diff --git a/matchpy/expressions/expressions.py b/matchpy/expressions/expressions.py index 17d22cf..f960108 100644 --- a/matchpy/expressions/expressions.py +++ b/matchpy/expressions/expressions.py @@ -280,7 +280,10 @@ def __call__(cls, *operands: Expression, variable_name=None): return operands[0] operation = Expression.__new__(cls) - operation.__init__(operands, variable_name=variable_name) + if not cls.unpacked_args_to_init: + operation.__init__(operands, variable_name=variable_name) + else: + operation.__init__(*operands, variable_name=variable_name) return operation @@ -358,6 +361,10 @@ class Operation(Expression, metaclass=_OperationMeta): infix = False """bool: True if the name of the operation should be used as an infix operator by str().""" + + unpacked_args_to_init = False + """bool: True if the class must be instantiated with ``*operands`` instead of ``operands``.""" + def __init__(self, operands: List[Expression], variable_name=None) -> None: """Create an operation expression. diff --git a/setup.cfg b/setup.cfg index e02980a..bac6ced 100644 --- a/setup.cfg +++ b/setup.cfg @@ -52,6 +52,7 @@ develop = %(docs)s pytest-cov>=2.4,<3.0 coveralls flake8 + dataclasses; python_version<'3.7' [flake8] max-line-length = 120 diff --git a/tests/test_expressions.py b/tests/test_expressions.py index 0849812..aa9879a 100644 --- a/tests/test_expressions.py +++ b/tests/test_expressions.py @@ -5,7 +5,11 @@ import pytest from multiset import Multiset -from matchpy.expressions.expressions import (Arity, Operation, Symbol, SymbolWildcard, Wildcard, Expression) +from matchpy.expressions.expressions import (Arity, Operation, Symbol, SymbolWildcard, Wildcard, Expression, Pattern) +from matchpy import match +from dataclasses import dataclass, field, fields +from typing import ClassVar, Optional + from .common import * SIMPLE_EXPRESSIONS = [ @@ -349,3 +353,32 @@ def test_one_identity_error(self): def test_infix_error(self): with pytest.raises(TypeError): Operation.new('Invalid', Arity.unary, infix=True) + + +class AbstractDataclassOp(Operation): + @property + def operands(self): + return tuple(getattr(self, field.name) + for field in fields(self) + if not field.metadata.get("not_an_operand", False)) + + +@dataclass +class MySum(AbstractDataclassOp): + x1: Expression + x2: Expression + arity: ClassVar[Arity] = Arity.binary + variable_name: Optional[str] = field(default=None, metadata={"not_an_operand": True}) + unpacked_args_to_init: ClassVar[bool] = True + + + +def test_dataclass_operation_subclass(): + x1_ = Wildcard.dot("x1") + x2_ = Wildcard.dot("x2") + + matches = match(MySum(Symbol("foo"), Symbol("bar")), Pattern(MySum(x1_, x2_))) + subst, = list(matches) + + assert subst["x1"].name == "foo" + assert subst["x2"].name == "bar"