Skip to content

Commit

Permalink
support arrays and optional fields in function call schema (#133)
Browse files Browse the repository at this point in the history
Co-authored-by: aoife cassidy <[email protected]>
  • Loading branch information
jb17q and nbsp authored Nov 4, 2024
1 parent 599195d commit 3349462
Show file tree
Hide file tree
Showing 4 changed files with 293 additions and 19 deletions.
6 changes: 6 additions & 0 deletions .changeset/forty-bulldogs-shop.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
---
"@livekit/agents": minor
"@livekit/agents-plugin-openai": minor
---

OpenAI function calling: support arrays and optional fields in function call schema
247 changes: 247 additions & 0 deletions agents/src/llm/function_context.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,247 @@
// SPDX-FileCopyrightText: 2024 LiveKit, Inc.
//
// SPDX-License-Identifier: Apache-2.0
import { describe, expect, it } from 'vitest';
import { z } from 'zod';
import { CallableFunction, oaiParams } from './function_context.js';

describe('function_context', () => {
describe('oaiParams', () => {
it('should handle basic object schema', () => {
const schema = z.object({
name: z.string().describe('The user name'),
age: z.number().describe('The user age'),
});

const result = oaiParams(schema);

expect(result).toEqual({
type: 'object',
properties: {
name: {
type: 'string',
description: 'The user name',
},
age: {
type: 'number',
description: 'The user age',
},
},
required: ['name', 'age'],
});
});

it('should handle enum fields', () => {
const schema = z.object({
color: z.enum(['red', 'blue', 'green']).describe('Choose a color'),
});

const result = oaiParams(schema);

expect(result).toEqual({
type: 'object',
properties: {
color: {
type: 'string',
description: 'Choose a color',
enum: ['red', 'blue', 'green'],
},
},
required: ['color'],
});
});

it('should handle array fields', () => {
const schema = z.object({
tags: z.array(z.string()).describe('List of tags'),
});

const result = oaiParams(schema);

expect(result).toEqual({
type: 'object',
properties: {
tags: {
type: 'array',
description: 'List of tags',
items: {
type: 'string',
},
},
},
required: ['tags'],
});
});

it('should handle array of enums', () => {
const schema = z.object({
colors: z.array(z.enum(['red', 'blue', 'green'])).describe('List of colors'),
});

const result = oaiParams(schema);

expect(result).toEqual({
type: 'object',
properties: {
colors: {
type: 'array',
description: 'List of colors',
items: {
type: 'string',
enum: ['red', 'blue', 'green'],
},
},
},
required: ['colors'],
});
});

it('should handle optional fields', () => {
const schema = z.object({
name: z.string().describe('The user name'),
age: z.number().optional().describe('The user age'),
});

const result = oaiParams(schema);

expect(result).toEqual({
type: 'object',
properties: {
name: {
type: 'string',
description: 'The user name',
},
age: {
type: 'number',
description: 'The user age',
},
},
required: ['name'], // age should not be required
});
});

it('should handle fields without descriptions', () => {
const schema = z.object({
name: z.string(),
age: z.number(),
});

const result = oaiParams(schema);

expect(result).toEqual({
type: 'object',
properties: {
name: {
type: 'string',
description: undefined,
},
age: {
type: 'number',
description: undefined,
},
},
required: ['name', 'age'],
});
});
});

describe('CallableFunction type', () => {
it('should properly type a callable function', async () => {
const schema = z.object({
name: z.string().describe('The user name'),
age: z.number().describe('The user age'),
});

const testFunction: CallableFunction<typeof schema, string> = {
description: 'Test function',
parameters: schema,
execute: async (args: z.infer<typeof schema>) => {
// TypeScript should recognize args.name and args.age
return `${args.name} is ${args.age} years old`;
},
};

const result = await testFunction.execute({ name: 'John', age: 30 });
expect(result).toBe('John is 30 years old');
});

it('should handle async execution', async () => {
const schema = z.object({
delay: z.number().describe('Delay in milliseconds'),
});

const testFunction: CallableFunction<typeof schema, number> = {
description: 'Async test function',
parameters: schema,
execute: async (args: z.infer<typeof schema>) => {
await new Promise((resolve) => setTimeout(resolve, args.delay));
return args.delay;
},
};

const start = Date.now();
const result = await testFunction.execute({ delay: 100 });
const duration = Date.now() - start;

expect(result).toBe(100);
expect(duration).toBeGreaterThanOrEqual(95); // Allow for small timing variations
});

describe('nested array support', () => {
it('should handle nested array fields', () => {
const schema = z.object({
items: z.array(
z.object({
name: z.string().describe('the item name'),
modifiers: z
.array(
z.object({
modifier_name: z.string(),
modifier_value: z.string(),
}),
)
.describe('list of the modifiers applied on this item, such as size'),
}),
),
});
const result = oaiParams(schema);
expect(result).toEqual({
type: 'object',
properties: {
items: {
type: 'array',
description: undefined,
items: {
type: 'object',
properties: {
name: {
type: 'string',
description: 'the item name',
},
modifiers: {
type: 'array',
description: 'list of the modifiers applied on this item, such as size',
items: {
type: 'object',
properties: {
modifier_name: {
type: 'string',
},
modifier_value: {
type: 'string',
},
},
required: ['modifier_name', 'modifier_value'],
},
},
},
required: ['name', 'modifiers'],
},
},
},
required: ['items'],
});
});
});
});
});
57 changes: 39 additions & 18 deletions agents/src/llm/function_context.ts
Original file line number Diff line number Diff line change
Expand Up @@ -34,36 +34,57 @@ export type FunctionContext = {
/** @internal */
export const oaiParams = (p: z.AnyZodObject) => {
const properties: Record<string, any> = {};
const required_properties: string[] = [];
const requiredProperties: string[] = [];

for (const key in p.shape) {
const field = p.shape[key];
const description = field._def.description || undefined;
let type: string;
let enumValues: any[] | undefined;
const processZodType = (field: z.ZodTypeAny): any => {
const isOptional = field instanceof z.ZodOptional;
const nestedField = isOptional ? field._def.innerType : field;
const description = field._def.description;

if (field instanceof z.ZodEnum) {
enumValues = field._def.values;
type = typeof enumValues![0];
if (nestedField instanceof z.ZodEnum) {
return {
type: typeof nestedField._def.values[0],
...(description && { description }),
enum: nestedField._def.values,
};
} else if (nestedField instanceof z.ZodArray) {
const elementType = nestedField._def.type;
return {
type: 'array',
...(description && { description }),
items: processZodType(elementType),
};
} else if (nestedField instanceof z.ZodObject) {
const { properties, required } = oaiParams(nestedField);
return {
type: 'object',
...(description && { description }),
properties,
required,
};
} else {
type = field._def.typeName.toLowerCase();
let type = nestedField._def.typeName.toLowerCase();
type = type.includes('zod') ? type.substring(3) : type;
return {
type,
...(description && { description }),
};
}
};

properties[key] = {
type: type.includes('zod') ? type.substring(3) : type,
description,
enum: enumValues,
};
for (const key in p.shape) {
const field = p.shape[key];
properties[key] = processZodType(field);

if (!field._def.defaultValue) {
required_properties.push(key);
if (!(field instanceof z.ZodOptional)) {
requiredProperties.push(key);
}
}

const type = 'object' as const;
return {
type,
properties,
required_properties,
required: requiredProperties,
};
};
2 changes: 1 addition & 1 deletion plugins/openai/src/realtime/api_proto.ts
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ export interface Tool {
[prop: string]: any;
};
};
required_properties: string[];
required: string[];
};
}

Expand Down

0 comments on commit 3349462

Please sign in to comment.