Improve AWS bedrock integration #289

Open · wants to merge 10 commits into base: main
89 changes: 77 additions & 12 deletions adalflow/adalflow/components/model_client/bedrock_client.py
@@ -54,6 +54,7 @@ class BedrockAPIClient(ModelClient):
    Setup:
    1. Install boto3: `pip install boto3`
    2. Ensure you have the AWS credentials set up. There are four variables you can optionally set;
       either AWS_PROFILE_NAME or (AWS_REGION_NAME, AWS_ACCESS_KEY_ID, and AWS_SECRET_ACCESS_KEY) is required:
       - AWS_PROFILE_NAME: The name of the AWS profile to use.
       - AWS_REGION_NAME: The name of the AWS region to use.
       - AWS_ACCESS_KEY_ID: The AWS access key ID.
@@ -80,10 +81,9 @@ class BedrockAPIClient(ModelClient):
        self.generator = Generator(
            model_client=BedrockAPIClient(),
            model_kwargs={
-               "modelId": "anthropic.claude-3-sonnet-20240229-v1:0",
-               "inferenceConfig": {
-                   "temperature": 0.8
-               }
+               "model": "mistral.mistral-7b-instruct-v0:2",
+               "temperature": 0.8,
+               "max_tokens": 100
            }, template=template
        )

@@ -99,8 +99,8 @@ class BedrockAPIClient(ModelClient):

    def __init__(
        self,
-       aws_profile_name="default",
-       aws_region_name="us-west-2",  # Use a supported default region
+       aws_profile_name=None,
+       aws_region_name=None,
        aws_access_key_id=None,
        aws_secret_access_key=None,
        aws_session_token=None,
@@ -123,6 +123,12 @@ def __init__(
        self.chat_completion_parser = (
            chat_completion_parser or get_first_message_content
        )
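        # Keys routed to the Converse API's inferenceConfig block; any other
        # model kwargs (except modelId) are forwarded via
        # additionalModelRequestFields (see _separate_parameters below).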
        self.inference_parameters = [
            "maxTokens",
            "temperature",
            "topP",
            "stopSequences",
        ]

    def init_sync_client(self):
        """
@@ -215,22 +221,71 @@ def track_completion_usage(self, completion: Dict) -> CompletionUsage:
            total_tokens=usage["totalTokens"],
        )

-   def list_models(self):
+   def list_models(self, **kwargs):
        # Initialize Bedrock client (not runtime)
        # Reference: https://docs.aws.amazon.com/bedrock/latest/APIReference/API_ListFoundationModels.html

        try:
-           response = self._client.list_foundation_models()
-           models = response.get("models", [])
+           response = self._client.list_foundation_models(**kwargs)
+           models = response.get("modelSummaries", [])
            for model in models:
                print(f"Model ID: {model['modelId']}")
-               print(f" Name: {model['name']}")
-               print(f" Description: {model['description']}")
-               print(f" Provider: {model['provider']}")
+               print(f" Name: {model['modelName']}")
+               print(f" Model ARN: {model['modelArn']}")
+               print(f" Provider: {model['providerName']}")
+               print(f" Input: {model['inputModalities']}")
+               print(f" Output: {model['outputModalities']}")
+               print(f" InferenceTypesSupported: {model['inferenceTypesSupported']}")
                print("")

        except Exception as e:
            print(f"Error listing models: {e}")

    def _validate_and_process_config_keys(self, api_kwargs: Dict):
        """
        Validate and process the model ID and token-limit keys in API kwargs.

        :param api_kwargs: Dictionary of API keyword arguments
        :raises KeyError: If the 'model' key is missing
        """
        if "model" in api_kwargs:
            api_kwargs["modelId"] = api_kwargs.pop("model")
        else:
            raise KeyError("The required key 'model' is missing in model_kwargs.")

        # In .converse(), `maxTokens` is the key for the maximum token limit
        if "max_tokens" in api_kwargs:
            api_kwargs["maxTokens"] = api_kwargs.pop("max_tokens")

        return api_kwargs
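
    # Illustrative example (assumed, not part of the diff): given
    #   {"model": "mistral.mistral-7b-instruct-v0:2", "max_tokens": 100, "temperature": 0.5}
    # _validate_and_process_config_keys returns
    #   {"modelId": "mistral.mistral-7b-instruct-v0:2", "maxTokens": 100, "temperature": 0.5}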

    def _separate_parameters(self, api_kwargs: Dict) -> tuple:
        """
        Separate inference configuration and additional model request fields.

        :param api_kwargs: Dictionary of API keyword arguments
        :return: Tuple of (api_kwargs, inference_config, additional_model_request_fields)
        """
        inference_config = {}
        additional_model_request_fields = {}
        keys_to_remove = set()
        excluded_keys = {"modelId"}

        # Categorize parameters
        for key, value in list(api_kwargs.items()):
            if key in self.inference_parameters:
                inference_config[key] = value
                keys_to_remove.add(key)
            elif key not in excluded_keys:
                additional_model_request_fields[key] = value
                keys_to_remove.add(key)

        # Remove categorized keys from api_kwargs
        for key in keys_to_remove:
            api_kwargs.pop(key, None)

        return api_kwargs, inference_config, additional_model_request_fields
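
    # Illustrative example (assumed, not part of the diff): given
    #   {"modelId": "mistral.mistral-7b-instruct-v0:2", "maxTokens": 100, "temperature": 0.5, "top_k": 50}
    # _separate_parameters returns
    #   api_kwargs -> {"modelId": "mistral.mistral-7b-instruct-v0:2"}
    #   inference_config -> {"maxTokens": 100, "temperature": 0.5}
    #   additional_model_request_fields -> {"top_k": 50}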

    def convert_inputs_to_api_kwargs(
        self,
        input: Optional[Any] = None,
@@ -243,9 +298,19 @@ def convert_inputs_to_api_kwargs(
"""
api_kwargs = model_kwargs.copy()
if model_type == ModelType.LLM:
# Validate and process model ID
api_kwargs = self._validate_and_process_config_keys(api_kwargs)

# Separate inference config and additional model request fields
api_kwargs, inference_config, additional_model_request_fields = (
self._separate_parameters(api_kwargs)
)

api_kwargs["messages"] = [
{"role": "user", "content": [{"text": input}]},
]
api_kwargs["inferenceConfig"] = inference_config
api_kwargs["additionalModelRequestFields"] = additional_model_request_fields
else:
raise ValueError(f"Model type {model_type} not supported")
return api_kwargs
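
    # Illustrative result (assumed, not part of the diff): for input "Hi" and
    #   model_kwargs={"model": "mistral.mistral-7b-instruct-v0:2", "temperature": 0.8, "max_tokens": 100}
    # convert_inputs_to_api_kwargs returns roughly:
    #   {
    #       "modelId": "mistral.mistral-7b-instruct-v0:2",
    #       "messages": [{"role": "user", "content": [{"text": "Hi"}]}],
    #       "inferenceConfig": {"temperature": 0.8, "maxTokens": 100},
    #       "additionalModelRequestFields": {},
    #   }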
55 changes: 55 additions & 0 deletions docs/source/integrations/aws_bedrock.rst
@@ -0,0 +1,55 @@
.. _integration-aws-bedrock:

AWS Bedrock API Client
=======================

.. admonition:: Author
:class: highlight

`Ajith Kumar <https://github.com/ajithvcoder>`_

Getting Credentials
-------------------

You need to have an AWS account and an access key and secret key to use AWS Bedrock services. Moreover, the account associated with the access key must have
the necessary permissions to access Bedrock services. Refer to the `AWS documentation <https://docs.aws.amazon.com/singlesignon/latest/userguide/howtogetcredentials.html>`_ for more information on obtaining credentials.

Enabling Foundation Models
--------------------------

AWS Bedrock offers several foundation models from providers such as Meta, Amazon, Cohere, Anthropic, and Mistral AI. To access these models, you need to enable them first. Note that each AWS region supports a specific set of models: not all foundation models are available in every region, and pricing varies by region.

Pricing information: `AWS Bedrock Pricing <https://aws.amazon.com/bedrock/pricing/>`_

Steps for enabling model access:

1. Select the desired region in the AWS Console (e.g., `us-east-1 (N. Virginia)`).
2. Navigate to the `Bedrock services home page <https://console.aws.amazon.com/bedrock/home>`_.
3. On the left sidebar, under "Bedrock Configuration," click "Model Access."

You will be redirected to a page where you can select the models to enable.

Note:

1. Avoid enabling high-cost models to prevent accidental high charges due to incorrect usage.
2. As of Nov 2024, a cost-effective option is the Llama-3.2 1B model, with model ID: ``meta.llama3-2-1b-instruct-v1:0`` in the ``us-east-1`` region.
3. AWS tags certain models with ``inferenceTypesSupported`` = ``INFERENCE_PROFILE``; in the UI these appear with a tooltip reading "This model can only be used through an inference profile." In such cases, put the inference-profile Model ARN, e.g. ``arn:aws:bedrock:us-east-1:306093656765:inference-profile/us.meta.llama3-2-1b-instruct-v1:0``, in the model ID field when using Adalflow.
4. Ensure (AWS_ACCESS_KEY_ID, AWS_SECRET_ACCESS_KEY, AWS_REGION_NAME) or AWS_DEFAULT_PROFILE is set in the ``.env`` file, using these exact key names; for example, the access key ID must be stored as ``AWS_ACCESS_KEY_ID``. A sketch of such a file follows.
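
A minimal ``.env`` sketch for the access-key option (values are placeholders):

.. code-block:: text

    AWS_ACCESS_KEY_ID=<your-access-key-id>
    AWS_SECRET_ACCESS_KEY=<your-secret-access-key>
    AWS_REGION_NAME=us-east-1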

.. code-block:: python

    import adalflow as adal

    # Ensure (AWS_ACCESS_KEY_ID, AWS_SECRET_ACCESS_KEY, AWS_REGION_NAME) or AWS_DEFAULT_PROFILE is set in the .env file
    adal.setup_env()
    model_client = adal.BedrockAPIClient()
    model_client.list_models()

For any model tagged with ``INFERENCE_PROFILE``, you may need to provide the Model ARN in the ``model`` field of ``model_kwargs``, as in the sketch below.
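
A minimal sketch (reusing the inference-profile ARN from the note above; substitute your own region and model):

.. code-block:: python

    import adalflow as adal

    adal.setup_env()
    bedrock_llm = adal.Generator(
        model_client=adal.BedrockAPIClient(),
        model_kwargs={
            "model": "arn:aws:bedrock:us-east-1:306093656765:inference-profile/us.meta.llama3-2-1b-instruct-v1:0",
            "temperature": 0.5,
            "max_tokens": 100,
        },
    )
    response = bedrock_llm({"input_str": "What is the capital of France?"})
    print(response)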

References
----------

1. You can refer to Model IDs or Model ARNs `here <https://us-east-1.console.aws.amazon.com/bedrock/home?region=us-east-1#/models>`_. Clicking on a model card provides additional information.
2. Internally, Adalflow's AWS client uses the `Converse API <https://boto3.amazonaws.com/v1/documentation/api/1.35.8/reference/services/bedrock-runtime/client/converse.html>`_ for each conversation.
54 changes: 54 additions & 0 deletions tutorials/bedrock_client_simple_qa.py
@@ -0,0 +1,54 @@
from adalflow.components.model_client import BedrockAPIClient
from adalflow.core.types import ModelType
from adalflow.utils import setup_env


def list_models():
    # For list of models
    model_client = BedrockAPIClient()
    model_client.list_models(byProvider="meta")
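    # Note (assumed from the boto3 docs): kwargs are forwarded to
    # list_foundation_models, so filters such as byProvider,
    # byOutputModality, or byInferenceType can be passed here.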


def bedrock_chat_conversation():
    # Initialize the Bedrock client for API interactions
    awsbedrock_client = BedrockAPIClient()
    query = "What is the capital of France?"

    # Embed the prompt in Llama 3's instruction format.
    formatted_prompt = f"""
<|begin_of_text|><|start_header_id|>user<|end_header_id|>
{query}
<|eot_id|>
<|start_header_id|>assistant<|end_header_id|>
"""

    # Set the model type to Large Language Model (LLM)
    model_type = ModelType.LLM

    # Configure model parameters:
    # - model: Specifies Llama-3.2 1B as the model to use
    # - temperature: Controls randomness (0.5 = balanced between deterministic and creative)
    # - max_tokens: Limits the response length to 100 tokens

    # Use the Model ARN, since this model has an inference profile in the us-east-1 region
    # https://us-east-1.console.aws.amazon.com/bedrock/home?region=us-east-1#/providers?model=meta.llama3-2-1b-instruct-v1:0
    model_id = "arn:aws:bedrock:us-east-1:306093656765:inference-profile/us.meta.llama3-2-1b-instruct-v1:0"
    model_kwargs = {"model": model_id, "temperature": 0.5, "max_tokens": 100}

    # Convert the inputs into the format required by Bedrock's API
    api_kwargs = awsbedrock_client.convert_inputs_to_api_kwargs(
        input=formatted_prompt, model_kwargs=model_kwargs, model_type=model_type
    )
    print(f"api_kwargs: {api_kwargs}")

    response = awsbedrock_client.call(api_kwargs=api_kwargs, model_type=model_type)

    # Extract the text from the chat completion response
    response_text = awsbedrock_client.parse_chat_completion(response)
    print(f"response_text: {response_text}")


if __name__ == "__main__":
    setup_env()
    list_models()
    bedrock_chat_conversation()
12 changes: 6 additions & 6 deletions tutorials/generator_all_providers.py
@@ -26,10 +26,10 @@ def use_all_providers():
    )
    # need to run ollama pull llama3.2:1b first to use this model

-   # aws_bedrock_llm = adal.Generator(
-   #     model_client=adal.BedrockAPIClient(),
-   #     model_kwargs={"modelId": "amazon.mistral.instruct-7b"},
-   # )
+   aws_bedrock_llm = adal.Generator(
+       model_client=adal.BedrockAPIClient(),
+       model_kwargs={"model": "mistral.mistral-7b-instruct-v0:2"},
+   )

    prompt_kwargs = {"input_str": "What is the meaning of life in one sentence?"}

@@ -38,14 +38,14 @@
    anthropic_response = anthropic_llm(prompt_kwargs)
    google_gen_ai_response = google_gen_ai_llm(prompt_kwargs)
    ollama_response = ollama_llm(prompt_kwargs)
-   # aws_bedrock_llm_response = aws_bedrock_llm(prompt_kwargs)
+   aws_bedrock_llm_response = aws_bedrock_llm(prompt_kwargs)

    print(f"OpenAI: {openai_response}\n")
    print(f"Groq: {groq_response}\n")
    print(f"Anthropic: {anthropic_response}\n")
    print(f"Google GenAI: {google_gen_ai_response}\n")
    print(f"Ollama: {ollama_response}\n")
-   # print(f"AWS Bedrock: {aws_bedrock_llm_response}\n")
+   print(f"AWS Bedrock: {aws_bedrock_llm_response}\n")


if __name__ == "__main__":