Skip to content

Commit

Permalink
feat: added functionCallingConfig config option for gemini models (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
pavelgj authored Nov 11, 2024
1 parent ea03813 commit c77d14e
Show file tree
Hide file tree
Showing 2 changed files with 95 additions and 5 deletions.
55 changes: 50 additions & 5 deletions js/plugins/googleai/src/gemini.ts
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

import {
FileDataPart,
FunctionCallingMode,
FunctionCallPart,
FunctionDeclaration,
FunctionResponsePart,
Expand All @@ -30,29 +31,30 @@ import {
SchemaType,
StartChatParams,
Tool,
ToolConfig,
} from '@google/generative-ai';
import {
GENKIT_CLIENT_HEADER,
Genkit,
GENKIT_CLIENT_HEADER,
GenkitError,
JSONSchema,
z,
} from 'genkit';
import {
CandidateData,
GenerationCommonConfigSchema,
getBasicUsageStats,
MediaPart,
MessageData,
ModelAction,
ModelInfo,
ModelMiddleware,
modelRef,
ModelReference,
Part,
ToolDefinitionSchema,
ToolRequestPart,
ToolResponsePart,
getBasicUsageStats,
modelRef,
} from 'genkit/model';
import {
downloadRequestMedia,
Expand All @@ -79,6 +81,12 @@ const SafetySettingsSchema = z.object({
export const GeminiConfigSchema = GenerationCommonConfigSchema.extend({
safetySettings: z.array(SafetySettingsSchema).optional(),
codeExecution: z.union([z.boolean(), z.object({}).strict()]).optional(),
functionCallingConfig: z
.object({
mode: z.enum(['MODE_UNSPECIFIED', 'AUTO', 'ANY', 'NONE']).optional(),
allowedFunctionNames: z.array(z.string()).optional(),
})
.optional(),
});

export const gemini10Pro = modelRef({
Expand Down Expand Up @@ -531,7 +539,7 @@ export function defineGoogleAIModel(
if (apiVersion) {
options.baseUrl = baseUrl;
}
const requestConfig = {
const requestConfig: z.infer<typeof GeminiConfigSchema> = {
...defaultConfig,
...request.config,
};
Expand Down Expand Up @@ -565,7 +573,6 @@ export function defineGoogleAIModel(
}

const tools: Tool[] = [];

if (request.tools?.length) {
tools.push({
functionDeclarations: request.tools.map(toGeminiTool),
Expand All @@ -581,6 +588,19 @@ export function defineGoogleAIModel(
});
}

let toolConfig: ToolConfig | undefined;
if (requestConfig.functionCallingConfig) {
toolConfig = {
functionCallingConfig: {
allowedFunctionNames:
requestConfig.functionCallingConfig.allowedFunctionNames,
mode: toGeminiFunctionMode(
requestConfig.functionCallingConfig.mode
),
},
};
}

// cannot use tools with json mode
const jsonMode =
(request.output?.format === 'json' ||
Expand All @@ -605,6 +625,7 @@ export function defineGoogleAIModel(
systemInstruction,
generationConfig,
tools,
toolConfig,
history: messages
.slice(0, -1)
.map((message) => toGeminiMessage(message, model)),
Expand Down Expand Up @@ -668,3 +689,27 @@ export function defineGoogleAIModel(
}
);
}

function toGeminiFunctionMode(
genkitMode: string | undefined
): FunctionCallingMode | undefined {
if (genkitMode === undefined) {
return undefined;
}
switch (genkitMode) {
case 'MODE_UNSPECIFIED': {
return FunctionCallingMode.MODE_UNSPECIFIED;
}
case 'ANY': {
return FunctionCallingMode.ANY;
}
case 'AUTO': {
return FunctionCallingMode.AUTO;
}
case 'NONE': {
return FunctionCallingMode.NONE;
}
default:
throw new Error(`unsupported function calling mode: ${genkitMode}`);
}
}
45 changes: 45 additions & 0 deletions js/plugins/vertexai/src/gemini.ts
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

import {
Content,
FunctionCallingMode,
FunctionDeclaration,
FunctionDeclarationSchemaType,
Part as GeminiPart,
Expand All @@ -25,6 +26,7 @@ import {
HarmBlockThreshold,
HarmCategory,
StartChatParams,
ToolConfig,
VertexAI,
} from '@google-cloud/vertexai';
import { GENKIT_CLIENT_HEADER, Genkit, JSONSchema, z } from 'genkit';
Expand Down Expand Up @@ -71,6 +73,12 @@ export const GeminiConfigSchema = GenerationCommonConfigSchema.extend({
location: z.string().optional(),
vertexRetrieval: VertexRetrievalSchema.optional(),
googleSearchRetrieval: GoogleSearchRetrievalSchema.optional(),
functionCallingConfig: z
.object({
mode: z.enum(['MODE_UNSPECIFIED', 'AUTO', 'ANY', 'NONE']).optional(),
allowedFunctionNames: z.array(z.string()).optional(),
})
.optional(),
});

export const gemini10Pro = modelRef({
Expand Down Expand Up @@ -485,6 +493,18 @@ export function defineGeminiModel(
? [{ functionDeclarations: request.tools?.map(toGeminiTool) }]
: [];

let toolConfig: ToolConfig | undefined;
if (request?.config?.functionCallingConfig) {
toolConfig = {
functionCallingConfig: {
allowedFunctionNames:
request.config.functionCallingConfig.allowedFunctionNames,
mode: toGeminiFunctionMode(
request.config.functionCallingConfig.mode
),
},
};
}
// Cannot use tools and function calling at the same time
const jsonMode =
(request.output?.format === 'json' || !!request.output?.schema) &&
Expand All @@ -493,6 +513,7 @@ export function defineGeminiModel(
const chatRequest: StartChatParams = {
systemInstruction,
tools,
toolConfig,
history: messages
.slice(0, -1)
.map((message) => toGeminiMessage(message, model)),
Expand Down Expand Up @@ -589,3 +610,27 @@ export function defineGeminiModel(
}
);
}

function toGeminiFunctionMode(
genkitMode: string | undefined
): FunctionCallingMode | undefined {
if (genkitMode === undefined) {
return undefined;
}
switch (genkitMode) {
case 'MODE_UNSPECIFIED': {
return FunctionCallingMode.MODE_UNSPECIFIED;
}
case 'ANY': {
return FunctionCallingMode.ANY;
}
case 'AUTO': {
return FunctionCallingMode.AUTO;
}
case 'NONE': {
return FunctionCallingMode.NONE;
}
default:
throw new Error(`unsupported function calling mode: ${genkitMode}`);
}
}

0 comments on commit c77d14e

Please sign in to comment.