Revert "[dynamo] Mark a vt unspecialized nn module variable source earlier (#154780)"
This reverts commit cc96febb97.
Reverted https://github.com/pytorch/pytorch/pull/154780 on behalf of https://github.com/seemethere because it fails internal testing; see https://fburl.com/diff/b0yuxk4w ([comment](https://github.com/pytorch/pytorch/pull/154780#issuecomment-2940381691))
This commit is contained in:
parent a0f2544502
commit a99a01a677
@@ -138,7 +138,7 @@ hf_Bert_large,pass,0
-hf_BigBird,pass,24
+hf_BigBird,pass,18
@@ -122,7 +122,7 @@ hf_Bert_large,pass,0
-hf_BigBird,pass,24
+hf_BigBird,pass,18
@@ -1299,7 +1299,6 @@ class NNModuleTests(torch._dynamo.test_case.TestCase):
         self.assertTrue(torch._dynamo.testing.same(r, m(i)))
         self.assertEqual(cnt.op_count, 6)

-    @patch.object(torch._dynamo.config, "allow_unspec_int_on_nn_module", True)
     def test_self_mutating1(self):
         m1 = torch.nn.Linear(10, 10)
         m2 = SelfMutatingModule(m1)
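The decorator removed above is plain unittest.mock.patch.object applied to the dynamo config module: the flag is flipped for the duration of one test and restored afterwards. A self-contained sketch of the pattern (the test class and its assertion are illustrative, not from this diff):

```python
from unittest.mock import patch
import unittest

import torch
import torch._dynamo


class DummyConfigTest(unittest.TestCase):  # hypothetical test case
    @patch.object(torch._dynamo.config, "allow_unspec_int_on_nn_module", True)
    def test_flag_is_patched(self):
        # Inside the test body the flag reads True; patch.object restores
        # the previous value once the test exits, even on failure.
        self.assertTrue(torch._dynamo.config.allow_unspec_int_on_nn_module)


if __name__ == "__main__":
    unittest.main()
```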
@@ -7454,13 +7454,14 @@ def forward(self, l_inp_, l_tmp_):
         self.assertExpectedInline(
             backend.graphs[0].code.strip(),
             """\
-def forward(self, s97 : torch.SymInt, L_a_ : torch.Tensor, L_b_ : torch.Tensor):
+def forward(self, s97 : torch.SymInt, L_a_ : torch.Tensor, L_b_ : torch.Tensor, L_self_num : torch.SymInt):
     l_a_ = L_a_
     l_b_ = L_b_
+    l_self_num = L_self_num
     tensor = torch.tensor([True])
     cond_true_0 = self.cond_true_0
     cond_false_0 = self.cond_false_0
-    cond = torch.ops.higher_order.cond(tensor, cond_true_0, cond_false_0, (l_a_, l_b_, s97)); tensor = cond_true_0 = cond_false_0 = l_a_ = l_b_ = s97 = None
+    cond = torch.ops.higher_order.cond(tensor, cond_true_0, cond_false_0, (l_a_, l_b_, l_self_num, s97)); tensor = cond_true_0 = cond_false_0 = l_a_ = l_b_ = l_self_num = s97 = None
     getitem = cond[0]; cond = None
     return (getitem,)""", # noqa: B950
         )
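For orientation, the post-revert expected graph above lifts the module's plain-int attribute as a SymInt input (L_self_num) instead of burning it into the graph as a constant. A hedged repro sketch of that shape; the module name, attribute name, and backend choice are mine, and whether the int is actually lifted depends on dynamo's unspecialization heuristics:

```python
import torch


class CondModule(torch.nn.Module):  # hypothetical module, not the test's
    def __init__(self):
        super().__init__()
        self.num = 4  # plain int attribute on the module

    def forward(self, a, b):
        def true_fn(a, b):
            return a + b + self.num

        def false_fn(a, b):
            return a - b - self.num

        # torch.cond shows up as torch.ops.higher_order.cond in dynamo graphs
        return torch.cond(torch.tensor(True), true_fn, false_fn, (a, b))


m = CondModule()
# dynamic=True encourages symbolic ints/shapes (the s97 and L_self_num above)
opt = torch.compile(m, backend="eager", dynamic=True)
print(opt(torch.randn(3), torch.randn(3)))
```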
@@ -2402,10 +2402,6 @@ def is_int_specialization_case(value, source):
             source.guard_source().is_unspecialized_builtin_nn_module()
             and not config.allow_unspec_int_on_nn_module
         )
-        or (
-            source.guard_source().is_unspecialized_nn_module()
-            and not config.allow_unspec_int_on_nn_module
-        )
         or is_from_defaults(source)
         # TODO: Delete this condition when rollout is done. NB: this
         # condition never evaluates True in open source
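With the clause removed above, only unspecialized *builtin* nn module sources consult allow_unspec_int_on_nn_module here after the revert. As a hedged usage sketch of what the flag trades off (the Counter module is illustrative, and the no-recompile behavior is the flag's stated intent rather than verified output):

```python
import torch
import torch._dynamo


class Counter(torch.nn.Module):  # hypothetical module with a mutable int
    def __init__(self):
        super().__init__()
        self.step = 0  # plain int attribute

    def forward(self, x):
        self.step += 1
        return x * self.step


# With the flag on, ints on nn.Modules may stay unspecialized (dynamic),
# so mutating self.step need not invalidate the compiled graph.
with torch._dynamo.config.patch(allow_unspec_int_on_nn_module=True):
    cm = torch.compile(Counter(), backend="eager")
    for _ in range(3):
        cm(torch.ones(2))
```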
@@ -115,8 +115,6 @@ from ..source import (
     Source,
     SubclassAttrListSource,
     TupleIteratorGetItemSource,
-    UnspecializedBuiltinNNModuleSource,
-    UnspecializedNNModuleSource,
 )
 from ..utils import (
     _extract_tensor_dict,
@@ -436,10 +434,7 @@ class VariableBuilder:
             return cached_vt

         vt = self._wrap(value)
-
-        if vt.source is None:
-            vt.source = self.source
+        vt.source = self.source

         if (
             self._can_lift_attrs_to_inputs(vt)
             and value not in self.tx.output.side_effects
@@ -1719,6 +1714,7 @@ class VariableBuilder:
             value = value.get_base()
             self.source = AttrProxySource(self.source)

+        self.install_guards(GuardBuilder.TYPE_MATCH)
         if torch._dynamo.config.inline_inbuilt_nn_modules:
             freezing = is_parameter_freezing()
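The reinstated install_guards(GuardBuilder.TYPE_MATCH) line requests a guard on the object's exact type: the compiled code stays valid only while the guarded object's class is unchanged. A toy sketch of the idea only; make_type_match_guard is my stand-in, not dynamo's guard machinery:

```python
import torch


def make_type_match_guard(obj):  # hypothetical stand-in for the real guard
    expected = type(obj)

    def guard_fn(candidate):
        # A TYPE_MATCH guard checks the exact class, not isinstance().
        return type(candidate) is expected

    return guard_fn


g = make_type_match_guard(torch.nn.Linear(2, 2))
print(g(torch.nn.Linear(4, 4)))  # True: same class, guard holds
print(g(torch.nn.ReLU()))        # False: in real dynamo this forces a recompile
```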
@@ -1753,23 +1749,12 @@ class VariableBuilder:
                 # this will get cleaned up once compile ends
                 self.tx.output.nn_modules[self.name] = value

-            if (
-                value.__module__.startswith(("torch.nn.modules", "torch.ao."))
-                and not value.__module__.startswith("torch.nn.modules.container")
-            ) or getattr(value.__class__, "_dynamo_marked_static", False):
-                new_source = self.source
-                if config.inline_inbuilt_nn_modules:
-                    # Export corner case - look at test_repros.py test_inlining_cornercase
-                    new_source = UnspecializedBuiltinNNModuleSource(self.source)
-                result = UnspecializedBuiltinNNModuleVariable(value, source=new_source)
-                install_guard(new_source.make_guard(GuardBuilder.TYPE_MATCH))
+            if value.__module__.startswith(("torch.nn.", "torch.ao.")) or getattr(
+                value.__class__, "_dynamo_marked_static", False
+            ):
+                result = UnspecializedBuiltinNNModuleVariable(value, source=self.source)
             else:
-                new_source = self.source
-                if config.inline_inbuilt_nn_modules:
-                    # Export corner case - look at test_repros.py test_inlining_cornercase
-                    new_source = UnspecializedNNModuleSource(self.source)
-                result = UnspecializedNNModuleVariable(value, source=new_source)
-                install_guard(new_source.make_guard(GuardBuilder.TYPE_MATCH))
+                result = UnspecializedNNModuleVariable(value, source=self.source)

             if not SideEffects.cls_supports_mutation_side_effects(type(value)):
                 # don't allow STORE_ATTR mutation with custom __setattr__
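The restored branch classifies modules by provenance: anything defined under torch.nn. or torch.ao., or explicitly marked static, becomes UnspecializedBuiltinNNModuleVariable; everything else becomes the general UnspecializedNNModuleVariable. The predicate, extracted into a standalone helper for illustration (the helper name is mine, not dynamo's):

```python
import torch


def is_builtin_nn_module(value: torch.nn.Module) -> bool:  # hypothetical helper
    # Mirrors the restored condition in the diff above.
    return value.__module__.startswith(
        ("torch.nn.", "torch.ao.")
    ) or getattr(value.__class__, "_dynamo_marked_static", False)


class Custom(torch.nn.Module):
    def forward(self, x):
        return x


print(is_builtin_nn_module(torch.nn.Linear(2, 2)))  # True: torch.nn.modules.linear
print(is_builtin_nn_module(Custom()))               # False: user-defined module
```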
@@ -2142,10 +2127,6 @@ class VariableBuilder:
         )
         proxy.node.meta["grapharg"] = grapharg

-        # TODO - Why do we need to set the source of the np ndarray vt back to
-        # original source. Many tests fails.
-        numpy_ndarray_variable.source = self.source
-
         return numpy_ndarray_variable

     def wrap_symint(
@@ -2658,8 +2658,8 @@ class AutoFunctionalizeHigherOrderVariable(TorchHigherOrderOperatorVariable):

 class FlexAttentionBackwardHighOrderVariable(TorchHigherOrderOperatorVariable):
     def proxy_submod(self, tx, arg):
-        assert isinstance(arg.source.base, DictGetItemSource)
-        submod_name = tx.output.install_subgraph(arg.source.base.index, arg.value)
+        assert isinstance(arg.source, DictGetItemSource)
+        submod_name = tx.output.install_subgraph(arg.source.index, arg.value)
         p_submod = make_attr(tx, submod_name)
         set_example_value(p_submod.node, arg.value)
         return p_submod
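The two assertions above differ only in whether the submodule's source is a DictGetItemSource directly (post-revert) or one wrapping level deeper at .source.base (with #154780 applied). A toy sketch of those two shapes using stand-in classes, not dynamo's real Source types:

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class DictGetItemSource:  # toy stand-in for dynamo's source class
    index: str


@dataclass(frozen=True)
class WrappedSource:  # toy stand-in for a wrapping source layer
    base: DictGetItemSource


direct = DictGetItemSource("score_mod")
wrapped = WrappedSource(direct)

# Post-revert expectation: the source itself is the dict-get-item source.
assert isinstance(direct, DictGetItemSource)
# Pre-revert (#154780) expectation: one unwrap via .base is needed first.
assert isinstance(wrapped.base, DictGetItemSource)
```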
@@ -48,6 +48,7 @@ from ..source import (
     FSDPNNModuleSource,
     GetItemSource,
     NNModuleSource,
+    UnspecializedBuiltinNNModuleSource,
     UnspecializedNNModuleSource,
 )
 from ..utils import (
|
||||||
self.nn_module_stack_source = self.source
|
self.nn_module_stack_source = self.source
|
||||||
|
|
||||||
def _wrap_source(self, attr_source):
|
def _wrap_source(self, attr_source):
|
||||||
# the vt is already wrapped with UnspecializedNNModuleSource
|
if not isinstance(attr_source, UnspecializedNNModuleSource):
|
||||||
|
return UnspecializedNNModuleSource(attr_source)
|
||||||
return attr_source
|
return attr_source
|
||||||
|
|
||||||
def get_nn_module_stack_source(self):
|
def get_nn_module_stack_source(self):
|
||||||
|
|
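The restored _wrap_source makes wrapping idempotent: an attribute source gets tagged with its owning-module kind exactly once, and wrapping an already-wrapped source is a no-op. The same pattern in isolation, with toy stand-ins for the source classes:

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Inner:  # stands in for an arbitrary attribute Source
    name: str


@dataclass(frozen=True)
class Wrapper:  # stands in for UnspecializedNNModuleSource
    base: object


def wrap_source(attr_source):
    # Wrap only if not already wrapped, exactly as _wrap_source does above.
    if not isinstance(attr_source, Wrapper):
        return Wrapper(attr_source)
    return attr_source


s = wrap_source(Inner("self.linear.weight"))
assert wrap_source(s) is s  # the second wrap is a no-op
```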
@@ -1191,7 +1193,8 @@ class UnspecializedBuiltinNNModuleVariable(UnspecializedNNModuleVariable):
     """

     def _wrap_source(self, attr_source):
-        # vt is already wrapped with the UnspecializedBuiltinNNModuleSource
+        if not isinstance(attr_source, UnspecializedBuiltinNNModuleSource):
+            return UnspecializedBuiltinNNModuleSource(attr_source)
         return attr_source