Skip to content

Commit

Permalink
Resolves dataclass type validation issue
Browse files Browse the repository at this point in the history
  • Loading branch information
Thomas Eleff committed Aug 21, 2024
1 parent 4f7b76c commit 3222e84
Showing 1 changed file with 67 additions and 10 deletions.
77 changes: 67 additions & 10 deletions assemblit/_app/_generic/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
""" Generic web-application """

import os
from typing import Union, Optional
from typing import Type, Union, Optional, Any, get_type_hints
from dataclasses import dataclass, field, fields, asdict
from assemblit.toolkit import _exceptions

Expand Down Expand Up @@ -74,17 +74,36 @@ def __post_init__(self):
raise _exceptions.MissingEnvironmentVariables

# Validate types
# for variable in fields(self):
# if variable.name not in ['ASSEMBLIT_DIR']:
# value = getattr(self, variable.name)
# if not isinstance(value, variable.type):
# raise ValueError(
# 'Invalid dtype {%s} for {%s}. Expected {%s}.' % (
# type(value).__name__,
# variable.name,
# (variable.type).__name__
# )
# )

# # Convert relative directory paths to absoluate paths
# if variable.name in ['ASSEMBLIT_DIR']:
# setattr(self, variable.name, os.path.abspath(getattr(self, variable.name)))

# Validate types v2
type_hints = _env.get_all_type_hints(dataclass_object=type(self))

# Check the type of field
for variable in fields(self):
if variable.name not in ['ASSEMBLIT_DIR']: # Avoid type-checking `ASSEMBLIT_DIR` due to Python 3.8 and 3.9 behavior
value = getattr(self, variable.name)
if not isinstance(value, variable.type):
raise ValueError(
'Invalid dtype {%s} for {%s}. Expected {%s}.' % (
type(value).__name__,
variable.name,
(variable.type).__name__
)
value = getattr(self, variable.name)
if not _env.check_type(expected_type=type_hints[variable.name], value=value):
raise ValueError(
'Invalid dtype {%s} for {%s}. Expected {%s}.' % (
type(value).__name__,
variable.name,
(variable.type).__name__
)
)

# Convert relative directory paths to absoluate paths
if variable.name in ['ASSEMBLIT_DIR']:
Expand All @@ -101,3 +120,41 @@ def list_variables(self) -> list:
def values(self) -> tuple:
""" Returns the environment variable values as a tuple. """
return tuple(asdict(self).values())

def check_type(expected_type: Type, value: Any) -> bool:
""" Recursively check if a value matches the expected type. This function is
compatible with basic types, Union, and Optional.
Parameters
----------
expected_type : `Type`
The expected type assigned to the dataclass field.
value : `Any`
The value to type check.
"""
if hasattr(expected_type, '__origin__'):

# Handle Optional[Type] which is equivalent to Union[Type, None]
if expected_type.__origin__ is Union:
return any(_env.check_type(arg, value) for arg in expected_type.__args__)

# For Optional, allow NoneType
if isinstance(value, expected_type) or value is None and expected_type is Optional:
return True

return isinstance(value, expected_type)

def get_all_type_hints(dataclass_object: object) -> dict:
""" Get all type hints from the dataclass and all subclasses.
Parameters
----------
dataclass_object : `object`
The dataclass object.
"""
hints = {}
for base in dataclass_object.__mro__:
if base is object:
continue
hints.update(get_type_hints(base))
return hints

0 comments on commit 3222e84

Please sign in to comment.