From 77fbca70f90ee8845a7093b528998af46df4423f Mon Sep 17 00:00:00 2001 From: Will Poynter Date: Tue, 1 Oct 2024 23:28:06 +0100 Subject: [PATCH] Allowing serdes to be async --- libs/checkpoint-mongodb/src/index.ts | 53 +++++++++++++++------------- libs/checkpoint-sqlite/src/index.ts | 34 ++++++++++-------- libs/checkpoint/src/memory.ts | 18 +++++----- libs/checkpoint/src/serde/base.ts | 4 +-- libs/langgraph/src/tests/utils.ts | 2 +- 5 files changed, 61 insertions(+), 50 deletions(-) diff --git a/libs/checkpoint-mongodb/src/index.ts b/libs/checkpoint-mongodb/src/index.ts index 54eb68a0..949c9dce 100644 --- a/libs/checkpoint-mongodb/src/index.ts +++ b/libs/checkpoint-mongodb/src/index.ts @@ -214,9 +214,12 @@ export class MongoDBSaver extends BaseCheckpointSaver { `The provided config must contain a configurable field with a "thread_id" field.` ); } - const [checkpointType, serializedCheckpoint] = - this.serde.dumpsTyped(checkpoint); - const [metadataType, serializedMetadata] = this.serde.dumpsTyped(metadata); + const [checkpointType, serializedCheckpoint] = await this.serde.dumpsTyped( + checkpoint + ); + const [metadataType, serializedMetadata] = await this.serde.dumpsTyped( + metadata + ); if (checkpointType !== metadataType) { throw new Error("Mismatched checkpoint and metadata types."); } @@ -268,31 +271,33 @@ export class MongoDBSaver extends BaseCheckpointSaver { ); } - const operations = writes.map(([channel, value], idx) => { - const upsertQuery = { - thread_id, - checkpoint_ns, - checkpoint_id, - task_id: taskId, - idx, - }; + const operations = await Promise.all( + writes.map(async ([channel, value], idx) => { + const upsertQuery = { + thread_id, + checkpoint_ns, + checkpoint_id, + task_id: taskId, + idx, + }; - const [type, serializedValue] = this.serde.dumpsTyped(value); + const [type, serializedValue] = await this.serde.dumpsTyped(value); - return { - updateOne: { - filter: upsertQuery, - update: { - $set: { - channel, - type, - value: serializedValue, + return { + updateOne: { + filter: upsertQuery, + update: { + $set: { + channel, + type, + value: serializedValue, + }, }, + upsert: true, }, - upsert: true, - }, - }; - }); + }; + }) + ); await this.db .collection(this.checkpointWritesCollectionName) diff --git a/libs/checkpoint-sqlite/src/index.ts b/libs/checkpoint-sqlite/src/index.ts index bda2b88e..a44020ce 100644 --- a/libs/checkpoint-sqlite/src/index.ts +++ b/libs/checkpoint-sqlite/src/index.ts @@ -221,8 +221,10 @@ CREATE TABLE IF NOT EXISTS writes ( ): Promise { this.setup(); - const [type1, serializedCheckpoint] = this.serde.dumpsTyped(checkpoint); - const [type2, serializedMetadata] = this.serde.dumpsTyped(metadata); + const [type1, serializedCheckpoint] = await this.serde.dumpsTyped( + checkpoint + ); + const [type2, serializedMetadata] = await this.serde.dumpsTyped(metadata); if (type1 !== type2) { throw new Error( "Failed to serialized checkpoint and metadata to the same type." @@ -272,19 +274,21 @@ CREATE TABLE IF NOT EXISTS writes ( } }); - const rows = writes.map((write, idx) => { - const [type, serializedWrite] = this.serde.dumpsTyped(write[1]); - return [ - config.configurable?.thread_id, - config.configurable?.checkpoint_ns, - config.configurable?.checkpoint_id, - taskId, - idx, - write[0], - type, - serializedWrite, - ]; - }); + const rows = await Promise.all( + writes.map(async (write, idx) => { + const [type, serializedWrite] = await this.serde.dumpsTyped(write[1]); + return [ + config.configurable?.thread_id, + config.configurable?.checkpoint_ns, + config.configurable?.checkpoint_id, + taskId, + idx, + write[0], + type, + serializedWrite, + ]; + }) + ); transaction(rows); } diff --git a/libs/checkpoint/src/memory.ts b/libs/checkpoint/src/memory.ts index ae4f124e..6f4ae0dc 100644 --- a/libs/checkpoint/src/memory.ts +++ b/libs/checkpoint/src/memory.ts @@ -47,8 +47,8 @@ export class MemorySaver extends BaseCheckpointSaver { ?.filter(([_taskId, channel]) => { return channel === TASKS; }) - .map(([_taskId, _channel, writes]) => { - return this.serde.loadsTyped("json", writes as string); + .map(async ([_taskId, _channel, writes]) => { + return await this.serde.loadsTyped("json", writes as string); }) ?? [] ); } @@ -294,8 +294,10 @@ export class MemorySaver extends BaseCheckpointSaver { this.storage[threadId][checkpointNamespace] = {}; } - const [, serializedCheckpoint] = this.serde.dumpsTyped(preparedCheckpoint); - const [, serializedMetadata] = this.serde.dumpsTyped(metadata); + const [, serializedCheckpoint] = await this.serde.dumpsTyped( + preparedCheckpoint + ); + const [, serializedMetadata] = await this.serde.dumpsTyped(metadata); this.storage[threadId][checkpointNamespace][checkpoint.id] = [ serializedCheckpoint, serializedMetadata, @@ -333,11 +335,11 @@ export class MemorySaver extends BaseCheckpointSaver { if (this.writes[key] === undefined) { this.writes[key] = []; } - const pendingWrites: CheckpointPendingWrite[] = writes.map( - ([channel, value]) => { - const [, serializedValue] = this.serde.dumpsTyped(value); + const pendingWrites: CheckpointPendingWrite[] = await Promise.all( + writes.map(async ([channel, value]) => { + const [, serializedValue] = await this.serde.dumpsTyped(value); return [taskId, channel, serializedValue]; - } + }) ); this.writes[key].push(...pendingWrites); } diff --git a/libs/checkpoint/src/serde/base.ts b/libs/checkpoint/src/serde/base.ts index a300fd95..da9b2505 100644 --- a/libs/checkpoint/src/serde/base.ts +++ b/libs/checkpoint/src/serde/base.ts @@ -1,6 +1,6 @@ export interface SerializerProtocol { // eslint-disable-next-line @typescript-eslint/no-explicit-any - dumpsTyped(data: any): [string, Uint8Array]; + dumpsTyped(data: any): Promise<[string, Uint8Array]> | [string, Uint8Array]; // eslint-disable-next-line @typescript-eslint/no-explicit-any - loadsTyped(type: string, data: Uint8Array | string): any; + loadsTyped(type: string, data: Uint8Array | string): Promise | any; } diff --git a/libs/langgraph/src/tests/utils.ts b/libs/langgraph/src/tests/utils.ts index 543c7524..4c6e4f44 100644 --- a/libs/langgraph/src/tests/utils.ts +++ b/libs/langgraph/src/tests/utils.ts @@ -175,7 +175,7 @@ export class MemorySaverAssertImmutable extends MemorySaver { ); } } - const [, serializedCheckpoint] = this.serde.dumpsTyped(checkpoint); + const [, serializedCheckpoint] = await this.serde.dumpsTyped(checkpoint); // save a copy of the checkpoint this.storageForCopies[thread_id][checkpoint.id] = new TextDecoder().decode( serializedCheckpoint