From 11e7db192c1385475c47fb363be44d02a0209561 Mon Sep 17 00:00:00 2001 From: moon Date: Sun, 3 Nov 2024 14:11:05 -0800 Subject: [PATCH] fix recall by agent id --- core/src/adapters/postgres.ts | 10 +++++----- core/src/adapters/sqlite.ts | 16 +++++++++++----- core/src/adapters/sqljs.ts | 5 +++-- core/src/clients/direct/index.ts | 1 - core/src/core/generation.ts | 4 ++++ 5 files changed, 23 insertions(+), 13 deletions(-) diff --git a/core/src/adapters/postgres.ts b/core/src/adapters/postgres.ts index d3375eeb..b97ec704 100644 --- a/core/src/adapters/postgres.ts +++ b/core/src/adapters/postgres.ts @@ -269,8 +269,8 @@ export class PostgresDatabaseAdapter extends DatabaseAdapter { await client.query( `INSERT INTO memories ( - id, type, content, embedding, "userId", "roomId", "unique", "createdAt" - ) VALUES ($1, $2, $3, $4::vector, $5::uuid, $6::uuid, $7, to_timestamp($8/1000.0))`, + id, type, content, embedding, "userId", "roomId", "agentId", "unique", "createdAt" + ) VALUES ($1, $2, $3, $4::vector, $5::uuid, $6::uuid, $7::uuid, $8, to_timestamp($9/1000.0))`, [ memory.id ?? v4(), tableName, @@ -278,6 +278,7 @@ export class PostgresDatabaseAdapter extends DatabaseAdapter { `[${memory.embedding.join(",")}]`, memory.userId, memory.roomId, + memory.agentId, memory.unique ?? isUnique, Date.now(), ] @@ -365,7 +366,7 @@ export class PostgresDatabaseAdapter extends DatabaseAdapter { } if (params.agentId) { - sql += " AND userId = $3"; + sql += " AND agentId = $3"; values.push(params.agentId); } @@ -638,9 +639,8 @@ export class PostgresDatabaseAdapter extends DatabaseAdapter { sql += ` AND "unique" = true`; } - // TODO: Test this if (params.agentId) { - sql += " AND userId = $3"; + sql += " AND agentId = $3"; values.push(params.agentId); } diff --git a/core/src/adapters/sqlite.ts b/core/src/adapters/sqlite.ts index 68afeb1a..efafc954 100644 --- a/core/src/adapters/sqlite.ts +++ b/core/src/adapters/sqlite.ts @@ -162,7 +162,7 @@ export class SqliteDatabaseAdapter extends DatabaseAdapter { let queryParams = [params.tableName, ...params.roomIds]; if (params.agentId) { - sql += ` AND userId = ?`; + sql += ` AND agentId = ?`; queryParams.push(params.agentId); } @@ -219,7 +219,7 @@ export class SqliteDatabaseAdapter extends DatabaseAdapter { const createdAt = memory.createdAt ?? Date.now(); // Insert the memory with the appropriate 'unique' value - const sql = `INSERT OR REPLACE INTO memories (id, type, content, embedding, userId, roomId, \`unique\`, createdAt) VALUES (?, ?, ?, ?, ?, ?, ?, ?)`; + const sql = `INSERT OR REPLACE INTO memories (id, type, content, embedding, userId, roomId, agentId, \`unique\`, createdAt) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)`; this.db.prepare(sql).run( memory.id ?? v4(), tableName, @@ -227,6 +227,7 @@ export class SqliteDatabaseAdapter extends DatabaseAdapter { new Float32Array(memory.embedding ?? embeddingZeroVector), // Store as Float32Array memory.userId, memory.roomId, + memory.agentId, isUnique ? 1 : 0, createdAt ); @@ -235,6 +236,7 @@ export class SqliteDatabaseAdapter extends DatabaseAdapter { async searchMemories(params: { tableName: string; roomId: UUID; + agentId?: UUID; embedding: number[]; match_threshold: number; match_count: number; @@ -256,6 +258,11 @@ export class SqliteDatabaseAdapter extends DatabaseAdapter { sql += " AND `unique` = 1"; } + if (params.agentId) { + sql += " AND agentId = ?"; + queryParams.push(params.agentId); + } + sql += ` ORDER BY similarity ASC LIMIT ?`; // ASC for lower distance // Updated queryParams order matches the placeholders @@ -297,9 +304,8 @@ export class SqliteDatabaseAdapter extends DatabaseAdapter { if (params.unique) { sql += " AND `unique` = 1"; } - // TODO: Test this if (params.agentId) { - sql += " AND userId = ?"; + sql += " AND agentId = ?"; queryParams.push(params.agentId); } @@ -413,7 +419,7 @@ export class SqliteDatabaseAdapter extends DatabaseAdapter { } if (params.agentId) { - sql += " AND userId = ?"; + sql += " AND agentId = ?"; queryParams.push(params.agentId); } diff --git a/core/src/adapters/sqljs.ts b/core/src/adapters/sqljs.ts index 02f5aee3..b6cdf447 100644 --- a/core/src/adapters/sqljs.ts +++ b/core/src/adapters/sqljs.ts @@ -236,7 +236,7 @@ export class SqlJsDatabaseAdapter extends DatabaseAdapter { } // Insert the memory with the appropriate 'unique' value - const sql = `INSERT INTO memories (id, type, content, embedding, userId, roomId, \`unique\`, createdAt) VALUES (?, ?, ?, ?, ?, ?, ?, ?)`; + const sql = `INSERT INTO memories (id, type, content, embedding, userId, roomId, agentId, \`unique\`, createdAt) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)`; const stmt = this.db.prepare(sql); const createdAt = memory.createdAt ?? Date.now(); @@ -248,6 +248,7 @@ export class SqlJsDatabaseAdapter extends DatabaseAdapter { JSON.stringify(memory.embedding), memory.userId, memory.roomId, + memory.agentId, isUnique ? 1 : 0, createdAt, ]); @@ -461,7 +462,7 @@ export class SqlJsDatabaseAdapter extends DatabaseAdapter { } if (params.agentId) { - sql += " AND userId = ?"; + sql += " AND agentId = ?"; } sql += " ORDER BY createdAt DESC"; diff --git a/core/src/clients/direct/index.ts b/core/src/clients/direct/index.ts index 18e6a11c..8be31891 100644 --- a/core/src/clients/direct/index.ts +++ b/core/src/clients/direct/index.ts @@ -166,7 +166,6 @@ class DirectClient { const memory: Memory = { id: messageId, - ...userMessage, agentId: runtime.agentId, userId, roomId, diff --git a/core/src/core/generation.ts b/core/src/core/generation.ts index 588fd1ef..caa7ffdc 100644 --- a/core/src/core/generation.ts +++ b/core/src/core/generation.ts @@ -74,6 +74,8 @@ export async function generateText({ prettyConsole.log("Initializing OpenAI model."); const openai = createOpenAI({ apiKey }); + console.log('****** CONTEXT\n', context) + const { text: openaiResponse } = await aiGenerateText({ model: openai.languageModel(model), prompt: context, @@ -83,6 +85,8 @@ export async function generateText({ presencePenalty: presence_penalty, }); + console.log("****** RESPONSE\n", openaiResponse); + response = openaiResponse; prettyConsole.log("Received response from OpenAI model."); break;