Skip to content

Commit

Permalink
Merge pull request #305 from receptron/embeddings
Browse files Browse the repository at this point in the history
Embeddings
  • Loading branch information
snakajima authored May 10, 2024
2 parents ff6367c + d7e6964 commit 630863e
Show file tree
Hide file tree
Showing 6 changed files with 76 additions and 12 deletions.
64 changes: 64 additions & 0 deletions samples/embeddings.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
import "dotenv/config";

import { graphDataTestRunner } from "~/utils/runner";
import {
tokenBoundStringsAgent,
sortByValuesAgent,
dotProductAgent,
stringEmbeddingsAgent,
stringSplitterAgent,
stringTemplateAgent,
slashGPTAgent,
wikipediaAgent,
} from "@/experimental_agents";

const graph_data = {
version: 0.3,
nodes: {
strings: {
value: ["王", "女王", "貴族", "男", "女", "庶民", "農民"],
isResult: true,
},
embeddings: {
agent: "stringEmbeddingsAgent",
inputs: [":strings"],
},
similarities: {
agent: "mapAgent",
inputs: [":embeddings", ":embeddings"],
graph: {
nodes: {
result: {
agent: "dotProductAgent",
inputs: [":$1", ":$0"],
isResult: true,
},
},
},
isResult: true,
},
},
};

const main = async () => {
const result = await graphDataTestRunner(
"sample_wiki.log",
graph_data,
{
tokenBoundStringsAgent,
sortByValuesAgent,
dotProductAgent,
stringEmbeddingsAgent,
stringSplitterAgent,
stringTemplateAgent,
slashGPTAgent,
wikipediaAgent,
},
undefined,
false,
);
console.log(result.similarities);
};
if (process.argv[1] === __filename) {
main();
}
2 changes: 1 addition & 1 deletion samples/embeddings/wikipedia.ts
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ const graph_data = {
similarityCheck: {
// Get the cosine similarities of those vectors
agent: "dotProductAgent",
inputs: [":embeddings", ":topicEmbedding"],
inputs: [":embeddings", ":topicEmbedding.$0"],
},
sortedChunks: {
// Sort chunks based on those similarities
Expand Down
12 changes: 6 additions & 6 deletions src/experimental_agents/matrix_agents/dot_product_agent.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,14 @@ import { AgentFunction } from "@/graphai";
// typically used to calculate cosine similarity of embedding vectors.
// Inputs:
// inputs[0]: Two dimentional array of numbers.
// inputs[1]: Two dimentional array of numbers (but the array size is 1 for the first dimention)
// inputs[1]: One dimentional array of numbers.
// Outputs:
// { contents: Array<number> } // array of docProduct of each vector (A[]) and vector B
export const dotProductAgent: AgentFunction<Record<never, never>, Array<number>, Array<Array<number>>> = async ({ inputs }) => {
const embeddings: Array<Array<number>> = inputs[0];
const reference: Array<number> = inputs[1][0];
export const dotProductAgent: AgentFunction<Record<never, never>, Array<number>, Array<Array<number>> | Array<number>> = async ({ inputs }) => {
const embeddings: Array<Array<number>> = inputs[0] as Array<Array<number>>;
const reference: Array<number> = inputs[1] as Array<number>;
if (embeddings[0].length != reference.length) {
throw new Error("dotProduct: Length of vectors do not match.");
throw new Error(`dotProduct: Length of vectors do not match. ${embeddings[0].length}, ${reference.length}`);
}
const contents = embeddings.map((embedding) => {
return embedding.reduce((dotProduct: number, value, index) => {
Expand All @@ -33,7 +33,7 @@ const dotProductAgentInfo = {
[3, 4],
[5, 6],
],
[[3, 2]],
[3, 2],
],
params: {},
result: [7, 17, 27],
Expand Down
2 changes: 1 addition & 1 deletion src/experimental_agents/service_agents/fetch_agent.ts
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ const fetchAgentInfo = {
headers: {
"Content-Type": "application/json",
},
body: '{"foo":"bar"}',
body: "{\"foo\":\"bar\"}",
},
},
],
Expand Down
6 changes: 3 additions & 3 deletions src/task_manager.ts
Original file line number Diff line number Diff line change
Expand Up @@ -74,9 +74,9 @@ export class TaskManager {
public getStatus(verbose: boolean = false) {
const nodes = verbose
? {
runningNodes: Array.from(this.runningNodes).map((node) => node.nodeId),
queuedNodes: this.taskQueue.map((task) => task.node.nodeId),
}
runningNodes: Array.from(this.runningNodes).map((node) => node.nodeId),
queuedNodes: this.taskQueue.map((task) => task.node.nodeId),
}
: {};
return {
concurrency: this.concurrency,
Expand Down
2 changes: 1 addition & 1 deletion tests/agents/test_matrix_agent.ts
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ test("test dotProductAgent", async () => {
[1, 2],
[2, 3],
],
[[1, 2]],
[1, 2],
],
});
assert.deepStrictEqual(result, [5, 8]);
Expand Down

0 comments on commit 630863e

Please sign in to comment.