[pyfunctorch] Generate a more meaningful name for _SingleLevelAutogradFunction (#90418)
The API to do this is not pretty, but at least it works.

Test Plan:
- new test

Pull Request resolved: https://github.com/pytorch/pytorch/pull/90418
Approved by: https://github.com/soulitzer
This commit is contained in:
parent da42eab48b
commit f21cb7d77e
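To illustrate the technique this commit relies on, here is a minimal, hypothetical sketch (not part of the diff below; the FooBarGenerated name and the trivial forward/backward are made up). A Tensor's .grad_fn node is named after the autograd.Function subclass plus "Backward", so building the subclass dynamically with type() lets us pick that name:

import torch

# Plain functions that will become the Function's static methods.
def forward(ctx, x):
    return x.clone()

def backward(ctx, grad_output):
    return grad_output

# type(name, bases, namespace) creates a class with exactly the name we choose;
# the autograd node for its outputs is then named "FooBarGeneratedBackward".
# (Names here are illustrative, not from the commit.)
FooBarGenerated = type(
    "FooBarGenerated",
    (torch.autograd.Function,),
    {"forward": staticmethod(forward), "backward": staticmethod(backward)},
)

x = torch.tensor(1.0, requires_grad=True)
y = FooBarGenerated.apply(x)
print(type(y.grad_fn).__name__)  # FooBarGeneratedBackward

The diff does the same thing with torch.autograd.function._SingleLevelFunction as the base, naming the generated class after the user's autograd.Function.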
@@ -1087,6 +1087,32 @@ class TestAutogradFunction(TestCase):
         grad(h, argnums=(0, 1))(x, grad_y)
 
+    @_set_autograd_function_extension_enabled()
+    def test_grad_fn_name(self, device):
+        names = []
+
+        class FooBar(torch.autograd.Function):
+            @staticmethod
+            def forward(x):
+                return x.clone()
+
+            @staticmethod
+            def setup_context(ctx, inputs, outputs):
+                return
+
+            @staticmethod
+            def backward(ctx, grad_output):
+                return grad_output
+
+        def f(x):
+            y = FooBar.apply(x)
+            names.append(type(y.grad_fn).__name__)
+            return y
+
+        x = torch.tensor(1.)
+        grad(f)(x)
+        self.assertEqual(names, ['FooBarGeneratedBackward'])
+
 
 class TestAutogradFunctionVmapAPI(TestCase):
     @_set_autograd_function_extension_enabled()
@@ -82,53 +82,63 @@ custom_function_call = CustomFunctionPyOperator()
 @custom_function_call.py_impl(TransformType.Grad)
 @custom_function_call.py_impl(TransformType.Jvp)
 def custom_function_call_grad(interpreter, autograd_function, *operands):
-    maybe_interpreter = interpreter
-    level = maybe_interpreter.level()
-
-    # TODO: The name of the grad_fn is GeneratedBackward. This isn't a great UX,
-    # but in theory functorch users shouldn't be peeking at the grad_fn.
-    # We should try to generate a better name for this.
-    # https://github.com/pytorch/pytorch/issues/90224
-    class Generated(torch.autograd.function._SingleLevelFunction):
-        @staticmethod
-        def forward(*operands):
-            unwrapped_operands = pytree.tree_map_only(
-                torch.Tensor,
-                lambda x: _unwrap_for_grad(x, level),
-                operands)
-            # Both enable_grad() and _set_fwd_grad_enabled() are necessary no matter
-            # the transform. _SingleLevelFunction will turn off both fwd and bwd
-            # gradient computation and we need to turn it back on here.
-            with torch.enable_grad(), _set_fwd_grad_enabled(True), maybe_interpreter.lower():
-                output = custom_function_call(autograd_function, *unwrapped_operands)
-
-            return pytree.tree_map_only(
-                torch.Tensor,
-                lambda x: _wrap_for_grad(x, level),
-                output)
-
-        @staticmethod
-        def setup_context(ctx, outputs, *operands):
-            ctx.mark_dirty = mark_dirty_error
-            return autograd_function.setup_context(ctx, outputs, *operands)
-
-        # backward is only used if the transform is TransformType.Grad
-        @staticmethod
-        def backward(ctx, *grads):
-            result = autograd_function.backward(ctx, *grads)
-            return result
-
-        # jvp is only used if the transform is TransformType.Jvp
-        @staticmethod
-        def jvp(ctx, *tangents):
-            result = autograd_function.jvp(ctx, *tangents)
-            return result
-
+    Generated = generate_single_level_function(interpreter, autograd_function)
     with enable_autograd_function():
         flat_out = Generated.apply(*operands)
     return flat_out
 
 
+def generate_single_level_function(interpreter, autograd_function):
+    level = interpreter.level()
+
+    def forward(*operands):
+        unwrapped_operands = pytree.tree_map_only(
+            torch.Tensor,
+            lambda x: _unwrap_for_grad(x, level),
+            operands)
+        # Both enable_grad() and _set_fwd_grad_enabled() are necessary no matter
+        # the transform. _SingleLevelFunction will turn off both fwd and bwd
+        # gradient computation and we need to turn it back on here.
+        with torch.enable_grad(), _set_fwd_grad_enabled(True), interpreter.lower():
+            output = custom_function_call(autograd_function, *unwrapped_operands)
+
+        return pytree.tree_map_only(
+            torch.Tensor,
+            lambda x: _wrap_for_grad(x, level),
+            output)
+
+    def setup_context(ctx, outputs, *operands):
+        ctx.mark_dirty = mark_dirty_error
+        return autograd_function.setup_context(ctx, outputs, *operands)
+
+    # backward is only used if the transform is TransformType.Grad
+    def backward(ctx, *grads):
+        result = autograd_function.backward(ctx, *grads)
+        return result
+
+    # jvp is only used if the transform is TransformType.Jvp
+    def jvp(ctx, *tangents):
+        result = autograd_function.jvp(ctx, *tangents)
+        return result
+
+    # This is the sequence of magic words to dynamically generate a Subclass with
+    # a given name. A Tensor's .grad_fn field has a class name that is the original
+    # autograd.Function's name + Backward, so we do this to generate some
+    # meaningful name.
+    name = f'{autograd_function.__name__}Generated'
+    Generated = type(
+        name,
+        (torch.autograd.function._SingleLevelFunction,),
+        {
+            'forward': staticmethod(forward),
+            'backward': staticmethod(backward),
+            'jvp': staticmethod(jvp),
+            'setup_context': staticmethod(setup_context),
+        },
+    )
+    return Generated
+
+
 # https://github.com/pytorch/pytorch/issues/90225
 # If an input was marked as dirty, and the autograd.Function returns the input
 # from the forward, then the grad rule for custom_function_call must also
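The generated forward() above unwraps and re-wraps only the Tensor leaves of its operands via pytree.tree_map_only. A small sketch of that pattern, using the internal torch.utils._pytree helper with made-up values (not part of the commit):

import torch
from torch.utils._pytree import tree_map_only

# A pytree of operands: Tensors mixed with non-Tensor leaves.
operands = (torch.randn(2), 3.14, [torch.ones(1), "tag"])

# Apply the function only to Tensor leaves; the float and the string pass
# through untouched, and the tuple/list structure is preserved.
doubled = tree_map_only(torch.Tensor, lambda t: t * 2, operands)
print(doubled)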