From 76c500b2c92c958187e6487d0d68bd687947e943 Mon Sep 17 00:00:00 2001 From: Tomas Pilar Date: Thu, 2 Jan 2025 14:50:19 +0100 Subject: [PATCH] feat(tools): add Model Context Protocol tool Signed-off-by: Tomas Pilar --- examples/tools/mcp.ts | 60 ++++++++++++++ package.json | 5 ++ src/internals/helpers/paginate.test.ts | 79 +++++++++++++++---- src/internals/helpers/paginate.ts | 32 ++++++++ src/tools/mcp.test.ts | 103 +++++++++++++++++++++++++ src/tools/mcp.ts | 70 +++++++++++++++++ yarn.lock | 61 ++++++++++++--- 7 files changed, 383 insertions(+), 27 deletions(-) create mode 100644 examples/tools/mcp.ts create mode 100644 src/tools/mcp.test.ts create mode 100644 src/tools/mcp.ts diff --git a/examples/tools/mcp.ts b/examples/tools/mcp.ts new file mode 100644 index 00000000..4eb2c841 --- /dev/null +++ b/examples/tools/mcp.ts @@ -0,0 +1,60 @@ +/** + * Copyright 2025 IBM Corp. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import { Client } from "@modelcontextprotocol/sdk/client/index.js"; +import { MCPTool } from "bee-agent-framework/tools/mcp"; +import { StdioClientTransport } from "@modelcontextprotocol/sdk/client/stdio.js"; +import { BeeAgent } from "bee-agent-framework/agents/bee/agent"; +import { UnconstrainedMemory } from "bee-agent-framework/memory/unconstrainedMemory"; +import { OllamaChatLLM } from "bee-agent-framework/adapters/ollama/chat"; + +// Create MCP Client +const client = new Client( + { + name: "test-client", + version: "1.0.0", + }, + { + capabilities: {}, + }, +); + +// Connect the client to any MCP server with tools capablity +await client.connect( + new StdioClientTransport({ + command: "npx", + args: ["-y", "@modelcontextprotocol/server-everything"], + }), +); + +try { + // Server usually supports several tools, use the factory for automatic discovery + const tools = await MCPTool.createTools(client); + const agent = new BeeAgent({ + llm: new OllamaChatLLM(), + memory: new UnconstrainedMemory(), + tools, + }); + // @modelcontextprotocol/server-everything contains "add" tool + await agent.run({ prompt: "Find out how much is 4 + 7" }).observe((emitter) => { + emitter.on("update", async ({ data, update, meta }) => { + console.log(`Agent (${update.key}) 🤖 : `, update.value); + }); + }); +} finally { + // Close the MCP connection + await client.close(); +} diff --git a/package.json b/package.json index 5a77eae7..462d3033 100644 --- a/package.json +++ b/package.json @@ -209,6 +209,7 @@ "@ibm-generative-ai/node-sdk": "~3.2.4", "@langchain/community": ">=0.2.28", "@langchain/core": ">=0.2.27", + "@modelcontextprotocol/sdk": "^1.0.4", "@zilliz/milvus2-sdk-node": "^2.4.9", "google-auth-library": "*", "groq-sdk": "^0.7.0", @@ -246,6 +247,9 @@ "@langchain/core": { "optional": true }, + "@modelcontextprotocol/sdk": { + "optional": true + }, "@zilliz/milvus2-sdk-node": { "optional": true }, @@ -285,6 +289,7 @@ "@ibm-generative-ai/node-sdk": "~3.2.4", "@langchain/community": "~0.3.17", "@langchain/core": "~0.3.22", + "@modelcontextprotocol/sdk": "^1.0.4", "@opentelemetry/instrumentation": "^0.56.0", "@opentelemetry/resources": "^1.29.0", "@opentelemetry/sdk-node": "^0.56.0", diff --git a/src/internals/helpers/paginate.test.ts b/src/internals/helpers/paginate.test.ts index aa359792..ff37cb36 100644 --- a/src/internals/helpers/paginate.test.ts +++ b/src/internals/helpers/paginate.test.ts @@ -14,10 +14,15 @@ * limitations under the License. */ -import { paginate, PaginateInput } from "@/internals/helpers/paginate.js"; +import { + paginate, + PaginateInput, + paginateWithCursor, + PaginateWithCursorInput, +} from "@/internals/helpers/paginate.js"; describe("paginate", () => { - it.each([ + const mockSetup = [ { size: 1, chunkSize: 1, @@ -38,23 +43,63 @@ describe("paginate", () => { chunkSize: 1, items: Array(20).fill(1), }, - ])("Works %#", async ({ size, items, chunkSize }) => { - const fn: PaginateInput["handler"] = vi.fn().mockImplementation(async ({ offset }) => { - const chunk = items.slice(offset, offset + chunkSize); - return { done: offset + chunk.length >= items.length, data: chunk }; - }); + ] as const; + + describe("paginate", () => { + it.each(mockSetup)("Works %#", async ({ size, items, chunkSize }) => { + const fn: PaginateInput["handler"] = vi + .fn() + .mockImplementation(async ({ offset }) => { + const chunk = items.slice(offset, offset + chunkSize); + return { done: offset + chunk.length >= items.length, data: chunk }; + }); - const results = await paginate({ - size, - handler: fn, + const results = await paginate({ + size, + handler: fn, + }); + + const maxItemsToBeRetrieved = Math.min(size, items.length); + let expectedCalls = Math.ceil(maxItemsToBeRetrieved / chunkSize); + if (expectedCalls === 0 && size > 0) { + expectedCalls = 1; + } + expect(fn).toBeCalledTimes(expectedCalls); + expect(results).toHaveLength(maxItemsToBeRetrieved); }); + }); - const maxItemsToBeRetrieved = Math.min(size, items.length); - let expectedCalls = Math.ceil(maxItemsToBeRetrieved / chunkSize); - if (expectedCalls === 0 && size > 0) { - expectedCalls = 1; - } - expect(fn).toBeCalledTimes(expectedCalls); - expect(results).toHaveLength(maxItemsToBeRetrieved); + describe("paginateWithCursor", () => { + it.each(mockSetup)("Works %#", async ({ size, items, chunkSize }) => { + const fn = vi + .fn["handler"]>() + .mockImplementation(async ({ cursor = 0 }) => { + const chunk = items.slice(cursor, cursor + chunkSize); + const isDone = cursor + chunk.length >= items.length; + return isDone + ? ({ + done: true, + data: chunk, + } as const) + : ({ + done: false, + data: chunk, + nextCursor: cursor + chunk.length, + } as const); + }); + + const results = await paginateWithCursor({ + size, + handler: fn, + }); + + const maxItemsToBeRetrieved = Math.min(size, items.length); + let expectedCalls = Math.ceil(maxItemsToBeRetrieved / chunkSize); + if (expectedCalls === 0 && size > 0) { + expectedCalls = 1; + } + expect(fn).toBeCalledTimes(expectedCalls); + expect(results).toHaveLength(maxItemsToBeRetrieved); + }); }); }); diff --git a/src/internals/helpers/paginate.ts b/src/internals/helpers/paginate.ts index dcd0b31c..954cfd8f 100644 --- a/src/internals/helpers/paginate.ts +++ b/src/internals/helpers/paginate.ts @@ -40,3 +40,35 @@ export async function paginate(input: PaginateInput): Promise { return acc; } + +export interface PaginateWithCursorInput { + size: number; + handler: (data: { + cursor: C | undefined; + limit: number; + }) => Promise<{ data: T[]; done: true } | { data: T[]; done: false; nextCursor: C }>; +} + +export async function paginateWithCursor(input: PaginateWithCursorInput): Promise { + const acc: T[] = []; + let cursor: C | undefined; + while (acc.length < input.size) { + const result = await input.handler({ + cursor, + limit: input.size - acc.length, + }); + acc.push(...result.data); + + if (result.done || result.data.length === 0) { + break; + } else { + cursor = result.nextCursor; + } + } + + if (acc.length > input.size) { + acc.length = input.size; + } + + return acc; +} diff --git a/src/tools/mcp.test.ts b/src/tools/mcp.test.ts new file mode 100644 index 00000000..642c9d46 --- /dev/null +++ b/src/tools/mcp.test.ts @@ -0,0 +1,103 @@ +/** + * Copyright 2025 IBM Corp. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import { Server } from "@modelcontextprotocol/sdk/server/index.js"; +import { CallToolRequestSchema, ListToolsRequestSchema } from "@modelcontextprotocol/sdk/types.js"; +import { InMemoryTransport } from "@modelcontextprotocol/sdk/inMemory.js"; +import { Client } from "@modelcontextprotocol/sdk/client/index.js"; +import { MCPTool } from "./mcp.js"; +import { entries } from "remeda"; +import { zodToJsonSchema } from "zod-to-json-schema"; +import { z } from "zod"; + +const abInputSchema = z.object({ a: z.number(), b: z.number() }); +const toolDescriptions = { + add: { + description: "Adds two numbers", + inputSchema: zodToJsonSchema(abInputSchema), + handler: ({ a, b }: z.input) => a + b, + }, + multiply: { + description: "Multiplies two numbers", + inputSchema: zodToJsonSchema(abInputSchema), + handler: ({ a, b }: z.input) => a * b, + }, +} as const; + +describe("MCPTool", () => { + let server: Server; + let client: Client; + let tools: MCPTool[]; + + beforeEach(async () => { + server = new Server( + { + name: "test-server", + version: "1.0.0", + }, + { + capabilities: { + tools: {}, + }, + }, + ); + server.setRequestHandler(ListToolsRequestSchema, async () => { + return { + tools: entries(toolDescriptions).map(([name, { description, inputSchema }]) => ({ + name, + description, + inputSchema, + })), + }; + }); + server.setRequestHandler(CallToolRequestSchema, async (request) => { + const tool = toolDescriptions[request.params.name as keyof typeof toolDescriptions]; + if (!tool) { + throw new Error("Tool not found"); + } + // Arguments are assumed to be valid in this mock + return { + contents: [tool.handler(request.params.arguments as any)], + }; + }); + + client = new Client( + { + name: "test-client", + version: "1.0.0", + }, + { + capabilities: {}, + }, + ); + + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); + await server.connect(serverTransport); + await client.connect(clientTransport); + + tools = await MCPTool.createTools(client); + }); + + it("should run the tools", async () => { + const tool = tools.at(0); + expect(tool).toBeDefined(); + }); + + afterEach(async () => { + await client.close(); + await server.close(); + }); +}); diff --git a/src/tools/mcp.ts b/src/tools/mcp.ts new file mode 100644 index 00000000..9d4123ce --- /dev/null +++ b/src/tools/mcp.ts @@ -0,0 +1,70 @@ +/** + * Copyright 2025 IBM Corp. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import { BaseToolRunOptions, ToolEmitter, ToolInput, JSONToolOutput, Tool } from "@/tools/base.js"; +import { Emitter } from "@/emitter/emitter.js"; +import { GetRunContext } from "@/context.js"; +import { Client as MCPClient } from "@modelcontextprotocol/sdk/client/index.js"; +import { ListToolsResult } from "@modelcontextprotocol/sdk/types.js"; +import { SchemaObject } from "ajv"; + +export interface MCPToolInput { + client: MCPClient; + tool: ListToolsResult["tools"][number]; +} + +export class MCPToolOutput extends JSONToolOutput {} + +export class MCPTool extends Tool { + public readonly name: string; + public readonly description: string; + + public readonly client: MCPClient; + private readonly tool: ListToolsResult["tools"][number]; + + constructor({ client, tool, ...options }: MCPToolInput) { + super(options); + this.client = client; + this.tool = tool; + this.name = tool.name; + this.description = tool.description ?? "No description, use based on name."; + } + + public readonly emitter: ToolEmitter, MCPToolOutput> = Emitter.root.child({ + namespace: ["tool", "mcp", "tool"], + creator: this, + }); + + inputSchema() { + return this.tool.inputSchema as SchemaObject; + } + + protected async _run( + input: ToolInput, + _options: BaseToolRunOptions, + run: GetRunContext, + ) { + const result = await this.client.callTool({ name: this.name, arguments: input }, undefined, { + signal: run.signal, + }); + return new MCPToolOutput(result); + } + + static async createTools(client: MCPClient): Promise { + const { tools } = await client.listTools(); + return tools.map((tool) => new MCPTool({ client, tool })); + } +} diff --git a/yarn.lock b/yarn.lock index 1d7be431..5963bc99 100644 --- a/yarn.lock +++ b/yarn.lock @@ -2173,6 +2173,17 @@ __metadata: languageName: node linkType: hard +"@modelcontextprotocol/sdk@npm:^1.0.4": + version: 1.1.0 + resolution: "@modelcontextprotocol/sdk@npm:1.1.0" + dependencies: + content-type: "npm:^1.0.5" + raw-body: "npm:^3.0.0" + zod: "npm:^3.23.8" + checksum: 10c0/1f80a2139a09cb0e7c6af75d7fe57373a554a9d30be54ee093d49dce5ea1266cc0b624a687017b99248c851f125845387897811360abed22d0d57998935ce04a + languageName: node + linkType: hard + "@nodelib/fs.scandir@npm:2.1.5": version: 2.1.5 resolution: "@nodelib/fs.scandir@npm:2.1.5" @@ -4827,6 +4838,7 @@ __metadata: "@ibm-generative-ai/node-sdk": "npm:~3.2.4" "@langchain/community": "npm:~0.3.17" "@langchain/core": "npm:~0.3.22" + "@modelcontextprotocol/sdk": "npm:^1.0.4" "@opentelemetry/api": "npm:^1.9.0" "@opentelemetry/instrumentation": "npm:^0.56.0" "@opentelemetry/resources": "npm:^1.29.0" @@ -4919,6 +4931,7 @@ __metadata: "@ibm-generative-ai/node-sdk": ~3.2.4 "@langchain/community": ">=0.2.28" "@langchain/core": ">=0.2.27" + "@modelcontextprotocol/sdk": ^1.0.4 "@zilliz/milvus2-sdk-node": ^2.4.9 google-auth-library: "*" groq-sdk: ^0.7.0 @@ -4946,6 +4959,8 @@ __metadata: optional: true "@langchain/core": optional: true + "@modelcontextprotocol/sdk": + optional: true "@zilliz/milvus2-sdk-node": optional: true google-auth-library: @@ -5141,6 +5156,13 @@ __metadata: languageName: node linkType: hard +"bytes@npm:3.1.2": + version: 3.1.2 + resolution: "bytes@npm:3.1.2" + checksum: 10c0/76d1c43cbd602794ad8ad2ae94095cddeb1de78c5dddaa7005c51af10b0176c69971a6d88e805a90c2b6550d76636e43c40d8427a808b8645ede885de4a0358e + languageName: node + linkType: hard + "cac@npm:^6.7.14": version: 6.7.14 resolution: "cac@npm:6.7.14" @@ -5810,6 +5832,13 @@ __metadata: languageName: node linkType: hard +"content-type@npm:^1.0.5": + version: 1.0.5 + resolution: "content-type@npm:1.0.5" + checksum: 10c0/b76ebed15c000aee4678c3707e0860cb6abd4e680a598c0a26e17f0bfae723ec9cc2802f0ff1bc6e4d80603719010431d2231018373d4dde10f9ccff9dadf5af + languageName: node + linkType: hard + "conventional-changelog-angular@npm:^7.0.0": version: 7.0.0 resolution: "conventional-changelog-angular@npm:7.0.0" @@ -8365,6 +8394,15 @@ __metadata: languageName: node linkType: hard +"iconv-lite@npm:0.6.3, iconv-lite@npm:^0.6.2, iconv-lite@npm:^0.6.3": + version: 0.6.3 + resolution: "iconv-lite@npm:0.6.3" + dependencies: + safer-buffer: "npm:>= 2.1.2 < 3.0.0" + checksum: 10c0/98102bc66b33fcf5ac044099d1257ba0b7ad5e3ccd3221f34dd508ab4070edff183276221684e1e0555b145fce0850c9f7d2b60a9fcac50fbb4ea0d6e845a3b1 + languageName: node + linkType: hard + "iconv-lite@npm:^0.4.24": version: 0.4.24 resolution: "iconv-lite@npm:0.4.24" @@ -8374,15 +8412,6 @@ __metadata: languageName: node linkType: hard -"iconv-lite@npm:^0.6.2, iconv-lite@npm:^0.6.3": - version: 0.6.3 - resolution: "iconv-lite@npm:0.6.3" - dependencies: - safer-buffer: "npm:>= 2.1.2 < 3.0.0" - checksum: 10c0/98102bc66b33fcf5ac044099d1257ba0b7ad5e3ccd3221f34dd508ab4070edff183276221684e1e0555b145fce0850c9f7d2b60a9fcac50fbb4ea0d6e845a3b1 - languageName: node - linkType: hard - "ieee754@npm:^1.1.13": version: 1.2.1 resolution: "ieee754@npm:1.2.1" @@ -12045,6 +12074,18 @@ __metadata: languageName: node linkType: hard +"raw-body@npm:^3.0.0": + version: 3.0.0 + resolution: "raw-body@npm:3.0.0" + dependencies: + bytes: "npm:3.1.2" + http-errors: "npm:2.0.0" + iconv-lite: "npm:0.6.3" + unpipe: "npm:1.0.0" + checksum: 10c0/f8daf4b724064a4811d118745a781ca0fb4676298b8adadfd6591155549cfea0a067523cf7dd3baeb1265fecc9ce5dfb2fc788c12c66b85202a336593ece0f87 + languageName: node + linkType: hard + "rc@npm:1.2.8, rc@npm:^1.2.7, rc@npm:^1.2.8": version: 1.2.8 resolution: "rc@npm:1.2.8" @@ -13979,7 +14020,7 @@ __metadata: languageName: node linkType: hard -"unpipe@npm:~1.0.0": +"unpipe@npm:1.0.0, unpipe@npm:~1.0.0": version: 1.0.0 resolution: "unpipe@npm:1.0.0" checksum: 10c0/193400255bd48968e5c5383730344fbb4fa114cdedfab26e329e50dd2d81b134244bb8a72c6ac1b10ab0281a58b363d06405632c9d49ca9dfd5e90cbd7d0f32c