diff --git a/src/node.ts b/src/node.ts index 34485ce2b..a6bc45bac 100644 --- a/src/node.ts +++ b/src/node.ts @@ -11,6 +11,7 @@ import { AgentFunctionContext, AgentFunction, AgentFilterInfo, + AgentFilterParams, DefaultParamsType, DefaultInputData, } from "@/type"; @@ -53,6 +54,7 @@ export class ComputedNode extends Node { public readonly graphId: string; public readonly isResult: boolean; public readonly params: NodeDataParams; // Agent-specific parameters + private readonly filterParams: AgentFilterParams; private readonly dynamicParams: Record; public readonly nestedGraph?: GraphData; public readonly retryLimit: number; @@ -76,6 +78,7 @@ export class ComputedNode extends Node { super(nodeId, graph); this.graphId = graphId; this.params = data.params ?? {}; + this.filterParams = data.filterParams ?? {}; this.nestedGraph = data.graph; if (typeof data.agent === "string") { this.agentId = data.agent; @@ -221,6 +224,9 @@ export class ComputedNode extends Node { const agentFilter = this.graph.agentFilters[index++]; if (agentFilter) { if (this.shouldApplyAgentFilter(agentFilter)) { + if (agentFilter.filterParams) { + context.filterParams = { ...agentFilter.filterParams, ...context.filterParams }; + } return agentFilter.agent(context, next); } return next(context); @@ -269,7 +275,7 @@ export class ComputedNode extends Node { retry: this.retryCount, verbose: this.graph.verbose, }, - filterParams: {}, + filterParams: this.filterParams, log: localLog, }; diff --git a/src/type.ts b/src/type.ts index 66d1b4f96..5010b0d5a 100644 --- a/src/type.ts +++ b/src/type.ts @@ -33,11 +33,14 @@ export type StaticNodeData = { }; export type AgentNamelessFunction = (...param: any[]) => unknown; +export type AgentFilterParams = Record; + export type ComputedNodeData = { agent: string | AgentNamelessFunction; inputs?: Array; anyInput?: boolean; // any input makes this node ready params?: NodeDataParams; + filterParams?: AgentFilterParams; // agent filter retry?: number; timeout?: number; // msec if?: string; // conditional execution @@ -74,7 +77,7 @@ export type AgentFunctionContext; // agent filter + filterParams: AgentFilterParams; // agent filter log?: TransactionLog[]; }; @@ -92,6 +95,7 @@ export type AgentFilterInfo = { agent: AgentFilterFunction; agentIds?: string[]; nodeIds?: string[]; + filterParams?: AgentFilterParams; }; export type AgentFunctionDictonary = Record>; diff --git a/src/validators/common.ts b/src/validators/common.ts index eaa16a6d4..7a094f1c6 100644 --- a/src/validators/common.ts +++ b/src/validators/common.ts @@ -1,4 +1,4 @@ export const graphDataAttributeKeys = ["nodes", "concurrency", "agentId", "loop", "verbose", "version"]; -export const computedNodeAttributeKeys = ["inputs", "anyInput", "params", "retry", "timeout", "agent", "graph", "isResult", "priority", "if"]; +export const computedNodeAttributeKeys = ["inputs", "anyInput", "params", "retry", "timeout", "agent", "graph", "isResult", "priority", "if", "filterParams"]; export const staticNodeAttributeKeys = ["value", "update", "isResult"]; diff --git a/tests/agentFilters/test_filter_params.ts b/tests/agentFilters/test_filter_params.ts new file mode 100644 index 000000000..de2e81d01 --- /dev/null +++ b/tests/agentFilters/test_filter_params.ts @@ -0,0 +1,231 @@ +import { GraphAI } from "@/graphai"; +import { AgentFilterFunction } from "@/type"; + +import { defaultTestAgents } from "@/utils/test_agents"; + +import test from "node:test"; +import assert from "node:assert"; + +const httpAgentFilter: AgentFilterFunction = async (context, next) => { + return next(context); +}; + +const callbackDictonary = {}; + +test("test filterParams on agent filter", async () => { + const graph_data = { + version: 0.3, + nodes: { + echo: { + agent: "echoAgent", + params: { + filterParams: true, + }, + }, + bypassAgent: { + agent: "bypassAgent", + inputs: [":echo"], + isResult: true, + }, + }, + }; + const agentFilters = [ + { + name: "httpAgentFilter", + agent: httpAgentFilter, + filterParams: { + agentServer: { + baseUrl: "http://localhost:8085/agentFilters/", + stream: true, + }, + }, + agentIds: ["echoAgent"], + }, + ]; + + const graph = new GraphAI({ ...graph_data }, { ...defaultTestAgents, ...callbackDictonary }, { agentFilters }); + + const result = await graph.run(); + // console.log(JSON.stringify(result)); + assert.deepStrictEqual(result, { bypassAgent: [{ agentServer: { baseUrl: "http://localhost:8085/agentFilters/", stream: true } }] }); +}); + +test("test filterParams on node", async () => { + const graph_data = { + version: 0.3, + nodes: { + echo: { + agent: "echoAgent", + params: { + filterParams: true, + }, + filterParams: { + agentServer: { + baseUrl: "http://localhost:8081/nodeParameter/", + }, + }, + }, + bypassAgent: { + agent: "bypassAgent", + inputs: [":echo"], + isResult: true, + }, + }, + }; + const agentFilters = [ + { + name: "httpAgentFilter", + agent: httpAgentFilter, + agentIds: ["echoAgent"], + }, + ]; + + const graph = new GraphAI({ ...graph_data }, { ...defaultTestAgents, ...callbackDictonary }, { agentFilters }); + + const result = await graph.run(); + // console.log(JSON.stringify(result)); + assert.deepStrictEqual(result, { bypassAgent: [{ agentServer: { baseUrl: "http://localhost:8081/nodeParameter/" } }] }); +}); + +test("test filterParams on agent filter and node. Then node.ts use filterParams on node", async () => { + const graph_data = { + version: 0.3, + nodes: { + echo: { + agent: "echoAgent", + params: { + filterParams: true, + }, + filterParams: { + agentServer: { + baseUrl: "http://localhost:8081/nodeParameter/", + }, + }, + }, + bypassAgent: { + agent: "bypassAgent", + inputs: [":echo"], + isResult: true, + }, + }, + }; + const agentFilters = [ + { + name: "httpAgentFilter", + agent: httpAgentFilter, + filterParams: { + agentServer: { + baseUrl: "http://localhost:8085/agentFilters/", + stream: true, + }, + }, + agentIds: ["echoAgent"], + }, + ]; + + const graph = new GraphAI({ ...graph_data }, { ...defaultTestAgents, ...callbackDictonary }, { agentFilters }); + + const result = await graph.run(); + console.log(JSON.stringify(result)); + assert.deepStrictEqual(result, { bypassAgent: [{ agentServer: { baseUrl: "http://localhost:8081/nodeParameter/" } }] }); +}); + +test("test filterParams on each agent", async () => { + const graph_data = { + version: 0.3, + nodes: { + echo: { + agent: "echoAgent", + params: { + filterParams: true, + }, + filterParams: { + agentServer: { + baseUrl: "http://localhost:8081/nodeParameter/", + }, + }, + }, + echo2: { + agent: "echoAgent", + params: { + filterParams: true, + }, + filterParams: { + agentServer: { + baseUrl: "http://localhost:8081/nodeParameter2/", + }, + }, + }, + bypassAgent: { + agent: "bypassAgent", + inputs: [":echo", ":echo2"], + isResult: true, + }, + }, + }; + const agentFilters = [ + { + name: "httpAgentFilter", + agent: httpAgentFilter, + agentIds: ["echoAgent"], + }, + ]; + + const graph = new GraphAI({ ...graph_data }, { ...defaultTestAgents, ...callbackDictonary }, { agentFilters }); + + const result = await graph.run(); + // console.log(JSON.stringify(result)); + assert.deepStrictEqual(result, { + bypassAgent: [{ agentServer: { baseUrl: "http://localhost:8081/nodeParameter/" } }, { agentServer: { baseUrl: "http://localhost:8081/nodeParameter2/" } }], + }); +}); + + +test("test filterParams on agent filter", async () => { + const graph_data = { + version: 0.3, + nodes: { + echo: { + agent: "echoAgent", + params: { + filterParams: true, + }, + }, + bypassAgent: { + agent: "bypassAgent", + inputs: [":echo"], + isResult: true, + }, + }, + }; + const agentFilters = [ + { + name: "httpAgentFilter", + agent: httpAgentFilter, + filterParams: { + agentServer: { + baseUrl: "http://localhost:8085/agentFilters/", + stream: true, + }, + }, + agentIds: ["echoAgent"], + }, + { + name: "httpAgentFilter", + agent: httpAgentFilter, + filterParams: { + agentServer: { + baseUrl: "http://localhost:8085/agentFilters2/", + stream: true, + }, + }, + agentIds: ["echoAgent"], + }, + ]; + + const graph = new GraphAI({ ...graph_data }, { ...defaultTestAgents, ...callbackDictonary }, { agentFilters }); + + const result = await graph.run(); + // console.log(JSON.stringify(result)); + assert.deepStrictEqual(result, { bypassAgent: [{ agentServer: { baseUrl: "http://localhost:8085/agentFilters/", stream: true } }] }); +});