Update `LLM.generate` output to include statistics #1034

base: develop
Conversation
Documentation for this PR has been built. You can view it at: https://distilabel.argilla.io/pr-1034/
CodSpeed Performance Report: Merging #1034 will not alter performance.
…dd new merge_dicts to help merging user-assistant messages in magpie
…to llm-generate-upgrade
""" | ||
if isinstance(text_or_messages, list): | ||
# If it's a list of messages, concatenate the content of each message | ||
text = " ".join([message["content"] for message in text_or_messages]) |
I think it would be better to tokenize each message individually and then sum the results to be 100% precise.
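A minimal sketch of that per-message approach, assuming a Hugging Face tokenizer (`AutoTokenizer` and the `"gpt2"` checkpoint are illustrative here, not the PR's actual code):

```python
from transformers import AutoTokenizer

# Hypothetical setup; the PR's real tokenizer object is not shown in this diff.
tokenizer = AutoTokenizer.from_pretrained("gpt2")

text_or_messages = [
    {"role": "user", "content": "What is 2 + 2?"},
    {"role": "assistant", "content": "4"},
]

# Tokenize each message's content individually and sum the counts, instead of
# tokenizing one space-joined string (which can miscount at the boundaries).
n_tokens = sum(
    len(tokenizer.encode(message["content"])) for message in text_or_messages
)
```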
```python
input_tokens: The number of tokens of the inputs. Defaults to [0].
output_tokens: The number of tokens of the LLM response. Defaults to [0].
```
Suggested change:
```diff
- input_tokens: The number of tokens of the inputs. Defaults to [0].
- output_tokens: The number of tokens of the LLM response. Defaults to [0].
+ input_tokens: The number of tokens of the inputs. Defaults to `None`.
+ output_tokens: The number of tokens of the LLM response. Defaults to `None`.
```
```python
}

@staticmethod
def _prepare_sorted_resuts(
```
Suggested change:
```diff
- def _prepare_sorted_resuts(
+ def _prepare_sorted_results(
```
```python
batched_outputs = _sort_batches(
    batched_outputs, sorted_indices, num_generations=num_generations
# Sort the batched outputs together with the statistics
generations = self._prepare_sorted_resuts(
```
Suggested change:
```diff
- generations = self._prepare_sorted_resuts(
+ generations = self._prepare_sorted_results(
```
```python
)
statistics[field] = batched_field

# Regenerates the outputs as they are returned buy `preare_output`
```
Suggested change:
```diff
- # Regenerates the outputs as they are returned buy `preare_output`
+ # Regenerates the outputs as they are returned buy `prepare_output`
```
```diff
@@ -312,6 +312,7 @@ def process(self, inputs: StepInput) -> "StepOutput":
     self._logger.info(f"📦 Processing internal batch of inputs {i}...")
     results = super().process(batched_inputs)
     for result in next(results):  # Extract the elements from the generator
+        print("INTERMEDIATE RESULTS", result)
```
Suggested change:
```diff
- print("INTERMEDIATE RESULTS", result)
```
```python
return output


def iterate_generations_with_stats(output: "GenerateOutput") -> "GenerateOutput":
```
I think the return type hint should be `Generator[Tuple[...], None, None]`.
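For illustration, a sketch of the annotated signature; the `Any` element types below are placeholders, since the reviewer left them elided as `Tuple[...]`:

```python
from typing import Any, Generator, Tuple


def iterate_generations_with_stats(
    output: "GenerateOutput",
) -> Generator[Tuple[Any, Any], None, None]:
    # `Any` stands in for the elided element types in the comment above.
    ...
```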
```python
messages = [output["generations"][0] for output in outputs]
statistics = [output["statistics"] for output in outputs]
```
Maybe we can combine these in the same for loop.
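A sketch of that single-pass version (the sample `outputs` data is made up for illustration):

```python
# Illustrative stand-in for the real `outputs` from the snippet above.
outputs = [
    {"generations": ["4"], "statistics": {"input_tokens": [12], "output_tokens": [1]}},
]

# One loop collects both fields instead of two list comprehensions.
messages, statistics = [], []
for output in outputs:
    messages.append(output["generations"][0])
    statistics.append(output["statistics"])
```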
# """`StepOutput` is an alias of the typing. | ||
# A step output is a dict of the form: | ||
# { | ||
# "outputs": [ | ||
# {"col1": "val1", "col2": "val2"}, | ||
# {"col1": "val1", "col2": "val2"}, | ||
# {"col1": "val1", "col2": "val2"}, | ||
# ], | ||
# "statistics": { | ||
# "llm": {}, | ||
# "time": 12341234, | ||
# ... | ||
# } | ||
# } | ||
# """ |
This should be uncommented, right?
```python
def flatten_dict(x: Dict[Any, Any]) -> Dict[Any, Any]:
    return {k: json.dumps(v) if isinstance(v, dict) else v for k, v in x.items()}


def merge_dicts(*dict_lists):
```
Suggested change:
```diff
- def merge_dicts(*dict_lists):
+ def merge_dicts(*dict_lists: dict) -> list[dict]:
```
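For context, a minimal sketch of what a helper with this signature might do, assuming (per the commit message about merging user-assistant messages in magpie) that it zips several lists of dicts and merges the dicts at each position. The body and the more precise annotations are assumptions; the actual implementation is not shown in this diff:

```python
from typing import Any, Dict, List


def merge_dicts(*dict_lists: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    # Assumed behavior: merge the i-th dict of every input list into one dict.
    merged = []
    for dicts in zip(*dict_lists):
        combined: Dict[str, Any] = {}
        for d in dicts:
            combined.update(d)
        merged.append(combined)
    return merged


# Example: pairing user and assistant messages position by position.
users = [{"user": "hi"}, {"user": "bye"}]
assistants = [{"assistant": "hello"}, {"assistant": "goodbye"}]
print(merge_dicts(users, assistants))
# [{'user': 'hi', 'assistant': 'hello'}, {'user': 'bye', 'assistant': 'goodbye'}]
```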
Update `LLM.generate` output to include statistics

Description

This PR updates the output from `llm.generate` to make it more feature rich. Previously we only returned the generated text; now the output also carries statistics related to the generation (see the sketch below). This PR only includes `input_tokens` and `output_tokens` as statistics, but we can add as much as needed in the future. This information is moved to `distilabel_metadata` to avoid collisions between statistics of different steps.

NOTE: Most `Task` subclasses reuse the same `Task.process` method to process the generations, and nothing else has to be done, but for tasks like `Magpie`, where the `process` method is overwritten, this has to be updated.

Closes #738