diff --git a/packages/core/src/services/chains/ChainStreamConsumer/index.test.ts b/packages/core/src/services/chains/ChainStreamConsumer/index.test.ts index 8914f97ea..678a290c4 100644 --- a/packages/core/src/services/chains/ChainStreamConsumer/index.test.ts +++ b/packages/core/src/services/chains/ChainStreamConsumer/index.test.ts @@ -134,7 +134,12 @@ describe('ChainStreamConsumer', () => { messages: [ { role: MessageRole.assistant, - content: 'text response', + content: [ + { + type: ContentType.text, + text: 'text response', + }, + ], toolCalls: [], }, ], @@ -165,7 +170,12 @@ describe('ChainStreamConsumer', () => { messages: [ { role: MessageRole.assistant, - content: '{\n "object": "response"\n}', + content: [ + { + type: ContentType.text, + text: '{\n "object": "response"\n}', + }, + ], toolCalls: [], }, ], @@ -250,14 +260,13 @@ describe('ChainStreamConsumer', () => { config: step.conversation.config, response: response, messages: [ - { - role: MessageRole.assistant, - content: 'text response', - toolCalls: [], - }, { role: MessageRole.assistant, content: [ + { + type: ContentType.text, + text: 'text response', + }, { type: ContentType.toolCall, toolCallId: 'tool-call-id', diff --git a/packages/core/src/services/chains/ChainStreamConsumer/index.ts b/packages/core/src/services/chains/ChainStreamConsumer/index.ts index 471d73d31..17329fba9 100644 --- a/packages/core/src/services/chains/ChainStreamConsumer/index.ts +++ b/packages/core/src/services/chains/ChainStreamConsumer/index.ts @@ -1,5 +1,6 @@ import { ContentType, + MessageContent, MessageRole, ToolRequestContent, } from '@latitude-data/compiler' @@ -56,36 +57,58 @@ export class ChainStreamConsumer { config: Config }) { let messages: Message[] = [] + let message: Message | undefined = undefined if (response.text.length > 0) { - messages.push({ + message = { role: MessageRole.assistant, - content: response.text, + content: [ + { + type: ContentType.text, + text: response.text, + }, + ], toolCalls: [], - }) + } } if (response.streamType === 'object' && response.object) { - messages.push({ + message = { role: MessageRole.assistant, - content: objectToString(response.object), + content: [ + { + type: ContentType.text, + text: objectToString(response.object), + }, + ], toolCalls: [], - }) + } } if (response.streamType === 'text' && response.toolCalls.length > 0) { - messages.push({ - role: MessageRole.assistant, - content: response.toolCalls.map((toolCall) => { - return { - type: ContentType.toolCall, - toolCallId: toolCall.id, - toolName: toolCall.name, - args: toolCall.arguments, - } as ToolRequestContent - }), - toolCalls: response.toolCalls, + const content = response.toolCalls.map((toolCall) => { + return { + type: ContentType.toolCall, + toolCallId: toolCall.id, + toolName: toolCall.name, + args: toolCall.arguments, + } as ToolRequestContent }) + + if (message) { + message.content = (message.content as MessageContent[]).concat(content) + message.toolCalls = response.toolCalls + } else { + message = { + role: MessageRole.assistant, + content: content, + toolCalls: response.toolCalls, + } + } + } + + if (message) { + messages.push(message) } enqueueChainEvent(controller, { diff --git a/packages/core/src/services/chains/run.test.ts b/packages/core/src/services/chains/run.test.ts index 1151c838b..d9bf58e9f 100644 --- a/packages/core/src/services/chains/run.test.ts +++ b/packages/core/src/services/chains/run.test.ts @@ -662,14 +662,13 @@ describe('runChain', () => { data: expect.objectContaining({ type: ChainEventTypes.Complete, messages: [ - { - role: MessageRole.assistant, - content: 'assistant message', - toolCalls: [], - }, { role: MessageRole.assistant, content: [ + { + type: ContentType.text, + text: 'assistant message', + }, { type: ContentType.toolCall, toolCallId: 'tool-call-id', diff --git a/packages/core/src/services/commits/runDocumentAtCommit.test.ts b/packages/core/src/services/commits/runDocumentAtCommit.test.ts index c204195d7..e09f7b797 100644 --- a/packages/core/src/services/commits/runDocumentAtCommit.test.ts +++ b/packages/core/src/services/commits/runDocumentAtCommit.test.ts @@ -263,7 +263,12 @@ model: gpt-4o messages: [ { role: 'assistant', - content: [{ type: 'text', text: 'Fake AI generated text' }], + content: [ + { + type: 'text', + text: 'Fake AI generated text', + }, + ], }, { role: 'system', @@ -306,7 +311,12 @@ model: gpt-4o messages: [ { role: 'assistant', - content: 'Fake AI generated text', + content: [ + { + type: 'text', + text: 'Fake AI generated text', + }, + ], toolCalls: [], }, ], diff --git a/packages/core/src/services/documentLogs/addMessages/index.test.ts b/packages/core/src/services/documentLogs/addMessages/index.test.ts index 46d1d2a4a..657b93be7 100644 --- a/packages/core/src/services/documentLogs/addMessages/index.test.ts +++ b/packages/core/src/services/documentLogs/addMessages/index.test.ts @@ -6,6 +6,7 @@ import { LatitudeErrorCodes, RunErrorCodes, } from '@latitude-data/constants/errors' +import { eq } from 'drizzle-orm' import { ChainEventTypes, Commit, @@ -18,8 +19,10 @@ import { User, Workspace, } from '../../../browser' +import { database } from '../../../client' import { Result, TypedResult } from '../../../lib' import { ProviderLogsRepository } from '../../../repositories' +import { providerLogs } from '../../../schema' import { createDocumentLog, createProject } from '../../../tests/factories' import { testConsumeStream } from '../../../tests/helpers' import * as aiModule from '../../ai' @@ -143,7 +146,18 @@ describe('addMessages', () => { expect(result.error).toBeDefined() }) - it('pass arguments to AI service', async () => { + it('pass arguments to AI service with text response', async () => { + providerLog = await database + .update(providerLogs) + .set({ + responseText: 'assistant message', + responseObject: null, + toolCalls: [], + }) + .where(eq(providerLogs.id, providerLog.id)) + .returning() + .then((p) => p[0]!) + const { stream } = await addMessages({ workspace, documentLogUuid: providerLog.documentLogUuid!, @@ -153,7 +167,7 @@ describe('addMessages', () => { content: [ { type: ContentType.text, - text: 'This is a user message', + text: 'user message', }, ], }, @@ -169,7 +183,12 @@ describe('addMessages', () => { ...providerLog.messages, { role: MessageRole.assistant, - content: providerLog.responseText, + content: [ + { + type: ContentType.text, + text: providerLog.responseText, + }, + ], toolCalls: [], }, { @@ -177,7 +196,219 @@ describe('addMessages', () => { content: [ { type: ContentType.text, - text: 'This is a user message', + text: 'user message', + }, + ], + }, + ]), + config: expect.objectContaining({ + model: 'gpt-4o', + provider: 'openai', + }), + provider: expect.any(Object), + }), + ) + }) + + it('pass arguments to AI service with object response', async () => { + providerLog = await database + .update(providerLogs) + .set({ + responseText: null, + responseObject: { object: 'response' }, + toolCalls: [], + }) + .where(eq(providerLogs.id, providerLog.id)) + .returning() + .then((p) => p[0]!) + + const { stream } = await addMessages({ + workspace, + documentLogUuid: providerLog.documentLogUuid!, + messages: [ + { + role: MessageRole.user, + content: [ + { + type: ContentType.text, + text: 'user message', + }, + ], + }, + ], + source: LogSources.API, + }).then((r) => r.unwrap()) + + await testConsumeStream(stream) + + expect(mocks.ai).toHaveBeenCalledWith( + expect.objectContaining({ + messages: expect.arrayContaining([ + ...providerLog.messages, + { + role: MessageRole.assistant, + content: [ + { + type: ContentType.text, + text: '{\n "object": "response"\n}', + }, + ], + toolCalls: [], + }, + { + role: MessageRole.user, + content: [ + { + type: ContentType.text, + text: 'user message', + }, + ], + }, + ]), + config: expect.objectContaining({ + model: 'gpt-4o', + provider: 'openai', + }), + provider: expect.any(Object), + }), + ) + }) + + it('pass arguments to AI service with tool calls response', async () => { + providerLog = await database + .update(providerLogs) + .set({ + responseText: null, + responseObject: null, + toolCalls: [ + { + id: 'tool-call-id', + name: 'tool-call-name', + arguments: { arg1: 'value1', arg2: 'value2' }, + }, + ], + }) + .where(eq(providerLogs.id, providerLog.id)) + .returning() + .then((p) => p[0]!) + + const { stream } = await addMessages({ + workspace, + documentLogUuid: providerLog.documentLogUuid!, + messages: [ + { + role: MessageRole.user, + content: [ + { + type: ContentType.text, + text: 'user message', + }, + ], + }, + ], + source: LogSources.API, + }).then((r) => r.unwrap()) + + await testConsumeStream(stream) + + expect(mocks.ai).toHaveBeenCalledWith( + expect.objectContaining({ + messages: expect.arrayContaining([ + ...providerLog.messages, + { + role: MessageRole.assistant, + content: [ + { + type: ContentType.toolCall, + toolCallId: 'tool-call-id', + toolName: 'tool-call-name', + args: { arg1: 'value1', arg2: 'value2' }, + }, + ], + toolCalls: providerLog.toolCalls, + }, + { + role: MessageRole.user, + content: [ + { + type: ContentType.text, + text: 'user message', + }, + ], + }, + ]), + config: expect.objectContaining({ + model: 'gpt-4o', + provider: 'openai', + }), + provider: expect.any(Object), + }), + ) + }) + + it('pass arguments to AI service with tool calls and text response', async () => { + providerLog = await database + .update(providerLogs) + .set({ + responseText: 'assistant message', + responseObject: null, + toolCalls: [ + { + id: 'tool-call-id', + name: 'tool-call-name', + arguments: { arg1: 'value1', arg2: 'value2' }, + }, + ], + }) + .where(eq(providerLogs.id, providerLog.id)) + .returning() + .then((p) => p[0]!) + + const { stream } = await addMessages({ + workspace, + documentLogUuid: providerLog.documentLogUuid!, + messages: [ + { + role: MessageRole.user, + content: [ + { + type: ContentType.text, + text: 'user message', + }, + ], + }, + ], + source: LogSources.API, + }).then((r) => r.unwrap()) + + await testConsumeStream(stream) + + expect(mocks.ai).toHaveBeenCalledWith( + expect.objectContaining({ + messages: expect.arrayContaining([ + ...providerLog.messages, + { + role: MessageRole.assistant, + content: [ + { + type: ContentType.text, + text: providerLog.responseText, + }, + { + type: ContentType.toolCall, + toolCallId: 'tool-call-id', + toolName: 'tool-call-name', + args: { arg1: 'value1', arg2: 'value2' }, + }, + ], + toolCalls: providerLog.toolCalls, + }, + { + role: MessageRole.user, + content: [ + { + type: ContentType.text, + text: 'user message', }, ], }, @@ -234,7 +465,12 @@ describe('addMessages', () => { messages: [ { role: 'assistant', - content: 'Fake AI generated text', + content: [ + { + type: ContentType.text, + text: 'Fake AI generated text', + }, + ], toolCalls: [], }, ], @@ -405,14 +641,13 @@ describe('addMessages', () => { data: expect.objectContaining({ type: ChainEventTypes.Complete, messages: [ - { - role: MessageRole.assistant, - content: 'assistant message', - toolCalls: [], - }, { role: MessageRole.assistant, content: [ + { + type: ContentType.text, + text: 'assistant message', + }, { type: ContentType.toolCall, toolCallId: 'tool-call-id', diff --git a/packages/core/src/services/providerLogs/addMessages.ts b/packages/core/src/services/providerLogs/addMessages.ts index 11f64bb6d..cee1d7d0f 100644 --- a/packages/core/src/services/providerLogs/addMessages.ts +++ b/packages/core/src/services/providerLogs/addMessages.ts @@ -1,6 +1,7 @@ import { ContentType, Message, + MessageContent, MessageRole, ToolRequestContent, } from '@latitude-data/compiler' @@ -31,37 +32,59 @@ import { } from '../chains/ProviderProcessor/saveOrPublishProviderLogs' function rebuildConversation(providerLog: ProviderLog) { - let messages = providerLog.messages + let messages: Message[] = providerLog.messages + let message: Message | undefined = undefined if (providerLog.responseText && providerLog.responseText.length > 0) { - messages.push({ + message = { role: MessageRole.assistant, - content: providerLog.responseText, + content: [ + { + type: ContentType.text, + text: providerLog.responseText, + }, + ], toolCalls: [], - }) + } } if (providerLog.responseObject) { - messages.push({ + message = { role: MessageRole.assistant, - content: objectToString(providerLog.responseObject), + content: [ + { + type: ContentType.text, + text: objectToString(providerLog.responseObject), + }, + ], toolCalls: [], - }) + } } if (providerLog.toolCalls.length > 0) { - messages.push({ - role: MessageRole.assistant, - content: providerLog.toolCalls.map((toolCall) => { - return { - type: ContentType.toolCall, - toolCallId: toolCall.id, - toolName: toolCall.name, - args: toolCall.arguments, - } as ToolRequestContent - }), - toolCalls: providerLog.toolCalls, + const content = providerLog.toolCalls.map((toolCall) => { + return { + type: ContentType.toolCall, + toolCallId: toolCall.id, + toolName: toolCall.name, + args: toolCall.arguments, + } as ToolRequestContent }) + + if (message) { + message.content = (message.content as MessageContent[]).concat(content) + message.toolCalls = providerLog.toolCalls + } else { + message = { + role: MessageRole.assistant, + content: content, + toolCalls: providerLog.toolCalls, + } + } + } + + if (message) { + messages.push(message) } return messages