[Feature request] Training with torch.compile #307

Open

RaulPPelaez (Collaborator) opened this issue Mar 18, 2024 · 0 comments
It is currently not possible to run backward twice through a module compiled with torch.compile. For instance, this code fails:

import torch
from torch import nn, Tensor


class Model(nn.Module):
    def forward(self, input: Tensor) -> Tensor:
        output = input * input
        return output


model = Model()
model = torch.compile(model, backend="inductor")
input = torch.randn(10, requires_grad=True)
y = model(input)
dy = torch.autograd.grad(
    y, input, grad_outputs=torch.ones_like(y), create_graph=True, retain_graph=True
)[0]
ddy = torch.autograd.grad(dy, input, grad_outputs=torch.ones_like(dy))[0]

It fails with the following error:

$ python test_model.py 
Traceback (most recent call last):
  File "/home/raul/work/bcn/torchmd-net/tests/test_model.py", line 266, in <module>
    ddy = torch.autograd.grad(dy, input, grad_outputs=torch.ones_like(dy))[0]
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/raul/miniforge3/envs/torchnightly/lib/python3.11/site-packages/torch/autograd/__init__.py", line 412, in grad
    result = _engine_run_backward(
             ^^^^^^^^^^^^^^^^^^^^^
  File "/home/raul/miniforge3/envs/torchnightly/lib/python3.11/site-packages/torch/autograd/graph.py", line 744, in _engine_run_backward
    return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/raul/miniforge3/envs/torchnightly/lib/python3.11/site-packages/torch/autograd/function.py", line 301, in apply
    return user_fn(self, *args)
           ^^^^^^^^^^^^^^^^^^^^
  File "/home/raul/miniforge3/envs/torchnightly/lib/python3.11/site-packages/torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py", line 877, in backward
    raise RuntimeError(
RuntimeError: torch.compile with aot_autograd does not currently support double backward

This happens even with the latest PyTorch nightly:

pytorch 2.3.0.dev20240313 py3.11_cpu_0 pytorch-nightly
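
For comparison, here is a minimal sketch of the same computation in eager mode (the torch.compile call removed), which runs without error:

import torch
from torch import nn, Tensor


class Model(nn.Module):
    def forward(self, input: Tensor) -> Tensor:
        return input * input


model = Model()  # no torch.compile: eager execution
input = torch.randn(10, requires_grad=True)
y = model(input)
# First derivative; create_graph=True keeps it differentiable
dy = torch.autograd.grad(
    y, input, grad_outputs=torch.ones_like(y), create_graph=True
)[0]
# Second derivative succeeds in eager mode: d²(x·x)/dx² = 2
ddy = torch.autograd.grad(dy, input, grad_outputs=torch.ones_like(dy))[0]
assert torch.allclose(ddy, torch.full_like(input, 2.0))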

This is a well-known limitation of torch.compile, tracked upstream in pytorch/pytorch#91469.

TorchMD-Net computes forces as the negative gradient of the energy with respect to the atomic positions, so training against force labels requires backpropagating through that gradient, i.e. a double backward (sketched below).
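
A minimal sketch of that training pattern, with a hypothetical energy_model and ref_forces standing in for TorchMD-Net's actual model and data:

import torch

# Hypothetical stand-ins for illustration; not TorchMD-Net's actual API
energy_model = torch.nn.Sequential(
    torch.nn.Linear(3, 16), torch.nn.Tanh(), torch.nn.Linear(16, 1)
)
pos = torch.randn(5, 3, requires_grad=True)  # atomic positions
ref_forces = torch.randn(5, 3)               # reference force labels

energy = energy_model(pos).sum()
# First backward: forces are the negative gradient of the energy w.r.t. positions.
# create_graph=True keeps this gradient differentiable for the training step.
forces = -torch.autograd.grad(energy, pos, create_graph=True)[0]
loss = torch.nn.functional.mse_loss(forces, ref_forces)
# Second backward: differentiates through the first gradient (double backward),
# which torch.compile with aot_autograd currently rejects.
loss.backward()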

Thus, one cannot currently do this, as it will trigger the same error as above:

import torch
import pytorch_lightning as pl
from torchmdnet.models.model import create_model
from utils import load_example_args, create_example_batch  # test-suite helpers


def test_compile_double_backwards():
    pl.seed_everything(12345)
    output_model = "Scalar"
    derivative = True
    args = load_example_args(
        "tensornet",
        remove_prior=True,
        output_model=output_model,
        derivative=derivative,
    )
    model = create_model(args)
    model = torch.compile(model, backend="inductor")
    z, pos, batch = create_example_batch(n_atoms=5)
    pos.requires_grad_(True)
    y, dy = model(z, pos, batch)  # derivative=True: dy is computed via a first backward
    dy.sum().backward()  # second backward through the compiled graph: fails

I am opening this issue to keep track of this feature request.
