Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/46204

Test Plan: Imported from OSS

Reviewed By: izdeby

Differential Revision: D24259500

Pulled By: bdhirsh

fbshipit-source-id: 223f8a07da4e4121009fc0a8b6760d90eef089b3
parent 0c5cd8c2b9
commit 00c779a92b
@@ -152,7 +152,7 @@ void quantile_impl(
   } else if (dim == self.dim() - 1) {
     sorted = std::get<0>(self.sort());
   } else {
-    sorted = std::get<0>(self.unsqueeze(-1).transpose_(dim, -1).sort());
+    sorted = std::get<0>(self.unsqueeze(-1).transpose(dim, -1).sort());
   }

   // Treat q as a 1D tensor for the following computations
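The one-character change above follows the theme of the whole commit: unsqueeze() returns a view of self, so the in-place transpose_() mutates a view, and once the base is a leaf that requires grad that mutation is now an error. A minimal standalone sketch (ordinary PyTorch code, not the quantile kernel; the shapes are arbitrary):

    import torch

    x = torch.rand(3, 4, requires_grad=True)   # a leaf that requires grad
    v = x.unsqueeze(-1)                        # a view of that leaf

    # v.transpose_(0, -1)   # in-place op on the view: now raises
    #                       # "a view of a leaf Variable that requires grad
    #                       #  is being used in an in-place operation."

    w = v.transpose(0, -1)                     # out-of-place: returns another view
    print(w.shape)                             # torch.Size([1, 4, 3])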
@@ -1068,7 +1068,8 @@ class TestFreezing(JitTestCase):
         inp = torch.ones(1, 8, 32, 32)
         out1 = fmod.forward(inp)
         # FIXME: frozen module mutated from outside (original module).
-        smod.weight[0, 0, 0, 0] += 100.0
+        with torch.no_grad():
+            smod.weight[0, 0, 0, 0] += 100.0
         out2 = fmod.forward(inp)
         out3 = smod(inp)
         self.assertNotEqual(out1, out2)
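The test updates in this commit all follow one pattern: indexing a Parameter yields a view of a leaf that requires grad, so mutating it in place must now happen under torch.no_grad(). A minimal sketch of that pattern with a stand-alone module (the Conv2d here is illustrative, not the scripted/frozen module from the test):

    import torch
    import torch.nn as nn

    conv = nn.Conv2d(8, 8, kernel_size=3)

    # conv.weight[0, 0, 0, 0] += 100.0   # now raises: the indexed element is a
    #                                    # view of the weight Parameter, a leaf
    #                                    # that requires grad

    with torch.no_grad():                  # nothing is recorded by autograd here
        conv.weight[0, 0, 0, 0] += 100.0   # allowed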
@@ -1538,7 +1538,7 @@ class TestTracer(JitTestCase):
                 x[i, :] = torch.zeros(4)
             return x

-        self.checkTrace(foo, (torch.rand(3, 4),))
+        self.checkTrace(foo, (torch.rand(3, 4),), inputs_require_grads=False)

     def test_trace_checker_inplace_on_view(self):
         def foo(x):
@@ -3520,6 +3520,16 @@ class TestAutograd(TestCase):
         test()
         self.assertEqual(dealloc[0], 1)

+    def test_inplace_view_leaf_errors(self):
+        # Issue #21875: Fail faster (when we try to modify the view vs. in backward())
+        x = torch.zeros(1, requires_grad=True)
+        y = x.view_as(x)
+        with self.assertRaisesRegex(RuntimeError,
+                                    "a view of a leaf Variable that "
+                                    "requires grad is being used in "
+                                    "an in-place operation."):
+            y.add_(1)
+
     def test_inplace_view_backward(self):
         # Issue #10532: Make sure that this does not raise RuntimeError.
         net = nn.Sequential(
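The new error is narrow by design: it only fires for in-place ops, in grad mode, on a view whose base is a leaf that requires grad. A small sketch of what is and is not affected (illustrative examples, not part of the test suite):

    import torch

    leaf = torch.zeros(1, requires_grad=True)

    # leaf.view_as(leaf).add_(1)        # RuntimeError: exactly the case the new
    #                                   # test_inplace_view_leaf_errors pins down

    non_leaf = leaf + 0                 # result of an op: not a leaf
    non_leaf.view_as(non_leaf).add_(1)  # still allowed and tracked by autograd

    with torch.no_grad():               # grad mode off: the check is skipped
        leaf.view_as(leaf).add_(1)      # allowed; this is the pattern the
                                        # TestFreezing/TestNN hunks switch to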
@@ -3564,13 +3574,10 @@ class TestAutograd(TestCase):
         s.backward()
         self.assertEqual(s, torch.tensor(1.0))

-        # Issue 23502: Ensure RuntimeError for modification of SavedVariable.
+        # Issue #21875: Fail faster (when we try to modify the view vs. in backward())
         a = torch.rand(10, requires_grad=True).narrow(0, 0, 10)
-        b = a.relu_()
-        c = b.add_(100)
-        del b
         with self.assertRaises(RuntimeError):
-            c.sum().backward(torch.ones(1, requires_grad=True))
+            b = a.relu_()

     def test_mul_out(self):
         a = torch.randn(2, 2, requires_grad=True)
@@ -4421,7 +4428,9 @@ for shape in [(1,), ()]:
                 if fn_id == "view_of_temp":
                     # This will be fixed after the deprecation cycle and the warning becomes
                     # an error.
-                    with self.assertRaisesRegex(RuntimeError, "Jacobian mismatch for output 0"):
+                    with self.assertRaisesRegex(RuntimeError,
+                                                "a view of a leaf Variable that requires grad "
+                                                "is being used in an in-place operation."):
                         gradcheck(fn, (a, b))
                 else:
                     # This works but the custom backward is not called (or called with partial)
@@ -4435,7 +4444,13 @@ for shape in [(1,), ()]:
                 bw_called[0] = 0
                 ga_nz[0] = True  # For the case where the backward is called
-                with warnings.catch_warnings(record=True) as w:
-                    fn(a, b).backward()
+                if inplace and output_is_a_view and fn_id != "one_output":
+                    with self.assertRaisesRegex(RuntimeError,
+                                                "a view of a leaf Variable that requires grad "
+                                                "is being used in an in-place operation."):
+                        fn(a, b).backward()
+                else:
+                    fn(a, b).backward()

                 expected_called = 1
                 expected_ga_nz = True
@@ -1552,7 +1552,8 @@ class TestNN(NNTestCase):
         m = nn.Linear(20, 10).float()
         mw = m.weight[:]
         m.double()
-        mw[0][0] = 5
+        with torch.no_grad():
+            mw[0][0] = 5
         self.assertTrue(mw[0][0].dtype == torch.float)
         self.assertTrue(mw._base[0][0].dtype == torch.double)

@@ -1565,7 +1566,8 @@ class TestNN(NNTestCase):
         m = nn.Linear(20, 10).float()
         mw = m.weight[:]
         m.double()
-        mw[0][0] = 5
+        with torch.no_grad():
+            mw[0][0] = 5
         self.assertTrue(mw[0][0] == mw._base[0][0])

         # Test that if `torch.__future__.get_overwrite_module_params_on_conversion() == True`,
@@ -3815,8 +3817,9 @@ class TestNN(NNTestCase):

         # use sequential to verify nesting
         m = nn.Sequential(CustomState())
-        m[0].param[0] = 10
-        m[0].sub.weight[0, 0] = 555
+        with torch.no_grad():
+            m[0].param[0] = 10
+            m[0].sub.weight[0, 0] = 555
         state_dict = m.state_dict()
         self.assertEqual(state_dict["0.serialized"].item(), 11)
         self.assertIn("0.sub.weight", state_dict)
@@ -11373,7 +11376,8 @@ class TestNNDeviceType(NNTestCase):
         m = nn.Linear(20, 10)
         mw = m.weight[:]
         m.to(device)
-        mw[0][0] = 5
+        with torch.no_grad():
+            mw[0][0] = 5
         self.assertTrue(mw[0][0] == mw._base[0][0])

         # Test that if `torch.__future__.get_overwrite_module_params_on_conversion() == True`,
@@ -47,6 +47,10 @@ inline void check_inplace(const Tensor& tensor) {
       auto diff_view_meta = static_cast<DifferentiableViewMeta*>(impl::get_autograd_meta(var));
       // This can throw or warn
       handle_view_on_rebase(diff_view_meta);
+      if (tensor._base().is_leaf()) {
+        AT_ERROR(
+          "a view of a leaf Variable that requires grad is being used in an in-place operation.");
+      }
     }
     if (var.is_leaf()) {
       AT_ERROR(
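Seen from Python, check_inplace() now covers two cases: in-place on the leaf itself (the pre-existing error at the bottom of the hunk) and in-place on a view of such a leaf (the newly added branch). A hedged sketch of both from user code:

    import torch

    p = torch.ones(3, requires_grad=True)   # a leaf

    # p.add_(1)      # pre-existing error:
    #                # "a leaf Variable that requires grad is being used in an
    #                #  in-place operation."

    # p[:2].mul_(2)  # new error from the branch added above:
    #                # "a view of a leaf Variable that requires grad is being
    #                #  used in an in-place operation."

    with torch.no_grad():
        p.add_(1)    # one supported way to mutate such a tensor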