pytorch/torch/_functorch/autograd_function.py
Richard Zou 2a55984139 [generate_vmap_rule] reductify_leaf helper function (#90965)
As seen in
https://docs.google.com/document/d/1bIQkWXy3J35_20c_a5kchikabBW5M8_uRAhl0BIMwU4/edit

`reductify_leaf(grad_input, ...)` is a helper function that processes a
single grad_input Tensor. The reason why we need it is:
- the grad_input has some optional bdim
- the input has some optional bdim
- if these are different, we need to coerce the grad_input into having
the same shape as the input, either by reducing or expanding the
grad_input.

Note that there is a special case in autograd that the user is allowed
to return a grad_input Tensor that is an expanded version of the
original input tensor. In this case, autograd automatically reduces
grad_input to the same shape as the input. Unfortunately this logic
doesn't work when bdims are involved, so we manually handle it in
`reductify_leaf`.

Test Plan:
- tests
Pull Request resolved: https://github.com/pytorch/pytorch/pull/90965
Approved by: https://github.com/soulitzer
2022-12-21 00:34:44 +00:00

328 lines
13 KiB
Python

import torch
from torch._ops import PyOperator
from torch._C._functorch import TransformType
from torch._functorch.utils import enable_autograd_function
import torch.utils._pytree as pytree
from torch._C._functorch import (
_wrap_for_grad,
_unwrap_for_grad,
)
from torch._functorch.vmap import (
wrap_batched,
unwrap_batched,
vmap,
)
from torch.autograd.forward_ad import _set_fwd_grad_enabled
from typing import NamedTuple, Tuple
# autograd.Function technically runs before the regular PyTorch dispatcher.
# This is how features like autocast and torch_dispatch (e.g. PythonTLSSnapshot)
# work with it. One day we might decide to change this, but until then,
# we need to give the illusion that autograd.Function runs before those things.
#
# We do this by creating a custom PyOperator that only functorch
# dispatches specially.
class CustomFunctionPyOperator(PyOperator):
    """PyOperator named ``'custom_function_call'`` that fronts an autograd.Function.

    Calling it with ``(autograd_function, *operands)`` either dispatches
    through functorch (when functorch transforms are active) or applies the
    autograd.Function directly.
    """

    def __init__(self):
        super().__init__('custom_function_call')

    def __call__(self, *args, **kwargs):
        # Once there is no functorch transform left to dispatch through, just
        # invoke the autograd.Function. This matches the autograd.Function
        # behavior of being invoked before the PyTorch dispatcher.
        #
        # Known, pre-existing caveat: a function traced by make_fx is supposed
        # to behave the same when given the same Tensor. However, make_fx sees
        # autograd.Function as a composite (because autograd.Function happens
        # before the Python dispatch key) and only traces the forward pass, so
        # this will lead us into trouble later down the line.
        if not torch._C._are_functorch_transforms_active():
            autograd_function = args[0]
            operands = args[1:]
            return autograd_function.apply(*operands, **kwargs)
        return super().__call__(*args, **kwargs)
# "custom_function_call"
# This is the mechanism for an autograd.Function that works with functorch transforms.
# It wraps an autograd.Function; interactions with functorch transforms are defined
# via PyDispatcher and PyOperator rather than through the traditional PyTorch
# dispatcher.
#
# This is a module-level singleton: the per-transform rules in this file are
# attached to it through its ``py_impl`` decorator.
custom_function_call = CustomFunctionPyOperator()
# The grad rule for custom_function_call is to construct a new _SingleLevelFunction
# (autograd.Function that only works with a single layer (level) of functorch) that:
# - unwraps the inputs
# - redispatches to custom_function_call
# - wraps the outputs
# and whose backward pass calls the original autograd.Function's backward.
#
# Why do we need to redispatch to custom_function_call?
# -----------------------------------------------------
# This is consistent with how ATen operators work with functorch's grad transform:
# they always redispatch to the original operator.
# Consider torch.sin, and let's say we do grad0(grad1(torch.sin))(x)
#
# grad1 will:
# - set up the autograd graph
# - unwrap the inputs
# - redispatch to at::sin (*)
# - rewrap the outputs on the return
#
# On the redispatch in (*), grad0 will:
# - set up the autograd graph
# - unwrap the inputs
# - redispatch to at::sin
# - rewrap the outputs on the return
#
# To "set up the autograd graph", we generate a _SingleLevelFunction
# and apply it.
@custom_function_call.py_impl(TransformType.Grad)
@custom_function_call.py_impl(TransformType.Jvp)
def custom_function_call_grad(interpreter, autograd_function, *operands):
    """Rule for ``custom_function_call`` under the Grad and Jvp transforms.

    Generates a single-level autograd.Function for ``autograd_function`` at
    the interpreter's level and applies it to ``operands`` inside the
    ``enable_autograd_function()`` context.

    Returns whatever the generated function's ``apply`` returns.
    """
    single_level_function = generate_single_level_function(interpreter, autograd_function)
    with enable_autograd_function():
        outputs = single_level_function.apply(*operands)
    return outputs
def generate_single_level_function(interpreter, autograd_function):
    """Create a ``_SingleLevelFunction`` subclass that wraps ``autograd_function``.

    The generated class:
    - ``forward``: unwraps Tensor operands for ``interpreter``'s level,
      redispatches to ``custom_function_call`` with the interpreter lowered,
      and re-wraps Tensor outputs for the same level;
    - ``setup_context``: replaces ``ctx.mark_dirty`` with an erroring stub and
      defers to ``autograd_function.setup_context``;
    - ``backward`` / ``jvp``: defer to the corresponding staticmethods of
      ``autograd_function``.

    Args:
        interpreter: object providing ``level()`` and a ``lower()`` context
            manager.
        autograd_function: the autograd.Function being wrapped; its
            ``__name__`` is used to name the generated class.

    Returns:
        The newly created class, named ``<autograd_function.__name__>Generated``.
    """
    level = interpreter.level()

    def unwrap(tensor):
        return _unwrap_for_grad(tensor, level)

    def wrap(tensor):
        return _wrap_for_grad(tensor, level)

    def forward(*operands):
        unwrapped_operands = pytree.tree_map_only(torch.Tensor, unwrap, operands)
        # Both enable_grad() and _set_fwd_grad_enabled() are necessary no matter
        # the transform. _SingleLevelFunction will turn off both fwd and bwd
        # gradient computation and we need to turn it back on here.
        with torch.enable_grad(), _set_fwd_grad_enabled(True), interpreter.lower():
            output = custom_function_call(autograd_function, *unwrapped_operands)
        # Re-wrapping happens after the interpreter has been restored.
        return pytree.tree_map_only(torch.Tensor, wrap, output)

    def setup_context(ctx, outputs, *operands):
        # Any ctx.mark_dirty(...) call made by the user's setup_context will
        # now raise instead of marking anything.
        ctx.mark_dirty = mark_dirty_error
        return autograd_function.setup_context(ctx, outputs, *operands)

    # backward is only used if the transform is TransformType.Grad
    def backward(ctx, *grads):
        return autograd_function.backward(ctx, *grads)

    # jvp is only used if the transform is TransformType.Jvp
    def jvp(ctx, *tangents):
        return autograd_function.jvp(ctx, *tangents)

    # Dynamically generate a subclass with a given name. A Tensor's .grad_fn
    # field has a class name that is the original autograd.Function's name +
    # Backward, so we do this to generate some meaningful name.
    namespace = {
        fn.__name__: staticmethod(fn)
        for fn in (forward, backward, jvp, setup_context)
    }
    return type(
        f'{autograd_function.__name__}Generated',
        (torch.autograd.function._SingleLevelFunction,),
        namespace,
    )
# https://github.com/pytorch/pytorch/issues/90225
# If an input was marked as dirty, and the autograd.Function returns the input
# from the forward, then the grad rule for custom_function_call must also
# return the corresponding input from the forward() of the Generated autograd.Function
#
# We haven't figured out how to do this yet. One possibility is to rely
# on if the return from the redispatched custom_function_call in Generated.forward
# has the same object id as one of the inputs,
# but https://github.com/pytorch/pytorch/issues/90209 means we cannot rely on
# that property.
def mark_dirty_error(*args, **kwargs):
    """Always raise: ``ctx.mark_dirty`` is unsupported under functorch transforms.

    Installed in place of ``ctx.mark_dirty`` by the generated
    ``setup_context``; accepts and ignores any arguments.

    Raises:
        RuntimeError: unconditionally.
    """
    message = (
        'NYI: we do not yet support ctx.mark_dirty with functorch transforms. '
        'Please try to avoid modifying inputs to the autograd.Function in-place '
        'by using out-of-place operations or by cloning the inputs. '
        'Please see https://github.com/pytorch/pytorch/issues/90209 for more details'
    )
    raise RuntimeError(message)
# NOTE: [functorch vjp and autograd interaction]
# There's an edge case with the functorch vjp and autograd interaction
# that will eventually be fixed by mode-only functorch.
# The TL;DR is that there's no way to unwrap a dead GradTensorWrapper,
# so we (the framework) need to do it manually. Regular PyTorch operators
# automatically do so, so this is consistent.
#
# class MyExp(torch.autograd.Function):
# @staticmethod
# def forward(x):
# return x.exp()
#
# @staticmethod
# def setup_context(ctx, outputs, x):
# y = outputs
# ctx.save_for_backward(y)
#
# @staticmethod
# def backward(ctx, gy):
# y, = ctx.saved_tensors
# return MyMul.apply(gy, y)
#
# x = torch.randn([], requires_grad=True)
# gy = torch.randn([], requires_grad=True)
# _, vjp_fn = vjp(MyExp.apply, x)
# result = vjp_fn(gy)
#
# MyMul is an autograd.Function that is not shown here.
# It saves a `y` for backward (since gy requires grad).
#
# in vjp_fn(gy), we get:
# > MyMul.apply(gy, GradTensorWrapper(y, level=dead))
# Because the y that is saved for backward by MyExp is a GradTensorWrapper
# but is now dead since we are outside the vjp context.
#
# PyTorch dispatcher operations, upon seeing a dead GradTensorWrapper,
# will automatically unwrap the GradTensorWrapper when applied.
# But since autograd.Function technically sits above the regular PyTorch
# dispatcher, it doesn't get this treatment. So we manually do
# the unwrapping to be consistent with regular PyTorch dispatcher operations.
class VmapInfo(NamedTuple):
    """Information about the current vmap level, passed as the first argument
    to an autograd.Function's ``vmap`` staticmethod.

    Constructed in ``custom_function_call_vmap`` from the active interpreter.
    """
    # Value of interpreter.batch_size() for the level being vmapped over.
    batch_size: int
    # Value of interpreter.randomness() for that level.
    randomness: str
@custom_function_call.py_impl(TransformType.Vmap)
def custom_function_call_vmap(interpreter, autograd_function, *operands):
    """Rule for ``custom_function_call`` under the Vmap transform.

    Unwraps the operands batched at the interpreter's level and hands them,
    together with a ``VmapInfo`` and their ``in_dims``, to the user-provided
    ``autograd_function.vmap`` staticmethod. The returned
    ``(output, out_dims)`` pair is wrapped back into batched form at the same
    level.

    Raises:
        RuntimeError: if ``autograd_function`` has no ``vmap`` attribute.
    """
    if not hasattr(autograd_function, "vmap"):
        # TODO: link docs when they're ready.
        # https://github.com/pytorch/pytorch/issues/90224
        raise RuntimeError(
            f"You tried to vmap over {autograd_function.__name__}, but "
            f"it does not have a vmap rule defined. Please add a vmap "
            f"staticmethod to it.")

    level = interpreter.level()
    vmap_info = VmapInfo(
        batch_size=interpreter.batch_size(),
        randomness=interpreter.randomness(),
    )
    unbatched_operands, in_dims = unwrap_batched(operands, level)

    def is_unbatched(dim):
        return dim is None

    # If none of the tensors are batched at the current level, then we skip the
    # current level. This saves the user from needing to handle this case in
    # their vmap staticmethod (and is consistent with our C++ batching rule API)
    if pytree.tree_all(is_unbatched, in_dims):
        with interpreter.lower():
            return custom_function_call(autograd_function, *operands)

    with interpreter.lower():
        raw_output, out_dims = autograd_function.vmap(vmap_info, in_dims, *unbatched_operands)

    # TODO: raise better error message to the user when they don't follow the API.
    # Should probably mimic the logic of _process_batched_inputs,
    # but that one is hyperspecialized on error messages.
    # https://github.com/pytorch/pytorch/issues/90224
    return wrap_batched(raw_output, out_dims, level)
@custom_function_call.py_impl(TransformType.Functionalize)
def custom_function_call_functionalize(interpreter, autograd_function, *operands):
    """Placeholder rule for the Functionalize transform.

    Raises:
        RuntimeError: always; no Functionalize rule is implemented yet.
    """
    message = "NYI: Functionalize rule for custom_function_call"
    raise RuntimeError(message)
# Wraps a ctx object. Forwards all attr accesses to the underlying object
# except for the attrs in _pt_reserved_attrs
class WrappedCtx:
    """Proxy around a ``ctx`` object.

    Attribute reads and writes are forwarded to the wrapped ctx, except for
    the names listed in ``_pt_reserved_attrs``, which live on the wrapper
    itself. Subclasses extend ``_pt_reserved_attrs`` to claim additional
    wrapper-local names.

    Raises:
        RuntimeError: from ``__init__`` if a ctx that is not itself a
            ``WrappedCtx`` already has an attribute with a reserved name.
    """

    _pt_reserved_attrs: Tuple[str, ...] = ('_pt_reserved_attrs', '_pt_inner_ctx')

    def __init__(self, ctx):
        # A ctx that is itself a WrappedCtx legitimately carries the reserved
        # names, so the collision check only applies to other ctx objects.
        if not isinstance(ctx, WrappedCtx):
            reserved_attrs = type(self)._pt_reserved_attrs
            if any(hasattr(ctx, name) for name in reserved_attrs):
                raise RuntimeError(
                    f'PyTorch reserves the {reserved_attrs} field on ctx. '
                    'Please name your fields on ctx something else to avoid name '
                    'collision.')
        self._pt_inner_ctx = ctx

    def __getattr__(self, name):
        # __getattr__ only runs when normal lookup has failed. For a reserved
        # name that means the wrapper has not been initialised (e.g. the
        # instance was created through __new__ without __init__). Forwarding
        # would read self._pt_inner_ctx, re-enter this method and recurse
        # until RecursionError, so report the attribute as missing instead.
        if name in type(self)._pt_reserved_attrs:
            raise AttributeError(
                f'{type(self).__name__!r} object has no attribute {name!r}')
        return getattr(self._pt_inner_ctx, name)

    def __setattr__(self, name, value):
        # Reserved names are stored on the wrapper; writing straight into
        # __dict__ avoids re-entering __setattr__.
        if name in type(self)._pt_reserved_attrs:
            self.__dict__[name] = value
            return
        # Everything else is written through to the wrapped ctx.
        return setattr(self._pt_inner_ctx, name, value)
# Wraps ctx to create a new ctx object that overrides saved_tensors.
class CtxWithSavedTensors(WrappedCtx):
    """``WrappedCtx`` whose ``saved_tensors`` returns a caller-supplied value.

    All other attribute access keeps the forwarding behavior inherited from
    ``WrappedCtx``.
    """

    # Reserve one extra wrapper-local slot on top of the base class's names.
    _pt_reserved_attrs = ('_pt_new_saved_tensors',) + WrappedCtx._pt_reserved_attrs

    def __init__(self, ctx, new_saved_tensors):
        super().__init__(ctx)
        self._pt_new_saved_tensors = new_saved_tensors

    @property
    def saved_tensors(self):
        # Return the replacement given at construction time.
        return self._pt_new_saved_tensors
def reductify_leaf(grad_input, grad_input_bdim, input_bdim, input_shape_without_bdim, batch_size):
    """Coerce one ``grad_input`` so its batch dim and shape agree with the input's.

    Args:
        grad_input: gradient Tensor for a single input, or None.
        grad_input_bdim: dim of ``grad_input`` that is the batch dim, or None
            if ``grad_input`` has no batch dim.
        input_bdim: batch dim of the corresponding input, or None if the
            input has no batch dim.
        input_shape_without_bdim: target size given to ``Tensor.sum_to_size``
            for each per-example slice of ``grad_input``.
        batch_size: size given to the batch dim when an unbatched
            ``grad_input`` has to be expanded.

    Returns:
        - None if ``grad_input`` is None;
        - ``grad_input`` itself if neither it nor the input is batched;
        - ``grad_input`` summed over its batch dim if only ``grad_input`` is
          batched;
        - otherwise a Tensor with its batch dim at ``input_bdim`` whose
          per-example slices were reduced with ``sum_to_size``.
    """
    if grad_input is None:
        return None

    if input_bdim is None:
        # The input is not batched: pass an unbatched gradient through and
        # sum a batched one over its batch dim.
        if grad_input_bdim is None:
            return grad_input
        return grad_input.sum(grad_input_bdim)

    # From here on the input is batched (input_bdim is not None).
    #
    # Given a grad_input and input, it is valid for the user to return a
    # grad_input that has a broadcasted shape when compared to the input.
    # In this situation, autograd automatically reduces the grad_input to
    # the shape of the input.
    #
    # However, when input_bdim is not None, we have problems.
    #
    # [example 1]
    # grad_input: Tensor[3, 4], input: Tensor[B, 4]
    # We can expand grad_input to Tensor[B, 3, 4], but that isn't broadcastable
    # from [B, 4].
    #
    # [example 2]
    # grad_input: Tensor[3, B, 4], input: Tensor[B, 4]
    # We can swizzle grad_input to Tensor[B, 3, 4], but that isn't broadcastable
    # from [B, 4].
    #
    # This means that we need to also reduce the grad_input to the shape of the
    # input.
    if grad_input_bdim is None:
        # Give the unbatched gradient a batch dim of size batch_size at the
        # same position as the input's batch dim.
        unsqueezed = grad_input.unsqueeze(input_bdim)
        batched_shape = list(unsqueezed.shape)
        batched_shape[input_bdim] = batch_size
        grad_input = unsqueezed.expand(batched_shape)
        grad_input_bdim = input_bdim

    # Reduce every per-example slice to the unbatched input shape and place
    # the batch dim of the result where the input has it.
    reduce_per_example = vmap(
        torch.Tensor.sum_to_size,
        in_dims=(grad_input_bdim, None),
        out_dims=input_bdim,
    )
    return reduce_per_example(grad_input, input_shape_without_bdim)