Beef up {jacobian, hessian} vectorize docs; eliminate a warning (#51638)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/51638

This PR makes the following doc changes:
- Makes it clear to users that they should use vectorize "at their own risk".
- Makes it clear that vectorize uses the "experimental prototype vmap", so that when users see error messages related to vmap they will know where they are coming from.

This PR also:
- makes it so that {jacobian, hessian} call a version of vmap that doesn't warn the user that they are using an "experimental prototype". The regular torch.vmap API does warn the user about this. This improves UX a little because the user already knows, from discovering the flag and reading the docs, what they are getting themselves into.

Test Plan:
- Add a test that {jacobian, hessian} with vectorize=True don't raise warnings.

Reviewed By: albanD
Differential Revision: D26225402
Pulled By: zou3519
fbshipit-source-id: 1a6db920ecf10597fb2e0c6576f510507d999c34
This commit is contained in: parent 443a431ac3, commit 45e5562fcc
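As a rough illustration of the behavior this commit documents and tests: with `vectorize=True`, the functional API routes its work through the vmap prototype internally but should not surface the prototype warning to the caller. The function `f` and input shape below are arbitrary examples chosen for this sketch, not taken from the commit:

    import torch
    import torch.autograd.functional as autogradF

    def f(x):
        return (x ** 3).sum()  # arbitrary scalar-valued example function

    x = torch.randn(5)

    # The jacobian rows are computed via the vmap prototype internally,
    # but no "experimental prototype" warning should be emitted here.
    jac = autogradF.jacobian(f, x, vectorize=True)
    print(jac.shape)  # torch.Size([5])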
@@ -5694,6 +5694,26 @@ class TestAutogradFunctional(TestCase):
         for inputs in test_cases:
             self._test_construct_standard_basis_for(inputs)
 
+    def _test_vectorize_raises_no_warnings(self, api):
+        # vmap is an experimental prototype. When someone calls torch.vmap,
+        # it raises a python warning. This test checks that
+        # autogradF.{jacobian, hessian} don't raise that experimental prototype
+        # warning; it is not nice for a public-facing API to raise a warning
+        # no matter how it is called.
+        def foo(a):
+            return (a ** 2).sum()
+
+        x = torch.randn(3)
+        with warnings.catch_warnings(record=True) as wa:
+            result = api(foo, x, vectorize=True)
+        self.assertEqual(len(wa), 0)
+
+    def test_jacobian_vectorize_raises_no_warnings(self):
+        return self._test_vectorize_raises_no_warnings(autogradF.jacobian)
+
+    def test_hessian_vectorize_raises_no_warnings(self):
+        return self._test_vectorize_raises_no_warnings(autogradF.hessian)
+
     def _test_jacobian_err_check(self, vectorize):
         def foo(a):
             return 3 * a.narrow(0, 0, 3)
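For contrast with the test above: at the time of this commit it is the public torch.vmap entry point that emits the prototype warning, which is exactly what the test asserts jacobian/hessian do not do. A minimal check of that behavior might look like the following sketch (the warning text and its presence apply to the PyTorch release this commit targets; later releases may behave differently):

    import warnings
    import torch

    x = torch.randn(3, 4)

    with warnings.catch_warnings(record=True) as wa:
        warnings.simplefilter("always")
        # The public torch.vmap API warns that it is an experimental prototype.
        out = torch.vmap(torch.sum)(x)

    print(len(wa) > 0)  # expected: True, the prototype warning was recorded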
@@ -245,7 +245,10 @@ def vmap(func: Callable, in_dims: in_dims_t = 0, out_dims: out_dims_t = 0) -> Callable:
             '`torch._C._debug_only_display_vmap_fallback_warnings(True) '
             'before the call to `vmap`.',
             stacklevel=2)
+    return _vmap(func, in_dims, out_dims)
+
+# A version of vmap but without the initial "experimental prototype" warning
+def _vmap(func: Callable, in_dims: in_dims_t = 0, out_dims: out_dims_t = 0) -> Callable:
     @functools.wraps(func)
     def wrapped(*args):
         _check_out_dims_is_int_or_int_tuple(out_dims, func)
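The structure of this change is a common one: keep the user-facing entry point as a thin wrapper that emits the warning and then delegates to an internal, warning-free implementation that the rest of the library can call. A generic sketch of that pattern follows; the names here are illustrative, not PyTorch's actual internals:

    import functools
    import warnings
    from typing import Callable

    def _quiet_impl(func: Callable) -> Callable:
        # Internal version: does the real work and never warns. Callers inside
        # the library (e.g. jacobian/hessian with vectorize=True) use this.
        @functools.wraps(func)
        def wrapped(*args):
            return func(*args)
        return wrapped

    def public_api(func: Callable) -> Callable:
        # Public version: warn that this is an experimental prototype,
        # then hand off to the quiet internal implementation.
        warnings.warn("public_api is an experimental prototype.", stacklevel=2)
        return _quiet_impl(func)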
@@ -1,6 +1,6 @@
 import torch
 from typing import Tuple, List
-from torch._vmap_internals import vmap
+from torch._vmap_internals import _vmap
 
 # Utility functions
 
@@ -417,10 +417,11 @@ def jacobian(func, inputs, create_graph=False, strict=False, vectorize=False):
             independent of it. If ``False``, we return a Tensor of zeros as the
             jacobian for said inputs, which is the expected mathematical value.
             Defaults to ``False``.
-        vectorize (bool, optional): This feature is experimental. When computing
-            the jacobian, usually we invoke ``autograd.grad`` once per row of
-            the jacobian. If this flag is ``True``, we use vmap as the backend
-            to vectorize calls to ``autograd.grad`` so we only invoke it once
+        vectorize (bool, optional): This feature is experimental, please use at
+            your own risk. When computing the jacobian, usually we invoke
+            ``autograd.grad`` once per row of the jacobian. If this flag is
+            ``True``, we use the vmap prototype feature as the backend to
+            vectorize calls to ``autograd.grad`` so we only invoke it once
             instead of once per row. This should lead to performance
             improvements in many use cases, however, due to this feature
             being incomplete, there may be performance cliffs. Please
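The docstring points users at a debug switch for the vmap fallback. A hedged usage sketch, following the flag named in the docstring (whether any fallback warnings actually appear depends on the operators hit inside the example function, which is arbitrary here):

    import torch
    from torch.autograd.functional import jacobian

    def func(x):
        return x.exp().cumsum(dim=0)  # arbitrary example function

    x = torch.randn(4)

    # Surface a warning whenever the vmap prototype falls back to a slow
    # per-example loop instead of a batched kernel, so performance cliffs
    # are visible while using vectorize=True.
    torch._C._debug_only_display_vmap_fallback_warnings(True)
    jac = jacobian(func, x, vectorize=True)
    print(jac.shape)  # torch.Size([4, 4])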
@@ -537,7 +538,7 @@ def jacobian(func, inputs, create_graph=False, strict=False, vectorize=False):
                    vj[el_idx] = torch.zeros_like(inputs[el_idx])
                 return tuple(vj)
 
-        jacobians_of_flat_output = vmap(vjp)(grad_outputs)
+        jacobians_of_flat_output = _vmap(vjp)(grad_outputs)
 
         # Step 3: The returned jacobian is one big tensor per input. In this step,
         # we split each Tensor by output.
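This single changed line is where the vectorization happens: instead of calling the vjp once per row of the jacobian in a Python loop, the internal vmap maps it over a whole batch of grad outputs at once, and after this commit it does so without the prototype warning. The following is a simplified illustration of the two strategies using plain torch.autograd.grad; it mirrors the idea, not the actual code in functional.py:

    import torch
    from torch._vmap_internals import _vmap  # internal, warning-free vmap

    def f(x):
        return x ** 2  # elementwise; jacobian is diag(2 * x)

    x = torch.randn(3, requires_grad=True)
    y = f(x)
    basis = torch.eye(y.numel())  # one grad_output per output element

    def vjp(grad_output):
        return torch.autograd.grad(y, x, grad_output, retain_graph=True)[0]

    # Loop strategy: one autograd.grad call per jacobian row.
    jac_loop = torch.stack([vjp(g) for g in basis])

    # Vectorized strategy: one batched call over all rows via the internal vmap.
    jac_vmap = _vmap(vjp)(basis)

    print(torch.allclose(jac_loop, jac_vmap))  # expected: True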
@@ -607,12 +608,13 @@ def hessian(func, inputs, create_graph=False, strict=False, vectorize=False):
             such that all the outputs are independent of it. If ``False``, we return a Tensor of zeros as the
             hessian for said inputs, which is the expected mathematical value.
             Defaults to ``False``.
-        vectorize (bool, optional): This feature is experimental. When
-            computing the hessian, usually we invoke ``autograd.grad`` once
-            per row of the hessian. If this flag is ``True``, we use vmap as
-            the backend to vectorize calls to ``autograd.grad`` so we only
-            invoke it once instead of once per row. This should lead to
-            performance improvements in many use cases, however, due to this feature
+        vectorize (bool, optional): This feature is experimental, please use at
+            your own risk. When computing the hessian, usually we invoke
+            ``autograd.grad`` once per row of the hessian. If this flag is
+            ``True``, we use the vmap prototype feature as the backend to
+            vectorize calls to ``autograd.grad`` so we only invoke it once
+            instead of once per row. This should lead to performance
+            improvements in many use cases, however, due to this feature
             being incomplete, there may be performance cliffs. Please
             use `torch._C._debug_only_display_vmap_fallback_warnings(True)`
             to show any performance warnings and file us issues if
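As with jacobian, the new wording only changes the documentation; the call itself is unchanged. A small example of the hessian path, using a quadratic chosen so the expected result is easy to verify by hand (this function and input are illustrative, not from the commit):

    import torch
    from torch.autograd.functional import hessian

    def pow_sum(x):
        return (x ** 2).sum()  # scalar output; hessian is 2 * I

    x = torch.randn(4)

    # vectorize=True routes the row-by-row grad calls through the vmap prototype.
    hess = hessian(pow_sum, x, vectorize=True)
    print(torch.allclose(hess, 2 * torch.eye(4)))  # expected: True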