Skip to content

Commit

Permalink
Merge pull request #867 from isamu/gemini_agent_stream
Browse files Browse the repository at this point in the history
Gemini agent stream
  • Loading branch information
isamu authored Dec 30, 2024
2 parents e02347f + 028e044 commit c941949
Show file tree
Hide file tree
Showing 5 changed files with 64 additions and 8 deletions.
1 change: 1 addition & 0 deletions llm_agents/gemini_agent/lib/gemini_agent.d.ts
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ type GeminiInputs = {
} & GraphAILLMInputBase;
type GeminiConfig = {
apiKey?: string;
stream?: boolean;
};
type GeminiParams = GeminiInputs & GeminiConfig;
export declare const geminiAgent: AgentFunction<GeminiParams, Record<string, any> | string, GeminiInputs, GeminiConfig>;
Expand Down
20 changes: 17 additions & 3 deletions llm_agents/gemini_agent/lib/gemini_agent.js
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,9 @@ exports.geminiAgent = void 0;
const graphai_1 = require("graphai");
const generative_ai_1 = require("@google/generative-ai");
const llm_utils_1 = require("@graphai/llm_utils");
const geminiAgent = async ({ params, namedInputs, config }) => {
const geminiAgent = async ({ params, namedInputs, config, filterParams, }) => {
const { model, system, temperature, max_tokens, tools, prompt, messages } = { ...params, ...namedInputs };
const { apiKey /* stream */ } = {
const { apiKey, stream } = {
...params,
...(config || {}),
};
Expand Down Expand Up @@ -60,6 +60,20 @@ const geminiAgent = async ({ params, namedInputs, config }) => {
}),
generationConfig,
});
if (stream) {
const result = await chat.sendMessageStream(lastMessage.content);
const chunks = [];
for await (const chunk of result.stream) {
const chunkText = chunk.text();
if (filterParams && filterParams.streamTokenCallback && chunkText) {
filterParams.streamTokenCallback(chunkText);
}
chunks.push(chunkText);
}
const text = chunks.join("");
const message = { role: "assistant", content: text };
return { choices: [{ message }], text, message };
}
const result = await chat.sendMessage(lastMessage.content);
const response = result.response;
const text = response.text();
Expand Down Expand Up @@ -111,7 +125,7 @@ const geminiAgentInfo = {
author: "Receptron team",
repository: "https://github.com/receptron/graphai",
license: "MIT",
// stream: true,
stream: true,
npms: ["@anthropic-ai/sdk"],
environmentVariables: ["GOOGLE_GENAI_API_KEY"],
};
Expand Down
2 changes: 1 addition & 1 deletion llm_agents/gemini_agent/package.json
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
{
"name": "@graphai/gemini_agent",
"version": "0.2.1",
"version": "0.2.2",
"description": "Gemini agents for GraphAI.",
"main": "lib/index.js",
"files": [
Expand Down
28 changes: 24 additions & 4 deletions llm_agents/gemini_agent/src/gemini_agent.ts
Original file line number Diff line number Diff line change
Expand Up @@ -14,15 +14,20 @@ type GeminiInputs = {

type GeminiConfig = {
apiKey?: string;
// stream?: boolean;
stream?: boolean;
};

type GeminiParams = GeminiInputs & GeminiConfig;

export const geminiAgent: AgentFunction<GeminiParams, Record<string, any> | string, GeminiInputs, GeminiConfig> = async ({ params, namedInputs, config }) => {
export const geminiAgent: AgentFunction<GeminiParams, Record<string, any> | string, GeminiInputs, GeminiConfig> = async ({
params,
namedInputs,
config,
filterParams,
}) => {
const { model, system, temperature, max_tokens, tools, prompt, messages } = { ...params, ...namedInputs };

const { apiKey /* stream */ } = {
const { apiKey, stream } = {
...params,
...(config || {}),
};
Expand Down Expand Up @@ -83,6 +88,21 @@ export const geminiAgent: AgentFunction<GeminiParams, Record<string, any> | stri
generationConfig,
});

if (stream) {
const result = await chat.sendMessageStream(lastMessage.content);
const chunks = [];
for await (const chunk of result.stream) {
const chunkText = chunk.text();
if (filterParams && filterParams.streamTokenCallback && chunkText) {
filterParams.streamTokenCallback(chunkText);
}
chunks.push(chunkText);
}
const text = chunks.join("");
const message: any = { role: "assistant", content: text };
return { choices: [{ message }], text, message };
}

const result = await chat.sendMessage(lastMessage.content);
const response = result.response;
const text = response.text();
Expand Down Expand Up @@ -136,7 +156,7 @@ const geminiAgentInfo: AgentFunctionInfo = {
author: "Receptron team",
repository: "https://github.com/receptron/graphai",
license: "MIT",
// stream: true,
stream: true,
npms: ["@anthropic-ai/sdk"],
environmentVariables: ["GOOGLE_GENAI_API_KEY"],
};
Expand Down
21 changes: 21 additions & 0 deletions llm_agents/gemini_agent/tests/run_gemini.ts
Original file line number Diff line number Diff line change
Expand Up @@ -50,3 +50,24 @@ test("test gemini", async () => {
}
assert.deepStrictEqual(true, true);
});

// Smoke test for geminiAgent's streaming mode (stream: true).
// Each streamed token is delivered through filterParams.streamTokenCallback,
// which here just logs it; the accumulated response is then logged as well.
// NOTE(review): this is a live-call test — presumably it needs a valid Gemini
// API key in the environment to actually stream; confirm before running in CI.
test("test gemini stream", async () => {
  const namedInputs = { prompt: ["tell me world history"] };
  // stream: true routes the agent through sendMessageStream instead of sendMessage.
  const params = { stream: true };
  const res = (await geminiAgent({
    namedInputs,
    params,
    filterParams: {
      // Invoked once per streamed chunk; side effect only (console output).
      streamTokenCallback: (token: string) => {
        console.log(token);
      },
    },
    debugInfo: { verbose: false, nodeId: "test", retry: 5 },
  })) as any;

  // Inspect the aggregated result: both the OpenAI-style choices[0].message
  // shape and the flattened text field are logged.
  if (res) {
    console.log(res.choices[0].message["content"]);
    console.log(res.text);
  }
  // Vacuous assertion — the test passes as long as the call above does not
  // throw; this matches the convention of the sibling "test gemini" test.
  assert.deepStrictEqual(true, true);
});

0 comments on commit c941949

Please sign in to comment.