Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Genie tracing #32

Merged
merged 8 commits into from
Dec 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion integrations/langchain/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ pip install databricks-langchain
### Install from source

```sh
pip install git+ssh://[email protected]/databricks/databricks-ai-bridge.git#subdirectory=integrations/langchain
pip install git+https://[email protected]/databricks/databricks-ai-bridge.git#subdirectory=integrations/langchain
```

## Get started
Expand Down
8 changes: 6 additions & 2 deletions integrations/langchain/src/databricks_langchain/genie.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import mlflow
from databricks_ai_bridge.genie import Genie


@mlflow.trace()
def _concat_messages_array(messages):
concatenated_message = "\n".join(
[
Expand All @@ -13,6 +15,7 @@ def _concat_messages_array(messages):
return concatenated_message


@mlflow.trace()
def _query_genie_as_agent(input, genie_space_id, genie_agent_name):
from langchain_core.messages import AIMessage

Expand All @@ -26,12 +29,13 @@ def _query_genie_as_agent(input, genie_space_id, genie_agent_name):
# Send the message and wait for a response
genie_response = genie.ask_question(message)

if genie_response:
return {"messages": [AIMessage(content=genie_response)]}
if query_result := genie_response.result:
return {"messages": [AIMessage(content=query_result)]}
else:
return {"messages": [AIMessage(content="")]}


@mlflow.trace(span_type="AGENT")
def GenieAgent(genie_space_id, genie_agent_name="Genie", description=""):
"""Create a genie agent that can be used to query the API"""
from functools import partial
Expand Down
5 changes: 3 additions & 2 deletions integrations/langchain/tests/unit_tests/test_genie.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from unittest.mock import patch

from databricks_ai_bridge.genie import GenieResponse
from langchain_core.messages import AIMessage

from databricks_langchain.genie import (
Expand Down Expand Up @@ -44,7 +45,7 @@ def __init__(self, role, content):
def test_query_genie_as_agent(MockGenie):
# Mock the Genie class and its response
mock_genie = MockGenie.return_value
mock_genie.ask_question.return_value = "It is sunny."
mock_genie.ask_question.return_value = GenieResponse(result="It is sunny.")

input_data = {"messages": [{"role": "user", "content": "What is the weather?"}]}
result = _query_genie_as_agent(input_data, "space-id", "Genie")
Expand All @@ -53,7 +54,7 @@ def test_query_genie_as_agent(MockGenie):
assert result == expected_message

# Test the case when genie_response is empty
mock_genie.ask_question.return_value = None
mock_genie.ask_question.return_value = GenieResponse(result=None)
result = _query_genie_as_agent(input_data, "space-id", "Genie")

expected_message = {"messages": [AIMessage(content="")]}
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ dependencies = [
"pandas",
"tiktoken>=0.8.0",
"tabulate",
"mlflow-skinny>=2.19.0",
]

[project.license]
Expand Down
104 changes: 50 additions & 54 deletions src/databricks_ai_bridge/genie.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,16 @@
import logging
import time
from dataclasses import dataclass
from datetime import datetime
from typing import Union
from typing import Optional, Union

import mlflow
import pandas as pd
import tiktoken
from databricks.sdk import WorkspaceClient

MAX_TOKENS_OF_DATA = 20000 # max tokens of data in markdown format
MAX_ITERATIONS = 50 # max times to poll the API when polling for either result or the query results, each iteration is ~1 second, so max latency == 2 * MAX_ITERATIONS
MAX_TOKENS_OF_DATA = 20000
MAX_ITERATIONS = 50


# Define a function to count tokens
Expand All @@ -17,6 +19,14 @@ def _count_tokens(text):
return len(encoding.encode(text))


@dataclass
class GenieResponse:
result: Union[str, pd.DataFrame]
query: Optional[str] = ""
description: Optional[str] = ""


@mlflow.trace(span_type="PARSER")
def _parse_query_result(resp) -> Union[str, pd.DataFrame]:
columns = resp["manifest"]["schema"]["columns"]
header = [str(col["name"]) for col in columns]
Expand All @@ -40,9 +50,7 @@ def _parse_query_result(resp) -> Union[str, pd.DataFrame]:
row.append(float(str_value))
elif type_name == "BOOLEAN":
row.append(str_value.lower() == "true")
elif type_name == "DATE":
row.append(datetime.strptime(str_value[:10], "%Y-%m-%d").date())
elif type_name == "TIMESTAMP":
elif type_name == "DATE" or type_name == "TIMESTAMP":
row.append(datetime.strptime(str_value[:10], "%Y-%m-%d").date())
elif type_name == "BINARY":
row.append(bytes(str_value, "utf-8"))
Expand All @@ -53,7 +61,6 @@ def _parse_query_result(resp) -> Union[str, pd.DataFrame]:

query_result = pd.DataFrame(rows, columns=header).to_markdown()

# trim down from the total rows until we get under the token limit
tokens_used = _count_tokens(query_result)
while tokens_used > MAX_TOKENS_OF_DATA:
rows.pop()
Expand All @@ -73,6 +80,7 @@ def __init__(self, space_id):
"Content-Type": "application/json",
}

@mlflow.trace()
def start_conversation(self, content):
resp = self.genie._api.do(
"POST",
Expand All @@ -82,6 +90,7 @@ def start_conversation(self, content):
)
return resp

@mlflow.trace()
def create_message(self, conversation_id, content):
resp = self.genie._api.do(
"POST",
Expand All @@ -91,7 +100,30 @@ def create_message(self, conversation_id, content):
)
return resp

@mlflow.trace()
def poll_for_result(self, conversation_id, message_id):
@mlflow.trace()
def poll_query_results(query, description):
iteration_count = 0
while iteration_count < MAX_ITERATIONS:
iteration_count += 1
resp = self.genie._api.do(
"GET",
f"/api/2.0/genie/spaces/{self.space_id}/conversations/{conversation_id}/messages/{message_id}/query-result",
headers=self.headers,
)["statement_response"]
state = resp["status"]["state"]
if state == "SUCCEEDED":
result = _parse_query_result(resp)
return GenieResponse(result, query, description)
elif state in ["RUNNING", "PENDING"]:
logging.debug("Waiting for query result...")
time.sleep(5)
else:
logging.debug(f"No query result: {resp['state']}")
return GenieResponse(None, query, description)

@mlflow.trace()
def poll_result():
iteration_count = 0
while iteration_count < MAX_ITERATIONS:
Expand All @@ -101,63 +133,27 @@ def poll_result():
f"/api/2.0/genie/spaces/{self.space_id}/conversations/{conversation_id}/messages/{message_id}",
headers=self.headers,
)
if resp["status"] == "EXECUTING_QUERY":
query = next(r for r in resp["attachments"] if "query" in r)["query"]
description = query.get("description", "")
sql = query.get("query", "")
logging.debug(f"Description: {description}")
logging.debug(f"SQL: {sql}")
return poll_query_results()
elif resp["status"] == "COMPLETED":
# Check if there is a query object in the attachments for the COMPLETED status
if resp["status"] == "EXECUTING_QUERY" or resp["status"] == "COMPLETED":
query_attachment = next((r for r in resp["attachments"] if "query" in r), None)
if query_attachment:
query = query_attachment["query"]
description = query.get("description", "")
sql = query.get("query", "")
logging.debug(f"Description: {description}")
logging.debug(f"SQL: {sql}")
return poll_query_results()
else:
# Handle the text object in the COMPLETED status
return next(r for r in resp["attachments"] if "text" in r)["text"][
query = query_attachment["query"]["query"]
description = query_attachment["query"].get("description", "")
return poll_query_results(query, description)
if resp["status"] == "COMPLETED":
text_content = next(r for r in resp["attachments"] if "text" in r)["text"][
"content"
]
elif resp["status"] == "FAILED":
logging.debug("Genie failed to execute the query")
return None
elif resp["status"] == "CANCELLED":
logging.debug("Genie query cancelled")
return None
elif resp["status"] == "QUERY_RESULT_EXPIRED":
logging.debug("Genie query result expired")
return None
return GenieResponse(result=text_content)
elif resp["status"] in ["FAILED", "CANCELLED", "QUERY_RESULT_EXPIRED"]:
logging.debug(f"Genie query {resp['status'].lower()}.")
return GenieResponse(result=None)
else:
logging.debug(f"Waiting...: {resp['status']}")
time.sleep(5)

def poll_query_results():
iteration_count = 0
while iteration_count < MAX_ITERATIONS:
iteration_count += 1
resp = self.genie._api.do(
"GET",
f"/api/2.0/genie/spaces/{self.space_id}/conversations/{conversation_id}/messages/{message_id}/query-result",
headers=self.headers,
)["statement_response"]
state = resp["status"]["state"]
if state == "SUCCEEDED":
return _parse_query_result(resp)
elif state == "RUNNING" or state == "PENDING":
logging.debug("Waiting for query result...")
time.sleep(5)
else:
logging.debug(f"No query result: {resp['state']}")
return None

return poll_result()

@mlflow.trace()
def ask_question(self, question):
resp = self.start_conversation(question)
# TODO (prithvi): return the query and the result
return self.poll_for_result(resp["conversation_id"], resp["message_id"])
28 changes: 14 additions & 14 deletions tests/databricks_ai_bridge/test_genie.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,8 @@ def test_poll_for_result_completed_with_text(genie, mock_workspace_client):
mock_workspace_client.genie._api.do.side_effect = [
{"status": "COMPLETED", "attachments": [{"text": {"content": "Result"}}]},
]
result = genie.poll_for_result("123", "456")
assert result == "Result"
genie_result = genie.poll_for_result("123", "456")
assert genie_result.result == "Result"


def test_poll_for_result_completed_with_query(genie, mock_workspace_client):
Expand All @@ -64,8 +64,8 @@ def test_poll_for_result_completed_with_query(genie, mock_workspace_client):
}
},
]
result = genie.poll_for_result("123", "456")
assert result == pd.DataFrame().to_markdown()
genie_result = genie.poll_for_result("123", "456")
assert genie_result.result == pd.DataFrame().to_markdown()


def test_poll_for_result_executing_query(genie, mock_workspace_client):
Expand All @@ -84,32 +84,32 @@ def test_poll_for_result_executing_query(genie, mock_workspace_client):
}
},
]
result = genie.poll_for_result("123", "456")
assert result == pd.DataFrame().to_markdown()
genie_result = genie.poll_for_result("123", "456")
assert genie_result.result == pd.DataFrame().to_markdown()


def test_poll_for_result_failed(genie, mock_workspace_client):
mock_workspace_client.genie._api.do.side_effect = [
{"status": "FAILED"},
]
result = genie.poll_for_result("123", "456")
assert result is None
genie_result = genie.poll_for_result("123", "456")
assert genie_result.result is None


def test_poll_for_result_cancelled(genie, mock_workspace_client):
mock_workspace_client.genie._api.do.side_effect = [
{"status": "CANCELLED"},
]
result = genie.poll_for_result("123", "456")
assert result is None
genie_result = genie.poll_for_result("123", "456")
assert genie_result.result is None


def test_poll_for_result_expired(genie, mock_workspace_client):
mock_workspace_client.genie._api.do.side_effect = [
{"status": "QUERY_RESULT_EXPIRED"},
]
result = genie.poll_for_result("123", "456")
assert result is None
genie_result = genie.poll_for_result("123", "456")
assert genie_result.result is None


def test_poll_for_result_max_iterations(genie, mock_workspace_client):
Expand Down Expand Up @@ -148,8 +148,8 @@ def test_ask_question(genie, mock_workspace_client):
{"conversation_id": "123", "message_id": "456"},
{"status": "COMPLETED", "attachments": [{"text": {"content": "Answer"}}]},
]
result = genie.ask_question("What is the meaning of life?")
assert result == "Answer"
genie_result = genie.ask_question("What is the meaning of life?")
assert genie_result.result == "Answer"


def test_parse_query_result_empty():
Expand Down
Loading