From 9a14dba31af657ea40a69192fd7cc41aada1ddda Mon Sep 17 00:00:00 2001
From: Diego Sandoval
Date: Sat, 6 Jul 2024 07:17:59 +0000
Subject: [PATCH 1/5] Implemented Repeat and Unique Ops in PyTorch

---
 pytensor/link/pytorch/dispatch/extra_ops.py | 43 ++++++++++++++++++-
 tests/link/pytorch/test_extra_ops.py        | 46 +++++++++++++++++++++
 2 files changed, 88 insertions(+), 1 deletion(-)

diff --git a/pytensor/link/pytorch/dispatch/extra_ops.py b/pytensor/link/pytorch/dispatch/extra_ops.py
index f7af1eca7b..5ee7a9b184 100644
--- a/pytensor/link/pytorch/dispatch/extra_ops.py
+++ b/pytensor/link/pytorch/dispatch/extra_ops.py
@@ -1,7 +1,7 @@
 import torch

 from pytensor.link.pytorch.dispatch.basic import pytorch_funcify
-from pytensor.tensor.extra_ops import CumOp
+from pytensor.tensor.extra_ops import CumOp, Repeat, Unique


 @pytorch_funcify.register(CumOp)
@@ -21,3 +21,44 @@ def cumop(x):
             return torch.cumprod(x, dim=dim)

     return cumop
+
+
+@pytorch_funcify.register(Repeat)
+def pytorch_funcify_Repeat(op, **kwargs):
+    axis = op.axis
+
+    def repeat(x, repeats, axis=axis):
+        return x.repeat_interleave(repeats, dim=axis)
+
+    return repeat
+
+
+@pytorch_funcify.register(Unique)
+def pytorch_funcify_Unique(op, **kwargs):
+    return_index = op.return_index
+
+    if return_index:
+        # TODO: evaluate whether is worth implementing this param
+        # (see https://github.com/pytorch/pytorch/issues/36748)
+        raise NotImplementedError("return_index is not implemented for pytorch")
+
+    axis = op.axis
+    return_inverse = op.return_inverse
+    return_counts = op.return_counts
+
+    def unique(
+        x,
+        return_index=return_index,
+        return_inverse=return_inverse,
+        return_counts=return_counts,
+        axis=axis,
+    ):
+        return torch.unique(
+            x,
+            sorted=True,
+            return_inverse=return_inverse,
+            return_counts=return_counts,
+            dim=axis,
+        )
+
+    return unique
diff --git a/tests/link/pytorch/test_extra_ops.py b/tests/link/pytorch/test_extra_ops.py
index 72faa3d0d0..dbff0895eb 100644
--- a/tests/link/pytorch/test_extra_ops.py
+++ b/tests/link/pytorch/test_extra_ops.py
@@ -41,3 +41,49 @@ def test_pytorch_CumOp(axis, dtype):
         out = pt.cumprod(a, axis=axis)
         fgraph = FunctionGraph([a], [out])
         compare_pytorch_and_py(fgraph, [test_value])
+
+
+def test_pytorch_Repeat():
+    a = pt.matrix("a", dtype="float64")
+
+    test_value = np.arange(6, dtype="float64").reshape((3, 2))
+
+    # Test along axis 0
+    out = pt.repeat(a, (1, 2, 3), axis=0)
+    fgraph = FunctionGraph([a], [out])
+    compare_pytorch_and_py(fgraph, [test_value])
+
+    # Test along axis 1
+    out = pt.repeat(a, (3, 3), axis=1)
+    fgraph = FunctionGraph([a], [out])
+    compare_pytorch_and_py(fgraph, [test_value])
+
+
+def test_pytorch_Unique():
+    a = pt.matrix("a", dtype="float64")
+
+    test_value = np.array(
+        [[1.0, 1.0, 2.0], [1.0, 1.0, 2.0], [3.0, 3.0, 0.0]], dtype="float64"
+    )
+
+    # Test along axis 0
+    out = pt.unique(a, axis=0)
+    fgraph = FunctionGraph([a], [out])
+    compare_pytorch_and_py(fgraph, [test_value])
+
+    # Test along axis 1
+    out = pt.unique(a, axis=1)
+    fgraph = FunctionGraph([a], [out])
+    compare_pytorch_and_py(fgraph, [test_value])
+
+    # Test with params
+    out = pt.unique(a, return_inverse=True, return_counts=True, axis=0)
+    fgraph = FunctionGraph([a], [out[0]])
+    compare_pytorch_and_py(fgraph, [test_value])
+
+    # Test with return_index=True
+    out = pt.unique(a, return_index=True, axis=0)
+    fgraph = FunctionGraph([a], [out[0]])
+
+    with pytest.raises(NotImplementedError):
+        compare_pytorch_and_py(fgraph, [test_value])
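A minimal standalone sketch of the torch primitives the dispatch above relies on;
the tensors and values below are illustrative only, not taken from the patch:

    import torch

    x = torch.arange(6, dtype=torch.float64).reshape(3, 2)

    # Per-row repeat counts along dim=0, matching np.repeat semantics:
    # rows are repeated 1, 2 and 3 times, giving a (6, 2) result.
    torch.repeat_interleave(x, torch.tensor([1, 2, 3]), dim=0)

    # sorted=True returns the unique rows in sorted order, mirroring np.unique;
    # return_inverse/return_counts add the extra outputs the Op exposes.
    y = torch.tensor([[1.0, 1.0, 2.0], [1.0, 1.0, 2.0], [3.0, 3.0, 0.0]])
    torch.unique(y, sorted=True, return_inverse=True, return_counts=True, dim=0)
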
From dfdc1149a501d61b5357cc9f480779edf46973a1 Mon Sep 17 00:00:00 2001
From: Diego Sandoval
Date: Sat, 6 Jul 2024 09:26:30 +0000
Subject: [PATCH 2/5] Removed kwargs in Repeat and Unique Ops impl. in PyTorch

---
 pytensor/link/pytorch/dispatch/extra_ops.py | 10 ++--------
 1 file changed, 2 insertions(+), 8 deletions(-)

diff --git a/pytensor/link/pytorch/dispatch/extra_ops.py b/pytensor/link/pytorch/dispatch/extra_ops.py
index 5ee7a9b184..74284d651d 100644
--- a/pytensor/link/pytorch/dispatch/extra_ops.py
+++ b/pytensor/link/pytorch/dispatch/extra_ops.py
@@ -27,7 +27,7 @@ def cumop(x):
 def pytorch_funcify_Repeat(op, **kwargs):
     axis = op.axis

-    def repeat(x, repeats, axis=axis):
+    def repeat(x, repeats):
         return x.repeat_interleave(repeats, dim=axis)

     return repeat
@@ -46,13 +46,7 @@ def pytorch_funcify_Unique(op, **kwargs):
     return_inverse = op.return_inverse
     return_counts = op.return_counts

-    def unique(
-        x,
-        return_index=return_index,
-        return_inverse=return_inverse,
-        return_counts=return_counts,
-        axis=axis,
-    ):
+    def unique(x):
         return torch.unique(
             x,
             sorted=True,
From 330a7d27c6098ad43277ab681deb8c0f0d49804c Mon Sep 17 00:00:00 2001
From: Diego Sandoval
Date: Tue, 9 Jul 2024 19:44:25 +0200
Subject: [PATCH 3/5] Parametrized tests for Repeat and Unique impls. in PyTorch

---
 tests/link/pytorch/test_extra_ops.py | 55 +++++++++++++++-------------
 1 file changed, 30 insertions(+), 25 deletions(-)

diff --git a/tests/link/pytorch/test_extra_ops.py b/tests/link/pytorch/test_extra_ops.py
index dbff0895eb..efc504b6a5 100644
--- a/tests/link/pytorch/test_extra_ops.py
+++ b/tests/link/pytorch/test_extra_ops.py
@@ -43,47 +43,52 @@ def test_pytorch_CumOp(axis, dtype):
         compare_pytorch_and_py(fgraph, [test_value])


-def test_pytorch_Repeat():
+@pytest.mark.parametrize("axis", [0, 1])
+def test_pytorch_Repeat(axis):
     a = pt.matrix("a", dtype="float64")

     test_value = np.arange(6, dtype="float64").reshape((3, 2))

-    # Test along axis 0
-    out = pt.repeat(a, (1, 2, 3), axis=0)
+    out = pt.repeat(a, (1, 2, 3) if axis == 0 else (3, 3), axis=axis)
     fgraph = FunctionGraph([a], [out])
     compare_pytorch_and_py(fgraph, [test_value])

-    # Test along axis 1
-    out = pt.repeat(a, (3, 3), axis=1)
-    fgraph = FunctionGraph([a], [out])
-    compare_pytorch_and_py(fgraph, [test_value])
-

-def test_pytorch_Unique():
+@pytest.mark.parametrize("axis", [0, 1])
+def test_pytorch_Unique_axis(axis):
     a = pt.matrix("a", dtype="float64")

     test_value = np.array(
         [[1.0, 1.0, 2.0], [1.0, 1.0, 2.0], [3.0, 3.0, 0.0]], dtype="float64"
     )

-    # Test along axis 0
-    out = pt.unique(a, axis=0)
-    fgraph = FunctionGraph([a], [out])
-    compare_pytorch_and_py(fgraph, [test_value])
-
-    # Test along axis 1
-    out = pt.unique(a, axis=1)
+    out = pt.unique(a, axis=axis)
     fgraph = FunctionGraph([a], [out])
     compare_pytorch_and_py(fgraph, [test_value])

-    # Test with params
-    out = pt.unique(a, return_inverse=True, return_counts=True, axis=0)
-    fgraph = FunctionGraph([a], [out[0]])
-    compare_pytorch_and_py(fgraph, [test_value])
-
-    # Test with return_index=True
-    out = pt.unique(a, return_index=True, axis=0)
-    fgraph = FunctionGraph([a], [out[0]])
+@pytest.mark.parametrize(
+    "return_index, return_inverse, return_counts",
+    [
+        (False, True, False),
+        (False, True, True),
+        pytest.param(
+            True, False, False, marks=pytest.mark.xfail(raises=NotImplementedError)
+        ),
+    ],
+)
+def test_pytorch_Unique_params(return_index, return_inverse, return_counts):
+    a = pt.matrix("a", dtype="float64")
+    test_value = np.array(
+        [[1.0, 1.0, 2.0], [1.0, 1.0, 2.0], [3.0, 3.0, 0.0]], dtype="float64"
+    )

-    with pytest.raises(NotImplementedError):
-        compare_pytorch_and_py(fgraph, [test_value])
+    out = pt.unique(
+        a,
+        return_index=return_index,
+        return_inverse=return_inverse,
+        return_counts=return_counts,
+        axis=0,
+    )
+    fgraph = FunctionGraph([a], [out[0] if isinstance(out, list) else out])
+    compare_pytorch_and_py(fgraph, [test_value])
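The out[0] if isinstance(out, list) else out guard in the parametrized test reflects
pt.unique's numpy-like return convention: with no extra outputs requested it returns a
single variable, while return_index/return_inverse/return_counts turn the result into a
list of variables with the unique values first. A short illustrative sketch, not part of
the patch:

    import pytensor.tensor as pt

    a = pt.matrix("a", dtype="float64")

    # Only unique values requested: a single symbolic variable comes back.
    out_single = pt.unique(a, axis=0)

    # Extra outputs requested: a list of variables, unique values first,
    # which is why the test keeps out[0] when building the FunctionGraph.
    outs = pt.unique(a, return_inverse=True, return_counts=True, axis=0)
    unique_values = outs[0]
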
From 8dbe5ca6ccc0a92f9b86324ed03e092a03eaad05 Mon Sep 17 00:00:00 2001
From: Diego Sandoval
Date: Wed, 10 Jul 2024 06:07:38 +0200
Subject: [PATCH 4/5] Change parametrization in tests for Repeat and Unique impls. in PyTorch

---
 tests/link/pytorch/test_extra_ops.py | 12 ++++--------
 1 file changed, 4 insertions(+), 8 deletions(-)

diff --git a/tests/link/pytorch/test_extra_ops.py b/tests/link/pytorch/test_extra_ops.py
index efc504b6a5..0896a5ddc1 100644
--- a/tests/link/pytorch/test_extra_ops.py
+++ b/tests/link/pytorch/test_extra_ops.py
@@ -67,15 +67,11 @@ def test_pytorch_Unique_axis(axis):
     compare_pytorch_and_py(fgraph, [test_value])


+@pytest.mark.parametrize("return_inverse", [False, True])
+@pytest.mark.parametrize("return_counts", [False, True])
 @pytest.mark.parametrize(
-    "return_index, return_inverse, return_counts",
-    [
-        (False, True, False),
-        (False, True, True),
-        pytest.param(
-            True, False, False, marks=pytest.mark.xfail(raises=NotImplementedError)
-        ),
-    ],
+    "return_index",
+    (False, pytest.param(True, marks=pytest.mark.xfail(raises=NotImplementedError))),
 )
 def test_pytorch_Unique_params(return_index, return_inverse, return_counts):
     a = pt.matrix("a", dtype="float64")
From 62e453dee46c6d1e3d66024512f27409ee0a27c8 Mon Sep 17 00:00:00 2001
From: Diego Sandoval
Date: Wed, 10 Jul 2024 21:40:48 +0200
Subject: [PATCH 5/5] Added test axis=None for Repeat in PyTorch

---
 tests/link/pytorch/test_extra_ops.py | 19 +++++++++++++++----
 1 file changed, 15 insertions(+), 4 deletions(-)

diff --git a/tests/link/pytorch/test_extra_ops.py b/tests/link/pytorch/test_extra_ops.py
index 0896a5ddc1..221855864a 100644
--- a/tests/link/pytorch/test_extra_ops.py
+++ b/tests/link/pytorch/test_extra_ops.py
@@ -43,18 +43,29 @@ def test_pytorch_CumOp(axis, dtype):
         compare_pytorch_and_py(fgraph, [test_value])


-@pytest.mark.parametrize("axis", [0, 1])
-def test_pytorch_Repeat(axis):
+@pytest.mark.parametrize(
+    "axis, repeats",
+    [
+        (0, (1, 2, 3)),
+        (1, (3, 3)),
+        pytest.param(
+            None,
+            3,
+            marks=pytest.mark.xfail(reason="Reshape not implemented"),
+        ),
+    ],
+)
+def test_pytorch_Repeat(axis, repeats):
     a = pt.matrix("a", dtype="float64")

     test_value = np.arange(6, dtype="float64").reshape((3, 2))

-    out = pt.repeat(a, (1, 2, 3) if axis == 0 else (3, 3), axis=axis)
+    out = pt.repeat(a, repeats, axis=axis)
     fgraph = FunctionGraph([a], [out])
     compare_pytorch_and_py(fgraph, [test_value])


-@pytest.mark.parametrize("axis", [0, 1])
+@pytest.mark.parametrize("axis", [None, 0, 1])
 def test_pytorch_Unique_axis(axis):
     a = pt.matrix("a", dtype="float64")
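The new axis=None case is expected to fail because np.repeat with axis=None flattens its
input before repeating; in PyTensor that flattening shows up as a Reshape, which the
PyTorch backend does not dispatch yet, hence the xfail reason above. A rough sketch of
the behaviour being targeted, assuming that lowering:

    import numpy as np
    import torch

    x = np.arange(6, dtype="float64").reshape((3, 2))

    # np.repeat with axis=None flattens first, then repeats every element.
    expected = np.repeat(x, 3, axis=None)  # shape (18,)

    # Roughly the torch computation the backend would need once Reshape
    # (the flatten step) is supported.
    result = torch.repeat_interleave(torch.from_numpy(x).flatten(), 3)

    assert np.array_equal(result.numpy(), expected)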