Skip to content

Commit

Permalink
Support setting OpenAI params using environment variables (#12)
Browse files Browse the repository at this point in the history
  • Loading branch information
jackmpcollins authored Sep 10, 2023
1 parent 943019c commit a4ed8ab
Show file tree
Hide file tree
Showing 6 changed files with 87 additions and 7 deletions.
14 changes: 14 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -219,6 +219,20 @@ for hero in create_superhero_team("The Food Dudes"):
- The `Annotated` type annotation can be used to provide descriptions and other metadata for function parameters. See [the pydantic documentation on using `Field` to describe function arguments](https://docs.pydantic.dev/latest/usage/validation_decorator/#using-field-to-describe-function-arguments).
- The `@prompt` and `@prompt_chain` decorators also accept a `model` argument. You can pass an instance of `OpenaiChatModel` (from `magentic.chat_model.openai_chat_model`) to use GPT4 or configure a different temperature.

## Configuration

The order of precedence of configuration is

1. Arguments passed when initializing an instance in Python
2. Environment variables

The following environment variables can be set.

| Environment Variable | Description |
| --------------------------- | ------------------------- |
| MAGENTIC_OPENAI_MODEL | OpenAI model e.g. "gpt-4" |
| MAGENTIC_OPENAI_TEMPERATURE | OpenAI temperature, float |

## Type Checking

Many type checkers will raise warnings or errors for functions with the `@prompt` decorator due to the function having no body or return value. There are several ways to deal with these.
Expand Down
33 changes: 31 additions & 2 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ repository = "https://github.com/jackmpcollins/magentic"
python = ">=3.10"
openai = ">=0.27"
pydantic = ">=2.0.0"
pydantic-settings = ">=2.0.0"

[tool.poetry.group.dev.dependencies]
black = "*"
Expand Down
23 changes: 18 additions & 5 deletions src/magentic/chat_model/openai_chat_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
UserMessage,
)
from magentic.function_call import FunctionCall
from magentic.settings import get_settings
from magentic.streaming import (
AsyncStreamedStr,
StreamedStr,
Expand Down Expand Up @@ -435,10 +436,22 @@ async def openai_chatcompletion_acreate(


class OpenaiChatModel:
def __init__(self, model: str = "gpt-3.5-turbo-0613", temperature: float = 0):
def __init__(self, model: str | None = None, temperature: float | None = None):
self._model = model
self._temperature = temperature

@property
def model(self) -> str:
if self._model is not None:
return self._model
return get_settings().openai_model

@property
def temperature(self) -> float | None:
if self._temperature is not None:
return self._temperature
return get_settings().openai_temperature

def complete(
self,
messages: Iterable[Message[Any]],
Expand All @@ -463,9 +476,9 @@ def complete(

openai_functions = [schema.dict() for schema in function_schemas]
response = openai_chatcompletion_create(
model=self._model,
model=self.model,
messages=[message_to_openai_message(m) for m in messages],
temperature=self._temperature,
temperature=self.temperature,
functions=openai_functions,
function_call=(
{"name": openai_functions[0]["name"]}
Expand Down Expand Up @@ -536,9 +549,9 @@ async def acomplete(

openai_functions = [schema.dict() for schema in function_schemas]
response = await openai_chatcompletion_acreate(
model=self._model,
model=self.model,
messages=[message_to_openai_message(m) for m in messages],
temperature=self._temperature,
temperature=self.temperature,
functions=openai_functions,
function_call=(
{"name": openai_functions[0]["name"]}
Expand Down
12 changes: 12 additions & 0 deletions src/magentic/settings.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
from pydantic_settings import BaseSettings, SettingsConfigDict


class Settings(BaseSettings):
model_config = SettingsConfigDict(env_prefix="MAGENTIC_")

openai_model: str = "gpt-3.5-turbo"
openai_temperature: float | None = None


def get_settings() -> Settings:
return Settings()
11 changes: 11 additions & 0 deletions tests/test_openai_chat_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
DictFunctionSchema,
FunctionCallFunctionSchema,
IterableFunctionSchema,
OpenaiChatModel,
)
from magentic.function_call import FunctionCall
from magentic.streaming import async_iter
Expand Down Expand Up @@ -605,3 +606,13 @@ def test_function_call_function_schema_serialize_args(
):
serialized_args = FunctionCallFunctionSchema(function).serialize_args(args)
assert json.loads(serialized_args) == json.loads(expected_args_str)


def test_openai_chat_model_model(monkeypatch):
monkeypatch.setenv("MAGENTIC_OPENAI_MODEL", "gpt-4")
assert OpenaiChatModel().model == "gpt-4"


def test_openai_chat_model_temperature(monkeypatch):
monkeypatch.setenv("MAGENTIC_OPENAI_TEMPERATURE", "2")
assert OpenaiChatModel().temperature == 2

0 comments on commit a4ed8ab

Please sign in to comment.