Skip to content

Commit

Permalink
Introduces static attribute Operation.unpacked_args_to_init (#73)
Browse files Browse the repository at this point in the history
* Introduces static attribute Operation.unpacked_args_to_init

Passing unpacked args is generally helpful if the class derived from
Operation has its constructor implemented as a dataclass.

* adds an example for initializing Operation sub-classes via dataclasses

* make 'dataclasses' a development dependency
  • Loading branch information
kaushikcfd authored Oct 29, 2021
1 parent 390b2d3 commit 45f6374
Show file tree
Hide file tree
Showing 3 changed files with 43 additions and 2 deletions.
9 changes: 8 additions & 1 deletion matchpy/expressions/expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -280,7 +280,10 @@ def __call__(cls, *operands: Expression, variable_name=None):
return operands[0]

operation = Expression.__new__(cls)
operation.__init__(operands, variable_name=variable_name)
if not cls.unpacked_args_to_init:
operation.__init__(operands, variable_name=variable_name)
else:
operation.__init__(*operands, variable_name=variable_name)

return operation

Expand Down Expand Up @@ -358,6 +361,10 @@ class Operation(Expression, metaclass=_OperationMeta):
infix = False
"""bool: True if the name of the operation should be used as an infix operator by str()."""


unpacked_args_to_init = False
"""bool: True if the class must be instantiated with ``*operands`` instead of ``operands``."""

def __init__(self, operands: List[Expression], variable_name=None) -> None:
"""Create an operation expression.
Expand Down
1 change: 1 addition & 0 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ develop = %(docs)s
pytest-cov>=2.4,<3.0
coveralls
flake8
dataclasses; python_version<'3.7'

[flake8]
max-line-length = 120
Expand Down
35 changes: 34 additions & 1 deletion tests/test_expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,11 @@
import pytest
from multiset import Multiset

from matchpy.expressions.expressions import (Arity, Operation, Symbol, SymbolWildcard, Wildcard, Expression)
from matchpy.expressions.expressions import (Arity, Operation, Symbol, SymbolWildcard, Wildcard, Expression, Pattern)
from matchpy import match
from dataclasses import dataclass, field, fields
from typing import ClassVar, Optional

from .common import *

SIMPLE_EXPRESSIONS = [
Expand Down Expand Up @@ -349,3 +353,32 @@ def test_one_identity_error(self):
def test_infix_error(self):
with pytest.raises(TypeError):
Operation.new('Invalid', Arity.unary, infix=True)


class AbstractDataclassOp(Operation):
@property
def operands(self):
return tuple(getattr(self, field.name)
for field in fields(self)
if not field.metadata.get("not_an_operand", False))


@dataclass
class MySum(AbstractDataclassOp):
x1: Expression
x2: Expression
arity: ClassVar[Arity] = Arity.binary
variable_name: Optional[str] = field(default=None, metadata={"not_an_operand": True})
unpacked_args_to_init: ClassVar[bool] = True



def test_dataclass_operation_subclass():
x1_ = Wildcard.dot("x1")
x2_ = Wildcard.dot("x2")

matches = match(MySum(Symbol("foo"), Symbol("bar")), Pattern(MySum(x1_, x2_)))
subst, = list(matches)

assert subst["x1"].name == "foo"
assert subst["x2"].name == "bar"

0 comments on commit 45f6374

Please sign in to comment.