Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

improve web search #1537

Open
wants to merge 13 commits into
base: main
Choose a base branch
from
2 changes: 1 addition & 1 deletion src/lib/server/websearch/runWebSearch.ts
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ import { mergeAsyncGenerators } from "$lib/utils/mergeAsyncGenerators";
import { MetricsServer } from "../metrics";
import { logger } from "$lib/server/logger";

const MAX_N_PAGES_TO_SCRAPE = 8 as const;
const MAX_N_PAGES_TO_SCRAPE = 15 as const;
const MAX_N_PAGES_TO_EMBED = 5 as const;

export async function* runWebSearch(
Expand Down
37 changes: 32 additions & 5 deletions src/lib/server/websearch/search/generateQuery.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ import { format } from "date-fns";
import type { EndpointMessage } from "../../endpoints/endpoints";
import { generateFromDefaultEndpoint } from "../../generateFromDefaultEndpoint";

export async function generateQuery(messages: Message[]) {
export async function generateQuery(messages: Message[], generatedQueries?: string[]) {
const currentDate = format(new Date(), "MMMM d, yyyy");
const userMessages = messages.filter(({ from }) => from === "user");
const previousUserMessages = userMessages.slice(0, -1);
Expand All @@ -17,6 +17,9 @@ export async function generateQuery(messages: Message[]) {
- Who is the president of France?

Current Question: What about Mexico?

Previously generated queries:
- President list of Mexico
`,
},
{
Expand Down Expand Up @@ -47,8 +50,26 @@ Current Question: Where is it being hosted?`,
from: "assistant",
content: `news ${format(new Date(Date.now() - 864e5), "MMMM d, yyyy")}`,
},
{ from: "user", content: "What is the current weather in Paris?" },
{ from: "assistant", content: `weather in Paris ${currentDate}` },
{
from: "user",
content: `Current Question: My dog has been bitten, what should the gums look like so that he is healthy and when does he need an infusion?

Previously generated queries:
- What healthy gums look like in dogs
- What unhealthy gums look like in dogs
`,
},
{ from: "assistant", content: `dog, gums, indications of necessary infusion` },
{
from: "user",
content: `Current Question: Who is Elon Musk ?

Previously generated queries:
- Elon Musk biography
- Who is Elon Musk
`,
},
{ from: "assistant", content: `Elon Musk` },
{
from: "user",
content:
Expand All @@ -58,13 +79,19 @@ Current Question: Where is it being hosted?`,
.join("\n")}`
: "") +
"\n\nCurrent Question: " +
lastMessage.content,
lastMessage.content +
"\n" +
(generatedQueries && generatedQueries.length > 0
? `Previously generated queries:\n${generatedQueries.map((q) => `- ${q}`).join("\n")}`
: ""),
},
];

const preprompt = `You are tasked with generating precise and effective web search queries to answer the user's question. Provide a concise and specific query for Google search that will yield the most relevant and up-to-date results. Include key terms and related phrases, and avoid unnecessary words. Answer with only the query. Avoid duplicates, make the prompts as divers as you can. You are not allowed to repeat queries. Today is ${currentDate}`;

const webQuery = await generateFromDefaultEndpoint({
messages: convQuery,
preprompt: `You are tasked with generating web search queries. Give me an appropriate query to answer my question for google search. Answer with only the query. Today is ${currentDate}`,
preprompt,
generateSettings: {
max_new_tokens: 30,
},
Expand Down
53 changes: 40 additions & 13 deletions src/lib/server/websearch/search/search.ts
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import type { MessageWebSearchUpdate } from "$lib/types/MessageUpdate";
const listSchema = z.array(z.string()).default([]);
const allowList = listSchema.parse(JSON5.parse(env.WEBSEARCH_ALLOWLIST));
const blockList = listSchema.parse(JSON5.parse(env.WEBSEARCH_BLOCKLIST));
const num_searches = 3;
flozi00 marked this conversation as resolved.
Show resolved Hide resolved

export async function* search(
messages: Message[],
Expand All @@ -25,6 +26,8 @@ export async function* search(
{ searchQuery: string; pages: WebSearchSource[] },
undefined
> {
const searchQueries: string[] = [];

if (ragSettings && ragSettings?.allowedLinks.length > 0) {
yield makeGeneralUpdate({ message: "Using links specified in Assistant" });
return {
Expand All @@ -33,24 +36,48 @@ export async function* search(
};
}

const searchQuery = query ?? (await generateQuery(messages));
yield makeGeneralUpdate({ message: `Searching ${getWebSearchProvider()}`, args: [searchQuery] });
for (let i = 0; i < num_searches; i++) {
const searchQuery = query ?? (await generateQuery(messages, searchQueries));
searchQueries.push(searchQuery);
yield makeGeneralUpdate({
message: `Searching ${getWebSearchProvider()}`,
args: [searchQuery],
});
}

let combinedResults: WebSearchSource[] = [];

// handle the global and (optional) rag lists
if (ragSettings && ragSettings?.allowedDomains.length > 0) {
yield makeGeneralUpdate({ message: "Filtering on specified domains" });
for (const searchQuery of searchQueries) {
// handle the global and (optional) rag lists
if (ragSettings && ragSettings?.allowedDomains.length > 0) {
yield makeGeneralUpdate({ message: "Filtering on specified domains" });
}
const filters = buildQueryFromSiteFilters(
[...(ragSettings?.allowedDomains ?? []), ...allowList],
blockList
);

const searchQueryWithFilters = `${filters} ${searchQuery}`;
const searchResults = await searchWeb(searchQueryWithFilters).then(filterByBlockList);
combinedResults = [...combinedResults, ...searchResults];
}
const filters = buildQueryFromSiteFilters(
[...(ragSettings?.allowedDomains ?? []), ...allowList],
blockList
);

const searchQueryWithFilters = `${filters} ${searchQuery}`;
const searchResults = await searchWeb(searchQueryWithFilters).then(filterByBlockList);
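// Re-order the combined results so that the top hits of every query come first.
// combinedResults holds the results of each query back to back (query a, then b, then c),
// so without re-ordering, the best hits of the later queries would sit at the end of the list.
// Intended behaviour: round-robin across the queries.
// example input: [a1,a2,a3,a4,a5,b1,b2,b3,b4,b5,c1,c2,c3,c4,c5]
// intended output: [a1,b1,c1,a2,b2,c2,a3,b3,c3,a4,b4,c4,a5,b5,c5]
// NOTE(review): the loop below strides by num_searches, not by the number of results per query.
// For the example input it actually yields [a1,a4,b2,b5,c3,a2,a5,b3,c1,c4,a3,b1,b4,c2,c5].
// A true round-robin needs the per-query result lists (or their lengths) kept separately.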
const sortedResults = [];
for (let i = 0; i < num_searches; i++) {
for (let j = i; j < combinedResults.length; j += num_searches) {
sortedResults.push(combinedResults[j]);
}
}

return {
searchQuery: searchQueryWithFilters,
pages: searchResults,
searchQuery: searchQueries.join(" | "),
pages: sortedResults,
};
}

Expand Down