mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Revert "[dynamo][guards] Consider tensors as immutable for dict tag matches (#139560)"
This reverts commit b09eb6ed6a.
Reverted https://github.com/pytorch/pytorch/pull/139560 on behalf of https://github.com/anijain2305 due to internal test failures ([comment](https://github.com/pytorch/pytorch/pull/139560#issuecomment-2486344859))
This commit is contained in:
parent
7ced49d2cc
commit
d276688da6
|
|
@ -3166,55 +3166,6 @@ class OptimizedModuleTest(torch._dynamo.test_case.TestCase):
|
|||
res = opt_fn(x)
|
||||
self.assertEqual(ref, res)
|
||||
|
||||
@patch.object(
|
||||
torch._dynamo.config, "skip_tensor_guards_with_matching_dict_tags", False
|
||||
)
|
||||
@patch.object(torch._dynamo.config, "inline_inbuilt_nn_modules", True)
|
||||
def test_param_requires_grad(self):
|
||||
def adjust_model(model):
|
||||
to_freeze = model.num_iter % 2 == 0
|
||||
if to_freeze:
|
||||
for param in model.layer2.parameters():
|
||||
param.requires_grad = False
|
||||
else:
|
||||
for param in model.layer2.parameters():
|
||||
param.requires_grad = True
|
||||
|
||||
class MyModule(torch.nn.Module):
|
||||
def __init__(self, input_size, hidden_size, output_size):
|
||||
super().__init__()
|
||||
|
||||
self.layer1 = torch.nn.Linear(hidden_size, hidden_size)
|
||||
self.layer2 = torch.nn.Linear(hidden_size, hidden_size)
|
||||
|
||||
self.num_iter = 0
|
||||
|
||||
def forward(self, x):
|
||||
x = self.layer2(x + self.layer1.bias)
|
||||
|
||||
self.num_iter += 1
|
||||
return x
|
||||
|
||||
input_size = 1024
|
||||
hidden_size = 1024
|
||||
output_size = 1
|
||||
num_samples = 2048
|
||||
features = torch.randn(num_samples, input_size)
|
||||
|
||||
model = MyModule(input_size, hidden_size, output_size)
|
||||
|
||||
cnt = torch._dynamo.testing.CompileCounter()
|
||||
opt_model = torch.compile(model, backend=cnt, fullgraph=True)
|
||||
|
||||
for _ in range(3):
|
||||
model.zero_grad(True)
|
||||
adjust_model(model)
|
||||
res = opt_model(features)
|
||||
res.sum().backward()
|
||||
|
||||
# Check that we have recompiled twice, which leads to 3 frames
|
||||
self.assertEqual(cnt.frame_count, 3)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
from torch._dynamo.test_case import run_tests
|
||||
|
|
|
|||
|
|
@ -331,10 +331,6 @@ skip_nnmodule_hook_guards = True
|
|||
# notice and lead to incorrect result.
|
||||
skip_no_tensor_aliasing_guards_on_parameters = True
|
||||
|
||||
# Considers a tensor immutable if it is one of the values of a dictionary, and
|
||||
# the dictionary tag is same across invocation calls.
|
||||
skip_tensor_guards_with_matching_dict_tags = True
|
||||
|
||||
# If True, raises exception if TorchDynamo is called with a context manager
|
||||
raise_on_ctx_manager_usage = True
|
||||
|
||||
|
|
|
|||
|
|
@ -887,11 +887,6 @@ std::string get_exception_message() {
|
|||
}
|
||||
|
||||
bool is_immutable_object(py::handle example_value) {
|
||||
static py::object config_module = py::module_::import("torch._dynamo.config");
|
||||
bool is_tensor_immutable =
|
||||
config_module.attr("skip_tensor_guards_with_matching_dict_tags")
|
||||
.cast<bool>();
|
||||
|
||||
if (PyTuple_Check(example_value.ptr())) {
|
||||
// Check that each element is immutable
|
||||
for (Py_ssize_t i = 0; i < PyTuple_Size(example_value.ptr()); ++i) {
|
||||
|
|
@ -902,11 +897,10 @@ bool is_immutable_object(py::handle example_value) {
|
|||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
return PyLong_Check(example_value.ptr()) ||
|
||||
PyFloat_Check(example_value.ptr()) || PyBool_Check(example_value.ptr()) ||
|
||||
PyUnicode_Check(example_value.ptr()) ||
|
||||
(is_tensor_immutable && THPVariable_Check(example_value.ptr()));
|
||||
THPVariable_Check(example_value.ptr());
|
||||
}
|
||||
|
||||
bool is_parameter(py::handle tensor) {
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user