diff --git a/.changeset/forty-bulldogs-shop.md b/.changeset/forty-bulldogs-shop.md new file mode 100644 index 0000000..0f8729c --- /dev/null +++ b/.changeset/forty-bulldogs-shop.md @@ -0,0 +1,6 @@ +--- +"@livekit/agents": minor +"@livekit/agents-plugin-openai": minor +--- + +OpenAI function calling: support arrays and optional fields in function call schema diff --git a/agents/src/llm/function_context.test.ts b/agents/src/llm/function_context.test.ts new file mode 100644 index 0000000..52a090f --- /dev/null +++ b/agents/src/llm/function_context.test.ts @@ -0,0 +1,247 @@ +// SPDX-FileCopyrightText: 2024 LiveKit, Inc. +// +// SPDX-License-Identifier: Apache-2.0 +import { describe, expect, it } from 'vitest'; +import { z } from 'zod'; +import { CallableFunction, oaiParams } from './function_context.js'; + +describe('function_context', () => { + describe('oaiParams', () => { + it('should handle basic object schema', () => { + const schema = z.object({ + name: z.string().describe('The user name'), + age: z.number().describe('The user age'), + }); + + const result = oaiParams(schema); + + expect(result).toEqual({ + type: 'object', + properties: { + name: { + type: 'string', + description: 'The user name', + }, + age: { + type: 'number', + description: 'The user age', + }, + }, + required: ['name', 'age'], + }); + }); + + it('should handle enum fields', () => { + const schema = z.object({ + color: z.enum(['red', 'blue', 'green']).describe('Choose a color'), + }); + + const result = oaiParams(schema); + + expect(result).toEqual({ + type: 'object', + properties: { + color: { + type: 'string', + description: 'Choose a color', + enum: ['red', 'blue', 'green'], + }, + }, + required: ['color'], + }); + }); + + it('should handle array fields', () => { + const schema = z.object({ + tags: z.array(z.string()).describe('List of tags'), + }); + + const result = oaiParams(schema); + + expect(result).toEqual({ + type: 'object', + properties: { + tags: { + type: 'array', + description: 'List of tags', + items: { + type: 'string', + }, + }, + }, + required: ['tags'], + }); + }); + + it('should handle array of enums', () => { + const schema = z.object({ + colors: z.array(z.enum(['red', 'blue', 'green'])).describe('List of colors'), + }); + + const result = oaiParams(schema); + + expect(result).toEqual({ + type: 'object', + properties: { + colors: { + type: 'array', + description: 'List of colors', + items: { + type: 'string', + enum: ['red', 'blue', 'green'], + }, + }, + }, + required: ['colors'], + }); + }); + + it('should handle optional fields', () => { + const schema = z.object({ + name: z.string().describe('The user name'), + age: z.number().optional().describe('The user age'), + }); + + const result = oaiParams(schema); + + expect(result).toEqual({ + type: 'object', + properties: { + name: { + type: 'string', + description: 'The user name', + }, + age: { + type: 'number', + description: 'The user age', + }, + }, + required: ['name'], // age should not be required + }); + }); + + it('should handle fields without descriptions', () => { + const schema = z.object({ + name: z.string(), + age: z.number(), + }); + + const result = oaiParams(schema); + + expect(result).toEqual({ + type: 'object', + properties: { + name: { + type: 'string', + description: undefined, + }, + age: { + type: 'number', + description: undefined, + }, + }, + required: ['name', 'age'], + }); + }); + }); + + describe('CallableFunction type', () => { + it('should properly type a callable function', async () => { + const schema = z.object({ + name: z.string().describe('The user name'), + age: z.number().describe('The user age'), + }); + + const testFunction: CallableFunction = { + description: 'Test function', + parameters: schema, + execute: async (args: z.infer) => { + // TypeScript should recognize args.name and args.age + return `${args.name} is ${args.age} years old`; + }, + }; + + const result = await testFunction.execute({ name: 'John', age: 30 }); + expect(result).toBe('John is 30 years old'); + }); + + it('should handle async execution', async () => { + const schema = z.object({ + delay: z.number().describe('Delay in milliseconds'), + }); + + const testFunction: CallableFunction = { + description: 'Async test function', + parameters: schema, + execute: async (args: z.infer) => { + await new Promise((resolve) => setTimeout(resolve, args.delay)); + return args.delay; + }, + }; + + const start = Date.now(); + const result = await testFunction.execute({ delay: 100 }); + const duration = Date.now() - start; + + expect(result).toBe(100); + expect(duration).toBeGreaterThanOrEqual(95); // Allow for small timing variations + }); + + describe('nested array support', () => { + it('should handle nested array fields', () => { + const schema = z.object({ + items: z.array( + z.object({ + name: z.string().describe('the item name'), + modifiers: z + .array( + z.object({ + modifier_name: z.string(), + modifier_value: z.string(), + }), + ) + .describe('list of the modifiers applied on this item, such as size'), + }), + ), + }); + const result = oaiParams(schema); + expect(result).toEqual({ + type: 'object', + properties: { + items: { + type: 'array', + description: undefined, + items: { + type: 'object', + properties: { + name: { + type: 'string', + description: 'the item name', + }, + modifiers: { + type: 'array', + description: 'list of the modifiers applied on this item, such as size', + items: { + type: 'object', + properties: { + modifier_name: { + type: 'string', + }, + modifier_value: { + type: 'string', + }, + }, + required: ['modifier_name', 'modifier_value'], + }, + }, + }, + required: ['name', 'modifiers'], + }, + }, + }, + required: ['items'], + }); + }); + }); + }); +}); diff --git a/agents/src/llm/function_context.ts b/agents/src/llm/function_context.ts index af193b7..864242a 100644 --- a/agents/src/llm/function_context.ts +++ b/agents/src/llm/function_context.ts @@ -34,29 +34,50 @@ export type FunctionContext = { /** @internal */ export const oaiParams = (p: z.AnyZodObject) => { const properties: Record = {}; - const required_properties: string[] = []; + const requiredProperties: string[] = []; - for (const key in p.shape) { - const field = p.shape[key]; - const description = field._def.description || undefined; - let type: string; - let enumValues: any[] | undefined; + const processZodType = (field: z.ZodTypeAny): any => { + const isOptional = field instanceof z.ZodOptional; + const nestedField = isOptional ? field._def.innerType : field; + const description = field._def.description; - if (field instanceof z.ZodEnum) { - enumValues = field._def.values; - type = typeof enumValues![0]; + if (nestedField instanceof z.ZodEnum) { + return { + type: typeof nestedField._def.values[0], + ...(description && { description }), + enum: nestedField._def.values, + }; + } else if (nestedField instanceof z.ZodArray) { + const elementType = nestedField._def.type; + return { + type: 'array', + ...(description && { description }), + items: processZodType(elementType), + }; + } else if (nestedField instanceof z.ZodObject) { + const { properties, required } = oaiParams(nestedField); + return { + type: 'object', + ...(description && { description }), + properties, + required, + }; } else { - type = field._def.typeName.toLowerCase(); + let type = nestedField._def.typeName.toLowerCase(); + type = type.includes('zod') ? type.substring(3) : type; + return { + type, + ...(description && { description }), + }; } + }; - properties[key] = { - type: type.includes('zod') ? type.substring(3) : type, - description, - enum: enumValues, - }; + for (const key in p.shape) { + const field = p.shape[key]; + properties[key] = processZodType(field); - if (!field._def.defaultValue) { - required_properties.push(key); + if (!(field instanceof z.ZodOptional)) { + requiredProperties.push(key); } } @@ -64,6 +85,6 @@ export const oaiParams = (p: z.AnyZodObject) => { return { type, properties, - required_properties, + required: requiredProperties, }; }; diff --git a/plugins/openai/src/realtime/api_proto.ts b/plugins/openai/src/realtime/api_proto.ts index f23e93e..9596f80 100644 --- a/plugins/openai/src/realtime/api_proto.ts +++ b/plugins/openai/src/realtime/api_proto.ts @@ -79,7 +79,7 @@ export interface Tool { [prop: string]: any; }; }; - required_properties: string[]; + required: string[]; }; }