From 97f0894a9648445e253994700d83d4120b2eb92b Mon Sep 17 00:00:00 2001
From: andresgutgon
Date: Thu, 2 Jan 2025 12:25:50 +0100
Subject: [PATCH] Handle tool calls in SDK and playground

We have a task to allow users to give a response to one or more tool calls
that the AI can request in our playground. The work necessary to make this
possible is also necessary for making tool calls work in our SDK, so this PR
is a spike to get a picture of the moving parts needed to make this happen.
---
 apps/gateway/package.json | 2 +-
 apps/web/package.json | 2 +-
 .../_components/Playground/Chat/index.tsx | 2 +-
 .../DocumentEditor/Editor/Playground/Chat.tsx | 94 +++----
 .../Editor/Playground/index.tsx | 4 +-
 .../Editor/Playground/useMessages.ts | 75 ++++++
 .../Messages/AllMessages/index.tsx | 2 +-
 .../Messages/ChatMessages/index.tsx | 2 +-
 packages/compiler/src/types/message.ts | 2 +-
 packages/constants/src/ai.ts | 5 +-
 packages/core/package.json | 2 +-
 packages/core/src/constants.ts | 6 +-
 packages/core/src/events/events.d.ts | 4 +-
 packages/core/src/helpers.ts | 59 ++---
 .../batchEvaluations/runEvaluationJob.test.ts | 4 +
 .../documentLogsRepository/index.ts | 6 +-
 .../providers/rules/providerMetadata/index.ts | 2 +
 .../chains/ChainStreamConsumer/index.test.ts | 102 +++++++-
 .../chains/ChainStreamConsumer/index.ts | 81 ++----
 .../services/chains/ChainValidator/index.ts | 38 ++-
 .../chains/ProviderProcessor/index.test.ts | 2 +
 .../chains/ProviderProcessor/index.ts | 13 +
 .../saveOrPublishProviderLogs.test.ts | 3 +
 .../saveOrPublishProviderLogs.ts | 11 +-
 .../core/src/services/chains/agents/run.ts | 14 +-
 .../src/services/chains/buildStep/index.ts | 111 ++++++++
 .../src/services/chains/chainCache/index.ts | 104 ++++++++
 packages/core/src/services/chains/run.ts | 238 ++++++++---------
 .../services/commits/runDocumentAtCommit.ts | 3 +-
 .../addMessages/findPausedChain.ts | 43 ++++
 .../documentLogs/addMessages/index.test.ts | 1 +
 .../documentLogs/addMessages/index.ts | 63 ++++-
 .../addMessages/resumeConversation/index.ts | 100 ++++++++
 .../src/services/providerLogs/addMessages.ts | 7 +
 packages/sdks/typescript/package.json | 6 +-
 packages/sdks/typescript/src/index.ts | 5 +-
 packages/web-ui/package.json | 1 +
 .../web-ui/src/ds/atoms/ClientOnly/index.tsx | 11 +-
 .../web-ui/src/ds/atoms/CodeBlock/index.tsx | 17 +-
 .../Chat/ChatTextArea/ToolBar/index.tsx | 24 ++
 .../ToolCallForm/Editor/index.tsx | 179 +++++++++++++
 .../ChatTextArea/ToolCallForm/Editor/types.ts | 6 +
 .../Chat/ChatTextArea/ToolCallForm/index.tsx | 241 ++++++++++++++++++
 .../ds/molecules/Chat/ChatTextArea/index.tsx | 85 ++++--
 .../Editor/RegularEditor.tsx | 3 +-
 .../Editor/useMonacoSetup.ts | 1 +
 packages/web-ui/src/index.ts | 1 +
 .../src/lib/versionedMessagesHelpers.ts | 85 ++++++
 pnpm-lock.yaml | 52 ++--
 49 files changed, 1561 insertions(+), 363 deletions(-)
 create mode 100644 apps/web/src/app/(private)/projects/[projectId]/versions/[commitUuid]/documents/[documentUuid]/_components/DocumentEditor/Editor/Playground/useMessages.ts
 create mode 100644 packages/core/src/services/chains/buildStep/index.ts
 create mode 100644 packages/core/src/services/chains/chainCache/index.ts
 create mode 100644 packages/core/src/services/documentLogs/addMessages/findPausedChain.ts
 create mode 100644 packages/core/src/services/documentLogs/addMessages/resumeConversation/index.ts
 create mode 100644 packages/web-ui/src/ds/molecules/Chat/ChatTextArea/ToolBar/index.tsx
 create mode 100644
packages/web-ui/src/ds/molecules/Chat/ChatTextArea/ToolCallForm/Editor/index.tsx create mode 100644 packages/web-ui/src/ds/molecules/Chat/ChatTextArea/ToolCallForm/Editor/types.ts create mode 100644 packages/web-ui/src/ds/molecules/Chat/ChatTextArea/ToolCallForm/index.tsx create mode 100644 packages/web-ui/src/lib/versionedMessagesHelpers.ts diff --git a/apps/gateway/package.json b/apps/gateway/package.json index 3f3a53963..4fe574874 100644 --- a/apps/gateway/package.json +++ b/apps/gateway/package.json @@ -30,7 +30,7 @@ "drizzle-orm": "^0.33.0", "hono": "^4.6.6", "lodash-es": "^4.17.21", - "promptl-ai": "^0.3.5", + "promptl-ai": "^0.4.5", "rate-limiter-flexible": "^5.0.3", "zod": "^3.23.8" }, diff --git a/apps/web/package.json b/apps/web/package.json index 813debe44..a5990de61 100644 --- a/apps/web/package.json +++ b/apps/web/package.json @@ -51,7 +51,7 @@ "oslo": "1.2.0", "pdfjs-dist": "^4.9.155", "posthog-js": "^1.161.6", - "promptl-ai": "^0.3.5", + "promptl-ai": "^0.4.5", "rate-limiter-flexible": "^5.0.3", "react": "19.0.0-rc-5d19e1c8-20240923", "react-dom": "19.0.0-rc-5d19e1c8-20240923", diff --git a/apps/web/src/app/(private)/evaluations/(evaluation)/[evaluationUuid]/editor/_components/Playground/Chat/index.tsx b/apps/web/src/app/(private)/evaluations/(evaluation)/[evaluationUuid]/editor/_components/Playground/Chat/index.tsx index 828c694a3..c58c60366 100644 --- a/apps/web/src/app/(private)/evaluations/(evaluation)/[evaluationUuid]/editor/_components/Playground/Chat/index.tsx +++ b/apps/web/src/app/(private)/evaluations/(evaluation)/[evaluationUuid]/editor/_components/Playground/Chat/index.tsx @@ -176,7 +176,7 @@ export default function Chat({ ) : ( )} diff --git a/apps/web/src/app/(private)/projects/[projectId]/versions/[commitUuid]/documents/[documentUuid]/_components/DocumentEditor/Editor/Playground/Chat.tsx b/apps/web/src/app/(private)/projects/[projectId]/versions/[commitUuid]/documents/[documentUuid]/_components/DocumentEditor/Editor/Playground/Chat.tsx index 8e14284b3..4d0d55dd6 100644 --- a/apps/web/src/app/(private)/projects/[projectId]/versions/[commitUuid]/documents/[documentUuid]/_components/DocumentEditor/Editor/Playground/Chat.tsx +++ b/apps/web/src/app/(private)/projects/[projectId]/versions/[commitUuid]/documents/[documentUuid]/_components/DocumentEditor/Editor/Playground/Chat.tsx @@ -2,9 +2,9 @@ import { useCallback, useContext, useEffect, useRef, useState } from 'react' import { ContentType, - Conversation, Message as ConversationMessage, MessageRole, + ToolMessage, } from '@latitude-data/compiler' import { ChainEventTypes, @@ -31,15 +31,31 @@ import { readStreamableValue } from 'ai/rsc' import { DocumentEditorContext } from '..' 
import Actions, { ActionsState } from './Actions' +import { useMessages } from './useMessages' +import { type PromptlVersion } from '@latitude-data/web-ui' -export default function Chat({ +function buildMessage({ input }: { input: string | ToolMessage[] }) { + if (typeof input === 'string') { + return [ + { + role: MessageRole.user, + content: [{ type: ContentType.text, text: input }], + } as ConversationMessage, + ] + } + return input +} + +export default function Chat({ document, + promptlVersion, parameters, clearChat, expandParameters, setExpandParameters, }: { document: DocumentVersion + promptlVersion: V parameters: Record clearChat: () => void } & ActionsState) { @@ -62,25 +78,11 @@ export default function Chat({ const runChainOnce = useRef(false) // Index where the chain ends and the chat begins const [chainLength, setChainLength] = useState(Infinity) - const [conversation, setConversation] = useState() const [responseStream, setResponseStream] = useState() const [isStreaming, setIsStreaming] = useState(false) - - const addMessageToConversation = useCallback( - (message: ConversationMessage) => { - let newConversation: Conversation - setConversation((prevConversation) => { - newConversation = { - ...prevConversation, - messages: [...(prevConversation?.messages ?? []), message], - } as Conversation - return newConversation - }) - return newConversation! - }, - [], - ) - + const { messages, addMessages, unresponedToolCalls } = useMessages({ + version: promptlVersion, + }) const startStreaming = useCallback(() => { setError(undefined) setUsage({ @@ -122,7 +124,7 @@ export default function Chat({ if ('messages' in data) { setResponseStream(undefined) - data.messages!.forEach(addMessageToConversation) + addMessages(data.messages ?? []) messagesCount += data.messages!.length } @@ -148,6 +150,7 @@ export default function Chat({ } break } + default: break } @@ -163,7 +166,7 @@ export default function Chat({ commit.uuid, parameters, runDocumentAction, - addMessageToConversation, + addMessages, startStreaming, stopStreaming, ]) @@ -176,14 +179,15 @@ export default function Chat({ }, [runDocument]) const submitUserMessage = useCallback( - async (input: string) => { + async (input: string | ToolMessage[]) => { if (!documentLogUuid) return // This should not happen - const message: ConversationMessage = { - role: MessageRole.user, - content: [{ type: ContentType.text, text: input }], + const newMessages = buildMessage({ input }) + + // Only in Chat mode we add optimistically the message + if (typeof input === 'string') { + addMessages(newMessages) } - addMessageToConversation(message) let response = '' startStreaming() @@ -191,7 +195,7 @@ export default function Chat({ try { const { output } = await addMessagesAction({ documentLogUuid, - messages: [message], + messages: newMessages, }) for await (const serverEvent of readStreamableValue(output)) { @@ -201,7 +205,7 @@ export default function Chat({ if ('messages' in data) { setResponseStream(undefined) - data.messages!.forEach(addMessageToConversation) + addMessages(data.messages ?? []) } switch (event) { @@ -222,6 +226,7 @@ export default function Chat({ } break } + default: break } @@ -235,7 +240,7 @@ export default function Chat({ [ documentLogUuid, addMessagesAction, - addMessageToConversation, + addMessages, startStreaming, stopStreaming, ], @@ -255,24 +260,22 @@ export default function Chat({ className='flex flex-col gap-3 flex-grow flex-shrink min-h-0 custom-scrollbar scrollable-indicator pb-12' > - {(conversation?.messages.length ?? 
0) >= chainLength && ( + {(messages.length ?? 0) >= chainLength && ( <> {time && } )} - {(conversation?.messages.length ?? 0) > chainLength && ( + {(messages.length ?? 0) > chainLength && ( <> Chat - + )} {error ? ( @@ -280,7 +283,7 @@ export default function Chat({ ) : ( )} @@ -296,6 +299,8 @@ export default function Chat({ placeholder='Enter followup message...' disabled={isStreaming} onSubmit={submitUserMessage} + toolRequests={unresponedToolCalls} + addMessages={addMessages} /> @@ -342,12 +347,8 @@ export function TokenUsage({ } >
- - {usage?.promptTokens || 0} prompt tokens - - - {usage?.completionTokens || 0} completion tokens - + {usage?.promptTokens || 0} prompt tokens + {usage?.completionTokens || 0} completion tokens
) : ( @@ -359,16 +360,15 @@ export function TokenUsage({ export function StreamMessage({ responseStream, - conversation, + messages, chainLength, }: { responseStream: string | undefined - conversation: Conversation | undefined + messages: ConversationMessage[] chainLength: number }) { if (responseStream === undefined) return null - if (conversation === undefined) return null - if (conversation.messages.length < chainLength - 1) { + if (messages.length < chainLength - 1) { return ( void metadata: ConversationMetadata }) { + const promptlVersion = document.promptlVersion === 1 ? 1 : 0 const [mode, setMode] = useState<'preview' | 'chat'>('preview') const { commit } = useCurrentCommit() const [expanded, setExpanded] = useState(true) @@ -77,6 +76,7 @@ export default function Playground({ ) : ( setMode('preview')} expandParameters={expandParameters} diff --git a/apps/web/src/app/(private)/projects/[projectId]/versions/[commitUuid]/documents/[documentUuid]/_components/DocumentEditor/Editor/Playground/useMessages.ts b/apps/web/src/app/(private)/projects/[projectId]/versions/[commitUuid]/documents/[documentUuid]/_components/DocumentEditor/Editor/Playground/useMessages.ts new file mode 100644 index 000000000..d35e5b6c2 --- /dev/null +++ b/apps/web/src/app/(private)/projects/[projectId]/versions/[commitUuid]/documents/[documentUuid]/_components/DocumentEditor/Editor/Playground/useMessages.ts @@ -0,0 +1,75 @@ +import { + ToolPart, + ToolRequest, + PromptlVersion, + VersionedMessage, + extractToolContents, +} from '@latitude-data/web-ui' +import { Message as CompilerMessage } from '@latitude-data/compiler' +import { useCallback, useState } from 'react' + +function isToolRequest(part: ToolPart): part is ToolRequest { + return 'toolCallId' in part +} + +function getUnrespondedToolRequests({ + version: _, + messages, +}: { + version: V + messages: VersionedMessage[] +}) { + // FIXME: Kill compiler please. I made this module compatible with + // both old compiler and promptl. But because everything is typed with + // old compiler prompts in promptl version are also formatted as old compiler + const parts = extractToolContents({ + version: 0, + messages: messages as VersionedMessage<0>[], + }) + const toolRequestIds = new Set() + const toolResponses = new Set() + + parts.forEach((part) => { + if (isToolRequest(part)) { + toolRequestIds.add(part.toolCallId) + } else { + toolResponses.add(part.id) + } + }) + + return parts.filter( + (part): part is ToolRequest => + isToolRequest(part) && !toolResponses.has(part.toolCallId), + ) +} + +type Props = { version: V } + +export function useMessages({ version }: Props) { + const [messages, setMessages] = useState[]>([]) + const [unresponedToolCalls, setUnresponedToolCalls] = useState( + [], + ) + // FIXME: Kill compiler please + // every where we have old compiler types. 
To avoid Typescript crying we + // allow only compiler messages and then transform them to versioned messages + const addMessages = useCallback( + (m: CompilerMessage[]) => { + const msg = m as VersionedMessage[] + setMessages((prevMessages) => { + const newMessages = prevMessages.concat(msg) + setUnresponedToolCalls( + getUnrespondedToolRequests({ version, messages: newMessages }), + ) + return newMessages + }) + }, + [version], + ) + + return { + addMessages, + messages: messages as CompilerMessage[], + unresponedToolCalls, + } +} diff --git a/apps/web/src/app/(public)/share/d/[publishedDocumentUuid]/_components/Messages/AllMessages/index.tsx b/apps/web/src/app/(public)/share/d/[publishedDocumentUuid]/_components/Messages/AllMessages/index.tsx index a95ec0933..98eff6eb8 100644 --- a/apps/web/src/app/(public)/share/d/[publishedDocumentUuid]/_components/Messages/AllMessages/index.tsx +++ b/apps/web/src/app/(public)/share/d/[publishedDocumentUuid]/_components/Messages/AllMessages/index.tsx @@ -38,7 +38,7 @@ export function AllMessages({ ) : ( )} diff --git a/apps/web/src/app/(public)/share/d/[publishedDocumentUuid]/_components/Messages/ChatMessages/index.tsx b/apps/web/src/app/(public)/share/d/[publishedDocumentUuid]/_components/Messages/ChatMessages/index.tsx index bcdb46150..314f825e5 100644 --- a/apps/web/src/app/(public)/share/d/[publishedDocumentUuid]/_components/Messages/ChatMessages/index.tsx +++ b/apps/web/src/app/(public)/share/d/[publishedDocumentUuid]/_components/Messages/ChatMessages/index.tsx @@ -30,7 +30,7 @@ export function ChatMessages({ ) : ( )} diff --git a/packages/compiler/src/types/message.ts b/packages/compiler/src/types/message.ts index 66fd6628a..2a372e809 100644 --- a/packages/compiler/src/types/message.ts +++ b/packages/compiler/src/types/message.ts @@ -93,7 +93,7 @@ export type AssistantMessage = { export type ToolMessage = { role: MessageRole.tool - content: ToolContent[] + content: (TextContent | ToolContent)[] [key: string]: unknown } diff --git a/packages/constants/src/ai.ts b/packages/constants/src/ai.ts index 68d3e6e86..f26a1c331 100644 --- a/packages/constants/src/ai.ts +++ b/packages/constants/src/ai.ts @@ -1,6 +1,7 @@ import { Message, ToolCall } from '@latitude-data/compiler' import { CoreTool, + FinishReason, LanguageModelUsage, ObjectStreamPart, TextStreamPart, @@ -93,6 +94,7 @@ export type ChainEventDto = | { type: ChainEventTypes.Complete config: Config + finishReason?: FinishReason messages?: Message[] object?: any response: ChainEventDtoResponse @@ -174,6 +176,7 @@ export type LatitudeChainCompleteEventData = { messages?: Message[] object?: any response: ChainStepResponse + finishReason: FinishReason documentLogUuid?: string } @@ -188,14 +191,12 @@ export type LatitudeEventData = | LatitudeChainCompleteEventData | LatitudeChainErrorEventData -// FIXME: Move to @latitude-data/constants export type RunSyncAPIResponse = { uuid: string conversation: Message[] response: ChainCallResponseDto } -// FIXME: Move to @latitude-data/constants export type ChatSyncAPIResponse = RunSyncAPIResponse export const toolCallResponseSchema = z.object({ diff --git a/packages/core/package.json b/packages/core/package.json index 80e3ae3b7..94c3b9725 100644 --- a/packages/core/package.json +++ b/packages/core/package.json @@ -151,6 +151,6 @@ "dependencies": { "date-fns": "^3.6.0", "diff-match-patch": "^1.0.5", - "promptl-ai": "^0.3.5" + "promptl-ai": "^0.4.5" } } diff --git a/packages/core/src/constants.ts b/packages/core/src/constants.ts index b42b6f8b9..289a23cac 100644 
--- a/packages/core/src/constants.ts +++ b/packages/core/src/constants.ts @@ -10,7 +10,7 @@ import { LogSources, ProviderData, } from '@latitude-data/constants' -import { LanguageModelUsage } from 'ai' +import { FinishReason, LanguageModelUsage } from 'ai' import { z } from 'zod' import { DocumentVersion, ProviderLog, Span, Trace } from './browser' @@ -62,6 +62,8 @@ export type StreamType = 'object' | 'text' type BaseResponse = { text: string usage: LanguageModelUsage + finishReason?: FinishReason + chainCompleted?: boolean documentLogUuid?: string providerLog?: ProviderLog } @@ -264,7 +266,7 @@ const toolResultContentSchema = z.object({ type: z.literal('tool-result'), toolCallId: z.string(), toolName: z.string(), - result: z.string(), + result: z.any(), isError: z.boolean().optional(), }) diff --git a/packages/core/src/events/events.d.ts b/packages/core/src/events/events.d.ts index 0f85744ff..ab1e1ef92 100644 --- a/packages/core/src/events/events.d.ts +++ b/packages/core/src/events/events.d.ts @@ -1,4 +1,4 @@ -import { LanguageModelUsage } from 'ai' +import { FinishReason, LanguageModelUsage } from 'ai' import type { ChainStepResponse, @@ -134,6 +134,8 @@ export type StreamCommonData = { messages: Message[] usage: LanguageModelUsage duration: number + chainCompleted: boolean + finishReason: FinishReason } type StreamTextData = { diff --git a/packages/core/src/helpers.ts b/packages/core/src/helpers.ts index ba86c76a3..0140531e5 100644 --- a/packages/core/src/helpers.ts +++ b/packages/core/src/helpers.ts @@ -85,6 +85,7 @@ type BuildMessageParams = T extends 'object' toolCallResponses?: ToolCallResponse[] } } + export function buildResponseMessage({ type, data, @@ -183,50 +184,30 @@ export function buildMessagesFromResponse({ }) : undefined - const previousMessages = response.providerLog - ? response.providerLog.messages - : [] - return [...previousMessages, ...(message ? [message] : [])] + return message ? ([message] as Message[]) : [] +} + +export function buildAllMessagesFromResponse({ + response, +}: { + response: ChainStepResponse +}) { + const previousMessages = response.providerLog?.messages ?? 
[] + const messages = buildMessagesFromResponse({ response }) + + return [...previousMessages, ...messages] } export function buildConversation(providerLog: ProviderLogDto) { let messages: Message[] = [...providerLog.messages] - let message: Message | undefined = undefined - - if (providerLog.response && providerLog.response.length > 0) { - message = { - role: 'assistant', - content: [ - { - type: 'text', - text: providerLog.response, - }, - ], - toolCalls: [], - } as Message - } - - if (providerLog.toolCalls.length > 0) { - const content = providerLog.toolCalls.map((toolCall) => { - return { - type: 'tool-call', - toolCallId: toolCall.id, - toolName: toolCall.name, - args: toolCall.arguments, - } as ToolRequestContent - }) - if (message) { - message.content = (message.content as MessageContent[]).concat(content) - message.toolCalls = providerLog.toolCalls - } else { - message = { - role: 'assistant', - content: content, - toolCalls: providerLog.toolCalls, - } as Message - } - } + const message = buildResponseMessage({ + type: 'text', + data: { + text: providerLog.response, + toolCalls: providerLog.toolCalls, + }, + }) if (message) { messages.push(message) diff --git a/packages/core/src/jobs/job-definitions/batchEvaluations/runEvaluationJob.test.ts b/packages/core/src/jobs/job-definitions/batchEvaluations/runEvaluationJob.test.ts index 3f698c080..6d25eb5b0 100644 --- a/packages/core/src/jobs/job-definitions/batchEvaluations/runEvaluationJob.test.ts +++ b/packages/core/src/jobs/job-definitions/batchEvaluations/runEvaluationJob.test.ts @@ -125,6 +125,8 @@ describe('runEvaluationJob', () => { usage: { promptTokens: 8, completionTokens: 2, totalTokens: 10 }, documentLogUuid: documentLog.uuid, providerLog: undefined, + finishReason: 'stop', + chainCompleted: true, }) runEvaluationSpy.mockResolvedValueOnce( Result.ok({ @@ -151,6 +153,8 @@ describe('runEvaluationJob', () => { usage: { promptTokens: 8, completionTokens: 2, totalTokens: 10 }, documentLogUuid: documentLog.uuid, providerLog: undefined, + finishReason: 'stop', + chainCompleted: true, }) runEvaluationSpy.mockResolvedValueOnce( Result.ok({ diff --git a/packages/core/src/repositories/documentLogsRepository/index.ts b/packages/core/src/repositories/documentLogsRepository/index.ts index 968c7417b..4825de993 100644 --- a/packages/core/src/repositories/documentLogsRepository/index.ts +++ b/packages/core/src/repositories/documentLogsRepository/index.ts @@ -42,7 +42,11 @@ export class DocumentLogsRepository extends Repository { .$dynamic() } - async findByUuid(uuid: string) { + async findByUuid(uuid: string | undefined) { + if (!uuid) { + return Result.error(new NotFoundError('DocumentLog not found')) + } + const result = await this.scope.where(eq(documentLogs.uuid, uuid)) if (!result.length) { diff --git a/packages/core/src/services/ai/providers/rules/providerMetadata/index.ts b/packages/core/src/services/ai/providers/rules/providerMetadata/index.ts index d9a6bd006..5e34e222e 100644 --- a/packages/core/src/services/ai/providers/rules/providerMetadata/index.ts +++ b/packages/core/src/services/ai/providers/rules/providerMetadata/index.ts @@ -21,6 +21,8 @@ const CONTENT_DEFINED_ATTRIBUTES = [ 'toolCallId', 'toolName', 'args', + // TODO: Add a test for this + 'result', ] as const type AttrArgs = { attributes: string[]; content: Record } diff --git a/packages/core/src/services/chains/ChainStreamConsumer/index.test.ts b/packages/core/src/services/chains/ChainStreamConsumer/index.test.ts index e97ff00fa..3334cae43 100644 --- 
a/packages/core/src/services/chains/ChainStreamConsumer/index.test.ts +++ b/packages/core/src/services/chains/ChainStreamConsumer/index.test.ts @@ -99,6 +99,8 @@ describe('ChainStreamConsumer', () => { toolCalls: [], usage: { promptTokens: 0, completionTokens: 0, totalTokens: 0 }, documentLogUuid: 'errorable-uuid', + finishReason: 'stop', + chainCompleted: false, } consumer.stepCompleted(response) @@ -121,9 +123,27 @@ describe('ChainStreamConsumer', () => { toolCalls: [], usage: { promptTokens: 0, completionTokens: 0, totalTokens: 0 }, documentLogUuid: 'errorable-uuid', + finishReason: 'stop', + chainCompleted: true, } - consumer.chainCompleted({ step, response }) + consumer.chainCompleted({ + step, + response, + finishReason: 'stop', + responseMessages: [ + { + role: MessageRole.assistant, + content: [ + { + type: ContentType.text, + text: 'text response', + }, + ], + toolCalls: [], + }, + ], + }) expect(controller.enqueue).toHaveBeenCalledWith({ event: StreamEventTypes.Latitude, @@ -157,9 +177,27 @@ describe('ChainStreamConsumer', () => { object: { object: 'response' }, usage: { promptTokens: 0, completionTokens: 0, totalTokens: 0 }, documentLogUuid: 'errorable-uuid', + finishReason: 'stop', + chainCompleted: true, } - consumer.chainCompleted({ step, response }) + consumer.chainCompleted({ + step, + response, + finishReason: 'stop', + responseMessages: [ + { + role: MessageRole.assistant, + content: [ + { + type: ContentType.text, + text: '{\n "object": "response"\n}', + }, + ], + toolCalls: [], + }, + ], + }) expect(controller.enqueue).toHaveBeenCalledWith({ event: StreamEventTypes.Latitude, @@ -199,9 +237,35 @@ describe('ChainStreamConsumer', () => { ], usage: { promptTokens: 0, completionTokens: 0, totalTokens: 0 }, documentLogUuid: 'errorable-uuid', + finishReason: 'stop', + chainCompleted: true, } - consumer.chainCompleted({ step, response }) + consumer.chainCompleted({ + step, + response, + finishReason: 'stop', + responseMessages: [ + { + role: MessageRole.assistant, + content: [ + { + type: ContentType.toolCall, + toolCallId: 'tool-call-id', + toolName: 'tool-call-name', + args: { arg1: 'value1', arg2: 'value2' }, + }, + ], + toolCalls: [ + { + id: 'tool-call-id', + name: 'tool-call-name', + arguments: { arg1: 'value1', arg2: 'value2' }, + }, + ], + }, + ], + }) expect(controller.enqueue).toHaveBeenCalledWith({ event: StreamEventTypes.Latitude, @@ -249,9 +313,39 @@ describe('ChainStreamConsumer', () => { ], usage: { promptTokens: 0, completionTokens: 0, totalTokens: 0 }, documentLogUuid: 'errorable-uuid', + finishReason: 'stop', + chainCompleted: true, } - consumer.chainCompleted({ step, response }) + consumer.chainCompleted({ + step, + response, + finishReason: 'stop', + responseMessages: [ + { + role: MessageRole.assistant, + content: [ + { + type: ContentType.text, + text: 'text response', + }, + { + type: ContentType.toolCall, + toolCallId: 'tool-call-id', + toolName: 'tool-call-name', + args: { arg1: 'value1', arg2: 'value2' }, + }, + ], + toolCalls: [ + { + id: 'tool-call-id', + name: 'tool-call-name', + arguments: { arg1: 'value1', arg2: 'value2' }, + }, + ], + }, + ], + }) expect(controller.enqueue).toHaveBeenCalledWith({ event: StreamEventTypes.Latitude, diff --git a/packages/core/src/services/chains/ChainStreamConsumer/index.ts b/packages/core/src/services/chains/ChainStreamConsumer/index.ts index 2a7f51fb6..9da671b14 100644 --- a/packages/core/src/services/chains/ChainStreamConsumer/index.ts +++ b/packages/core/src/services/chains/ChainStreamConsumer/index.ts @@ -1,9 
+1,3 @@ -import { - ContentType, - MessageContent, - MessageRole, - ToolRequestContent, -} from '@latitude-data/compiler' import { RunErrorCodes } from '@latitude-data/constants/errors' import { @@ -14,13 +8,14 @@ import { StreamEventTypes, StreamType, } from '../../../constants' -import { objectToString } from '../../../helpers' import { Config } from '../../ai' import { ChainError } from '../ChainErrors' import { ValidatedChainStep } from '../ChainValidator' import { ValidatedAgentStep } from '../agents/AgentStepValidator' type ValidatedStep = ValidatedChainStep | ValidatedAgentStep +import { FinishReason } from 'ai' +import { buildMessagesFromResponse } from '../../../helpers' export function enqueueChainEvent( controller: ReadableStreamDefaultController, @@ -54,66 +49,15 @@ export class ChainStreamConsumer { response, config, controller, + finishReason, + responseMessages, }: { controller: ReadableStreamDefaultController response: ChainStepResponse config: Config + finishReason: FinishReason + responseMessages: Message[] }) { - let messages: Message[] = [] - let message: Message | undefined = undefined - - if (response.text.length > 0) { - message = { - role: MessageRole.assistant, - content: [ - { - type: ContentType.text, - text: response.text, - }, - ], - toolCalls: [], - } - } - - if (response.streamType === 'object' && response.object) { - message = { - role: MessageRole.assistant, - content: [ - { - type: ContentType.text, - text: objectToString(response.object), - }, - ], - toolCalls: [], - } - } - - if (response.streamType === 'text' && response.toolCalls.length > 0) { - const content = response.toolCalls.map((toolCall) => { - return { - type: ContentType.toolCall, - toolCallId: toolCall.id, - toolName: toolCall.name, - args: toolCall.arguments, - } as ToolRequestContent - }) - - if (message) { - message.content = (message.content as MessageContent[]).concat(content) - message.toolCalls = response.toolCalls - } else { - message = { - role: MessageRole.assistant, - content: content, - toolCalls: response.toolCalls, - } - } - } - - if (message) { - messages.push(message) - } - enqueueChainEvent(controller, { event: StreamEventTypes.Latitude, data: { @@ -121,7 +65,8 @@ export class ChainStreamConsumer { config, documentLogUuid: response.documentLogUuid, response, - messages, + messages: responseMessages, + finishReason, }, }) @@ -199,14 +144,22 @@ export class ChainStreamConsumer { chainCompleted({ step, response, + finishReason, + responseMessages: defaultResponseMessages, }: { step: ValidatedStep response: ChainStepResponse + finishReason: FinishReason + responseMessages?: Message[] }) { - ChainStreamConsumer.chainCompleted({ + const responseMessages = + defaultResponseMessages ?? 
buildMessagesFromResponse({ response }) + return ChainStreamConsumer.chainCompleted({ controller: this.controller, response, config: step.conversation.config as Config, + finishReason, + responseMessages, }) } diff --git a/packages/core/src/services/chains/ChainValidator/index.ts b/packages/core/src/services/chains/ChainValidator/index.ts index 91cf65c1d..81218149c 100644 --- a/packages/core/src/services/chains/ChainValidator/index.ts +++ b/packages/core/src/services/chains/ChainValidator/index.ts @@ -6,7 +6,7 @@ import { } from '@latitude-data/compiler' import { RunErrorCodes } from '@latitude-data/constants/errors' import { JSONSchema7 } from 'json-schema' -import { Chain as PromptlChain } from 'promptl-ai' +import { Chain as PromptlChain, Message as PromptlMessage } from 'promptl-ai' import { z } from 'zod' import { applyProviderRules, ProviderApiKey, Workspace } from '../../../browser' @@ -33,7 +33,7 @@ export type ConfigOverrides = JSONOverride | { output: 'no-schema' } type ValidatorContext = { workspace: Workspace - prevText: string | undefined + prevContent: Message[] | string | undefined chain: SomeChain promptlVersion: number providersMap: CachedApiKeys @@ -77,24 +77,48 @@ const getOutputType = ({ return configSchema.type === 'array' ? 'array' : 'object' } +/* + * Legacy compiler wants a string as response for the next step + * But new Promptl can handle an array of messages + */ +function getTextFromMessages( + prevContent: Message[] | string | undefined, +): string | undefined { + if (!prevContent) return undefined + if (typeof prevContent === 'string') return prevContent + + try { + return prevContent + .flatMap((message) => { + if (typeof message.content === 'string') return message.content + return JSON.stringify(message.content) + }) + .join('\n') + } catch { + return '' + } +} + const safeChain = async ({ promptlVersion, chain, - prevText, + prevContent, }: { promptlVersion: number chain: SomeChain - prevText: string | undefined + prevContent: Message[] | string | undefined }) => { try { if (promptlVersion === 0) { + let prevText = getTextFromMessages(prevContent) const { completed, conversation } = await (chain as LegacyChain).step( prevText, ) return Result.ok({ chainCompleted: completed, conversation }) } + const { completed, messages, config } = await (chain as PromptlChain).step( - prevText, + prevContent as PromptlMessage[] | undefined | string, ) return Result.ok({ chainCompleted: completed, @@ -170,13 +194,13 @@ export const validateChain = async ( const { workspace, promptlVersion, - prevText, + prevContent, chain, providersMap, configOverrides, removeSchema, } = context - const chainResult = await safeChain({ promptlVersion, chain, prevText }) + const chainResult = await safeChain({ promptlVersion, chain, prevContent }) if (chainResult.error) return chainResult const { chainCompleted, conversation } = chainResult.value diff --git a/packages/core/src/services/chains/ProviderProcessor/index.test.ts b/packages/core/src/services/chains/ProviderProcessor/index.test.ts index 588e69ee2..6f2b728ed 100644 --- a/packages/core/src/services/chains/ProviderProcessor/index.test.ts +++ b/packages/core/src/services/chains/ProviderProcessor/index.test.ts @@ -93,6 +93,8 @@ describe('ProviderProcessor', () => { }, }, startTime: Date.now(), + chainCompleted: true, + finishReason: 'stop', }) expect(result).toEqual({ diff --git a/packages/core/src/services/chains/ProviderProcessor/index.ts b/packages/core/src/services/chains/ProviderProcessor/index.ts index 9e99213f9..97311bb20 100644 
--- a/packages/core/src/services/chains/ProviderProcessor/index.ts +++ b/packages/core/src/services/chains/ProviderProcessor/index.ts @@ -12,6 +12,7 @@ import { generateUUIDIdentifier } from '../../../lib' import { AIReturn, PartialConfig } from '../../ai' import { processStreamObject } from './processStreamObject' import { processStreamText } from './processStreamText' +import { FinishReason } from 'ai' async function buildCommonData({ aiResult, @@ -21,6 +22,8 @@ async function buildCommonData({ apiProvider, config, messages, + chainCompleted, + finishReason, errorableUuid, }: { aiResult: Awaited> @@ -30,6 +33,8 @@ async function buildCommonData({ apiProvider: ProviderApiKey config: PartialConfig messages: Message[] + chainCompleted: boolean + finishReason: FinishReason errorableUuid?: string }): Promise { const endTime = Date.now() @@ -53,6 +58,8 @@ async function buildCommonData({ config: config, messages: messages, usage: await aiResult.data.usage, + finishReason, + chainCompleted, } } @@ -67,6 +74,8 @@ export async function processResponse({ apiProvider, config, messages, + finishReason, + chainCompleted, errorableUuid, }: { aiResult: Awaited> @@ -76,6 +85,8 @@ export async function processResponse({ apiProvider: ProviderApiKey config: PartialConfig messages: Message[] + finishReason: FinishReason + chainCompleted: boolean errorableUuid?: string }): Promise> { const commonData = await buildCommonData({ @@ -87,6 +98,8 @@ export async function processResponse({ config, messages, errorableUuid, + finishReason, + chainCompleted, }) if (aiResult.type === 'text') { diff --git a/packages/core/src/services/chains/ProviderProcessor/saveOrPublishProviderLogs.test.ts b/packages/core/src/services/chains/ProviderProcessor/saveOrPublishProviderLogs.test.ts index e838598c1..0a290ce07 100644 --- a/packages/core/src/services/chains/ProviderProcessor/saveOrPublishProviderLogs.test.ts +++ b/packages/core/src/services/chains/ProviderProcessor/saveOrPublishProviderLogs.test.ts @@ -84,6 +84,7 @@ describe('saveOrPublishProviderLogs', () => { streamType: 'text', saveSyncProviderLogs: true, finishReason: 'stop', + chainCompleted: true, }) expect(publisherSpy).toHaveBeenCalledWith({ @@ -99,6 +100,7 @@ describe('saveOrPublishProviderLogs', () => { workspace, saveSyncProviderLogs: true, finishReason: 'stop', + chainCompleted: true, }) expect(createProviderLogSpy).toHaveBeenCalledWith({ @@ -114,6 +116,7 @@ describe('saveOrPublishProviderLogs', () => { streamType: 'text', saveSyncProviderLogs: false, finishReason: 'stop', + chainCompleted: true, workspace, }) diff --git a/packages/core/src/services/chains/ProviderProcessor/saveOrPublishProviderLogs.ts b/packages/core/src/services/chains/ProviderProcessor/saveOrPublishProviderLogs.ts index 44dededde..9906dcf09 100644 --- a/packages/core/src/services/chains/ProviderProcessor/saveOrPublishProviderLogs.ts +++ b/packages/core/src/services/chains/ProviderProcessor/saveOrPublishProviderLogs.ts @@ -16,18 +16,23 @@ export async function saveOrPublishProviderLogs({ streamType, saveSyncProviderLogs, finishReason, + chainCompleted, }: { workspace: Workspace streamType: StreamType data: ReturnType saveSyncProviderLogs: boolean finishReason: FinishReason + chainCompleted: boolean }) { publisher.publishLater({ type: 'aiProviderCallCompleted', - data: { ...data, streamType } as AIProviderCallCompletedData< - typeof streamType - >, + data: { + ...data, + finishReason, + chainCompleted, + streamType, + } as AIProviderCallCompletedData, }) const providerLogsData = { diff --git 
a/packages/core/src/services/chains/agents/run.ts b/packages/core/src/services/chains/agents/run.ts index b663d80f2..09fc3dd3d 100644 --- a/packages/core/src/services/chains/agents/run.ts +++ b/packages/core/src/services/chains/agents/run.ts @@ -26,7 +26,7 @@ import { StreamType, } from '../../../constants' import { Conversation } from '@latitude-data/compiler' -import { buildMessagesFromResponse, Workspace } from '../../../browser' +import { buildAllMessagesFromResponse, Workspace } from '../../../browser' import { ChainStreamConsumer } from '../ChainStreamConsumer' import { getCachedResponse, setCachedResponse } from '../../commits/promptCache' import { validateAgentStep } from './AgentStepValidator' @@ -97,7 +97,7 @@ export async function runAgent({ chainEventsReader.read().then(({ done, value }) => { if (value) { if (isChainComplete(value)) { - const messages = buildMessagesFromResponse({ + const messages = buildAllMessagesFromResponse({ response: value.data.response, }) conversation = { @@ -206,6 +206,7 @@ async function runAgentStep({ streamConsumer.chainCompleted({ step, response: previousResponse, + finishReason: 'stop', }) return previousResponse @@ -236,6 +237,7 @@ async function runAgentStep({ workspace, streamType: cachedResponse.streamType, finishReason: 'stop', // TODO: we probably should add a cached reason here + chainCompleted: step.chainCompleted, data: buildProviderLogDto({ workspace, source, @@ -258,13 +260,14 @@ async function runAgentStep({ streamConsumer.chainCompleted({ step, response, + finishReason: 'stop', }) return response } else { streamConsumer.stepCompleted(response) - const responseMessages = buildMessagesFromResponse({ response }) + const responseMessages = buildAllMessagesFromResponse({ response }) const nextConversation = { ...conversation, messages: responseMessages, @@ -311,12 +314,15 @@ async function runAgentStep({ source, workspace, startTime: stepStartTime, + chainCompleted: step.chainCompleted, + finishReason: consumedStream.finishReason, }) const providerLog = await saveOrPublishProviderLogs({ workspace, streamType: aiResult.type, finishReason: consumedStream.finishReason, + chainCompleted: step.chainCompleted, data: buildProviderLogDto({ workspace, source, @@ -340,7 +346,7 @@ async function runAgentStep({ streamConsumer.stepCompleted(response) - const responseMessages = buildMessagesFromResponse({ response }) + const responseMessages = buildAllMessagesFromResponse({ response }) const nextConversation = { ...conversation, messages: responseMessages, diff --git a/packages/core/src/services/chains/buildStep/index.ts b/packages/core/src/services/chains/buildStep/index.ts new file mode 100644 index 000000000..bdb7d1293 --- /dev/null +++ b/packages/core/src/services/chains/buildStep/index.ts @@ -0,0 +1,111 @@ +import { Chain as PromptlChain, Message as PromptlMessage } from 'promptl-ai' +import { FinishReason } from 'ai' + +import { ChainStepResponse, StreamType } from '../../../constants' +import { ChainStreamConsumer } from '../ChainStreamConsumer' +import { ValidatedChainStep } from '../ChainValidator' +import { runStep, StepProps } from '../run' +import { + buildProviderLogDto, + saveOrPublishProviderLogs, +} from '../ProviderProcessor/saveOrPublishProviderLogs' +import { cacheChain } from '../chainCache' +import { buildMessagesFromResponse } from '../../../helpers' + +function getToolCalls({ + response, +}: { + response: ChainStepResponse +}) { + const type = response.streamType + if (type === 'object') return [] + + const toolCalls = 
response.toolCalls ?? [] + + return toolCalls +} + +export async function buildStepExecution({ + streamConsumer, + baseResponse, + step, + stepProps, + providerLogProps: { streamType, finishReason, stepStartTime }, +}: { + streamConsumer: ChainStreamConsumer + baseResponse: ChainStepResponse + step: ValidatedChainStep + stepProps: StepProps + providerLogProps: { + streamType: StreamType + finishReason: FinishReason + stepStartTime: number + } +}) { + const { chain } = stepProps + const workspace = stepProps.workspace + const documentLogUuid = stepProps.errorableUuid + const providerLog = await saveOrPublishProviderLogs({ + workspace, + streamType, + finishReason, + chainCompleted: step.chainCompleted, + data: buildProviderLogDto({ + workspace, + provider: step.provider, + conversation: step.conversation, + source: stepProps.source, + errorableUuid: documentLogUuid, + stepStartTime, + response: baseResponse, + }), + // TODO: temp bugfix, shuold only save last one syncronously + saveSyncProviderLogs: true, + }) + + async function executeStep({ + finalResponse, + }: { + finalResponse: ChainStepResponse + }): Promise> { + const isPromptl = chain instanceof PromptlChain + const toolCalls = getToolCalls({ response: finalResponse }) + const hasTools = isPromptl && toolCalls.length > 0 + const responseMessages = buildMessagesFromResponse({ + response: finalResponse, + }) + + if (hasTools) { + await cacheChain({ + workspace, + chain, + documentLogUuid, + responseMessages: responseMessages as unknown as PromptlMessage[], + }) + } + + if (step.chainCompleted || hasTools) { + streamConsumer.chainCompleted({ + step, + response: finalResponse, + finishReason, + responseMessages, + }) + + return { + ...finalResponse, + finishReason, + chainCompleted: step.chainCompleted, + } + } + + streamConsumer.stepCompleted(finalResponse) + + return runStep({ + ...stepProps, + previousResponse: finalResponse, + }) + } + + return { providerLog, executeStep } +} diff --git a/packages/core/src/services/chains/chainCache/index.ts b/packages/core/src/services/chains/chainCache/index.ts new file mode 100644 index 000000000..b26583bba --- /dev/null +++ b/packages/core/src/services/chains/chainCache/index.ts @@ -0,0 +1,104 @@ +import { hash } from 'crypto' +import { + Chain as PromptlChain, + type SerializedChain, + Message, +} from 'promptl-ai' + +import { Workspace } from '../../../browser' +import { cache } from '../../../cache' + +type CachedChain = { + chain: PromptlChain + messages: Message[] +} + +function generateCacheKey({ + workspace, + documentLogUuid, +}: { + workspace: Workspace + documentLogUuid: string +}): string { + const k = hash('sha256', `${documentLogUuid}`) + return `workspace:${workspace.id}:chain:${k}` +} + +async function setToCache({ + key, + chain, + messages, +}: { + key: string + chain: SerializedChain + messages: Message[] +}) { + try { + const c = await cache() + await c.set(key, JSON.stringify({ chain, messages })) + } catch (e) { + // Silently fail cache writes + } +} + +async function getFromCache(key: string): Promise { + try { + const c = await cache() + const serialized = await c.get(key) + if (!serialized) return undefined + + const deserialized = JSON.parse(serialized) + const chain = PromptlChain.deserialize({ serialized: deserialized.chain }) + + if (!chain || !deserialized.messages) return undefined + return { chain, messages: deserialized.messages ?? 
[] } + } catch (e) { + return undefined + } +} + +export async function getCachedChain({ + documentLogUuid, + workspace, +}: { + documentLogUuid: string | undefined + workspace: Workspace +}) { + if (!documentLogUuid) return undefined + + const key = generateCacheKey({ documentLogUuid, workspace }) + return await getFromCache(key) +} + +export async function deleteCachedChain({ + documentLogUuid, + workspace, +}: { + documentLogUuid: string + workspace: Workspace +}) { + const key = generateCacheKey({ documentLogUuid, workspace }) + try { + const c = await cache() + await c.del(key) + } catch (e) { + // Silently fail cache writes + } +} + +export async function cacheChain({ + workspace, + documentLogUuid, + chain, + responseMessages, +}: { + workspace: Workspace + chain: PromptlChain + responseMessages: Message[] + documentLogUuid: string +}) { + const key = generateCacheKey({ documentLogUuid, workspace }) + const serialized = chain.serialize() + + await setToCache({ key, chain: serialized, messages: responseMessages }) +} diff --git a/packages/core/src/services/chains/run.ts b/packages/core/src/services/chains/run.ts index 9c7bc89ab..6c4709ede 100644 --- a/packages/core/src/services/chains/run.ts +++ b/packages/core/src/services/chains/run.ts @@ -1,4 +1,4 @@ -import { Chain as LegacyChain } from '@latitude-data/compiler' +import { Chain as LegacyChain, Message } from '@latitude-data/compiler' import { RunErrorCodes } from '@latitude-data/constants/errors' import { Chain as PromptlChain } from 'promptl-ai' @@ -21,13 +21,14 @@ import { createRunError as createRunErrorFn } from '../runErrors/create' import { ChainError } from './ChainErrors' import { ChainStreamConsumer } from './ChainStreamConsumer' import { consumeStream } from './ChainStreamConsumer/consumeStream' -import { ConfigOverrides, validateChain } from './ChainValidator' +import { + ConfigOverrides, + validateChain, + ValidatedChainStep, +} from './ChainValidator' import { checkValidStream } from './checkValidStream' import { processResponse } from './ProviderProcessor' -import { - buildProviderLogDto, - saveOrPublishProviderLogs, -} from './ProviderProcessor/saveOrPublishProviderLogs' +import { buildStepExecution } from './buildStep' export type CachedApiKeys = Map export type SomeChain = LegacyChain | PromptlChain @@ -77,6 +78,8 @@ type CommonArgs = { generateUUID?: () => string persistErrors?: T removeSchema?: boolean + extraMessages?: Message[] + previousCount?: number } export type RunChainArgs< T extends boolean, @@ -98,6 +101,8 @@ export async function runChain({ persistErrors = true, generateUUID = generateUUIDIdentifier, removeSchema = false, + extraMessages, + previousCount = 0, }: RunChainArgs) { const errorableUuid = generateUUID() @@ -121,6 +126,8 @@ export async function runChain({ errorableType, configOverrides, removeSchema, + extraMessages, + previousCount, }) .then((okResponse) => { responseResolve(Result.ok(okResponse)) @@ -147,21 +154,34 @@ export async function runChain({ } } -async function runStep({ - workspace, - source, +function getNextStepCount({ + stepCount, chain, - promptlVersion, - providersMap, - controller, - previousCount = 0, - previousResponse, - errorableUuid, - errorableType, - configOverrides, - removeSchema, - stepCount = 0, + step, }: { + stepCount: number + chain: SomeChain + step: ValidatedChainStep +}) { + const maxSteps = Math.min( + (step.conversation.config[MAX_STEPS_CONFIG_NAME] as number | undefined) ?? 
+ DEFAULT_MAX_STEPS, + ABSOLUTE_MAX_STEPS, + ) + const exceededMaxSteps = + chain instanceof PromptlChain ? stepCount >= maxSteps : stepCount > maxSteps + + if (!exceededMaxSteps) return Result.ok(++stepCount) + + return Result.error( + new ChainError({ + message: stepLimitExceededErrorMessage(maxSteps), + code: RunErrorCodes.MaxStepCountExceededError, + }), + ) +} + +export type StepProps = { workspace: Workspace source: LogSources chain: SomeChain @@ -172,11 +192,34 @@ async function runStep({ errorableType?: ErrorableEntity previousCount?: number previousResponse?: ChainStepResponse - configOverrides?: ConfigOverrides removeSchema?: boolean stepCount?: number -}) { - const prevText = previousResponse?.text + configOverrides?: ConfigOverrides + extraMessages?: Message[] +} + +export async function runStep({ + workspace, + source, + chain, + promptlVersion, + providersMap, + controller, + previousCount = 0, + previousResponse, + errorableUuid, + errorableType, + configOverrides, + extraMessages, + removeSchema, + stepCount = 0, +}: StepProps) { + // When passed extra messages it means we are resuming a conversation + const prevContent = previousResponse + ? previousResponse.text + : extraMessages + ? extraMessages + : undefined const streamConsumer = new ChainStreamConsumer({ controller, previousCount, @@ -186,7 +229,7 @@ async function runStep({ try { const step = await validateChain({ workspace, - prevText, + prevContent, chain, promptlVersion, providersMap, @@ -194,30 +237,20 @@ async function runStep({ removeSchema, }).then((r) => r.unwrap()) + const messages = step.conversation.messages + if (chain instanceof PromptlChain && step.chainCompleted) { streamConsumer.chainCompleted({ step, response: previousResponse!, + finishReason: previousResponse?.finishReason ?? 'stop', }) + previousResponse!.chainCompleted = true return previousResponse! } - const maxSteps = Math.min( - (step.conversation.config[MAX_STEPS_CONFIG_NAME] as number | undefined) ?? - DEFAULT_MAX_STEPS, - ABSOLUTE_MAX_STEPS, - ) - const exceededMaxSteps = - chain instanceof PromptlChain - ? stepCount >= maxSteps - : stepCount > maxSteps - if (exceededMaxSteps) { - throw new ChainError({ - message: stepLimitExceededErrorMessage(maxSteps), - code: RunErrorCodes.MaxStepCountExceededError, - }) - } + const nextStepCount = getNextStepCount({ stepCount, chain, step }).unwrap() const { messageCount, stepStartTime } = streamConsumer.setup(step) @@ -228,39 +261,17 @@ async function runStep({ }) if (cachedResponse) { - const providerLog = await saveOrPublishProviderLogs({ - workspace, - streamType: cachedResponse.streamType, - finishReason: 'stop', // TODO: we probably should add a cached reason here - data: buildProviderLogDto({ - workspace, - source, - provider: step.provider, - conversation: step.conversation, + const { providerLog, executeStep } = await buildStepExecution({ + baseResponse: cachedResponse as ChainStepResponse, + step, + streamConsumer, + providerLogProps: { + streamType: cachedResponse.streamType, + // NOTE: Before we were hardcoding `stop` when cached. + finishReason: cachedResponse.finishReason ?? 
'stop', stepStartTime, - errorableUuid, - response: cachedResponse as ChainStepResponse, - }), - saveSyncProviderLogs: true, // TODO: temp bugfix, it should only save last one syncronously - }) - - const response = { - ...cachedResponse, - providerLog, - documentLogUuid: errorableUuid, - } as ChainStepResponse - - if (step.chainCompleted) { - streamConsumer.chainCompleted({ - step, - response, - }) - - return response - } else { - streamConsumer.stepCompleted(response) - - return runStep({ + }, + stepProps: { workspace, source, chain, @@ -270,28 +281,37 @@ async function runStep({ errorableUuid, errorableType, previousCount: messageCount, - previousResponse: response, + previousResponse: cachedResponse as ChainStepResponse, configOverrides, - stepCount: ++stepCount, - }) - } + stepCount: nextStepCount, + }, + }) + + const finalResponse = { + ...cachedResponse, + providerLog, + documentLogUuid: errorableUuid, + } as ChainStepResponse + + return executeStep({ finalResponse }) } const aiResult = await ai({ - messages: step.conversation.messages, + messages, config: step.config, provider: step.provider, schema: step.schema, output: step.output, }).then((r) => r.unwrap()) - const checkResult = checkValidStream(aiResult) + if (checkResult.error) throw checkResult.error const consumedStream = await consumeStream({ controller, result: aiResult, }) + if (consumedStream.error) throw consumedStream.error const _response = await processResponse({ @@ -299,48 +319,24 @@ async function runStep({ apiProvider: step.provider, config: step.config, errorableUuid, - messages: step.conversation.messages, + messages, source, workspace, startTime: stepStartTime, - }) - - const providerLog = await saveOrPublishProviderLogs({ - workspace, - streamType: aiResult.type, finishReason: consumedStream.finishReason, - data: buildProviderLogDto({ - workspace, - source, - provider: step.provider, - conversation: step.conversation, - stepStartTime, - errorableUuid, - response: _response, - }), - saveSyncProviderLogs: true, // TODO: temp bugfix, shuold only save last one syncronously - }) - - const response = { ..._response, providerLog } - - await setCachedResponse({ - workspace, - config: step.config, - conversation: step.conversation, - response, + chainCompleted: step.chainCompleted, }) - if (step.chainCompleted) { - streamConsumer.chainCompleted({ - step, - response, - }) - - return response - } else { - streamConsumer.stepCompleted(response) - - return runStep({ + const { providerLog, executeStep } = await buildStepExecution({ + step, + streamConsumer, + baseResponse: _response, + providerLogProps: { + streamType: aiResult.type, + finishReason: consumedStream.finishReason, + stepStartTime, + }, + stepProps: { workspace, source, chain, @@ -350,11 +346,21 @@ async function runStep({ providersMap, controller, previousCount: messageCount, - previousResponse: response, + previousResponse: _response, configOverrides, - stepCount: ++stepCount, - }) - } + stepCount: nextStepCount, + }, + }) + + const finalResponse = { ..._response, providerLog } + await setCachedResponse({ + workspace, + config: step.config, + conversation: step.conversation, + response: finalResponse, + }) + + return executeStep({ finalResponse }) } catch (e: unknown) { const error = streamConsumer.chainError(e) throw error diff --git a/packages/core/src/services/commits/runDocumentAtCommit.ts b/packages/core/src/services/commits/runDocumentAtCommit.ts index 868e28681..ceeb91a02 100644 --- a/packages/core/src/services/commits/runDocumentAtCommit.ts +++ 
b/packages/core/src/services/commits/runDocumentAtCommit.ts @@ -17,7 +17,7 @@ import { getResolvedContent } from '../documents' import { buildProvidersMap } from '../providerApiKeys/buildMap' import { RunDocumentChecker } from './RunDocumentChecker' -export async function createDocumentRunResult({ +async function createDocumentRunResult({ workspace, document, commit, @@ -43,6 +43,7 @@ export async function createDocumentRunResult({ response?: ChainStepResponse }) { const durantionInMs = duration ?? 0 + if (publishEvent) { publisher.publishLater({ type: 'documentRun', diff --git a/packages/core/src/services/documentLogs/addMessages/findPausedChain.ts b/packages/core/src/services/documentLogs/addMessages/findPausedChain.ts new file mode 100644 index 000000000..249d3da6a --- /dev/null +++ b/packages/core/src/services/documentLogs/addMessages/findPausedChain.ts @@ -0,0 +1,43 @@ +import { Workspace } from '../../../browser' +import { Result } from '../../../lib' +import { + CommitsRepository, + DocumentLogsRepository, + DocumentVersionsRepository, +} from '../../../repositories' +import { getCachedChain } from '../../chains/chainCache' + +export async function findPausedChain({ + workspace, + documentLogUuid, +}: { + workspace: Workspace + documentLogUuid: string | undefined +}) { + const cachedData = await getCachedChain({ workspace, documentLogUuid }) + if (!cachedData) return undefined + + const logsRepo = new DocumentLogsRepository(workspace.id) + const logResult = await logsRepo.findByUuid(documentLogUuid) + if (logResult.error) return logResult + + const documentLog = logResult.value + const commitsRepo = new CommitsRepository(workspace.id) + const commitResult = await commitsRepo.find(documentLog.commitId) + if (commitResult.error) return commitResult + + const commit = commitResult.value + const documentsRepo = new DocumentVersionsRepository(workspace.id) + const result = await documentsRepo.getDocumentByUuid({ + commitUuid: commit?.uuid, + documentUuid: documentLog.documentUuid, + }) + if (result.error) return result + + return Result.ok({ + document: result.value, + commit, + pausedChainMessages: cachedData.messages, + pausedChain: cachedData.chain, + }) +} diff --git a/packages/core/src/services/documentLogs/addMessages/index.test.ts b/packages/core/src/services/documentLogs/addMessages/index.test.ts index 657b93be7..7bf2e6161 100644 --- a/packages/core/src/services/documentLogs/addMessages/index.test.ts +++ b/packages/core/src/services/documentLogs/addMessages/index.test.ts @@ -458,6 +458,7 @@ describe('addMessages', () => { data: { type: 'chain-complete', documentLogUuid: providerLog.documentLogUuid!, + finishReason: 'stop', config: { provider: 'openai', model: 'gpt-4o', diff --git a/packages/core/src/services/documentLogs/addMessages/index.ts b/packages/core/src/services/documentLogs/addMessages/index.ts index 87f4b4a9d..1d29a2a42 100644 --- a/packages/core/src/services/documentLogs/addMessages/index.ts +++ b/packages/core/src/services/documentLogs/addMessages/index.ts @@ -1,8 +1,34 @@ import type { Message } from '@latitude-data/compiler' import { LogSources, Workspace } from '../../../browser' -import { ProviderLogsRepository } from '../../../repositories' import { addMessages as addMessagesProviderLog } from '../../providerLogs/addMessages' +import { ProviderLogsRepository } from '../../../repositories' +import { findPausedChain } from './findPausedChain' +import { resumeConversation } from './resumeConversation' +import { Result } from '../../../lib' + +type CommonProps = { + 
workspace: Workspace + documentLogUuid: string | undefined + messages: Message[] + source: LogSources +} + +async function addMessagesToCompleteChain({ + workspace, + documentLogUuid, + messages, + source, +}: CommonProps) { + const providerLogRepo = new ProviderLogsRepository(workspace.id) + const providerLogResult = + await providerLogRepo.findLastByDocumentLogUuid(documentLogUuid) + if (providerLogResult.error) return providerLogResult + + const providerLog = providerLogResult.value + + return addMessagesProviderLog({ workspace, providerLog, messages, source }) +} export async function addMessages({ workspace, @@ -15,12 +41,35 @@ export async function addMessages({ messages: Message[] source: LogSources }) { - const providerLogRepo = new ProviderLogsRepository(workspace.id) - const providerLogResult = - await providerLogRepo.findLastByDocumentLogUuid(documentLogUuid) - if (providerLogResult.error) return providerLogResult + if (!documentLogUuid) { + return Result.error(new Error('documentLogUuid is required')) + } - const providerLog = providerLogResult.value + const foundResult = await findPausedChain({ workspace, documentLogUuid }) - return addMessagesProviderLog({ workspace, providerLog, messages, source }) + if (!foundResult) { + // No cached chain found means normal chat behavior + return addMessagesToCompleteChain({ + workspace, + documentLogUuid, + messages, + source, + }) + } + + if (foundResult.error) return foundResult + + const pausedChainData = foundResult.value + + return resumeConversation({ + workspace, + commit: pausedChainData.commit, + document: pausedChainData.document, + pausedChain: pausedChainData.pausedChain, + pausedChainMessages: + pausedChainData.pausedChainMessages as unknown as Message[], + documentLogUuid, + responseMessages: messages, + source, + }) } diff --git a/packages/core/src/services/documentLogs/addMessages/resumeConversation/index.ts b/packages/core/src/services/documentLogs/addMessages/resumeConversation/index.ts new file mode 100644 index 000000000..504053dce --- /dev/null +++ b/packages/core/src/services/documentLogs/addMessages/resumeConversation/index.ts @@ -0,0 +1,100 @@ +import { + Commit, + ErrorableEntity, + DocumentVersion, + LogSources, + Workspace, +} from '../../../../browser' +import { Result } from '../../../../lib' +import { runChain } from '../../../chains/run' +import { buildProvidersMap } from '../../../providerApiKeys/buildMap' +import { Message } from '@latitude-data/compiler' +import { Chain as PromptlChain } from 'promptl-ai' +import { getResolvedContent } from '../../../documents' +import { deleteCachedChain } from '../../../chains/chainCache' + +/** + * What it means to resume a conversation + * :::::::::::::::::::: + * When a paused/cached chain is found in our cache (Redis at the time of writing), + * we assume it is a paused, incomplete conversation. + * To resume it we re-run it, passing `extraMessages` that are handed down to the prompt + * as the response from the previous run. + * + * One use case is tool calling: + * it lets users prepare tool responses when + * the AI returns a tool call. Latitude then adds the tool request and tool response to the + * conversation, so the next time the AI runs it has all the info necessary to continue the conversation.
+ */
+export async function resumeConversation({
+  workspace,
+  document,
+  commit,
+  documentLogUuid,
+  pausedChain,
+  pausedChainMessages,
+  responseMessages,
+  source,
+}: {
+  workspace: Workspace
+  commit: Commit
+  document: DocumentVersion
+  documentLogUuid: string
+  pausedChain: PromptlChain
+  pausedChainMessages: Message[]
+  responseMessages: Message[]
+  source: LogSources
+}) {
+  const resultResolvedContent = await getResolvedContent({
+    workspaceId: workspace.id,
+    document,
+    commit,
+  })
+
+  if (resultResolvedContent.error) return resultResolvedContent
+
+  const resolvedContent = resultResolvedContent.value
+  const errorableType = ErrorableEntity.DocumentLog
+  const providersMap = await buildProvidersMap({
+    workspaceId: workspace.id,
+  })
+  const errorableUuid = documentLogUuid
+
+  let extraMessages = pausedChainMessages.concat(responseMessages)
+  // These are all the messages that the client
+  // has already seen, so we don't want to send them again.
+  const previousCount =
+    pausedChain.globalMessagesCount +
+    pausedChainMessages.length +
+    responseMessages.length
+
+  const run = await runChain({
+    generateUUID: () => errorableUuid,
+    errorableType,
+    workspace,
+    chain: pausedChain,
+    promptlVersion: document.promptlVersion,
+    providersMap,
+    source,
+    extraMessages,
+    previousCount,
+  })
+
+  return Result.ok({
+    stream: run.stream,
+    duration: run.duration,
+    resolvedContent,
+    errorableUuid,
+    response: run.response.then(async (response) => {
+      const isCompleted = response.value?.chainCompleted
+
+      if (isCompleted) {
+        // We delete the cached chain so the next time someone adds a message to
+        // this documentLogUuid it will simply add the message to the conversation.
+        await deleteCachedChain({ workspace, documentLogUuid })
+      }
+
+      return response
+    }),
+  })
+}
diff --git a/packages/core/src/services/providerLogs/addMessages.ts b/packages/core/src/services/providerLogs/addMessages.ts
index e74d4389d..3c4c433e0 100644
--- a/packages/core/src/services/providerLogs/addMessages.ts
+++ b/packages/core/src/services/providerLogs/addMessages.ts
@@ -3,6 +3,7 @@ import { RunErrorCodes } from '@latitude-data/constants/errors'
 
 import {
   buildConversation,
+  buildMessagesFromResponse,
   ChainEvent,
   ChainStepResponse,
   LogSources,
@@ -149,6 +150,8 @@ async function iterate({
     apiProvider: provider,
     messages,
     startTime: stepStartTime,
+    finishReason: consumedStream.finishReason,
+    chainCompleted: true,
   })
 
   const providerLog = await saveOrPublishProviderLogs({
@@ -165,10 +168,12 @@ async function iterate({
       response: _response,
     }),
     saveSyncProviderLogs: true,
+    chainCompleted: true,
   })
 
   const response = { ..._response, providerLog }
+  const responseMessages = buildMessagesFromResponse({ response })
 
   ChainStreamConsumer.chainCompleted({
     controller,
     response,
@@ -177,6 +182,8 @@ async function iterate({
       provider: provider.name,
       model: config.model,
     },
+    finishReason: consumedStream.finishReason,
+    responseMessages,
   })
 
   return response
diff --git a/packages/sdks/typescript/package.json b/packages/sdks/typescript/package.json
index 58b8c2e4b..996166020 100644
--- a/packages/sdks/typescript/package.json
+++ b/packages/sdks/typescript/package.json
@@ -36,12 +36,12 @@
     }
   },
   "dependencies": {
-    "@t3-oss/env-core": "^0.10.1",
     "@latitude-data/telemetry": "workspace:^",
+    "@t3-oss/env-core": "^0.10.1",
     "eventsource-parser": "^2.0.1",
     "node-fetch": "3.3.2",
-    "zod": "^3.23.8",
-    "promptl-ai": "^0.3.5"
+    "promptl-ai": "^0.4.5",
+    "zod": "^3.23.8"
   },
   "devDependencies": {
     "@latitude-data/compiler": "workspace:^",
diff
--git a/packages/sdks/typescript/src/index.ts b/packages/sdks/typescript/src/index.ts index 1e32af172..4f59f52df 100644 --- a/packages/sdks/typescript/src/index.ts +++ b/packages/sdks/typescript/src/index.ts @@ -8,6 +8,7 @@ import { type ChainEventDto, type DocumentLog, type EvaluationResultDto, + type ToolCallResponse, } from '@latitude-data/constants' import { @@ -358,7 +359,8 @@ class Latitude { if (logResponses) { await this.logs.create( prompt.path, - step.messages as unknown as Message[], // Inexistent types incompatibilities between legacy messages and promptl messages + // Inexistent types incompatibilities between legacy messages and promptl messages + step.messages as unknown as Message[], ) } @@ -449,4 +451,5 @@ export type { MessageRole, Options, StreamChainResponse, + ToolCallResponse, } diff --git a/packages/web-ui/package.json b/packages/web-ui/package.json index b531fd953..da1b24a8d 100644 --- a/packages/web-ui/package.json +++ b/packages/web-ui/package.json @@ -21,6 +21,7 @@ "dependencies": { "@latitude-data/compiler": "workspace:^", "@latitude-data/core": "workspace:^", + "@latitude-data/constants": "workspace:^", "@radix-ui/react-checkbox": "^1.1.2", "@radix-ui/react-progress": "^1.1.0", "date-fns": "^3.6.0", diff --git a/packages/web-ui/src/ds/atoms/ClientOnly/index.tsx b/packages/web-ui/src/ds/atoms/ClientOnly/index.tsx index e4bced3a7..6b70ec422 100644 --- a/packages/web-ui/src/ds/atoms/ClientOnly/index.tsx +++ b/packages/web-ui/src/ds/atoms/ClientOnly/index.tsx @@ -2,7 +2,13 @@ import { ReactNode, useEffect, useState } from 'react' -export function ClientOnly({ children }: { children: ReactNode }) { +export function ClientOnly({ + children, + className, +}: { + children: ReactNode + className?: string +}) { const [mounted, setMounted] = useState(false) useEffect(() => { setMounted(true) @@ -12,6 +18,7 @@ export function ClientOnly({ children }: { children: ReactNode }) { // We have a Hydration issue with the inputs because // they come from localStorage and are not available on the server if (!mounted) return null + if (!className) return <>{children} - return <>{children} + return
<div className={className}>{children}</div>
} diff --git a/packages/web-ui/src/ds/atoms/CodeBlock/index.tsx b/packages/web-ui/src/ds/atoms/CodeBlock/index.tsx index c1f585a7c..0a3895506 100644 --- a/packages/web-ui/src/ds/atoms/CodeBlock/index.tsx +++ b/packages/web-ui/src/ds/atoms/CodeBlock/index.tsx @@ -29,6 +29,12 @@ export function CodeBlock(props: CodeBlockProps) { ) } +export function useCodeBlockBackgroundColor() { + const { resolvedTheme } = useTheme() + if (resolvedTheme === CurrentTheme.Light) return 'bg-backgroundCode' + return 'bg-[#282c34]' +} + function Content({ language, children, @@ -36,16 +42,11 @@ function Content({ className, }: CodeBlockProps) { const { resolvedTheme } = useTheme() - + const bgColor = useCodeBlockBackgroundColor() return ( -
+
{copy && ( -
+
)} diff --git a/packages/web-ui/src/ds/molecules/Chat/ChatTextArea/ToolBar/index.tsx b/packages/web-ui/src/ds/molecules/Chat/ChatTextArea/ToolBar/index.tsx new file mode 100644 index 000000000..c09ff58e1 --- /dev/null +++ b/packages/web-ui/src/ds/molecules/Chat/ChatTextArea/ToolBar/index.tsx @@ -0,0 +1,24 @@ +import { Button } from '../../../../atoms' + +export function ToolBar({ + onSubmit, + clearChat, + disabled = false, + submitLabel = 'Send Message', +}: { + onSubmit?: () => void + clearChat?: () => void + disabled?: boolean + submitLabel?: string +}) { + return ( +
+ + +
+ ) +} diff --git a/packages/web-ui/src/ds/molecules/Chat/ChatTextArea/ToolCallForm/Editor/index.tsx b/packages/web-ui/src/ds/molecules/Chat/ChatTextArea/ToolCallForm/Editor/index.tsx new file mode 100644 index 000000000..1e9d3ecdb --- /dev/null +++ b/packages/web-ui/src/ds/molecules/Chat/ChatTextArea/ToolCallForm/Editor/index.tsx @@ -0,0 +1,179 @@ +import { useCallback, useEffect, useRef, useState } from 'react' +import Editor, { Monaco } from '@monaco-editor/react' +import { type editor } from 'monaco-editor' +import { useMonacoSetup } from '../../../../DocumentTextEditor/Editor/useMonacoSetup' +import { TextEditorProps } from './types' + +const LINE_HEIGHT = 18 +const CONTAINER_GUTTER = 10 +function useUpdateEditorHeight({ initialHeight }: { initialHeight: number }) { + const [heightState, setHeight] = useState(initialHeight) + const prevLineCount = useRef(0) + const updateHeight = useCallback((editor: editor.IStandaloneCodeEditor) => { + const el = editor.getDomNode() + if (!el) return + const codeContainer = el.getElementsByClassName( + 'view-lines', + )[0] as HTMLDivElement | null + + if (!codeContainer) return + + setTimeout(() => { + const height = + codeContainer.childElementCount > prevLineCount.current + ? codeContainer.offsetHeight + : codeContainer.childElementCount * LINE_HEIGHT + CONTAINER_GUTTER // fold + prevLineCount.current = codeContainer.childElementCount + + // Max height + if (height >= 200) return + + setHeight(height) + el.style.height = height + 'px' + + editor.layout() + }, 0) + }, []) + return { height: heightState, updateHeight } +} + +function updatePlaceholder({ + collection, + decoration, + editor, + hasPlaceholder, +}: { + collection: editor.IEditorDecorationsCollection + decoration: editor.IModelDeltaDecoration + editor: editor.IStandaloneCodeEditor + monaco: Monaco + hasPlaceholder: boolean +}) { + const model = editor.getModel() + if (!model) return false + + const modelValue = model.getValue() + + if ((modelValue === '' || modelValue === ' ') && !hasPlaceholder) { + collection.append([decoration]) + return true + } else if (hasPlaceholder) { + collection.clear() + return false + } + + return false +} + +export default function TextEditor({ + value, + onChange, + onCmdEnter, + placeholder, +}: TextEditorProps) { + const editorRef = useRef(null) + const { monacoRef, handleEditorWillMount } = useMonacoSetup() + const isMountedRef = useRef(false) + const { height, updateHeight } = useUpdateEditorHeight({ initialHeight: 0 }) + const handleEditorDidMount = useCallback( + (editor: editor.IStandaloneCodeEditor, monaco: Monaco) => { + monacoRef.current = monaco + editorRef.current = editor + let hasPlaceholder = false + const collection = editor.createDecorationsCollection([]) + const decoration = { + range: new monaco.Range(1, 1, 1, 1), + options: { + isWholeLine: true, + className: 'hack-monaco-editor-placeholder', + }, + } + editor.focus() + + // Initial placeholder setup + hasPlaceholder = updatePlaceholder({ + collection, + decoration, + editor, + monaco, + hasPlaceholder, + }) + + editor.onDidChangeModelContent(() => { + if (!isMountedRef.current) return + + hasPlaceholder = updatePlaceholder({ + collection, + decoration, + editor, + monaco, + hasPlaceholder, + }) + }) + + editor.onDidChangeModelDecorations(() => { + updateHeight(editor) + }) + + editor.addCommand(monaco.KeyMod.CtrlCmd | monaco.KeyCode.Enter, () => { + onCmdEnter?.(editor.getValue()) + }) + + isMountedRef.current = true + }, + [updateHeight, isMountedRef, monacoRef, placeholder, 
onCmdEnter, onChange], + ) + + // Refresh onCmdEnter prop callback when it changes + useEffect(() => { + if (!editorRef.current) return + if (!isMountedRef.current) return + if (!monacoRef.current) return + + const monaco = monacoRef.current + const editor = editorRef.current + editor.addCommand(monaco.KeyMod.CtrlCmd | monaco.KeyCode.Enter, () => { + onCmdEnter?.(editor.getValue()) + }) + }, [onCmdEnter]) + return ( + <> + + + + ) +} diff --git a/packages/web-ui/src/ds/molecules/Chat/ChatTextArea/ToolCallForm/Editor/types.ts b/packages/web-ui/src/ds/molecules/Chat/ChatTextArea/ToolCallForm/Editor/types.ts new file mode 100644 index 000000000..05eb25e68 --- /dev/null +++ b/packages/web-ui/src/ds/molecules/Chat/ChatTextArea/ToolCallForm/Editor/types.ts @@ -0,0 +1,6 @@ +export type TextEditorProps = { + value?: string + onChange: (value: string | undefined) => void + onCmdEnter?: (value?: string | undefined) => void + placeholder: string +} diff --git a/packages/web-ui/src/ds/molecules/Chat/ChatTextArea/ToolCallForm/index.tsx b/packages/web-ui/src/ds/molecules/Chat/ChatTextArea/ToolCallForm/index.tsx new file mode 100644 index 000000000..83b712a8d --- /dev/null +++ b/packages/web-ui/src/ds/molecules/Chat/ChatTextArea/ToolCallForm/index.tsx @@ -0,0 +1,241 @@ +'use client' + +import { MouseEvent, lazy, useState, useCallback, useMemo } from 'react' +import { ToolMessage, Message } from '@latitude-data/compiler' +import { + Badge, + Icon, + Text, + CodeBlock, + Tooltip, + ClientOnly, +} from '../../../../atoms' +import { ToolRequest } from '../../../../../lib/versionedMessagesHelpers' +import { buildResponseMessage } from '@latitude-data/core/browser' +import { ToolBar } from '../ToolBar' + +const TextEditor = lazy(() => import('./Editor/index')) + +function generateExampleFunctionCall(toolCall: ToolRequest) { + const args = toolCall.toolArguments + const functionName = toolCall.toolName + const formattedArgs = Object.keys(args).length + ? JSON.stringify(args, null, 2) + : '' + + return `${functionName}(${formattedArgs ? formattedArgs : ''})` +} + +function buildToolResponseMessage({ + value, + toolRequest, +}: { + value: string + toolRequest: ToolRequest +}) { + const toolResponse = { + id: toolRequest.toolCallId, + name: toolRequest.toolName, + result: value, + } + const message = buildResponseMessage<'text'>({ + type: 'text', + data: { + text: undefined, + toolCallResponses: [toolResponse], + }, + }) + return message! as ToolMessage +} + +function ToolEditor({ + value, + toolRequest, + onChange, + onSubmit, + placeholder, + currentToolRequest, + totalToolRequests, +}: { + toolRequest: ToolRequest + placeholder: string + value: string | undefined + currentToolRequest: number + totalToolRequests: number + onChange: (value: string | undefined) => void + onSubmit?: (sentValue?: string | undefined) => void +}) { + const functionCall = useMemo( + () => generateExampleFunctionCall(toolRequest), + [toolRequest], + ) + return ( +
+
+ + + You have{' '} + {totalToolRequests <= 1 ? ( + <> + {totalToolRequests} tool to + respond to + + ) : ( + <> + {currentToolRequest} of{' '} + {totalToolRequests} tools to + respond to + + )} + +
+ } > The Assistant has responded with a Tool Request. On your server, this would mean running the function in your code and sending back the result. However, since we cannot execute your code from the Playground, you must write the expected response for this tool.
+ + {functionCall} + +
+
+ Your tool response + +
+
+ ) +} + +function getValue(value: string | MouseEvent | undefined) { + if (typeof value === 'string') return value.trim() + return undefined +} + +export function ToolCallForm({ + toolRequests, + sendToServer, + addLocalMessages, + clearChat, + disabled, + placeholder, +}: { + toolRequests: ToolRequest[] + placeholder: string + addLocalMessages: (messages: Message[]) => void + sendToServer?: (messages: ToolMessage[]) => void + clearChat?: () => void + disabled?: boolean +}) { + const [currentToolRequestIndex, setCurrentToolRequestIndex] = useState(1) + const [currentToolRequest, setCurrentToolRequest] = useState< + ToolRequest | undefined + >(toolRequests[0]) + const [respondedToolRequests, setRespondedToolRequests] = useState< + ToolMessage[] + >([]) + const [value, setValue] = useState('') + const onChange = useCallback((newValue: string | undefined) => { + setValue(newValue) + }, []) + const onLocalSend = useCallback( + (sentValue?: MouseEvent | string | undefined) => { + const cleanValue = getValue(value || sentValue) + + if (!currentToolRequest) return + if (!cleanValue) return + + const message = buildToolResponseMessage({ + value: cleanValue, + toolRequest: currentToolRequest, + }) + setRespondedToolRequests((prev) => [...prev, message]) + addLocalMessages([message]) + const findNextToolRequest = toolRequests.findIndex( + (tr) => tr.toolCallId === currentToolRequest.toolCallId, + ) + const nextIndex = findNextToolRequest + 1 + const nextToolRequest = toolRequests[nextIndex] + + if (nextToolRequest) { + setCurrentToolRequest(nextToolRequest) + setCurrentToolRequestIndex(nextIndex) + } + setValue('') + }, + [ + currentToolRequest, + addLocalMessages, + value, + setCurrentToolRequest, + setCurrentToolRequestIndex, + setRespondedToolRequests, + toolRequests, + ], + ) + + const onServerSend = useCallback( + (sentValue?: string | undefined) => { + const cleanValue = getValue(value || sentValue) + + if (!currentToolRequest) return + if (!cleanValue) return + + const message = buildToolResponseMessage({ + value: cleanValue, + toolRequest: currentToolRequest, + }) + addLocalMessages([message]) + const allToolMessages = respondedToolRequests.concat([message]) + sendToServer?.(allToolMessages) + }, + [ + addLocalMessages, + respondedToolRequests, + sendToServer, + currentToolRequest, + value, + ], + ) + + if (!currentToolRequest) return null + + const isLastRequest = + currentToolRequest.toolCallId === + toolRequests[toolRequests.length - 1]?.toolCallId + + const onSubmitHandler = isLastRequest ? onServerSend : onLocalSend + + return ( + + +
+ +
+
+ ) +} diff --git a/packages/web-ui/src/ds/molecules/Chat/ChatTextArea/index.tsx b/packages/web-ui/src/ds/molecules/Chat/ChatTextArea/index.tsx index 8ad04fa17..b66d56e62 100644 --- a/packages/web-ui/src/ds/molecules/Chat/ChatTextArea/index.tsx +++ b/packages/web-ui/src/ds/molecules/Chat/ChatTextArea/index.tsx @@ -4,12 +4,16 @@ import { KeyboardEvent, useCallback, useState } from 'react' import TextareaAutosize from 'react-textarea-autosize' -import { Button } from '../../../atoms' +import { ToolRequest } from '../../../../lib/versionedMessagesHelpers' +import { ToolCallForm } from './ToolCallForm' +import { Message, ToolMessage } from '@latitude-data/compiler' +import { cn } from '../../../../lib/utils' +import { ToolBar } from './ToolBar' -export function ChatTextArea({ +function SimpleTextArea({ placeholder, clearChat, - onSubmit, + onSubmit: onSubmitProp, disabled = false, }: { placeholder: string @@ -18,26 +22,23 @@ export function ChatTextArea({ onSubmit?: (value: string) => void }) { const [value, setValue] = useState('') - - const handleSubmit = useCallback(() => { + const onSubmit = useCallback(() => { if (disabled) return if (value === '') return setValue('') - onSubmit?.(value) - }, [value, onSubmit, disabled]) - + onSubmitProp?.(value) + }, [value, onSubmitProp, disabled]) const handleKeyDown = useCallback( (e: KeyboardEvent) => { if (e.key === 'Enter' && !e.shiftKey) { e.preventDefault() - handleSubmit() + onSubmit() } }, - [handleSubmit], + [onSubmit], ) - return ( -
+ <> -
- - +
+
+ + ) +} + +type OnSubmitWithTools = (value: string | ToolMessage[]) => void +type OnSubmit = (value: string) => void +export function ChatTextArea({ + placeholder, + clearChat, + onSubmit, + disabled = false, + toolRequests = [], + addMessages, +}: { + placeholder: string + clearChat: () => void + disabled?: boolean + onSubmit?: OnSubmit | OnSubmitWithTools + addMessages?: (messages: Message[]) => void + toolRequests?: ToolRequest[] +}) { + return ( +
0, + }, + )} + > + {toolRequests.length > 0 && addMessages ? ( + + ) : ( + + )}
) } diff --git a/packages/web-ui/src/ds/molecules/DocumentTextEditor/Editor/RegularEditor.tsx b/packages/web-ui/src/ds/molecules/DocumentTextEditor/Editor/RegularEditor.tsx index 8a23ae538..4298ef9e7 100644 --- a/packages/web-ui/src/ds/molecules/DocumentTextEditor/Editor/RegularEditor.tsx +++ b/packages/web-ui/src/ds/molecules/DocumentTextEditor/Editor/RegularEditor.tsx @@ -34,7 +34,8 @@ export function RegularMonacoEditor({ const { monacoRef, handleEditorWillMount } = useMonacoSetup({ errorFixFn }) const [defaultValue, _] = useState(value) - const [isEditorMounted, setIsEditorMounted] = useState(false) // to avoid race conditions + // to avoid race conditions + const [isEditorMounted, setIsEditorMounted] = useState(false) const { options } = useEditorOptions({ tabSize: 2, readOnly: !!readOnlyMessage, diff --git a/packages/web-ui/src/ds/molecules/DocumentTextEditor/Editor/useMonacoSetup.ts b/packages/web-ui/src/ds/molecules/DocumentTextEditor/Editor/useMonacoSetup.ts index 063b7762d..417fcc74b 100644 --- a/packages/web-ui/src/ds/molecules/DocumentTextEditor/Editor/useMonacoSetup.ts +++ b/packages/web-ui/src/ds/molecules/DocumentTextEditor/Editor/useMonacoSetup.ts @@ -41,6 +41,7 @@ export function useMonacoSetup({ useEffect(() => { if (!monacoRef.current) return + applyTheme(monacoRef.current) }, [applyTheme]) diff --git a/packages/web-ui/src/index.ts b/packages/web-ui/src/index.ts index 258ef642c..a774e8b4c 100644 --- a/packages/web-ui/src/index.ts +++ b/packages/web-ui/src/index.ts @@ -10,4 +10,5 @@ export * from './lib/getUserInfo' export * from './lib/hooks/useAutoScroll' export * from './lib/hooks/useLocalStorage' export * from './lib/commonTypes' +export * from './lib/versionedMessagesHelpers' export * from './ds/atoms/ChartBlankSlate' diff --git a/packages/web-ui/src/lib/versionedMessagesHelpers.ts b/packages/web-ui/src/lib/versionedMessagesHelpers.ts new file mode 100644 index 000000000..6b2aa554d --- /dev/null +++ b/packages/web-ui/src/lib/versionedMessagesHelpers.ts @@ -0,0 +1,85 @@ +import { + Message as CompilerMessage, + ContentType as CompilerContentType, + MessageRole as CompilerMessageRole, + ToolRequestContent as CompilerToolRequestContent, +} from '@latitude-data/compiler' +import { ToolCallResponse as ToolResponse } from '@latitude-data/constants' +import { + Message as PromptlMessage, + ToolCallContent as ToolRequest, + ContentType as PromptlContentType, + MessageRole as PromptlMessageRole, +} from 'promptl-ai' + +export type PromptlVersion = 0 | 1 +export type VersionedMessage = V extends 0 + ? CompilerMessage + : PromptlMessage +type ToolResponsePart = Pick + +export type ToolPart = ToolRequest | ToolResponsePart + +function extractCompilerToolContents(messages: CompilerMessage[]): ToolPart[] { + return messages.flatMap((message) => { + if (message.role === CompilerMessageRole.tool) { + return message.content + .filter((content) => { + return content.type === CompilerContentType.toolResult + }) + .map((content) => ({ + id: content.toolCallId, + })) + } + + if (message.role !== CompilerMessageRole.assistant) return [] + if (typeof message.content === 'string') return [] + + const content = Array.isArray(message.content) + ? 
message.content + : [message.content] + + const toolRequestContents = content.filter((content) => { + return content.type === CompilerContentType.toolCall + }) as CompilerToolRequestContent[] + + return toolRequestContents.map((content) => ({ + type: PromptlContentType.toolCall, + toolCallId: content.toolCallId, + toolName: content.toolName, + toolArguments: content.args, + })) + }) +} + +function extractPromptlToolContents(messages: PromptlMessage[]): ToolPart[] { + return messages.flatMap((message) => { + if (message.role === PromptlMessageRole.tool) { + return [{ id: message.toolId }] + } + + if (message.role !== PromptlMessageRole.assistant) return [] + if (typeof message.content === 'string') return [] + + const content = Array.isArray(message.content) + ? message.content + : [message.content] + return content.filter((content) => { + return content.type === PromptlContentType.toolCall + }) + }) +} + +export function extractToolContents({ + version, + messages, +}: { + version: V + messages: VersionedMessage[] +}) { + return version === 0 + ? extractCompilerToolContents(messages as CompilerMessage[]) + : extractPromptlToolContents(messages as PromptlMessage[]) +} + +export type { ToolRequest } diff --git a/pnpm-lock.yaml b/pnpm-lock.yaml index 6b071821a..01801575e 100644 --- a/pnpm-lock.yaml +++ b/pnpm-lock.yaml @@ -260,8 +260,8 @@ importers: specifier: ^4.17.21 version: 4.17.21 promptl-ai: - specifier: ^0.3.5 - version: 0.3.5 + specifier: ^0.4.5 + version: 0.4.5 rate-limiter-flexible: specifier: ^5.0.3 version: 5.0.4 @@ -427,8 +427,8 @@ importers: specifier: ^1.161.6 version: 1.176.0 promptl-ai: - specifier: ^0.3.5 - version: 0.3.5 + specifier: ^0.4.5 + version: 0.4.5 rate-limiter-flexible: specifier: ^5.0.3 version: 5.0.4 @@ -773,8 +773,8 @@ importers: specifier: ^1.0.5 version: 1.0.5 promptl-ai: - specifier: ^0.3.5 - version: 0.3.5 + specifier: ^0.4.5 + version: 0.4.5 devDependencies: '@ai-sdk/anthropic': specifier: ^1.0.5 @@ -993,8 +993,8 @@ importers: specifier: 3.3.2 version: 3.3.2 promptl-ai: - specifier: ^0.3.5 - version: 0.3.5 + specifier: ^0.4.5 + version: 0.4.5 typescript: specifier: ^5.5.4 version: 5.7.2 @@ -1186,6 +1186,9 @@ importers: '@latitude-data/compiler': specifier: workspace:^ version: link:../compiler + '@latitude-data/constants': + specifier: workspace:^ + version: link:../constants '@latitude-data/core': specifier: workspace:^ version: link:../core @@ -11276,6 +11279,9 @@ packages: promptl-ai@0.3.5: resolution: {integrity: sha512-g3RtXgCOtMky68rkdalGGApO85DBZvL65gJy7SzZ+D+eEpCB2hR+7q2G67y+s66AHDzG2DM+sU+1WTRbDCrfyg==} + promptl-ai@0.4.5: + resolution: {integrity: sha512-bJkWlLtyxJkxt28+InTtKF8e4OZ67J9QtkOnI9xKBeEv6ILRFHcU09quTbCFi0wCL+6lw6okL2tV1yix21Jf6g==} + prop-types@15.8.1: resolution: {integrity: sha512-oj87CgZICdulUohogVAR7AjlC0327U4el4L6eAvOqCeudMDVU0NThNaV+b9Df4dXgSP1gXMTnPdhfe/2qDH5cg==} @@ -12575,10 +12581,6 @@ packages: resolution: {integrity: sha512-4dbzIzqvjtgiM5rw1k5rEHtBANKmdudhGyBEajN01fEyhaAIhsoKNy6y7+IN93IfpFtwY9iqi7kD+xwKhQsNJA==} engines: {node: '>=8'} - type-fest@4.26.1: - resolution: {integrity: sha512-yOGpmOAL7CkKe/91I5O3gPICmJNLJ1G4zFYVAsRHg7M64biSnPtRj0WNQt++bRkjYOqjWXrhnUw1utzmVErAdg==} - engines: {node: '>=16'} - type-fest@4.30.0: resolution: {integrity: sha512-G6zXWS1dLj6eagy6sVhOMQiLtJdxQBHIA9Z6HFUNLOlr6MFOgzV8wvmidtPONfPtEUv0uZsy77XJNzTAfwPDaA==} engines: {node: '>=16'} @@ -19635,7 +19637,7 @@ snapshots: '@types/body-parser@1.19.5': dependencies: '@types/connect': 3.4.38 - '@types/node': 22.10.1 + '@types/node': 20.17.10 
'@types/bytes@3.1.4': {} @@ -19654,7 +19656,7 @@ snapshots: '@types/connect@3.4.38': dependencies: - '@types/node': 22.10.1 + '@types/node': 20.17.10 '@types/cookie-parser@1.4.7': dependencies: @@ -19708,7 +19710,7 @@ snapshots: '@types/express-serve-static-core@4.19.6': dependencies: - '@types/node': 22.10.1 + '@types/node': 20.17.10 '@types/qs': 6.9.17 '@types/range-parser': 1.2.7 '@types/send': 0.17.4 @@ -19908,12 +19910,12 @@ snapshots: '@types/send@0.17.4': dependencies: '@types/mime': 1.3.5 - '@types/node': 22.10.1 + '@types/node': 20.17.10 '@types/serve-static@1.15.7': dependencies: '@types/http-errors': 2.0.4 - '@types/node': 22.10.1 + '@types/node': 20.17.10 '@types/send': 0.17.4 '@types/shimmer@1.2.0': {} @@ -25352,7 +25354,7 @@ snapshots: outvariant: 1.4.3 path-to-regexp: 6.3.0 strict-event-emitter: 0.5.1 - type-fest: 4.26.1 + type-fest: 4.30.0 yargs: 17.7.2 optionalDependencies: typescript: 5.7.2 @@ -25377,7 +25379,7 @@ snapshots: outvariant: 1.4.3 path-to-regexp: 6.3.0 strict-event-emitter: 0.5.1 - type-fest: 4.26.1 + type-fest: 4.30.0 yargs: 17.7.2 optionalDependencies: typescript: 5.7.2 @@ -25402,7 +25404,7 @@ snapshots: outvariant: 1.4.3 path-to-regexp: 6.3.0 strict-event-emitter: 0.5.1 - type-fest: 4.26.1 + type-fest: 4.30.0 yargs: 17.7.2 optionalDependencies: typescript: 5.7.2 @@ -26462,6 +26464,14 @@ snapshots: yaml: 2.6.0 zod: 3.23.8 + promptl-ai@0.4.5: + dependencies: + acorn: 8.14.0 + code-red: 1.0.4 + locate-character: 3.0.0 + yaml: 2.6.0 + zod: 3.23.8 + prop-types@15.8.1: dependencies: loose-envify: 1.4.0 @@ -28342,8 +28352,6 @@ snapshots: type-fest@0.8.1: {} - type-fest@4.26.1: {} - type-fest@4.30.0: {} type-is@1.6.18: