diff --git a/src/magentic/chat_model/mistral_chat_model.py b/src/magentic/chat_model/mistral_chat_model.py index cd9009b3..8b6a3a18 100644 --- a/src/magentic/chat_model/mistral_chat_model.py +++ b/src/magentic/chat_model/mistral_chat_model.py @@ -42,6 +42,11 @@ def _get_tool_choice( # type: ignore[override] """ return openai.NOT_GIVEN if allow_string_output else _MistralToolChoice.ANY.value + def _get_parallel_tool_calls( + self, *, tools_specified: bool, output_types: Iterable[type] + ) -> bool | openai.NotGiven: + return openai.NOT_GIVEN + R = TypeVar("R") diff --git a/src/magentic/chat_model/openai_chat_model.py b/src/magentic/chat_model/openai_chat_model.py index b366b712..5ea82e0f 100644 --- a/src/magentic/chat_model/openai_chat_model.py +++ b/src/magentic/chat_model/openai_chat_model.py @@ -435,6 +435,19 @@ def _get_tool_choice( return tool_schemas[0].as_tool_choice() return "required" + def _get_parallel_tool_calls( + self, *, tools_specified: bool, output_types: Iterable[type] + ) -> bool | openai.NotGiven: + if not tools_specified: # Enforced by OpenAI API + return openai.NOT_GIVEN + if self.api_type == "azure": + return openai.NOT_GIVEN + if is_any_origin_subclass(output_types, ParallelFunctionCall): + return openai.NOT_GIVEN + if is_any_origin_subclass(output_types, AsyncParallelFunctionCall): + return openai.NOT_GIVEN + return False + @overload def complete( self, @@ -496,6 +509,9 @@ def complete( tool_choice=self._get_tool_choice( tool_schemas=tool_schemas, allow_string_output=allow_string_output ), + parallel_tool_calls=self._get_parallel_tool_calls( + tools_specified=bool(tool_schemas), output_types=output_types + ), ) usage_ref, response = _create_usage_ref(response) @@ -607,6 +623,9 @@ async def acomplete( tool_choice=self._get_tool_choice( tool_schemas=tool_schemas, allow_string_output=allow_string_output ), + parallel_tool_calls=self._get_parallel_tool_calls( + tools_specified=bool(tool_schemas), output_types=output_types + ), ) usage_ref, response = _create_usage_ref_async(response)