diff --git a/test/functorch/test_eager_transforms.py b/test/functorch/test_eager_transforms.py
index 16401353fa3..74681c4d4b0 100644
--- a/test/functorch/test_eager_transforms.py
+++ b/test/functorch/test_eager_transforms.py
@@ -1087,6 +1087,32 @@ class TestAutogradFunction(TestCase):
         grad(h, argnums=(0, 1))(x, grad_y)
 
+    @_set_autograd_function_extension_enabled()
+    def test_grad_fn_name(self, device):
+        names = []
+
+        class FooBar(torch.autograd.Function):
+            @staticmethod
+            def forward(x):
+                return x.clone()
+
+            @staticmethod
+            def setup_context(ctx, inputs, outputs):
+                return
+
+            @staticmethod
+            def backward(ctx, grad_output):
+                return grad_output
+
+        def f(x):
+            y = FooBar.apply(x)
+            names.append(type(y.grad_fn).__name__)
+            return y
+
+        x = torch.tensor(1.)
+        grad(f)(x)
+        self.assertEqual(names, ['FooBarGeneratedBackward'])
+
 
 class TestAutogradFunctionVmapAPI(TestCase):
     @_set_autograd_function_extension_enabled()
diff --git a/torch/_functorch/autograd_function.py b/torch/_functorch/autograd_function.py
index a1b4744ebda..9c38eb8dd82 100644
--- a/torch/_functorch/autograd_function.py
+++ b/torch/_functorch/autograd_function.py
@@ -82,53 +82,63 @@ custom_function_call = CustomFunctionPyOperator()
 @custom_function_call.py_impl(TransformType.Grad)
 @custom_function_call.py_impl(TransformType.Jvp)
 def custom_function_call_grad(interpreter, autograd_function, *operands):
-    maybe_interpreter = interpreter
-    level = maybe_interpreter.level()
-
-    # TODO: The name of the grad_fn is GeneratedBackward. This isn't a great UX,
-    # but in theory functorch users shouldn't be peeking at the grad_fn.
-    # We should try to generate a better name for this.
-    # https://github.com/pytorch/pytorch/issues/90224
-    class Generated(torch.autograd.function._SingleLevelFunction):
-        @staticmethod
-        def forward(*operands):
-            unwrapped_operands = pytree.tree_map_only(
-                torch.Tensor,
-                lambda x: _unwrap_for_grad(x, level),
-                operands)
-            # Both enable_grad() and _set_fwd_grad_enabled() are necessary no matter
-            # the transform. _SingleLevelFunction will turn off both fwd and bwd
-            # gradient computation and we need to turn it back on here.
-            with torch.enable_grad(), _set_fwd_grad_enabled(True), maybe_interpreter.lower():
-                output = custom_function_call(autograd_function, *unwrapped_operands)
-
-            return pytree.tree_map_only(
-                torch.Tensor,
-                lambda x: _wrap_for_grad(x, level),
-                output)
-
-        @staticmethod
-        def setup_context(ctx, outputs, *operands):
-            ctx.mark_dirty = mark_dirty_error
-            return autograd_function.setup_context(ctx, outputs, *operands)
-
-        # backward is only used if the transform is TransformType.Grad
-        @staticmethod
-        def backward(ctx, *grads):
-            result = autograd_function.backward(ctx, *grads)
-            return result
-
-        # jvp is only used if the transform is TransformType.Jvp
-        @staticmethod
-        def jvp(ctx, *tangents):
-            result = autograd_function.jvp(ctx, *tangents)
-            return result
-
+    Generated = generate_single_level_function(interpreter, autograd_function)
     with enable_autograd_function():
         flat_out = Generated.apply(*operands)
     return flat_out
 
 
+def generate_single_level_function(interpreter, autograd_function):
+    level = interpreter.level()
+
+    def forward(*operands):
+        unwrapped_operands = pytree.tree_map_only(
+            torch.Tensor,
+            lambda x: _unwrap_for_grad(x, level),
+            operands)
+        # Both enable_grad() and _set_fwd_grad_enabled() are necessary no matter
+        # the transform. _SingleLevelFunction will turn off both fwd and bwd
+        # gradient computation and we need to turn it back on here.
+        with torch.enable_grad(), _set_fwd_grad_enabled(True), interpreter.lower():
+            output = custom_function_call(autograd_function, *unwrapped_operands)
+
+        return pytree.tree_map_only(
+            torch.Tensor,
+            lambda x: _wrap_for_grad(x, level),
+            output)
+
+    def setup_context(ctx, outputs, *operands):
+        ctx.mark_dirty = mark_dirty_error
+        return autograd_function.setup_context(ctx, outputs, *operands)
+
+    # backward is only used if the transform is TransformType.Grad
+    def backward(ctx, *grads):
+        result = autograd_function.backward(ctx, *grads)
+        return result
+
+    # jvp is only used if the transform is TransformType.Jvp
+    def jvp(ctx, *tangents):
+        result = autograd_function.jvp(ctx, *tangents)
+        return result
+
+    # This is the sequence of magic words to dynamically generate a Subclass with
+    # a given name. A Tensor's .grad_fn field has a class name that is the original
+    # autograd.Function's name + Backward, so we do this to generate some
+    # meaningful name.
+    name = f'{autograd_function.__name__}Generated'
+    Generated = type(
+        name,
+        (torch.autograd.function._SingleLevelFunction,),
+        {
+            'forward': staticmethod(forward),
+            'backward': staticmethod(backward),
+            'jvp': staticmethod(jvp),
+            'setup_context': staticmethod(setup_context),
+        },
+    )
+    return Generated
+
+
 # https://github.com/pytorch/pytorch/issues/90225
 # If an input was marked as dirty, and the autograd.Function returns the input
 # from the forward, then the grad rule for custom_function_call must also
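
Note for reviewers: the core trick in this patch is the three-argument form of type(), which is the
dynamic equivalent of a class statement. The string passed as the first argument becomes the class's
__name__, and autograd appends "Backward" to that name when naming the grad_fn node, which is why the
new test expects 'FooBarGeneratedBackward'. Below is a minimal standalone sketch of the naming pattern
using a plain torch.autograd.Function instead of _SingleLevelFunction; the make_named_function helper
and the MyClone name are illustrative only, not part of this patch:

    import torch

    def make_named_function(name):
        # Hypothetical helper (not in the patch). type(name, bases, namespace)
        # builds a class whose __name__ is `name`; autograd derives the grad_fn
        # class name from it by appending "Backward".
        def forward(x):
            return x.clone()

        def setup_context(ctx, inputs, output):
            return

        def backward(ctx, grad_output):
            return grad_output

        return type(
            name,
            (torch.autograd.Function,),
            {
                'forward': staticmethod(forward),
                'setup_context': staticmethod(setup_context),
                'backward': staticmethod(backward),
            },
        )

    MyClone = make_named_function('MyClone')
    x = torch.tensor(1., requires_grad=True)
    y = MyClone.apply(x)
    print(type(y.grad_fn).__name__)  # prints 'MyCloneBackward'

This is the same mechanism generate_single_level_function uses, except there the name is derived from
the user's autograd.Function ({autograd_function.__name__}Generated) and the base class is the
single-level variant used by the functorch transforms.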