diff --git a/apps/gateway/package.json b/apps/gateway/package.json index 3f3a53963..4fe574874 100644 --- a/apps/gateway/package.json +++ b/apps/gateway/package.json @@ -30,7 +30,7 @@ "drizzle-orm": "^0.33.0", "hono": "^4.6.6", "lodash-es": "^4.17.21", - "promptl-ai": "^0.3.5", + "promptl-ai": "^0.4.5", "rate-limiter-flexible": "^5.0.3", "zod": "^3.23.8" }, diff --git a/apps/web/package.json b/apps/web/package.json index 813debe44..a5990de61 100644 --- a/apps/web/package.json +++ b/apps/web/package.json @@ -51,7 +51,7 @@ "oslo": "1.2.0", "pdfjs-dist": "^4.9.155", "posthog-js": "^1.161.6", - "promptl-ai": "^0.3.5", + "promptl-ai": "^0.4.5", "rate-limiter-flexible": "^5.0.3", "react": "19.0.0-rc-5d19e1c8-20240923", "react-dom": "19.0.0-rc-5d19e1c8-20240923", diff --git a/apps/web/src/app/(private)/evaluations/(evaluation)/[evaluationUuid]/editor/_components/Playground/Chat/index.tsx b/apps/web/src/app/(private)/evaluations/(evaluation)/[evaluationUuid]/editor/_components/Playground/Chat/index.tsx index 828c694a3..c58c60366 100644 --- a/apps/web/src/app/(private)/evaluations/(evaluation)/[evaluationUuid]/editor/_components/Playground/Chat/index.tsx +++ b/apps/web/src/app/(private)/evaluations/(evaluation)/[evaluationUuid]/editor/_components/Playground/Chat/index.tsx @@ -176,7 +176,7 @@ export default function Chat({ ) : ( )} diff --git a/apps/web/src/app/(private)/projects/[projectId]/versions/[commitUuid]/documents/[documentUuid]/_components/DocumentEditor/Editor/Playground/Chat.tsx b/apps/web/src/app/(private)/projects/[projectId]/versions/[commitUuid]/documents/[documentUuid]/_components/DocumentEditor/Editor/Playground/Chat.tsx index 8e14284b3..4d0d55dd6 100644 --- a/apps/web/src/app/(private)/projects/[projectId]/versions/[commitUuid]/documents/[documentUuid]/_components/DocumentEditor/Editor/Playground/Chat.tsx +++ b/apps/web/src/app/(private)/projects/[projectId]/versions/[commitUuid]/documents/[documentUuid]/_components/DocumentEditor/Editor/Playground/Chat.tsx @@ -2,9 +2,9 @@ import { useCallback, useContext, useEffect, useRef, useState } from 'react' import { ContentType, - Conversation, Message as ConversationMessage, MessageRole, + ToolMessage, } from '@latitude-data/compiler' import { ChainEventTypes, @@ -31,15 +31,31 @@ import { readStreamableValue } from 'ai/rsc' import { DocumentEditorContext } from '..' import Actions, { ActionsState } from './Actions' +import { useMessages } from './useMessages' +import { type PromptlVersion } from '@latitude-data/web-ui' -export default function Chat({ +function buildMessage({ input }: { input: string | ToolMessage[] }) { + if (typeof input === 'string') { + return [ + { + role: MessageRole.user, + content: [{ type: ContentType.text, text: input }], + } as ConversationMessage, + ] + } + return input +} + +export default function Chat({ document, + promptlVersion, parameters, clearChat, expandParameters, setExpandParameters, }: { document: DocumentVersion + promptlVersion: V parameters: Record clearChat: () => void } & ActionsState) { @@ -62,25 +78,11 @@ export default function Chat({ const runChainOnce = useRef(false) // Index where the chain ends and the chat begins const [chainLength, setChainLength] = useState(Infinity) - const [conversation, setConversation] = useState() const [responseStream, setResponseStream] = useState() const [isStreaming, setIsStreaming] = useState(false) - - const addMessageToConversation = useCallback( - (message: ConversationMessage) => { - let newConversation: Conversation - setConversation((prevConversation) => { - newConversation = { - ...prevConversation, - messages: [...(prevConversation?.messages ?? []), message], - } as Conversation - return newConversation - }) - return newConversation! - }, - [], - ) - + const { messages, addMessages, unresponedToolCalls } = useMessages({ + version: promptlVersion, + }) const startStreaming = useCallback(() => { setError(undefined) setUsage({ @@ -122,7 +124,7 @@ export default function Chat({ if ('messages' in data) { setResponseStream(undefined) - data.messages!.forEach(addMessageToConversation) + addMessages(data.messages ?? []) messagesCount += data.messages!.length } @@ -148,6 +150,7 @@ export default function Chat({ } break } + default: break } @@ -163,7 +166,7 @@ export default function Chat({ commit.uuid, parameters, runDocumentAction, - addMessageToConversation, + addMessages, startStreaming, stopStreaming, ]) @@ -176,14 +179,15 @@ export default function Chat({ }, [runDocument]) const submitUserMessage = useCallback( - async (input: string) => { + async (input: string | ToolMessage[]) => { if (!documentLogUuid) return // This should not happen - const message: ConversationMessage = { - role: MessageRole.user, - content: [{ type: ContentType.text, text: input }], + const newMessages = buildMessage({ input }) + + // Only in Chat mode we add optimistically the message + if (typeof input === 'string') { + addMessages(newMessages) } - addMessageToConversation(message) let response = '' startStreaming() @@ -191,7 +195,7 @@ export default function Chat({ try { const { output } = await addMessagesAction({ documentLogUuid, - messages: [message], + messages: newMessages, }) for await (const serverEvent of readStreamableValue(output)) { @@ -201,7 +205,7 @@ export default function Chat({ if ('messages' in data) { setResponseStream(undefined) - data.messages!.forEach(addMessageToConversation) + addMessages(data.messages ?? []) } switch (event) { @@ -222,6 +226,7 @@ export default function Chat({ } break } + default: break } @@ -235,7 +240,7 @@ export default function Chat({ [ documentLogUuid, addMessagesAction, - addMessageToConversation, + addMessages, startStreaming, stopStreaming, ], @@ -255,24 +260,22 @@ export default function Chat({ className='flex flex-col gap-3 flex-grow flex-shrink min-h-0 custom-scrollbar scrollable-indicator pb-12' > - {(conversation?.messages.length ?? 0) >= chainLength && ( + {(messages.length ?? 0) >= chainLength && ( <> {time && } )} - {(conversation?.messages.length ?? 0) > chainLength && ( + {(messages.length ?? 0) > chainLength && ( <> Chat - + )} {error ? ( @@ -280,7 +283,7 @@ export default function Chat({ ) : ( )} @@ -296,6 +299,8 @@ export default function Chat({ placeholder='Enter followup message...' disabled={isStreaming} onSubmit={submitUserMessage} + toolRequests={unresponedToolCalls} + addMessages={addMessages} /> @@ -342,12 +347,8 @@ export function TokenUsage({ } >
- - {usage?.promptTokens || 0} prompt tokens - - - {usage?.completionTokens || 0} completion tokens - + {usage?.promptTokens || 0} prompt tokens + {usage?.completionTokens || 0} completion tokens
) : ( @@ -359,16 +360,15 @@ export function TokenUsage({ export function StreamMessage({ responseStream, - conversation, + messages, chainLength, }: { responseStream: string | undefined - conversation: Conversation | undefined + messages: ConversationMessage[] chainLength: number }) { if (responseStream === undefined) return null - if (conversation === undefined) return null - if (conversation.messages.length < chainLength - 1) { + if (messages.length < chainLength - 1) { return ( void metadata: ConversationMetadata }) { + const promptlVersion = document.promptlVersion === 1 ? 1 : 0 const [mode, setMode] = useState<'preview' | 'chat'>('preview') const { commit } = useCurrentCommit() const [expanded, setExpanded] = useState(true) @@ -77,6 +76,7 @@ export default function Playground({ ) : ( setMode('preview')} expandParameters={expandParameters} diff --git a/apps/web/src/app/(private)/projects/[projectId]/versions/[commitUuid]/documents/[documentUuid]/_components/DocumentEditor/Editor/Playground/useMessages.ts b/apps/web/src/app/(private)/projects/[projectId]/versions/[commitUuid]/documents/[documentUuid]/_components/DocumentEditor/Editor/Playground/useMessages.ts new file mode 100644 index 000000000..d35e5b6c2 --- /dev/null +++ b/apps/web/src/app/(private)/projects/[projectId]/versions/[commitUuid]/documents/[documentUuid]/_components/DocumentEditor/Editor/Playground/useMessages.ts @@ -0,0 +1,75 @@ +import { + ToolPart, + ToolRequest, + PromptlVersion, + VersionedMessage, + extractToolContents, +} from '@latitude-data/web-ui' +import { Message as CompilerMessage } from '@latitude-data/compiler' +import { useCallback, useState } from 'react' + +function isToolRequest(part: ToolPart): part is ToolRequest { + return 'toolCallId' in part +} + +function getUnrespondedToolRequests({ + version: _, + messages, +}: { + version: V + messages: VersionedMessage[] +}) { + // FIXME: Kill compiler please. I made this module compatible with + // both old compiler and promptl. But because everything is typed with + // old compiler prompts in promptl version are also formatted as old compiler + const parts = extractToolContents({ + version: 0, + messages: messages as VersionedMessage<0>[], + }) + const toolRequestIds = new Set() + const toolResponses = new Set() + + parts.forEach((part) => { + if (isToolRequest(part)) { + toolRequestIds.add(part.toolCallId) + } else { + toolResponses.add(part.id) + } + }) + + return parts.filter( + (part): part is ToolRequest => + isToolRequest(part) && !toolResponses.has(part.toolCallId), + ) +} + +type Props = { version: V } + +export function useMessages({ version }: Props) { + const [messages, setMessages] = useState[]>([]) + const [unresponedToolCalls, setUnresponedToolCalls] = useState( + [], + ) + // FIXME: Kill compiler please + // every where we have old compiler types. To avoid Typescript crying we + // allow only compiler messages and then transform them to versioned messages + const addMessages = useCallback( + (m: CompilerMessage[]) => { + const msg = m as VersionedMessage[] + setMessages((prevMessages) => { + const newMessages = prevMessages.concat(msg) + setUnresponedToolCalls( + getUnrespondedToolRequests({ version, messages: newMessages }), + ) + return newMessages + }) + }, + [version], + ) + + return { + addMessages, + messages: messages as CompilerMessage[], + unresponedToolCalls, + } +} diff --git a/apps/web/src/app/(public)/share/d/[publishedDocumentUuid]/_components/Messages/AllMessages/index.tsx b/apps/web/src/app/(public)/share/d/[publishedDocumentUuid]/_components/Messages/AllMessages/index.tsx index a95ec0933..98eff6eb8 100644 --- a/apps/web/src/app/(public)/share/d/[publishedDocumentUuid]/_components/Messages/AllMessages/index.tsx +++ b/apps/web/src/app/(public)/share/d/[publishedDocumentUuid]/_components/Messages/AllMessages/index.tsx @@ -38,7 +38,7 @@ export function AllMessages({ ) : ( )} diff --git a/apps/web/src/app/(public)/share/d/[publishedDocumentUuid]/_components/Messages/ChatMessages/index.tsx b/apps/web/src/app/(public)/share/d/[publishedDocumentUuid]/_components/Messages/ChatMessages/index.tsx index bcdb46150..314f825e5 100644 --- a/apps/web/src/app/(public)/share/d/[publishedDocumentUuid]/_components/Messages/ChatMessages/index.tsx +++ b/apps/web/src/app/(public)/share/d/[publishedDocumentUuid]/_components/Messages/ChatMessages/index.tsx @@ -30,7 +30,7 @@ export function ChatMessages({ ) : ( )} diff --git a/packages/compiler/src/types/message.ts b/packages/compiler/src/types/message.ts index 66fd6628a..2a372e809 100644 --- a/packages/compiler/src/types/message.ts +++ b/packages/compiler/src/types/message.ts @@ -93,7 +93,7 @@ export type AssistantMessage = { export type ToolMessage = { role: MessageRole.tool - content: ToolContent[] + content: (TextContent | ToolContent)[] [key: string]: unknown } diff --git a/packages/constants/src/ai.ts b/packages/constants/src/ai.ts index 68d3e6e86..f26a1c331 100644 --- a/packages/constants/src/ai.ts +++ b/packages/constants/src/ai.ts @@ -1,6 +1,7 @@ import { Message, ToolCall } from '@latitude-data/compiler' import { CoreTool, + FinishReason, LanguageModelUsage, ObjectStreamPart, TextStreamPart, @@ -93,6 +94,7 @@ export type ChainEventDto = | { type: ChainEventTypes.Complete config: Config + finishReason?: FinishReason messages?: Message[] object?: any response: ChainEventDtoResponse @@ -174,6 +176,7 @@ export type LatitudeChainCompleteEventData = { messages?: Message[] object?: any response: ChainStepResponse + finishReason: FinishReason documentLogUuid?: string } @@ -188,14 +191,12 @@ export type LatitudeEventData = | LatitudeChainCompleteEventData | LatitudeChainErrorEventData -// FIXME: Move to @latitude-data/constants export type RunSyncAPIResponse = { uuid: string conversation: Message[] response: ChainCallResponseDto } -// FIXME: Move to @latitude-data/constants export type ChatSyncAPIResponse = RunSyncAPIResponse export const toolCallResponseSchema = z.object({ diff --git a/packages/core/package.json b/packages/core/package.json index 80e3ae3b7..94c3b9725 100644 --- a/packages/core/package.json +++ b/packages/core/package.json @@ -151,6 +151,6 @@ "dependencies": { "date-fns": "^3.6.0", "diff-match-patch": "^1.0.5", - "promptl-ai": "^0.3.5" + "promptl-ai": "^0.4.5" } } diff --git a/packages/core/src/constants.ts b/packages/core/src/constants.ts index b42b6f8b9..289a23cac 100644 --- a/packages/core/src/constants.ts +++ b/packages/core/src/constants.ts @@ -10,7 +10,7 @@ import { LogSources, ProviderData, } from '@latitude-data/constants' -import { LanguageModelUsage } from 'ai' +import { FinishReason, LanguageModelUsage } from 'ai' import { z } from 'zod' import { DocumentVersion, ProviderLog, Span, Trace } from './browser' @@ -62,6 +62,8 @@ export type StreamType = 'object' | 'text' type BaseResponse = { text: string usage: LanguageModelUsage + finishReason?: FinishReason + chainCompleted?: boolean documentLogUuid?: string providerLog?: ProviderLog } @@ -264,7 +266,7 @@ const toolResultContentSchema = z.object({ type: z.literal('tool-result'), toolCallId: z.string(), toolName: z.string(), - result: z.string(), + result: z.any(), isError: z.boolean().optional(), }) diff --git a/packages/core/src/events/events.d.ts b/packages/core/src/events/events.d.ts index 0f85744ff..ab1e1ef92 100644 --- a/packages/core/src/events/events.d.ts +++ b/packages/core/src/events/events.d.ts @@ -1,4 +1,4 @@ -import { LanguageModelUsage } from 'ai' +import { FinishReason, LanguageModelUsage } from 'ai' import type { ChainStepResponse, @@ -134,6 +134,8 @@ export type StreamCommonData = { messages: Message[] usage: LanguageModelUsage duration: number + chainCompleted: boolean + finishReason: FinishReason } type StreamTextData = { diff --git a/packages/core/src/helpers.ts b/packages/core/src/helpers.ts index ba86c76a3..0140531e5 100644 --- a/packages/core/src/helpers.ts +++ b/packages/core/src/helpers.ts @@ -85,6 +85,7 @@ type BuildMessageParams = T extends 'object' toolCallResponses?: ToolCallResponse[] } } + export function buildResponseMessage({ type, data, @@ -183,50 +184,30 @@ export function buildMessagesFromResponse({ }) : undefined - const previousMessages = response.providerLog - ? response.providerLog.messages - : [] - return [...previousMessages, ...(message ? [message] : [])] + return message ? ([message] as Message[]) : [] +} + +export function buildAllMessagesFromResponse({ + response, +}: { + response: ChainStepResponse +}) { + const previousMessages = response.providerLog?.messages ?? [] + const messages = buildMessagesFromResponse({ response }) + + return [...previousMessages, ...messages] } export function buildConversation(providerLog: ProviderLogDto) { let messages: Message[] = [...providerLog.messages] - let message: Message | undefined = undefined - - if (providerLog.response && providerLog.response.length > 0) { - message = { - role: 'assistant', - content: [ - { - type: 'text', - text: providerLog.response, - }, - ], - toolCalls: [], - } as Message - } - - if (providerLog.toolCalls.length > 0) { - const content = providerLog.toolCalls.map((toolCall) => { - return { - type: 'tool-call', - toolCallId: toolCall.id, - toolName: toolCall.name, - args: toolCall.arguments, - } as ToolRequestContent - }) - if (message) { - message.content = (message.content as MessageContent[]).concat(content) - message.toolCalls = providerLog.toolCalls - } else { - message = { - role: 'assistant', - content: content, - toolCalls: providerLog.toolCalls, - } as Message - } - } + const message = buildResponseMessage({ + type: 'text', + data: { + text: providerLog.response, + toolCalls: providerLog.toolCalls, + }, + }) if (message) { messages.push(message) diff --git a/packages/core/src/jobs/job-definitions/batchEvaluations/runEvaluationJob.test.ts b/packages/core/src/jobs/job-definitions/batchEvaluations/runEvaluationJob.test.ts index 3f698c080..6d25eb5b0 100644 --- a/packages/core/src/jobs/job-definitions/batchEvaluations/runEvaluationJob.test.ts +++ b/packages/core/src/jobs/job-definitions/batchEvaluations/runEvaluationJob.test.ts @@ -125,6 +125,8 @@ describe('runEvaluationJob', () => { usage: { promptTokens: 8, completionTokens: 2, totalTokens: 10 }, documentLogUuid: documentLog.uuid, providerLog: undefined, + finishReason: 'stop', + chainCompleted: true, }) runEvaluationSpy.mockResolvedValueOnce( Result.ok({ @@ -151,6 +153,8 @@ describe('runEvaluationJob', () => { usage: { promptTokens: 8, completionTokens: 2, totalTokens: 10 }, documentLogUuid: documentLog.uuid, providerLog: undefined, + finishReason: 'stop', + chainCompleted: true, }) runEvaluationSpy.mockResolvedValueOnce( Result.ok({ diff --git a/packages/core/src/repositories/documentLogsRepository/index.ts b/packages/core/src/repositories/documentLogsRepository/index.ts index 968c7417b..4825de993 100644 --- a/packages/core/src/repositories/documentLogsRepository/index.ts +++ b/packages/core/src/repositories/documentLogsRepository/index.ts @@ -42,7 +42,11 @@ export class DocumentLogsRepository extends Repository { .$dynamic() } - async findByUuid(uuid: string) { + async findByUuid(uuid: string | undefined) { + if (!uuid) { + return Result.error(new NotFoundError('DocumentLog not found')) + } + const result = await this.scope.where(eq(documentLogs.uuid, uuid)) if (!result.length) { diff --git a/packages/core/src/services/ai/providers/rules/providerMetadata/index.ts b/packages/core/src/services/ai/providers/rules/providerMetadata/index.ts index d9a6bd006..5e34e222e 100644 --- a/packages/core/src/services/ai/providers/rules/providerMetadata/index.ts +++ b/packages/core/src/services/ai/providers/rules/providerMetadata/index.ts @@ -21,6 +21,8 @@ const CONTENT_DEFINED_ATTRIBUTES = [ 'toolCallId', 'toolName', 'args', + // TODO: Add a test for this + 'result', ] as const type AttrArgs = { attributes: string[]; content: Record } diff --git a/packages/core/src/services/chains/ChainStreamConsumer/index.test.ts b/packages/core/src/services/chains/ChainStreamConsumer/index.test.ts index e97ff00fa..d39019499 100644 --- a/packages/core/src/services/chains/ChainStreamConsumer/index.test.ts +++ b/packages/core/src/services/chains/ChainStreamConsumer/index.test.ts @@ -99,6 +99,8 @@ describe('ChainStreamConsumer', () => { toolCalls: [], usage: { promptTokens: 0, completionTokens: 0, totalTokens: 0 }, documentLogUuid: 'errorable-uuid', + finishReason: 'stop', + chainCompleted: false, } consumer.stepCompleted(response) @@ -121,9 +123,27 @@ describe('ChainStreamConsumer', () => { toolCalls: [], usage: { promptTokens: 0, completionTokens: 0, totalTokens: 0 }, documentLogUuid: 'errorable-uuid', + finishReason: 'stop', + chainCompleted: true, } - consumer.chainCompleted({ step, response }) + consumer.chainCompleted({ + step, + response, + finishReason: 'stop', + responseMessages: [ + { + role: MessageRole.assistant, + content: [ + { + type: ContentType.text, + text: 'text response', + }, + ], + toolCalls: [], + }, + ], + }) expect(controller.enqueue).toHaveBeenCalledWith({ event: StreamEventTypes.Latitude, @@ -131,6 +151,7 @@ describe('ChainStreamConsumer', () => { type: ChainEventTypes.Complete, config: step.conversation.config, response: response, + finishReason: 'stop', messages: [ { role: MessageRole.assistant, @@ -157,9 +178,27 @@ describe('ChainStreamConsumer', () => { object: { object: 'response' }, usage: { promptTokens: 0, completionTokens: 0, totalTokens: 0 }, documentLogUuid: 'errorable-uuid', + finishReason: 'stop', + chainCompleted: true, } - consumer.chainCompleted({ step, response }) + consumer.chainCompleted({ + step, + response, + finishReason: 'stop', + responseMessages: [ + { + role: MessageRole.assistant, + content: [ + { + type: ContentType.text, + text: '{\n "object": "response"\n}', + }, + ], + toolCalls: [], + }, + ], + }) expect(controller.enqueue).toHaveBeenCalledWith({ event: StreamEventTypes.Latitude, @@ -167,6 +206,7 @@ describe('ChainStreamConsumer', () => { type: ChainEventTypes.Complete, config: step.conversation.config, response: response, + finishReason: 'stop', messages: [ { role: MessageRole.assistant, @@ -199,9 +239,35 @@ describe('ChainStreamConsumer', () => { ], usage: { promptTokens: 0, completionTokens: 0, totalTokens: 0 }, documentLogUuid: 'errorable-uuid', + finishReason: 'stop', + chainCompleted: true, } - consumer.chainCompleted({ step, response }) + consumer.chainCompleted({ + step, + response, + finishReason: 'stop', + responseMessages: [ + { + role: MessageRole.assistant, + content: [ + { + type: ContentType.toolCall, + toolCallId: 'tool-call-id', + toolName: 'tool-call-name', + args: { arg1: 'value1', arg2: 'value2' }, + }, + ], + toolCalls: [ + { + id: 'tool-call-id', + name: 'tool-call-name', + arguments: { arg1: 'value1', arg2: 'value2' }, + }, + ], + }, + ], + }) expect(controller.enqueue).toHaveBeenCalledWith({ event: StreamEventTypes.Latitude, @@ -209,6 +275,7 @@ describe('ChainStreamConsumer', () => { type: ChainEventTypes.Complete, config: step.conversation.config, response: response, + finishReason: 'stop', messages: [ { role: MessageRole.assistant, @@ -249,15 +316,46 @@ describe('ChainStreamConsumer', () => { ], usage: { promptTokens: 0, completionTokens: 0, totalTokens: 0 }, documentLogUuid: 'errorable-uuid', + finishReason: 'stop', + chainCompleted: true, } - consumer.chainCompleted({ step, response }) + consumer.chainCompleted({ + step, + response, + finishReason: 'stop', + responseMessages: [ + { + role: MessageRole.assistant, + content: [ + { + type: ContentType.text, + text: 'text response', + }, + { + type: ContentType.toolCall, + toolCallId: 'tool-call-id', + toolName: 'tool-call-name', + args: { arg1: 'value1', arg2: 'value2' }, + }, + ], + toolCalls: [ + { + id: 'tool-call-id', + name: 'tool-call-name', + arguments: { arg1: 'value1', arg2: 'value2' }, + }, + ], + }, + ], + }) expect(controller.enqueue).toHaveBeenCalledWith({ event: StreamEventTypes.Latitude, data: { type: ChainEventTypes.Complete, config: step.conversation.config, + finishReason: 'stop', response: response, messages: [ { diff --git a/packages/core/src/services/chains/ChainStreamConsumer/index.ts b/packages/core/src/services/chains/ChainStreamConsumer/index.ts index 2a7f51fb6..9da671b14 100644 --- a/packages/core/src/services/chains/ChainStreamConsumer/index.ts +++ b/packages/core/src/services/chains/ChainStreamConsumer/index.ts @@ -1,9 +1,3 @@ -import { - ContentType, - MessageContent, - MessageRole, - ToolRequestContent, -} from '@latitude-data/compiler' import { RunErrorCodes } from '@latitude-data/constants/errors' import { @@ -14,13 +8,14 @@ import { StreamEventTypes, StreamType, } from '../../../constants' -import { objectToString } from '../../../helpers' import { Config } from '../../ai' import { ChainError } from '../ChainErrors' import { ValidatedChainStep } from '../ChainValidator' import { ValidatedAgentStep } from '../agents/AgentStepValidator' type ValidatedStep = ValidatedChainStep | ValidatedAgentStep +import { FinishReason } from 'ai' +import { buildMessagesFromResponse } from '../../../helpers' export function enqueueChainEvent( controller: ReadableStreamDefaultController, @@ -54,66 +49,15 @@ export class ChainStreamConsumer { response, config, controller, + finishReason, + responseMessages, }: { controller: ReadableStreamDefaultController response: ChainStepResponse config: Config + finishReason: FinishReason + responseMessages: Message[] }) { - let messages: Message[] = [] - let message: Message | undefined = undefined - - if (response.text.length > 0) { - message = { - role: MessageRole.assistant, - content: [ - { - type: ContentType.text, - text: response.text, - }, - ], - toolCalls: [], - } - } - - if (response.streamType === 'object' && response.object) { - message = { - role: MessageRole.assistant, - content: [ - { - type: ContentType.text, - text: objectToString(response.object), - }, - ], - toolCalls: [], - } - } - - if (response.streamType === 'text' && response.toolCalls.length > 0) { - const content = response.toolCalls.map((toolCall) => { - return { - type: ContentType.toolCall, - toolCallId: toolCall.id, - toolName: toolCall.name, - args: toolCall.arguments, - } as ToolRequestContent - }) - - if (message) { - message.content = (message.content as MessageContent[]).concat(content) - message.toolCalls = response.toolCalls - } else { - message = { - role: MessageRole.assistant, - content: content, - toolCalls: response.toolCalls, - } - } - } - - if (message) { - messages.push(message) - } - enqueueChainEvent(controller, { event: StreamEventTypes.Latitude, data: { @@ -121,7 +65,8 @@ export class ChainStreamConsumer { config, documentLogUuid: response.documentLogUuid, response, - messages, + messages: responseMessages, + finishReason, }, }) @@ -199,14 +144,22 @@ export class ChainStreamConsumer { chainCompleted({ step, response, + finishReason, + responseMessages: defaultResponseMessages, }: { step: ValidatedStep response: ChainStepResponse + finishReason: FinishReason + responseMessages?: Message[] }) { - ChainStreamConsumer.chainCompleted({ + const responseMessages = + defaultResponseMessages ?? buildMessagesFromResponse({ response }) + return ChainStreamConsumer.chainCompleted({ controller: this.controller, response, config: step.conversation.config as Config, + finishReason, + responseMessages, }) } diff --git a/packages/core/src/services/chains/ChainValidator/index.ts b/packages/core/src/services/chains/ChainValidator/index.ts index 91cf65c1d..81218149c 100644 --- a/packages/core/src/services/chains/ChainValidator/index.ts +++ b/packages/core/src/services/chains/ChainValidator/index.ts @@ -6,7 +6,7 @@ import { } from '@latitude-data/compiler' import { RunErrorCodes } from '@latitude-data/constants/errors' import { JSONSchema7 } from 'json-schema' -import { Chain as PromptlChain } from 'promptl-ai' +import { Chain as PromptlChain, Message as PromptlMessage } from 'promptl-ai' import { z } from 'zod' import { applyProviderRules, ProviderApiKey, Workspace } from '../../../browser' @@ -33,7 +33,7 @@ export type ConfigOverrides = JSONOverride | { output: 'no-schema' } type ValidatorContext = { workspace: Workspace - prevText: string | undefined + prevContent: Message[] | string | undefined chain: SomeChain promptlVersion: number providersMap: CachedApiKeys @@ -77,24 +77,48 @@ const getOutputType = ({ return configSchema.type === 'array' ? 'array' : 'object' } +/* + * Legacy compiler wants a string as response for the next step + * But new Promptl can handle an array of messages + */ +function getTextFromMessages( + prevContent: Message[] | string | undefined, +): string | undefined { + if (!prevContent) return undefined + if (typeof prevContent === 'string') return prevContent + + try { + return prevContent + .flatMap((message) => { + if (typeof message.content === 'string') return message.content + return JSON.stringify(message.content) + }) + .join('\n') + } catch { + return '' + } +} + const safeChain = async ({ promptlVersion, chain, - prevText, + prevContent, }: { promptlVersion: number chain: SomeChain - prevText: string | undefined + prevContent: Message[] | string | undefined }) => { try { if (promptlVersion === 0) { + let prevText = getTextFromMessages(prevContent) const { completed, conversation } = await (chain as LegacyChain).step( prevText, ) return Result.ok({ chainCompleted: completed, conversation }) } + const { completed, messages, config } = await (chain as PromptlChain).step( - prevText, + prevContent as PromptlMessage[] | undefined | string, ) return Result.ok({ chainCompleted: completed, @@ -170,13 +194,13 @@ export const validateChain = async ( const { workspace, promptlVersion, - prevText, + prevContent, chain, providersMap, configOverrides, removeSchema, } = context - const chainResult = await safeChain({ promptlVersion, chain, prevText }) + const chainResult = await safeChain({ promptlVersion, chain, prevContent }) if (chainResult.error) return chainResult const { chainCompleted, conversation } = chainResult.value diff --git a/packages/core/src/services/chains/ProviderProcessor/index.test.ts b/packages/core/src/services/chains/ProviderProcessor/index.test.ts index 588e69ee2..6f2b728ed 100644 --- a/packages/core/src/services/chains/ProviderProcessor/index.test.ts +++ b/packages/core/src/services/chains/ProviderProcessor/index.test.ts @@ -93,6 +93,8 @@ describe('ProviderProcessor', () => { }, }, startTime: Date.now(), + chainCompleted: true, + finishReason: 'stop', }) expect(result).toEqual({ diff --git a/packages/core/src/services/chains/ProviderProcessor/index.ts b/packages/core/src/services/chains/ProviderProcessor/index.ts index 9e99213f9..97311bb20 100644 --- a/packages/core/src/services/chains/ProviderProcessor/index.ts +++ b/packages/core/src/services/chains/ProviderProcessor/index.ts @@ -12,6 +12,7 @@ import { generateUUIDIdentifier } from '../../../lib' import { AIReturn, PartialConfig } from '../../ai' import { processStreamObject } from './processStreamObject' import { processStreamText } from './processStreamText' +import { FinishReason } from 'ai' async function buildCommonData({ aiResult, @@ -21,6 +22,8 @@ async function buildCommonData({ apiProvider, config, messages, + chainCompleted, + finishReason, errorableUuid, }: { aiResult: Awaited> @@ -30,6 +33,8 @@ async function buildCommonData({ apiProvider: ProviderApiKey config: PartialConfig messages: Message[] + chainCompleted: boolean + finishReason: FinishReason errorableUuid?: string }): Promise { const endTime = Date.now() @@ -53,6 +58,8 @@ async function buildCommonData({ config: config, messages: messages, usage: await aiResult.data.usage, + finishReason, + chainCompleted, } } @@ -67,6 +74,8 @@ export async function processResponse({ apiProvider, config, messages, + finishReason, + chainCompleted, errorableUuid, }: { aiResult: Awaited> @@ -76,6 +85,8 @@ export async function processResponse({ apiProvider: ProviderApiKey config: PartialConfig messages: Message[] + finishReason: FinishReason + chainCompleted: boolean errorableUuid?: string }): Promise> { const commonData = await buildCommonData({ @@ -87,6 +98,8 @@ export async function processResponse({ config, messages, errorableUuid, + finishReason, + chainCompleted, }) if (aiResult.type === 'text') { diff --git a/packages/core/src/services/chains/ProviderProcessor/saveOrPublishProviderLogs.test.ts b/packages/core/src/services/chains/ProviderProcessor/saveOrPublishProviderLogs.test.ts index e838598c1..214f30211 100644 --- a/packages/core/src/services/chains/ProviderProcessor/saveOrPublishProviderLogs.test.ts +++ b/packages/core/src/services/chains/ProviderProcessor/saveOrPublishProviderLogs.test.ts @@ -84,11 +84,17 @@ describe('saveOrPublishProviderLogs', () => { streamType: 'text', saveSyncProviderLogs: true, finishReason: 'stop', + chainCompleted: true, }) expect(publisherSpy).toHaveBeenCalledWith({ type: 'aiProviderCallCompleted', - data: { ...data, streamType: 'text' }, + data: { + ...data, + streamType: 'text', + finishReason: 'stop', + chainCompleted: true, + }, }) }) @@ -99,6 +105,7 @@ describe('saveOrPublishProviderLogs', () => { workspace, saveSyncProviderLogs: true, finishReason: 'stop', + chainCompleted: true, }) expect(createProviderLogSpy).toHaveBeenCalledWith({ @@ -114,6 +121,7 @@ describe('saveOrPublishProviderLogs', () => { streamType: 'text', saveSyncProviderLogs: false, finishReason: 'stop', + chainCompleted: true, workspace, }) diff --git a/packages/core/src/services/chains/ProviderProcessor/saveOrPublishProviderLogs.ts b/packages/core/src/services/chains/ProviderProcessor/saveOrPublishProviderLogs.ts index 44dededde..9906dcf09 100644 --- a/packages/core/src/services/chains/ProviderProcessor/saveOrPublishProviderLogs.ts +++ b/packages/core/src/services/chains/ProviderProcessor/saveOrPublishProviderLogs.ts @@ -16,18 +16,23 @@ export async function saveOrPublishProviderLogs({ streamType, saveSyncProviderLogs, finishReason, + chainCompleted, }: { workspace: Workspace streamType: StreamType data: ReturnType saveSyncProviderLogs: boolean finishReason: FinishReason + chainCompleted: boolean }) { publisher.publishLater({ type: 'aiProviderCallCompleted', - data: { ...data, streamType } as AIProviderCallCompletedData< - typeof streamType - >, + data: { + ...data, + finishReason, + chainCompleted, + streamType, + } as AIProviderCallCompletedData, }) const providerLogsData = { diff --git a/packages/core/src/services/chains/agents/run.ts b/packages/core/src/services/chains/agents/run.ts index b663d80f2..09fc3dd3d 100644 --- a/packages/core/src/services/chains/agents/run.ts +++ b/packages/core/src/services/chains/agents/run.ts @@ -26,7 +26,7 @@ import { StreamType, } from '../../../constants' import { Conversation } from '@latitude-data/compiler' -import { buildMessagesFromResponse, Workspace } from '../../../browser' +import { buildAllMessagesFromResponse, Workspace } from '../../../browser' import { ChainStreamConsumer } from '../ChainStreamConsumer' import { getCachedResponse, setCachedResponse } from '../../commits/promptCache' import { validateAgentStep } from './AgentStepValidator' @@ -97,7 +97,7 @@ export async function runAgent({ chainEventsReader.read().then(({ done, value }) => { if (value) { if (isChainComplete(value)) { - const messages = buildMessagesFromResponse({ + const messages = buildAllMessagesFromResponse({ response: value.data.response, }) conversation = { @@ -206,6 +206,7 @@ async function runAgentStep({ streamConsumer.chainCompleted({ step, response: previousResponse, + finishReason: 'stop', }) return previousResponse @@ -236,6 +237,7 @@ async function runAgentStep({ workspace, streamType: cachedResponse.streamType, finishReason: 'stop', // TODO: we probably should add a cached reason here + chainCompleted: step.chainCompleted, data: buildProviderLogDto({ workspace, source, @@ -258,13 +260,14 @@ async function runAgentStep({ streamConsumer.chainCompleted({ step, response, + finishReason: 'stop', }) return response } else { streamConsumer.stepCompleted(response) - const responseMessages = buildMessagesFromResponse({ response }) + const responseMessages = buildAllMessagesFromResponse({ response }) const nextConversation = { ...conversation, messages: responseMessages, @@ -311,12 +314,15 @@ async function runAgentStep({ source, workspace, startTime: stepStartTime, + chainCompleted: step.chainCompleted, + finishReason: consumedStream.finishReason, }) const providerLog = await saveOrPublishProviderLogs({ workspace, streamType: aiResult.type, finishReason: consumedStream.finishReason, + chainCompleted: step.chainCompleted, data: buildProviderLogDto({ workspace, source, @@ -340,7 +346,7 @@ async function runAgentStep({ streamConsumer.stepCompleted(response) - const responseMessages = buildMessagesFromResponse({ response }) + const responseMessages = buildAllMessagesFromResponse({ response }) const nextConversation = { ...conversation, messages: responseMessages, diff --git a/packages/core/src/services/chains/buildStep/index.ts b/packages/core/src/services/chains/buildStep/index.ts new file mode 100644 index 000000000..bdb7d1293 --- /dev/null +++ b/packages/core/src/services/chains/buildStep/index.ts @@ -0,0 +1,111 @@ +import { Chain as PromptlChain, Message as PromptlMessage } from 'promptl-ai' +import { FinishReason } from 'ai' + +import { ChainStepResponse, StreamType } from '../../../constants' +import { ChainStreamConsumer } from '../ChainStreamConsumer' +import { ValidatedChainStep } from '../ChainValidator' +import { runStep, StepProps } from '../run' +import { + buildProviderLogDto, + saveOrPublishProviderLogs, +} from '../ProviderProcessor/saveOrPublishProviderLogs' +import { cacheChain } from '../chainCache' +import { buildMessagesFromResponse } from '../../../helpers' + +function getToolCalls({ + response, +}: { + response: ChainStepResponse +}) { + const type = response.streamType + if (type === 'object') return [] + + const toolCalls = response.toolCalls ?? [] + + return toolCalls +} + +export async function buildStepExecution({ + streamConsumer, + baseResponse, + step, + stepProps, + providerLogProps: { streamType, finishReason, stepStartTime }, +}: { + streamConsumer: ChainStreamConsumer + baseResponse: ChainStepResponse + step: ValidatedChainStep + stepProps: StepProps + providerLogProps: { + streamType: StreamType + finishReason: FinishReason + stepStartTime: number + } +}) { + const { chain } = stepProps + const workspace = stepProps.workspace + const documentLogUuid = stepProps.errorableUuid + const providerLog = await saveOrPublishProviderLogs({ + workspace, + streamType, + finishReason, + chainCompleted: step.chainCompleted, + data: buildProviderLogDto({ + workspace, + provider: step.provider, + conversation: step.conversation, + source: stepProps.source, + errorableUuid: documentLogUuid, + stepStartTime, + response: baseResponse, + }), + // TODO: temp bugfix, shuold only save last one syncronously + saveSyncProviderLogs: true, + }) + + async function executeStep({ + finalResponse, + }: { + finalResponse: ChainStepResponse + }): Promise> { + const isPromptl = chain instanceof PromptlChain + const toolCalls = getToolCalls({ response: finalResponse }) + const hasTools = isPromptl && toolCalls.length > 0 + const responseMessages = buildMessagesFromResponse({ + response: finalResponse, + }) + + if (hasTools) { + await cacheChain({ + workspace, + chain, + documentLogUuid, + responseMessages: responseMessages as unknown as PromptlMessage[], + }) + } + + if (step.chainCompleted || hasTools) { + streamConsumer.chainCompleted({ + step, + response: finalResponse, + finishReason, + responseMessages, + }) + + return { + ...finalResponse, + finishReason, + chainCompleted: step.chainCompleted, + } + } + + streamConsumer.stepCompleted(finalResponse) + + return runStep({ + ...stepProps, + previousResponse: finalResponse, + }) + } + + return { providerLog, executeStep } +} diff --git a/packages/core/src/services/chains/chainCache/index.ts b/packages/core/src/services/chains/chainCache/index.ts new file mode 100644 index 000000000..b26583bba --- /dev/null +++ b/packages/core/src/services/chains/chainCache/index.ts @@ -0,0 +1,104 @@ +import { hash } from 'crypto' +import { + Chain as PromptlChain, + type SerializedChain, + Message, +} from 'promptl-ai' + +import { Workspace } from '../../../browser' +import { cache } from '../../../cache' + +type CachedChain = { + chain: PromptlChain + messages: Message[] +} + +function generateCacheKey({ + workspace, + documentLogUuid, +}: { + workspace: Workspace + documentLogUuid: string +}): string { + const k = hash('sha256', `${documentLogUuid}`) + return `workspace:${workspace.id}:chain:${k}` +} + +async function setToCache({ + key, + chain, + messages, +}: { + key: string + chain: SerializedChain + messages: Message[] +}) { + try { + const c = await cache() + await c.set(key, JSON.stringify({ chain, messages })) + } catch (e) { + // Silently fail cache writes + } +} + +async function getFromCache(key: string): Promise { + try { + const c = await cache() + const serialized = await c.get(key) + if (!serialized) return undefined + + const deserialized = JSON.parse(serialized) + const chain = PromptlChain.deserialize({ serialized: deserialized.chain }) + + if (!chain || !deserialized.messages) return undefined + return { chain, messages: deserialized.messages ?? [] } + } catch (e) { + return undefined + } +} + +export async function getCachedChain({ + documentLogUuid, + workspace, +}: { + documentLogUuid: string | undefined + workspace: Workspace +}) { + if (!documentLogUuid) return undefined + + const key = generateCacheKey({ documentLogUuid, workspace }) + return await getFromCache(key) +} + +export async function deleteCachedChain({ + documentLogUuid, + workspace, +}: { + documentLogUuid: string + workspace: Workspace +}) { + const key = generateCacheKey({ documentLogUuid, workspace }) + try { + const c = await cache() + await c.del(key) + } catch (e) { + // Silently fail cache writes + } +} + +export async function cacheChain({ + workspace, + documentLogUuid, + chain, + responseMessages, +}: { + workspace: Workspace + chain: PromptlChain + responseMessages: Message[] + documentLogUuid: string +}) { + const key = generateCacheKey({ documentLogUuid, workspace }) + const serialized = chain.serialize() + + await setToCache({ key, chain: serialized, messages: responseMessages }) +} diff --git a/packages/core/src/services/chains/run-errors.test.ts b/packages/core/src/services/chains/run-errors.test.ts index 98a0344e4..e318fff3b 100644 --- a/packages/core/src/services/chains/run-errors.test.ts +++ b/packages/core/src/services/chains/run-errors.test.ts @@ -370,6 +370,8 @@ describe('run chain error handling', () => { streamType: 'text', text: 'MY TEXT', toolCalls: [], + chainCompleted: true, + finishReason: 'stop', usage: { promptTokens: 3, completionTokens: 7, diff --git a/packages/core/src/services/chains/run.ts b/packages/core/src/services/chains/run.ts index 9c7bc89ab..6c4709ede 100644 --- a/packages/core/src/services/chains/run.ts +++ b/packages/core/src/services/chains/run.ts @@ -1,4 +1,4 @@ -import { Chain as LegacyChain } from '@latitude-data/compiler' +import { Chain as LegacyChain, Message } from '@latitude-data/compiler' import { RunErrorCodes } from '@latitude-data/constants/errors' import { Chain as PromptlChain } from 'promptl-ai' @@ -21,13 +21,14 @@ import { createRunError as createRunErrorFn } from '../runErrors/create' import { ChainError } from './ChainErrors' import { ChainStreamConsumer } from './ChainStreamConsumer' import { consumeStream } from './ChainStreamConsumer/consumeStream' -import { ConfigOverrides, validateChain } from './ChainValidator' +import { + ConfigOverrides, + validateChain, + ValidatedChainStep, +} from './ChainValidator' import { checkValidStream } from './checkValidStream' import { processResponse } from './ProviderProcessor' -import { - buildProviderLogDto, - saveOrPublishProviderLogs, -} from './ProviderProcessor/saveOrPublishProviderLogs' +import { buildStepExecution } from './buildStep' export type CachedApiKeys = Map export type SomeChain = LegacyChain | PromptlChain @@ -77,6 +78,8 @@ type CommonArgs = { generateUUID?: () => string persistErrors?: T removeSchema?: boolean + extraMessages?: Message[] + previousCount?: number } export type RunChainArgs< T extends boolean, @@ -98,6 +101,8 @@ export async function runChain({ persistErrors = true, generateUUID = generateUUIDIdentifier, removeSchema = false, + extraMessages, + previousCount = 0, }: RunChainArgs) { const errorableUuid = generateUUID() @@ -121,6 +126,8 @@ export async function runChain({ errorableType, configOverrides, removeSchema, + extraMessages, + previousCount, }) .then((okResponse) => { responseResolve(Result.ok(okResponse)) @@ -147,21 +154,34 @@ export async function runChain({ } } -async function runStep({ - workspace, - source, +function getNextStepCount({ + stepCount, chain, - promptlVersion, - providersMap, - controller, - previousCount = 0, - previousResponse, - errorableUuid, - errorableType, - configOverrides, - removeSchema, - stepCount = 0, + step, }: { + stepCount: number + chain: SomeChain + step: ValidatedChainStep +}) { + const maxSteps = Math.min( + (step.conversation.config[MAX_STEPS_CONFIG_NAME] as number | undefined) ?? + DEFAULT_MAX_STEPS, + ABSOLUTE_MAX_STEPS, + ) + const exceededMaxSteps = + chain instanceof PromptlChain ? stepCount >= maxSteps : stepCount > maxSteps + + if (!exceededMaxSteps) return Result.ok(++stepCount) + + return Result.error( + new ChainError({ + message: stepLimitExceededErrorMessage(maxSteps), + code: RunErrorCodes.MaxStepCountExceededError, + }), + ) +} + +export type StepProps = { workspace: Workspace source: LogSources chain: SomeChain @@ -172,11 +192,34 @@ async function runStep({ errorableType?: ErrorableEntity previousCount?: number previousResponse?: ChainStepResponse - configOverrides?: ConfigOverrides removeSchema?: boolean stepCount?: number -}) { - const prevText = previousResponse?.text + configOverrides?: ConfigOverrides + extraMessages?: Message[] +} + +export async function runStep({ + workspace, + source, + chain, + promptlVersion, + providersMap, + controller, + previousCount = 0, + previousResponse, + errorableUuid, + errorableType, + configOverrides, + extraMessages, + removeSchema, + stepCount = 0, +}: StepProps) { + // When passed extra messages it means we are resuming a conversation + const prevContent = previousResponse + ? previousResponse.text + : extraMessages + ? extraMessages + : undefined const streamConsumer = new ChainStreamConsumer({ controller, previousCount, @@ -186,7 +229,7 @@ async function runStep({ try { const step = await validateChain({ workspace, - prevText, + prevContent, chain, promptlVersion, providersMap, @@ -194,30 +237,20 @@ async function runStep({ removeSchema, }).then((r) => r.unwrap()) + const messages = step.conversation.messages + if (chain instanceof PromptlChain && step.chainCompleted) { streamConsumer.chainCompleted({ step, response: previousResponse!, + finishReason: previousResponse?.finishReason ?? 'stop', }) + previousResponse!.chainCompleted = true return previousResponse! } - const maxSteps = Math.min( - (step.conversation.config[MAX_STEPS_CONFIG_NAME] as number | undefined) ?? - DEFAULT_MAX_STEPS, - ABSOLUTE_MAX_STEPS, - ) - const exceededMaxSteps = - chain instanceof PromptlChain - ? stepCount >= maxSteps - : stepCount > maxSteps - if (exceededMaxSteps) { - throw new ChainError({ - message: stepLimitExceededErrorMessage(maxSteps), - code: RunErrorCodes.MaxStepCountExceededError, - }) - } + const nextStepCount = getNextStepCount({ stepCount, chain, step }).unwrap() const { messageCount, stepStartTime } = streamConsumer.setup(step) @@ -228,39 +261,17 @@ async function runStep({ }) if (cachedResponse) { - const providerLog = await saveOrPublishProviderLogs({ - workspace, - streamType: cachedResponse.streamType, - finishReason: 'stop', // TODO: we probably should add a cached reason here - data: buildProviderLogDto({ - workspace, - source, - provider: step.provider, - conversation: step.conversation, + const { providerLog, executeStep } = await buildStepExecution({ + baseResponse: cachedResponse as ChainStepResponse, + step, + streamConsumer, + providerLogProps: { + streamType: cachedResponse.streamType, + // NOTE: Before we were hardcoding `stop` when cached. + finishReason: cachedResponse.finishReason ?? 'stop', stepStartTime, - errorableUuid, - response: cachedResponse as ChainStepResponse, - }), - saveSyncProviderLogs: true, // TODO: temp bugfix, it should only save last one syncronously - }) - - const response = { - ...cachedResponse, - providerLog, - documentLogUuid: errorableUuid, - } as ChainStepResponse - - if (step.chainCompleted) { - streamConsumer.chainCompleted({ - step, - response, - }) - - return response - } else { - streamConsumer.stepCompleted(response) - - return runStep({ + }, + stepProps: { workspace, source, chain, @@ -270,28 +281,37 @@ async function runStep({ errorableUuid, errorableType, previousCount: messageCount, - previousResponse: response, + previousResponse: cachedResponse as ChainStepResponse, configOverrides, - stepCount: ++stepCount, - }) - } + stepCount: nextStepCount, + }, + }) + + const finalResponse = { + ...cachedResponse, + providerLog, + documentLogUuid: errorableUuid, + } as ChainStepResponse + + return executeStep({ finalResponse }) } const aiResult = await ai({ - messages: step.conversation.messages, + messages, config: step.config, provider: step.provider, schema: step.schema, output: step.output, }).then((r) => r.unwrap()) - const checkResult = checkValidStream(aiResult) + if (checkResult.error) throw checkResult.error const consumedStream = await consumeStream({ controller, result: aiResult, }) + if (consumedStream.error) throw consumedStream.error const _response = await processResponse({ @@ -299,48 +319,24 @@ async function runStep({ apiProvider: step.provider, config: step.config, errorableUuid, - messages: step.conversation.messages, + messages, source, workspace, startTime: stepStartTime, - }) - - const providerLog = await saveOrPublishProviderLogs({ - workspace, - streamType: aiResult.type, finishReason: consumedStream.finishReason, - data: buildProviderLogDto({ - workspace, - source, - provider: step.provider, - conversation: step.conversation, - stepStartTime, - errorableUuid, - response: _response, - }), - saveSyncProviderLogs: true, // TODO: temp bugfix, shuold only save last one syncronously - }) - - const response = { ..._response, providerLog } - - await setCachedResponse({ - workspace, - config: step.config, - conversation: step.conversation, - response, + chainCompleted: step.chainCompleted, }) - if (step.chainCompleted) { - streamConsumer.chainCompleted({ - step, - response, - }) - - return response - } else { - streamConsumer.stepCompleted(response) - - return runStep({ + const { providerLog, executeStep } = await buildStepExecution({ + step, + streamConsumer, + baseResponse: _response, + providerLogProps: { + streamType: aiResult.type, + finishReason: consumedStream.finishReason, + stepStartTime, + }, + stepProps: { workspace, source, chain, @@ -350,11 +346,21 @@ async function runStep({ providersMap, controller, previousCount: messageCount, - previousResponse: response, + previousResponse: _response, configOverrides, - stepCount: ++stepCount, - }) - } + stepCount: nextStepCount, + }, + }) + + const finalResponse = { ..._response, providerLog } + await setCachedResponse({ + workspace, + config: step.config, + conversation: step.conversation, + response: finalResponse, + }) + + return executeStep({ finalResponse }) } catch (e: unknown) { const error = streamConsumer.chainError(e) throw error diff --git a/packages/core/src/services/commits/runDocumentAtCommit.test.ts b/packages/core/src/services/commits/runDocumentAtCommit.test.ts index e09f7b797..0d0d3a334 100644 --- a/packages/core/src/services/commits/runDocumentAtCommit.test.ts +++ b/packages/core/src/services/commits/runDocumentAtCommit.test.ts @@ -289,6 +289,7 @@ model: gpt-4o type: 'chain-step-complete', documentLogUuid: expect.any(String), response: { + chainCompleted: true, documentLogUuid: expect.any(String), providerLog: expect.any(Object), streamType: 'text', @@ -304,6 +305,7 @@ model: gpt-4o data: { type: 'chain-complete', documentLogUuid: expect.any(String), + finishReason: 'stop', config: { provider: 'openai', model: 'gpt-4o', @@ -323,6 +325,7 @@ model: gpt-4o response: { streamType: 'text', text: 'Fake AI generated text', + chainCompleted: true, toolCalls: [], usage: { promptTokens: 0, completionTokens: 0, totalTokens: 0 }, providerLog: logs[logs.length - 1], diff --git a/packages/core/src/services/commits/runDocumentAtCommit.ts b/packages/core/src/services/commits/runDocumentAtCommit.ts index 868e28681..ceeb91a02 100644 --- a/packages/core/src/services/commits/runDocumentAtCommit.ts +++ b/packages/core/src/services/commits/runDocumentAtCommit.ts @@ -17,7 +17,7 @@ import { getResolvedContent } from '../documents' import { buildProvidersMap } from '../providerApiKeys/buildMap' import { RunDocumentChecker } from './RunDocumentChecker' -export async function createDocumentRunResult({ +async function createDocumentRunResult({ workspace, document, commit, @@ -43,6 +43,7 @@ export async function createDocumentRunResult({ response?: ChainStepResponse }) { const durantionInMs = duration ?? 0 + if (publishEvent) { publisher.publishLater({ type: 'documentRun', diff --git a/packages/core/src/services/documentLogs/addMessages/findPausedChain.ts b/packages/core/src/services/documentLogs/addMessages/findPausedChain.ts new file mode 100644 index 000000000..249d3da6a --- /dev/null +++ b/packages/core/src/services/documentLogs/addMessages/findPausedChain.ts @@ -0,0 +1,43 @@ +import { Workspace } from '../../../browser' +import { Result } from '../../../lib' +import { + CommitsRepository, + DocumentLogsRepository, + DocumentVersionsRepository, +} from '../../../repositories' +import { getCachedChain } from '../../chains/chainCache' + +export async function findPausedChain({ + workspace, + documentLogUuid, +}: { + workspace: Workspace + documentLogUuid: string | undefined +}) { + const cachedData = await getCachedChain({ workspace, documentLogUuid }) + if (!cachedData) return undefined + + const logsRepo = new DocumentLogsRepository(workspace.id) + const logResult = await logsRepo.findByUuid(documentLogUuid) + if (logResult.error) return logResult + + const documentLog = logResult.value + const commitsRepo = new CommitsRepository(workspace.id) + const commitResult = await commitsRepo.find(documentLog.commitId) + if (commitResult.error) return commitResult + + const commit = commitResult.value + const documentsRepo = new DocumentVersionsRepository(workspace.id) + const result = await documentsRepo.getDocumentByUuid({ + commitUuid: commit?.uuid, + documentUuid: documentLog.documentUuid, + }) + if (result.error) return result + + return Result.ok({ + document: result.value, + commit, + pausedChainMessages: cachedData.messages, + pausedChain: cachedData.chain, + }) +} diff --git a/packages/core/src/services/documentLogs/addMessages/index.test.ts b/packages/core/src/services/documentLogs/addMessages/index.test.ts index 657b93be7..7bf2e6161 100644 --- a/packages/core/src/services/documentLogs/addMessages/index.test.ts +++ b/packages/core/src/services/documentLogs/addMessages/index.test.ts @@ -458,6 +458,7 @@ describe('addMessages', () => { data: { type: 'chain-complete', documentLogUuid: providerLog.documentLogUuid!, + finishReason: 'stop', config: { provider: 'openai', model: 'gpt-4o', diff --git a/packages/core/src/services/documentLogs/addMessages/index.ts b/packages/core/src/services/documentLogs/addMessages/index.ts index 87f4b4a9d..1d29a2a42 100644 --- a/packages/core/src/services/documentLogs/addMessages/index.ts +++ b/packages/core/src/services/documentLogs/addMessages/index.ts @@ -1,8 +1,34 @@ import type { Message } from '@latitude-data/compiler' import { LogSources, Workspace } from '../../../browser' -import { ProviderLogsRepository } from '../../../repositories' import { addMessages as addMessagesProviderLog } from '../../providerLogs/addMessages' +import { ProviderLogsRepository } from '../../../repositories' +import { findPausedChain } from './findPausedChain' +import { resumeConversation } from './resumeConversation' +import { Result } from '../../../lib' + +type CommonProps = { + workspace: Workspace + documentLogUuid: string | undefined + messages: Message[] + source: LogSources +} + +async function addMessagesToCompleteChain({ + workspace, + documentLogUuid, + messages, + source, +}: CommonProps) { + const providerLogRepo = new ProviderLogsRepository(workspace.id) + const providerLogResult = + await providerLogRepo.findLastByDocumentLogUuid(documentLogUuid) + if (providerLogResult.error) return providerLogResult + + const providerLog = providerLogResult.value + + return addMessagesProviderLog({ workspace, providerLog, messages, source }) +} export async function addMessages({ workspace, @@ -15,12 +41,35 @@ export async function addMessages({ messages: Message[] source: LogSources }) { - const providerLogRepo = new ProviderLogsRepository(workspace.id) - const providerLogResult = - await providerLogRepo.findLastByDocumentLogUuid(documentLogUuid) - if (providerLogResult.error) return providerLogResult + if (!documentLogUuid) { + return Result.error(new Error('documentLogUuid is required')) + } - const providerLog = providerLogResult.value + const foundResult = await findPausedChain({ workspace, documentLogUuid }) - return addMessagesProviderLog({ workspace, providerLog, messages, source }) + if (!foundResult) { + // No chain cached found means normal chat behavior + return addMessagesToCompleteChain({ + workspace, + documentLogUuid, + messages, + source, + }) + } + + if (foundResult.error) return foundResult + + const pausedChainData = foundResult.value + + return resumeConversation({ + workspace, + commit: pausedChainData.commit, + document: pausedChainData.document, + pausedChain: pausedChainData.pausedChain, + pausedChainMessages: + pausedChainData.pausedChainMessages as unknown as Message[], + documentLogUuid, + responseMessages: messages, + source, + }) } diff --git a/packages/core/src/services/documentLogs/addMessages/resumeConversation/index.ts b/packages/core/src/services/documentLogs/addMessages/resumeConversation/index.ts new file mode 100644 index 000000000..504053dce --- /dev/null +++ b/packages/core/src/services/documentLogs/addMessages/resumeConversation/index.ts @@ -0,0 +1,100 @@ +import { + Commit, + ErrorableEntity, + DocumentVersion, + LogSources, + Workspace, +} from '../../../../browser' +import { Result } from '../../../../lib' +import { runChain } from '../../../chains/run' +import { buildProvidersMap } from '../../../providerApiKeys/buildMap' +import { Message } from '@latitude-data/compiler' +import { Chain as PromptlChain } from 'promptl-ai' +import { getResolvedContent } from '../../../documents' +import { deleteCachedChain } from '../../../chains/chainCache' + +/** + * What means resume a converstation + * :::::::::::::::::::: + * When a paused/cached chain is found in our cache (Redis at the time of writing), + * we asume is a paused and incompleted conversation. + * To resume it we re-run it passing `extraMessages` that will be passed down to prompt + * as response from the previous run. + * + * One use case is tool calling: + * This is helpful for tool calling. Allow users to preprare tool responses when + * AI returns a tool call. And Latitude add the tool request and tool response to the + * conversation so next time AI runs have all the info necesary to continue the conversation. + */ +export async function resumeConversation({ + workspace, + document, + commit, + documentLogUuid, + pausedChain, + pausedChainMessages, + responseMessages, + source, +}: { + workspace: Workspace + commit: Commit + document: DocumentVersion + documentLogUuid: string + pausedChain: PromptlChain + pausedChainMessages: Message[] + responseMessages: Message[] + source: LogSources +}) { + const resultResolvedContent = await getResolvedContent({ + workspaceId: workspace.id, + document, + commit, + }) + + if (resultResolvedContent.error) return resultResolvedContent + + const resolvedContent = resultResolvedContent.value + const errorableType = ErrorableEntity.DocumentLog + const providersMap = await buildProvidersMap({ + workspaceId: workspace.id, + }) + const errorableUuid = documentLogUuid + + let extraMessages = pausedChainMessages.concat(responseMessages) + // These are all the messages that the client + // already have seen. So we don't want to send them again. + const previousCount = + pausedChain.globalMessagesCount + + pausedChainMessages.length + + responseMessages.length + + const run = await runChain({ + generateUUID: () => errorableUuid, + errorableType, + workspace, + chain: pausedChain, + promptlVersion: document.promptlVersion, + providersMap, + source, + extraMessages, + previousCount, + }) + + return Result.ok({ + stream: run.stream, + duration: run.duration, + resolvedContent, + errorableUuid, + response: run.response.then(async (response) => { + const isCompleted = response.value?.chainCompleted + + if (isCompleted) { + // We delete cached chain so next time someone add a message to + // this documentLogUuid it will simple add the message to the conversation. + await deleteCachedChain({ workspace, documentLogUuid }) + } + + return response + }), + }) +} diff --git a/packages/core/src/services/providerLogs/addMessages.ts b/packages/core/src/services/providerLogs/addMessages.ts index e74d4389d..3c4c433e0 100644 --- a/packages/core/src/services/providerLogs/addMessages.ts +++ b/packages/core/src/services/providerLogs/addMessages.ts @@ -3,6 +3,7 @@ import { RunErrorCodes } from '@latitude-data/constants/errors' import { buildConversation, + buildMessagesFromResponse, ChainEvent, ChainStepResponse, LogSources, @@ -149,6 +150,8 @@ async function iterate({ apiProvider: provider, messages, startTime: stepStartTime, + finishReason: consumedStream.finishReason, + chainCompleted: true, }) const providerLog = await saveOrPublishProviderLogs({ @@ -165,10 +168,12 @@ async function iterate({ response: _response, }), saveSyncProviderLogs: true, + chainCompleted: true, }) const response = { ..._response, providerLog } + const responseMessages = buildMessagesFromResponse({ response }) ChainStreamConsumer.chainCompleted({ controller, response, @@ -177,6 +182,8 @@ async function iterate({ provider: provider.name, model: config.model, }, + finishReason: consumedStream.finishReason, + responseMessages, }) return response diff --git a/packages/sdks/typescript/package.json b/packages/sdks/typescript/package.json index 58b8c2e4b..996166020 100644 --- a/packages/sdks/typescript/package.json +++ b/packages/sdks/typescript/package.json @@ -36,12 +36,12 @@ } }, "dependencies": { - "@t3-oss/env-core": "^0.10.1", "@latitude-data/telemetry": "workspace:^", + "@t3-oss/env-core": "^0.10.1", "eventsource-parser": "^2.0.1", "node-fetch": "3.3.2", - "zod": "^3.23.8", - "promptl-ai": "^0.3.5" + "promptl-ai": "^0.4.5", + "zod": "^3.23.8" }, "devDependencies": { "@latitude-data/compiler": "workspace:^", diff --git a/packages/sdks/typescript/src/index.ts b/packages/sdks/typescript/src/index.ts index 1e32af172..4f59f52df 100644 --- a/packages/sdks/typescript/src/index.ts +++ b/packages/sdks/typescript/src/index.ts @@ -8,6 +8,7 @@ import { type ChainEventDto, type DocumentLog, type EvaluationResultDto, + type ToolCallResponse, } from '@latitude-data/constants' import { @@ -358,7 +359,8 @@ class Latitude { if (logResponses) { await this.logs.create( prompt.path, - step.messages as unknown as Message[], // Inexistent types incompatibilities between legacy messages and promptl messages + // Inexistent types incompatibilities between legacy messages and promptl messages + step.messages as unknown as Message[], ) } @@ -449,4 +451,5 @@ export type { MessageRole, Options, StreamChainResponse, + ToolCallResponse, } diff --git a/packages/web-ui/package.json b/packages/web-ui/package.json index b531fd953..da1b24a8d 100644 --- a/packages/web-ui/package.json +++ b/packages/web-ui/package.json @@ -21,6 +21,7 @@ "dependencies": { "@latitude-data/compiler": "workspace:^", "@latitude-data/core": "workspace:^", + "@latitude-data/constants": "workspace:^", "@radix-ui/react-checkbox": "^1.1.2", "@radix-ui/react-progress": "^1.1.0", "date-fns": "^3.6.0", diff --git a/packages/web-ui/src/ds/atoms/ClientOnly/index.tsx b/packages/web-ui/src/ds/atoms/ClientOnly/index.tsx index e4bced3a7..6b70ec422 100644 --- a/packages/web-ui/src/ds/atoms/ClientOnly/index.tsx +++ b/packages/web-ui/src/ds/atoms/ClientOnly/index.tsx @@ -2,7 +2,13 @@ import { ReactNode, useEffect, useState } from 'react' -export function ClientOnly({ children }: { children: ReactNode }) { +export function ClientOnly({ + children, + className, +}: { + children: ReactNode + className?: string +}) { const [mounted, setMounted] = useState(false) useEffect(() => { setMounted(true) @@ -12,6 +18,7 @@ export function ClientOnly({ children }: { children: ReactNode }) { // We have a Hydration issue with the inputs because // they come from localStorage and are not available on the server if (!mounted) return null + if (!className) return <>{children} - return <>{children} + return
{children}
} diff --git a/packages/web-ui/src/ds/atoms/CodeBlock/index.tsx b/packages/web-ui/src/ds/atoms/CodeBlock/index.tsx index c1f585a7c..0a3895506 100644 --- a/packages/web-ui/src/ds/atoms/CodeBlock/index.tsx +++ b/packages/web-ui/src/ds/atoms/CodeBlock/index.tsx @@ -29,6 +29,12 @@ export function CodeBlock(props: CodeBlockProps) { ) } +export function useCodeBlockBackgroundColor() { + const { resolvedTheme } = useTheme() + if (resolvedTheme === CurrentTheme.Light) return 'bg-backgroundCode' + return 'bg-[#282c34]' +} + function Content({ language, children, @@ -36,16 +42,11 @@ function Content({ className, }: CodeBlockProps) { const { resolvedTheme } = useTheme() - + const bgColor = useCodeBlockBackgroundColor() return ( -
+
{copy && ( -
+
)} diff --git a/packages/web-ui/src/ds/molecules/Chat/ChatTextArea/ToolBar/index.tsx b/packages/web-ui/src/ds/molecules/Chat/ChatTextArea/ToolBar/index.tsx new file mode 100644 index 000000000..c09ff58e1 --- /dev/null +++ b/packages/web-ui/src/ds/molecules/Chat/ChatTextArea/ToolBar/index.tsx @@ -0,0 +1,24 @@ +import { Button } from '../../../../atoms' + +export function ToolBar({ + onSubmit, + clearChat, + disabled = false, + submitLabel = 'Send Message', +}: { + onSubmit?: () => void + clearChat?: () => void + disabled?: boolean + submitLabel?: string +}) { + return ( +
+ + +
+ ) +} diff --git a/packages/web-ui/src/ds/molecules/Chat/ChatTextArea/ToolCallForm/Editor/index.tsx b/packages/web-ui/src/ds/molecules/Chat/ChatTextArea/ToolCallForm/Editor/index.tsx new file mode 100644 index 000000000..1e9d3ecdb --- /dev/null +++ b/packages/web-ui/src/ds/molecules/Chat/ChatTextArea/ToolCallForm/Editor/index.tsx @@ -0,0 +1,179 @@ +import { useCallback, useEffect, useRef, useState } from 'react' +import Editor, { Monaco } from '@monaco-editor/react' +import { type editor } from 'monaco-editor' +import { useMonacoSetup } from '../../../../DocumentTextEditor/Editor/useMonacoSetup' +import { TextEditorProps } from './types' + +const LINE_HEIGHT = 18 +const CONTAINER_GUTTER = 10 +function useUpdateEditorHeight({ initialHeight }: { initialHeight: number }) { + const [heightState, setHeight] = useState(initialHeight) + const prevLineCount = useRef(0) + const updateHeight = useCallback((editor: editor.IStandaloneCodeEditor) => { + const el = editor.getDomNode() + if (!el) return + const codeContainer = el.getElementsByClassName( + 'view-lines', + )[0] as HTMLDivElement | null + + if (!codeContainer) return + + setTimeout(() => { + const height = + codeContainer.childElementCount > prevLineCount.current + ? codeContainer.offsetHeight + : codeContainer.childElementCount * LINE_HEIGHT + CONTAINER_GUTTER // fold + prevLineCount.current = codeContainer.childElementCount + + // Max height + if (height >= 200) return + + setHeight(height) + el.style.height = height + 'px' + + editor.layout() + }, 0) + }, []) + return { height: heightState, updateHeight } +} + +function updatePlaceholder({ + collection, + decoration, + editor, + hasPlaceholder, +}: { + collection: editor.IEditorDecorationsCollection + decoration: editor.IModelDeltaDecoration + editor: editor.IStandaloneCodeEditor + monaco: Monaco + hasPlaceholder: boolean +}) { + const model = editor.getModel() + if (!model) return false + + const modelValue = model.getValue() + + if ((modelValue === '' || modelValue === ' ') && !hasPlaceholder) { + collection.append([decoration]) + return true + } else if (hasPlaceholder) { + collection.clear() + return false + } + + return false +} + +export default function TextEditor({ + value, + onChange, + onCmdEnter, + placeholder, +}: TextEditorProps) { + const editorRef = useRef(null) + const { monacoRef, handleEditorWillMount } = useMonacoSetup() + const isMountedRef = useRef(false) + const { height, updateHeight } = useUpdateEditorHeight({ initialHeight: 0 }) + const handleEditorDidMount = useCallback( + (editor: editor.IStandaloneCodeEditor, monaco: Monaco) => { + monacoRef.current = monaco + editorRef.current = editor + let hasPlaceholder = false + const collection = editor.createDecorationsCollection([]) + const decoration = { + range: new monaco.Range(1, 1, 1, 1), + options: { + isWholeLine: true, + className: 'hack-monaco-editor-placeholder', + }, + } + editor.focus() + + // Initial placeholder setup + hasPlaceholder = updatePlaceholder({ + collection, + decoration, + editor, + monaco, + hasPlaceholder, + }) + + editor.onDidChangeModelContent(() => { + if (!isMountedRef.current) return + + hasPlaceholder = updatePlaceholder({ + collection, + decoration, + editor, + monaco, + hasPlaceholder, + }) + }) + + editor.onDidChangeModelDecorations(() => { + updateHeight(editor) + }) + + editor.addCommand(monaco.KeyMod.CtrlCmd | monaco.KeyCode.Enter, () => { + onCmdEnter?.(editor.getValue()) + }) + + isMountedRef.current = true + }, + [updateHeight, isMountedRef, monacoRef, placeholder, onCmdEnter, onChange], + ) + + // Refresh onCmdEnter prop callback when it changes + useEffect(() => { + if (!editorRef.current) return + if (!isMountedRef.current) return + if (!monacoRef.current) return + + const monaco = monacoRef.current + const editor = editorRef.current + editor.addCommand(monaco.KeyMod.CtrlCmd | monaco.KeyCode.Enter, () => { + onCmdEnter?.(editor.getValue()) + }) + }, [onCmdEnter]) + return ( + <> + + + + ) +} diff --git a/packages/web-ui/src/ds/molecules/Chat/ChatTextArea/ToolCallForm/Editor/types.ts b/packages/web-ui/src/ds/molecules/Chat/ChatTextArea/ToolCallForm/Editor/types.ts new file mode 100644 index 000000000..05eb25e68 --- /dev/null +++ b/packages/web-ui/src/ds/molecules/Chat/ChatTextArea/ToolCallForm/Editor/types.ts @@ -0,0 +1,6 @@ +export type TextEditorProps = { + value?: string + onChange: (value: string | undefined) => void + onCmdEnter?: (value?: string | undefined) => void + placeholder: string +} diff --git a/packages/web-ui/src/ds/molecules/Chat/ChatTextArea/ToolCallForm/index.tsx b/packages/web-ui/src/ds/molecules/Chat/ChatTextArea/ToolCallForm/index.tsx new file mode 100644 index 000000000..83b712a8d --- /dev/null +++ b/packages/web-ui/src/ds/molecules/Chat/ChatTextArea/ToolCallForm/index.tsx @@ -0,0 +1,241 @@ +'use client' + +import { MouseEvent, lazy, useState, useCallback, useMemo } from 'react' +import { ToolMessage, Message } from '@latitude-data/compiler' +import { + Badge, + Icon, + Text, + CodeBlock, + Tooltip, + ClientOnly, +} from '../../../../atoms' +import { ToolRequest } from '../../../../../lib/versionedMessagesHelpers' +import { buildResponseMessage } from '@latitude-data/core/browser' +import { ToolBar } from '../ToolBar' + +const TextEditor = lazy(() => import('./Editor/index')) + +function generateExampleFunctionCall(toolCall: ToolRequest) { + const args = toolCall.toolArguments + const functionName = toolCall.toolName + const formattedArgs = Object.keys(args).length + ? JSON.stringify(args, null, 2) + : '' + + return `${functionName}(${formattedArgs ? formattedArgs : ''})` +} + +function buildToolResponseMessage({ + value, + toolRequest, +}: { + value: string + toolRequest: ToolRequest +}) { + const toolResponse = { + id: toolRequest.toolCallId, + name: toolRequest.toolName, + result: value, + } + const message = buildResponseMessage<'text'>({ + type: 'text', + data: { + text: undefined, + toolCallResponses: [toolResponse], + }, + }) + return message! as ToolMessage +} + +function ToolEditor({ + value, + toolRequest, + onChange, + onSubmit, + placeholder, + currentToolRequest, + totalToolRequests, +}: { + toolRequest: ToolRequest + placeholder: string + value: string | undefined + currentToolRequest: number + totalToolRequests: number + onChange: (value: string | undefined) => void + onSubmit?: (sentValue?: string | undefined) => void +}) { + const functionCall = useMemo( + () => generateExampleFunctionCall(toolRequest), + [toolRequest], + ) + return ( +
+
+ + + You have{' '} + {totalToolRequests <= 1 ? ( + <> + {totalToolRequests} tool to + be responded + + ) : ( + <> + {currentToolRequest} of{' '} + {totalToolRequests} tools to + be responded + + )} + + +
+ } + > + The Assistant has responded with a Tool Request. In your server, this + would mean to run the function from your code and send back the + result. However, since we cannot access your code execution from the + Playground, you must write the expected response when using this tool. + +
+ + {functionCall} + +
+
+ Your tool response + +
+
+ ) +} + +function getValue(value: string | MouseEvent | undefined) { + if (typeof value === 'string') return value.trim() + return undefined +} + +export function ToolCallForm({ + toolRequests, + sendToServer, + addLocalMessages, + clearChat, + disabled, + placeholder, +}: { + toolRequests: ToolRequest[] + placeholder: string + addLocalMessages: (messages: Message[]) => void + sendToServer?: (messages: ToolMessage[]) => void + clearChat?: () => void + disabled?: boolean +}) { + const [currentToolRequestIndex, setCurrentToolRequestIndex] = useState(1) + const [currentToolRequest, setCurrentToolRequest] = useState< + ToolRequest | undefined + >(toolRequests[0]) + const [respondedToolRequests, setRespondedToolRequests] = useState< + ToolMessage[] + >([]) + const [value, setValue] = useState('') + const onChange = useCallback((newValue: string | undefined) => { + setValue(newValue) + }, []) + const onLocalSend = useCallback( + (sentValue?: MouseEvent | string | undefined) => { + const cleanValue = getValue(value || sentValue) + + if (!currentToolRequest) return + if (!cleanValue) return + + const message = buildToolResponseMessage({ + value: cleanValue, + toolRequest: currentToolRequest, + }) + setRespondedToolRequests((prev) => [...prev, message]) + addLocalMessages([message]) + const findNextToolRequest = toolRequests.findIndex( + (tr) => tr.toolCallId === currentToolRequest.toolCallId, + ) + const nextIndex = findNextToolRequest + 1 + const nextToolRequest = toolRequests[nextIndex] + + if (nextToolRequest) { + setCurrentToolRequest(nextToolRequest) + setCurrentToolRequestIndex(nextIndex) + } + setValue('') + }, + [ + currentToolRequest, + addLocalMessages, + value, + setCurrentToolRequest, + setCurrentToolRequestIndex, + setRespondedToolRequests, + toolRequests, + ], + ) + + const onServerSend = useCallback( + (sentValue?: string | undefined) => { + const cleanValue = getValue(value || sentValue) + + if (!currentToolRequest) return + if (!cleanValue) return + + const message = buildToolResponseMessage({ + value: cleanValue, + toolRequest: currentToolRequest, + }) + addLocalMessages([message]) + const allToolMessages = respondedToolRequests.concat([message]) + sendToServer?.(allToolMessages) + }, + [ + addLocalMessages, + respondedToolRequests, + sendToServer, + currentToolRequest, + value, + ], + ) + + if (!currentToolRequest) return null + + const isLastRequest = + currentToolRequest.toolCallId === + toolRequests[toolRequests.length - 1]?.toolCallId + + const onSubmitHandler = isLastRequest ? onServerSend : onLocalSend + + return ( + + +
+ +
+
+ ) +} diff --git a/packages/web-ui/src/ds/molecules/Chat/ChatTextArea/index.tsx b/packages/web-ui/src/ds/molecules/Chat/ChatTextArea/index.tsx index 8ad04fa17..b66d56e62 100644 --- a/packages/web-ui/src/ds/molecules/Chat/ChatTextArea/index.tsx +++ b/packages/web-ui/src/ds/molecules/Chat/ChatTextArea/index.tsx @@ -4,12 +4,16 @@ import { KeyboardEvent, useCallback, useState } from 'react' import TextareaAutosize from 'react-textarea-autosize' -import { Button } from '../../../atoms' +import { ToolRequest } from '../../../../lib/versionedMessagesHelpers' +import { ToolCallForm } from './ToolCallForm' +import { Message, ToolMessage } from '@latitude-data/compiler' +import { cn } from '../../../../lib/utils' +import { ToolBar } from './ToolBar' -export function ChatTextArea({ +function SimpleTextArea({ placeholder, clearChat, - onSubmit, + onSubmit: onSubmitProp, disabled = false, }: { placeholder: string @@ -18,26 +22,23 @@ export function ChatTextArea({ onSubmit?: (value: string) => void }) { const [value, setValue] = useState('') - - const handleSubmit = useCallback(() => { + const onSubmit = useCallback(() => { if (disabled) return if (value === '') return setValue('') - onSubmit?.(value) - }, [value, onSubmit, disabled]) - + onSubmitProp?.(value) + }, [value, onSubmitProp, disabled]) const handleKeyDown = useCallback( (e: KeyboardEvent) => { if (e.key === 'Enter' && !e.shiftKey) { e.preventDefault() - handleSubmit() + onSubmit() } }, - [handleSubmit], + [onSubmit], ) - return ( -
+ <> -
- - +
+
+ + ) +} + +type OnSubmitWithTools = (value: string | ToolMessage[]) => void +type OnSubmit = (value: string) => void +export function ChatTextArea({ + placeholder, + clearChat, + onSubmit, + disabled = false, + toolRequests = [], + addMessages, +}: { + placeholder: string + clearChat: () => void + disabled?: boolean + onSubmit?: OnSubmit | OnSubmitWithTools + addMessages?: (messages: Message[]) => void + toolRequests?: ToolRequest[] +}) { + return ( +
0, + }, + )} + > + {toolRequests.length > 0 && addMessages ? ( + + ) : ( + + )}
) } diff --git a/packages/web-ui/src/ds/molecules/DocumentTextEditor/Editor/RegularEditor.tsx b/packages/web-ui/src/ds/molecules/DocumentTextEditor/Editor/RegularEditor.tsx index 8a23ae538..4298ef9e7 100644 --- a/packages/web-ui/src/ds/molecules/DocumentTextEditor/Editor/RegularEditor.tsx +++ b/packages/web-ui/src/ds/molecules/DocumentTextEditor/Editor/RegularEditor.tsx @@ -34,7 +34,8 @@ export function RegularMonacoEditor({ const { monacoRef, handleEditorWillMount } = useMonacoSetup({ errorFixFn }) const [defaultValue, _] = useState(value) - const [isEditorMounted, setIsEditorMounted] = useState(false) // to avoid race conditions + // to avoid race conditions + const [isEditorMounted, setIsEditorMounted] = useState(false) const { options } = useEditorOptions({ tabSize: 2, readOnly: !!readOnlyMessage, diff --git a/packages/web-ui/src/ds/molecules/DocumentTextEditor/Editor/useMonacoSetup.ts b/packages/web-ui/src/ds/molecules/DocumentTextEditor/Editor/useMonacoSetup.ts index 063b7762d..417fcc74b 100644 --- a/packages/web-ui/src/ds/molecules/DocumentTextEditor/Editor/useMonacoSetup.ts +++ b/packages/web-ui/src/ds/molecules/DocumentTextEditor/Editor/useMonacoSetup.ts @@ -41,6 +41,7 @@ export function useMonacoSetup({ useEffect(() => { if (!monacoRef.current) return + applyTheme(monacoRef.current) }, [applyTheme]) diff --git a/packages/web-ui/src/index.ts b/packages/web-ui/src/index.ts index 258ef642c..a774e8b4c 100644 --- a/packages/web-ui/src/index.ts +++ b/packages/web-ui/src/index.ts @@ -10,4 +10,5 @@ export * from './lib/getUserInfo' export * from './lib/hooks/useAutoScroll' export * from './lib/hooks/useLocalStorage' export * from './lib/commonTypes' +export * from './lib/versionedMessagesHelpers' export * from './ds/atoms/ChartBlankSlate' diff --git a/packages/web-ui/src/lib/versionedMessagesHelpers.ts b/packages/web-ui/src/lib/versionedMessagesHelpers.ts new file mode 100644 index 000000000..6b2aa554d --- /dev/null +++ b/packages/web-ui/src/lib/versionedMessagesHelpers.ts @@ -0,0 +1,85 @@ +import { + Message as CompilerMessage, + ContentType as CompilerContentType, + MessageRole as CompilerMessageRole, + ToolRequestContent as CompilerToolRequestContent, +} from '@latitude-data/compiler' +import { ToolCallResponse as ToolResponse } from '@latitude-data/constants' +import { + Message as PromptlMessage, + ToolCallContent as ToolRequest, + ContentType as PromptlContentType, + MessageRole as PromptlMessageRole, +} from 'promptl-ai' + +export type PromptlVersion = 0 | 1 +export type VersionedMessage = V extends 0 + ? CompilerMessage + : PromptlMessage +type ToolResponsePart = Pick + +export type ToolPart = ToolRequest | ToolResponsePart + +function extractCompilerToolContents(messages: CompilerMessage[]): ToolPart[] { + return messages.flatMap((message) => { + if (message.role === CompilerMessageRole.tool) { + return message.content + .filter((content) => { + return content.type === CompilerContentType.toolResult + }) + .map((content) => ({ + id: content.toolCallId, + })) + } + + if (message.role !== CompilerMessageRole.assistant) return [] + if (typeof message.content === 'string') return [] + + const content = Array.isArray(message.content) + ? message.content + : [message.content] + + const toolRequestContents = content.filter((content) => { + return content.type === CompilerContentType.toolCall + }) as CompilerToolRequestContent[] + + return toolRequestContents.map((content) => ({ + type: PromptlContentType.toolCall, + toolCallId: content.toolCallId, + toolName: content.toolName, + toolArguments: content.args, + })) + }) +} + +function extractPromptlToolContents(messages: PromptlMessage[]): ToolPart[] { + return messages.flatMap((message) => { + if (message.role === PromptlMessageRole.tool) { + return [{ id: message.toolId }] + } + + if (message.role !== PromptlMessageRole.assistant) return [] + if (typeof message.content === 'string') return [] + + const content = Array.isArray(message.content) + ? message.content + : [message.content] + return content.filter((content) => { + return content.type === PromptlContentType.toolCall + }) + }) +} + +export function extractToolContents({ + version, + messages, +}: { + version: V + messages: VersionedMessage[] +}) { + return version === 0 + ? extractCompilerToolContents(messages as CompilerMessage[]) + : extractPromptlToolContents(messages as PromptlMessage[]) +} + +export type { ToolRequest } diff --git a/pnpm-lock.yaml b/pnpm-lock.yaml index 6b071821a..01801575e 100644 --- a/pnpm-lock.yaml +++ b/pnpm-lock.yaml @@ -260,8 +260,8 @@ importers: specifier: ^4.17.21 version: 4.17.21 promptl-ai: - specifier: ^0.3.5 - version: 0.3.5 + specifier: ^0.4.5 + version: 0.4.5 rate-limiter-flexible: specifier: ^5.0.3 version: 5.0.4 @@ -427,8 +427,8 @@ importers: specifier: ^1.161.6 version: 1.176.0 promptl-ai: - specifier: ^0.3.5 - version: 0.3.5 + specifier: ^0.4.5 + version: 0.4.5 rate-limiter-flexible: specifier: ^5.0.3 version: 5.0.4 @@ -773,8 +773,8 @@ importers: specifier: ^1.0.5 version: 1.0.5 promptl-ai: - specifier: ^0.3.5 - version: 0.3.5 + specifier: ^0.4.5 + version: 0.4.5 devDependencies: '@ai-sdk/anthropic': specifier: ^1.0.5 @@ -993,8 +993,8 @@ importers: specifier: 3.3.2 version: 3.3.2 promptl-ai: - specifier: ^0.3.5 - version: 0.3.5 + specifier: ^0.4.5 + version: 0.4.5 typescript: specifier: ^5.5.4 version: 5.7.2 @@ -1186,6 +1186,9 @@ importers: '@latitude-data/compiler': specifier: workspace:^ version: link:../compiler + '@latitude-data/constants': + specifier: workspace:^ + version: link:../constants '@latitude-data/core': specifier: workspace:^ version: link:../core @@ -11276,6 +11279,9 @@ packages: promptl-ai@0.3.5: resolution: {integrity: sha512-g3RtXgCOtMky68rkdalGGApO85DBZvL65gJy7SzZ+D+eEpCB2hR+7q2G67y+s66AHDzG2DM+sU+1WTRbDCrfyg==} + promptl-ai@0.4.5: + resolution: {integrity: sha512-bJkWlLtyxJkxt28+InTtKF8e4OZ67J9QtkOnI9xKBeEv6ILRFHcU09quTbCFi0wCL+6lw6okL2tV1yix21Jf6g==} + prop-types@15.8.1: resolution: {integrity: sha512-oj87CgZICdulUohogVAR7AjlC0327U4el4L6eAvOqCeudMDVU0NThNaV+b9Df4dXgSP1gXMTnPdhfe/2qDH5cg==} @@ -12575,10 +12581,6 @@ packages: resolution: {integrity: sha512-4dbzIzqvjtgiM5rw1k5rEHtBANKmdudhGyBEajN01fEyhaAIhsoKNy6y7+IN93IfpFtwY9iqi7kD+xwKhQsNJA==} engines: {node: '>=8'} - type-fest@4.26.1: - resolution: {integrity: sha512-yOGpmOAL7CkKe/91I5O3gPICmJNLJ1G4zFYVAsRHg7M64biSnPtRj0WNQt++bRkjYOqjWXrhnUw1utzmVErAdg==} - engines: {node: '>=16'} - type-fest@4.30.0: resolution: {integrity: sha512-G6zXWS1dLj6eagy6sVhOMQiLtJdxQBHIA9Z6HFUNLOlr6MFOgzV8wvmidtPONfPtEUv0uZsy77XJNzTAfwPDaA==} engines: {node: '>=16'} @@ -19635,7 +19637,7 @@ snapshots: '@types/body-parser@1.19.5': dependencies: '@types/connect': 3.4.38 - '@types/node': 22.10.1 + '@types/node': 20.17.10 '@types/bytes@3.1.4': {} @@ -19654,7 +19656,7 @@ snapshots: '@types/connect@3.4.38': dependencies: - '@types/node': 22.10.1 + '@types/node': 20.17.10 '@types/cookie-parser@1.4.7': dependencies: @@ -19708,7 +19710,7 @@ snapshots: '@types/express-serve-static-core@4.19.6': dependencies: - '@types/node': 22.10.1 + '@types/node': 20.17.10 '@types/qs': 6.9.17 '@types/range-parser': 1.2.7 '@types/send': 0.17.4 @@ -19908,12 +19910,12 @@ snapshots: '@types/send@0.17.4': dependencies: '@types/mime': 1.3.5 - '@types/node': 22.10.1 + '@types/node': 20.17.10 '@types/serve-static@1.15.7': dependencies: '@types/http-errors': 2.0.4 - '@types/node': 22.10.1 + '@types/node': 20.17.10 '@types/send': 0.17.4 '@types/shimmer@1.2.0': {} @@ -25352,7 +25354,7 @@ snapshots: outvariant: 1.4.3 path-to-regexp: 6.3.0 strict-event-emitter: 0.5.1 - type-fest: 4.26.1 + type-fest: 4.30.0 yargs: 17.7.2 optionalDependencies: typescript: 5.7.2 @@ -25377,7 +25379,7 @@ snapshots: outvariant: 1.4.3 path-to-regexp: 6.3.0 strict-event-emitter: 0.5.1 - type-fest: 4.26.1 + type-fest: 4.30.0 yargs: 17.7.2 optionalDependencies: typescript: 5.7.2 @@ -25402,7 +25404,7 @@ snapshots: outvariant: 1.4.3 path-to-regexp: 6.3.0 strict-event-emitter: 0.5.1 - type-fest: 4.26.1 + type-fest: 4.30.0 yargs: 17.7.2 optionalDependencies: typescript: 5.7.2 @@ -26462,6 +26464,14 @@ snapshots: yaml: 2.6.0 zod: 3.23.8 + promptl-ai@0.4.5: + dependencies: + acorn: 8.14.0 + code-red: 1.0.4 + locate-character: 3.0.0 + yaml: 2.6.0 + zod: 3.23.8 + prop-types@15.8.1: dependencies: loose-envify: 1.4.0 @@ -28342,8 +28352,6 @@ snapshots: type-fest@0.8.1: {} - type-fest@4.26.1: {} - type-fest@4.30.0: {} type-is@1.6.18: