Skip to content

Commit

Permalink
code formatting fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
jeffxtang committed Jul 17, 2024
1 parent 43630fc commit a0e80ca
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 19 deletions.
2 changes: 1 addition & 1 deletion aimodels/client/multi_fm_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def __init__(self):
"replicate": ReplicateInterface,
"together": TogetherInterface,
"octo": OctoInterface,
"aws": AWSBedrockInterface
"aws": AWSBedrockInterface,
}

def get_provider_interface(self, model):
Expand Down
50 changes: 32 additions & 18 deletions aimodels/providers/aws_bedrock_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,15 @@

from ..framework.provider_interface import ProviderInterface


def convert_messages_to_llama3_prompt(messages):
"""
Convert a list of messages to a prompt in Llama 3 instruction format.
Args:
messages (list of dict): List of messages where each message is a dictionary
messages (list of dict): List of messages where each message is a dictionary
with 'role' ('system', 'user', 'assistant') and 'content'.
Returns:
str: Formatted prompt for Llama 3.
"""
Expand All @@ -23,21 +24,26 @@ def convert_messages_to_llama3_prompt(messages):
prompt += f"<|start_header_id|>{message['role']}<|end_header_id|>{message['content']}<|eot_id|>\n"

prompt += "<|start_header_id|>assistant<|end_header_id|>"

return prompt

return prompt


class RecursiveNamespace:
"""
Convert dictionaries to objects with attribute access, including nested dictionaries.
This class is used to simulate the OpenAI chat.completions.create's return type, so
response.choices[0].message.content works consistenly for AWS Bedrock's LLM return of a string.
"""

def __init__(self, data):
for key, value in data.items():
if isinstance(value, dict):
value = RecursiveNamespace(value)
elif isinstance(value, list):
value = [RecursiveNamespace(item) if isinstance(item, dict) else item for item in value]
value = [
RecursiveNamespace(item) if isinstance(item, dict) else item
for item in value
]
setattr(self, key, value)

@classmethod
Expand All @@ -50,10 +56,14 @@ def to_dict(self):
if isinstance(value, RecursiveNamespace):
value = value.to_dict()
elif isinstance(value, list):
value = [item.to_dict() if isinstance(item, RecursiveNamespace) else item for item in value]
value = [
item.to_dict() if isinstance(item, RecursiveNamespace) else item
for item in value
]
result[key] = value
return result


class AWSBedrockInterface(ProviderInterface):
"""Implements the ProviderInterface for interacting with AWS Bedrock's APIs."""

Expand All @@ -64,7 +74,7 @@ def __init__(self):
region_name="us-west-2",
aws_access_key_id=os.getenv("AWS_ACCESS_KEY_ID"),
aws_secret_access_key=os.getenv("AWS_SECRET_ACCESS_KEY"),
)
)

def chat_completion_create(self, messages=None, model=None, temperature=0):
"""Request chat completions from the AWS Bedrock API.
Expand All @@ -80,22 +90,26 @@ def chat_completion_create(self, messages=None, model=None, temperature=0):
The API response with the completion result.
"""
body = json.dumps({
"prompt": convert_messages_to_llama3_prompt(messages),
"temperature": temperature
})
accept = 'application/json'
content_type = 'application/json'
response = self.aws_bedrock_client.invoke_model(body=body, modelId=model, accept=accept, contentType=content_type)
response_body = json.loads(response.get('body').read())
generation = response_body.get('generation')
body = json.dumps(
{
"prompt": convert_messages_to_llama3_prompt(messages),
"temperature": temperature,
}
)
accept = "application/json"
content_type = "application/json"
response = self.aws_bedrock_client.invoke_model(
body=body, modelId=model, accept=accept, contentType=content_type
)
response_body = json.loads(response.get("body").read())
generation = response_body.get("generation")

response_data = {
"choices": [
{
"message": {"content": generation},
}
],
}
}

return RecursiveNamespace.from_dict(response_data)

0 comments on commit a0e80ca

Please sign in to comment.