From 769e904b99fd2b0b5c6fac38d2143026809dadd3 Mon Sep 17 00:00:00 2001 From: Ryan Henning Date: Fri, 8 Nov 2024 20:01:34 -0600 Subject: [PATCH 01/11] More flexible --- src/lasagna/lasagna_openai.py | 20 ++++++++++++-------- 1 file changed, 12 insertions(+), 8 deletions(-) diff --git a/src/lasagna/lasagna_openai.py b/src/lasagna/lasagna_openai.py index 93f7a77..a984535 100644 --- a/src/lasagna/lasagna_openai.py +++ b/src/lasagna/lasagna_openai.py @@ -272,14 +272,18 @@ def should_combine( # This is the case where the model started with text and switched # to tool-calling part-way-through. We need to combine these # messages. - assert ('content' in m1 and m1['content']) and ('tool_calls' not in m1 or not m1['tool_calls']) - assert ('content' not in m2 or not m2['content']) and ('tool_calls' in m2 and m2['tool_calls']) - m_combined: ChatCompletionMessageParam = { - 'role': 'assistant', - 'content': m1['content'], - 'tool_calls': m2['tool_calls'], - } - return True, m_combined + is_first_just_content = ('content' in m1 and m1['content']) and ('tool_calls' not in m1 or not m1['tool_calls']) + is_second_just_tools = ('content' not in m2 or not m2['content']) and ('tool_calls' in m2 and m2['tool_calls']) + if is_first_just_content and is_second_just_tools: + assert 'content' in m1 and 'tool_calls' in m2 + m_combined: ChatCompletionMessageParam = { + 'role': 'assistant', + 'content': m1['content'], + 'tool_calls': m2['tool_calls'], + } + return True, m_combined + else: + return False return False return combine_pairs(ms, should_combine) From 6f9586caaf00f808220a73b892a2e542e12da573 Mon Sep 17 00:00:00 2001 From: Ryan Henning Date: Fri, 8 Nov 2024 20:06:58 -0600 Subject: [PATCH 02/11] Reorder functions and imports --- src/lasagna/agent_util.py | 163 +++++++++++++++++++------------------- 1 file changed, 83 insertions(+), 80 deletions(-) diff --git a/src/lasagna/agent_util.py b/src/lasagna/agent_util.py index 248f293..42100d1 100644 --- a/src/lasagna/agent_util.py +++ b/src/lasagna/agent_util.py @@ -1,4 +1,4 @@ -from typing import Union, Dict, Any, List, Callable, Protocol, Type +from .agent_runner import run from .util import get_name @@ -17,7 +17,10 @@ ToolResult, ) -from .agent_runner import run +from typing import ( + Union, Dict, List, Any, Type, + Callable, Protocol, +) def bind_model( @@ -73,6 +76,79 @@ def __str__(self) -> str: return PartialModelBinder() +def _extract_messages_from_tool_result( + tool_result: ToolResult, +) -> List[Message]: + if tool_result['type'] == 'layered_agent': + run = tool_result['run'] + return recursive_extract_messages([run], from_layered_agents=True) + return [] + + +def _recursive_extract_messages_from_tool_res( + messages: List[Message], +) -> List[Message]: + ms: List[Message] = [] + for m in messages: + ms.append(m) + if m['role'] == 'tool_res': + for t in m['tools']: + ms.extend(_extract_messages_from_tool_result(t)) + return ms + + +def recursive_extract_messages( + agent_runs: List[AgentRun], + from_layered_agents: bool, +) -> List[Message]: + """DFS retrieve all messages within a list of `AgentRuns`.""" + messages: List[Message] = [] + for run in agent_runs: + if run['type'] == 'messages': + messages.extend( + ( + _recursive_extract_messages_from_tool_res(run['messages']) + if from_layered_agents else + run['messages'] + ), + ) + elif run['type'] == 'chain' or run['type'] == 'parallel': + messages.extend( + recursive_extract_messages(run['runs'], from_layered_agents=from_layered_agents), + ) + elif run['type'] == 'extraction': + messages.extend( + ( + _recursive_extract_messages_from_tool_res([run['message']]) + if from_layered_agents else + [run['message']] + ), + ) + else: + raise RuntimeError('unreachable') + return messages + + +def extract_last_message( + agent_run_or_runs: Union[AgentRun, List[AgentRun]], + from_layered_agents: bool, +) -> Message: + if isinstance(agent_run_or_runs, list): + messages = recursive_extract_messages(agent_run_or_runs, from_layered_agents=from_layered_agents) + else: + messages = recursive_extract_messages([agent_run_or_runs], from_layered_agents=from_layered_agents) + if len(messages) == 0: + raise ValueError('no messages found') + return messages[-1] + + +def flat_messages(messages: List[Message]) -> AgentRun: + return { + 'type': 'messages', + 'messages': messages, + } + + def override_system_prompt( messages: List[Message], system_prompt: str, @@ -97,6 +173,11 @@ def strip_tool_calls_and_results( ] +async def noop_callback(event: EventPayload) -> None: + # "noop" mean "no operation" means DON'T DO ANYTHING! + assert event + + def build_simple_agent( name: str, tools: List[Callable] = [], @@ -163,81 +244,3 @@ def __str__(self) -> str: if doc: a.__doc__ = doc return a - - -def _extract_messages_from_tool_result( - tool_result: ToolResult, -) -> List[Message]: - if tool_result['type'] == 'layered_agent': - run = tool_result['run'] - return recursive_extract_messages([run], from_layered_agents=True) - return [] - - -def _recursive_extract_messages_from_tool_res( - messages: List[Message], -) -> List[Message]: - ms: List[Message] = [] - for m in messages: - ms.append(m) - if m['role'] == 'tool_res': - for t in m['tools']: - ms.extend(_extract_messages_from_tool_result(t)) - return ms - - -def recursive_extract_messages( - agent_runs: List[AgentRun], - from_layered_agents: bool, -) -> List[Message]: - """DFS retrieve all messages within a list of `AgentRuns`.""" - messages: List[Message] = [] - for run in agent_runs: - if run['type'] == 'messages': - messages.extend( - ( - _recursive_extract_messages_from_tool_res(run['messages']) - if from_layered_agents else - run['messages'] - ), - ) - elif run['type'] == 'chain' or run['type'] == 'parallel': - messages.extend( - recursive_extract_messages(run['runs'], from_layered_agents=from_layered_agents), - ) - elif run['type'] == 'extraction': - messages.extend( - ( - _recursive_extract_messages_from_tool_res([run['message']]) - if from_layered_agents else - [run['message']] - ), - ) - else: - raise RuntimeError('unreachable') - return messages - - -def flat_messages(messages: List[Message]) -> AgentRun: - return { - 'type': 'messages', - 'messages': messages, - } - - -async def noop_callback(event: EventPayload) -> None: - # "noop" mean "no operation" means DON'T DO ANYTHING! - pass - - -def extract_last_message( - agent_run_or_runs: Union[AgentRun, List[AgentRun]], - from_layered_agents: bool, -) -> Message: - if isinstance(agent_run_or_runs, list): - messages = recursive_extract_messages(agent_run_or_runs, from_layered_agents=from_layered_agents) - else: - messages = recursive_extract_messages([agent_run_or_runs], from_layered_agents=from_layered_agents) - if len(messages) == 0: - raise ValueError('no messages found') - return messages[-1] From 5698055360f1558585ecbcdd5ed037281c752db0 Mon Sep 17 00:00:00 2001 From: Ryan Henning Date: Fri, 8 Nov 2024 20:18:07 -0600 Subject: [PATCH 03/11] Changing the interface! --- src/lasagna/agent_util.py | 39 ++++++++++++++++++++++++++++----------- 1 file changed, 28 insertions(+), 11 deletions(-) diff --git a/src/lasagna/agent_util.py b/src/lasagna/agent_util.py index 42100d1..7ef1152 100644 --- a/src/lasagna/agent_util.py +++ b/src/lasagna/agent_util.py @@ -178,14 +178,34 @@ async def noop_callback(event: EventPayload) -> None: assert event +MessageExtractor = Callable[[List[AgentRun]], List[Message]] + + +def build_standard_message_extractor( + system_prompt_override: Union[str, None] = None, + strip_old_tool_use_messages: bool = False, +) -> MessageExtractor: + def extractor(prev_runs: List[AgentRun]) -> List[Message]: + messages = recursive_extract_messages(prev_runs, from_layered_agents=False) + if system_prompt_override: + messages = override_system_prompt(messages, system_prompt_override) + if strip_old_tool_use_messages: + messages = strip_tool_calls_and_results(messages) + return messages + + return extractor + + +default_message_extractor = build_standard_message_extractor() + + def build_simple_agent( name: str, tools: List[Callable] = [], - doc: Union[str, None] = None, - system_prompt_override: Union[str, None] = None, - strip_old_tool_use_messages: bool = False, force_tool: bool = False, max_tool_iters: int = 5, + message_extractor: MessageExtractor = default_message_extractor, + doc: Union[str, None] = None, ) -> AgentCallable: class SimpleAgent(): async def __call__( @@ -194,11 +214,7 @@ async def __call__( event_callback: EventCallback, prev_runs: List[AgentRun], ) -> AgentRun: - messages = recursive_extract_messages(prev_runs, from_layered_agents=False) - if system_prompt_override: - messages = override_system_prompt(messages, system_prompt_override) - if strip_old_tool_use_messages: - messages = strip_tool_calls_and_results(messages) + messages = message_extractor(prev_runs) new_messages = await model.run( event_callback = event_callback, messages = messages, @@ -218,8 +234,9 @@ def __str__(self) -> str: def build_extraction_agent( + name: str, extraction_type: Type[ExtractionType], - name: Union[str, None] = None, + message_extractor: MessageExtractor = default_message_extractor, doc: Union[str, None] = None, ) -> AgentCallable: class ExtractionAgent(): @@ -229,7 +246,7 @@ async def __call__( event_callback: EventCallback, prev_runs: List[AgentRun], ) -> AgentRun: - messages = recursive_extract_messages(prev_runs, from_layered_agents=False) + messages = message_extractor(prev_runs) message, result = await model.extract(event_callback, messages, extraction_type) return { 'type': 'extraction', @@ -238,7 +255,7 @@ async def __call__( } def __str__(self) -> str: - return name or 'extraction_agent' + return name a = ExtractionAgent() if doc: From 22353de7865262b5185836cdcca75c9c37ddc992 Mon Sep 17 00:00:00 2001 From: Ryan Henning Date: Fri, 8 Nov 2024 20:26:33 -0600 Subject: [PATCH 04/11] Chaining agent --- src/lasagna/agent_util.py | 35 +++++++++++++++++++++++++++++++++++ 1 file changed, 35 insertions(+) diff --git a/src/lasagna/agent_util.py b/src/lasagna/agent_util.py index 7ef1152..7ece734 100644 --- a/src/lasagna/agent_util.py +++ b/src/lasagna/agent_util.py @@ -261,3 +261,38 @@ def __str__(self) -> str: if doc: a.__doc__ = doc return a + + +def build_agent_chainer( + name: str, + agents: List[BoundAgentCallable], + message_extractor: Union[MessageExtractor, None] = None, + doc: Union[str, None] = None, +) -> BoundAgentCallable: + class ChainedAgents(): + async def __call__( + self, + event_callback: EventCallback, + prev_runs: List[AgentRun], + ) -> AgentRun: + if message_extractor is not None: + prev_runs = [flat_messages(message_extractor(prev_runs))] + else: + prev_runs = [*prev_runs] # shallow copy + new_runs: List[AgentRun] = [] + for agent in agents: + this_run = await agent(event_callback, prev_runs) + prev_runs.append(this_run) + new_runs.append(this_run) + return { + 'type': 'chain', + 'runs': new_runs, + } + + def __str__(self) -> str: + return name + + a = ChainedAgents() + if doc: + a.__doc__ = doc + return a From e54c4ce81a6a642389aeb00ac9fd404e24f7ee46 Mon Sep 17 00:00:00 2001 From: Ryan Henning Date: Sat, 9 Nov 2024 12:16:27 -0600 Subject: [PATCH 05/11] Agent router --- src/lasagna/agent_util.py | 40 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 40 insertions(+) diff --git a/src/lasagna/agent_util.py b/src/lasagna/agent_util.py index 7ece734..f67cfcd 100644 --- a/src/lasagna/agent_util.py +++ b/src/lasagna/agent_util.py @@ -296,3 +296,43 @@ def __str__(self) -> str: if doc: a.__doc__ = doc return a + + +def build_agent_router( + name: str, + extraction_type: Type[ExtractionType], + pick_agent_func: Callable[[ExtractionType], BoundAgentCallable], + message_extractor: MessageExtractor = default_message_extractor, + doc: Union[str, None] = None, +) -> AgentCallable: + class AgentRouter(): + async def __call__( + self, + model: Model, + event_callback: EventCallback, + prev_runs: List[AgentRun], + ) -> AgentRun: + messages = message_extractor(prev_runs) + message, result = await model.extract(event_callback, messages, extraction_type) + extraction: AgentRun = { + 'type': 'extraction', + 'message': message, + 'result': result, + } + agent = pick_agent_func(result) + run = await agent(event_callback, prev_runs) + return { + 'type': 'chain', + 'runs': [ + extraction, + run, + ], + } + + def __str__(self) -> str: + return name + + a = AgentRouter() + if doc: + a.__doc__ = doc + return a From c0aee74939d64661d7167b1b863fdefefd113bf4 Mon Sep 17 00:00:00 2001 From: Ryan Henning Date: Sat, 9 Nov 2024 12:21:18 -0600 Subject: [PATCH 06/11] Export and bump version --- src/lasagna/__init__.py | 26 +++++++++++++++++++++----- 1 file changed, 21 insertions(+), 5 deletions(-) diff --git a/src/lasagna/__init__.py b/src/lasagna/__init__.py index 18849d0..f40525a 100644 --- a/src/lasagna/__init__.py +++ b/src/lasagna/__init__.py @@ -4,23 +4,39 @@ from .agent_util import ( bind_model, + partial_bind_model, recursive_extract_messages, + extract_last_message, flat_messages, + override_system_prompt, + strip_tool_calls_and_results, + noop_callback, + MessageExtractor, + build_standard_message_extractor, + default_message_extractor, build_simple_agent, build_extraction_agent, - noop_callback, - extract_last_message, + build_agent_chainer, + build_agent_router, ) __all__ = [ 'run', 'bind_model', + 'partial_bind_model', 'recursive_extract_messages', + 'extract_last_message', 'flat_messages', + 'override_system_prompt', + 'strip_tool_calls_and_results', + 'noop_callback', + 'MessageExtractor', + 'build_standard_message_extractor', + 'default_message_extractor', 'build_simple_agent', 'build_extraction_agent', - 'noop_callback', - 'extract_last_message', + 'build_agent_chainer', + 'build_agent_router', ] -__version__ = "0.9.1" +__version__ = "0.10.0" From 7f64840d46e14683d7d18c173c77e57f4297b456 Mon Sep 17 00:00:00 2001 From: Ryan Henning Date: Sat, 9 Nov 2024 12:25:51 -0600 Subject: [PATCH 07/11] Require agent's name --- src/lasagna/agent_util.py | 19 ++++++++++++++++--- src/lasagna/types.py | 2 +- 2 files changed, 17 insertions(+), 4 deletions(-) diff --git a/src/lasagna/agent_util.py b/src/lasagna/agent_util.py index f67cfcd..3fe1523 100644 --- a/src/lasagna/agent_util.py +++ b/src/lasagna/agent_util.py @@ -142,8 +142,12 @@ def extract_last_message( return messages[-1] -def flat_messages(messages: List[Message]) -> AgentRun: +def flat_messages( + agent_name: str, + messages: List[Message], +) -> AgentRun: return { + 'agent': agent_name, 'type': 'messages', 'messages': messages, } @@ -222,7 +226,7 @@ async def __call__( force_tool = force_tool, max_tool_iters = max_tool_iters, ) - return flat_messages(new_messages) + return flat_messages(name, new_messages) def __str__(self) -> str: return name @@ -249,6 +253,7 @@ async def __call__( messages = message_extractor(prev_runs) message, result = await model.extract(event_callback, messages, extraction_type) return { + 'agent': name, 'type': 'extraction', 'message': message, 'result': result, @@ -276,7 +281,12 @@ async def __call__( prev_runs: List[AgentRun], ) -> AgentRun: if message_extractor is not None: - prev_runs = [flat_messages(message_extractor(prev_runs))] + prev_runs = [ + flat_messages( + agent_name = name, + messages = message_extractor(prev_runs), + ), + ] else: prev_runs = [*prev_runs] # shallow copy new_runs: List[AgentRun] = [] @@ -285,6 +295,7 @@ async def __call__( prev_runs.append(this_run) new_runs.append(this_run) return { + 'agent': name, 'type': 'chain', 'runs': new_runs, } @@ -315,6 +326,7 @@ async def __call__( messages = message_extractor(prev_runs) message, result = await model.extract(event_callback, messages, extraction_type) extraction: AgentRun = { + 'agent': name, 'type': 'extraction', 'message': message, 'result': result, @@ -322,6 +334,7 @@ async def __call__( agent = pick_agent_func(result) run = await agent(event_callback, prev_runs) return { + 'agent': name, 'type': 'chain', 'runs': [ extraction, diff --git a/src/lasagna/types.py b/src/lasagna/types.py index b2f79bc..5b48e76 100644 --- a/src/lasagna/types.py +++ b/src/lasagna/types.py @@ -85,7 +85,7 @@ class MessageToolResult(MessageBase): class AgentRunBase(TypedDict): - agent: NotRequired[str] + agent: str provider: NotRequired[str] model: NotRequired[str] model_kwargs: NotRequired[Dict[str, Any]] From 843f561a25081aeda775f50f95d370e0b06cc5a4 Mon Sep 17 00:00:00 2001 From: Ryan Henning Date: Sat, 9 Nov 2024 12:46:19 -0600 Subject: [PATCH 08/11] Better --- src/lasagna/__init__.py | 2 ++ src/lasagna/agent_util.py | 24 ++++++++++++++++++++---- 2 files changed, 22 insertions(+), 4 deletions(-) diff --git a/src/lasagna/__init__.py b/src/lasagna/__init__.py index f40525a..fa44666 100644 --- a/src/lasagna/__init__.py +++ b/src/lasagna/__init__.py @@ -10,6 +10,7 @@ flat_messages, override_system_prompt, strip_tool_calls_and_results, + strip_all_but_last_human_message, noop_callback, MessageExtractor, build_standard_message_extractor, @@ -29,6 +30,7 @@ 'flat_messages', 'override_system_prompt', 'strip_tool_calls_and_results', + 'strip_all_but_last_human_message', 'noop_callback', 'MessageExtractor', 'build_standard_message_extractor', diff --git a/src/lasagna/agent_util.py b/src/lasagna/agent_util.py index 3fe1523..e108f5a 100644 --- a/src/lasagna/agent_util.py +++ b/src/lasagna/agent_util.py @@ -177,6 +177,15 @@ def strip_tool_calls_and_results( ] +def strip_all_but_last_human_message( + messages: List[Message], +) -> List[Message]: + for m in reversed(messages): + if m['role'] == 'human': + return [m] + return [] + + async def noop_callback(event: EventPayload) -> None: # "noop" mean "no operation" means DON'T DO ANYTHING! assert event @@ -186,15 +195,22 @@ async def noop_callback(event: EventPayload) -> None: def build_standard_message_extractor( + extract_from_layered_agents: bool = False, + keep_only_last_human_message: bool = False, + strip_tool_messages: bool = False, system_prompt_override: Union[str, None] = None, - strip_old_tool_use_messages: bool = False, ) -> MessageExtractor: def extractor(prev_runs: List[AgentRun]) -> List[Message]: - messages = recursive_extract_messages(prev_runs, from_layered_agents=False) + messages = recursive_extract_messages( + agent_runs = prev_runs, + from_layered_agents = extract_from_layered_agents, + ) + if keep_only_last_human_message: + messages = strip_all_but_last_human_message(messages) + if strip_tool_messages: + messages = strip_tool_calls_and_results(messages) if system_prompt_override: messages = override_system_prompt(messages, system_prompt_override) - if strip_old_tool_use_messages: - messages = strip_tool_calls_and_results(messages) return messages return extractor From fb6055bdeea9a73b8ad4efb16de20679c1417bf9 Mon Sep 17 00:00:00 2001 From: Ryan Henning Date: Sat, 9 Nov 2024 13:19:51 -0600 Subject: [PATCH 09/11] Update --- src/lasagna/agent_runner.py | 2 -- src/lasagna/tools_util.py | 7 ++++++- src/lasagna/tui.py | 26 ++++++++++++++++++-------- 3 files changed, 24 insertions(+), 11 deletions(-) diff --git a/src/lasagna/agent_runner.py b/src/lasagna/agent_runner.py index 7a58318..a8646d1 100644 --- a/src/lasagna/agent_runner.py +++ b/src/lasagna/agent_runner.py @@ -58,8 +58,6 @@ async def run( prev_runs, ) - if 'agent' not in agent_run: - agent_run['agent'] = agent_name if 'provider' not in agent_run: agent_run['provider'] = provider_str if 'model' not in agent_run: diff --git a/src/lasagna/tools_util.py b/src/lasagna/tools_util.py index 7154943..cee47cb 100644 --- a/src/lasagna/tools_util.py +++ b/src/lasagna/tools_util.py @@ -281,7 +281,12 @@ async def _run_single_tool( 'role': 'tool_call', 'tools': [tool_call], }) - prev_runs: List[AgentRun] = [flat_messages(messages)] + prev_runs: List[AgentRun] = [ + flat_messages( + agent_name = 'upstream_agent', + messages = messages, + ), + ] if is_agent_callable: agent = cast(AgentCallable, func) diff --git a/src/lasagna/tui.py b/src/lasagna/tui.py index 19edef4..52eb104 100644 --- a/src/lasagna/tui.py +++ b/src/lasagna/tui.py @@ -35,18 +35,28 @@ async def tui_input_loop( just_fix_windows_console() prev_runs: List[AgentRun] = [] if system_prompt is not None: - prev_runs.append(flat_messages([{ - 'role': 'system', - 'text': system_prompt, - }])) + prev_runs.append(flat_messages( + 'tui_input_loop', + [ + { + 'role': 'system', + 'text': system_prompt, + }, + ], + )) try: while True: human_input = input(Fore.GREEN + Style.BRIGHT + '> ') print(Style.RESET_ALL, end='', flush=True) - prev_runs.append(flat_messages([{ - 'role': 'human', - 'text': human_input, - }])) + prev_runs.append(flat_messages( + 'tui_input_loop', + [ + { + 'role': 'human', + 'text': human_input, + }, + ], + )) this_run = await bound_agent(tui_event_callback, prev_runs) prev_runs.append(this_run) print(Style.RESET_ALL) From 9e2620de3cdd1c01936481eed4bd17642553db29 Mon Sep 17 00:00:00 2001 From: Ryan Henning Date: Sat, 9 Nov 2024 13:37:20 -0600 Subject: [PATCH 10/11] Update tests --- tests/test_agent_runner.py | 1 + tests/test_agent_util.py | 52 +++++++++++++++++++++++++++++--------- tests/test_caching.py | 6 +++++ tests/test_tools_util.py | 3 +++ 4 files changed, 50 insertions(+), 12 deletions(-) diff --git a/tests/test_agent_runner.py b/tests/test_agent_runner.py index f822ac1..6284858 100644 --- a/tests/test_agent_runner.py +++ b/tests/test_agent_runner.py @@ -32,6 +32,7 @@ async def agent_1( messages: List[Message] = [] new_messages = await model.run(event_callback, messages, []) return { + 'agent': 'agent_1', 'type': 'messages', 'messages': new_messages, } diff --git a/tests/test_agent_util.py b/tests/test_agent_util.py index 39baa03..30b553d 100644 --- a/tests/test_agent_util.py +++ b/tests/test_agent_util.py @@ -5,6 +5,7 @@ from lasagna.agent_util import ( bind_model, build_simple_agent, + build_standard_message_extractor, extract_last_message, noop_callback, partial_bind_model, @@ -96,8 +97,10 @@ async def test_build_layered_agent(): my_agent = build_simple_agent( name = 'a_layered_agent', tools = [], + message_extractor = build_standard_message_extractor( + system_prompt_override = 'system test', + ), doc = 'doc test', - system_prompt_override = 'system test', ) assert my_agent.__doc__ == 'doc test' assert str(my_agent) == 'a_layered_agent' @@ -106,12 +109,17 @@ async def test_build_layered_agent(): assert my_bound_agent.__doc__ == 'doc test' assert my_bound_agent.__name__ == 'a_layered_agent' assert get_name(my_bound_agent) == 'a_layered_agent' - prev_runs: List[AgentRun] = [flat_messages([ - { - 'role': 'human', - 'text': 'layered agent test', - }, - ])] + prev_runs: List[AgentRun] = [ + flat_messages( + 'some_agent', + [ + { + 'role': 'human', + 'text': 'layered agent test', + }, + ], + ), + ] new_run = await my_bound_agent(noop_callback, prev_runs) assert new_run == { 'agent': 'a_layered_agent', @@ -143,12 +151,17 @@ async def test_model_extract(): async def event_callback(event: EventPayload) -> None: events.append(event) prev_runs: List[AgentRun] = [] - run = await my_binder(build_extraction_agent(MyTestType))(event_callback, prev_runs) + run = await my_binder( + build_extraction_agent( + name = 'my_extraction_agent', + extraction_type = MyTestType, + ), + )(event_callback, prev_runs) assert events == [ ( 'agent', 'start', - 'extraction_agent', + 'my_extraction_agent', ), ( 'tool_call', @@ -173,7 +186,7 @@ async def event_callback(event: EventPayload) -> None: }, 'result': MyTestType(a='yes', b=6), 'provider': 'MockProvider', - 'agent': 'extraction_agent', + 'agent': 'my_extraction_agent', 'model': 'some_model', 'model_kwargs': { 'a': 'yes', @@ -198,17 +211,25 @@ async def test_model_extract_type_mismatch(): my_binder = bind_model(MockProvider, 'some_model', {'a': 'yes', 'b': 'BAD VALUE'}) prev_runs: List[AgentRun] = [] with pytest.raises(ValidationError): - await my_binder(build_extraction_agent(MyTestType))(noop_callback, prev_runs) + await my_binder( + build_extraction_agent( + name = 'my_extraction_agent', + extraction_type = MyTestType, + ), + )(noop_callback, prev_runs) def test_recursive_extract_messages(): agent_run: AgentRun = { + 'agent': 'outer_agent', 'type': 'chain', 'runs': [ { + 'agent': 'inner_agent_1', 'type': 'parallel', 'runs': [ { + 'agent': 'inner_agent_2', 'type': 'messages', 'messages': [ { @@ -222,6 +243,7 @@ def test_recursive_extract_messages(): ], }, { + 'agent': 'inner_agent_3', 'type': 'messages', 'messages': [ { @@ -244,6 +266,7 @@ def test_recursive_extract_messages(): 'type': 'layered_agent', 'call_id': 'call001', 'run': { + 'agent': 'inner_agent_4', 'type': 'messages', 'messages': [ { @@ -260,6 +283,7 @@ def test_recursive_extract_messages(): ], }, { + 'agent': 'inner_agent_5', 'type': 'extraction', 'message': { 'role': 'tool_call', @@ -277,6 +301,7 @@ def test_recursive_extract_messages(): 'result': {'value': 7}, }, { + 'agent': 'inner_agent_6', 'type': 'messages', 'messages': [ { @@ -320,6 +345,7 @@ def test_recursive_extract_messages(): 'type': 'layered_agent', 'call_id': 'call001', 'run': { + 'agent': 'inner_agent_4', 'type': 'messages', 'messages': [ { @@ -386,6 +412,7 @@ def test_recursive_extract_messages(): 'type': 'layered_agent', 'call_id': 'call001', 'run': { + 'agent': 'inner_agent_4', 'type': 'messages', 'messages': [ { @@ -436,7 +463,8 @@ def test_flat_messages(): 'text': 'Here kitty kitty!', }, ] - assert flat_messages(messages) == { + assert flat_messages('an_agent', messages) == { + 'agent': 'an_agent', 'type': 'messages', 'messages': [ { diff --git a/tests/test_caching.py b/tests/test_caching.py index b6c58d4..c62c693 100644 --- a/tests/test_caching.py +++ b/tests/test_caching.py @@ -21,6 +21,7 @@ _PREV_RUNS: List[AgentRun] = [ { + 'agent': 'some_agent', 'type': 'messages', 'messages': [ { @@ -45,6 +46,7 @@ async def my_agent( messages: List[Message] = [] new_messages = await model.run(event_callback, messages, []) return { + 'agent': 'my_agent', 'type': 'messages', 'messages': [ *new_messages, @@ -118,6 +120,10 @@ async def test_hash_agent_runs(): __open_list__ __open_dict__ __open_list__ + __open_list__ + __str__agent + __str__some_agent + __close_list__ __open_list__ __str__messages __open_list__ diff --git a/tests/test_tools_util.py b/tests/test_tools_util.py index 0b3dee1..045c96c 100644 --- a/tests/test_tools_util.py +++ b/tests/test_tools_util.py @@ -983,6 +983,7 @@ def test_extract_tool_result_as_sting(): 'call_id': '1001', 'type': 'layered_agent', 'run': { + 'agent': 'some_downstream_agent', 'type': 'messages', 'messages': [ { @@ -998,6 +999,7 @@ def test_extract_tool_result_as_sting(): 'call_id': '1001', 'type': 'layered_agent', 'run': { + 'agent': 'some_downstream_agent', 'type': 'messages', 'messages': [ { @@ -1022,6 +1024,7 @@ def test_extract_tool_result_as_sting(): 'call_id': '1001', 'type': 'layered_agent', 'run': { + 'agent': 'some_downstream_agent', 'type': 'messages', 'messages': [ { From 3bc6b46e0c3cc68668675100a1d30b38365da135 Mon Sep 17 00:00:00 2001 From: Ryan Henning Date: Sat, 9 Nov 2024 13:44:29 -0600 Subject: [PATCH 11/11] Update examples --- examples/demo_committee.py | 2 +- examples/demo_layered_agents.py | 13 +++++++++---- examples/demo_structured_output.py | 22 +++++++++++++++------- examples/demo_vision.py | 9 ++++++++- 4 files changed, 33 insertions(+), 13 deletions(-) diff --git a/examples/demo_committee.py b/examples/demo_committee.py index e4d2882..13bf52a 100644 --- a/examples/demo_committee.py +++ b/examples/demo_committee.py @@ -53,7 +53,7 @@ async def vote_on_jokes(joke_a: str, joke_b: str) -> Dict[str, int]: }, ] tasks = [ - model(noop_callback, [flat_messages(messages)]) + model(noop_callback, [flat_messages('input', messages)]) for model in COMMITTEE_MODELS ] outputs = await asyncio.gather(*tasks) diff --git a/examples/demo_layered_agents.py b/examples/demo_layered_agents.py index ab0a72e..72a92ce 100644 --- a/examples/demo_layered_agents.py +++ b/examples/demo_layered_agents.py @@ -1,6 +1,7 @@ from lasagna import ( known_models, build_simple_agent, + build_standard_message_extractor, ) from lasagna.tui import ( @@ -35,17 +36,21 @@ async def main() -> None: tools = [ evaluate_math_expression, ], + message_extractor = build_standard_message_extractor( + strip_tool_messages = True, + system_prompt_override = "You are a math assistant.", + ), doc = "Use this tool if the user asks a math question.", - system_prompt_override = "You are a math assistant.", - strip_old_tool_use_messages = True, ) health_agent = known_models.BIND_ANTHROPIC_claude_35_sonnet()( build_simple_agent( name = 'health_agent', tools = [], + message_extractor = build_standard_message_extractor( + strip_tool_messages = True, + system_prompt_override = "You are a health coach who motivates through fear.", + ), doc = "Use this tool if the user asks a health question.", - system_prompt_override = "You are a health coach who motivates through fear.", - strip_old_tool_use_messages = True, ), ) my_bound_agent = known_models.BIND_OPENAI_gpt_4o_mini()( diff --git a/examples/demo_structured_output.py b/examples/demo_structured_output.py index 135019e..c13f8f7 100644 --- a/examples/demo_structured_output.py +++ b/examples/demo_structured_output.py @@ -41,14 +41,22 @@ class ExtractionModel(BaseModel): async def main() -> None: my_bound_agent = MODEL_BINDER( - build_extraction_agent(ExtractionModel), + build_extraction_agent( + name = 'extraction_agent', + extraction_type = ExtractionModel, + ), ) - prev_runs = [flat_messages([ - { - 'role': 'human', - 'text': PROMPT, - }, - ])] + prev_runs = [ + flat_messages( + 'input', + [ + { + 'role': 'human', + 'text': PROMPT, + }, + ], + ), + ] agent_run = await my_bound_agent(tui_event_callback, prev_runs) print('DONE!') diff --git a/examples/demo_vision.py b/examples/demo_vision.py index 91822e1..2008ce9 100644 --- a/examples/demo_vision.py +++ b/examples/demo_vision.py @@ -2,6 +2,7 @@ known_models, build_simple_agent, flat_messages, + AgentRun, Message, ) @@ -35,7 +36,13 @@ async def main() -> None: bound_agent = MODEL_BINDER( build_simple_agent(name='agent'), ) - await bound_agent(tui_event_callback, [flat_messages(messages)]) + prev_runs: List[AgentRun] = [ + flat_messages( + 'input', + messages, + ), + ] + await bound_agent(tui_event_callback, prev_runs) print()