diff --git a/src/tools/custom.test.ts b/src/tools/custom.test.ts deleted file mode 100644 index 570430e4..00000000 --- a/src/tools/custom.test.ts +++ /dev/null @@ -1,177 +0,0 @@ -/** - * Copyright 2025 IBM Corp. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -import { describe, it, expect, vi } from "vitest"; -import { CustomTool } from "./custom.js"; -import { StringToolOutput } from "./base.js"; - -const mocks = vi.hoisted(() => { - return { - parseCustomTool: vi.fn(), - executeCustomTool: vi.fn(), - }; -}); - -vi.mock("@connectrpc/connect", () => ({ - createPromiseClient: vi.fn().mockReturnValue({ - parseCustomTool: mocks.parseCustomTool, - executeCustomTool: mocks.executeCustomTool, - }), -})); - -describe("CustomTool", () => { - it("should instantiate correctly", async () => { - mocks.parseCustomTool.mockResolvedValue({ - response: { - case: "success", - value: { - toolName: "test", - toolDescription: "A test tool", - toolInputSchemaJson: `{ - "$schema": "http://json-schema.org/draft-07/schema#", - "type": "object", - "properties": { - "a": { "type": "integer" }, - "b": { "type": "string" } - } - }`, - }, - }, - }); - - const customTool = await CustomTool.fromSourceCode({ url: "http://localhost" }, "source code"); - - expect(customTool.name).toBe("test"); - expect(customTool.description).toBe("A test tool"); - expect(await customTool.inputSchema()).toEqual({ - $schema: "http://json-schema.org/draft-07/schema#", - type: "object", - properties: { - a: { type: "integer" }, - b: { type: "string" }, - }, - }); - }); - - it("should throw InvalidCustomToolError on parse error", async () => { - mocks.parseCustomTool.mockResolvedValue({ - response: { - case: "error", - value: { - errorMessages: ["Error parsing tool"], - }, - }, - }); - - await expect( - CustomTool.fromSourceCode({ url: "http://localhost" }, "source code"), - ).rejects.toThrow("Error parsing tool"); - }); - - it("should run the custom tool", async () => { - mocks.parseCustomTool.mockResolvedValue({ - response: { - case: "success", - value: { - toolName: "test", - toolDescription: "A test tool", - toolInputSchemaJson: `{ - "$schema": "http://json-schema.org/draft-07/schema#", - "type": "object", - "properties": { - "a": { "type": "integer" }, - "b": { "type": "string" } - } - }`, - }, - }, - }); - - const customTool = await CustomTool.fromSourceCode( - { url: "http://localhost" }, - "source code", - "executor-id", - ); - - mocks.executeCustomTool.mockResolvedValue({ - response: { - case: "success", - value: { - toolOutputJson: '{"something": "42"}', - }, - }, - }); - - const result = await customTool.run( - { - a: 42, - b: "test", - }, - { - signal: new AbortController().signal, - }, - ); - expect(result).toBeInstanceOf(StringToolOutput); - expect(result.getTextContent()).toEqual('{"something": "42"}'); - }); - - it("should throw CustomToolExecutionError on execution error", async () => { - mocks.parseCustomTool.mockResolvedValue({ - response: { - case: "success", - value: { - toolName: "test", - toolDescription: "A test tool", - toolInputSchemaJson: `{ - "$schema": "http://json-schema.org/draft-07/schema#", - "type": "object", - "properties": { - "a": { "type": "integer" }, - "b": { "type": "string" } - } - }`, - }, - }, - }); - - const customTool = await CustomTool.fromSourceCode( - { url: "http://localhost" }, - "source code", - "executor-id", - ); - - mocks.executeCustomTool.mockResolvedValue({ - response: { - case: "error", - value: { - stderr: "Error executing tool", - }, - }, - }); - - await expect( - customTool.run( - { - a: 42, - b: "test", - }, - { - signal: new AbortController().signal, - }, - ), - ).rejects.toThrow('Tool "test" has occurred an error!'); - }); -}); diff --git a/src/tools/custom.ts b/src/tools/custom.ts index 7e9b45f5..0eefb84b 100644 --- a/src/tools/custom.ts +++ b/src/tools/custom.ts @@ -22,12 +22,9 @@ import { Tool, ToolInput, } from "@/tools/base.js"; -import { createGrpcTransport } from "@connectrpc/connect-node"; -import { PromiseClient, createPromiseClient } from "@connectrpc/connect"; import { FrameworkError } from "@/errors.js"; import { z } from "zod"; import { validate } from "@/internals/helpers/general.js"; -import { CodeInterpreterService } from "bee-proto/code_interpreter/v1/code_interpreter_service_connect"; import { CodeInterpreterOptions } from "./python/python.js"; import { RunContext } from "@/context.js"; import { Emitter } from "@/emitter/emitter.js"; @@ -51,17 +48,6 @@ const toolOptionsSchema = z export type CustomToolOptions = z.output & BaseToolOptions; -function createCodeInterpreterClient(codeInterpreter: CodeInterpreterOptions) { - return createPromiseClient( - CodeInterpreterService, - createGrpcTransport({ - baseUrl: codeInterpreter.url, - httpVersion: "2", - nodeOptions: codeInterpreter.connectionOptions, - }), - ); -} - export class CustomTool extends Tool { name: string; description: string; @@ -75,19 +61,13 @@ export class CustomTool extends Tool { return this.options.inputSchema; } - protected client: PromiseClient; - static { this.register(); } - public constructor( - options: CustomToolOptions, - client?: PromiseClient, - ) { + public constructor(options: CustomToolOptions) { validate(options, toolOptionsSchema); super(options); - this.client = client || createCodeInterpreterClient(options.codeInterpreter); this.name = options.name; this.description = options.description; } @@ -97,25 +77,52 @@ export class CustomTool extends Tool { _options: Partial, run: RunContext, ) { - const { response } = await this.client.executeCustomTool( - { - executorId: this.options.executorId || "default", - toolSourceCode: this.options.sourceCode, - toolInputJson: JSON.stringify(input), - }, - { signal: run.signal }, - ); - - if (response.case === "error") { - throw new CustomToolExecuteError(response.value.stderr); + // Execute custom tool + const httpUrl = this.options.codeInterpreter.url + "/v1/execute-custom-tool"; + const response = await this.customFetch(httpUrl, input, run); + + if (!response?.ok) { + throw new CustomToolExecuteError("HTTP request failed!", [new Error(await response.text())]); } - return new StringToolOutput(response.value!.toolOutputJson); + const result = await response.json(); + + if (result?.exit_code) { + throw new CustomToolExecuteError(`Custom tool {tool_name} execution error!`); + } + + return new StringToolOutput(result.tool_output_json); + } + + private async customFetch(httpUrl: string, input: any, run: RunContext) { + try { + return await fetch(httpUrl, { + method: "POST", + headers: { + "Accept": "application/json", + "Content-Type": "application/json", + }, + body: JSON.stringify({ + tool_source_code: this.options.sourceCode, + executorId: this.options.executorId ?? "default", + tool_input_json: JSON.stringify(input), + }), + signal: run.signal, + }); + } catch (error) { + if (error.cause.name == "HTTPParserError") { + throw new CustomToolExecuteError( + "Custom tool over HTTP failed -- not using HTTP endpoint!", + [error], + ); + } else { + throw new CustomToolExecuteError("Custom tool over HTTP failed!", [error]); + } + } } loadSnapshot(snapshot: ReturnType): void { super.loadSnapshot(snapshot); - this.client = createCodeInterpreterClient(this.options.codeInterpreter); } static async fromSourceCode( @@ -123,25 +130,48 @@ export class CustomTool extends Tool { sourceCode: string, executorId?: string, ) { - const client = createCodeInterpreterClient(codeInterpreter); - const response = await client.parseCustomTool({ toolSourceCode: sourceCode }); + // Parse custom tool + let response; + const httpUrl = codeInterpreter.url + "/v1/parse-custom-tool"; + try { + response = await fetch(httpUrl, { + method: "POST", + headers: { + "Accept": "application/json", + "Content-Type": "application/json", + }, + body: JSON.stringify({ + tool_source_code: sourceCode, + executorId: executorId ?? "default", + }), + }); + } catch (error) { + if (error.cause.name == "HTTPParserError") { + throw new CustomToolCreateError("Custom tool parse error -- not using HTTP endpoint!", [ + error, + ]); + } else { + throw new CustomToolCreateError("Custom tool parse error!", [error]); + } + } - if (response.response.case === "error") { - throw new CustomToolCreateError(response.response.value.errorMessages.join("\n")); + if (!response?.ok) { + throw new CustomToolCreateError("Error parsing custom tool!", [ + new Error(await response.text()), + ]); } - const { toolName, toolDescription, toolInputSchemaJson } = response.response.value!; - - return new CustomTool( - { - codeInterpreter, - sourceCode, - name: toolName, - description: toolDescription, - inputSchema: JSON.parse(toolInputSchemaJson), - executorId, - }, - client, - ); + const result = await response.json(); + + const { tool_name, tool_description, tool_input_schema_json } = result; + + return new CustomTool({ + codeInterpreter, + sourceCode, + name: tool_name, + description: tool_description, + inputSchema: JSON.parse(tool_input_schema_json), + executorId, + }); } } diff --git a/src/tools/python/python.ts b/src/tools/python/python.ts index 85d25f4a..9cf35fb9 100644 --- a/src/tools/python/python.ts +++ b/src/tools/python/python.ts @@ -163,6 +163,7 @@ export class PythonTool extends Tool { inputFiles.map((file) => [`${prefix}${file.filename}`, file.pythonId]), ), }), + signal: run.signal, }); } catch (error) { if (error.cause.name == "HTTPParserError") { diff --git a/tests/e2e/tools/custom.test.ts b/tests/e2e/tools/custom.test.ts new file mode 100644 index 00000000..a98c6b30 --- /dev/null +++ b/tests/e2e/tools/custom.test.ts @@ -0,0 +1,111 @@ +/** + * Copyright 2024 IBM Corp. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import { describe, it, expect } from "vitest"; +import { CustomTool } from "@/tools/custom.js"; +import { StringToolOutput } from "@/tools/base.js"; + +describe("CustomTool", () => { + it("should instantiate correctly", async () => { + const customTool = await CustomTool.fromSourceCode( + { url: process.env.CODE_INTERPRETER_URL! }, + ` + def test_func(a: int, b: str=None): + """A test tool""" + print(a) + print(b) + `, + ); + + expect(customTool.name).toBe("test_func"); + expect(customTool.description).toBe("A test tool"); + expect(await customTool.inputSchema()).toEqual({ + $schema: "http://json-schema.org/draft-07/schema#", + additionalProperties: false, + properties: { + a: { type: "integer" }, + b: { type: "string" }, + }, + required: ["a"], + title: "test_func", + type: "object", + }); + }); + + it("should throw InvalidCustomToolError on parse error", async () => { + await expect( + CustomTool.fromSourceCode({ url: process.env.CODE_INTERPRETER_URL! }, "source code"), + ).rejects.toThrow("Error parsing custom tool!"); + }); + + it("should run the custom tool", async () => { + const customTool = await CustomTool.fromSourceCode( + { url: process.env.CODE_INTERPRETER_URL! }, + `import requests + +def get_riddle(a: int, b: str) -> dict[str, str] | None: + """ + Fetches a random riddle from the Riddles API. + + This function retrieves a random riddle and its answer. Testing with input params. + + Returns: + dict[str,str] | None: A dictionary containing: + - 'riddle' (str): Passed in a int + - 'answer' (str): Passed in b string + Returns None if the request fails. + """ + return { "riddle": str(a), "answer": b}`, + ); + + const result = await customTool.run( + { + a: 42, + b: "something", + }, + { + signal: new AbortController().signal, + }, + ); + expect(result).toBeInstanceOf(StringToolOutput); + expect(result.getTextContent()).toEqual('{"riddle": "42", "answer": "something"}'); + }); + + it("should throw CustomToolExecutionError on execution error", async () => { + const customTool = await CustomTool.fromSourceCode( + { url: process.env.CODE_INTERPRETER_URL! }, + ` +def test(a: int, b: str=None): + """A test tool""" + print("hello") + div_by_zero = 123 / 0 + return "foo" +`, + ); + + await expect( + customTool.run( + { + a: 42, + b: "test", + }, + { + signal: new AbortController().signal, + }, + ), + ).rejects.toThrow('Tool "test" has occurred an error!'); + }); +});