diff --git a/test/test_maskedtensor.py b/test/test_maskedtensor.py
index db1ffbc38c1..03c05c7ea6d 100644
--- a/test/test_maskedtensor.py
+++ b/test/test_maskedtensor.py
@@ -236,6 +236,32 @@ class TestBasics(TestCase):
         _compare_mt_t(sparse_mt, data)
         _compare_mt_t(mt.grad, data.grad)
 
+    def test_to_device(self, device):
+        for sample in _generate_sample_data(device=device):
+            data = sample.input
+            mask = sample.kwargs["mask"]
+            mt = masked_tensor(data, mask, requires_grad=True)
+
+            new_device = torch.device("cuda") if device != "cuda" and torch.cuda.is_available() else torch.device("cpu")
+            mt_device = mt.to(new_device)
+
+            self.assertEqual(mt_device.device.type, new_device.type)
+            self.assertEqual(mt_device.get_mask().device.type, new_device.type)
+            self.assertEqual(mt_device.get_data().device.type, new_device.type)
+
+    def test_to_dtype(self, device):
+        for sample in _generate_sample_data(device=device):
+            data = sample.input
+            mask = sample.kwargs["mask"]
+            mt = masked_tensor(data, mask, requires_grad=True)
+
+            new_dtype = torch.float64 if data.dtype == torch.float32 else torch.float32
+            mt_dtype = mt.to(new_dtype)
+
+            self.assertEqual(mt_dtype.dtype, new_dtype)
+            self.assertEqual(mt_dtype.get_mask().dtype, torch.bool)
+            self.assertEqual(mt_dtype.get_data().dtype, new_dtype)
+
     def test_to_dense(self, device):
         samples = _generate_sample_data(
             device=device,
diff --git a/torch/masked/maskedtensor/_ops_refs.py b/torch/masked/maskedtensor/_ops_refs.py
index 719df7eac46..8135f149a1b 100644
--- a/torch/masked/maskedtensor/_ops_refs.py
+++ b/torch/masked/maskedtensor/_ops_refs.py
@@ -351,7 +351,10 @@ def _apply_fn_on_data(func, *args, **kwargs):
 @register_dispatch_func([torch.ops.aten._to_copy])
 def _to_copy(func, *args, **kwargs):
     new_data = func(_get_data(args[0]), *args[1:], **kwargs)
-    return MaskedTensor(new_data, _maybe_get_mask(args[0]))
+    cloned_kwargs = kwargs.copy()
+    cloned_kwargs["dtype"] = torch.bool
+    new_mask = func(_maybe_get_mask(args[0]), *args[1:], **cloned_kwargs)
+    return MaskedTensor(new_data, new_mask)
 
 
 @register_dispatch_func([torch.ops.aten._softmax])
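
Why this matters (an illustrative sketch, not part of the patch): before this change, the aten._to_copy handler converted only the data tensor and reattached the original mask unchanged, so MaskedTensor.to(device) left the mask on the old device and .to(dtype) could leave the two components inconsistent. After the fix, both components go through _to_copy, with the mask's dtype forced back to torch.bool so a dtype cast never corrupts it. Assuming the public torch.masked.masked_tensor constructor and the prototype get_data()/get_mask() accessors:

    import torch
    from torch.masked import masked_tensor

    data = torch.tensor([1.0, 2.0, 3.0])
    mask = torch.tensor([True, False, True])
    mt = masked_tensor(data, mask)

    # A dtype cast converts the data tensor but must keep the mask boolean.
    mt64 = mt.to(torch.float64)
    assert mt64.get_data().dtype == torch.float64
    assert mt64.get_mask().dtype == torch.bool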