Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Develop #540

Merged
merged 6 commits into from
Nov 2, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion .github/workflows/tag.yml
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,9 @@ jobs:
- name: Set up docker Buildx
uses: docker/setup-buildx-action@v3
with:
platforms: linux/amd64,linux/arm64
platforms:
- linux/amd64
- linux/arm64

# Uses the `docker/login-action`
# action to log in to the Container registry using the account and password that will publish the packages.
Expand Down
4 changes: 1 addition & 3 deletions core/Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,5 @@ RUN pip install -U pip && \
python3 -c "import nltk; nltk.download('punkt');nltk.download('averaged_perceptron_tagger')" &&\
python3 install_plugin_dependencies.py


### FINISH ###
# ready to go (docker-compose up)

CMD python3 -m cat.main
22 changes: 1 addition & 21 deletions core/cat/mad_hatter/plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
import glob
import traceback
import importlib
from importlib import machinery
from typing import Dict
from inspect import getmembers
from pydantic import BaseModel
Expand Down Expand Up @@ -207,27 +206,8 @@ def _load_decorated_functions(self):
tools = []
plugin_overrides = []

"""
for py_file in self.py_files:
module_name = os.path.splitext(os.path.basename(py_file))[0]

log.info(f"Import module {py_file}")

# save a reference to decorated functions
try:
plugin_module = machinery.SourceFileLoader(module_name, py_file).load_module()
hooks += getmembers(plugin_module, self._is_cat_hook)
tools += getmembers(plugin_module, self._is_cat_tool)
plugin_overrides += getmembers(plugin_module, self._is_cat_plugin_override)
except Exception as e:
log.error(f"Error in {module_name}: {str(e)}")
traceback.print_exc()
raise Exception(f"Unable to load the plugin {self._id}")
"""


for py_file in self.py_files:
py_filename = py_file.replace("/", ".").replace(".py", "") # this is UGLY I know. I'm sorry
py_filename = py_file.replace(".py", "").replace("/", ".")

log.info(f"Import module {py_filename}")

Expand Down
3 changes: 2 additions & 1 deletion core/cat/memory/working_memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,8 @@ def update_conversation_history(self, who, message):
self["history"].append({"who": who, "message": message})

# do not allow more than k messages in convo history (+2 which are the current turn)
k = 3
# TODO: allow infinite history, but only insert in prompts the last k messages
k = 5
self["history"] = self["history"][(-k - 1):]


Expand Down
8 changes: 2 additions & 6 deletions core/cat/routes/memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,10 +206,8 @@ async def wipe_conversation_history(
) -> Dict:
"""Delete the specified user's conversation history from working memory"""

# TODO: Add possibility to wipe the working memory of specified user id

ccat = request.app.state.ccat
ccat.working_memory["history"] = []
ccat.working_memory_list[user_id]["history"] = []

return {
"deleted": True,
Expand All @@ -224,10 +222,8 @@ async def get_conversation_history(
) -> Dict:
"""Get the specified user's conversation history from working memory"""

# TODO: Add possibility to get the working memory of specified user id

ccat = request.app.state.ccat
history = ccat.working_memory["history"]
history = ccat.working_memory_list[user_id]["history"]

return {
"history": history
Expand Down
99 changes: 99 additions & 0 deletions core/tests/routes/memory/test_working_memory.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
import time
from tests.utils import send_websocket_message


def test_convo_history_creation(client):

response = client.get(f"/memory/conversation_history/")
json = response.json()
assert response.status_code == 200
assert "history" in json
assert len(json["history"]) == 0


def test_convo_history_update(client):

message = "It's late! It's late!"

# send websocket messages
res = send_websocket_message({
"text": message
}, client)

# check working memory update
response = client.get(f"/memory/conversation_history/")
json = response.json()
assert response.status_code == 200
assert "history" in json
assert len(json["history"]) == 2 # mex and reply
assert json["history"][0]["who"] == "Human"
assert json["history"][0]["message"] == message


def test_convo_history_reset(client):

# send websocket messages
res = send_websocket_message({
"text": "It's late! It's late!"
}, client)

# delete convo history
response = client.delete(f"/memory/conversation_history/")
assert response.status_code == 200

# check working memory update
response = client.get(f"/memory/conversation_history/")
json = response.json()
assert response.status_code == 200
assert "history" in json
assert len(json["history"]) == 0


# TODO: should be tested also with concurrency!
def test_convo_history_by_user(client):

convos = {
# user_id: n_messages
"White Rabbit": 2,
"Alice": 3
}

# send websocket messages
for user_id, n_messages in convos.items():
for m in range(n_messages):
time.sleep(0.1)
res = send_websocket_message({
"text": f"Mex n.{m} from {user_id}"
}, client, user_id=user_id)

# check working memories
for user_id, n_messages in convos.items():
response = client.get(f"/memory/conversation_history/", headers={"user_id": user_id})
json = response.json()
assert response.status_code == 200
assert "history" in json
assert len(json["history"]) == n_messages * 2 # mex and reply
for m_idx, m in enumerate(json["history"]):
assert "who" in m
assert "message" in m
if m_idx%2==0: # even message
m_number_from_user = int(m_idx/2)
assert m["who"] == "Human"
assert m["message"] == f"Mex n.{m_number_from_user} from {user_id}"
else:
assert m["who"] == "AI"

# delete White Rabbit convo
response = client.delete(f"/memory/conversation_history/", headers={"user_id": "White Rabbit"})
assert response.status_code == 200

# check convo deletion per user
### White Rabbit convo is empty
response = client.get(f"/memory/conversation_history/", headers={"user_id": "White Rabbit"})
json = response.json()
assert len(json["history"]) == 0
### Alice convo still the same
response = client.get(f"/memory/conversation_history/", headers={"user_id": "Alice"})
json = response.json()
assert len(json["history"]) == convos["Alice"] * 2

2 changes: 1 addition & 1 deletion core/tests/routes/test_websocket.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ def test_websocket(client):

# send websocket message
res = send_websocket_message({
"text": "Your bald aunt with a wooden leg"
"text": "It's late! It's late"
}, client)

for k in ["type", "content", "why"]:
Expand Down
Loading