diff --git a/docs/source/api_doc/tree/integration.rst b/docs/source/api_doc/tree/integration.rst index 1e8bebf6af..72afb9f89d 100644 --- a/docs/source/api_doc/tree/integration.rst +++ b/docs/source/api_doc/tree/integration.rst @@ -27,3 +27,36 @@ register_treevalue_class .. autofunction:: register_treevalue_class + +.. _apidoc_tree_integration_register_integrate_container: + +register_integrate_container +-------------------------------- + +.. autofunction:: register_integrate_container + + +.. _apidoc_tree_integration_generic_flatten: + +generic_flatten +-------------------------------- + +.. autofunction:: generic_flatten + + +.. _apidoc_tree_integration_generic_unflatten: + +generic_unflatten +-------------------------------- + +.. autofunction:: generic_unflatten + + +.. _apidoc_tree_integration_generic_mapping: + +generic_mapping +-------------------------------- + +.. autofunction:: generic_mapping + + diff --git a/test/tree/integration/test_general.py b/test/tree/integration/test_general.py new file mode 100644 index 0000000000..7c959b82f7 --- /dev/null +++ b/test/tree/integration/test_general.py @@ -0,0 +1,95 @@ +from collections import namedtuple + +import pytest +from easydict import EasyDict + +from treevalue import generic_flatten, generic_unflatten, FastTreeValue, register_integrate_container, generic_mapping + +nt = namedtuple('nt', ['a', 'b']) + + +class MyTreeValue(FastTreeValue): + pass + + +@pytest.mark.unittest +class TestTreeIntegrationGeneral: + def test_general_flatten_and_unflatten(self): + demo_data = { + 'a': 1, + 'b': [2, 3, 'f'], + 'c': (2, 5, 'ds', EasyDict({ + 'x': None, + 'z': [34, '1.2'], + })), + 'd': nt('f', 100), + 'e': MyTreeValue({'x': 1, 'y': 'dsfljk'}) + } + v, spec = generic_flatten(demo_data) + assert v == [1, [2, 3, 'f'], [2, 5, 'ds', [None, [34, '1.2']]], ['f', 100], [1, 'dsfljk']] + + rv = generic_unflatten(v, spec) + assert rv == demo_data + assert isinstance(rv['c'][-1], EasyDict) + assert isinstance(rv['d'], nt) + assert 
isinstance(rv['c'][-1]['z'], list) + assert isinstance(rv['e'], MyTreeValue) + + def test_register_my_class(self): + class MyDC: + def __init__(self, x, y): + self.x = x + self.y = y + + def __eq__(self, other): + return isinstance(other, MyDC) and self.x == other.x and self.y == other.y + + def _mydc_flatten(v): + return [v.x, v.y], MyDC + + def _mydc_unflatten(v, spec): + return spec(*v) + + register_integrate_container(MyDC, _mydc_flatten, _mydc_unflatten) + + demo_data = { + 'a': 1, + 'b': [2, 3, 'f'], + 'c': (2, 5, 'ds', EasyDict({ + 'x': None, + 'z': MyDC(34, '1.2'), + })), + 'd': nt('f', 100), + 'e': MyTreeValue({'x': 1, 'y': 'dsfljk'}) + } + v, spec = generic_flatten(demo_data) + assert v == [1, [2, 3, 'f'], [2, 5, 'ds', [None, [34, '1.2']]], ['f', 100], [1, 'dsfljk']] + + rv = generic_unflatten(v, spec) + assert rv == demo_data + assert isinstance(rv['c'][-1], EasyDict) + assert isinstance(rv['d'], nt) + assert isinstance(rv['c'][-1]['z'], MyDC) + assert isinstance(rv['e'], MyTreeValue) + + def test_generic_mapping(self): + demo_data = { + 'a': 1, + 'b': [2, 3, 'f'], + 'c': (2, 5, 'ds', EasyDict({ + 'x': None, + 'z': (34, '1.2'), + })), + 'd': nt('f', 100), + 'e': MyTreeValue({'x': 1, 'y': 'dsfljk'}) + } + assert generic_mapping(demo_data, str) == { + 'a': '1', + 'b': ['2', '3', 'f'], + 'c': ('2', '5', 'ds', EasyDict({ + 'x': 'None', + 'z': ('34', '1.2'), + })), + 'd': nt('f', '100'), + 'e': MyTreeValue({'x': '1', 'y': 'dsfljk'}) + } diff --git a/treevalue/tree/integration/__init__.py b/treevalue/tree/integration/__init__.py index 7bf1079447..7899b7242e 100644 --- a/treevalue/tree/integration/__init__.py +++ b/treevalue/tree/integration/__init__.py @@ -1,5 +1,6 @@ from typing import Type +from .general import generic_flatten, generic_unflatten, register_integrate_container, generic_mapping from .jax import register_for_jax from .torch import register_for_torch from ..tree import TreeValue diff --git a/treevalue/tree/integration/base.pyx 
b/treevalue/tree/integration/base.pyx index 43bb761eb6..8ac5f40674 100644 --- a/treevalue/tree/integration/base.pyx +++ b/treevalue/tree/integration/base.pyx @@ -14,7 +14,6 @@ cdef inline tuple _c_flatten_for_integration(object tv): values.append(value) return values, (type(tv), paths) - pass cdef inline object _c_unflatten_for_integration(object values, tuple spec): cdef object type_ diff --git a/treevalue/tree/integration/general.pxd b/treevalue/tree/integration/general.pxd new file mode 100644 index 0000000000..dfb18b5c44 --- /dev/null +++ b/treevalue/tree/integration/general.pxd @@ -0,0 +1,27 @@ +# distutils:language=c++ +# cython:language_level=3 + +from libcpp cimport bool + +cdef tuple _dict_flatten(object d) +cdef object _dict_unflatten(list values, tuple spec) + +cdef tuple _list_and_tuple_flatten(object l) +cdef object _list_and_tuple_unflatten(list values, object spec) + +cdef tuple _namedtuple_flatten(object l) +cdef object _namedtuple_unflatten(list values, object spec) + +cdef tuple _treevalue_flatten(object l) +cdef object _treevalue_unflatten(list values, tuple spec) + +cdef bool _is_namedtuple_instance(pytree) except* + +cpdef void register_integrate_container(object type_, object flatten_func, object unflatten_func) except* + +cdef tuple _c_get_flatted_values_and_spec(object v) +cdef object _c_get_object_from_flatted(object values, object type_, object spec) + +cpdef object generic_flatten(object v) +cpdef object generic_unflatten(object v, tuple gspec) +cpdef object generic_mapping(object v, object func) diff --git a/treevalue/tree/integration/general.pyx b/treevalue/tree/integration/general.pyx new file mode 100644 index 0000000000..a9467c774c --- /dev/null +++ b/treevalue/tree/integration/general.pyx @@ -0,0 +1,284 @@ +# distutils:language=c++ +# cython:language_level=3 + +from collections import namedtuple + +import cython +from libcpp cimport bool + +from .base cimport _c_flatten_for_integration, _c_unflatten_for_integration +from ..tree.tree 
cimport TreeValue + +_REGISTERED_CONTAINERS = {} + +cdef inline tuple _dict_flatten(object d): + cdef list values = [] + cdef list keys = [] + + cdef object key, value + for key, value in d.items(): + keys.append(key) + values.append(value) + + return values, (type(d), keys) + +cdef inline object _dict_unflatten(list values, tuple spec): + cdef object type_ + cdef list keys + type_, keys = spec + + cdef dict retval = {} + for key, value in zip(keys, values): + retval[key] = value + + return type_(retval) + +cdef inline tuple _list_and_tuple_flatten(object l): + return list(l), type(l) + +cdef inline object _list_and_tuple_unflatten(list values, object spec): + return spec(values) + +cdef inline tuple _namedtuple_flatten(object l): + return list(l), type(l) + +cdef inline object _namedtuple_unflatten(list values, object spec): + return spec(*values) + +cdef inline tuple _treevalue_flatten(object l): + return _c_flatten_for_integration(l) + +cdef inline object _treevalue_unflatten(list values, tuple spec): + return _c_unflatten_for_integration(values, spec) + +cdef inline bool _is_namedtuple_instance(pytree) except*: + cdef object typ = type(pytree) + cdef tuple bases = typ.__bases__ + if len(bases) != 1 or bases[0] != tuple: + return False + + fields = getattr(typ, '_fields', None) + if not isinstance(fields, tuple): + return False # pragma: no cover + + return all(type(entry) == str for entry in fields) + +@cython.binding(True) +cpdef inline void register_integrate_container(object type_, object flatten_func, object unflatten_func) except*: + """ + Overview: + Register custom data class for generic flatten and unflatten. + + :param type_: Class of data to be registered. + :param flatten_func: Function for flattening. + :param unflatten_func: Function for unflattening. + + Examples:: + >>> from treevalue import register_integrate_container, generic_flatten, FastTreeValue, generic_unflatten + >>> + >>> class MyDC: + ... def __init__(self, x, y): + ... 
self.x = x + ... self.y = y + ... + ... def __eq__(self, other): + ... return isinstance(other, MyDC) and self.x == other.x and self.y == other.y + >>> + >>> def _mydc_flatten(v): + ... return [v.x, v.y], MyDC + >>> + >>> def _mydc_unflatten(v, spec): # spec will be MyDC + ... return spec(*v) + + >>> + >>> register_integrate_container(MyDC, _mydc_flatten, _mydc_unflatten) # register MyDC + >>> + >>> v, spec = generic_flatten({'a': MyDC(2, 3), 'b': MyDC((4, 5), FastTreeValue({'x': 1, 'y': 'f'}))}) + >>> v + [[2, 3], [[4, 5], [1, 'f']]] + >>> + >>> rt=generic_unflatten(v, spec) + >>> rt + {'a': <__main__.MyDC object at 0x7fbda613f9d0>, 'b': <__main__.MyDC object at 0x7fbda6148150>} + >>> rt['a'].x + 2 + >>> rt['a'].y + 3 + >>> rt['b'].x + (4, 5) + >>> rt['b'].y + <FastTreeValue 0x7fbda6146ed0> + ├── 'x' --> 1 + └── 'y' --> 'f' + """ + _REGISTERED_CONTAINERS[type_] = (flatten_func, unflatten_func) + +cdef inline tuple _c_get_flatted_values_and_spec(object v): + cdef list values + cdef object spec, type_ + cdef object flatten_func + if isinstance(v, dict): + values, spec = _dict_flatten(v) + type_ = dict + elif _is_namedtuple_instance(v): + values, spec = _namedtuple_flatten(v) + type_ = namedtuple + elif isinstance(v, (list, tuple)): + values, spec = _list_and_tuple_flatten(v) + type_ = list + elif isinstance(v, TreeValue): + values, spec = _treevalue_flatten(v) + type_ = TreeValue + elif type(v) in _REGISTERED_CONTAINERS: + flatten_func, _ = _REGISTERED_CONTAINERS[type(v)] + values, spec = flatten_func(v) + type_ = type(v) + else: + return v, None, None + + return values, type_, spec + +cdef inline object _c_get_object_from_flatted(object values, object type_, object spec): + cdef object unflatten_func + if type_ is dict: + return _dict_unflatten(values, spec) + elif type_ is namedtuple: + return _namedtuple_unflatten(values, spec) + elif type_ is list: + return _list_and_tuple_unflatten(values, spec) + elif type_ is TreeValue: + return _treevalue_unflatten(values, spec) + elif type_ in 
_REGISTERED_CONTAINERS: + _, unflatten_func = _REGISTERED_CONTAINERS[type_] + return unflatten_func(values, spec) + else: + raise TypeError(f'Unknown type for unflatten - {values!r}, {spec!r}.') # pragma: no cover + +@cython.binding(True) +cpdef inline object generic_flatten(object v): + """ + Overview: + Flatten generic data, including native objects, ``TreeValue``, namedtuples and custom classes \ + (see :func:`register_integrate_container`). + + :param v: Value to be flattened. + :return: Tuple of the flattened values and the spec needed by :func:`generic_unflatten`. + + Examples:: + >>> from collections import namedtuple + >>> from easydict import EasyDict + >>> from treevalue import FastTreeValue, generic_flatten, generic_unflatten + >>> + >>> class MyTreeValue(FastTreeValue): + ... pass + >>> + >>> nt = namedtuple('nt', ['a', 'b']) + >>> + >>> origin = { + ... 'a': 1, + ... 'b': (2, 3, 'f',), + ... 'c': (2, 5, 'ds', EasyDict({ # dict's child class + ... 'x': None, + ... 'z': [34, '1.2'], # list + ... })), + ... 'd': nt('f', 100), # namedtuple + ... 'e': MyTreeValue({'x': 1, 'y': 'dsfljk'}) # treevalue + ... } + >>> v, spec = generic_flatten(origin) + >>> v + [1, [2, 3, 'f'], [2, 5, 'ds', [None, [34, '1.2']]], ['f', 100], [1, 'dsfljk']] + >>> + >>> rv = generic_unflatten(v, spec) + >>> rv + {'a': 1, 'b': (2, 3, 'f'), 'c': (2, 5, 'ds', {'x': None, 'z': [34, '1.2']}), 'd': nt(a='f', b=100), 'e': <MyTreeValue 0x7f2a5d1c8e50> + ├── 'x' --> 1 + └── 'y' --> 'dsfljk' + } + >>> type(rv['c'][-1]) + <class 'easydict.EasyDict'> + """ + values, type_, spec = _c_get_flatted_values_and_spec(v) + if type_ is None: + return values, (None, None, None) + + cdef list child_values = [] + cdef list child_specs = [] + cdef object value, cval, cspec + for value in values: + cval, cspec = generic_flatten(value) + child_values.append(cval) + child_specs.append(cspec) + + return child_values, (type_, spec, child_specs) + +@cython.binding(True) +cpdef inline object generic_unflatten(object v, tuple gspec): + """ + Overview: + Inverse operation of :func:`generic_flatten`. + + :param v: Flattened values. 
+ :param gspec: Spec data of original object. + + Examples: + See :func:`generic_flatten`. + """ + cdef object type_, spec + cdef list child_specs + type_, spec, child_specs = gspec + if type_ is None: + return v + + cdef list values = [] + cdef object _i_value, _i_spec + for _i_value, _i_spec in zip(v, child_specs): + values.append(generic_unflatten(_i_value, _i_spec)) + + return _c_get_object_from_flatted(values, type_, spec) + +@cython.binding(True) +cpdef inline object generic_mapping(object v, object func): + """ + Overview: + Map all the values of generic data, including native objects, ``TreeValue``, namedtuples and custom classes \ + (see :func:`register_integrate_container`). + + :param v: Original value, nested structure is supported. + :param func: Function applied to each non-container value. + + Examples:: + >>> from collections import namedtuple + >>> from easydict import EasyDict + >>> from treevalue import FastTreeValue, generic_mapping + >>> + >>> class MyTreeValue(FastTreeValue): + ... pass + >>> + >>> nt = namedtuple('nt', ['a', 'b']) + >>> + >>> origin = { + ... 'a': 1, + ... 'b': (2, 3, 'f',), + ... 'c': (2, 5, 'ds', EasyDict({ # dict's child class + ... 'x': None, + ... 'z': [34, '1.2'], # list + ... })), + ... 'd': nt('f', 100), # namedtuple + ... 'e': MyTreeValue({'x': 1, 'y': 'dsfljk'}) # treevalue + ... } + >>> generic_mapping(origin, str) + {'a': '1', 'b': ('2', '3', 'f'), 'c': ('2', '5', 'ds', {'x': 'None', 'z': ['34', '1.2']}), 'd': nt(a='f', b='100'), 'e': <MyTreeValue 0x7f2a5d1c8e50> + ├── 'x' --> '1' + └── 'y' --> 'dsfljk' + } + """ + values, type_, spec = _c_get_flatted_values_and_spec(v) + if type_ is None: + return func(values) + + cdef list retvals = [] + cdef object value + for value in values: + retvals.append(generic_mapping(value, func)) + + return _c_get_object_from_flatted(retvals, type_, spec)