
Update LLM.generate output to include statistics #1034

Open

plaguss wants to merge 34 commits into develop from llm-generate-upgrade
Conversation

plaguss (Contributor) commented Oct 11, 2024

Description

This PR updates the output of llm.generate to make it more feature-rich.

Previously we only returned the generated text:

GenerateOutput = List[Union[str, None]]

Now it is updated to also include statistics related to the generation:

from typing import Any, Dict, List, TypedDict, Union

LLMOutput = List[Union[str, None]]

class TokenCount(TypedDict):
    input_tokens: List[int]
    output_tokens: List[int]

LLMStatistics = Union[TokenCount, Dict[str, Any]]
"""Initially LLMStatistics will only contain the token count, but more variables
can be added once they are defined for every LLM.
"""

class GenerateOutput(TypedDict):
    generations: LLMOutput
    statistics: LLMStatistics

This PR only includes input_tokens and output_tokens as statistics, but we can add as many as needed in the future.
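For illustration, a GenerateOutput value for a single prompt could then look like this (a sketch built from the types above and the metadata example below):

output: GenerateOutput = {
    "generations": ["Hello Magpie"],
    "statistics": {
        "input_tokens": [12],
        "output_tokens": [12],
    },
}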

This information is moved to distilabel_metadata in the following way, to avoid collisions between statistics of different steps:

{
    "generations": ["Hello Magpie"],
    f"statistics_{step_name}": {
        "input_tokens": [12],
        "output_tokens": [12],
    },
}
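A minimal sketch of how a task could build that namespaced key (the helper name is hypothetical, not part of distilabel's API):

def add_statistics_to_metadata(
    metadata: dict, step_name: str, statistics: dict
) -> dict:
    # Prefix the key with the step name so statistics coming from
    # different steps never collide inside distilabel_metadata.
    metadata[f"statistics_{step_name}"] = statistics
    return metadata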

NOTE:
Most Task subclasses reuse the same Task.process method to process the generations, so nothing else has to be done for them, but tasks like Magpie that override the process method have to be updated.

Closes #738

@plaguss plaguss added this to the 1.5.0 milestone Oct 11, 2024
@plaguss plaguss self-assigned this Oct 11, 2024

Documentation for this PR has been built. You can view it at: https://distilabel.argilla.io/pr-1034/


codspeed-hq bot commented Oct 11, 2024

CodSpeed Performance Report

Merging #1034 will not alter performance

Comparing llm-generate-upgrade (e97f901) with develop (7c8976b)

Summary

✅ 1 untouched benchmarks

@plaguss plaguss added the enhancement New feature or request label Oct 14, 2024
@plaguss plaguss marked this pull request as ready for review October 25, 2024 07:17
"""
if isinstance(text_or_messages, list):
# If it's a list of messages, concatenate the content of each message
text = " ".join([message["content"] for message in text_or_messages])
Member

I think it would be better to tokenize each message individually and then sum the results to be 100% precise.
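A minimal sketch of that suggestion, assuming a Hugging Face-style tokenizer object with an encode method (the function and parameter names are illustrative):

from typing import List, Union

def count_tokens(text_or_messages: Union[str, List[dict]], tokenizer) -> int:
    if isinstance(text_or_messages, list):
        # Tokenize each message's content individually and sum the counts,
        # instead of concatenating everything into one string first.
        return sum(
            len(tokenizer.encode(message["content"]))
            for message in text_or_messages
        )
    return len(tokenizer.encode(text_or_messages))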

Comment on lines +53 to +54
input_tokens: The number of tokens of the inputs. Defaults to [0].
output_tokens: The number of tokens of the LLM response. Defaults to [0].
Member

Suggested change
input_tokens: The number of tokens of the inputs. Defaults to [0].
output_tokens: The number of tokens of the LLM response. Defaults to [0].
input_tokens: The number of tokens of the inputs. Defaults to `None`.
output_tokens: The number of tokens of the LLM response. Defaults to `None`.

}

@staticmethod
def _prepare_sorted_resuts(
Member

Suggested change
def _prepare_sorted_resuts(
def _prepare_sorted_results(

batched_outputs = _sort_batches(
    batched_outputs, sorted_indices, num_generations=num_generations
)
# Sort the batched outputs together with the statistics
generations = self._prepare_sorted_resuts(
Member

Suggested change
generations = self._prepare_sorted_resuts(
generations = self._prepare_sorted_results(

)
statistics[field] = batched_field

# Regenerates the outputs as they are returned buy `preare_output`
Member

Suggested change
# Regenerates the outputs as they are returned buy `preare_output`
# Regenerates the outputs as they are returned by `prepare_output`

@@ -312,6 +312,7 @@ def process(self, inputs: StepInput) -> "StepOutput":
        self._logger.info(f"📦 Processing internal batch of inputs {i}...")
        results = super().process(batched_inputs)
        for result in next(results):  # Extract the elements from the generator
            print("INTERMEDIATE RESULTS", result)
Member

Suggested change
print("INTERMEDIATE RESULTS", result)

return output


def iterate_generations_with_stats(output: "GenerateOutput") -> "GenerateOutput":
Member

I think the return type hint should be Generator[Tuple[...], None, None]
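A sketch of what that annotated signature could look like; the body is illustrative, not the PR's actual implementation:

from typing import Generator, Tuple, Union

def iterate_generations_with_stats(
    output: "GenerateOutput",
) -> Generator[Tuple[Union[str, None], "LLMStatistics"], None, None]:
    # Yield each generation paired with its slice of the statistics,
    # assuming the statistics lists are aligned with the generations.
    for i, generation in enumerate(output["generations"]):
        stats = {key: [values[i]] for key, values in output["statistics"].items()}
        yield generation, stats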

Comment on lines +269 to +270
messages = [output["generations"][0] for output in outputs]
statistics = [output["statistics"] for output in outputs]
Member

Maybe we can combine these in the same for loop.
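One way to fold the two comprehensions into a single pass, as suggested (a sketch, not necessarily the PR's final code):

messages, statistics = [], []
for output in outputs:
    messages.append(output["generations"][0])
    statistics.append(output["statistics"])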

Comment on lines +22 to +36
# """`StepOutput` is an alias of the typing.
# A step output is a dict of the form:
# {
# "outputs": [
# {"col1": "val1", "col2": "val2"},
# {"col1": "val1", "col2": "val2"},
# {"col1": "val1", "col2": "val2"},
# ],
# "statistics": {
# "llm": {},
# "time": 12341234,
# ...
# }
# }
# """
Member

This should be uncommented, right?



def flatten_dict(x: Dict[Any, Any]) -> Dict[Any, Any]:
return {k: json.dumps(v) if isinstance(v, dict) else v for k, v in x.items()}
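As a side note, flatten_dict (which needs import json in scope) does not flatten nested keys; it JSON-encodes any dict values and leaves other values untouched, for example:

flatten_dict({"a": {"b": 1}, "c": 2})
# -> {"a": '{"b": 1}', "c": 2}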


def merge_dicts(*dict_lists):
Member

Suggested change
def merge_dicts(*dict_lists):
def merge_dicts(*dict_lists: dict) -> list[dict]:
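For reference, a sketch of one plausible implementation matching the suggested signature, where each positional argument is a list of dicts merged element-wise; the PR's actual semantics may differ:

def merge_dicts(*dict_lists: list) -> list:
    # Merge the i-th dict of every input list into a single dict,
    # collapsing parallel lists of dicts into one list.
    return [
        {key: value for d in group for key, value in d.items()}
        for group in zip(*dict_lists)
    ]

# Example: merge_dicts([{"a": 1}], [{"b": 2}]) -> [{"a": 1, "b": 2}]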

@gabrielmbmb gabrielmbmb changed the title Llm generate upgrade Update LLM.generate output to include statistics Nov 15, 2024
Labels
enhancement New feature or request
Development

Successfully merging this pull request may close these issues.

[FEATURE] Update LLM.generate interface to allow returning arbitrary/extra stuff related to the generation