# query_with_langchain.py
import ast
from typing import (
Any,
List,
Tuple
)
import tiktoken
from dotenv import load_dotenv
from langchain.docstore.document import Document
from env_manager import llm_class, vectorstore_class
from utils import convert_chat_messages, get_from_env_or_config
from logger import logger
from redis_util import read_messages_from_redis, store_messages_in_redis
load_dotenv()
temperature = float(get_from_env_or_config("llm", "temperature"))
chatClient = llm_class.get_client(temperature=temperature)
max_messages = int(get_from_env_or_config("llm", "max_messages")) # Maximum number of messages to include in conversation history
def querying_with_langchain_gpt3(index_id, query, context):
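    """Answer a single query with RAG over the given index (no chat history).
    Returns a (response, error_message, status_code) tuple.
    """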
intent_response = check_bot_intent(query, context)
if intent_response:
return intent_response, None, 200
try:
system_rules = ""
activity_prompt_config = get_from_env_or_config("llm", "activity_prompt", None)
logger.debug(f"activity_prompt_config: {activity_prompt_config}")
if activity_prompt_config:
activity_prompt_dict = ast.literal_eval(activity_prompt_config)
            system_rules = activity_prompt_dict.get(context, "")  # fall back to an empty prompt if this context isn't configured
top_docs_to_fetch = get_from_env_or_config("database", "top_docs_to_fetch", None)
documents = vectorstore_class.similarity_search_with_score(query, index_id, k=20)
logger.debug(f"Marqo documents : {str(documents)}")
min_score = get_from_env_or_config("database", "docs_min_score", None)
filtered_document = get_score_filtered_documents(documents, float(min_score))
filtered_document = filtered_document[:int(top_docs_to_fetch)]
logger.info(f"Score filtered documents : {str(filtered_document)}")
contexts = get_formatted_documents(filtered_document)
if not documents or not contexts:
return "I'm sorry, but I am not currently trained with relevant documents to provide a specific answer for your question.", None, 200
system_rules = system_rules.format(contexts=contexts)
logger.debug("==== System Rules ====")
logger.debug(f"System Rules : {system_rules}")
response = call_chat_model(
messages=[
{"role": "system", "content": system_rules},
{"role": "user", "content": query}
]
)
logger.info({"label": "llm_response", "response": response})
return response.strip(";"), None, 200
except Exception as e:
        error_message = f"{e.__context__} and {e}"
status_code = 500
return "", error_message, status_code
def conversation_retrieval_chain(index_id, query, session_id, context):
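    """Answer a query with RAG, using the session's chat history stored in Redis.
    The history is condensed into a search intent before retrieval, and the new
    user/assistant exchange is written back to Redis afterwards.
    Returns a (response, error_message, status_code) tuple.
    """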
intent_response = check_bot_intent(query, context)
if intent_response:
return intent_response, None, 200
try:
system_rules = ""
activity_prompt_config = get_from_env_or_config("llm", "activity_prompt", None)
logger.debug(f"activity_prompt_config: {activity_prompt_config}")
        if activity_prompt_config:
            activity_prompt_dict = ast.literal_eval(activity_prompt_config)
            system_rules = activity_prompt_dict.get(context, "")  # fall back to an empty prompt if this context isn't configured
previous_messages = read_messages_from_redis(session_id)
formatted_messages = format_previous_messages(previous_messages)
user_message = {"role":"user","content": query}
intent_system_prompt = get_chat_intent_prompt()
intent_payload = create_payload_by_message_count(user_message, intent_system_prompt, messages=formatted_messages, max_messages=max_messages)
logger.debug(f"intent_payload :: {intent_payload}")
search_intent = get_intent_query(intent_payload)
logger.info(f"search_intent :: {search_intent}")
documents = vectorstore_class.similarity_search_with_score(search_intent, index_id, k=20)
logger.debug(f"Marqo documents : {str(documents)}")
min_score = get_from_env_or_config("database", "docs_min_score", None)
filtered_document = get_score_filtered_documents(documents, float(min_score))
top_docs_to_fetch = get_from_env_or_config("database", "top_docs_to_fetch", None)
filtered_document = filtered_document[:int(top_docs_to_fetch)]
logger.info(f"Score filtered documents : {str(filtered_document)}")
contexts = get_formatted_documents(filtered_document)
if not documents or not contexts:
return "I'm sorry, but I am not currently trained with relevant documents to provide a specific answer for your question.", None, 200
system_rules = system_rules.format(contexts=contexts)
system_rules = {"role": "system", "content": system_rules}
logger.debug(f"System Rules : {system_rules}")
        message_payload = create_payload_by_message_count(user_message, system_rules, formatted_messages, max_messages=max_messages)
logger.debug(f"message_payload :: {message_payload}")
response = call_chat_model(message_payload)
logger.info({"label": "llm_response", "response": response})
assistant_message = format_assistant_message(response.strip(";"))
messages = read_messages_from_redis(session_id)
        messages.extend([user_message, assistant_message])
store_messages_in_redis(session_id, messages)
return response.strip(";"), None, 200
except Exception as e:
        error_message = f"{e.__context__} and {e}"
status_code = 500
return "", error_message, status_code
def call_chat_model(messages: List[dict]) -> str:
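    """Send messages to the configured chat model and return its text response."""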
    converted_messages = convert_chat_messages(messages)
    response = chatClient.invoke(input=converted_messages)
return response.content
def format_assistant_message(content):
    """Formats the assistant message.
    Args:
        content (str): assistant's reply.
    Returns:
        dict: formatted assistant message
    """
    return {'role': 'assistant', 'content': content.strip()}
def get_chat_intent_prompt():
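    """Return the configured chat-intent prompt as a system message dict."""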
intent_prompt = get_from_env_or_config("llm", "chat_intent_prompt")
    return {'role': 'system', 'content': intent_prompt}
def get_intent_query(messages=[]):
    """Condense the conversation history into a focused English search query.
    Args:
        messages (list): messages (intent system prompt plus chat history) to derive the search query from
    Returns:
        str: the synthesized search query
    """
    # Schema retained from an earlier OpenAI function-calling implementation
    # (see the commented-out call below); currently unused.
function_info = {
"name": "get_search_intent",
"description": "This function takes the user's previous interactions and synthesizes it into a focused English search query that can be used to find the most relevant documents.",
"parameters": {
"type": "object",
"properties": {
"query": {
"description": "A precise English search query (typically 5-10 words long) generated from the user's previous interactions with the chatbot and/or the available documents.",
"type": "string"
}
},
"required": ["query"]
}
}
# gpt_model = get_from_env_or_config("llm", "GPT_MODEL", "gpt-4")
# response = client.chat.completions.create(
# model=gpt_model,
# messages=messages,
# # functions=[function_info],
# # function_call= {"name": function_info.get("name")},
# stream=False,
# temperature=0.1,
# )
clientIntent = llm_class.get_client(temperature=0.1)
    converted_messages = convert_chat_messages(messages)
    response = clientIntent.invoke(input=converted_messages)
# message = response.choices[0].message
# function_call = message.function_call
# arguments = json.loads(function_call.arguments)
# print("response ====>", arguments)
# return arguments
return response.content
def count_tokens_str(doc, model="gpt-4"):
    """Count tokens in a string.
    Args:
        doc (str): String to count tokens for.
        model (str, optional): Model whose tokenizer to use. Defaults to "gpt-4".
    Returns:
        int: number of tokens in the string
    """
    encoder = tiktoken.encoding_for_model(model)  # BPE encoder  # type: ignore
    return len(encoder.encode(doc, disallowed_special=()))
def count_tokens(messages):
"""
Counts tokens in a list of messages.
Source: https://platform.openai.com/docs/guides/chat/introduction
Args:
messages (list): list of messages to count tokens for
Returns:
int: number of tokens in the list of messages
"""
num_tokens = 0
for message in messages:
# every message follows <im_start>{role/name}\n{content}<im_end>\n
num_tokens += 4
for key, value in message.items():
num_tokens += count_tokens_str(value)
if key == "name": # if there's a name, the role is omitted
num_tokens += -1 # role is always required and always 1 token
num_tokens += 2 # every reply is primed with <im_start>assistant
return num_tokens
def create_payload_by_message_count(user_message, system_message, messages=[], max_messages=4):  # IMPORTANT
    """Get the message history for the conversation, limited by message count.
    Args:
        user_message (dict): User message to append to the history.
        system_message (dict): System message to prepend to the history.
        messages (list, optional): List of previous messages. Defaults to [].
        max_messages (int, optional): Maximum number of user/assistant exchanges to include. Defaults to 4.
    Returns:
        list: Message history
    """
message_history = [system_message]
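    # each exchange is a user/assistant pair, so the window holds max_messages * 2 entries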
total_count = max_messages * 2
message_history.extend(messages[-total_count:])
message_history.append(user_message)
return message_history
def create_message_payload(user_message, system_message, messages=[], max_tokens=3000):  # IMPORTANT
    """Get the message history for the conversation, limited by token count.
    Args:
        user_message (dict): User message to append to the history (may be empty).
        system_message (dict): System message to prepend to the history.
        messages (list, optional): List of previous messages. Defaults to [].
        max_tokens (int, optional): Maximum number of tokens to limit the message history to. Defaults to 3000.
    Returns:
        list: message history
    NOTE:
        - The system message is always added to the beginning of the message history
        - Older messages are dropped first once the token budget is exhausted
    """
    message_history = []
    total_tokens = 0
    system_token_count = count_tokens([system_message])
    max_tokens -= system_token_count  # subtract the system prompt tokens
    if user_message:
        messages = messages + [user_message]
for message in reversed(messages):
message_tokens = count_tokens([message])
if total_tokens + message_tokens <= max_tokens:
total_tokens += message_tokens
# This inserts the message at the beginning of the list
message_history.insert(0, message)
else:
break
message_history.insert(0, system_message)
return message_history
def format_previous_messages(messages):
    """Format previous messages for the chat payload.
    User turns are prefixed with "Question: "; assistant turns are passed through unchanged.
    """
    formatted_messages = []
    for message in messages:
        if message['role'] == 'user':
            formatted_messages.append({"role": "user", "content": f"Question: {message['content']}"})
        elif message['role'] == 'assistant':
            formatted_messages.append({"role": "assistant", "content": message['content']})
    return formatted_messages
def check_bot_intent(query: str, context: str):
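    """Answer queries aimed at the bot itself, as judged by the intent prompt.
    Returns the bot's response for such queries, or None when bot-intent
    handling is disabled or the query is not about the bot.
    """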
    enable_bot_intent = get_from_env_or_config("llm", "enable_bot_intent", "false")  # default off, so .lower() below is safe when unset
logger.debug(f"enable_bot_intent: {enable_bot_intent}")
if enable_bot_intent.lower() == "false":
return None
intent_prompt = get_from_env_or_config("llm", "intent_prompt")
intent_response = call_chat_model(
messages=[{"role": "system", "content": intent_prompt}, {"role": "user", "content": query}]
)
logger.info({"label": "intent_response", "intent_response": intent_response})
if intent_response.lower() == "yes":
bot_prompt_config = get_from_env_or_config("llm", "bot_prompt", "")
logger.debug(f"bot_prompt_config: {bot_prompt_config}")
bot_prompt_dict = ast.literal_eval(bot_prompt_config)
system_rules = bot_prompt_dict.get(context)
logger.debug(f"Intent System Rules : {system_rules}")
response = call_chat_model(
messages=[
{"role": "system", "content": system_rules},
{"role": "user", "content": query}
]
)
logger.info({"label": "llm_bot_response", "bot_response": response})
return response
else:
return None
def get_score_filtered_documents(documents: List[Tuple[Document, Any]], min_score=0.0):
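    """Keep only (document, score) pairs whose score is greater than min_score."""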
return [(document, search_score) for document, search_score in documents if search_score > min_score]
def get_formatted_documents(documents: List[Tuple[Document, Any]]):
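    """Render documents as quoted snippets annotated with source file name and page."""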
sources = ""
for document, _ in documents:
sources += f"""
> {document.page_content} \n Source: {document.metadata['file_name']}, page# {document.metadata['page_label']};\n\n
"""
return sources
def generate_source_format(documents: List[Tuple[Document, Any]]) -> str:
"""Generates an answer format based on the given data.
Args:
data: A list of tuples, where each tuple contains a Document object and a
score.
Returns:
A string containing the formatted answer, listing the source documents
and their corresponding pages.
"""
try:
sources = {}
for doc, _ in documents:
file_name = doc.metadata['file_name']
page_label = doc.metadata['page_label']
sources.setdefault(file_name, []).append(page_label)
answer_format = "\nSources:\n"
counter = 1
for file_name, pages in sources.items():
answer_format += f"{counter}. {file_name} - (Pages: {', '.join(pages)})\n"
counter += 1
return answer_format
except Exception as e:
error_message = "Error while preparing source markdown"
logger.error(f"{error_message}: {e}", exc_info=True)
return ""
def concatenate_elements(arr):
    # Concatenate elements from index 1 onward with ': ' as the separator
    separator = ': '
    result = separator.join(arr[1:])
    return result
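# Usage sketch (illustrative only, not part of the original module): assumes the
# env/config provide the "llm" and "database" settings read above, and that the
# vector store and Redis are reachable. The index_id, session_id, and context
# values below are hypothetical.
#
#   response, error, status = conversation_retrieval_chain(
#       index_id="sakhi_index",
#       query="Suggest a storytelling activity",
#       session_id="session-123",
#       context="activity",
#   )
#   if status == 200:
#       print(response)
#   else:
#       logger.error(error)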