diff --git a/.changeset/purple-beds-clean.md b/.changeset/purple-beds-clean.md
new file mode 100644
index 00000000..e746bf37
--- /dev/null
+++ b/.changeset/purple-beds-clean.md
@@ -0,0 +1,7 @@
+---
+"@livekit/agents": patch
+"@livekit/agents-plugin-openai": patch
+"livekit-agents-examples": patch
+---
+
+add ChatContext
diff --git a/agents/src/llm/chat_context.ts b/agents/src/llm/chat_context.ts
new file mode 100644
index 00000000..0043e2e6
--- /dev/null
+++ b/agents/src/llm/chat_context.ts
@@ -0,0 +1,136 @@
+// SPDX-FileCopyrightText: 2024 LiveKit, Inc.
+//
+// SPDX-License-Identifier: Apache-2.0
+import type { AudioFrame, VideoFrame } from '@livekit/rtc-node';
+import type { CallableFunctionResult, FunctionContext } from './function_context.js';
+
+export enum ChatRole {
+  SYSTEM,
+  USER,
+  ASSISTANT,
+  TOOL,
+}
+
+export interface ChatImage {
+  image: string | VideoFrame;
+  inferenceWidth?: number;
+  inferenceHeight?: number;
+  /**
+   * @internal
+   * Used by LLM implementations to store a processed version of the image for later use.
+   */
+  cache: { [id: string | number | symbol]: any };
+}
+
+export interface ChatAudio {
+  frame: AudioFrame | AudioFrame[];
+}
+
+export type ChatContent = string | ChatImage | ChatAudio;
+
+const defaultCreateChatMessage = {
+  text: '',
+  images: [],
+  role: ChatRole.SYSTEM,
+};
+
+export class ChatMessage {
+  readonly role: ChatRole;
+  readonly id?: string;
+  readonly name?: string;
+  readonly content?: ChatContent | ChatContent[];
+  readonly toolCalls?: FunctionContext;
+  readonly toolCallId?: string;
+  readonly toolException?: Error;
+
+  /** @internal */
+  constructor({
+    role,
+    id,
+    name,
+    content,
+    toolCalls,
+    toolCallId,
+    toolException,
+  }: {
+    role: ChatRole;
+    id?: string;
+    name?: string;
+    content?: ChatContent | ChatContent[];
+    toolCalls?: FunctionContext;
+    toolCallId?: string;
+    toolException?: Error;
+  }) {
+    this.role = role;
+    this.id = id;
+    this.name = name;
+    this.content = content;
+    this.toolCalls = toolCalls;
+    this.toolCallId = toolCallId;
+    this.toolException = toolException;
+  }
+
+  static createToolFromFunctionResult(func: CallableFunctionResult): ChatMessage {
+    if (!func.result && !func.error) {
+      throw new TypeError('CallableFunctionResult must include result or error');
+    }
+
+    return new ChatMessage({
+      role: ChatRole.TOOL,
+      name: func.name,
+      content: func.result || `Error: ${func.error}`,
+      toolCallId: func.toolCallId,
+      toolException: func.error,
+    });
+  }
+
+  static createToolCalls(toolCalls: FunctionContext, text = ''): ChatMessage {
+    return new ChatMessage({
+      role: ChatRole.ASSISTANT,
+      toolCalls,
+      content: text,
+    });
+  }
+
+  static create(
+    options: Partial<{
+      text?: string;
+      images: ChatImage[];
+      role: ChatRole;
+    }>,
+  ): ChatMessage {
+    const { text, images, role } = { ...defaultCreateChatMessage, ...options };
+
+    if (!images.length) {
+      return new ChatMessage({
+        role,
+        content: text,
+      });
+    } else {
+      return new ChatMessage({
+        role,
+        content: [...(text ? [text] : []), ...images],
+      });
+    }
+  }
+
+  /** Returns a structured clone of this message. */
+  copy(): ChatMessage {
+    return structuredClone(this);
+  }
+}
+
+export class ChatContext {
+  messages: ChatMessage[] = [];
+  metadata: { [id: string]: any } = {};
+
+  append(msg: { text?: string; images?: ChatImage[]; role: ChatRole }): ChatContext {
+    this.messages.push(ChatMessage.create(msg));
+    return this;
+  }
+
+  /** Returns a structured clone of this context. */
+  copy(): ChatContext {
+    return structuredClone(this);
+  }
+}
diff --git a/agents/src/llm/function_context.ts b/agents/src/llm/function_context.ts
index e6673abb..af193b78 100644
--- a/agents/src/llm/function_context.ts
+++ b/agents/src/llm/function_context.ts
@@ -18,6 +18,14 @@ export interface CallableFunction<P extends z.ZodTypeAny = any, R = any> {
   execute: (args: inferParameters<P>) => PromiseLike<R>;
 }
 
+/** The result of a function call, as returned to the LLM. */
+export interface CallableFunctionResult {
+  name: string;
+  toolCallId: string;
+  result?: any;
+  error?: any;
+}
+
 /** An object containing callable functions and their names */
 export type FunctionContext = {
   [name: string]: CallableFunction;
diff --git a/agents/src/llm/index.ts b/agents/src/llm/index.ts
index 80336ecb..a2672fde 100644
--- a/agents/src/llm/index.ts
+++ b/agents/src/llm/index.ts
@@ -1,11 +1,19 @@
 // SPDX-FileCopyrightText: 2024 LiveKit, Inc.
 //
 // SPDX-License-Identifier: Apache-2.0
-import {
+export {
   type CallableFunction,
+  type CallableFunctionResult,
   type FunctionContext,
   type inferParameters,
   oaiParams,
 } from './function_context.js';
-
-export { CallableFunction, FunctionContext, inferParameters, oaiParams };
+
+export {
+  type ChatImage,
+  type ChatAudio,
+  type ChatContent,
+  ChatRole,
+  ChatMessage,
+  ChatContext,
+} from './chat_context.js';
diff --git a/agents/src/multimodal/multimodal_agent.ts b/agents/src/multimodal/multimodal_agent.ts
index 045ef868..d7cf1395 100644
--- a/agents/src/multimodal/multimodal_agent.ts
+++ b/agents/src/multimodal/multimodal_agent.ts
@@ -64,13 +64,16 @@ export class MultimodalAgent extends EventEmitter {
 
   constructor({
     model,
+    chatCtx,
     fncCtx,
   }: {
     model: RealtimeModel;
-    fncCtx?: llm.FunctionContext | undefined;
+    chatCtx?: llm.ChatContext;
+    fncCtx?: llm.FunctionContext;
   }) {
     super();
     this.model = model;
+    this.#chatCtx = chatCtx;
     this.#fncCtx = fncCtx;
   }
 
@@ -83,6 +86,7 @@
   #logger = log();
   #session: RealtimeSession | null = null;
   #fncCtx: llm.FunctionContext | undefined = undefined;
+  #chatCtx: llm.ChatContext | undefined = undefined;
 
   #_started: boolean = false;
   #_pendingFunctionCalls: Set<string> = new Set();
@@ -200,7 +204,7 @@
         }
       }
 
-      this.#session = this.model.session({ fncCtx: this.#fncCtx });
+      this.#session = this.model.session({ fncCtx: this.#fncCtx, chatCtx: this.#chatCtx });
       this.#started = true;
 
       // eslint-disable-next-line @typescript-eslint/no-explicit-any
diff --git a/examples/src/minimal_assistant.ts b/examples/src/minimal_assistant.ts
index 765f118a..fcfca73d 100644
--- a/examples/src/minimal_assistant.ts
+++ b/examples/src/minimal_assistant.ts
@@ -1,7 +1,7 @@
 // SPDX-FileCopyrightText: 2024 LiveKit, Inc.
 //
 // SPDX-License-Identifier: Apache-2.0
-import { type JobContext, WorkerOptions, cli, defineAgent, multimodal } from '@livekit/agents';
+import { type JobContext, WorkerOptions, cli, defineAgent, llm, multimodal } from '@livekit/agents';
 import * as openai from '@livekit/agents-plugin-openai';
 import { fileURLToPath } from 'node:url';
 import { z } from 'zod';
@@ -52,11 +52,12 @@ export default defineAgent({
       .start(ctx.room, participant)
       .then((session) => session as openai.realtime.RealtimeSession);
 
-    session.conversation.item.create({
-      type: 'message',
-      role: 'user',
-      content: [{ type: 'input_text', text: 'Say "How can I help you today?"' }],
-    });
+    session.conversation.item.create(
+      llm.ChatMessage.create({
+        role: llm.ChatRole.USER,
+        text: 'Say "How can I help you today?"',
+      }),
+    );
     session.response.create();
   },
 });
diff --git a/plugins/openai/src/realtime/realtime_model.ts b/plugins/openai/src/realtime/realtime_model.ts
index 23b16e5a..af33a185 100644
--- a/plugins/openai/src/realtime/realtime_model.ts
+++ b/plugins/openai/src/realtime/realtime_model.ts
@@ -1,7 +1,15 @@
 // SPDX-FileCopyrightText: 2024 LiveKit, Inc.
 //
 // SPDX-License-Identifier: Apache-2.0
-import { AsyncIterableQueue, Future, Queue, llm, log, multimodal } from '@livekit/agents';
+import {
+  AsyncIterableQueue,
+  Future,
+  Queue,
+  llm,
+  log,
+  mergeFrames,
+  multimodal,
+} from '@livekit/agents';
 import { AudioFrame } from '@livekit/rtc-node';
 import { once } from 'node:events';
 import { WebSocket } from 'ws';
@@ -108,6 +116,7 @@ class InputAudioBuffer {
 
 class ConversationItem {
   #session: RealtimeSession;
+  #logger = log();
 
   constructor(session: RealtimeSession) {
     this.#session = session;
@@ -129,12 +138,126 @@
     });
   }
 
-  create(item: api_proto.ConversationItemCreateContent, previousItemId?: string): void {
-    this.#session.queueMsg({
-      type: 'conversation.item.create',
-      item,
-      previous_item_id: previousItemId,
-    });
+  create(message: llm.ChatMessage, previousItemId?: string): void {
+    if (!message.content) {
+      return;
+    }
+
+    let event: api_proto.ConversationItemCreateEvent;
+
+    if (message.toolCallId) {
+      if (typeof message.content !== 'string') {
+        throw new TypeError('message.content must be a string');
+      }
+
+      event = {
+        type: 'conversation.item.create',
+        previous_item_id: previousItemId,
+        item: {
+          type: 'function_call_output',
+          call_id: message.toolCallId,
+          output: message.content,
+        },
+      };
+    } else {
+      let content = message.content;
+      if (!Array.isArray(content)) {
+        content = [content];
+      }
+
+      if (message.role === llm.ChatRole.USER) {
+        const contents: (api_proto.InputTextContent | api_proto.InputAudioContent)[] = [];
+        for (const c of content) {
+          if (typeof c === 'string') {
+            contents.push({
+              type: 'input_text',
+              text: c,
+            });
+          } else if (
+            // TypeScript type guard for determining ChatAudio vs ChatImage
+            ((c: llm.ChatAudio | llm.ChatImage): c is llm.ChatAudio => {
+              return (c as llm.ChatAudio).frame !== undefined;
+            })(c)
+          ) {
+            contents.push({
+              type: 'input_audio',
+              audio: Buffer.from(mergeFrames(c.frame).data.buffer).toString('base64'),
+            });
+          }
+        }
+
+        event = {
+          type: 'conversation.item.create',
+          previous_item_id: previousItemId,
+          item: {
+            type: 'message',
+            role: 'user',
+            content: contents,
+          },
+        };
+      } else if (message.role === llm.ChatRole.ASSISTANT) {
+        const contents: api_proto.TextContent[] = [];
+        for (const c of content) {
+          if (typeof c === 'string') {
+            contents.push({
+              type: 'text',
+              text: c,
+            });
+          } else if (
+            // TypeScript type guard for determining ChatAudio vs ChatImage
+            ((c: llm.ChatAudio | llm.ChatImage): c is llm.ChatAudio => {
+              return (c as llm.ChatAudio).frame !== undefined;
+            })(c)
+          ) {
+            this.#logger.warn('audio content in assistant message is not supported');
+          }
+        }
+
+        event = {
+          type: 'conversation.item.create',
+          previous_item_id: previousItemId,
+          item: {
+            type: 'message',
+            role: 'assistant',
+            content: contents,
+          },
+        };
+      } else if (message.role === llm.ChatRole.SYSTEM) {
+        const contents: api_proto.InputTextContent[] = [];
+        for (const c of content) {
+          if (typeof c === 'string') {
+            contents.push({
+              type: 'input_text',
+              text: c,
+            });
+          } else if (
+            // TypeScript type guard for determining ChatAudio vs ChatImage
+            ((c: llm.ChatAudio | llm.ChatImage): c is llm.ChatAudio => {
+              return (c as llm.ChatAudio).frame !== undefined;
+            })(c)
+          ) {
+            this.#logger.warn('audio content in system message is not supported');
+          }
+        }
+
+        event = {
+          type: 'conversation.item.create',
+          previous_item_id: previousItemId,
+          item: {
+            type: 'message',
+            role: 'system',
+            content: contents,
+          },
+        };
+      } else {
+        this.#logger
+          .child({ message })
+          .warn('chat message is not supported inside the realtime API');
+        return;
+      }
+    }
+
+    this.#session.queueMsg(event);
   }
 }
 
@@ -302,6 +426,7 @@
 
   session({
     fncCtx,
+    chatCtx,
     modalities = this.#defaultOpts.modalities,
     instructions = this.#defaultOpts.instructions,
     voice = this.#defaultOpts.voice,
@@ -313,6 +438,7 @@
     maxResponseOutputTokens = this.#defaultOpts.maxResponseOutputTokens,
   }: {
     fncCtx?: llm.FunctionContext;
+    chatCtx?: llm.ChatContext;
     modalities?: ['text', 'audio'] | ['text'];
     instructions?: string;
     voice?: api_proto.Voice;
@@ -341,7 +467,10 @@
       entraToken: this.#defaultOpts.entraToken,
     };
 
-    const newSession = new RealtimeSession(opts, fncCtx);
+    const newSession = new RealtimeSession(opts, {
+      chatCtx: chatCtx || new llm.ChatContext(),
+      fncCtx,
+    });
     this.#sessions.push(newSession);
     return newSession;
   }
@@ -352,6 +481,7 @@
 }
 
 export class RealtimeSession extends multimodal.RealtimeSession {
+  #chatCtx: llm.ChatContext | undefined = undefined;
   #fncCtx: llm.FunctionContext | undefined = undefined;
   #opts: ModelOptions;
   #pendingResponses: { [id: string]: RealtimeResponse } = {};
@@ -363,10 +493,14 @@
   #closing = true;
   #sendQueue = new Queue<api_proto.ClientEvent>();
 
-  constructor(opts: ModelOptions, fncCtx?: llm.FunctionContext | undefined) {
+  constructor(
+    opts: ModelOptions,
+    { fncCtx, chatCtx }: { fncCtx?: llm.FunctionContext; chatCtx?: llm.ChatContext },
+  ) {
     super();
 
     this.#opts = opts;
+    this.#chatCtx = chatCtx;
     this.#fncCtx = fncCtx;
 
     this.#task = this.#start();
@@ -385,6 +519,10 @@
     });
   }
 
+  get chatCtx(): llm.ChatContext | undefined {
+    return this.#chatCtx;
+  }
+
   get fncCtx(): llm.FunctionContext | undefined {
     return this.#fncCtx;
   }
@@ -869,11 +1007,11 @@
           callId: item.call_id,
         });
         this.conversation.item.create(
-          {
-            type: 'function_call_output',
-            call_id: item.call_id,
-            output: content,
-          },
+          llm.ChatMessage.createToolFromFunctionResult({
+            name: item.name,
+            toolCallId: item.call_id,
+            result: content,
+          }),
           output.itemId,
         );
         this.response.create();
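For reference, a minimal sketch of how the pieces added by this patch fit together. The agent wiring and the `instructions` option are illustrative assumptions, not part of the diff; only `ChatContext`, `ChatMessage`, `ChatRole`, and the new `chatCtx` parameter come from the changes above.

```ts
import { llm, multimodal } from '@livekit/agents';
import * as openai from '@livekit/agents-plugin-openai';

// Build up conversation history with the new ChatContext;
// append() wraps ChatMessage.create() and returns the context for chaining.
const chatCtx = new llm.ChatContext()
  .append({ role: llm.ChatRole.SYSTEM, text: 'You are a helpful assistant.' })
  .append({ role: llm.ChatRole.USER, text: 'Hello there!' });

// chatCtx is forwarded through RealtimeModel.session() into the RealtimeSession,
// alongside the existing fncCtx parameter.
const agent = new multimodal.MultimodalAgent({
  model: new openai.realtime.RealtimeModel({ instructions: 'Be friendly.' }),
  chatCtx,
});
```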