diff --git a/aten/src/ATen/native/TensorCompare.cpp b/aten/src/ATen/native/TensorCompare.cpp
index cc6552ecbe9..9fa76ad16b6 100644
--- a/aten/src/ATen/native/TensorCompare.cpp
+++ b/aten/src/ATen/native/TensorCompare.cpp
@@ -134,6 +134,7 @@ std::vector<Tensor> where(const Tensor& condition) {
 }
 
 Tensor _s_where_cpu(const Tensor& condition, const Tensor& self, const Tensor& other) {
+  TORCH_CHECK(self.dtype() == other.dtype(), "expected scalar type ", self.dtype(), " but found ", other.dtype());
   Tensor ret = at::empty(self.sizes(), self.options());
   AT_DISPATCH_ALL_TYPES_AND_COMPLEX(ret.scalar_type(), "where_cpu", [&] {
     where_cpu<scalar_t>(ret, condition, self, other);
diff --git a/test/test_nn.py b/test/test_nn.py
index 46c0e216825..9b73cd8ec5c 100644
--- a/test/test_nn.py
+++ b/test/test_nn.py
@@ -6232,6 +6232,38 @@ class TestNN(NNTestCase):
         inp = torch.randn(4, 5, device='cuda', requires_grad=True)
         gradgradcheck(F.pdist, (inp,))
 
+    def test_cosine_embedding_loss_with_diff_type(self):
+        for device in device_():
+            input1 = torch.tensor([[2, 3, 4], [6, 2, 4]], dtype=torch.double, device=device)
+            input2 = torch.tensor([[2, 3, 5], [3, 2, 1]], dtype=torch.double, device=device)
+            target = torch.tensor([1, -1], dtype=torch.int, device=device)
+            expected = torch.nn.functional.cosine_embedding_loss(input1, input2, target)
+            for dt1 in torch.testing.get_all_math_dtypes(device):
+                for dt2 in torch.testing.get_all_math_dtypes(device):
+                    for dt3 in torch.testing.get_all_math_dtypes(device):
+                        # dt3 is used as dtype for target = [1, -1], so let's skip unsigned type
+                        if dt3 == torch.uint8:
+                            continue
+                        input1 = input1.to(dt1)
+                        input2 = input2.to(dt2)
+                        target = target.to(dt3)
+                        result = torch.nn.functional.cosine_embedding_loss(input1, input2, target)
+                        self.assertEqual(result.item(), expected.item(), 0.001)
+
+    def test_kl_div_with_diff_type(self):
+        for device in device_():
+            input = torch.tensor([[2, 3, 5], [3, 2, 1]], dtype=torch.double, device=device)
+            target = torch.tensor([[1, 2, 3], [4, 5, 6]], dtype=torch.double, device=device)
+            expected = torch.nn.functional.kl_div(input, target)
+            for input_dtype in torch.testing.get_all_math_dtypes(device):
+                for target_dtype in [torch.float32, torch.float64, torch.float16]:
+                    if (torch.device(device).type == 'cpu' and target_dtype == torch.float16):
+                        continue
+                    input = input.to(input_dtype)
+                    target = target.to(target_dtype)
+                    result = torch.nn.functional.kl_div(input, target)
+                    self.assertEqual(result.item(), expected.item(), 0.001)
+
     def test_cosine_embedding_loss_no_reduce(self):
         input1 = torch.randn(15, 10, requires_grad=True)
         input2 = torch.randn(15, 10, requires_grad=True)
diff --git a/test/test_torch.py b/test/test_torch.py
index 3047f51f641..d20425af170 100644
--- a/test/test_torch.py
+++ b/test/test_torch.py
@@ -763,6 +763,45 @@ class _TestTorchMixin(object):
         res = torch.where(a > 0)
         self.assertEqual(1, len(res))
 
+    def test_where_tensor(self):
+        def rand_tensor(size, dtype, device):
+            if dtype.is_floating_point:
+                return torch.rand(size=size, dtype=dtype, device=device)
+            elif dtype == torch.uint8:
+                return torch.randint(1, 5, size=size, dtype=dtype, device=device)
+            elif dtype == torch.bool:
+                return torch.randint(0, 1, size=size, dtype=dtype, device=device).bool()
+            else:
+                return torch.randint(-5, 5, size=size, dtype=dtype, device=device)
+
+        def get_tensor(size, dtype, device, contiguous):
+            if not contiguous and len(size) < 2:
+                raise RuntimeError("Unable to generate non contiguous tensor with size < 2")
+            t = rand_tensor(size, dtype, device)
+            if contiguous:
+                return t
+            else:
+                return t.transpose(0, 1)
+
+        height = 5
+        width = 5
+        for device in torch.testing.get_all_device_types():
+            for dt1 in torch.testing.get_all_math_dtypes(device):
+                for dt2 in torch.testing.get_all_math_dtypes(device):
+                    for contiguous in [True, False]:
+                        x1 = get_tensor((height, width), dt1, device, contiguous)
+                        x2 = get_tensor((height, width), dt2, device, contiguous)
+                        if dt1 != dt2:
+                            self.assertRaisesRegex(RuntimeError, "expected scalar type", lambda: torch.where(x1 == 1, x1, x2))
+                        else:
+                            if x1.is_floating_point():
+                                condition = (x1 < 0.5)
+                            else:
+                                condition = (x1 == 1)
+                            expected = condition.to(x1.dtype) * x1 + (~condition).to(x2.dtype) * x2
+                            result = torch.where(condition, x1, x2)
+                            self.assertEqual(expected, result)
+
     def test_all_any_with_dim(self):
         def test(x):
             r1 = x.prod(dim=0, keepdim=False).byte()