Skip to content

Commit

Permalink
Implements shape Ops and MakeVector in PyTorch (#926)
Browse files Browse the repository at this point in the history
* Implements shape and MakeVector Ops in PyTorch

- Shape
- Shape_i
- Reshape
- SpecifyShape
- Unbroadcast
- MakeVector
  • Loading branch information
twaclaw authored Jul 17, 2024
1 parent 6b8df2c commit 426931b
Show file tree
Hide file tree
Showing 7 changed files with 140 additions and 23 deletions.
1 change: 1 addition & 0 deletions pytensor/link/pytorch/dispatch/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,4 +6,5 @@
import pytensor.link.pytorch.dispatch.elemwise
import pytensor.link.pytorch.dispatch.extra_ops
import pytensor.link.pytorch.dispatch.sort
import pytensor.link.pytorch.dispatch.shape
# isort: on
18 changes: 17 additions & 1 deletion pytensor/link/pytorch/dispatch/basic.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
from functools import singledispatch
from types import NoneType

import torch

from pytensor.compile.ops import DeepCopyOp
from pytensor.graph.fg import FunctionGraph
from pytensor.link.utils import fgraph_to_python
from pytensor.raise_op import CheckAndRaise
from pytensor.tensor.basic import Alloc, AllocEmpty, ARange, Eye, Join
from pytensor.tensor.basic import Alloc, AllocEmpty, ARange, Eye, Join, MakeVector


@singledispatch
Expand All @@ -15,6 +16,11 @@ def pytorch_typify(data, dtype=None, **kwargs):
return torch.as_tensor(data, dtype=dtype)


@pytorch_typify.register(NoneType)
def pytorch_typify_None(data, **kwargs):
return None


@singledispatch
def pytorch_funcify(op, node=None, storage_map=None, **kwargs):
"""Create a PyTorch compatible function from an PyTensor `Op`."""
Expand Down Expand Up @@ -116,3 +122,13 @@ def eye(N, M, k):
return zeros

return eye


@pytorch_funcify.register(MakeVector)
def pytorch_funcify_MakeVector(op, **kwargs):
torch_dtype = getattr(torch, op.dtype)

def makevector(*x):
return torch.tensor(x, dtype=torch_dtype)

return makevector
52 changes: 52 additions & 0 deletions pytensor/link/pytorch/dispatch/shape.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
import torch

from pytensor.link.pytorch.dispatch.basic import pytorch_funcify
from pytensor.tensor.shape import Reshape, Shape, Shape_i, SpecifyShape, Unbroadcast


@pytorch_funcify.register(Reshape)
def pytorch_funcify_Reshape(op, node, **kwargs):
def reshape(x, shape):
return torch.reshape(x, tuple(shape))

return reshape


@pytorch_funcify.register(Shape)
def pytorch_funcify_Shape(op, **kwargs):
def shape(x):
return x.shape

return shape


@pytorch_funcify.register(Shape_i)
def pytorch_funcify_Shape_i(op, **kwargs):
i = op.i

def shape_i(x):
return torch.tensor(x.shape[i])

return shape_i


@pytorch_funcify.register(SpecifyShape)
def pytorch_funcify_SpecifyShape(op, node, **kwargs):
def specifyshape(x, *shape):
assert x.ndim == len(shape)
for actual, expected in zip(x.shape, shape):
if expected is None:
continue
if actual != expected:
raise ValueError(f"Invalid shape: Expected {shape} but got {x.shape}")
return x

return specifyshape


@pytorch_funcify.register(Unbroadcast)
def pytorch_funcify_Unbroadcast(op, **kwargs):
def unbroadcast(x):
return x

return unbroadcast
7 changes: 7 additions & 0 deletions tests/link/pytorch/test_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -294,3 +294,10 @@ def test_eye(dtype):
for _M in range(1, 6):
for _k in list(range(_M + 2)) + [-x for x in range(1, _N + 2)]:
np.testing.assert_array_equal(fn(_N, _M, _k), np.eye(_N, _M, _k))


def test_pytorch_MakeVector():
x = ptb.make_vector(1, 2, 3)
x_fg = FunctionGraph([], [x])

compare_pytorch_and_py(x_fg, [])
13 changes: 1 addition & 12 deletions tests/link/pytorch/test_extra_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,18 +43,7 @@ def test_pytorch_CumOp(axis, dtype):
compare_pytorch_and_py(fgraph, [test_value])


@pytest.mark.parametrize(
"axis, repeats",
[
(0, (1, 2, 3)),
(1, (3, 3)),
pytest.param(
None,
3,
marks=pytest.mark.xfail(reason="Reshape not implemented"),
),
],
)
@pytest.mark.parametrize("axis, repeats", [(0, (1, 2, 3)), (1, (3, 3)), (None, 3)])
def test_pytorch_Repeat(axis, repeats):
a = pt.matrix("a", dtype="float64")

Expand Down
61 changes: 61 additions & 0 deletions tests/link/pytorch/test_shape.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
import numpy as np

import pytensor.tensor as pt
from pytensor.configdefaults import config
from pytensor.graph.fg import FunctionGraph
from pytensor.tensor.shape import Shape, Shape_i, Unbroadcast, reshape
from pytensor.tensor.type import iscalar, vector
from tests.link.pytorch.test_basic import compare_pytorch_and_py


def test_pytorch_shape_ops():
x_np = np.zeros((20, 3))
x = Shape()(pt.as_tensor_variable(x_np))
x_fg = FunctionGraph([], [x])

compare_pytorch_and_py(x_fg, [], must_be_device_array=False)

x = Shape_i(1)(pt.as_tensor_variable(x_np))
x_fg = FunctionGraph([], [x])

compare_pytorch_and_py(x_fg, [], must_be_device_array=False)


def test_pytorch_specify_shape():
in_pt = pt.matrix("in")
x = pt.specify_shape(in_pt, (4, None))
x_fg = FunctionGraph([in_pt], [x])
compare_pytorch_and_py(x_fg, [np.ones((4, 5)).astype(config.floatX)])

# When used to assert two arrays have similar shapes
in_pt = pt.matrix("in")
shape_pt = pt.matrix("shape")
x = pt.specify_shape(in_pt, shape_pt.shape)
x_fg = FunctionGraph([in_pt, shape_pt], [x])
compare_pytorch_and_py(
x_fg,
[np.ones((4, 5)).astype(config.floatX), np.ones((4, 5)).astype(config.floatX)],
)


def test_pytorch_Reshape_constant():
a = vector("a")
x = reshape(a, (2, 2))
x_fg = FunctionGraph([a], [x])
compare_pytorch_and_py(x_fg, [np.r_[1.0, 2.0, 3.0, 4.0].astype(config.floatX)])


def test_pytorch_Reshape_dynamic():
a = vector("a")
shape_pt = iscalar("b")
x = reshape(a, (shape_pt, shape_pt))
x_fg = FunctionGraph([a, shape_pt], [x])
compare_pytorch_and_py(x_fg, [np.r_[1.0, 2.0, 3.0, 4.0].astype(config.floatX), 2])


def test_pytorch_unbroadcast():
x_np = np.zeros((20, 1, 1))
x = Unbroadcast(0, 2)(pt.as_tensor_variable(x_np))
x_fg = FunctionGraph([], [x])

compare_pytorch_and_py(x_fg, [])
11 changes: 1 addition & 10 deletions tests/link/pytorch/test_sort.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,16 +8,7 @@


@pytest.mark.parametrize("func", (sort, argsort))
@pytest.mark.parametrize(
"axis",
[
pytest.param(0),
pytest.param(1),
pytest.param(
None, marks=pytest.mark.xfail(reason="Reshape Op not implemented")
),
],
)
@pytest.mark.parametrize("axis", [0, 1, None])
def test_sort(func, axis):
x = matrix("x", shape=(2, 2), dtype="float64")
out = func(x, axis=axis)
Expand Down

0 comments on commit 426931b

Please sign in to comment.