mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Previously we were accommodating `torch._dynamo.mark_dynamic()` for export's dynamic shapes. Here we clean things up and ignore it, requiring users to specify an export input for `dynamic_shapes`. Note: there are 4 decorators relevant to export, `mark_dynamic, maybe_mark_dynamic, mark_static, mark_unbacked`. User calls that involve export have only been `mark_dynamic()`, and we use `maybe_mark_dynamic` under the hood for `Dim.AUTO`, but we could start using others. One reason I decided to not warn and just silently ignore is that these decorators cause the tensors to carry dynamic info, and it'll be hard to tell whether the markers are from export or user calls when re-exporting with the same inputs. Pull Request resolved: https://github.com/pytorch/pytorch/pull/135536 Approved by: https://github.com/avikchaudhuri
1211 lines
48 KiB
Python
1211 lines
48 KiB
Python
# mypy: allow-untyped-defs
|
|
import dataclasses
|
|
import inspect
|
|
import logging
|
|
import sys
|
|
from collections import defaultdict
|
|
from enum import auto, Enum
|
|
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, TYPE_CHECKING, Union
|
|
|
|
import torch
|
|
from torch.utils._pytree import (
|
|
_get_node_type,
|
|
BUILTIN_TYPES,
|
|
keystr,
|
|
LeafSpec,
|
|
MappingKey,
|
|
SequenceKey,
|
|
SUPPORTED_NODES,
|
|
tree_flatten,
|
|
tree_map_with_path,
|
|
)
|
|
|
|
from .exported_program import ExportedProgram
|
|
|
|
|
|
if TYPE_CHECKING:
|
|
from sympy import Symbol
|
|
|
|
from torch._guards import Source
|
|
from torch.fx.experimental.symbolic_shapes import ShapeEnv, StrictMinMaxConstraint
|
|
|
|
# Names that make up this module's public API.
__all__ = ["Constraint", "Dim", "dims", "refine_dynamic_shapes_from_suggested_fixes"]

# Module-level logger (used, e.g., for the deprecation warning about None dims).
log = logging.getLogger(name=__name__)
|
|
|
|
|
|
class _DimHint(Enum):
    """
    Hints that may be used in place of a concrete :func:`Dim` in a dynamic
    shape specification.

    - ``AUTO``: let export infer whether the dimension is static or dynamic.
    - ``STATIC``: always specialize the dimension (treat its size as fixed).
    """

    AUTO = auto()  # inferred: may end up static or dynamic
    STATIC = auto()  # always specialized
|
|
|
|
|
|
class _Dim(type):
|
|
"""
|
|
Metaclass for :func:`Dim` types.
|
|
"""
|
|
|
|
@staticmethod
|
|
def readable(name, min_, max_):
|
|
from torch.utils._sympy.numbers import int_oo
|
|
|
|
if min_ == 2:
|
|
min_ = None
|
|
if max_ == int_oo:
|
|
max_ = None
|
|
if min_ is None and max_ is None:
|
|
return f"Dim('{name}')"
|
|
if min_ is None:
|
|
return f"Dim('{name}', max={max_})"
|
|
if max_ is None:
|
|
return f"Dim('{name}', min={min_})"
|
|
return f"Dim('{name}', min={min_}, max={max_})"
|
|
|
|
def __add__(cls, other):
|
|
# e.g., dim + 1
|
|
if type(other) is not int:
|
|
raise NotImplementedError(
|
|
f"Attempted to add {other} to {cls.__name__}, where an integer was expected. "
|
|
"(Only increasing linear operations with integer coefficients are supported.)"
|
|
)
|
|
return cls._derive(lambda x: x + other)
|
|
|
|
def __radd__(cls, other):
|
|
return cls + other
|
|
|
|
def __sub__(cls, other):
|
|
# e.g., dim - 1
|
|
if type(other) is not int:
|
|
raise NotImplementedError(
|
|
f"Attempted to subtract {other} from {cls.__name__}, where an integer was expected. "
|
|
"(Only increasing linear operations with integer coefficients are supported.)"
|
|
)
|
|
return cls._derive(lambda x: x - other)
|
|
|
|
def __rsub__(cls, other):
|
|
raise NotImplementedError(
|
|
f"Attempted to negate {cls.__name__}. "
|
|
"(Only increasing linear operations with integer coefficients are supported.)"
|
|
)
|
|
|
|
def __mul__(cls, other):
|
|
# e.g., dim * 2
|
|
if type(other) is not int or other <= 0:
|
|
raise NotImplementedError(
|
|
f"Attempted to multiply {other} with {cls.__name__}, where a positive integer was expected. "
|
|
"(Only increasing linear operations with integer coefficients are supported.)"
|
|
)
|
|
return cls._derive(lambda x: x * other)
|
|
|
|
def __rmul__(cls, other):
|
|
return cls * other
|
|
|
|
def _derived_name(cls, fn):
|
|
from sympy import sympify
|
|
|
|
return str(fn(sympify(cls.__name__)))
|
|
|
|
def _derive(cls, fn):
|
|
return _DerivedDim(cls._derived_name(fn), (int,), {"root": cls, "fn": fn})
|
|
|
|
|
|
class _StaticDim(_Dim):
    """
    Metaclass for static :func:`Dim` types.

    Internal only: used for setting and checking static dim constraints and
    not meant to be used directly by users. A static Dim carries a single
    ``value`` attribute, which serves as both its lower and its upper bound.
    """

    def _value(self):
        # a static dim's range collapses to the single point `value`
        return self.value  # type: ignore[attr-defined]

    min = property(_value)
    max = property(_value)
    del _value
|
|
|
|
|
|
class _DerivedDim(_Dim):
    """
    Metaclass for derived :func:`Dim` types.

    Only increasing linear expressions with integer coefficients are supported:
    a derived Dim always has the form ``A*x + B``. Here ``x`` is a regular
    (non-derived) Dim, ``A`` and ``B`` are integers, and ``A`` is positive, so
    ``x < y`` implies ``A*x + B < A*y + B``. These restrictions keep the
    metatheory simple. For example, they make it easy to compute ranges of
    derived Dims, to solve for the underlying regular Dims, and to decide
    equalities between derived Dims.

    The function ``lambda x: A*x + B`` is stored as ``fn``, and the regular
    Dim ``x`` is stored as ``root``. The range of a derived Dim is obtained
    by mapping ``fn`` over the range of ``root``.
    """

    @property
    def min(self):
        # Relies on `fn` being increasing: the root's minimum maps to the
        # derived minimum.
        # TODO(avik): use sympy value range analysis instead?
        from sympy import Integer

        from torch.utils._sympy.numbers import int_oo

        root = self.root  # type: ignore[attr-defined]
        if root.min is -int_oo:
            return -int_oo  # fn not needed since it is increasing
        lowest = self.fn(Integer(root.min))  # type: ignore[attr-defined]
        assert lowest >= 0, (
            f"Expected derived min value of {self.__name__} to be >= 0. "
            f"Please specify an appropriate min value for {root.__name__} "
            f"(currently {root.min})."
        )
        return int(lowest)

    @property
    def max(self):
        # Same reasoning as `min`: the root's maximum maps to the derived maximum.
        # TODO(avik): use sympy value range analysis instead?
        from sympy import Integer

        from torch.utils._sympy.numbers import int_oo

        root = self.root  # type: ignore[attr-defined]
        if root.max is int_oo:
            return int_oo  # fn not needed since it is increasing
        highest = self.fn(Integer(root.max))  # type: ignore[attr-defined]
        limit = sys.maxsize - 1
        assert highest <= limit, (
            f"Expected derived max value of {self.__name__} to be <= {limit}. "
            f"Please specify an appropriate max value for {root.__name__} "
            f"(currently {root.max})."
        )
        return int(highest)

    def _derive(self, fn):
        # Nesting (e.g. 2*dim + 1) is handled by composing onto the same root,
        # so the root of a derived Dim is always a regular Dim.
        def composed(x):
            return fn(self.fn(x))  # type: ignore[attr-defined]

        return _DerivedDim(
            self._derived_name(fn),
            (int,),
            {"root": self.root, "fn": composed},  # type: ignore[attr-defined]
        )
|
|
|
|
|
|
def Dim(name: str, *, min: Optional[int] = None, max: Optional[int] = None):
    """
    :func:`Dim` constructs a type analogous to a named symbolic integer with a range.
    It can be used to describe multiple possible values of a dynamic tensor dimension.
    Note that different dynamic dimensions of the same tensor, or of different tensors,
    can be described by the same type.

    Args:
        name (str): Human-readable name for debugging. Must be a valid Python identifier.
        min (Optional[int]): Minimum possible value of given symbol (inclusive).
            Defaults to 0.
        max (Optional[int]): Maximum possible value of given symbol (inclusive).
            Defaults to unbounded (``int_oo``).

    Returns:
        A type that can be used in dynamic shape specifications for tensors.

    Raises:
        AssertionError: if ``max <= min`` after defaults are applied, or if
            ``name`` is not a valid identifier.
    """
    from torch.utils._sympy.numbers import int_oo

    _min = 0 if min is None else min
    _max = int_oo if max is None else max
    assert _max > _min, f"Cannot create Dim with inconsistent min={min}, max={max}"
    assert name.isidentifier(), f"Dim name must be a valid identifier, got {name}"
    dim = _Dim(name, (int,), {"min": _min, "max": _max})

    # Attribute the new type to the caller's module (falling back to "__main__").
    # Only the immediate caller's frame is needed. inspect.stack() would build a
    # FrameInfo for every frame on the stack, reading source context from disk
    # for each one, so we step back a single frame instead.
    frame = inspect.currentframe()
    caller = frame.f_back if frame is not None else None
    try:
        module = inspect.getmodule(caller) if caller is not None else None
    finally:
        # drop frame references promptly to avoid reference cycles
        del frame, caller
    dim.__module__ = getattr(module, "__name__", "__main__")
    return dim
|
|
|
|
|
|
# Expose the shape hints on the `Dim` factory, so that specs can say
# `Dim.AUTO` / `Dim.STATIC`.
Dim.AUTO, Dim.STATIC = _DimHint.AUTO, _DimHint.STATIC  # type: ignore[attr-defined]
|
|
|
|
|
|
def dims(*names: str, min: Optional[int] = None, max: Optional[int] = None):
    """
    Create several :func:`Dim` types at once.

    Each name in ``names`` gets its own Dim, and all of them share the same
    ``min``/``max`` bounds.

    Returns:
        A tuple of Dims, in the same order as ``names``.
    """
    created = []
    for dim_name in names:
        created.append(Dim(dim_name, min=min, max=max))
    return tuple(created)
|
|
|
|
|
|
@dataclasses.dataclass
class _ConstraintTarget:
    """
    Identifies a single dimension of a single input tensor.
    """

    # identifier of the input tensor (presumably id(tensor) — confirm at call sites)
    t_id: int
    # index of the dimension within that tensor's shape
    dim: int
|
|
|
|
|
|
@dataclasses.dataclass
class _Constraint(_ConstraintTarget):
    """
    A regular (non-derived) Dim attached to a constraint target.

    In addition to ``t_id``/``dim``:
      - ``name`` is the name of the Dim;
      - ``constraint_range`` holds the Dim's min/max bounds.

    The comparison operators return a copy whose range is intersected with
    the given bound, e.g. ``c >= 4`` or ``c < 16``.
    """

    name: str
    constraint_range: "StrictMinMaxConstraint"

    def _clone_with_range(self, lower=0, upper=None):
        """Copy of this constraint with its range intersected with [lower, upper]."""
        # Imported locally (sympy-backed helpers)
        from torch.fx.experimental.symbolic_shapes import StrictMinMaxConstraint
        from torch.utils._sympy.numbers import int_oo
        from torch.utils._sympy.value_ranges import ValueRanges

        upper = int_oo if upper is None else upper
        narrowed = StrictMinMaxConstraint(
            vr=self.constraint_range.vr & ValueRanges(lower=lower, upper=upper),
            warn_only=False,
        )
        return _Constraint(self.t_id, self.dim, self.name, narrowed)

    def __ge__(self, lower):
        return self._clone_with_range(lower=lower)

    def __gt__(self, lower):
        # strict lower bound on integers: x > k  <=>  x >= k + 1
        return self.__ge__(lower + 1)

    def __le__(self, upper):
        return self._clone_with_range(upper=upper)

    def __lt__(self, upper):
        # strict upper bound on integers: x < k  <=>  x <= k - 1
        return self.__le__(upper - 1)

    def __bool__(self):
        # NOTE(avik): Compound expressions like a <= x <= b are unsupported.
        # Python desugars them into bool(a <= x) and bool(x <= b), and any
        # __bool__ override must return True or False. (sympy also raises
        # TypeError in this situation.)
        raise TypeError(
            "Cannot determine truth value of _Constraint. "
            "If you are trying to combine _Constraint's with logical connectives, "
            "you can specify them separately instead."
        )

    @property
    def serializable_spec(self):
        # A serialization-friendly form of the constraint. It can be saved in
        # the graph module without breaking module serialization, and it is
        # used directly by the post-export pass that turns constraints into
        # runtime assertions. It is not stored in the serialized module.
        # TODO: find a better approach; mapping constraints by 't_id' is not reliable.
        vr = self.constraint_range.vr
        return {"t_id": self.t_id, "dim": self.dim, "min": vr.lower, "max": vr.upper}
|
|
|
|
|
|
@dataclasses.dataclass
class _PhantomRoot:
    """
    Root of a derived Dim that does not itself describe any input dimension,
    although the Dims derived from it do.

    For example, the input shapes ``2*dim`` and ``dim + 1`` are related only
    through a "phantom" ``dim``.

    ``name``, ``constraint_range`` and ``val`` are what is needed to create a
    symbol for the root. Derived Dims with this root are backed by
    expressions over that symbol.
    """

    name: str
    constraint_range: "StrictMinMaxConstraint"
    val: int
|
|
|
|
|
|
@dataclasses.dataclass
class _DerivedConstraint(_ConstraintTarget):
    """
    A derived Dim attached to a constraint target.

    Its ``root`` is one of two things:
      - a regular constraint target, which directly specifies the shape of
        some input dimension; or
      - a phantom root, which specifies it only indirectly.

    Conceptually this is like a `_Constraint`, except that <, <=, >, >= are
    not supported. ``fn`` maps the root's value to this Dim's value.
    """

    name: str
    constraint_range: "StrictMinMaxConstraint"
    root: Union[_ConstraintTarget, _PhantomRoot]
    fn: Callable

    @property
    def serializable_spec(self):
        # same layout as _Constraint.serializable_spec
        vr = self.constraint_range.vr
        return {"t_id": self.t_id, "dim": self.dim, "min": vr.lower, "max": vr.upper}
|
|
|
|
|
|
# A dimension constraint is either a regular Dim constraint (`_Constraint`)
# or a constraint derived from a root via a function (`_DerivedConstraint`).
Constraint = Union[_Constraint, _DerivedConstraint]
|
|
|
|
|
|
def _process_equalities(
    constraint: Constraint,
    get_sources: Callable[[int, int], List["Source"]],
    shape_env: "ShapeEnv",
    names: Dict[str, Tuple[int, int]],
    source_pairs: List[Tuple["Source", "Source"]],
    derived_equalities: List[Tuple["Source", Union["Source", "Symbol"], Callable]],
    phantom_symbols: Dict[str, "Symbol"],
):
    """
    Record the equalities implied by ``constraint``.

    The accumulators below are mutated in place; they later become the
    fields of an ``EqualityConstraint``:

    - ``source_pairs``: pairs of sources that must be equal;
    - ``derived_equalities``: ``(source, root, fn)`` triples, meaning
      ``source == fn(root)``;
    - ``phantom_symbols``: symbols created for phantom roots, keyed by name.

    ``names`` maps each regular Dim name to the (t_id, dim) where it was first
    seen, so that later uses of the same name can be tied back to it.
    """
    sources = get_sources(constraint.t_id, constraint.dim)
    if not sources:
        # empty sources due to unused shapes: nothing to relate
        return

    primary, *aliases = sources
    # When t.size()[dim] maps to src0, src1, ..., srcN, make src0 "equal" to the rest.
    source_pairs.extend((primary, alias) for alias in aliases)

    if not isinstance(constraint, _DerivedConstraint):
        # Regular Dim: remember where its name first appeared; any later
        # appearance is tied to all sources of that first one.
        if constraint.name not in names:
            names[constraint.name] = (constraint.t_id, constraint.dim)
            return
        shared_t_id, shared_dim = names[constraint.name]
        source_pairs.extend(
            (primary, other) for other in get_sources(shared_t_id, shared_dim)
        )
        return

    root_spec = constraint.root
    if not isinstance(root_spec, _PhantomRoot):
        # the root describes an input dimension: use its first source
        root = get_sources(root_spec.t_id, root_spec.dim)[0]  # type: ignore[assignment]
    elif root_spec.name in phantom_symbols:
        # the phantom symbol was already created for an earlier derived Dim
        root = phantom_symbols[root_spec.name]  # type: ignore[assignment]
    else:
        # first use of this phantom root: create a symbol for it in the shape env
        root = shape_env.create_symbol(
            val=root_spec.val,
            source=torch._dynamo.source.ConstantSource(root_spec.name),
            dynamic_dim=torch.fx.experimental.symbolic_shapes.DimDynamic.DYNAMIC,
            constraint_dim=root_spec.constraint_range,
        )
        phantom_symbols[root_spec.name] = root  # type: ignore[assignment]

    # (source, root, fn) informally means source = fn(root); root is either
    # another input's source or a phantom symbol.
    derived_equalities.append((primary, root, constraint.fn))
|
|
|
|
|
|
def _tree_map_with_path(
|
|
func: Callable[..., Any],
|
|
tree: Any,
|
|
*dynamic_shapes: Any,
|
|
tree_name: Optional[str] = None,
|
|
) -> Any:
|
|
"""
|
|
Customized tree_map for mapping pytrees to dynamic_shapes.
|
|
|
|
For built-in types (e.g., standard collections) this behaves exactly like tree_map.
|
|
|
|
OTOH for a user-defined class C registered with pytree, we cannot assume that a C
|
|
containing tensors can be mapped to a C containing dynamic shapes (i.e., C may not
|
|
be a polymorphic container). In that case we use the flattened form of C instead.
|
|
Thus a C(**tensors) that flattens to (**tensors) will map to (**dynamic_shapes).
|
|
|
|
Args:
|
|
func: function to apply to each (int, float, str, bool, None, torch.Tensor)
|
|
tree: input pytree
|
|
dynamic_shapes: zero or more (typically one) dynamic_shapes to match
|
|
|
|
Returns:
|
|
output pytree mapping func to each (int, float, str, bool, None, torch.Tensor)
|
|
"""
|
|
|
|
def is_leaf(t):
|
|
# BUILTIN_TYPES is a subset of SUPPORTED_NODES, the latter being all types
|
|
# registered with pytree. Types *not* in BUILTIN_TYPES include primitive types
|
|
# (int, float, str, bool, None, torch.Tensor), which are not in SUPPORTED_NODES,
|
|
# as well as user-defined classes registered with pytree, which are.
|
|
return _get_node_type(t) not in BUILTIN_TYPES
|
|
|
|
def f(path, t, *dynamic_shapes):
|
|
typ = _get_node_type(t)
|
|
# typ is not in BUILTIN_TYPES
|
|
if typ in SUPPORTED_NODES:
|
|
# thus typ is a user-defined class registered with pytree,
|
|
# in which case flatten and recurse
|
|
return tree_map_with_path(
|
|
f,
|
|
SUPPORTED_NODES[typ].flatten_fn(t)[0],
|
|
*dynamic_shapes,
|
|
is_leaf=is_leaf,
|
|
)
|
|
else:
|
|
return func(path, t, *dynamic_shapes)
|
|
|
|
try:
|
|
return tree_map_with_path(f, tree, *dynamic_shapes, is_leaf=is_leaf)
|
|
except ValueError as e:
|
|
if "mismatch" in e.args[0]:
|
|
# When PyTree finds a structural mismatch between tree and dynamic_shapes,
|
|
# the error message is unfortunately quite horrible. Let's fix that.
|
|
assert dynamic_shapes, "Cannot be a mismatch if there is no dynamic_shapes"
|
|
assert tree_name, "Must provide a tree_name when there might be a mismatch"
|
|
|
|
def _key(type_, context, i):
|
|
# derive a PyTree key given the type, context, and child # of a TreeSpec
|
|
if type_ is dict:
|
|
return MappingKey(context[i])
|
|
if type_ in (list, tuple):
|
|
assert context is None
|
|
return SequenceKey(i)
|
|
raise AssertionError(f"Did not expect type {type_}")
|
|
|
|
def raise_mismatch_error(msg):
|
|
from torch._dynamo.exc import UserError, UserErrorType
|
|
|
|
raise UserError(
|
|
UserErrorType.INVALID_INPUT,
|
|
f"Detected mismatch between the structure of `{tree_name}` and `dynamic_shapes`: {msg}",
|
|
case_name="dynamic_shapes_validation",
|
|
)
|
|
|
|
def _compare(tree, dynamic_shapes, path):
|
|
# raise an error at the point where tree and dynamic_shapes differ,
|
|
# including the path to that point and the reason for the difference
|
|
rendered_path = keystr(path)
|
|
if isinstance(tree, LeafSpec):
|
|
return
|
|
if isinstance(dynamic_shapes, LeafSpec):
|
|
raise_mismatch_error(
|
|
f"`{tree_name}{rendered_path}` is a {tree.type}, "
|
|
f"but `dynamic_shapes{rendered_path}` is not"
|
|
)
|
|
if tree.type != dynamic_shapes.type:
|
|
raise_mismatch_error(
|
|
f"`{tree_name}{rendered_path}` is a {tree.type}, "
|
|
f"but `dynamic_shapes{rendered_path}` is a {dynamic_shapes.type}"
|
|
)
|
|
if len(tree.children_specs) != len(dynamic_shapes.children_specs):
|
|
raise_mismatch_error(
|
|
f"`{tree_name}{rendered_path}` has {len(tree.children_specs)} elements, "
|
|
f"but `dynamic_shapes{rendered_path}` has {len(dynamic_shapes.children_specs)} elements"
|
|
)
|
|
if tree.type is dict:
|
|
# context, children could be out of order
|
|
if sorted(tree.context) != sorted(dynamic_shapes.context):
|
|
raise_mismatch_error(
|
|
f"`{tree_name}{rendered_path}` has keys {tree.context}, "
|
|
f"but `dynamic_shapes{rendered_path}` has keys {dynamic_shapes.context}"
|
|
)
|
|
_remap = dict(
|
|
zip(dynamic_shapes.context, dynamic_shapes.children_specs)
|
|
)
|
|
dynamic_shapes_children_specs = [_remap[k] for k in tree.context]
|
|
else:
|
|
dynamic_shapes_children_specs = dynamic_shapes.children_specs
|
|
for i, (tree_, dynamic_shapes_) in enumerate(
|
|
zip(tree.children_specs, dynamic_shapes_children_specs)
|
|
):
|
|
_compare(
|
|
tree_,
|
|
dynamic_shapes_,
|
|
path + [_key(tree.type, tree.context, i)],
|
|
)
|
|
|
|
_, tree_spec = tree_flatten(tree, is_leaf=is_leaf)
|
|
for other_tree in dynamic_shapes:
|
|
_, other_tree_spec = tree_flatten(other_tree, is_leaf)
|
|
_compare(tree_spec, other_tree_spec, [])
|
|
raise
|
|
|
|
|
|
def _combine_args(f, args, kwargs, _is_torch_jit_trace=False) -> Dict[str, Any]:
|
|
# combine args and kwargs following the signature of f, as it happens
|
|
# in the body of f when called with *args, **kwargs
|
|
if isinstance(f, ExportedProgram):
|
|
f = f.module()
|
|
if not _is_torch_jit_trace:
|
|
signature = (
|
|
inspect.signature(f.forward)
|
|
if isinstance(f, torch.nn.Module)
|
|
else inspect.signature(f)
|
|
)
|
|
kwargs = kwargs if kwargs is not None else {}
|
|
return signature.bind(*args, **kwargs).arguments
|
|
return args
|
|
|
|
|
|
class ShapesCollection:
    """
    Builder for dynamic_shapes.
    Used to assign dynamic shape specifications to tensors that appear in inputs.

    Example::
        args = ({"x": tensor_x, "others": [tensor_y, tensor_z]},)

        dim = torch.export.Dim(...)
        dynamic_shapes = torch.export.ShapesCollection()
        dynamic_shapes[tensor_x] = (dim, dim + 1, 8)
        dynamic_shapes[tensor_y] = {0: dim * 2}
        # For the dict argument above, this is equivalent to (now auto-generated):
        # {"x": (dim, dim + 1, 8), "others": [{0: dim * 2}, None]}

        torch.export.export(..., args, dynamic_shapes=dynamic_shapes)

    Tensors are matched by identity (``id``). The collection keeps a reference
    to every tensor it has been given a shape for. Otherwise a freed tensor's
    id could be recycled by a different tensor, which would then silently
    pick up the wrong shape specification.
    """

    def __init__(self):
        # id(tensor) -> shape specification
        self._shapes = {}
        # id(tensor) -> tensor; keeps the tensors alive so that their ids stay unique
        self._tensors = {}

    def __setitem__(self, t, shape):
        """
        Assign ``shape`` to tensor ``t``.

        Raises:
            AssertionError: if ``t`` is not a tensor, or if ``t`` already has
                a different shape assigned.
        """
        assert isinstance(
            t, torch.Tensor
        ), f"Cannot assign shape to non-tensor type {type(t)}"
        # TODO(avik): check that shape is indeed a Shape
        t_id = id(t)
        if t_id in self._shapes:
            _shape = self._shapes[t_id]
            assert (
                shape == _shape
            ), f"Shapes assigned to tensor do not match: expected {_shape}, got {shape}"
        else:
            self._shapes[t_id] = shape
            self._tensors[t_id] = t

    def __getitem__(self, t):
        """Shape assigned to tensor ``t``, or None if there is none."""
        return self._shapes.get(id(t))

    def __len__(self):
        return len(self._shapes)

    def dynamic_shapes(self, m, args, kwargs=None):
        """
        Generate dynamic_shapes.

        Walks the inputs of ``m`` (bound from ``args``/``kwargs``) and, for
        every tensor, takes the assigned shape or None.

        Raises:
            ValueError: if some tensor with an assigned shape was not found
                among the inputs.
        """
        t_ids = set()

        def find_shape(path, t):
            t_id = id(t)
            if t_id in self._shapes:
                t_ids.add(t_id)
                return self._shapes[t_id]
            else:
                return None

        combined_args = _combine_args(m, args, kwargs)
        dynamic_shapes = _tree_map_with_path(find_shape, combined_args)
        if any(t_id not in t_ids for t_id in self._shapes):
            raise ValueError(
                "Some tensors that were assigned shapes were not found in args. "
                "Maybe such tensors were copied when passing them as args? "
                "Maybe such tensors are contained in classes that were not registered with pytree?"
            )
        return dynamic_shapes
|
|
|
|
|
|
def _warn_on_None_dynamic_shape_dimension():
    """Log a deprecation warning for ``None`` used as a dimension in a shape spec."""
    # TODO(avik): raise an error in the future
    log.warning(
        "Using None as a dynamic shape dimension is deprecated. "
        "Please use Dim.STATIC instead"
    )
|
|
|
|
|
|
def _check_dynamic_shapes(
    combined_args: Dict[str, Any],
    dynamic_shapes: Union[Dict[str, Any], Tuple[Any], List[Any], None],
):
    """
    Checks the dynamic_shapes specification for correctness,
    using combined args + kwargs as reference for inputs structure.

    Args:
        combined_args: the inputs bound to parameter names (see ``_combine_args``).
        dynamic_shapes: either a dict keyed by those parameter names, or a
            list/tuple matching the inputs positionally. ``None`` or empty
            means there is nothing to check.

    Raises:
        UserError: (INVALID_INPUT) in any of these cases:
            - the structure of the spec does not match the inputs;
            - a dimension has an unexpected type;
            - a list/tuple spec's length differs from its tensor's rank;
            - the same Dim name is given different bounds;
            - ``Dim.AUTO`` is mixed with ``Dim``/``DerivedDim``.
    """
    from torch._dynamo.exc import UserError, UserErrorType
    from torch._export.non_strict_utils import _flatten_dynamic_shapes

    if dynamic_shapes is None or len(dynamic_shapes) == 0:
        return
    if isinstance(dynamic_shapes, (tuple, list)):
        # positional spec: compare against the input values in order
        combined_args = type(dynamic_shapes)(combined_args.values())  # type: ignore[assignment, misc]

    # Dim name -> (min, max) of its first occurrence
    bounds: Dict[str, Tuple[int, int]] = {}

    def check_same_bounds(dim):
        if dim.__name__ in bounds:
            min_, max_ = bounds[dim.__name__]
            if dim.min != min_ or dim.max != max_:
                this_ = _Dim.readable(dim.__name__, min_, max_)
                that_ = _Dim.readable(dim.__name__, dim.min, dim.max)
                raise UserError(
                    UserErrorType.INVALID_INPUT,
                    f"Found different definitions {this_} and {that_} "
                    f"for the same symbolic dimension {dim}!",
                )
        else:
            bounds[dim.__name__] = (dim.min, dim.max)

    def check_symbols(path, tensor, shape):
        if isinstance(shape, dict):
            for i, dim in shape.items():
                if isinstance(dim, _Dim):
                    check_same_bounds(dim)
                elif dim is None:
                    _warn_on_None_dynamic_shape_dimension()
                elif not (isinstance(dim, (int, _DimHint))):
                    raise UserError(
                        UserErrorType.INVALID_INPUT,
                        f"Unexpected dimension mapped to index {i} in input tensor shape {shape} "
                        f"specified at `dynamic_shapes{keystr(path)}` "
                        f"(expected None, an int, a Dim, Dim.AUTO, or Dim.STATIC, but got {dim} instead)",
                        case_name="dynamic_shapes_validation",
                    )
        elif isinstance(shape, (tuple, list)):
            # A positional spec must describe every dimension of the tensor.
            # Otherwise a shorter spec crashes later with an IndexError, and a
            # longer one has its extra entries silently dropped.
            if len(shape) != len(tensor.shape):
                raise UserError(
                    UserErrorType.INVALID_INPUT,
                    f"Expected dynamic shape spec {shape} specified at `dynamic_shapes{keystr(path)}` "
                    f"to have the same length as the actual tensor shape {tensor.shape} "
                    f"(expected {len(tensor.shape)}, but got {len(shape)} instead)",
                    case_name="dynamic_shapes_validation",
                )
            for i, dim in enumerate(shape):
                if isinstance(dim, _Dim):
                    check_same_bounds(dim)
                elif dim is None:
                    _warn_on_None_dynamic_shape_dimension()
                elif not (isinstance(dim, (int, _DimHint))):
                    raise UserError(
                        UserErrorType.INVALID_INPUT,
                        f"Unexpected dimension #{i} in input tensor shape {shape} "
                        f"specified at `dynamic_shapes{keystr(path)}` "
                        f"(expected None, an int, a Dim, Dim.AUTO, or Dim.STATIC, but got {dim} instead)",
                        case_name="dynamic_shapes_validation",
                    )
        elif shape is not None:
            raise UserError(
                UserErrorType.INVALID_INPUT,
                f"Unexpected input tensor shape {shape} specified at `dynamic_shapes{keystr(path)}` "
                f"(expected either a list/tuple of dimensions, or a dict mapping indices to dimensions,"
                f" where each dimension is an int, a Dim, Dim.AUTO, or Dim.STATIC)",
                case_name="dynamic_shapes_validation",
            )

    assert isinstance(dynamic_shapes, (dict, tuple, list))
    if isinstance(dynamic_shapes, dict):
        got_keys = list(dynamic_shapes.keys())
        expected_arg_names = list(combined_args.keys())
        if sorted(got_keys) != sorted(expected_arg_names):
            msg = (
                f"When `dynamic_shapes` is specified as a dict, its top-level keys "
                f"must be the arg names {expected_arg_names} of `inputs`, but "
                f"here they are {got_keys}. "
            )
            if (
                len(combined_args) == 1
                and expected_arg_names[0] not in got_keys
                and isinstance(combined_args[expected_arg_names[0]], dict)
            ):
                msg += (
                    "Since here `inputs` is a list/tuple enclosing a single dict, "
                    "maybe you just forgot to enclose `dynamic_shapes` in a list/tuple?"
                )
            else:
                msg += (
                    "Alternatively, you could also ignore arg names entirely "
                    "and specify `dynamic_shapes` as a list/tuple matching `inputs`."
                )
            raise UserError(
                UserErrorType.INVALID_INPUT, msg, case_name="dynamic_shapes_validation"
            )

    def check_shape(path, t, dynamic_shape):
        if isinstance(t, torch.Tensor):
            check_symbols(path, t, dynamic_shape)
        else:
            if dynamic_shape is not None:
                rendered_path = keystr(path)
                raise UserError(
                    UserErrorType.INVALID_INPUT,
                    f"Cannot associate shape {dynamic_shape} specified at `dynamic_shapes{rendered_path}` "
                    f"to non-tensor type {type(t)} at `inputs{rendered_path}` (expected None)",
                    case_name="dynamic_shapes_validation",
                )

    _tree_map_with_path(check_shape, combined_args, dynamic_shapes, tree_name="inputs")

    # raise user warning if both Dim.AUTO & Dims are specified in dynamic_shapes
    flat_dynamic_shapes = _flatten_dynamic_shapes(combined_args, dynamic_shapes)
    flatter_dynamic_shapes, _ = tree_flatten(flat_dynamic_shapes)
    if any(isinstance(s, _Dim) for s in flatter_dynamic_shapes) and any(
        s == _DimHint.AUTO for s in flatter_dynamic_shapes
    ):
        raise UserError(
            UserErrorType.INVALID_INPUT,
            "Specifying both `Dim.AUTO` and `Dim` or `DerivedDim` in `dynamic_shapes` is not well supported at the moment, "
            "and can easily lead to constraint violation errors or obscure errors in torch.export. Dim/DerivedDims "
            "expect all equal or related dimensions to be specified, and does not yet compose well with `Dim.AUTO`. "
            "We suggest using `Dim.AUTO` mixed with `None` for auto-dynamic + static shapes, plus torch._check(dim >= min), "
            "torch._check(dim <= max) calls in your program to specify min/max ranges, or `Dim`/`DerivedDim` mixed with `None` "
            "if you want to assert on the exact specification of your program's dynamic shapes behavior.",
            case_name="dynamic_shapes_validation",
        )
|
|
|
|
|
|
def _transform_shapes_for_default_dynamic(
    combined_args: Dict[str, Any],
    dynamic_shapes: Union[Dict[str, Any], Tuple[Any], List[Any], None],
) -> Union[Dict[str, Any], Tuple[Any], List[Any], None]:
    """
    This may not be needed in the long run. It exists because export.export()
    and _dynamo.export() historically use different semantics for
    dynamic_shapes, yet both go through the same process of producing
    constraints, and both now use assume_static_by_default=False.

    For _dynamo.export(), the semantics for dynamic_shapes are:
    - None: dynamic, allocated a symbol
    - Dim/DerivedDim: a strict assertion on the min/max range for this symbol. A
      specification is required for every dim governed by this symbol
      (i.e. relations, equality, linear relations, etc.)

    For export.export(), dynamism for unspecified dims has historically been
    undesirable, so the semantics are:
    - Dim.AUTO: dynamic, allocated a symbol
    - None/unspecified/Dim.STATIC: static
    - Dim/DerivedDims: also a strict assertion

    So that both APIs can share the constraint-producing process, this function
    converts export.export() dynamic_shapes into a form compatible with
    _process_dynamic_shapes() and assume_static_by_default=False. The result is
    essentially what the spec would look like for _dynamo.export().

    Side effect: for every Dim.AUTO dimension, the input tensor itself is marked
    with torch._dynamo.maybe_mark_dynamic(tensor, i).

    An example conversion might look like, for a 3-d input tensor:

        input spec: {
            0: Dim.AUTO,
            1: None,  # or Dim.STATIC
            2: Dim("dx"),
        }
        output spec: {
            0: None,  # None: dynamic by default
            1: 32,  # explicitly provide static shape
            2: Dim("dx"),  # remains the same
        }

    NOTE: in the dict form, an AUTO index is actually omitted from the output
    rather than mapped to None. That is equivalent under
    assume_static_by_default=False. In the list/tuple form, the AUTO entry
    becomes None.
    """

    def _tree_map_helper(tree, val):
        """
        If the user specifies dynamic_shapes=None for a pytree input, convert it
        into a tree of Nones following the input spec. That way static dims
        can be specified explicitly for every tensor dimension.

        Non-builtin pytree types (e.g. custom dataclasses) are a complication:
        for them, the correct format is a list containing a spec for each
        child attribute.
        """
        if (node_type := _get_node_type(tree)) not in SUPPORTED_NODES:  # is_leaf
            return val
        flatten_fn = SUPPORTED_NODES[node_type].flatten_fn
        child_pytrees, context = flatten_fn(tree)  # flatten from whatever original type
        unflatten_fn = SUPPORTED_NODES[
            node_type if node_type in BUILTIN_TYPES else list
        ].unflatten_fn
        children = [_tree_map_helper(child, val) for child in child_pytrees]
        return unflatten_fn(
            children, context
        )  # unflatten into original type, or list if not built-in type

    if (
        dynamic_shapes is None or len(dynamic_shapes) == 0
    ):  # create pytree structure of static dim
        dynamic_shapes = _tree_map_helper(combined_args, None)
    if isinstance(dynamic_shapes, (tuple, list)):
        # positional spec: match it against the input values in order
        combined_args = type(dynamic_shapes)(combined_args.values())  # type: ignore[assignment, misc]

    def transform_shapes(path, tensor, shape):
        # Convert one tensor's export.export()-style spec into the
        # _dynamo.export()-style spec described above.
        out: Union[None, List[Any], Dict[int, Any]] = None
        if isinstance(shape, dict):
            out = {}
            for i, val in enumerate(tensor.shape):
                # indices missing from the dict default to static
                dim = shape.get(i, _DimHint.STATIC)
                if dim == _DimHint.AUTO:
                    # don't have to specify anything if dynamic
                    # None also works, since assume_static_by_default=False
                    torch._dynamo.maybe_mark_dynamic(tensor, i)  # avoid duck sizing
                elif isinstance(dim, _Dim):
                    out[i] = dim
                elif isinstance(dim, int):
                    # important that this is dim and not val,
                    # so we can raise error if user-specified dim != val
                    out[i] = dim
                elif dim is None:
                    # deprecated spelling of static: warn, then pin to the example size
                    _warn_on_None_dynamic_shape_dimension()
                    out[i] = val
                else:
                    # make explicitly static
                    assert dim == _DimHint.STATIC
                    out[i] = val
        elif isinstance(shape, (tuple, list)):
            out = []
            # NOTE(review): assumes len(shape) >= tensor.ndim; a shorter spec
            # raises IndexError here unless validated beforehand
            for i, val in enumerate(tensor.shape):
                dim = shape[i]
                if dim == _DimHint.AUTO:
                    torch._dynamo.maybe_mark_dynamic(tensor, i)  # avoid duck sizing
                    out.append(None)
                elif isinstance(dim, _Dim):
                    out.append(dim)
                elif isinstance(dim, int):
                    out.append(dim)
                elif dim is None:
                    _warn_on_None_dynamic_shape_dimension()
                    out.append(val)
                else:
                    assert dim == _DimHint.STATIC
                    out.append(val)
            # preserve the container type (tuple vs list) of the user's spec
            out = type(shape)(out)  # type: ignore[assignment]
        else:
            assert shape is None
            if isinstance(tensor, torch.Tensor):
                # no spec: every dim static (None for a 0-d tensor's empty shape)
                out = list(tensor.shape) or None
            else:
                out = None
        return out

    def transform_shape(path, t, dynamic_shape):
        # non-tensor leaves implicitly map to None
        if isinstance(t, torch.Tensor):
            return transform_shapes(path, t, dynamic_shape)

    result = _tree_map_with_path(
        transform_shape, combined_args, dynamic_shapes, tree_name="inputs"
    )
    return result
|
|
|
|
|
|
def _process_dynamic_shapes(
    combined_args: Dict[str, Any],
    dynamic_shapes: Union[Dict[str, Any], Tuple[Any], List[Any], None],
) -> List[Constraint]:
    """
    Turn a dynamic_shapes specification into the list of constraints it implies.

    For every tensor input, each dimension specified as an ``int`` or a ``Dim``
    yields exactly one constraint:

    - ``int``: wrapped in a ``_StaticDim`` and pinned to the range [value, value];
    - derived ``Dim``: a ``_DerivedConstraint`` whose root is either an input
      dimension carrying the root's name, or a ``_PhantomRoot`` when no input
      dimension carries it;
    - any other ``Dim``: a ``_Constraint`` bounded by the Dim's [min, max].

    Any other spec entry (e.g. ``None``) and any non-tensor input produce nothing.

    Args:
        combined_args: input names mapped to input values.
        dynamic_shapes: the spec, either keyed like ``combined_args`` or, as a
            tuple/list, matched positionally against ``combined_args.values()``.

    Returns:
        All constraints, grouped by Dim name in the order each name was first
        seen. A ``None`` or empty spec returns ``[]``.

    Raises:
        UserError: if a derived dim's observed size cannot be written as
            ``fn(root)`` for an integer root value.
    """
    from torch._dynamo.exc import UserError, UserErrorType

    if dynamic_shapes is None or len(dynamic_shapes) == 0:
        # Inputs are dynamic by default, so there is nothing to constrain.
        return []
    if isinstance(dynamic_shapes, (tuple, list)):
        # Positional spec: rebuild the args as the same sequence type so both
        # trees line up during the traversal below.
        combined_args = type(dynamic_shapes)(combined_args.values())  # type: ignore[assignment, misc]

    # Dim name -> constraints on the input dimensions carrying that name.
    symbols: Dict[str, List[Constraint]] = defaultdict(list)
    # Root name -> placeholder for derived-dim roots that (so far) do not
    # correspond to any input dimension.
    phantom_roots: Dict[str, _PhantomRoot] = {}
    # Derived constraints currently attached to a phantom root; revisited once
    # every input has been seen.
    phantom_rooted: List[_DerivedConstraint] = []

    def to_constraint(dim, tensor, i):
        import sympy

        from torch.fx.experimental.symbolic_shapes import StrictMinMaxConstraint
        from torch.utils._sympy.solve import try_solve
        from torch.utils._sympy.value_ranges import ValueRanges

        def strict_range(lower, upper):
            return StrictMinMaxConstraint(
                vr=ValueRanges(lower=lower, upper=upper), warn_only=False
            )

        def solve_root():
            # tensor.shape[i] equals dim.fn(root); solve that equation for root.
            symbol = sympy.Symbol(dim.root.__name__, integer=True)
            expr = dim.fn(symbol)
            solution = try_solve(sympy.Eq(expr, tensor.shape[i]), symbol)
            if solution is None:
                raise UserError(
                    UserErrorType.CONSTRAINT_VIOLATION,
                    f"Expected shape[{i}] = {tensor.shape[i]} of input Tensor to be "
                    f"of the form {expr}, where {symbol} is an integer",
                )
            return int(solution[1])  # type: ignore[call-overload]

        if not isinstance(dim, _DerivedDim):
            if isinstance(dim, _StaticDim):
                lower = upper = dim.value  # type: ignore[attr-defined]
            else:
                lower, upper = dim.min, dim.max  # type: ignore[attr-defined]
            return _Constraint(  # type: ignore[assignment]
                id(tensor), i, dim.__name__, strict_range(lower, upper)
            )

        # Derived dim: its root is either a _ConstraintTarget (the root names an
        # input dimension already seen) or a _PhantomRoot (it does not, yet).
        root_dim = dim.root  # type: ignore[attr-defined]
        root_name = root_dim.__name__
        if root_name in symbols:
            seen = symbols[root_name][0]
            root = _ConstraintTarget(seen.t_id, seen.dim)
        elif root_name in phantom_roots:
            root = phantom_roots[root_name]  # type: ignore[assignment]
        else:
            root = _PhantomRoot(  # type: ignore[assignment]
                name=root_name,
                constraint_range=strict_range(root_dim.min, root_dim.max),
                val=solve_root(),
            )
            phantom_roots[root_name] = root  # type: ignore[assignment]

        derived_constraint = _DerivedConstraint(
            id(tensor),
            i,
            dim.__name__,
            strict_range(dim.min, dim.max),
            root,
            dim.fn,  # type: ignore[attr-defined]
        )
        if isinstance(root, _PhantomRoot):
            # NOTE(avik): since we have not processed all inputs yet, we may replace
            # this with a root that does represent an input shape dimension later.
            phantom_rooted.append(derived_constraint)
        return derived_constraint

    def collect(path, t, dynamic_shape):
        # Only tensor inputs carry per-dimension specs.
        if not isinstance(t, torch.Tensor):
            return
        if isinstance(dynamic_shape, dict):
            indexed = dynamic_shape.items()
        elif isinstance(dynamic_shape, (tuple, list)):
            indexed = enumerate(dynamic_shape)
        else:
            return
        for i, dim in indexed:
            if isinstance(dim, int):
                dim = _StaticDim(str(dim), (int,), {"value": dim})
            elif not isinstance(dim, _Dim):
                continue
            # Build the constraint before registering its name, so root lookups
            # inside to_constraint only see previously processed dimensions.
            constraint = to_constraint(dim, t, i)
            symbols[dim.__name__].append(constraint)

    _tree_map_with_path(collect, combined_args, dynamic_shapes, tree_name="inputs")

    for pending in phantom_rooted:
        phantom_name = pending.root.name  # type: ignore[union-attr]
        if phantom_name in symbols:
            # An input dimension carries this name after all, so the phantom is
            # unnecessary. NOTE(avik): roots that stay phantom must be truly
            # phantom, i.e. not representable by any input source; this lets
            # derived-equality decisions focus exclusively on input sources.
            pending.root = symbols[phantom_name][0]

    return [c for group in symbols.values() for c in group]  # type: ignore[return-value]
|
|
|
|
|
|
def _get_dim_name_mapping(
    dynamic_shapes: Union[Dict[str, Any], Tuple[Any], List[Any], None]
):
    """
    Collect every named Dim in a dynamic_shapes spec, keyed by its name.

    Derived dims contribute both themselves and their root Dim. Leaves without a
    name are skipped: ``None``, plain ``int`` sizes, and dim hints (``_DimHint``
    values such as ``Dim.AUTO`` / ``Dim.STATIC``).

    Returns:
        Dict mapping ``Dim.__name__`` to the Dim object.
    """
    name_to_dim = {}
    for dim in tree_flatten(
        dynamic_shapes,
        is_leaf=lambda x: isinstance(x, _Dim),
    )[0]:
        if dim is None or isinstance(dim, (int, _DimHint)):
            # None denotes a non-Tensor input or an automatic dim, an int is a
            # static size, and a _DimHint is an automatic/static hint: none of
            # these carries a Dim name.
            continue
        assert isinstance(dim, _Dim), f"unexpected entry in dynamic_shapes: {dim!r}"
        name_to_dim[dim.__name__] = dim
        if isinstance(dim, _DerivedDim):
            name_to_dim[dim.root.__name__] = dim.root  # type: ignore[attr-defined]
    return name_to_dim


def refine_dynamic_shapes_from_suggested_fixes(
    msg: str,
    dynamic_shapes: Union[Dict[str, Any], Tuple[Any], List[Any]],
) -> Union[Dict[str, Any], Tuple[Any], List[Any]]:
    """
    For working with export's dynamic shapes suggested fixes, and/or automatic dynamic shapes.
    Refines the given dynamic shapes spec, given a ConstraintViolation error message and the original dynamic shapes.

    ``msg`` must contain a "Suggested fixes:" section; each following non-blank
    line has the form ``name = <fix>``.

    For most cases behavior is straightforward - i.e. for suggested fixes that specialize or refine a Dim's range,
    or fixes that suggest a derived relation, the new dynamic shapes spec will be updated as such.

    e.g.
    Suggested fixes:

        dim = Dim('dim', min=3, max=6) -> this just refines the dim's range
        dim = 4 -> this specializes to a constant
        dy = dx + 1 -> dy was specified as an independent dim, but is actually tied to dx with this relation

    However, suggested fixes associated with derived dims can be more complicated.
    For example, if a suggested fix is provided for a root dim, the new derived dim value is evaluated based on the root.

    e.g.
    dx = Dim('dx')
    dy = dx + 2
    dynamic_shapes = {"x": (dx,), "y": (dy,)}

    Suggested fixes:

        dx = 4  # specialization will lead to dy also specializing = 6
        dx = Dim('dx', max=6)  # dy now has max = 8

    Derived dims suggested fixes can also be used to express divisibility constraints.
    This involves creating new root dims that aren't tied to a particular input shape.
    In this case the root dims won't appear directly in the new spec, but as a root of
    one of the dims.

    e.g.
    Suggested fixes:

        _dx = Dim('_dx', max=1024)  # this won't appear in the return result, but dx will
        dx = 4*_dx  # dx is now divisible by 4, with a max value of 4096

    Spec entries that are not named Dims (``None``, ints, and dim hints such as
    ``Dim.AUTO`` / ``Dim.STATIC``) are returned unchanged.

    Args:
        msg: error message containing the "Suggested fixes:" section.
        dynamic_shapes: the original dynamic_shapes spec.

    Returns:
        The spec with the fixes applied, rebuilt by ``_tree_map_with_path``.

    Raises:
        UserError: if ``msg`` has no "Suggested fixes:" section.
    """

    import re

    import sympy

    from torch._dynamo.exc import UserError, UserErrorType
    from torch.fx.experimental.symbolic_shapes import _is_supported_equivalence

    try:
        shape_fixes_msg = msg.split("Suggested fixes:")[1].strip()
    except Exception as exc:
        raise UserError(
            UserErrorType.INVALID_INPUT,
            "Suggested fixes not found in error message given to refine_dynamic_shapes_from_suggested_fixes()",
        ) from exc

    # Build shape_fixes: name -> Dim (refined range), int (specialization),
    # or sympy expression (derived dim / relation to another dim).
    shape_fixes = {}
    for fix in shape_fixes_msg.split("\n"):
        fix = fix.strip()
        if not fix:
            # A blank line has no " = " to split on; it carries no fix.
            continue
        if match := re.match(r"(.*) = Dim\('(.*)'.*\)", fix):
            name = match.group(1)
            _min, _max = None, None
            if match_min := re.match(r".* = Dim\('.*', min\=([0-9]+).*\)", fix):
                _min = int(match_min.group(1))
            if match_max := re.match(r".* = Dim\('.*'.*max\=([0-9]+)\)", fix):
                _max = int(match_max.group(1))
            shape_fixes[name] = Dim(name, min=_min, max=_max)
        else:
            name, expr = fix.split(" = ")
            expr = sympy.sympify(expr)
            if isinstance(expr, sympy.Number):
                # static, integer
                shape_fixes[name] = int(expr)  # type: ignore[assignment]
            else:
                # relation or derived dim
                shape_fixes[name] = expr

    name_to_dim = _get_dim_name_mapping(dynamic_shapes)

    # Track derived dim roots: names that may be new (not tied to any input shape).
    roots: Set[str] = set()
    for c in shape_fixes.values():
        assert isinstance(c, (int, _Dim, _DerivedDim, sympy.Expr))
        if isinstance(c, sympy.Expr):  # check dim/derived dim expression
            assert _is_supported_equivalence(c)
            roots.add(str(next(iter(c.free_symbols))))
        if isinstance(c, _DerivedDim):
            roots.add(c.root.__name__)  # type: ignore[attr-defined]

    # check keys are existing dims or new roots
    for k in shape_fixes:
        assert k in name_to_dim or k in roots

    # cache so we don't produce multiple derived dim objects
    derived_dim_cache: Dict[str, _DerivedDim] = {}

    def apply_fixes(path, dim, dummy):
        if dim is None or isinstance(dim, (int, _DimHint)):
            # Not a named Dim (unspecified, static, or an automatic/static hint):
            # nothing to fix. _DimHint members also have no __name__ to look up.
            return dim
        if dim.__name__ in shape_fixes:  # directly fix
            fix = shape_fixes[dim.__name__]
            if not isinstance(fix, sympy.Expr):
                # a refined Dim or an int specialization
                return fix
            # now derived or related
            key = str(fix)
            if key in derived_dim_cache:
                return derived_dim_cache[key]
            symbol = next(iter(fix.free_symbols))
            # Locate the root: a fixed name takes priority over the original spec.
            if symbol.name in shape_fixes:  # type: ignore[attr-defined]
                root = shape_fixes[symbol.name]  # type: ignore[attr-defined]
            else:
                assert symbol.name in name_to_dim  # type: ignore[attr-defined]
                root = name_to_dim[symbol.name]  # type: ignore[attr-defined]
            # fix == coeff * symbol + remainder; rebuild it on top of root
            coeff, remainder = sympy.polys.polytools.div(fix, symbol)
            new_dim = root
            if coeff != 1:
                new_dim = int(coeff) * new_dim
            if remainder != 0:
                new_dim = new_dim + int(remainder)
            derived_dim_cache[key] = new_dim
            return new_dim
        if isinstance(dim, _DerivedDim) and dim.root.__name__ in shape_fixes:  # type: ignore[attr-defined]
            if dim.__name__ in derived_dim_cache:
                return derived_dim_cache[dim.__name__]
            # evaluate new derived value based on the fixed root
            _dim = dim.fn(shape_fixes[dim.root.__name__])  # type: ignore[attr-defined]
            derived_dim_cache[dim.__name__] = _dim
            return _dim
        return dim  # unchanged dim

    return _tree_map_with_path(apply_fixes, dynamic_shapes, dynamic_shapes)
|