Add key validation on all models (#1578)
## Description

There have been various issues raised about incorrect errors when API
keys are not set. This change addresses that for all models.

Fixes # (issue)

## Type of change

Please check the options that are relevant:

- [x] Bug fix (non-breaking change which fixes an issue)
- [ ] New feature (non-breaking change which adds functionality)
- [ ] Breaking change (fix or feature that would cause existing
functionality to not work as expected)
- [ ] Model update
- [ ] Infrastructure change

## Checklist

- [x] My code follows Phidata's style guidelines and best practices
- [x] I have performed a self-review of my code
- [x] I have added docstrings and comments for complex logic
- [x] My changes generate no new warnings or errors
- [x] I have added cookbook examples for my new addition (if needed)
- [x] I have updated requirements.txt/pyproject.toml (if needed)
- [x] I have verified my changes in a clean environment

---------

Co-authored-by: Dirk Brand <[email protected]>
dirkbrnd and dirkvolter authored Dec 16, 2024
1 parent b27c973 commit 6d4b9db
Showing 18 changed files with 76 additions and 38 deletions.
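
Every model applies the same pattern in its client setup: resolve the key from the explicit argument or the provider's environment variable, log a clear error if it is still missing, and only then assemble the client parameters; missing SDKs are likewise surfaced as an `ImportError` that carries the install hint. A minimal sketch of the key-validation half of that pattern (the `ExampleModel` class, the `EXAMPLE_API_KEY` name, and the standard-library logger are placeholders, not code from this commit):

```python
import logging
from os import getenv
from typing import Any, Dict, Optional

# Standard logging stands in here for phi.utils.log.logger used by the real modules.
logger = logging.getLogger(__name__)


class ExampleModel:
    """Illustrative only: the key-validation pattern this commit adds to each model."""

    def __init__(self, api_key: Optional[str] = None):
        self.api_key = api_key

    def get_client_params(self) -> Dict[str, Any]:
        # Fall back to the provider's environment variable, then log a clear error
        # if the key is still missing instead of failing later with a confusing message.
        self.api_key = self.api_key or getenv("EXAMPLE_API_KEY")
        if not self.api_key:
            logger.error("EXAMPLE_API_KEY not set. Please set the EXAMPLE_API_KEY environment variable.")

        client_params: Dict[str, Any] = {}
        if self.api_key:
            client_params["api_key"] = self.api_key
        return client_params
```

The diffs below apply this with each provider's variable name (`OPENAI_API_KEY`, `ANTHROPIC_API_KEY`, `CO_API_KEY`, `GOOGLE_API_KEY`, `GROQ_API_KEY`, `HF_TOKEN`, `MISTRAL_API_KEY`).
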
6 changes: 6 additions & 0 deletions phi/llm/openai/chat.py
@@ -1,3 +1,5 @@
+import os
+
import httpx
from typing import Optional, List, Iterator, Dict, Any, Union, Tuple

@@ -73,6 +75,10 @@ def get_client(self) -> OpenAIClient:
        if self.openai_client:
            return self.openai_client

+        self.api_key = self.api_key or os.getenv("OPENAI_API_KEY")
+        if not self.api_key:
+            logger.error("OPENAI_API_KEY not set. Please set the OPENAI_API_KEY environment variable.")
+
        _client_params: Dict[str, Any] = {}
        if self.api_key:
            _client_params["api_key"] = self.api_key
3 changes: 2 additions & 1 deletion phi/model/InternLM/internlm.py
@@ -1,5 +1,6 @@
from os import getenv
from typing import Optional

from phi.model.openai.like import OpenAILike

@@ -19,5 +20,5 @@ class InternLM(OpenAILike):
    name: str = "InternLM"
    provider: str = "InternLM"

-    api_key: Optional[str] = getenv("INTERNLM_API_KEY")
+    api_key: Optional[str] = getenv("INTERNLM_API_KEY", None)
    base_url: Optional[str] = "https://internlm-chat.intern-ai.org.cn/puyu/api/v1/chat/completions"
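
The InternLM change above, like the later DeepSeek, Fireworks, Nvidia, and Gemini (OpenAI-compatible) changes, only makes the default explicit: `os.getenv` already returns `None` when a variable is unset, so `getenv("INTERNLM_API_KEY", None)` behaves exactly like `getenv("INTERNLM_API_KEY")`. A quick illustration (the variable name below is a placeholder assumed to be unset):

```python
from os import getenv

# getenv's default is already None, so the two calls are equivalent;
# "SOME_UNSET_VAR" is a placeholder assumed not to be set in the environment.
assert getenv("SOME_UNSET_VAR") is None
assert getenv("SOME_UNSET_VAR", None) is None
```
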
10 changes: 7 additions & 3 deletions phi/model/anthropic/claude.py
@@ -1,4 +1,5 @@
import json
+from os import getenv
from dataclasses import dataclass, field
from typing import Optional, List, Iterator, Dict, Any, Union, Tuple

@@ -18,9 +19,8 @@
        RawContentBlockDeltaEvent,
        ContentBlockStopEvent,
    )
-except ImportError:
-    logger.error("`anthropic` not installed")
-    raise
+except (ModuleNotFoundError, ImportError):
+    raise ImportError("`anthropic` not installed. Please install using `pip install anthropic`")


@dataclass

@@ -89,6 +89,10 @@ def get_client(self) -> AnthropicClient:
        if self.client:
            return self.client

+        self.api_key = self.api_key or getenv("ANTHROPIC_API_KEY")
+        if not self.api_key:
+            logger.error("ANTHROPIC_API_KEY not set. Please set the ANTHROPIC_API_KEY environment variable.")
+
        _client_params: Dict[str, Any] = {}
        # Set client parameters if they are provided
        if self.api_key:
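
The import guards change the same way across the SDK-backed models: instead of logging and re-raising the original exception, they now raise an `ImportError` whose message includes the install command. Schematically (`some_sdk` is a placeholder package name, not a real dependency):

```python
# Illustrative only: the guarded-import pattern now used in the model modules.
try:
    import some_sdk  # noqa: F401
except (ModuleNotFoundError, ImportError):
    raise ImportError("`some_sdk` not installed. Please install using `pip install some_sdk`")
```
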
6 changes: 2 additions & 4 deletions phi/model/azure/openai_chat.py
@@ -1,15 +1,13 @@
from os import getenv
from typing import Optional, Dict, Any
from phi.utils.log import logger
from phi.model.openai.like import OpenAILike
import httpx

try:
    from openai import AzureOpenAI as AzureOpenAIClient
    from openai import AsyncAzureOpenAI as AsyncAzureOpenAIClient
-except ImportError:
-    logger.error("`azure openai` not installed")
-    raise
+except (ModuleNotFoundError, ImportError):
+    raise ImportError("`azure openai` not installed. Please install using `pip install openai`")


class AzureOpenAIChat(OpenAILike):
11 changes: 8 additions & 3 deletions phi/model/cohere/chat.py
@@ -1,4 +1,5 @@
import json
+from os import getenv
from dataclasses import dataclass, field
from typing import Optional, List, Iterator, Dict, Any, Tuple

@@ -28,9 +29,8 @@
    )
    from cohere.types.api_meta_tokens import ApiMetaTokens
    from cohere.types.api_meta import ApiMeta
-except ImportError:
-    logger.error("`cohere` not installed")
-    raise
+except (ModuleNotFoundError, ImportError):
+    raise ImportError("`cohere` not installed. Please install using `pip install cohere`")


@dataclass

@@ -72,6 +72,11 @@ def client(self) -> CohereClient:
            return self.cohere_client

        _client_params: Dict[str, Any] = {}
+
+        self.api_key = self.api_key or getenv("CO_API_KEY")
+        if not self.api_key:
+            logger.error("CO_API_KEY not set. Please set the CO_API_KEY environment variable.")
+
        if self.api_key:
            _client_params["api_key"] = self.api_key
        return CohereClient(**_client_params)
2 changes: 1 addition & 1 deletion phi/model/deepseek/deepseek.py
@@ -20,5 +20,5 @@ class DeepSeekChat(OpenAILike):
    name: str = "DeepSeekChat"
    provider: str = "DeepSeek"

-    api_key: Optional[str] = getenv("DEEPSEEK_API_KEY")
+    api_key: Optional[str] = getenv("DEEPSEEK_API_KEY", None)
    base_url: str = "https://api.deepseek.com"
2 changes: 1 addition & 1 deletion phi/model/fireworks/fireworks.py
@@ -22,7 +22,7 @@ class Fireworks(OpenAILike):
    name: str = "Fireworks: " + id
    provider: str = "Fireworks"

-    api_key: Optional[str] = getenv("FIREWORKS_API_KEY")
+    api_key: Optional[str] = getenv("FIREWORKS_API_KEY", None)
    base_url: str = "https://api.fireworks.ai/inference/v1"

    def invoke_stream(self, messages: List[Message]) -> Iterator[ChatCompletionChunk]:
13 changes: 8 additions & 5 deletions phi/model/google/gemini.py
@@ -1,3 +1,4 @@
+from os import getenv
import time
import json
from pathlib import Path

@@ -24,8 +25,7 @@
    )
    from google.protobuf.struct_pb2 import Struct
except (ModuleNotFoundError, ImportError):
-    logger.error("`google-generativeai` not installed. Please install it using `pip install google-generativeai`")
-    raise
+    raise ImportError("`google-generativeai` not installed. Please install it using `pip install google-generativeai`")


@dataclass

@@ -103,9 +103,12 @@ def get_client(self) -> GenerativeModel:
            return self.client

        client_params: Dict[str, Any] = {}
-        # Set client parameters if they are provided
-        if self.api_key:
-            client_params["api_key"] = self.api_key
+
+        self.api_key = self.api_key or getenv("GOOGLE_API_KEY")
+        if not self.api_key:
+            logger.error("GOOGLE_API_KEY not set. Please set the GOOGLE_API_KEY environment variable.")
+        client_params["api_key"] = self.api_key
+
        if self.client_params:
            client_params.update(self.client_params)
        genai.configure(**client_params)
4 changes: 3 additions & 1 deletion phi/model/google/gemini_openai.py
@@ -1,5 +1,7 @@
from os import getenv
from typing import Optional


from phi.model.openai.like import OpenAILike

@@ -19,5 +21,5 @@ class GeminiOpenAIChat(OpenAILike):
    name: str = "Gemini"
    provider: str = "Google"

-    api_key: Optional[str] = getenv("GOOGLE_API_KEY")
+    api_key: Optional[str] = getenv("GOOGLE_API_KEY", None)
    base_url: Optional[str] = "https://generativelanguage.googleapis.com/v1beta/"
10 changes: 7 additions & 3 deletions phi/model/groq/groq.py
@@ -1,3 +1,4 @@
+from os import getenv
from dataclasses import dataclass, field
from typing import Optional, List, Iterator, Dict, Any, Union

@@ -16,9 +17,8 @@
    from groq.types.chat import ChatCompletion, ChatCompletionMessage
    from groq.types.chat.chat_completion_chunk import ChatCompletionChunk, ChoiceDeltaToolCall, ChoiceDelta
    from groq.types.completion_usage import CompletionUsage
-except ImportError:
-    logger.error("`groq` not installed")
-    raise
+except (ModuleNotFoundError, ImportError):
+    raise ImportError("`groq` not installed. Please install using `pip install groq`")


@dataclass

@@ -104,6 +104,10 @@ class Groq(Model):
    async_client: Optional[AsyncGroqClient] = None

    def get_client_params(self) -> Dict[str, Any]:
+        self.api_key = self.api_key or getenv("GROQ_API_KEY")
+        if not self.api_key:
+            logger.error("GROQ_API_KEY not set. Please set the GROQ_API_KEY environment variable.")
+
        client_params: Dict[str, Any] = {}
        if self.api_key:
            client_params["api_key"] = self.api_key
10 changes: 7 additions & 3 deletions phi/model/huggingface/hf.py
@@ -1,3 +1,4 @@
+from os import getenv
from dataclasses import dataclass, field
from typing import Optional, List, Iterator, Dict, Any, Union

@@ -23,9 +24,8 @@
        ChatCompletionOutputMessage,
        ChatCompletionOutputUsage,
    )
-except ImportError:
-    logger.error("`huggingface_hub` not installed")
-    raise
+except (ModuleNotFoundError, ImportError):
+    raise ImportError("`huggingface_hub` not installed. Please install using `pip install huggingface_hub`")


@dataclass

@@ -129,6 +129,10 @@ class HuggingFaceChat(Model):
    async_client: Optional[AsyncInferenceClient] = None

    def get_client_params(self) -> Dict[str, Any]:
+        self.api_key = self.api_key or getenv("HF_TOKEN")
+        if not self.api_key:
+            logger.error("HF_TOKEN not set. Please set the HF_TOKEN environment variable.")
+
        _client_params: Dict[str, Any] = {}
        if self.api_key is not None:
            _client_params["api_key"] = self.api_key
10 changes: 7 additions & 3 deletions phi/model/mistral/mistral.py
@@ -1,3 +1,4 @@
+from os import getenv
from dataclasses import dataclass, field
from typing import Optional, List, Iterator, Dict, Any, Union

@@ -14,9 +15,8 @@
    from mistralai.models.chatcompletionresponse import ChatCompletionResponse
    from mistralai.models.deltamessage import DeltaMessage
    from mistralai.types.basemodel import Unset
-except ImportError:
-    logger.error("`mistralai` not installed")
-    raise
+except (ModuleNotFoundError, ImportError):
+    raise ImportError("`mistralai` not installed. Please install using `pip install mistralai`")

MistralMessage = Union[models.UserMessage, models.AssistantMessage, models.SystemMessage, models.ToolMessage]

@@ -91,6 +91,10 @@ def client(self) -> Mistral:
        if self.mistral_client:
            return self.mistral_client

+        self.api_key = self.api_key or getenv("MISTRAL_API_KEY")
+        if not self.api_key:
+            logger.error("MISTRAL_API_KEY not set. Please set the MISTRAL_API_KEY environment variable.")
+
        _client_params: Dict[str, Any] = {}
        if self.api_key:
            _client_params["api_key"] = self.api_key
2 changes: 1 addition & 1 deletion phi/model/nvidia/nvidia.py
@@ -20,5 +20,5 @@ class Nvidia(OpenAILike):
    name: str = "Nvidia"
    provider: str = "Nvidia " + id

-    api_key: Optional[str] = getenv("NVIDIA_API_KEY")
+    api_key: Optional[str] = getenv("NVIDIA_API_KEY", None)
    base_url: str = "https://integrate.api.nvidia.com/v1"
5 changes: 2 additions & 3 deletions phi/model/ollama/chat.py
@@ -14,9 +14,8 @@

try:
    from ollama import Client as OllamaClient, AsyncClient as AsyncOllamaClient
-except ImportError:
-    logger.error("`ollama` not installed")
-    raise
+except (ModuleNotFoundError, ImportError):
+    raise ImportError("`ollama` not installed. Please install using `pip install ollama`")


@dataclass
11 changes: 8 additions & 3 deletions phi/model/openai/chat.py
@@ -1,3 +1,4 @@
+from os import getenv
from dataclasses import dataclass, field
from typing import Optional, List, Iterator, Dict, Any, Union

@@ -23,9 +24,8 @@
        ChoiceDeltaToolCall,
    )
    from openai.types.chat.chat_completion_message import ChatCompletionMessage
-except ImportError:
-    logger.error("`openai` not installed")
-    raise
+except (ModuleNotFoundError, ImportError):
+    raise ImportError("`openai` not installed. Please install using `pip install openai`")


@dataclass

@@ -118,6 +118,11 @@ class OpenAIChat(Model):

    def get_client_params(self) -> Dict[str, Any]:
        client_params: Dict[str, Any] = {}
+
+        self.api_key = self.api_key or getenv("OPENAI_API_KEY")
+        if not self.api_key:
+            logger.error("OPENAI_API_KEY not set. Please set the OPENAI_API_KEY environment variable.")
+
        if self.api_key is not None:
            client_params["api_key"] = self.api_key
        if self.organization is not None:
7 changes: 4 additions & 3 deletions phi/model/vertexai/gemini.py
@@ -21,9 +21,10 @@
        Content,
        Part,
    )
-except ImportError:
-    logger.error("`google-cloud-aiplatform` not installed")
-    raise
+except (ModuleNotFoundError, ImportError):
+    raise ImportError(
+        "`google-cloud-aiplatform` not installed. Please install using `pip install google-cloud-aiplatform`"
+    )


@dataclass
1 change: 1 addition & 0 deletions phi/model/xai/xai.py
@@ -1,5 +1,6 @@
from os import getenv
from typing import Optional

from phi.model.openai.like import OpenAILike
1 change: 1 addition & 0 deletions phi/workspace/config.py
@@ -19,6 +19,7 @@ def get_workspace_objects_from_file(resource_file: Path) -> dict:
    """Returns workspace objects from the resource file"""
    from phi.aws.resources import AwsResources
    from phi.docker.resources import DockerResources
+
    try:
        python_objects = get_python_objects_from_module(resource_file)
        # logger.debug(f"python_objects: {python_objects}")
