[dynamo] Trace nn.Module __delattr__ (#159969)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/159969
Approved by: https://github.com/atalman, https://github.com/malfet, https://github.com/StrongerXi
Authored by Animesh Jain on 2025-08-06 13:36:02 -07:00; committed by PyTorch MergeBot
parent cb4b29b754
commit 3daef4d128
2 changed files with 61 additions and 4 deletions
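
In practice, the change lets a forward pre-hook delete and re-assign a submodule attribute inside a torch.compile(fullgraph=True) region without a graph break. A condensed sketch of that pattern, adapted from the new test below (illustrative only, not part of the diff):

import torch

class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(10, 5)

    def forward(self, x):
        return self.linear(x)

def swap_weight(module, _inputs):
    # The delete is routed through nn.Module.__delattr__, which Dynamo
    # can now trace instead of falling back.
    w = module.weight
    delattr(module, "weight")
    module.weight = w + 1.0  # re-attached as a plain tensor via __setattr__

m = M()
m.linear.register_forward_pre_hook(swap_weight)
out = torch.compile(m, fullgraph=True)(torch.randn(4, 10))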


@@ -3422,6 +3422,58 @@ class OptimizedModuleTest(torch._dynamo.test_case.TestCase):
        compiled_mod = torch.compile(mod, backend="eager")
        compiled_mod(x)

    def test_trace_delattr(self):
        TMP_PREFIX = "_tmp_"

        def pre_forward_rename_hook(module: torch.nn.Module, _input: torch.Tensor):
            param_name = "weight"
            original_param = getattr(module, param_name)
            setattr(module, TMP_PREFIX + param_name, original_param)
            new_param = original_param + 1.0
            delattr(module, param_name)
            setattr(module, param_name, new_param)

        def post_forward_restore_hook(
            module: torch.nn.Module, _input: torch.Tensor, _output: torch.Tensor
        ):
            param_name = "weight"
            tmp_param_name = TMP_PREFIX + param_name
            original_param = getattr(module, tmp_param_name)
            delattr(module, param_name)
            setattr(module, param_name, original_param)
            delattr(module, tmp_param_name)

        class SimpleModel(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.linear = torch.nn.Linear(10, 5)

            def forward(self, x):
                return self.linear(x)

        torch.manual_seed(0)
        model = SimpleModel()
        model.linear.register_forward_pre_hook(pre_forward_rename_hook)
        model.linear.register_forward_hook(post_forward_restore_hook)

        input_tensor = torch.randn(4, 10)
        eager_output = model(input_tensor)
        assert hasattr(model.linear, "weight")
        assert not hasattr(model.linear, "_tmp_weight")

        torch.manual_seed(0)
        model_to_compile = SimpleModel()
        model_to_compile.linear.register_forward_pre_hook(pre_forward_rename_hook)
        model_to_compile.linear.register_forward_hook(post_forward_restore_hook)

        compiled_model = torch.compile(model_to_compile, fullgraph=True)
        compiled_output = compiled_model(input_tensor)
        assert hasattr(compiled_model.linear, "weight")
        assert not hasattr(compiled_model.linear, "_tmp_weight")

        torch.testing.assert_close(eager_output, compiled_output)
devices = ["cuda", "hpu", "xpu"]
instantiate_device_type_tests(
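
For reference, the delete/re-register dance in the hooks above is plain nn.Module behavior in eager mode: __delattr__ drops the entry from the module's parameter registry and __setattr__ puts a Parameter back, which is exactly the state Dynamo has to track. A small standalone check (not part of the diff):

import torch

lin = torch.nn.Linear(10, 5)
w = lin.weight

delattr(lin, "weight")  # nn.Module.__delattr__ removes it from _parameters
assert "weight" not in lin._parameters

lin.weight = w  # nn.Module.__setattr__ re-registers the Parameter
assert isinstance(lin._parameters["weight"], torch.nn.Parameter)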


@@ -909,7 +909,11 @@ class UnspecializedNNModuleVariable(UserDefinedObjectVariable):
 @functools.cache
 def _nn_module_method_ids():
     # Allow __setattr__ to fall through to base class handler
-    supported = {torch.nn.Module.__setattr__, torch.nn.Module.__init__}
+    supported = {
+        torch.nn.Module.__setattr__,
+        torch.nn.Module.__init__,
+        torch.nn.Module.__delattr__,
+    }
     return {
         id(x.__code__)
         for x in torch.nn.Module.__dict__.values()
@@ -1091,9 +1095,10 @@ class UnspecializedNNModuleVariable(UserDefinedObjectVariable):
             # Handle submodules
             self.is_state_mutated = True
 
-        if method is torch.nn.Module.__setattr__ and isinstance(
-            args[1], variables.DeletedVariable
-        ):
+        if (
+            method is torch.nn.Module.__setattr__
+            and isinstance(args[1], variables.DeletedVariable)
+        ) or method is torch.nn.Module.__delattr__:
             # Trace through __delattr__ to track mutations on the module
             # members like ``_modules``.
             return tx.inline_user_function_return(
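
For orientation, the recognition step extended in the first hunk hinges on code-object identity: a module class that does not override __delattr__ resolves it to nn.Module's implementation, so the id(__code__) comparison in _nn_module_method_ids matches. A rough standalone sketch of that check (hypothetical helper, not Dynamo's API):

import torch

def uses_base_delattr(module: torch.nn.Module) -> bool:
    # True when the class inherits nn.Module.__delattr__ unchanged, i.e. the
    # resolved method shares the base implementation's code object.
    return type(module).__delattr__.__code__ is torch.nn.Module.__delattr__.__code__

assert uses_base_delattr(torch.nn.Linear(2, 2))

The second hunk then sends both the older __setattr__-with-DeletedVariable path and direct __delattr__ calls through the same inlined trace of the base-class method, so deletions of entries such as _parameters or _modules are recorded as module mutations.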