mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[dynamo][easy] Simple fixes to prepare for nn module guards (#125316)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/125316 Approved by: https://github.com/williamwen42 ghstack dependencies: #125275
This commit is contained in:
parent
0b70026d3b
commit
a13a0a2479
|
|
@ -1961,6 +1961,7 @@ class OptimizedModuleTest(torch._dynamo.test_case.TestCase):
|
|||
self.assertEqual(compiled_func(inp).item(), 16)
|
||||
self.assertRegex(failure_reason, r"^___check_obj_id\(L\['m'\]._forward_hooks")
|
||||
|
||||
@patch.object(torch._dynamo.config, "guard_nn_modules", False)
|
||||
@patch.object(torch._dynamo.config, "skip_nnmodule_hook_guards", True)
|
||||
def test_hooks_skip_guards(self):
|
||||
class TestModule(torch.nn.Module):
|
||||
|
|
|
|||
|
|
@ -1475,7 +1475,9 @@ class GuardBuilder(GuardBuilderBase):
|
|||
self._produce_guard_code(guard, [shape_guard], shape_env=True)
|
||||
|
||||
def TENSOR_MATCH(self, guard: Guard, value=None):
|
||||
if guard.is_nn_module() or match_on_id_for_tensor(guard):
|
||||
if (
|
||||
not torch._dynamo.config.guard_nn_modules and guard.is_nn_module()
|
||||
) or match_on_id_for_tensor(guard):
|
||||
self.ID_MATCH(guard)
|
||||
else:
|
||||
if isinstance(value, TensorWeakRef):
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user