Skip to content

Commit

Permalink
chore: enable mypy strict mode (#257)
Browse files Browse the repository at this point in the history
Signed-off-by: gruebel <[email protected]>
Co-authored-by: Federico Bond <[email protected]>
  • Loading branch information
gruebel and federicobond authored Jan 8, 2024
1 parent 30f4e69 commit af9d3da
Show file tree
Hide file tree
Showing 8 changed files with 70 additions and 39 deletions.
3 changes: 2 additions & 1 deletion openfeature/_backports/strenum.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import sys

if sys.version_info >= (3, 11):
from enum import StrEnum
# re-export needed for type checking
from enum import StrEnum as StrEnum # noqa: PLC0414
else:
from enum import Enum

Expand Down
10 changes: 5 additions & 5 deletions openfeature/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ def get_client(
return OpenFeatureClient(name=name, version=version, provider=_provider)


def set_provider(provider: AbstractProvider):
def set_provider(provider: AbstractProvider) -> None:
global _provider
if provider is None:
raise GeneralError(error_message="No provider")
Expand All @@ -46,19 +46,19 @@ def get_evaluation_context() -> EvaluationContext:
return _evaluation_context


def set_evaluation_context(evaluation_context: EvaluationContext):
def set_evaluation_context(evaluation_context: EvaluationContext) -> None:
global _evaluation_context
if evaluation_context is None:
raise GeneralError(error_message="No api level evaluation context")
_evaluation_context = evaluation_context


def add_hooks(hooks: typing.List[Hook]):
def add_hooks(hooks: typing.List[Hook]) -> None:
global _hooks
_hooks = _hooks + hooks


def clear_hooks():
def clear_hooks() -> None:
global _hooks
_hooks = []

Expand All @@ -68,5 +68,5 @@ def get_hooks() -> typing.List[Hook]:
return _hooks


def shutdown():
def shutdown() -> None:
_provider.shutdown()
38 changes: 24 additions & 14 deletions openfeature/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,11 +45,21 @@
FlagResolutionDetails[typing.Union[dict, list]],
],
]
TypeMap = typing.Dict[
FlagType,
typing.Union[
typing.Type[bool],
typing.Type[int],
typing.Type[float],
typing.Type[str],
typing.Tuple[typing.Type[dict], typing.Type[list]],
],
]


@dataclass
class ClientMetadata:
name: str
name: typing.Optional[str]


class OpenFeatureClient:
Expand All @@ -60,17 +70,17 @@ def __init__(
provider: AbstractProvider,
context: typing.Optional[EvaluationContext] = None,
hooks: typing.Optional[typing.List[Hook]] = None,
):
) -> None:
self.name = name
self.version = version
self.context = context or EvaluationContext()
self.hooks = hooks or []
self.provider = provider

def get_metadata(self):
def get_metadata(self) -> ClientMetadata:
return ClientMetadata(name=self.name)

def add_hooks(self, hooks: typing.List[Hook]):
def add_hooks(self, hooks: typing.List[Hook]) -> None:
self.hooks = self.hooks + hooks

def get_boolean_value(
Expand All @@ -93,7 +103,7 @@ def get_boolean_details(
default_value: bool,
evaluation_context: typing.Optional[EvaluationContext] = None,
flag_evaluation_options: typing.Optional[FlagEvaluationOptions] = None,
) -> FlagEvaluationDetails:
) -> FlagEvaluationDetails[bool]:
return self.evaluate_flag_details(
FlagType.BOOLEAN,
flag_key,
Expand Down Expand Up @@ -122,7 +132,7 @@ def get_string_details(
default_value: str,
evaluation_context: typing.Optional[EvaluationContext] = None,
flag_evaluation_options: typing.Optional[FlagEvaluationOptions] = None,
) -> FlagEvaluationDetails:
) -> FlagEvaluationDetails[str]:
return self.evaluate_flag_details(
FlagType.STRING,
flag_key,
Expand Down Expand Up @@ -151,7 +161,7 @@ def get_integer_details(
default_value: int,
evaluation_context: typing.Optional[EvaluationContext] = None,
flag_evaluation_options: typing.Optional[FlagEvaluationOptions] = None,
) -> FlagEvaluationDetails:
) -> FlagEvaluationDetails[int]:
return self.evaluate_flag_details(
FlagType.INTEGER,
flag_key,
Expand Down Expand Up @@ -180,7 +190,7 @@ def get_float_details(
default_value: float,
evaluation_context: typing.Optional[EvaluationContext] = None,
flag_evaluation_options: typing.Optional[FlagEvaluationOptions] = None,
) -> FlagEvaluationDetails:
) -> FlagEvaluationDetails[float]:
return self.evaluate_flag_details(
FlagType.FLOAT,
flag_key,
Expand All @@ -195,7 +205,7 @@ def get_object_value(
default_value: typing.Union[dict, list],
evaluation_context: typing.Optional[EvaluationContext] = None,
flag_evaluation_options: typing.Optional[FlagEvaluationOptions] = None,
) -> dict:
) -> typing.Union[dict, list]:
return self.get_object_details(
flag_key,
default_value,
Expand All @@ -209,7 +219,7 @@ def get_object_details(
default_value: typing.Union[dict, list],
evaluation_context: typing.Optional[EvaluationContext] = None,
flag_evaluation_options: typing.Optional[FlagEvaluationOptions] = None,
) -> FlagEvaluationDetails:
) -> FlagEvaluationDetails[typing.Union[dict, list]]:
return self.evaluate_flag_details(
FlagType.OBJECT,
flag_key,
Expand All @@ -225,7 +235,7 @@ def evaluate_flag_details(
default_value: typing.Any,
evaluation_context: typing.Optional[EvaluationContext] = None,
flag_evaluation_options: typing.Optional[FlagEvaluationOptions] = None,
) -> FlagEvaluationDetails:
) -> FlagEvaluationDetails[typing.Any]:
"""
Evaluate the flag requested by the user from the clients provider.
Expand Down Expand Up @@ -335,7 +345,7 @@ def _create_provider_evaluation(
flag_key: str,
default_value: typing.Any,
evaluation_context: typing.Optional[EvaluationContext] = None,
) -> FlagEvaluationDetails:
) -> FlagEvaluationDetails[typing.Any]:
"""
Encapsulated method to create a FlagEvaluationDetail from a specific provider.
Expand Down Expand Up @@ -384,8 +394,8 @@ def _create_provider_evaluation(
)


def _typecheck_flag_value(value, flag_type):
type_map = {
def _typecheck_flag_value(value: typing.Any, flag_type: FlagType) -> None:
type_map: TypeMap = {
FlagType.BOOLEAN: bool,
FlagType.STRING: str,
FlagType.OBJECT: (dict, list),
Expand Down
13 changes: 9 additions & 4 deletions openfeature/hook/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,11 @@ def before(
return None

def after(
self, hook_context: HookContext, details: FlagEvaluationDetails, hints: dict
):
self,
hook_context: HookContext,
details: FlagEvaluationDetails[typing.Any],
hints: dict,
) -> None:
"""
Runs after a flag is resolved.
Expand All @@ -58,7 +61,9 @@ def after(
"""
pass

def error(self, hook_context: HookContext, exception: Exception, hints: dict):
def error(
self, hook_context: HookContext, exception: Exception, hints: dict
) -> None:
"""
Run when evaluation encounters an error. Errors thrown will be swallowed.
Expand All @@ -68,7 +73,7 @@ def error(self, hook_context: HookContext, exception: Exception, hints: dict):
"""
pass

def finally_after(self, hook_context: HookContext, hints: dict):
def finally_after(self, hook_context: HookContext, hints: dict) -> None:
"""
Run after flag evaluation, including any error processing.
This will always run. Errors will be swallowed.
Expand Down
32 changes: 22 additions & 10 deletions openfeature/hook/hook_support.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ def error_hooks(
exception: Exception,
hooks: typing.List[Hook],
hints: typing.Optional[typing.Mapping] = None,
):
) -> None:
kwargs = {"hook_context": hook_context, "exception": exception, "hints": hints}
_execute_hooks(
flag_type=flag_type, hooks=hooks, hook_method=HookType.ERROR, **kwargs
Expand All @@ -25,7 +25,7 @@ def after_all_hooks(
hook_context: HookContext,
hooks: typing.List[Hook],
hints: typing.Optional[typing.Mapping] = None,
):
) -> None:
kwargs = {"hook_context": hook_context, "hints": hints}
_execute_hooks(
flag_type=flag_type, hooks=hooks, hook_method=HookType.FINALLY_AFTER, **kwargs
Expand All @@ -35,10 +35,10 @@ def after_all_hooks(
def after_hooks(
flag_type: FlagType,
hook_context: HookContext,
details: FlagEvaluationDetails,
details: FlagEvaluationDetails[typing.Any],
hooks: typing.List[Hook],
hints: typing.Optional[typing.Mapping] = None,
):
) -> None:
kwargs = {"hook_context": hook_context, "details": details, "hints": hints}
_execute_hooks_unchecked(
flag_type=flag_type, hooks=hooks, hook_method=HookType.AFTER, **kwargs
Expand All @@ -55,7 +55,7 @@ def before_hooks(
executed_hooks = _execute_hooks_unchecked(
flag_type=flag_type, hooks=hooks, hook_method=HookType.BEFORE, **kwargs
)
filtered_hooks = list(filter(lambda hook: hook is not None, executed_hooks))
filtered_hooks = [result for result in executed_hooks if result is not None]

if filtered_hooks:
return reduce(lambda a, b: a.merge(b), filtered_hooks)
Expand All @@ -64,7 +64,10 @@ def before_hooks(


def _execute_hooks(
flag_type: FlagType, hooks: typing.List[Hook], hook_method: HookType, **kwargs
flag_type: FlagType,
hooks: typing.List[Hook],
hook_method: HookType,
**kwargs: typing.Any,
) -> list:
"""
Run multiple hooks of any hook type. All of these hooks will be run through an
Expand All @@ -84,8 +87,11 @@ def _execute_hooks(


def _execute_hooks_unchecked(
flag_type: FlagType, hooks, hook_method: HookType, **kwargs
) -> list:
flag_type: FlagType,
hooks: typing.List[Hook],
hook_method: HookType,
**kwargs: typing.Any,
) -> typing.List[typing.Optional[EvaluationContext]]:
"""
Execute a single hook without checking whether an exception is thrown. This is
used in the before and after hooks since any exception will be caught in the
Expand All @@ -104,7 +110,9 @@ def _execute_hooks_unchecked(
]


def _execute_hook_checked(hook: Hook, hook_method: HookType, **kwargs):
def _execute_hook_checked(
hook: Hook, hook_method: HookType, **kwargs: typing.Any
) -> typing.Optional[EvaluationContext]:
"""
Try and run a single hook and catch any exception thrown. This is used in the
after all and error hooks since any error thrown at this point needs to be caught.
Expand All @@ -115,6 +123,10 @@ def _execute_hook_checked(hook: Hook, hook_method: HookType, **kwargs):
:return: the result of the hook method
"""
try:
return getattr(hook, hook_method.value)(**kwargs)
return typing.cast(
"typing.Optional[EvaluationContext]",
getattr(hook, hook_method.value)(**kwargs),
)
except Exception: # pragma: no cover
logging.error(f"Exception when running {hook_method.value} hooks")
return None
6 changes: 3 additions & 3 deletions openfeature/provider/in_memory_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ class State(StrEnum):
state: State = State.ENABLED
context_evaluator: typing.Optional[
typing.Callable[
["InMemoryFlag", EvaluationContext], FlagResolutionDetails[T_co]
["InMemoryFlag[T_co]", EvaluationContext], FlagResolutionDetails[T_co]
]
] = None

Expand All @@ -52,15 +52,15 @@ def resolve(
)


FlagStorage = typing.Dict[str, InMemoryFlag]
FlagStorage = typing.Dict[str, InMemoryFlag[typing.Any]]

V = typing.TypeVar("V")


class InMemoryProvider(AbstractProvider):
_flags: FlagStorage

def __init__(self, flags: FlagStorage):
def __init__(self, flags: FlagStorage) -> None:
self._flags = flags.copy()

def get_metadata(self) -> Metadata:
Expand Down
4 changes: 2 additions & 2 deletions openfeature/provider/provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,10 @@


class AbstractProvider:
def initialize(self, evaluation_context: EvaluationContext):
def initialize(self, evaluation_context: EvaluationContext) -> None:
pass

def shutdown(self):
def shutdown(self) -> None:
pass

@abstractmethod
Expand Down
3 changes: 3 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,9 @@ Homepage = "https://github.com/open-feature/python-sdk"
files = "openfeature"
namespace_packages = true
explicit_package_bases = true
pretty = true
strict = true
disallow_any_generics = false

[tool.ruff]
exclude = [
Expand Down

0 comments on commit af9d3da

Please sign in to comment.