Skip to content

Commit

Permalink
Fix param.pprint for Array parameters by replacing all_equal wi…
Browse files Browse the repository at this point in the history
…th `Comparator.is_equal` in `values()` (#795)

Co-authored-by: James A. Bednar <[email protected]>
  • Loading branch information
maximlt and jbednar authored Jul 21, 2023
1 parent a75c54d commit bcf9a74
Show file tree
Hide file tree
Showing 5 changed files with 38 additions and 4 deletions.
5 changes: 5 additions & 0 deletions param/_utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import inspect
import functools
import re
import warnings

from textwrap import dedent
Expand Down Expand Up @@ -109,3 +110,7 @@ def wrapper(self, *args, **kwargs):
return wrapper

return decorating_function


def _is_auto_name(class_name, instance_name):
return re.match('^'+class_name+'[0-9]{5}$', instance_name)
19 changes: 15 additions & 4 deletions param/parameterized.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
from ._utils import (
_deprecated,
_deprecate_positional_args,
_is_auto_name,
_recursive_repr,
ParamDeprecationWarning as _ParamDeprecationWarning,
)
Expand Down Expand Up @@ -286,6 +287,8 @@ def get_occupied_slots(instance):
if hasattr(instance,slot)]


# PARAM3_DEPRECATION
@_deprecated()
def all_equal(arg1,arg2):
"""
Return a single boolean for arg1==arg2, even for numpy arrays
Expand Down Expand Up @@ -1615,15 +1618,21 @@ class Comparator:
str: operator.eq,
bytes: operator.eq,
type(None): operator.eq,
lambda o: hasattr(o, '_infinitely_iterable'): operator.eq, # Time
}
equalities.update({dtt: operator.eq for dtt in dt_types})

@classmethod
def is_equal(cls, obj1, obj2):
for eq_type, eq in cls.equalities.items():
if ((isinstance(eq_type, FunctionType)
and eq_type(obj1) and eq_type(obj2))
or (isinstance(obj1, eq_type) and isinstance(obj2, eq_type))):
try:
are_instances = isinstance(obj1, eq_type) and isinstance(obj2, eq_type)
except TypeError:
pass
else:
if are_instances:
return eq(obj1, obj2)
if isinstance(eq_type, FunctionType) and eq_type(obj1) and eq_type(obj2):
return eq(obj1, obj2)
if isinstance(obj2, (list, set, tuple)):
return cls.compare_iterator(obj1, obj2)
Expand Down Expand Up @@ -2408,7 +2417,9 @@ def values(self_, onlychanged=False):
vals = []
for name, val in self_or_cls.param.objects('existing').items():
value = self_or_cls.param.get_value_generator(name)
if not onlychanged or not all_equal(value, val.default):
if name == 'name' and onlychanged and _is_auto_name(self_.cls.__name__, value):
continue
if not onlychanged or not Comparator.is_equal(value, val.default):
vals.append((name, value))

vals.sort(key=itemgetter(0))
Expand Down
3 changes: 3 additions & 0 deletions tests/testdeprecations.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,9 @@ def test_deprecate_recursive_repr(self):
with pytest.raises(param._utils.ParamDeprecationWarning):
param.parameterized.recursive_repr(lambda: '')

def test_deprecate_all_equal(self):
with pytest.raises(param._utils.ParamDeprecationWarning):
param.parameterized.all_equal(1, 1)

class TestDeprecateParameters:

Expand Down
7 changes: 7 additions & 0 deletions tests/testnumpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,3 +73,10 @@ class Z(param.Parameterized):

z = Z(z=numpy.array([1,2]))
_is_array_and_equal(z.z,[1,2])

def test_array_pprint(self):
class MatParam(param.Parameterized):
mat = param.Array(numpy.zeros((2, 2)))

mp = MatParam()
mp.param.pprint()
8 changes: 8 additions & 0 deletions tests/testparameterizedobject.py
Original file line number Diff line number Diff line change
Expand Up @@ -391,6 +391,14 @@ def test_values(self):
assert 'inst' in TestPO.param.values()
assert 'notinst' in TestPO.param.values()

def test_values_name_ignored_for_instances_and_onlychanged(self):
default_inst = param.Parameterized()
assert 'Parameterized' in default_inst.name
# name ignored when automatically computed (behavior inherited from all_equal)
assert 'name' not in default_inst.param.values(onlychanged=True)
# name not ignored when set
assert param.Parameterized(name='foo').param.values(onlychanged=True)['name'] == 'foo'

def test_param_iterator(self):
self.assertEqual(set(TestPO.param), {'name', 'inst', 'notinst', 'const', 'dyn',
'ro', 'ro2', 'ro_label', 'ro_format'})
Expand Down

0 comments on commit bcf9a74

Please sign in to comment.