Skip to content

Commit

Permalink
feat: wip
Browse files Browse the repository at this point in the history
  • Loading branch information
andresgutgon committed Jan 15, 2025
1 parent d80cb36 commit dae14e0
Show file tree
Hide file tree
Showing 2 changed files with 57 additions and 26 deletions.
76 changes: 52 additions & 24 deletions src/compiler/chain.ts
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,15 @@ type ChainStep<M extends AdapterMessageType> = ProviderConversation<M> & {

type StepResponse<M extends AdapterMessageType> =
| string
| M[]
| (Omit<M, 'role'> & {
role?: M['role']
})
role?: M['role']
})

type BuildStepResponseContent = {
messages?: Message[]
contents: MessageContent[] | undefined
}

export class Chain<M extends AdapterMessageType = Message> {
public rawText: string
Expand Down Expand Up @@ -65,6 +71,7 @@ export class Chain<M extends AdapterMessageType = Message> {
ast: Fragment
scope: Scope
globalConfig: Config | undefined
globalMessages: Message[]
}
} & CompileOptions) {
this.rawText = prompt
Expand All @@ -73,6 +80,7 @@ export class Chain<M extends AdapterMessageType = Message> {
this.ast = serialized?.ast ?? parse(prompt)
this.scope = serialized?.scope ?? new Scope(parameters)
this.globalConfig = serialized?.globalConfig
this.globalMessages = serialized?.globalMessages ?? []
this.didStart = !!serialized

this.adapter = adapter
Expand All @@ -98,33 +106,26 @@ export class Chain<M extends AdapterMessageType = Message> {

this.didStart = true

// TODO: Maybe response can be an array of messages
// now is a string or a Message
const responseContent = this.buildStepResponseContent(response)
const newGlobalMessages = this.buildGlobalMessages(responseContent)

if (responseContent && !this.wasLastStepIsolated) {
this.globalMessages.push({
role: MessageRole.assistant,
content: responseContent ?? [],
})
if (newGlobalMessages.length > 0) {
this.globalMessages = [
...this.globalMessages,
...(newGlobalMessages as Message[]),
]
}

const compile = new Compile({
ast: this.ast,
rawText: this.rawText,
globalScope: this.scope,
stepResponse: responseContent,
stepResponse: responseContent.contents,
...this.compileOptions,
})

const {
completed,
scopeStash,
ast,
messages,
globalConfig,
stepConfig,
} = await compile.run()
const { completed, scopeStash, ast, messages, globalConfig, stepConfig } =
await compile.run()

this.scope = Scope.withStash(scopeStash).copy(this.scope.getPointers())
this.ast = ast
Expand Down Expand Up @@ -172,6 +173,7 @@ export class Chain<M extends AdapterMessageType = Message> {
adapterType: this.adapter.type,
compilerOptions: this.compileOptions,
globalConfig: this.globalConfig,
globalMessages: this.globalMessages,
}
}

Expand All @@ -180,12 +182,20 @@ export class Chain<M extends AdapterMessageType = Message> {
}

private buildStepResponseContent(
response?: StepResponse<M>,
): MessageContent[] | undefined {
if (response == undefined) return response

response?: StepResponse<M> | M[],
): BuildStepResponseContent {
if (response == undefined) return { contents: undefined }
if (typeof response === 'string') {
return [{ type: ContentType.text, text: response }]
return { contents: [{ text: response, type: ContentType.text }] }
}

if (Array.isArray(response)) {
const converted = this.adapter.toPromptl({
config: this.globalConfig ?? {},
messages: response as M[],
})
const contents = converted.messages.flatMap((m) => m.content)
return { messages: converted.messages as Message[], contents }
}

const responseMessage = {
Expand All @@ -198,6 +208,24 @@ export class Chain<M extends AdapterMessageType = Message> {
messages: [responseMessage],
})

return convertedMessages.messages[0]!.content
return { contents: convertedMessages.messages[0]!.content }
}

private buildGlobalMessages(
buildStepResponseContent: BuildStepResponseContent,
) {
const { messages, contents } = buildStepResponseContent

if (this.wasLastStepIsolated) return []
if (!contents) return []

if (messages) return messages

return [
{
role: MessageRole.assistant,
content: contents ?? [],
},
]
}
}
7 changes: 5 additions & 2 deletions src/compiler/deserializeChain.ts
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,10 @@ function safeSerializedData(data: string | SerializedChain): SerializedChain {
: typeof data === 'object'
? data
: {}

const compilerOptions = serialized.compilerOptions || {}
const globalConfig = serialized.globalConfig
const globalMessages = serialized.globalMessages || []

if (
typeof serialized !== 'object' ||
Expand All @@ -31,7 +33,7 @@ function safeSerializedData(data: string | SerializedChain): SerializedChain {
) {
throw new Error()
}
return { ...serialized, compilerOptions, globalConfig }
return { ...serialized, compilerOptions, globalConfig, globalMessages }
} catch {
throw new Error('Invalid serialized chain data')
}
Expand All @@ -51,6 +53,7 @@ export function deserializeChain({
adapterType,
compilerOptions,
globalConfig,
globalMessages,
} = safeSerializedData(serialized)

const adapter = getAdapter(adapterType)
Expand All @@ -60,7 +63,7 @@ export function deserializeChain({

return new Chain({
prompt: '',
serialized: { ast, scope, globalConfig },
serialized: { ast, scope, globalConfig, globalMessages },
adapter,
...compilerOptions,
})
Expand Down

0 comments on commit dae14e0

Please sign in to comment.