Introduce `Langchain::Assistant#parallel_tool_calls` option to control whether the LLM may make multiple parallel tool calls #827

Merged (6 commits), Oct 12, 2024
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -2,6 +2,7 @@
- [BREAKING] Remove `Langchain::Assistant#clear_thread!` method
- [BREAKING] `Langchain::Messages::*` namespace has been migrated to `Langchain::Assistant::Messages::*`
- [BREAKING] Modify `Langchain::LLM::AwsBedrock` constructor to pass model options via default_options: {...}
- Introduce `Langchain::Assistant#parallel_tool_calls` option to control whether the LLM is allowed to make multiple parallel tool calls. Default: true
- Minor improvements to the Langchain::Assistant class
- Added support for streaming with Anthropic
- Bump anthropic gem
1 change: 1 addition & 0 deletions README.md
@@ -534,6 +534,7 @@ Note that streaming is not currently supported for all LLMs.
* `tools`: An array of tool instances (optional)
* `instructions`: System instructions for the assistant (optional)
* `tool_choice`: Specifies how tools should be selected. Default: "auto". A specific tool function name can be passed. This will force the Assistant to **always** use this function.
* `parallel_tool_calls`: Whether to make multiple parallel tool calls. Default: true
* `add_message_callback`: A callback function (proc, lambda) that is called when any message is added to the conversation (optional)
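For example, the new option can be disabled at construction time. This is an illustrative sketch based on the options above; the `NewsRetriever` tool and the API keys are example values:

```ruby
# Illustrative only: tool and API keys are example values.
# parallel_tool_calls: false asks the LLM for at most one tool call per turn.
assistant = Langchain::Assistant.new(
  llm: Langchain::LLM::OpenAI.new(api_key: ENV["OPENAI_API_KEY"]),
  tools: [Langchain::Tool::NewsRetriever.new(api_key: ENV["NEWS_API_KEY"])],
  parallel_tool_calls: false
)
```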

### Key Methods
25 changes: 20 additions & 5 deletions lib/langchain/assistant.rb
@@ -12,22 +12,35 @@ module Langchain
# tools: [Langchain::Tool::NewsRetriever.new(api_key: ENV["NEWS_API_KEY"])]
# )
class Assistant
attr_reader :llm, :instructions, :state, :llm_adapter, :tool_choice
attr_reader :total_prompt_tokens, :total_completion_tokens, :total_tokens, :messages
attr_accessor :tools, :add_message_callback
attr_reader :llm,
:instructions,
:state,
:llm_adapter,
:messages,
:tool_choice,
:total_prompt_tokens,
:total_completion_tokens,
:total_tokens

attr_accessor :tools,
:add_message_callback,
:parallel_tool_calls

# Create a new assistant
#
# @param llm [Langchain::LLM::Base] LLM instance that the assistant will use
# @param tools [Array<Langchain::Tool::Base>] Tools that the assistant has access to
# @param instructions [String] The system instructions
# @param tool_choice [String] Specify how tools should be selected. Options: "auto", "any", "none", or <specific function name>
# @params add_message_callback [Proc] A callback function (Proc or lambda) that is called when any message is added to the conversation
# @param parallel_tool_calls [Boolean] Whether or not to run tools in parallel
# @param messages [Array<Langchain::Assistant::Messages::Base>] The messages
# @param add_message_callback [Proc] A callback function (Proc or lambda) that is called when any message is added to the conversation
def initialize(
llm:,
tools: [],
instructions: nil,
tool_choice: "auto",
parallel_tool_calls: true,
messages: [],
add_message_callback: nil,
&block
@@ -47,6 +60,7 @@ def initialize(

self.messages = messages
@tools = tools
@parallel_tool_calls = parallel_tool_calls
self.tool_choice = tool_choice
self.instructions = instructions
@block = block
@@ -326,7 +340,8 @@ def chat_with_llm
instructions: @instructions,
messages: array_of_message_hashes,
tools: @tools,
tool_choice: tool_choice
tool_choice: tool_choice,
parallel_tool_calls: parallel_tool_calls
)
@llm.chat(**params, &@block)
end
30 changes: 21 additions & 9 deletions lib/langchain/assistant/llm/adapters/anthropic.rb
@@ -7,16 +7,23 @@ module Adapters
class Anthropic < Base
# Build the chat parameters for the Anthropic API
#
# @param tools [Array<Hash>] The tools to use
# @param instructions [String] The system instructions
# @param messages [Array<Hash>] The messages
# @param instructions [String] The system instructions
# @param tools [Array<Hash>] The tools to use
# @param tool_choice [String] The tool choice
# @param parallel_tool_calls [Boolean] Whether to make parallel tool calls
# @return [Hash] The chat parameters
def build_chat_params(tools:, instructions:, messages:, tool_choice:)
def build_chat_params(
messages:,
instructions:,
tools:,
tool_choice:,
parallel_tool_calls:
)
params = {messages: messages}
if tools.any?
params[:tools] = build_tools(tools)
params[:tool_choice] = build_tool_choice(tool_choice)
params[:tool_choice] = build_tool_choice(tool_choice, parallel_tool_calls)
end
params[:system] = instructions if instructions
params
@@ -31,7 +38,7 @@ def build_chat_params(tools:, instructions:, messages:, tool_choice:)
# @param tool_call_id [String] The tool call ID
# @return [Messages::AnthropicMessage] The Anthropic message
def build_message(role:, content: nil, image_url: nil, tool_calls: [], tool_call_id: nil)
warn "Image URL is not supported by Anthropic currently" if image_url
Langchain.logger.warn "WARNING: Image URL is not supported by Anthropic currently" if image_url

Messages::AnthropicMessage.new(role: role, content: content, tool_calls: tool_calls, tool_call_id: tool_call_id)
end
@@ -76,15 +83,20 @@ def support_system_message?

private

def build_tool_choice(choice)
def build_tool_choice(choice, parallel_tool_calls)
tool_choice_object = {disable_parallel_tool_use: !parallel_tool_calls}

case choice
when "auto"
{type: "auto"}
tool_choice_object[:type] = "auto"
when "any"
{type: "any"}
tool_choice_object[:type] = "any"
else
{type: "tool", name: choice}
tool_choice_object[:type] = "tool"
tool_choice_object[:name] = choice
end

tool_choice_object
end
end
end
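The new Anthropic mapping can be exercised in isolation. The following standalone sketch copies `build_tool_choice` from the diff above out of the adapter class; note that Anthropic expresses the option as `disable_parallel_tool_use` on the `tool_choice` object, so the boolean is inverted:

```ruby
# Standalone copy of the Anthropic adapter's tool_choice mapping from this PR.
# The parallel_tool_calls boolean is inverted into disable_parallel_tool_use.
def build_tool_choice(choice, parallel_tool_calls)
  tool_choice_object = {disable_parallel_tool_use: !parallel_tool_calls}

  case choice
  when "auto"
    tool_choice_object[:type] = "auto"
  when "any"
    tool_choice_object[:type] = "any"
  else
    # Any other value is treated as a specific tool name the model must use
    tool_choice_object[:type] = "tool"
    tool_choice_object[:name] = choice
  end

  tool_choice_object
end

build_tool_choice("auto", false)
# => {disable_parallel_tool_use: true, type: "auto"}
build_tool_choice("get_weather", true)
# => {disable_parallel_tool_use: false, type: "tool", name: "get_weather"}
```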
13 changes: 10 additions & 3 deletions lib/langchain/assistant/llm/adapters/base.rb
@@ -7,12 +7,19 @@ module Adapters
class Base
# Build the chat parameters for the LLM
#
# @param tools [Array] The tools to use
# @param instructions [String] The system instructions
# @param messages [Array] The messages
# @param instructions [String] The system instructions
# @param tools [Array] The tools to use
# @param tool_choice [String] The tool choice
# @param parallel_tool_calls [Boolean] Whether to make parallel tool calls
# @return [Hash] The chat parameters
def build_chat_params(tools:, instructions:, messages:, tool_choice:)
def build_chat_params(
messages:,
instructions:,
tools:,
tool_choice:,
parallel_tool_calls:
)
raise NotImplementedError, "Subclasses must implement build_chat_params"
end

17 changes: 13 additions & 4 deletions lib/langchain/assistant/llm/adapters/google_gemini.rb
@@ -7,12 +7,21 @@ module Adapters
class GoogleGemini < Base
# Build the chat parameters for the Google Gemini LLM
#
# @param tools [Array] The tools to use
# @param instructions [String] The system instructions
# @param messages [Array] The messages
# @param instructions [String] The system instructions
# @param tools [Array] The tools to use
# @param tool_choice [String] The tool choice
# @param parallel_tool_calls [Boolean] Whether to make parallel tool calls
# @return [Hash] The chat parameters
def build_chat_params(tools:, instructions:, messages:, tool_choice:)
def build_chat_params(
messages:,
instructions:,
tools:,
tool_choice:,
parallel_tool_calls:
)
Langchain.logger.warn "WARNING: `parallel_tool_calls:` is not supported by Google Gemini currently"

params = {messages: messages}
if tools.any?
params[:tools] = build_tools(tools)
@@ -31,7 +40,7 @@ def build_chat_params(tools:, instructions:, messages:, tool_choice:)
# @param tool_call_id [String] The tool call ID
# @return [Messages::GoogleGeminiMessage] The Google Gemini message
def build_message(role:, content: nil, image_url: nil, tool_calls: [], tool_call_id: nil)
warn "Image URL is not supported by Google Gemini" if image_url
Langchain.logger.warn "Image URL is not supported by Google Gemini" if image_url

Messages::GoogleGeminiMessage.new(role: role, content: content, tool_calls: tool_calls, tool_call_id: tool_call_id)
end
15 changes: 12 additions & 3 deletions lib/langchain/assistant/llm/adapters/mistral_ai.rb
@@ -7,12 +7,21 @@ module Adapters
class MistralAI < Base
# Build the chat parameters for the Mistral AI LLM
#
# @param tools [Array] The tools to use
# @param instructions [String] The system instructions
# @param messages [Array] The messages
# @param instructions [String] The system instructions
# @param tools [Array] The tools to use
# @param tool_choice [String] The tool choice
# @param parallel_tool_calls [Boolean] Whether to make parallel tool calls
# @return [Hash] The chat parameters
def build_chat_params(tools:, instructions:, messages:, tool_choice:)
def build_chat_params(
messages:,
instructions:,
tools:,
tool_choice:,
parallel_tool_calls:
)
Langchain.logger.warn "WARNING: `parallel_tool_calls:` is not supported by Mistral AI currently"

params = {messages: messages}
if tools.any?
params[:tools] = build_tools(tools)
18 changes: 14 additions & 4 deletions lib/langchain/assistant/llm/adapters/ollama.rb
@@ -7,12 +7,22 @@ module Adapters
class Ollama < Base
# Build the chat parameters for the Ollama LLM
#
# @param tools [Array] The tools to use
# @param instructions [String] The system instructions
# @param messages [Array] The messages
# @param instructions [String] The system instructions
# @param tools [Array] The tools to use
# @param tool_choice [String] The tool choice
# @param parallel_tool_calls [Boolean] Whether to make parallel tool calls
# @return [Hash] The chat parameters
def build_chat_params(tools:, instructions:, messages:, tool_choice:)
def build_chat_params(
messages:,
instructions:,
tools:,
tool_choice:,
parallel_tool_calls:
)
Langchain.logger.warn "WARNING: `parallel_tool_calls:` is not supported by Ollama currently"
Langchain.logger.warn "WARNING: `tool_choice:` is not supported by Ollama currently"

params = {messages: messages}
if tools.any?
params[:tools] = build_tools(tools)
@@ -29,7 +39,7 @@ def build_chat_params(tools:, instructions:, messages:, tool_choice:)
# @param tool_call_id [String] The tool call ID
# @return [Messages::OllamaMessage] The Ollama message
def build_message(role:, content: nil, image_url: nil, tool_calls: [], tool_call_id: nil)
warn "Image URL is not supported by Ollama currently" if image_url
Langchain.logger.warn "WARNING: Image URL is not supported by Ollama currently" if image_url

Messages::OllamaMessage.new(role: role, content: content, tool_calls: tool_calls, tool_call_id: tool_call_id)
end
14 changes: 11 additions & 3 deletions lib/langchain/assistant/llm/adapters/openai.rb
@@ -7,16 +7,24 @@ module Adapters
class OpenAI < Base
# Build the chat parameters for the OpenAI LLM
#
# @param tools [Array] The tools to use
# @param instructions [String] The system instructions
# @param messages [Array] The messages
# @param instructions [String] The system instructions
# @param tools [Array] The tools to use
# @param tool_choice [String] The tool choice
# @param parallel_tool_calls [Boolean] Whether to make parallel tool calls
# @return [Hash] The chat parameters
def build_chat_params(tools:, instructions:, messages:, tool_choice:)
def build_chat_params(
messages:,
instructions:,
tools:,
tool_choice:,
parallel_tool_calls:
)
params = {messages: messages}
if tools.any?
params[:tools] = build_tools(tools)
params[:tool_choice] = build_tool_choice(tool_choice)
params[:parallel_tool_calls] = parallel_tool_calls
end
params
end
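In contrast to Anthropic, OpenAI accepts `parallel_tool_calls` as a flat top-level request parameter, sent only when tools are present. A standalone sketch of that param build (the real adapter also converts tools and tool_choice via `build_tools` / `build_tool_choice`, elided here):

```ruby
# Standalone sketch of the OpenAI adapter's build_chat_params from this PR.
# parallel_tool_calls is only included when tools are present.
def build_chat_params(messages:, tools: [], tool_choice: "auto", parallel_tool_calls: true)
  params = {messages: messages}
  if tools.any?
    params[:tools] = tools              # real adapter: build_tools(tools)
    params[:tool_choice] = tool_choice  # real adapter: build_tool_choice(tool_choice)
    params[:parallel_tool_calls] = parallel_tool_calls
  end
  params
end

build_chat_params(messages: [{role: "user", content: "hi"}])
# => {messages: [{role: "user", content: "hi"}]}  (no tools, so the flag is omitted)
```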
2 changes: 1 addition & 1 deletion lib/langchain/llm/ai21.rb
@@ -8,7 +8,7 @@ module Langchain::LLM
# gem "ai21", "~> 0.2.1"
#
# Usage:
# ai21 = Langchain::LLM::AI21.new(api_key: ENV["AI21_API_KEY"])
# llm = Langchain::LLM::AI21.new(api_key: ENV["AI21_API_KEY"])
#
class AI21 < Base
DEFAULTS = {
2 changes: 1 addition & 1 deletion lib/langchain/llm/anthropic.rb
@@ -8,7 +8,7 @@ module Langchain::LLM
# gem "anthropic", "~> 0.3.2"
#
# Usage:
# anthropic = Langchain::LLM::Anthropic.new(api_key: ENV["ANTHROPIC_API_KEY"])
# llm = Langchain::LLM::Anthropic.new(api_key: ENV["ANTHROPIC_API_KEY"])
#
class Anthropic < Base
DEFAULTS = {
2 changes: 1 addition & 1 deletion lib/langchain/llm/aws_bedrock.rb
@@ -7,7 +7,7 @@ module Langchain::LLM
# gem 'aws-sdk-bedrockruntime', '~> 1.1'
#
# Usage:
# bedrock = Langchain::LLM::AwsBedrock.new(llm_options: {})
# llm = Langchain::LLM::AwsBedrock.new(llm_options: {})
#
class AwsBedrock < Base
DEFAULTS = {
2 changes: 1 addition & 1 deletion lib/langchain/llm/azure.rb
@@ -7,7 +7,7 @@ module Langchain::LLM
# gem "ruby-openai", "~> 6.3.0"
#
# Usage:
# openai = Langchain::LLM::Azure.new(api_key:, llm_options: {}, embedding_deployment_url: chat_deployment_url:)
# llm = Langchain::LLM::Azure.new(api_key:, llm_options: {}, embedding_deployment_url:, chat_deployment_url:)
#
class Azure < OpenAI
attr_reader :embed_client
2 changes: 1 addition & 1 deletion lib/langchain/llm/hugging_face.rb
@@ -8,7 +8,7 @@ module Langchain::LLM
# gem "hugging-face", "~> 0.3.4"
#
# Usage:
# hf = Langchain::LLM::HuggingFace.new(api_key: ENV["HUGGING_FACE_API_KEY"])
# llm = Langchain::LLM::HuggingFace.new(api_key: ENV["HUGGING_FACE_API_KEY"])
#
class HuggingFace < Base
DEFAULTS = {
1 change: 0 additions & 1 deletion lib/langchain/llm/ollama.rb
@@ -5,7 +5,6 @@ module Langchain::LLM
# Available models: https://ollama.ai/library
#
# Usage:
# llm = Langchain::LLM::Ollama.new
# llm = Langchain::LLM::Ollama.new(url: ENV["OLLAMA_URL"], default_options: {})
#
class Ollama < Base
4 changes: 2 additions & 2 deletions lib/langchain/llm/openai.rb
@@ -7,7 +7,7 @@ module Langchain::LLM
# gem "ruby-openai", "~> 6.3.0"
#
# Usage:
# openai = Langchain::LLM::OpenAI.new(
# llm = Langchain::LLM::OpenAI.new(
# api_key: ENV["OPENAI_API_KEY"],
# llm_options: {}, # Available options: https://github.com/alexrudall/ruby-openai/blob/main/lib/openai/client.rb#L5-L13
# default_options: {}
@@ -100,7 +100,7 @@ def embed(
# @param params [Hash] The parameters to pass to the `chat()` method
# @return [Langchain::LLM::OpenAIResponse] Response object
def complete(prompt:, **params)
warn "DEPRECATED: `Langchain::LLM::OpenAI#complete` is deprecated, and will be removed in the next major version. Use `Langchain::LLM::OpenAI#chat` instead."
Langchain.logger.warn "DEPRECATED: `Langchain::LLM::OpenAI#complete` is deprecated, and will be removed in the next major version. Use `Langchain::LLM::OpenAI#chat` instead."

if params[:stop_sequences]
params[:stop] = params.delete(:stop_sequences)
1 change: 1 addition & 0 deletions lib/langchain/llm/parameters/chat.rb
@@ -34,6 +34,7 @@ class Chat < SimpleDelegator
# Function-calling
tools: {default: []},
tool_choice: {},
parallel_tool_calls: {},

# Additional optional parameters
logit_bias: {}
12 changes: 2 additions & 10 deletions lib/langchain/llm/replicate.rb
@@ -7,16 +7,8 @@ module Langchain::LLM
# Gem requirements:
# gem "replicate-ruby", "~> 0.2.2"
#
# Use it directly:
# replicate = Langchain::LLM::Replicate.new(api_key: ENV["REPLICATE_API_KEY"])
#
# Or pass it to be used by a vector search DB:
# chroma = Langchain::Vectorsearch::Chroma.new(
# url: ENV["CHROMA_URL"],
# index_name: "...",
# llm: replicate
# )
#
# Usage:
# llm = Langchain::LLM::Replicate.new(api_key: ENV["REPLICATE_API_KEY"])
class Replicate < Base
DEFAULTS = {
# TODO: Figure out how to send the temperature to the API