Added check for torch.where on CPU that both arguments have same dtype (#30662)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/30662

Cherry picked from: https://github.com/pytorch/pytorch/pull/29081

Test Plan: Imported from OSS

Differential Revision: D18782295

Pulled By: nairbv

fbshipit-source-id: 897ab25ddf8819ca34f5e86c5d3f41debb56cb04

Co-authored-by: ifedan
Brian Vaughan 2019-12-03 15:17:02 -08:00 committed by Facebook Github Bot
parent 56dd2836ec
commit a376dd344c
3 changed files with 72 additions and 0 deletions
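For reference, a minimal sketch of the behavior this change introduces for torch.where on CPU (tensor values here are illustrative only, not taken from the PR):

import torch

a = torch.tensor([1.0, -2.0, 3.0])    # float32
b = torch.tensor([10.0, 20.0, 30.0])  # float32

# Same dtype: elementwise selection works as before.
torch.where(a > 0, a, b)               # tensor([ 1., 20.,  3.])

# Different dtypes: with this change the CPU path raises a RuntimeError
# whose message starts with "expected scalar type ... but found ...".
torch.where(a > 0, a, b.double())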


@@ -134,6 +134,7 @@ std::vector<Tensor> where(const Tensor& condition) {
}

Tensor _s_where_cpu(const Tensor& condition, const Tensor& self, const Tensor& other) {
  TORCH_CHECK(self.dtype() == other.dtype(), "expected scalar type ", self.dtype(), " but found ", other.dtype());
  Tensor ret = at::empty(self.sizes(), self.options());
  AT_DISPATCH_ALL_TYPES_AND_COMPLEX(ret.scalar_type(), "where_cpu", [&] {
    where_cpu<scalar_t>(ret, condition, self, other);


@@ -6232,6 +6232,38 @@ class TestNN(NNTestCase):
        inp = torch.randn(4, 5, device='cuda', requires_grad=True)
        gradgradcheck(F.pdist, (inp,))

    def test_cosine_embedding_loss_with_diff_type(self):
        for device in device_():
            input1 = torch.tensor([[2, 3, 4], [6, 2, 4]], dtype=torch.double, device=device)
            input2 = torch.tensor([[2, 3, 5], [3, 2, 1]], dtype=torch.double, device=device)
            target = torch.tensor([1, -1], dtype=torch.int, device=device)
            expected = torch.nn.functional.cosine_embedding_loss(input1, input2, target)
            for dt1 in torch.testing.get_all_math_dtypes(device):
                for dt2 in torch.testing.get_all_math_dtypes(device):
                    for dt3 in torch.testing.get_all_math_dtypes(device):
                        # dt3 is used as dtype for target = [1, -1], so let's skip unsigned type
                        if dt3 == torch.uint8:
                            continue
                        input1 = input1.to(dt1)
                        input2 = input2.to(dt2)
                        target = target.to(dt3)
                        result = torch.nn.functional.cosine_embedding_loss(input1, input2, target)
                        self.assertEqual(result.item(), expected.item(), 0.001)

    def test_kl_div_with_diff_type(self):
        for device in device_():
            input = torch.tensor([[2, 3, 5], [3, 2, 1]], dtype=torch.double, device=device)
            target = torch.tensor([[1, 2, 3], [4, 5, 6]], dtype=torch.double, device=device)
            expected = torch.nn.functional.kl_div(input, target)
            for input_dtype in torch.testing.get_all_math_dtypes(device):
                for target_dtype in [torch.float32, torch.float64, torch.float16]:
                    if (torch.device(device).type == 'cpu' and target_dtype == torch.float16):
                        continue
                    input = input.to(input_dtype)
                    target = target.to(target_dtype)
                    result = torch.nn.functional.kl_div(input, target)
                    self.assertEqual(result.item(), expected.item(), 0.001)

    def test_cosine_embedding_loss_no_reduce(self):
        input1 = torch.randn(15, 10, requires_grad=True)
        input2 = torch.randn(15, 10, requires_grad=True)

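The two TestNN cases above check mixed input dtypes against a double-precision reference. A condensed sketch of a single iteration of those loops (this particular dtype combination is just an example):

import torch
import torch.nn.functional as F

input1 = torch.tensor([[2, 3, 4], [6, 2, 4]], dtype=torch.double)
input2 = torch.tensor([[2, 3, 5], [3, 2, 1]], dtype=torch.double)
target = torch.tensor([1, -1], dtype=torch.int)
expected = F.cosine_embedding_loss(input1, input2, target)

# One (dt1, dt2, dt3) combination from the nested loops: the result should
# still match the double-precision reference within 1e-3.
result = F.cosine_embedding_loss(input1.float(), input2.float(), target.long())
assert abs(result.item() - expected.item()) < 1e-3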

@@ -763,6 +763,45 @@ class _TestTorchMixin(object):
        res = torch.where(a > 0)
        self.assertEqual(1, len(res))

    def test_where_tensor(self):
        def rand_tensor(size, dtype, device):
            if dtype.is_floating_point:
                return torch.rand(size=size, dtype=dtype, device=device)
            elif dtype == torch.uint8:
                return torch.randint(1, 5, size=size, dtype=dtype, device=device)
            elif dtype == torch.bool:
                return torch.randint(0, 1, size=size, dtype=dtype, device=device).bool()
            else:
                return torch.randint(-5, 5, size=size, dtype=dtype, device=device)

        def get_tensor(size, dtype, device, contiguous):
            if not contiguous and len(size) < 2:
                raise RuntimeError("Unable to generate non contiguous tensor with size < 2")
            t = rand_tensor(size, dtype, device)
            if contiguous:
                return t
            else:
                return t.transpose(0, 1)

        height = 5
        width = 5
        for device in torch.testing.get_all_device_types():
            for dt1 in torch.testing.get_all_math_dtypes(device):
                for dt2 in torch.testing.get_all_math_dtypes(device):
                    for contiguous in [True, False]:
                        x1 = get_tensor((height, width), dt1, device, contiguous)
                        x2 = get_tensor((height, width), dt2, device, contiguous)
                        if dt1 != dt2:
                            self.assertRaisesRegex(RuntimeError, "expected scalar type", lambda: torch.where(x1 == 1, x1, x2))
                        else:
                            if x1.is_floating_point():
                                condition = (x1 < 0.5)
                            else:
                                condition = (x1 == 1)
                            expected = condition.to(x1.dtype) * x1 + (~condition).to(x2.dtype) * x2
                            result = torch.where(condition, x1, x2)
                            self.assertEqual(expected, result)

    def test_all_any_with_dim(self):
        def test(x):
            r1 = x.prod(dim=0, keepdim=False).byte()
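As a standalone illustration of the matching-dtype case that test_where_tensor verifies (a sketch with made-up sizes, not the test's helpers): torch.where(condition, x1, x2) picks elements of x1 where the condition holds and elements of x2 elsewhere, which is exactly the masked blend the test uses as its reference value.

import torch

x1 = torch.randint(-5, 5, (5, 5), dtype=torch.int64)
x2 = torch.randint(-5, 5, (5, 5), dtype=torch.int64)
condition = (x1 == 1)

# Reference blend, as in the test above, compared against torch.where.
expected = condition.to(x1.dtype) * x1 + (~condition).to(x2.dtype) * x2
result = torch.where(condition, x1, x2)
assert torch.equal(result, expected)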