diff --git a/CHANGELOG.md b/CHANGELOG.md
index dd65a498a..c69dc8afa 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -2,6 +2,7 @@
 - [BREAKING] Remove `Langchain::Assistant#clear_thread!` method
 - [BREAKING] `Langchain::Messages::*` namespace had migrated to `Langchain::Assistant::Messages::*`
 - [BREAKING] Modify `Langchain::LLM::AwsBedrock` constructor to pass model options via default_options: {...}
+- Introduce `Langchain::Assistant#parallel_tool_calls` option that specifies whether to allow the LLM to make multiple parallel tool calls. Default: true
 - Minor improvements to the Langchain::Assistant class
 - Added support for streaming with Anthropic
 - Bump anthropic gem
diff --git a/README.md b/README.md
index 89a5e436a..78f027683 100644
--- a/README.md
+++ b/README.md
@@ -534,6 +534,7 @@ Note that streaming is not currently supported for all LLMs.
 * `tools`: An array of tool instances (optional)
 * `instructions`: System instructions for the assistant (optional)
 * `tool_choice`: Specifies how tools should be selected. Default: "auto". A specific tool function name can be passed. This will force the Assistant to **always** use this function.
+* `parallel_tool_calls`: Whether to allow the LLM to make multiple parallel tool calls. Default: true
 * `add_message_callback`: A callback function (proc, lambda) that is called when any message is added to the conversation (optional)

 ### Key Methods
diff --git a/lib/langchain/assistant.rb b/lib/langchain/assistant.rb
index 9d0d020bf..6cf976e9f 100644
--- a/lib/langchain/assistant.rb
+++ b/lib/langchain/assistant.rb
@@ -12,9 +12,19 @@ module Langchain
   #     tools: [Langchain::Tool::NewsRetriever.new(api_key: ENV["NEWS_API_KEY"])]
   #   )
   class Assistant
-    attr_reader :llm, :instructions, :state, :llm_adapter, :tool_choice
-    attr_reader :total_prompt_tokens, :total_completion_tokens, :total_tokens, :messages
-    attr_accessor :tools, :add_message_callback
+    attr_reader :llm,
+      :instructions,
+      :state,
+      :llm_adapter,
+      :messages,
+      :tool_choice,
+      :total_prompt_tokens,
+      :total_completion_tokens,
+      :total_tokens
+
+    attr_accessor :tools,
+      :add_message_callback,
+      :parallel_tool_calls

     # Create a new assistant
     #
@@ -22,12 +32,15 @@ class Assistant
     # @param tools [Array] Tools that the assistant has access to
     # @param instructions [String] The system instructions
     # @param tool_choice [String] Specify how tools should be selected. Options: "auto", "any", "none", or
-    # @params add_message_callback [Proc] A callback function (Proc or lambda) that is called when any message is added to the conversation
+    # @param parallel_tool_calls [Boolean] Whether to allow the LLM to make multiple parallel tool calls. Default: true
+    # @param messages [Array] The messages
+    # @param add_message_callback [Proc] A callback function (Proc or lambda) that is called when any message is added to the conversation
     def initialize(
       llm:,
       tools: [],
       instructions: nil,
       tool_choice: "auto",
+      parallel_tool_calls: true,
       messages: [],
       add_message_callback: nil,
       &block
@@ -47,6 +60,7 @@ def initialize(
       self.messages = messages
       @tools = tools
+      @parallel_tool_calls = parallel_tool_calls
       self.tool_choice = tool_choice
       self.instructions = instructions
       @block = block

@@ -326,7 +340,8 @@ def chat_with_llm
         instructions: @instructions,
         messages: array_of_message_hashes,
         tools: @tools,
-        tool_choice: tool_choice
+        tool_choice: tool_choice,
+        parallel_tool_calls: parallel_tool_calls
       )
       @llm.chat(**params, &@block)
     end
diff --git a/lib/langchain/assistant/llm/adapters/anthropic.rb b/lib/langchain/assistant/llm/adapters/anthropic.rb
index 69e1f7630..c33c8cbab 100644
--- a/lib/langchain/assistant/llm/adapters/anthropic.rb
+++ b/lib/langchain/assistant/llm/adapters/anthropic.rb
@@ -7,16 +7,23 @@ module Adapters
       class Anthropic < Base
         # Build the chat parameters for the Anthropic API
         #
-        # @param tools [Array] The tools to use
-        # @param instructions [String] The system instructions
         # @param messages [Array] The messages
+        # @param instructions [String] The system instructions
+        # @param tools [Array] The tools to use
         # @param tool_choice [String] The tool choice
+        # @param parallel_tool_calls [Boolean] Whether to make parallel tool calls
         # @return [Hash] The chat parameters
-        def build_chat_params(tools:, instructions:, messages:, tool_choice:)
+        def build_chat_params(
+          messages:,
+          instructions:,
+          tools:,
+          tool_choice:,
+          parallel_tool_calls:
+        )
           params = {messages: messages}
           if tools.any?
             params[:tools] = build_tools(tools)
-            params[:tool_choice] = build_tool_choice(tool_choice)
+            params[:tool_choice] = build_tool_choice(tool_choice, parallel_tool_calls)
           end
           params[:system] = instructions if instructions
           params
@@ -31,7 +38,7 @@ def build_chat_params(tools:, instructions:, messages:, tool_choice:)
         # @param tool_call_id [String] The tool call ID
         # @return [Messages::AnthropicMessage] The Anthropic message
         def build_message(role:, content: nil, image_url: nil, tool_calls: [], tool_call_id: nil)
-          warn "Image URL is not supported by Anthropic currently" if image_url
+          Langchain.logger.warn "WARNING: Image URL is not supported by Anthropic currently" if image_url

           Messages::AnthropicMessage.new(role: role, content: content, tool_calls: tool_calls, tool_call_id: tool_call_id)
         end
@@ -76,15 +83,20 @@ def support_system_message?
         private

-        def build_tool_choice(choice)
+        def build_tool_choice(choice, parallel_tool_calls)
+          tool_choice_object = {disable_parallel_tool_use: !parallel_tool_calls}
+
           case choice
           when "auto"
-            {type: "auto"}
+            tool_choice_object[:type] = "auto"
           when "any"
-            {type: "any"}
+            tool_choice_object[:type] = "any"
           else
-            {type: "tool", name: choice}
+            tool_choice_object[:type] = "tool"
+            tool_choice_object[:name] = choice
           end
+
+          tool_choice_object
         end
       end
     end
diff --git a/lib/langchain/assistant/llm/adapters/base.rb b/lib/langchain/assistant/llm/adapters/base.rb
index f571c8319..d74fb1d7b 100644
--- a/lib/langchain/assistant/llm/adapters/base.rb
+++ b/lib/langchain/assistant/llm/adapters/base.rb
@@ -7,12 +7,19 @@ module Adapters
       class Base
         # Build the chat parameters for the LLM
         #
-        # @param tools [Array] The tools to use
-        # @param instructions [String] The system instructions
         # @param messages [Array] The messages
+        # @param instructions [String] The system instructions
+        # @param tools [Array] The tools to use
         # @param tool_choice [String] The tool choice
+        # @param parallel_tool_calls [Boolean] Whether to make parallel tool calls
         # @return [Hash] The chat parameters
-        def build_chat_params(tools:, instructions:, messages:, tool_choice:)
+        def build_chat_params(
+          messages:,
+          instructions:,
+          tools:,
+          tool_choice:,
+          parallel_tool_calls:
+        )
           raise NotImplementedError, "Subclasses must implement build_chat_params"
         end

diff --git a/lib/langchain/assistant/llm/adapters/google_gemini.rb b/lib/langchain/assistant/llm/adapters/google_gemini.rb
index 249203746..2bbb955d3 100644
--- a/lib/langchain/assistant/llm/adapters/google_gemini.rb
+++ b/lib/langchain/assistant/llm/adapters/google_gemini.rb
@@ -7,12 +7,21 @@ module Adapters
       class GoogleGemini < Base
         # Build the chat parameters for the Google Gemini LLM
         #
-        # @param tools [Array] The tools to use
-        # @param instructions [String] The system instructions
         # @param messages [Array] The messages
+        # @param instructions [String] The system instructions
+        # @param tools [Array] The tools to use
         # @param tool_choice [String] The tool choice
+        # @param parallel_tool_calls [Boolean] Whether to make parallel tool calls
         # @return [Hash] The chat parameters
-        def build_chat_params(tools:, instructions:, messages:, tool_choice:)
+        def build_chat_params(
+          messages:,
+          instructions:,
+          tools:,
+          tool_choice:,
+          parallel_tool_calls:
+        )
+          Langchain.logger.warn "WARNING: `parallel_tool_calls:` is not supported by Google Gemini currently"
+
           params = {messages: messages}
           if tools.any?
             params[:tools] = build_tools(tools)
@@ -31,7 +40,7 @@ def build_chat_params(tools:, instructions:, messages:, tool_choice:)
         # @param tool_call_id [String] The tool call ID
         # @return [Messages::GoogleGeminiMessage] The Google Gemini message
         def build_message(role:, content: nil, image_url: nil, tool_calls: [], tool_call_id: nil)
-          warn "Image URL is not supported by Google Gemini" if image_url
+          Langchain.logger.warn "WARNING: Image URL is not supported by Google Gemini" if image_url

           Messages::GoogleGeminiMessage.new(role: role, content: content, tool_calls: tool_calls, tool_call_id: tool_call_id)
         end
diff --git a/lib/langchain/assistant/llm/adapters/mistral_ai.rb b/lib/langchain/assistant/llm/adapters/mistral_ai.rb
index c421d2a57..a7cd65f28 100644
--- a/lib/langchain/assistant/llm/adapters/mistral_ai.rb
+++ b/lib/langchain/assistant/llm/adapters/mistral_ai.rb
@@ -7,12 +7,21 @@ module Adapters
       class MistralAI < Base
         # Build the chat parameters for the Mistral AI LLM
         #
-        # @param tools [Array] The tools to use
-        # @param instructions [String] The system instructions
         # @param messages [Array] The messages
+        # @param instructions [String] The system instructions
+        # @param tools [Array] The tools to use
         # @param tool_choice [String] The tool choice
+        # @param parallel_tool_calls [Boolean] Whether to make parallel tool calls
         # @return [Hash] The chat parameters
-        def build_chat_params(tools:, instructions:, messages:, tool_choice:)
+        def build_chat_params(
+          messages:,
+          instructions:,
+          tools:,
+          tool_choice:,
+          parallel_tool_calls:
+        )
+          Langchain.logger.warn "WARNING: `parallel_tool_calls:` is not supported by Mistral AI currently"
+
           params = {messages: messages}
           if tools.any?
             params[:tools] = build_tools(tools)
diff --git a/lib/langchain/assistant/llm/adapters/ollama.rb b/lib/langchain/assistant/llm/adapters/ollama.rb
index 7e6e4d4f7..fbc01ab93 100644
--- a/lib/langchain/assistant/llm/adapters/ollama.rb
+++ b/lib/langchain/assistant/llm/adapters/ollama.rb
@@ -7,12 +7,22 @@ module Adapters
       class Ollama < Base
         # Build the chat parameters for the Ollama LLM
         #
-        # @param tools [Array] The tools to use
-        # @param instructions [String] The system instructions
         # @param messages [Array] The messages
+        # @param instructions [String] The system instructions
+        # @param tools [Array] The tools to use
         # @param tool_choice [String] The tool choice
+        # @param parallel_tool_calls [Boolean] Whether to make parallel tool calls
         # @return [Hash] The chat parameters
-        def build_chat_params(tools:, instructions:, messages:, tool_choice:)
+        def build_chat_params(
+          messages:,
+          instructions:,
+          tools:,
+          tool_choice:,
+          parallel_tool_calls:
+        )
+          Langchain.logger.warn "WARNING: `parallel_tool_calls:` is not supported by Ollama currently"
+          Langchain.logger.warn "WARNING: `tool_choice:` is not supported by Ollama currently"
+
           params = {messages: messages}
           if tools.any?
             params[:tools] = build_tools(tools)
@@ -29,7 +39,7 @@ def build_chat_params(tools:, instructions:, messages:, tool_choice:)
         # @param tool_call_id [String] The tool call ID
         # @return [Messages::OllamaMessage] The Ollama message
         def build_message(role:, content: nil, image_url: nil, tool_calls: [], tool_call_id: nil)
-          warn "Image URL is not supported by Ollama currently" if image_url
+          Langchain.logger.warn "WARNING: Image URL is not supported by Ollama currently" if image_url

           Messages::OllamaMessage.new(role: role, content: content, tool_calls: tool_calls, tool_call_id: tool_call_id)
         end
diff --git a/lib/langchain/assistant/llm/adapters/openai.rb b/lib/langchain/assistant/llm/adapters/openai.rb
index 7c710c131..44fed0197 100644
--- a/lib/langchain/assistant/llm/adapters/openai.rb
+++ b/lib/langchain/assistant/llm/adapters/openai.rb
@@ -7,16 +7,24 @@ module Adapters
       class OpenAI < Base
         # Build the chat parameters for the OpenAI LLM
         #
-        # @param tools [Array] The tools to use
-        # @param instructions [String] The system instructions
         # @param messages [Array] The messages
+        # @param instructions [String] The system instructions
+        # @param tools [Array] The tools to use
         # @param tool_choice [String] The tool choice
+        # @param parallel_tool_calls [Boolean] Whether to make parallel tool calls
         # @return [Hash] The chat parameters
-        def build_chat_params(tools:, instructions:, messages:, tool_choice:)
+        def build_chat_params(
+          messages:,
+          instructions:,
+          tools:,
+          tool_choice:,
+          parallel_tool_calls:
+        )
           params = {messages: messages}
           if tools.any?
             params[:tools] = build_tools(tools)
             params[:tool_choice] = build_tool_choice(tool_choice)
+            params[:parallel_tool_calls] = parallel_tool_calls
           end
           params
         end
diff --git a/lib/langchain/llm/ai21.rb b/lib/langchain/llm/ai21.rb
index 107711a8e..58a2292d2 100644
--- a/lib/langchain/llm/ai21.rb
+++ b/lib/langchain/llm/ai21.rb
@@ -8,7 +8,7 @@ module Langchain::LLM
   #     gem "ai21", "~> 0.2.1"
   #
   # Usage:
-  #     ai21 = Langchain::LLM::AI21.new(api_key: ENV["AI21_API_KEY"])
+  #     llm = Langchain::LLM::AI21.new(api_key: ENV["AI21_API_KEY"])
   #
   class AI21 < Base
     DEFAULTS = {
diff --git a/lib/langchain/llm/anthropic.rb b/lib/langchain/llm/anthropic.rb
index feedc75d0..97f3b9f7c 100644
--- a/lib/langchain/llm/anthropic.rb
+++ b/lib/langchain/llm/anthropic.rb
@@ -8,7 +8,7 @@ module Langchain::LLM
   #     gem "anthropic", "~> 0.3.2"
   #
   # Usage:
-  #     anthropic = Langchain::LLM::Anthropic.new(api_key: ENV["ANTHROPIC_API_KEY"])
+  #     llm = Langchain::LLM::Anthropic.new(api_key: ENV["ANTHROPIC_API_KEY"])
   #
   class Anthropic < Base
     DEFAULTS = {
diff --git a/lib/langchain/llm/aws_bedrock.rb b/lib/langchain/llm/aws_bedrock.rb
index 2e9333287..64fdc954d 100644
--- a/lib/langchain/llm/aws_bedrock.rb
+++ b/lib/langchain/llm/aws_bedrock.rb
@@ -7,7 +7,7 @@ module Langchain::LLM
   #     gem 'aws-sdk-bedrockruntime', '~> 1.1'
   #
   # Usage:
-  #     bedrock = Langchain::LLM::AwsBedrock.new(llm_options: {})
+  #     llm = Langchain::LLM::AwsBedrock.new(llm_options: {})
   #
   class AwsBedrock < Base
     DEFAULTS = {
diff --git a/lib/langchain/llm/azure.rb b/lib/langchain/llm/azure.rb
index 4c79f27e4..d8f022b9b 100644
--- a/lib/langchain/llm/azure.rb
+++ b/lib/langchain/llm/azure.rb
@@ -7,7 +7,7 @@ module Langchain::LLM
   #     gem "ruby-openai", "~> 6.3.0"
   #
   # Usage:
-  #     openai = Langchain::LLM::Azure.new(api_key:, llm_options: {}, embedding_deployment_url: chat_deployment_url:)
+  #     llm = Langchain::LLM::Azure.new(api_key:, llm_options: {}, embedding_deployment_url:, chat_deployment_url:)
   #
   class Azure < OpenAI
     attr_reader :embed_client
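A minimal sketch of how the new flag maps onto each provider, based on the adapter changes above. This assumes the langchainrb gem with this patch applied and the `eqn` gem that backs the Calculator tool; the message and tool values are illustrative only.

    require "langchain"

    args = {
      messages: [{role: "user", content: "Hello"}],
      instructions: nil,
      tools: [Langchain::Tool::Calculator.new],
      tool_choice: "auto",
      parallel_tool_calls: false
    }

    # OpenAI forwards the flag as a top-level request parameter.
    Langchain::Assistant::LLM::Adapters::OpenAI.new.build_chat_params(**args)[:parallel_tool_calls]
    # => false

    # Anthropic folds it into the tool_choice object instead.
    Langchain::Assistant::LLM::Adapters::Anthropic.new.build_chat_params(**args)[:tool_choice]
    # => {disable_parallel_tool_use: true, type: "auto"}

    # Google Gemini, Mistral AI and Ollama log a WARNING and drop the flag.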
diff --git a/lib/langchain/llm/hugging_face.rb b/lib/langchain/llm/hugging_face.rb
index aa1eb4c1c..021149de5 100644
--- a/lib/langchain/llm/hugging_face.rb
+++ b/lib/langchain/llm/hugging_face.rb
@@ -8,7 +8,7 @@ module Langchain::LLM
   #     gem "hugging-face", "~> 0.3.4"
   #
   # Usage:
-  #     hf = Langchain::LLM::HuggingFace.new(api_key: ENV["HUGGING_FACE_API_KEY"])
+  #     llm = Langchain::LLM::HuggingFace.new(api_key: ENV["HUGGING_FACE_API_KEY"])
   #
   class HuggingFace < Base
     DEFAULTS = {
diff --git a/lib/langchain/llm/ollama.rb b/lib/langchain/llm/ollama.rb
index fa7a04178..31d0bbd13 100644
--- a/lib/langchain/llm/ollama.rb
+++ b/lib/langchain/llm/ollama.rb
@@ -5,7 +5,6 @@ module Langchain::LLM
   # Available models: https://ollama.ai/library
   #
   # Usage:
-  #     llm = Langchain::LLM::Ollama.new
   #     llm = Langchain::LLM::Ollama.new(url: ENV["OLLAMA_URL"], default_options: {})
   #
   class Ollama < Base
diff --git a/lib/langchain/llm/openai.rb b/lib/langchain/llm/openai.rb
index 591b36c6e..771a85a84 100644
--- a/lib/langchain/llm/openai.rb
+++ b/lib/langchain/llm/openai.rb
@@ -7,7 +7,7 @@ module Langchain::LLM
   #     gem "ruby-openai", "~> 6.3.0"
   #
   # Usage:
-  #     openai = Langchain::LLM::OpenAI.new(
+  #     llm = Langchain::LLM::OpenAI.new(
   #       api_key: ENV["OPENAI_API_KEY"],
   #       llm_options: {}, # Available options: https://github.com/alexrudall/ruby-openai/blob/main/lib/openai/client.rb#L5-L13
   #       default_options: {}
@@ -100,7 +100,7 @@ def embed(
   # @param params [Hash] The parameters to pass to the `chat()` method
   # @return [Langchain::LLM::OpenAIResponse] Response object
   def complete(prompt:, **params)
-    warn "DEPRECATED: `Langchain::LLM::OpenAI#complete` is deprecated, and will be removed in the next major version. Use `Langchain::LLM::OpenAI#chat` instead."
+    Langchain.logger.warn "DEPRECATED: `Langchain::LLM::OpenAI#complete` is deprecated, and will be removed in the next major version. Use `Langchain::LLM::OpenAI#chat` instead."

     if params[:stop_sequences]
       params[:stop] = params.delete(:stop_sequences)
diff --git a/lib/langchain/llm/parameters/chat.rb b/lib/langchain/llm/parameters/chat.rb
index 299c41cca..21e963057 100644
--- a/lib/langchain/llm/parameters/chat.rb
+++ b/lib/langchain/llm/parameters/chat.rb
@@ -34,6 +34,7 @@ class Chat < SimpleDelegator
         # Function-calling
         tools: {default: []},
         tool_choice: {},
+        parallel_tool_calls: {},

         # Additional optional parameters
         logit_bias: {}
diff --git a/lib/langchain/llm/replicate.rb b/lib/langchain/llm/replicate.rb
index 7fb682726..7133d0fba 100644
--- a/lib/langchain/llm/replicate.rb
+++ b/lib/langchain/llm/replicate.rb
@@ -7,16 +7,8 @@ module Langchain::LLM
   # Gem requirements:
   #     gem "replicate-ruby", "~> 0.2.2"
   #
-  # Use it directly:
-  #     replicate = Langchain::LLM::Replicate.new(api_key: ENV["REPLICATE_API_KEY"])
-  #
-  # Or pass it to be used by a vector search DB:
-  #     chroma = Langchain::Vectorsearch::Chroma.new(
-  #       url: ENV["CHROMA_URL"],
-  #       index_name: "...",
-  #       llm: replicate
-  #     )
-  #
+  # Usage:
+  #     llm = Langchain::LLM::Replicate.new(api_key: ENV["REPLICATE_API_KEY"])
   class Replicate < Base
     DEFAULTS = {
       # TODO: Figure out how to send the temperature to the API
diff --git a/spec/langchain/assistant/assistant_spec.rb b/spec/langchain/assistant/assistant_spec.rb
index 2c6844df1..43ac054b3 100644
--- a/spec/langchain/assistant/assistant_spec.rb
+++ b/spec/langchain/assistant/assistant_spec.rb
@@ -28,6 +28,10 @@
     it "raises an error if messages array contains non-Langchain::Message instance(s)" do
       expect { described_class.new(llm: llm, messages: [Langchain::Assistant::Messages::OpenAIMessage.new, "foo"]) }.to raise_error(ArgumentError)
     end
+
+    it "parallel_tool_calls defaults to true" do
+      expect(described_class.new(llm: llm).parallel_tool_calls).to eq(true)
+    end
   end

   context "methods" do
@@ -228,7 +232,8 @@
             {role: "user", content: [{type: "text", text: "Please calculate 2+2"}]}
           ],
           tools: calculator.class.function_schemas.to_openai_format,
-          tool_choice: "auto"
+          tool_choice: "auto",
+          parallel_tool_calls: true
         )
         .and_return(Langchain::LLM::OpenAIResponse.new(raw_openai_response))

@@ -280,7 +285,8 @@
             {content: [{type: "text", text: "4.0"}], role: "tool", tool_call_id: "call_9TewGANaaIjzY31UCpAAGLeV"}
           ],
           tools: calculator.class.function_schemas.to_openai_format,
-          tool_choice: "auto"
+          tool_choice: "auto",
+          parallel_tool_calls: true
         )
         .and_return(Langchain::LLM::OpenAIResponse.new(raw_openai_response2))

@@ -1135,7 +1141,7 @@
         .with(
           messages: [{role: "user", content: "Please calculate 2+2"}],
           tools: calculator.class.function_schemas.to_anthropic_format,
-          tool_choice: {type: "auto"},
+          tool_choice: {disable_parallel_tool_use: false, type: "auto"},
           system: instructions
         )
         .and_return(Langchain::LLM::AnthropicResponse.new(raw_anthropic_response))
@@ -1190,7 +1196,7 @@
             {role: "user", content: [{type: "tool_result", tool_use_id: "toolu_014eSx9oBA5DMe8gZqaqcJ3H", content: "4.0"}]}
           ],
           tools: calculator.class.function_schemas.to_anthropic_format,
-          tool_choice: {type: "auto"},
+          tool_choice: {disable_parallel_tool_use: false, type: "auto"},
           system: instructions
         )
         .and_return(Langchain::LLM::AnthropicResponse.new(raw_anthropic_response2))
diff --git a/spec/langchain/assistant/llm/adapters/anthropic_spec.rb b/spec/langchain/assistant/llm/adapters/anthropic_spec.rb
index 1ab6e5215..e52121abe 100644
--- a/spec/langchain/assistant/llm/adapters/anthropic_spec.rb
+++ b/spec/langchain/assistant/llm/adapters/anthropic_spec.rb
@@ -1,17 +1,46 @@
 # frozen_string_literal: true

 RSpec.describe Langchain::Assistant::LLM::Adapters::Anthropic do
-  let(:adapter) { described_class.new }
+  subject { described_class.new }
+
+  describe "#build_chat_params" do
+    it "returns the chat parameters" do
+      expect(
+        subject.build_chat_params(
+          messages: [{role: "user", content: "Hello"}],
+          instructions: "Instructions",
+          tools: [Langchain::Tool::Calculator.new],
+          tool_choice: "langchain_tool_calculator__execute",
+          parallel_tool_calls: false
+        )
+      ).to eq({
+        messages: [{role: "user", content: "Hello"}],
+        tools: Langchain::Tool::Calculator.function_schemas.to_anthropic_format,
+        tool_choice: {disable_parallel_tool_use: true, name: "langchain_tool_calculator__execute", type: "tool"},
+        system: "Instructions"
+      })
+    end
+  end

   describe "#support_system_message?" do
-    it "returns true" do
-      expect(adapter.support_system_message?).to eq(false)
+    it "returns false" do
+      expect(subject.support_system_message?).to eq(false)
     end
   end

   describe "#tool_role" do
-    it "returns 'tool'" do
-      expect(adapter.tool_role).to eq("tool_result")
+    it "returns 'tool_result'" do
+      expect(subject.tool_role).to eq("tool_result")
+    end
+  end
+
+  describe "#build_tool_choice" do
+    it "returns the tool choice object with 'auto'" do
+      expect(subject.send(:build_tool_choice, "auto", true)).to eq({disable_parallel_tool_use: false, type: "auto"})
+    end
+
+    it "returns the tool choice object with selected tool function" do
+      expect(subject.send(:build_tool_choice, "langchain_tool_calculator__execute", false)).to eq({disable_parallel_tool_use: true, type: "tool", name: "langchain_tool_calculator__execute"})
     end
   end
 end
diff --git a/spec/langchain/assistant/llm/adapters/base_spec.rb b/spec/langchain/assistant/llm/adapters/base_spec.rb
index 8c568310b..5f436586c 100644
--- a/spec/langchain/assistant/llm/adapters/base_spec.rb
+++ b/spec/langchain/assistant/llm/adapters/base_spec.rb
@@ -1,35 +1,33 @@
 # frozen_string_literal: true

 RSpec.describe Langchain::Assistant::LLM::Adapters::Base do
-  let(:adapter) { described_class.new }
-
   describe "#build_chat_params" do
     it "raises NotImplementedError" do
-      expect { adapter.build_chat_params(tools: [], instructions: "", messages: [], tool_choice: "") }.to raise_error(NotImplementedError)
+      expect { subject.build_chat_params(tools: [], instructions: "", messages: [], tool_choice: "", parallel_tool_calls: false) }.to raise_error(NotImplementedError)
     end
   end

   describe "#extract_tool_call_args" do
     it "raises NotImplementedError" do
-      expect { adapter.extract_tool_call_args(tool_call: {}) }.to raise_error(NotImplementedError)
+      expect { subject.extract_tool_call_args(tool_call: {}) }.to raise_error(NotImplementedError)
     end
   end

   describe "#build_message" do
     it "raises NotImplementedError" do
-      expect { adapter.build_message(role: "", content: "", image_url: "", tool_calls: [], tool_call_id: "") }.to raise_error(NotImplementedError)
+      expect { subject.build_message(role: "", content: "", image_url: "", tool_calls: [], tool_call_id: "") }.to raise_error(NotImplementedError)
     end
   end

   describe "#support_system_message?" do
     it "raises NotImplementedError" do
-      expect { adapter.support_system_message? }.to raise_error(NotImplementedError)
+      expect { subject.support_system_message? }.to raise_error(NotImplementedError)
     end
   end

   describe "#tool_role" do
     it "raises NotImplementedError" do
-      expect { adapter.tool_role }.to raise_error(NotImplementedError)
+      expect { subject.tool_role }.to raise_error(NotImplementedError)
     end
   end
 end
diff --git a/spec/langchain/assistant/llm/adapters/google_gemini_spec.rb b/spec/langchain/assistant/llm/adapters/google_gemini_spec.rb
index 65dc1efcb..09b0d97e8 100644
--- a/spec/langchain/assistant/llm/adapters/google_gemini_spec.rb
+++ b/spec/langchain/assistant/llm/adapters/google_gemini_spec.rb
@@ -1,17 +1,34 @@
 # frozen_string_literal: true

 RSpec.describe Langchain::Assistant::LLM::Adapters::GoogleGemini do
-  let(:adapter) { described_class.new }
+  describe "#build_chat_params" do
+    it "returns the chat parameters" do
+      expect(
+        subject.build_chat_params(
+          messages: [{role: "user", content: "Hello"}],
+          instructions: "Instructions",
+          tools: [Langchain::Tool::Calculator.new],
+          tool_choice: "langchain_tool_calculator__execute",
+          parallel_tool_calls: false
+        )
+      ).to eq({
+        messages: [{role: "user", content: "Hello"}],
+        tools: Langchain::Tool::Calculator.function_schemas.to_google_gemini_format,
+        tool_choice: {function_calling_config: {allowed_function_names: ["langchain_tool_calculator__execute"], mode: "any"}},
+        system: "Instructions"
+      })
+    end
+  end

   describe "#support_system_message?" do
-    it "returns true" do
-      expect(adapter.support_system_message?).to eq(false)
+    it "returns false" do
+      expect(subject.support_system_message?).to eq(false)
     end
   end

   describe "#tool_role" do
-    it "returns 'tool'" do
-      expect(adapter.tool_role).to eq("function")
+    it "returns 'function'" do
+      expect(subject.tool_role).to eq("function")
     end
   end
 end
diff --git a/spec/langchain/assistant/llm/adapters/mistral_ai_spec.rb b/spec/langchain/assistant/llm/adapters/mistral_ai_spec.rb
index b449c9612..31ae258d7 100644
--- a/spec/langchain/assistant/llm/adapters/mistral_ai_spec.rb
+++ b/spec/langchain/assistant/llm/adapters/mistral_ai_spec.rb
@@ -1,17 +1,33 @@
 # frozen_string_literal: true

 RSpec.describe Langchain::Assistant::LLM::Adapters::MistralAI do
-  let(:adapter) { described_class.new }
+  describe "#build_chat_params" do
+    it "returns the chat parameters" do
+      expect(
+        subject.build_chat_params(
+          messages: [{role: "user", content: "Hello"}],
+          instructions: "Instructions",
+          tools: [Langchain::Tool::Calculator.new],
+          tool_choice: "langchain_tool_calculator__execute",
+          parallel_tool_calls: false
+        )
+      ).to eq({
+        messages: [{role: "user", content: "Hello"}],
+        tools: Langchain::Tool::Calculator.function_schemas.to_openai_format,
+        tool_choice: {"function" => {"name" => "langchain_tool_calculator__execute"}, "type" => "function"}
+      })
+    end
+  end

   describe "#support_system_message?" do
     it "returns true" do
-      expect(adapter.support_system_message?).to eq(true)
+      expect(subject.support_system_message?).to eq(true)
     end
   end

   describe "#tool_role" do
     it "returns 'tool'" do
-      expect(adapter.tool_role).to eq("tool")
+      expect(subject.tool_role).to eq("tool")
     end
   end
 end
diff --git a/spec/langchain/assistant/llm/adapters/ollama_spec.rb b/spec/langchain/assistant/llm/adapters/ollama_spec.rb
index a317a568c..82ee2ec28 100644
--- a/spec/langchain/assistant/llm/adapters/ollama_spec.rb
+++ b/spec/langchain/assistant/llm/adapters/ollama_spec.rb
@@ -1,17 +1,32 @@
 # frozen_string_literal: true

 RSpec.describe Langchain::Assistant::LLM::Adapters::Ollama do
-  let(:adapter) { described_class.new }
+  describe "#build_chat_params" do
+    it "returns the chat parameters" do
+      expect(
+        subject.build_chat_params(
+          messages: [{role: "user", content: "Hello"}],
+          instructions: "Instructions",
+          tools: [Langchain::Tool::Calculator.new],
+          tool_choice: "langchain_tool_calculator__execute",
+          parallel_tool_calls: false
+        )
+      ).to eq({
+        messages: [{role: "user", content: "Hello"}],
+        tools: Langchain::Tool::Calculator.function_schemas.to_openai_format
+      })
+    end
+  end

   describe "#support_system_message?" do
     it "returns true" do
-      expect(adapter.support_system_message?).to eq(true)
+      expect(subject.support_system_message?).to eq(true)
     end
   end

   describe "#tool_role" do
     it "returns 'tool'" do
-      expect(adapter.tool_role).to eq("tool")
+      expect(subject.tool_role).to eq("tool")
     end
   end
 end
diff --git a/spec/langchain/assistant/llm/adapters/openai_spec.rb b/spec/langchain/assistant/llm/adapters/openai_spec.rb
index 103b2bafc..c56c14d57 100644
--- a/spec/langchain/assistant/llm/adapters/openai_spec.rb
+++ b/spec/langchain/assistant/llm/adapters/openai_spec.rb
@@ -1,17 +1,34 @@
 # frozen_string_literal: true

 RSpec.describe Langchain::Assistant::LLM::Adapters::OpenAI do
-  let(:adapter) { described_class.new }
+  describe "#build_chat_params" do
+    it "returns the chat parameters" do
+      expect(
+        subject.build_chat_params(
+          messages: [{role: "user", content: "Hello"}],
+          instructions: "Instructions",
+          tools: [Langchain::Tool::Calculator.new],
+          tool_choice: "langchain_tool_calculator__execute",
+          parallel_tool_calls: false
+        )
+      ).to eq({
+        messages: [{role: "user", content: "Hello"}],
+        tools: Langchain::Tool::Calculator.function_schemas.to_openai_format,
+        tool_choice: {"function" => {"name" => "langchain_tool_calculator__execute"}, "type" => "function"},
+        parallel_tool_calls: false
+      })
+    end
+  end

   describe "#support_system_message?" do
     it "returns true" do
-      expect(adapter.support_system_message?).to eq(true)
+      expect(subject.support_system_message?).to eq(true)
     end
   end

   describe "#tool_role" do
     it "returns 'tool'" do
-      expect(adapter.tool_role).to eq("tool")
+      expect(subject.tool_role).to eq("tool")
     end
   end
 end