mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[dynamo][easy] Simple fixes to prepare for nn module guards (#125316)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/125316 Approved by: https://github.com/williamwen42 ghstack dependencies: #125275
This commit is contained in:
parent
0b70026d3b
commit
a13a0a2479
|
|
@ -1961,6 +1961,7 @@ class OptimizedModuleTest(torch._dynamo.test_case.TestCase):
|
||||||
self.assertEqual(compiled_func(inp).item(), 16)
|
self.assertEqual(compiled_func(inp).item(), 16)
|
||||||
self.assertRegex(failure_reason, r"^___check_obj_id\(L\['m'\]._forward_hooks")
|
self.assertRegex(failure_reason, r"^___check_obj_id\(L\['m'\]._forward_hooks")
|
||||||
|
|
||||||
|
@patch.object(torch._dynamo.config, "guard_nn_modules", False)
|
||||||
@patch.object(torch._dynamo.config, "skip_nnmodule_hook_guards", True)
|
@patch.object(torch._dynamo.config, "skip_nnmodule_hook_guards", True)
|
||||||
def test_hooks_skip_guards(self):
|
def test_hooks_skip_guards(self):
|
||||||
class TestModule(torch.nn.Module):
|
class TestModule(torch.nn.Module):
|
||||||
|
|
|
||||||
|
|
@ -1475,7 +1475,9 @@ class GuardBuilder(GuardBuilderBase):
|
||||||
self._produce_guard_code(guard, [shape_guard], shape_env=True)
|
self._produce_guard_code(guard, [shape_guard], shape_env=True)
|
||||||
|
|
||||||
def TENSOR_MATCH(self, guard: Guard, value=None):
|
def TENSOR_MATCH(self, guard: Guard, value=None):
|
||||||
if guard.is_nn_module() or match_on_id_for_tensor(guard):
|
if (
|
||||||
|
not torch._dynamo.config.guard_nn_modules and guard.is_nn_module()
|
||||||
|
) or match_on_id_for_tensor(guard):
|
||||||
self.ID_MATCH(guard)
|
self.ID_MATCH(guard)
|
||||||
else:
|
else:
|
||||||
if isinstance(value, TensorWeakRef):
|
if isinstance(value, TensorWeakRef):
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user