[dynamo] Track params/buffers and mark them as static (#132334)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/132334
Approved by: https://github.com/ezyang, https://github.com/mlazos
This commit is contained in:
Animesh Jain 2024-08-01 16:21:58 -07:00 committed by PyTorch MergeBot
parent 2ee9895304
commit babb249a89
6 changed files with 36 additions and 8 deletions

View File

@ -560,6 +560,7 @@ class TestSelectAlgorithm(BaseTestSelectAlgorithm):
@inductor_config.patch({"freezing": True})
@patches
@torch._dynamo.config.patch(inline_inbuilt_nn_modules=True)
@torch.no_grad
@parametrize("batch_size", (3, 16, 32, 49))
@parametrize("in_features", (4, 68, 128)) # k should be a multiple of 4

View File

@ -99,6 +99,7 @@ from .source import (
TypeSource,
UnspecializedBuiltinNNModuleSource,
UnspecializedNNModuleSource,
UnspecializedParamBufferSource,
WeakRefCallSource,
)
from .types import CacheEntry, ExtraState, GuardedCode, GuardFail, GuardFn # noqa: F401
@ -875,7 +876,7 @@ class GuardBuilder(GuardBuilderBase):
example_value=example_value,
guard_manager_enum=guard_manager_enum,
)
elif istype(source, AttrSource):
elif istype(source, (AttrSource, UnspecializedParamBufferSource)):
assert base_guard_manager # to make mypy happy
if (
@ -1914,7 +1915,7 @@ class GuardBuilder(GuardBuilderBase):
#
assert guard.source is not None
static, reason = tensor_always_has_static_shape(
value, is_tensor=True, guard_source=guard.source
value, is_tensor=True, tensor_source=guard.originating_source
)
if not static:

View File

@ -236,6 +236,12 @@ class ParamBufferSource(AttrSource):
return _GUARD_SOURCE_SPECIALIZED_NN_MODULE[self.base.guard_source()]
# Special AttrSource to differentiate module._buffers or module._parameters
@dataclasses.dataclass(frozen=True)
class UnspecializedParamBufferSource(AttrSource):
pass
# This source is intended to be used in places where a source is needed but it is expected
# that the symbol will be simplified out later on. Symbols with ephemeral sources are
# prioritized to be simplified out when e.g. compared against a symbol without an ephemeral
@ -679,6 +685,14 @@ def is_from_local_source(source: Source, *, allow_cell_or_freevar=True):
return True
def is_from_unspecialized_param_buffer_source(source: Source):
if isinstance(source, UnspecializedParamBufferSource):
return True
if isinstance(source, ChainedSource):
return is_from_unspecialized_param_buffer_source(source.base)
return False
def is_from_flatten_script_object_source(source: Source):
if isinstance(source, FlattenScriptObjectSource):
return True

View File

@ -60,7 +60,7 @@ import torch.fx.experimental.symbolic_shapes
import torch.utils._pytree as pytree
from torch import fx
from torch._dispatch.python import enable_python_dispatcher
from torch._guards import TracingContext
from torch._guards import Source, TracingContext
from torch._subclasses.meta_utils import is_sparse_compressed
from torch._utils_internal import log_compilation_event
from torch.fx._utils import _format_graph_code, lazy_format_graph_code
@ -2156,7 +2156,7 @@ def tensor_static_reason_to_message(reason: TensorStaticReason):
def tensor_always_has_static_shape(
tensor: Union[torch.Tensor, Any],
is_tensor: bool,
guard_source: "torch._guards.GuardSource",
tensor_source: Source,
) -> Tuple[bool, Optional[TensorStaticReason]]:
"""
Given a tensor, source, and is_tensor flag, determine if a shape should be static.
@ -2169,12 +2169,18 @@ def tensor_always_has_static_shape(
Returns a tuple, where the first element is the bool of whether or not this tensor should have a static shape.
The second element is a TensorStaticReason, useful for passing to tensor_static_reason_to_message if needed.
"""
from .source import is_from_unspecialized_param_buffer_source
if (
guard_source.is_specialized_nn_module()
or guard_source.is_unspecialized_builtin_nn_module()
tensor_source.guard_source().is_specialized_nn_module()
or tensor_source.guard_source().is_unspecialized_builtin_nn_module()
) and config.force_nn_module_property_static_shapes:
return True, TensorStaticReason.NN_MODULE_PROPERTY
if type(tensor) is torch.nn.Parameter and config.force_parameter_static_shapes:
if (
type(tensor) is torch.nn.Parameter
or is_from_unspecialized_param_buffer_source(tensor_source)
) and config.force_parameter_static_shapes:
return True, TensorStaticReason.PARAMETER
if not is_tensor:
return True, TensorStaticReason.NOT_TENSOR

View File

@ -2548,7 +2548,9 @@ def wrap_to_fake_tensor_and_record(
):
assert source is not None
static_shapes, reason = tensor_always_has_static_shape(
e, is_tensor, guard_source=source.guard_source()
e,
is_tensor,
tensor_source=source,
)
if not parent_context:

View File

@ -27,6 +27,7 @@ from ..source import (
GetItemSource,
ODictGetItemSource,
RandomValueSource,
UnspecializedParamBufferSource,
WeakRefCallSource,
)
from ..utils import (
@ -1022,6 +1023,9 @@ class UserDefinedObjectVariable(UserDefinedVariable):
return trace_rules.lookup(func)(func)
if source and isinstance(self, variables.UnspecializedNNModuleVariable):
# Recalculate source for params/buffers
if name in ("_buffers", "_parameters"):
source = UnspecializedParamBufferSource(self.source, name)
source = self._wrap_source(source)
if subobj is not NO_SUCH_SUBOBJ and not is_wrapper_or_member_descriptor(subobj):