
Commit

add sources coming from Knowledge base + add metadata capability for bedrock LLM and OpenAI
hghandri committed Jan 17, 2025
1 parent bd8c42c commit eb084b6
Showing 6 changed files with 134 additions and 60 deletions.
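The net effect: the knowledge-base retriever now returns the combined passage text together with the source URIs it came from, and the Bedrock and OpenAI agents copy those URIs into the returned message's metadata (and hand the final message to the new on_llm_end callback when streaming). A rough caller-side sketch, assuming the usual process_request(input_text, user_id, session_id, chat_history) signature and an agent already configured with a knowledge-base retriever:

# Illustrative sketch only -- the agent/retriever wiring outside this diff is assumed.
async def ask_with_citations(agent, question: str, user_id: str, session_id: str):
    answer = await agent.process_request(question, user_id, session_id, chat_history=[])
    # The agents below attach the retrieved source URIs as citation metadata.
    sources = answer.metadata.citations if answer.metadata else []
    return answer.content, sources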
3 changes: 3 additions & 0 deletions python/src/multi_agent_orchestrator/agents/agent.py
@@ -27,6 +27,9 @@ def on_llm_new_token(self, message: ConversationMessage) -> None:
# Default implementation
pass

def on_llm_end(self, token: ConversationMessage) -> None:
# Default implementation
pass

@dataclass
class AgentOptions:
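With on_llm_end added alongside on_llm_new_token, a callback handler can observe both the streamed chunks and the final assembled message, including any citations. A minimal sketch, assuming these hooks live on the AgentCallbacks class exposed by the agents module:

from multi_agent_orchestrator.agents import AgentCallbacks  # assumed export
from multi_agent_orchestrator.types import ConversationMessage

class LoggingCallbacks(AgentCallbacks):
    # Hypothetical handler: print streamed chunks, then inspect the final message.
    def on_llm_new_token(self, message: ConversationMessage) -> None:
        print("chunk:", message.content)

    def on_llm_end(self, token: ConversationMessage) -> None:
        if token.metadata:
            print("citations:", token.metadata.citations)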
@@ -118,7 +118,7 @@ async def process_request(

if self.retriever:
response = await self.retriever.retrieve_and_combine_results(input_text)
context_prompt = f"\nHere is the context to use to answer the user's question:\n{response}"
context_prompt = f"\nHere is the context to use to answer the user's question:\n{response['text']}"
system_prompt += context_prompt

input = {
75 changes: 51 additions & 24 deletions python/src/multi_agent_orchestrator/agents/bedrock_llm_agent.py
@@ -5,15 +5,15 @@
import os
import boto3
from multi_agent_orchestrator.agents import Agent, AgentOptions
from multi_agent_orchestrator.types import (ConversationMessage,
ConversationMessageMetadata,
from multi_agent_orchestrator.types import (ConversationMessage, ConversationMessageMetadata,
ParticipantRole,
BEDROCK_MODEL_ID_CLAUDE_3_HAIKU,
TemplateVariables,
AgentProviderType)
from multi_agent_orchestrator.utils import conversation_to_dict, Logger, Tools
from multi_agent_orchestrator.retrievers import Retriever

import traceback

@dataclass
class BedrockLLMAgentOptions(AgentOptions):
@@ -113,11 +113,13 @@ async def process_request(
self.update_system_prompt()

system_prompt = self.system_prompt
citations = []

if self.retriever:
response = await self.retriever.retrieve_and_combine_results(input_text)
context_prompt = "\nHere is the context to use to answer the user's question:\n" + response
context_prompt = "\nHere is the context to use to answer the user's question:\n" + response['text']
system_prompt += context_prompt
citations = response['sources']

converse_cmd = {
'modelId': self.model_id,
@@ -150,6 +152,17 @@
else:
bedrock_response = await self.handle_single_response(converse_cmd)

if citations:
if not bedrock_response.metadata:
bedrock_response.metadata = ConversationMessageMetadata()

bedrock_response.metadata.citations.extend(citations)

if self.streaming:
self.callbacks.on_llm_end(
bedrock_response
)

conversation.append(bedrock_response)

if any('toolUse' in content for content in bedrock_response.content):
@@ -171,17 +184,31 @@ async def process_request(
return final_message

if self.streaming:
return await self.handle_streaming_response(converse_cmd)
converse_message = await self.handle_streaming_response(converse_cmd)
else:
converse_message = await self.handle_single_response(converse_cmd)

if citations:
if not converse_message.metadata:
converse_message.metadata = ConversationMessageMetadata()

converse_message.metadata.citations.extend(citations)

if self.streaming:
self.callbacks.on_llm_end(
converse_message
)

return await self.handle_single_response(converse_cmd)
return converse_message

async def handle_single_response(self, converse_input: dict[str, Any]) -> ConversationMessage:
try:
response = self.client.converse(**converse_input)
if 'output' not in response:
raise ValueError("No output received from Bedrock model")

return ConversationMessage(
role=response['output']['message']['role'],
role=ParticipantRole.ASSISTANT.value,
content=response['output']['message']['content'],
metadata=ConversationMessageMetadata({
'usage': response['usage'],
@@ -199,31 +226,36 @@ async def handle_streaming_response(self, converse_input: dict[str, Any]) -> Con
message = {}
content = []
message['content'] = content
message['metadata'] = None
text = ''
tool_use = {}
metadata: Optional[ConversationMessageMetadata] = None

#stream the response into a message.
for chunk in response['stream']:

if 'messageStart' in chunk:
message['role'] = chunk['messageStart']['role']

elif 'contentBlockStart' in chunk:
tool = chunk['contentBlockStart']['start']['toolUse']
tool_use['toolUseId'] = tool['toolUseId']
tool_use['name'] = tool['name']

elif 'contentBlockDelta' in chunk:
delta = chunk['contentBlockDelta']['delta']

if 'toolUse' in delta:
if 'input' not in tool_use:
tool_use['input'] = ''
tool_use['input'] += delta['toolUse']['input']

elif 'text' in delta:
text += delta['text']
self.callbacks.on_llm_new_token(
ConversationMessage(
role=ParticipantRole.ASSISTANT.value,
content=[{'text': delta['text']}]
)
ConversationMessage(
role=ParticipantRole.ASSISTANT.value,
content=delta['text']
)
)
elif 'contentBlockStop' in chunk:
if 'input' in tool_use:
@@ -235,24 +267,19 @@ async def handle_streaming_response(self, converse_input: dict[str, Any]) -> Con
text = ''

elif 'metadata' in chunk:
metadata = {
'usage': chunk['metadata']['usage'],
'metrics': chunk['metadata']['metrics']
}

self.callbacks.on_llm_new_token(
ConversationMessage(
role=ParticipantRole.ASSISTANT.value,
metadata=ConversationMessageMetadata(**metadata)
)

message['metadata'] = ConversationMessageMetadata(
usage=chunk['metadata']['usage'],
metrics=chunk['metadata']['metrics']
)

print('generate message stream :', message)
return ConversationMessage(
role=ParticipantRole.ASSISTANT.value,
content=message['content'],
metadata=ConversationMessageMetadata(**metadata)
**message
)

except Exception as error:
print(traceback.print_exc())
Logger.error(f"Error getting stream from Bedrock model: {str(error)}")
raise error

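The citation handling added above follows the same guard-then-extend pattern in both the tool-use branch and the final return path. Isolated, it amounts to this sketch (assuming, as the diff implies, that ConversationMessageMetadata takes no required arguments and exposes a mutable citations list):

from typing import List
from multi_agent_orchestrator.types import ConversationMessage, ConversationMessageMetadata

def attach_citations(message: ConversationMessage, sources: List[str]) -> ConversationMessage:
    # Make sure the message carries a metadata object, then append the
    # knowledge-base source URIs so callers can surface them as citations.
    if sources:
        if not message.metadata:
            message.metadata = ConversationMessageMetadata()
        message.metadata.citations.extend(sources)
    return message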
63 changes: 46 additions & 17 deletions python/src/multi_agent_orchestrator/agents/openai_agent.py
@@ -4,6 +4,7 @@
from multi_agent_orchestrator.agents import Agent, AgentOptions
from multi_agent_orchestrator.types import (
ConversationMessage,
ConversationMessageMetadata,
ParticipantRole,
OPENAI_MODEL_ID_GPT_O_MINI,
TemplateVariables
@@ -28,15 +29,15 @@ class OpenAIAgentOptions(AgentOptions):
class OpenAIAgent(Agent):
def __init__(self, options: OpenAIAgentOptions):
super().__init__(options)
if not options.api_key:
raise ValueError("OpenAI API key is required")


if options.client:
self.client = options.client
else:
if not options.api_key:
raise ValueError("OpenAI API key is required")

self.client = OpenAI(api_key=options.api_key)


self.model = options.model or OPENAI_MODEL_ID_GPT_O_MINI
self.streaming = options.streaming or False
self.retriever: Optional[Retriever] = options.retriever
@@ -83,7 +84,7 @@ def __init__(self, options: OpenAIAgentOptions):
options.custom_system_prompt.get('template'),
options.custom_system_prompt.get('variables')
)



def is_streaming_enabled(self) -> bool:
@@ -102,11 +103,13 @@ async def process_request(
self.update_system_prompt()

system_prompt = self.system_prompt
citations = None

if self.retriever:
response = await self.retriever.retrieve_and_combine_results(input_text)
context_prompt = "\nHere is the context to use to answer the user's question:\n" + response
context_prompt = "\nHere is the context to use to answer the user's question:\n" + response['text']
system_prompt += context_prompt
citations = response['sources']


messages = [
@@ -118,7 +121,6 @@
{"role": "user", "content": input_text}
]


request_options = {
"model": self.model,
"messages": messages,
@@ -128,10 +130,24 @@
"stop": self.inference_config.get('stopSequences'),
"stream": self.streaming
}

if self.streaming:
return await self.handle_streaming_response(request_options)
converse_message = await self.handle_streaming_response(request_options)
else:
return await self.handle_single_response(request_options)
converse_message = await self.handle_single_response(request_options)

if citations:
if not converse_message.metadata:
converse_message.metadata = ConversationMessageMetadata()

converse_message.metadata.citations.extend(citations)

if self.streaming:
self.callbacks.on_llm_end(
converse_message
)

return converse_message

except Exception as error:
Logger.error(f"Error in OpenAI API call: {str(error)}")
@@ -152,7 +168,11 @@ async def handle_single_response(self, request_options: Dict[str, Any]) -> Conve

return ConversationMessage(
role=ParticipantRole.ASSISTANT.value,
content=[{"text": assistant_message}]
content=[{"text": assistant_message}],
metadata=ConversationMessageMetadata({
'citations': chat_completion.citations,
'usage': chat_completion.usage
})
)

except Exception as error:
@@ -163,24 +183,33 @@ async def handle_streaming_response(self, request_options: Dict[str, Any]) -> Co
try:
stream = self.client.chat.completions.create(**request_options)
accumulated_message = []

for chunk in stream:
if chunk.choices[0].delta.content:

metadata = {
'citations': chunk.citations,
'usage': chunk.usage
}

chunk_content = chunk.choices[0].delta.content
accumulated_message.append(chunk_content)

if self.callbacks:
self.callbacks.on_llm_new_token(
ConversationMessage(
role=ParticipantRole.ASSISTANT.value,
content=[{'text': chunk_content}]
)
ConversationMessage(
role=ParticipantRole.ASSISTANT.value,
content=chunk_content,
metadata=ConversationMessageMetadata(**metadata)
)
)
#yield chunk_content

# Store the complete message in the instance for later access if needed
return ConversationMessage(
role=ParticipantRole.ASSISTANT.value,
content=[{"text": ''.join(accumulated_message)}]
role=ParticipantRole.ASSISTANT.value,
content=[{"text": ''.join(accumulated_message)}],
metadata=ConversationMessageMetadata(**metadata)
)

except Exception as error:
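The reworked OpenAIAgent __init__ only requires an API key when no client is supplied, so a pre-configured client can be injected instead. A sketch, assuming OpenAIAgentOptions accepts the name and description fields common to AgentOptions plus the client and streaming fields used above (placeholder values throughout):

from openai import OpenAI
from multi_agent_orchestrator.agents import OpenAIAgent, OpenAIAgentOptions

client = OpenAI(api_key="sk-placeholder", base_url="https://llm-proxy.example.com/v1")

agent = OpenAIAgent(OpenAIAgentOptions(
    name="kb-assistant",
    description="Answers questions using the knowledge base",
    client=client,      # api_key can be omitted when a client is provided
    streaming=True,
))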
@@ -48,8 +48,21 @@ async def retrieve_and_combine_results(self, text, knowledge_base_id=None, retri

@staticmethod
def combine_retrieval_results(retrieval_results):
return "\n".join(
sources = []

sources.extend(
set(result['metadata']['x-amz-bedrock-kb-source-uri']
for result in retrieval_results
if result and result.get('metadata') and isinstance(result['metadata'].get('x-amz-bedrock-kb-source-uri'), str))
)

text = "\n".join(
result['content']['text']
for result in retrieval_results
if result and result.get('content') and isinstance(result['content'].get('text'), str)
)
)

return {
'text': text,
'sources': sources
}
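For reference, a small usage sketch of the new return shape. The input mirrors the Bedrock knowledge-base result structure referenced above, the retriever class name is assumed since it is not visible in this diff, and the values are invented:

from multi_agent_orchestrator.retrievers import AmazonKnowledgeBasesRetriever  # assumed class name

results = [
    {'content': {'text': 'Paris is the capital of France.'},
     'metadata': {'x-amz-bedrock-kb-source-uri': 's3://my-kb/geography.txt'}},
    {'content': {'text': 'France uses the euro.'},
     'metadata': {'x-amz-bedrock-kb-source-uri': 's3://my-kb/economy.txt'}},
]

combined = AmazonKnowledgeBasesRetriever.combine_retrieval_results(results)
# combined['text']    -> the two passages joined by newlines
# combined['sources'] -> both S3 URIs, deduplicated (order not guaranteed: built from a set)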