mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
This PR adds support for torch.autograd.Function subclasses in compiled autograd. We do this by:
- Creating a uid for all torch.autograd.Function via its metaclass. This uid is used in the compiled autograd key, which is a subset of the cache key to the compiled graph
- "Lifting" the backward/saved_tensors, having them as input arguments in the compiled graph
- Creating proxies to track the backward's inputs and outputs. Since the backward's outputs (grads) have to match the forward's inputs, we pass the node's `input_info` (forward's input sizes) to build the proxies tracking the backward's outputs.
- Using a `FakeContext` class as a replacement for the autograd node's context object (`BackwardCFunction`) during tracing; it only supports passing saved_tensors from the forward to the backward
- Indexing each backward, to support multiple torch.autograd.Functions in the same graph
- Special-casing `CompiledFunctionBackward`: lifting CompiledFunction will fail 4 tests and requires some skipfiles changes that I'd rather do in a separate PR
Example graph: test_custom_fn_saved_multiple_tensors (eager fw + compiled autograd)
```python
class MyFn(torch.autograd.Function):
@staticmethod
def forward(ctx, x, y):
ctx.save_for_backward(x, y)
return torch.sin(x), torch.sin(y)
@staticmethod
def backward(ctx, gO_x, gO_y):
(x, y) = ctx.saved_tensors
return gO_x * torch.cos(x), gO_y * torch.cos(y)
```
The backward is lifted via `getitem_5` and `call_backward`:
```python
# Compiled autograd graph
===== Compiled autograd graph =====
<eval_with_key>.0 class CompiledAutograd(torch.nn.Module):
def forward(self, inputs, sizes, hooks):
# No stacktrace found for following nodes
getitem: "f32[]" = inputs[0]
getitem_1: "f32[10]" = inputs[1]
getitem_2: "f32[10]" = inputs[2]
getitem_3: "f32[10]" = inputs[3]
getitem_4: "f32[10]" = inputs[4]; inputs = None
expand: "f32[10]" = torch.ops.aten.expand.default(getitem, [10]); getitem = None
mul: "f32[10]" = torch.ops.aten.mul.Tensor(expand, getitem_2); getitem_2 = None
mul_1: "f32[10]" = torch.ops.aten.mul.Tensor(expand, getitem_1); expand = getitem_1 = None
getitem_5 = hooks[0]; hooks = None
call_backward = torch__dynamo_external_utils_call_backward(getitem_5, (getitem_3, getitem_4), mul_1, mul); getitem_5 = mul_1 = mul = None
getitem_6: "f32[10]" = call_backward[0]
getitem_7: "f32[10]" = call_backward[1]; call_backward = None
accumulate_grad_ = torch.ops.inductor.accumulate_grad_.default(getitem_4, getitem_7); getitem_4 = getitem_7 = None
accumulate_grad__1 = torch.ops.inductor.accumulate_grad_.default(getitem_3, getitem_6); getitem_3 = getitem_6 = None
return []
```
It is then later inlined by dynamo:
```python
# Dynamo graph
===== __compiled_fn_0 =====
<eval_with_key>.1 class GraphModule(torch.nn.Module):
def forward(self, L_inputs_0_ : torch.Tensor, L_inputs_1_ : torch.Tensor, L_inputs_2_ : torch.Tensor, L_inputs_3_ : torch.Tensor, L_inputs_4_ : torch.Tensor):
getitem = L_inputs_0_
getitem_1 = L_inputs_1_
getitem_2 = L_inputs_2_
x = L_inputs_3_
y = L_inputs_4_
# File: <eval_with_key>.0:10, code: expand = torch.ops.aten.expand.default(getitem, [10]); getitem = None
expand = torch.ops.aten.expand.default(getitem, [10]); getitem = None
# File: <eval_with_key>.0:11, code: mul = torch.ops.aten.mul.Tensor(expand, getitem_2); getitem_2 = None
mul = torch.ops.aten.mul.Tensor(expand, getitem_2); getitem_2 = None
# File: <eval_with_key>.0:12, code: mul_1 = torch.ops.aten.mul.Tensor(expand, getitem_1); expand = getitem_1 = None
mul_1 = torch.ops.aten.mul.Tensor(expand, getitem_1); expand = getitem_1 = None
# File: /data/users/xmfan/core/pytorch/test/inductor/test_compiled_autograd.py:412, code: return gO_x * torch.cos(x), gO_y * torch.cos(y)
cos = torch.cos(x)
getitem_6 = mul_1 * cos; mul_1 = cos = None
cos_1 = torch.cos(y)
getitem_7 = mul * cos_1; mul = cos_1 = None
# File: <eval_with_key>.0:17, code: accumulate_grad_ = torch.ops.inductor.accumulate_grad_.default(getitem_4, getitem_7); getitem_4 = getitem_7 = None
accumulate_grad__default = torch.ops.inductor.accumulate_grad_.default(y, getitem_7); y = getitem_7 = None
# File: <eval_with_key>.0:18, code: accumulate_grad__1 = torch.ops.inductor.accumulate_grad_.default(getitem_3, getitem_6); getitem_3 = getitem_6 = None
accumulate_grad__default_1 = torch.ops.inductor.accumulate_grad_.default(x, getitem_6); x = getitem_6 = None
return ()
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/115573
Approved by: https://github.com/jansel
73 lines
1.6 KiB
Python
73 lines
1.6 KiB
Python
# This module contains functions that *will be allowed* by dynamo
|
|
|
|
import functools
|
|
|
|
import torch
|
|
import torch.utils._pytree as pytree
|
|
|
|
try:
|
|
import numpy as np
|
|
except ModuleNotFoundError:
|
|
np = None # type: ignore[assignment]
|
|
|
|
|
|
def is_compiling() -> bool:
    """Return ``False``; this plain-Python definition never reports compilation."""
    return False
|
|
|
|
|
|
def wrap_inline(fn):
    """
    Return a thin forwarding wrapper around ``fn``.

    The wrapper adds one extra frame around ``fn`` that is not in skipfiles;
    it passes every positional and keyword argument through unchanged and
    returns whatever ``fn`` returns.
    """

    def inner(*args, **kwargs):
        return fn(*args, **kwargs)

    # Copy fn's metadata (__name__, __doc__, __wrapped__, ...) onto the
    # wrapper; this is what the ``@functools.wraps(fn)`` decorator form does.
    return functools.update_wrapper(inner, fn)
|
|
|
|
|
|
def call_hook(hook, *args):
    """
    Invoke ``hook(*args)`` on behalf of compiled autograd.

    A hook that returns ``None`` is treated as "leave the value alone": in
    that case the first positional argument is handed back instead of
    ``None``. Any other return value is passed through as-is.
    """
    outcome = hook(*args)
    return args[0] if outcome is None else outcome
|
|
|
|
|
|
def wrap_numpy(f):
    r"""Decorator adapting a function over ``np.ndarray``s so that it takes and
    returns ``torch.Tensor``s instead.

    Every ``torch.Tensor`` found in the positional and keyword arguments is
    converted with ``Tensor.numpy()`` before ``f`` runs, and every
    ``np.ndarray`` found in the result is converted back with
    ``torch.as_tensor``. If NumPy is unavailable (``np`` is falsy), ``f`` is
    returned unchanged.
    """
    if not np:
        return f

    def _to_ndarray(tensor):
        return tensor.numpy()

    def _to_tensor(array):
        return torch.as_tensor(array)

    @functools.wraps(f)
    def wrap(*args, **kwargs):
        # Map over (args, kwargs) together so nested containers in either
        # are handled by the same pytree traversal.
        np_args, np_kwargs = pytree.tree_map_only(
            torch.Tensor, _to_ndarray, (args, kwargs)
        )
        result = f(*np_args, **np_kwargs)
        return pytree.tree_map_only(np.ndarray, _to_tensor, result)

    return wrap
|
|
|
|
|
|
class FakeContext:
    """Minimal context object that exposes only ``saved_tensors``."""

    def __init__(self, saved_tensors):
        # Storing the value as a plain attribute caches the results of
        # saved_tensors, so later reads no longer call into the c++ binding.
        self.saved_tensors = saved_tensors
|
|
|
|
|
|
def call_backward(backward_fn, saved_tensors, *args):
    """
    Run ``backward_fn`` with a ``FakeContext`` and return its grads as a tuple.

    ``backward_fn`` is called as ``backward_fn(ctx, *args)``, where ``ctx`` is
    a ``FakeContext`` holding ``saved_tensors``. A result whose type is exactly
    ``tuple`` is returned unchanged; anything else is wrapped in a one-element
    tuple.
    """
    ctx = FakeContext(saved_tensors)
    grads = backward_fn(ctx, *args)

    # in eager, we wrap in a tuple when there's only one grad output
    if type(grads) is tuple:
        return grads
    return (grads,)
|