diff --git a/README.md b/README.md
index 9612b7f..4388a20 100644
--- a/README.md
+++ b/README.md
@@ -3,7 +3,7 @@ ToolFormer (Pytorch)
 
 - 🚧 WORK IN PROGRESS 🚧
 
-![image.png](index_files/figure-commonmark/ec9347e6-1-image.png)
+![image.png](index_files/figure-commonmark/7a1d6aa5-1-image.png)
 
 Paper: [Toolformer: Language Models Can Teach Themselves to Use
 Tools](https://arxiv.org/abs/2302.04761)
@@ -13,11 +13,12 @@ Tools](https://arxiv.org/abs/2302.04761)
 
 ### TODO
 
-- Support batch
+- Support augmenting a batch of text
 - Executing API calls in parallel
 
 
 **API**
 
+- 
 - Calendar API
 - WolframeAlpha API
diff --git a/configs/default.yaml b/configs/default.yaml
index 53c57d4..a173116 100644
--- a/configs/default.yaml
+++ b/configs/default.yaml
@@ -15,7 +15,7 @@ data_generator:
   top_k_sampling: 3
   sampling_threshold: 0.1
-  filtering_threshold: 0.2
+  filtering_threshold: 0.05
   max_new_tokens: 100
diff --git a/index_files/figure-commonmark/7a1d6aa5-1-image.png b/index_files/figure-commonmark/7a1d6aa5-1-image.png
new file mode 100644
index 0000000..02946d9
Binary files /dev/null and b/index_files/figure-commonmark/7a1d6aa5-1-image.png differ
diff --git a/nbs/02_prompt.ipynb b/nbs/02_prompt.ipynb
index 2e8568f..ce09116 100644
--- a/nbs/02_prompt.ipynb
+++ b/nbs/02_prompt.ipynb
@@ -124,18 +124,18 @@
    "source": [
     "#| export\n",
     "wolframe_prompt = \"\"\"\n",
-    "Your task is to add calls to a Scientific API to a piece of text. The questions should help you get information required to complete the text.\n",
+    "Your task is to add calls to a Scientific API to a piece of text related to chemistry, math, and physics. The questions should help you get information required to complete the text.\n",
     "You can call the API by writing \"[Wolframe(question)]\" where \"question\" is the question you want to ask. Here are some examples of API calls:\n",
     "\n",
-    "Input: Joe Biden was born in Scranton, Pennsylvania\n",
-    "Output: Joe Biden was born in [Wolframe(\"Where was Joe Biden born?\")] Scranton, [Wolframe(\"In which state is Scranton?\")] Pennsylvania.\n",
-    "\n",
     "Input: The complex conjugate of 2 + 3i is 2 - 3i.\n",
     "Output: The complex conjugate of 2 + 3i is [Wolframe(\"What is the complex conjugate of 2 - 3i\")] 2 - 3i.\n",
     "\n",
     "Input: Solve x^2 + 4x + 6 = 0. The answer is x = -2 - i sqrt(2)\n",
     "Output: Solve x^2 + 4x + 6 = 0. The answer is [Wolframe(\"Solve x^2 + 4x + 6 = 0\")] x = -2 - i sqrt(2)\n",
     "\n",
+    "Input: Given a sequence of numbers: 21.3, 38.4, 12.7, 41.6. The mean is 28.5\n",
+    "Output: Given a sequence of numbers: 21.3, 38.4, 12.7, 41.6. The mean is [Wolframe(\"What is the mean of 21.3, 38.4, 12.7, 41.6\")] 28.5\n",
+    "\n",
     "Input: {input}\n",
     "Output:\n",
     "\"\"\""
diff --git a/nbs/03_api.ipynb b/nbs/03_api.ipynb
index 1d1e92f..03eaefb 100644
--- a/nbs/03_api.ipynb
+++ b/nbs/03_api.ipynb
@@ -68,13 +68,17 @@
    "source": [
     "#| export\n",
     "class BaseAPI:\n",
-    "    # def __init__(\n",
-    "    #     self,\n",
-    "    #     name: str,\n",
-    "    #     prompt_template: PromptTemplate\n",
-    "    # ):\n",
-    "    #     self.name = name\n",
-    "    #     self.prompt_template = prompt_template\n",
+    "    def __init__(\n",
+    "        self,\n",
+    "        name: str, # the name of the API call\n",
+    "        prompt_template: PromptTemplate,\n",
+    "        sampling_threshold: float = 0.2,\n",
+    "        filtering_threshold: float = 0.2,\n",
+    "    ):\n",
+    "        self.name = name\n",
+    "        self.prompt_template = prompt_template\n",
+    "        self.sampling_threshold = sampling_threshold\n",
+    "        self.filtering_threshold = filtering_threshold\n",
     "\n",
     "    @abstractclassmethod\n",
     "    def execute(self):\n",
@@ -129,42 +133,6 @@
     "    res = client.query(input=input)\n",
     "    return next(res.results).text"
    ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "output = WolframeAPI()(\"solve x^2 + 4x + 6 = 0\")"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [
-    {
-     "data": {
-      "text/plain": [
-       "'x = -2 - i sqrt(2)'"
-      ]
-     },
-     "execution_count": null,
-     "metadata": {},
-     "output_type": "execute_result"
-    }
-   ],
-   "source": [
-    "output"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": []
  }
 ],
 "metadata": {
diff --git a/nbs/04_data_generator.ipynb b/nbs/04_data_generator.ipynb
index 796cc2a..9d4f982 100644
--- a/nbs/04_data_generator.ipynb
+++ b/nbs/04_data_generator.ipynb
@@ -167,9 +167,12 @@
     "        else:\n",
     "            i += 1\n",
     "    \n",
-    "        _, indices = torch.sort(api_pos_probs[:, 0], descending=True)\n",
-    "        top_k_sampling = self.top_k_sampling\n",
-    "        api_positions = api_pos_probs[indices[:top_k_sampling], 1]\n",
+    "        if api_pos_probs.numel() == 0:\n",
+    "            api_positions = torch.tensor([])\n",
+    "        else:\n",
+    "            _, indices = torch.sort(api_pos_probs[:, 0], descending=True)\n",
+    "            top_k_sampling = self.top_k_sampling\n",
+    "            api_positions = api_pos_probs[indices[:top_k_sampling], 1]\n",
     "    \n",
     "        return api_positions.long(), generated_ids.long()\n",
     "\n",
@@ -219,12 +222,12 @@
     "    \n",
     "    def _generate_conditioning_prompts(\n",
     "        self,\n",
+    "        api: BaseAPI,\n",
     "        candidate_ids: TensorType[\"n_candidates\", \"seq_len\"],\n",
     "    ):\n",
-    "        calculator_api = CalculatorAPI()\n",
     "        conditioning_api_ids = torch.tensor([])\n",
     "\n",
-    "        API_NAME = \"Calculator\"\n",
+    "        API_NAME = api.name\n",
     "        MAX_PAD = 100\n",
     "\n",
     "        for text_ids in candidate_ids:\n",
@@ -232,7 +235,7 @@
     "            text = self.tokenizer.decode(text_ids, skip_special_tokens=True)\n",
     "            \n",
     "            api_request_content = self.extract_api_request_content(text, api_name=API_NAME)\n",
-    "            api_response = calculator_api(api_request_content)\n",
+    "            api_response = api(api_request_content)\n",
     "            api_response_ids = self.tokenizer(api_response, return_tensors=\"pt\")[\"input_ids\"][0]\n",
     "            # Format: \"-> [api_response]\"\n",
     "            api_response_with_arrow_ids = torch.cat([self.api_output_token_id, api_response_ids], dim=0)\n",
@@ -302,23 +305,28 @@
     "        losses,\n",
     "        candidates: TensorType[\"seq_len\"]\n",
     "    ):\n",
-    "        filtered_augmented_text_ids = []\n",
+    "        filtered_augmented_text_ids = torch.tensor([])\n",
     "        for i, position in enumerate(losses):\n",
     "            negative_loss = min(losses[position][0], losses[position][1])\n",
     "            positive_loss = losses[position][2]\n",
     "            \n",
     "            if negative_loss - positive_loss >= self.filtering_threshold:\n",
-    "                filtered_augmented_text_ids.append(candidates[i])\n",
+    "                # filtered_augmented_text_ids.append(candidates[i])\n",
+    "                filtered_augmented_text_ids = torch.cat([\n",
+    "                    filtered_augmented_text_ids,\n",
+    "                    candidates[i].unsqueeze(0)\n",
+    "                ], dim=0)\n",
     "            \n",
-    "        return filtered_augmented_text_ids\n",
+    "        return filtered_augmented_text_ids.long()\n",
     "\n",
     "    def filter_api( \n",
     "        self,\n",
+    "        api: BaseAPI,\n",
     "        text_ids: TensorType[\"seq_len\"],\n",
     "        api_start_idxs: TensorType[\"n_positions\"],\n",
     "        candidate_ids: TensorType[\"n_positions\", \"seq_len\"]\n",
     "    ):\n",
-    "        conditioning_api_ids = self._generate_conditioning_prompts(candidate_ids)\n",
+    "        conditioning_api_ids = self._generate_conditioning_prompts(api, candidate_ids)\n",
     "        \n",
     "        SPACE_TOKEN = self.tokenizer(\". \", return_tensors=\"pt\")[\"input_ids\"][0]\n",
     "        API_LENGTH = 100\n",
@@ -398,26 +406,30 @@
     "    \n",
     "    def generate(\n",
     "        self,\n",
-    "        prompt_tempalte: PromptTemplate,\n",
     "        text: str,\n",
-    "    ) -> TensorType[\"n_candidates\", \"seq_len\"]:\n",
-    "        # TODO: add support batch\n",
-    "        prompt = prompt_tempalte.format(input=text)\n",
-    "        prompt_ids = self.tokenizer(prompt, return_tensors=\"pt\")[\"input_ids\"][0]\n",
-    "        \n",
-    "        # sampling positions\n",
-    "        api_start_idxs, generated_ids = self.sample_api_position(prompt_ids)\n",
+    "    ) -> TensorType[\"n_apis\", \"n_candidates\", \"seq_len\"]:\n",
+    "        filtered_apis = torch.tensor([])\n",
+    "        \n",
+    "        for api in self.apis:\n",
+    "            # TODO: add support batch\n",
+    "            prompt = api.prompt_template.format(input=text)\n",
+    "            prompt_ids = self.tokenizer(prompt, return_tensors=\"pt\")[\"input_ids\"][0]\n",
     "        \n",
-    "        # obtaining api responses\n",
-    "        candidate_ids = self.obtain_api_response(prompt_ids, api_start_idxs, generated_ids)\n",
+    "            # sampling positions\n",
+    "            api_start_idxs, generated_ids = self.sample_api_position(prompt_ids)\n",
+    "            \n",
+    "            # obtaining api responses\n",
+    "            candidate_ids = self.obtain_api_response(prompt_ids, api_start_idxs, generated_ids)\n",
     "\n",
-    "        # filtering\n",
-    "        text_ids = self.tokenizer(text, return_tensors=\"pt\")[\"input_ids\"][0]\n",
+    "            # filtering\n",
+    "            text_ids = self.tokenizer(text, return_tensors=\"pt\")[\"input_ids\"][0]\n",
+    "            \n",
+    "            # return prompt_ids, api_start_idxs, generated_ids, candidate_ids, text_ids\n",
+    "            filtered_candidate_ids = self.filter_api(api, text_ids, api_start_idxs, candidate_ids)\n",
+    "            \n",
+    "            filtered_apis = torch.cat([filtered_apis, filtered_candidate_ids.unsqueeze(0)], dim=0)\n",
     "        \n",
-    "        # return prompt_ids, api_start_idxs, generated_ids, candidate_ids, text_ids\n",
-    "        filtered_candidate_ids = self.filter_api(text_ids, api_start_idxs, candidate_ids)\n",
-    "        \n",
-    "        return filtered_candidate_ids"
+    "        return filtered_apis.long()"
    ]
   }
  ],
diff --git a/nbs/index.ipynb b/nbs/index.ipynb
index 5955855..73bb356 100644
--- a/nbs/index.ipynb
+++ b/nbs/index.ipynb
@@ -44,12 +44,13 @@
    "metadata": {},
    "source": [
     "### TODO\n",
-    "- Support batch\n",
+    "- Support augmenting a batch of text\n",
     "- Executing API calls in parallel\n",
     "\n",
     "\n",
     "**API**\n",
     "\n",
+    "- \n",
     "- Calendar API\n",
     "- WolframeAlpha API"
    ]
diff --git a/settings.ini b/settings.ini
index 0558518..ed31971 100644
--- a/settings.ini
+++ b/settings.ini
@@ -1,43 +1,37 @@
 [DEFAULT]
-# All sections below are required unless otherwise specified.
-# See https://github.com/fastai/nbdev/blob/master/settings.ini for examples.
-
-### Python library ###
 repo = toolformer
-lib_name = %(repo)s
-version = 0.0.1
+lib_name = toolformer
+version = 0.0.2
 min_python = 3.7
 license = apache2
 black_formatting = False
-
-### nbdev ###
 doc_path = _docs
 lib_path = toolformer
 nbs_path = nbs
 recursive = True
 tst_flags = notest
 put_version_in_init = True
-
-### Docs ###
 branch = main
 custom_sidebar = False
-doc_host = https://%(user)s.github.io
-doc_baseurl = /%(repo)s
-git_url = https://github.com/%(user)s/%(repo)s
-title = %(lib_name)s
-
-### PyPI ###
+doc_host = https://xrsrke.github.io
+doc_baseurl = /toolformer
+git_url = https://github.com/xrsrke/toolformer
+title = toolformer
 audience = Developers
 author = xrsrke
 author_email = xariusdrake@hotmail.com
-copyright = 2023 onwards, %(author)s
+copyright = 2023 onwards, xrsrke
 description = Implementation of Toolformer
 keywords = nbdev jupyter notebook python
 language = English
 status = 3
 user = xrsrke
-
-### Optional ###
 requirements = torch einops torchtyping langchain transformers datasets wolframalpha
 dev_requirements = pytest
-# console_scripts = 
\ No newline at end of file
+readme_nb = index.ipynb
+allowed_metadata_keys = 
+allowed_cell_metadata_keys = 
+jupyter_hooks = True
+clean_ids = True
+clear_all = False
+
diff --git a/tests/test_api.py b/tests/test_api.py
index ea89b14..7e7d3d7 100644
--- a/tests/test_api.py
+++ b/tests/test_api.py
@@ -1,6 +1,7 @@
 import pytest
 
-from toolformer.api import CalculatorAPI
+from toolformer.api import CalculatorAPI, WolframeAPI
+from toolformer.prompt import calculator_prompt, wolframe_prompt
 
 
 # generate test for execute_calculator
@@ -13,9 +14,18 @@
     )
 )
 def test_execute_calculator_api(input, expected):
-    calculator_api = CalculatorAPI()
+    calculator_api = CalculatorAPI("Calculator", calculator_prompt)
 
     output = calculator_api(input)
 
     assert output == expected
-    assert isinstance(output, str)
\ No newline at end of file
+    assert isinstance(output, str)
+
+def test_execute_wolframe_api():
+    wolframe_api = WolframeAPI("Wolframe", wolframe_prompt)
+
+    input = "integrate x^2 sin^3 x dx"
+    output = wolframe_api(input)
+
+    assert isinstance(output, str)
+    assert len(output) > 0
\ No newline at end of file
diff --git a/tests/test_data_generator.py b/tests/test_data_generator.py
index e1e027d..275fe11 100644
--- a/tests/test_data_generator.py
+++ b/tests/test_data_generator.py
@@ -1,16 +1,12 @@
+import pytest
+
 import torch
 import torch.nn.functional as F
 from langchain import PromptTemplate
 
 from toolformer.data_generator import DataGenerator
-from toolformer.prompt import calculator_prompt
-
-def test_create_data_generator(default_config):
-    # model = AutoModelForCausalLM.from_pretrained(default_config['model']['path'])
-    # tokenizer = AutoTokenizer.from_pretrained(default_config['tokenizer']['path'])
-
-    # generator = DataGenerator(default_config, model, tokenizer, apis=[])
-    pass
+from toolformer.api import CalculatorAPI, WolframeAPI
+from toolformer.prompt import calculator_prompt, wolframe_prompt
 
 def test_sampling_apis_call(
     data_generator, prompt_tempalte,
@@ -76,12 +72,21 @@ def test_filtering_api_call(default_config, model, tokenizer):
 
     assert isinstance(filtered_candidate_ids, list)
 
-def test_generate_data_generator(default_config, model, tokenizer):
-    text = "From this, we have 10 - 5 minutes = 5 minutes."
-    prompt_tempalte = PromptTemplate(template=calculator_prompt, input_variables=["input"])
+calculator_api = CalculatorAPI("Calculator", calculator_prompt)
+wolframe_api = WolframeAPI("Wolframe", wolframe_prompt)
 
-    generator = DataGenerator(default_config, model, tokenizer, apis=[])
+@pytest.mark.parametrize("apis", [
+    [calculator_api],
+    [wolframe_api],
+    [calculator_api, wolframe_api],
+])
+def test_generate_data_generator(default_config, model, tokenizer, apis):
+    text = "From this, we have 10 - 5 minutes [Calculator(10 - 5)] 5 minutes."
+
+    generator = DataGenerator(default_config, model, tokenizer, apis=apis)
 
-    filtered_candidate_ids = generator.generate(prompt_tempalte, text)
+    filtered_candidate_ids = generator.generate(text)
 
-    assert isinstance(filtered_candidate_ids, list)
\ No newline at end of file
+    assert filtered_candidate_ids.shape[0] == len(apis)
+    assert filtered_candidate_ids.ndim == 3
+    assert isinstance(filtered_candidate_ids, torch.Tensor)
\ No newline at end of file
diff --git a/toolformer/__init__.py b/toolformer/__init__.py
index f102a9c..3b93d0b 100644
--- a/toolformer/__init__.py
+++ b/toolformer/__init__.py
@@ -1 +1 @@
-__version__ = "0.0.1"
+__version__ = "0.0.2"
diff --git a/toolformer/_modidx.py b/toolformer/_modidx.py
index 85ffb6e..a650448 100644
--- a/toolformer/_modidx.py
+++ b/toolformer/_modidx.py
@@ -7,6 +7,7 @@
                 'lib_path': 'toolformer'},
  'syms': { 'toolformer.api': { 'toolformer.api.BaseAPI': ('api.html#baseapi', 'toolformer/api.py'),
                                'toolformer.api.BaseAPI.__call__': ('api.html#baseapi.__call__', 'toolformer/api.py'),
+                               'toolformer.api.BaseAPI.__init__': ('api.html#baseapi.__init__', 'toolformer/api.py'),
                               'toolformer.api.BaseAPI.execute': ('api.html#baseapi.execute', 'toolformer/api.py'),
                               'toolformer.api.CalculatorAPI': ('api.html#calculatorapi', 'toolformer/api.py'),
                               'toolformer.api.CalculatorAPI.execute': ('api.html#calculatorapi.execute', 'toolformer/api.py'),
diff --git a/toolformer/api.py b/toolformer/api.py
index e724eff..668769d 100644
--- a/toolformer/api.py
+++ b/toolformer/api.py
@@ -11,13 +11,17 @@
 
 # %% ../nbs/03_api.ipynb 6
 class BaseAPI:
-    # def __init__(
-    #     self,
-    #     name: str,
-    #     prompt_template: PromptTemplate
-    # ):
-    #     self.name = name
-    #     self.prompt_template = prompt_template
+    def __init__(
+        self,
+        name: str, # the name of the API call
+        prompt_template: PromptTemplate,
+        sampling_threshold: float = 0.2,
+        filtering_threshold: float = 0.2,
+    ):
+        self.name = name
+        self.prompt_template = prompt_template
+        self.sampling_threshold = sampling_threshold
+        self.filtering_threshold = filtering_threshold
 
     @abstractclassmethod
     def execute(self):
diff --git a/toolformer/data_generator.py b/toolformer/data_generator.py
index aff8ee0..4fd65aa 100644
--- a/toolformer/data_generator.py
+++ b/toolformer/data_generator.py
@@ -111,9 +111,12 @@ def sample_api_position(
         else:
             i += 1
 
-        _, indices = torch.sort(api_pos_probs[:, 0], descending=True)
-        top_k_sampling = self.top_k_sampling
-        api_positions = api_pos_probs[indices[:top_k_sampling], 1]
+        if api_pos_probs.numel() == 0:
+            api_positions = torch.tensor([])
+        else:
+            _, indices = torch.sort(api_pos_probs[:, 0], descending=True)
+            top_k_sampling = self.top_k_sampling
+            api_positions = api_pos_probs[indices[:top_k_sampling], 1]
 
         return api_positions.long(), generated_ids.long()
 
@@ -163,12 +166,12 @@ def obtain_api_response(
 
     def _generate_conditioning_prompts(
         self,
+        api: BaseAPI,
         candidate_ids: TensorType["n_candidates", "seq_len"],
     ):
-        calculator_api = CalculatorAPI()
         conditioning_api_ids = torch.tensor([])
 
-        API_NAME = "Calculator"
+        API_NAME = api.name
         MAX_PAD = 100
 
         for text_ids in candidate_ids:
@@ -176,7 +179,7 @@ def _generate_conditioning_prompts(
             text = self.tokenizer.decode(text_ids, skip_special_tokens=True)
 
             api_request_content = self.extract_api_request_content(text, api_name=API_NAME)
-            api_response = calculator_api(api_request_content)
+            api_response = api(api_request_content)
             api_response_ids = self.tokenizer(api_response, return_tensors="pt")["input_ids"][0]
             # Format: "-> [api_response]"
             api_response_with_arrow_ids = torch.cat([self.api_output_token_id, api_response_ids], dim=0)
@@ -246,23 +249,28 @@ def _filter_candidate_by_threshold(
         losses,
         candidates: TensorType["seq_len"]
     ):
-        filtered_augmented_text_ids = []
+        filtered_augmented_text_ids = torch.tensor([])
         for i, position in enumerate(losses):
             negative_loss = min(losses[position][0], losses[position][1])
             positive_loss = losses[position][2]
 
             if negative_loss - positive_loss >= self.filtering_threshold:
-                filtered_augmented_text_ids.append(candidates[i])
+                # filtered_augmented_text_ids.append(candidates[i])
+                filtered_augmented_text_ids = torch.cat([
+                    filtered_augmented_text_ids,
+                    candidates[i].unsqueeze(0)
+                ], dim=0)
 
-        return filtered_augmented_text_ids
+        return filtered_augmented_text_ids.long()
 
     def filter_api( 
         self,
+        api: BaseAPI,
         text_ids: TensorType["seq_len"],
         api_start_idxs: TensorType["n_positions"],
         candidate_ids: TensorType["n_positions", "seq_len"]
     ):
-        conditioning_api_ids = self._generate_conditioning_prompts(candidate_ids)
+        conditioning_api_ids = self._generate_conditioning_prompts(api, candidate_ids)
 
         SPACE_TOKEN = self.tokenizer(". ", return_tensors="pt")["input_ids"][0]
         API_LENGTH = 100
@@ -342,23 +350,27 @@ def extract_target_logprob_from_logits(logits, target_ids):
 
     def generate(
         self,
-        prompt_tempalte: PromptTemplate,
         text: str,
-    ) -> TensorType["n_candidates", "seq_len"]:
-        # TODO: add support batch
-        prompt = prompt_tempalte.format(input=text)
-        prompt_ids = self.tokenizer(prompt, return_tensors="pt")["input_ids"][0]
-
-        # sampling positions
-        api_start_idxs, generated_ids = self.sample_api_position(prompt_ids)
+    ) -> TensorType["n_apis", "n_candidates", "seq_len"]:
+        filtered_apis = torch.tensor([])
 
-        # obtaining api responses
-        candidate_ids = self.obtain_api_response(prompt_ids, api_start_idxs, generated_ids)
+        for api in self.apis:
+            # TODO: add support batch
+            prompt = api.prompt_template.format(input=text)
+            prompt_ids = self.tokenizer(prompt, return_tensors="pt")["input_ids"][0]
+
+            # sampling positions
+            api_start_idxs, generated_ids = self.sample_api_position(prompt_ids)
+
+            # obtaining api responses
+            candidate_ids = self.obtain_api_response(prompt_ids, api_start_idxs, generated_ids)
 
-        # filtering
-        text_ids = self.tokenizer(text, return_tensors="pt")["input_ids"][0]
+            # filtering
+            text_ids = self.tokenizer(text, return_tensors="pt")["input_ids"][0]
+
+            # return prompt_ids, api_start_idxs, generated_ids, candidate_ids, text_ids
+            filtered_candidate_ids = self.filter_api(api, text_ids, api_start_idxs, candidate_ids)
+
+            filtered_apis = torch.cat([filtered_apis, filtered_candidate_ids.unsqueeze(0)], dim=0)
 
-        # return prompt_ids, api_start_idxs, generated_ids, candidate_ids, text_ids
-        filtered_candidate_ids = self.filter_api(text_ids, api_start_idxs, candidate_ids)
-
-        return filtered_candidate_ids
+        return filtered_apis.long()
diff --git a/toolformer/prompt.py b/toolformer/prompt.py
index 28a5ae5..8e6ae7f 100644
--- a/toolformer/prompt.py
+++ b/toolformer/prompt.py
@@ -43,18 +43,18 @@
 
 # %% ../nbs/02_prompt.ipynb 9
 wolframe_prompt = """
-Your task is to add calls to a Scientific API to a piece of text. The questions should help you get information required to complete the text.
+Your task is to add calls to a Scientific API to a piece of text related to chemistry, math, and physics. The questions should help you get information required to complete the text.
 You can call the API by writing "[Wolframe(question)]" where "question" is the question you want to ask. Here are some examples of API calls:
 
-Input: Joe Biden was born in Scranton, Pennsylvania
-Output: Joe Biden was born in [Wolframe("Where was Joe Biden born?")] Scranton, [Wolframe("In which state is Scranton?")] Pennsylvania.
-
 Input: The complex conjugate of 2 + 3i is 2 - 3i.
 Output: The complex conjugate of 2 + 3i is [Wolframe("What is the complex conjugate of 2 - 3i")] 2 - 3i.
 
 Input: Solve x^2 + 4x + 6 = 0. The answer is x = -2 - i sqrt(2)
 Output: Solve x^2 + 4x + 6 = 0. The answer is [Wolframe("Solve x^2 + 4x + 6 = 0")] x = -2 - i sqrt(2)
 
+Input: Given a sequence of numbers: 21.3, 38.4, 12.7, 41.6. The mean is 28.5
+Output: Given a sequence of numbers: 21.3, 38.4, 12.7, 41.6. The mean is [Wolframe("What is the mean of 21.3, 38.4, 12.7, 41.6")] 28.5
+
 Input: {input}
 Output:
 """
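
With this change applied, each API object carries its own name and prompt template (the new BaseAPI.__init__), and DataGenerator.generate loops over the configured APIs, returning a LongTensor of shape [n_apis, n_candidates, seq_len] instead of a list. A minimal usage sketch of the new call pattern, mirroring tests/test_data_generator.py; the yaml loading and the GPT-2 checkpoint below are illustrative assumptions, not part of this diff:

    import yaml
    from transformers import AutoModelForCausalLM, AutoTokenizer

    from toolformer.api import CalculatorAPI, WolframeAPI
    from toolformer.data_generator import DataGenerator
    from toolformer.prompt import calculator_prompt, wolframe_prompt

    # Assumed setup: configs/default.yaml ships with the repo, but the
    # checkpoint name is only a stand-in for whatever model you use.
    # WolframeAPI additionally needs WolframAlpha credentials at run time.
    with open("configs/default.yaml") as f:
        config = yaml.safe_load(f)
    model = AutoModelForCausalLM.from_pretrained("gpt2")
    tokenizer = AutoTokenizer.from_pretrained("gpt2")

    # APIs are constructed with an explicit name and prompt template,
    # so generate() no longer takes a PromptTemplate argument.
    apis = [
        CalculatorAPI("Calculator", calculator_prompt),
        WolframeAPI("Wolframe", wolframe_prompt),
    ]
    generator = DataGenerator(config, model, tokenizer, apis=apis)

    # One augmentation pass per configured API; one tensor slice per API.
    augmented_ids = generator.generate("From this, we have 10 - 5 minutes = 5 minutes.")
    assert augmented_ids.ndim == 3 and augmented_ids.shape[0] == len(apis)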
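The new BaseAPI.__init__ also clears a path for the Calendar API still open in the TODO list. A hypothetical sketch, under two stated assumptions: calendar_prompt does not exist in toolformer/prompt.py yet, and BaseAPI.__call__ dispatches to execute() in the same way CalculatorAPI and WolframeAPI are invoked in the tests above:

    from datetime import date

    from toolformer.api import BaseAPI

    # Hypothetical prompt; a real one would follow the few-shot format of
    # calculator_prompt / wolframe_prompt in toolformer/prompt.py.
    calendar_prompt = """Your task is to add calls to a Calendar API to a piece of text.
    Input: {input}
    Output:
    """

    class CalendarAPI(BaseAPI):
        # Only execute() needs overriding; name, prompt_template, and the two
        # thresholds are stored by the new BaseAPI.__init__.
        def execute(self, input: str) -> str:
            # Like the Toolformer paper's Calendar tool, ignore the request
            # content and return the current date.
            return date.today().strftime("Today is %A, %B %d, %Y.")

    calendar_api = CalendarAPI("Calendar", calendar_prompt)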