From fb2e1878cbd386ec43a00d615166ebaf622397d7 Mon Sep 17 00:00:00 2001
From: Richard Zou
Date: Wed, 21 Dec 2022 12:37:37 -0500
Subject: [PATCH] [torch.func] alias torch.func.vmap as torch.vmap (#91026)

This PR also redirects torch.vmap to torch.func.vmap instead of the old
vmap prototype.

Test Plan:
- tests
- view docs preview

Pull Request resolved: https://github.com/pytorch/pytorch/pull/91026
Approved by: https://github.com/albanD, https://github.com/samdow
---
 aten/src/ATen/LegacyBatchedFallback.cpp |   9 +-
 test/test_legacy_vmap.py                |   3 +-
 torch/__init__.py                       |   5 +-
 torch/_functorch/eager_transforms.py    |   3 +-
 torch/_functorch/vmap.py                |  22 +++--
 torch/_vmap_internals.py                | 105 +-----------------------
 torch/autograd/gradcheck.py             |   2 +-
 torch/overrides.py                      |   1 +
 8 files changed, 27 insertions(+), 123 deletions(-)

diff --git a/aten/src/ATen/LegacyBatchedFallback.cpp b/aten/src/ATen/LegacyBatchedFallback.cpp
index 72794ece1c5..83e95472a68 100644
--- a/aten/src/ATen/LegacyBatchedFallback.cpp
+++ b/aten/src/ATen/LegacyBatchedFallback.cpp
@@ -72,10 +72,11 @@ static void warnFallback(const c10::FunctionSchema& schema) {
   }
   TORCH_WARN("There is a performance drop because we have not yet implemented ",
              "the batching rule for ", schema.operator_name(), ". ",
-             "We've moved development of vmap to to functorch "
-             "(https://github.com/pytorch/functorch), please try functorch.vmap "
-             "instead and/or file ",
-             " an issue on GitHub so that we can prioritize its implementation.");
+             "You are using the legacy vmap prototype (torch._vmap_internals.vmap). ",
+             "If you are using torch.autograd.functional.{jacobian, hessian} ",
+             "or torch._vmap_internals.vmap: please switch to using ",
+             "torch.func.{jacrev, jacfwd, hessian} and/or torch.vmap instead ",
+             "for better operator coverage and performance improvements.");
 }
 
 // The general flow of the algorithm is as follows.
diff --git a/test/test_legacy_vmap.py b/test/test_legacy_vmap.py
index 6a80268e794..7e3e5215633 100644
--- a/test/test_legacy_vmap.py
+++ b/test/test_legacy_vmap.py
@@ -3,7 +3,8 @@
 from torch.testing._internal.common_utils import TestCase, run_tests
 import torch
 import torch.nn.functional as F
-from torch import Tensor, vmap
+from torch import Tensor
+from torch._vmap_internals import vmap
 import functools
 import itertools
 import warnings
diff --git a/torch/__init__.py b/torch/__init__.py
index c375b94bdb6..56544d28bfa 100644
--- a/torch/__init__.py
+++ b/torch/__init__.py
@@ -48,7 +48,7 @@ __all__ = [
     'set_deterministic_debug_mode', 'get_deterministic_debug_mode',
     'set_float32_matmul_precision', 'get_float32_matmul_precision',
     'set_warn_always', 'is_warn_always_enabled', 'SymInt', 'SymFloat',
-    'compile',
+    'compile', 'vmap',
 ]
 
 ################################################################################
@@ -1116,8 +1116,6 @@ del register_after_fork
 # torch.jit.script as a decorator, for instance):
 from ._lobpcg import lobpcg as lobpcg
 
-from ._vmap_internals import vmap as vmap
-
 # These were previously defined in native_functions.yaml and appeared on the
 # `torch` namespace, but we moved them to c10 dispatch to facilitate custom
 # class usage. We add these lines here to preserve backward compatibility.
@@ -1245,3 +1243,4 @@ if 'TORCH_CUDA_SANITIZER' in os.environ:
 import torch.fx.experimental.symbolic_shapes
 
 from torch import func as func
+from torch.func import vmap
diff --git a/torch/_functorch/eager_transforms.py b/torch/_functorch/eager_transforms.py
index 320e0997238..9af1017dd60 100644
--- a/torch/_functorch/eager_transforms.py
+++ b/torch/_functorch/eager_transforms.py
@@ -1259,8 +1259,7 @@ def grad(func: Callable, argnums: argnums_t = 0, has_aux: bool = False) -> Calla
 
     When composed with ``vmap``, ``grad`` can be used to compute per-sample-gradients:
 
-        >>> from torch.func import grad
-        >>> from torch.func import vmap
+        >>> from torch.func import grad, vmap
         >>> batch_size, feature_size = 3, 5
         >>>
        >>> def model(weights, feature_vec):
diff --git a/torch/_functorch/vmap.py b/torch/_functorch/vmap.py
index 950262ee4e5..142f40f7e9e 100644
--- a/torch/_functorch/vmap.py
+++ b/torch/_functorch/vmap.py
@@ -255,6 +255,10 @@ def vmap(
     take batches of examples with ``vmap(func)``. vmap can also be used to
     compute batched gradients when composed with autograd.
 
+    .. note::
+        :func:`torch.vmap` is aliased to :func:`torch.func.vmap` for
+        convenience. Use whichever one you'd like.
+
     Args:
         func (function): A Python function that takes one or more arguments.
             Must return one or more Tensors.
@@ -308,7 +312,7 @@ def vmap(
         >>>     return feature_vec.dot(weights).relu()
         >>>
         >>> examples = torch.randn(batch_size, feature_size)
-        >>> result = torch.func.vmap(model)(examples)
+        >>> result = torch.vmap(model)(examples)
 
     :func:`vmap` can also help vectorize computations that were previously difficult
     or impossible to batch. One example is higher-order gradient computation.
@@ -333,12 +337,12 @@ def vmap(
         >>> # vectorized gradient computation
         >>> def get_vjp(v):
         >>>     return torch.autograd.grad(y, x, v)
-        >>> jacobian = torch.func.vmap(get_vjp)(I_N)
+        >>> jacobian = torch.vmap(get_vjp)(I_N)
 
     :func:`vmap` can also be nested, producing an output with multiple batched dimensions
 
         >>> torch.dot                            # [D], [D] -> []
-        >>> batched_dot = torch.func.vmap(torch.func.vmap(torch.dot))  # [N1, N0, D], [N1, N0, D] -> [N1, N0]
+        >>> batched_dot = torch.vmap(torch.vmap(torch.dot))  # [N1, N0, D], [N1, N0, D] -> [N1, N0]
         >>> x, y = torch.randn(2, 3, 5), torch.randn(2, 3, 5)
         >>> batched_dot(x, y) # tensor of size [2, 3]
 
@@ -346,7 +350,7 @@ def vmap(
     the dimension that each inputs are batched along as
 
         >>> torch.dot                            # [N], [N] -> []
-        >>> batched_dot = torch.func.vmap(torch.dot, in_dims=1)  # [N, D], [N, D] -> [D]
+        >>> batched_dot = torch.vmap(torch.dot, in_dims=1)  # [N, D], [N, D] -> [D]
         >>> x, y = torch.randn(2, 5), torch.randn(2, 5)
         >>> batched_dot(x, y)  # output is [5] instead of [2] if batched along the 0th dimension
 
@@ -354,7 +358,7 @@ def vmap(
     ``in_dims`` must be a tuple with the batch dimension for each input as
 
         >>> torch.dot                            # [D], [D] -> []
-        >>> batched_dot = torch.func.vmap(torch.dot, in_dims=(0, None))  # [N, D], [D] -> [N]
+        >>> batched_dot = torch.vmap(torch.dot, in_dims=(0, None))  # [N, D], [D] -> [N]
         >>> x, y = torch.randn(2, 5), torch.randn(5)
         >>> batched_dot(x, y) # second arg doesn't have a batch dim because in_dim[1] was None
 
@@ -364,7 +368,7 @@ def vmap(
         >>> f = lambda dict: torch.dot(dict['x'], dict['y'])
         >>> x, y = torch.randn(2, 5), torch.randn(5)
         >>> input = {'x': x, 'y': y}
-        >>> batched_dot = torch.func.vmap(f, in_dims=({'x': 0, 'y': None},))
+        >>> batched_dot = torch.vmap(f, in_dims=({'x': 0, 'y': None},))
         >>> batched_dot(input)
 
     By default, the output is batched along the first dimension. However, it can be batched
@@ -372,17 +376,17 @@ def vmap(
 
         >>> f = lambda x: x ** 2
         >>> x = torch.randn(2, 5)
-        >>> batched_pow = torch.func.vmap(f, out_dims=1)
+        >>> batched_pow = torch.vmap(f, out_dims=1)
         >>> batched_pow(x) # [5, 2]
 
     For any function that uses kwargs, the returned function will not batch the kwargs but
     will accept kwargs
 
         >>> x = torch.randn([2, 5])
-        >>> def f(x, scale=4.):
+        >>> def fn(x, scale=4.):
         >>>     return x * scale
         >>>
-        >>> batched_pow = torch.func.vmap(f)
+        >>> batched_pow = torch.vmap(fn)
         >>> assert torch.allclose(batched_pow(x), x * 4)
         >>> batched_pow(x, scale=x) # scale is not batched, output has shape [2, 2, 5]
 
diff --git a/torch/_vmap_internals.py b/torch/_vmap_internals.py
index 5bb88c06ed7..8226d8a9fef 100644
--- a/torch/_vmap_internals.py
+++ b/torch/_vmap_internals.py
@@ -191,111 +191,10 @@ def _get_name(func: Callable):
 # on BatchedTensors perform the batched operations that the user is asking for.
 def vmap(func: Callable, in_dims: in_dims_t = 0, out_dims: out_dims_t = 0) -> Callable:
     """
-    vmap is the vectorizing map. Returns a new function that maps `func` over some
-    dimension of the inputs. Semantically, vmap pushes the map into PyTorch
-    operations called by `func`, effectively vectorizing those operations.
-
-    vmap is useful for handling batch dimensions: one can write a function `func`
-    that runs on examples and then lift it to a function that can take batches of
-    examples with `vmap(func)`. vmap can also be used to compute batched
-    gradients when composed with autograd.
-
-    .. note::
-        We have moved development of vmap to
-        `functorch. <https://github.com/pytorch/functorch>`_ functorch's
-        vmap is able to arbitrarily compose with gradient computation
-        and contains significant performance improvements.
-        Please give that a try if that is what you're looking for.
-
-        Furthermore, if you're interested in using vmap for your use case,
-        please `contact us! `_
-        We're interested in gathering feedback from early adopters to inform
-        the design.
-
-    .. warning::
-        torch.vmap is an experimental prototype that is subject to
-        change and/or deletion. Please use at your own risk.
-
-    Args:
-        func (function): A Python function that takes one or more arguments.
-            Must return one or more Tensors.
-        in_dims (int or nested structure): Specifies which dimension of the
-            inputs should be mapped over. `in_dims` should have a structure
-            like the inputs. If the `in_dim` for a particular input is None,
-            then that indicates there is no map dimension. Default: 0.
-        out_dims (int or Tuple[int]): Specifies where the mapped dimension
-            should appear in the outputs. If `out_dims` is a Tuple, then it should
-            have one element per output. Default: 0.
-
-    Returns:
-        Returns a new "batched" function. It takes the same inputs as `func`,
-        except each input has an extra dimension at the index specified by `in_dims`.
-        It takes returns the same outputs as `func`, except each output has
-        an extra dimension at the index specified by `out_dims`.
-
-    .. warning:
-        vmap works best with functional-style code. Please do not perform any
-        side-effects in `func`, with the exception of in-place PyTorch operations.
-        Examples of side-effects include mutating Python data structures and
-        assigning values to variables not captured in `func`.
-
-    One example of using `vmap` is to compute batched dot products. PyTorch
-    doesn't provide a batched `torch.dot` API; instead of unsuccessfully
-    rummaging through docs, use `vmap` to construct a new function.
-
-        >>> torch.dot                            # [D], [D] -> []
-        >>> batched_dot = torch.vmap(torch.dot)  # [N, D], [N, D] -> [N]
-        >>> x, y = torch.randn(2, 5), torch.randn(2, 5)
-        >>> batched_dot(x, y)
-
-    `vmap` can be helpful in hiding batch dimensions, leading to a simpler
-    model authoring experience.
-
-        >>> batch_size, feature_size = 3, 5
-        >>> weights = torch.randn(feature_size, requires_grad=True)
-        >>>
-        >>> def model(feature_vec):
-        >>>     # Very simple linear model with activation
-        >>>     return feature_vec.dot(weights).relu()
-        >>>
-        >>> examples = torch.randn(batch_size, feature_size)
-        >>> result = torch.vmap(model)(examples)
-
-    `vmap` can also help vectorize computations that were previously difficult
-    or impossible to batch. One example is higher-order gradient computation.
-    The PyTorch autograd engine computes vjps (vector-Jacobian products).
-    Computing a full Jacobian matrix for some function f: R^N -> R^N usually
-    requires N calls to `autograd.grad`, one per Jacobian row. Using `vmap`,
-    we can vectorize the whole computation, computing the Jacobian in a single
-    call to `autograd.grad`.
-
-        >>> # Setup
-        >>> N = 5
-        >>> f = lambda x: x ** 2
-        >>> x = torch.randn(N, requires_grad=True)
-        >>> y = f(x)
-        >>> I_N = torch.eye(N)
-        >>>
-        >>> # Sequential approach
-        >>> jacobian_rows = [torch.autograd.grad(y, x, v, retain_graph=True)[0]
-        >>>                  for v in I_N.unbind()]
-        >>> jacobian = torch.stack(jacobian_rows)
-        >>>
-        >>> # vectorized gradient computation
-        >>> def get_vjp(v):
-        >>>     return torch.autograd.grad(y, x, v)
-        >>> jacobian = torch.vmap(get_vjp)(I_N)
-
-    .. note::
-        vmap does not provide general autobatching or handle variable-length
-        sequences out of the box.
+    Please use torch.vmap instead of this API.
     """
     warnings.warn(
-        "Please use functorch.vmap instead of torch.vmap "
-        "(https://github.com/pytorch/functorch). "
-        "We've moved development on torch.vmap over to functorch; "
-        "functorch's vmap has a multitude of significant performance and "
-        "functionality improvements.",
+        "Please use torch.vmap instead of torch._vmap_internals.vmap. ",
         stacklevel=2,
     )
     return _vmap(func, in_dims, out_dims)
diff --git a/torch/autograd/gradcheck.py b/torch/autograd/gradcheck.py
index 8ffb3a38c5c..086c2b470ef 100644
--- a/torch/autograd/gradcheck.py
+++ b/torch/autograd/gradcheck.py
@@ -870,7 +870,7 @@ def _test_batched_grad(input, output, output_idx) -> bool:
     # NB: this doesn't work for CUDA tests: https://github.com/pytorch/pytorch/issues/50209
     with warnings.catch_warnings():
         warnings.filterwarnings("ignore", message="There is a performance drop")
-        warnings.filterwarnings("ignore", message="Please use functorch.vmap")
+        warnings.filterwarnings("ignore", message="Please use torch.vmap")
         try:
             result = vmap(vjp)(torch.stack(grad_outputs))
         except RuntimeError as ex:
diff --git a/torch/overrides.py b/torch/overrides.py
index 0ab3f8d1bba..cc64ca60fc9 100644
--- a/torch/overrides.py
+++ b/torch/overrides.py
@@ -238,6 +238,7 @@ def get_ignored_functions() -> Set[Callable]:
         torch.vitals_enabled,
         torch.set_vital,
         torch.read_vitals,
+        torch.vmap,
         torch.frombuffer,
         torch.asarray,
         Tensor.__delitem__,
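
After this patch, `torch.vmap` and `torch.func.vmap` should resolve to the same function,
while the legacy prototype remains importable only from the private `torch._vmap_internals`
module. A minimal doctest-style sanity check (a sketch, assuming a PyTorch build that
includes this change; the tensor shapes mirror the batched-dot examples in the docstrings
above):

    >>> import torch
    >>> # the re-export added in torch/__init__.py: both names point at the same function
    >>> assert torch.vmap is torch.func.vmap
    >>> # batched dot product through the aliased API
    >>> x, y = torch.randn(3, 5), torch.randn(3, 5)
    >>> torch.vmap(torch.dot)(x, y).shape
    torch.Size([3])
    >>> # the legacy prototype is still reachable, but only via the private module,
    >>> # and calling it emits the updated deprecation warning
    >>> from torch._vmap_internals import vmap as legacy_vmap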