Revert "[dynamo] Mark a vt unspecialized nn module variable source earlier (#154780)"

This reverts commit cc96febb97.

Reverted https://github.com/pytorch/pytorch/pull/154780 on behalf of https://github.com/seemethere due to This fails internal testing see, https://fburl.com/diff/b0yuxk4w ([comment](https://github.com/pytorch/pytorch/pull/154780#issuecomment-2940381691))
This commit is contained in:
PyTorch MergeBot 2025-06-04 15:03:34 +00:00
parent a0f2544502
commit a99a01a677
8 changed files with 19 additions and 39 deletions

View File

@ -138,7 +138,7 @@ hf_Bert_large,pass,0
hf_BigBird,pass,24 hf_BigBird,pass,18

1 name accuracy graph_breaks
138
139
140
141
142
143
144

View File

@ -122,7 +122,7 @@ hf_Bert_large,pass,0
hf_BigBird,pass,24 hf_BigBird,pass,18

1 name accuracy graph_breaks
122
123
124
125
126
127
128

View File

@ -1299,7 +1299,6 @@ class NNModuleTests(torch._dynamo.test_case.TestCase):
self.assertTrue(torch._dynamo.testing.same(r, m(i))) self.assertTrue(torch._dynamo.testing.same(r, m(i)))
self.assertEqual(cnt.op_count, 6) self.assertEqual(cnt.op_count, 6)
@patch.object(torch._dynamo.config, "allow_unspec_int_on_nn_module", True)
def test_self_mutating1(self): def test_self_mutating1(self):
m1 = torch.nn.Linear(10, 10) m1 = torch.nn.Linear(10, 10)
m2 = SelfMutatingModule(m1) m2 = SelfMutatingModule(m1)

View File

@ -7454,13 +7454,14 @@ def forward(self, l_inp_, l_tmp_):
self.assertExpectedInline( self.assertExpectedInline(
backend.graphs[0].code.strip(), backend.graphs[0].code.strip(),
"""\ """\
def forward(self, s97 : torch.SymInt, L_a_ : torch.Tensor, L_b_ : torch.Tensor): def forward(self, s97 : torch.SymInt, L_a_ : torch.Tensor, L_b_ : torch.Tensor, L_self_num : torch.SymInt):
l_a_ = L_a_ l_a_ = L_a_
l_b_ = L_b_ l_b_ = L_b_
l_self_num = L_self_num
tensor = torch.tensor([True]) tensor = torch.tensor([True])
cond_true_0 = self.cond_true_0 cond_true_0 = self.cond_true_0
cond_false_0 = self.cond_false_0 cond_false_0 = self.cond_false_0
cond = torch.ops.higher_order.cond(tensor, cond_true_0, cond_false_0, (l_a_, l_b_, s97)); tensor = cond_true_0 = cond_false_0 = l_a_ = l_b_ = s97 = None cond = torch.ops.higher_order.cond(tensor, cond_true_0, cond_false_0, (l_a_, l_b_, l_self_num, s97)); tensor = cond_true_0 = cond_false_0 = l_a_ = l_b_ = l_self_num = s97 = None
getitem = cond[0]; cond = None getitem = cond[0]; cond = None
return (getitem,)""", # noqa: B950 return (getitem,)""", # noqa: B950
) )

View File

@ -2402,10 +2402,6 @@ def is_int_specialization_case(value, source):
source.guard_source().is_unspecialized_builtin_nn_module() source.guard_source().is_unspecialized_builtin_nn_module()
and not config.allow_unspec_int_on_nn_module and not config.allow_unspec_int_on_nn_module
) )
or (
source.guard_source().is_unspecialized_nn_module()
and not config.allow_unspec_int_on_nn_module
)
or is_from_defaults(source) or is_from_defaults(source)
# TODO: Delete this condition when rollout is done. NB: this # TODO: Delete this condition when rollout is done. NB: this
# condition never evaluates True in open source # condition never evaluates True in open source

View File

@ -115,8 +115,6 @@ from ..source import (
Source, Source,
SubclassAttrListSource, SubclassAttrListSource,
TupleIteratorGetItemSource, TupleIteratorGetItemSource,
UnspecializedBuiltinNNModuleSource,
UnspecializedNNModuleSource,
) )
from ..utils import ( from ..utils import (
_extract_tensor_dict, _extract_tensor_dict,
@ -436,10 +434,7 @@ class VariableBuilder:
return cached_vt return cached_vt
vt = self._wrap(value) vt = self._wrap(value)
vt.source = self.source
if vt.source is None:
vt.source = self.source
if ( if (
self._can_lift_attrs_to_inputs(vt) self._can_lift_attrs_to_inputs(vt)
and value not in self.tx.output.side_effects and value not in self.tx.output.side_effects
@ -1719,6 +1714,7 @@ class VariableBuilder:
value = value.get_base() value = value.get_base()
self.source = AttrProxySource(self.source) self.source = AttrProxySource(self.source)
self.install_guards(GuardBuilder.TYPE_MATCH)
if torch._dynamo.config.inline_inbuilt_nn_modules: if torch._dynamo.config.inline_inbuilt_nn_modules:
freezing = is_parameter_freezing() freezing = is_parameter_freezing()
@ -1753,23 +1749,12 @@ class VariableBuilder:
# this will get cleaned up once compile ends # this will get cleaned up once compile ends
self.tx.output.nn_modules[self.name] = value self.tx.output.nn_modules[self.name] = value
if ( if value.__module__.startswith(("torch.nn.", "torch.ao.")) or getattr(
value.__module__.startswith(("torch.nn.modules", "torch.ao.")) value.__class__, "_dynamo_marked_static", False
and not value.__module__.startswith("torch.nn.modules.container") ):
) or getattr(value.__class__, "_dynamo_marked_static", False): result = UnspecializedBuiltinNNModuleVariable(value, source=self.source)
new_source = self.source
if config.inline_inbuilt_nn_modules:
# Export corner case - look at test_repros.py test_inlining_cornercase
new_source = UnspecializedBuiltinNNModuleSource(self.source)
result = UnspecializedBuiltinNNModuleVariable(value, source=new_source)
install_guard(new_source.make_guard(GuardBuilder.TYPE_MATCH))
else: else:
new_source = self.source result = UnspecializedNNModuleVariable(value, source=self.source)
if config.inline_inbuilt_nn_modules:
# Export corner case - look at test_repros.py test_inlining_cornercase
new_source = UnspecializedNNModuleSource(self.source)
result = UnspecializedNNModuleVariable(value, source=new_source)
install_guard(new_source.make_guard(GuardBuilder.TYPE_MATCH))
if not SideEffects.cls_supports_mutation_side_effects(type(value)): if not SideEffects.cls_supports_mutation_side_effects(type(value)):
# don't allow STORE_ATTR mutation with custom __setattr__ # don't allow STORE_ATTR mutation with custom __setattr__
@ -2142,10 +2127,6 @@ class VariableBuilder:
) )
proxy.node.meta["grapharg"] = grapharg proxy.node.meta["grapharg"] = grapharg
# TODO - Why do we need to set the source of the np ndarray vt back to
# original source. Many tests fails.
numpy_ndarray_variable.source = self.source
return numpy_ndarray_variable return numpy_ndarray_variable
def wrap_symint( def wrap_symint(

View File

@ -2658,8 +2658,8 @@ class AutoFunctionalizeHigherOrderVariable(TorchHigherOrderOperatorVariable):
class FlexAttentionBackwardHighOrderVariable(TorchHigherOrderOperatorVariable): class FlexAttentionBackwardHighOrderVariable(TorchHigherOrderOperatorVariable):
def proxy_submod(self, tx, arg): def proxy_submod(self, tx, arg):
assert isinstance(arg.source.base, DictGetItemSource) assert isinstance(arg.source, DictGetItemSource)
submod_name = tx.output.install_subgraph(arg.source.base.index, arg.value) submod_name = tx.output.install_subgraph(arg.source.index, arg.value)
p_submod = make_attr(tx, submod_name) p_submod = make_attr(tx, submod_name)
set_example_value(p_submod.node, arg.value) set_example_value(p_submod.node, arg.value)
return p_submod return p_submod

View File

@ -48,6 +48,7 @@ from ..source import (
FSDPNNModuleSource, FSDPNNModuleSource,
GetItemSource, GetItemSource,
NNModuleSource, NNModuleSource,
UnspecializedBuiltinNNModuleSource,
UnspecializedNNModuleSource, UnspecializedNNModuleSource,
) )
from ..utils import ( from ..utils import (
@ -890,7 +891,8 @@ class UnspecializedNNModuleVariable(UserDefinedObjectVariable):
self.nn_module_stack_source = self.source self.nn_module_stack_source = self.source
def _wrap_source(self, attr_source): def _wrap_source(self, attr_source):
# the vt is already wrapped with UnspecializedNNModuleSource if not isinstance(attr_source, UnspecializedNNModuleSource):
return UnspecializedNNModuleSource(attr_source)
return attr_source return attr_source
def get_nn_module_stack_source(self): def get_nn_module_stack_source(self):
@ -1191,7 +1193,8 @@ class UnspecializedBuiltinNNModuleVariable(UnspecializedNNModuleVariable):
""" """
def _wrap_source(self, attr_source): def _wrap_source(self, attr_source):
# vt is already wrapped with the UnspecializedBuiltinNNModuleSource if not isinstance(attr_source, UnspecializedBuiltinNNModuleSource):
return UnspecializedBuiltinNNModuleSource(attr_source)
return attr_source return attr_source