import torch
import torch.utils._pytree as pytree
from collections import namedtuple
import functools


# NOTE [CustomOp autograd kernel indirection]
# We register `inner` as the autograd kernel for this custom_op.
# `inner` either calls the autograd formula registered by the user,
# or goes into an `autograd_not_implemented` kernel.
#
# The reason why this indirection exists is
# so that we can swap out the autograd kernel (the PyTorch dispatcher
# doesn't actually allow us to do this). By default, we want
# the `autograd_not_implemented` behavior, but then the user may come
# and register something that is actually a backward formula.
def autograd_kernel_indirection(custom_op):
    autograd_fallback = autograd_not_implemented(custom_op)

    def inner(*args, **kwargs):
        if custom_op._has_impl('autograd'):
            kernel = custom_op._get_impl('autograd').func
            return kernel(*args, **kwargs)
        # As explained in NOTE ["backward", "save_for_backward", and "autograd"],
        # after the user gives us "backward" and "save_for_backward", we generate
        # the "autograd" impl. If the user only provided one, then we tell
        # the user they've done something wrong.
        if custom_op._has_impl('save_for_backward') or custom_op._has_impl('backward'):
            missing = (
                'save_for_backward' if custom_op._has_impl('backward')
                else 'backward'
            )
            found = 'save_for_backward' if missing == 'backward' else 'backward'
            loc = custom_op._get_impl(found).location
            raise RuntimeError(
                f"We found a '{found}' registration for {custom_op} at "
                f"{loc} but were unable to find a '{missing}' registration. "
                f"To use the CustomOp API to register a backward formula, "
                f"please provide us both a backward function and a "
                f"'save for backward' function via `impl_backward` and "
                f"`impl_save_for_backward` respectively.")
        return autograd_fallback(*args, **kwargs)
    return inner

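# Illustrative sketch (hypothetical CustomOp `my_op`; the name is an assumption,
# not part of this module): the kernel returned here is what gets registered for
# the Autograd dispatch key, so later user registrations take effect without
# touching the dispatcher again.
#
#   kernel = autograd_kernel_indirection(my_op)
#   # With an 'autograd' impl registered, `kernel(x)` calls it; otherwise it
#   # falls through to autograd_not_implemented(my_op).
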
# TODO(#101191): Use the actual C++ autograd not implemented fallback,
# or change the default autograd fallback to the autograd not implemented fallback.
def autograd_not_implemented(custom_op):
    def kernel(*args, **kwargs):
        if torch.is_grad_enabled() and pytree.tree_any(
            lambda x: isinstance(x, torch.Tensor) and x.requires_grad, (args, kwargs)
        ):
            raise RuntimeError("Autograd has not been implemented for operator")
        with torch._C._AutoDispatchBelowAutograd():
            return custom_op(*args, **kwargs)
    return kernel

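# Illustrative behavior sketch (hypothetical CustomOp `my_op`):
#
#   fallback = autograd_not_implemented(my_op)
#   fallback(torch.randn(3, requires_grad=True))  # raises RuntimeError
#   with torch.no_grad():
#       fallback(torch.randn(3))                  # redispatches below autograd
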
def mark_non_differentiable(ctx, output, output_differentiability):
    # Output types are restricted to be:
    #   - Tensor
    #   - Tensor[]
    #   - int, bool, Scalar, float
    # See _check_can_register_backward
    if output_differentiability is not None:
        if not isinstance(output, tuple):
            tuple_output = (output,)
        else:
            tuple_output = output  # type: ignore[assignment]
        assert len(output_differentiability) == len(tuple_output)
        non_differentiable_tensors = []
        for idx, (differentiable, out) in enumerate(zip(output_differentiability, tuple_output)):
            if isinstance(out, torch.Tensor):
                if not differentiable:
                    non_differentiable_tensors.append(out)
                continue
            if isinstance(out, list):
                if not differentiable:
                    non_differentiable_tensors.extend(out)
                continue
            if differentiable:
                raise RuntimeError(
                    f"With output_differentiability={output_differentiability}. "
                    f"At idx {idx}, we received an object of type {type(out)} that "
                    f"is not a Tensor, so it cannot be marked as differentiable in "
                    f"output_differentiability.")
        if non_differentiable_tensors:
            ctx.mark_non_differentiable(*non_differentiable_tensors)

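# Illustrative sketch: for a hypothetical op returning (Tensor, Tensor) that was
# registered with output_differentiability=[True, False], the second output is
# handed to ctx.mark_non_differentiable and receives no gradient in backward.
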
def construct_autograd_kernel(
        schema,
        output_differentiability,
        custom_op,
        op_overload,
        save_for_backward_fn,
        backward_fn):

    def apply(*args):
        flat_args, spec = pytree.tree_flatten(args)
        out_spec = None

        def forward(ctx, *flat_args):
            ctx.set_materialize_grads(True)
            args = pytree.tree_unflatten(list(flat_args), spec)
            with torch._C._AutoDispatchBelowAutograd():
                output = op_overload(*args)

            # We use the info about args to give better error messages in backward
            args_info = namedtuple_args(
                schema, pytree.tree_map(type, args))

            save_for_backward_fn_inputs = namedtuple_args(schema, args)
            to_save = save_for_backward_fn(save_for_backward_fn_inputs, output)

            save_pytree_for_backward(ctx, (to_save, args_info))
            mark_non_differentiable(ctx, output, output_differentiability)

            nonlocal out_spec
            flat_output, out_spec = pytree.tree_flatten(output)
            return tuple(flat_output)

        def backward(ctx, *flat_grad_output):
            assert out_spec is not None
            grads = pytree.tree_unflatten(list(flat_grad_output), out_spec)
            saved, args_info = unpack_saved(ctx)
            # There is nothing on the ctx object for now, it is just there so
            # that we can add additional things in the future.
            inner_ctx = object()
            if not isinstance(grads, tuple):
                grads = (grads,)
            grad_inputs_dict = backward_fn(inner_ctx, saved, *grads)

            # Massage the grad_inputs_dict to a form acceptable by
            # autograd.Function.
            validate_grad_inputs_dict(grad_inputs_dict, custom_op, args_info)
            return grad_inputs_dict_to_flat_tuple(grad_inputs_dict, args_info)

        generated_cls = gen_autograd_function(
            custom_op._opname + '_customop', forward, backward)

        flat_output = generated_cls.apply(*flat_args)
        assert out_spec is not None
        return pytree.tree_unflatten(list(flat_output), out_spec)
    return apply

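# Rough data flow of the generated kernel above (a summary only; the names are
# the local variables in construct_autograd_kernel):
#
#   apply(*args)
#     -> tree_flatten(args)                      # autograd.Function wants flat Tensors
#     -> generated_cls.apply(*flat_args)         # forward: redispatch + save pytree
#     -> backward: unpack_saved -> backward_fn   # user formula returns a grad dict
#     -> grad_inputs_dict_to_flat_tuple(...)     # back to the flat layout autograd expects
#     -> tree_unflatten(flat_output, out_spec)   # restore the op's output structure
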
def gen_autograd_function(name, forward, backward):
    generated_cls = type(
        name,
        (torch.autograd.Function,),
        {
            'forward': staticmethod(forward),
            'backward': staticmethod(backward),
        }
    )
    return generated_cls

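# Minimal usage sketch (hypothetical `fwd`/`bwd` functions with the usual
# autograd.Function static-method signatures):
#
#   MyFn = gen_autograd_function('MyFn', fwd, bwd)
#   # Equivalent to writing `class MyFn(torch.autograd.Function)` by hand,
#   # just built dynamically so the class name can include the op name.
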
@functools.lru_cache
def namedtuple_args_cls(schema):
    attribs = [arg.name for arg in schema.arguments.flat_all]
    name = str(schema.name) + "_args"
    # mypy doesn't support dynamic namedtuple name
    tuple_cls = namedtuple(name, attribs)  # type: ignore[misc]
    return tuple_cls


def namedtuple_args(schema, args):
    assert isinstance(args, tuple)
    tuple_cls = namedtuple_args_cls(schema)
    return tuple_cls(*args)

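# Illustrative sketch: for a hypothetical schema `my_op(Tensor x, int n) -> Tensor`,
# namedtuple_args_cls produces a namedtuple class with fields ('x', 'n'), so
# namedtuple_args(schema, (t, 3)).x is the Tensor argument. The lru_cache means
# one class per schema rather than one per call.
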
def validate_grad_inputs_dict(grad_inputs_dict, forward_op, args_info):
    def error(what):
        backward = forward_op._get_impl('backward')
        raise RuntimeError(
            f"In the backward function defined for {forward_op} at "
            f"{backward.location} using the CustomOp API, {what}")

    if not isinstance(grad_inputs_dict, dict):
        error(f"expected the output of the backward function to be a dict but "
              f"got {type(grad_inputs_dict)}")

    expected_keys = {arg.name for arg in forward_op._schema.arguments.flat_all
                     if arg.type.is_tensor_like()}
    actual_keys = grad_inputs_dict.keys()
    if expected_keys != actual_keys:
        error(f"expected the returned grad_input dict to have keys "
              f"{expected_keys} but got {actual_keys}. The backward "
              f"function must return a gradient (can be None) for each arg "
              f"to the CustomOp that may be a Tensor or Sequence[Tensor]. "
              f"Args declared to be non-Tensor-like types should not appear "
              f"in the grad_input dict")

    for name, grad in grad_inputs_dict.items():
        arg_info = getattr(args_info, name)

        if isinstance(arg_info, list):
            if not isinstance(grad, (tuple, list)):
                error(f"for input '{name}' expected the grad_input dict to "
                      f"hold a list of gradients but got object of type "
                      f"{type(grad)}.")
            if not len(grad) == len(arg_info):
                error(f"for input '{name}' expected the grad_input dict to "
                      f"hold a list of {len(arg_info)} gradients but got "
                      f"{len(grad)}")
            for idx, (g, info) in enumerate(zip(grad, arg_info)):
                if g is None:
                    continue
                if not isinstance(g, torch.Tensor):
                    error(f"for input '{name}' expected the grad_input dict to "
                          f"hold a list of None or Tensor gradients but got "
                          f"object of {type(g)} at index {idx}")
                if not issubclass(info, torch.Tensor):
                    error(f"for input '{name}', got a Tensor as the gradient "
                          f"for the {idx}-th value but expected None because "
                          f"the {idx}-th value was not a Tensor (it was "
                          f"type {arg_info})")
            continue

        if grad is None:
            continue
        if not isinstance(grad, torch.Tensor):
            error(f"got object of type {type(grad)} as the gradient for input "
                  f"'{name}', but expected the gradient to be either None or "
                  f"a Tensor")
        if not issubclass(arg_info, torch.Tensor):
            error(f"got a Tensor as the gradient for input '{name}' but "
                  f"expected None as the gradient because input '{name}' "
                  f"was not a Tensor (it was type {arg_info}).")

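# Illustrative sketch of an acceptable grad_inputs_dict for a hypothetical op
# with schema `my_op(Tensor x, Tensor[] ys, int n) -> Tensor`:
#
#   {'x': grad_x_or_None, 'ys': [gy0_or_None, gy1_or_None, ...]}
#
# 'n' is not Tensor-like, so it must not appear as a key.
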
def grad_inputs_dict_to_flat_tuple(grad_inputs_dict, args_info):
    result = []
    for name, arg_info in args_info._asdict().items():
        if name not in grad_inputs_dict:
            result.append(pytree.tree_map(lambda x: None, arg_info))
            continue
        result.append(grad_inputs_dict[name])
    return tuple(pytree.tree_leaves(result))

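# Illustrative sketch: continuing the hypothetical example above, args_info has
# fields ('x', 'ys', 'n'); an argument missing from the dict contributes None
# gradients with the same structure as that argument, and the result is
# flattened to match the flat layout autograd.Function.backward must return.
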
# Saves "stuff" (a pytree) onto the ctx object. Use unpack_saved to unpack it.
# autograd.Function prefers that users use ctx.save_for_backward to
# save Tensors (to avoid reference cycles) and for non-Tensors to go onto the
# ctx object.
def save_pytree_for_backward(ctx, stuff):
    flat_stuff, spec = pytree.tree_flatten(stuff)
    num_elts = len(flat_stuff)
    tensor_idxs = [idx for idx, thing in enumerate(flat_stuff)
                   if isinstance(thing, torch.Tensor)]
    non_tensor_idxs = [idx for idx, thing in enumerate(flat_stuff)
                       if not isinstance(thing, torch.Tensor)]
    tensors = [thing for thing in flat_stuff if isinstance(thing, torch.Tensor)]
    non_tensors = [thing for thing in flat_stuff if not isinstance(thing, torch.Tensor)]

    ctx.spec = spec
    ctx.num_elts = num_elts
    ctx.save_for_backward(*tensors)
    ctx.tensor_idxs = tensor_idxs
    ctx.saved_non_tensors = non_tensors
    ctx.non_tensor_idxs = non_tensor_idxs

# Inverse operation to save_pytree_for_backward
def unpack_saved(ctx):
    flat_stuff = [None] * ctx.num_elts
    for tensor, idx in zip(ctx.saved_tensors, ctx.tensor_idxs):
        flat_stuff[idx] = tensor
    for non_tensor, idx in zip(ctx.saved_non_tensors, ctx.non_tensor_idxs):
        flat_stuff[idx] = non_tensor
    stuff = pytree.tree_unflatten(flat_stuff, ctx.spec)
    return stuff
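

# Round-trip sketch for the two helpers above (illustrative only; `ctx` is the
# autograd.Function context, with the save happening in forward and the unpack
# in backward):
#
#   save_pytree_for_backward(ctx, {'w': weight_tensor, 'k': 3})   # in forward
#   saved = unpack_saved(ctx)  # in backward: same pytree structure and values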