Skip to content

Commit

Permalink
Merge pull request #1 from xrsrke/fix-loss
Browse files Browse the repository at this point in the history
refactor filtering api calls
  • Loading branch information
XλRI-U5 authored Mar 14, 2023
2 parents 303d209 + 1063f37 commit 2d278c3
Show file tree
Hide file tree
Showing 5 changed files with 54 additions and 58 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ ToolFormer (Pytorch) - 🚧 WORK IN PROGRESS 🚧

<!-- WARNING: THIS FILE WAS AUTOGENERATED! DO NOT EDIT! -->

![image.png](index_files/figure-commonmark/369b5b37-1-image.png)
![image.png](index_files/figure-commonmark/ec9347e6-1-image.png)

Paper: [Toolformer: Language Models Can Teach Themselves to Use
Tools](https://arxiv.org/abs/2302.04761)
Expand Down
2 changes: 1 addition & 1 deletion configs/default.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ data_generator:

top_k_sampling: 3
sampling_threshold: 0.1
filtering_threshold: 0.3
filtering_threshold: 0.2

max_new_tokens: 100

Expand Down
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
54 changes: 26 additions & 28 deletions nbs/04_data_generator.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -122,10 +122,10 @@
" \n",
" def sample_api_position(\n",
" self,\n",
" prompt_ids: TensorType[\"batch_size\", \"seq_len\"], # the ids of the prompt\n",
" prompt_ids: TensorType[\"seq_len\"], # the ids of the prompt\n",
" ) -> Tuple[\n",
" TensorType[\"batch_size\", \"n_positions\"], # The positions of api call\n",
" TensorType[\"batch_size\", \"seq_len\"] # The generated text\n",
" TensorType[\"n_positions\"], # The positions of api call\n",
" TensorType[\"seq_len\"] # The generated text\n",
" ]:\n",
" \"\"\"Sampling API positions.\"\"\"\n",
" # TODO: add batch support\n",
Expand Down Expand Up @@ -175,10 +175,10 @@
"\n",
" def obtain_api_response(\n",
" self,\n",
" prompt_ids: TensorType[\"batch_size\", \"seq_len\"],\n",
" positions: TensorType[\"batch_size\", \"n_positions\"],\n",
" generated_ids: TensorType[\"batch_size\", \"seq_len\"]\n",
" ) -> TensorType[\"batch_size\", \"n_positions\", \"seq_len\"]:\n",
" prompt_ids: TensorType[\"seq_len\"],\n",
" positions: TensorType[\"n_positions\"],\n",
" generated_ids: TensorType[\"seq_len\"]\n",
" ) -> TensorType[\"n_positions\", \"seq_len\"]:\n",
" \n",
" MAX_PAD = 50\n",
" \n",
Expand Down Expand Up @@ -219,7 +219,7 @@
" \n",
" def _generate_conditioning_prompts(\n",
" self,\n",
" candidate_ids: TensorType[\"batch_size\", \"n_candidates\", \"seq_len\"],\n",
" candidate_ids: TensorType[\"n_candidates\", \"seq_len\"],\n",
" ):\n",
" calculator_api = CalculatorAPI()\n",
" conditioning_api_ids = torch.tensor([])\n",
Expand Down Expand Up @@ -300,7 +300,7 @@
" def _filter_candidate_by_threshold(\n",
" self,\n",
" losses,\n",
" candidates: TensorType[\"batch_size\", \"seq_len\"]\n",
" candidates: TensorType[\"seq_len\"]\n",
" ):\n",
" filtered_augmented_text_ids = []\n",
" for i, position in enumerate(losses):\n",
Expand All @@ -314,9 +314,9 @@
"\n",
" def filter_api( \n",
" self,\n",
" text_ids: TensorType[\"batch_size\", \"seq_len\"],\n",
" api_start_idxs: TensorType[\"batch_size\", \"n_positions\"],\n",
" candidate_ids: TensorType[\"batch_size\", \"n_positions\", \"seq_len\"]\n",
" text_ids: TensorType[\"seq_len\"],\n",
" api_start_idxs: TensorType[\"n_positions\"],\n",
" candidate_ids: TensorType[\"n_positions\", \"seq_len\"]\n",
" ):\n",
" conditioning_api_ids = self._generate_conditioning_prompts(candidate_ids)\n",
" \n",
Expand Down Expand Up @@ -347,53 +347,49 @@
" torch.cat([api_ids[1], SPACE_TOKEN, conditioning_text_ids], dim=0), # [api->result, text_ids]\n",
" ], dim=0)\n",
" \n",
" # api_and_text_ids = conditioning_api_ids[]\n",
" # the next token after x_j\n",
" next_token_ids = text_ids[j]\n",
" augmented_text_ids[\"api_start_positions\"][idx][\"seq_positions\"][j] = {\n",
" \"prompt_ids\": api_and_text_ids,\n",
" \"unnormalized_weight\": self._compute_weight(t=j-idx),\n",
" \"losses\": [],\n",
" \"target_ids\": next_token_ids\n",
" \"target_ids\": torch.tensor([next_token_ids, next_token_ids, next_token_ids])\n",
" }\n",
" j += 1\n",
" \n",
" augmented_text_ids = self._normalize_weights(augmented_text_ids)\n",
" \n",
" \n",
" def extract_conditioning_ids_and_target_ids(augmented_text_ids):\n",
" conditioning_text_ids = torch.tensor([])\n",
" target_ids = []\n",
" target_ids = torch.tensor([])\n",
" \n",
" for _, api_start_position_dict in augmented_text_ids[\"api_start_positions\"].items():\n",
" for _, seq_position_dict in api_start_position_dict[\"seq_positions\"].items():\n",
" target_ids.append(seq_position_dict[\"target_ids\"])\n",
" target_ids = torch.concat([target_ids, seq_position_dict[\"target_ids\"]], dim=0)\n",
" for prompt_id in seq_position_dict[\"prompt_ids\"]:\n",
" conditioning_text_ids = torch.cat([\n",
" conditioning_text_ids,\n",
" F.pad(prompt_id.long(), pad=(50-prompt_id.shape[-1], 0), value=self.pad_token_id).unsqueeze(0)\n",
" ], dim=0)\n",
" \n",
" return conditioning_text_ids, target_ids\n",
" return conditioning_text_ids.long(), target_ids.long()\n",
"\n",
" conditioning_text_ids, target_ids = extract_conditioning_ids_and_target_ids(augmented_text_ids)\n",
" \n",
" output = self.model(input_ids=conditioning_text_ids.long())\n",
" logits = output.logits[:, -1, :]\n",
" \n",
" def extract_target_logprob_from_logits(logits, target_ids):\n",
" probs = F.softmax(logits, dim=-1)\n",
" i = 0\n",
" log_probs = torch.tensor([])\n",
" for x in target_ids:\n",
" log_probs = torch.cat([log_probs, probs[i:i+3][:, x].log().unsqueeze(0)], dim=0)\n",
" i += 3\n",
" return log_probs\n",
" log_probs = F.log_softmax(logits, dim=-1)\n",
" target_log_probs = log_probs[range(target_ids.shape[-1]), target_ids]\n",
" return target_log_probs\n",
"\n",
" log_probs = extract_target_logprob_from_logits(logits, target_ids)\n",
" \n",
" for _, api_start_position_dict in augmented_text_ids[\"api_start_positions\"].items():\n",
" for _, seq_position_dict in api_start_position_dict[\"seq_positions\"].items():\n",
" seq_position_dict[\"losses\"] = log_probs[:1].squeeze(0)\n",
" log_probs = log_probs[1:]\n",
" seq_position_dict[\"losses\"] = log_probs[:3].squeeze(0)\n",
" log_probs = log_probs[3:]\n",
" \n",
" augmented_text_ids = self._calculate_weighted_loss(augmented_text_ids)\n",
" losses = self._calculate_loss(augmented_text_ids)\n",
Expand All @@ -404,7 +400,7 @@
" self,\n",
" prompt_tempalte: PromptTemplate,\n",
" text: str,\n",
" ) -> TensorType[\"batch_size\", \"n_candidates\", \"seq_len\"]:\n",
" ) -> TensorType[\"n_candidates\", \"seq_len\"]:\n",
" # TODO: add batch support\n",
" prompt = prompt_tempalte.format(input=text)\n",
" prompt_ids = self.tokenizer(prompt, return_tensors=\"pt\")[\"input_ids\"][0]\n",
Expand All @@ -417,6 +413,8 @@
"\n",
" # filtering\n",
" text_ids = self.tokenizer(text, return_tensors=\"pt\")[\"input_ids\"][0]\n",
" \n",
" # return prompt_ids, api_start_idxs, generated_ids, candidate_ids, text_ids\n",
" filtered_candidate_ids = self.filter_api(text_ids, api_start_idxs, candidate_ids)\n",
" \n",
" return filtered_candidate_ids"
Expand Down
54 changes: 26 additions & 28 deletions toolformer/data_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,10 +66,10 @@ def extract_api_syntax(self, text: str, api_name: str) -> str:

def sample_api_position(
self,
prompt_ids: TensorType["batch_size", "seq_len"], # the ids of the prompt
prompt_ids: TensorType["seq_len"], # the ids of the prompt
) -> Tuple[
TensorType["batch_size", "n_positions"], # The positions of api call
TensorType["batch_size", "seq_len"] # The generated text
TensorType["n_positions"], # The positions of api call
TensorType["seq_len"] # The generated text
]:
"""Sampling API positions."""
# TODO: add batch support
Expand Down Expand Up @@ -119,10 +119,10 @@ def sample_api_position(

def obtain_api_response(
self,
prompt_ids: TensorType["batch_size", "seq_len"],
positions: TensorType["batch_size", "n_positions"],
generated_ids: TensorType["batch_size", "seq_len"]
) -> TensorType["batch_size", "n_positions", "seq_len"]:
prompt_ids: TensorType["seq_len"],
positions: TensorType["n_positions"],
generated_ids: TensorType["seq_len"]
) -> TensorType["n_positions", "seq_len"]:

MAX_PAD = 50

Expand Down Expand Up @@ -163,7 +163,7 @@ def obtain_api_response(

def _generate_conditioning_prompts(
self,
candidate_ids: TensorType["batch_size", "n_candidates", "seq_len"],
candidate_ids: TensorType["n_candidates", "seq_len"],
):
calculator_api = CalculatorAPI()
conditioning_api_ids = torch.tensor([])
Expand Down Expand Up @@ -244,7 +244,7 @@ def _calculate_loss(self, augmented_text_ids):
def _filter_candidate_by_threshold(
self,
losses,
candidates: TensorType["batch_size", "seq_len"]
candidates: TensorType["seq_len"]
):
filtered_augmented_text_ids = []
for i, position in enumerate(losses):
Expand All @@ -258,9 +258,9 @@ def _filter_candidate_by_threshold(

def filter_api(
self,
text_ids: TensorType["batch_size", "seq_len"],
api_start_idxs: TensorType["batch_size", "n_positions"],
candidate_ids: TensorType["batch_size", "n_positions", "seq_len"]
text_ids: TensorType["seq_len"],
api_start_idxs: TensorType["n_positions"],
candidate_ids: TensorType["n_positions", "seq_len"]
):
conditioning_api_ids = self._generate_conditioning_prompts(candidate_ids)

Expand Down Expand Up @@ -291,53 +291,49 @@ def filter_api(
torch.cat([api_ids[1], SPACE_TOKEN, conditioning_text_ids], dim=0), # [api->result, text_ids]
], dim=0)

# api_and_text_ids = conditioning_api_ids[]
# the next token after x_j
next_token_ids = text_ids[j]
augmented_text_ids["api_start_positions"][idx]["seq_positions"][j] = {
"prompt_ids": api_and_text_ids,
"unnormalized_weight": self._compute_weight(t=j-idx),
"losses": [],
"target_ids": next_token_ids
"target_ids": torch.tensor([next_token_ids, next_token_ids, next_token_ids])
}
j += 1

augmented_text_ids = self._normalize_weights(augmented_text_ids)

def extract_conditioning_ids_and_target_ids(augmented_text_ids):
conditioning_text_ids = torch.tensor([])
target_ids = []
target_ids = torch.tensor([])

for _, api_start_position_dict in augmented_text_ids["api_start_positions"].items():
for _, seq_position_dict in api_start_position_dict["seq_positions"].items():
target_ids.append(seq_position_dict["target_ids"])
target_ids = torch.concat([target_ids, seq_position_dict["target_ids"]], dim=0)
for prompt_id in seq_position_dict["prompt_ids"]:
conditioning_text_ids = torch.cat([
conditioning_text_ids,
F.pad(prompt_id.long(), pad=(50-prompt_id.shape[-1], 0), value=self.pad_token_id).unsqueeze(0)
], dim=0)

return conditioning_text_ids, target_ids
return conditioning_text_ids.long(), target_ids.long()

conditioning_text_ids, target_ids = extract_conditioning_ids_and_target_ids(augmented_text_ids)

output = self.model(input_ids=conditioning_text_ids.long())
logits = output.logits[:, -1, :]

def extract_target_logprob_from_logits(logits, target_ids):
probs = F.softmax(logits, dim=-1)
i = 0
log_probs = torch.tensor([])
for x in target_ids:
log_probs = torch.cat([log_probs, probs[i:i+3][:, x].log().unsqueeze(0)], dim=0)
i += 3
return log_probs
log_probs = F.log_softmax(logits, dim=-1)
target_log_probs = log_probs[range(target_ids.shape[-1]), target_ids]
return target_log_probs

log_probs = extract_target_logprob_from_logits(logits, target_ids)

for _, api_start_position_dict in augmented_text_ids["api_start_positions"].items():
for _, seq_position_dict in api_start_position_dict["seq_positions"].items():
seq_position_dict["losses"] = log_probs[:1].squeeze(0)
log_probs = log_probs[1:]
seq_position_dict["losses"] = log_probs[:3].squeeze(0)
log_probs = log_probs[3:]

augmented_text_ids = self._calculate_weighted_loss(augmented_text_ids)
losses = self._calculate_loss(augmented_text_ids)
Expand All @@ -348,7 +344,7 @@ def generate(
self,
prompt_tempalte: PromptTemplate,
text: str,
) -> TensorType["batch_size", "n_candidates", "seq_len"]:
) -> TensorType["n_candidates", "seq_len"]:
# TODO: add batch support
prompt = prompt_tempalte.format(input=text)
prompt_ids = self.tokenizer(prompt, return_tensors="pt")["input_ids"][0]
Expand All @@ -361,6 +357,8 @@ def generate(

# filtering
text_ids = self.tokenizer(text, return_tensors="pt")["input_ids"][0]

# return prompt_ids, api_start_idxs, generated_ids, candidate_ids, text_ids
filtered_candidate_ids = self.filter_api(text_ids, api_start_idxs, candidate_ids)

return filtered_candidate_ids

0 comments on commit 2d278c3

Please sign in to comment.