Skip to content

Commit

Permalink
Introduce Langchain::Assistant#parallel_tool_calls options whether …
Browse files Browse the repository at this point in the history
…to allow the LLM to make multiple parallel tool calls (#827)

* Implement assistant.parallel_tool_calls: true/false option

* fix linter

* Specs

* specs

* fix linter

* CHANGELOG entry
  • Loading branch information
andreibondarev authored Oct 12, 2024
1 parent fe33fa2 commit 49161bb
Show file tree
Hide file tree
Showing 25 changed files with 237 additions and 75 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
- [BREAKING] Remove `Langchain::Assistant#clear_thread!` method
- [BREAKING] `Langchain::Messages::*` namespace had migrated to `Langchain::Assistant::Messages::*`
- [BREAKING] Modify `Langchain::LLM::AwsBedrock` constructor to pass model options via default_options: {...}
- Introduce `Langchain::Assistant#parallel_tool_calls` options whether to allow the LLM to make multiple parallel tool calls. Default: true
- Minor improvements to the Langchain::Assistant class
- Added support for streaming with Anthropic
- Bump anthropic gem
Expand Down
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -534,6 +534,7 @@ Note that streaming is not currently supported for all LLMs.
* `tools`: An array of tool instances (optional)
* `instructions`: System instructions for the assistant (optional)
* `tool_choice`: Specifies how tools should be selected. Default: "auto". A specific tool function name can be passed. This will force the Assistant to **always** use this function.
* `parallel_tool_calls`: Whether to make multiple parallel tool calls. Default: true
* `add_message_callback`: A callback function (proc, lambda) that is called when any message is added to the conversation (optional)

### Key Methods
Expand Down
25 changes: 20 additions & 5 deletions lib/langchain/assistant.rb
Original file line number Diff line number Diff line change
Expand Up @@ -12,22 +12,35 @@ module Langchain
# tools: [Langchain::Tool::NewsRetriever.new(api_key: ENV["NEWS_API_KEY"])]
# )
class Assistant
attr_reader :llm, :instructions, :state, :llm_adapter, :tool_choice
attr_reader :total_prompt_tokens, :total_completion_tokens, :total_tokens, :messages
attr_accessor :tools, :add_message_callback
attr_reader :llm,
:instructions,
:state,
:llm_adapter,
:messages,
:tool_choice,
:total_prompt_tokens,
:total_completion_tokens,
:total_tokens

attr_accessor :tools,
:add_message_callback,
:parallel_tool_calls

# Create a new assistant
#
# @param llm [Langchain::LLM::Base] LLM instance that the assistant will use
# @param tools [Array<Langchain::Tool::Base>] Tools that the assistant has access to
# @param instructions [String] The system instructions
# @param tool_choice [String] Specify how tools should be selected. Options: "auto", "any", "none", or <specific function name>
# @params add_message_callback [Proc] A callback function (Proc or lambda) that is called when any message is added to the conversation
# @param parallel_tool_calls [Boolean] Whether or not to run tools in parallel
# @param messages [Array<Langchain::Assistant::Messages::Base>] The messages
# @param add_message_callback [Proc] A callback function (Proc or lambda) that is called when any message is added to the conversation
def initialize(
llm:,
tools: [],
instructions: nil,
tool_choice: "auto",
parallel_tool_calls: true,
messages: [],
add_message_callback: nil,
&block
Expand All @@ -47,6 +60,7 @@ def initialize(

self.messages = messages
@tools = tools
@parallel_tool_calls = parallel_tool_calls
self.tool_choice = tool_choice
self.instructions = instructions
@block = block
Expand Down Expand Up @@ -326,7 +340,8 @@ def chat_with_llm
instructions: @instructions,
messages: array_of_message_hashes,
tools: @tools,
tool_choice: tool_choice
tool_choice: tool_choice,
parallel_tool_calls: parallel_tool_calls
)
@llm.chat(**params, &@block)
end
Expand Down
30 changes: 21 additions & 9 deletions lib/langchain/assistant/llm/adapters/anthropic.rb
Original file line number Diff line number Diff line change
Expand Up @@ -7,16 +7,23 @@ module Adapters
class Anthropic < Base
# Build the chat parameters for the Anthropic API
#
# @param tools [Array<Hash>] The tools to use
# @param instructions [String] The system instructions
# @param messages [Array<Hash>] The messages
# @param instructions [String] The system instructions
# @param tools [Array<Hash>] The tools to use
# @param tool_choice [String] The tool choice
# @param parallel_tool_calls [Boolean] Whether to make parallel tool calls
# @return [Hash] The chat parameters
def build_chat_params(tools:, instructions:, messages:, tool_choice:)
def build_chat_params(
messages:,
instructions:,
tools:,
tool_choice:,
parallel_tool_calls:
)
params = {messages: messages}
if tools.any?
params[:tools] = build_tools(tools)
params[:tool_choice] = build_tool_choice(tool_choice)
params[:tool_choice] = build_tool_choice(tool_choice, parallel_tool_calls)
end
params[:system] = instructions if instructions
params
Expand All @@ -31,7 +38,7 @@ def build_chat_params(tools:, instructions:, messages:, tool_choice:)
# @param tool_call_id [String] The tool call ID
# @return [Messages::AnthropicMessage] The Anthropic message
def build_message(role:, content: nil, image_url: nil, tool_calls: [], tool_call_id: nil)
warn "Image URL is not supported by Anthropic currently" if image_url
Langchain.logger.warn "WARNING: Image URL is not supported by Anthropic currently" if image_url

Messages::AnthropicMessage.new(role: role, content: content, tool_calls: tool_calls, tool_call_id: tool_call_id)
end
Expand Down Expand Up @@ -76,15 +83,20 @@ def support_system_message?

private

def build_tool_choice(choice)
def build_tool_choice(choice, parallel_tool_calls)
tool_choice_object = {disable_parallel_tool_use: !parallel_tool_calls}

case choice
when "auto"
{type: "auto"}
tool_choice_object[:type] = "auto"
when "any"
{type: "any"}
tool_choice_object[:type] = "any"
else
{type: "tool", name: choice}
tool_choice_object[:type] = "tool"
tool_choice_object[:name] = choice
end

tool_choice_object
end
end
end
Expand Down
13 changes: 10 additions & 3 deletions lib/langchain/assistant/llm/adapters/base.rb
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,19 @@ module Adapters
class Base
# Build the chat parameters for the LLM
#
# @param tools [Array] The tools to use
# @param instructions [String] The system instructions
# @param messages [Array] The messages
# @param instructions [String] The system instructions
# @param tools [Array] The tools to use
# @param tool_choice [String] The tool choice
# @param parallel_tool_calls [Boolean] Whether to make parallel tool calls
# @return [Hash] The chat parameters
def build_chat_params(tools:, instructions:, messages:, tool_choice:)
def build_chat_params(
messages:,
instructions:,
tools:,
tool_choice:,
parallel_tool_calls:
)
raise NotImplementedError, "Subclasses must implement build_chat_params"
end

Expand Down
17 changes: 13 additions & 4 deletions lib/langchain/assistant/llm/adapters/google_gemini.rb
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,21 @@ module Adapters
class GoogleGemini < Base
# Build the chat parameters for the Google Gemini LLM
#
# @param tools [Array] The tools to use
# @param instructions [String] The system instructions
# @param messages [Array] The messages
# @param instructions [String] The system instructions
# @param tools [Array] The tools to use
# @param tool_choice [String] The tool choice
# @param parallel_tool_calls [Boolean] Whether to make parallel tool calls
# @return [Hash] The chat parameters
def build_chat_params(tools:, instructions:, messages:, tool_choice:)
def build_chat_params(
messages:,
instructions:,
tools:,
tool_choice:,
parallel_tool_calls:
)
Langchain.logger.warn "WARNING: `parallel_tool_calls:` is not supported by Google Gemini currently"

params = {messages: messages}
if tools.any?
params[:tools] = build_tools(tools)
Expand All @@ -31,7 +40,7 @@ def build_chat_params(tools:, instructions:, messages:, tool_choice:)
# @param tool_call_id [String] The tool call ID
# @return [Messages::GoogleGeminiMessage] The Google Gemini message
def build_message(role:, content: nil, image_url: nil, tool_calls: [], tool_call_id: nil)
warn "Image URL is not supported by Google Gemini" if image_url
Langchain.logger.warn "Image URL is not supported by Google Gemini" if image_url

Messages::GoogleGeminiMessage.new(role: role, content: content, tool_calls: tool_calls, tool_call_id: tool_call_id)
end
Expand Down
15 changes: 12 additions & 3 deletions lib/langchain/assistant/llm/adapters/mistral_ai.rb
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,21 @@ module Adapters
class MistralAI < Base
# Build the chat parameters for the Mistral AI LLM
#
# @param tools [Array] The tools to use
# @param instructions [String] The system instructions
# @param messages [Array] The messages
# @param instructions [String] The system instructions
# @param tools [Array] The tools to use
# @param tool_choice [String] The tool choice
# @param parallel_tool_calls [Boolean] Whether to make parallel tool calls
# @return [Hash] The chat parameters
def build_chat_params(tools:, instructions:, messages:, tool_choice:)
def build_chat_params(
messages:,
instructions:,
tools:,
tool_choice:,
parallel_tool_calls:
)
Langchain.logger.warn "WARNING: `parallel_tool_calls:` is not supported by Mistral AI currently"

params = {messages: messages}
if tools.any?
params[:tools] = build_tools(tools)
Expand Down
18 changes: 14 additions & 4 deletions lib/langchain/assistant/llm/adapters/ollama.rb
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,22 @@ module Adapters
class Ollama < Base
# Build the chat parameters for the Ollama LLM
#
# @param tools [Array] The tools to use
# @param instructions [String] The system instructions
# @param messages [Array] The messages
# @param instructions [String] The system instructions
# @param tools [Array] The tools to use
# @param tool_choice [String] The tool choice
# @param parallel_tool_calls [Boolean] Whether to make parallel tool calls
# @return [Hash] The chat parameters
def build_chat_params(tools:, instructions:, messages:, tool_choice:)
def build_chat_params(
messages:,
instructions:,
tools:,
tool_choice:,
parallel_tool_calls:
)
Langchain.logger.warn "WARNING: `parallel_tool_calls:` is not supported by Ollama currently"
Langchain.logger.warn "WARNING: `tool_choice:` is not supported by Ollama currently"

params = {messages: messages}
if tools.any?
params[:tools] = build_tools(tools)
Expand All @@ -29,7 +39,7 @@ def build_chat_params(tools:, instructions:, messages:, tool_choice:)
# @param tool_call_id [String] The tool call ID
# @return [Messages::OllamaMessage] The Ollama message
def build_message(role:, content: nil, image_url: nil, tool_calls: [], tool_call_id: nil)
warn "Image URL is not supported by Ollama currently" if image_url
Langchain.logger.warn "WARNING: Image URL is not supported by Ollama currently" if image_url

Messages::OllamaMessage.new(role: role, content: content, tool_calls: tool_calls, tool_call_id: tool_call_id)
end
Expand Down
14 changes: 11 additions & 3 deletions lib/langchain/assistant/llm/adapters/openai.rb
Original file line number Diff line number Diff line change
Expand Up @@ -7,16 +7,24 @@ module Adapters
class OpenAI < Base
# Build the chat parameters for the OpenAI LLM
#
# @param tools [Array] The tools to use
# @param instructions [String] The system instructions
# @param messages [Array] The messages
# @param instructions [String] The system instructions
# @param tools [Array] The tools to use
# @param tool_choice [String] The tool choice
# @param parallel_tool_calls [Boolean] Whether to make parallel tool calls
# @return [Hash] The chat parameters
def build_chat_params(tools:, instructions:, messages:, tool_choice:)
def build_chat_params(
messages:,
instructions:,
tools:,
tool_choice:,
parallel_tool_calls:
)
params = {messages: messages}
if tools.any?
params[:tools] = build_tools(tools)
params[:tool_choice] = build_tool_choice(tool_choice)
params[:parallel_tool_calls] = parallel_tool_calls
end
params
end
Expand Down
2 changes: 1 addition & 1 deletion lib/langchain/llm/ai21.rb
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ module Langchain::LLM
# gem "ai21", "~> 0.2.1"
#
# Usage:
# ai21 = Langchain::LLM::AI21.new(api_key: ENV["AI21_API_KEY"])
# llm = Langchain::LLM::AI21.new(api_key: ENV["AI21_API_KEY"])
#
class AI21 < Base
DEFAULTS = {
Expand Down
2 changes: 1 addition & 1 deletion lib/langchain/llm/anthropic.rb
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ module Langchain::LLM
# gem "anthropic", "~> 0.3.2"
#
# Usage:
# anthropic = Langchain::LLM::Anthropic.new(api_key: ENV["ANTHROPIC_API_KEY"])
# llm = Langchain::LLM::Anthropic.new(api_key: ENV["ANTHROPIC_API_KEY"])
#
class Anthropic < Base
DEFAULTS = {
Expand Down
2 changes: 1 addition & 1 deletion lib/langchain/llm/aws_bedrock.rb
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ module Langchain::LLM
# gem 'aws-sdk-bedrockruntime', '~> 1.1'
#
# Usage:
# bedrock = Langchain::LLM::AwsBedrock.new(llm_options: {})
# llm = Langchain::LLM::AwsBedrock.new(llm_options: {})
#
class AwsBedrock < Base
DEFAULTS = {
Expand Down
2 changes: 1 addition & 1 deletion lib/langchain/llm/azure.rb
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ module Langchain::LLM
# gem "ruby-openai", "~> 6.3.0"
#
# Usage:
# openai = Langchain::LLM::Azure.new(api_key:, llm_options: {}, embedding_deployment_url: chat_deployment_url:)
# llm = Langchain::LLM::Azure.new(api_key:, llm_options: {}, embedding_deployment_url: chat_deployment_url:)
#
class Azure < OpenAI
attr_reader :embed_client
Expand Down
2 changes: 1 addition & 1 deletion lib/langchain/llm/hugging_face.rb
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ module Langchain::LLM
# gem "hugging-face", "~> 0.3.4"
#
# Usage:
# hf = Langchain::LLM::HuggingFace.new(api_key: ENV["HUGGING_FACE_API_KEY"])
# llm = Langchain::LLM::HuggingFace.new(api_key: ENV["HUGGING_FACE_API_KEY"])
#
class HuggingFace < Base
DEFAULTS = {
Expand Down
1 change: 0 additions & 1 deletion lib/langchain/llm/ollama.rb
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ module Langchain::LLM
# Available models: https://ollama.ai/library
#
# Usage:
# llm = Langchain::LLM::Ollama.new
# llm = Langchain::LLM::Ollama.new(url: ENV["OLLAMA_URL"], default_options: {})
#
class Ollama < Base
Expand Down
4 changes: 2 additions & 2 deletions lib/langchain/llm/openai.rb
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ module Langchain::LLM
# gem "ruby-openai", "~> 6.3.0"
#
# Usage:
# openai = Langchain::LLM::OpenAI.new(
# llm = Langchain::LLM::OpenAI.new(
# api_key: ENV["OPENAI_API_KEY"],
# llm_options: {}, # Available options: https://github.com/alexrudall/ruby-openai/blob/main/lib/openai/client.rb#L5-L13
# default_options: {}
Expand Down Expand Up @@ -100,7 +100,7 @@ def embed(
# @param params [Hash] The parameters to pass to the `chat()` method
# @return [Langchain::LLM::OpenAIResponse] Response object
def complete(prompt:, **params)
warn "DEPRECATED: `Langchain::LLM::OpenAI#complete` is deprecated, and will be removed in the next major version. Use `Langchain::LLM::OpenAI#chat` instead."
Langchain.logger.warn "DEPRECATED: `Langchain::LLM::OpenAI#complete` is deprecated, and will be removed in the next major version. Use `Langchain::LLM::OpenAI#chat` instead."

if params[:stop_sequences]
params[:stop] = params.delete(:stop_sequences)
Expand Down
1 change: 1 addition & 0 deletions lib/langchain/llm/parameters/chat.rb
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ class Chat < SimpleDelegator
# Function-calling
tools: {default: []},
tool_choice: {},
parallel_tool_calls: {},

# Additional optional parameters
logit_bias: {}
Expand Down
12 changes: 2 additions & 10 deletions lib/langchain/llm/replicate.rb
Original file line number Diff line number Diff line change
Expand Up @@ -7,16 +7,8 @@ module Langchain::LLM
# Gem requirements:
# gem "replicate-ruby", "~> 0.2.2"
#
# Use it directly:
# replicate = Langchain::LLM::Replicate.new(api_key: ENV["REPLICATE_API_KEY"])
#
# Or pass it to be used by a vector search DB:
# chroma = Langchain::Vectorsearch::Chroma.new(
# url: ENV["CHROMA_URL"],
# index_name: "...",
# llm: replicate
# )
#
# Usage:
# llm = Langchain::LLM::Replicate.new(api_key: ENV["REPLICATE_API_KEY"])
class Replicate < Base
DEFAULTS = {
# TODO: Figure out how to send the temperature to the API
Expand Down
Loading

0 comments on commit 49161bb

Please sign in to comment.