[fbode-testing][dynamo][reland][inline-inbuilt-nn-modules] Mark attri… (#134136)

Shuai wants to test this internally before https://github.com/pytorch/pytorch/pull/133713 can go in. Creating a separate PR for ghimport.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/134136
Approved by: https://github.com/yanboliang
Animesh Jain 2024-08-22 17:54:58 +00:00 committed by PyTorch MergeBot
parent 8f7d66f0c3
commit fee677eeb6
10 changed files with 204 additions and 28 deletions
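
For context (not part of the diff): a minimal usage sketch of the behavior this PR exercises. The module and backend names below are illustrative; the config knob and the flow mirror the new test_mark_static_nn_module_tensor test in the first file. With inline_inbuilt_nn_modules enabled, a plain tensor attribute on an nn.Module is treated as a static input when the module is compiled:

    # Illustrative sketch only; it mirrors the new test below rather than defining new API.
    import torch

    class ModuleWithTensorAttr(torch.nn.Module):
        def __init__(self) -> None:
            super().__init__()
            # A plain tensor attribute (not a registered buffer or Parameter).
            self.buf = torch.ones(2, 2)

        def forward(self, x):
            return self.buf * x

    with torch._dynamo.config.patch("inline_inbuilt_nn_modules", True):
        mod = ModuleWithTensorAttr()
        opt_mod = torch.compile(mod, backend="eager")
        # With this patch, Dynamo marks mod.buf as a static input
        # (its placeholder carries _dynamo_static_input_type == "unguarded").
        opt_mod(torch.ones(2))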

View File

@@ -2771,6 +2771,49 @@ class OptimizedModuleTest(torch._dynamo.test_case.TestCase):
        fn(inp, buf, mod)
        self.assertEqual(num_compiles, 1)

    @torch._dynamo.config.patch("inline_inbuilt_nn_modules", True)
    def test_mark_static_nn_module_tensor(self):
        # This test verifies that dynamo will mark
        # the nn module tensor attributes as static
        num_compiles = 0

        def debug_compiler(gm, _):
            nonlocal num_compiles
            num_compiles += 1
            input_nodes = [
                n
                for n in gm.graph.nodes
                if n.op == "placeholder" and n.name == "l_mod_buf"
            ]
            self.assertGreater(len(input_nodes), 0)
            for input_node in input_nodes:
                self.assertEqual(
                    input_node.meta["tensor_dict"]["_dynamo_static_input_type"],
                    "unguarded",
                )
            return gm

        class TestModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.buf = torch.ones(2, 2)

            def forward(self, x):
                return self.buf * x

        mod = TestModule()

        @torch._dynamo.optimize(backend=debug_compiler)
        def fn(x):
            return x * mod(x)

        inp = torch.ones(2)
        fn(inp)
        self.assertEqual(num_compiles, 1)

    @torch._dynamo.config.patch("inline_inbuilt_nn_modules", True)
    @torch._inductor.config.patch("freezing", True)
    @torch.no_grad()

View File

@@ -5593,6 +5593,64 @@ def forward(self, s0 : torch.SymInt, s1 : torch.SymInt, L_x_ : torch.Tensor):
        self.assertTrue(cnt.frame_count <= 2)

    @torch._dynamo.config.patch(guard_nn_modules=False)
    @torch._dynamo.config.patch(inline_inbuilt_nn_modules=False)
    def test_inlining_cornercase(self):
        """
        nn.Modules can be mapped to either NNModuleVariable or UnspecializedNNModuleVariable. For NNModuleVariable,
        the tensor attributes become part of the Dynamo graph. For unspecialized, they are lifted as inputs.

        But there is a corner case. Suppose you have an NNModuleVariable with a submodule that is an
        UnspecializedNNModuleVariable. Today, Dynamo will still consider the submodule as specialized (courtesy of
        guard.source().is_nn_module()). In retrospect, this is a mistake, but there are dependencies on export and
        also cudagraphs that make it harder to fix the corner case right away. The long-term solution is
        inline_inbuilt_nn_modules anyway, so we might have to live with this corner case in the short term.

        We are starting to annotate the source of each nn module more precisely - an NNModuleVariable attribute is
        marked as NNModuleSource, an UnspecializedNNModuleVariable attribute is marked as
        UnspecializedNNModuleSource. But this changes the behavior for the corner case and fails some tests which
        have unfortunately relied on this behavior.

        To solve this, we tag the source only when the inline_inbuilt_nn_modules flag is turned on.

        In this test, we purposely turn the flag off, testing that the tagging is disabled.
        """

        class SubMod(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.linear = torch.nn.Linear(1, 1)
                self.a = torch.randn(1, 1)
                self.counter = 0
                self.multipliers = [2.2, 3.3]

            def forward(self, x):
                self.counter += 1
                return (
                    self.linear(x) * self.a * self.multipliers[0] * self.multipliers[1]
                )

        class Mod(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.submod = SubMod()

            def forward(self, x):
                return self.submod(x)

        mod = Mod()
        opt_mod = torch.compile(mod, backend="eager")
        x = torch.randn(1, 1)
        ref = mod(x)
        res = opt_mod(x)

        mod.submod.multipliers = [3.3, 4.4]
        # Since guard_nn_modules is False, this will not recompile
        with torch._dynamo.config.patch(error_on_recompile=True):
            ref = mod(x)
            res = opt_mod(x)

    def test_optimized_module_training(self):
        mod = torch.nn.Linear(3, 3)
        mod.eval()

View File

@@ -99,6 +99,7 @@ from .source import (
    TypeSource,
    UnspecializedBuiltinNNModuleSource,
    UnspecializedNNModuleSource,
    UnspecializedParamBufferSource,
    WeakRefCallSource,
)
from .types import CacheEntry, ExtraState, GuardedCode, GuardFail, GuardFn  # noqa: F401
@@ -876,7 +877,7 @@ class GuardBuilder(GuardBuilderBase):
                example_value=example_value,
                guard_manager_enum=guard_manager_enum,
            )
        elif istype(source, AttrSource):
        elif istype(source, (AttrSource, UnspecializedParamBufferSource)):
            assert base_guard_manager  # to make mypy happy
            if (
@@ -1940,7 +1941,7 @@ class GuardBuilder(GuardBuilderBase):
            #
            assert guard.source is not None
            static, reason = tensor_always_has_static_shape(
                value, is_tensor=True, guard_source=guard.source
                value, is_tensor=True, tensor_source=guard.originating_source
            )
            if not static:

View File

@@ -236,6 +236,12 @@ class ParamBufferSource(AttrSource):
        return _GUARD_SOURCE_SPECIALIZED_NN_MODULE[self.base.guard_source()]


# Special AttrSource to differentiate module._buffers or module._parameters
@dataclasses.dataclass(frozen=True)
class UnspecializedParamBufferSource(AttrSource):
    pass


# This source is intended to be used in places where a source is needed but it is expected
# that the symbol will be simplified out later on. Symbols with ephemeral sources are
# prioritized to be simplified out when e.g. compared against a symbol without an ephemeral
@@ -704,6 +710,14 @@ def is_from_local_source(source: Source, *, allow_cell_or_freevar=True):
    return True


def is_from_unspecialized_param_buffer_source(source: Source):
    if isinstance(source, UnspecializedParamBufferSource):
        return True
    if isinstance(source, ChainedSource):
        return is_from_unspecialized_param_buffer_source(source.base)
    return False


def is_from_flatten_script_object_source(source: Source):
    if isinstance(source, FlattenScriptObjectSource):
        return True
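
For context (not part of the diff): the new helper walks chained sources down to their base, so an attribute access whose chain passes through an UnspecializedParamBufferSource is still detected. A hypothetical sketch, assuming the Dynamo-internal LocalSource/AttrSource classes from torch/_dynamo/source.py:

    # Hypothetical sketch; only meant to show how the new helper recurses
    # through ChainedSource.base.
    from torch._dynamo.source import (
        AttrSource,
        LocalSource,
        UnspecializedParamBufferSource,
        is_from_unspecialized_param_buffer_source,
    )

    mod_src = LocalSource("mod")
    buffers_src = UnspecializedParamBufferSource(mod_src, "_buffers")  # mod._buffers
    buf_src = AttrSource(buffers_src, "buf")  # chained on top of the buffer dict source

    assert is_from_unspecialized_param_buffer_source(buffers_src)
    assert is_from_unspecialized_param_buffer_source(buf_src)  # walks .base
    assert not is_from_unspecialized_param_buffer_source(AttrSource(mod_src, "buf"))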

View File

@@ -69,7 +69,7 @@ from torch._C import (
    _push_on_torch_function_stack,
)
from torch._dispatch.python import enable_python_dispatcher
from torch._guards import TracingContext
from torch._guards import Source, TracingContext
from torch._subclasses.meta_utils import is_sparse_compressed
from torch._utils_internal import log_chromium_event_internal, log_compilation_event
from torch.fx._utils import _format_graph_code, lazy_format_graph_code
@@ -2344,7 +2344,7 @@ def tensor_static_reason_to_message(reason: TensorStaticReason):
def tensor_always_has_static_shape(
    tensor: Union[torch.Tensor, Any],
    is_tensor: bool,
    guard_source: torch._guards.GuardSource,
    tensor_source: Source,
) -> Tuple[bool, Optional[TensorStaticReason]]:
    """
    Given a tensor, source, and is_tensor flag, determine if a shape should be static.
@@ -2357,12 +2357,20 @@ def tensor_always_has_static_shape(
    Returns a tuple, where the first element is the bool of whether or not this tensor should have a static shape.
    The second element is a TensorStaticReason, useful for passing to tensor_static_reason_to_message if needed.
    """
    from .source import is_from_unspecialized_param_buffer_source

    if (
        guard_source.is_specialized_nn_module()
        and config.force_nn_module_property_static_shapes
    ):
        tensor_source.guard_source().is_specialized_nn_module()
        # Marking the tensor attributes of nn modules static keeps the behavior the same
        # as before the inline_inbuilt_nn_modules flag was introduced.
        or tensor_source.guard_source().is_unspecialized_nn_module()
    ) and config.force_nn_module_property_static_shapes:
        return True, TensorStaticReason.NN_MODULE_PROPERTY
    if type(tensor) is torch.nn.Parameter and config.force_parameter_static_shapes:
    if (
        type(tensor) is torch.nn.Parameter
        or is_from_unspecialized_param_buffer_source(tensor_source)
    ) and config.force_parameter_static_shapes:
        return True, TensorStaticReason.PARAMETER
    if not is_tensor:
        return True, TensorStaticReason.NOT_TENSOR
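
For context (not part of the diff): after this change the helper takes the originating Source rather than a GuardSource, and the guard source is derived inside via tensor_source.guard_source(). A hypothetical sketch of the new calling convention, assuming this patch is applied:

    # Hypothetical sketch of the new calling convention (Dynamo-internal API).
    import torch
    from torch._dynamo.source import LocalSource
    from torch._dynamo.utils import tensor_always_has_static_shape

    param = torch.nn.Parameter(torch.randn(2, 2))
    static, reason = tensor_always_has_static_shape(
        param,
        is_tensor=True,
        tensor_source=LocalSource("p"),  # a Source object, not a GuardSource
    )
    # With the default config.force_parameter_static_shapes, this yields
    # (True, TensorStaticReason.PARAMETER).
    print(static, reason)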

View File

@@ -82,7 +82,12 @@ from .misc import (
    TypingVariable,
    UnknownVariable,
)
from .nn_module import NNModuleVariable, UnspecializedNNModuleVariable
from .nn_module import (
    FSDPManagedNNModuleVariable,
    NNModuleVariable,
    UnspecializedBuiltinNNModuleVariable,
    UnspecializedNNModuleVariable,
)
from .optimizer import OptimizerVariable
from .sdpa import SDPAParamsVariable
from .tensor import (

View File

@@ -170,7 +170,11 @@ from .misc import (
    TorchVersionVariable,
    TypingVariable,
)
from .nn_module import FSDPManagedNNModuleVariable, UnspecializedNNModuleVariable
from .nn_module import (
    FSDPManagedNNModuleVariable,
    UnspecializedBuiltinNNModuleVariable,
    UnspecializedNNModuleVariable,
)
from .optimizer import OptimizerVariable
from .script_object import TorchScriptObjectVariable
from .sdpa import SDPAParamsVariable
@@ -1320,7 +1324,11 @@ class VariableBuilder:
            # this will get cleaned up once compile ends
            self.tx.output.nn_modules[self.name] = value
            result = UnspecializedNNModuleVariable(value, source=self.source)
            if value.__module__.startswith(("torch.nn.", "torch.ao.")):
                result = UnspecializedBuiltinNNModuleVariable(value, source=self.source)
            else:
                result = UnspecializedNNModuleVariable(value, source=self.source)
            if not SideEffects.cls_supports_mutation_side_effects(type(value)):
                # don't allow STORE_ATTR mutation with custom __setattr__
                return result
@@ -1349,6 +1357,7 @@ class VariableBuilder:
                # specialized (as we don't expect users to be changing the
                # NN modules on the fly)
                or self.source.guard_source().is_specialized_nn_module()
                or self.source.guard_source().is_unspecialized_builtin_nn_module()
                or is_from_defaults(self.source)
                or is_cell_contents(self.source)
                # TODO: Delete this condition when rollout is done. NB: this
@@ -1392,7 +1401,12 @@ class VariableBuilder:
        if (
            config.inline_inbuilt_nn_modules
            and not is_static_input
            and isinstance(value, torch.nn.Parameter)
            and (
                isinstance(value, torch.nn.Parameter)
                # Mark tensor attributes of nn modules static, to keep
                # inline_inbuilt_nn_modules compatible with the previous behavior.
                or (source and source.guard_source().is_unspecialized_nn_module())
            )
        ):
            self.mark_static_input(value, guard=False)
@@ -2572,7 +2586,9 @@ def wrap_to_fake_tensor_and_record(
    ):
        assert source is not None
        static_shapes, reason = tensor_always_has_static_shape(
            e, is_tensor, guard_source=source.guard_source()
            e,
            is_tensor,
            tensor_source=source,
        )

        if not parent_context:

View File

@@ -5,7 +5,7 @@ import inspect
import itertools
import types
from contextlib import contextmanager, nullcontext
from typing import Any, Dict, List, TYPE_CHECKING
from typing import Dict, List, TYPE_CHECKING

import torch.nn
@@ -24,6 +24,7 @@ from ..source import (
    FSDPNNModuleSource,
    GetItemSource,
    NNModuleSource,
    UnspecializedBuiltinNNModuleSource,
    UnspecializedNNModuleSource,
)
from ..utils import (
@@ -800,6 +801,11 @@ class UnspecializedNNModuleVariable(UserDefinedObjectVariable):
        # nn_module_stack_source appropriately to resemble mod.linear.
        self.nn_module_stack_source = self.source

    def _wrap_source(self, attr_source):
        if not isinstance(attr_source, UnspecializedNNModuleSource):
            return UnspecializedNNModuleSource(attr_source)
        return attr_source

    def get_nn_module_stack_source(self):
        return self.nn_module_stack_source or self.source
@@ -1131,6 +1137,17 @@ class UnspecializedNNModuleVariable(UserDefinedObjectVariable):
        return out


class UnspecializedBuiltinNNModuleVariable(UnspecializedNNModuleVariable):
    """
    Differentiates between builtin nn modules (e.g. torch.nn.Linear) and user defined nn modules.
    """

    def _wrap_source(self, attr_source):
        if not isinstance(attr_source, UnspecializedBuiltinNNModuleSource):
            return UnspecializedBuiltinNNModuleSource(attr_source)
        return attr_source


class FSDPManagedNNModuleVariable(UnspecializedNNModuleVariable):
    """
    Tracing behavior: trace into submodules and treat them as Unspecialized, do not
@@ -1152,19 +1169,12 @@ class FSDPManagedNNModuleVariable(UnspecializedNNModuleVariable):
        super().__init__(value=value, **kwargs)
        self.source = source

    @staticmethod
    def _wrap_source(source):
        if not isinstance(source, (FSDPNNModuleSource, UnspecializedNNModuleSource)):
    def _wrap_source(self, attr_source):
        if not isinstance(
            attr_source, (FSDPNNModuleSource, UnspecializedNNModuleSource)
        ):
            if torch._dynamo.config.skip_fsdp_guards:
                return FSDPNNModuleSource(source)
                return FSDPNNModuleSource(attr_source)
            else:
                # this makes us behave like a usual UnspecializedNNModuleVariable for guarding purposes
                return UnspecializedNNModuleSource(source)
        else:
            return source

    def __setattr__(self, name: str, value: Any) -> None:
        if name == "source":
            value = FSDPManagedNNModuleVariable._wrap_source(value)
        return super().__setattr__(name, value)
                return UnspecializedNNModuleSource(attr_source)
        return attr_source
View File

@@ -31,6 +31,7 @@ from ..source import (
    GetItemSource,
    ODictGetItemSource,
    RandomValueSource,
    UnspecializedParamBufferSource,
    WeakRefCallSource,
)
from ..utils import (
@@ -1088,6 +1089,25 @@ class UserDefinedObjectVariable(UserDefinedVariable):
            else:
                return trace_rules.lookup(func)(func)

        if (
            # Wrap the source only if inline_inbuilt_nn_modules is set or for FSDP modules. This is a
            # temporary solution to keep Dynamo behavior compatible with no inlining, as there will be
            # some delay before the flag is turned on in fbcode.
            (
                torch._dynamo.config.inline_inbuilt_nn_modules
                or isinstance(self, variables.FSDPManagedNNModuleVariable)
            )
            and source
            and isinstance(self, variables.UnspecializedNNModuleVariable)
            # Export has some awkwardness around specialized and unspecialized modules. Skip wrapping
            # the source for the export use case for now.
            and not tx.output.export
        ):
            # Recalculate source for params/buffers
            if name in ("_buffers", "_parameters"):
                source = UnspecializedParamBufferSource(self.source, name)
            source = self._wrap_source(source)

        if subobj is not NO_SUCH_SUBOBJ:
            if is_wrapper_or_member_descriptor(subobj):
                options = {"source": source}

View File

@@ -129,6 +129,7 @@ class GuardSource(enum.Enum):
            GuardSource.LOCAL_SPECIALIZED_NN_MODULE,
            GuardSource.LOCAL_FSDP_MODULE,
            GuardSource.LOCAL_UNSPECIALIZED_NN_MODULE,
            GuardSource.LOCAL_UNSPECIALIZED_BUILTIN_NN_MODULE,
        )