Skip to content

Commit

Permalink
fixed bugs
Browse files Browse the repository at this point in the history
  • Loading branch information
xrsrke committed Mar 15, 2023
1 parent 2d278c3 commit 7c1d71d
Show file tree
Hide file tree
Showing 15 changed files with 160 additions and 152 deletions.
5 changes: 3 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ ToolFormer (Pytorch) - 🚧 WORK IN PROGRESS 🚧

<!-- WARNING: THIS FILE WAS AUTOGENERATED! DO NOT EDIT! -->

![image.png](index_files/figure-commonmark/ec9347e6-1-image.png)
![image.png](index_files/figure-commonmark/7a1d6aa5-1-image.png)

Paper: [Toolformer: Language Models Can Teach Themselves to Use
Tools](https://arxiv.org/abs/2302.04761)
Expand All @@ -13,11 +13,12 @@ Tools](https://arxiv.org/abs/2302.04761)

### TODO

- Support batch
- Support augment a batch of text
- Executing API calls in parallel

**API**

-
- Calendar API
- WolframeAlpha API

Expand Down
2 changes: 1 addition & 1 deletion configs/default.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ data_generator:

top_k_sampling: 3
sampling_threshold: 0.1
filtering_threshold: 0.2
filtering_threshold: 0.05

max_new_tokens: 100

Expand Down
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
8 changes: 4 additions & 4 deletions nbs/02_prompt.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -124,18 +124,18 @@
"source": [
"#| export\n",
"wolframe_prompt = \"\"\"\n",
"Your task is to add calls to a Scientific API to a piece of text. The questions should help you get information required to complete the text.\n",
"Your task is to add calls to a Scientific API to a piece of text that related to chemistry, math, physics. The questions should help you get information required to complete the text.\n",
"You can call the API by writing \"[Wolframe(question)]\" where \"question\" is the question you want to ask. Here are some examples of API calls:\n",
"\n",
"Input: Joe Biden was born in Scranton, Pennsylvania\n",
"Output: Joe Biden was born in [Wolframe(\"Where was Joe Biden born?\")] Scranton, [Wolframe(\"In which state is Scranton?\")] Pennsylvania.\n",
"\n",
"Input: The complex conjugate of 2 + 3i is 2 - 3i.\n",
"Output: The complex conjugate of 2 + 3i is [Wolframe(\"What is the complex conjugate of 2 - 3i\")] 2 - 3i.\n",
"\n",
"Input: Solve x^2 + 4x + 6 = 0. The answer is x = -2 - i sqrt(2)\n",
"Output: Solve x^2 + 4x + 6 = 0. The answer is [Wolframe(\"Solve x^2 + 4x + 6 = 0\")] x = -2 - i sqrt(2)\n",
"\n",
"Input: Given a sequence of numbers: 21.3, 38.4, 12.7, 41.6. The mean is 28.5\n",
"Output: Given a sequence of numbers: 21.3, 38.4, 12.7, 41.6. The mean is [Wolframe(\"What is the mean of 21.3, 38.4, 12.7, 41.6\")] 28.5\n",
"\n",
"Input: {input}\n",
"Output:\n",
"\"\"\""
Expand Down
54 changes: 11 additions & 43 deletions nbs/03_api.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -68,13 +68,17 @@
"source": [
"#| export\n",
"class BaseAPI:\n",
" # def __init__(\n",
" # self,\n",
" # name: str,\n",
" # prompt_template: PromptTemplate\n",
" # ):\n",
" # self.name = name\n",
" # self.prompt_template = prompt_template\n",
" def __init__(\n",
" self,\n",
" name: str, # the name of the API call\n",
" prompt_template: PromptTemplate,\n",
" sampling_threshold: float = 0.2,\n",
" filtering_threshold: float = 0.2,\n",
" ):\n",
" self.name = name\n",
" self.prompt_template = prompt_template\n",
" self.sampling_threshold = sampling_threshold\n",
" self.filtering_threshold = filtering_threshold\n",
"\n",
" @abstractclassmethod\n",
" def execute(self):\n",
Expand Down Expand Up @@ -129,42 +133,6 @@
" res = client.query(input=input)\n",
" return next(res.results).text"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"output = WolframeAPI()(\"solve x^2 + 4x + 6 = 0\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'x = -2 - i sqrt(2)'"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"output"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
Expand Down
64 changes: 38 additions & 26 deletions nbs/04_data_generator.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -167,9 +167,12 @@
" else:\n",
" i += 1\n",
" \n",
" _, indices = torch.sort(api_pos_probs[:, 0], descending=True)\n",
" top_k_sampling = self.top_k_sampling\n",
" api_positions = api_pos_probs[indices[:top_k_sampling], 1]\n",
" if api_pos_probs.numel() == 0:\n",
" api_positions = torch.tensor([])\n",
" else:\n",
" _, indices = torch.sort(api_pos_probs[:, 0], descending=True)\n",
" top_k_sampling = self.top_k_sampling\n",
" api_positions = api_pos_probs[indices[:top_k_sampling], 1]\n",
" \n",
" return api_positions.long(), generated_ids.long()\n",
"\n",
Expand Down Expand Up @@ -219,20 +222,20 @@
" \n",
" def _generate_conditioning_prompts(\n",
" self,\n",
" api: BaseAPI,\n",
" candidate_ids: TensorType[\"n_candidates\", \"seq_len\"],\n",
" ):\n",
" calculator_api = CalculatorAPI()\n",
" conditioning_api_ids = torch.tensor([])\n",
"\n",
" API_NAME = \"Calculator\"\n",
" API_NAME = api.name\n",
" MAX_PAD = 100\n",
"\n",
" for text_ids in candidate_ids:\n",
" # the ids of the prediction\n",
" text = self.tokenizer.decode(text_ids, skip_special_tokens=True)\n",
" \n",
" api_request_content = self.extract_api_request_content(text, api_name=API_NAME)\n",
" api_response = calculator_api(api_request_content)\n",
" api_response = api(api_request_content)\n",
" api_response_ids = self.tokenizer(api_response, return_tensors=\"pt\")[\"input_ids\"][0]\n",
" # Format: \"-> [api_response]\"\n",
" api_response_with_arrow_ids = torch.cat([self.api_output_token_id, api_response_ids], dim=0)\n",
Expand Down Expand Up @@ -302,23 +305,28 @@
" losses,\n",
" candidates: TensorType[\"seq_len\"]\n",
" ):\n",
" filtered_augmented_text_ids = []\n",
" filtered_augmented_text_ids = torch.tensor([])\n",
" for i, position in enumerate(losses):\n",
" negative_loss = min(losses[position][0], losses[position][1])\n",
" positive_loss = losses[position][2]\n",
" \n",
" if negative_loss - positive_loss >= self.filtering_threshold:\n",
" filtered_augmented_text_ids.append(candidates[i])\n",
" # filtered_augmented_text_ids.append(candidates[i])\n",
" filtered_augmented_text_ids = torch.cat([\n",
" filtered_augmented_text_ids,\n",
" candidates[i].unsqueeze(0)\n",
" ], dim=0)\n",
" \n",
" return filtered_augmented_text_ids\n",
" return filtered_augmented_text_ids.long()\n",
"\n",
" def filter_api( \n",
" self,\n",
" api: BaseAPI,\n",
" text_ids: TensorType[\"seq_len\"],\n",
" api_start_idxs: TensorType[\"n_positions\"],\n",
" candidate_ids: TensorType[\"n_positions\", \"seq_len\"]\n",
" ):\n",
" conditioning_api_ids = self._generate_conditioning_prompts(candidate_ids)\n",
" conditioning_api_ids = self._generate_conditioning_prompts(api, candidate_ids)\n",
" \n",
" SPACE_TOKEN = self.tokenizer(\". \", return_tensors=\"pt\")[\"input_ids\"][0]\n",
" API_LENGTH = 100\n",
Expand Down Expand Up @@ -398,26 +406,30 @@
" \n",
" def generate(\n",
" self,\n",
" prompt_tempalte: PromptTemplate,\n",
" text: str,\n",
" ) -> TensorType[\"n_candidates\", \"seq_len\"]:\n",
" # TODO: add support batch\n",
" prompt = prompt_tempalte.format(input=text)\n",
" prompt_ids = self.tokenizer(prompt, return_tensors=\"pt\")[\"input_ids\"][0]\n",
" \n",
" # sampling positions\n",
" api_start_idxs, generated_ids = self.sample_api_position(prompt_ids)\n",
" ) -> TensorType[\"n_apis\", \"n_candidates\", \"seq_len\"]:\n",
" filtered_apis = torch.tensor([])\n",
" \n",
" for api in self.apis:\n",
" # TODO: add support batch\n",
" prompt = api.prompt_template.format(input=text)\n",
" prompt_ids = self.tokenizer(prompt, return_tensors=\"pt\")[\"input_ids\"][0]\n",
" \n",
" # obtaining api responses\n",
" candidate_ids = self.obtain_api_response(prompt_ids, api_start_idxs, generated_ids)\n",
" # sampling positions\n",
" api_start_idxs, generated_ids = self.sample_api_position(prompt_ids)\n",
" \n",
" # obtaining api responses\n",
" candidate_ids = self.obtain_api_response(prompt_ids, api_start_idxs, generated_ids)\n",
"\n",
" # filtering\n",
" text_ids = self.tokenizer(text, return_tensors=\"pt\")[\"input_ids\"][0]\n",
" # filtering\n",
" text_ids = self.tokenizer(text, return_tensors=\"pt\")[\"input_ids\"][0]\n",
" \n",
" # return prompt_ids, api_start_idxs, generated_ids, candidate_ids, text_ids\n",
" filtered_candidate_ids = self.filter_api(api, text_ids, api_start_idxs, candidate_ids)\n",
" \n",
" filtered_apis = torch.cat([filtered_apis, filtered_candidate_ids.unsqueeze(0)], dim=0)\n",
" \n",
" # return prompt_ids, api_start_idxs, generated_ids, candidate_ids, text_ids\n",
" filtered_candidate_ids = self.filter_api(text_ids, api_start_idxs, candidate_ids)\n",
" \n",
" return filtered_candidate_ids"
" return filtered_apis.long()"
]
}
],
Expand Down
3 changes: 2 additions & 1 deletion nbs/index.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -44,12 +44,13 @@
"metadata": {},
"source": [
"### TODO\n",
"- Support batch\n",
"- Support augment a batch of text\n",
"- Executing API calls in parallel\n",
"\n",
"\n",
"**API**\n",
"\n",
"- \n",
"- Calendar API\n",
"- WolframeAlpha API"
]
Expand Down
34 changes: 14 additions & 20 deletions settings.ini
Original file line number Diff line number Diff line change
@@ -1,43 +1,37 @@
[DEFAULT]
# All sections below are required unless otherwise specified.
# See https://github.com/fastai/nbdev/blob/master/settings.ini for examples.

### Python library ###
repo = toolformer
lib_name = %(repo)s
version = 0.0.1
lib_name = toolformer
version = 0.0.2
min_python = 3.7
license = apache2
black_formatting = False

### nbdev ###
doc_path = _docs
lib_path = toolformer
nbs_path = nbs
recursive = True
tst_flags = notest
put_version_in_init = True

### Docs ###
branch = main
custom_sidebar = False
doc_host = https://%(user)s.github.io
doc_baseurl = /%(repo)s
git_url = https://github.com/%(user)s/%(repo)s
title = %(lib_name)s

### PyPI ###
doc_host = https://xrsrke.github.io
doc_baseurl = /toolformer
git_url = https://github.com/xrsrke/toolformer
title = toolformer
audience = Developers
author = xrsrke
author_email = [email protected]
copyright = 2023 onwards, %(author)s
copyright = 2023 onwards, xrsrke
description = Implementation of Toolformer
keywords = nbdev jupyter notebook python
language = English
status = 3
user = xrsrke

### Optional ###
requirements = torch einops torchtyping langchain transformers datasets wolframalpha
dev_requirements = pytest
# console_scripts =
readme_nb = index.ipynb
allowed_metadata_keys =
allowed_cell_metadata_keys =
jupyter_hooks = True
clean_ids = True
clear_all = False

16 changes: 13 additions & 3 deletions tests/test_api.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import pytest

from toolformer.api import CalculatorAPI
from toolformer.api import CalculatorAPI, WolframeAPI
from toolformer.prompt import calculator_prompt, wolframe_prompt


# generate test for execute_calculator
Expand All @@ -13,9 +14,18 @@
)
)
def test_execute_calculator_api(input, expected):
calculator_api = CalculatorAPI()
calculator_api = CalculatorAPI("Calculator", calculator_prompt)

output = calculator_api(input)

assert output == expected
assert isinstance(output, str)
assert isinstance(output, str)

def test_execute_wolframe_api():
wolframe_api = WolframeAPI("Wolframe", wolframe_prompt)

input = "integrate x^2 sin^3 x dx"
output = wolframe_api(input)

assert isinstance(output, str)
assert len(output) > 0
33 changes: 19 additions & 14 deletions tests/test_data_generator.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,12 @@
import pytest

import torch
import torch.nn.functional as F
from langchain import PromptTemplate

from toolformer.data_generator import DataGenerator
from toolformer.prompt import calculator_prompt

def test_create_data_generator(default_config):
# model = AutoModelForCausalLM.from_pretrained(default_config['model']['path'])
# tokenizer = AutoTokenizer.from_pretrained(default_config['tokenizer']['path'])

# generator = DataGenerator(default_config, model, tokenizer, apis=[])
pass
from toolformer.api import CalculatorAPI, WolframeAPI
from toolformer.prompt import calculator_prompt, wolframe_prompt

def test_sampling_apis_call(
data_generator, prompt_tempalte,
Expand Down Expand Up @@ -76,12 +72,21 @@ def test_filtering_api_call(default_config, model, tokenizer):

assert isinstance(filtered_candidate_ids, list)

def test_generate_data_generator(default_config, model, tokenizer):
text = "From this, we have 10 - 5 minutes = 5 minutes."
prompt_tempalte = PromptTemplate(template=calculator_prompt, input_variables=["input"])
calculator_api = CalculatorAPI("Calculator", calculator_prompt)
wolframe_api = WolframeAPI("Wolframe", wolframe_prompt)

generator = DataGenerator(default_config, model, tokenizer, apis=[])
@pytest.mark.parametrize("apis", [
[calculator_api],
[wolframe_api],
[calculator_api, wolframe_api],
])
def test_generate_data_generator(default_config, model, tokenizer, apis):
text = "From this, we have 10 - 5 minutes [Calculator(10 - 5)] 5 minutes."

generator = DataGenerator(default_config, model, tokenizer, apis=apis)

filtered_candidate_ids = generator.generate(prompt_tempalte, text)
filtered_candidate_ids = generator.generate(text)

assert isinstance(filtered_candidate_ids, list)
assert filtered_candidate_ids.shape[0] == len(apis)
assert filtered_candidate_ids.ndim == 3
assert isinstance(filtered_candidate_ids, torch.Tensor)
Loading

0 comments on commit 7c1d71d

Please sign in to comment.