Enable out_dims for vmap frontend API (#40576)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/40576

`out_dims` specifies where in the output tensors the vmapped dimension should appear. We implement this by creating a view with the batch dimension moved to the desired position. `out_dims` must either:
- be an int (the same value is used for all outputs), or
- be a Tuple[int] (the user specifies one out_dim per output).

(See the vmap docstring for what we advertise `out_dims` to do.)

I also renamed `TestVmap` to `TestVmapAPI` to make it clearer that we are testing the API here and not specific operators (which will go into their own test class).

Test Plan:
- `pytest test/test_vmap.py -v`

Differential Revision: D22288086

Pulled By: zou3519

fbshipit-source-id: c8666cb1a0e22c54473d8045477e14c2089167cf
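
For illustration, a usage sketch of the behavior described above, based on the tests added in this PR (the `from torch import vmap` import path follows test/test_vmap.py):

import torch
from torch import vmap

# out_dims=1 asks for the vmapped (batch) dimension to land at dim 1 of the output.
x = torch.randn(2, 3, 5)
out = vmap(lambda t: t, out_dims=1)(x)
print(out.shape)  # torch.Size([3, 2, 5]): the batch dim of size 2 moved to position 1

# A tuple supplies one destination dim per output.
a, b = vmap(lambda t: (t, t), out_dims=(0, 2))(x)
print(a.shape, b.shape)  # torch.Size([2, 3, 5]) torch.Size([3, 5, 2])
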
Parent: 2f94b7f95c
Commit: a6a31bcd47
@@ -23,7 +23,9 @@ constexpr int64_t kBatchDimsStackSize = 5;
 // a BatchDim represents a "private" dimension on a Tensor created inside of
 // vmap. It is a (level, dim) tuple, with the `dim` indicating which dimension
 // is being vmap'ed over and the `level` being an identifier for which vmap
-// said dimension was created inside.
+// said dimension was created inside. The `dim` corresponds to a "physical
+// dim" - it is a dimension index on the underlying physical tensor that is being
+// vmapped over.
 struct BatchDim {
   BatchDim(int64_t level, int64_t dim) : dim_(dim), level_(level) {}
   int64_t dim() const {
@@ -1,4 +1,5 @@
 #include <ATen/BatchedTensorImpl.h>
+#include <ATen/WrapDimUtils.h>
 #include <ATen/VmapTransforms.h>
 
 namespace at { namespace native {
@@ -20,6 +21,78 @@ static bool has_level(const Tensor& self, int64_t level) {
   return it != bdims.end();
 }
 
+// Returns a Tensor with batch dim with level `level` turned into a regular dimension,
+// as well as a logical dim index of where said dimension is in the returned tensor.
+// A call to this function is always followed by a call to `movedim`.
+//
+// Preconditions: A BatchDim with level `level` must exist inside `batched`.
+//
+// The reason why we want to return the index of where said dimension is in the returned
+// tensor is because we want to keep track of which dimension used to be the batch
+// dimension so that we can move it to the correct logical dimension specified by
+// `out_dims` in vmap. For example, if we had
+// >>> x = torch.randn(2, 3, 5)
+// >>> vmap(lambda x: x, in_dims=0, out_dims=1)(x)
+// then right when we are about to exit the vmap block, x is a BatchedTensor with a
+// batch dimension at (physical) index 0. Note that the batch dimension doesn't
+// always have to exist at (physical) index 0. When we undo the batch dimension,
+// we want to move it to dimension 1 (as specified by out_dims). So we return the
+// index at which the batch dim appears so that we can move it to the correct place
+// later down the line via a call to `movedim`.
+static std::pair<Tensor,int64_t> remove_existing_batch_dim(
+    const BatchedTensorImpl* batched, int64_t level) {
+  auto bdims = batched->bdims();
+  if (bdims.size() == 1) {
+    TORCH_INTERNAL_ASSERT(bdims[0].level() == level);
+    return std::make_pair(batched->value(), bdims[0].dim());
+  }
+  BatchDims new_bdims;
+  int64_t newly_exposed_physical_dim = -1;
+  new_bdims.reserve(bdims.size() - 1);
+  for (const auto& bdim : bdims) {
+    if (bdim.level() == level) {
+      newly_exposed_physical_dim = bdim.dim();
+    } else {
+      new_bdims.push_back(bdim);
+    }
+  }
+  // Because a BatchDim with level `level` must exist inside `batched`,
+  // we should have found a `newly_exposed_physical_dim`.
+  TORCH_INTERNAL_ASSERT(newly_exposed_physical_dim != -1);
+  int64_t num_batch_dims_before_newly_exposed_physical_dim = std::count_if(
+      new_bdims.begin(), new_bdims.end(),
+      [&](const BatchDim& bdim) {
+        return bdim.dim() < newly_exposed_physical_dim;
+      });
+  int64_t newly_exposed_logical_dim =
+      newly_exposed_physical_dim - num_batch_dims_before_newly_exposed_physical_dim;
+  auto result_tensor = makeBatched(batched->value(), std::move(new_bdims));
+  return std::make_pair(std::move(result_tensor), newly_exposed_logical_dim);
+}
+
+// Poor man's version of np.moveaxis. Moves the dimension at `src` to `dst`
+// while preserving the order of other existing dimensions.
+// We should probably add np.moveaxis (it is more general) to PyTorch. (#36048)
+// When we do, replace the following with it.
+static Tensor movedim(const Tensor& self, int64_t src, int64_t dst) {
+  auto logical_dim = self.dim();
+  src = maybe_wrap_dim(src, logical_dim);
+  dst = maybe_wrap_dim(dst, logical_dim);
+  if (src == dst) {
+    return self;
+  }
+  VmapDimVector permutation;
+  permutation.reserve(logical_dim);
+  for (int64_t dim = 0; dim < logical_dim; dim++) {
+    if (dim == src) {
+      continue;
+    }
+    permutation.push_back(dim);
+  }
+  permutation.insert(permutation.begin() + dst, src);
+  return self.permute(permutation);
+}
+
 // Removes the batch dim with level `level` from `self`. If this causes the
 // last batch dim to be removed from a BatchedTensor, then this returns a
 // regular Tensor.
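
The comment block in the hunk above explains why `remove_existing_batch_dim` reports where the newly exposed dimension ended up; `movedim` then carries it to the requested `out_dim`. A minimal Python sketch of that permute-based move (a hypothetical standalone helper mirroring the C++ above; unlike `maybe_wrap_dim`, it does not wrap negative dims):

import torch

def move_dim(t, src, dst):
    # Keep every dim except `src` in order, then insert `src` at position `dst`,
    # which is the same permutation the C++ movedim helper builds.
    perm = [d for d in range(t.dim()) if d != src]
    perm.insert(dst, src)
    return t.permute(perm)

x = torch.randn(2, 3, 5)        # newly exposed batch dim at logical index 0
print(move_dim(x, 0, 1).shape)  # torch.Size([3, 2, 5]): batch dim now at out_dim=1
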
@@ -37,7 +110,6 @@ static bool has_level(const Tensor& self, int64_t level) {
 //
 // `out_dim` controls where we should put the batch dimension in the output tensor.
 Tensor _remove_batch_dim(const Tensor& self, int64_t level, int64_t batch_size, int64_t out_dim) {
-  TORCH_INTERNAL_ASSERT(out_dim == 0);
   if (!has_level(self, level)) {
     auto self_sizes = self.sizes();
     VmapDimVector expanded_sizes(self_sizes.begin(), self_sizes.end());
@@ -45,17 +117,14 @@ Tensor _remove_batch_dim(const Tensor& self, int64_t level, int64_t batch_size,
     return self.expand(expanded_sizes);
   }
+
   // Must be batched if has_level(self, /*any_level*/)
   const auto* batched = maybeGetBatched(self);
   TORCH_INTERNAL_ASSERT(batched != nullptr);
-  auto bdims = batched->bdims();
-  if (bdims.size() == 1) {
-    return batched->value();
-  }
-  BatchDims new_bdims;
-  new_bdims.reserve(bdims.size() - 1);
-  std::copy_if(bdims.begin(), bdims.end(), std::back_inserter(new_bdims),
-               [&](const BatchDim& bdim) { return bdim.level() != level; });
-  return makeBatched(batched->value(), std::move(new_bdims));
+
+  Tensor self_without_bdim;
+  int64_t newly_exposed_logical_dim;
+  std::tie(self_without_bdim, newly_exposed_logical_dim) = remove_existing_batch_dim(batched, level);
+  return movedim(self_without_bdim, newly_exposed_logical_dim, out_dim);
 }
 
 } // namespace native
@@ -2,7 +2,7 @@ from torch.testing._internal.common_utils import TestCase, run_tests
 import torch
 from torch import vmap
 
-class TestVmap(TestCase):
+class TestVmapAPI(TestCase):
     def test_non_tensor_output_raises(self):
         with self.assertRaisesRegex(ValueError, "got type <class 'float'> as the return"):
             output = vmap(lambda x: 3.14)(torch.ones(3))
@@ -135,6 +135,143 @@ class TestVmap(TestCase):
                 RuntimeError, 'Tried to call KernelFunction::call'):
             vmap(foo)(x)
 
+    def test_nonzero_out_dims(self):
+        # Basic test
+        tensor = torch.randn(2, 3)
+        result = vmap(lambda x: x, out_dims=1)(tensor)
+        self.assertEqual(result, tensor.permute(1, 0))
+        self.assertEqual(result.data_ptr(), tensor.data_ptr())
+
+        # Test that the batch dimension gets permuted to dim 2
+        tensor = torch.randn(2, 3, 5, 7)
+        result = vmap(lambda x: x, out_dims=2)(tensor)
+        self.assertEqual(result, tensor.permute(1, 2, 0, 3))
+        self.assertEqual(result.data_ptr(), tensor.data_ptr())
+
+        # negative out_dim
+        tensor = torch.randn(2, 3, 5, 7)
+        result = vmap(lambda x: x, out_dims=-1)(tensor)
+        self.assertEqual(result, tensor.permute(1, 2, 3, 0))
+        self.assertEqual(result.data_ptr(), tensor.data_ptr())
+
+        # check that out_dims works on ALL outputs
+        tensor = torch.randn(2, 3, 5, 7)
+        other = torch.randn(2, 3, 5, 7)
+        result = vmap(lambda x, y: (x, y), out_dims=2)(tensor, other)
+        self.assertEqual(result, (tensor.permute(1, 2, 0, 3), other.permute(1, 2, 0, 3)))
+
+        # use out_dims with the maximum vmap-able tensor dims (64 dims)
+        ndims = 64
+        shape = [2] + [1] * (ndims - 1)
+        expected_shape = [1, 1, 2] + [1] * (ndims - 3)
+        tensor = torch.randn(shape)
+        result = vmap(lambda x: x, out_dims=2)(tensor)
+        self.assertEqual(result.shape, expected_shape)
+
+        # test something that is not the identity function
+        def foo(x, y):
+            return x, x * y, x * y * y
+        x = torch.randn(2, 3, 5)
+        y = torch.randn(2, 3, 5)
+        result = vmap(foo, out_dims=1)(x, y)
+        self.assertEqual(
+            result,
+            (x.permute(1, 0, 2), (x * y).permute(1, 0, 2), (x * y * y).permute(1, 0, 2)))
+
+    def test_multiple_out_dims(self):
+        def foo(x):
+            return x, x
+
+        def bar(x, y):
+            return x, x, x, x * y
+
+        x = torch.randn(2, 3, 5)
+        y = torch.randn(2, 3, 5)
+        result = vmap(foo, out_dims=(0, 1))(x)
+        self.assertEqual(result, (x, x.permute(1, 0, 2)))
+
+        result = vmap(bar, out_dims=(-1, 0, 1, 2))(x, y)
+        expected = (
+            x.permute(1, 2, 0),
+            x,
+            x.permute(1, 0, 2),
+            (x * y).permute(1, 2, 0),
+        )
+        self.assertEqual(result, expected)
+
+    def test_nested_out_dims(self):
+        y = torch.randn(2, 3, 5, 7)
+
+        # Inner vmap has non-zero out_dim
+        result = vmap(lambda y: vmap(lambda x: x, out_dims=1)(y))(y)
+        self.assertEqual(result.shape, (2, 5, 3, 7))
+        self.assertEqual(result, y.permute(0, 2, 1, 3))
+
+        # all vmaps have non-zero out_dim
+        result = vmap(lambda y: vmap(lambda x: x, out_dims=1)(y), out_dims=1)(y)
+        self.assertEqual(result.shape, (5, 2, 3, 7))
+        self.assertEqual(result, y.permute(2, 0, 1, 3))
+
+        # throwing in some negative out_dims
+        result = vmap(lambda y: vmap(lambda x: x, out_dims=-1)(y), out_dims=-1)(y)
+        self.assertEqual(result.shape, (5, 7, 3, 2))
+        self.assertEqual(result, y.permute(2, 3, 1, 0))
+
+        # testing fn that isn't the identity
+        x = torch.randn(2, 3)
+        y = torch.randn(5, 3)
+        result = vmap(lambda y: vmap(lambda x: x * y, out_dims=1)(x), out_dims=-1)(y)
+        self.assertEqual(result.shape, (3, 2, 5))
+        self.assertEqual(result, (y.view(5, 1, 3) * x).permute(2, 1, 0))
+
+    def test_out_dims_edge_case(self):
+        def foo(x):
+            return x
+
+        # Test that we accept out_dims=(1,) for a function with one output.
+        tensor = torch.randn(2, 3)
+        expected = vmap(foo, out_dims=1)(tensor)
+        result = vmap(foo, out_dims=(1,))(tensor)
+        self.assertEqual(result, expected)
+
+    def test_out_dims_must_be_int_or_tuple_of_int_err_msg(self):
+        msg = '`out_dims` must be an int or a tuple of int'
+        tensor = torch.randn(2, 3)
+        with self.assertRaisesRegex(ValueError, msg):
+            vmap(lambda x: x, out_dims='lol')(tensor)
+        with self.assertRaisesRegex(ValueError, msg):
+            vmap(lambda x: x, out_dims=('lol',))(tensor)
+        with self.assertRaisesRegex(ValueError, msg):
+            vmap(lambda x: x, out_dims=None)(tensor)
+        with self.assertRaisesRegex(ValueError, msg):
+            vmap(lambda x: x, out_dims=(None,))(tensor)
+
+    def test_out_dims_and_num_outputs_mismatch_err_msg(self):
+        msg = '`out_dims` must have one dim per output'
+        x = torch.randn(2, 3, 5)
+
+        # Too many out_dims
+        with self.assertRaisesRegex(ValueError, msg):
+            vmap(lambda x: x, out_dims=(0, 0))(x)
+        with self.assertRaisesRegex(ValueError, msg):
+            vmap(lambda x: (x, x, x), out_dims=(0, 0, 0, 0))(x)
+
+        # Too few out_dims
+        with self.assertRaisesRegex(ValueError, msg):
+            vmap(lambda x: (x, x), out_dims=(0,))(x)
+        with self.assertRaisesRegex(ValueError, msg):
+            vmap(lambda x: (x, x, x), out_dims=(0, 0))(x)
+
+    def test_out_dim_out_of_bounds_err_msg(self):
+        # TODO(rzou): This error message isn't that great. It comes straight
+        # from maybe_wrap_dim. Consider doing a try-catch-(add some context) to
+        # the error message in the future in C++
+        msg = 'Dimension out of range'
+        x = torch.randn(2, 3, 5)
+        with self.assertRaisesRegex(IndexError, msg):
+            vmap(lambda x: x, out_dims=3)(x)
+        with self.assertRaisesRegex(IndexError, msg):
+            vmap(lambda x: x, out_dims=-4)(x)
 
 if __name__ == '__main__':
     run_tests()
@@ -24,6 +24,17 @@ NO_INPUTS = (
     'function with no inputs. The latter is unsupported.'
 )
 
+OUT_DIMS_MUST_BE_INT_OR_TUPLE_OF_INT = (
+    'vmap({fn}, ..., out_dims={out_dims}): `out_dims` must be an int or a tuple '
+    'of int representing where in the outputs the vmapped dimension should appear.'
+)
+
+OUT_DIMS_AND_NUM_OUTPUTS_MISMATCH = (
+    'vmap({fn}, ..., out_dims={out_dims}): `out_dims` must have one dim per '
+    'output (got {num_outputs} outputs) of {fn}.'
+)
+
+
 # Checks that all args have the same batch dim size.
 def _validate_and_get_batch_size(args):
     batch_sizes = [arg.size(0) for arg in args]
@@ -37,15 +48,36 @@ def _validate_inputs_and_get_batch_size(args, fn_name):
         raise ValueError(NO_INPUTS.format(fn=fn_name))
     return _validate_and_get_batch_size(args)
 
+def _num_outputs(batched_outputs):
+    if isinstance(batched_outputs, tuple):
+        return len(batched_outputs)
+    return 1
+
+# If value is a tuple, check it has length `num_elements`.
+# If value is not a tuple, make a tuple with `value` repeated `num_elements` times
+def _as_tuple(value, num_elements, error_message_lambda):
+    if not isinstance(value, tuple):
+        return (value,) * num_elements
+    if len(value) != num_elements:
+        raise ValueError(error_message_lambda())
+    return value
+
 # Undos the batching (and any batch dimensions) associated with the `vmap_level`.
-def _unwrap_batched(batched_outputs, vmap_level, batch_size):
+def _unwrap_batched(batched_outputs, out_dims, vmap_level, batch_size, fn_name):
+    num_outputs = _num_outputs(batched_outputs)
+    out_dims_as_tuple = _as_tuple(
+        out_dims, num_outputs,
+        lambda: OUT_DIMS_AND_NUM_OUTPUTS_MISMATCH.format(
+            fn=fn_name, out_dims=out_dims, num_outputs=num_outputs))
+
     # NOTE [Ignored _remove_batch_dim, _add_batch_dim]
     # There is something wrong with our type bindings for functions that begin
     # with '_', see #40397.
     if isinstance(batched_outputs, Tensor):
-        return torch._remove_batch_dim(batched_outputs, vmap_level, batch_size, 0)  # type: ignore
-    return tuple(torch._remove_batch_dim(out, vmap_level, batch_size, 0)  # type: ignore
-                 for out in batched_outputs)
+        out_dim = out_dims_as_tuple[0]
+        return torch._remove_batch_dim(batched_outputs, vmap_level, batch_size, out_dim)  # type: ignore
+    return tuple(torch._remove_batch_dim(out, vmap_level, batch_size, out_dim)  # type: ignore
+                 for out, out_dim in zip(batched_outputs, out_dims_as_tuple))
 
 # Checks that `fn` returned one or more Tensors and nothing else.
 # NB: A python function that return multiple arguments returns a single tuple,
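
To make the normalization above concrete: an int `out_dims` is broadcast to every output, while a tuple must match the number of outputs. A standalone restatement of `_as_tuple` for illustration (not imported from torch):

def as_tuple(value, num_elements, error_message_lambda):
    # Broadcast a non-tuple value; otherwise require one entry per output.
    if not isinstance(value, tuple):
        return (value,) * num_elements
    if len(value) != num_elements:
        raise ValueError(error_message_lambda())
    return value

print(as_tuple(1, 3, lambda: 'mismatch'))          # (1, 1, 1)
print(as_tuple((0, 2, 1), 3, lambda: 'mismatch'))  # (0, 2, 1)
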
@@ -61,6 +93,13 @@ def _validate_outputs(outputs, fn_name):
             continue
         raise ValueError(ELEMENT_MUST_BE_TENSOR.format(fn=fn_name, out=type(output), idx=idx))
 
+def _check_out_dims_is_int_or_int_tuple(out_dims, fn_name):
+    if isinstance(out_dims, int):
+        return
+    if not isinstance(out_dims, tuple) or \
+            not all([isinstance(out_dim, int) for out_dim in out_dims]):
+        raise ValueError(OUT_DIMS_MUST_BE_INT_OR_TUPLE_OF_INT.format(out_dims=out_dims, fn=fn_name))
+
 # This is the global tracker for how many nested vmaps we are currently inside.
 VMAP_LEVEL = 0
 
@@ -111,23 +150,23 @@ def vmap(func, in_dims=0, out_dims=0):
 
     if in_dims != 0:
         raise NotImplementedError('NYI: vmap with `in_dims` other than 0')
-    if out_dims != 0:
-        raise NotImplementedError('NYI: vmap with `out_dims` other than 0')
 
     @functools.wraps(func)
     def wrapped(*args):
         if any(not isinstance(arg, Tensor) for arg in args):
            raise NotImplementedError('NYI: vmap with non-tensor inputs')
 
-        batch_size = _validate_inputs_and_get_batch_size(args, func.__name__)
+        fn_name = func.__name__
+        _check_out_dims_is_int_or_int_tuple(out_dims, fn_name)
+        batch_size = _validate_inputs_and_get_batch_size(args, fn_name)
         global VMAP_LEVEL
         VMAP_LEVEL += 1
         try:
             # See NOTE [Ignored _remove_batch_dim, _add_batch_dim]
             batched_inputs = [torch._add_batch_dim(arg, 0, VMAP_LEVEL) for arg in args]  # type: ignore
             batched_outputs = func(*batched_inputs)
-            _validate_outputs(batched_outputs, func.__name__)
-            return _unwrap_batched(batched_outputs, VMAP_LEVEL, batch_size)
+            _validate_outputs(batched_outputs, fn_name)
+            return _unwrap_batched(batched_outputs, out_dims, VMAP_LEVEL, batch_size, fn_name)
         finally:
             VMAP_LEVEL -= 1
     return wrapped