Skip to content

Commit

Permalink
Merge pull request #551 from Pingdred/why_in_wm
Browse files Browse the repository at this point in the history
Added `why` in the working memory
  • Loading branch information
pieroit authored Nov 8, 2023
2 parents c928d7a + fea5461 commit 8bae130
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 7 deletions.
7 changes: 4 additions & 3 deletions core/cat/looking_glass/cheshire_cat.py
Original file line number Diff line number Diff line change
Expand Up @@ -478,10 +478,7 @@ def __call__(self, user_message_json):
log.info("cat_message:")
log.info(cat_message)

# update conversation history
user_message = user_working_memory["user_message_json"]["text"]
user_working_memory.update_conversation_history(who="Human", message=user_message)
user_working_memory.update_conversation_history(who="AI", message=cat_message["output"])

# store user message in episodic memory
# TODO: vectorize and store also conversation chunks
Expand Down Expand Up @@ -513,6 +510,10 @@ def __call__(self, user_message_json):

final_output = self.mad_hatter.execute_hook("before_cat_sends_message", final_output)

# update conversation history
user_working_memory.update_conversation_history(who="Human", message=user_message)
user_working_memory.update_conversation_history(who="AI", message=final_output["content"], why=final_output["why"])

return final_output


Expand Down
4 changes: 2 additions & 2 deletions core/cat/memory/working_memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def get_user_id(self):

return self["user_message_json"]["user_id"]

def update_conversation_history(self, who, message):
def update_conversation_history(self, who, message, why={}):
"""Update the conversation history.
The methods append to the history key the last three conversation turns.
Expand All @@ -41,7 +41,7 @@ def update_conversation_history(self, who, message):
"""
# append latest message in conversation
self["history"].append({"who": who, "message": message})
self["history"].append({"who": who, "message": message, "why": why})

# do not allow more than k messages in convo history (+2 which are the current turn)
# TODO: allow infinite history, but only insert in prompts the last k messages
Expand Down
5 changes: 3 additions & 2 deletions core/cat/routes/websocket.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import traceback
import asyncio
from cat.looking_glass.cheshire_cat import CheshireCat
from typing import Dict, Optional
from typing import Dict
from fastapi import APIRouter, WebSocketDisconnect, WebSocket
from cat.log import log
from fastapi.concurrency import run_in_threadpool
Expand Down Expand Up @@ -88,8 +88,9 @@ async def websocket_endpoint(websocket: WebSocket, user_id: str = "user"):
# Retrieve the `ccat` instance from the application's state.
ccat = websocket.app.state.ccat

# Skip the coroutine if the same user is already connected via WebSocket.
if user_id in manager.active_connections:
# Skip the coroutine if the same user is already connected via WebSocket.
log.error(f"A websocket connection with ID '{user_id}' has already been opened.")
return

# Add the new WebSocket connection to the manager.
Expand Down

0 comments on commit 8bae130

Please sign in to comment.