
Implemented Repeat and Unique Ops in PyTorch #890

Merged: 5 commits, Jul 12, 2024
pytensor/link/pytorch/dispatch/extra_ops.py (36 additions, 1 deletion)
@@ -1,7 +1,7 @@
 import torch
 
 from pytensor.link.pytorch.dispatch.basic import pytorch_funcify
-from pytensor.tensor.extra_ops import CumOp
+from pytensor.tensor.extra_ops import CumOp, Repeat, Unique
 
 
 @pytorch_funcify.register(CumOp)
@@ -21,3 +21,38 @@ def cumop(x):
        return torch.cumprod(x, dim=dim)

    return cumop


@pytorch_funcify.register(Repeat)
def pytorch_funcify_Repeat(op, **kwargs):
    axis = op.axis

    def repeat(x, repeats):
        return x.repeat_interleave(repeats, dim=axis)

    return repeat
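
For context (not part of the diff): this dispatch relies on torch.repeat_interleave matching np.repeat semantics when given per-slice counts along an axis. A minimal standalone check of that equivalence, mirroring the (1, 2, 3) test case added below:

    import numpy as np
    import torch

    x_np = np.arange(6, dtype="float64").reshape((3, 2))
    x_pt = torch.as_tensor(x_np)

    # np.repeat with a tuple of counts repeats each row along `axis`;
    # torch.repeat_interleave does the same with a tensor of counts along `dim`.
    np_res = np.repeat(x_np, (1, 2, 3), axis=0)
    pt_res = torch.repeat_interleave(x_pt, torch.tensor([1, 2, 3]), dim=0)
    assert np.array_equal(np_res, pt_res.numpy())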


@pytorch_funcify.register(Unique)
def pytorch_funcify_Unique(op, **kwargs):
    return_index = op.return_index

    if return_index:
        # TODO: evaluate whether it is worth implementing this param
        # (see https://github.com/pytorch/pytorch/issues/36748)
        raise NotImplementedError("return_index is not implemented for pytorch")

    axis = op.axis
    return_inverse = op.return_inverse
    return_counts = op.return_counts

    def unique(x):
        return torch.unique(
            x,
            sorted=True,
            return_inverse=return_inverse,
            return_counts=return_counts,
            dim=axis,
        )

    return unique
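
As a standalone illustration of the call this dispatch emits (plain torch, not PR code): torch.unique supports return_inverse and return_counts, but it has no return_index option, which is why that flag raises NotImplementedError above.

    import torch

    x = torch.tensor([[1.0, 1.0, 2.0], [1.0, 1.0, 2.0], [3.0, 3.0, 0.0]])
    values, inverse, counts = torch.unique(
        x, sorted=True, return_inverse=True, return_counts=True, dim=0
    )
    print(values)   # the two distinct rows, sorted: [[1., 1., 2.], [3., 3., 0.]]
    print(inverse)  # tensor([0, 0, 1]): rows 0 and 1 map to the same unique row
    print(counts)   # tensor([2, 1])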
tests/link/pytorch/test_extra_ops.py (58 additions, 0 deletions)
@@ -41,3 +41,61 @@ def test_pytorch_CumOp(axis, dtype):
    out = pt.cumprod(a, axis=axis)
    fgraph = FunctionGraph([a], [out])
    compare_pytorch_and_py(fgraph, [test_value])


@pytest.mark.parametrize(
    "axis, repeats",
    [
        (0, (1, 2, 3)),
        (1, (3, 3)),
        pytest.param(
            None,
            3,
            marks=pytest.mark.xfail(reason="Reshape not implemented"),
        ),
    ],
)
def test_pytorch_Repeat(axis, repeats):
    a = pt.matrix("a", dtype="float64")

    test_value = np.arange(6, dtype="float64").reshape((3, 2))

    out = pt.repeat(a, repeats, axis=axis)
    fgraph = FunctionGraph([a], [out])
    compare_pytorch_and_py(fgraph, [test_value])
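
A note on the xfail above (my reading of the reason string, not text from the PR): with axis=None, repeat operates on the flattened input, and PyTensor lowers that flattening through a Reshape op that the PyTorch backend did not yet implement. The flattening semantics themselves are plain NumPy:

    import numpy as np

    x = np.arange(6, dtype="float64").reshape((3, 2))
    # axis=None repeats elements of the flattened array.
    assert np.array_equal(np.repeat(x, 3), np.repeat(x.ravel(), 3))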


@pytest.mark.parametrize("axis", [None, 0, 1])
def test_pytorch_Unique_axis(axis):
a = pt.matrix("a", dtype="float64")

test_value = np.array(
[[1.0, 1.0, 2.0], [1.0, 1.0, 2.0], [3.0, 3.0, 0.0]], dtype="float64"
)

out = pt.unique(a, axis=axis)

Review thread on this line:

ricardoV94 (Member), Jul 10, 2024: Does PyTensor only allow an integer axis at the moment? No None or partial multiple axes? If it does support those, we should test them.

Contributor (Author): All good with this one?

ricardoV94 (Member): I didn't check the code, but we need to see whether multiple (but not all) axes are supported by PyTensor and, if so, test that.

Contributor (Author): Do you mean multiple axes as in the case of ArgMax and Max? I don't think so:

    axis : int, optional

    fgraph = FunctionGraph([a], [out])
    compare_pytorch_and_py(fgraph, [test_value])
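
For reference on the thread above (an aside, not from the PR): np.unique, whose interface pt.unique mirrors, likewise accepts only a single integer axis or None, so there is no multi-axis variant for a backend to match:

    import numpy as np

    x = np.ones((2, 3))
    np.unique(x, axis=0)  # fine: one integer axis
    np.unique(x)          # fine: axis=None flattens first
    # np.unique(x, axis=(0, 1)) raises a TypeError; a tuple of axes is not accepted.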


@pytest.mark.parametrize("return_inverse", [False, True])
@pytest.mark.parametrize("return_counts", [False, True])
@pytest.mark.parametrize(
"return_index",
(False, pytest.param(True, marks=pytest.mark.xfail(raises=NotImplementedError))),
)
def test_pytorch_Unique_params(return_index, return_inverse, return_counts):
a = pt.matrix("a", dtype="float64")
test_value = np.array(
[[1.0, 1.0, 2.0], [1.0, 1.0, 2.0], [3.0, 3.0, 0.0]], dtype="float64"
)

out = pt.unique(
a,
return_index=return_index,
return_inverse=return_inverse,
return_counts=return_counts,
axis=0,
)
fgraph = FunctionGraph([a], [out[0] if isinstance(out, list) else out])
compare_pytorch_and_py(fgraph, [test_value])
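
Why the test unwraps out (an illustrative sketch, assuming the same pt alias as this test module): pt.unique returns a single variable when only the unique values are requested, and a list of variables once any extra flag is set, so only the first output is compiled here.

    import pytensor.tensor as pt

    a = pt.matrix("a", dtype="float64")
    single = pt.unique(a, axis=0)                     # one variable
    multi = pt.unique(a, return_counts=True, axis=0)  # a list: [values, counts]
    assert not isinstance(single, list)
    assert isinstance(multi, list) and len(multi) == 2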