Add is_grads_batched parameter to autograd.grad (#65564)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/65564

- wraps the call into the engine with vmap if `is_grads_batched` is `True`
- improves the comment on the call into the engine (partially addressing https://github.com/pytorch/pytorch/issues/41659)
- borrows the message from `functional.jacobian`'s `vectorize` argument concerning usage of the vmap feature
- adds a basic test (further testing is done when we replace the usage in the vectorized Jacobian computation)

TODO:
 - create an issue tracking this
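
For illustration, a minimal usage sketch of the new flag (example values are made up; the shape convention matches the new test below: each `grad_outputs` entry has shape `(B,) + output.shape`):

    import torch

    x = torch.randn(2, 2, requires_grad=True)
    out = x * 2  # shape (2, 2)

    # one "vector" per batch entry: grad_outputs shape is (B,) + out.shape = (3, 2, 2)
    batched_grad = torch.randn(3, 2, 2)
    grad, = torch.autograd.grad(out, (x,), (batched_grad,), is_grads_batched=True)

    # d(out)/dx = 2 * I, so each batch entry is just 2 * its grad_output
    assert grad.shape == (3, 2, 2)
    assert torch.allclose(grad, 2 * batched_grad)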

Test Plan: Imported from OSS

Reviewed By: albanD

Differential Revision: D31236259

Pulled By: soulitzer

fbshipit-source-id: b33e6b26ea98fa9f70c44da08458fc54ba4df0f7
Author: soulitzer
Date: 2021-10-03 19:52:30 -07:00 (committed by Facebook GitHub Bot)
Commit: 73901b099d (parent: b6d5f1ee70)
2 changed files with 83 additions and 18 deletions


@@ -719,6 +719,30 @@ class TestAutograd(TestCase):
         torch.autograd.backward(x, inputs=(y, ))  # allow_unused is implicitly True!
         self.assertIsNone(y.grad)
 
+    def test_grad_batched_grad(self):
+        x = torch.randn(2, 2, requires_grad=True)
+
+        out = x.clone()  # Size([2, 2])
+        batched_grad = torch.arange(3).expand(2, 2, 3).transpose(0, 2)  # Size([3, 2, 2])
+        grad, = torch.autograd.grad(out, (x,), (batched_grad,), is_grads_batched=True)
+        self.assertEqual(grad, torch.arange(3).expand(2, 2, 3).transpose(0, 2).to(dtype=grad.dtype))
+
+        # Detect shape mismatch
+        grad_out = torch.ones(2, 2)
+        with self.assertRaisesRegex(RuntimeError, "If `is_grads_batched=True`, we interpret the first"):
+            torch.autograd.grad(outputs=out, grad_outputs=(grad_out,), inputs=(x,), is_grads_batched=True)
+
+        # Scalar outputs
+        out = x.sum()  # Size([])
+        batched_grad = torch.arange(3)  # Size([3])
+        grad, = torch.autograd.grad(out, (x,), (batched_grad,), is_grads_batched=True)
+        self.assertEqual(grad, torch.arange(3).expand(2, 2, 3).transpose(0, 2).to(dtype=grad.dtype))
+
+        # We consider scalar and sized-1 to be a mismatch. This is consistent with current non-batched behavior.
+        grad_out = torch.ones(2).unsqueeze(1)
+        with self.assertRaisesRegex(RuntimeError, "If `is_grads_batched=True`, we interpret the first"):
+            torch.autograd.grad(outputs=out, grad_outputs=(grad_out,), inputs=(x,), is_grads_batched=True)
+
     def test_hooks(self):
         x = torch.ones(5, 5, requires_grad=True)
         y = torch.ones(5, 5) * 4

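For context, `is_grads_batched=True` computes the same result as looping over the leading dimension of `grad_outputs` and running one backward pass per "vector", just in a single vectorized call; a small equivalence sketch (example values are illustrative):

    import torch

    x = torch.randn(2, 2, requires_grad=True)
    out = x * 3
    batched_grad = torch.randn(4, 2, 2)  # batch of 4 "vectors"

    # single vectorized call (vmap over the leading dimension under the hood)
    g_batched, = torch.autograd.grad(out, (x,), (batched_grad,),
                                     is_grads_batched=True, retain_graph=True)

    # manual loop: one engine call per vector in the batch
    g_loop = torch.stack([torch.autograd.grad(out, (x,), (v,), retain_graph=True)[0]
                          for v in batched_grad])

    assert torch.allclose(g_batched, g_loop)  # both have shape (4, 2, 2)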

@@ -21,21 +21,37 @@ from ..overrides import has_torch_function, handle_torch_function
 from . import functional
 from . import forward_ad
 from . import graph
+from .. import _vmap_internals
 
 __all__ = ['Variable', 'Function', 'backward', 'grad_mode']
 
 _OptionalTensor = Optional[torch.Tensor]
 
-def _make_grads(outputs: Sequence[torch.Tensor], grads: Sequence[_OptionalTensor]) -> Tuple[_OptionalTensor, ...]:
+def _make_grads(outputs: Sequence[torch.Tensor], grads: Sequence[_OptionalTensor],
+                is_grads_batched: bool) -> Tuple[_OptionalTensor, ...]:
     new_grads: List[_OptionalTensor] = []
     for out, grad in zip(outputs, grads):
         if isinstance(grad, torch.Tensor):
-            if not out.shape == grad.shape:
-                raise RuntimeError("Mismatch in shape: grad_output["
-                                   + str(grads.index(grad)) + "] has a shape of "
-                                   + str(grad.shape) + " and output["
-                                   + str(outputs.index(out)) + "] has a shape of "
-                                   + str(out.shape) + ".")
+            grad_shape = grad.shape if not is_grads_batched else grad.shape[1:]
+            if not out.shape == grad_shape:
+                if is_grads_batched:
+                    raise RuntimeError("If `is_grads_batched=True`, we interpret the first "
+                                       "dimension of each grad_output as the batch dimension. "
+                                       "The sizes of the remaining dimensions are expected to match "
+                                       "the shape of corresponding output, but a mismatch "
+                                       "was detected: grad_output["
+                                       + str(grads.index(grad)) + "] has a shape of "
+                                       + str(grad.shape) + " and output["
+                                       + str(outputs.index(out)) + "] has a shape of "
+                                       + str(out.shape) + ". "
+                                       "If you only want some tensors in `grad_output` to be considered "
+                                       "batched, consider using vmap.")
+                else:
+                    raise RuntimeError("Mismatch in shape: grad_output["
+                                       + str(grads.index(grad)) + "] has a shape of "
+                                       + str(grad.shape) + " and output["
+                                       + str(outputs.index(out)) + "] has a shape of "
+                                       + str(out.shape) + ".")
             if out.dtype.is_complex != grad.dtype.is_complex:
                 raise RuntimeError("For complex Tensors, both grad_output and output"
                                    " are required to have the same dtype."
@@ -147,14 +163,16 @@ def backward(
         tuple(inputs) if inputs is not None else tuple()
 
     grad_tensors_ = _tensor_or_tensors_to_tuple(grad_tensors, len(tensors))
-    grad_tensors_ = _make_grads(tensors, grad_tensors_)
+    grad_tensors_ = _make_grads(tensors, grad_tensors_, is_grads_batched=False)
     if retain_graph is None:
         retain_graph = create_graph
 
-    Variable._execution_engine.run_backward(
+    # The reason we repeat the same comment below is that
+    # some Python versions print out the first line of a multi-line function
+    # call in the traceback and some print out the last line
+    Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
         tensors, grad_tensors_, retain_graph, create_graph, inputs,
-        allow_unreachable=True, accumulate_grad=True)  # allow_unreachable flag
+        allow_unreachable=True, accumulate_grad=True)  # Calls into the C++ engine to run the backward pass
 
 
 def grad(
     outputs: _TensorOrTensors,
@@ -163,13 +181,14 @@ def grad(
     retain_graph: Optional[bool] = None,
     create_graph: bool = False,
     only_inputs: bool = True,
-    allow_unused: bool = False
+    allow_unused: bool = False,
+    is_grads_batched: bool = False
 ) -> Tuple[torch.Tensor, ...]:
     r"""Computes and returns the sum of gradients of outputs with respect to
     the inputs.
 
     ``grad_outputs`` should be a sequence of length matching ``output``
-    containing the "vector" in Jacobian-vector product, usually the pre-computed
+    containing the "vector" in vector-Jacobian product, usually the pre-computed
     gradients w.r.t. each of the outputs. If an output doesn't require_grad,
     then the gradient can be ``None``).
 
@@ -189,7 +208,7 @@ def grad(
         outputs (sequence of Tensor): outputs of the differentiated function.
         inputs (sequence of Tensor): Inputs w.r.t. which the gradient will be
             returned (and not accumulated into ``.grad``).
-        grad_outputs (sequence of Tensor): The "vector" in the Jacobian-vector product.
+        grad_outputs (sequence of Tensor): The "vector" in the vector-Jacobian product.
             Usually gradients w.r.t. each output. None values can be specified for scalar
             Tensors or ones that don't require grad. If a None value would be acceptable
             for all grad_tensors, then this argument is optional. Default: None.
@@ -203,6 +222,18 @@ def grad(
         allow_unused (bool, optional): If ``False``, specifying inputs that were not
             used when computing outputs (and therefore their grad is always zero)
             is an error. Defaults to ``False``.
+        is_grads_batched (bool, optional): If ``True``, the first dimension of each
+            tensor in ``grad_outputs`` will be interpreted as the batch dimension.
+            Instead of computing a single vector-Jacobian product, we compute a
+            batch of vector-Jacobian products for each "vector" in the batch.
+            We use the vmap prototype feature as the backend to vectorize calls
+            to the autograd engine so that this computation can be performed in a
+            single call. This should lead to performance improvements when compared
+            to manually looping and performing backward multiple times. Note that
+            due to this feature being experimental, there may be performance
+            cliffs. Please use ``torch._C._debug_only_display_vmap_fallback_warnings(True)``
+            to show any performance warnings and file an issue on github if warnings exist
+            for your use case. Defaults to ``False``.
     """
     outputs = (outputs,) if isinstance(outputs, torch.Tensor) else tuple(outputs)
     inputs = (inputs,) if isinstance(inputs, torch.Tensor) else tuple(inputs)
@@ -226,14 +257,24 @@ def grad(
                       "parts of the graph, please use torch.autograd.backward.")
 
     grad_outputs_ = _tensor_or_tensors_to_tuple(grad_outputs, len(outputs))
-    grad_outputs_ = _make_grads(outputs, grad_outputs_)
+    grad_outputs_ = _make_grads(outputs, grad_outputs_, is_grads_batched=is_grads_batched)
     if retain_graph is None:
         retain_graph = create_graph
 
-    return Variable._execution_engine.run_backward(
-        outputs, grad_outputs_, retain_graph, create_graph,
-        inputs, allow_unused, accumulate_grad=False)
+    # The reason we repeat the same comment several times below is because
+    # some Python versions print out the first line of multi-line function
+    # calls in the traceback and some print out the last line
+    if is_grads_batched:
+        def vjp(gO):
+            return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
+                outputs, gO, retain_graph, create_graph, inputs,
+                allow_unused, accumulate_grad=False)  # Calls into the C++ engine to run the backward pass
+        return _vmap_internals._vmap(vjp, 0, 0, allow_none_pass_through=True)(grad_outputs)
+    else:
+        return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
+            outputs, grad_outputs_, retain_graph, create_graph, inputs,
+            allow_unused, accumulate_grad=False)  # Calls into the C++ engine to run the backward pass
 
 
 # This function applies in case of gradient checkpointing for memory
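
As a closing illustration of the docstring above: passing a basis of one-hot "vectors" as the batched `grad_outputs` yields a full Jacobian in a single call, which is roughly what `torch.autograd.functional.jacobian(..., vectorize=True)` does and what this flag is meant to back. A hedged sketch:

    import torch

    x = torch.randn(3, requires_grad=True)
    out = x.exp()  # shape (3,)

    # one one-hot "vector" per output element: shape (m,) + out.shape = (3, 3)
    basis = torch.eye(out.numel()).view(out.numel(), *out.shape)
    jac, = torch.autograd.grad(out, (x,), (basis,), is_grads_batched=True)

    # row i is d(out[i])/dx, so here the Jacobian is diag(exp(x))
    assert jac.shape == (3, 3)
    assert torch.allclose(jac, torch.diag(x.detach().exp()))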