# Mirror of https://github.com/zebrajr/pytorch.git
# (synced 2025-12-06 12:20:52 +01:00)
#
# Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/46671
#
# Previously, the vmap fallback would choke whenever it saw an undefined
# tensor. For each sample in a batch, the fallback runs an operator and then
# stacks together outputs to get the actual output. Undefined tensors can
# occur as outputs while computing batched gradients with vmap.
#
# This PR updates the vmap fallback to handle undefined tensors which can
# appear in backward formulas:
# - if for each sample in a batch the output was undefined, then the vmap
#   fallback returns an undefined tensor
# - if for each sample in a batch the output is defined, then the vmap
#   fallback stacks together the defined tensors
# - if for some samples in a batch the output is defined and for others it
#   is undefined, then we error out.
#
# Test Plan: new tests
#
# Reviewed By: ezyang
# Differential Revision: D24454909
# Pulled By: zou3519
# fbshipit-source-id: d225382fd17881f23c9833323b68834cfef351f3
#
# File: 1822 lines, 74 KiB, Python
from torch.testing._internal.common_utils import TestCase, run_tests
|
|
import torch
|
|
from torch import Tensor, vmap
|
|
import functools
|
|
import warnings
|
|
from torch.testing._internal.common_device_type import instantiate_device_type_tests
|
|
from torch.testing._internal.common_utils import TEST_WITH_ROCM
|
|
import types
|
|
|
|
|
|
# Regex matched against captured warning messages to detect that vmap took
# the slow fallback path (a per-example for loop, optionally "and stack").
FALLBACK_REGEX = (
    r'falling back to slow \(for loop'
    r'( and stack)?\) implementation'
)
|
|
|
|
class TestVmapAPI(TestCase):
    """Tests for the ``torch.vmap`` API.

    Covers argument validation, ``in_dims``/``out_dims`` handling, nesting,
    the slow for-loop fallback path, and the interaction between vmap and
    autograd.
    """

    def test_non_tensor_output_raises(self):
        with self.assertRaisesRegex(ValueError, "got type <class 'float'> as the return"):
            vmap(lambda x: 3.14)(torch.ones(3))

        def multiple_outputs(x):
            return x, 3

        with self.assertRaisesRegex(ValueError, "got type <class 'int'> for return 1"):
            vmap(multiple_outputs)(torch.ones(3))

    def test_different_map_dim_size_raises(self):
        x = torch.randn(2)
        y = torch.randn(3)
        expected_msg = 'Expected all tensors to have the same size in the mapped dimension'
        with self.assertRaisesRegex(ValueError, expected_msg):
            vmap(torch.mul)(x, y)
        with self.assertRaisesRegex(ValueError, expected_msg):
            vmap(lambda z: z[0] + z[1], in_dims=((0, 0),))((x, y))
        with self.assertRaisesRegex(ValueError, expected_msg):
            vmap(lambda z: z['x'] + z['y'], in_dims=({'x': 0, 'y': 0},))({'x': x, 'y': y})

    def test_func_with_no_inputs(self):
        expected_msg = 'got no inputs'

        def foo():
            return torch.randn(3)

        def bar(x):
            return torch.randn(3)

        with self.assertRaisesRegex(ValueError, expected_msg):
            vmap(foo)()

        with self.assertRaisesRegex(ValueError, expected_msg):
            vmap(bar)()

    def test_constant_function(self):
        output = vmap(lambda x: torch.tensor(3.14))(torch.ones(3))
        self.assertEqual(output, torch.tensor([3.14, 3.14, 3.14]))

    def test_single_input(self):
        x = torch.randn(2, 3)

        def square(x):
            return x * x

        output = vmap(square)(x)
        self.assertEqual(output, x * x)

    def test_multiple_inputs(self):
        x = torch.randn(2, 3)
        y = torch.randn(2, 3)
        output = vmap(torch.mul)(x, y)
        self.assertEqual(output, x * y)

    def test_multiple_outputs(self):
        def foo(x):
            return x * x, x * x * x

        x = torch.randn(3)
        outputs = vmap(foo)(x)
        self.assertEqual(outputs[0], x * x)
        self.assertEqual(outputs[1], x * x * x)

    def test_multiple_outputs_error_cases(self):
        # This is the same thing as
        # def returns_tuple_of_tensors(x):
        #     return x, x
        def returns_tuple_of_tensors(x):
            return (x, x)

        def returns_list_of_two_tensors(x):
            return [x, x]

        def returns_list_of_one_tensor(x):
            return [x]

        x = torch.randn(3)

        # should not throw
        vmap(returns_tuple_of_tensors)(x)

        # jax supports these, but we don't yet
        msg = "must only return Tensors, got type <class 'list'>"
        with self.assertRaisesRegex(ValueError, msg):
            vmap(returns_list_of_two_tensors)(x)
        with self.assertRaisesRegex(ValueError, msg):
            vmap(returns_list_of_one_tensor)(x)

    def test_nested_with_same_map_dim(self):
        x = torch.randn(2, 3, 5)
        y = torch.randn(2, 3, 5)
        output = vmap(vmap(torch.mul))(x, y)
        self.assertEqual(output, x * y)

        output = vmap(vmap(vmap(torch.mul)))(x, y)
        self.assertEqual(output, x * y)

    def test_nested_with_different_map_dim(self):
        x = torch.randn(2, 3)
        y = torch.randn(5, 3)
        output = vmap(lambda x: vmap(lambda y: x * y)(y))(x)
        self.assertEqual(output.shape, (2, 5, 3))
        self.assertEqual(output, x.view(2, 1, 3) * y)

        z = torch.randn(7, 3)
        output = vmap(lambda x: vmap(lambda y: vmap(lambda z: x * y * z)(z))(y))(x)
        self.assertEqual(output.shape, (2, 5, 7, 3))
        self.assertEqual(output, x.view(2, 1, 1, 3) * y.view(5, 1, 3) * z)

    def test_noop_in_inner_vmap(self):
        x = torch.randn(3)
        y = torch.randn(5)
        output = vmap(lambda x: vmap(lambda y: x)(y))(x)
        self.assertEqual(output, x.view(3, 1).expand(3, 5))

    def test_unsupported_op_err_msg(self):
        # Unsupported view op
        tensor = torch.randn(2, 3)
        msg = (
            r"Batching rule not implemented for aten::.+; the "
            r"fallback path doesn't work on out= or view ops"
        )
        with self.assertRaisesRegex(RuntimeError, msg):
            vmap(torch.as_strided, (0, None, None))(tensor, [2, 3], [0, 0])

        def out_op(x, y):
            return torch.abs(x, out=y)

        with self.assertRaisesRegex(RuntimeError, msg):
            vmap(out_op)(tensor, tensor)

        # The fallback doesn't support TensorList
        with self.assertRaisesRegex(RuntimeError, 'Batching rule not implemented'):
            vmap(lambda t: torch.atleast_1d([t]))(tensor)

        # Don't support non-tensor returns. This is a limitation of vmap;
        # functions that don't return tensors must be special cased
        with self.assertRaisesRegex(RuntimeError, 'Batching rule not implemented'):
            vmap(torch.Tensor.item)(tensor)

    def test_nonzero_out_dims(self):
        # Basic test
        tensor = torch.randn(2, 3)
        result = vmap(lambda x: x, out_dims=1)(tensor)
        self.assertEqual(result, tensor.permute(1, 0))
        self.assertEqual(result.data_ptr(), tensor.data_ptr())

        # Test that the batch dimension gets permuted to dim 2
        tensor = torch.randn(2, 3, 5, 7)
        result = vmap(lambda x: x, out_dims=2)(tensor)
        self.assertEqual(result, tensor.permute(1, 2, 0, 3))
        self.assertEqual(result.data_ptr(), tensor.data_ptr())

        # negative out_dim
        tensor = torch.randn(2, 3, 5, 7)
        result = vmap(lambda x: x, out_dims=-1)(tensor)
        self.assertEqual(result, tensor.permute(1, 2, 3, 0))
        self.assertEqual(result.data_ptr(), tensor.data_ptr())

        # check that out_dims works on ALL outputs
        tensor = torch.randn(2, 3, 5, 7)
        other = torch.randn(2, 3, 5, 7)
        result = vmap(lambda x, y: (x, y), out_dims=2)(tensor, other)
        self.assertEqual(result, (tensor.permute(1, 2, 0, 3), other.permute(1, 2, 0, 3)))

        # use out_dims with the maximum vmap-able tensor dims (64 dims)
        ndims = 64
        shape = [2] + [1] * (ndims - 1)
        expected_shape = [1, 1, 2] + [1] * (ndims - 3)
        tensor = torch.randn(shape)
        result = vmap(lambda x: x, out_dims=2)(tensor)
        self.assertEqual(result.shape, expected_shape)

        # test something that is not the identity function
        def foo(x, y):
            return x, x * y, x * y * y

        x = torch.randn(2, 3, 5)
        y = torch.randn(2, 3, 5)
        result = vmap(foo, out_dims=1)(x, y)
        self.assertEqual(
            result,
            (x.permute(1, 0, 2), (x * y).permute(1, 0, 2), (x * y * y).permute(1, 0, 2)))

    def test_multiple_out_dims(self):
        def foo(x):
            return x, x

        def bar(x, y):
            return x, x, x, x * y

        x = torch.randn(2, 3, 5)
        y = torch.randn(2, 3, 5)
        result = vmap(foo, out_dims=(0, 1))(x)
        self.assertEqual(result, (x, x.permute(1, 0, 2)))

        result = vmap(bar, out_dims=(-1, 0, 1, 2))(x, y)
        expected = (
            x.permute(1, 2, 0),
            x,
            x.permute(1, 0, 2),
            (x * y).permute(1, 2, 0),
        )
        self.assertEqual(result, expected)

    def test_nested_out_dims(self):
        y = torch.randn(2, 3, 5, 7)

        # Inner vmap has non-zero out_dim
        result = vmap(lambda y: vmap(lambda x: x, out_dims=1)(y))(y)
        self.assertEqual(result.shape, (2, 5, 3, 7))
        self.assertEqual(result, y.permute(0, 2, 1, 3))

        # all vmaps have non-zero out_dim
        result = vmap(lambda y: vmap(lambda x: x, out_dims=1)(y), out_dims=1)(y)
        self.assertEqual(result.shape, (5, 2, 3, 7))
        self.assertEqual(result, y.permute(2, 0, 1, 3))

        # throwing in some negative out_dims
        result = vmap(lambda y: vmap(lambda x: x, out_dims=-1)(y), out_dims=-1)(y)
        self.assertEqual(result.shape, (5, 7, 3, 2))
        self.assertEqual(result, y.permute(2, 3, 1, 0))

        # testing fn that isn't the identity
        x = torch.randn(2, 3)
        y = torch.randn(5, 3)
        result = vmap(lambda y: vmap(lambda x: x * y, out_dims=1)(x), out_dims=-1)(y)
        self.assertEqual(result.shape, (3, 2, 5))
        self.assertEqual(result, (y.view(5, 1, 3) * x).permute(2, 1, 0))

    def test_out_dims_edge_case(self):
        def foo(x):
            return x

        # Test that we accept out_dims=(1,) for a function with one output.
        tensor = torch.randn(2, 3)
        expected = vmap(foo, out_dims=1)(tensor)
        result = vmap(foo, out_dims=(1,))(tensor)
        self.assertEqual(result, expected)

    def test_out_dims_must_be_int_or_tuple_of_int_err_msg(self):
        msg = '`out_dims` must be an int or a tuple of int'
        tensor = torch.randn(2, 3)
        with self.assertRaisesRegex(ValueError, msg):
            vmap(lambda x: x, out_dims='lol')(tensor)
        with self.assertRaisesRegex(ValueError, msg):
            vmap(lambda x: x, out_dims=('lol',))(tensor)
        with self.assertRaisesRegex(ValueError, msg):
            vmap(lambda x: x, out_dims=None)(tensor)
        with self.assertRaisesRegex(ValueError, msg):
            vmap(lambda x: x, out_dims=(None,))(tensor)

    def test_out_dims_and_num_outputs_mismatch_err_msg(self):
        msg = '`out_dims` must have one dim per output'
        x = torch.randn(2, 3, 5)

        # Too many out_dims
        with self.assertRaisesRegex(ValueError, msg):
            vmap(lambda x: x, out_dims=(0, 0))(x)
        with self.assertRaisesRegex(ValueError, msg):
            vmap(lambda x: (x, x, x), out_dims=(0, 0, 0, 0))(x)

        # Too few out_dims
        with self.assertRaisesRegex(ValueError, msg):
            vmap(lambda x: (x, x), out_dims=(0,))(x)
        with self.assertRaisesRegex(ValueError, msg):
            vmap(lambda x: (x, x, x), out_dims=(0, 0))(x)

    def test_out_dim_out_of_bounds_err_msg(self):
        # TODO(rzou): This error message isn't that great. It comes straight
        # from maybe_wrap_dim. Consider doing a try-catch-(add some context) to
        # the error message in the future in C++
        msg = 'Dimension out of range'
        x = torch.randn(2, 3, 5)
        with self.assertRaisesRegex(IndexError, msg):
            vmap(lambda x: x, out_dims=3)(x)
        with self.assertRaisesRegex(IndexError, msg):
            vmap(lambda x: x, out_dims=-4)(x)

    def test_non_zero_in_dims(self):
        tensor = torch.randn(2, 3, 5)

        # Implicit out_dims = 0; vmap will move the batch dim to the front.
        output = vmap(lambda x: x, (1,))(tensor)
        self.assertEqual(output, tensor.permute(1, 0, 2))
        self.assertEqual(output.data_ptr(), tensor.data_ptr())

        x = torch.randn(2, 3)
        y = torch.randn(3, 2)
        output = vmap(torch.mul, (0, 1))(x, y)
        self.assertEqual(output, x * y.t())
        output = vmap(torch.mul, (1, 0))(x, y)
        self.assertEqual(output, x.t() * y)

    def test_none_in_dims(self):
        x = torch.randn(2, 3)
        y = torch.randn(2, 3)

        # None in_dim for a Tensor means we don't map over it
        output = vmap(torch.mul, (0, None))(x, y)
        self.assertEqual(output.shape, (2, 2, 3))
        self.assertEqual(output, x.view(2, 1, 3) * y)

        # None in_dim for non-tensor arguments
        output = vmap(torch.mul, (0, None))(x, 2)
        self.assertEqual(output, x * 2)

    def test_nested_non_default_in_dims(self):
        x = torch.rand(5, 2, 3)
        y = torch.rand(3, 5, 2)
        result = vmap(vmap(vmap(torch.mul), (1, 0)), (1, 2))(x, y)
        self.assertEqual(result, x.permute(1, 2, 0) * y.permute(2, 0, 1))

    def test_non_default_in_dims_out_dims(self):
        x = torch.randn(2, 3, 5)

        # Same in_dim as out_dim, vmap over identity
        result = vmap(lambda x: x, in_dims=1, out_dims=1)(x)
        self.assertEqual(result, x)
        self.assertEqual(result.data_ptr(), x.data_ptr())

        # Different in_dim from out_dim, vmap over identity
        result = vmap(lambda x: x, in_dims=2, out_dims=1)(x)
        self.assertEqual(result.shape, (2, 5, 3))
        self.assertEqual(result, x.transpose(1, 2))
        self.assertEqual(result.data_ptr(), x.data_ptr())

        def foo(x):
            return x * 2

        # Same in_dim as out_dim, vmap over operation
        result = vmap(foo, in_dims=1, out_dims=1)(x)
        self.assertEqual(result, x * 2)

        # Different in_dim as out_dim, vmap over operation
        result = vmap(foo, in_dims=2, out_dims=1)(x)
        self.assertEqual(result.shape, (2, 5, 3))
        self.assertEqual(result, (x * 2).transpose(1, 2))

        # Basic nested test.
        result = vmap(vmap(foo, 1, 1), 1, 1)(x)
        self.assertEqual(result, x * 2)

    def test_accepts_nested_inputs(self):
        x = torch.randn(2, 3)
        y = torch.randn(2, 3)

        # Single layer of nesting
        out = vmap(lambda z: z[0] + z[1])((x, y))
        self.assertEqual(out, x + y)
        out = vmap(lambda z: z[0] + z[1], in_dims=(0,))((x, y))
        self.assertEqual(out, x + y)
        out = vmap(lambda z: z[0] + z[1], in_dims=((0, 0),))((x, y))
        self.assertEqual(out, x + y)

        out = vmap(lambda z: z[0] + z[1])([x, y])
        self.assertEqual(out, x + y)
        out = vmap(lambda z: z[0] + z[1], in_dims=(0,))([x, y])
        self.assertEqual(out, x + y)
        out = vmap(lambda z: z[0] + z[1], in_dims=([0, 0],))([x, y])
        self.assertEqual(out, x + y)

        out = vmap(lambda z: z['x'] + z['y'])({'x': x, 'y': y})
        self.assertEqual(out, x + y)
        out = vmap(lambda z: z['x'] + z['y'], in_dims=(0,))({'x': x, 'y': y})
        self.assertEqual(out, x + y)
        out = vmap(lambda z: z['x'] + z['y'], in_dims=({'x': 0, 'y': 0},))({'x': x, 'y': y})
        self.assertEqual(out, x + y)

        # Multiple layers of nesting
        out_fn = vmap(lambda z: z['x'][0] + z['x'][1][0] + z['y'][0] + z['y'][1])
        out = out_fn({'x': [x, (x,)], 'y': [y, y]})
        self.assertEqual(out, x + x + y + y)

    def test_in_dims_wrong_type_err_msg(self):
        x = torch.randn(3)
        y = torch.randn(3)
        msg = r'expected `in_dims` to be int or a \(potentially nested\) tuple'
        with self.assertRaisesRegex(ValueError, msg):
            vmap(torch.mul, [0, 0])(x, y)
        with self.assertRaisesRegex(ValueError, msg):
            vmap(torch.mul, set({0, 0}))(x, y)
        with self.assertRaisesRegex(ValueError, msg):
            vmap(torch.mul, 'lol')(x, y)
        with self.assertRaisesRegex(ValueError, msg):
            vmap(lambda z: z[0] + z[1], in_dims=[0, 0])([x, y])
        # The following should not throw
        vmap(torch.mul, (0, 0))(x, y)

    def test_not_enough_in_dims_err_msg(self):
        x = torch.randn(3)
        y = torch.randn(3)
        msg = r'in_dims is not compatible with the structure of `inputs`'

        with self.assertRaisesRegex(ValueError, msg):
            vmap(torch.mul, (0,))(x, y)
        with self.assertRaisesRegex(ValueError, msg):
            vmap(torch.mul, (0, 0, 0))(x, y)
        with self.assertRaisesRegex(ValueError, msg):
            vmap(lambda z: z[0] + z[1], in_dims=([0],))([x, y])
        with self.assertRaisesRegex(ValueError, msg):
            vmap(lambda z: z[0] + z[1], in_dims=((0, 0),))([x, y])
        # The following should not throw
        vmap(torch.mul, (0, 0))(x, y)

    def test_integer_in_dim_but_not_tensor_input_err_msg(self):
        def foo(xy):
            return xy[0] * xy[1]

        def bar(x, yz):
            return x * yz[0] * yz[1]

        x = torch.randn(2, 3)
        y = torch.randn(2, 3)

        # the following are errors in jax (and will always be errors)
        msg = 'Got in_dim=0 for an input but the input is of type'
        with self.assertRaisesRegex(ValueError, msg):
            vmap(torch.sum)(x, 0)
        with self.assertRaisesRegex(ValueError, msg):
            vmap(torch.sum, (0, 0))(x, 0)
        with self.assertRaisesRegex(ValueError, msg):
            vmap(lambda z: z[0] + z[1], in_dims=([0, 0],))([x, 1])
        # The following should not throw
        vmap(torch.sum, (0, None))(x, 0)

    def test_in_dim_not_in_tensor_err_msg(self):
        def foo(x):
            return x * x

        x = torch.randn(2, 3)
        y = torch.randn(2, 3)

        msg = r'Got in_dim=-?\w for some input, but that input is a Tensor of dimensionality \w'
        with self.assertRaisesRegex(ValueError, msg):
            vmap(foo)(torch.randn([]))
        with self.assertRaisesRegex(ValueError, msg):
            vmap(foo, in_dims=(0,))(torch.randn([]))
        with self.assertRaisesRegex(ValueError, msg):
            vmap(foo, in_dims=(-1,))(x)
        with self.assertRaisesRegex(ValueError, msg):
            vmap(foo, in_dims=(2,))(y)
        with self.assertRaisesRegex(ValueError, msg):
            vmap(lambda z: z[0] + z[1], in_dims=([3, 0],))([x, y])
        # the following should not throw
        vmap(foo, in_dims=(0,))(torch.randn(2, 3))
        vmap(foo, in_dims=(1,))(torch.randn(2, 3))

    def _assert_uses_vmap_fallback(self, vmap_args, inputs):
        """Run ``vmap(*vmap_args)(*inputs)`` and check it hit the slow fallback.

        Asserts that exactly two warnings were recorded and that the last one
        matches ``FALLBACK_REGEX``. NOTE(review): presumably the other
        recorded warning is a generic vmap warning; its content is not
        checked here.
        """
        with warnings.catch_warnings(record=True) as wa:
            vmap(*vmap_args)(*inputs)
            self.assertEqual(len(wa), 2)
            self.assertRegex(str(wa[-1].message), FALLBACK_REGEX)

    def test_fallback_atan2(self):
        # NB: One day we will implement a batching rule for torch.atan2.
        # If/when we do, this test should be replaced to test the fallback
        # path on another operator to avoid bitrot.
        op = torch.atan2

        x = torch.randn(5, 7, 11)
        y = torch.randn(5, 7, 11)

        self._assert_uses_vmap_fallback((op,), (x, y))

        # fallback on torch.atan2
        x = torch.randn(7, 11, 5)
        y = torch.randn(5, 7, 11)
        result = vmap(op, (2, 0))(x, y)
        self.assertEqual(result, op(x.permute(2, 0, 1), y))

        # fallback on torch.atan2, nested vmap
        x = torch.randn(7, 11, 5)
        y = torch.randn(5, 7, 11)
        result = vmap(vmap(op), (2, 0))(x, y)
        self.assertEqual(result, op(x.permute(2, 0, 1), y))

        # big batch size (total 10000)
        x = torch.randn(100, 10, 10, 5)
        y = torch.randn(100, 10, 10)
        result = vmap(vmap(vmap(op)))(x, y)
        self.assertEqual(result, op(x, y.view(100, 10, 10, 1)))

    def test_fallback_masked_fill(self):
        # NB: Despite this test's name, the operator exercised below is
        # torch.index_add. One day we will implement a batching rule for
        # index_add. If/when we do, this test should be replaced to test the
        # fallback path on another operator to avoid bitrot.
        def run_test(batch_size):
            B0 = batch_size
            x = torch.randn(B0, 7, 11, 13)
            dim = 0
            index = torch.tensor([0, 4, 2])
            values = torch.randn(B0, 3, 13)

            self._assert_uses_vmap_fallback((torch.index_add, (0, None, None, 0)), (x, dim, index, values))

            result = vmap(torch.index_add, (0, None, None, 0))(x, dim, index, values)
            expected = torch.index_add(
                x, dim + 1, index, values.view(B0, 3, 1, 13))
            self.assertEqual(result, expected)

        run_test(batch_size=5)
        run_test(batch_size=1237)

    def test_fallback_multiple_returns(self):
        # NB: One day we will implement a batching rule for torch.var_mean
        # If/when we do, this test should be replaced to test the fallback
        # path on another operator to avoid bitrot.
        B0, B1, B2 = 2, 3, 1237
        tensor = torch.randn(B0, 10)

        self._assert_uses_vmap_fallback((torch.var_mean,), (tensor,))

        # fallback correctness on torch.var_mean
        result = vmap(torch.var_mean)(tensor)
        expected = torch.var_mean(tensor, dim=1)
        self.assertEqual(result, expected)

        # nested vmap
        tensor = torch.randn(B0, B1, 10)
        result = vmap(vmap(torch.var_mean))(tensor)
        expected = torch.var_mean(tensor, dim=2)
        self.assertEqual(result, expected)

        # big batch size, nested vmap
        tensor = torch.randn(B0, B1, B2, 10)
        result = vmap(vmap(vmap(torch.var_mean)))(tensor)
        expected = torch.var_mean(tensor, dim=3)
        self.assertEqual(result, expected)

    def test_inplace_fallback_unary(self):
        # Test the in-place fallback on an in-place method that takes no
        # additional Tensor arguments. This is the simplest case of the fallback.
        # NB: One day we will implement a batching rule for acos_.
        # If/when we do, this test should be replaced to test the fallback
        # path on another operator to avoid bitrot.
        op = Tensor.acos_
        B0, B1, B2 = 2, 3, 10000

        x = torch.randn(B0, 5)
        self._assert_uses_vmap_fallback((op,), (x,))

        # Single vmap
        x_orig = torch.rand(B0, 5)
        x = x_orig.clone()
        result = vmap(op)(x)
        self.assertTrue(result is x)
        self.assertEqual(result, x_orig.acos())

        # Single vmap + different out_dim produces a view(!)
        x_orig = torch.rand(B0, 5)
        x = x_orig.clone()
        result = vmap(op, out_dims=(1,))(x)
        self.assertTrue(result._base is x)
        self.assertEqual(result, x_orig.t().acos())

        # Nested vmap
        x_orig = torch.randn(B0, B1, 5)
        x = x_orig.clone()
        result = vmap(vmap(op))(x)
        self.assertTrue(result is x)
        self.assertEqual(result, x_orig.acos())

        # Nested vmap, large batch size
        x_orig = torch.randn(B0, B1, B2, 5)
        x = x_orig.clone()
        result = vmap(vmap(vmap(op)))(x)
        self.assertTrue(result is x)
        self.assertEqual(result, x_orig.acos())

    def test_inplace_fallback_nary_same_levels(self):
        # NB: One day we will implement a batching rule for atan2_
        # If/when we do, this test should be replaced to test the fallback
        # path on another operator to avoid bitrot.
        op = Tensor.atan2_
        outplace_op = torch.atan2

        x = torch.randn(5, 7, 11)
        y = torch.randn(5, 7, 11)
        self._assert_uses_vmap_fallback((op,), (x, y))

        # Single vmap
        B0 = 5
        x_orig = torch.randn(7, 11, B0)
        x = x_orig.clone()
        y = torch.randn(B0, 7, 11)
        vmap(op, (2, 0))(x, y)
        self.assertEqual(x, outplace_op(x_orig, y.movedim(0, 2)))

        # Nested vmap
        B0, B1 = 5, 7
        x_orig = torch.randn(B1, 11, B0)
        x = x_orig.clone()
        y = torch.randn(B0, B1, 11)
        vmap(vmap(op), (2, 0))(x, y)
        self.assertEqual(x, outplace_op(x_orig, y.movedim([0, 1], [2, 0])))

        # big batch size (total 10000)
        B0, B1, B2 = 100, 10, 10
        x_orig = torch.randn(B0, B1, B2, 5)
        x = x_orig.clone()
        y = torch.randn(B0, B1, B2)
        vmap(vmap(vmap(op)))(x, y)
        self.assertEqual(x, outplace_op(x_orig, y.view(B0, B1, B2, 1)))

    def test_inplace_fallback_nary_different_levels(self):
        # NB: One day we will implement a batching rule for atan2_
        # If/when we do, this test should be replaced to test the fallback
        # path on another operator to avoid bitrot.
        op = Tensor.atan2_
        outplace_op = torch.atan2
        B0, B1, B2 = 2, 3, 5

        x = torch.rand(B0, 7)
        y = torch.rand(7)
        self._assert_uses_vmap_fallback((op, (0, None)), (x, y))

        # op(left, right): All of the levels in right are found in left
        x_orig = torch.rand(B0, 7)
        x = x_orig.clone()
        y = torch.rand(7)
        vmap(op, in_dims=(0, None))(x, y)
        self.assertEqual(x, outplace_op(x_orig, y))

        x_orig = torch.rand(B0, B1, 7)
        x = x_orig.clone()
        y = torch.rand(B0, 7)
        vmap(vmap(op, in_dims=(0, None)))(x, y)
        self.assertEqual(x, outplace_op(x_orig, y.view(B0, 1, 7)))

        # op(left, right): Some of the levels in right are not found in left
        msg = r'vmap: aten::atan2_\(self, \*extra_args\) is not possible'
        x = torch.rand(7)
        y = torch.rand(B0, 7)
        with self.assertRaisesRegex(RuntimeError, msg):
            vmap(op, in_dims=(None, 0))(x, y)

        x = torch.rand(B1, 7)
        y = torch.rand(B0, 7)
        with self.assertRaisesRegex(RuntimeError, msg):
            vmap(vmap(op, in_dims=(0, None)), in_dims=(None, 0))(x, y)

        x = torch.rand(B1, 7)
        y = torch.rand(7, B0)
        with self.assertRaisesRegex(RuntimeError, msg):
            vmap(vmap(op, in_dims=(0, None)), in_dims=(None, 1))(x, y)

        x = torch.rand(B0, 7)
        y = torch.rand(B0, B1, 7)
        with self.assertRaisesRegex(RuntimeError, msg):
            vmap(vmap(op, in_dims=(None, 0)))(x, y)

    def test_backward_unsupported_interaction(self):
        x = torch.randn(3, requires_grad=True)
        y = torch.randn(5)
        grad = torch.randn_like(x)
        err_msg = r'backward\(\) called inside torch.vmap'

        def backward_on_vmapped_tensor(x):
            x.sum().backward()

        with self.assertRaisesRegex(RuntimeError, err_msg):
            vmap(backward_on_vmapped_tensor)(x)

        def backward_with_vmapped_grad(x, grad):
            x.backward(grad)

        with self.assertRaisesRegex(RuntimeError, err_msg):
            vmap(backward_with_vmapped_grad)(x, grad)

        def completely_unrelated_backward(y):
            x.sum().backward()

        with self.assertRaisesRegex(RuntimeError, err_msg):
            vmap(completely_unrelated_backward)(y)

    def test_grad_unsupported_interaction(self):
        input_tensor = torch.randn(3, requires_grad=True)
        err_msg = 'autograd.grad.* called inside torch.vmap'

        captured = torch.randn(3, requires_grad=True)

        def output_to_grad_is_vmapped(input_tensor):
            output = (captured * input_tensor).sum()
            return torch.autograd.grad([output], [captured])[0]

        with self.assertRaisesRegex(RuntimeError, err_msg):
            vmap(output_to_grad_is_vmapped)(input_tensor)

        output = (input_tensor ** 2).sum()

        def input_to_grad_is_vmapped(input_tensor):
            return torch.autograd.grad([output], [input_tensor])[0]

        with self.assertRaisesRegex(RuntimeError, err_msg):
            vmap(input_to_grad_is_vmapped)(input_tensor)

    def test_batched_gradient_basic(self):
        N = 3
        x = torch.randn(N, requires_grad=True)
        y = torch.randn(N)

        def vjp_mul(v):
            return torch.autograd.grad([x * y], [x], grad_outputs=[v])[0]

        batched_v = torch.eye(N)
        jacobian = vmap(vjp_mul)(batched_v)
        self.assertEqual(jacobian, torch.diagflat(y))

    def test_functools_partial(self):
        x = torch.randn(3)
        y = torch.randn(2, 3)
        result = vmap(functools.partial(torch.mul, x))(y)
        self.assertEqual(result, x * y)

    def test_nn_module(self):
        tensor = torch.randn(2, 3)
        model = torch.nn.Linear(3, 3, bias=False)
        result = vmap(model)(tensor)
        self.assertEqual(result, model(tensor))

    def test_fallback_with_undefined_grad(self):
        B0 = 7
        x = torch.randn(2, 3, 4, 5, requires_grad=True)
        weight = torch.randn(3, 3, 1, 1)
        v = torch.randn(B0, 2, 3, 4, 5)

        def get_vjp(v):
            result = torch.nn.functional.conv2d(x, weight)
            grad_x, = torch.autograd.grad(result, x, v)
            return grad_x

        # Runs vmap(get_vjp)(v), which should not error out.
        # The backward formula for convolution returns an undefined
        # Tensor for grad_bias because the original bias does not exist.
        #
        # In the future we'll probably add a batching rule for convolution
        # backward. When this happens, we should modify this test to use a
        # different op (and/or create and use a dummy operator) to avoid bitrot.
        self._assert_uses_vmap_fallback([get_vjp], [v])
|
|
|
|
def slice_inputs(inputs, bdims, i):
    """Return the ``i``-th per-example slice of ``inputs`` as a tuple.

    Each input is paired with its batch dim from ``bdims``. Inputs whose
    batch dim is ``None`` are passed through unchanged. All other inputs are
    indexed with ``inp.select(bdim, i)``.
    """
    return tuple(
        inp if bdim is None else inp.select(bdim, i)
        for inp, bdim in zip(inputs, bdims)
    )
|
|
|
|
|
|
def reference_vmap(op, inputs, in_dims=0, out_dims=0):
    """Slow reference for ``vmap(op, in_dims, out_dims)(*inputs)``.

    Calls ``op`` once per batch index on the slices produced by
    ``slice_inputs``. It then stacks the per-sample results with
    ``torch.stack``.

    Args:
        op: callable returning a Tensor or a tuple of Tensors.
        inputs: sequence of positional arguments for ``op``.
        in_dims: an int applied to every input, or a per-input sequence in
            which ``None`` marks an argument that is passed through unsliced.
        out_dims: an int applied to every output, or a per-output sequence
            giving the dimension each output is stacked along.

    Returns:
        A Tensor if ``op`` returns a single Tensor, else a tuple of Tensors.

    Raises:
        AssertionError: if the batched inputs disagree on batch size, the
            batch size is 0, or the per-sample results are not consistently
            Tensors / equal-length tuples.
        IndexError: if no input is batched (every in_dim is ``None``).
    """
    if isinstance(in_dims, int):
        in_dims = (in_dims,) * len(inputs)

    batch_sizes = [
        inp.size(d) for inp, d in zip(inputs, in_dims) if d is not None
    ]
    assert all(size == batch_sizes[0] for size in batch_sizes)
    n = batch_sizes[0]
    per_sample = tuple(
        op(*slice_inputs(inputs, in_dims, k)) for k in range(n)
    )
    assert len(per_sample) > 0

    if not isinstance(per_sample[0], tuple):
        # Single-output op: stack the per-sample tensors along one dim.
        assert all(isinstance(out, torch.Tensor) for out in per_sample)
        stack_dim = out_dims if isinstance(out_dims, int) else out_dims[0]
        return torch.stack(per_sample, dim=stack_dim)

    # Multi-output op: regroup by output position, then stack each group.
    assert all(isinstance(out, tuple) for out in per_sample)
    num_outputs = len(per_sample[0])
    assert all(len(out) == num_outputs for out in per_sample)
    if isinstance(out_dims, int):
        out_dims = (out_dims,) * num_outputs
    stacked = []
    for shards, dim in zip(zip(*per_sample), out_dims):
        stacked.append(torch.stack(shards, dim))
    return tuple(stacked)
|
|
|
|
|
|
class TensorFactory:
    """Named constructors for random test tensors.

    Every factory takes ``size`` plus optional ``device`` and ``dtype``
    (defaulting to ``'cpu'`` and ``torch.float``) and returns a new tensor.
    """

    @staticmethod
    def rand(size, device='cpu', dtype=torch.float):
        """Samples from ``torch.rand`` (uniform on [0, 1))."""
        options = dict(device=device, dtype=dtype)
        return torch.rand(size, **options)

    @staticmethod
    def randn(size, device='cpu', dtype=torch.float):
        """Samples from ``torch.randn`` (standard normal)."""
        options = dict(device=device, dtype=dtype)
        return torch.randn(size, **options)

    @staticmethod
    def randp1(size, device='cpu', dtype=torch.float):
        """``torch.rand`` samples plus one, so every value is at least 1."""
        shifted = torch.rand(size, device=device, dtype=dtype)
        shifted = shifted + 1
        return shifted
|
|
|
|
# Tests vmap(op, in_dims, out_dims)(*inputs) by comparing the output to a
# (slow) sequential map+stack fallback.
#
# check_view: Test if every returned output is a view of the first input
# check_propagates_grad: Test if the operation propagates gradients.
def _vmap_test(self, op, inputs, in_dims=0, out_dims=0,
               check_view=False, check_propagates_grad=True):
    """Compare ``vmap(op, in_dims, out_dims)(*inputs)`` against ``reference_vmap``.

    Args:
        self: the ``TestCase`` whose assertion methods are used.
        op, inputs, in_dims, out_dims: forwarded to ``vmap`` and to
            ``reference_vmap``.
        check_view: if True, assert every output's ``_base`` is the base of
            ``inputs[0]`` (or ``inputs[0]`` itself when it is not a view).
        check_propagates_grad: if True, rerun with a ``requires_grad`` clone
            of ``inputs[0]`` and assert the zeroth output requires grad.
    """
    result = vmap(op, in_dims, out_dims)(*inputs)
    reference_result = reference_vmap(op, inputs, in_dims, out_dims)
    self.assertEqual(result, reference_result)
    op_has_single_return = not isinstance(result, tuple)

    if check_view:
        result_as_tuple = (result,) if op_has_single_return else result
        # Loop-invariant: if inputs[0] is itself a view, outputs are expected
        # to share its base rather than alias inputs[0] directly.
        input0_base = inputs[0] if inputs[0]._base is None else inputs[0]._base
        for output in result_as_tuple:
            self.assertTrue(output._base is input0_base,
                            msg="result was not a view of the first input!")

    if not check_propagates_grad:
        return
    # Assuming input[0] is a floating-point tensor. Check if the vmap
    # operation propagates the requires_grad flag to the zeroth output.
    # Some vmap operators are implemented in a way that assumes that
    # they are composite with respect to autograd. If the operator ever is
    # changed to not be composite with respect to autograd, then the
    # following check should fail.
    inputs_clone = list(inputs)
    inputs_clone[0] = inputs[0].clone().requires_grad_()
    result = vmap(op, in_dims, out_dims)(*inputs_clone)
    result_as_tuple = (result,) if op_has_single_return else result
    # Index the tuple of outputs, not the single output tensor: for a
    # single-return op, result[0] would be the first slice along dim 0.
    self.assertTrue(result_as_tuple[0].requires_grad)
|
|
|
|
def should_allow_vmap_fallback_usage(fn):
    """Report whether ``fn`` was marked by ``allowVmapFallbackUsage``.

    Returns the ``_allow_vmap_fallback_usage`` attribute when ``fn`` has one,
    and ``False`` otherwise.
    """
    try:
        return fn._allow_vmap_fallback_usage
    except AttributeError:
        return False
|
|
|
|
def allowVmapFallbackUsage(fn):
    """Decorator marking a test as permitted to hit the slow vmap fallback.

    Sets ``fn._allow_vmap_fallback_usage = True`` and returns ``fn`` itself
    rather than a wrapper.
    """
    setattr(fn, '_allow_vmap_fallback_usage', True)
    return fn
|
|
|
|
# All tests of TestVmapBase check that the slow vmap fallback is never invoked.
# This is so that we can incrementally add batching rules for operators to
# replace the slow vmap fallback path for said operators. To skip this check,
# please use the allowVmapFallbackUsage decorator.
#
# NB: Don't add tests to TestVmapBase directly, unless you want them to run
# on every subclass of TestVmapBase. Add them to e.g. TestVmapOperators.
#
# NB: TestVmapBase is a nested class. This prevents test runners from picking
# it up and running it.
class Namespace:
    class TestVmapBase(TestCase):
        """Base class whose tests fail if they trigger the slow vmap fallback.

        On construction, the selected test method is replaced by a wrapped
        version unless it was marked with ``allowVmapFallbackUsage``. The
        wrapper fails if any captured warning matches ``FALLBACK_REGEX``.
        """

        def __init__(self, method_name='runTest'):
            super().__init__(method_name)

            test_method = getattr(self, method_name, None)
            if test_method is None:
                return

            if not should_allow_vmap_fallback_usage(test_method):
                setattr(self, method_name,
                        self._wrap_method_with_vmap_fallback_check(test_method))

        def _wrap_method_with_vmap_fallback_check(self, method):
            """Return ``method`` wrapped with a vmap-fallback check, bound to ``self``.

            The wrapper calls ``method`` with its arguments excluding ``self``,
            records all warnings raised meanwhile, and fails the test if any
            of them matches ``FALLBACK_REGEX``. It returns whatever
            ``method`` returns.
            """
            msg = (
                'Expected the test to not invoke the vmap fallback path, i.e., '
                'all of the operators being tested in this test should have batching '
                'rules implemented. If you are intentionally testing something to '
                'do with the fallback path, use allowVmapFallbackUsage. Otherwise, '
                'please make sure that batching rules are implemented for the '
                'operator(s) being tested.'
            )

            @functools.wraps(method)
            def wrapper(self, *args, **kwargs):
                with warnings.catch_warnings(record=True) as wa:
                    # 'always' so repeated warnings are not deduplicated away.
                    warnings.simplefilter('always')
                    result = method(*args, **kwargs)
                    for captured_warning in wa:
                        self.assertNotRegex(str(captured_warning.message), FALLBACK_REGEX, msg)
                return result
            return types.MethodType(wrapper, self)

        @allowVmapFallbackUsage
        def test_vmap_fallback_check_ok(self):
            # One day we'll implement a batching rule for torch.var_mean.
            # When that happens, please change the example to use an
            # operator that doesn't have a batching rule implemented.
            op_using_fallback = torch.var_mean
            vmap(op_using_fallback)(torch.rand(3))

        def test_vmap_fallback_check(self):
            @self._wrap_method_with_vmap_fallback_check
            def no_fallback(self):
                pass

            # One day we'll implement a batching rule for torch.var_mean.
            # When that happens, please change the example to use an
            # operator that doesn't have a batching rule implemented.
            op_using_fallback = torch.var_mean

            @self._wrap_method_with_vmap_fallback_check
            def uses_fallback(self):
                vmap(op_using_fallback)(torch.rand(3))

            no_fallback(self)

            with self.assertRaises(AssertionError):
                uses_fallback(self)
|
|
|
|
|
|
class TestVmapOperators(Namespace.TestVmapBase):
    """Per-operator tests of vmap batching rules.

    Most checks go through the module-level ``_vmap_test`` helper, which is
    defined earlier in this file. Because this class derives from
    ``Namespace.TestVmapBase``, every test also fails if the slow vmap
    fallback is used, unless the test opts in with ``allowVmapFallbackUsage``.
    """

    def _vmap_test(self, *args, **kwargs):
        """Forward to the module-level ``_vmap_test``, passing this TestCase."""
        return _vmap_test(self, *args, **kwargs)

    def _vmap_view_test(self, *args, **kwargs):
        """Run ``_vmap_test`` with ``check_view=True``.

        This also requires the result to be a view of the first input.
        """
        self._vmap_test(*args, **kwargs, check_view=True)

    def _test_unary(self, op, getter, device):
        """Exercise unary ``op`` under single and doubly nested vmap.

        Args:
            op: unary callable under test.
            getter: ``getter(shape, device)`` factory that builds the input.
            device: device string forwarded to ``getter``.
        """
        test = self._vmap_test
        B0, B1 = 7, 11

        # Single vmap, various in_dims / out_dims
        test(op, [getter([B0, 3], device)])
        test(op, [getter([2, 5, B0, 3], device)], in_dims=2)
        test(op, [getter([2, 5, B0, 3], device)], in_dims=2, out_dims=2)

        # Doubly nested vmap
        test(vmap(op), [getter([B0, B1], device)])
        test(vmap(op), [getter([B1, 2, 5, B0, 3], device)], in_dims=2)
        test(vmap(op, in_dims=2), [getter([2, 5, B0, B1, 3], device)],
             in_dims=2, out_dims=2)

    def test_unary_pointwise_ops(self):
        cases = [
            (torch.abs, TensorFactory.randn),
            (torch.acos, TensorFactory.rand),
            (torch.asin, TensorFactory.rand),
            (torch.atan, TensorFactory.rand),
            (torch.ceil, TensorFactory.randn),
            (torch.cos, TensorFactory.rand),
            (torch.cosh, TensorFactory.rand),
            (torch.digamma, TensorFactory.rand),
            (torch.exp, TensorFactory.randn),
            (torch.expm1, TensorFactory.randn),
            (torch.floor, TensorFactory.randn),
            (torch.frac, TensorFactory.randn),
            (torch.lgamma, TensorFactory.rand),
            (torch.log, TensorFactory.randp1),
            (torch.log10, TensorFactory.randp1),
            (torch.log1p, TensorFactory.randp1),
            (torch.log2, TensorFactory.randp1),
            (torch.neg, TensorFactory.randn),
            (torch.reciprocal, TensorFactory.randp1),
            (torch.relu, TensorFactory.randn),
            (torch.round, TensorFactory.randn),
            (torch.rsqrt, TensorFactory.randp1),
            (torch.sigmoid, TensorFactory.randn),
            (torch.sign, TensorFactory.randn),
            (torch.sin, TensorFactory.rand),
            (torch.sinh, TensorFactory.rand),
            (torch.sqrt, TensorFactory.rand),
            (torch.tan, TensorFactory.rand),
            (torch.tanh, TensorFactory.rand),
            (torch.trunc, TensorFactory.randn),
        ]
        for op, getter in cases:
            self._test_unary(op, getter, 'cpu')

    def test_binary_pointwise_ops(self):
        def get_number(getter):
            return getter([]).item()

        def make_case(op, input_getter=TensorFactory.randn):
            return (op, input_getter)

        cases = [
            # Basic arithmetic
            make_case(torch.add),
            make_case(lambda x, y: x + y),
            make_case(torch.sub),
            make_case(lambda x, y: x - y),
            make_case(torch.mul),
            make_case(lambda x, y: x * y),
            make_case(torch.div, input_getter=TensorFactory.randp1),
            make_case(lambda x, y: x / y, input_getter=TensorFactory.randp1),
            make_case(torch.pow, input_getter=TensorFactory.randp1),
            make_case(lambda x, y: x ** y, input_getter=TensorFactory.randp1),
        ]
        test = self._vmap_test

        for op, getter in cases:
            device = 'cpu'
            B0, B1 = 7, 11

            # Single vmap: op(Tensor, Tensor)
            test(op, (getter([B0, 3], device), getter([B0, 3], device)))
            test(op, (getter([B0], device), getter([B0, 2, 3], device)))
            test(op, (getter([B0], device), getter([2, B0, 3], device)), in_dims=(0, 1))
            test(op, (getter([B0], device), getter([2, B0, 3], device)),
                 in_dims=(0, 1), out_dims=1)
            test(op, (getter([B0], device), getter([2, 3], device)), in_dims=(0, None))
            # NOTE(review): here the [2, 3] tensor is the mapped one (batch size 2)
            # and the [B0, 3] tensor is broadcast. in_dims=(None, 0) may have been
            # intended; confirm before changing.
            test(op, (getter([2, 3], device), getter([B0, 3], device)), in_dims=(0, None))

            # Nested vmap: op(Tensor, Tensor)
            test(vmap(op), (getter([B0, B1, 2, 3], device), getter([B0, B1, 3], device)))
            test(vmap(op, in_dims=(None, 0)),
                 (getter([B0, 2, 3], device), getter([B1, 3], device)), in_dims=(0, None))

            # Python number overload: op(Tensor, Number) (and vice-versa)
            number = get_number(getter)
            self._test_unary(lambda t: op(t, number), getter, device)
            number = get_number(getter)
            self._test_unary(lambda t: op(number, t), getter, device)

            # Type promotion: op(Logical Scalar Tensor, Logical Scalar Tensor)
            test(op, (getter([B0], device), getter([B0], device, dtype=torch.double)))
            test(op, (getter([B0], device, dtype=torch.double), getter([B0], device)))
            test(op, (getter([B0], device), getter([B0], device)))

            # Type promotion: op(Tensor, Logical Scalar Tensor) (and vice-versa)
            test(op, (getter([B0, 2], device), getter([B0], device, torch.double)))
            test(op, (getter([B0], device, torch.double), getter([B0, 2], device)))

            if not torch.cuda.is_available():
                continue

            # TODO(rzou): fix the following
            # # Test cross-device scalars
            # number = get_number(getter)
            # self._test_unary(lambda t: op(t, number), getter, device='cuda')
            # self._test_unary(lambda t: op(number, t), getter, device='cuda')
            # self._test_unary(lambda t: op(t, torch.tensor(number)), getter, device='cuda')

    def test_bmm(self):
        op = torch.bmm
        test = self._vmap_test
        B0, B1 = 7, 11

        # shape mismatch
        msg = "Shape mismatch"
        with self.assertRaisesRegex(RuntimeError, msg):
            vmap(op)(torch.randn(B0, 2, 2, 2), torch.randn(B0, 2))
        with self.assertRaisesRegex(RuntimeError, msg):
            vmap(op, in_dims=(0, None))(torch.randn(B0, 3, 3, 2), torch.randn(2, 2))
        with self.assertRaisesRegex(RuntimeError, msg):
            vmap(op, in_dims=(None, 0))(torch.randn(2, 2), torch.randn(B0, 2, 2, 2))

        # left arg is vmapped
        test(op, (torch.rand(B0, 2, 3, 5), torch.rand(2, 5, 3)), in_dims=(0, None))
        test(vmap(op, in_dims=(0, None)), (torch.rand(B1, B0, 2, 3, 5), torch.rand(2, 5, 3)),
             in_dims=(1, None))

        # right arg is vmapped
        test(op, (torch.rand(2, 5, 3), torch.rand(B0, 2, 3, 5)), in_dims=(None, 0))
        test(vmap(op, in_dims=(None, 0)), (torch.rand(2, 5, 3), torch.rand(B1, B0, 2, 3, 5)),
             in_dims=(None, 1))

        # both args are vmapped
        test(op, (torch.rand(B0, 2, 3, 5), torch.rand(B0, 2, 5, 3)))
        test(vmap(op), (torch.rand(B1, B0, 2, 3, 5), torch.rand(B0, B1, 2, 5, 3)), in_dims=(1, 0))
        test(vmap(op, in_dims=(0, None)),
             (torch.rand(B1, 2, 3, 5), torch.rand(B0, 2, 5, 3)), in_dims=(None, 0))

    def test_cat(self):
        test = self._vmap_test
        B0, B1 = 5, 7

        # Quick hack b/c vmap can't accept a list of tensors as an argument
        def get_op(dim):
            def op(*tensors):
                return torch.cat(tensors, dim=dim)
            return op

        test(get_op(0), (torch.rand(B0, 2), torch.rand(B0, 3)))
        test(get_op(0), (torch.rand(2), torch.rand(B0, 3)), in_dims=(None, 0))
        test(get_op(0), (torch.rand(2, 17), torch.rand(3, 17, B0)), in_dims=(None, 2))
        test(get_op(-1), (torch.rand(17, 2), torch.rand(17, 3, B0)), in_dims=(None, 2))
        test(vmap(get_op(0), in_dims=(0, None)),
             (torch.rand(B1, 2), torch.rand(B0, 3)), in_dims=(None, 0))
        test(vmap(get_op(0), in_dims=(0, 0)),
             (torch.rand(B1, 2), torch.rand(B0, B1, 3)), in_dims=(None, 0))

    def test_conj(self):
        op = torch.conj

        def run_test(dtype):
            def get(shape):
                return torch.randn(shape, dtype=dtype)
            B0, B1 = 7, 11
            test = self._vmap_test

            # Single vmap, various in_dims / out_dims
            test(op, [get([B0, 3])])
            test(op, [get([2, 5, B0, 3])], in_dims=2)
            test(op, [get([2, 5, B0, 3])], in_dims=2, out_dims=2)

            # Doubly nested vmap
            test(vmap(op), [get([B0, B1])])
            test(vmap(op), [get([B1, 2, 5, B0, 3])], in_dims=2)
            test(vmap(op, in_dims=2), [get([2, 5, B0, B1, 3])],
                 in_dims=2, out_dims=2)

        # correctness tests
        run_test(torch.float)
        run_test(torch.cfloat)

        # check that torch.conj on a non-complex tensor returns the same tensor
        real_tensor = torch.randn(3)
        result = vmap(op)(real_tensor)
        self.assertEqual(result.data_ptr(), real_tensor.data_ptr())

    def test_chunk(self):
        test = self._vmap_view_test
        op = torch.chunk
        B0, B1, B2 = 7, 11, 13

        # tests for torch.chunk(self, chunks: int, dim)
        test(op, (torch.rand(B0, 2, 1024), 15, -1), in_dims=(0, None, None))
        test(op, (torch.rand(2, B0, 1024), 9, 1), in_dims=(1, None, None))
        test(vmap(op, in_dims=(0, None, None)), (torch.rand(B1, 1023, B0, 5), 4, 0),
             in_dims=(2, None, None))
        test(vmap(vmap(lambda t: op(t, 4, 1), in_dims=2)),
             (torch.rand(B1, 2, B0, 64, B2),), in_dims=2)

    def test_diagonal(self):
        tensor = torch.randn(3, 5, 7, 11, 13)
        test = self._vmap_view_test
        op = torch.diagonal
        test(op, (tensor, 1, 0, 1), in_dims=(0, None, None, None))
        test(op, (tensor, 0, 2, -1), in_dims=(0, None, None, None))
        test(op, (tensor, 2, 1, 2), in_dims=(1, None, None, None))
        test(op, (tensor, 0, -2, -1), in_dims=(1, None, None, None), out_dims=1)
        test(vmap(lambda t: op(t, 0, 0, -1)), (tensor,), in_dims=1, out_dims=1)
        test(vmap(vmap(lambda t: op(t, 0, 0, 1), in_dims=1), in_dims=3),
             (tensor,), in_dims=1, out_dims=1)

    def test_dot(self):
        op = torch.dot
        test = self._vmap_test
        B0, B1 = 7, 11

        # shape mismatch
        msg = "Shape mismatch"
        with self.assertRaisesRegex(RuntimeError, msg):
            vmap(op)(torch.randn(B0, 2, 2, 2), torch.randn(B0, 2))
        with self.assertRaisesRegex(RuntimeError, msg):
            vmap(op, in_dims=(0, None))(torch.randn(B0, 2), torch.randn(2, 2))
        with self.assertRaisesRegex(RuntimeError, msg):
            vmap(op, in_dims=(None, 0))(torch.randn(2, 2), torch.randn(B0, 2))

        # left arg is vmapped
        test(op, (torch.rand(B0, 5), torch.rand(5)), in_dims=(0, None))
        test(vmap(op, in_dims=(0, None)), (torch.rand(B1, B0, 5), torch.rand(5)),
             in_dims=(1, None))

        # right arg is vmapped
        test(op, (torch.rand(5), torch.rand(B0, 5)), in_dims=(None, 0))
        test(vmap(op, in_dims=(None, 0)), (torch.rand(5), torch.rand(B1, B0, 5)),
             in_dims=(None, 1))

        # both args are vmapped
        test(op, (torch.rand(B0, 5), torch.rand(B0, 5)))
        test(vmap(op), (torch.rand(B1, B0, 5), torch.rand(B0, B1, 5)), in_dims=(1, 0))
        test(vmap(op, in_dims=(0, None)),
             (torch.rand(B1, 5), torch.rand(B0, 5)), in_dims=(None, 0))

    def test_expand_as(self):
        op = torch.Tensor.expand_as
        test = self._vmap_view_test
        B0, B1, B2 = 7, 11, 13
        test(op, (torch.rand(B0, 1, 5), torch.rand(B0, 2, 3, 5)))
        test(op, (torch.rand(B0, 1, 5), torch.rand(2, 3, 5)), in_dims=(0, None))
        test(op, (torch.rand(1, 5), torch.rand(B0, 2, 3, 5)), in_dims=(None, 0))
        test(vmap(op), (torch.rand(B0, B1, 1, 5), torch.rand(B0, B1, 2, 3, 5)))
        test(vmap(op), (torch.rand(B0, B1, 1, 5), torch.rand(B1, B0, 2, 3, 5)), in_dims=(0, 1))
        test(vmap(op), (torch.rand(B0, B1), torch.rand(B1, 2, 3, 5)), in_dims=(0, None))
        test(vmap(vmap(op)), (torch.rand(B0, B1, B2), torch.rand(B0, B1, B2, 2, 3, 5)))

    def test_is_complex(self):
        ctensor = torch.randn(3, dtype=torch.cfloat)
        tensor = torch.randn(3)

        def foo(x):
            if x.is_complex():
                return torch.tensor(1)
            else:
                return torch.tensor(0)

        self.assertEqual(vmap(foo)(ctensor), torch.tensor([1, 1, 1]))
        self.assertEqual(vmap(foo)(tensor), torch.tensor([0, 0, 0]))

    def test_movedim(self):
        op = torch.movedim
        test = self._vmap_view_test
        B0, B1, B2 = 7, 11, 13

        # movedim(tensor, int, int) variant
        test(op, (torch.rand(B0, 2, 5), 0, 1), in_dims=(0, None, None))
        test(op, (torch.rand(2, B0, 5), 0, 1), in_dims=(1, None, None))
        test(vmap(op, in_dims=(0, None, None)), (torch.rand(B1, 2, B0, 5), 0, 1), in_dims=(2, None, None))
        test(vmap(vmap(op, in_dims=(2, None, None)), in_dims=(0, None, None)),
             (torch.rand(B1, 2, B0, 5, B2), 0, 1), in_dims=(2, None, None))

        # movedim(tensor, intlist, intlist) variant
        test(op, (torch.rand(B0, 2, 3, 5), [1, 0], [0, 2]), in_dims=(0, None, None))
        test(op, (torch.rand(2, 3, B0, 5), [1, 0], [0, 2]), in_dims=(1, None, None))
        test(vmap(op, in_dims=(0, None, None)),
             (torch.rand(B1, 2, B0, 5), [0, 1], [1, 0]), in_dims=(2, None, None))
        test(vmap(vmap(op, in_dims=(2, None, None)), in_dims=(0, None, None)),
             (torch.rand(B1, 2, B0, 5, B2), [0, 1], [1, 0]), in_dims=(2, None, None))

    def test_mm(self):
        op = torch.mm
        test = self._vmap_test
        B0, B1 = 7, 11

        # shape mismatch
        msg = "Shape mismatch"
        with self.assertRaisesRegex(RuntimeError, msg):
            vmap(op)(torch.randn(B0, 2, 2, 2), torch.randn(B0, 2))
        with self.assertRaisesRegex(RuntimeError, msg):
            vmap(op, in_dims=(0, None))(torch.randn(B0, 2), torch.randn(2, 2))
        with self.assertRaisesRegex(RuntimeError, msg):
            vmap(op, in_dims=(None, 0))(torch.randn(2, 2), torch.randn(B0, 2, 2, 2))

        # left arg is vmapped
        test(op, (torch.rand(B0, 2, 5), torch.rand(5, 2)), in_dims=(0, None))
        test(vmap(op, in_dims=(0, None)), (torch.rand(B1, B0, 2, 5), torch.rand(5, 2)),
             in_dims=(1, None))

        # right arg is vmapped
        test(op, (torch.rand(2, 5), torch.rand(B0, 5, 2)), in_dims=(None, 0))
        test(vmap(op, in_dims=(None, 0)), (torch.rand(2, 5), torch.rand(B1, B0, 5, 2)),
             in_dims=(None, 1))

        # both args are vmapped
        test(op, (torch.rand(B0, 2, 5), torch.rand(B0, 5, 2)))
        test(vmap(op), (torch.rand(B1, B0, 2, 5), torch.rand(B0, B1, 5, 2)), in_dims=(1, 0))
        test(vmap(op, in_dims=(0, None)),
             (torch.rand(B1, 2, 5), torch.rand(B0, 5, 2)), in_dims=(None, 0))

    def test_mv(self):
        op = torch.mv
        test = self._vmap_test
        B0, B1 = 7, 11

        # shape mismatch
        msg = "Shape mismatch"
        with self.assertRaisesRegex(RuntimeError, msg):
            vmap(op)(torch.randn(B0, 2, 2, 2), torch.randn(B0, 2))
        with self.assertRaisesRegex(RuntimeError, msg):
            vmap(op, in_dims=(0, None))(torch.randn(B0, 2, 2), torch.randn(2, 2))
        with self.assertRaisesRegex(RuntimeError, msg):
            vmap(op, in_dims=(None, 0))(torch.randn(2, 2), torch.randn(B0, 2, 2))

        # left arg is vmapped
        test(op, (torch.rand(B0, 2, 5), torch.rand(5)), in_dims=(0, None))
        test(vmap(op, in_dims=(0, None)), (torch.rand(B1, B0, 2, 5), torch.rand(5)),
             in_dims=(1, None))

        # right arg is vmapped
        test(op, (torch.rand(2, 5), torch.rand(B0, 5)), in_dims=(None, 0))
        test(vmap(op, in_dims=(None, 0)), (torch.rand(2, 5), torch.rand(B1, B0, 5)),
             in_dims=(None, 1))

        # both args are vmapped
        test(op, (torch.rand(B0, 2, 5), torch.rand(B0, 5)))
        test(vmap(op), (torch.rand(B1, B0, 2, 5), torch.rand(B0, B1, 5)), in_dims=(1, 0))
        test(vmap(op, in_dims=(0, None)),
             (torch.rand(B1, 2, 5), torch.rand(B0, 5)), in_dims=(None, 0))

    def test_narrow(self):
        op = torch.narrow
        test = self._vmap_view_test
        B0, B1, B2 = 7, 11, 13

        test(op, (torch.rand(B0, 2, 5), -1, 1, 3), in_dims=(0, None, None, None))
        test(op, (torch.rand(2, B0, 5), 1, 1, 3), in_dims=(1, None, None, None))
        test(vmap(op, in_dims=(0, None, None, None)),
             (torch.rand(B1, 2, B0, 5), 1, 0, 0), in_dims=(2, None, None, None))
        test(vmap(vmap(op, in_dims=(2, None, None, None)), in_dims=(0, None, None, None)),
             (torch.rand(B1, 2, B0, 5, B2), -1, 2, 3), in_dims=(2, None, None, None))

    def test_new_empty(self):
        # Empty is non-deterministic so we just check that the shape of the
        # output tensor is what we expect and that the vmap fallback isn't used.
        op = Tensor.new_empty

        B0, B1 = 7, 11

        result = vmap(lambda x: op(x, [2, 3]))(torch.randn(B0))
        self.assertEqual(result.shape, [B0, 2, 3])

        result = vmap(lambda x: op(x, []))(torch.randn(B0))
        self.assertEqual(result.shape, [B0])

        result = vmap(vmap(lambda x: op(x, [2, 3])))(torch.randn(B0, B1))
        self.assertEqual(result.shape, [B0, B1, 2, 3])

    def test_new_zeros(self):
        op = Tensor.new_zeros
        test = functools.partial(self._vmap_test, check_propagates_grad=False)
        B0, B1 = 7, 11

        test(lambda x: op(x, 2, 3), (torch.rand(B0),))
        test(lambda x: op(x, []), (torch.rand(B0),))
        test(vmap(lambda x: op(x, 3, 5)), (torch.rand(B0, B1),))

    def test_select(self):
        op = torch.select
        test = self._vmap_view_test
        B0, B1, B2 = 7, 11, 13
        test(op, (torch.rand(B0, 2, 5), 0, 0), in_dims=(0, None, None))
        test(op, (torch.rand(2, B0, 5), 1, 1), in_dims=(1, None, None))
        test(vmap(lambda t: op(t, 1, 1)), (torch.rand(B1, 2, B0, 5),), in_dims=2)
        test(vmap(vmap(lambda t: op(t, 1, 1), in_dims=1)), (torch.rand(B1, 2, B0, B2, 5),), in_dims=2)

    def test_stack(self):
        test = self._vmap_test
        B0, B1 = 5, 7

        # Quick hack b/c vmap can't accept a list of tensors as an argument
        def get_op(dim):
            def op(*tensors):
                return torch.stack(tensors, dim=dim)
            return op

        test(get_op(0), (torch.rand(B0, 3), torch.rand(B0, 3)))
        test(get_op(0), (torch.rand(3), torch.rand(B0, 3)), in_dims=(None, 0))
        test(get_op(0), (torch.rand(2, 17), torch.rand(2, 17, B0)), in_dims=(None, 2))
        test(get_op(-1), (torch.rand(2, 17), torch.rand(2, 17, B0)), in_dims=(None, 2))
        test(vmap(get_op(0), in_dims=(0, None)),
             (torch.rand(B1, 2), torch.rand(B0, 2)), in_dims=(None, 0))
        test(vmap(get_op(0), in_dims=(0, 0)),
             (torch.rand(B1, 2), torch.rand(B0, B1, 2)), in_dims=(None, 0))

    def test_slice(self):
        test = self._vmap_view_test
        B0, B1, B2 = 7, 11, 13
        test(lambda t: t[0:1], (torch.rand(B0, 3, 5),))
        test(lambda t: t[:, 1:3], (torch.rand(3, 5, B0),), in_dims=2)
        test(vmap(lambda t: t[:, 0:1], in_dims=2), (torch.rand(3, 5, B0, B1),), in_dims=2)
        test(vmap(vmap(lambda t: t[0:1], in_dims=2), in_dims=2),
             (torch.rand(3, 5, B0, B1, B2),), in_dims=2)

    def test_reshape(self):
        test = self._vmap_test
        B0, B1, B2 = 7, 11, 13
        op = torch.reshape
        test(op, (torch.rand(B0, 2 * 5), [2, 5]), in_dims=(0, None), check_view=True)
        test(op, (torch.rand(2, B0, 5), [1, 1, 10]), in_dims=(1, None), check_view=False)
        test(vmap(lambda t: t.reshape([-1])), (torch.rand(B0, B1, 2, 5),), check_view=True)
        test(vmap(vmap(lambda t: t.reshape([-1]), in_dims=2), in_dims=1),
             (torch.rand(3, B1, 2, B2, 5, B0),), in_dims=5, check_view=False)

    def test_reshape_as(self):
        test = self._vmap_test
        B0, B1, B2 = 7, 11, 13
        op = torch.Tensor.reshape_as
        test(op, (torch.rand(B0, 2 * 5), torch.rand(B0, 2, 5)), check_view=True)
        test(op, (torch.rand(2 * 5), torch.rand(B0, 2, 5)), in_dims=(None, 0), check_view=True)
        test(op, (torch.rand(B0, 2 * 5), torch.rand(2, 5)), in_dims=(0, None), check_view=True)

        test(op, (torch.rand(2, B0, 5), torch.rand(1, 1, 10)), in_dims=(1, None), check_view=False)

        test(vmap(op), (torch.rand(B0, B1, 2, 5), torch.randn(B0, B1, 10)), check_view=True)
        test(vmap(vmap(op, in_dims=(2, None)), in_dims=(1, None)),
             (torch.rand(3, B1, 2, B2, 5, B0), torch.rand(B0, 3 * 2 * 5)),
             in_dims=(5, 0), check_view=False)

    def test_result_type(self):
        def scalar_tensor_with_dtype(op):
            def wrapped(*args, **kwargs):
                dtype = op(*args, **kwargs)
                return torch.ones([], dtype=dtype)
            return wrapped

        test = self._vmap_test
        op = scalar_tensor_with_dtype(torch.result_type)

        B0 = 2

        test(op, (torch.randn(B0), torch.randn(B0, dtype=torch.float64)),
             check_propagates_grad=False)
        test(op, (torch.randn(B0), torch.randint(10, [B0], dtype=torch.int64)),
             check_propagates_grad=False)

        test(lambda x: op(x, 1), (torch.randn(B0),), check_propagates_grad=False)
        test(lambda x: op(x, 1.6), (torch.randn(B0),), check_propagates_grad=False)

        test(lambda x: op(x, torch.tensor(1)), (torch.randn(B0),),
             check_propagates_grad=False)
        test(lambda x: op(x, torch.tensor(1.6, dtype=torch.double)),
             (torch.randn(B0),), check_propagates_grad=False)

        test(op, (torch.randn(B0, 2), torch.randn(B0, 2, dtype=torch.float64)),
             check_propagates_grad=False)
        test(op, (torch.randn(B0, 2), torch.randint(10, [B0, 2], dtype=torch.int64)),
             check_propagates_grad=False)

        test(lambda x: op(x, 1), (torch.randn(B0, 2),), check_propagates_grad=False)
        test(lambda x: op(x, 1.6), (torch.randn(B0, 2),), check_propagates_grad=False)

        test(lambda x: op(x, torch.tensor(1)), (torch.randn(B0, 2),),
             check_propagates_grad=False)
        test(lambda x: op(x, torch.tensor(1.6, dtype=torch.double)),
             (torch.randn(B0, 2),), check_propagates_grad=False)

        test(op, (torch.randn(B0, 2), torch.randn(B0, dtype=torch.float64)),
             check_propagates_grad=False)
        test(op, (torch.randn(B0, 2), torch.randint(10, [B0], dtype=torch.int64)),
             check_propagates_grad=False)

    def test_tensor_split(self):
        test = self._vmap_view_test
        op = torch.tensor_split
        B0, B1, B2 = 7, 11, 13

        # tests for torch.tensor_split(self, indices_or_sections: int, dim)
        test(op, (torch.rand(B0, 2, 1024), 5, -1), in_dims=(0, None, None))
        test(op, (torch.rand(2, B0, 1024), 150, 1), in_dims=(1, None, None))
        test(vmap(op, in_dims=(0, None, None)), (torch.rand(B1, 1023, B0, 5), 256, 0),
             in_dims=(2, None, None))
        test(vmap(vmap(lambda t: op(t, 4, 1), in_dims=2)),
             (torch.rand(B1, 2, B0, 64, B2),), in_dims=2)

        # tests for torch.tensor_split(self, indices_or_sections: List[int], dim)
        test(op, (torch.rand(B0, 2, 1024), [50, 100, 378, 890], -1), in_dims=(0, None, None))
        test(op, (torch.rand(2, B0, 1024), [50, 100, 212, 345, 0, 378, 890], 1), in_dims=(1, None, None))
        test(vmap(op, in_dims=(0, None, None)), (torch.rand(B1, 1023, B0, 5), [50, 100, 212, 345, 0, 378, 890], 0),
             in_dims=(2, None, None))
        test(vmap(vmap(lambda t: op(t, [4, 8, 9, 34, 29], 1), in_dims=2)),
             (torch.rand(B1, 2, B0, 64, B2),), in_dims=2)

    def test_split(self):
        test = self._vmap_view_test
        op = torch.split
        B0, B1, B2 = 7, 11, 13

        # tests for torch.split(self, split_size: int, dim)
        test(op, (torch.rand(B0, 2, 1024), 101, -1), in_dims=(0, None, None))
        test(op, (torch.rand(2, B0, 1024), 130, 1), in_dims=(1, None, None))
        test(vmap(op, in_dims=(0, None, None)), (torch.rand(B1, 1023, B0, 5), 256, 0),
             in_dims=(2, None, None))
        test(vmap(vmap(lambda t: op(t, 4, 1), in_dims=2)),
             (torch.rand(B1, 2, B0, 64, B2),), in_dims=2)

        # tests for torch.split(self, split_size: List[int], dim)
        test(op, (torch.rand(B0, 2, 1024), [1, 1020, 3], -1), in_dims=(0, None, None))
        test(op, (torch.rand(2, B0, 1024), [100] * 10 + [24], 1), in_dims=(1, None, None))
        test(vmap(op, in_dims=(0, None, None)), (torch.rand(B1, 1023, B0, 5), [256] * 3 + [255], 0),
             in_dims=(2, None, None))
        test(vmap(vmap(lambda t: op(t, [4] * 8 + [8] * 4, 1), in_dims=2)),
             (torch.rand(B1, 2, B0, 64, B2),), in_dims=2)

    def test_t(self):
        op = torch.t
        test = self._vmap_view_test
        B0, B1, B2 = 7, 11, 13
        test(op, (torch.rand(B0, 2, 5),))
        test(op, (torch.rand(2, B0, 5),), in_dims=1)
        test(vmap(op), (torch.rand(B1, 2, B0, 5),), in_dims=2)
        test(vmap(vmap(op, in_dims=2)), (torch.rand(B1, 2, B0, 5, B2),), in_dims=2)

    def test_T_numpy(self):
        def op(t):
            return t.T

        test = self._vmap_view_test
        B0, B1, B2 = 7, 11, 13
        test(op, (torch.rand(B0, 2, 3, 5),))
        test(op, (torch.rand(B0),))
        test(op, (torch.rand(2, B0, 3, 5),), in_dims=1)
        test(vmap(op), (torch.rand(B1, 2, B0, 5),), in_dims=2)
        test(vmap(op), (torch.rand(B1, 2, B0, 3, 5),), in_dims=2)
        test(vmap(vmap(op, in_dims=2)), (torch.rand(B1, 2, B0, 3, B2, 5),), in_dims=2)

    def test_to(self):
        test = self._vmap_test
        B0, B1 = 7, 11

        test(lambda t: t.to('cpu'), (torch.rand(B0),))
        test(lambda t: t.to(torch.double), (torch.rand(B0),))
        test(lambda t, o: t.to(o), (torch.rand(B0), torch.randn(B0, dtype=torch.float64)))
        test(lambda t, o: t.to(o),
             (torch.rand(B0), torch.randn(B0, dtype=torch.float64)),
             in_dims=(0, None))
        test(vmap(lambda t: t.to(torch.double)), (torch.rand(B0, B1, 3),))

        # also test some casting methods
        test(lambda t: t.double(), (torch.rand(B0),))
        test(lambda t: t.float(), (torch.rand(B0),))
        test(lambda t: t.int(), (torch.rand(B0),), check_propagates_grad=False)
        test(lambda t: t.long(), (torch.rand(B0),), check_propagates_grad=False)

    def test_unfold(self):
        op = torch.Tensor.unfold
        test = self._vmap_view_test
        B0, B1, B2 = 3, 2, 5

        test(op, (torch.rand(B0, 7, 11), 0, 2, 1), in_dims=(0, None, None, None))
        test(op, (torch.rand(7, B0, 11), 1, 4, 2), in_dims=(1, None, None, None))
        test(vmap(op, in_dims=(0, None, None, None)),
             (torch.rand(B1, 7, B0, 11), 1, 5, 1), in_dims=(2, None, None, None))
        test(vmap(vmap(op, in_dims=(2, None, None, None)), in_dims=(0, None, None, None)),
             (torch.rand(B1, 7, B0, 11, B2), -1, 2, 4), in_dims=(2, None, None, None))

    def test_unbind(self):
        test = self._vmap_view_test
        op = torch.unbind
        B0, B1, B2 = 7, 11, 13

        test(op, (torch.rand(B0, 2, 1024), -1), in_dims=(0, None))
        test(op, (torch.rand(B0, 2, 0),))
        test(op, (torch.rand(2, B0, 7), 0), in_dims=(1, None))
        test(vmap(op, in_dims=(0, None)), (torch.rand(B1, 1023, B0, 5), 1),
             in_dims=(2, None))
        test(vmap(vmap(lambda t: op(t, dim=1), in_dims=2)),
             (torch.rand(B1, 2, B0, 32, B2),), in_dims=2)

    def test_view(self):
        test = self._vmap_view_test
        B0, B1, B2 = 7, 11, 13
        op = torch.Tensor.view

        # We should error out if the view would produce an incorrect result
        with self.assertRaises(RuntimeError):
            vmap(op, in_dims=(1, None))(torch.rand(2, B0, 5), [10])

        test(op, (torch.rand(B0, 2 * 5), [2, 5]), in_dims=(0, None))
        test(op, (torch.rand(B0, 4, 5), [1, 2, 1, 10]), in_dims=(0, None))
        test(vmap(lambda t: t.view([-1])), (torch.rand(B0, B1, 2, 5, 3),))
        # NOTE(review): this nested case calls reshape rather than view inside
        # test_view; confirm whether t.view([-1]) was intended.
        test(vmap(vmap(lambda t: t.reshape([-1])), in_dims=1),
             (torch.rand(B2, B0, B1, 3, 2, 5),), in_dims=1)

    def test_view_as(self):
        test = self._vmap_view_test
        B0, B1, B2 = 7, 11, 13
        op = torch.Tensor.view_as

        # We should error out if the view would produce an incorrect result
        with self.assertRaises(RuntimeError):
            vmap(op, in_dims=(1, 0))(torch.rand(2, B0, 5), torch.rand(B0, 10))

        test(op, (torch.rand(B0, 2 * 5), torch.rand(B0, 2, 5)))
        test(op, (torch.rand(2 * 5), torch.rand(B0, 2, 5)), in_dims=(None, 0))
        test(op, (torch.rand(B0, 2 * 5), torch.rand(2, 5)), in_dims=(0, None))

        test(op, (torch.rand(B0, 4, 5), torch.rand(2, 1, 1, 10)), in_dims=(0, None))

        test(vmap(op), (torch.rand(B0, B1, 2, 5), torch.randn(B0, B1, 10)))
        test(vmap(vmap(op, in_dims=(0, None)), in_dims=(0, None)),
             (torch.rand(B1, B2, B0, 3, 2, 5), torch.rand(B0, 3 * 2 * 5)),
             in_dims=(2, 0))

    def test_no_random_op_support(self):
        B0 = 2

        captured = torch.rand(3)

        random_ops = [
            # out-of-place on BatchedTensor
            (torch.bernoulli, (torch.rand(B0, 1),)),
            (lambda t: torch.bernoulli(t, p=0.5), (torch.rand(B0, 1),)),
            (lambda t: torch.multinomial(t, 2), (torch.rand(B0, 3),)),
            (torch.normal, (torch.randn(B0, 1), torch.randn(B0, 1))),
            (lambda t: torch.normal(t, 1.), (torch.randn(B0, 1),)),
            (lambda t: torch.normal(0., t), (torch.randn(B0, 1),)),
            (torch.poisson, (torch.rand(B0, 1),)),
            (torch.rand_like, (torch.rand(B0, 1),)),
            (torch.randn_like, (torch.rand(B0, 1),)),
            (lambda t: torch.randint_like(t, 2), (torch.rand(B0, 1),)),
            (lambda t: torch.randint_like(t, 0, 2), (torch.rand(B0, 1),)),

            # out-of-place on captured tensor
            (lambda t: torch.bernoulli(captured), (torch.rand(B0),)),
            (lambda t: torch.bernoulli(captured, p=0.5), (torch.rand(B0),)),
            (lambda t: torch.multinomial(captured, 2), (torch.rand(B0),)),
            (lambda t: torch.normal(captured, captured), (torch.randn(B0),)),
            (lambda t: torch.normal(captured, 1.), (torch.randn(B0),)),
            (lambda t: torch.normal(0., captured), (torch.randn(B0),)),
            (lambda t: torch.poisson(captured), (torch.rand(B0),)),
            (lambda t: torch.rand_like(captured), (torch.rand(B0),)),
            (lambda t: torch.randn_like(captured), (torch.rand(B0),)),
            (lambda t: torch.randint_like(captured, 2), (torch.rand(B0),)),
            (lambda t: torch.randint_like(captured, 0, 2), (torch.rand(B0),)),

            # in-place on BatchedTensor
            (lambda t: t.bernoulli_(), (torch.randn(B0, 1),)),
            (lambda t: t.cauchy_(), (torch.randn(B0, 1),)),
            (lambda t: t.exponential_(), (torch.randn(B0, 1),)),
            (lambda t: t.geometric_(0.5), (torch.randn(B0, 1),)),
            (lambda t: t.log_normal_(), (torch.randn(B0, 1),)),
            (lambda t: t.normal_(), (torch.randn(B0, 1),)),
            (lambda t: t.random_(), (torch.randn(B0, 1),)),
            (lambda t: t.random_(0, 2), (torch.randn(B0, 1),)),
            (lambda t: t.random_(2), (torch.randn(B0, 1),)),
            (lambda t: t.uniform_(), (torch.randn(B0, 1),)),

            # in-place on captured tensor
            (lambda t: captured.bernoulli_(), (torch.randn(B0),)),
            (lambda t: captured.cauchy_(), (torch.randn(B0),)),
            (lambda t: captured.exponential_(), (torch.randn(B0),)),
            (lambda t: captured.geometric_(0.5), (torch.randn(B0),)),
            (lambda t: captured.log_normal_(), (torch.randn(B0),)),
            (lambda t: captured.normal_(), (torch.randn(B0),)),
            (lambda t: captured.random_(), (torch.randn(B0),)),
            (lambda t: captured.random_(0, 2), (torch.randn(B0),)),
            (lambda t: captured.random_(2), (torch.randn(B0),)),
            (lambda t: captured.uniform_(), (torch.randn(B0),)),

            # factory functions
            (lambda t: torch.rand(1), (torch.randn(B0),)),
            (lambda t: torch.randn(1), (torch.randn(B0),)),
            (lambda t: torch.randint(5, [1]), (torch.randn(B0),)),
            (lambda t: torch.randperm(5), (torch.randn(B0),)),
        ]
        for op, args in random_ops:
            with self.assertRaisesRegex(RuntimeError,
                                        'vmap: We do not yet support calling random operations'):
                vmap(op)(*args)
|
|
|
|
def construct_v(output, batch_size):
    """Build a random batch of cotangent vectors shaped like ``output``.

    Returns a standard-normal tensor of shape ``(batch_size, *output.shape)``
    with the same dtype and device as ``output``.
    """
    batched_shape = (batch_size, *output.shape)
    return torch.randn(batched_shape, dtype=output.dtype, device=output.device)
|
|
|
|
def as_tuple(x):
    """Normalize ``x`` to a tuple.

    A tuple is returned unchanged, a list is converted with ``tuple(...)``,
    and any other value is wrapped as a 1-tuple.
    """
    if isinstance(x, tuple):
        return x
    if isinstance(x, list):
        return tuple(x)
    return (x,)
|
|
|
|
def differentiable(args):
    """Return, as a tuple, the entries of ``args`` that are grad-requiring tensors.

    ``args`` is first normalized with ``as_tuple``, so a single value, a list,
    or a tuple are all accepted. Non-tensor entries and tensors without
    ``requires_grad`` are dropped.
    """
    def _needs_grad(candidate):
        return isinstance(candidate, torch.Tensor) and candidate.requires_grad

    return tuple(filter(_needs_grad, as_tuple(args)))
|
|
|
|
def _get_rand_no_zeros(*args, **kwargs):
|
|
requires_grad = kwargs.get('requires_grad', False)
|
|
kwargs_without_requires_grad = kwargs.copy()
|
|
kwargs_without_requires_grad['requires_grad'] = False
|
|
result = torch.rand(*args, **kwargs_without_requires_grad)
|
|
return result.clamp_min_(0.1).requires_grad_(requires_grad)
|
|
|
|
class TestVmapBatchedGradient(Namespace.TestVmapBase):
    """Tests for gradients computed in a batched fashion under vmap.

    Each ``test_*`` method builds inputs on ``device`` and hands an op to
    ``_batched_grad_test`` and/or ``_batched_grad_grad_test``. Those helpers
    wrap a ``torch.autograd.grad`` call in a function and pass it, together
    with batched vectors, to ``self._vmap_test``.
    """

    def _vmap_test(self, *args, **kwargs):
        # Method shim so the helpers below can call the module-level
        # ``_vmap_test`` as ``self._vmap_test(...)``.
        return _vmap_test(self, *args, **kwargs)

    def _batched_grad_test(self, op, args, kwargs,
                           output_process_fn=lambda x: x, batch_size=3):
        """Test batched vector-Jacobian products of ``op(*args, **kwargs)``.

        The batched result is compared against a sequential map+stack fallback.

        Args:
            op: callable under test.
            args: positional arguments for ``op``. The tensors among them that
                require grad are the inputs being differentiated.
            kwargs: keyword arguments for ``op``.
            output_process_fn: maps the outputs of ``op`` to the part that
                should be differentiated.
            batch_size: size of the batch dimension for the batched grad.
        """
        diff_outputs = differentiable(output_process_fn(op(*args, **kwargs)))
        diff_inputs = differentiable(args)
        cotangents = tuple(construct_v(out, batch_size) for out in diff_outputs)

        def vector_jacobian_product(*vectors):
            return torch.autograd.grad(diff_outputs, diff_inputs, vectors,
                                       retain_graph=True)

        self._vmap_test(vector_jacobian_product, cotangents,
                        check_propagates_grad=False)

    def _batched_grad_grad_test(self, op, args, kwargs,
                                output_process_fn=lambda x: x, batch_size=3):
        """Test batched second-order gradients of ``op(*args, **kwargs)``.

        The batched result is compared against a sequential map+stack fallback.
        Arguments have the same meaning as in ``_batched_grad_test``.

        NB: only the second gradient computation is batched. The first
        gradient is computed once, unbatched, with ``create_graph=True``. One
        use case for this pattern is computing the Hessian of a scalar-valued
        function, which is useful in Bayesian Logistic Regression. A future
        test could also batch the first gradients and then use them to
        compute batched second gradients.
        """
        diff_outputs = differentiable(output_process_fn(op(*args, **kwargs)))
        diff_inputs = differentiable(args)
        # Backpropagating all-ones is the same as summing every output and
        # calling .backward().
        all_ones = tuple(torch.ones_like(out) for out in diff_outputs)
        first_grads = differentiable(
            torch.autograd.grad(diff_outputs, diff_inputs, all_ones,
                                create_graph=True))
        self.assertNotEqual(
            len(first_grads), 0, "None of the first grads depend on the input!")

        cotangents = tuple(construct_v(grad, batch_size) for grad in first_grads)

        def vector_hessian_product(*vectors):
            second_grads = torch.autograd.grad(first_grads, diff_inputs, vectors,
                                               retain_graph=True, allow_unused=True)
            # With allow_unused=True, inputs the first grads do not depend on
            # come back as None; only the real gradients are returned.
            second_grads = tuple(g for g in second_grads if g is not None)
            assert len(second_grads) > 0
            return second_grads

        self._vmap_test(vector_hessian_product, cotangents,
                        check_propagates_grad=False)

    def _test_arithmetic(self, op, device, test_grad_grad=True):
        """Run batched-grad checks for a binary ``op``.

        The operand pairs are tensor/tensor, scalar/tensor and tensor/scalar.
        When ``test_grad_grad`` is set, the batched grad-grad check is also
        run on the tensor/tensor pair.
        """
        lhs = torch.randn(2, 3, requires_grad=True, device=device)
        # Bounded away from zero so it is safe as a divisor (e.g. test_div).
        rhs = _get_rand_no_zeros(2, 3, device=device, requires_grad=True)
        scalar = 3.14
        for operands in ((lhs, rhs), (scalar, rhs), (lhs, scalar)):
            self._batched_grad_test(op, operands, {})

        if test_grad_grad:
            self._batched_grad_grad_test(op, (lhs, rhs), {})

    # add/sub have constant first derivatives, so there is no second-order
    # graph for the grad-grad check to exercise.
    def test_add(self, device):
        for op in (torch.add, lambda x, y: x + y):
            self._test_arithmetic(op, device, test_grad_grad=False)

    def test_sub(self, device):
        for op in (torch.sub, lambda x, y: x - y):
            self._test_arithmetic(op, device, test_grad_grad=False)

    def test_mul(self, device):
        for op in (torch.mul, lambda x, y: x * y):
            self._test_arithmetic(op, device)

    def test_div(self, device):
        for op in (torch.div, lambda x, y: x / y):
            self._test_arithmetic(op, device)

    def test_expand(self, device):
        x = torch.randn(2, 3, device=device, requires_grad=True)
        self._batched_grad_test(lambda t: t.expand(5, 5, 2, 3), (x,), {})

    def test_lgamma(self, device):
        x = torch.randn(2, 3, requires_grad=True, device=device)
        for check in (self._batched_grad_test, self._batched_grad_grad_test):
            check(Tensor.lgamma, (x,), {})

    def test_log(self, device):
        x = _get_rand_no_zeros(2, 3, device=device, requires_grad=True)
        for check in (self._batched_grad_test, self._batched_grad_grad_test):
            check(torch.log, (x,), {})

    def test_logsumexp(self, device):
        x = _get_rand_no_zeros(2, 3, device=device, requires_grad=True)

        def reduce_last_dim(t):
            return torch.logsumexp(t, -1)

        for check in (self._batched_grad_test, self._batched_grad_grad_test):
            check(reduce_last_dim, (x,), {})

    def test_log1p(self, device):
        x = _get_rand_no_zeros(2, 3, device=device, requires_grad=True)
        for check in (self._batched_grad_test, self._batched_grad_grad_test):
            check(torch.log1p, (x,), {})

    def test_permute(self, device):
        x = torch.randn(2, 3, 5, requires_grad=True, device=device)
        self._batched_grad_test(lambda t: t.permute(2, 0, 1), (x,), {})

    def test_reshape(self, device):
        x = torch.randn(2, 3, 5, requires_grad=True, device=device)
        self._batched_grad_test(lambda t: t.reshape([2 * 3, 5]), (x,), {})

    def test_sigmoid(self, device):
        x = torch.randn(2, 3, requires_grad=True, device=device)
        for check in (self._batched_grad_test, self._batched_grad_grad_test):
            check(Tensor.sigmoid, (x,), {})

    def test_stack(self, device):
        x = torch.randn(2, 3, device=device, requires_grad=True)
        y = torch.randn(2, 3, device=device, requires_grad=True)
        self._batched_grad_test(lambda a, b: torch.stack([a, b]), (x, y), {})

    def test_select(self, device):
        x = torch.randn(2, 3, device=device, requires_grad=True)
        selectors = (
            lambda t: t[1],
            lambda t: t.select(1, 2),
            lambda t: t.select(-1, 0),
        )
        for select_fn in selectors:
            self._batched_grad_test(select_fn, (x,), {})

    def test_slice(self, device):
        x = torch.randn(2, 3, 5, device=device, requires_grad=True)
        slicers = (
            lambda t: t[0:1],
            lambda t: t[:, 1:3],
            lambda t: t[..., 1:3],
        )
        for slice_fn in slicers:
            self._batched_grad_test(slice_fn, (x,), {})

    def test_diagonal(self, device):
        cases = (
            ((4, 5), lambda t: t.diagonal(1, 0, 1)),
            ((3, 4, 5), lambda t: t.diagonal(0, -1, -2)),
        )
        for shape, diagonal_fn in cases:
            x = torch.randn(*shape, device=device, requires_grad=True)
            self._batched_grad_test(diagonal_fn, (x,), {})
|
|
|
|
# Instantiate per-device variants of TestVmapBatchedGradient (its test
# methods take a `device` argument) into this module's globals(). When
# running under ROCm, the 'cuda' device type is excluded.
instantiate_device_type_tests(
    TestVmapBatchedGradient, globals(),
    except_for=('cuda' if TEST_WITH_ROCM else None))
|
|
|
|
# Run the test suite only when this file is executed directly, not when it
# is imported.
if __name__ == '__main__':
    run_tests()
|