mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 00:21:07 +01:00
Added check for torch.where on CPU that both arguments have same dtype (#30662)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/30662 Cherry picked from: https://github.com/pytorch/pytorch/pull/29081 Test Plan: Imported from OSS Differential Revision: D18782295 Pulled By: nairbv fbshipit-source-id: 897ab25ddf8819ca34f5e86c5d3f41debb56cb04 Co-authored-by: ifedan
This commit is contained in:
parent
56dd2836ec
commit
a376dd344c
|
|
@ -134,6 +134,7 @@ std::vector<Tensor> where(const Tensor& condition) {
|
|||
}
|
||||
|
||||
Tensor _s_where_cpu(const Tensor& condition, const Tensor& self, const Tensor& other) {
|
||||
TORCH_CHECK(self.dtype() == other.dtype(), "expected scalar type ", self.dtype(), " but found ", other.dtype());
|
||||
Tensor ret = at::empty(self.sizes(), self.options());
|
||||
AT_DISPATCH_ALL_TYPES_AND_COMPLEX(ret.scalar_type(), "where_cpu", [&] {
|
||||
where_cpu<scalar_t>(ret, condition, self, other);
|
||||
|
|
|
|||
|
|
@ -6232,6 +6232,38 @@ class TestNN(NNTestCase):
|
|||
inp = torch.randn(4, 5, device='cuda', requires_grad=True)
|
||||
gradgradcheck(F.pdist, (inp,))
|
||||
|
||||
def test_cosine_embedding_loss_with_diff_type(self):
|
||||
for device in device_():
|
||||
input1 = torch.tensor([[2, 3, 4], [6, 2, 4]], dtype=torch.double, device=device)
|
||||
input2 = torch.tensor([[2, 3, 5], [3, 2, 1]], dtype=torch.double, device=device)
|
||||
target = torch.tensor([1, -1], dtype=torch.int, device=device)
|
||||
expected = torch.nn.functional.cosine_embedding_loss(input1, input2, target)
|
||||
for dt1 in torch.testing.get_all_math_dtypes(device):
|
||||
for dt2 in torch.testing.get_all_math_dtypes(device):
|
||||
for dt3 in torch.testing.get_all_math_dtypes(device):
|
||||
# dt3 is used as dtype for target = [1, -1], so let's skip unsigned type
|
||||
if dt3 == torch.uint8:
|
||||
continue
|
||||
input1 = input1.to(dt1)
|
||||
input2 = input2.to(dt2)
|
||||
target = target.to(dt3)
|
||||
result = torch.nn.functional.cosine_embedding_loss(input1, input2, target)
|
||||
self.assertEqual(result.item(), expected.item(), 0.001)
|
||||
|
||||
def test_kl_div_with_diff_type(self):
|
||||
for device in device_():
|
||||
input = torch.tensor([[2, 3, 5], [3, 2, 1]], dtype=torch.double, device=device)
|
||||
target = torch.tensor([[1, 2, 3], [4, 5, 6]], dtype=torch.double, device=device)
|
||||
expected = torch.nn.functional.kl_div(input, target)
|
||||
for input_dtype in torch.testing.get_all_math_dtypes(device):
|
||||
for target_dtype in [torch.float32, torch.float64, torch.float16]:
|
||||
if (torch.device(device).type == 'cpu' and target_dtype == torch.float16):
|
||||
continue
|
||||
input = input.to(input_dtype)
|
||||
target = target.to(target_dtype)
|
||||
result = torch.nn.functional.kl_div(input, target)
|
||||
self.assertEqual(result.item(), expected.item(), 0.001)
|
||||
|
||||
def test_cosine_embedding_loss_no_reduce(self):
|
||||
input1 = torch.randn(15, 10, requires_grad=True)
|
||||
input2 = torch.randn(15, 10, requires_grad=True)
|
||||
|
|
|
|||
|
|
@ -763,6 +763,45 @@ class _TestTorchMixin(object):
|
|||
res = torch.where(a > 0)
|
||||
self.assertEqual(1, len(res))
|
||||
|
||||
def test_where_tensor(self):
|
||||
def rand_tensor(size, dtype, device):
|
||||
if dtype.is_floating_point:
|
||||
return torch.rand(size=size, dtype=dtype, device=device)
|
||||
elif dtype == torch.uint8:
|
||||
return torch.randint(1, 5, size=size, dtype=dtype, device=device)
|
||||
elif dtype == torch.bool:
|
||||
return torch.randint(0, 1, size=size, dtype=dtype, device=device).bool()
|
||||
else:
|
||||
return torch.randint(-5, 5, size=size, dtype=dtype, device=device)
|
||||
|
||||
def get_tensor(size, dtype, device, contiguous):
|
||||
if not contiguous and len(size) < 2:
|
||||
raise RuntimeError("Unable to generate non contiguous tensor with size < 2")
|
||||
t = rand_tensor(size, dtype, device)
|
||||
if contiguous:
|
||||
return t
|
||||
else:
|
||||
return t.transpose(0, 1)
|
||||
|
||||
height = 5
|
||||
width = 5
|
||||
for device in torch.testing.get_all_device_types():
|
||||
for dt1 in torch.testing.get_all_math_dtypes(device):
|
||||
for dt2 in torch.testing.get_all_math_dtypes(device):
|
||||
for contiguous in [True, False]:
|
||||
x1 = get_tensor((height, width), dt1, device, contiguous)
|
||||
x2 = get_tensor((height, width), dt2, device, contiguous)
|
||||
if dt1 != dt2:
|
||||
self.assertRaisesRegex(RuntimeError, "expected scalar type", lambda: torch.where(x1 == 1, x1, x2))
|
||||
else:
|
||||
if x1.is_floating_point():
|
||||
condition = (x1 < 0.5)
|
||||
else:
|
||||
condition = (x1 == 1)
|
||||
expected = condition.to(x1.dtype) * x1 + (~condition).to(x2.dtype) * x2
|
||||
result = torch.where(condition, x1, x2)
|
||||
self.assertEqual(expected, result)
|
||||
|
||||
def test_all_any_with_dim(self):
|
||||
def test(x):
|
||||
r1 = x.prod(dim=0, keepdim=False).byte()
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user