Skip to content

Commit

Permalink
Fix llms (#2003)
Browse files Browse the repository at this point in the history
* iwp

* add in api_base

---------

Co-authored-by: Lorenze Jay <[email protected]>
  • Loading branch information
bhancockio and lorenzejay authored Jan 30, 2025
1 parent 7bed63a commit 477cce3
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 3 deletions.
5 changes: 4 additions & 1 deletion src/crewai/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,7 @@ def __init__(
logprobs: Optional[int] = None,
top_logprobs: Optional[int] = None,
base_url: Optional[str] = None,
api_base: Optional[str] = None,
api_version: Optional[str] = None,
api_key: Optional[str] = None,
callbacks: List[Any] = [],
Expand All @@ -152,6 +153,7 @@ def __init__(
self.logprobs = logprobs
self.top_logprobs = top_logprobs
self.base_url = base_url
self.api_base = api_base
self.api_version = api_version
self.api_key = api_key
self.callbacks = callbacks
Expand Down Expand Up @@ -232,7 +234,8 @@ def call(
"seed": self.seed,
"logprobs": self.logprobs,
"top_logprobs": self.top_logprobs,
"api_base": self.base_url,
"api_base": self.api_base,
"base_url": self.base_url,
"api_version": self.api_version,
"api_key": self.api_key,
"stream": False,
Expand Down
17 changes: 15 additions & 2 deletions src/crewai/utilities/llm_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ def create_llm(
timeout: Optional[float] = getattr(llm_value, "timeout", None)
api_key: Optional[str] = getattr(llm_value, "api_key", None)
base_url: Optional[str] = getattr(llm_value, "base_url", None)
api_base: Optional[str] = getattr(llm_value, "api_base", None)

created_llm = LLM(
model=model,
Expand All @@ -62,6 +63,7 @@ def create_llm(
timeout=timeout,
api_key=api_key,
base_url=base_url,
api_base=api_base,
)
return created_llm
except Exception as e:
Expand Down Expand Up @@ -101,8 +103,18 @@ def _llm_via_environment_or_fallback() -> Optional[LLM]:
callbacks: List[Any] = []

# Optional base URL from env
api_base = os.environ.get("OPENAI_API_BASE") or os.environ.get("OPENAI_BASE_URL")
if api_base:
base_url = (
os.environ.get("BASE_URL")
or os.environ.get("OPENAI_API_BASE")
or os.environ.get("OPENAI_BASE_URL")
)

api_base = os.environ.get("API_BASE") or os.environ.get("AZURE_API_BASE")

# Synchronize base_url and api_base if one is populated and the other is not
if base_url and not api_base:
api_base = base_url
elif api_base and not base_url:
base_url = api_base

# Initialize llm_params dictionary
Expand All @@ -115,6 +127,7 @@ def _llm_via_environment_or_fallback() -> Optional[LLM]:
"timeout": timeout,
"api_key": api_key,
"base_url": base_url,
"api_base": api_base,
"api_version": api_version,
"presence_penalty": presence_penalty,
"frequency_penalty": frequency_penalty,
Expand Down

0 comments on commit 477cce3

Please sign in to comment.