-
Notifications
You must be signed in to change notification settings - Fork 17
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #11 from waleedkadous/refactor
Refactor
- Loading branch information
Showing
70 changed files
with
1,047 additions
and
12,436 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
| --- | --- | --- |
@@ -1,3 +1 @@ | ||
web: source setup.sh && python main_langchain.py & python discord_presenter.py & wait | ||
|
||
|
||
web: gunicorn -w 4 -k uvicorn.workers.UvicornWorker main_api:app |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
| --- | --- | --- |
@@ -1,56 +1,148 @@ | ||
from hermetic.agents.openai_chat_agent import OpenAIChatAgent | ||
from hermetic.core.prompt_mgr import PromptMgr | ||
from tools.kalemat import Kalemat | ||
|
||
NAME = 'ansari' | ||
class Ansari(OpenAIChatAgent): | ||
def __init__(self, env): | ||
super().__init__(model = 'gpt-4', | ||
environment = env, id=NAME) | ||
env.add_agent(NAME, self) | ||
self.pm = self.env.prompt_mgr | ||
sys_msg = self.pm.bind('system_msg') | ||
|
||
import time | ||
from pydantic import BaseModel | ||
from util.prompt_mgr import PromptMgr | ||
from tools.search_quran import SearchQuran | ||
from tools.search_hadith import SearchHadith | ||
import json | ||
import openai | ||
|
||
|
||
# Model / retry configuration for the agent.
MODEL = 'gpt-4'
MAX_FUNCTION_TRIES = 3  # NOTE(review): currently unused -- the API retry loop below is unbounded.


class Ansari:
    """Chat agent that answers questions via OpenAI chat completions,
    with access to Quran and Hadith search tools through function calling.

    Conversation state lives in ``self.message_history`` as a list of
    OpenAI-style message dicts (``role``/``content``, plus ``name`` for
    function results).
    """

    def __init__(self):
        # Register the available tools, keyed by the function name the
        # model will emit in its function_call responses.
        sq = SearchQuran()
        sh = SearchHadith()
        self.tools = {sq.get_fn_name(): sq, sh.get_fn_name(): sh}
        self.model = MODEL
        self.pm = PromptMgr()
        self.sys_msg = self.pm.bind('system_msg_fn').render()
        # Function descriptions advertised to the model on every request.
        self.functions = [x.get_function_description() for x in self.tools.values()]
        # Conversation starts with the rendered system prompt.
        self.message_history = [{
            'role': 'system',
            'content': self.sys_msg
        }]

    def greet(self):
        """Return the rendered greeting prompt."""
        self.greeting = self.pm.bind('greeting')
        return self.greeting.render()

    def process_input(self, user_input):
        """Append a user message and return a generator that streams the
        assistant's reply tokens."""
        self.message_history.append({
            'role': 'user',
            'content': user_input
        })
        return self.process_message_history()

    def replace_message_history(self, message_history):
        """Replace the conversation (re-prepending the system prompt) and
        stream the assistant's reply, skipping empty tokens."""
        self.message_history = [{
            'role': 'system',
            'content': self.sys_msg
        }] + message_history
        for m in self.process_message_history():
            if m:
                yield m

    def process_message_history(self):
        """Drive the model until the newest message is from the assistant.

        A function-call round appends 'function' messages (not 'assistant'),
        so the loop runs again to let the model consume the tool results.
        """
        while self.message_history[-1]['role'] != 'assistant':
            # yield from so streamed tokens pass straight through to the caller.
            yield from self.process_one_round()

    def process_one_round(self):
        """Issue one streamed chat completion and yield content tokens.

        If the model opts for a function call, accumulate its streamed
        arguments (yielding '' placeholders), then run the tool at end of
        stream; the tool results are appended to the message history.

        Raises:
            Exception: if a delta arrives after an unrecognized response mode.
        """
        response = None
        while not response:
            try:
                response = openai.ChatCompletion.create(
                    model = self.model,
                    messages = self.message_history,
                    stream = True,
                    functions = self.functions,
                    temperature = 0.0,
                )
            except Exception as e:
                # Unbounded retry: transient API errors are simply waited out.
                print('Exception occurred: ', e)
                print('Retrying in 5 seconds...')
                time.sleep(5)

        words = ''
        function_name = ''
        function_arguments = ''
        response_mode = ''  # 'words' or 'fn'
        for tok in response:
            delta = tok.choices[0].delta
            if not response_mode:
                # First delta of the stream tells us which mode we are in.
                if 'function_call' in delta:
                    response_mode = 'fn'
                    function_name = delta['function_call']['name']
                else:
                    response_mode = 'words'
                print('Response mode: ' + response_mode)

            # Tokens are handled differently for text vs. function calls.
            if response_mode == 'words':
                if not delta:  # Empty delta marks end of stream.
                    self.message_history.append({
                        'role': 'assistant',
                        'content': words
                    })
                    break
                elif 'content' in delta:
                    if delta['content']:
                        words += delta['content']
                        yield delta['content']
                    else:
                        continue
            elif response_mode == 'fn':
                if not delta:  # Empty delta marks end of stream.
                    function_call = function_name + '(' + function_arguments + ')'
                    print(f'Function call is {function_call}')
                    # BUGFIX: previously the *builtin* `input` function was
                    # passed as orig_question; the argument is unused, so
                    # pass None explicitly.
                    yield self.process_fn_call(None, function_name, function_arguments)
                    break
                elif 'function_call' in delta:
                    function_arguments += delta['function_call']['arguments']
                    yield ''  # Nothing user-visible while arguments stream in.
                else:
                    continue
            else:
                raise Exception("Invalid response mode: " + response_mode)

    def process_fn_call(self, orig_question, function_name, function_arguments):
        """Run the named tool and append its results as 'function' messages.

        Args:
            orig_question: unused (kept for interface compatibility).
            function_name: must match a key in ``self.tools``; otherwise the
                call is logged and ignored.
            function_arguments: JSON string accumulated from the stream;
                expected to contain a 'query' key.
        """
        if function_name in self.tools.keys():
            args = json.loads(function_arguments)
            query = args['query']
            results = self.tools[function_name].run_as_list(query)
            # Feed the results back to the model, one message per result.
            if len(results) > 0:
                for result in results:
                    self.message_history.append({
                        'role': 'function',
                        'name': function_name,
                        'content': result
                    })
            else:
                # Still record a message so the model knows the search ran.
                self.message_history.append({
                    'role': 'function',
                    'name': function_name,
                    'content': 'No results found'
                })
        else:
            print('Unknown function name: ', function_name)
This file was deleted.
Oops, something went wrong.
This file was deleted.
Oops, something went wrong.
This file was deleted.
Oops, something went wrong.
This file was deleted.
Oops, something went wrong.
This file was deleted.
Oops, something went wrong.
Oops, something went wrong.