[dynamo] Track params/buffers and mark them as static (#132334)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/132334
Approved by: https://github.com/ezyang, https://github.com/mlazos
commit babb249a89
parent 2ee9895304
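The net effect: with `inline_inbuilt_nn_modules` enabled, dynamo traces into builtin `nn.Module`s and tracks their parameters and buffers via a dedicated source, so `force_parameter_static_shapes` marks their shapes static rather than guarding on them as dynamic. A minimal sketch of the user-facing behavior (the module and shapes here are illustrative, not taken from this PR):

    import torch
    import torch._dynamo

    class M(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.linear = torch.nn.Linear(16, 16)
            self.register_buffer("scale", torch.ones(16))

        def forward(self, x):
            return self.linear(x) * self.scale

    # With the flag on, dynamo inlines the Linear module and reaches its
    # weight/bias through mod._parameters; this PR tags that access with
    # UnspecializedParamBufferSource so the shapes are treated as static.
    with torch._dynamo.config.patch(inline_inbuilt_nn_modules=True):
        compiled = torch.compile(M())
        out = compiled(torch.randn(4, 16))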
@@ -560,6 +560,7 @@ class TestSelectAlgorithm(BaseTestSelectAlgorithm):
     @inductor_config.patch({"freezing": True})
     @patches
+    @torch._dynamo.config.patch(inline_inbuilt_nn_modules=True)
     @torch.no_grad
     @parametrize("batch_size", (3, 16, 32, 49))
     @parametrize("in_features", (4, 68, 128))  # k should be a multiple of 4

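For context, `torch._dynamo.config.patch` works both as a decorator (as in the test above) and as a context manager, so the flag can be scoped to a single call; a small sketch (the function and arguments here are hypothetical):

    import torch
    import torch._dynamo

    @torch._dynamo.config.patch(inline_inbuilt_nn_modules=True)
    def run_compiled(mod, x):
        # the config flag is active only while this function executes
        return torch.compile(mod)(x)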
@@ -99,6 +99,7 @@ from .source import (
     TypeSource,
     UnspecializedBuiltinNNModuleSource,
     UnspecializedNNModuleSource,
+    UnspecializedParamBufferSource,
     WeakRefCallSource,
 )
 from .types import CacheEntry, ExtraState, GuardedCode, GuardFail, GuardFn  # noqa: F401

@@ -875,7 +876,7 @@ class GuardBuilder(GuardBuilderBase):
                 example_value=example_value,
                 guard_manager_enum=guard_manager_enum,
             )
-        elif istype(source, AttrSource):
+        elif istype(source, (AttrSource, UnspecializedParamBufferSource)):
             assert base_guard_manager  # to make mypy happy

             if (

@@ -1914,7 +1915,7 @@ class GuardBuilder(GuardBuilderBase):
             #
             assert guard.source is not None
             static, reason = tensor_always_has_static_shape(
-                value, is_tensor=True, guard_source=guard.source
+                value, is_tensor=True, tensor_source=guard.originating_source
             )

             if not static:

@@ -236,6 +236,12 @@ class ParamBufferSource(AttrSource):
         return _GUARD_SOURCE_SPECIALIZED_NN_MODULE[self.base.guard_source()]


+# Special AttrSource to differentiate module._buffers or module._parameters
+@dataclasses.dataclass(frozen=True)
+class UnspecializedParamBufferSource(AttrSource):
+    pass
+
+
 # This source is intended to be used in places where a source is needed but it is expected
 # that the symbol will be simplified out later on. Symbols with ephemeral sources are
 # prioritized to be simplified out when e.g. compared against a symbol without an ephemeral

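Because `UnspecializedParamBufferSource` subclasses `AttrSource` without adding fields, it chains and names itself like any attribute source, and existing `AttrSource` handling keeps working. A sketch of how it composes (assuming `AttrSource(base, member)` fields and the usual dotted `name()`, as in dynamo's other source classes):

    from torch._dynamo.source import (
        AttrSource,
        LocalSource,
        UnspecializedParamBufferSource,
    )

    # mod._parameters, where `mod` is a local in the traced frame
    params = UnspecializedParamBufferSource(LocalSource("mod"), "_parameters")

    assert isinstance(params, AttrSource)  # existing AttrSource code paths apply
    print(params.name())                   # e.g. "L['mod']._parameters"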
@@ -679,6 +685,14 @@ def is_from_local_source(source: Source, *, allow_cell_or_freevar=True):
     return True


+def is_from_unspecialized_param_buffer_source(source: Source):
+    if isinstance(source, UnspecializedParamBufferSource):
+        return True
+    if isinstance(source, ChainedSource):
+        return is_from_unspecialized_param_buffer_source(source.base)
+    return False
+
+
 def is_from_flatten_script_object_source(source: Source):
     if isinstance(source, FlattenScriptObjectSource):
         return True

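The helper recurses through `ChainedSource` wrappers, so it still fires when the param/buffer access is buried under further chaining. An illustrative check (the exact chain shown is an assumption):

    from torch._dynamo.source import (
        AttrSource,
        GetItemSource,
        LocalSource,
        UnspecializedParamBufferSource,
        is_from_unspecialized_param_buffer_source,
    )

    # mod._parameters["weight"]: a GetItemSource chained onto the new source
    chain = GetItemSource(
        UnspecializedParamBufferSource(LocalSource("mod"), "_parameters"),
        "weight",
    )
    assert is_from_unspecialized_param_buffer_source(chain)

    # a plain attribute access does not match
    assert not is_from_unspecialized_param_buffer_source(
        AttrSource(LocalSource("mod"), "weight")
    )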
@@ -60,7 +60,7 @@ import torch.fx.experimental.symbolic_shapes
 import torch.utils._pytree as pytree
 from torch import fx
 from torch._dispatch.python import enable_python_dispatcher
-from torch._guards import TracingContext
+from torch._guards import Source, TracingContext
 from torch._subclasses.meta_utils import is_sparse_compressed
 from torch._utils_internal import log_compilation_event
 from torch.fx._utils import _format_graph_code, lazy_format_graph_code

@@ -2156,7 +2156,7 @@ def tensor_static_reason_to_message(reason: TensorStaticReason):
 def tensor_always_has_static_shape(
     tensor: Union[torch.Tensor, Any],
     is_tensor: bool,
-    guard_source: "torch._guards.GuardSource",
+    tensor_source: Source,
 ) -> Tuple[bool, Optional[TensorStaticReason]]:
     """
     Given a tensor, source, and is_tensor flag, determine if a shape should be static.

@@ -2169,12 +2169,18 @@ def tensor_always_has_static_shape(
     Returns a tuple, where the first element is the bool of whether or not this tensor should have a static shape.
     The second element is a TensorStaticReason, useful for passing to tensor_static_reason_to_message if needed.
     """
+    from .source import is_from_unspecialized_param_buffer_source
+
     if (
-        guard_source.is_specialized_nn_module()
-        or guard_source.is_unspecialized_builtin_nn_module()
+        tensor_source.guard_source().is_specialized_nn_module()
+        or tensor_source.guard_source().is_unspecialized_builtin_nn_module()
     ) and config.force_nn_module_property_static_shapes:
         return True, TensorStaticReason.NN_MODULE_PROPERTY
-    if type(tensor) is torch.nn.Parameter and config.force_parameter_static_shapes:
+
+    if (
+        type(tensor) is torch.nn.Parameter
+        or is_from_unspecialized_param_buffer_source(tensor_source)
+    ) and config.force_parameter_static_shapes:
         return True, TensorStaticReason.PARAMETER
     if not is_tensor:
         return True, TensorStaticReason.NOT_TENSOR

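The updated decision order is: nn-module property first, then Parameter-or-param-buffer source, then non-tensor. A standalone sketch of that predicate with hypothetical stand-in flags (not dynamo's real signatures):

    import torch

    def should_be_static(tensor, is_tensor, is_nn_module_prop, is_param_buffer_src,
                         force_nn_module_static=True, force_param_static=True):
        # 1. properties of specialized / unspecialized-builtin nn.Modules
        if is_nn_module_prop and force_nn_module_static:
            return True, "NN_MODULE_PROPERTY"
        # 2. real Parameters, or tensors reached via module._parameters/_buffers
        if (type(tensor) is torch.nn.Parameter or is_param_buffer_src) and force_param_static:
            return True, "PARAMETER"
        # 3. non-tensors always count as static
        if not is_tensor:
            return True, "NOT_TENSOR"
        return False, None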
@@ -2548,7 +2548,9 @@ def wrap_to_fake_tensor_and_record(
     ):
         assert source is not None
         static_shapes, reason = tensor_always_has_static_shape(
-            e, is_tensor, guard_source=source.guard_source()
+            e,
+            is_tensor,
+            tensor_source=source,
         )

         if not parent_context:

@@ -27,6 +27,7 @@ from ..source import (
     GetItemSource,
     ODictGetItemSource,
     RandomValueSource,
+    UnspecializedParamBufferSource,
     WeakRefCallSource,
 )
 from ..utils import (

@@ -1022,6 +1023,9 @@ class UserDefinedObjectVariable(UserDefinedVariable):
             return trace_rules.lookup(func)(func)

         if source and isinstance(self, variables.UnspecializedNNModuleVariable):
+            # Recalculate source for params/buffers
+            if name in ("_buffers", "_parameters"):
+                source = UnspecializedParamBufferSource(self.source, name)
             source = self._wrap_source(source)

         if subobj is not NO_SUCH_SUBOBJ and not is_wrapper_or_member_descriptor(subobj):

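End to end: when dynamo inlines an unspecialized module and an attribute lookup tunnels through `_parameters` or `_buffers`, the source is retagged so the static-shape logic above recognizes it. A toy version of just the retagging step (simplified stand-ins, not dynamo's classes):

    import dataclasses

    @dataclasses.dataclass(frozen=True)
    class ToyAttrSource:
        base: object
        member: str

    @dataclasses.dataclass(frozen=True)
    class ToyParamBufferSource(ToyAttrSource):
        pass

    def retag(base_source, attr_name):
        # accesses through the private param/buffer dicts get the special type
        if attr_name in ("_buffers", "_parameters"):
            return ToyParamBufferSource(base_source, attr_name)
        return ToyAttrSource(base_source, attr_name)

    assert isinstance(retag("L['mod']", "_parameters"), ToyParamBufferSource)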