[pyfunctorch] Generate a more meaningful name for _SingleLevelAutogradFunction (#90418)
The API to do this is not pretty, but at least it works.

Test Plan:
- new test

Pull Request resolved: https://github.com/pytorch/pytorch/pull/90418
Approved by: https://github.com/soulitzer
parent da42eab48b
commit f21cb7d77e
@@ -1087,6 +1087,32 @@ class TestAutogradFunction(TestCase):
         grad(h, argnums=(0, 1))(x, grad_y)
 
+    @_set_autograd_function_extension_enabled()
+    def test_grad_fn_name(self, device):
+        names = []
+
+        class FooBar(torch.autograd.Function):
+            @staticmethod
+            def forward(x):
+                return x.clone()
+
+            @staticmethod
+            def setup_context(ctx, inputs, outputs):
+                return
+
+            @staticmethod
+            def backward(ctx, grad_output):
+                return grad_output
+
+        def f(x):
+            y = FooBar.apply(x)
+            names.append(type(y.grad_fn).__name__)
+            return y
+
+        x = torch.tensor(1.)
+        grad(f)(x)
+        self.assertEqual(names, ['FooBarGeneratedBackward'])
+
 
 class TestAutogradFunctionVmapAPI(TestCase):
     @_set_autograd_function_extension_enabled()

@@ -82,15 +82,15 @@ custom_function_call = CustomFunctionPyOperator()
 @custom_function_call.py_impl(TransformType.Grad)
 @custom_function_call.py_impl(TransformType.Jvp)
 def custom_function_call_grad(interpreter, autograd_function, *operands):
-    maybe_interpreter = interpreter
-    level = maybe_interpreter.level()
+    Generated = generate_single_level_function(interpreter, autograd_function)
+    with enable_autograd_function():
+        flat_out = Generated.apply(*operands)
+    return flat_out
+
+
+def generate_single_level_function(interpreter, autograd_function):
+    level = interpreter.level()
 
-    # TODO: The name of the grad_fn is GeneratedBackward. This isn't a great UX,
-    # but in theory functorch users shouldn't be peeking at the grad_fn.
-    # We should try to generate a better name for this.
-    # https://github.com/pytorch/pytorch/issues/90224
-    class Generated(torch.autograd.function._SingleLevelFunction):
-        @staticmethod
     def forward(*operands):
         unwrapped_operands = pytree.tree_map_only(
             torch.Tensor,
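Side note on the helper used for the wrap/unwrap plumbing in this hunk: pytree.tree_map_only applies a function only to leaves of a given type and passes every other leaf through unchanged. A minimal sketch, not part of the commit, assuming the pytree module here is torch.utils._pytree and using made-up example data:

import torch
import torch.utils._pytree as pytree

# Only torch.Tensor leaves are transformed; non-Tensor leaves are returned as-is.
tree = {'a': torch.ones(2), 'b': 3, 'c': [torch.zeros(1), 'hello']}
doubled = pytree.tree_map_only(torch.Tensor, lambda t: t * 2, tree)
print(doubled['a'])  # tensor([2., 2.])
print(doubled['b'])  # 3
print(doubled['c'])  # [tensor([0.]), 'hello']

In the diff, the same call shape is used with wrappers like _wrap_for_grad to move tensors across the functorch interpreter level.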
@@ -99,7 +99,7 @@ def custom_function_call_grad(interpreter, autograd_function, *operands):
         # Both enable_grad() and _set_fwd_grad_enabled() are necessary no matter
         # the transform. _SingleLevelFunction will turn off both fwd and bwd
         # gradient computation and we need to turn it back on here.
-        with torch.enable_grad(), _set_fwd_grad_enabled(True), maybe_interpreter.lower():
+        with torch.enable_grad(), _set_fwd_grad_enabled(True), interpreter.lower():
             output = custom_function_call(autograd_function, *unwrapped_operands)
 
         return pytree.tree_map_only(
@@ -107,26 +107,36 @@ def custom_function_call_grad(interpreter, autograd_function, *operands):
             lambda x: _wrap_for_grad(x, level),
             output)
 
-        @staticmethod
     def setup_context(ctx, outputs, *operands):
         ctx.mark_dirty = mark_dirty_error
         return autograd_function.setup_context(ctx, outputs, *operands)
 
     # backward is only used if the transform is TransformType.Grad
-        @staticmethod
     def backward(ctx, *grads):
         result = autograd_function.backward(ctx, *grads)
         return result
 
     # jvp is only used if the transform is TransformType.Jvp
-        @staticmethod
     def jvp(ctx, *tangents):
         result = autograd_function.jvp(ctx, *tangents)
         return result
 
-    with enable_autograd_function():
-        flat_out = Generated.apply(*operands)
-    return flat_out
+    # This is the sequence of magic words to dynamically generate a Subclass with
+    # a given name. A Tensor's .grad_fn field has a class name that is the original
+    # autograd.Function's name + Backward, so we do this to generate some
+    # meaningful name.
+    name = f'{autograd_function.__name__}Generated'
+    Generated = type(
+        name,
+        (torch.autograd.function._SingleLevelFunction,),
+        {
+            'forward': staticmethod(forward),
+            'backward': staticmethod(backward),
+            'jvp': staticmethod(jvp),
+            'setup_context': staticmethod(setup_context),
+        },
+    )
+    return Generated
 
 
 # https://github.com/pytorch/pytorch/issues/90225
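For context, here is a minimal, standalone sketch (not part of the commit) of the naming trick the new generate_single_level_function relies on. Autograd names a Tensor's grad_fn after the Function subclass that produced it ("<ClassName>Backward"), so building the subclass dynamically with type() and a name derived from the user's autograd.Function yields a readable grad_fn. The FooBar class below is purely illustrative.

import torch

class FooBar(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        return x.clone()

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output

# Dynamically build a subclass whose __name__ embeds the original Function's name,
# mirroring the type(name, bases, namespace) call added in this commit.
Generated = type(
    f'{FooBar.__name__}Generated',
    (torch.autograd.Function,),
    {
        'forward': staticmethod(FooBar.forward),
        'backward': staticmethod(FooBar.backward),
    },
)

x = torch.tensor(1., requires_grad=True)
y = Generated.apply(x)
print(type(y.grad_fn).__name__)  # FooBarGeneratedBackward

This is the same behavior the new test_grad_fn_name test asserts on: applying FooBar through the generated single-level function produces a grad_fn named FooBarGeneratedBackward.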