diff --git a/README.md b/README.md index 5fc419c6..16391cfd 100644 --- a/README.md +++ b/README.md @@ -219,6 +219,20 @@ for hero in create_superhero_team("The Food Dudes"): - The `Annotated` type annotation can be used to provide descriptions and other metadata for function parameters. See [the pydantic documentation on using `Field` to describe function arguments](https://docs.pydantic.dev/latest/usage/validation_decorator/#using-field-to-describe-function-arguments). - The `@prompt` and `@prompt_chain` decorators also accept a `model` argument. You can pass an instance of `OpenaiChatModel` (from `magentic.chat_model.openai_chat_model`) to use GPT4 or configure a different temperature. +## Configuration + +The order of precedence of configuration is + +1. Arguments passed when initializing an instance in Python +2. Environment variables + +The following environment variables can be set. + +| Environment Variable | Description | +| --------------------------- | ------------------------- | +| MAGENTIC_OPENAI_MODEL | OpenAI model e.g. "gpt-4" | +| MAGENTIC_OPENAI_TEMPERATURE | OpenAI temperature, float | + ## Type Checking Many type checkers will raise warnings or errors for functions with the `@prompt` decorator due to the function having no body or return value. There are several ways to deal with these. diff --git a/poetry.lock b/poetry.lock index 9507c52c..0af8d5a5 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 1.5.1 and should not be changed by hand. +# This file is automatically @generated by Poetry 1.6.1 and should not be changed by hand. [[package]] name = "aiohttp" @@ -2135,6 +2135,21 @@ files = [ [package.dependencies] typing-extensions = ">=4.6.0,<4.7.0 || >4.7.0" +[[package]] +name = "pydantic-settings" +version = "2.0.3" +description = "Settings management using Pydantic" +optional = false +python-versions = ">=3.7" +files = [ + {file = "pydantic_settings-2.0.3-py3-none-any.whl", hash = "sha256:ddd907b066622bd67603b75e2ff791875540dc485b7307c4fffc015719da8625"}, + {file = "pydantic_settings-2.0.3.tar.gz", hash = "sha256:962dc3672495aad6ae96a4390fac7e593591e144625e5112d359f8f67fb75945"}, +] + +[package.dependencies] +pydantic = ">=2.0.1" +python-dotenv = ">=0.21.0" + [[package]] name = "pygments" version = "2.15.1" @@ -2236,6 +2251,20 @@ files = [ [package.dependencies] six = ">=1.5" +[[package]] +name = "python-dotenv" +version = "1.0.0" +description = "Read key-value pairs from a .env file and set them as environment variables" +optional = false +python-versions = ">=3.8" +files = [ + {file = "python-dotenv-1.0.0.tar.gz", hash = "sha256:a8df96034aae6d2d50a4ebe8216326c61c3eb64836776504fcca410e5937a3ba"}, + {file = "python_dotenv-1.0.0-py3-none-any.whl", hash = "sha256:f5971a9226b701070a4bf2c38c89e5a3f0d64de8debda981d1db98583009122a"}, +] + +[package.extras] +cli = ["click (>=5.0)"] + [[package]] name = "python-json-logger" version = "2.0.7" @@ -3032,4 +3061,4 @@ multidict = ">=4.0" [metadata] lock-version = "2.0" python-versions = ">=3.10" -content-hash = "6ad254b5ba4fe3cae7376c55865874f824d466170007a720e6a13279733987a0" +content-hash = "181195948efad6794679a9c5e61a9a5b0cb7908304f3905797e55853a64163b8" diff --git a/pyproject.toml b/pyproject.toml index 10dedf9f..de8bbded 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -27,6 +27,7 @@ repository = "https://github.com/jackmpcollins/magentic" python = ">=3.10" openai = ">=0.27" pydantic = ">=2.0.0" +pydantic-settings = ">=2.0.0" [tool.poetry.group.dev.dependencies] black = "*" diff --git a/src/magentic/chat_model/openai_chat_model.py b/src/magentic/chat_model/openai_chat_model.py index fa4a228b..1d8c057e 100644 --- a/src/magentic/chat_model/openai_chat_model.py +++ b/src/magentic/chat_model/openai_chat_model.py @@ -17,6 +17,7 @@ UserMessage, ) from magentic.function_call import FunctionCall +from magentic.settings import get_settings from magentic.streaming import ( AsyncStreamedStr, StreamedStr, @@ -435,10 +436,22 @@ async def openai_chatcompletion_acreate( class OpenaiChatModel: - def __init__(self, model: str = "gpt-3.5-turbo-0613", temperature: float = 0): + def __init__(self, model: str | None = None, temperature: float | None = None): self._model = model self._temperature = temperature + @property + def model(self) -> str: + if self._model is not None: + return self._model + return get_settings().openai_model + + @property + def temperature(self) -> float | None: + if self._temperature is not None: + return self._temperature + return get_settings().openai_temperature + def complete( self, messages: Iterable[Message[Any]], @@ -463,9 +476,9 @@ def complete( openai_functions = [schema.dict() for schema in function_schemas] response = openai_chatcompletion_create( - model=self._model, + model=self.model, messages=[message_to_openai_message(m) for m in messages], - temperature=self._temperature, + temperature=self.temperature, functions=openai_functions, function_call=( {"name": openai_functions[0]["name"]} @@ -536,9 +549,9 @@ async def acomplete( openai_functions = [schema.dict() for schema in function_schemas] response = await openai_chatcompletion_acreate( - model=self._model, + model=self.model, messages=[message_to_openai_message(m) for m in messages], - temperature=self._temperature, + temperature=self.temperature, functions=openai_functions, function_call=( {"name": openai_functions[0]["name"]} diff --git a/src/magentic/settings.py b/src/magentic/settings.py new file mode 100644 index 00000000..9970d5b1 --- /dev/null +++ b/src/magentic/settings.py @@ -0,0 +1,12 @@ +from pydantic_settings import BaseSettings, SettingsConfigDict + + +class Settings(BaseSettings): + model_config = SettingsConfigDict(env_prefix="MAGENTIC_") + + openai_model: str = "gpt-3.5-turbo" + openai_temperature: float | None = None + + +def get_settings() -> Settings: + return Settings() diff --git a/tests/test_openai_chat_model.py b/tests/test_openai_chat_model.py index f8f1f38d..ce712a5b 100644 --- a/tests/test_openai_chat_model.py +++ b/tests/test_openai_chat_model.py @@ -14,6 +14,7 @@ DictFunctionSchema, FunctionCallFunctionSchema, IterableFunctionSchema, + OpenaiChatModel, ) from magentic.function_call import FunctionCall from magentic.streaming import async_iter @@ -605,3 +606,13 @@ def test_function_call_function_schema_serialize_args( ): serialized_args = FunctionCallFunctionSchema(function).serialize_args(args) assert json.loads(serialized_args) == json.loads(expected_args_str) + + +def test_openai_chat_model_model(monkeypatch): + monkeypatch.setenv("MAGENTIC_OPENAI_MODEL", "gpt-4") + assert OpenaiChatModel().model == "gpt-4" + + +def test_openai_chat_model_temperature(monkeypatch): + monkeypatch.setenv("MAGENTIC_OPENAI_TEMPERATURE", "2") + assert OpenaiChatModel().temperature == 2