Skip to content

Commit

Permalink
feat: enhance WhatsApp integration with retention settings and debug …
Browse files Browse the repository at this point in the history
…logging
  • Loading branch information
OdyAsh committed Jan 7, 2025
1 parent c18d615 commit 792fc63
Show file tree
Hide file tree
Showing 10 changed files with 213 additions and 118 deletions.
Binary file added sql/00_ansari_db_data_model.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
19 changes: 11 additions & 8 deletions sql/10_create_whatsapp_tables.sql
Original file line number Diff line number Diff line change
Expand Up @@ -47,11 +47,14 @@ CREATE TABLE messages_whatsapp (
-- Debugging notes: run below commands to untie dependencies and drop tables
-- (if you're still prototyping with the tables' final schema)

-- -- Drop foreign key constraints
-- ALTER TABLE users_whatsapp DROP CONSTRAINT users_whatsapp_user_id_fkey;
-- ALTER TABLE messages_whatsapp DROP CONSTRAINT messages_whatsapp_user_id_whatsapp_fkey;

-- -- Drop the tables
-- DROP TABLE IF EXISTS users_whatsapp;
-- DROP TABLE IF EXISTS messages_whatsapp;
-- DROP TABLE IF EXISTS threads_whatsapp;
-- -- Drop foreign key constraints
--ALTER TABLE users_whatsapp DROP CONSTRAINT users_whatsapp_user_id_fkey;
--ALTER TABLE messages_whatsapp
-- DROP CONSTRAINT messages_whatsapp_user_id_whatsapp_fkey,
-- DROP CONSTRAINT messages_whatsapp_thread_id_fkey;
--ALTER TABLE threads_whatsapp DROP CONSTRAINT threads_whatsapp_user_id_whatsapp_fkey;

-- -- Drop the tables
--DROP TABLE IF EXISTS users_whatsapp;
--DROP TABLE IF EXISTS messages_whatsapp;
--DROP TABLE IF EXISTS threads_whatsapp;
Binary file removed sql/ansari_db_data_model_before_whatsapp.png
Binary file not shown.
30 changes: 24 additions & 6 deletions src/ansari/agents/ansari.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import copy
import hashlib
import json
import os
Expand Down Expand Up @@ -92,6 +93,28 @@ def replace_message_history(self, message_history, use_tool=True, stream=True):
if m:
yield m

def _debug_log_truncated_message_history(self, message_history, count: int, failures: int):
"""
Logs a truncated version of the message history for debugging purposes.
Args:
message_history (list): The message history to be truncated and logged.
"""
trunc_msg_hist = copy.deepcopy(message_history)
if (
len(trunc_msg_hist) > 1
and isinstance(trunc_msg_hist[0], dict)
and "role" in trunc_msg_hist[0]
and trunc_msg_hist[0]["role"] == "system"
and "content" in trunc_msg_hist[0]
):
sys_p = trunc_msg_hist[0]["content"]
trunc_msg_hist[0]["content"] = sys_p[:15] + "..."

logger.info(
f"Process attempt #{count+failures+1} of this message history:\n" + "-" * 60 + f"\n{trunc_msg_hist}\n" + "-" * 60,
)

@observe(capture_input=False, capture_output=False)
def process_message_history(self, use_tool=True, stream=True):
"""
Expand All @@ -110,12 +133,7 @@ def process_message_history(self, use_tool=True, stream=True):
failures = 0
while self.message_history[-1]["role"] != "assistant" or "tool_call_id" in self.message_history[-1]:
try:
logger.info(
f"Process attempt #{count+failures+1} of this message history:\n"
+ "-" * 60
+ f"\n{self.message_history}\n"
+ "-" * 60,
)
self._debug_log_truncated_message_history(self.message_history, count, failures)
# This is pretty complicated so leaving a comment.
# We want to yield from so that we can send the sequence through the input
# Also use tools only if we haven't tried too many times (failure)
Expand Down
60 changes: 31 additions & 29 deletions src/ansari/ansari_db.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,9 @@ class MessageLogger:
without having to share details about the user_id and the thread_id
"""

def __init__(self, db: "AnsariDB", user_id: int, thread_id: int, trace_id: int, to_whatsapp: bool = False) -> None:
def __init__(self, db: "AnsariDB", user_id: int, thread_id: int, trace_id: int = None, to_whatsapp: bool = False) -> None:
if not to_whatsapp and trace_id is None:
raise ValueError("trace_id must be provided when not logging to WhatsApp")
self.user_id = user_id
self.thread_id = thread_id
self.trace_id = trace_id
Expand Down Expand Up @@ -167,7 +169,7 @@ def _execute_query(
which_fetch = [which_fetch] * len(query)

caller_function_name = inspect.stack()[1].function
logger.debug(f"Function {caller_function_name}() \nis running queries: \n{query} \nwith params: \n{params}")
logger.debug(f"Running DB function: {caller_function_name}()")

results = []
with self.get_connection() as conn:
Expand Down Expand Up @@ -249,7 +251,7 @@ def register(self, email, first_name, last_name, password_hash):
self._execute_query(insert_cmd, (email, password_hash, first_name, last_name))
return {"status": "success"}
except Exception as e:
logger.warning(f"Error is {e}")
logger.warning(f"Warning (possbile error): {e}")
return {"status": "failure", "error": str(e)}

def register_whatsapp(self, phone_num: str, db_cols_to_vals: dict) -> dict:
Expand Down Expand Up @@ -281,7 +283,7 @@ def register_whatsapp(self, phone_num: str, db_cols_to_vals: dict) -> dict:

return {"status": "success"}
except Exception as e:
logger.warning(f"Error is {e}")
logger.warning(f"Warning (possbile error): {e}")
return {"status": "failure", "error": str(e)}

def account_exists(self, email):
Expand All @@ -290,7 +292,7 @@ def account_exists(self, email):
result = self._execute_query(select_cmd, (email,), "one")[0]
return result is not None
except Exception as e:
logger.warning(f"Error is {e}")
logger.warning(f"Warning (possbile error): {e}")
return False

def account_exists_whatsapp(self, phone_num):
Expand All @@ -299,7 +301,7 @@ def account_exists_whatsapp(self, phone_num):
result = self._execute_query(select_cmd, (phone_num,), "one")[0]
return result is not None
except Exception as e:
logger.warning(f"Error is {e}")
logger.warning(f"Warning (possbile error): {e}")
return False

def save_access_token(self, user_id, token):
Expand All @@ -313,7 +315,7 @@ def save_access_token(self, user_id, token):
"token_db_id": inserted_id,
}
except Exception as e:
logger.warning(f"Error is {e}")
logger.warning(f"Warning (possbile error): {e}")
return {"status": "failure", "error": str(e)}

def save_refresh_token(self, user_id, token, access_token_id):
Expand All @@ -322,7 +324,7 @@ def save_refresh_token(self, user_id, token, access_token_id):
self._execute_query(insert_cmd, (user_id, token, access_token_id))
return {"status": "success", "token": token}
except Exception as e:
logger.warning(f"Error is {e}")
logger.warning(f"Warning (possbile error): {e}")
return {"status": "failure", "error": str(e)}

def save_reset_token(self, user_id, token):
Expand All @@ -334,7 +336,7 @@ def save_reset_token(self, user_id, token):
self._execute_query(insert_cmd, (user_id, token, token))
return {"status": "success", "token": token}
except Exception as e:
logger.warning(f"Error is {e}")
logger.warning(f"Warning (possbile error): {e}")
return {"status": "failure", "error": str(e)}

def retrieve_user_info(self, email):
Expand All @@ -346,7 +348,7 @@ def retrieve_user_info(self, email):
return user_id, existing_hash, first_name, last_name
return None, None, None, None
except Exception as e:
logger.warning(f"Error is {e}")
logger.warning(f"Warning (possbile error): {e}")
return None, None, None, None

def retrieve_user_info_whatsapp(self, phone_num: str, db_cols: Union[list, str]) -> Optional[Tuple]:
Expand Down Expand Up @@ -382,7 +384,7 @@ def retrieve_user_info_whatsapp(self, phone_num: str, db_cols: Union[list, str])
return result
return None
except Exception as e:
logger.warning(f"Error is {e}")
logger.warning(f"Warning (possbile error): {e}")
return None

def add_feedback(self, user_id, thread_id, message_id, feedback_class, comment):
Expand All @@ -393,7 +395,7 @@ def add_feedback(self, user_id, thread_id, message_id, feedback_class, comment):
self._execute_query(insert_cmd, (user_id, thread_id, message_id, feedback_class, comment))
return {"status": "success"}
except Exception as e:
logger.warning(f"Error is {e}")
logger.warning(f"Warning (possbile error): {e}")
return {"status": "failure", "error": str(e)}

def create_thread(self, user_id):
Expand All @@ -403,7 +405,7 @@ def create_thread(self, user_id):
inserted_id = result[0] if result else None
return {"status": "success", "thread_id": inserted_id}
except Exception as e:
logger.warning(f"Error is {e}")
logger.warning(f"Warning (possbile error): {e}")
return {"status": "failure", "error": str(e)}

def create_thread_whatsapp(self, user_id_whatsapp: int, thread_name: str) -> str:
Expand All @@ -426,7 +428,7 @@ def create_thread_whatsapp(self, user_id_whatsapp: int, thread_name: str) -> str
result = self._execute_query(insert_cmd, (user_id_whatsapp, thread_name), "one")[0]
return result[0] if result else None
except Exception as e:
logger.warning(f"Error is {e}")
logger.warning(f"Warning (possbile error): {e}")
return None

def get_all_threads(self, user_id):
Expand All @@ -435,7 +437,7 @@ def get_all_threads(self, user_id):
result = self._execute_query(select_cmd, (user_id,), "all")[0]
return [{"thread_id": x[0], "thread_name": x[1], "updated_at": x[2]} for x in result] if result else []
except Exception as e:
logger.warning(f"Error is {e}")
logger.warning(f"Warning (possbile error): {e}")
return []

def set_thread_name(self, thread_id, user_id, thread_name):
Expand All @@ -454,7 +456,7 @@ def set_thread_name(self, thread_id, user_id, thread_name):
)
return {"status": "success"}
except Exception as e:
logger.warning(f"Error is {e}")
logger.warning(f"Warning (possbile error): {e}")
return {"status": "failure", "error": str(e)}

def append_message(self, user_id, thread_id, role, content, tool_name=None):
Expand All @@ -474,7 +476,7 @@ def append_message(self, user_id, thread_id, role, content, tool_name=None):

return {"status": "success"}
except Exception as e:
logger.warning(f"Error is {e}")
logger.warning(f"Warning (possbile error): {e}")
return {"status": "failure", "error": str(e)}

def append_message_whatsapp(self, user_id_whatsapp: int, thread_id: int, db_cols_to_vals: dict) -> dict:
Expand Down Expand Up @@ -511,7 +513,7 @@ def append_message_whatsapp(self, user_id_whatsapp: int, thread_id: int, db_cols

return {"status": "success"}
except Exception as e:
logger.warning(f"Error is {e}")
logger.warning(f"Warning (possbile error): {e}")
return {"status": "failure", "error": str(e)}

def get_thread(self, thread_id, user_id):
Expand Down Expand Up @@ -542,7 +544,7 @@ def get_thread(self, thread_id, user_id):
}
return retval
except Exception as e:
logger.warning(f"Error is {e}")
logger.warning(f"Warning (possbile error): {e}")
return {}

def get_thread_llm(self, thread_id, user_id):
Expand Down Expand Up @@ -574,7 +576,7 @@ def get_thread_llm(self, thread_id, user_id):
}
return retval
except Exception as e:
logger.warning(f"Error is {e}")
logger.warning(f"Warning (possbile error): {e}")
return {}

def get_thread_llm_whatsapp(self, thread_id: str, user_id_whatsapp: int) -> list[dict]:
Expand Down Expand Up @@ -603,7 +605,7 @@ def get_thread_llm_whatsapp(self, thread_id: str, user_id_whatsapp: int) -> list
else []
)
except Exception as e:
logger.warning(f"Error is {e}")
logger.warning(f"Warning (possbile error): {e}")
return []

def get_last_message_time_whatsapp(self, user_id_whatsapp: int) -> tuple[Optional[str], Optional[datetime]]:
Expand All @@ -630,7 +632,7 @@ def get_last_message_time_whatsapp(self, user_id_whatsapp: int) -> tuple[Optiona
return result[0], result[1]
return None, None
except Exception as e:
logger.warning(f"Error is {e}")
logger.warning(f"Warning (possbile error): {e}")
return None, None

def snapshot_thread(self, thread_id, user_id):
Expand All @@ -649,7 +651,7 @@ def snapshot_thread(self, thread_id, user_id):
logger.info(f"Result is {result}")
return result[0] if result else None
except Exception as e:
logger.warning(f"Error is {e}")
logger.warning(f"Warning (possbile error): {e}")
return {"status": "failure", "error": str(e)}

def get_snapshot(self, share_uuid):
Expand All @@ -662,7 +664,7 @@ def get_snapshot(self, share_uuid):
return json.loads(result[0])
return {}
except Exception as e:
logger.warning(f"Error is {e}")
logger.warning(f"Warning (possbile error): {e}")
return {}

def delete_thread(self, thread_id, user_id):
Expand All @@ -675,7 +677,7 @@ def delete_thread(self, thread_id, user_id):
self._execute_query([delete_cmd_1, delete_cmd_2], [params, params])
return {"status": "success"}
except Exception as e:
logger.warning(f"Error is {e}")
logger.warning(f"Warning (possbile error): {e}")
return {"status": "failure", "error": str(e)}

def delete_access_refresh_tokens_pair(self, refresh_token):
Expand Down Expand Up @@ -715,7 +717,7 @@ def delete_access_token(self, user_id, token):
self._execute_query(delete_cmd, (user_id, token))
return {"status": "success"}
except Exception as e:
logger.warning(f"Error is {e}")
logger.warning(f"Warning (possbile error): {e}")
return {"status": "failure", "error": str(e)}

def logout(self, user_id, token):
Expand All @@ -725,7 +727,7 @@ def logout(self, user_id, token):
self._execute_query(delete_cmd, (user_id, token))
return {"status": "success"}
except Exception as e:
logger.warning(f"Error is {e}")
logger.warning(f"Warning (possbile error): {e}")
return {"status": "failure", "error": str(e)}

def set_pref(self, user_id, key, value):
Expand All @@ -750,7 +752,7 @@ def update_password(self, user_id, new_password_hash):
self._execute_query(update_cmd, (new_password_hash, user_id))
return {"status": "success"}
except Exception as e:
logger.warning(f"Error is {e}")
logger.warning(f"Warning (possbile error): {e}")
return {"status": "failure", "error": str(e)}

def update_user_whatsapp(self, phone_num: str, db_cols_to_vals: dict) -> dict:
Expand Down Expand Up @@ -782,7 +784,7 @@ def update_user_whatsapp(self, phone_num: str, db_cols_to_vals: dict) -> dict:

return {"status": "success"}
except Exception as e:
logger.warning(f"Error is {e}")
logger.warning(f"Warning (possbile error): {e}")
return {"status": "failure", "error": str(e)}

def convert_message(self, msg):
Expand Down
Loading

0 comments on commit 792fc63

Please sign in to comment.