diff --git a/openfeature/_backports/strenum.py b/openfeature/_backports/strenum.py index ed2d7b34..6dcdfd7d 100644 --- a/openfeature/_backports/strenum.py +++ b/openfeature/_backports/strenum.py @@ -1,7 +1,8 @@ import sys if sys.version_info >= (3, 11): - from enum import StrEnum + # re-export needed for type checking + from enum import StrEnum as StrEnum # noqa: PLC0414 else: from enum import Enum diff --git a/openfeature/api.py b/openfeature/api.py index 902e8b3b..6687bb89 100644 --- a/openfeature/api.py +++ b/openfeature/api.py @@ -21,7 +21,7 @@ def get_client( return OpenFeatureClient(name=name, version=version, provider=_provider) -def set_provider(provider: AbstractProvider): +def set_provider(provider: AbstractProvider) -> None: global _provider if provider is None: raise GeneralError(error_message="No provider") @@ -46,19 +46,19 @@ def get_evaluation_context() -> EvaluationContext: return _evaluation_context -def set_evaluation_context(evaluation_context: EvaluationContext): +def set_evaluation_context(evaluation_context: EvaluationContext) -> None: global _evaluation_context if evaluation_context is None: raise GeneralError(error_message="No api level evaluation context") _evaluation_context = evaluation_context -def add_hooks(hooks: typing.List[Hook]): +def add_hooks(hooks: typing.List[Hook]) -> None: global _hooks _hooks = _hooks + hooks -def clear_hooks(): +def clear_hooks() -> None: global _hooks _hooks = [] @@ -68,5 +68,5 @@ def get_hooks() -> typing.List[Hook]: return _hooks -def shutdown(): +def shutdown() -> None: _provider.shutdown() diff --git a/openfeature/client.py b/openfeature/client.py index cc29897f..9b842272 100644 --- a/openfeature/client.py +++ b/openfeature/client.py @@ -45,11 +45,21 @@ FlagResolutionDetails[typing.Union[dict, list]], ], ] +TypeMap = typing.Dict[ + FlagType, + typing.Union[ + typing.Type[bool], + typing.Type[int], + typing.Type[float], + typing.Type[str], + typing.Tuple[typing.Type[dict], typing.Type[list]], + ], +] @dataclass class ClientMetadata: - name: str + name: typing.Optional[str] class OpenFeatureClient: @@ -60,17 +70,17 @@ def __init__( provider: AbstractProvider, context: typing.Optional[EvaluationContext] = None, hooks: typing.Optional[typing.List[Hook]] = None, - ): + ) -> None: self.name = name self.version = version self.context = context or EvaluationContext() self.hooks = hooks or [] self.provider = provider - def get_metadata(self): + def get_metadata(self) -> ClientMetadata: return ClientMetadata(name=self.name) - def add_hooks(self, hooks: typing.List[Hook]): + def add_hooks(self, hooks: typing.List[Hook]) -> None: self.hooks = self.hooks + hooks def get_boolean_value( @@ -93,7 +103,7 @@ def get_boolean_details( default_value: bool, evaluation_context: typing.Optional[EvaluationContext] = None, flag_evaluation_options: typing.Optional[FlagEvaluationOptions] = None, - ) -> FlagEvaluationDetails: + ) -> FlagEvaluationDetails[bool]: return self.evaluate_flag_details( FlagType.BOOLEAN, flag_key, @@ -122,7 +132,7 @@ def get_string_details( default_value: str, evaluation_context: typing.Optional[EvaluationContext] = None, flag_evaluation_options: typing.Optional[FlagEvaluationOptions] = None, - ) -> FlagEvaluationDetails: + ) -> FlagEvaluationDetails[str]: return self.evaluate_flag_details( FlagType.STRING, flag_key, @@ -151,7 +161,7 @@ def get_integer_details( default_value: int, evaluation_context: typing.Optional[EvaluationContext] = None, flag_evaluation_options: typing.Optional[FlagEvaluationOptions] = None, - ) -> FlagEvaluationDetails: + ) -> FlagEvaluationDetails[int]: return self.evaluate_flag_details( FlagType.INTEGER, flag_key, @@ -180,7 +190,7 @@ def get_float_details( default_value: float, evaluation_context: typing.Optional[EvaluationContext] = None, flag_evaluation_options: typing.Optional[FlagEvaluationOptions] = None, - ) -> FlagEvaluationDetails: + ) -> FlagEvaluationDetails[float]: return self.evaluate_flag_details( FlagType.FLOAT, flag_key, @@ -195,7 +205,7 @@ def get_object_value( default_value: typing.Union[dict, list], evaluation_context: typing.Optional[EvaluationContext] = None, flag_evaluation_options: typing.Optional[FlagEvaluationOptions] = None, - ) -> dict: + ) -> typing.Union[dict, list]: return self.get_object_details( flag_key, default_value, @@ -209,7 +219,7 @@ def get_object_details( default_value: typing.Union[dict, list], evaluation_context: typing.Optional[EvaluationContext] = None, flag_evaluation_options: typing.Optional[FlagEvaluationOptions] = None, - ) -> FlagEvaluationDetails: + ) -> FlagEvaluationDetails[typing.Union[dict, list]]: return self.evaluate_flag_details( FlagType.OBJECT, flag_key, @@ -225,7 +235,7 @@ def evaluate_flag_details( default_value: typing.Any, evaluation_context: typing.Optional[EvaluationContext] = None, flag_evaluation_options: typing.Optional[FlagEvaluationOptions] = None, - ) -> FlagEvaluationDetails: + ) -> FlagEvaluationDetails[typing.Any]: """ Evaluate the flag requested by the user from the clients provider. @@ -335,7 +345,7 @@ def _create_provider_evaluation( flag_key: str, default_value: typing.Any, evaluation_context: typing.Optional[EvaluationContext] = None, - ) -> FlagEvaluationDetails: + ) -> FlagEvaluationDetails[typing.Any]: """ Encapsulated method to create a FlagEvaluationDetail from a specific provider. @@ -384,8 +394,8 @@ def _create_provider_evaluation( ) -def _typecheck_flag_value(value, flag_type): - type_map = { +def _typecheck_flag_value(value: typing.Any, flag_type: FlagType) -> None: + type_map: TypeMap = { FlagType.BOOLEAN: bool, FlagType.STRING: str, FlagType.OBJECT: (dict, list), diff --git a/openfeature/hook/__init__.py b/openfeature/hook/__init__.py index 590843bf..79c20184 100644 --- a/openfeature/hook/__init__.py +++ b/openfeature/hook/__init__.py @@ -46,8 +46,11 @@ def before( return None def after( - self, hook_context: HookContext, details: FlagEvaluationDetails, hints: dict - ): + self, + hook_context: HookContext, + details: FlagEvaluationDetails[typing.Any], + hints: dict, + ) -> None: """ Runs after a flag is resolved. @@ -58,7 +61,9 @@ def after( """ pass - def error(self, hook_context: HookContext, exception: Exception, hints: dict): + def error( + self, hook_context: HookContext, exception: Exception, hints: dict + ) -> None: """ Run when evaluation encounters an error. Errors thrown will be swallowed. @@ -68,7 +73,7 @@ def error(self, hook_context: HookContext, exception: Exception, hints: dict): """ pass - def finally_after(self, hook_context: HookContext, hints: dict): + def finally_after(self, hook_context: HookContext, hints: dict) -> None: """ Run after flag evaluation, including any error processing. This will always run. Errors will be swallowed. diff --git a/openfeature/hook/hook_support.py b/openfeature/hook/hook_support.py index e1432c80..1e530e6a 100644 --- a/openfeature/hook/hook_support.py +++ b/openfeature/hook/hook_support.py @@ -13,7 +13,7 @@ def error_hooks( exception: Exception, hooks: typing.List[Hook], hints: typing.Optional[typing.Mapping] = None, -): +) -> None: kwargs = {"hook_context": hook_context, "exception": exception, "hints": hints} _execute_hooks( flag_type=flag_type, hooks=hooks, hook_method=HookType.ERROR, **kwargs @@ -25,7 +25,7 @@ def after_all_hooks( hook_context: HookContext, hooks: typing.List[Hook], hints: typing.Optional[typing.Mapping] = None, -): +) -> None: kwargs = {"hook_context": hook_context, "hints": hints} _execute_hooks( flag_type=flag_type, hooks=hooks, hook_method=HookType.FINALLY_AFTER, **kwargs @@ -35,10 +35,10 @@ def after_all_hooks( def after_hooks( flag_type: FlagType, hook_context: HookContext, - details: FlagEvaluationDetails, + details: FlagEvaluationDetails[typing.Any], hooks: typing.List[Hook], hints: typing.Optional[typing.Mapping] = None, -): +) -> None: kwargs = {"hook_context": hook_context, "details": details, "hints": hints} _execute_hooks_unchecked( flag_type=flag_type, hooks=hooks, hook_method=HookType.AFTER, **kwargs @@ -55,7 +55,7 @@ def before_hooks( executed_hooks = _execute_hooks_unchecked( flag_type=flag_type, hooks=hooks, hook_method=HookType.BEFORE, **kwargs ) - filtered_hooks = list(filter(lambda hook: hook is not None, executed_hooks)) + filtered_hooks = [result for result in executed_hooks if result is not None] if filtered_hooks: return reduce(lambda a, b: a.merge(b), filtered_hooks) @@ -64,7 +64,10 @@ def before_hooks( def _execute_hooks( - flag_type: FlagType, hooks: typing.List[Hook], hook_method: HookType, **kwargs + flag_type: FlagType, + hooks: typing.List[Hook], + hook_method: HookType, + **kwargs: typing.Any, ) -> list: """ Run multiple hooks of any hook type. All of these hooks will be run through an @@ -84,8 +87,11 @@ def _execute_hooks( def _execute_hooks_unchecked( - flag_type: FlagType, hooks, hook_method: HookType, **kwargs -) -> list: + flag_type: FlagType, + hooks: typing.List[Hook], + hook_method: HookType, + **kwargs: typing.Any, +) -> typing.List[typing.Optional[EvaluationContext]]: """ Execute a single hook without checking whether an exception is thrown. This is used in the before and after hooks since any exception will be caught in the @@ -104,7 +110,9 @@ def _execute_hooks_unchecked( ] -def _execute_hook_checked(hook: Hook, hook_method: HookType, **kwargs): +def _execute_hook_checked( + hook: Hook, hook_method: HookType, **kwargs: typing.Any +) -> typing.Optional[EvaluationContext]: """ Try and run a single hook and catch any exception thrown. This is used in the after all and error hooks since any error thrown at this point needs to be caught. @@ -115,6 +123,10 @@ def _execute_hook_checked(hook: Hook, hook_method: HookType, **kwargs): :return: the result of the hook method """ try: - return getattr(hook, hook_method.value)(**kwargs) + return typing.cast( + "typing.Optional[EvaluationContext]", + getattr(hook, hook_method.value)(**kwargs), + ) except Exception: # pragma: no cover logging.error(f"Exception when running {hook_method.value} hooks") + return None diff --git a/openfeature/provider/in_memory_provider.py b/openfeature/provider/in_memory_provider.py index 2495d2dc..7fe7e735 100644 --- a/openfeature/provider/in_memory_provider.py +++ b/openfeature/provider/in_memory_provider.py @@ -32,7 +32,7 @@ class State(StrEnum): state: State = State.ENABLED context_evaluator: typing.Optional[ typing.Callable[ - ["InMemoryFlag", EvaluationContext], FlagResolutionDetails[T_co] + ["InMemoryFlag[T_co]", EvaluationContext], FlagResolutionDetails[T_co] ] ] = None @@ -52,7 +52,7 @@ def resolve( ) -FlagStorage = typing.Dict[str, InMemoryFlag] +FlagStorage = typing.Dict[str, InMemoryFlag[typing.Any]] V = typing.TypeVar("V") @@ -60,7 +60,7 @@ def resolve( class InMemoryProvider(AbstractProvider): _flags: FlagStorage - def __init__(self, flags: FlagStorage): + def __init__(self, flags: FlagStorage) -> None: self._flags = flags.copy() def get_metadata(self) -> Metadata: diff --git a/openfeature/provider/provider.py b/openfeature/provider/provider.py index 73ce37b5..625b7f6b 100644 --- a/openfeature/provider/provider.py +++ b/openfeature/provider/provider.py @@ -8,10 +8,10 @@ class AbstractProvider: - def initialize(self, evaluation_context: EvaluationContext): + def initialize(self, evaluation_context: EvaluationContext) -> None: pass - def shutdown(self): + def shutdown(self) -> None: pass @abstractmethod diff --git a/pyproject.toml b/pyproject.toml index c92bfdfd..bd236e71 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -29,6 +29,9 @@ Homepage = "https://github.com/open-feature/python-sdk" files = "openfeature" namespace_packages = true explicit_package_bases = true +pretty = true +strict = true +disallow_any_generics = false [tool.ruff] exclude = [