diff --git a/src/graphai.ts b/src/graphai.ts index 03d6db690..314c52e16 100644 --- a/src/graphai.ts +++ b/src/graphai.ts @@ -19,6 +19,7 @@ export class GraphAI { private readonly data: GraphData; private readonly loop?: LoopData; private readonly logs: Array = []; + private readonly hooks: Record) => void>; public readonly callbackDictonary: AgentFunctionDictonary; public readonly taskManager: TaskManager; public readonly agentFilters: AgentFilterInfo[]; @@ -90,7 +91,7 @@ export class GraphAI { constructor( data: GraphData, callbackDictonary: AgentFunctionDictonary, - options: { agentFilters?: AgentFilterInfo[] | undefined; taskManager?: TaskManager | undefined } = { taskManager: undefined, agentFilters: [] }, + options: { agentFilters?: AgentFilterInfo[] | undefined; taskManager?: TaskManager | undefined, hooks?: Record) => void> } = { taskManager: undefined, agentFilters: [] }, ) { if (!data.version && !options.taskManager) { console.log("------------ missing version number"); @@ -105,6 +106,7 @@ export class GraphAI { this.callbackDictonary = callbackDictonary; this.taskManager = options.taskManager ?? new TaskManager(data.concurrency ?? defaultConcurrency); this.agentFilters = options.agentFilters ?? []; + this.hooks = options.hooks ?? {}; this.loop = data.loop; this.verbose = data.verbose === true; this.onComplete = () => { @@ -291,4 +293,8 @@ export class GraphAI { return getDataFromSource(result, source); }); } + + public getHook(hookId: string) { + return this.hooks[hookId]; + } } diff --git a/src/node.ts b/src/node.ts index a6bc45bac..ebf9c96fe 100644 --- a/src/node.ts +++ b/src/node.ts @@ -108,6 +108,10 @@ export class ComputedNode extends Node { assert(!this.anyInput, "Dynamic params are not supported with anyInput"); tmp[key] = dataSource; this.pendings.add(dataSource.nodeId); + } else if (dataSource.isHook) { + const hook = graph.getHook(dataSource.value); + assert(hook !== undefined, `Specified hook does not exist: ${dataSource.value}`); + this.params[key] = hook; } return tmp; }, {}); diff --git a/src/type.ts b/src/type.ts index 5010b0d5a..b000f4bc2 100644 --- a/src/type.ts +++ b/src/type.ts @@ -24,6 +24,7 @@ export type DataSource = { nodeId?: string; value?: any; propIds?: string[]; + isHook?: boolean; }; export type StaticNodeData = { diff --git a/src/utils/utils.ts b/src/utils/utils.ts index 3306154c5..546797822 100644 --- a/src/utils/utils.ts +++ b/src/utils/utils.ts @@ -25,6 +25,12 @@ export const parseNodeName = (inputNodeId: any, version: number): DataSource => return parseNodeName_02(inputNodeId); } if (typeof inputNodeId === "string") { + const regexHook = /^~(.*)$/; + const matchHook = inputNodeId.match(regexHook); + if (matchHook) { + console.log("****", inputNodeId, matchHook[1]); + return { value: matchHook[1], isHook: true }; + } const regex = /^:(.*)$/; const match = inputNodeId.match(regex); if (!match) { diff --git a/tests/graphai/test_hooks.ts b/tests/graphai/test_hooks.ts new file mode 100644 index 000000000..9b0a03f88 --- /dev/null +++ b/tests/graphai/test_hooks.ts @@ -0,0 +1,58 @@ +import { GraphAI } from "@/graphai"; +import { defaultTestAgents } from "@/utils/test_agents"; +import { AgentFunction } from "@/graphai"; +import { sleep } from "@/utils/utils"; + +import test from "node:test"; +// import assert from "node:assert"; + +const graphdata_hook = { + version: 0.3, + nodes: { + source: { + value: "May the force be with you." + }, + streamNode: { + agent: "streamAgent", + params: { + stream: "~test_hook", + }, + isResult: true, + inputs: [":source"], + }, + }, +}; + + +const streamAgent: AgentFunction<{ stream?: (data:Record)=> void }, string, string> = async ({ + params, inputs +}) => { + const [message] = inputs; + const {stream} = params; + if (stream) { + for await (const word of message.split(' ')) { + await sleep(10); + stream({word}); + }; + } + return message; +}; + +const test_hook = (data: Record) => { + console.log(data.word); +} + +test("test hook", async () => { + const graph = new GraphAI(graphdata_hook, { ...defaultTestAgents, streamAgent }, { hooks: { test_hook }}); + const result = await graph.run(false); + console.log(result); + /* + assert.deepStrictEqual(result, { + node1: { node1: "output" }, + node2: { port1: { node2: "dispatch" } }, + node3: { node3: "output", node1: "output", node2: "dispatch" }, + node4: { node4: "output", node3: "output", node1: "output", node2: "dispatch" }, + node5: { node5: "output", node4: "output", node3: "output", node1: "output", node2: "dispatch" }, + }); + */ +});