Add batched_grad parameter to autograd.grad (#65564)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/65564

- wraps the call into the engine with vmap if `batched_grad` is `True`
- improves the comment on the call to the engine (somewhat addressing https://github.com/pytorch/pytorch/issues/41659)
- borrows the message from functional.jacobian's vectorized argument concerning usage of the vmap feature
- adds a basic test (further testing is done when we replace the usage in vectorized jacobian computation)

TODO:
- create an issue tracking this

Test Plan: Imported from OSS

Reviewed By: albanD

Differential Revision: D31236259

Pulled By: soulitzer

fbshipit-source-id: b33e6b26ea98fa9f70c44da08458fc54ba4df0f7
This commit is contained in:
parent b6d5f1ee70
commit 73901b099d
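For orientation, a minimal usage sketch of the feature this commit adds (the parameter lands as `is_grads_batched`; the snippet assumes a PyTorch build that includes this change):

import torch

x = torch.randn(2, 2, requires_grad=True)
out = x.exp()                                  # output shape: (2, 2)

# Three "vectors" stacked along a leading batch dimension.
batched_v = torch.randn(3, 2, 2)

# One call computes all three vector-Jacobian products; the first dimension
# of grad_outputs is interpreted as the batch dimension and vmapped over.
grad, = torch.autograd.grad(out, (x,), (batched_v,),
                            is_grads_batched=True, retain_graph=True)
print(grad.shape)                              # torch.Size([3, 2, 2])

# Reference result via an explicit Python loop over the batch dimension.
loop = torch.stack([torch.autograd.grad(out, (x,), (v,), retain_graph=True)[0]
                    for v in batched_v])
assert torch.allclose(grad, loop)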
test/test_autograd.py
@@ -719,6 +719,30 @@ class TestAutograd(TestCase):
         torch.autograd.backward(x, inputs=(y, ))  # allow_unused is implicitly True!
         self.assertIsNone(y.grad)
 
+    def test_grad_batched_grad(self):
+        x = torch.randn(2, 2, requires_grad=True)
+
+        out = x.clone()  # Size([2, 2])
+        batched_grad = torch.arange(3).expand(2, 2, 3).transpose(0, 2)  # Size([3, 2, 2])
+        grad, = torch.autograd.grad(out, (x,), (batched_grad,), is_grads_batched=True)
+        self.assertEqual(grad, torch.arange(3).expand(2, 2, 3).transpose(0, 2).to(dtype=grad.dtype))
+
+        # Detect shape mismatch
+        grad_out = torch.ones(2, 2)
+        with self.assertRaisesRegex(RuntimeError, "If `is_grads_batched=True`, we interpret the first"):
+            torch.autograd.grad(outputs=out, grad_outputs=(grad_out,), inputs=(x,), is_grads_batched=True)
+
+        # Scalar outputs
+        out = x.sum()  # Size([])
+        batched_grad = torch.arange(3)  # Size([3])
+        grad, = torch.autograd.grad(out, (x,), (batched_grad,), is_grads_batched=True)
+        self.assertEqual(grad, torch.arange(3).expand(2, 2, 3).transpose(0, 2).to(dtype=grad.dtype))
+
+        # We consider scalar and sized-1 to be a mismatch. This is consistent with current non-batched behavior.
+        grad_out = torch.ones(2).unsqueeze(1)
+        with self.assertRaisesRegex(RuntimeError, "If `is_grads_batched=True`, we interpret the first"):
+            torch.autograd.grad(outputs=out, grad_outputs=(grad_out,), inputs=(x,), is_grads_batched=True)
+
     def test_hooks(self):
         x = torch.ones(5, 5, requires_grad=True)
         y = torch.ones(5, 5) * 4
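As a side note (not part of the diff), the expected values in test_grad_batched_grad follow from two small facts: x.clone() has an identity Jacobian, so each batched "vector" comes back unchanged, and x.sum() maps a scalar grad_output v to v * ones_like(x). A standalone sketch of the same expectations, assuming a build with this change:

import torch

x = torch.randn(2, 2, requires_grad=True)

# Identity Jacobian: each batched "vector" is returned as-is.
v = torch.arange(3.).expand(2, 2, 3).transpose(0, 2)      # shape (3, 2, 2)
g, = torch.autograd.grad(x.clone(), (x,), (v,), is_grads_batched=True)
assert torch.equal(g, v)

# Scalar output: grad_output k yields k * ones_like(x), so the stacked batch
# 0, 1, 2 reproduces the expanded/transposed arange tensor used in the test.
s = torch.arange(3.)                                       # shape (3,)
g, = torch.autograd.grad(x.sum(), (x,), (s,), is_grads_batched=True)
assert torch.equal(g, torch.arange(3.).expand(2, 2, 3).transpose(0, 2))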
torch/autograd/__init__.py
@@ -21,21 +21,37 @@ from ..overrides import has_torch_function, handle_torch_function
 from . import functional
 from . import forward_ad
 from . import graph
+from .. import _vmap_internals
 
 __all__ = ['Variable', 'Function', 'backward', 'grad_mode']
 
 _OptionalTensor = Optional[torch.Tensor]
 
-def _make_grads(outputs: Sequence[torch.Tensor], grads: Sequence[_OptionalTensor]) -> Tuple[_OptionalTensor, ...]:
+def _make_grads(outputs: Sequence[torch.Tensor], grads: Sequence[_OptionalTensor],
+                is_grads_batched: bool) -> Tuple[_OptionalTensor, ...]:
     new_grads: List[_OptionalTensor] = []
     for out, grad in zip(outputs, grads):
         if isinstance(grad, torch.Tensor):
-            if not out.shape == grad.shape:
-                raise RuntimeError("Mismatch in shape: grad_output["
-                                   + str(grads.index(grad)) + "] has a shape of "
-                                   + str(grad.shape) + " and output["
-                                   + str(outputs.index(out)) + "] has a shape of "
-                                   + str(out.shape) + ".")
+            grad_shape = grad.shape if not is_grads_batched else grad.shape[1:]
+            if not out.shape == grad_shape:
+                if is_grads_batched:
+                    raise RuntimeError("If `is_grads_batched=True`, we interpret the first "
+                                       "dimension of each grad_output as the batch dimension. "
+                                       "The sizes of the remaining dimensions are expected to match "
+                                       "the shape of corresponding output, but a mismatch "
+                                       "was detected: grad_output["
+                                       + str(grads.index(grad)) + "] has a shape of "
+                                       + str(grad.shape) + " and output["
+                                       + str(outputs.index(out)) + "] has a shape of "
+                                       + str(out.shape) + ". "
+                                       "If you only want some tensors in `grad_output` to be considered "
+                                       "batched, consider using vmap.")
+                else:
+                    raise RuntimeError("Mismatch in shape: grad_output["
+                                       + str(grads.index(grad)) + "] has a shape of "
+                                       + str(grad.shape) + " and output["
+                                       + str(outputs.index(out)) + "] has a shape of "
+                                       + str(out.shape) + ".")
             if out.dtype.is_complex != grad.dtype.is_complex:
                 raise RuntimeError("For complex Tensors, both grad_output and output"
                                    " are required to have the same dtype."
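In short, with is_grads_batched=True the new check in _make_grads compares out.shape against grad.shape[1:], so every grad_output needs one extra leading batch dimension. A small sketch of the rule and the new error message (shapes here are hypothetical, assuming a build with this change):

import torch

x = torch.randn(2, 2, requires_grad=True)
out = x * 3                      # output shape: (2, 2)

ok = torch.randn(5, 2, 2)        # (B, *out.shape) with B = 5 -> accepted
bad = torch.randn(2, 2)          # no leading batch dimension   -> rejected

torch.autograd.grad(out, (x,), (ok,), is_grads_batched=True, retain_graph=True)
try:
    torch.autograd.grad(out, (x,), (bad,), is_grads_batched=True)
except RuntimeError as err:
    print(err)  # "If `is_grads_batched=True`, we interpret the first dimension ..."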
@@ -147,14 +163,16 @@ def backward(
         tuple(inputs) if inputs is not None else tuple()
 
     grad_tensors_ = _tensor_or_tensors_to_tuple(grad_tensors, len(tensors))
-    grad_tensors_ = _make_grads(tensors, grad_tensors_)
+    grad_tensors_ = _make_grads(tensors, grad_tensors_, is_grads_batched=False)
     if retain_graph is None:
         retain_graph = create_graph
 
-    Variable._execution_engine.run_backward(
+    # The reason we repeat the same comment below is that
+    # some Python versions print out the first line of a multi-line function
+    # call in the traceback and some print out the last line
+    Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
         tensors, grad_tensors_, retain_graph, create_graph, inputs,
-        allow_unreachable=True, accumulate_grad=True)  # allow_unreachable flag
-
+        allow_unreachable=True, accumulate_grad=True)  # Calls into the C++ engine to run the backward pass
 
 def grad(
     outputs: _TensorOrTensors,
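Note that torch.autograd.backward itself is unchanged in behavior: it now calls _make_grads with is_grads_batched=False, so grad_tensors must still match the output shapes exactly, with no leading batch dimension. A quick sketch:

import torch

x = torch.randn(2, 2, requires_grad=True)
out = x * 2

# backward() keeps the non-batched validation: grad_tensors match out exactly.
torch.autograd.backward((out,), grad_tensors=(torch.ones(2, 2),))
print(x.grad)                    # tensor filled with 2.0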
@@ -163,13 +181,14 @@ def grad(
     retain_graph: Optional[bool] = None,
     create_graph: bool = False,
     only_inputs: bool = True,
-    allow_unused: bool = False
+    allow_unused: bool = False,
+    is_grads_batched: bool = False
 ) -> Tuple[torch.Tensor, ...]:
     r"""Computes and returns the sum of gradients of outputs with respect to
     the inputs.
 
     ``grad_outputs`` should be a sequence of length matching ``output``
-    containing the "vector" in Jacobian-vector product, usually the pre-computed
+    containing the "vector" in vector-Jacobian product, usually the pre-computed
     gradients w.r.t. each of the outputs. If an output doesn't require_grad,
     then the gradient can be ``None``).
 
@@ -189,7 +208,7 @@ def grad(
         outputs (sequence of Tensor): outputs of the differentiated function.
         inputs (sequence of Tensor): Inputs w.r.t. which the gradient will be
             returned (and not accumulated into ``.grad``).
-        grad_outputs (sequence of Tensor): The "vector" in the Jacobian-vector product.
+        grad_outputs (sequence of Tensor): The "vector" in the vector-Jacobian product.
             Usually gradients w.r.t. each output. None values can be specified for scalar
             Tensors or ones that don't require grad. If a None value would be acceptable
             for all grad_tensors, then this argument is optional. Default: None.
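The two docstring hunks above correct "Jacobian-vector" to "vector-Jacobian": grad_outputs supplies the vector v in v^T J, not J v (which is what forward-mode/jvp computes). A small check of that reading, independent of this commit:

import torch

x = torch.randn(3, requires_grad=True)
W = torch.randn(2, 3)
out = W @ x                        # Jacobian d(out)/dx = W, shape (2, 3)

v = torch.randn(2)
vjp, = torch.autograd.grad(out, (x,), (v,))
assert torch.allclose(vjp, v @ W)  # v^T J, shape (3,), not J @ v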
@@ -203,6 +222,18 @@ def grad(
         allow_unused (bool, optional): If ``False``, specifying inputs that were not
             used when computing outputs (and therefore their grad is always zero)
             is an error. Defaults to ``False``.
+        is_grads_batched (bool, optional): If ``True``, the first dimension of each
+            tensor in ``grad_outputs`` will be interpreted as the batch dimension.
+            Instead of computing a single vector-Jacobian product, we compute a
+            batch of vector-Jacobian products for each "vector" in the batch.
+            We use the vmap prototype feature as the backend to vectorize calls
+            to the autograd engine so that this computation can be performed in a
+            single call. This should lead to performance improvements when compared
+            to manually looping and performing backward multiple times. Note that
+            due to this feature being experimental, there may be performance
+            cliffs. Please use ``torch._C._debug_only_display_vmap_fallback_warnings(True)``
+            to show any performance warnings and file an issue on github if warnings exist
+            for your use case. Defaults to ``False``.
     """
     outputs = (outputs,) if isinstance(outputs, torch.Tensor) else tuple(outputs)
     inputs = (inputs,) if isinstance(inputs, torch.Tensor) else tuple(inputs)
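The docstring added above points at torch._C._debug_only_display_vmap_fallback_warnings(True) for surfacing vmap fallback slowdowns; a hedged sketch of how that might be used around a batched-grad call (names taken from the docstring, actual warning output depends on your build and operators):

import torch

# Surface warnings when the vmap prototype falls back to a slow loop.
torch._C._debug_only_display_vmap_fallback_warnings(True)

x = torch.randn(4, requires_grad=True)
out = x.sin()
basis = torch.eye(4)              # batch of one-hot vectors -> rows of the Jacobian
jac_rows, = torch.autograd.grad(out, (x,), (basis,), is_grads_batched=True)
print(jac_rows)                   # diag(cos(x)) if no fallback issues arise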
@@ -226,14 +257,24 @@ def grad(
                       "parts of the graph, please use torch.autograd.backward.")
 
     grad_outputs_ = _tensor_or_tensors_to_tuple(grad_outputs, len(outputs))
-    grad_outputs_ = _make_grads(outputs, grad_outputs_)
+    grad_outputs_ = _make_grads(outputs, grad_outputs_, is_grads_batched=is_grads_batched)
 
     if retain_graph is None:
         retain_graph = create_graph
 
-    return Variable._execution_engine.run_backward(
-        outputs, grad_outputs_, retain_graph, create_graph,
-        inputs, allow_unused, accumulate_grad=False)
+    # The reason we repeat the same comment several times below is that
+    # some Python versions print out the first line of a multi-line function
+    # call in the traceback and some print out the last line
+    if is_grads_batched:
+        def vjp(gO):
+            return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
+                outputs, gO, retain_graph, create_graph, inputs,
+                allow_unused, accumulate_grad=False)  # Calls into the C++ engine to run the backward pass
+        return _vmap_internals._vmap(vjp, 0, 0, allow_none_pass_through=True)(grad_outputs_)
+    else:
+        return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
+            outputs, grad_outputs_, retain_graph, create_graph, inputs,
+            allow_unused, accumulate_grad=False)  # Calls into the C++ engine to run the backward pass
 
 
 # This function applies in case of gradient checkpointing for memory
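As the summary notes, the intended consumer of this path is the vectorized Jacobian computation in torch.autograd.functional; conceptually, passing a batch of basis vectors through the new branch reproduces functional.jacobian. A sketch of that equivalence (a simple elementwise function chosen for illustration):

import torch

x = torch.randn(3, requires_grad=True)
out = x ** 2                       # elementwise, Jacobian = diag(2 * x)

# One row of the Jacobian per basis vector, computed in a single batched call.
basis = torch.eye(3)
rows, = torch.autograd.grad(out, (x,), (basis,), is_grads_batched=True)

ref = torch.autograd.functional.jacobian(lambda t: t ** 2, x)
assert torch.allclose(rows, ref)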