Mirror of https://github.com/zebrajr/pytorch.git
[fbcode-testing][dynamo][reland][inline-inbuilt-nn-modules] Mark attri… (#134136)
Shuai wants to test this internally before https://github.com/pytorch/pytorch/pull/133713 can go in. Creating a separate PR for ghimport.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/134136
Approved by: https://github.com/yanboliang
This commit is contained in:
parent
8f7d66f0c3
commit
fee677eeb6
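
For context, the sketch below is a hypothetical, minimal way to observe the behavior this PR exercises. It paraphrases the test added in the first hunk of the diff; the module, buffer, and backend names are illustrative, and it assumes a PyTorch build that has the torch._dynamo.config.inline_inbuilt_nn_modules flag.

import torch


class MyMod(torch.nn.Module):  # illustrative module, mirrors the added test
    def __init__(self) -> None:
        super().__init__()
        self.buf = torch.ones(2, 2)

    def forward(self, x):
        return self.buf * x


def inspect_backend(gm, example_inputs):
    # Print the static-input tag that the added test asserts on ("unguarded").
    for n in gm.graph.nodes:
        if n.op == "placeholder":
            tag = n.meta.get("tensor_dict", {}).get("_dynamo_static_input_type")
            print(n.name, tag)
    return gm


mod = MyMod()
with torch._dynamo.config.patch("inline_inbuilt_nn_modules", True):
    opt_fn = torch.compile(mod, backend=inspect_backend)
    opt_fn(torch.ones(2))
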
@@ -2771,6 +2771,49 @@ class OptimizedModuleTest(torch._dynamo.test_case.TestCase):
         fn(inp, buf, mod)
         self.assertEqual(num_compiles, 1)

+    @torch._dynamo.config.patch("inline_inbuilt_nn_modules", True)
+    def test_mark_static_nn_module_tensor(self):
+        # This test verifies that dynamo will mark
+        # the nn module tensor attributes as static
+        num_compiles = 0
+
+        def debug_compiler(gm, _):
+            nonlocal num_compiles
+            num_compiles += 1
+
+            input_nodes = [
+                n
+                for n in gm.graph.nodes
+                if n.op == "placeholder" and n.name == "l_mod_buf"
+            ]
+
+            self.assertGreater(len(input_nodes), 0)
+            for input_node in input_nodes:
+                self.assertEqual(
+                    input_node.meta["tensor_dict"]["_dynamo_static_input_type"],
+                    "unguarded",
+                )
+
+            return gm
+
+        class TestModule(torch.nn.Module):
+            def __init__(self) -> None:
+                super().__init__()
+                self.buf = torch.ones(2, 2)
+
+            def forward(self, x):
+                return self.buf * x
+
+        mod = TestModule()
+
+        @torch._dynamo.optimize(backend=debug_compiler)
+        def fn(x):
+            return x * mod(x)
+
+        inp = torch.ones(2)
+        fn(inp)
+        self.assertEqual(num_compiles, 1)
+
     @torch._dynamo.config.patch("inline_inbuilt_nn_modules", True)
     @torch._inductor.config.patch("freezing", True)
     @torch.no_grad()
@@ -5593,6 +5593,64 @@ def forward(self, s0 : torch.SymInt, s1 : torch.SymInt, L_x_ : torch.Tensor):
         self.assertTrue(cnt.frame_count <= 2)

+    @torch._dynamo.config.patch(guard_nn_modules=False)
+    @torch._dynamo.config.patch(inline_inbuilt_nn_modules=False)
+    def test_inlining_cornercase(self):
+        """
+        nn.Modules can be mapped to either NNModuleVariable or UnspecializedNNModuleVariable. For NNModuleVariable, the
+        tensor attributes become part of the Dynamo graph. For unspecialized, they are lifted as inputs.
+
+        But there is a cornercase. Suppose you have an NNModuleVariable with a submodule that is an
+        UnspecializedNNModuleVariable. Today, Dynamo will still consider the submodule as specialized (courtesy of
+        guard.source().is_nn_module()). In retrospect, this is a mistake, but there are dependencies on export and also
+        cudagraphs which make it harder to fix the cornercase right away. The long-term solution is
+        inline_inbuilt_nn_modules anyway, so we might have to live with this cornercase in the short term.
+
+        We are starting to annotate the source of each nn module more precisely - an NNModuleVariable attribute is marked
+        as NNModuleSource, an UnspecializedNNModuleVariable attribute is marked as UnspecializedNNModuleSource. But this
+        changes the behavior for the cornercase, and fails some tests which have unfortunately relied on this behavior.
+
+
+        To solve this, we tag the source only when the inline_inbuilt_nn_modules flag is turned on.
+
+        In this test, we purposely turn the flag off, testing that the tagging is disabled.
+        """
+
+        class SubMod(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.linear = torch.nn.Linear(1, 1)
+                self.a = torch.randn(1, 1)
+                self.counter = 0
+                self.multipliers = [2.2, 3.3]
+
+            def forward(self, x):
+                self.counter += 1
+                return (
+                    self.linear(x) * self.a * self.multipliers[0] * self.multipliers[1]
+                )
+
+        class Mod(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.submod = SubMod()
+
+            def forward(self, x):
+                return self.submod(x)
+
+        mod = Mod()
+        opt_mod = torch.compile(mod, backend="eager")
+
+        x = torch.randn(1, 1)
+        ref = mod(x)
+        res = opt_mod(x)
+
+        mod.submod.multipliers = [3.3, 4.4]
+        # Since guard_nn_modules is False, this will not recompile
+        with torch._dynamo.config.patch(error_on_recompile=True):
+            ref = mod(x)
+            res = opt_mod(x)
+
     def test_optimized_module_training(self):
         mod = torch.nn.Linear(3, 3)
         mod.eval()
@@ -99,6 +99,7 @@ from .source import (
     TypeSource,
     UnspecializedBuiltinNNModuleSource,
     UnspecializedNNModuleSource,
+    UnspecializedParamBufferSource,
     WeakRefCallSource,
 )
 from .types import CacheEntry, ExtraState, GuardedCode, GuardFail, GuardFn  # noqa: F401
@@ -876,7 +877,7 @@ class GuardBuilder(GuardBuilderBase):
                 example_value=example_value,
                 guard_manager_enum=guard_manager_enum,
             )
-        elif istype(source, AttrSource):
+        elif istype(source, (AttrSource, UnspecializedParamBufferSource)):
            assert base_guard_manager  # to make mypy happy

            if (
@@ -1940,7 +1941,7 @@ class GuardBuilder(GuardBuilderBase):
            #
            assert guard.source is not None
            static, reason = tensor_always_has_static_shape(
-               value, is_tensor=True, guard_source=guard.source
+               value, is_tensor=True, tensor_source=guard.originating_source
            )

            if not static:
@@ -236,6 +236,12 @@ class ParamBufferSource(AttrSource):
         return _GUARD_SOURCE_SPECIALIZED_NN_MODULE[self.base.guard_source()]


+# Special AttrSource to differentiate module._buffers or module._parameters
+@dataclasses.dataclass(frozen=True)
+class UnspecializedParamBufferSource(AttrSource):
+    pass
+
+
 # This source is intended to be used in places where a source is needed but it is expected
 # that the symbol will be simplified out later on. Symbols with ephemeral sources are
 # prioritized to be simplified out when e.g. compared against a symbol without an ephemeral
@@ -704,6 +710,14 @@ def is_from_local_source(source: Source, *, allow_cell_or_freevar=True):
     return True


+def is_from_unspecialized_param_buffer_source(source: Source):
+    if isinstance(source, UnspecializedParamBufferSource):
+        return True
+    if isinstance(source, ChainedSource):
+        return is_from_unspecialized_param_buffer_source(source.base)
+    return False
+
+
 def is_from_flatten_script_object_source(source: Source):
     if isinstance(source, FlattenScriptObjectSource):
         return True
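
A hypothetical illustration of the recursive check added just above: the source chain below (LocalSource -> UnspecializedParamBufferSource -> AttrSource) is made up for the example and assumes this PR's torch._dynamo.source definitions.

from torch._dynamo.source import (
    AttrSource,
    LocalSource,
    UnspecializedParamBufferSource,
    is_from_unspecialized_param_buffer_source,
)

base = LocalSource("mod")                                   # mod
buffers = UnspecializedParamBufferSource(base, "_buffers")  # mod._buffers
buf = AttrSource(buffers, "buf")                            # mod._buffers.buf

# AttrSource is a ChainedSource, so the helper walks .base until it hits the
# UnspecializedParamBufferSource (or the chain ends).
assert is_from_unspecialized_param_buffer_source(buf)
assert not is_from_unspecialized_param_buffer_source(base)
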
@@ -69,7 +69,7 @@ from torch._C import (
     _push_on_torch_function_stack,
 )
 from torch._dispatch.python import enable_python_dispatcher
-from torch._guards import TracingContext
+from torch._guards import Source, TracingContext
 from torch._subclasses.meta_utils import is_sparse_compressed
 from torch._utils_internal import log_chromium_event_internal, log_compilation_event
 from torch.fx._utils import _format_graph_code, lazy_format_graph_code
@@ -2344,7 +2344,7 @@ def tensor_static_reason_to_message(reason: TensorStaticReason):
 def tensor_always_has_static_shape(
     tensor: Union[torch.Tensor, Any],
     is_tensor: bool,
-    guard_source: torch._guards.GuardSource,
+    tensor_source: Source,
 ) -> Tuple[bool, Optional[TensorStaticReason]]:
     """
     Given a tensor, source, and is_tensor flag, determine if a shape should be static.
@@ -2357,12 +2357,20 @@ def tensor_always_has_static_shape(
     Returns a tuple, where the first element is the bool of whether or not this tensor should have a static shape.
     The second element is a TensorStaticReason, useful for passing to tensor_static_reason_to_message if needed.
     """
+    from .source import is_from_unspecialized_param_buffer_source
+
     if (
-        guard_source.is_specialized_nn_module()
-        and config.force_nn_module_property_static_shapes
-    ):
+        tensor_source.guard_source().is_specialized_nn_module()
+        # Marking the tensor attributes of nn modules static to keep the behavior same as before
+        # inline_inbuilt_nn_module flag was introduced.
+        or tensor_source.guard_source().is_unspecialized_nn_module()
+    ) and config.force_nn_module_property_static_shapes:
         return True, TensorStaticReason.NN_MODULE_PROPERTY
-    if type(tensor) is torch.nn.Parameter and config.force_parameter_static_shapes:
+
+    if (
+        type(tensor) is torch.nn.Parameter
+        or is_from_unspecialized_param_buffer_source(tensor_source)
+    ) and config.force_parameter_static_shapes:
         return True, TensorStaticReason.PARAMETER
     if not is_tensor:
         return True, TensorStaticReason.NOT_TENSOR
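
A hedged call sketch for the updated helper, showing the new tensor_source keyword in place of guard_source; the LocalSource is illustrative, and the expected result assumes the default force_parameter_static_shapes=True at the time of this PR.

import torch
from torch._dynamo.source import LocalSource
from torch._dynamo.utils import tensor_always_has_static_shape

param = torch.nn.Parameter(torch.randn(2, 2))
static, reason = tensor_always_has_static_shape(
    param,
    is_tensor=True,
    tensor_source=LocalSource("p"),
)
# A plain local Parameter is reported static via TensorStaticReason.PARAMETER.
print(static, reason)
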
@@ -82,7 +82,12 @@ from .misc import (
     TypingVariable,
     UnknownVariable,
 )
-from .nn_module import NNModuleVariable, UnspecializedNNModuleVariable
+from .nn_module import (
+    FSDPManagedNNModuleVariable,
+    NNModuleVariable,
+    UnspecializedBuiltinNNModuleVariable,
+    UnspecializedNNModuleVariable,
+)
 from .optimizer import OptimizerVariable
 from .sdpa import SDPAParamsVariable
 from .tensor import (
@@ -170,7 +170,11 @@ from .misc import (
     TorchVersionVariable,
     TypingVariable,
 )
-from .nn_module import FSDPManagedNNModuleVariable, UnspecializedNNModuleVariable
+from .nn_module import (
+    FSDPManagedNNModuleVariable,
+    UnspecializedBuiltinNNModuleVariable,
+    UnspecializedNNModuleVariable,
+)
 from .optimizer import OptimizerVariable
 from .script_object import TorchScriptObjectVariable
 from .sdpa import SDPAParamsVariable
@@ -1320,7 +1324,11 @@ class VariableBuilder:
             # this will get cleaned up once compile ends
             self.tx.output.nn_modules[self.name] = value

-            result = UnspecializedNNModuleVariable(value, source=self.source)
+            if value.__module__.startswith(("torch.nn.", "torch.ao.")):
+                result = UnspecializedBuiltinNNModuleVariable(value, source=self.source)
+            else:
+                result = UnspecializedNNModuleVariable(value, source=self.source)

             if not SideEffects.cls_supports_mutation_side_effects(type(value)):
                 # don't allow STORE_ATTR mutation with custom __setattr__
                 return result
@@ -1349,6 +1357,7 @@ class VariableBuilder:
                 # specialized (as we don't expect users to be changing the
                 # NN modules on the fly)
                 or self.source.guard_source().is_specialized_nn_module()
+                or self.source.guard_source().is_unspecialized_builtin_nn_module()
                 or is_from_defaults(self.source)
                 or is_cell_contents(self.source)
                 # TODO: Delete this condition when rollout is done. NB: this
@@ -1392,7 +1401,12 @@ class VariableBuilder:
        if (
            config.inline_inbuilt_nn_modules
            and not is_static_input
-           and isinstance(value, torch.nn.Parameter)
+           and (
+               isinstance(value, torch.nn.Parameter)
+               # mark tensor attributes of nn modules static. This is done to keep inline_inbuilt_nn_modules behavior
+               # compatible with previous behavior.
+               or (source and source.guard_source().is_unspecialized_nn_module())
+           )
        ):
            self.mark_static_input(value, guard=False)

@@ -2572,7 +2586,9 @@ def wrap_to_fake_tensor_and_record(
     ):
         assert source is not None
         static_shapes, reason = tensor_always_has_static_shape(
-            e, is_tensor, guard_source=source.guard_source()
+            e,
+            is_tensor,
+            tensor_source=source,
         )

         if not parent_context:
@@ -5,7 +5,7 @@ import inspect
 import itertools
 import types
 from contextlib import contextmanager, nullcontext
-from typing import Any, Dict, List, TYPE_CHECKING
+from typing import Dict, List, TYPE_CHECKING

 import torch.nn

@@ -24,6 +24,7 @@ from ..source import (
     FSDPNNModuleSource,
     GetItemSource,
     NNModuleSource,
+    UnspecializedBuiltinNNModuleSource,
     UnspecializedNNModuleSource,
 )
 from ..utils import (
@@ -800,6 +801,11 @@ class UnspecializedNNModuleVariable(UserDefinedObjectVariable):
             # nn_module_stack_source appropriately to resemble mod.linear.
             self.nn_module_stack_source = self.source

+    def _wrap_source(self, attr_source):
+        if not isinstance(attr_source, UnspecializedNNModuleSource):
+            return UnspecializedNNModuleSource(attr_source)
+        return attr_source
+
     def get_nn_module_stack_source(self):
         return self.nn_module_stack_source or self.source

@@ -1131,6 +1137,17 @@ class UnspecializedNNModuleVariable(UserDefinedObjectVariable):
         return out


+class UnspecializedBuiltinNNModuleVariable(UnspecializedNNModuleVariable):
+    """
+    Differentiates between builtin nn modules (e.g. torch.nn.Linear) and user defined nn modules.
+    """
+
+    def _wrap_source(self, attr_source):
+        if not isinstance(attr_source, UnspecializedBuiltinNNModuleSource):
+            return UnspecializedBuiltinNNModuleSource(attr_source)
+        return attr_source
+
+
 class FSDPManagedNNModuleVariable(UnspecializedNNModuleVariable):
     """
     Tracing behavior: trace into submodules and treat them as Unspecialized, do not
@@ -1152,19 +1169,12 @@ class FSDPManagedNNModuleVariable(UnspecializedNNModuleVariable):
         super().__init__(value=value, **kwargs)
         self.source = source

-    @staticmethod
-    def _wrap_source(source):
-        if not isinstance(source, (FSDPNNModuleSource, UnspecializedNNModuleSource)):
+    def _wrap_source(self, attr_source):
+        if not isinstance(
+            attr_source, (FSDPNNModuleSource, UnspecializedNNModuleSource)
+        ):
             if torch._dynamo.config.skip_fsdp_guards:
-                return FSDPNNModuleSource(source)
+                return FSDPNNModuleSource(attr_source)
             else:
                 # this makes us behave like a usual UnspecializedNNModuleVariable for guarding purposes
-                return UnspecializedNNModuleSource(source)
-        else:
-            return source
-
-    def __setattr__(self, name: str, value: Any) -> None:
-        if name == "source":
-            value = FSDPManagedNNModuleVariable._wrap_source(value)
-
-        return super().__setattr__(name, value)
+                return UnspecializedNNModuleSource(attr_source)
+        return attr_source
@@ -31,6 +31,7 @@ from ..source import (
     GetItemSource,
     ODictGetItemSource,
     RandomValueSource,
+    UnspecializedParamBufferSource,
     WeakRefCallSource,
 )
 from ..utils import (
@@ -1088,6 +1089,25 @@ class UserDefinedObjectVariable(UserDefinedVariable):
             else:
                 return trace_rules.lookup(func)(func)

+        if (
+            # wrap the source only if inline_inbuilt_nn_modules is set or fsdp modules. This is a temporary solution to
+            # keep Dynamo behavior compatible with no inlining, as there will be some delay to turn on the flag in
+            # fbcode.
+            (
+                torch._dynamo.config.inline_inbuilt_nn_modules
+                or isinstance(self, variables.FSDPManagedNNModuleVariable)
+            )
+            and source
+            and isinstance(self, variables.UnspecializedNNModuleVariable)
+            # export has some awkwardness around specialized and unspecialized modules. Skip wrapping source for export
+            # usecase for now.
+            and not tx.output.export
+        ):
+            # Recalculate source for params/buffers
+            if name in ("_buffers", "_parameters"):
+                source = UnspecializedParamBufferSource(self.source, name)
+            source = self._wrap_source(source)
+
         if subobj is not NO_SUCH_SUBOBJ:
             if is_wrapper_or_member_descriptor(subobj):
                 options = {"source": source}
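
Putting the pieces together, a hypothetical sketch of the source chain the branch above produces when an unspecialized user module accessed as a local `mod` has its mod._buffers read; the names are illustrative and the classes assume this PR's torch._dynamo.source definitions.

from torch._dynamo.source import (
    LocalSource,
    UnspecializedNNModuleSource,
    UnspecializedParamBufferSource,
)

mod_source = LocalSource("mod")
# name in ("_buffers", "_parameters") -> the attribute source is recalculated
buffers_source = UnspecializedParamBufferSource(mod_source, "_buffers")
# UnspecializedNNModuleVariable._wrap_source (see the nn_module.py hunk) then tags it
tagged = UnspecializedNNModuleSource(buffers_source)

# The tag is what the guard machinery later keys off via guard_source().
print(type(tagged).__name__, tagged.guard_source())
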
@@ -129,6 +129,7 @@ class GuardSource(enum.Enum):
             GuardSource.LOCAL_SPECIALIZED_NN_MODULE,
             GuardSource.LOCAL_FSDP_MODULE,
             GuardSource.LOCAL_UNSPECIALIZED_NN_MODULE,
+            GuardSource.LOCAL_UNSPECIALIZED_BUILTIN_NN_MODULE,
         )
