diff --git a/server/lorax_server/utils/sources/__init__.py b/server/lorax_server/utils/sources/__init__.py index ba9af397..a47df0af 100644 --- a/server/lorax_server/utils/sources/__init__.py +++ b/server/lorax_server/utils/sources/__init__.py @@ -78,6 +78,9 @@ def get_model_source( api_token: Optional[str] = None, ): if source == HUB: + if not api_token and bool(os.environ.get("LORAX_USE_GLOBAL_HF_TOKEN", 0)): + # User initialized LoRAX to fallback to global HF token if request token is empty + api_token = os.environ.get("HUGGING_FACE_HUB_TOKEN") return HubModelSource(model_id, revision, extension, api_token) elif source == S3: return S3ModelSource(model_id, revision, extension) diff --git a/server/lorax_server/utils/sources/hub.py b/server/lorax_server/utils/sources/hub.py index 0b6f6bb6..ec52b952 100644 --- a/server/lorax_server/utils/sources/hub.py +++ b/server/lorax_server/utils/sources/hub.py @@ -195,7 +195,7 @@ def download_file(self, filename: str, ignore_errors: bool = False) -> Optional[ def get_hub_api(token: Optional[str] = None) -> HfApi: - if token == "" and bool(os.environ.get("LORAX_USE_GLOBAL_HF_TOKEN", 0)): + if not token and bool(os.environ.get("LORAX_USE_GLOBAL_HF_TOKEN", 0)): # User initialized LoRAX to fallback to global HF token if request token is empty token = os.environ.get("HUGGING_FACE_HUB_TOKEN") return HfApi(token=token)