Fix missing system_prompt_key column in Magpie tasks (#983)
* Fix missing `system_prompt_key` column

* Fix wrong `system_prompt_key` associated with conversation

* Update unit tests
gabrielmbmb authored Sep 17, 2024
1 parent b2d8eb5 commit e67864e
Showing 2 changed files with 70 additions and 23 deletions.
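For context, a sketch of the behavior this commit restores, reconstructed from the updated tests below. It assumes the distilabel 1.x import paths (`distilabel.llms`, `distilabel.steps.tasks`) and an `LLM` that supports a Magpie pre-query template, mirroring the documented Magpie example; the model id and prompt texts are illustrative only:

```python
from distilabel.llms import InferenceEndpointsLLM
from distilabel.steps.tasks import Magpie

# Weighted system prompts: each key maps to (prompt text, sampling probability).
task = Magpie(
    llm=InferenceEndpointsLLM(
        model_id="meta-llama/Meta-Llama-3-8B-Instruct",
        tokenizer_id="meta-llama/Meta-Llama-3-8B-Instruct",
        magpie_pre_query_template="llama3",
    ),
    system_prompt={
        "math": ("You are a math tutor.", 0.6),
        "code": ("You are a coding assistant.", 0.4),
    },
)
task.load()

# After this fix, each generated row reports the key that was sampled for it,
# e.g. {"instruction": ..., "response": ..., "system_prompt_key": "math", ...}.
for row in next(task.process(inputs=[{}, {}, {}])):
    print(row["system_prompt_key"])
```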
src/distilabel/steps/tasks/magpie/base.py (31 changes: 19 additions & 12 deletions)

@@ -13,6 +13,7 @@
 # limitations under the License.
 
 import random
+from itertools import zip_longest
 from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
 
 from pydantic import Field, PositiveInt, field_validator
@@ -115,7 +116,7 @@ def system_prompts_weights_validator(

     def _prepare_inputs_for_instruction_generation(
         self, inputs: List[Dict[str, Any]]
-    ) -> Tuple[List["ChatType"], Union[str, None]]:
+    ) -> Tuple[List["ChatType"], List[str]]:
         """Prepares the inputs adding the system (if required) prompt provided in each row,
         or if the conversations to generate have more than one turn, then adding the system
         prompt for multi-turn conversation from the paper.
@@ -124,10 +125,10 @@ def _prepare_inputs_for_instruction_generation(
             inputs: the inputs to prepare.
 
         Returns:
-            The prepared inputs.
+            The prepared inputs and the system prompt keys used for each input.
         """
         prepared_inputs = []
-        system_prompt_key = None
+        system_prompt_keys = []
         for input in inputs:
             conversation = []
             if "system_prompt" in input:
@@ -146,6 +147,7 @@
                 system_prompt_key = random.choices(
                     system_prompts_keys, weights, k=1
                 )[0]
+                system_prompt_keys.append(system_prompt_key)
                 system_prompt = self.system_prompt[system_prompt_key]
                 if isinstance(system_prompt, tuple):
                     system_prompt = system_prompt[0]
@@ -159,7 +161,7 @@

             prepared_inputs.append(conversation)
 
-        return prepared_inputs, system_prompt_key
+        return prepared_inputs, system_prompt_keys
 
     def _append_messages_to_conversations(
         self, role: str, messages: List[str], conversations: List["ChatType"]
@@ -182,7 +184,7 @@ def _append_messages_to_conversations(
     def _generate_instruction(
         self, inputs: List[Dict[str, Any]]
     ) -> List[Dict[str, Any]]:
-        prepared_inputs, system_prompt_key = (
+        prepared_inputs, system_prompt_keys = (
             self._prepare_inputs_for_instruction_generation(inputs)
         )
         outputs = self.llm.generate(
@@ -191,30 +193,35 @@
             **self.llm.generation_kwargs,  # type: ignore
         )
         rows = []
-        for output in outputs:
-            row = {"instruction": output[0]}
+        for output, system_prompt_key in zip_longest(
+            outputs, system_prompt_keys, fillvalue=None
+        ):
+            row = {"instruction": output[0]}  # type: ignore
             if system_prompt_key is not None:
                 row["system_prompt_key"] = system_prompt_key
             rows.append(row)
         return rows

     def _prepare_conversation_outputs(
-        self, conversations: List["ChatType"], system_prompt_key: Optional[str] = None
+        self, conversations: List["ChatType"], system_prompt_keys: List[str]
     ) -> List[Dict[str, Any]]:
"""Prepare the output conversation removing the system prompt if necessary. If
`n_turns==1`, then it will return a dictionary with "instruction" and "response"
keys. Otherwise, it will return a dictionary with a "conversation" key.
Args:
conversations: the list of generated conversations.
system_prompt_key: the key of the system prompt used to generate the conversation.
system_prompt_keys: the list of system prompt keys used to generate the conversations.
Returns:
A list of dictionaries containing a "conversation" key or "instruction" and
"responses" key.
"""
         outputs = []
-        for conversation in conversations:
+        for conversation, system_prompt_key in zip_longest(
+            conversations, system_prompt_keys, fillvalue=None
+        ):
+            assert conversation is not None
             # Something went wrong with the `LLM` and it didn't generate any message
             if len(conversation) == 0:
                 if self.n_turns == 1:
@@ -265,7 +272,7 @@ def _generate_conversation_turn(
     def _generate_multi_turn_conversation(
         self, inputs: List[Dict[str, Any]]
     ) -> List[Dict[str, Any]]:
-        conversations, system_prompt_key = (
+        conversations, system_prompt_keys = (
            self._prepare_inputs_for_instruction_generation(inputs)
         )
         # Keep track of the active conversations, as it could happen that for some conversation
@@ -294,7 +301,7 @@ def _generate_multi_turn_conversation(
                 active_indices=active_indices,
             )
 
-        return self._prepare_conversation_outputs(conversations)
+        return self._prepare_conversation_outputs(conversations, system_prompt_keys)

     def _generate_with_pre_query_template(
         self, inputs: List[Dict[str, Any]]
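The heart of the change above is pairing the per-input keys back to generated rows with `itertools.zip_longest`, so the `system_prompt_key` column is emitted only when keys were actually sampled. A self-contained sketch of that pattern with placeholder data (not distilabel's real types):

```python
from itertools import zip_longest

# Placeholder generation outputs: one list of messages per input.
outputs = [["instr A"], ["instr B"], ["instr C"]]

# When `system_prompt` was a dict, one key was sampled per input ...
system_prompt_keys = ["math", "code", "math"]
# ... otherwise the list is empty, and `fillvalue=None` pads it so the
# rows are still produced, just without the extra column:
# system_prompt_keys = []

rows = []
for output, key in zip_longest(outputs, system_prompt_keys, fillvalue=None):
    row = {"instruction": output[0]}
    if key is not None:
        row["system_prompt_key"] = key
    rows.append(row)

print(rows)
# [{'instruction': 'instr A', 'system_prompt_key': 'math'}, ...]
```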
tests/unit/steps/tasks/magpie/test_base.py (62 changes: 51 additions & 11 deletions)

@@ -395,6 +395,46 @@ def test_process_with_system_prompt_per_row(self) -> None:
             },
         ]
 
+    def test_process_with_system_prompt_and_probabilities(self) -> None:
+        with mock.patch(
+            "random.choices",
+            side_effect=[
+                ["system_prompt_1"],
+                ["system_prompt_2"],
+                ["system_prompt_1"],
+            ],
+        ):
+            task = Magpie(
+                llm=DummyMagpieLLM(magpie_pre_query_template="llama3"),
+                system_prompt={
+                    "system_prompt_1": ("system_prompt", 0.6),
+                    "system_prompt_2": ("system_prompt", 0.4),
+                },
+            )
+
+            task.load()
+
+            assert next(task.process(inputs=[{}, {}, {}])) == [
+                {
+                    "instruction": "Hello Magpie",
+                    "response": "Hello Magpie",
+                    "system_prompt_key": "system_prompt_1",
+                    "model_name": "test",
+                },
+                {
+                    "instruction": "Hello Magpie",
+                    "response": "Hello Magpie",
+                    "system_prompt_key": "system_prompt_2",
+                    "model_name": "test",
+                },
+                {
+                    "instruction": "Hello Magpie",
+                    "response": "Hello Magpie",
+                    "system_prompt_key": "system_prompt_1",
+                    "model_name": "test",
+                },
+            ]
+

     def test_process_only_instruction(self) -> None:
         task = Magpie(
             llm=DummyMagpieLLM(magpie_pre_query_template="llama3"),

@@ -504,58 +544,58 @@ def test_prepare_conversation_outputs(
             n_turns=n_turns,
             include_system_prompt=include_system_prompt,
         )
-        assert task._prepare_conversation_outputs([conversation]) == [expected]
+        assert task._prepare_conversation_outputs([conversation], []) == [expected]

     @pytest.mark.parametrize(
-        "system_prompt, n_turns, inputs, random_choices_return, expected_prepared_inputs, expected_system_prompt_key",
+        "system_prompt, n_turns, inputs, random_choices_return, expected_prepared_inputs, expected_system_prompt_keys",
         [
             (
                 None,
                 1,
                 [{"system_prompt": "Custom system prompt."}],
                 None,
                 [[{"role": "system", "content": "Custom system prompt."}]],
-                None,
+                [],
             ),
             (
                 ["Prompt A", "Prompt B"],
                 1,
                 [{}],
                 ["Prompt A"],
                 [[{"role": "system", "content": "Prompt A"}]],
-                None,
+                [],
             ),
             (
                 {"Key1": "Prompt 1", "Key2": "Prompt 2"},
                 1,
                 [{}],
                 ["Key1"],
                 [[{"role": "system", "content": "Prompt 1"}]],
-                "Key1",
+                ["Key1"],
             ),
             (
                 {"Key1": ("Prompt 1", 0.7), "Key2": ("Prompt 2", 0.3)},
                 1,
                 [{}],
                 ["Key1"],
                 [[{"role": "system", "content": "Prompt 1"}]],
-                "Key1",
+                ["Key1"],
             ),
             (
                 None,
                 2,
                 [{}],
                 None,
                 [[{"role": "system", "content": MAGPIE_MULTI_TURN_SYSTEM_PROMPT}]],
-                None,
+                [],
             ),
             (
                 None,
                 1,
                 [{}],
                 None,
                 [[]],
-                None,
+                [],
             ),
         ],
     )
@@ -566,7 +606,7 @@ def test_prepare_inputs_for_instruction_generation(
         inputs,
         random_choices_return,
         expected_prepared_inputs,
-        expected_system_prompt_key,
+        expected_system_prompt_keys,
     ):
         task = Magpie(
             llm=DummyMagpieLLM(magpie_pre_query_template="llama3"),
@@ -578,12 +618,12 @@
             if random_choices_return is not None:
                 mock_choices.return_value = random_choices_return
 
-            prepared_inputs, system_prompt_key = (
+            prepared_inputs, system_prompt_keys = (
                 task._prepare_inputs_for_instruction_generation(inputs)
             )
 
         assert prepared_inputs == expected_prepared_inputs
-        assert system_prompt_key == expected_system_prompt_key
+        assert system_prompt_keys == expected_system_prompt_keys

     def test_serialization(self) -> None:
         task = Magpie(
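The new test makes the weighted sampling deterministic by patching `random.choices` with a `side_effect` list, one return value per call. The same technique in isolation, using only the standard library (the key names mirror the test above):

```python
import random
from unittest import mock

# Each call to the patched `random.choices` consumes the next item of
# `side_effect`, so the "random" key sequence is fixed in advance.
with mock.patch(
    "random.choices",
    side_effect=[["system_prompt_1"], ["system_prompt_2"], ["system_prompt_1"]],
):
    keys = [
        random.choices(["system_prompt_1", "system_prompt_2"], [0.6, 0.4], k=1)[0]
        for _ in range(3)
    ]

assert keys == ["system_prompt_1", "system_prompt_2", "system_prompt_1"]
```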
