pytorch/torch/_functorch/autograd_function.py

import torch
from torch._ops import PyOperator
from torch._C._functorch import TransformType
from torch._functorch.utils import enable_autograd_function
import torch.utils._pytree as pytree
from torch._C._functorch import (
    _wrap_for_grad,
    _unwrap_for_grad,
    _unwrap_batched,
)
from torch._functorch.vmap import (
    _broadcast_to_and_flatten,
    _create_batched_inputs,
)
from torch.autograd.forward_ad import _set_fwd_grad_enabled
from typing import NamedTuple


# autograd.Function technically runs before the regular PyTorch dispatcher.
# This is how features like autocast and torch_dispatch (e.g. PythonTLSSnapshot)
# work with it. One day we might decide to change this, but until then,
# we need to give the illusion that autograd.Function runs before those things.
#
# We do this by creating a custom PyOperator that only functorch
# dispatches specially.
class CustomFunctionPyOperator(PyOperator):
    def __init__(self):
        super().__init__('custom_function_call')

    def __call__(self, *args, **kwargs):
        # When custom_function_call is done dispatching through functorch,
        # it should just invoke the autograd.Function. This is consistent
        # with the autograd.Function behavior of being invoked before the
        # PyTorch dispatcher.
        #
        # This will lead us into trouble later down the line, but this is
        # pre-existing. There is an invariant that a function traced by
        # make_fx should have the same behavior when provided the same
        # Tensor. However, make_fx sees autograd.Function as a composite
        # (because autograd.Function happens before the Python dispatch key)
        # and only traces the forward pass.
        if torch._C._are_functorch_transforms_active():
            return super().__call__(*args, **kwargs)
        autograd_function = args[0]
        return autograd_function.apply(*args[1:], **kwargs)


# "custom_function_call"
# This is the mechanism for an autograd.Function that works with functorch transforms.
# It wraps an autograd.Function; interactions with functorch transforms are defined
# via PyDispatcher and PyOperator rather than through the traditional PyTorch
# dispatcher.
custom_function_call = CustomFunctionPyOperator()
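
# For illustration (a sketch; `MySin` is a hypothetical user-defined
# autograd.Function, not part of this module):
#   custom_function_call(MySin, x)
# behaves like `MySin.apply(x)` when no functorch transforms are active; when a
# transform is active, PyDispatcher routes the call to the per-transform rules
# registered below via py_impl.
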
# The grad rule for custom_function_call is to construct a new _SingleLevelFunction
# (autograd.Function that only works with a single layer (level) of functorch) that:
# - unwraps the inputs
# - redispatches to custom_function_call
# - wraps the outputs
# and whose backward pass calls the original autograd.Function's backward.
#
# Why do we need to redispatch to custom_function_call?
# -----------------------------------------------------
# This is consistent with how ATen operators work with functorch's grad transform:
# they always redispatch to the original operator.
# Consider torch.sin, and let's say we do grad0(grad1(torch.sin))(x)
#
# grad1 will:
# - set up the autograd graph
# - unwrap the inputs
# - redispatch to at::sin (*)
# - rewrap the outputs on the return
#
# On the redispatch in (*), grad0 will:
# - set up the autograd graph
# - unwrap the inputs
# - redispatch to at::sin
# - rewrap the outputs on the return
#
# To "set up the autograd graph", we generate a _SingleLevelFunction
# and apply it.
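#
# Concretely, a sketch using the functorch frontend (functorch.grad):
#   x = torch.randn([])
#   functorch.grad(functorch.grad(torch.sin))(x)   # second derivative, i.e. -sin(x)
# Each grad level builds its own autograd graph through a freshly generated
# _SingleLevelFunction, exactly as described above.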
@custom_function_call.py_impl(TransformType.Grad)
@custom_function_call.py_impl(TransformType.Jvp)
def custom_function_call_grad(interpreter, autograd_function, *operands):
    Generated = generate_single_level_function(interpreter, autograd_function)
    with enable_autograd_function():
        flat_out = Generated.apply(*operands)
    return flat_out


def generate_single_level_function(interpreter, autograd_function):
    level = interpreter.level()

    def forward(*operands):
        unwrapped_operands = pytree.tree_map_only(
            torch.Tensor,
            lambda x: _unwrap_for_grad(x, level),
            operands)
        # Both enable_grad() and _set_fwd_grad_enabled() are necessary no matter
        # the transform. _SingleLevelFunction will turn off both fwd and bwd
        # gradient computation and we need to turn it back on here.
        with torch.enable_grad(), _set_fwd_grad_enabled(True), interpreter.lower():
            output = custom_function_call(autograd_function, *unwrapped_operands)
        return pytree.tree_map_only(
            torch.Tensor,
            lambda x: _wrap_for_grad(x, level),
            output)

    def setup_context(ctx, outputs, *operands):
        ctx.mark_dirty = mark_dirty_error
        return autograd_function.setup_context(ctx, outputs, *operands)

    # backward is only used if the transform is TransformType.Grad
    def backward(ctx, *grads):
        result = autograd_function.backward(ctx, *grads)
        return result

    # jvp is only used if the transform is TransformType.Jvp
    def jvp(ctx, *tangents):
        result = autograd_function.jvp(ctx, *tangents)
        return result

    # This is the sequence of magic words to dynamically generate a subclass
    # with a given name. A Tensor's .grad_fn field has a class name that is the
    # original autograd.Function's name + "Backward", so we do this to generate
    # a meaningful name for the generated class.
    name = f'{autograd_function.__name__}Generated'
    Generated = type(
        name,
        (torch.autograd.function._SingleLevelFunction,),
        {
            'forward': staticmethod(forward),
            'backward': staticmethod(backward),
            'jvp': staticmethod(jvp),
            'setup_context': staticmethod(setup_context),
        },
    )
    return Generated
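

# A sketch of the effect (with a hypothetical `MySin` autograd.Function, not part
# of this module): under grad, MySin.apply routes through a class generated here
# named "MySinGenerated", so the autograd graph node created for it is
# "MySinGeneratedBackward" rather than a generic name.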


# https://github.com/pytorch/pytorch/issues/90225
# If an input was marked as dirty, and the autograd.Function returns the input
# from the forward, then the grad rule for custom_function_call must also
# return the corresponding input from the forward() of the Generated autograd.Function.
#
# We haven't figured out how to do this yet. One possibility is to check whether
# the return from the redispatched custom_function_call in Generated.forward
# has the same object id as one of the inputs, but
# https://github.com/pytorch/pytorch/issues/90209 means we cannot rely on
# that property.
def mark_dirty_error(*args, **kwargs):
    raise RuntimeError(
        'NYI: we do not yet support ctx.mark_dirty with functorch transforms. '
        'Please try to avoid modifying inputs to the autograd.Function in-place '
        'by using out-of-place operations or by cloning the inputs. '
        'Please see https://github.com/pytorch/pytorch/issues/90209 for more details'
    )
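

# A sketch of code that hits this error (hypothetical `MyRelu_`, not part of
# this module):
#
#   class MyRelu_(torch.autograd.Function):
#       @staticmethod
#       def forward(x):
#           return x.relu_()
#
#       @staticmethod
#       def setup_context(ctx, outputs, x):
#           ctx.mark_dirty(x)  # under a functorch transform, ctx.mark_dirty is
#                              # mark_dirty_error and raises the RuntimeError above
#
#       @staticmethod
#       def backward(ctx, gy):
#           ...  # elided; the error fires before backward is ever reached
#
#   grad(MyRelu_.apply)(torch.randn([]))  # -> RuntimeError: NYI ... ctx.mark_dirty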


# NOTE: [functorch vjp and autograd interaction]
# There's an edge case with the functorch vjp and autograd interaction
# that will eventually be fixed by mode-only functorch.
# The TL;DR is that there's no way to unwrap a dead GradTensorWrapper,
# so we (the framework) need to do it manually. Regular PyTorch operators
# do this automatically, so handling it here keeps things consistent.
#
# class MyExp(torch.autograd.Function):
#     @staticmethod
#     def forward(x):
#         return x.exp()
#
#     @staticmethod
#     def setup_context(ctx, outputs, x):
#         y = outputs
#         ctx.save_for_backward(y)
#
#     @staticmethod
#     def backward(ctx, gy):
#         y, = ctx.saved_tensors
#         return MyMul.apply(gy, y)
#
# x = torch.randn([], requires_grad=True)
# gy = torch.randn([], requires_grad=True)
# _, vjp_fn = vjp(MyExp.apply, x)
# result = vjp_fn(gy)
#
# MyMul is an autograd.Function that is not shown here.
# It saves a `y` for backward (since gy requires grad).
#
# In vjp_fn(gy), we get:
# > MyMul.apply(gy, GradTensorWrapper(y, level=dead))
# because the y that MyExp saved for backward is a GradTensorWrapper,
# and it is now dead since we are outside the vjp context.
#
# PyTorch dispatcher operations, upon seeing a dead GradTensorWrapper,
# will automatically unwrap the GradTensorWrapper when applied.
# But since autograd.Function technically sits above the regular PyTorch
# dispatcher, it doesn't get this treatment. So we manually do
# the unwrapping to be consistent with regular PyTorch dispatcher operations.
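#
# For completeness, a plausible MyMul for the example above (a hypothetical
# sketch; the NOTE deliberately leaves it out):
#
# class MyMul(torch.autograd.Function):
#     @staticmethod
#     def forward(x, y):
#         return x * y
#
#     @staticmethod
#     def setup_context(ctx, outputs, x, y):
#         ctx.save_for_backward(x, y)
#
#     @staticmethod
#     def backward(ctx, gxy):
#         x, y = ctx.saved_tensors
#         return gxy * y, gxy * x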


class VmapInfo(NamedTuple):
    batch_size: int


@custom_function_call.py_impl(TransformType.Vmap)
def custom_function_call_vmap(interpreter, autograd_function, *operands):
    if not hasattr(autograd_function, "vmap"):
        # TODO: link docs when they're ready.
        # https://github.com/pytorch/pytorch/issues/90224
        raise RuntimeError(
            f"You tried to vmap over {autograd_function.__name__}, but "
            f"it does not have a vmap rule defined. Please add a vmap "
            f"staticmethod to it.")

    current_level = interpreter.level()
    info = VmapInfo(batch_size=interpreter.batch_size())
    unwrapped_operands, in_dims = unwrap_batched(operands, current_level)

    # If none of the tensors are batched at the current level, then we skip the
    # current level. This saves the user from needing to handle this case in
    # their vmap staticmethod (and is consistent with our C++ batching rule API).
    if pytree.tree_all(lambda dim: dim is None, in_dims):
        with interpreter.lower():
            return custom_function_call(autograd_function, *operands)

    with interpreter.lower():
        unwrapped_output, out_dims = autograd_function.vmap(info, in_dims, *unwrapped_operands)

    output = wrap_batched(unwrapped_output, out_dims, current_level)
    return output
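

# A sketch of the vmap staticmethod this rule expects (hypothetical `MySin`,
# not part of this module). It receives already-unwrapped operands plus their
# batch dims and must return (output, out_dims):
#
#   class MySin(torch.autograd.Function):
#       ...
#       @staticmethod
#       def vmap(info, in_dims, x):
#           # in_dims mirrors the operands: in_dims[0] is x's batch dim (or None).
#           # info.batch_size is the size of the dimension being vmapped over.
#           return x.sin(), in_dims[0]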


def unwrap_batched(args, level):
    flat_args, spec = pytree.tree_flatten(args)
    if len(flat_args) == 0:
        return args, ()
    result = [_unwrap_batched(arg, level) if isinstance(arg, torch.Tensor)
              else (arg, None) for arg in flat_args]
    output, bdims = zip(*result)
    return pytree.tree_unflatten(output, spec), pytree.tree_unflatten(bdims, spec)
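

# For example (a sketch, assuming _unwrap_batched returns a (tensor, bdim) pair):
# if `args` is (batched_tensor, 2) at this level, unwrap_batched returns
# ((plain_tensor, 2), (bdim, None)) -- non-Tensor leaves get a batch dim of None.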


def wrap_batched(args, bdims, level):
    # TODO: raise a better error message to the user when they don't follow the API.
    # Should probably mimic the logic of _process_batched_inputs,
    # but that one is hyperspecialized on error messages.
    # https://github.com/pytorch/pytorch/issues/90224
    flat_args, spec = pytree.tree_flatten(args)
    flat_bdims = _broadcast_to_and_flatten(bdims, spec)
    assert flat_bdims is not None
    result = _create_batched_inputs(flat_bdims, flat_args, level, spec)
    return result


@custom_function_call.py_impl(TransformType.Functionalize)
def custom_function_call_functionalize(interpreter, autograd_function, *operands):
    raise RuntimeError("NYI: Functionalize rule for custom_function_call")