mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Allow setting attribute to NestedUserFunctionVariable (#146505)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/146505 Approved by: https://github.com/zou3519
This commit is contained in:
parent
aae4c0729e
commit
44e6464914
|
|
@ -3247,6 +3247,26 @@ utils_device.CURRENT_DEVICE == None""".split(
|
|||
self.assertEqual(cnts.frame_count, 1)
|
||||
self.assertEqual(cnts.op_count, 9)
|
||||
|
||||
def test_nesteduserfunction_setattr(self):
|
||||
x = 0
|
||||
|
||||
def update(y):
|
||||
def wrapper():
|
||||
x += y
|
||||
|
||||
return wrapper
|
||||
|
||||
@torch.compile(backend="eager", fullgraph=True)
|
||||
def fn(t):
|
||||
w = update(123)
|
||||
w.__wrapped__ = x
|
||||
return t.sin(), w
|
||||
|
||||
t = torch.randn(2)
|
||||
y, w = fn(t)
|
||||
self.assertEqual(y, t.sin())
|
||||
self.assertEqual(w.__wrapped__, x)
|
||||
|
||||
def test_object_setattr(self):
|
||||
@dataclasses.dataclass
|
||||
class A:
|
||||
|
|
|
|||
|
|
@ -1911,6 +1911,7 @@ class BuiltinVariable(VariableTracker):
|
|||
variables.PlacementVariable,
|
||||
variables.NamedTupleVariable,
|
||||
variables.UserDefinedObjectVariable,
|
||||
variables.NestedUserFunctionVariable,
|
||||
variables.ExceptionVariable,
|
||||
),
|
||||
):
|
||||
|
|
|
|||
|
|
@ -64,7 +64,7 @@ from ..utils import (
|
|||
istype,
|
||||
make_cell,
|
||||
)
|
||||
from .base import typestr, ValueMutationNew, VariableTracker
|
||||
from .base import AttributeMutationNew, typestr, ValueMutationNew, VariableTracker
|
||||
from .constant import ConstantVariable
|
||||
|
||||
|
||||
|
|
@ -1013,6 +1013,8 @@ class NestedUserFunctionVariable(BaseUserFunctionVariable):
|
|||
wrapped_fn=None,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
if kwargs.get("mutation_type") is None:
|
||||
kwargs.update(mutation_type=AttributeMutationNew())
|
||||
super().__init__(**kwargs)
|
||||
assert isinstance(fn_name.as_python_constant(), str)
|
||||
assert isinstance(code.as_python_constant(), types.CodeType)
|
||||
|
|
@ -1059,6 +1061,20 @@ class NestedUserFunctionVariable(BaseUserFunctionVariable):
|
|||
func.__annotations__ = annotations
|
||||
return func
|
||||
|
||||
def call_setattr(
|
||||
self,
|
||||
tx: "InstructionTranslator",
|
||||
name_var: VariableTracker,
|
||||
val: VariableTracker,
|
||||
):
|
||||
tx.output.side_effects.store_attr(self, name_var.value, val)
|
||||
return ConstantVariable(None)
|
||||
|
||||
def call_method(self, tx, name, args, kwargs):
|
||||
if name == "__setattr__":
|
||||
return self.call_setattr(tx, *args)
|
||||
return super().call_method(tx, name, args, kwargs)
|
||||
|
||||
def has_closure(self):
|
||||
return self.closure is not None
|
||||
|
||||
|
|
@ -1137,6 +1153,19 @@ class NestedUserFunctionVariable(BaseUserFunctionVariable):
|
|||
codegen.extend_output(create_rot_n(2))
|
||||
codegen.extend_output(create_call_function(1, True))
|
||||
|
||||
# codegen attributes
|
||||
from torch._dynamo.symbolic_convert import InstructionTranslator
|
||||
|
||||
tx = InstructionTranslator.current_tx()
|
||||
if tx.output.side_effects.has_pending_mutation(self):
|
||||
for name, value in tx.output.side_effects.store_attr_mutations[
|
||||
self
|
||||
].items():
|
||||
codegen.dup_top()
|
||||
codegen(value)
|
||||
codegen.extend_output(create_rot_n(2))
|
||||
codegen.store_attr(name)
|
||||
|
||||
|
||||
class SkipFunctionVariable(VariableTracker):
|
||||
_nonvar_fields = {
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user