From c3bb30a3455bcc00fae0eb1c3cbd8ab033b5da59 Mon Sep 17 00:00:00 2001 From: Ryan Henning Date: Mon, 27 Jan 2025 23:53:10 +0000 Subject: [PATCH 1/4] Start of ollama support --- README.md | 3 +- pyproject.toml | 1 + src/lasagna/known_models.py | 24 +++ src/lasagna/known_providers.py | 11 ++ src/lasagna/lasagna_ollama.py | 351 +++++++++++++++++++++++++++++++++ 5 files changed, 389 insertions(+), 1 deletion(-) create mode 100644 src/lasagna/lasagna_ollama.py diff --git a/README.md b/README.md index c274aa9..7883002 100644 --- a/README.md +++ b/README.md @@ -26,7 +26,8 @@ - Yes, you _can_ have _both_ streaming and easy database storage. - ↔️ **Provider/model agnostic and interoperable!** - - Native support for [OpenAI](https://platform.openai.com/docs/models), [Anthropic](https://docs.anthropic.com/en/docs/welcome), [NVIDIA NIM/NGC](https://build.nvidia.com/explore/reasoning) (+ more to come). + - Core support for [OpenAI](https://platform.openai.com/docs/models) and [Anthropic](https://docs.anthropic.com/en/docs/welcome). + - Experimental support for [Ollama](https://ollama.com/search) and [NVIDIA NIM/NGC](https://build.nvidia.com/explore/reasoning). - Message representations are canonized. 😇 - Supports vision! - Easily build committees! diff --git a/pyproject.toml b/pyproject.toml index fc9b2ce..993200b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -35,6 +35,7 @@ classifiers = [ dependencies = [ "svs", "aiohttp[speedups]", + "httpx", "python-dotenv", "colorama", "pydantic >= 2.7, < 3", diff --git a/src/lasagna/known_models.py b/src/lasagna/known_models.py index f0806d2..d9c8d8a 100644 --- a/src/lasagna/known_models.py +++ b/src/lasagna/known_models.py @@ -63,6 +63,30 @@ BIND_ANTHROPIC_claude_3_haiku = partial_bind_model('anthropic', 'claude-3-haiku-20240307') +OLLAMA_KNOWN_MODELS: List[ModelRecord] = [ + # We'll only list models here that (1) support tool-calling because + # that's kind of the whole point of lasagna, and (2) we've tested + # ourselves and pass our "vibe" test. Users are of course more than + # welcome to use *other* Ollama models as they see fit. 
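+    # The `formal_name` values below are the model tags you would pass to
+    # `ollama pull` (see https://ollama.com/search).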
+ { + 'formal_name': 'llama3.2', + 'display_name': 'Meta Llama 3.2', + }, + { + 'formal_name': 'mistral-small', + 'display_name': 'Mistral Small', + }, + { + 'formal_name': 'mistral-large', + 'display_name': 'Mistral Large', + }, +] + +BIND_OLLAMA_llama3_2 = partial_bind_model('ollama', 'llama3.2') +BIND_OLLAMA_mistral_small = partial_bind_model('ollama', 'mistral-small') +BIND_OLLAMA_mistral_large = partial_bind_model('ollama', 'mistral-large') + + NVIDIA_KNOWN_MODELS: List[ModelRecord] = [ { 'formal_name': 'meta/llama3-70b-instruct', diff --git a/src/lasagna/known_providers.py b/src/lasagna/known_providers.py index 7ef116c..9d67a6e 100644 --- a/src/lasagna/known_providers.py +++ b/src/lasagna/known_providers.py @@ -14,6 +14,7 @@ def attempt_load_all_known_providers() -> None: providers_to_try_to_load = [ 'openai', 'anthropic', + 'ollama', 'nvidia', ] for provider in providers_to_try_to_load: @@ -63,6 +64,16 @@ def attempt_load_known_providers(provider: str) -> None: models = ANTHROPIC_KNOWN_MODELS, ) + elif provider == 'ollama': + from .known_models import OLLAMA_KNOWN_MODELS + from .lasagna_ollama import LasagnaOllama + register_provider( + key = 'ollama', + name = 'Ollama', + factory = LasagnaOllama, + models = OLLAMA_KNOWN_MODELS, + ) + elif provider == 'nvidia': from .known_models import NVIDIA_KNOWN_MODELS from .lasagna_nvidia import LasagnaNVIDIA diff --git a/src/lasagna/lasagna_ollama.py b/src/lasagna/lasagna_ollama.py new file mode 100644 index 0000000..11a85b0 --- /dev/null +++ b/src/lasagna/lasagna_ollama.py @@ -0,0 +1,351 @@ +""" +This module is the Lasagna adapter for the Ollama server. + +For more information about Ollama, see: + - https://ollama.com/ +""" + +from .types import ( + ModelSpec, + Message, + Media, + EventCallback, + ToolCall, + Model, + ExtractionType, +) + +from .util import ( + convert_to_image_base64, + exponential_backoff_retry_delays, + get_name, + recursive_hash, +) + +from .tools_util import ( + convert_to_json_schema, + extract_tool_result_as_sting, + get_tool_params, + handle_tools, + build_tool_response_message, +) + +from .pydantic_util import ensure_pydantic_model, build_and_validate + +from .known_models import OLLAMA_KNOWN_MODELS + +from openai.lib._pydantic import to_strict_json_schema + +from typing import ( + List, Callable, AsyncIterator, Any, Type, + Tuple, Dict, Union, + cast, +) + +import os +import asyncio +import httpx +import copy +import json + +import logging + +_LOG = logging.getLogger(__name__) + + +def _convert_to_ollama_tool(tool: Callable) -> Dict: + description, params = get_tool_params(tool) + return { + 'name': get_name(tool), + 'description': description, + 'input_schema': convert_to_json_schema(params), + } + + +def _convert_to_ollama_tools(tools: List[Callable]) -> Union[None, List[Dict]]: + if len(tools) == 0: + return None + specs = [_convert_to_ollama_tool(tool) for tool in tools] + return specs + + +def _log_dumps(val: Any) -> str: + if isinstance(val, dict): + return json.dumps(val) + else: + return str(val) + + +async def _convert_to_ollama_media(media: List[Media]) -> Dict: + # Ollama only supports *images* as media. But, so does lasagna, so all good. 
+ res: Dict = { + 'images': [], + } + for m in media: + assert m['type'] == 'image' + mimetype, data = await convert_to_image_base64(m['image']) + assert mimetype + res['images'].append(data) + return res + + +def _convert_to_ollama_tool_calls(tools: List[ToolCall]) -> List: + res: List = [] + for t in tools: + assert t['call_type'] == 'function' + res.append({ + # t['call_id'] NOT USED! Ollama doesn't do that sort of thing. + 'function': { + 'name': t['function']['name'], + 'arguments': json.loads(t['function']['arguments']), + }, + }) + return res + + +async def _convert_to_ollama_messages(messages: List[Message]) -> List[Dict]: + res: List[Dict] = [] + map = { + 'system': 'system', + 'human': 'user', + 'ai': 'assistant', + } + for m in messages: + if m['role'] == 'system' or m['role'] == 'human' or m['role'] == 'ai': # <-- not using boolean 'in' to make mypy happy + media = {} + if 'media' in m and m['media']: + media = await _convert_to_ollama_media(m['media']) + res.append({ + 'role': map[m['role']], + 'content': m.get('text', ''), + **media, + }) + elif m['role'] == 'tool_call': + res.append({ + 'role': 'assistant', + 'content': '', + 'tool_calls': _convert_to_ollama_tool_calls(m['tools']), + }) + elif m['role'] == 'tool_res': + for t in m['tools']: + # t['is_error'] IS NOT USED! + res.append({ + 'role': 'tool', + 'content': extract_tool_result_as_sting(t), + 'name': t['call_id'], # <-- weird, I know, but Ollama sort of uses the name as the call id, so we do that too + }) + else: + raise RuntimeError(f"unreachable: {m['role']}") + return res + + +async def _event_stream(url: str, payload: Dict) -> AsyncIterator[Dict]: + async with httpx.AsyncClient() as client: + async with client.stream('POST', url, json=payload) as r: + if r.status_code != 200: + error_text = json.loads(await r.aread())['error'] + raise RuntimeError(f'Ollama error: {error_text}') + async for line in r.aiter_lines(): + rec = json.loads(line) + if rec.get('error'): + error_text = rec['error'] + raise RuntimeError(f'Ollama error: {error_text}') + yield rec + + +async def _process_stream( + stream: AsyncIterator[Dict], + event_callback: EventCallback, +) -> List[Message]: + async for event in stream: + if 'message' in event: + m = event['message'] + if 'content' in m: + c = m['content'] + assert isinstance(c, str) + await event_callback(('ai', 'text_event', c)) # TODO + return [] # TODO + + +class LasagnaOllama(Model): + def __init__(self, model: str, **model_kwargs: Dict[str, Any]): + known_model_names = [m['formal_name'] for m in OLLAMA_KNOWN_MODELS] + if model not in known_model_names: + _LOG.warning(f'untested model: {model} (may or may not work)') + self.model = model + self.model_kwargs = copy.deepcopy(model_kwargs or {}) + self.n_retries: int = cast(int, self.model_kwargs['retries']) if 'retries' in self.model_kwargs else 3 + if not isinstance(self.n_retries, int) or self.n_retries < 0: + raise ValueError(f"model_kwargs['retries'] must be a non-negative integer (got {self.model_kwargs['retries']})") + self.model_spec: ModelSpec = { + 'provider': 'ollama', + 'model': self.model, + 'model_kwargs': self.model_kwargs, + } + self.base_url = model_kwargs.get('base_url', os.environ.get('OLLAMA_BASE_URL', 'http://127.0.0.1:11434')) + + def config_hash(self) -> str: + return recursive_hash(None, { + 'provider': 'ollama', + 'model': self.model, + 'model_kwargs': self.model_kwargs, + }) + + async def _run_once( + self, + event_callback: EventCallback, + messages: List[Message], + tools_spec: Union[None, List[Dict]], + 
force_tool: bool, + ) -> List[Message]: + stream = True + tools = None + + if tools_spec and len(tools_spec) > 0: + stream = False # Ollama does not yet support streaming tool responses. + tools = tools_spec + if not force_tool: + raise ValueError("Oops! Ollama currently does not support *optional* tool use. Thus, if you pass tools, you must also pass `force_tool=True` to show that your intended use matches Ollama's behavior.") + + ollama_messages = await _convert_to_ollama_messages(messages) + + _LOG.info(f"Invoking {self.model} with:\n messages: {_log_dumps(ollama_messages)}\n tools: {_log_dumps(tools)}") + + url = f'{self.base_url}/api/chat' + + payload = { + 'model': self.model, + 'messages': ollama_messages, + 'stream': stream, + **({'tools': tools} if tools is not None else {}), + } + + event_stream = _event_stream(url, payload) + new_messages = await _process_stream(event_stream, event_callback) + + _LOG.info(f"Finished {self.model}") + + return new_messages + + async def _retrying_run_once( + self, + event_callback: EventCallback, + messages: List[Message], + tools_spec: Union[None, List[Dict]], + force_tool: bool, + ) -> List[Message]: + last_error: Union[Exception, None] = None + assert self.n_retries + 1 > 0 # <-- we know this is true from the check in __init__ + for delay_on_error in exponential_backoff_retry_delays(self.n_retries + 1): + try: + await event_callback(('transaction', 'start', ('ollama', self.model))) + try: + new_messages = await self._run_once( + event_callback = event_callback, + messages = messages, + tools_spec = tools_spec, + force_tool = force_tool, + ) + except: + await event_callback(('transaction', 'rollback', None)) + raise + await event_callback(('transaction', 'commit', None)) + return new_messages + except Exception as e: + # Some errors should be retried, some should not. Below + # is the logic to decide when to retry vs when to not. + # It's likely this will change as we get more usage and see + # where Ollama tends to fail, and when we know more what is + # recoverable vs not. + last_error = e + if isinstance(e, httpx.HTTPError): + # Network connection error. We can retry this. + pass + else: + # Some other error that we don't know about. Let's bail. 
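+                    # Note: non-200 HTTP responses surface as RuntimeError
+                    # from _event_stream(), so they land here and are NOT
+                    # retried; only transport-level httpx errors (timeouts,
+                    # refused connections, etc.) take the retry branch above.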
+ raise + if delay_on_error > 0.0: + _LOG.warning(f"Got a maybe-recoverable error (will retry in {delay_on_error:.2f} seconds): {e}") + await asyncio.sleep(delay_on_error) + assert last_error is not None # <-- we know this is true because `n_retries + 1 > 0` + raise last_error + + async def run( + self, + event_callback: EventCallback, + messages: List[Message], + tools: List[Callable], + force_tool: bool = False, + max_tool_iters: int = 5, + ) -> List[Message]: + messages = [*messages] # shallow copy + new_messages: List[Message] = [] + tools_spec = _convert_to_ollama_tools(tools) + tools_map = {get_name(tool): tool for tool in tools} + for _ in range(max_tool_iters): + new_messages_here = await self._retrying_run_once( + event_callback = event_callback, + messages = messages, + tools_spec = tools_spec, + force_tool = force_tool, + ) + tools_results = await handle_tools( + prev_messages = messages, + new_messages = new_messages_here, + tools_map = tools_map, + event_callback = event_callback, + model_spec = self.model_spec, + ) + new_messages.extend(new_messages_here) + messages.extend(new_messages_here) + if tools_results is None: + break + for tool_result in tools_results: + await event_callback(('tool_res', 'tool_res_event', tool_result)) + tool_response_message = build_tool_response_message(tools_results) + new_messages.append(tool_response_message) + messages.append(tool_response_message) + return new_messages + + async def extract( + self, + event_callback: EventCallback, + messages: List[Message], + extraction_type: Type[ExtractionType], + ) -> Tuple[Message, ExtractionType]: + tools_spec: List[Dict] = [ + { + 'name': get_name(extraction_type), + 'input_schema': to_strict_json_schema(ensure_pydantic_model(extraction_type)), + }, + ] + + # TODO: use the `format` feature of Ollama + + docstr = getattr(extraction_type, '__doc__', None) + if docstr: + tools_spec[0]['description'] = docstr + + new_messages = await self._retrying_run_once( + event_callback = event_callback, + messages = messages, + tools_spec = tools_spec, + force_tool = True, + ) + + assert len(new_messages) == 1 + new_message = new_messages[0] + + if new_message['role'] == 'tool_call': + tools = new_message['tools'] + + assert len(tools) == 1 + parsed = json.loads(tools[0]['function']['arguments']) + result = build_and_validate(extraction_type, parsed) + + return new_message, result + + else: + assert new_message['role'] == 'ai' + text = new_message['text'] + raise RuntimeError(f"Model failed to generate structured output; instead, it output: {text}") From 13960d7ffe085739fe3ac618c5dece6c12c84888 Mon Sep 17 00:00:00 2001 From: Ryan Henning Date: Tue, 28 Jan 2025 17:12:51 -0600 Subject: [PATCH 2/4] Cont' --- src/lasagna/lasagna_ollama.py | 187 ++++++++--- tests/test_lasagna_ollama.py | 565 ++++++++++++++++++++++++++++++++++ 2 files changed, 707 insertions(+), 45 deletions(-) create mode 100644 tests/test_lasagna_ollama.py diff --git a/src/lasagna/lasagna_ollama.py b/src/lasagna/lasagna_ollama.py index 11a85b0..de9f051 100644 --- a/src/lasagna/lasagna_ollama.py +++ b/src/lasagna/lasagna_ollama.py @@ -6,6 +6,9 @@ """ from .types import ( + Cost, + EventPayload, + MessageToolCall, ModelSpec, Message, Media, @@ -32,8 +35,6 @@ from .pydantic_util import ensure_pydantic_model, build_and_validate -from .known_models import OLLAMA_KNOWN_MODELS - from openai.lib._pydantic import to_strict_json_schema from typing import ( @@ -56,9 +57,12 @@ def _convert_to_ollama_tool(tool: Callable) -> Dict: description, params = 
get_tool_params(tool) return { - 'name': get_name(tool), - 'description': description, - 'input_schema': convert_to_json_schema(params), + 'type': 'function', + 'function': { + 'name': get_name(tool), + 'description': description, + 'parameters': convert_to_json_schema(params), + }, } @@ -77,7 +81,6 @@ def _log_dumps(val: Any) -> str: async def _convert_to_ollama_media(media: List[Media]) -> Dict: - # Ollama only supports *images* as media. But, so does lasagna, so all good. res: Dict = { 'images': [], } @@ -110,6 +113,7 @@ async def _convert_to_ollama_messages(messages: List[Message]) -> List[Dict]: 'human': 'user', 'ai': 'assistant', } + prev_tool_call_map = {} for m in messages: if m['role'] == 'system' or m['role'] == 'human' or m['role'] == 'ai': # <-- not using boolean 'in' to make mypy happy media = {} @@ -126,13 +130,17 @@ async def _convert_to_ollama_messages(messages: List[Message]) -> List[Dict]: 'content': '', 'tool_calls': _convert_to_ollama_tool_calls(m['tools']), }) + prev_tool_call_map = { + t['call_id']: t['function']['name'] + for t in m['tools'] + } elif m['role'] == 'tool_res': for t in m['tools']: # t['is_error'] IS NOT USED! res.append({ 'role': 'tool', 'content': extract_tool_result_as_sting(t), - 'name': t['call_id'], # <-- weird, I know, but Ollama sort of uses the name as the call id, so we do that too + 'name': prev_tool_call_map.get(t['call_id'], 'unknown'), }) else: raise RuntimeError(f"unreachable: {m['role']}") @@ -144,34 +152,111 @@ async def _event_stream(url: str, payload: Dict) -> AsyncIterator[Dict]: async with client.stream('POST', url, json=payload) as r: if r.status_code != 200: error_text = json.loads(await r.aread())['error'] - raise RuntimeError(f'Ollama error: {error_text}') + raise RuntimeError(f'Ollama error: {error_text}') # TODO TEST ME async for line in r.aiter_lines(): rec = json.loads(line) if rec.get('error'): error_text = rec['error'] - raise RuntimeError(f'Ollama error: {error_text}') + raise RuntimeError(f'Ollama error: {error_text}') # TODO TEST ME yield rec +def _set_cost_raw(message: Message, raw: List[Dict]) -> Message: + cost: Cost = { + 'input_tokens': None, + 'output_tokens': None, + 'total_tokens': None, + } + for event in raw: + if 'prompt_eval_count' in event: + cost['input_tokens'] = (cost['input_tokens'] or 0) + event['prompt_eval_count'] + if 'eval_count' in event: + cost['output_tokens'] = (cost['output_tokens'] or 0) + event['eval_count'] + cost['total_tokens'] = (cost['input_tokens'] or 0) + (cost['output_tokens'] or 0) + new_message: Message = copy.copy(message) + if (cost['total_tokens'] or 0) > 0: + new_message['cost'] = cost + new_message['raw'] = raw + return new_message + + async def _process_stream( stream: AsyncIterator[Dict], event_callback: EventCallback, ) -> List[Message]: + raw: List[Dict] = [] + content: List[str] = [] + tools: List[ToolCall] = [] async for event in stream: + raw.append(event) if 'message' in event: m = event['message'] if 'content' in m: c = m['content'] - assert isinstance(c, str) - await event_callback(('ai', 'text_event', c)) # TODO - return [] # TODO + if c: + assert isinstance(c, str) + await event_callback(('ai', 'text_event', c)) + content.append(c) + if 'tool_calls' in m: + for i, t in enumerate(m['tool_calls']): + assert 'function' in t + f = t['function'] + name = f['name'] + args = f['arguments'] # is a dict + args_str = json.dumps(args) + tool_call: ToolCall = { + 'call_id': f'call_{i}', # <-- Ollama doesn't do the `call_id` thing, so we invent a call_id + 'call_type': 
'function', + 'function': { + 'name': name, + 'arguments': args_str, + }, + } + await event_callback(('tool_call', 'text_event', f'{name}({args_str})\n')) + await event_callback(('tool_call', 'tool_call_event', tool_call)) + tools.append(tool_call) + messages: List[Message] = [] + if content: + messages.append({ + 'role': 'ai', + 'text': ''.join(content), + }) + if tools: + messages.append({ + 'role': 'tool_call', + 'tools': tools, + }) + if len(messages) > 0: + messages[-1] = _set_cost_raw(messages[-1], raw) + return messages + + +def _wrap_event_callback_convert_ai_text_to_tool_call_text( + wrapped: EventCallback, +) -> EventCallback: + async def wrapper(event: EventPayload) -> None: + if event[0] == 'ai' and event[1] == 'text_event': + await wrapped(('tool_call', 'text_event', event[2])) + else: + await wrapped(event) + + return wrapper + + +def _get_ollama_format_for_structured_output( + extraction_type: Type[ExtractionType], +) -> Dict: + format: Dict = to_strict_json_schema(ensure_pydantic_model(extraction_type)) + + docstr = getattr(extraction_type, '__doc__', None) + if docstr: + format['description'] = docstr + + return format class LasagnaOllama(Model): def __init__(self, model: str, **model_kwargs: Dict[str, Any]): - known_model_names = [m['formal_name'] for m in OLLAMA_KNOWN_MODELS] - if model not in known_model_names: - _LOG.warning(f'untested model: {model} (may or may not work)') self.model = model self.model_kwargs = copy.deepcopy(model_kwargs or {}) self.n_retries: int = cast(int, self.model_kwargs['retries']) if 'retries' in self.model_kwargs else 3 @@ -197,19 +282,24 @@ async def _run_once( messages: List[Message], tools_spec: Union[None, List[Dict]], force_tool: bool, + format: Union[None, Dict], ) -> List[Message]: stream = True - tools = None - if tools_spec and len(tools_spec) > 0: + if tools_spec and format: + raise ValueError("Oops! You cannot do both tool-use and structured output at the same time!") + + if tools_spec: stream = False # Ollama does not yet support streaming tool responses. - tools = tools_spec if not force_tool: raise ValueError("Oops! Ollama currently does not support *optional* tool use. Thus, if you pass tools, you must also pass `force_tool=True` to show that your intended use matches Ollama's behavior.") + else: + if force_tool: + raise ValueError("Oops! 
You cannot force tools that are not specified!") ollama_messages = await _convert_to_ollama_messages(messages) - _LOG.info(f"Invoking {self.model} with:\n messages: {_log_dumps(ollama_messages)}\n tools: {_log_dumps(tools)}") + _LOG.info(f"Invoking {self.model} with:\n messages: {_log_dumps(ollama_messages)}\n tools: {_log_dumps(tools_spec)}") url = f'{self.base_url}/api/chat' @@ -217,7 +307,8 @@ async def _run_once( 'model': self.model, 'messages': ollama_messages, 'stream': stream, - **({'tools': tools} if tools is not None else {}), + **({'tools': tools_spec} if tools_spec else {}), + **({'format': format} if format else {}), } event_stream = _event_stream(url, payload) @@ -233,6 +324,7 @@ async def _retrying_run_once( messages: List[Message], tools_spec: Union[None, List[Dict]], force_tool: bool, + format: Union[None, Dict], ) -> List[Message]: last_error: Union[Exception, None] = None assert self.n_retries + 1 > 0 # <-- we know this is true from the check in __init__ @@ -245,6 +337,7 @@ async def _retrying_run_once( messages = messages, tools_spec = tools_spec, force_tool = force_tool, + format = format, ) except: await event_callback(('transaction', 'rollback', None)) @@ -288,6 +381,7 @@ async def run( messages = messages, tools_spec = tools_spec, force_tool = force_tool, + format = None, ) tools_results = await handle_tools( prev_messages = messages, @@ -313,39 +407,42 @@ async def extract( messages: List[Message], extraction_type: Type[ExtractionType], ) -> Tuple[Message, ExtractionType]: - tools_spec: List[Dict] = [ - { - 'name': get_name(extraction_type), - 'input_schema': to_strict_json_schema(ensure_pydantic_model(extraction_type)), - }, - ] - - # TODO: use the `format` feature of Ollama - - docstr = getattr(extraction_type, '__doc__', None) - if docstr: - tools_spec[0]['description'] = docstr + format = _get_ollama_format_for_structured_output(extraction_type) new_messages = await self._retrying_run_once( - event_callback = event_callback, + event_callback = _wrap_event_callback_convert_ai_text_to_tool_call_text(event_callback), messages = messages, - tools_spec = tools_spec, - force_tool = True, + tools_spec = None, + force_tool = False, + format = format, ) assert len(new_messages) == 1 new_message = new_messages[0] - if new_message['role'] == 'tool_call': - tools = new_message['tools'] + assert new_message['role'] == 'ai' # Ollama generates structured output just like it does normal text, so at this point it is just text. 
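+        # Repackage the structured-output text below as a synthetic
+        # `tool_call` message, so extract() returns the same message shape
+        # that a tool-based extraction would.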
+ + arguments = new_message['text'] or '{}' + + tool_message: MessageToolCall = { + 'role': 'tool_call', + 'tools': [ + { + 'call_id': 'extraction_tool_call', + 'call_type': 'function', + 'function': { + 'name': '...', # TODO + 'arguments': arguments, + }, + }, + ], + #'cost': new_message.get('cost', None), # TODO + #'raw': new_message.get('raw', None), # TODO + } - assert len(tools) == 1 - parsed = json.loads(tools[0]['function']['arguments']) - result = build_and_validate(extraction_type, parsed) + await event_callback(('tool_call', 'tool_call_event', tool_message['tools'][0])) - return new_message, result + parsed = json.loads(arguments) + result = build_and_validate(extraction_type, parsed) - else: - assert new_message['role'] == 'ai' - text = new_message['text'] - raise RuntimeError(f"Model failed to generate structured output; instead, it output: {text}") + return tool_message, result diff --git a/tests/test_lasagna_ollama.py b/tests/test_lasagna_ollama.py new file mode 100644 index 0000000..5765a19 --- /dev/null +++ b/tests/test_lasagna_ollama.py @@ -0,0 +1,565 @@ +import pytest + +import os +import tempfile + +from typing import List, Dict, Literal +from typing_extensions import TypedDict + +from pydantic import BaseModel + +from lasagna.types import EventPayload, Message, ToolCall + +from lasagna.stream import fake_async + +from lasagna.lasagna_ollama import ( + _convert_to_ollama_tool, + _convert_to_ollama_tools, + _convert_to_ollama_media, + _convert_to_ollama_tool_calls, + _convert_to_ollama_messages, + _set_cost_raw, + _process_stream, + _wrap_event_callback_convert_ai_text_to_tool_call_text, + _get_ollama_format_for_structured_output, + LasagnaOllama, +) + + +def sample_tool(a: int, b: str = 'hi') -> str: + """ + A sample tool. + :param: a: int: the value of a + :param: b: str: (optional) the value of b + """ + return b * a + + +_sample_tool_correct_schema = { + 'type': 'function', + 'function': { + 'name': 'sample_tool', + 'description': 'A sample tool.', + 'parameters': { + 'type': 'object', + 'properties': { + 'a': { + 'type': 'integer', + 'description': 'the value of a', + }, + 'b': { + 'type': 'string', + 'description': '(optional) the value of b', + }, + }, + 'required': ['a'], + 'additionalProperties': False, + }, + }, +} + + +_sample_events_streaming_text: List[Dict] = [ + {'message': {'role': 'assistant', 'content': 'Hello'}, 'done': False}, + {'message': {'role': 'assistant', 'content': '!'}, 'done': False}, + {'message': {'role': 'assistant', 'content': ' How'}, 'done': False}, + {'message': {'role': 'assistant', 'content': ' can'}, 'done': False}, + {'message': {'role': 'assistant', 'content': ' I'}, 'done': False}, + {'message': {'role': 'assistant', 'content': ' assist'}, 'done': False}, + {'message': {'role': 'assistant', 'content': ' you'}, 'done': False}, + {'message': {'role': 'assistant', 'content': ' today'}, 'done': False}, + {'message': {'role': 'assistant', 'content': '?'}, 'done': False}, + {'message': {'role': 'assistant', 'content': ''}, 'done_reason': 'stop', 'done': True, 'total_duration': 290565553, 'load_duration': 4737247, 'prompt_eval_count': 20, 'prompt_eval_duration': 27000000, 'eval_count': 10, 'eval_duration': 257000000}, +] + + +_sample_events_tool_call: List[Dict] = [ + # Ollama cannot stream in this case, so we get a single message here: + {'message': {'role': 'assistant', 'content': '', 'tool_calls': [{'function': {'name': 'evaluate_math_expression', 'arguments': {'expression': '2.5 * sin(4.5)'}}}]}, 'done_reason': 'stop', 'done': True, 
'total_duration': 799653463, 'load_duration': 4977833, 'prompt_eval_count': 100, 'prompt_eval_duration': 27000000, 'eval_count': 29, 'eval_duration': 765000000}, +] + + +_sample_events_structued_output: List[Dict] = [ + # Ollama cannot stream in this case, so we get a single message here: + {'message': {'role': 'assistant', 'content': '{"base":2,"exponent":101}'}, 'done_reason': 'stop', 'done': True, 'total_duration': 426267794, 'load_duration': 4418431, 'prompt_eval_count': 28, 'prompt_eval_duration': 21000000, 'eval_count': 15, 'eval_duration': 399000000}, +] + + +def test_convert_to_ollama_tool(): + assert _convert_to_ollama_tool(sample_tool) == _sample_tool_correct_schema + + +def test_convert_to_ollama_tools(): + assert _convert_to_ollama_tools([]) is None + assert _convert_to_ollama_tools([sample_tool]) == [ + _sample_tool_correct_schema, + ] + + +@pytest.mark.asyncio +async def test_convert_to_ollama_media(): + with tempfile.TemporaryDirectory() as tmp: + fn1 = os.path.join(tmp, 'a.png') + fn2 = os.path.join(tmp, 'b.png') + with open(fn1, 'wb') as f: + f.write(b'1234') + with open(fn2, 'wb') as f: + f.write(b'1235') + assert await _convert_to_ollama_media([ + { + 'type': 'image', + 'image': fn1, + }, + { + 'type': 'image', + 'image': fn2, + }, + ]) == { + 'images': ['MTIzNA==', 'MTIzNQ=='], + } + + +def test_convert_to_ollama_tool_calls(): + assert _convert_to_ollama_tool_calls([ + { + 'call_id': 'abc', + 'call_type': 'function', + 'function': { + 'name': 'foo', + 'arguments': '{"a": 7}', + }, + }, + { + 'call_id': 'xyz', + 'call_type': 'function', + 'function': { + 'name': 'bar', + 'arguments': '{"b": 42}', + }, + }, + ]) == [ + { + 'function': { + 'name': 'foo', + 'arguments': { + 'a': 7, + }, + }, + }, + { + 'function': { + 'name': 'bar', + 'arguments': { + 'b': 42, + }, + }, + }, + ] + + +@pytest.mark.asyncio +async def test_convert_to_ollama_messages(): + with tempfile.TemporaryDirectory() as tmp: + fn = os.path.join(tmp, 'a.png') + with open(fn, 'wb') as f: + f.write(b'1234') + assert await _convert_to_ollama_messages([ + { + 'role': 'system', + 'text': 'You are a robot.', + }, + { + 'role': 'human', + 'text': 'What is in this image?', + 'media': [ + { + 'type': 'image', + 'image': fn, + }, + ], + }, + { + 'role': 'ai', + 'text': 'Nothing!', + }, + { + 'role': 'tool_call', + 'tools': [ + { + 'call_id': 'abc', + 'call_type': 'function', + 'function': { + 'name': 'foo', + 'arguments': '{"a": 7}', + }, + }, + { + 'call_id': 'xyz', + 'call_type': 'function', + 'function': { + 'name': 'bar', + 'arguments': '{"b": 42}', + }, + }, + ], + }, + { + 'role': 'tool_res', + 'tools': [ + { + 'call_id': 'xyz', + 'type': 'any', + 'result': True, + }, + { + 'call_id': 'abc', + 'type': 'any', + 'result': 'it worked', + }, + { + 'call_id': 'dne', + 'type': 'any', + 'result': 'ghost result', + }, + ], + }, + ]) == [ + { + 'role': 'system', + 'content': 'You are a robot.', + }, + { + 'role': 'user', + 'content': 'What is in this image?', + 'images': ['MTIzNA=='], + }, + { + 'role': 'assistant', + 'content': 'Nothing!', + }, + { + 'role': 'assistant', + 'content': '', + 'tool_calls': [ + { + 'function': { + 'name': 'foo', + 'arguments': { + 'a': 7, + }, + }, + }, + { + 'function': { + 'name': 'bar', + 'arguments': { + 'b': 42, + }, + }, + }, + ], + }, + { + 'role': 'tool', + 'content': 'True', + 'name': 'bar', + }, + { + 'role': 'tool', + 'content': 'it worked', + 'name': 'foo', + }, + { + 'role': 'tool', + 'content': 'ghost result', + 'name': 'unknown', + }, + ] + + +def test_set_cost_raw(): + 
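+    # _set_cost_raw() must attach cost/raw to a *copy* of the message; the
+    # assertion at the end checks the original dict is left unmodified.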
orig_message: Message = { + 'role': 'ai', + 'text': 'Hi', + } + assert _set_cost_raw(orig_message, _sample_events_streaming_text) == { + 'role': 'ai', + 'text': 'Hi', + 'cost': { + 'input_tokens': 20, + 'output_tokens': 10, + 'total_tokens': 30, + }, + 'raw': _sample_events_streaming_text, + } + assert orig_message == { # <-- orig message should be unmodified! + 'role': 'ai', + 'text': 'Hi', + } + + +@pytest.mark.asyncio +async def test_process_stream_streaming_text(): + events = [] + async def callback(event: EventPayload) -> None: + events.append(event) + messages = await _process_stream( + fake_async(_sample_events_streaming_text), + callback, + ) + assert events == [ + ('ai', 'text_event', 'Hello'), + ('ai', 'text_event', '!'), + ('ai', 'text_event', ' How'), + ('ai', 'text_event', ' can'), + ('ai', 'text_event', ' I'), + ('ai', 'text_event', ' assist'), + ('ai', 'text_event', ' you'), + ('ai', 'text_event', ' today'), + ('ai', 'text_event', '?'), + ] + assert messages == [ + { + 'role': 'ai', + 'text': 'Hello! How can I assist you today?', + 'cost': { + 'input_tokens': 20, + 'output_tokens': 10, + 'total_tokens': 30, + }, + 'raw': _sample_events_streaming_text, + }, + ] + + +@pytest.mark.asyncio +async def test_process_stream_tool_call(): + events = [] + async def callback(event: EventPayload) -> None: + events.append(event) + messages = await _process_stream( + fake_async(_sample_events_tool_call), + callback, + ) + tool_call: ToolCall = { + 'call_id': 'call_0', + 'call_type': 'function', + 'function': { + 'name': 'evaluate_math_expression', + 'arguments': '{"expression": "2.5 * sin(4.5)"}', + }, + } + assert events == [ + ('tool_call', 'text_event', 'evaluate_math_expression({"expression": "2.5 * sin(4.5)"})\n'), + ('tool_call', 'tool_call_event', tool_call), + ] + assert messages == [ + { + 'role': 'tool_call', + 'tools': [tool_call], + 'cost': { + 'input_tokens': 100, + 'output_tokens': 29, + 'total_tokens': 129, + }, + 'raw': _sample_events_tool_call, + }, + ] + + +@pytest.mark.asyncio +async def test_process_stream_text_and_tool_call(): + _sample_events = [ + *_sample_events_streaming_text, + *_sample_events_tool_call, + ] + events = [] + async def callback(event: EventPayload) -> None: + events.append(event) + messages = await _process_stream( + fake_async(_sample_events), + callback, + ) + tool_call: ToolCall = { + 'call_id': 'call_0', + 'call_type': 'function', + 'function': { + 'name': 'evaluate_math_expression', + 'arguments': '{"expression": "2.5 * sin(4.5)"}', + }, + } + assert events == [ + ('ai', 'text_event', 'Hello'), + ('ai', 'text_event', '!'), + ('ai', 'text_event', ' How'), + ('ai', 'text_event', ' can'), + ('ai', 'text_event', ' I'), + ('ai', 'text_event', ' assist'), + ('ai', 'text_event', ' you'), + ('ai', 'text_event', ' today'), + ('ai', 'text_event', '?'), + ('tool_call', 'text_event', 'evaluate_math_expression({"expression": "2.5 * sin(4.5)"})\n'), + ('tool_call', 'tool_call_event', tool_call), + ] + assert messages == [ + { + 'role': 'ai', + 'text': 'Hello! 
How can I assist you today?', + }, + { + 'role': 'tool_call', + 'tools': [tool_call], + 'cost': { + 'input_tokens': 20 + 100, + 'output_tokens': 10 + 29, + 'total_tokens': 30 + 129, + }, + 'raw': _sample_events, + }, + ] + + +@pytest.mark.asyncio +async def test_process_stream_structured_output(): + events = [] + async def callback(event: EventPayload) -> None: + events.append(event) + messages = await _process_stream( + fake_async(_sample_events_structued_output), + _wrap_event_callback_convert_ai_text_to_tool_call_text(callback), + ) + assert events == [ + ('tool_call', 'text_event', '{"base":2,"exponent":101}'), + ] + assert messages == [ + { + 'role': 'ai', + 'text': '{"base":2,"exponent":101}', + 'cost': { + 'input_tokens': 28, + 'output_tokens': 15, + 'total_tokens': 43, + }, + 'raw': _sample_events_structued_output, + }, + ] + + +class MyTypedDict(TypedDict): + """My special type which is a TypedDict""" + a: str + b: int + c: Literal['one', 'two', 'three'] + +class MyPydanticType(BaseModel): + """My special type which is a Pydantic thing""" + a: str + b: int + c: Literal['one', 'two', 'three'] + d: MyTypedDict + +def test_get_ollama_format_for_structured_output(): + assert _get_ollama_format_for_structured_output(MyTypedDict) == { + 'type': 'object', + 'title': 'MyTypedDict', + 'description': 'My special type which is a TypedDict', + 'properties': { + 'a': { + 'type': 'string', + 'title': 'A', + }, + 'b': { + 'type': 'integer', + 'title': 'B', + }, + 'c': { + 'type': 'string', + 'title': 'C', + 'enum': ['one', 'two', 'three'], + }, + }, + 'required': ['a', 'b', 'c'], + 'additionalProperties': False, + } + assert _get_ollama_format_for_structured_output(MyPydanticType) == { + 'type': 'object', + 'title': 'MyPydanticType', + 'description': 'My special type which is a Pydantic thing', + 'properties': { + 'a': { + 'type': 'string', + 'title': 'A', + }, + 'b': { + 'type': 'integer', + 'title': 'B', + }, + 'c': { + 'type': 'string', + 'title': 'C', + 'enum': ['one', 'two', 'three'], + }, + 'd': { + '$ref': '#/$defs/MyTypedDict', + }, + }, + 'required': ['a', 'b', 'c', 'd'], + 'additionalProperties': False, + '$defs': { + 'MyTypedDict': { + 'type': 'object', + 'title': 'MyTypedDict', + 'description': 'My special type which is a TypedDict', + 'properties': { + 'a': { + 'type': 'string', + 'title': 'A', + }, + 'b': { + 'type': 'integer', + 'title': 'B', + }, + 'c': { + 'type': 'string', + 'title': 'C', + 'enum': ['one', 'two', 'three'], + }, + }, + 'required': ['a', 'b', 'c'], + 'additionalProperties': False, + }, + }, + } + + +@pytest.mark.asyncio +async def test_LasagnaOllama_run_once_invalid_arguments(): + async def callback(event: EventPayload) -> None: + assert event + + with pytest.raises(ValueError, match=r"Oops! You cannot do both tool-use and structured output at the same time!"): + await LasagnaOllama(model='mistral-small')._run_once( + event_callback = callback, + messages = [], + tools_spec = _convert_to_ollama_tools([sample_tool]), + force_tool = True, + format = {'type': 'object'}, + ) + + with pytest.raises(ValueError, match=r"Oops! Ollama currently does not support \*optional\* tool use."): + await LasagnaOllama(model='mistral-small')._run_once( + event_callback = callback, + messages = [], + tools_spec = _convert_to_ollama_tools([sample_tool]), + force_tool = False, + format = None, + ) + + with pytest.raises(ValueError, match=r"Oops! 
You cannot force tools that are not specified!"): + await LasagnaOllama(model='mistral-small')._run_once( + event_callback = callback, + messages = [], + tools_spec = None, + force_tool = True, + format = None, + ) From a5873d89fefb2a7aabfb53c14c20c00b86e13c7d Mon Sep 17 00:00:00 2001 From: Ryan Henning Date: Tue, 28 Jan 2025 17:26:11 -0600 Subject: [PATCH 3/4] Better --- src/lasagna/lasagna_ollama.py | 32 +++++++++++++++++--------------- 1 file changed, 17 insertions(+), 15 deletions(-) diff --git a/src/lasagna/lasagna_ollama.py b/src/lasagna/lasagna_ollama.py index de9f051..7d3ccdd 100644 --- a/src/lasagna/lasagna_ollama.py +++ b/src/lasagna/lasagna_ollama.py @@ -152,12 +152,12 @@ async def _event_stream(url: str, payload: Dict) -> AsyncIterator[Dict]: async with client.stream('POST', url, json=payload) as r: if r.status_code != 200: error_text = json.loads(await r.aread())['error'] - raise RuntimeError(f'Ollama error: {error_text}') # TODO TEST ME + raise RuntimeError(f'Ollama error: {error_text}') async for line in r.aiter_lines(): rec = json.loads(line) if rec.get('error'): error_text = rec['error'] - raise RuntimeError(f'Ollama error: {error_text}') # TODO TEST ME + raise RuntimeError(f'Ollama error: {error_text}') yield rec @@ -424,23 +424,25 @@ async def extract( arguments = new_message['text'] or '{}' + tool_call: ToolCall = { + 'call_id': 'extraction_tool_call', + 'call_type': 'function', + 'function': { + 'name': get_name(extraction_type), + 'arguments': arguments, + }, + } + tool_message: MessageToolCall = { 'role': 'tool_call', - 'tools': [ - { - 'call_id': 'extraction_tool_call', - 'call_type': 'function', - 'function': { - 'name': '...', # TODO - 'arguments': arguments, - }, - }, - ], - #'cost': new_message.get('cost', None), # TODO - #'raw': new_message.get('raw', None), # TODO + 'tools': [tool_call], } + if 'cost' in new_message: + tool_message['cost'] = new_message['cost'] + if 'raw' in new_message: + tool_message['raw'] = new_message['raw'] - await event_callback(('tool_call', 'tool_call_event', tool_message['tools'][0])) + await event_callback(('tool_call', 'tool_call_event', tool_call)) parsed = json.loads(arguments) result = build_and_validate(extraction_type, parsed) From e48cc98ebf4f0eb13c0a4658cb48447879003021 Mon Sep 17 00:00:00 2001 From: Ryan Henning Date: Tue, 28 Jan 2025 17:36:29 -0600 Subject: [PATCH 4/4] Final touches --- examples/demo_structured_output.py | 2 +- src/lasagna/__init__.py | 2 +- tests/test_known_providers.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/examples/demo_structured_output.py b/examples/demo_structured_output.py index c13f8f7..3e57847 100644 --- a/examples/demo_structured_output.py +++ b/examples/demo_structured_output.py @@ -15,7 +15,7 @@ from dotenv import load_dotenv; load_dotenv() -MODEL_BINDER = known_models.BIND_OPENAI_gpt_4o_mini() +MODEL_BINDER = known_models.BIND_OLLAMA_mistral_small() PROMPT = """ diff --git a/src/lasagna/__init__.py b/src/lasagna/__init__.py index 0a4ad5f..32aed44 100644 --- a/src/lasagna/__init__.py +++ b/src/lasagna/__init__.py @@ -43,4 +43,4 @@ 'build_static_output_agent', ] -__version__ = "0.10.4" +__version__ = "0.11.0" diff --git a/tests/test_known_providers.py b/tests/test_known_providers.py index c1ca0e1..49c317c 100644 --- a/tests/test_known_providers.py +++ b/tests/test_known_providers.py @@ -10,5 +10,5 @@ def test_attempt_load_known_providers(): AGENTS.clear() PROVIDERS.clear() attempt_load_all_known_providers() - assert len(PROVIDERS) == 3 + assert len(PROVIDERS) == 4 
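+    # Four known providers now: openai, anthropic, ollama, and nvidia.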
assert len(AGENTS) == 0
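
---

For reviewers, a minimal sketch of driving the new provider end to end. It assumes a local Ollama server at the default `http://127.0.0.1:11434` with `llama3.2` already pulled; it otherwise uses only the APIs added in the patches above, and the prompt text is illustrative.

```python
# Sketch only: assumes a running local Ollama server with `llama3.2` pulled.
import asyncio
from typing import List

from lasagna.lasagna_ollama import LasagnaOllama
from lasagna.types import EventPayload, Message

async def main() -> None:
    # base_url defaults to http://127.0.0.1:11434 (or $OLLAMA_BASE_URL).
    model = LasagnaOllama(model='llama3.2')

    async def callback(event: EventPayload) -> None:
        # Print streamed text deltas; ignore transaction/tool events.
        if event[1] == 'text_event':
            print(event[2], end='', flush=True)

    messages: List[Message] = [
        {'role': 'system', 'text': 'You are a helpful assistant.'},
        {'role': 'human', 'text': 'In one sentence, what is lasagna (the food)?'},
    ]
    new_messages = await model.run(
        event_callback = callback,
        messages = messages,
        tools = [],   # no tools -> the streaming-text path
    )
    print()  # final newline after the stream
    assert new_messages[-1]['role'] == 'ai'

asyncio.run(main())
```

Tool use works the same way, except that Ollama disables streaming in that case and requires `force_tool=True` whenever tools are passed (see `_run_once` above).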