Skip to content

Commit

Permalink
feat!: use custom code tool over http
Browse files Browse the repository at this point in the history
* Modify the existing CustomTool over gRPC to go over HTTP.
* Add e2e tests.
* Requires CODE_INTERPRETER_URL env var to point to exposed HTTP port (50081).

BREAKING CHANGE: Requires exposed port and updated CODE_INTERPRETER_URL.

Signed-off-by: Mark Sturdevant <[email protected]>
  • Loading branch information
markstur committed Jan 8, 2025
1 parent 12703a4 commit da928fc
Show file tree
Hide file tree
Showing 4 changed files with 193 additions and 228 deletions.
177 changes: 0 additions & 177 deletions src/tools/custom.test.ts

This file was deleted.

132 changes: 81 additions & 51 deletions src/tools/custom.ts
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,9 @@ import {
Tool,
ToolInput,
} from "@/tools/base.js";
import { createGrpcTransport } from "@connectrpc/connect-node";
import { PromiseClient, createPromiseClient } from "@connectrpc/connect";
import { FrameworkError } from "@/errors.js";
import { z } from "zod";
import { validate } from "@/internals/helpers/general.js";
import { CodeInterpreterService } from "bee-proto/code_interpreter/v1/code_interpreter_service_connect";
import { CodeInterpreterOptions } from "./python/python.js";
import { RunContext } from "@/context.js";
import { Emitter } from "@/emitter/emitter.js";
Expand All @@ -51,17 +48,6 @@ const toolOptionsSchema = z

export type CustomToolOptions = z.output<typeof toolOptionsSchema> & BaseToolOptions;

function createCodeInterpreterClient(codeInterpreter: CodeInterpreterOptions) {
return createPromiseClient(
CodeInterpreterService,
createGrpcTransport({
baseUrl: codeInterpreter.url,
httpVersion: "2",
nodeOptions: codeInterpreter.connectionOptions,
}),
);
}

export class CustomTool extends Tool<StringToolOutput, CustomToolOptions> {
name: string;
description: string;
Expand All @@ -75,19 +61,13 @@ export class CustomTool extends Tool<StringToolOutput, CustomToolOptions> {
return this.options.inputSchema;
}

protected client: PromiseClient<typeof CodeInterpreterService>;

static {
this.register();
}

public constructor(
options: CustomToolOptions,
client?: PromiseClient<typeof CodeInterpreterService>,
) {
public constructor(options: CustomToolOptions) {
validate(options, toolOptionsSchema);
super(options);
this.client = client || createCodeInterpreterClient(options.codeInterpreter);
this.name = options.name;
this.description = options.description;
}
Expand All @@ -97,51 +77,101 @@ export class CustomTool extends Tool<StringToolOutput, CustomToolOptions> {
_options: Partial<BaseToolRunOptions>,
run: RunContext<typeof this>,
) {
const { response } = await this.client.executeCustomTool(
{
executorId: this.options.executorId || "default",
toolSourceCode: this.options.sourceCode,
toolInputJson: JSON.stringify(input),
},
{ signal: run.signal },
);

if (response.case === "error") {
throw new CustomToolExecuteError(response.value.stderr);
// Execute custom tool
const httpUrl = this.options.codeInterpreter.url + "/v1/execute-custom-tool";
const response = await this.customFetch(httpUrl, input, run);

if (!response?.ok) {
throw new CustomToolExecuteError("HTTP request failed!", [new Error(await response.text())]);
}

return new StringToolOutput(response.value!.toolOutputJson);
const result = await response.json();

if (result?.exit_code) {
throw new CustomToolExecuteError(`Custom tool {tool_name} execution error!`);
}

return new StringToolOutput(result.tool_output_json);
}

private async customFetch(httpUrl: string, input: any, run: RunContext<typeof this>) {
try {
return await fetch(httpUrl, {
method: "POST",
headers: {
"Accept": "application/json",
"Content-Type": "application/json",
},
body: JSON.stringify({
tool_source_code: this.options.sourceCode,
executorId: this.options.executorId ?? "default",
tool_input_json: JSON.stringify(input),
}),
signal: run.signal,
});
} catch (error) {
if (error.cause.name == "HTTPParserError") {
throw new CustomToolExecuteError(
"Custom tool over HTTP failed -- not using HTTP endpoint!",
[error],
);
} else {
throw new CustomToolExecuteError("Custom tool over HTTP failed!", [error]);
}
}
}

loadSnapshot(snapshot: ReturnType<typeof this.createSnapshot>): void {
super.loadSnapshot(snapshot);
this.client = createCodeInterpreterClient(this.options.codeInterpreter);
}

static async fromSourceCode(
codeInterpreter: CodeInterpreterOptions,
sourceCode: string,
executorId?: string,
) {
const client = createCodeInterpreterClient(codeInterpreter);
const response = await client.parseCustomTool({ toolSourceCode: sourceCode });
// Parse custom tool
let response;
const httpUrl = codeInterpreter.url + "/v1/parse-custom-tool";
try {
response = await fetch(httpUrl, {
method: "POST",
headers: {
"Accept": "application/json",
"Content-Type": "application/json",
},
body: JSON.stringify({
tool_source_code: sourceCode,
executorId: executorId ?? "default",
}),
});
} catch (error) {
if (error.cause.name == "HTTPParserError") {
throw new CustomToolCreateError("Custom tool parse error -- not using HTTP endpoint!", [
error,
]);
} else {
throw new CustomToolCreateError("Custom tool parse error!", [error]);
}
}

if (response.response.case === "error") {
throw new CustomToolCreateError(response.response.value.errorMessages.join("\n"));
if (!response?.ok) {
throw new CustomToolCreateError("Error parsing custom tool!", [
new Error(await response.text()),
]);
}

const { toolName, toolDescription, toolInputSchemaJson } = response.response.value!;

return new CustomTool(
{
codeInterpreter,
sourceCode,
name: toolName,
description: toolDescription,
inputSchema: JSON.parse(toolInputSchemaJson),
executorId,
},
client,
);
const result = await response.json();

const { tool_name, tool_description, tool_input_schema_json } = result;

return new CustomTool({
codeInterpreter,
sourceCode,
name: tool_name,
description: tool_description,
inputSchema: JSON.parse(tool_input_schema_json),
executorId,
});
}
}
1 change: 1 addition & 0 deletions src/tools/python/python.ts
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,7 @@ export class PythonTool extends Tool<PythonToolOutput, PythonToolOptions> {
inputFiles.map((file) => [`${prefix}${file.filename}`, file.pythonId]),
),
}),
signal: run.signal,
});
} catch (error) {
if (error.cause.name == "HTTPParserError") {
Expand Down
Loading

0 comments on commit da928fc

Please sign in to comment.