mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[dynamo] Trace nn.Module __delattr__ (#159969)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/159969 Approved by: https://github.com/atalman, https://github.com/malfet, https://github.com/StrongerXi
This commit is contained in:
parent
cb4b29b754
commit
3daef4d128
|
|
@ -3422,6 +3422,58 @@ class OptimizedModuleTest(torch._dynamo.test_case.TestCase):
|
|||
compiled_mod = torch.compile(mod, backend="eager")
|
||||
compiled_mod(x)
|
||||
|
||||
def test_trace_delattr(self):
|
||||
TMP_PREFIX = "_tmp_"
|
||||
|
||||
def pre_forward_rename_hook(module: torch.nn.Module, _input: torch.Tensor):
|
||||
param_name = "weight"
|
||||
original_param = getattr(module, param_name)
|
||||
setattr(module, TMP_PREFIX + param_name, original_param)
|
||||
new_param = original_param + 1.0
|
||||
delattr(module, param_name)
|
||||
setattr(module, param_name, new_param)
|
||||
|
||||
def post_forward_restore_hook(
|
||||
module: torch.nn.Module, _input: torch.Tensor, _output: torch.Tensor
|
||||
):
|
||||
param_name = "weight"
|
||||
tmp_param_name = TMP_PREFIX + param_name
|
||||
original_param = getattr(module, tmp_param_name)
|
||||
delattr(module, param_name)
|
||||
setattr(module, param_name, original_param)
|
||||
delattr(module, tmp_param_name)
|
||||
|
||||
class SimpleModel(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.linear = torch.nn.Linear(10, 5)
|
||||
|
||||
def forward(self, x):
|
||||
return self.linear(x)
|
||||
|
||||
torch.manual_seed(0)
|
||||
model = SimpleModel()
|
||||
|
||||
model.linear.register_forward_pre_hook(pre_forward_rename_hook)
|
||||
model.linear.register_forward_hook(post_forward_restore_hook)
|
||||
|
||||
input_tensor = torch.randn(4, 10)
|
||||
|
||||
eager_output = model(input_tensor)
|
||||
assert hasattr(model.linear, "weight")
|
||||
assert not hasattr(model.linear, "_tmp_weight")
|
||||
|
||||
torch.manual_seed(0)
|
||||
model_to_compile = SimpleModel()
|
||||
model_to_compile.linear.register_forward_pre_hook(pre_forward_rename_hook)
|
||||
model_to_compile.linear.register_forward_hook(post_forward_restore_hook)
|
||||
|
||||
compiled_model = torch.compile(model_to_compile, fullgraph=True)
|
||||
compiled_output = compiled_model(input_tensor)
|
||||
assert hasattr(model.linear, "weight")
|
||||
assert not hasattr(compiled_model.linear, "_tmp_weight")
|
||||
torch.testing.assert_close(eager_output, compiled_output)
|
||||
|
||||
|
||||
devices = ["cuda", "hpu", "xpu"]
|
||||
instantiate_device_type_tests(
|
||||
|
|
|
|||
|
|
@ -909,7 +909,11 @@ class UnspecializedNNModuleVariable(UserDefinedObjectVariable):
|
|||
@functools.cache
|
||||
def _nn_module_method_ids():
|
||||
# Allow __setattr__ to fall through to base class handler
|
||||
supported = {torch.nn.Module.__setattr__, torch.nn.Module.__init__}
|
||||
supported = {
|
||||
torch.nn.Module.__setattr__,
|
||||
torch.nn.Module.__init__,
|
||||
torch.nn.Module.__delattr__,
|
||||
}
|
||||
return {
|
||||
id(x.__code__)
|
||||
for x in torch.nn.Module.__dict__.values()
|
||||
|
|
@ -1091,9 +1095,10 @@ class UnspecializedNNModuleVariable(UserDefinedObjectVariable):
|
|||
# Handle submodules
|
||||
self.is_state_mutated = True
|
||||
|
||||
if method is torch.nn.Module.__setattr__ and isinstance(
|
||||
args[1], variables.DeletedVariable
|
||||
):
|
||||
if (
|
||||
method is torch.nn.Module.__setattr__
|
||||
and isinstance(args[1], variables.DeletedVariable)
|
||||
) or method is torch.nn.Module.__delattr__:
|
||||
# Trace through __delattr__ to track mutations on the module
|
||||
# members like `_modules``.
|
||||
return tx.inline_user_function_return(
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user