From 99dd207787b8f3e60101647fa52a94e82b73ea76 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Jan=20Pokorn=C3=BD?=
Date: Thu, 2 Jan 2025 17:13:34 +0100
Subject: [PATCH] fix: switch to Pydantic JSON schema (#13)

---
 executor/Dockerfile                           |   2 +
 executor/requirements-skip.txt                |   1 +
 pyproject.toml                                |   2 +-
 src/code_interpreter/health_check.py          |   2 +-
 .../services/custom_tool_executor.py          | 161 +++++++++++-------
 .../code_interpreter_servicer.py              |   2 +-
 src/code_interpreter/services/http_server.py  |   2 +-
 test/e2e/test_grpc.py                         |  28 ++-
 test/e2e/test_http.py                         |  30 +++-
 9 files changed, 155 insertions(+), 75 deletions(-)

diff --git a/executor/Dockerfile b/executor/Dockerfile
index f424247..c9fffaa 100644
--- a/executor/Dockerfile
+++ b/executor/Dockerfile
@@ -73,6 +73,8 @@ RUN apk add --no-cache --repository=https://dl-cdn.alpinelinux.org/alpine/edge/t
     py3-pillow-pyc \
     py3-pip \
     py3-pip-pyc \
+    py3-pydantic \
+    py3-pydantic-pyc \
     py3-pypandoc \
     py3-pypandoc-pyc \
     py3-scipy \
diff --git a/executor/requirements-skip.txt b/executor/requirements-skip.txt
index 5a6d099..3fb2f1c 100644
--- a/executor/requirements-skip.txt
+++ b/executor/requirements-skip.txt
@@ -9,6 +9,7 @@ pandas
 pdf2image
 pikepdf
 pillow
+pydantic
 pypandoc
 scipy
 sympy
diff --git a/pyproject.toml b/pyproject.toml
index 730374f..b707f8e 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "bee-code-interpreter"
-version = "0.0.27"
+version = "0.0.28"
 license = "Apache-2.0"
 description = "A gRPC service intended as a backend for an LLM that can run arbitrary pieces of Python code."
 authors = [ "Jan Pokorný ", "Tomáš Dvořák " ]
diff --git a/src/code_interpreter/health_check.py b/src/code_interpreter/health_check.py
index 4458278..fe8dfa8 100644
--- a/src/code_interpreter/health_check.py
+++ b/src/code_interpreter/health_check.py
@@ -45,7 +45,7 @@ def health_check():
     assert (
         CodeInterpreterServiceStub(channel)
         .Execute(
-            ExecuteRequest(executor_id="health-check", source_code="print(21 * 2)"),
+            ExecuteRequest(source_code="print(21 * 2)"),
             timeout=9999,  # no need to timeout here -- k8s health checks have their own timeouts
         )
         .stdout
diff --git a/src/code_interpreter/services/custom_tool_executor.py b/src/code_interpreter/services/custom_tool_executor.py
index 4440e74..272ed40 100644
--- a/src/code_interpreter/services/custom_tool_executor.py
+++ b/src/code_interpreter/services/custom_tool_executor.py
@@ -14,14 +14,13 @@
 import ast
 from dataclasses import dataclass
-import json
 import typing
 import inspect
 import re
+import json
 import textwrap
-
-from pydantic import validate_call
-
+import pydantic
+import pydantic.json_schema
 from code_interpreter.services.kubernetes_code_executor import KubernetesCodeExecutor
 
 
@@ -29,7 +28,7 @@
 class CustomTool:
     name: str
     description: str
-    input_schema: dict
+    input_schema: dict[str, typing.Any]
 
 
 @dataclass
@@ -53,8 +52,8 @@ def parse(self, tool_source_code: str) -> CustomTool:
         The source code must contain a single function definition, optionally preceded by imports.
         The function must not have positional-only arguments, *args or **kwargs.
         The function arguments must have type annotations.
         The docstring must follow the ReST format -- :param something: and :return: directives are supported.
-        Supported types for input arguments: int, float, str, bool, typing.Any, list[...], dict[str, ...], typing.Tuple[...], typing.Optional[...], typing.Union[...], where ... is any of the supported types.
-        Supported types for return value: anything that can be JSON-serialized.
+        Function arguments are converted to a JSON Schema by Pydantic, so anything that Pydantic can (de)serialize can be used.
+        However, the imports usable in type annotations are currently limited to `typing`, `pathlib` and `datetime` for safety reasons.
         """
         try:
             *imports, function_def = ast.parse(textwrap.dedent(tool_source_code)).body
@@ -99,12 +98,14 @@ def parse(self, tool_source_code: str) -> CustomTool:
             ast.get_docstring(function_def) or ""
         )
 
+        namespace = _build_namespace(imports)
+
         json_schema = {
             "$schema": "http://json-schema.org/draft-07/schema#",
             "type": "object",
             "title": function_def.name,
             "properties": {
-                arg.arg: _type_to_json_schema(arg.annotation)
+                arg.arg: _type_to_json_schema(arg.annotation, namespace)
                 | (
                     {"description": param_description}
                     if (param_description := param_descriptions.get(arg.arg))
@@ -153,11 +154,11 @@ def parse(self, tool_source_code: str) -> CustomTool:
             input_schema=json_schema,
         )
 
-    @validate_call
+    @pydantic.validate_call
     async def execute(
         self,
         tool_source_code: str,
-        tool_input: dict[str, typing.Any],
+        tool_input_json: str,
     ) -> typing.Any:
         """
         Execute the given custom tool with the given input.
 
@@ -166,24 +167,22 @@ async def execute(
         The input is expected to be valid according to the input schema produced by the parse method.
         """
+        clean_tool_source_code = textwrap.dedent(tool_source_code)
+        *imports, function_def = ast.parse(clean_tool_source_code).body
+
         result = await self.code_executor.execute(
             source_code=f"""
+# Import all tool dependencies here -- to aid the dependency detection
+{"\n".join(ast.unparse(node) for node in imports if isinstance(node, (ast.Import, ast.ImportFrom)))}
+
+import pydantic
 import contextlib
 import json
 
-# Import all tool dependencies here -- to aid the dependency detection
-{
-    "\n".join(
-        ast.unparse(node)
-        for node in ast.parse(textwrap.dedent(tool_source_code)).body
-        if isinstance(node, (ast.Import, ast.ImportFrom))
-    )
-}
-
 with contextlib.redirect_stdout(None):
     inner_globals = {{}}
-    exec(compile({repr(textwrap.dedent(tool_source_code))}, "", "exec"), inner_globals)
-    result = next(x for x in inner_globals.values() if getattr(x, '__module__', ...) is None)(**{repr(tool_input)})
+    exec(compile({repr(clean_tool_source_code)}, "", "exec"), inner_globals)
+    result = pydantic.TypeAdapter(inner_globals[{repr(function_def.name)}]).validate_json({repr(tool_input_json)})
 
 print(json.dumps(result))
 """,
@@ -195,50 +194,6 @@ async def execute(
         return json.loads(result.stdout)
 
 
-def _type_to_json_schema(type_node: ast.AST) -> dict:
-    if isinstance(type_node, ast.Subscript):
-        type_node_name = ast.unparse(type_node.value)
-        if type_node_name == "list":
-            return {"type": "array", "items": _type_to_json_schema(type_node.slice)}
-        elif type_node_name == "dict" and isinstance(type_node.slice, ast.Tuple):
-            key_type_node, value_type_node = type_node.slice.elts
-            if ast.unparse(key_type_node) != "str":
-                raise ValueError(f"Unsupported type: {type_node}")
-            return {
-                "type": "object",
-                "additionalProperties": _type_to_json_schema(value_type_node),
-            }
-        elif type_node_name == "Optional" or type_node_name == "typing.Optional":
-            return {"anyOf": [{"type": "null"}, _type_to_json_schema(type_node.slice)]}
-        elif (
-            type_node_name == "Union" or type_node_name == "typing.Union"
-        ) and isinstance(type_node.slice, ast.Tuple):
-            return {"anyOf": [_type_to_json_schema(el) for el in type_node.slice.elts]}
-        elif (
-            type_node_name == "Tuple" or type_node_name == "typing.Tuple"
-        ) and isinstance(type_node.slice, ast.Tuple):
-            return {
-                "type": "array",
-                "minItems": len(type_node.slice.elts),
-                "items": [_type_to_json_schema(el) for el in type_node.slice.elts],
-                "additionalItems": False,
-            }
-
-    type_node_name = ast.unparse(type_node)
-    if type_node_name == "int":
-        return {"type": "integer"}
-    elif type_node_name == "float":
-        return {"type": "number"}
-    elif type_node_name == "str":
-        return {"type": "string"}
-    elif type_node_name == "bool":
-        return {"type": "boolean"}
-    elif type_node_name == "Any" or type_node_name == "typing.Any":
-        return {"type": "array"}
-    else:
-        raise ValueError(f"Unsupported type: {type_node_name}")
-
-
 def _parse_docstring(docstring: str) -> typing.Tuple[str, str, dict[str, str]]:
     """
     Parse a docstring in the ReST format and return the function description, return description and a dictionary of parameter descriptions.
@@ -262,3 +217,79 @@ def _parse_docstring(docstring: str) -> typing.Tuple[str, str, dict[str, str]]:
         elif match := re.match(r"return: ((?:.|\n)+)", chunk, flags=re.MULTILINE):
             return_description = match.group(1)
     return fn_description, return_description, param_descriptions
+
+
+def _build_namespace(
+    imports: list[ast.AST],
+    allowed_modules: set[str] = {"typing", "pathlib", "datetime"},
+) -> dict[str, typing.Any]:
+    namespace = {
+        "str": str,
+        "int": int,
+        "float": float,
+        "bool": bool,
+        "list": list,
+        "dict": dict,
+        "set": set,
+        "tuple": tuple,
+    }
+
+    for node in imports:
+        if isinstance(node, ast.Import):
+            for name in node.names:
+                if name.name in allowed_modules:
+                    namespace[name.asname or name.name] = __import__(name.name)
+        elif isinstance(node, ast.ImportFrom):
+            if node.module in allowed_modules:
+                module = __import__(node.module, fromlist=[n.name for n in node.names])
+                for name in node.names:
+                    namespace[name.asname or name.name] = getattr(module, name.name)
+
+    return namespace
+
+
+def _type_to_json_schema(type_ast: ast.AST, namespace: dict) -> dict:
+    type_str = ast.unparse(type_ast)
+    if not _is_safe_type_ast(type_ast):
+        raise CustomToolParseError([f"Invalid type annotation `{type_str}`"])
+    try:
+        return pydantic.TypeAdapter(eval(type_str, namespace)).json_schema(
+            schema_generator=_GenerateJsonSchema
+        )
+    except Exception as e:
+        raise CustomToolParseError([f"Error when parsing type `{type_str}`: {e}"])
+
+
+class _GenerateJsonSchema(pydantic.json_schema.GenerateJsonSchema):
+    schema_dialect = "http://json-schema.org/draft-07/schema#"
+
+    def tuple_schema(self, schema):
+        # Use draft-07 syntax for tuples
+        schema = super().tuple_schema(schema)
+        if "prefixItems" in schema:
+            schema["items"] = schema.pop("prefixItems")
+            schema.pop("maxItems")
+            schema["additionalItems"] = False
+        return schema
+
+
+def _is_safe_type_ast(node: ast.AST) -> bool:
+    match node:
+        case ast.Name():
+            return True
+        case ast.Attribute():
+            return _is_safe_type_ast(node.value)
+        case ast.Subscript():
+            return _is_safe_type_ast(node.value) and _is_safe_type_ast(node.slice)
+        case ast.Tuple() | ast.List():
+            return all(_is_safe_type_ast(elt) for elt in node.elts)
+        case ast.Constant():
+            return isinstance(node.value, (str, int, float, bool, type(None)))
+        case ast.BinOp():
+            return (
+                isinstance(node.op, ast.BitOr)
+                and _is_safe_type_ast(node.left)
+                and _is_safe_type_ast(node.right)
+            )
+        case _:
+            return False
diff --git a/src/code_interpreter/services/grpc_servicers/code_interpreter_servicer.py b/src/code_interpreter/services/grpc_servicers/code_interpreter_servicer.py
index a55399d..a07ae1c 100644
--- a/src/code_interpreter/services/grpc_servicers/code_interpreter_servicer.py
+++ b/src/code_interpreter/services/grpc_servicers/code_interpreter_servicer.py
@@ -120,7 +120,7 @@ async def ExecuteCustomTool(
 
         try:
             result = await self.custom_tool_executor.execute(
-                tool_input=json.loads(request.tool_input_json),
+                tool_input_json=request.tool_input_json,
                 tool_source_code=request.tool_source_code,
             )
         except CustomToolExecuteError as e:
diff --git a/src/code_interpreter/services/http_server.py b/src/code_interpreter/services/http_server.py
index c2254e1..d9732df 100644
--- a/src/code_interpreter/services/http_server.py
+++ b/src/code_interpreter/services/http_server.py
@@ -141,7 +141,7 @@ async def execute_custom_tool(
         "Executing custom tool with source code %s", request.tool_source_code
     )
     result = await custom_tool_executor.execute(
-        tool_input=json.loads(request.tool_input_json),
+        tool_input_json=request.tool_input_json,
         tool_source_code=request.tool_source_code,
     )
     logger.info("Executed custom tool with result %s", result)
diff --git a/test/e2e/test_grpc.py b/test/e2e/test_grpc.py
index 7da98e0..97aeb6b 100755
--- a/test/e2e/test_grpc.py
+++ b/test/e2e/test_grpc.py
@@ -110,7 +110,12 @@ def test_parse_custom_tool_success(grpc_stub: CodeInterpreterServiceStub):
     response: ParseCustomToolResponse = grpc_stub.ParseCustomTool(
         ParseCustomToolRequest(
             tool_source_code='''
-def my_tool(a: int, b: typing.Tuple[Optional[str], str] = ("hello", "world"), *, c: typing.Union[list[str], dict[str, typing.Optional[float]]]) -> int:
+import typing
+import typing as banana
+from typing import Optional
+from typing import Union as Onion
+
+def my_tool(a: int, b: typing.Tuple[Optional[str], str] = ("hello", "world"), *, c: Onion[list[str], dict[str, banana.Optional[float]]]) -> int:
     """
     This tool is really really cool.
     Very toolish experience:
@@ -149,7 +154,7 @@ def my_tool(a: int, b: typing.Tuple[Optional[str], str] = ("hello", "world"), *,
                 "type": "array",
                 "minItems": 2,
                 "items": [
-                    {"anyOf": [{"type": "null"}, {"type": "string"}]},
+                    {"anyOf": [{"type": "string"}, {"type": "null"}]},
                     {"type": "string"},
                 ],
                 "additionalItems": False,
@@ -161,7 +166,7 @@ def my_tool(a: int, b: typing.Tuple[Optional[str], str] = ("hello", "world"), *,
                     {
                         "type": "object",
                         "additionalProperties": {
-                            "anyOf": [{"type": "null"}, {"type": "number"}]
+                            "anyOf": [{"type": "number"}, {"type": "null"}]
                         },
                     },
                 ],
@@ -249,6 +254,23 @@ def test_execute_custom_tool_success(grpc_stub: CodeInterpreterServiceStub):
     assert result.success.tool_output_json == "3"
 
 
+def test_execute_custom_tool_advanced_success(grpc_stub: CodeInterpreterServiceStub):
+    result = grpc_stub.ExecuteCustomTool(
+        ExecuteCustomToolRequest(
+            tool_source_code="""
+import datetime
+
+def date_tool(a: datetime.datetime) -> str:
+    return f"The year is {a.year}"
+""",
+            tool_input_json='{"a": "2000-01-01T00:00:00"}',
+        )
+    )
+
+    assert result.WhichOneof("response") == "success"
+    assert result.success.tool_output_json == "\"The year is 2000\""
+
+
 def test_execute_custom_tool_error(grpc_stub: CodeInterpreterServiceStub):
     result = grpc_stub.ExecuteCustomTool(
         ExecuteCustomToolRequest(
diff --git a/test/e2e/test_http.py b/test/e2e/test_http.py
index f673b33..f8e0bdb 100644
--- a/test/e2e/test_http.py
+++ b/test/e2e/test_http.py
@@ -90,7 +90,12 @@ def test_parse_custom_tool_success(http_client: httpx.Client):
         "/v1/parse-custom-tool",
         json={
             "tool_source_code": '''
-def my_tool(a: int, b: typing.Tuple[Optional[str], str] = ("hello", "world"), *, c: typing.Union[list[str], dict[str, typing.Optional[float]]]) -> int:
+import typing
+import typing as banana
+from typing import Optional
+from typing import Union as Onion
+
+def my_tool(a: int, b: typing.Tuple[Optional[str], str] = ("hello", "world"), *, c: Onion[list[str], dict[str, banana.Optional[float]]]) -> int:
     """
     This tool is really really cool.
     Very toolish experience:
@@ -128,7 +133,7 @@ def my_tool(a: int, b: typing.Tuple[Optional[str], str] = ("hello", "world"), *,
                 "type": "array",
                 "minItems": 2,
                 "items": [
-                    {"anyOf": [{"type": "null"}, {"type": "string"}]},
+                    {"anyOf": [{"type": "string"}, {"type": "null"}]},
                     {"type": "string"},
                 ],
                 "additionalItems": False,
@@ -140,7 +145,7 @@ def my_tool(a: int, b: typing.Tuple[Optional[str], str] = ("hello", "world"), *,
                     {
                         "type": "object",
                         "additionalProperties": {
-                            "anyOf": [{"type": "null"}, {"type": "number"}]
+                            "anyOf": [{"type": "number"}, {"type": "null"}]
                         },
                     },
                 ],
@@ -215,6 +220,25 @@ def test_execute_custom_tool_success(http_client: httpx.Client):
     assert json.loads(response_json["tool_output_json"]) == 3
 
 
+def test_execute_custom_tool_advanced_success(http_client: httpx.Client):
+    response = http_client.post(
+        "/v1/execute-custom-tool",
+        json={
+            "tool_source_code": """
+import datetime
+
+def date_tool(a: datetime.datetime) -> str:
+    return f"The year is {a.year}"
+""",
+            "tool_input_json": '{"a": "2000-01-01T00:00:00"}',
+        },
+    )
+
+    assert response.status_code == 200
+    response_json = response.json()
+    assert json.loads(response_json["tool_output_json"]) == "The year is 2000"
+
+
 def test_parse_custom_tool_error(http_client: httpx.Client):
     response = http_client.post(
         "/v1/parse-custom-tool",
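
Note on the flipped `anyOf` ordering asserted in the tests: Pydantic lists the non-null variant of an `Optional` first, and fixed-size tuples come out in draft 2020-12 shape (`prefixItems`) until `_GenerateJsonSchema` rewrites them. A minimal sketch of that rewrite in isolation, assuming pydantic v2 is installed (the `Draft07Tuples` name is illustrative, not part of this patch):

import typing

import pydantic
import pydantic.json_schema


class Draft07Tuples(pydantic.json_schema.GenerateJsonSchema):
    schema_dialect = "http://json-schema.org/draft-07/schema#"

    def tuple_schema(self, schema):
        # Pydantic emits draft 2020-12 "prefixItems" for fixed-size tuples;
        # rewrite it into the draft-07 "items" + "additionalItems" form.
        schema = super().tuple_schema(schema)
        if "prefixItems" in schema:
            schema["items"] = schema.pop("prefixItems")
            schema.pop("maxItems")  # a fixed-size tuple sets minItems == maxItems
            schema["additionalItems"] = False
        return schema


adapter = pydantic.TypeAdapter(typing.Tuple[typing.Optional[str], str])
print(adapter.json_schema(schema_generator=Draft07Tuples))
# {'type': 'array', 'minItems': 2,
#  'items': [{'anyOf': [{'type': 'string'}, {'type': 'null'}]}, {'type': 'string'}],
#  'additionalItems': False}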
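
The parsing side boils down to three steps: allow-list the annotation's AST, evaluate it in a namespace built only from builtins plus the whitelisted imports, and hand the resulting type to `pydantic.TypeAdapter`. A condensed sketch of that flow under the same assumptions (`safe` and `schema_for` are illustrative names; the patched code raises `CustomToolParseError` instead of asserting):

import ast
import typing

import pydantic


def schema_for(annotation: str, namespace: dict) -> dict:
    def safe(node: ast.AST) -> bool:
        # Only names, attributes, subscripts, tuples/lists, simple constants
        # and `X | Y` unions pass -- notably not ast.Call, so an annotation
        # like __import__("os").system("...") never reaches eval().
        match node:
            case ast.Name():
                return True
            case ast.Attribute():
                return safe(node.value)
            case ast.Subscript():
                return safe(node.value) and safe(node.slice)
            case ast.Tuple() | ast.List():
                return all(safe(e) for e in node.elts)
            case ast.Constant():
                return isinstance(node.value, (str, int, float, bool, type(None)))
            case ast.BinOp():
                return isinstance(node.op, ast.BitOr) and safe(node.left) and safe(node.right)
            case _:
                return False

    assert safe(ast.parse(annotation, mode="eval").body), f"unsafe annotation: {annotation}"
    return pydantic.TypeAdapter(eval(annotation, namespace)).json_schema()


# `import typing as banana` in the tool source lands in the namespace,
# which is why the aliased annotations in the tests above resolve.
print(schema_for("dict[str, banana.Optional[float]]", {"str": str, "float": float, "banana": typing}))
# {'type': 'object', 'additionalProperties': {'anyOf': [{'type': 'number'}, {'type': 'null'}]}}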
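
End to end, the contract change is that `tool_input_json` now travels to the sandbox as a raw JSON string and is only deserialized there, so Pydantic can coerce values against the function signature (e.g. an ISO 8601 string into `datetime.datetime`). A usage sketch against the HTTP API mirroring the new test; the host and port are illustrative, not taken from this patch:

import httpx

response = httpx.post(
    "http://localhost:50081/v1/execute-custom-tool",
    json={
        "tool_source_code": """
import datetime

def date_tool(a: datetime.datetime) -> str:
    return f"The year is {a.year}"
""",
        "tool_input_json": '{"a": "2000-01-01T00:00:00"}',
    },
)
response.raise_for_status()
print(response.json()["tool_output_json"])  # "The year is 2000" (JSON-encoded string)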