Skip to content

Commit

Permalink
Handle tool calls in SDK and playground
Browse files Browse the repository at this point in the history
We have a task to allow users give a response on one or more tool calls
that the AI can request in our playground. The thing is that the work
necessary to make this work is also necessary for making tool call work
on our SDK so this PR is a spike to have a picture of what are the
moving parts necessary to make this happen
  • Loading branch information
andresgutgon committed Jan 16, 2025
1 parent 76830b5 commit 1afb6cd
Show file tree
Hide file tree
Showing 40 changed files with 1,248 additions and 296 deletions.
2 changes: 1 addition & 1 deletion apps/gateway/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
"drizzle-orm": "^0.33.0",
"hono": "^4.6.6",
"lodash-es": "^4.17.21",
"promptl-ai": "^0.3.5",
"promptl-ai": "file:/Users/andres/code/latitude/promptl",
"rate-limiter-flexible": "^5.0.3",
"zod": "^3.23.8"
},
Expand Down
2 changes: 1 addition & 1 deletion apps/web/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@
"oslo": "1.2.0",
"pdfjs-dist": "^4.9.155",
"posthog-js": "^1.161.6",
"promptl-ai": "^0.3.5",
"promptl-ai": "file:/Users/andres/code/latitude/promptl",
"rate-limiter-flexible": "^5.0.3",
"react": "19.0.0-rc-5d19e1c8-20240923",
"react-dom": "19.0.0-rc-5d19e1c8-20240923",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,13 @@ import { useCallback, useContext, useEffect, useRef, useState } from 'react'

import {
ContentType,
Conversation,
Message as ConversationMessage,
MessageRole,
} from '@latitude-data/compiler'
import {
ChainEventTypes,
StreamEventTypes,
buildResponseMessage,
type DocumentVersion,
} from '@latitude-data/core/browser'
import {
Expand All @@ -26,20 +26,42 @@ import {
useCurrentCommit,
useCurrentProject,
} from '@latitude-data/web-ui'
import { ToolCallResponse } from '@latitude-data/constants'
import { LanguageModelUsage } from 'ai'
import { readStreamableValue } from 'ai/rsc'

import { DocumentEditorContext } from '..'
import Actions, { ActionsState } from './Actions'
import { useMessages } from './useMessages'
import { type PromptlVersion } from '$/lib/versionedMessagesHelpers'

function buildMessage({ input }: { input: string | ToolCallResponse }) {
if (typeof input === 'string') {
return {
role: MessageRole.user,
content: [{ type: ContentType.text, text: input }],
} as ConversationMessage
}

const toolMessage = buildResponseMessage<'text'>({
type: 'text',
data: { text: input.text, toolCallResponses: [input] },
})

// We asume it can not be null
return toolMessage!
}

export default function Chat({
export default function Chat<V extends PromptlVersion>({
document,
promptlVersion,
parameters,
clearChat,
expandParameters,
setExpandParameters,
}: {
document: DocumentVersion
promptlVersion: V
parameters: Record<string, unknown>
clearChat: () => void
} & ActionsState) {
Expand All @@ -62,25 +84,11 @@ export default function Chat({
const runChainOnce = useRef(false)
// Index where the chain ends and the chat begins
const [chainLength, setChainLength] = useState<number>(Infinity)
const [conversation, setConversation] = useState<Conversation | undefined>()
const [responseStream, setResponseStream] = useState<string | undefined>()
const [isStreaming, setIsStreaming] = useState(false)

const addMessageToConversation = useCallback(
(message: ConversationMessage) => {
let newConversation: Conversation
setConversation((prevConversation) => {
newConversation = {
...prevConversation,
messages: [...(prevConversation?.messages ?? []), message],
} as Conversation
return newConversation
})
return newConversation!
},
[],
)

const { messages, addMessages, unresponedToolCalls } = useMessages<V>({
version: promptlVersion,
})
const startStreaming = useCallback(() => {
setError(undefined)
setUsage({
Expand Down Expand Up @@ -122,7 +130,7 @@ export default function Chat({

if ('messages' in data) {
setResponseStream(undefined)
data.messages!.forEach(addMessageToConversation)
addMessages(data.messages ?? [])
messagesCount += data.messages!.length
}

Expand All @@ -148,6 +156,7 @@ export default function Chat({
}
break
}

default:
break
}
Expand All @@ -163,7 +172,7 @@ export default function Chat({
commit.uuid,
parameters,
runDocumentAction,
addMessageToConversation,
addMessages,
startStreaming,
stopStreaming,
])
Expand All @@ -176,14 +185,15 @@ export default function Chat({
}, [runDocument])

const submitUserMessage = useCallback(
async (input: string) => {
async (input: string | ToolCallResponse) => {
if (!documentLogUuid) return // This should not happen

const message: ConversationMessage = {
role: MessageRole.user,
content: [{ type: ContentType.text, text: input }],
const message = buildMessage({ input })

// Only in Chat mode we add optimistically the message
if (message.role === MessageRole.user) {
addMessages([message])
}
addMessageToConversation(message)

let response = ''
startStreaming()
Expand All @@ -201,7 +211,7 @@ export default function Chat({

if ('messages' in data) {
setResponseStream(undefined)
data.messages!.forEach(addMessageToConversation)
addMessages(data.messages ?? [])
}

switch (event) {
Expand All @@ -222,6 +232,7 @@ export default function Chat({
}
break
}

default:
break
}
Expand All @@ -235,7 +246,7 @@ export default function Chat({
[
documentLogUuid,
addMessagesAction,
addMessageToConversation,
addMessages,
startStreaming,
stopStreaming,
],
Expand All @@ -255,32 +266,30 @@ export default function Chat({
className='flex flex-col gap-3 flex-grow flex-shrink min-h-0 custom-scrollbar scrollable-indicator pb-12'
>
<MessageList
messages={conversation?.messages.slice(0, chainLength - 1) ?? []}
messages={messages.slice(0, chainLength - 1) ?? []}
parameters={Object.keys(parameters)}
collapseParameters={!expandParameters}
/>
{(conversation?.messages.length ?? 0) >= chainLength && (
{(messages.length ?? 0) >= chainLength && (
<>
<MessageList
messages={
conversation?.messages.slice(chainLength - 1, chainLength) ?? []
}
messages={messages.slice(chainLength - 1, chainLength) ?? []}
/>
{time && <Timer timeMs={time} />}
</>
)}
{(conversation?.messages.length ?? 0) > chainLength && (
{(messages.length ?? 0) > chainLength && (
<>
<Text.H6M>Chat</Text.H6M>
<MessageList messages={conversation!.messages.slice(chainLength)} />
<MessageList messages={messages.slice(chainLength)} />
</>
)}
{error ? (
<ErrorMessage error={error} />
) : (
<StreamMessage
responseStream={responseStream}
conversation={conversation}
messages={messages}
chainLength={chainLength}
/>
)}
Expand Down Expand Up @@ -359,16 +368,15 @@ export function TokenUsage({

export function StreamMessage({
responseStream,
conversation,
messages,
chainLength,
}: {
responseStream: string | undefined
conversation: Conversation | undefined
messages: ConversationMessage[]
chainLength: number
}) {
if (responseStream === undefined) return null
if (conversation === undefined) return null
if (conversation.messages.length < chainLength - 1) {
if (messages.length < chainLength - 1) {
return (
<Message
role={MessageRole.assistant}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
'use client'

import { useState } from 'react'

import { useDocumentParameters } from '$/hooks/useDocumentParameters'
Expand Down Expand Up @@ -30,6 +28,7 @@ export default function Playground({
setPrompt: (prompt: string) => void
metadata: ConversationMetadata
}) {
const promptlVersion = document.promptlVersion === 1 ? 1 : 0
const [mode, setMode] = useState<'preview' | 'chat'>('preview')
const { commit } = useCurrentCommit()
const [expanded, setExpanded] = useState(true)
Expand Down Expand Up @@ -77,6 +76,7 @@ export default function Playground({
) : (
<Chat
document={document}
promptlVersion={promptlVersion}
parameters={parameters}
clearChat={() => setMode('preview')}
expandParameters={expandParameters}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
import {
ToolPart,
ToolRequest,
PromptlVersion,
VersionedMessage,
extractToolContents,
} from '$/lib/versionedMessagesHelpers'
import { Message as CompilerMessage } from '@latitude-data/compiler'
import { useCallback, useState } from 'react'

function isToolRequest(part: ToolPart): part is ToolRequest {
return 'toolCallId' in part
}

function getUnrespondedToolRequests<V extends PromptlVersion>({
version,
messages,
}: {
version: V
messages: VersionedMessage<V>[]
}) {
const parts = extractToolContents({ version, messages })
const toolRequestIds = new Set<string>()
const toolResponses = new Set<string>()

parts.forEach((part) => {
if (isToolRequest(part)) {
toolRequestIds.add(part.toolCallId)
} else {
toolResponses.add(part.id)
}
})

return parts.filter(
(part): part is ToolRequest =>
isToolRequest(part) && !toolResponses.has(part.toolCallId),
)
}

type Props<V extends PromptlVersion> = { version: V }

export function useMessages<V extends PromptlVersion>({ version }: Props<V>) {
const [messages, setMessages] = useState<VersionedMessage<V>[]>([])
const [unresponedToolCalls, setUnresponedToolCalls] = useState<ToolRequest[]>(
[],
)
// FIXME: Kill compiler please
// every where we have old compiler types. To avoid Typescript crying we
// allow only compiler messages and then transform them to versioned messages
const addMessages = useCallback(
(m: CompilerMessage[]) => {
const msg = m as VersionedMessage<V>[]
setMessages((prevMessages) => {
const newMessages = prevMessages.concat(msg)
setUnresponedToolCalls(
getUnrespondedToolRequests({ version, messages: newMessages }),
)
return newMessages
})
},
[version],
)

return {
addMessages,
messages: messages as CompilerMessage[],
unresponedToolCalls,
}
}
Loading

0 comments on commit 1afb6cd

Please sign in to comment.