diff --git a/test/export/test_export.py b/test/export/test_export.py
index faa634266986a..552aa23d11d56 100644
--- a/test/export/test_export.py
+++ b/test/export/test_export.py
@@ -2040,6 +2040,32 @@ def forward(self, x):
         ):
             export(Module(), (torch.tensor(1, device="cpu"),))
 
+    def test_float_conversion(self):
+        class Module(torch.nn.Module):
+            def forward(self, x):
+                return x.float()
+
+        ep = export(Module(), (torch.tensor(1, dtype=torch.float),))
+        ops = []
+        for node in ep.graph.nodes:
+            if node.op == "call_function":
+                ops.append(node.target)
+        self.assertGreater(len(ops), 0)
+        for op in ops:
+            self.assertIn(op, (torch.ops.aten._to_copy.default,))
+
+    def test_device_to_mutation_float(self):
+        class Module(torch.nn.Module):
+            def forward(self, x):
+                y = x.float()
+                y.add_(1)
+                return y, x
+
+        with self.assertRaisesRegex(
+            RuntimeError, "cannot mutate tensors with frozen storage"
+        ):
+            export(Module(), (torch.tensor(1, dtype=torch.float),))
+
     def test_module(self):
         class MyLinear(torch.nn.Module):
             def __init__(self):
diff --git a/torch/_subclasses/functional_tensor.py b/torch/_subclasses/functional_tensor.py
index fb2a81b8aeb20..f0a730926af64 100644
--- a/torch/_subclasses/functional_tensor.py
+++ b/torch/_subclasses/functional_tensor.py
@@ -17,6 +17,13 @@
 not_implemented_log = torch._logging.getArtifactLogger(__name__, "not_implemented")
 
 
+def _conversion_method_template(**extra_kwargs):
+    def _(self, *args, **kwargs):
+        return self.to(*args, **{**kwargs, **extra_kwargs})
+
+    return _
+
+
 class FunctionalTensor(torch.Tensor):
     """
     Functional tensors represent tensors that will remove mutations
@@ -225,6 +232,24 @@ def to(self, *args, **kwargs):
             return super().to(*args, **{**kwargs, "copy": True})
         return super().to(*args, **kwargs)
 
+    def cuda(self, device=None, *args, **kwargs):
+        device = device or torch.cuda.current_device()
+        if len(args) > 0:
+            return self.to(device, *args, **kwargs)
+        else:
+            return self.to(device=device, **kwargs)
+
+    char = _conversion_method_template(dtype=torch.int8)
+    cpu = _conversion_method_template(device=torch.device("cpu"))
+    bfloat16 = _conversion_method_template(dtype=torch.bfloat16)
+    byte = _conversion_method_template(dtype=torch.uint8)
+    double = _conversion_method_template(dtype=torch.float64)
+    float = _conversion_method_template(dtype=torch.float32)
+    bool = _conversion_method_template(dtype=torch.bool)
+    half = _conversion_method_template(dtype=torch.float16)
+    int = _conversion_method_template(dtype=torch.int32)
+    long = _conversion_method_template(dtype=torch.int64)
+
 
 class FunctionalTensorMode(TorchDispatchMode):
     def __init__(self, pre_dispatch=False, export=False, _allow_token_discovery=False):
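
A minimal usage sketch, not part of the patch, mirroring what test_float_conversion above checks: with these conversion methods defined on FunctionalTensor, a call such as x.float() inside an exported module funnels through FunctionalTensor.to and shows up in the exported graph as aten._to_copy. The module name FloatConversion and the call_targets variable are illustrative only.

# Sketch only; assumes a PyTorch build with torch.export and this patch applied.
import torch
from torch.export import export


class FloatConversion(torch.nn.Module):
    def forward(self, x):
        # With this patch, .float() on a FunctionalTensor routes through
        # .to(dtype=torch.float32) rather than the raw Tensor method.
        return x.float()


ep = export(FloatConversion(), (torch.tensor(1, dtype=torch.float),))
call_targets = [n.target for n in ep.graph.nodes if n.op == "call_function"]
# As in the new test, the conversion should appear as aten._to_copy in the graph.
assert call_targets
assert all(t == torch.ops.aten._to_copy.default for t in call_targets)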