Skip to content

Commit

Permalink
Merge branch 'master' into flairNLPgh-3488/save-column-corpus-to-files
Browse files Browse the repository at this point in the history
  • Loading branch information
chelseagzr authored Jul 22, 2024
2 parents 6941510 + 9c4e1d2 commit f1b1d55
Show file tree
Hide file tree
Showing 6 changed files with 194 additions and 106 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ jobs:
- name: Install Torch cpu
run: pip install torch --index-url https://download.pytorch.org/whl/cpu
- name: Install Flair dependencies
run: pip install -e .
run: pip install -e .[word-embeddings]
- name: Install unittest dependencies
run: pip install -r requirements-dev.txt
- name: Show installed dependencies
Expand Down
28 changes: 27 additions & 1 deletion flair/class_utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import importlib
import inspect
from typing import Iterable, Optional, Type, TypeVar
from types import ModuleType
from typing import Any, Iterable, List, Optional, Type, TypeVar, Union, overload

T = TypeVar("T")

Expand All @@ -17,3 +19,27 @@ def get_state_subclass_by_name(cls: Type[T], cls_name: Optional[str]) -> Type[T]
if sub_cls.__name__ == cls_name:
return sub_cls
raise ValueError(f"Could not find any class with name '{cls_name}'")


@overload
def lazy_import(group: str, module: str, first_symbol: None) -> ModuleType: ...


@overload
def lazy_import(group: str, module: str, first_symbol: str, *symbols: str) -> List[Any]: ...


def lazy_import(
group: str, module: str, first_symbol: Optional[str] = None, *symbols: str
) -> Union[List[Any], ModuleType]:
try:
imported_module = importlib.import_module(module)
except ImportError:
raise ImportError(
f"Could not import {module}. Please install the optional '{group}' dependency. Via 'pip install flair[{group}]'"
)
if first_symbol is None:
return imported_module
symbols = (first_symbol, *symbols)

return [getattr(imported_module, symbol) for symbol in symbols]
1 change: 0 additions & 1 deletion flair/embeddings/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,6 @@

# Expose token embedding classes
from .token import (
BPEmbSerializable,
BytePairEmbeddings,
CharacterEmbeddings,
FastTextEmbeddings,
Expand Down
Loading

0 comments on commit f1b1d55

Please sign in to comment.