diff --git a/apps/web/src/app/(private)/projects/[projectId]/versions/[commitUuid]/documents/[documentUuid]/_components/DocumentEditor/Editor/Playground/Chat.tsx b/apps/web/src/app/(private)/projects/[projectId]/versions/[commitUuid]/documents/[documentUuid]/_components/DocumentEditor/Editor/Playground/Chat.tsx index e4f5b26c8..89e81d95e 100644 --- a/apps/web/src/app/(private)/projects/[projectId]/versions/[commitUuid]/documents/[documentUuid]/_components/DocumentEditor/Editor/Playground/Chat.tsx +++ b/apps/web/src/app/(private)/projects/[projectId]/versions/[commitUuid]/documents/[documentUuid]/_components/DocumentEditor/Editor/Playground/Chat.tsx @@ -136,14 +136,11 @@ export default function Chat({ messagesCount += data.messages!.length } - const isComplete = - data.type === ChainEventTypes.Complete || - data.type === ChainEventTypes.ToolsCalled switch (event) { case StreamEventTypes.Latitude: { if (data.type === ChainEventTypes.StepComplete) { response = '' - } else if (isComplete) { + } else if (data.type === ChainEventTypes.Complete) { setUsage(data.response.usage) setChainLength(messagesCount) setTime(performance.now() - start) diff --git a/packages/constants/src/ai.ts b/packages/constants/src/ai.ts index e98858b51..8a1e5b95c 100644 --- a/packages/constants/src/ai.ts +++ b/packages/constants/src/ai.ts @@ -1,6 +1,7 @@ import { Message, ToolCall } from '@latitude-data/compiler' import { CoreTool, + FinishReason, LanguageModelUsage, ObjectStreamPart, TextStreamPart, @@ -93,14 +94,7 @@ export type ChainEventDto = | { type: ChainEventTypes.Complete config: Config - messages?: Message[] - object?: any - response: ChainEventDtoResponse - uuid?: string - } - | { - type: ChainEventTypes.ToolsCalled - config: Config + finishReason?: FinishReason messages?: Message[] object?: any response: ChainEventDtoResponse @@ -181,14 +175,7 @@ export type LatitudeEventData = messages?: Message[] object?: any response: ChainStepResponse - documentLogUuid?: string - } - | { - type: ChainEventTypes.ToolsCalled - config: Config - messages?: Message[] - object?: any - response: ChainStepResponse + finishReason: FinishReason documentLogUuid?: string } | { diff --git a/packages/core/src/helpers.ts b/packages/core/src/helpers.ts index 97e8d203b..d63ad1f99 100644 --- a/packages/core/src/helpers.ts +++ b/packages/core/src/helpers.ts @@ -7,8 +7,10 @@ import { Message, MessageContent, MessageRole, + ToolCall, ToolRequestContent, } from '@latitude-data/compiler' +import { StreamType } from '@latitude-data/constants' const DEFAULT_OBJECT_TO_STRING_MESSAGE = 'Error: Provider returned an object that could not be stringified' @@ -61,25 +63,53 @@ export function buildCsvFile(csvData: CsvData, name: string): File { return new File([csv], `${name}.csv`, { type: 'text/csv' }) } -export function buildConversation(providerLog: ProviderLogDto) { - let messages: Message[] = [...providerLog.messages] - let message: Message | undefined = undefined - - if (providerLog.response && providerLog.response.length > 0) { - message = { - role: MessageRole.assistant, - content: [ - { - type: ContentType.text, - text: providerLog.response, - }, - ], - toolCalls: [], +type BuildMessageParams = T extends 'object' + ? { + type: 'object' + data?: { + object: any | undefined + text: string | undefined + } } + : { + type: 'text' + data?: { + text: string + toolCalls: ToolCall[] + } + } +export function buildResponseMessage({ + type, + data, +}: BuildMessageParams) { + let message: Message = { + role: MessageRole.assistant, + content: [] as MessageContent[], + toolCalls: [], + } + if (!data) return undefined + + const text = data.text + const object = type === 'object' ? data.object : undefined + const toolCalls = type === 'text' ? (data.toolCalls ?? []) : [] + let content: MessageContent[] = [] + + if (text && text.length > 0) { + content.push({ + type: ContentType.text, + text: text, + }) + } + + if (object) { + content.push({ + type: ContentType.text, + text: objectToString(object), + }) } - if (providerLog.toolCalls.length > 0) { - const content = providerLog.toolCalls.map((toolCall) => { + if (toolCalls.length > 0) { + const toolContents = toolCalls.map((toolCall) => { return { type: ContentType.toolCall, toolCallId: toolCall.id, @@ -88,18 +118,26 @@ export function buildConversation(providerLog: ProviderLogDto) { } as ToolRequestContent }) - if (message) { - message.content = (message.content as MessageContent[]).concat(content) - message.toolCalls = providerLog.toolCalls - } else { - message = { - role: MessageRole.assistant, - content: content, - toolCalls: providerLog.toolCalls, - } - } + message.toolCalls = toolCalls + content = content.concat(toolContents) } + message.content = content + + return content.length > 0 ? message : undefined +} + +export function buildConversation(providerLog: ProviderLogDto) { + let messages: Message[] = [...providerLog.messages] + + const message = buildResponseMessage({ + type: 'text', + data: { + text: providerLog.response, + toolCalls: providerLog.toolCalls, + }, + }) + if (message) { messages.push(message) } diff --git a/packages/core/src/services/chains/ChainStreamConsumer/index.ts b/packages/core/src/services/chains/ChainStreamConsumer/index.ts index 17329fba9..c3ea332de 100644 --- a/packages/core/src/services/chains/ChainStreamConsumer/index.ts +++ b/packages/core/src/services/chains/ChainStreamConsumer/index.ts @@ -14,10 +14,11 @@ import { StreamEventTypes, StreamType, } from '../../../constants' -import { objectToString } from '../../../helpers' +import { buildResponseMessage, objectToString } from '../../../helpers' import { Config } from '../../ai' import { ChainError } from '../ChainErrors' import { ValidatedStep } from '../ChainValidator' +import { FinishReason } from 'ai' export function enqueueChainEvent( controller: ReadableStreamDefaultController, @@ -51,65 +52,29 @@ export class ChainStreamConsumer { response, config, controller, + finishReason, }: { controller: ReadableStreamDefaultController response: ChainStepResponse config: Config + finishReason: FinishReason }) { - let messages: Message[] = [] - let message: Message | undefined = undefined - - if (response.text.length > 0) { - message = { - role: MessageRole.assistant, - content: [ - { - type: ContentType.text, - text: response.text, - }, - ], - toolCalls: [], - } - } - - if (response.streamType === 'object' && response.object) { - message = { - role: MessageRole.assistant, - content: [ - { - type: ContentType.text, - text: objectToString(response.object), - }, - ], - toolCalls: [], - } - } - - if (response.streamType === 'text' && response.toolCalls.length > 0) { - const content = response.toolCalls.map((toolCall) => { - return { - type: ContentType.toolCall, - toolCallId: toolCall.id, - toolName: toolCall.name, - args: toolCall.arguments, - } as ToolRequestContent - }) - - if (message) { - message.content = (message.content as MessageContent[]).concat(content) - message.toolCalls = response.toolCalls - } else { - message = { - role: MessageRole.assistant, - content: content, - toolCalls: response.toolCalls, - } - } - } - - if (message) { - messages.push(message) - } + const type = response.streamType + const message = + type === 'object' + ? buildResponseMessage<'object'>({ + type: 'object', + data: { + object: response.object, + text: response.text, + }, + }) + : type === 'text' + ? buildResponseMessage<'text'>({ + type: 'text', + data: { text: response.text, toolCalls: response.toolCalls }, + }) + : undefined enqueueChainEvent(controller, { event: StreamEventTypes.Latitude, @@ -118,7 +83,8 @@ export class ChainStreamConsumer { config, documentLogUuid: response.documentLogUuid, response, - messages, + messages: message ? [message] : undefined, + finishReason, }, }) @@ -196,14 +162,17 @@ export class ChainStreamConsumer { chainCompleted({ step, response, + finishReason, }: { step: ValidatedStep response: ChainStepResponse + finishReason: FinishReason }) { ChainStreamConsumer.chainCompleted({ controller: this.controller, response, config: step.conversation.config as Config, + finishReason, }) } diff --git a/packages/core/src/services/chains/buildStep/index.ts b/packages/core/src/services/chains/buildStep/index.ts index 6a71adf7f..b55c3e876 100644 --- a/packages/core/src/services/chains/buildStep/index.ts +++ b/packages/core/src/services/chains/buildStep/index.ts @@ -79,8 +79,7 @@ export async function buildStepExecution({ streamConsumer.chainCompleted({ step, response: finalResponse, - // TODO: Add this - finishReason + finishReason, }) return finalResponse