Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

improve web search #1537

Open
wants to merge 13 commits into
base: main
Choose a base branch
from
3 changes: 2 additions & 1 deletion src/lib/server/tools/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ import calculator from "./calculator";
import directlyAnswer from "./directlyAnswer";
import fetchUrl from "./web/url";
import websearch from "./web/search";
import weather from "./weather";
import { callSpace, getIpToken } from "./utils";
import { uploadFile } from "../files/uploadFile";
import type { MessageFile } from "$lib/types/Message";
Expand Down Expand Up @@ -127,7 +128,7 @@ export const configTools = z
}))
)
// add the extra hardcoded tools
.transform((val) => [...val, calculator, directlyAnswer, fetchUrl, websearch]);
.transform((val) => [...val, calculator, directlyAnswer, fetchUrl, websearch, weather]);

export function getCallMethod(tool: Omit<BaseTool, "call">): BackendCall {
return async function* (params, ctx, uuid) {
Expand Down
27 changes: 27 additions & 0 deletions src/lib/server/tools/utils.ts
Original file line number Diff line number Diff line change
Expand Up @@ -111,3 +111,30 @@ export async function extractJson(text: string): Promise<unknown[]> {
}
return calls.flat();
}

export async function fetchWeatherData(latitude: number, longitude: number): Promise<ArrayBuffer> {
const response = await fetch(
`https://api.open-meteo.com/v1/forecast?latitude=${latitude}&longitude=${longitude}&hourly=temperature_2m`
);
if (!response.ok) {
throw new Error("Failed to fetch weather data");
}
return response.json();
}

export async function fetchCoordinates(
location: string
): Promise<{ latitude: number; longitude: number }> {
const response = await fetch(
`https://geocoding-api.open-meteo.com/v1/search?name=${location}&count=1`
);
if (!response.ok) {
throw new Error("Failed to fetch coordinates");
}
const data = await response.json();
if (data.results.length === 0) {
throw new Error("Location not found");
}
const { latitude, longitude } = data.results[0];
return { latitude, longitude };
}
42 changes: 42 additions & 0 deletions src/lib/server/tools/weather.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
import type { ConfigTool } from "$lib/types/Tool";
import { ObjectId } from "mongodb";
import { fetchWeatherData, fetchCoordinates } from "./utils";

const weather: ConfigTool = {
_id: new ObjectId("00000000000000000000000D"),
type: "config",
description: "Fetch the weather for a specified location",
color: "blue",
icon: "cloud",
displayName: "Weather",
name: "weather",
endpoint: null,
inputs: [
{
name: "location",
type: "str",
description: "The name of the location to fetch the weather for",
paramType: "required",
},
],
outputComponent: null,
outputComponentIdx: null,
showOutput: false,
async *call({ location }) {
try {
if (typeof location !== "string") {
throw new Error("Location must be a string");
}
const coordinates = await fetchCoordinates(location);
const weatherData = await fetchWeatherData(coordinates.latitude, coordinates.longitude);

return {
outputs: [{ weather: weatherData }],
};
} catch (error) {
throw new Error("Failed to fetch weather data", { cause: error });
}
},
};

export default weather;
2 changes: 1 addition & 1 deletion src/lib/server/websearch/runWebSearch.ts
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ import { mergeAsyncGenerators } from "$lib/utils/mergeAsyncGenerators";
import { MetricsServer } from "../metrics";
import { logger } from "$lib/server/logger";

const MAX_N_PAGES_TO_SCRAPE = 8 as const;
const MAX_N_PAGES_TO_SCRAPE = 15 as const;
const MAX_N_PAGES_TO_EMBED = 5 as const;

export async function* runWebSearch(
Expand Down
39 changes: 33 additions & 6 deletions src/lib/server/websearch/search/generateQuery.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,9 @@ import type { Message } from "$lib/types/Message";
import { format } from "date-fns";
import type { EndpointMessage } from "../../endpoints/endpoints";
import { generateFromDefaultEndpoint } from "../../generateFromDefaultEndpoint";
import { env } from "$env/dynamic/private";

const num_searches = env.NUM_SEARCHES ? parseInt(env.NUM_SEARCHES, 10) : 3;

export async function generateQuery(messages: Message[]) {
const currentDate = format(new Date(), "MMMM d, yyyy");
Expand Down Expand Up @@ -47,8 +50,26 @@ Current Question: Where is it being hosted?`,
from: "assistant",
content: `news ${format(new Date(Date.now() - 864e5), "MMMM d, yyyy")}`,
},
{ from: "user", content: "What is the current weather in Paris?" },
{ from: "assistant", content: `weather in Paris ${currentDate}` },
{
from: "user",
content: `Current Question: My dog has been bitten, what should the gums look like so that he is healthy and when does he need an infusion?`,
},
{
from: "assistant",
content: `What healthy gums look like in dogs
What unhealthy gums look like in dogs
When dogs need an infusion, gum signals
`,
},
{
from: "user",
content: `Current Question: Who is Elon Musk ?`,
},
{
from: "assistant",
content: `Elon Musk
Elon Musk Biography`,
},
{
from: "user",
content:
Expand All @@ -62,13 +83,19 @@ Current Question: Where is it being hosted?`,
},
];

const preprompt = `You are tasked with generating precise and effective web search queries to answer the user's question. Provide a concise and specific query for Google search that will yield the most relevant and up-to-date results. Include key terms and related phrases, and avoid unnecessary words. Answer with only the queries split by linebreaks. Avoid duplicates, make the prompts as divers as you can. You are not allowed to repeat queries. Today is ${currentDate}`;

const webQuery = await generateFromDefaultEndpoint({
messages: convQuery,
preprompt: `You are tasked with generating web search queries. Give me an appropriate query to answer my question for google search. Answer with only the query. Today is ${currentDate}`,
preprompt,
generateSettings: {
max_new_tokens: 30,
max_new_tokens: 128,
},
});

return webQuery.trim();
// transform to list, split by linebreaks
const webQueryList = webQuery.split("\n").map((query) => query.trim());
// remove duplicates
const uniqueWebQueryList = Array.from(new Set(webQueryList));
// return only the first num_searches queries
return uniqueWebQueryList.slice(0, num_searches);
}
87 changes: 73 additions & 14 deletions src/lib/server/websearch/search/search.ts
Original file line number Diff line number Diff line change
Expand Up @@ -25,32 +25,91 @@ export async function* search(
{ searchQuery: string; pages: WebSearchSource[] },
undefined
> {
const newLinks: string[] = [];
let requireQuery = false;

if (ragSettings && ragSettings?.allowedLinks.length > 0) {
for (const link of ragSettings.allowedLinks) {
if (link.includes("[query]")) {
requireQuery = true;
break;
}
}
if (!requireQuery) {
yield makeGeneralUpdate({ message: "Using links specified in Assistant" });
return {
searchQuery: "",
pages: await directLinksToSource(ragSettings?.allowedLinks).then(filterByBlockList),
};
}
}

let searchQueries = await generateQuery(messages);
if (!searchQueries.length && query) {
searchQueries = [query];
}

for (const searchQuery of searchQueries) {
if (ragSettings && ragSettings?.allowedLinks.length > 0) {
for (const link of ragSettings.allowedLinks) {
const newLink = link.replace("[query]", encodeURIComponent(searchQuery));
if (!newLinks.includes(newLink)) {
newLinks.push(newLink);
}
}
yield makeGeneralUpdate({
message: `Querying provided Endpoints with`,
args: [searchQuery],
});
} else {
yield makeGeneralUpdate({
message: `Searching ${getWebSearchProvider()}`,
args: [searchQuery],
});
}
}

if (newLinks.length > 0) {
yield makeGeneralUpdate({ message: "Using links specified in Assistant" });
return {
searchQuery: "",
pages: await directLinksToSource(ragSettings.allowedLinks).then(filterByBlockList),
pages: await directLinksToSource(newLinks).then(filterByBlockList),
};
}

const searchQuery = query ?? (await generateQuery(messages));
yield makeGeneralUpdate({ message: `Searching ${getWebSearchProvider()}`, args: [searchQuery] });
let combinedResults: WebSearchSource[] = [];

for (const searchQuery of searchQueries) {
// handle the global and (optional) rag lists
if (ragSettings && ragSettings?.allowedDomains.length > 0) {
yield makeGeneralUpdate({ message: "Filtering on specified domains" });
}
const filters = buildQueryFromSiteFilters(
[...(ragSettings?.allowedDomains ?? []), ...allowList],
blockList
);

// handle the global and (optional) rag lists
if (ragSettings && ragSettings?.allowedDomains.length > 0) {
yield makeGeneralUpdate({ message: "Filtering on specified domains" });
const searchQueryWithFilters = `${filters} ${searchQuery}`;
const searchResults = await searchWeb(searchQueryWithFilters).then(filterByBlockList);
combinedResults = [...combinedResults, ...searchResults];
}
const filters = buildQueryFromSiteFilters(
[...(ragSettings?.allowedDomains ?? []), ...allowList],
blockList
);

const searchQueryWithFilters = `${filters} ${searchQuery}`;
const searchResults = await searchWeb(searchQueryWithFilters).then(filterByBlockList);
// re-sort the results by relevance
// all results are appended to the end of the list
// so the most relevant results are at the beginning
// using num_searches iterating over the list to get the most relevant results
// example input: [a1,a2,a3,a4,a5,b1,b2,b3,b4,b5,c1,c2,c3,c4,c5]
// example output: [a1,b1,c1,a2,b2,c2,a3,b3,c3,a4,b4,c4,a5,b5,c5]
const sortedResults = [];
for (let i = 0; i < searchQueries.length; i++) {
for (let j = i; j < combinedResults.length; j += searchQueries.length) {
sortedResults.push(combinedResults[j]);
}
}

return {
searchQuery: searchQueryWithFilters,
pages: searchResults,
searchQuery: searchQueries.join(" | "),
pages: sortedResults,
};
}

Expand Down