Revert "[dynamo][guards] Consider tensors as immutable for dict tag matches (#139560)"

This reverts commit b09eb6ed6a.

Reverted https://github.com/pytorch/pytorch/pull/139560 on behalf of https://github.com/anijain2305 due to internal test failures ([comment](https://github.com/pytorch/pytorch/pull/139560#issuecomment-2486344859))
PyTorch MergeBot 2024-11-19 17:37:44 +00:00
parent 7ced49d2cc
commit d276688da6
3 changed files with 1 addition and 60 deletions
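Context (not part of the diff): the test removed below checks recompilation by counting compiled frames with `CompileCounter`. A minimal sketch of that pattern, assuming current `torch.compile` and `torch._dynamo.testing` APIs; the model, sizes, and frame-count comments are illustrative and depend on dynamo config defaults such as `inline_inbuilt_nn_modules`.

```python
# Minimal sketch (not from this commit) of the CompileCounter pattern used by the
# removed test: flip requires_grad on parameters between calls and check whether
# a recompile is recorded.
import torch
import torch._dynamo.testing

model = torch.nn.Linear(8, 8)  # illustrative stand-in for MyModule in the removed test
cnt = torch._dynamo.testing.CompileCounter()
opt_model = torch.compile(model, backend=cnt, fullgraph=True)

x = torch.randn(2, 8)
opt_model(x)  # first call compiles: frame_count == 1

for p in model.parameters():  # mirrors adjust_model in the removed test
    p.requires_grad_(False)

opt_model(x)  # recompiles only if guards notice the requires_grad change
print(cnt.frame_count)
```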


@@ -3166,55 +3166,6 @@ class OptimizedModuleTest(torch._dynamo.test_case.TestCase):
            res = opt_fn(x)
            self.assertEqual(ref, res)

    @patch.object(
        torch._dynamo.config, "skip_tensor_guards_with_matching_dict_tags", False
    )
    @patch.object(torch._dynamo.config, "inline_inbuilt_nn_modules", True)
    def test_param_requires_grad(self):
        def adjust_model(model):
            to_freeze = model.num_iter % 2 == 0
            if to_freeze:
                for param in model.layer2.parameters():
                    param.requires_grad = False
            else:
                for param in model.layer2.parameters():
                    param.requires_grad = True

        class MyModule(torch.nn.Module):
            def __init__(self, input_size, hidden_size, output_size):
                super().__init__()

                self.layer1 = torch.nn.Linear(hidden_size, hidden_size)
                self.layer2 = torch.nn.Linear(hidden_size, hidden_size)

                self.num_iter = 0

            def forward(self, x):
                x = self.layer2(x + self.layer1.bias)

                self.num_iter += 1
                return x

        input_size = 1024
        hidden_size = 1024
        output_size = 1
        num_samples = 2048
        features = torch.randn(num_samples, input_size)

        model = MyModule(input_size, hidden_size, output_size)
        cnt = torch._dynamo.testing.CompileCounter()
        opt_model = torch.compile(model, backend=cnt, fullgraph=True)

        for _ in range(3):
            model.zero_grad(True)
            adjust_model(model)
            res = opt_model(features)
            res.sum().backward()

        # Check that we have recompiled twice, which leads to 3 frames
        self.assertEqual(cnt.frame_count, 3)


if __name__ == "__main__":
    from torch._dynamo.test_case import run_tests


@@ -331,10 +331,6 @@ skip_nnmodule_hook_guards = True
# notice and lead to incorrect result.
skip_no_tensor_aliasing_guards_on_parameters = True
# Considers a tensor immutable if it is one of the values of a dictionary, and
# the dictionary tag is same across invocation calls.
skip_tensor_guards_with_matching_dict_tags = True
# If True, raises exception if TorchDynamo is called with a context manager
raise_on_ctx_manager_usage = True
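Aside (not part of the diff): the removed comment describes the condition under which a tensor was treated as immutable. A rough sketch of that scenario, using a plain dict as a hypothetical stand-in for a module's parameter dict; whether tensor guards are actually skipped depends on this flag and the dynamo build.

```python
# Sketch of the dict-tag scenario from the removed comment: the tensor is a dict
# value and the dict's version tag does not change between calls, so with
# skip_tensor_guards_with_matching_dict_tags=True its per-tensor guards could be skipped.
import torch

weights = {"w": torch.randn(4)}  # hypothetical stand-in for a module's _parameters dict

@torch.compile(backend="eager")
def fn(x):
    return x * weights["w"]

fn(torch.randn(4))
weights["w"].requires_grad_(True)  # in-place tensor change; the dict tag is unchanged
fn(torch.randn(4))  # this change is only re-checked if tensor guards actually run
```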


@@ -887,11 +887,6 @@ std::string get_exception_message() {
}

bool is_immutable_object(py::handle example_value) {
  static py::object config_module = py::module_::import("torch._dynamo.config");
  bool is_tensor_immutable =
      config_module.attr("skip_tensor_guards_with_matching_dict_tags")
          .cast<bool>();

  if (PyTuple_Check(example_value.ptr())) {
    // Check that each element is immutable
    for (Py_ssize_t i = 0; i < PyTuple_Size(example_value.ptr()); ++i) {
@@ -902,11 +897,10 @@ bool is_immutable_object(py::handle example_value) {
    }
    return true;
  }

  return PyLong_Check(example_value.ptr()) ||
      PyFloat_Check(example_value.ptr()) || PyBool_Check(example_value.ptr()) ||
      PyUnicode_Check(example_value.ptr()) ||
      (is_tensor_immutable && THPVariable_Check(example_value.ptr()));
      THPVariable_Check(example_value.ptr());
}

bool is_parameter(py::handle tensor) {