From dc9a34c49942afe487079dabf0e41b80d72f7b5e Mon Sep 17 00:00:00 2001 From: Hassan Kibirige Date: Thu, 24 Oct 2024 01:09:42 +0300 Subject: [PATCH] Improve code for datetime breaks --- doc/changelog.rst | 6 ++++++ mizani/_core/date_utils.py | 25 ++++++++++++++++++------- mizani/_core/dates.py | 14 +++++++++++--- mizani/breaks.py | 6 ++++-- mizani/utils.py | 20 ++++++-------------- tests/test_bounds.py | 4 ++-- tests/test_breaks.py | 6 +++++- tests/test_date_utils.py | 9 +++++++++ tests/test_utils.py | 4 ++-- 9 files changed, 63 insertions(+), 31 deletions(-) diff --git a/doc/changelog.rst b/doc/changelog.rst index c1e7826..5c3c727 100644 --- a/doc/changelog.rst +++ b/doc/changelog.rst @@ -13,6 +13,12 @@ API Changes - `mizani.transforms.trans_new` function has been deprecated. +Enhancements +************ + +- `~mizani.breaks.breaks_date` has been slightly improved for the case when it + generates monthly breaks. + New *** diff --git a/mizani/_core/date_utils.py b/mizani/_core/date_utils.py index cf9064c..299491f 100644 --- a/mizani/_core/date_utils.py +++ b/mizani/_core/date_utils.py @@ -34,12 +34,6 @@ class Interval: end: datetime def __post_init__(self): - if isinstance(self.start, date): - self.start = datetime.fromisoformat(self.start.isoformat()) - - if isinstance(self.end, date): - self.end = datetime.fromisoformat(self.end.isoformat()) - self._delta = relativedelta(self.end, self.start) self._tdelta = self.end - self.start @@ -149,7 +143,7 @@ def limits_year(self) -> tuple[datetime, datetime]: return floor_year(self.start), ceil_year(self.end) def limits_month(self) -> tuple[datetime, datetime]: - return round_month(self.start), round_month(self.end) + return floor_month(self.start), ceil_month(self.end) def limits_week(self) -> tuple[datetime, datetime]: return floor_week(self.start), ceil_week(self.end) @@ -481,3 +475,20 @@ def expand_datetime_limits( end = end.replace(y2) return start, end + + +def as_datetime( + tup: tuple[datetime, datetime] | tuple[date, date], +) -> tuple[datetime, datetime]: + """ + Ensure that a tuple of datetime values + """ + l, h = tup + + if not isinstance(l, datetime): + l = datetime.fromisoformat(l.isoformat()) + + if not isinstance(h, datetime): + h = datetime.fromisoformat(h.isoformat()) + + return l, h diff --git a/mizani/_core/dates.py b/mizani/_core/dates.py index 998b4a9..1f3f5a2 100644 --- a/mizani/_core/dates.py +++ b/mizani/_core/dates.py @@ -11,7 +11,12 @@ from dateutil.rrule import rrule from ..utils import get_timezone, isclose_abs -from .date_utils import Interval, align_limits, expand_datetime_limits +from .date_utils import ( + Interval, + align_limits, + as_datetime, + expand_datetime_limits, +) from .types import DateFrequency, date_breaks_info if TYPE_CHECKING: @@ -316,10 +321,13 @@ def calculate_date_breaks_info( return res -def calculate_date_breaks_auto(limits, n: int = 5) -> Sequence[datetime]: +def calculate_date_breaks_auto( + limits: tuple[datetime, datetime], n: int = 5 +) -> Sequence[datetime]: """ Calcuate date breaks using appropriate units """ + limits = as_datetime(limits) info = calculate_date_breaks_info(limits, n=n) lookup = { DF.YEARLY: yearly_breaks, @@ -334,7 +342,7 @@ def calculate_date_breaks_auto(limits, n: int = 5) -> Sequence[datetime]: def calculate_date_breaks_byunits( - limits, + limits: tuple[datetime, datetime], units: DatetimeBreaksUnits, width: int, max_breaks: int | None = None, diff --git a/mizani/breaks.py b/mizani/breaks.py index 2446a39..eadf85b 100644 --- a/mizani/breaks.py +++ b/mizani/breaks.py @@ -15,7 +15,7 @@ import sys from dataclasses import KW_ONLY, dataclass, field -from datetime import datetime, timedelta +from datetime import date, datetime, timedelta from itertools import product from typing import TYPE_CHECKING from warnings import warn @@ -23,6 +23,7 @@ import numpy as np import pandas as pd +from mizani._core.date_utils import as_datetime from mizani._core.dates import ( calculate_date_breaks_auto, calculate_date_breaks_byunits, @@ -460,7 +461,7 @@ def __post_init__(self): self._units = units.rstrip("s") # type: ignore def __call__( - self, limits: tuple[datetime, datetime] + self, limits: tuple[datetime, datetime] | tuple[date, date] ) -> Sequence[datetime]: """ Compute breaks @@ -483,6 +484,7 @@ def __call__( ): limits = limits[0].astype(object), limits[1].astype(object) + limits = as_datetime(limits) if self._units and self._width: return calculate_date_breaks_byunits( limits, self._units, self._width diff --git a/mizani/utils.py b/mizani/utils.py index 4409bab..01d5806 100644 --- a/mizani/utils.py +++ b/mizani/utils.py @@ -2,8 +2,8 @@ import math import sys -from datetime import datetime, timezone -from typing import TYPE_CHECKING, cast, overload +from datetime import datetime +from typing import TYPE_CHECKING, overload from warnings import warn import numpy as np @@ -22,7 +22,6 @@ NDArrayFloat, NullType, NumericUFunction, - SeqDatetime, ) T = TypeVar("T") @@ -327,28 +326,21 @@ def log(x, base): return res -def get_timezone(x: SeqDatetime) -> tzinfo | None: +def get_timezone(x: Sequence[datetime]) -> tzinfo | None: """ Return a single timezone for the sequence of datetimes Returns the timezone of first item and warns if any other items have a different timezone """ - # Ref: https://en.wikipedia.org/wiki/List_of_tz_database_time_zones - x0 = next(iter(x)) - if not isinstance(x0, datetime): + if not len(x) or x[0].tzinfo is None: return None - x = cast(list[datetime], x) - info = x0.tzinfo - if info is None: - return timezone.utc - # Consistency check - tzname0 = info.tzname(x0) + info = x[0].tzinfo + tzname0 = info.tzname(x[0]) tznames = (dt.tzinfo.tzname(dt) if dt.tzinfo else None for dt in x) - if any(tzname0 != name for name in tznames): msg = ( "Dates in column have different time zones. " diff --git a/tests/test_bounds.py b/tests/test_bounds.py index 94b88cd..d907e21 100644 --- a/tests/test_bounds.py +++ b/tests/test_bounds.py @@ -255,7 +255,7 @@ def test_squish_infinite(): squish_infinite(a, (-100, 100)), [-100, 100, -100, 100] ) - b = np.array([5, -np.inf, 2, 3, 6]) + b = pd.Series([5, -np.inf, 2, 3, 6]) npt.assert_allclose(squish_infinite(b, (1, 10)), [5, 1, 2, 3, 6]) @@ -270,7 +270,7 @@ def test_squish(): b = np.array([5, 0, -2, 3, 10]) npt.assert_allclose(squish(b, (0, 5)), [5, 0, 0, 3, 5]) - c = np.array([5, -np.inf, 2, 3, 6]) + c = pd.Series([5, -np.inf, 2, 3, 6]) npt.assert_allclose(squish(c, (1, 10), only_finite=False), [5, 1, 2, 3, 6]) npt.assert_allclose(squish(c, (1, 10)), c) diff --git a/tests/test_breaks.py b/tests/test_breaks.py index a6d3080..021a3e9 100644 --- a/tests/test_breaks.py +++ b/tests/test_breaks.py @@ -210,7 +210,7 @@ def test_breaks_date(): # automatic monthly breaks with rounding limits = (datetime(2019, 12, 27), datetime(2020, 6, 3)) breaks = breaks_date()(limits) - assert [dt.month for dt in breaks] == [1, 3, 5] + assert [dt.month for dt in breaks] == [12, 2, 4, 6] # automatic day breaks limits = (datetime(2020, 1, 1), datetime(2020, 1, 15)) @@ -246,6 +246,10 @@ def test_breaks_date(): breaks = breaks_date()(limits) assert breaks[0].tzinfo == UG + # date + limits = (date(2000, 4, 23), date(2000, 6, 15)) + breaks = breaks_date()(limits) + # Special cases limits = (datetime(2039, 12, 17), datetime(2045, 12, 16)) breaks = breaks_date()(limits) diff --git a/tests/test_date_utils.py b/tests/test_date_utils.py index aa0207b..b9c6c35 100644 --- a/tests/test_date_utils.py +++ b/tests/test_date_utils.py @@ -3,6 +3,7 @@ from mizani._core.date_utils import ( ceil_month, expand_datetime_limits, + round_month, shift_limits_down, ) @@ -31,3 +32,11 @@ def test_ceil_month(): d = datetime(2020, 1, 1) assert ceil_month(d) == d + + +def test_round_month(): + d = datetime(2000, 4, 23) + assert round_month(d) == datetime(2000, 5, 1) + + d = datetime(2000, 4, 14) + assert round_month(d) == datetime(2000, 4, 1) diff --git a/tests/test_utils.py b/tests/test_utils.py index e47eed7..6c24379 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -1,4 +1,4 @@ -from datetime import date, datetime +from datetime import datetime from zoneinfo import ZoneInfo import pandas as pd @@ -133,7 +133,7 @@ def test_get_timezone(): UTC = ZoneInfo("UTC") UG = ZoneInfo("Africa/Kampala") - x = [date(2022, 1, 1), date(2022, 12, 1)] + x = [datetime(2022, 1, 1), datetime(2022, 12, 1)] assert get_timezone(x) is None x = [datetime(2022, 1, 1, tzinfo=UTC), datetime(2022, 12, 1, tzinfo=UG)]