-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathapp.py
159 lines (134 loc) · 5.68 KB
/
app.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
import sys
import os
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from dotenv import load_dotenv
from langchain_openai import AzureChatOpenAI
from langchain.chains import ConversationalRetrievalChain
from langchain.memory import ConversationBufferMemory
from langchain.schema import SystemMessage, HumanMessage
from langchain.prompts import PromptTemplate
import openai
import uvicorn
# --- Startup configuration -------------------------------------------------
# Load environment variables from the local .env file before anything reads them.
print("Loading environment variables...")
load_dotenv()

# Fail fast when any Azure OpenAI credential is absent so the error surfaces
# at startup rather than on the first request.
_REQUIRED_ENV = ["AZURE_OPENAI_ENDPOINT", "OPENAI_API_VERSION", "AZURE_OPENAI_API_KEY"]
missing_vars = [name for name in _REQUIRED_ENV if os.getenv(name) is None]
if missing_vars:
    print(f"Error: Missing environment variables: {', '.join(missing_vars)}")
    sys.exit(1)

# Create the FastAPI application instance.
app = FastAPI()

# CORS: allow every origin/method/header (wide open — fine for a demo;
# NOTE(review): tighten `origins` before production).
origins = ["*"]
app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Point the openai module at the Azure deployment described by the env vars.
print("Setting up Azure OpenAI...")
openai.api_type = "azure"
openai.api_base = os.getenv("AZURE_OPENAI_ENDPOINT")
openai.api_version = os.getenv("OPENAI_API_VERSION")
openai.api_key = os.getenv("AZURE_OPENAI_API_KEY")
print("Azure OpenAI setup completed.")
# Function to get all MDX files from the docs directory and its subdirectories
def get_mdx_files(directory):
    """Recursively collect Markdown/MDX documentation files under *directory*.

    Walks the tree rooted at ``directory`` and returns the paths of every
    file ending in ``.md`` or ``.mdx``.

    Args:
        directory: Root folder to search.

    Returns:
        list[str]: Paths (joined from the walk root) of the matching files;
        an empty list when none are found or the directory does not exist.
    """
    print(f"Searching for MDX files in {directory}...")
    mdx_files = []
    for root, _dirs, files in os.walk(directory):
        for file in files:
            # Fix: the original matched only '.md' even though the function,
            # its log messages, and the downstream indexer are all named for
            # MDX files; accept both extensions (backward-compatible superset).
            if file.endswith(('.md', '.mdx')):
                mdx_files.append(os.path.join(root, file))
    print(f"Found {len(mdx_files)} MDX files.")
    return mdx_files
# Function to create a vector database for the provided MDX files
def create_vectordb(files, filenames):
    """Build a vector index over the given file contents.

    Delegates to ``brain.get_index_for_mdx``. On any failure the whole
    process exits, since the app cannot serve requests without an index.

    Args:
        files: Raw file contents (as read from disk) to index.
        filenames: Parallel list of display names for the files.

    Returns:
        The vector database object produced by ``get_index_for_mdx``.
    """
    try:
        print("Creating vector database...")
        # Deferred import: `brain` is only needed once indexing actually runs.
        from brain import get_index_for_mdx

        index = get_index_for_mdx(files, filenames)
        print("Vector database created successfully.")
        return index
    except Exception as exc:
        print(f"Error creating vector database: {str(exc)}")
        sys.exit(1)
# Load MDX files and create vector database
docs_folder = os.path.join(os.getcwd(), "docs")
mdx_file_paths = get_mdx_files(docs_folder)

if mdx_file_paths:
    try:
        print("Loading and creating vector database from MDX files...")
        # Fix: read each file inside a context manager so the descriptor is
        # closed promptly — the original `open(f, "rb").read()` comprehension
        # left every file handle open until garbage collection.
        mdx_files = []
        for path in mdx_file_paths:
            with open(path, "rb") as fh:
                mdx_files.append(fh.read())
        mdx_file_names = [os.path.basename(f) for f in mdx_file_paths]
        vectordb = create_vectordb(mdx_files, mdx_file_names)
    except Exception as e:
        print(f"Error processing MDX files: {str(e)}")
        sys.exit(1)

    # Create a conversational chain
    print("Creating conversational chain...")
    # Memory stores the running chat history; output_key="answer" tells the
    # memory which chain output to record (needed because the chain also
    # returns source documents).
    memory = ConversationBufferMemory(memory_key="chat_history", return_messages=True, output_key="answer", max_messages=10)

    # Prompt used by the combine-docs step of the retrieval chain.
    template = """
You are a helpful assistant specialized in answering technical questions related to Keploy. You are provided with context from a vector database and a chat history. Your task is to answer the user's question based on the provided context and the chat history. If you don't know the answer, just say 'I don't know'. Do not try to make up an answer. If the question is not related to Keploy, say 'I am not sure about that'."
Context: {context}
Question: {question}
Answer:
"""
    prompt = PromptTemplate(
        template=template,
        input_variables=["context", "question"]
    )

    # Azure-hosted chat model; credentials/endpoint come from the env vars
    # validated at startup.
    llm = AzureChatOpenAI(
        azure_endpoint=os.getenv("AZURE_OPENAI_ENDPOINT"),
        azure_deployment="keploy-gpt4o",
        openai_api_version=os.getenv("OPENAI_API_VERSION"),
        openai_api_type="azure",
        temperature=0.7,
    )

    # Retrieval-augmented chain: fetches the top-3 similar chunks per query
    # and answers with the prompt above, returning the source documents.
    conversation_chain = ConversationalRetrievalChain.from_llm(
        llm=llm,
        retriever=vectordb.as_retriever(search_kwargs={"k": 3}),
        memory=memory,
        return_source_documents=True,
        verbose=False,
        combine_docs_chain_kwargs={"prompt": prompt}
    )
    print("Conversational chain created successfully.")
else:
    print("Error: No MDX files found in the docs folder.")
    sys.exit(1)
# Define the Question model for the API request
class Question(BaseModel):
    """Request body accepted by the ``/chat`` endpoint."""

    # The user's natural-language question (the endpoint rejects empty text).
    question: str
# API endpoint to handle chat queries
@app.post('/chat')
def chat(question: Question):
    """Answer a user question via the retrieval-augmented conversation chain.

    Args:
        question: Request body carrying the question text.

    Returns:
        dict: ``answer`` (the model's reply) and ``sources`` (the ``source``
        metadata of each retrieved document, 'Unknown' when absent).

    Raises:
        HTTPException: 400 when the question is empty; 500 on any failure
        during chain execution.
    """
    print("Received chat request")
    if not question.question:
        raise HTTPException(status_code=400, detail="No question provided")
    try:
        # Fix: removed the explicit `vectordb.similarity_search(...)` call —
        # its result (`context`) was never used; the chain performs its own
        # retrieval through the retriever configured at startup, so the extra
        # search was pure wasted work per request.
        response = conversation_chain({"question": question.question})

        # Shape the chain output for the API client.
        result = {
            "answer": response['answer'],
            "sources": [doc.metadata.get('source', 'Unknown') for doc in response.get('source_documents', [])]
        }
        return result
    except Exception as e:
        print(f"Error during chat processing: {str(e)}")
        raise HTTPException(status_code=500, detail="An error occurred during chat processing")
# Main entry point to start the FastAPI server
if __name__ == '__main__':
    print("Starting FastAPI app...")
    # Blocks until the server is shut down; the final print only runs then.
    uvicorn.run(app, host="0.0.0.0", port=8000)
    print("FastAPI app started.")