[dynamo][easy] Simple fixes to prepare for nn module guards (#125316)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/125316
Approved by: https://github.com/williamwen42
ghstack dependencies: #125275
This commit is contained in:
Animesh Jain 2024-05-02 00:22:14 -07:00 committed by PyTorch MergeBot
parent 0b70026d3b
commit a13a0a2479
2 changed files with 4 additions and 1 deletions

View File

@ -1961,6 +1961,7 @@ class OptimizedModuleTest(torch._dynamo.test_case.TestCase):
self.assertEqual(compiled_func(inp).item(), 16) self.assertEqual(compiled_func(inp).item(), 16)
self.assertRegex(failure_reason, r"^___check_obj_id\(L\['m'\]._forward_hooks") self.assertRegex(failure_reason, r"^___check_obj_id\(L\['m'\]._forward_hooks")
@patch.object(torch._dynamo.config, "guard_nn_modules", False)
@patch.object(torch._dynamo.config, "skip_nnmodule_hook_guards", True) @patch.object(torch._dynamo.config, "skip_nnmodule_hook_guards", True)
def test_hooks_skip_guards(self): def test_hooks_skip_guards(self):
class TestModule(torch.nn.Module): class TestModule(torch.nn.Module):

View File

@ -1475,7 +1475,9 @@ class GuardBuilder(GuardBuilderBase):
self._produce_guard_code(guard, [shape_guard], shape_env=True) self._produce_guard_code(guard, [shape_guard], shape_env=True)
def TENSOR_MATCH(self, guard: Guard, value=None): def TENSOR_MATCH(self, guard: Guard, value=None):
if guard.is_nn_module() or match_on_id_for_tensor(guard): if (
not torch._dynamo.config.guard_nn_modules and guard.is_nn_module()
) or match_on_id_for_tensor(guard):
self.ID_MATCH(guard) self.ID_MATCH(guard)
else: else:
if isinstance(value, TensorWeakRef): if isinstance(value, TensorWeakRef):