From bb5ec6e613db939fa7551d2204b2a66a9581d9a6 Mon Sep 17 00:00:00 2001 From: Hassan Kibirige Date: Sun, 20 Oct 2024 00:24:07 +0300 Subject: [PATCH] Refactor transforms to use dataclasses --- mizani/transforms.py | 443 +++++++++++++++++++-------------------- mizani/typing.py | 2 + tests/test_transforms.py | 30 ++- 3 files changed, 234 insertions(+), 241 deletions(-) diff --git a/mizani/transforms.py b/mizani/transforms.py index 610c9c4..545d865 100644 --- a/mizani/transforms.py +++ b/mizani/transforms.py @@ -23,7 +23,8 @@ import sys from abc import ABC, abstractmethod -from datetime import MAXYEAR, MINYEAR, datetime, timedelta +from dataclasses import KW_ONLY, dataclass, field +from datetime import MAXYEAR, MINYEAR, datetime, timedelta, tzinfo from types import MethodType from typing import TYPE_CHECKING from zoneinfo import ZoneInfo @@ -52,7 +53,6 @@ label_number, label_timedelta, ) -from .utils import identity if TYPE_CHECKING: from typing import Any, Callable, Sequence, Type @@ -103,43 +103,41 @@ UTC = ZoneInfo("UTC") +@dataclass(kw_only=True) class trans(ABC): - """ - Base class for all transforms - - This class is used to transform data and also tell the - x and y axes how to create and label the tick locations. - - The key methods to override are :meth:`trans.transform` - and :meth:`trans.inverse`. Alternately, you can quickly - create a transform class using the :func:`trans_new` - function. - - Parameters - ---------- - kwargs : dict - Attributes of the class to set/override + domain: DomainType = (-np.inf, np.inf) + transform_is_linear: bool = False + """ + Whether the transformation over the whole domain is linear. + e.g. `2x` is linear while `1/x` and `log(x)` are not. """ - #: Whether the transformation over the whole domain is linear. - #: e.g. `2x` is linear while `1/x` and `log(x)` are not. - transform_is_linear: bool = False + breaks_func: BreaksFunction = field(default_factory=breaks_extended) + "Callable to calculate breaks" - domain: DomainType = (-np.inf, np.inf) + format_func: FormatFunction = field(default_factory=label_number) + "Function to format breaks" - #: Callable to calculate breaks - breaks_: BreaksFunction = breaks_extended(n=5) + minor_breaks_func: MinorBreaksFunction | None = None + "Callable to calculate minor breaks" - #: Function to format breaks - format: FormatFunction = staticmethod(label_number()) + # Use type variables for trans.transform and trans.inverse + # to help upstream packages avoid type mismatches. e.g. + # transform(tuple[float, float]) -> tuple[float, float] + @abstractmethod + def transform(self, x: TFloatArrayLike) -> TFloatArrayLike: + """ + Transform of x + """ + ... - def __init__(self, **kwargs: Any): - for k, v in kwargs.items(): - if hasattr(self, k): - setattr(self, k, v) - else: - raise AttributeError(f"Unknown Parameter: {k}") + @abstractmethod + def inverse(self, x: TFloatArrayLike) -> TFloatArrayLike: + """ + Inverse of x + """ + ... @property def domain_is_numerical(self) -> bool: @@ -160,6 +158,9 @@ def minor_breaks( """ Calculate minor_breaks """ + if self.minor_breaks_func is not None: + return self.minor_breaks_func(major, limits, n) + n = 1 if n is None else n # minor_breaks_trans undoes the transformation and @@ -171,25 +172,6 @@ def minor_breaks( func = minor_breaks_trans(self, n=n) return func(major, limits, n) - # Use type variables for trans.transform and trans.inverse - # to help upstream packages avoid type mismatches. e.g. - # transform(tuple[float, float]) -> tuple[float, float] - @staticmethod - @abstractmethod - def transform(x: TFloatArrayLike) -> TFloatArrayLike: - """ - Transform of x - """ - ... - - @staticmethod - @abstractmethod - def inverse(x: TFloatArrayLike) -> TFloatArrayLike: - """ - Inverse of x - """ - ... - def breaks(self, limits: DomainType) -> NDArrayFloat: """ Calculate breaks in data space and return them @@ -224,7 +206,7 @@ def breaks(self, limits: DomainType) -> NDArrayFloat: max(self.domain[0], limits[0]), min(self.domain[1], limits[1]), ) - breaks = np.asarray(self.breaks_(limits)) + breaks = np.asarray(self.breaks_func(limits)) # Some methods (e.g. breaks_extended) that # calculate breaks take the limits as guide posts and @@ -259,10 +241,10 @@ def trans_new( name: str, transform: TransformFunction, inverse: InverseFunction, - breaks: BreaksFunction | None = None, - minor_breaks: MinorBreaksFunction | None = None, - _format: FormatFunction | None = None, - domain=(-np.inf, np.inf), + breaks_func: BreaksFunction | None = None, + minor_breaks_func: MinorBreaksFunction | None = None, + format_func: FormatFunction | None = None, + domain: DomainType=(-np.inf, np.inf), doc: str = "", **kwargs, ) -> trans: @@ -320,19 +302,20 @@ def _get(func): **kwargs, } - if breaks: - d["breaks_"] = _get(breaks) + if breaks_func: + d["breaks_func"] = _get(breaks_func) if minor_breaks: - d["minor_breaks"] = _get(minor_breaks) + d["minor_breaks_func"] = _get(minor_breaks_func) - if _format: - d["format"] = _get(_format) + if format_func: + d["format_func"] = _get(format_func) return type(klass_name, (trans,), d) # type: ignore -def log_trans(base: float | None = None, **kwargs: Any) -> trans: +@dataclass +class log_trans(trans): """ Create a log transform class for *base* @@ -351,50 +334,56 @@ def log_trans(base: float | None = None, **kwargs: Any) -> trans: out : type Log transform class """ - # transform function - if base is None: - name = "log" - base = np.exp(1) - transform = np.log # type: ignore - elif base == 10: - name = "log10" - transform = np.log10 # type: ignore - elif base == 2: - name = "log2" - transform = np.log2 # type: ignore - else: - name = "log{}".format(base) - def transform(x: FloatArrayLike) -> NDArrayFloat: - return np.log(x) / np.log(base) + base: float = np.exp(1) + _: KW_ONLY + domain: DomainType = (sys.float_info.min, np.inf) + + def __post_init__(self): + if self.base == 10: + self._transform = np.log10 + elif self.base == 2: + self._transform = np.log2 + elif self.base == np.exp(1): + self._transform = np.log + else: + + def _transform(x: FloatArrayLike) -> NDArrayFloat: + return np.log(x) / np.log(self.base) + + self._transform = _transform - # inverse function - def inverse(x): - return np.power(base, x) # type: ignore + self.breaks_func = breaks_log(base=self.base) + self.format_func = label_log(base=self.base) + self.minor_breaks_func = minor_breaks_trans(self, n=int(self.base) - 2) - if "domain" not in kwargs: - kwargs["domain"] = (sys.float_info.min, np.inf) + def transform(self, x): + return self._transform(x) - if "breaks" not in kwargs: - kwargs["breaks"] = breaks_log(base=base) # type: ignore + def inverse(self, x): + return np.power(self.base, x) - kwargs["base"] = base - kwargs["_format"] = label_log(base) # type: ignore - _trans = trans_new(name, transform, inverse, **kwargs) +@dataclass +class log10_trans(log_trans): + """ + Log 10 Transformation + """ - if "minor_breaks" not in kwargs: - n = int(base) - 2 # type: ignore - _trans.minor_breaks = minor_breaks_trans(_trans, n=n) + base: float = 10 - return _trans +@dataclass +class log2_trans(log_trans): + """ + Log 2 Transformation + """ -log10_trans = log_trans(10, doc="Log 10 Transformation") -log2_trans = log_trans(2, doc="Log 2 Transformation") + base: float = 2 -def exp_trans(base: float | None = None, **kwargs: Any): +@dataclass +class exp_trans(trans): """ Create a exponential transform class for *base* @@ -414,34 +403,30 @@ def exp_trans(base: float | None = None, **kwargs: Any): out : type Exponential transform class """ - # default to e - if base is None: - name = "power_e" - base = np.exp(1) - else: - name = "power_{}".format(base) - # transform function - def transform(x): - return np.power(base, x) # type: ignore + base: float = np.exp(1) - # inverse function - def inverse(x): - return np.log(x) / np.log(base) # type: ignore + def transform(self, x): + return np.power(self.base, x) - kwargs["base"] = base - return trans_new(name, transform, inverse, **kwargs) + def inverse(self, x): + return np.log(x) / np.log(self.base) +@dataclass class log1p_trans(trans): """ Log plus one Transformation """ - transform = staticmethod(np.log1p) # type: ignore - inverse = staticmethod(np.expm1) # type: ignore + def transform(self, x): + return np.log1p(x) + def inverse(self, x): + return np.expm1(x) + +@dataclass class identity_trans(trans): """ Identity Transformation @@ -458,61 +443,83 @@ class identity_trans(trans): Create a trans that returns 4 minor breaks - >>> t = identity_trans(minor_breaks=minor_breaks(4)) + >>> t = identity_trans(minor_breaks_func=minor_breaks(4)) >>> t.minor_breaks(major) array([0.2, 0.4, 0.6, 0.8, 1.2, 1.4, 1.6, 1.8]) """ - transform_is_linear = True - transform = staticmethod(identity) # type: ignore - inverse = staticmethod(identity) # type: ignore + transform_is_linear: bool = True + + def transform(self, x): + return x + + def inverse(self, x): + return x +@dataclass(kw_only=True) class reverse_trans(trans): """ Reverse Transformation """ - transform_is_linear = True - transform = staticmethod(np.negative) # type: ignore - inverse = staticmethod(np.negative) # type: ignore + transform_is_linear: bool = True + def transform(self, x): + return np.negative(x) + def inverse(self, x): + return np.negative(x) + + +@dataclass(kw_only=True) class sqrt_trans(trans): """ Square-root Transformation """ - transform = staticmethod(np.sqrt) # type: ignore - inverse = staticmethod(np.square) # type: ignore - domain = (0, np.inf) + domain: DomainType = (0, np.inf) + + def transform(self, x): + return np.sqrt(x) + def inverse(self, x): + return np.square(x) + +@dataclass(kw_only=True) class asn_trans(trans): """ Arc-sin square-root Transformation """ - @staticmethod - def transform(x: FloatArrayLike) -> NDArrayFloat: + transform_is_linear: bool = True + + def transform(self, x: FloatArrayLike) -> NDArrayFloat: return 2 * np.arcsin(np.sqrt(x)) # type: ignore - @staticmethod - def inverse(x: FloatArrayLike) -> NDArrayFloat: + def inverse(self, x: FloatArrayLike) -> NDArrayFloat: x = np.asarray(x) return np.sin(x / 2) ** 2 # type: ignore +@dataclass(kw_only=True) class atanh_trans(trans): """ Arc-tangent Transformation """ - transform = staticmethod(np.arctanh) # type: ignore - inverse = staticmethod(np.tanh) # type: ignore + transform_is_linear: bool = True + + def transform(self, x): + return np.arctanh(x) + def inverse(self, x): + return np.tanh(x) -def boxcox_trans(p, offset=0, **kwargs): + +@dataclass +class boxcox_trans(trans): r""" Boxcox Transformation @@ -556,34 +563,31 @@ def boxcox_trans(p, offset=0, **kwargs): """ - def transform(x: FloatArrayLike) -> NDArrayFloat: + p: float + offset: int = 0 + + def transform(self, x: FloatArrayLike) -> NDArrayFloat: x = np.asarray(x) - if np.any((x + offset) < 0): + if np.any((x + self.offset) < 0): raise ValueError( "boxcox_trans must be given only positive values. " "Consider using modulus_trans instead?" ) - if np.abs(p) < 1e-7: - return np.log(x + offset) + if np.abs(self.p) < 1e-7: + return np.log(x + self.offset) else: - return ((x + offset) ** p - 1) / p + return ((x + self.offset) ** self.p - 1) / self.p - def inverse(x: FloatArrayLike) -> NDArrayFloat: + def inverse(self, x: FloatArrayLike) -> NDArrayFloat: x = np.asarray(x) - if np.abs(p) < 1e-7: - return np.exp(x) - offset # type: ignore + if np.abs(self.p) < 1e-7: + return np.exp(x) - self.offset # type: ignore else: - return (x * p + 1) ** (1 / p) - offset - - kwargs["p"] = p - kwargs["offset"] = offset - kwargs["name"] = kwargs.get("name", "pow_{}".format(p)) - kwargs["transform"] = transform - kwargs["inverse"] = inverse - return trans_new(**kwargs) + return (x * self.p + 1) ** (1 / self.p) - self.offset -def modulus_trans(p, offset=1, **kwargs): +@dataclass +class modulus_trans(trans): r""" Modulus Transformation @@ -628,35 +632,31 @@ def modulus_trans(p, offset=1, **kwargs): -------- :func:`~mizani.transforms.boxcox_trans` """ - if np.abs(p) < 1e-7: - def transform(x: FloatArrayLike) -> NDArrayFloat: - x = np.asarray(x) - return np.sign(x) * np.log(np.abs(x) + offset) + p: float + offset: int = 1 - def inverse(x: FloatArrayLike) -> NDArrayFloat: - x = np.asarray(x) - return np.sign(x) * (np.exp(np.abs(x)) - offset) # type: ignore - - else: + def transform(self, x: FloatArrayLike) -> NDArrayFloat: + x = np.asarray(x) + p, offset = self.p, self.offset - def transform(x: FloatArrayLike) -> NDArrayFloat: - x = np.asarray(x) + if np.abs(self.p) < 1e-7: + return np.sign(x) * np.log(np.abs(x) + offset) + else: return np.sign(x) * ((np.abs(x) + offset) ** p - 1) / p - def inverse(x: FloatArrayLike) -> NDArrayFloat: - x = np.asarray(x) - return np.sign(x) * ((np.abs(x) * p + 1) ** (1 / p) - offset) + def inverse(self, x: FloatArrayLike) -> NDArrayFloat: + x = np.asarray(x) + p, offset = self.p, self.offset - kwargs["p"] = p - kwargs["offset"] = offset - kwargs["name"] = kwargs.get("name", "mt_pow_{}".format(p)) - kwargs["transform"] = transform - kwargs["inverse"] = inverse - return trans_new(**kwargs) + if np.abs(self.p) < 1e-7: + return np.sign(x) * (np.exp(np.abs(x)) - offset) # type: ignore + else: + return np.sign(x) * ((np.abs(x) * p + 1) ** (1 / p) - offset) -def probability_trans(distribution: str, *args, **kwargs) -> trans: +@dataclass +class probability_trans(trans): """ Probability Transformation @@ -678,39 +678,44 @@ def probability_trans(distribution: str, *args, **kwargs) -> trans: computations may run into errors. Absence of any errors does not imply that the distribution fits the data. """ - import scipy.stats as stats - cdists = {k for k in dir(stats) if hasattr(getattr(stats, k), "cdf")} - if distribution not in cdists: - raise ValueError(f"Unknown distribution '{distribution}'") + def __init__(self, distribution: str, *args, **kwargs): + import scipy.stats as stats - try: - doc = kwargs.pop("_doc") - except KeyError: - doc = "" + cdists = {k for k in dir(stats) if hasattr(getattr(stats, k), "cdf")} + if distribution not in cdists: + raise ValueError(f"Unknown distribution '{distribution}'") - try: - name = kwargs.pop("_name") - except KeyError: - name = "prob_{}".format(distribution) + self._dist = getattr(stats, distribution) + self._args = args + self._kwargs = kwargs - def transform(x: FloatArrayLike) -> NDArrayFloat: - return getattr(stats, distribution).cdf(x, *args, **kwargs) + def transform(self, x: FloatArrayLike) -> NDArrayFloat: + return self._dist.cdf(x, *self._args, **self._kwargs) - def inverse(x: FloatArrayLike) -> NDArrayFloat: - return getattr(stats, distribution).ppf(x, *args, **kwargs) + def inverse(self, x: FloatArrayLike) -> NDArrayFloat: + return self._dist.ppf(x, *self._args, **self._kwargs) - return trans_new(name, transform, inverse, domain=(0, 1), doc=doc) +class logit_trans(probability_trans): + """ + Logit Transformation + """ -logit_trans = probability_trans( - "logistic", _name="logit", _doc="Logit Transformation" -) -probit_trans = probability_trans( - "norm", _name="norm", _doc="Probit Transformation" -) + def __init__(self): + super().__init__("logistic") + + +class probit_trans(probability_trans): + """ + Probit Transformation + """ + + def __init__(self): + super().__init__("norm") +@dataclass class datetime_trans(trans): """ Datetime Transformation @@ -738,20 +743,19 @@ class datetime_trans(trans): 'EST' """ - domain = ( + tz: tzinfo | str | None = None + + _: KW_ONLY + domain: DomainType = ( datetime(MINYEAR, 1, 1, tzinfo=UTC), datetime(MAXYEAR, 12, 31, tzinfo=UTC), ) - breaks_ = staticmethod(breaks_date()) - format = staticmethod(label_date()) - tz = None + breaks_func: BreaksFunction = field(default_factory=breaks_date) + format_func: FormatFunction = field(default_factory=label_date) - def __init__(self, tz=None, **kwargs): - if isinstance(tz, str): - tz = ZoneInfo(tz) - - super().__init__(**kwargs) - self.tz = tz + def __post_init__(self): + if isinstance(self.tz, str): + self.tz = ZoneInfo(self.tz) def transform(self, x: DatetimeArrayLike) -> NDArrayFloat: # pyright: ignore[reportIncompatibleMethodOverride] """ @@ -794,17 +798,17 @@ def diff_type_to_num(self, x: TimedeltaArrayLike) -> FloatArrayLike: return timedelta_to_num(x) +@dataclass(kw_only=True) class timedelta_trans(trans): """ Timedelta Transformation """ - domain = (timedelta.min, timedelta.max) - breaks_ = staticmethod(breaks_timedelta()) - format = staticmethod(label_timedelta()) + domain: DomainType = (timedelta.min, timedelta.max) + breaks_func: BreaksFunction = field(default_factory=breaks_timedelta) + format_func: FormatFunction = field(default_factory=label_timedelta) - @staticmethod - def transform(x: TimedeltaArrayLike) -> NDArrayFloat: # pyright: ignore[reportIncompatibleMethodOverride] + def transform(self, x: TimedeltaArrayLike) -> NDArrayFloat: # pyright: ignore[reportIncompatibleMethodOverride] """ Transform from Timeddelta to numerical format @@ -812,8 +816,7 @@ def transform(x: TimedeltaArrayLike) -> NDArrayFloat: # pyright: ignore[reportI """ return timedelta_to_num(x) - @staticmethod - def inverse(x: FloatArrayLike) -> Sequence[pd.Timedelta]: # pyright: ignore[reportIncompatibleMethodOverride] + def inverse(self, x: FloatArrayLike) -> Sequence[pd.Timedelta]: # pyright: ignore[reportIncompatibleMethodOverride] """ Transform to Timedelta from numerical format """ @@ -828,12 +831,13 @@ def diff_type_to_num(self, x: TimedeltaArrayLike) -> FloatArrayLike: return timedelta_to_num(x) +@dataclass(kw_only=True) class pd_timedelta_trans(timedelta_trans): """ Pandas timedelta Transformation """ - domain = (pd.Timedelta.min, pd.Timedelta.max) + domain: DomainType = (pd.Timedelta.min, pd.Timedelta.max) class reciprocal_trans(trans): @@ -841,15 +845,14 @@ class reciprocal_trans(trans): Reciprocal Transformation """ - @staticmethod - def transform(x: FloatArrayLike) -> NDArrayFloat: + def transform(self, x: FloatArrayLike) -> NDArrayFloat: return 1 / np.asarray(x) - @staticmethod - def inverse(x: FloatArrayLike) -> NDArrayFloat: + def inverse(self, x: FloatArrayLike) -> NDArrayFloat: return 1 / np.asarray(x) +@dataclass class pseudo_log_trans(trans): """ Pseudo-log transformation @@ -870,19 +873,14 @@ class pseudo_log_trans(trans): the `transform` or `inverse`. """ - def __init__(self, sigma=1, base=None, **kwargs): - if base is None: - base = np.exp(1) - - self.sigma = sigma - self.base = base - super().__init__(**kwargs) + sigma: float = 1 + base: float = np.exp(1) - def transform(self, x: FloatArrayLike) -> NDArrayFloat: # pyright: ignore[reportIncompatibleMethodOverride] + def transform(self, x: FloatArrayLike) -> NDArrayFloat: x = np.asarray(x) return np.arcsinh(x / (2 * self.sigma)) / np.log(self.base) - def inverse(self, x: FloatArrayLike) -> NDArrayFloat: # pyright: ignore[reportIncompatibleMethodOverride] + def inverse(self, x: FloatArrayLike) -> NDArrayFloat: x = np.asarray(x) return 2 * self.sigma * np.sinh(x * np.log(self.base)) @@ -896,6 +894,7 @@ def minor_breaks( return super().minor_breaks(major, limits, n) +@dataclass(kw_only=True) class symlog_trans(trans): """ Symmetric Log Transformation @@ -911,14 +910,12 @@ class symlog_trans(trans): and negative values (including zero). """ - breaks_: BreaksFunction = breaks_symlog() + breaks_func: BreaksFunction = breaks_symlog() - @staticmethod - def transform(x: FloatArrayLike) -> NDArrayFloat: + def transform(self, x: FloatArrayLike) -> NDArrayFloat: return np.sign(x) * np.log1p(np.abs(x)) - @staticmethod - def inverse(x: FloatArrayLike) -> NDArrayFloat: + def inverse(self, x: FloatArrayLike) -> NDArrayFloat: return np.sign(x) * (np.exp(np.abs(x)) - 1) # type: ignore diff --git a/mizani/typing.py b/mizani/typing.py index 4aaf862..2c2d78a 100644 --- a/mizani/typing.py +++ b/mizani/typing.py @@ -190,6 +190,8 @@ def __gt__(self, other, /) -> bool: ... DomainType: TypeAlias = tuple[PComparison, PComparison] + TFloatTimedelta = TypeVar("TFloatTimedelta", float, timedelta) + # This does not work probably due to a bug in the typechecker # FormatFunction: TypeAlias = Callable[[AnyArrayLike], Sequence[str]] FormatFunction: TypeAlias = ( diff --git a/tests/test_transforms.py b/tests/test_transforms.py index b0bc766..1c9c81e 100644 --- a/tests/test_transforms.py +++ b/tests/test_transforms.py @@ -1,5 +1,5 @@ from datetime import datetime, timedelta -from types import FunctionType, MethodType +from types import FunctionType from zoneinfo import ZoneInfo import numpy as np @@ -41,29 +41,26 @@ def test_trans(): with pytest.raises(TypeError): trans() # type: ignore - with pytest.raises(AttributeError): - identity_trans(universe=True) - def test_trans_new(): t = trans_new( "bounded_identity", - staticmethod(lambda x: x), - classmethod(lambda x: x), - _format=lambda x: str(x), - domain=(-999, 999), + transform=staticmethod(lambda x: x), + inverse=staticmethod(lambda x: x), doc="Bounded Identity transform", ) assert t.__name__ == "bounded_identity_trans" assert isinstance(t.transform, FunctionType) - assert isinstance(t.inverse, MethodType) - assert isinstance(t.format, FunctionType) - assert t.domain == (-999, 999) + assert isinstance(t.inverse, FunctionType) assert t.__doc__ == "Bounded Identity transform" - # ticks do not go beyond the bounds - major = t().breaks((-1999, 1999)) + # ticks do not go beyond the domain bounds + # major = t().breaks((-1999, 1999)) + # + major = t(format_func=lambda x: str(x), domain=(-999, 999)).breaks( + (-1999, 1999) + ) assert min(major) >= -999 assert max(major) <= 999 @@ -76,9 +73,6 @@ def test_gettrans(): t4 = gettrans() assert all(isinstance(x, identity_trans) for x in (t0, t1, t2, t3, t4)) - t = gettrans(exp_trans) - assert t.__class__.__name__ == "power_e_trans" - with pytest.raises(ValueError): gettrans(object) @@ -188,8 +182,8 @@ def test_logn_trans(): log4_trans = log_trans( 4, domain=(0.1, 100), - breaks=breaks_extended(), - minor_breaks=minor_breaks(), + breaks_func=breaks_extended(), + minor_breaks_func=minor_breaks(), ) _test_trans(log4_trans, arr)