Skip to content

Commit

Permalink
fix: (cohere) convert user's message to search query (#1248)
Browse files Browse the repository at this point in the history
  • Loading branch information
abvthecity authored Aug 6, 2024
1 parent 8cbf09d commit bfc1eea
Show file tree
Hide file tree
Showing 5 changed files with 186 additions and 101 deletions.
4 changes: 3 additions & 1 deletion packages/ui/chatbot/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -34,13 +34,15 @@
"react": "^18.3.1",
"react-dom": "^18.3.1",
"react-markdown": "^9.0.1",
"remark-gfm": "^4.0.0"
"remark-gfm": "^4.0.0",
"uuid": "^9.0.0"
},
"devDependencies": {
"@tailwindcss/typography": "^0.5.10",
"@types/hast": "^3.0.4",
"@types/react": "^18.3.3",
"@types/react-dom": "^18.3.0",
"@types/uuid": "^9.0.1",
"@typescript-eslint/eslint-plugin": "^7.15.0",
"@typescript-eslint/parser": "^7.15.0",
"@vitejs/plugin-react-swc": "^3.5.0",
Expand Down
7 changes: 4 additions & 3 deletions packages/ui/chatbot/src/components/ChatbotModal.tsx
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
import { FernScrollArea } from "@fern-ui/components";
import clsx from "clsx";
import { throttle, uniqueId } from "lodash-es";
import { throttle } from "lodash-es";
import { forwardRef, useRef, useState } from "react";
import type { Components } from "react-markdown";
import { v4 } from "uuid";
import { ChatbotMessage, Citation, Message } from "../types";
import { AskInput } from "./AskInput";
import { ChatConversation } from "./ChatConversation";
Expand All @@ -26,7 +27,7 @@ interface ChatbotModalProps {
export const ChatbotModal = forwardRef<HTMLElement, ChatbotModalProps>(
({ chatStream, className, components, belowInput }, ref) => {
const [chatHistory, setChatHistory] = useState<ChatHistory>(() => ({
conversationId: uniqueId(),
conversationId: v4(),
messages: [],
}));
const [isStreaming, setIsStreaming] = useState(false);
Expand Down Expand Up @@ -86,7 +87,7 @@ export const ChatbotModal = forwardRef<HTMLElement, ChatbotModalProps>(

const reset = () => {
setChatHistory({
conversationId: uniqueId(),
conversationId: v4(),
messages: [],
});
setIsStreaming(false);
Expand Down
3 changes: 2 additions & 1 deletion packages/ui/docs-bundle/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,8 @@
"sharp": "^0.33.3",
"ts-essentials": "^10.0.1",
"url-join": "5.0.0",
"uuid": "^9.0.0"
"uuid": "^9.0.0",
"zod": "^3.23.8"
},
"devDependencies": {
"@fern-platform/configs": "workspace:*",
Expand Down
59 changes: 38 additions & 21 deletions packages/ui/docs-bundle/src/pages/api/fern-docs/search/cohere.ts
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import { Cohere, CohereClient } from "cohere-ai";
import { ChatMessage } from "cohere-ai/api";
import { NextRequest } from "next/server";
import { v4 } from "uuid";
import { z } from "zod";
import { getXFernHostEdge } from "../../../../utils/xFernHost";

export const runtime = "edge";
Expand All @@ -29,6 +30,11 @@ The user asking questions may be a developer, technical writer, or product manag
- Always be respectful and professional. If you are unsure about a question, let the user know.
`;

const RequestSchema = z.strictObject({
conversationId: z.string().optional(),
message: z.string(),
});

class ConversationCache {
constructor(private conversationId: string) {}
async get() {
Expand Down Expand Up @@ -63,6 +69,16 @@ export default async function handler(req: NextRequest): Promise<Response> {
return new Response(null, { status: 401 });
}

const body = RequestSchema.safeParse(await req.json());
if (body.success === false) {
// eslint-disable-next-line no-console
console.error(body.error);
return new Response(null, { status: 400 });
}
const { conversationId = v4(), message } = body.data;
const cache = new ConversationCache(conversationId);
const chatHistory = await cache.get();

const docs = await REGISTRY_SERVICE.docs.v2.read.getDocsForUrl({ url: docsUrl });
if (!docs.ok) {
if (docs.error.error === "UnauthorizedError") {
Expand Down Expand Up @@ -102,17 +118,29 @@ export default async function handler(req: NextRequest): Promise<Response> {
return new Response(null, { status: 500 });
}

const body = await req.json();
assertNonNullish(process.env.NEXT_PUBLIC_ALGOLIA_APP_ID, "NEXT_PUBLIC_ALGOLIA_APP_ID is required");
assertNonNullish(process.env.NEXT_PUBLIC_ALGOLIA_API_KEY, "NEXT_PUBLIC_ALGOLIA_API_KEY is required");
assertNonNullish(process.env.NEXT_PUBLIC_ALGOLIA_SEARCH_INDEX, "NEXT_PUBLIC_ALGOLIA_SEARCH_INDEX is required");
assertNonNullish(process.env.COHERE_API_KEY, "COHERE_API_KEY is required");

assertNonNullish(process.env.NEXT_PUBLIC_ALGOLIA_APP_ID);
assertNonNullish(process.env.NEXT_PUBLIC_ALGOLIA_API_KEY);
assertNonNullish(process.env.NEXT_PUBLIC_ALGOLIA_SEARCH_INDEX);
const algolia = algoliasearch(process.env.NEXT_PUBLIC_ALGOLIA_APP_ID, searchApiKey.body.searchApiKey, {
// create clients
const index = algoliasearch(process.env.NEXT_PUBLIC_ALGOLIA_APP_ID, searchApiKey.body.searchApiKey, {
requester: createFetchRequester(),
}).initIndex(process.env.NEXT_PUBLIC_ALGOLIA_SEARCH_INDEX);
const cohere = new CohereClient({ token: process.env.COHERE_API_KEY });

// construct search query
const query = await cohere.chat({
preamble:
"Convert the user's last message into a search query, which will be used to find relevant documentation. Don't respond with anything else.",
chatHistory,
message,
});
const index = algolia.initIndex(process.env.NEXT_PUBLIC_ALGOLIA_SEARCH_INDEX);
const { hits } = await index.search<Algolia.AlgoliaRecord>(body.message, { hitsPerPage: 20 });

// execute search
const { hits } = await index.search<Algolia.AlgoliaRecord>(query.text, { hitsPerPage: 10 });

// convert search results to cohere documents
const documents: Record<string, string>[] = hits
.map((hit) => ({
id: getSlugForSearchRecord(hit, docs.body.baseUrl.basePath ?? ""),
Expand All @@ -122,26 +150,15 @@ export default async function handler(req: NextRequest): Promise<Response> {
// remove undefined values
.map((doc) => JSON.parse(JSON.stringify(doc)));

const cohere = new CohereClient({
token: process.env.COHERE_API_KEY,
});

let conversationId = body.conversationId;

if (!body.conversationId) {
conversationId = v4();
}
const cache = new ConversationCache(conversationId);

const chatHistory = await cache.get();
// re-run the user's message with the search results
const response = await cohere.chatStream({
preamble: PREAMBLE,
chatHistory,
message: body.message,
message,
documents,
});

chatHistory.push({ role: "USER", message: body.message });
chatHistory.push({ role: "USER", message });

const stream = convertAsyncIterableToStream(response).pipeThrough(getCohereStreamTransformer(cache, chatHistory));

Expand Down
Loading

0 comments on commit bfc1eea

Please sign in to comment.