Enable out_dims for vmap frontend API (#40576)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/40576

`out_dims` specifies where in the output tensors the vmapped dimension
should appear. We implement this by simply creating a view with the
batch dimension moved to the desired position.

`out_dims` must either:
- be an int (the same value is used for all outputs), or
- be a Tuple[int] (the user specifies one out_dim per output).
(See the vmap docstring for what we advertise out_dims to do.)
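For example, a quick illustration in the spirit of the tests added below (identity
functions, batch size 2):

>>> x = torch.randn(2, 3, 5)
>>> vmap(lambda x: x, out_dims=1)(x).shape
torch.Size([3, 2, 5])
>>> a, b = vmap(lambda x: (x, x), out_dims=(0, 1))(x)
>>> a.shape, b.shape
(torch.Size([2, 3, 5]), torch.Size([3, 2, 5]))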

I also renamed `TestVmap` to `TestVmapAPI` to make it clearer that we
are testing the API here and not specific operators (which will go into
their own test class).

Test Plan: - `pytest test/test_vmap.py -v`

Differential Revision: D22288086

Pulled By: zou3519

fbshipit-source-id: c8666cb1a0e22c54473d8045477e14c2089167cf

@@ -23,7 +23,9 @@ constexpr int64_t kBatchDimsStackSize = 5;
// a BatchDim represents a "private" dimension on a Tensor created inside of
// vmap. It is a (level, dim) tuple, with the `dim` indicating which dimension
// is being vmap'ed over and the `level` being an identifier for which vmap
// said dimension was created inside.
// said dimension was created inside. The `dim` corresponds to a "physical
// dim" - it is a dimension index on the underlying physical tensor that is being
// vmapped over.
struct BatchDim {
BatchDim(int64_t level, int64_t dim) : dim_(dim), level_(level) {}
int64_t dim() const {

@@ -1,4 +1,5 @@
#include <ATen/BatchedTensorImpl.h>
#include <ATen/WrapDimUtils.h>
#include <ATen/VmapTransforms.h>
namespace at { namespace native {
@@ -20,6 +21,78 @@ static bool has_level(const Tensor& self, int64_t level) {
return it != bdims.end();
}
// Returns a Tensor with the batch dim at level `level` turned into a regular dimension,
// as well as the logical dim index of where said dimension appears in the returned tensor.
// A call to this function is always followed by a call to `movedim`.
//
// Preconditions: A BatchDim with level `level` must exist inside `batched`.
//
// We return the index of where said dimension is in the returned tensor because we
// need to keep track of which dimension used to be the batch dimension, so that we
// can later move it to the logical dimension specified by `out_dims` in vmap.
// For example, if we had
// >>> x = torch.randn(2, 3, 5)
// >>> vmap(lambda x: x, in_dims=0, out_dims=1)(x)
// then right when we are about to exit the vmap block, x is a BatchedTensor with a
// batch dimension at (physical) index 0. Note that the batch dimension doesn't
// always have to exist at (physical) index 0. When we undo the batch dimension,
// we want to move it to dimension 1 (as specified by out_dims). So we return the
// index at which the batch dim appears so that we can move it to the correct place
// later down the line via a call to `movedim`.
static std::pair<Tensor,int64_t> remove_existing_batch_dim(
const BatchedTensorImpl* batched, int64_t level) {
auto bdims = batched->bdims();
if (bdims.size() == 1) {
TORCH_INTERNAL_ASSERT(bdims[0].level() == level);
return std::make_pair(batched->value(), bdims[0].dim());
}
BatchDims new_bdims;
int64_t newly_exposed_physical_dim = -1;
new_bdims.reserve(bdims.size() - 1);
for (const auto& bdim : bdims) {
if (bdim.level() == level) {
newly_exposed_physical_dim = bdim.dim();
} else {
new_bdims.push_back(bdim);
}
}
// Because a BatchDim with level `level` must exist inside `batched`,
// we should have found a `newly_exposed_physical_dim`.
TORCH_INTERNAL_ASSERT(newly_exposed_physical_dim != -1);
int64_t num_batch_dims_before_newly_exposed_physical_dim = std::count_if(
new_bdims.begin(), new_bdims.end(),
[&](const BatchDim& bdim) {
return bdim.dim() < newly_exposed_physical_dim;
});
int64_t newly_exposed_logical_dim =
newly_exposed_physical_dim - num_batch_dims_before_newly_exposed_physical_dim;
auto result_tensor = makeBatched(batched->value(), std::move(new_bdims));
return std::make_pair(std::move(result_tensor), newly_exposed_logical_dim);
}
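To make the physical-to-logical bookkeeping above concrete, here is a minimal
standalone Python sketch (illustrative only; `logical_index_of_exposed_dim` is a
hypothetical name, the real logic is the C++ above):

def logical_index_of_exposed_dim(batch_dims, level):
    # batch_dims: list of (level, physical_dim) pairs on the BatchedTensor.
    exposed = next(d for (l, d) in batch_dims if l == level)
    remaining = [d for (l, d) in batch_dims if l != level]
    # Batch dims that sit at a lower physical index than the exposed dim are
    # still hidden from the logical view, so subtract them out.
    return exposed - sum(1 for d in remaining if d < exposed)

# Batch dims at physical indices 0 and 2; removing the level-2 one exposes it
# at logical index 1 of the remaining BatchedTensor's view.
assert logical_index_of_exposed_dim([(1, 0), (2, 2)], level=2) == 1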
// Poor man's version of np.moveaxis. Moves the dimension at `src` to `dst`
// while preserving the order of other existing dimensions.
// We should probably add np.moveaxis (it is more general) to PyTorch. (#36048)
// When we do, replace the following with it.
static Tensor movedim(const Tensor& self, int64_t src, int64_t dst) {
auto logical_dim = self.dim();
src = maybe_wrap_dim(src, logical_dim);
dst = maybe_wrap_dim(dst, logical_dim);
if (src == dst) {
return self;
}
VmapDimVector permutation;
permutation.reserve(logical_dim);
for (int64_t dim = 0; dim < logical_dim; dim++) {
if (dim == src) {
continue;
}
permutation.push_back(dim);
}
permutation.insert(permutation.begin() + dst, src);
return self.permute(permutation);
}
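For reference, the same permute-based move as a small Python sketch (illustrative
only; `movedim_via_permute` is a hypothetical helper name):

import torch

def movedim_via_permute(t, src, dst):
    # Wrap negative dims, drop `src` from the identity permutation, and
    # re-insert it at `dst`, preserving the order of the other dimensions.
    src = src % t.dim()
    dst = dst % t.dim()
    perm = [d for d in range(t.dim()) if d != src]
    perm.insert(dst, src)
    return t.permute(perm)

# Moving dim 0 of a (2, 3, 5) tensor to position 2 yields shape (3, 5, 2).
assert movedim_via_permute(torch.randn(2, 3, 5), 0, 2).shape == (3, 5, 2)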
// Removes the batch dim with level `level` from `self`. If this causes the
// last batch dim to be removed from a BatchedTensor, then this returns a
// regular Tensor.
@@ -37,7 +110,6 @@ static bool has_level(const Tensor& self, int64_t level) {
//
// `out_dim` controls where we should put the batch dimension in the output tensor.
Tensor _remove_batch_dim(const Tensor& self, int64_t level, int64_t batch_size, int64_t out_dim) {
TORCH_INTERNAL_ASSERT(out_dim == 0);
if (!has_level(self, level)) {
auto self_sizes = self.sizes();
VmapDimVector expanded_sizes(self_sizes.begin(), self_sizes.end());
@@ -45,17 +117,14 @@ Tensor _remove_batch_dim(const Tensor& self, int64_t level, int64_t batch_size,
return self.expand(expanded_sizes);
}
// Must be batched if has_level(self, /*any_level*/)
const auto* batched = maybeGetBatched(self);
TORCH_INTERNAL_ASSERT(batched != nullptr);
auto bdims = batched->bdims();
if (bdims.size() == 1) {
return batched->value();
}
BatchDims new_bdims;
new_bdims.reserve(bdims.size() - 1);
std::copy_if(bdims.begin(), bdims.end(), std::back_inserter(new_bdims),
[&](const BatchDim& bdim) { return bdim.level() != level; });
return makeBatched(batched->value(), std::move(new_bdims));
Tensor self_without_bdim;
int64_t newly_exposed_logical_dim;
std::tie(self_without_bdim, newly_exposed_logical_dim) = remove_existing_batch_dim(batched, level);
return movedim(self_without_bdim, newly_exposed_logical_dim, out_dim);
}
} // namespace native

@@ -2,7 +2,7 @@ from torch.testing._internal.common_utils import TestCase, run_tests
import torch
from torch import vmap
class TestVmap(TestCase):
class TestVmapAPI(TestCase):
def test_non_tensor_output_raises(self):
with self.assertRaisesRegex(ValueError, "got type <class 'float'> as the return"):
output = vmap(lambda x: 3.14)(torch.ones(3))
@@ -135,6 +135,143 @@ class TestVmap(TestCase):
RuntimeError, 'Tried to call KernelFunction::call'):
vmap(foo)(x)
def test_nonzero_out_dims(self):
# Basic test
tensor = torch.randn(2, 3)
result = vmap(lambda x: x, out_dims=1)(tensor)
self.assertEqual(result, tensor.permute(1, 0))
self.assertEqual(result.data_ptr(), tensor.data_ptr())
# Test that the batch dimension gets permuted to dim 2
tensor = torch.randn(2, 3, 5, 7)
result = vmap(lambda x: x, out_dims=2)(tensor)
self.assertEqual(result, tensor.permute(1, 2, 0, 3))
self.assertEqual(result.data_ptr(), tensor.data_ptr())
# negative out_dim
tensor = torch.randn(2, 3, 5, 7)
result = vmap(lambda x: x, out_dims=-1)(tensor)
self.assertEqual(result, tensor.permute(1, 2, 3, 0))
self.assertEqual(result.data_ptr(), tensor.data_ptr())
# check that out_dims works on ALL outputs
tensor = torch.randn(2, 3, 5, 7)
other = torch.randn(2, 3, 5, 7)
result = vmap(lambda x, y: (x, y), out_dims=2)(tensor, other)
self.assertEqual(result, (tensor.permute(1, 2, 0, 3), other.permute(1, 2, 0, 3)))
# use out_dims with the maximum vmap-able tensor dims (64 dims)
ndims = 64
shape = [2] + [1] * (ndims - 1)
expected_shape = [1, 1, 2] + [1] * (ndims - 3)
tensor = torch.randn(shape)
result = vmap(lambda x: x, out_dims=2)(tensor)
self.assertEqual(result.shape, expected_shape)
# test something that is not the identity function
def foo(x, y):
return x, x * y, x * y * y
x = torch.randn(2, 3, 5)
y = torch.randn(2, 3, 5)
result = vmap(foo, out_dims=1)(x, y)
self.assertEqual(
result,
(x.permute(1, 0, 2), (x * y).permute(1, 0, 2), (x * y * y).permute(1, 0, 2)))
def test_multiple_out_dims(self):
def foo(x):
return x, x
def bar(x, y):
return x, x, x, x * y
x = torch.randn(2, 3, 5)
y = torch.randn(2, 3, 5)
result = vmap(foo, out_dims=(0, 1))(x)
self.assertEqual(result, (x, x.permute(1, 0, 2)))
result = vmap(bar, out_dims=(-1, 0, 1, 2))(x, y)
expected = (
x.permute(1, 2, 0),
x,
x.permute(1, 0, 2),
(x * y).permute(1, 2, 0),
)
self.assertEqual(result, expected)
def test_nested_out_dims(self):
y = torch.randn(2, 3, 5, 7)
# Inner vmap has non-zero out_dim
result = vmap(lambda y: vmap(lambda x: x, out_dims=1)(y))(y)
self.assertEqual(result.shape, (2, 5, 3, 7))
self.assertEqual(result, y.permute(0, 2, 1, 3))
# all vmaps have non-zero out_dim
result = vmap(lambda y: vmap(lambda x: x, out_dims=1)(y), out_dims=1)(y)
self.assertEqual(result.shape, (5, 2, 3, 7))
self.assertEqual(result, y.permute(2, 0, 1, 3))
# throwing in some negative out_dims
result = vmap(lambda y: vmap(lambda x: x, out_dims=-1)(y), out_dims=-1)(y)
self.assertEqual(result.shape, (5, 7, 3, 2))
self.assertEqual(result, y.permute(2, 3, 1, 0))
# testing fn that isn't the identity
x = torch.randn(2, 3)
y = torch.randn(5, 3)
result = vmap(lambda y: vmap(lambda x: x * y, out_dims=1)(x), out_dims=-1)(y)
self.assertEqual(result.shape, (3, 2, 5))
self.assertEqual(result, (y.view(5, 1, 3) * x).permute(2, 1, 0))
def test_out_dims_edge_case(self):
def foo(x):
return x
# Test that we accept out_dims=(1,) for a function with one output.
tensor = torch.randn(2, 3)
expected = vmap(foo, out_dims=1)(tensor)
result = vmap(foo, out_dims=(1,))(tensor)
self.assertEqual(result, expected)
def test_out_dims_must_be_int_or_tuple_of_int_err_msg(self):
msg = '`out_dims` must be an int or a tuple of int'
tensor = torch.randn(2, 3)
with self.assertRaisesRegex(ValueError, msg):
vmap(lambda x: x, out_dims='lol')(tensor)
with self.assertRaisesRegex(ValueError, msg):
vmap(lambda x: x, out_dims=('lol',))(tensor)
with self.assertRaisesRegex(ValueError, msg):
vmap(lambda x: x, out_dims=None)(tensor)
with self.assertRaisesRegex(ValueError, msg):
vmap(lambda x: x, out_dims=(None,))(tensor)
def test_out_dims_and_num_outputs_mismatch_err_msg(self):
msg = '`out_dims` must have one dim per output'
x = torch.randn(2, 3, 5)
# Too many out_dims
with self.assertRaisesRegex(ValueError, msg):
vmap(lambda x: x, out_dims=(0, 0))(x)
with self.assertRaisesRegex(ValueError, msg):
vmap(lambda x: (x, x, x), out_dims=(0, 0, 0, 0))(x)
# Too few out_dims
with self.assertRaisesRegex(ValueError, msg):
vmap(lambda x: (x, x), out_dims=(0,))(x)
with self.assertRaisesRegex(ValueError, msg):
vmap(lambda x: (x, x, x), out_dims=(0, 0))(x)
def test_out_dim_out_of_bounds_err_msg(self):
# TODO(rzou): This error message isn't that great. It comes straight
# from maybe_wrap_dim. Consider catching it in C++ and adding some context to
# the error message in the future.
msg = 'Dimension out of range'
x = torch.randn(2, 3, 5)
with self.assertRaisesRegex(IndexError, msg):
vmap(lambda x: x, out_dims=3)(x)
with self.assertRaisesRegex(IndexError, msg):
vmap(lambda x: x, out_dims=-4)(x)
if __name__ == '__main__':
run_tests()

@@ -24,6 +24,17 @@ NO_INPUTS = (
'function with no inputs. The latter is unsupported.'
)
OUT_DIMS_MUST_BE_INT_OR_TUPLE_OF_INT = (
'vmap({fn}, ..., out_dims={out_dims}): `out_dims` must be an int or a tuple '
'of int representing where in the outputs the vmapped dimension should appear.'
)
OUT_DIMS_AND_NUM_OUTPUTS_MISMATCH = (
'vmap({fn}, ..., out_dims={out_dims}): `out_dims` must have one dim per '
'output (got {num_outputs} outputs) of {fn}.'
)
# Checks that all args have the same batch dim size.
def _validate_and_get_batch_size(args):
batch_sizes = [arg.size(0) for arg in args]
@@ -37,15 +48,36 @@ def _validate_inputs_and_get_batch_size(args, fn_name):
raise ValueError(NO_INPUTS.format(fn=fn_name))
return _validate_and_get_batch_size(args)
def _num_outputs(batched_outputs):
if isinstance(batched_outputs, tuple):
return len(batched_outputs)
return 1
# If value is a tuple, check it has length `num_elements`.
# If value is not a tuple, make a tuple with `value` repeated `num_elements` times
def _as_tuple(value, num_elements, error_message_lambda):
if not isinstance(value, tuple):
return (value,) * num_elements
if len(value) != num_elements:
raise ValueError(error_message_lambda())
return value
# Undoes the batching (and any batch dimensions) associated with the `vmap_level`.
def _unwrap_batched(batched_outputs, vmap_level, batch_size):
def _unwrap_batched(batched_outputs, out_dims, vmap_level, batch_size, fn_name):
num_outputs = _num_outputs(batched_outputs)
out_dims_as_tuple = _as_tuple(
out_dims, num_outputs,
lambda: OUT_DIMS_AND_NUM_OUTPUTS_MISMATCH.format(
fn=fn_name, out_dims=out_dims, num_outputs=num_outputs))
# NOTE [Ignored _remove_batch_dim, _add_batch_dim]
# There is something wrong with our type bindings for functions that begin
# with '_', see #40397.
if isinstance(batched_outputs, Tensor):
return torch._remove_batch_dim(batched_outputs, vmap_level, batch_size, 0) # type: ignore
return tuple(torch._remove_batch_dim(out, vmap_level, batch_size, 0) # type: ignore
for out in batched_outputs)
out_dim = out_dims_as_tuple[0]
return torch._remove_batch_dim(batched_outputs, vmap_level, batch_size, out_dim) # type: ignore
return tuple(torch._remove_batch_dim(out, vmap_level, batch_size, out_dim) # type: ignore
for out, out_dim in zip(batched_outputs, out_dims_as_tuple))
# Checks that `fn` returned one or more Tensors and nothing else.
# NB: A python function that returns multiple values returns a single tuple,
@@ -61,6 +93,13 @@ def _validate_outputs(outputs, fn_name):
continue
raise ValueError(ELEMENT_MUST_BE_TENSOR.format(fn=fn_name, out=type(output), idx=idx))
def _check_out_dims_is_int_or_int_tuple(out_dims, fn_name):
if isinstance(out_dims, int):
return
if not isinstance(out_dims, tuple) or \
not all([isinstance(out_dim, int) for out_dim in out_dims]):
raise ValueError(OUT_DIMS_MUST_BE_INT_OR_TUPLE_OF_INT.format(out_dims=out_dims, fn=fn_name))
# This is the global tracker for how many nested vmaps we are currently inside.
VMAP_LEVEL = 0
@@ -111,23 +150,23 @@ def vmap(func, in_dims=0, out_dims=0):
if in_dims != 0:
raise NotImplementedError('NYI: vmap with `in_dims` other than 0')
if out_dims != 0:
raise NotImplementedError('NYI: vmap with `out_dims` other than 0')
@functools.wraps(func)
def wrapped(*args):
if any(not isinstance(arg, Tensor) for arg in args):
raise NotImplementedError('NYI: vmap with non-tensor inputs')
batch_size = _validate_inputs_and_get_batch_size(args, func.__name__)
fn_name = func.__name__
_check_out_dims_is_int_or_int_tuple(out_dims, fn_name)
batch_size = _validate_inputs_and_get_batch_size(args, fn_name)
global VMAP_LEVEL
VMAP_LEVEL += 1
try:
# See NOTE [Ignored _remove_batch_dim, _add_batch_dim]
batched_inputs = [torch._add_batch_dim(arg, 0, VMAP_LEVEL) for arg in args] # type: ignore
batched_outputs = func(*batched_inputs)
_validate_outputs(batched_outputs, func.__name__)
return _unwrap_batched(batched_outputs, VMAP_LEVEL, batch_size)
_validate_outputs(batched_outputs, fn_name)
return _unwrap_batched(batched_outputs, out_dims, VMAP_LEVEL, batch_size, fn_name)
finally:
VMAP_LEVEL -= 1
return wrapped