Skip to content

Commit

Permalink
Handle tool calls in SDK and playground
Browse files Browse the repository at this point in the history
We have a task to allow users give a response on one or more tool calls
that the AI can request in our playground. The thing is that the work
necessary to make this work is also necessary for making tool call work
on our SDK so this PR is a spike to have a picture of what are the
moving parts necessary to make this happen
  • Loading branch information
andresgutgon committed Jan 20, 2025
1 parent b3f233b commit 97f0894
Show file tree
Hide file tree
Showing 49 changed files with 1,561 additions and 363 deletions.
2 changes: 1 addition & 1 deletion apps/gateway/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
"drizzle-orm": "^0.33.0",
"hono": "^4.6.6",
"lodash-es": "^4.17.21",
"promptl-ai": "^0.3.5",
"promptl-ai": "^0.4.5",
"rate-limiter-flexible": "^5.0.3",
"zod": "^3.23.8"
},
Expand Down
2 changes: 1 addition & 1 deletion apps/web/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@
"oslo": "1.2.0",
"pdfjs-dist": "^4.9.155",
"posthog-js": "^1.161.6",
"promptl-ai": "^0.3.5",
"promptl-ai": "^0.4.5",
"rate-limiter-flexible": "^5.0.3",
"react": "19.0.0-rc-5d19e1c8-20240923",
"react-dom": "19.0.0-rc-5d19e1c8-20240923",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,7 @@ export default function Chat({
) : (
<StreamMessage
responseStream={responseStream}
conversation={conversation}
messages={conversation?.messages ?? []}
chainLength={chainLength}
/>
)}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@ import { useCallback, useContext, useEffect, useRef, useState } from 'react'

import {
ContentType,
Conversation,
Message as ConversationMessage,
MessageRole,
ToolMessage,
} from '@latitude-data/compiler'
import {
ChainEventTypes,
Expand All @@ -31,15 +31,31 @@ import { readStreamableValue } from 'ai/rsc'

import { DocumentEditorContext } from '..'
import Actions, { ActionsState } from './Actions'
import { useMessages } from './useMessages'
import { type PromptlVersion } from '@latitude-data/web-ui'

export default function Chat({
function buildMessage({ input }: { input: string | ToolMessage[] }) {
if (typeof input === 'string') {
return [
{
role: MessageRole.user,
content: [{ type: ContentType.text, text: input }],
} as ConversationMessage,
]
}
return input
}

export default function Chat<V extends PromptlVersion>({
document,
promptlVersion,
parameters,
clearChat,
expandParameters,
setExpandParameters,
}: {
document: DocumentVersion
promptlVersion: V
parameters: Record<string, unknown>
clearChat: () => void
} & ActionsState) {
Expand All @@ -62,25 +78,11 @@ export default function Chat({
const runChainOnce = useRef(false)
// Index where the chain ends and the chat begins
const [chainLength, setChainLength] = useState<number>(Infinity)
const [conversation, setConversation] = useState<Conversation | undefined>()
const [responseStream, setResponseStream] = useState<string | undefined>()
const [isStreaming, setIsStreaming] = useState(false)

const addMessageToConversation = useCallback(
(message: ConversationMessage) => {
let newConversation: Conversation
setConversation((prevConversation) => {
newConversation = {
...prevConversation,
messages: [...(prevConversation?.messages ?? []), message],
} as Conversation
return newConversation
})
return newConversation!
},
[],
)

const { messages, addMessages, unresponedToolCalls } = useMessages<V>({
version: promptlVersion,
})
const startStreaming = useCallback(() => {
setError(undefined)
setUsage({
Expand Down Expand Up @@ -122,7 +124,7 @@ export default function Chat({

if ('messages' in data) {
setResponseStream(undefined)
data.messages!.forEach(addMessageToConversation)
addMessages(data.messages ?? [])
messagesCount += data.messages!.length
}

Expand All @@ -148,6 +150,7 @@ export default function Chat({
}
break
}

default:
break
}
Expand All @@ -163,7 +166,7 @@ export default function Chat({
commit.uuid,
parameters,
runDocumentAction,
addMessageToConversation,
addMessages,
startStreaming,
stopStreaming,
])
Expand All @@ -176,22 +179,23 @@ export default function Chat({
}, [runDocument])

const submitUserMessage = useCallback(
async (input: string) => {
async (input: string | ToolMessage[]) => {
if (!documentLogUuid) return // This should not happen

const message: ConversationMessage = {
role: MessageRole.user,
content: [{ type: ContentType.text, text: input }],
const newMessages = buildMessage({ input })

// Only in Chat mode we add optimistically the message
if (typeof input === 'string') {
addMessages(newMessages)
}
addMessageToConversation(message)

let response = ''
startStreaming()

try {
const { output } = await addMessagesAction({
documentLogUuid,
messages: [message],
messages: newMessages,
})

for await (const serverEvent of readStreamableValue(output)) {
Expand All @@ -201,7 +205,7 @@ export default function Chat({

if ('messages' in data) {
setResponseStream(undefined)
data.messages!.forEach(addMessageToConversation)
addMessages(data.messages ?? [])
}

switch (event) {
Expand All @@ -222,6 +226,7 @@ export default function Chat({
}
break
}

default:
break
}
Expand All @@ -235,7 +240,7 @@ export default function Chat({
[
documentLogUuid,
addMessagesAction,
addMessageToConversation,
addMessages,
startStreaming,
stopStreaming,
],
Expand All @@ -255,32 +260,30 @@ export default function Chat({
className='flex flex-col gap-3 flex-grow flex-shrink min-h-0 custom-scrollbar scrollable-indicator pb-12'
>
<MessageList
messages={conversation?.messages.slice(0, chainLength - 1) ?? []}
messages={messages.slice(0, chainLength - 1) ?? []}
parameters={Object.keys(parameters)}
collapseParameters={!expandParameters}
/>
{(conversation?.messages.length ?? 0) >= chainLength && (
{(messages.length ?? 0) >= chainLength && (
<>
<MessageList
messages={
conversation?.messages.slice(chainLength - 1, chainLength) ?? []
}
messages={messages.slice(chainLength - 1, chainLength) ?? []}
/>
{time && <Timer timeMs={time} />}
</>
)}
{(conversation?.messages.length ?? 0) > chainLength && (
{(messages.length ?? 0) > chainLength && (
<>
<Text.H6M>Chat</Text.H6M>
<MessageList messages={conversation!.messages.slice(chainLength)} />
<MessageList messages={messages.slice(chainLength)} />
</>
)}
{error ? (
<ErrorMessage error={error} />
) : (
<StreamMessage
responseStream={responseStream}
conversation={conversation}
messages={messages}
chainLength={chainLength}
/>
)}
Expand All @@ -296,6 +299,8 @@ export default function Chat({
placeholder='Enter followup message...'
disabled={isStreaming}
onSubmit={submitUserMessage}
toolRequests={unresponedToolCalls}
addMessages={addMessages}
/>
</div>
</div>
Expand Down Expand Up @@ -342,12 +347,8 @@ export function TokenUsage({
}
>
<div className='flex flex-col gap-2'>
<Text.H6M color='foregroundMuted'>
{usage?.promptTokens || 0} prompt tokens
</Text.H6M>
<Text.H6M color='foregroundMuted'>
{usage?.completionTokens || 0} completion tokens
</Text.H6M>
<span>{usage?.promptTokens || 0} prompt tokens</span>
<span>{usage?.completionTokens || 0} completion tokens</span>
</div>
</Tooltip>
) : (
Expand All @@ -359,16 +360,15 @@ export function TokenUsage({

export function StreamMessage({
responseStream,
conversation,
messages,
chainLength,
}: {
responseStream: string | undefined
conversation: Conversation | undefined
messages: ConversationMessage[]
chainLength: number
}) {
if (responseStream === undefined) return null
if (conversation === undefined) return null
if (conversation.messages.length < chainLength - 1) {
if (messages.length < chainLength - 1) {
return (
<Message
role={MessageRole.assistant}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
'use client'

import { useState } from 'react'

import { useDocumentParameters } from '$/hooks/useDocumentParameters'
Expand Down Expand Up @@ -30,6 +28,7 @@ export default function Playground({
setPrompt: (prompt: string) => void
metadata: ConversationMetadata
}) {
const promptlVersion = document.promptlVersion === 1 ? 1 : 0
const [mode, setMode] = useState<'preview' | 'chat'>('preview')
const { commit } = useCurrentCommit()
const [expanded, setExpanded] = useState(true)
Expand Down Expand Up @@ -77,6 +76,7 @@ export default function Playground({
) : (
<Chat
document={document}
promptlVersion={promptlVersion}
parameters={parameters}
clearChat={() => setMode('preview')}
expandParameters={expandParameters}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
import {
ToolPart,
ToolRequest,
PromptlVersion,
VersionedMessage,
extractToolContents,
} from '@latitude-data/web-ui'
import { Message as CompilerMessage } from '@latitude-data/compiler'
import { useCallback, useState } from 'react'

function isToolRequest(part: ToolPart): part is ToolRequest {
return 'toolCallId' in part
}

function getUnrespondedToolRequests<V extends PromptlVersion>({
version: _,
messages,
}: {
version: V
messages: VersionedMessage<V>[]
}) {
// FIXME: Kill compiler please. I made this module compatible with
// both old compiler and promptl. But because everything is typed with
// old compiler prompts in promptl version are also formatted as old compiler
const parts = extractToolContents({
version: 0,
messages: messages as VersionedMessage<0>[],
})
const toolRequestIds = new Set<string>()
const toolResponses = new Set<string>()

parts.forEach((part) => {
if (isToolRequest(part)) {
toolRequestIds.add(part.toolCallId)
} else {
toolResponses.add(part.id)
}
})

return parts.filter(
(part): part is ToolRequest =>
isToolRequest(part) && !toolResponses.has(part.toolCallId),
)
}

type Props<V extends PromptlVersion> = { version: V }

export function useMessages<V extends PromptlVersion>({ version }: Props<V>) {
const [messages, setMessages] = useState<VersionedMessage<V>[]>([])
const [unresponedToolCalls, setUnresponedToolCalls] = useState<ToolRequest[]>(
[],
)
// FIXME: Kill compiler please
// every where we have old compiler types. To avoid Typescript crying we
// allow only compiler messages and then transform them to versioned messages
const addMessages = useCallback(
(m: CompilerMessage[]) => {
const msg = m as VersionedMessage<V>[]
setMessages((prevMessages) => {
const newMessages = prevMessages.concat(msg)
setUnresponedToolCalls(
getUnrespondedToolRequests({ version, messages: newMessages }),
)
return newMessages
})
},
[version],
)

return {
addMessages,
messages: messages as CompilerMessage[],
unresponedToolCalls,
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ export function AllMessages({
) : (
<StreamMessage
responseStream={responseStream}
conversation={conversation}
messages={conversation?.messages ?? []}
chainLength={chainLength}
/>
)}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ export function ChatMessages({
) : (
<StreamMessage
responseStream={responseStream}
conversation={conversation}
messages={conversation.messages}
chainLength={chainLength}
/>
)}
Expand Down
2 changes: 1 addition & 1 deletion packages/compiler/src/types/message.ts
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ export type AssistantMessage = {

export type ToolMessage = {
role: MessageRole.tool
content: ToolContent[]
content: (TextContent | ToolContent)[]
[key: string]: unknown
}

Expand Down
Loading

0 comments on commit 97f0894

Please sign in to comment.