Skip to content

Commit

Permalink
fix recall by agent id
Browse files Browse the repository at this point in the history
  • Loading branch information
lalalune committed Nov 3, 2024
1 parent d9aeb80 commit 11e7db1
Show file tree
Hide file tree
Showing 5 changed files with 23 additions and 13 deletions.
10 changes: 5 additions & 5 deletions core/src/adapters/postgres.ts
Original file line number Diff line number Diff line change
Expand Up @@ -269,15 +269,16 @@ export class PostgresDatabaseAdapter extends DatabaseAdapter {

await client.query(
`INSERT INTO memories (
id, type, content, embedding, "userId", "roomId", "unique", "createdAt"
) VALUES ($1, $2, $3, $4::vector, $5::uuid, $6::uuid, $7, to_timestamp($8/1000.0))`,
id, type, content, embedding, "userId", "roomId", "agentId", "unique", "createdAt"
) VALUES ($1, $2, $3, $4::vector, $5::uuid, $6::uuid, $7::uuid, $8, to_timestamp($9/1000.0))`,
[
memory.id ?? v4(),
tableName,
JSON.stringify(memory.content),
`[${memory.embedding.join(",")}]`,
memory.userId,
memory.roomId,
memory.agentId,
memory.unique ?? isUnique,
Date.now(),
]
Expand Down Expand Up @@ -365,7 +366,7 @@ export class PostgresDatabaseAdapter extends DatabaseAdapter {
}

if (params.agentId) {
sql += " AND userId = $3";
sql += " AND agentId = $3";
values.push(params.agentId);
}

Expand Down Expand Up @@ -638,9 +639,8 @@ export class PostgresDatabaseAdapter extends DatabaseAdapter {
sql += ` AND "unique" = true`;
}

// TODO: Test this
if (params.agentId) {
sql += " AND userId = $3";
sql += " AND agentId = $3";
values.push(params.agentId);
}

Expand Down
16 changes: 11 additions & 5 deletions core/src/adapters/sqlite.ts
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,7 @@ export class SqliteDatabaseAdapter extends DatabaseAdapter {
let queryParams = [params.tableName, ...params.roomIds];

if (params.agentId) {
sql += ` AND userId = ?`;
sql += ` AND agentId = ?`;
queryParams.push(params.agentId);
}

Expand Down Expand Up @@ -219,14 +219,15 @@ export class SqliteDatabaseAdapter extends DatabaseAdapter {
const createdAt = memory.createdAt ?? Date.now();

// Insert the memory with the appropriate 'unique' value
const sql = `INSERT OR REPLACE INTO memories (id, type, content, embedding, userId, roomId, \`unique\`, createdAt) VALUES (?, ?, ?, ?, ?, ?, ?, ?)`;
const sql = `INSERT OR REPLACE INTO memories (id, type, content, embedding, userId, roomId, agentId, \`unique\`, createdAt) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)`;
this.db.prepare(sql).run(
memory.id ?? v4(),
tableName,
content,
new Float32Array(memory.embedding ?? embeddingZeroVector), // Store as Float32Array
memory.userId,
memory.roomId,
memory.agentId,
isUnique ? 1 : 0,
createdAt
);
Expand All @@ -235,6 +236,7 @@ export class SqliteDatabaseAdapter extends DatabaseAdapter {
async searchMemories(params: {
tableName: string;
roomId: UUID;
agentId?: UUID;
embedding: number[];
match_threshold: number;
match_count: number;
Expand All @@ -256,6 +258,11 @@ export class SqliteDatabaseAdapter extends DatabaseAdapter {
sql += " AND `unique` = 1";
}

if (params.agentId) {
sql += " AND agentId = ?";
queryParams.push(params.agentId);
}

sql += ` ORDER BY similarity ASC LIMIT ?`; // ASC for lower distance
// Updated queryParams order matches the placeholders

Expand Down Expand Up @@ -297,9 +304,8 @@ export class SqliteDatabaseAdapter extends DatabaseAdapter {
if (params.unique) {
sql += " AND `unique` = 1";
}
// TODO: Test this
if (params.agentId) {
sql += " AND userId = ?";
sql += " AND agentId = ?";
queryParams.push(params.agentId);
}

Expand Down Expand Up @@ -413,7 +419,7 @@ export class SqliteDatabaseAdapter extends DatabaseAdapter {
}

if (params.agentId) {
sql += " AND userId = ?";
sql += " AND agentId = ?";
queryParams.push(params.agentId);
}

Expand Down
5 changes: 3 additions & 2 deletions core/src/adapters/sqljs.ts
Original file line number Diff line number Diff line change
Expand Up @@ -236,7 +236,7 @@ export class SqlJsDatabaseAdapter extends DatabaseAdapter {
}

// Insert the memory with the appropriate 'unique' value
const sql = `INSERT INTO memories (id, type, content, embedding, userId, roomId, \`unique\`, createdAt) VALUES (?, ?, ?, ?, ?, ?, ?, ?)`;
const sql = `INSERT INTO memories (id, type, content, embedding, userId, roomId, agentId, \`unique\`, createdAt) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)`;
const stmt = this.db.prepare(sql);

const createdAt = memory.createdAt ?? Date.now();
Expand All @@ -248,6 +248,7 @@ export class SqlJsDatabaseAdapter extends DatabaseAdapter {
JSON.stringify(memory.embedding),
memory.userId,
memory.roomId,
memory.agentId,
isUnique ? 1 : 0,
createdAt,
]);
Expand Down Expand Up @@ -461,7 +462,7 @@ export class SqlJsDatabaseAdapter extends DatabaseAdapter {
}

if (params.agentId) {
sql += " AND userId = ?";
sql += " AND agentId = ?";
}

sql += " ORDER BY createdAt DESC";
Expand Down
1 change: 0 additions & 1 deletion core/src/clients/direct/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,6 @@ class DirectClient {

const memory: Memory = {
id: messageId,
...userMessage,
agentId: runtime.agentId,
userId,
roomId,
Expand Down
4 changes: 4 additions & 0 deletions core/src/core/generation.ts
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,8 @@ export async function generateText({
prettyConsole.log("Initializing OpenAI model.");
const openai = createOpenAI({ apiKey });

console.log('****** CONTEXT\n', context)

const { text: openaiResponse } = await aiGenerateText({
model: openai.languageModel(model),
prompt: context,
Expand All @@ -83,6 +85,8 @@ export async function generateText({
presencePenalty: presence_penalty,
});

console.log("****** RESPONSE\n", openaiResponse);

response = openaiResponse;
prettyConsole.log("Received response from OpenAI model.");
break;
Expand Down

0 comments on commit 11e7db1

Please sign in to comment.