Skip to content

Commit

Permalink
Limit the use of class level attributes on trans
Browse files Browse the repository at this point in the history
  • Loading branch information
has2k1 committed Nov 5, 2017
1 parent 7b3b353 commit 2dff398
Show file tree
Hide file tree
Showing 5 changed files with 72 additions and 24 deletions.
6 changes: 6 additions & 0 deletions doc/changelog.rst
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,12 @@ v0.4.1
------
*not-yet-released*

- :class:`~mizani.transforms.trans` objects can now be instantiated
with parameter to override attributes of the instance. And the
default methods for computing breaks and minor breaks on the
transform instance are not class attributes, so they can be
modified without global repercussions.

v0.4.0
------
*(2017-10-24)*
Expand Down
8 changes: 4 additions & 4 deletions mizani/breaks.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,12 +241,12 @@ class trans_minor_breaks(object):
>>> from mizani.transforms import sqrt_trans
>>> major = [1, 2, 3, 4]
>>> limits = [0, 5]
>>> sqrt_trans.minor_breaks(major, limits)
>>> sqrt_trans().minor_breaks(major, limits)
array([ 0.5, 1.5, 2.5, 3.5, 4.5])
>>> class sqrt_trans2(sqrt_trans):
... pass
>>> sqrt_trans2.minor_breaks = trans_minor_breaks(sqrt_trans2)
>>> sqrt_trans2.minor_breaks(major, limits)
... def __init__(self):
... self.minor_breaks = trans_minor_breaks(sqrt_trans2)
>>> sqrt_trans2().minor_breaks(major, limits)
array([ 1.58113883, 2.54950976, 3.53553391])
"""
def __init__(self, trans, n=1):
Expand Down
17 changes: 9 additions & 8 deletions mizani/tests/test_breaks.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,31 +81,32 @@ def test_minor_breaks():

def test_trans_minor_breaks():
class identity_trans(trans):
pass

identity_trans.minor_breaks = trans_minor_breaks(identity_trans)
def __init__(self):
self.minor_breaks = trans_minor_breaks(identity_trans)

class square_trans(trans):
transform = staticmethod(np.square)
inverse = staticmethod(np.sqrt)

square_trans.minor_breaks = trans_minor_breaks(square_trans)
def __init__(self):
self.minor_breaks = trans_minor_breaks(square_trans)

class weird_trans(trans):
dataspace_is_numerical = False

weird_trans.minor_breaks = trans_minor_breaks(weird_trans)
def __init__(self):
self.minor_breaks = trans_minor_breaks(weird_trans)

major = [1, 2, 3, 4]
limits = [0, 5]
regular_minors = trans.minor_breaks(major, limits)
regular_minors = trans().minor_breaks(major, limits)
npt.assert_allclose(
regular_minors,
identity_trans.minor_breaks(major, limits))
identity_trans().minor_breaks(major, limits))

# Transform the input major breaks and check against
# the inverse of the output minor breaks
squared_input_minors = square_trans.minor_breaks(
squared_input_minors = square_trans().minor_breaks(
np.square(major), np.square(limits))
npt.assert_allclose(regular_minors,
np.sqrt(squared_input_minors))
Expand Down
8 changes: 7 additions & 1 deletion mizani/tests/test_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

from mizani.breaks import mpl_breaks, minor_breaks
from mizani.transforms import (
trans,
asn_trans, atanh_trans, boxcox_trans, datetime_trans,
exp_trans, identity_trans, log10_trans, log1p_trans,
log2_trans, log_trans, probability_trans, reverse_trans,
Expand All @@ -19,6 +20,11 @@
arr = np.arange(1, 100)


def test_trans():
with pytest.raises(KeyError):
trans(universe=True)


def test_trans_new():
t = trans_new('bounded_identity',
staticmethod(lambda x: x),
Expand All @@ -36,7 +42,7 @@ def test_trans_new():
assert t.__doc__ == 'Bounded Identity transform'

# ticks do not go beyond the bounds
major = t.breaks((-1999, 1999))
major = t().breaks((-1999, 1999))
assert min(major) >= -999
assert max(major) <= 999

Expand Down
57 changes: 46 additions & 11 deletions mizani/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,27 @@ class trans(object):
and :meth:`trans.inverse`. Alternately, you can quickly
create a transform class using the :func:`trans_new`
function.
Parameters
---------
kwargs : dict
Attributes of the class to set/override
Examples
--------
By default trans returns one minor break between every pair
of major break
>>> major = [0, 1, 2]
>>> t = trans()
>>> t.minor_breaks(major)
array([ 0.5, 1.5])
Create a trans that returns 4 minor breaks
>>> t = trans(minor_breaks=minor_breaks(4))
>>> t.minor_breaks(major)
array([ 0.2, 0.4, 0.6, 0.8, 1.2, 1.4, 1.6, 1.8])
"""
#: Aesthetic that the transform works on
aesthetic = None
Expand All @@ -66,15 +87,30 @@ class trans(object):
#: Limits of the transformed data
domain = (-np.inf, np.inf)

#: Function to calculate breaks
breaks_ = staticmethod(extended_breaks(n=5))
#: Callable to calculate breaks
breaks_ = None

#: Function to calculate minor_breaks
minor_breaks = staticmethod(minor_breaks(1))
#: Callable to calculate minor_breaks
minor_breaks = None

#: Function to format breaks
format = staticmethod(mpl_format())

def __init__(self, **kwargs):
for attr in kwargs:
if hasattr(self, attr):
setattr(self, attr, kwargs[attr])
else:
raise KeyError(
"Unknown Parameter {!r}".format(attr))

# Defaults
if 'breaks_' not in kwargs:
self.breaks_ = extended_breaks(n=5)

if 'minor_breaks' not in kwargs:
self.minor_breaks = minor_breaks(1)

@staticmethod
def transform(x):
"""
Expand All @@ -89,8 +125,7 @@ def inverse(x):
"""
return x

@classmethod
def breaks(cls, limits):
def breaks(self, limits):
"""
Calculate breaks in data space and return them
in transformed space.
Expand Down Expand Up @@ -120,15 +155,15 @@ def breaks(cls, limits):
"""
# clip the breaks to the domain,
# e.g. probabilities will be in [0, 1] domain
vmin = np.max([cls.domain[0], limits[0]])
vmax = np.min([cls.domain[1], limits[1]])
breaks = np.asarray(cls.breaks_([vmin, vmax]))
vmin = np.max([self.domain[0], limits[0]])
vmax = np.min([self.domain[1], limits[1]])
breaks = np.asarray(self.breaks_([vmin, vmax]))

# Some methods(mpl_breaks, extended_breaks) that
# calculate breaks take the limits as guide posts and
# not hard limits.
breaks = breaks.compress((breaks >= cls.domain[0]) &
(breaks <= cls.domain[1]))
breaks = breaks.compress((breaks >= self.domain[0]) &
(breaks <= self.domain[1]))
return breaks


Expand Down

0 comments on commit 2dff398

Please sign in to comment.