Allow setting attribute to NestedUserFunctionVariable (#146505)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/146505
Approved by: https://github.com/zou3519
This commit is contained in:
Guilherme Leobas 2025-03-20 11:51:04 -03:00 committed by PyTorch MergeBot
parent aae4c0729e
commit 44e6464914
3 changed files with 51 additions and 1 deletions

View File

@ -3247,6 +3247,26 @@ utils_device.CURRENT_DEVICE == None""".split(
self.assertEqual(cnts.frame_count, 1)
self.assertEqual(cnts.op_count, 9)
def test_nesteduserfunction_setattr(self):
x = 0
def update(y):
def wrapper():
x += y
return wrapper
@torch.compile(backend="eager", fullgraph=True)
def fn(t):
w = update(123)
w.__wrapped__ = x
return t.sin(), w
t = torch.randn(2)
y, w = fn(t)
self.assertEqual(y, t.sin())
self.assertEqual(w.__wrapped__, x)
def test_object_setattr(self):
@dataclasses.dataclass
class A:

View File

@ -1911,6 +1911,7 @@ class BuiltinVariable(VariableTracker):
variables.PlacementVariable,
variables.NamedTupleVariable,
variables.UserDefinedObjectVariable,
variables.NestedUserFunctionVariable,
variables.ExceptionVariable,
),
):

View File

@ -64,7 +64,7 @@ from ..utils import (
istype,
make_cell,
)
from .base import typestr, ValueMutationNew, VariableTracker
from .base import AttributeMutationNew, typestr, ValueMutationNew, VariableTracker
from .constant import ConstantVariable
@ -1013,6 +1013,8 @@ class NestedUserFunctionVariable(BaseUserFunctionVariable):
wrapped_fn=None,
**kwargs,
) -> None:
if kwargs.get("mutation_type") is None:
kwargs.update(mutation_type=AttributeMutationNew())
super().__init__(**kwargs)
assert isinstance(fn_name.as_python_constant(), str)
assert isinstance(code.as_python_constant(), types.CodeType)
@ -1059,6 +1061,20 @@ class NestedUserFunctionVariable(BaseUserFunctionVariable):
func.__annotations__ = annotations
return func
def call_setattr(
self,
tx: "InstructionTranslator",
name_var: VariableTracker,
val: VariableTracker,
):
tx.output.side_effects.store_attr(self, name_var.value, val)
return ConstantVariable(None)
def call_method(self, tx, name, args, kwargs):
if name == "__setattr__":
return self.call_setattr(tx, *args)
return super().call_method(tx, name, args, kwargs)
def has_closure(self):
return self.closure is not None
@ -1137,6 +1153,19 @@ class NestedUserFunctionVariable(BaseUserFunctionVariable):
codegen.extend_output(create_rot_n(2))
codegen.extend_output(create_call_function(1, True))
# codegen attributes
from torch._dynamo.symbolic_convert import InstructionTranslator
tx = InstructionTranslator.current_tx()
if tx.output.side_effects.has_pending_mutation(self):
for name, value in tx.output.side_effects.store_attr_mutations[
self
].items():
codegen.dup_top()
codegen(value)
codegen.extend_output(create_rot_n(2))
codegen.store_attr(name)
class SkipFunctionVariable(VariableTracker):
_nonvar_fields = {