Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(core): Allow checking tool_calls on any BaseMessage without type casts #7479

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 2 additions & 5 deletions langchain-core/src/messages/ai.ts
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,10 @@ import {
type MessageType,
BaseMessageFields,
_mergeLists,
} from "./base.js";
import {
InvalidToolCall,
ToolCall,
ToolCallChunk,
defaultToolCallParser,
} from "./tool.js";
} from "./base.js";
import { InvalidToolCall, defaultToolCallParser } from "./tool.js";

export type AIMessageFields = BaseMessageFields & {
tool_calls?: ToolCall[];
Expand Down
93 changes: 93 additions & 0 deletions langchain-core/src/messages/base.ts
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,23 @@ function stringifyWithDepthLimit(obj: any, depthLimit: number): string {
return JSON.stringify(helper(obj, 0), null, 2);
}

/**
* A call to a tool.
* @property {string} name - The name of the tool to be called
* @property {Record<string, any>} args - The arguments to the tool call
* @property {string} [id] - If provided, an identifier associated with the tool call
*/
export type ToolCall = {
name: string;

// eslint-disable-next-line @typescript-eslint/no-explicit-any
args: Record<string, any>;

id?: string;

type?: "tool_call";
};

/**
* Base class for all types of messages in a conversation. It includes
* properties like `content`, `name`, and `additional_kwargs`. It also
Expand Down Expand Up @@ -219,6 +236,8 @@ export abstract class BaseMessage
*/
id?: string;

tool_calls?: never | ToolCall[];

/**
* @deprecated Use .getType() instead or import the proper typeguard.
* For example:
Expand Down Expand Up @@ -457,6 +476,76 @@ export function _mergeObj<T = any>(
}
}

/**
* A chunk of a tool call (e.g., as part of a stream).
* When merging ToolCallChunks (e.g., via AIMessageChunk.__add__),
* all string attributes are concatenated. Chunks are only merged if their
* values of `index` are equal and not None.
*
* @example
* ```ts
* const leftChunks = [
* {
* name: "foo",
* args: '{"a":',
* index: 0
* }
* ];
*
* const leftAIMessageChunk = new AIMessageChunk({
* content: "",
* tool_call_chunks: leftChunks
* });
*
* const rightChunks = [
* {
* name: undefined,
* args: '1}',
* index: 0
* }
* ];
*
* const rightAIMessageChunk = new AIMessageChunk({
* content: "",
* tool_call_chunks: rightChunks
* });
*
* const result = leftAIMessageChunk.concat(rightAIMessageChunk);
* // result.tool_call_chunks is equal to:
* // [
* // {
* // name: "foo",
* // args: '{"a":1}'
* // index: 0
* // }
* // ]
* ```
*
* @property {string} [name] - If provided, a substring of the name of the tool to be called
* @property {string} [args] - If provided, a JSON substring of the arguments to the tool call
* @property {string} [id] - If provided, a substring of an identifier for the tool call
* @property {number} [index] - If provided, the index of the tool call in a sequence
*/
export type ToolCallChunk = {
name?: string;

args?: string;

id?: string;

index?: number;

type?: "tool_call_chunk";
};

export type InvalidToolCall = {
name?: string;
args?: string;
id?: string;
error?: string;
type?: "invalid_tool_call";
};

/**
* Represents a chunk of a message, which can be concatenated with other
* message chunks. It includes a method `_merge_kwargs_dict()` for merging
Expand All @@ -465,6 +554,10 @@ export function _mergeObj<T = any>(
* of `BaseMessageChunk` instances.
*/
export abstract class BaseMessageChunk extends BaseMessage {
tool_call_chunks?: never | ToolCallChunk[];

invalid_tool_calls?: never | InvalidToolCall[];

abstract concat(chunk: BaseMessageChunk): BaseMessageChunk;
}

Expand Down
8 changes: 8 additions & 0 deletions langchain-core/src/messages/chat.ts
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@ export class ChatMessage
extends BaseMessage
implements ChatMessageFieldsWithRole
{
declare tool_calls?: never;

static lc_name() {
return "ChatMessage";
}
Expand Down Expand Up @@ -62,6 +64,12 @@ export class ChatMessage
* other chat message chunks.
*/
export class ChatMessageChunk extends BaseMessageChunk {
declare tool_calls?: never;

declare tool_call_chunks?: never;

declare invalid_tool_calls?: never;

static lc_name() {
return "ChatMessageChunk";
}
Expand Down
8 changes: 8 additions & 0 deletions langchain-core/src/messages/function.ts
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@ export interface FunctionMessageFieldsWithName extends BaseMessageFields {
* Represents a function message in a conversation.
*/
export class FunctionMessage extends BaseMessage {
declare tool_calls?: never;

static lc_name() {
return "FunctionMessage";
}
Expand Down Expand Up @@ -49,6 +51,12 @@ export class FunctionMessage extends BaseMessage {
* with other function message chunks.
*/
export class FunctionMessageChunk extends BaseMessageChunk {
declare tool_calls?: never;

declare tool_call_chunks?: never;

declare invalid_tool_calls?: never;

static lc_name() {
return "FunctionMessageChunk";
}
Expand Down
8 changes: 8 additions & 0 deletions langchain-core/src/messages/human.ts
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@ import {
* Represents a human message in a conversation.
*/
export class HumanMessage extends BaseMessage {
declare tool_calls?: never;

static lc_name() {
return "HumanMessage";
}
Expand All @@ -24,6 +26,12 @@ export class HumanMessage extends BaseMessage {
* other human message chunks.
*/
export class HumanMessageChunk extends BaseMessageChunk {
declare tool_calls?: never;

declare tool_call_chunks?: never;

declare invalid_tool_calls?: never;

static lc_name() {
return "HumanMessageChunk";
}
Expand Down
2 changes: 0 additions & 2 deletions langchain-core/src/messages/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,6 @@ export * from "./system.js";
export * from "./utils.js";
export * from "./transformers.js";
export * from "./modifier.js";
// TODO: Use a star export when we deprecate the
// existing "ToolCall" type in "base.js".
export {
type ToolMessageFieldsWithToolCallId,
ToolMessage,
Expand Down
2 changes: 2 additions & 0 deletions langchain-core/src/messages/modifier.ts
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@ export interface RemoveMessageFields
* Message responsible for deleting other messages.
*/
export class RemoveMessage extends BaseMessage {
declare tool_calls?: never;

/**
* The ID of the message to remove.
*/
Expand Down
8 changes: 8 additions & 0 deletions langchain-core/src/messages/system.ts
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@ import {
* Represents a system message in a conversation.
*/
export class SystemMessage extends BaseMessage {
declare tool_calls?: never;

static lc_name() {
return "SystemMessage";
}
Expand All @@ -24,6 +26,12 @@ export class SystemMessage extends BaseMessage {
* other system message chunks.
*/
export class SystemMessageChunk extends BaseMessageChunk {
declare tool_calls?: never;

declare tool_call_chunks?: never;

declare invalid_tool_calls?: never;

static lc_name() {
return "SystemMessageChunk";
}
Expand Down
102 changes: 102 additions & 0 deletions langchain-core/src/messages/tests/base_message.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,19 @@ import {
AIMessageChunk,
coerceMessageLikeToMessage,
SystemMessage,
isAIMessageChunk,
isHumanMessageChunk,
isSystemMessageChunk,
BaseMessageChunk,
SystemMessageChunk,
isAIMessage,
isHumanMessage,
isSystemMessage,
isToolMessage,
BaseMessage,
isChatMessage,
isToolMessageChunk,
isChatMessageChunk,
} from "../index.js";
import { load } from "../../load/index.js";
import { concat } from "../../utils/stream.js";
Expand Down Expand Up @@ -462,3 +475,92 @@ describe("usage_metadata serialized", () => {
expect(jsonConcatenatedAIMessageChunk).toContain("total_tokens");
});
});

it("Should narrow tool call typing when accessing a base message array", async () => {
const messages: BaseMessage[] = [new SystemMessage("test")];
messages.push(new AIMessage("foo"));
if (messages[0].tool_calls?.[0] !== undefined) {
// Allow checking existence on BaseMessage with no errors
void messages[0].tool_calls[0]?.args;
}

const msg = messages[0];
if (isAIMessage(msg)) {
// Should allow access from AI messages
void msg.tool_calls?.[0].args;
}
if (isHumanMessage(msg)) {
// @ts-expect-error Typing should not allow access from human messages
void msg.tool_calls?.[0].args;
}
if (isSystemMessage(msg)) {
// @ts-expect-error Typing should not allow access from system messages
void msg.tool_calls?.[0].args;
}
if (isToolMessage(msg)) {
// @ts-expect-error Typing should not allow access from tool messages
void msg.tool_calls?.[0].args;
}
if (isChatMessage(msg)) {
// @ts-expect-error Typing should not allow access from chat messages
void msg.tool_calls?.[0].args;
}

const messageChunks: BaseMessageChunk[] = [new SystemMessageChunk("test")];
if (messageChunks[0].tool_calls?.[0] !== undefined) {
// Allow checking existence on BaseMessage with no errors
void messageChunks[0].tool_calls[0].args;
}

if (messageChunks[0].tool_call_chunks?.[0] !== undefined) {
// Allow checking existence on BaseMessage with no errors
void messageChunks[0].tool_call_chunks[0].args;
}

if (messageChunks[0].invalid_tool_calls?.[0] !== undefined) {
// Allow checking existence on BaseMessage with no errors
void messageChunks[0].invalid_tool_calls[0].args;
}

const msgChunk = messageChunks[0];
if (isAIMessageChunk(msgChunk)) {
// Typing should allow access from AI message chunks
void msgChunk.tool_calls?.[0].args;
// Typing should allow access from AI message chunks
void msgChunk.tool_call_chunks?.[0].args;
// Typing should allow access from AI message chunks
void msgChunk.invalid_tool_calls?.[0].args;
}
if (isHumanMessageChunk(msgChunk)) {
// @ts-expect-error Typing should not allow access from human message chunks
void msgChunk.tool_calls?.[0].args;
// @ts-expect-error Typing should not allow access from human message chunks
void msgChunk.tool_call_chunks?.[0].args;
// @ts-expect-error Typing should not allow access from human message chunks
void msgChunk.invalid_tool_calls?.[0].args;
}
if (isSystemMessageChunk(msgChunk)) {
// @ts-expect-error Typing should not allow access from system message chunks
void msgChunk.tool_calls?.[0].args;
// @ts-expect-error Typing should not allow access from system message chunks
void msgChunk.tool_call_chunks?.[0].args;
// @ts-expect-error Typing should not allow access from system message chunks
void msgChunk.invalid_tool_calls?.[0].args;
}
if (isToolMessageChunk(msgChunk)) {
// @ts-expect-error Typing should not allow access from tool message chunks
void msgChunk.tool_calls?.[0].args;
// @ts-expect-error Typing should not allow access from tool message chunks
void msgChunk.tool_call_chunks?.[0].args;
// @ts-expect-error Typing should not allow access from tool message chunks
void msgChunk.invalid_tool_calls?.[0].args;
}
if (isChatMessageChunk(msgChunk)) {
// @ts-expect-error Typing should not allow access from chat message chunks
void msgChunk.tool_calls?.[0].args;
// @ts-expect-error Typing should not allow access from chat message chunks
void msgChunk.tool_call_chunks?.[0].args;
// @ts-expect-error Typing should not allow access from chat message chunks
void msgChunk.invalid_tool_calls?.[0].args;
}
});
Loading
Loading