[pyfunctorch] Generate a more meaningful name for _SingleLevelAutogradFunction (#90418)

The API to do this is not pretty, but at least it works.
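
For the curious: the "not pretty" API is Python's three-argument `type(name, bases, namespace)` constructor, which is the dynamic equivalent of a `class` statement. A minimal sketch of the naming mechanics outside functorch (the `MyFunc` name and the doubling math are made up for illustration):

    import torch

    # Hypothetical forward/backward, purely to illustrate the naming.
    def forward(ctx, x):
        return x * 2

    def backward(ctx, grad_output):
        return grad_output * 2

    # type() picks up autograd.Function's metaclass from the base class,
    # so the generated class behaves like a hand-written autograd.Function.
    Generated = type('MyFuncGenerated', (torch.autograd.Function,), {
        'forward': staticmethod(forward),
        'backward': staticmethod(backward),
    })

    x = torch.tensor(1., requires_grad=True)
    y = Generated.apply(x)
    # autograd names the node after the class, plus a "Backward" suffix.
    print(type(y.grad_fn).__name__)  # MyFuncGeneratedBackward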

Test Plan:
- new test
Pull Request resolved: https://github.com/pytorch/pytorch/pull/90418
Approved by: https://github.com/soulitzer
Richard Zou 2022-12-13 12:36:07 -08:00 committed by PyTorch MergeBot
parent da42eab48b
commit f21cb7d77e
2 changed files with 78 additions and 42 deletions


@@ -1087,6 +1087,32 @@ class TestAutogradFunction(TestCase):
         grad(h, argnums=(0, 1))(x, grad_y)
 
+    @_set_autograd_function_extension_enabled()
+    def test_grad_fn_name(self, device):
+        names = []
+
+        class FooBar(torch.autograd.Function):
+            @staticmethod
+            def forward(x):
+                return x.clone()
+
+            @staticmethod
+            def setup_context(ctx, inputs, outputs):
+                return
+
+            @staticmethod
+            def backward(ctx, grad_output):
+                return grad_output
+
+        def f(x):
+            y = FooBar.apply(x)
+            names.append(type(y.grad_fn).__name__)
+            return y
+
+        x = torch.tensor(1.)
+        grad(f)(x)
+        self.assertEqual(names, ['FooBarGeneratedBackward'])
+
 
 class TestAutogradFunctionVmapAPI(TestCase):
     @_set_autograd_function_extension_enabled()
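
The new test exercises the setup_context-style autograd.Function API (a `forward` with no `ctx`, plus a separate `setup_context`), which is what `_set_autograd_function_extension_enabled()` switches on. A self-contained sketch of that API under a functorch transform, assuming a build where the transform is exposed as `torch.func.grad` (at the time of this commit it was `functorch.grad`); `Square` is a made-up example:

    import torch
    from torch.func import grad

    class Square(torch.autograd.Function):
        @staticmethod
        def forward(x):  # no ctx here; state is stashed in setup_context
            return x ** 2

        @staticmethod
        def setup_context(ctx, inputs, output):
            x, = inputs
            ctx.save_for_backward(x)

        @staticmethod
        def backward(ctx, grad_output):
            x, = ctx.saved_tensors
            return 2 * x * grad_output

    x = torch.tensor(3.)
    print(grad(lambda t: Square.apply(t))(x))  # tensor(6.)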


@@ -82,53 +82,63 @@ custom_function_call = CustomFunctionPyOperator()
 
 @custom_function_call.py_impl(TransformType.Grad)
 @custom_function_call.py_impl(TransformType.Jvp)
 def custom_function_call_grad(interpreter, autograd_function, *operands):
-    maybe_interpreter = interpreter
-    level = maybe_interpreter.level()
-
-    # TODO: The name of the grad_fn is GeneratedBackward. This isn't a great UX,
-    # but in theory functorch users shouldn't be peeking at the grad_fn.
-    # We should try to generate a better name for this.
-    # https://github.com/pytorch/pytorch/issues/90224
-    class Generated(torch.autograd.function._SingleLevelFunction):
-        @staticmethod
-        def forward(*operands):
-            unwrapped_operands = pytree.tree_map_only(
-                torch.Tensor,
-                lambda x: _unwrap_for_grad(x, level),
-                operands)
-            # Both enable_grad() and _set_fwd_grad_enabled() are necessary no matter
-            # the transform. _SingleLevelFunction will turn off both fwd and bwd
-            # gradient computation and we need to turn it back on here.
-            with torch.enable_grad(), _set_fwd_grad_enabled(True), maybe_interpreter.lower():
-                output = custom_function_call(autograd_function, *unwrapped_operands)
-
-            return pytree.tree_map_only(
-                torch.Tensor,
-                lambda x: _wrap_for_grad(x, level),
-                output)
-
-        @staticmethod
-        def setup_context(ctx, outputs, *operands):
-            ctx.mark_dirty = mark_dirty_error
-            return autograd_function.setup_context(ctx, outputs, *operands)
-
-        # backward is only used if the transform is TransformType.Grad
-        @staticmethod
-        def backward(ctx, *grads):
-            result = autograd_function.backward(ctx, *grads)
-            return result
-
-        # jvp is only used if the transform is TransformType.Jvp
-        @staticmethod
-        def jvp(ctx, *tangents):
-            result = autograd_function.jvp(ctx, *tangents)
-            return result
-
+    Generated = generate_single_level_function(interpreter, autograd_function)
     with enable_autograd_function():
         flat_out = Generated.apply(*operands)
     return flat_out
+
+
+def generate_single_level_function(interpreter, autograd_function):
+    level = interpreter.level()
+
+    def forward(*operands):
+        unwrapped_operands = pytree.tree_map_only(
+            torch.Tensor,
+            lambda x: _unwrap_for_grad(x, level),
+            operands)
+        # Both enable_grad() and _set_fwd_grad_enabled() are necessary no matter
+        # the transform. _SingleLevelFunction will turn off both fwd and bwd
+        # gradient computation and we need to turn it back on here.
+        with torch.enable_grad(), _set_fwd_grad_enabled(True), interpreter.lower():
+            output = custom_function_call(autograd_function, *unwrapped_operands)
+
+        return pytree.tree_map_only(
+            torch.Tensor,
+            lambda x: _wrap_for_grad(x, level),
+            output)
+
+    def setup_context(ctx, outputs, *operands):
+        ctx.mark_dirty = mark_dirty_error
+        return autograd_function.setup_context(ctx, outputs, *operands)
+
+    # backward is only used if the transform is TransformType.Grad
+    def backward(ctx, *grads):
+        result = autograd_function.backward(ctx, *grads)
+        return result
+
+    # jvp is only used if the transform is TransformType.Jvp
+    def jvp(ctx, *tangents):
+        result = autograd_function.jvp(ctx, *tangents)
+        return result
+
+    # This is the sequence of magic words to dynamically generate a Subclass with
+    # a given name. A Tensor's .grad_fn field has a class name that is the original
+    # autograd.Function's name + Backward, so we do this to generate some
+    # meaningful name.
+    name = f'{autograd_function.__name__}Generated'
+    Generated = type(
+        name,
+        (torch.autograd.function._SingleLevelFunction,),
+        {
+            'forward': staticmethod(forward),
+            'backward': staticmethod(backward),
+            'jvp': staticmethod(jvp),
+            'setup_context': staticmethod(setup_context),
+        },
+    )
+    return Generated
 
 # https://github.com/pytorch/pytorch/issues/90225
 # If an input was marked as dirty, and the autograd.Function returns the input
 # from the forward, then the grad rule for custom_function_call must also
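
For reference, the `pytree.tree_map_only` helper that `forward` uses to unwrap and re-wrap operands applies a function only to leaves of the given type; everything else passes through untouched. A quick sketch; note that `torch.utils._pytree` is a private module, so the import path is an implementation detail:

    import torch
    from torch.utils import _pytree as pytree

    # Only Tensor leaves are transformed; the string passes through as-is.
    operands = (torch.tensor(1.), 'flag', [torch.tensor(2.)])
    doubled = pytree.tree_map_only(torch.Tensor, lambda t: t * 2, operands)
    print(doubled)  # (tensor(2.), 'flag', [tensor(4.)])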