Skip to content

Commit

Permalink
Merge pull request #303 from isamu/filterParams
Browse files Browse the repository at this point in the history
Filter params
  • Loading branch information
snakajima authored May 10, 2024
2 parents 68483ce + 97e98f1 commit 80837ce
Show file tree
Hide file tree
Showing 4 changed files with 244 additions and 3 deletions.
8 changes: 7 additions & 1 deletion src/node.ts
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import {
AgentFunctionContext,
AgentFunction,
AgentFilterInfo,
AgentFilterParams,
DefaultParamsType,
DefaultInputData,
} from "@/type";
Expand Down Expand Up @@ -53,6 +54,7 @@ export class ComputedNode extends Node {
public readonly graphId: string;
public readonly isResult: boolean;
public readonly params: NodeDataParams; // Agent-specific parameters
private readonly filterParams: AgentFilterParams;
private readonly dynamicParams: Record<string, DataSource>;
public readonly nestedGraph?: GraphData;
public readonly retryLimit: number;
Expand All @@ -76,6 +78,7 @@ export class ComputedNode extends Node {
super(nodeId, graph);
this.graphId = graphId;
this.params = data.params ?? {};
this.filterParams = data.filterParams ?? {};
this.nestedGraph = data.graph;
if (typeof data.agent === "string") {
this.agentId = data.agent;
Expand Down Expand Up @@ -221,6 +224,9 @@ export class ComputedNode extends Node {
const agentFilter = this.graph.agentFilters[index++];
if (agentFilter) {
if (this.shouldApplyAgentFilter(agentFilter)) {
if (agentFilter.filterParams) {
context.filterParams = { ...agentFilter.filterParams, ...context.filterParams };
}
return agentFilter.agent(context, next);
}
return next(context);
Expand Down Expand Up @@ -269,7 +275,7 @@ export class ComputedNode extends Node {
retry: this.retryCount,
verbose: this.graph.verbose,
},
filterParams: {},
filterParams: this.filterParams,
log: localLog,
};

Expand Down
6 changes: 5 additions & 1 deletion src/type.ts
Original file line number Diff line number Diff line change
Expand Up @@ -33,11 +33,14 @@ export type StaticNodeData = {
};
export type AgentNamelessFunction = (...param: any[]) => unknown;

export type AgentFilterParams = Record<string, any>;

export type ComputedNodeData = {
agent: string | AgentNamelessFunction;
inputs?: Array<any>;
anyInput?: boolean; // any input makes this node ready
params?: NodeDataParams;
filterParams?: AgentFilterParams; // agent filter
retry?: number;
timeout?: number; // msec
if?: string; // conditional execution
Expand Down Expand Up @@ -74,7 +77,7 @@ export type AgentFunctionContext<ParamsType = DefaultParamsType, InputDataType =
graphData?: GraphData | string; // nested graph
agents?: AgentFunctionDictonary; // for nested graph
taskManager?: TaskManager; // for nested graph
filterParams: Record<string, any>; // agent filter
filterParams: AgentFilterParams; // agent filter
log?: TransactionLog[];
};

Expand All @@ -92,6 +95,7 @@ export type AgentFilterInfo = {
agent: AgentFilterFunction;
agentIds?: string[];
nodeIds?: string[];
filterParams?: AgentFilterParams;
};

export type AgentFunctionDictonary = Record<string, AgentFunction<any, any, any>>;
Expand Down
2 changes: 1 addition & 1 deletion src/validators/common.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
export const graphDataAttributeKeys = ["nodes", "concurrency", "agentId", "loop", "verbose", "version"];

export const computedNodeAttributeKeys = ["inputs", "anyInput", "params", "retry", "timeout", "agent", "graph", "isResult", "priority", "if"];
export const computedNodeAttributeKeys = ["inputs", "anyInput", "params", "retry", "timeout", "agent", "graph", "isResult", "priority", "if", "filterParams"];
export const staticNodeAttributeKeys = ["value", "update", "isResult"];
231 changes: 231 additions & 0 deletions tests/agentFilters/test_filter_params.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,231 @@
import { GraphAI } from "@/graphai";
import { AgentFilterFunction } from "@/type";

import { defaultTestAgents } from "@/utils/test_agents";

import test from "node:test";
import assert from "node:assert";

const httpAgentFilter: AgentFilterFunction = async (context, next) => {
return next(context);
};

const callbackDictonary = {};

test("test filterParams on agent filter", async () => {
const graph_data = {
version: 0.3,
nodes: {
echo: {
agent: "echoAgent",
params: {
filterParams: true,
},
},
bypassAgent: {
agent: "bypassAgent",
inputs: [":echo"],
isResult: true,
},
},
};
const agentFilters = [
{
name: "httpAgentFilter",
agent: httpAgentFilter,
filterParams: {
agentServer: {
baseUrl: "http://localhost:8085/agentFilters/",
stream: true,
},
},
agentIds: ["echoAgent"],
},
];

const graph = new GraphAI({ ...graph_data }, { ...defaultTestAgents, ...callbackDictonary }, { agentFilters });

const result = await graph.run();
// console.log(JSON.stringify(result));
assert.deepStrictEqual(result, { bypassAgent: [{ agentServer: { baseUrl: "http://localhost:8085/agentFilters/", stream: true } }] });
});

test("test filterParams on node", async () => {
const graph_data = {
version: 0.3,
nodes: {
echo: {
agent: "echoAgent",
params: {
filterParams: true,
},
filterParams: {
agentServer: {
baseUrl: "http://localhost:8081/nodeParameter/",
},
},
},
bypassAgent: {
agent: "bypassAgent",
inputs: [":echo"],
isResult: true,
},
},
};
const agentFilters = [
{
name: "httpAgentFilter",
agent: httpAgentFilter,
agentIds: ["echoAgent"],
},
];

const graph = new GraphAI({ ...graph_data }, { ...defaultTestAgents, ...callbackDictonary }, { agentFilters });

const result = await graph.run();
// console.log(JSON.stringify(result));
assert.deepStrictEqual(result, { bypassAgent: [{ agentServer: { baseUrl: "http://localhost:8081/nodeParameter/" } }] });
});

test("test filterParams on agent filter and node. Then node.ts use filterParams on node", async () => {
const graph_data = {
version: 0.3,
nodes: {
echo: {
agent: "echoAgent",
params: {
filterParams: true,
},
filterParams: {
agentServer: {
baseUrl: "http://localhost:8081/nodeParameter/",
},
},
},
bypassAgent: {
agent: "bypassAgent",
inputs: [":echo"],
isResult: true,
},
},
};
const agentFilters = [
{
name: "httpAgentFilter",
agent: httpAgentFilter,
filterParams: {
agentServer: {
baseUrl: "http://localhost:8085/agentFilters/",
stream: true,
},
},
agentIds: ["echoAgent"],
},
];

const graph = new GraphAI({ ...graph_data }, { ...defaultTestAgents, ...callbackDictonary }, { agentFilters });

const result = await graph.run();
console.log(JSON.stringify(result));
assert.deepStrictEqual(result, { bypassAgent: [{ agentServer: { baseUrl: "http://localhost:8081/nodeParameter/" } }] });
});

test("test filterParams on each agent", async () => {
const graph_data = {
version: 0.3,
nodes: {
echo: {
agent: "echoAgent",
params: {
filterParams: true,
},
filterParams: {
agentServer: {
baseUrl: "http://localhost:8081/nodeParameter/",
},
},
},
echo2: {
agent: "echoAgent",
params: {
filterParams: true,
},
filterParams: {
agentServer: {
baseUrl: "http://localhost:8081/nodeParameter2/",
},
},
},
bypassAgent: {
agent: "bypassAgent",
inputs: [":echo", ":echo2"],
isResult: true,
},
},
};
const agentFilters = [
{
name: "httpAgentFilter",
agent: httpAgentFilter,
agentIds: ["echoAgent"],
},
];

const graph = new GraphAI({ ...graph_data }, { ...defaultTestAgents, ...callbackDictonary }, { agentFilters });

const result = await graph.run();
// console.log(JSON.stringify(result));
assert.deepStrictEqual(result, {
bypassAgent: [{ agentServer: { baseUrl: "http://localhost:8081/nodeParameter/" } }, { agentServer: { baseUrl: "http://localhost:8081/nodeParameter2/" } }],
});
});


test("test filterParams on agent filter", async () => {
const graph_data = {
version: 0.3,
nodes: {
echo: {
agent: "echoAgent",
params: {
filterParams: true,
},
},
bypassAgent: {
agent: "bypassAgent",
inputs: [":echo"],
isResult: true,
},
},
};
const agentFilters = [
{
name: "httpAgentFilter",
agent: httpAgentFilter,
filterParams: {
agentServer: {
baseUrl: "http://localhost:8085/agentFilters/",
stream: true,
},
},
agentIds: ["echoAgent"],
},
{
name: "httpAgentFilter",
agent: httpAgentFilter,
filterParams: {
agentServer: {
baseUrl: "http://localhost:8085/agentFilters2/",
stream: true,
},
},
agentIds: ["echoAgent"],
},
];

const graph = new GraphAI({ ...graph_data }, { ...defaultTestAgents, ...callbackDictonary }, { agentFilters });

const result = await graph.run();
// console.log(JSON.stringify(result));
assert.deepStrictEqual(result, { bypassAgent: [{ agentServer: { baseUrl: "http://localhost:8085/agentFilters/", stream: true } }] });
});

0 comments on commit 80837ce

Please sign in to comment.