set data permits requires_grad=True on integer tensor

Pull Request resolved: https://github.com/pytorch/pytorch/pull/78436

Approved by: https://github.com/albanD, https://github.com/soulitzer
This commit is contained in:
yuguo68 2022-05-31 18:11:31 -07:00 committed by PyTorch MergeBot
parent e41389f84b
commit efdb4192bc
4 changed files with 17 additions and 5 deletions

View File

@ -164,10 +164,10 @@ void test_move_to_dtype(const std::string& path_to_exported_script_module) {
torch::jit::Module module =
torch::jit::load(path_to_exported_script_module);
module.to(torch::kInt);
module.to(torch::kFloat16);
helpers::check_all_parameters(module, [](const torch::Tensor& tensor) {
return tensor.dtype() == torch::kInt;
return tensor.dtype() == torch::kFloat16;
});
module.to(torch::kDouble);

View File

@ -4301,6 +4301,14 @@ class TestAutograd(TestCase):
b.data = a
self.assertTrue(b_id_saved == id(b))
def test_set_data_self_requires_grad(self):
a = torch.tensor(1.0, requires_grad=True)
b = torch.tensor(2.0)
c = torch.tensor(3, dtype=torch.int64)
a.data = b
with self.assertRaisesRegex(RuntimeError, 'must be floating point or complex dtype'):
a.data = c
@unittest.skipIf(IS_WINDOWS, "Skipping because doesn't work for windows")
def test_thread_shutdown(self):
code = """import torch

View File

@ -7535,12 +7535,12 @@ tensor([[[1.+1.j, 1.+1.j, 1.+1.j, ..., 1.+1.j, 1.+1.j, 1.+1.j],
# raise a nice error.
with self.assertRaisesRegex(
RuntimeError,
# message includes both Double and Long
'(?=.*Double)(?=.*Long)'):
# message includes both Double and ComplexFloat
'(?=.*Double)(?=.*ComplexFloat)'):
# Calls model with a LongTensor input but DoubleTensor weights
input = torch.randn(1, 1, 1, 6, dtype=torch.double)
weight = torch.zeros(1, 1, 1, 3, dtype=torch.long)
weight = torch.zeros(1, 1, 1, 3, dtype=torch.complex64)
model = torch.nn.Conv2d(1, 1, (1, 3), stride=1, padding=0, bias=False)
model.weight.data = weight
out = model(input)

View File

@ -401,6 +401,10 @@ void VariableHooks::set_data(const at::TensorBase & self_base, const at::TensorB
_has_compatible_shallow_copy_type(self, new_data),
"Attempted to call `variable.set_data(tensor)`, but `variable` and `tensor` have incompatible tensor type.");
TORCH_CHECK(
!self.requires_grad() || isDifferentiableType(at::typeMetaToScalarType(new_data.dtype())),
"data set to a tensor that requires gradients must be floating point or complex dtype");
// Resets gradient accumulator if metadata is out of date
AutogradMeta* autograd_meta = impl::get_autograd_meta(self);
if (autograd_meta) {