set_data permits requires_grad=True on integer tensor
Pull Request resolved: https://github.com/pytorch/pytorch/pull/78436
Approved by: https://github.com/albanD, https://github.com/soulitzer
This commit is contained in:
parent e41389f84b
commit efdb4192bc
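In short: assigning through .data routes through VariableHooks::set_data, which previously let an integer-dtype tensor replace the data of a variable with requires_grad=True, leaving autograd holding a non-differentiable dtype. This change adds a TORCH_CHECK that rejects such assignments. A minimal sketch of the behavior the new check enforces:

import torch

a = torch.tensor(1.0, requires_grad=True)
a.data = torch.tensor(2.0)   # fine: float is a differentiable dtype
a.data = torch.tensor(3, dtype=torch.int64)
# RuntimeError: data set to a tensor that requires gradients must be
# floating point or complex dtype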
@@ -164,10 +164,10 @@ void test_move_to_dtype(const std::string& path_to_exported_script_module) {
   torch::jit::Module module =
       torch::jit::load(path_to_exported_script_module);
 
-  module.to(torch::kInt);
+  module.to(torch::kFloat16);
 
   helpers::check_all_parameters(module, [](const torch::Tensor& tensor) {
-    return tensor.dtype() == torch::kInt;
+    return tensor.dtype() == torch::kFloat16;
   });
 
   module.to(torch::kDouble);
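Note on the hunk above: parameters of a loaded torch::jit::Module generally have requires_grad set, and the C++ Module::to path writes the converted tensors back through set_data, so module.to(torch::kInt) would now trip the new check. The test therefore switches to torch::kFloat16, which still exercises dtype movement but stays within differentiable dtypes.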
@@ -4301,6 +4301,14 @@ class TestAutograd(TestCase):
         b.data = a
         self.assertTrue(b_id_saved == id(b))
 
+    def test_set_data_self_requires_grad(self):
+        a = torch.tensor(1.0, requires_grad=True)
+        b = torch.tensor(2.0)
+        c = torch.tensor(3, dtype=torch.int64)
+        a.data = b
+        with self.assertRaisesRegex(RuntimeError, 'must be floating point or complex dtype'):
+            a.data = c
+
     @unittest.skipIf(IS_WINDOWS, "Skipping because doesn't work for windows")
     def test_thread_shutdown(self):
         code = """import torch
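This test closes a loophole rather than adding a brand-new restriction: requesting gradients on an integer tensor has long been rejected at construction time (torch.tensor(3, requires_grad=True) raises, with a message along the lines of "Only Tensors of floating point and complex dtype can require gradients"), and .data assignment now behaves consistently with that.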
@@ -7535,12 +7535,12 @@ tensor([[[1.+1.j, 1.+1.j, 1.+1.j, ..., 1.+1.j, 1.+1.j, 1.+1.j],
         # raise a nice error.
         with self.assertRaisesRegex(
                 RuntimeError,
-                # message includes both Double and Long
-                '(?=.*Double)(?=.*Long)'):
+                # message includes both Double and ComplexFloat
+                '(?=.*Double)(?=.*ComplexFloat)'):
 
             # Calls model with a LongTensor input but DoubleTensor weights
             input = torch.randn(1, 1, 1, 6, dtype=torch.double)
-            weight = torch.zeros(1, 1, 1, 3, dtype=torch.long)
+            weight = torch.zeros(1, 1, 1, 3, dtype=torch.complex64)
             model = torch.nn.Conv2d(1, 1, (1, 3), stride=1, padding=0, bias=False)
             model.weight.data = weight
             out = model(input)
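Note on the hunk above: the test provokes a dtype-mismatch error by pairing double input with mismatched weights. With this change, model.weight.data = weight would itself raise the new set_data error for a long weight, since Conv2d weights require grad, and the mismatch under test would never be reached. Switching the weight to complex64 preserves the intended mismatch while staying within differentiable dtypes, which is also why the expected error regex changes from Long to ComplexFloat.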
@@ -401,6 +401,10 @@ void VariableHooks::set_data(const at::TensorBase & self_base, const at::TensorB
       _has_compatible_shallow_copy_type(self, new_data),
       "Attempted to call `variable.set_data(tensor)`, but `variable` and `tensor` have incompatible tensor type.");
 
+  TORCH_CHECK(
+      !self.requires_grad() || isDifferentiableType(at::typeMetaToScalarType(new_data.dtype())),
+      "data set to a tensor that requires gradients must be floating point or complex dtype");
+
   // Resets gradient accumulator if metadata is out of date
   AutogradMeta* autograd_meta = impl::get_autograd_meta(self);
   if (autograd_meta) {
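For reference, isDifferentiableType accepts (as I understand it) exactly the floating point and complex scalar types. A rough Python rendering of the new condition, with illustrative names that are not part of any public API:

import torch

def set_data_allowed(self_requires_grad: bool, new_dtype: torch.dtype) -> bool:
    # mirrors: !self.requires_grad() || isDifferentiableType(new_data.dtype())
    differentiable = new_dtype.is_floating_point or new_dtype.is_complex
    return (not self_requires_grad) or differentiable

assert set_data_allowed(True, torch.float16)    # half precision is differentiable
assert not set_data_allowed(True, torch.int64)  # integer dtypes are not
assert set_data_allowed(False, torch.int64)     # without requires_grad, any dtype is fine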