mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 00:21:07 +01:00
Previous discussion: https://github.com/pytorch/pytorch/pull/109476 In this PR, I made the following additions to the original PR: 1) The unlifted graph module now runs the runtime assertions in its forward call. 2) When we retrace, we run the assertions to make sure the user is tracing the module with inputs that are correct with respect to the assumptions we made during the first tracing. The way I do this is by creating a new graph module type with a modified call method. The runtime assertions happen under torchdynamo.disable so that they are run directly in eager mode. The reason is that we don't want this to be traced as part of the graph. 3) Both ep.module and capture_pre_autograd now return _UnliftedGraphModule. Differential Revision: [D51078056](https://our.internmc.facebook.com/intern/diff/D51078056) Pull Request resolved: https://github.com/pytorch/pytorch/pull/110222 Approved by: https://github.com/zhxchen17
625 lines
23 KiB
Python
625 lines
23 KiB
Python
import copy
|
|
import dataclasses
|
|
from typing import (
|
|
Any,
|
|
Callable,
|
|
Dict,
|
|
Iterator,
|
|
List,
|
|
Optional,
|
|
Tuple,
|
|
Type,
|
|
TYPE_CHECKING,
|
|
Union,
|
|
)
|
|
|
|
if TYPE_CHECKING:
|
|
# Import the following modules during type checking to enable code intelligence features,
|
|
# such as auto-completion in tools like pylance, even when these modules are not explicitly
|
|
# imported in user code.
|
|
|
|
import sympy
|
|
|
|
from torch.utils._sympy.value_ranges import ValueRanges
|
|
|
|
import torch
|
|
import torch.fx._pytree as fx_pytree
|
|
import torch.utils._pytree as pytree
|
|
from torch.fx._compatibility import compatibility
|
|
|
|
from torch.fx.passes.infra.pass_base import PassResult
|
|
from torch.fx.passes.infra.pass_manager import PassManager
|
|
|
|
from .graph_signature import ( # noqa: F401
|
|
_sig_to_specs,
|
|
ArgumentSpec,
|
|
ConstantArgument,
|
|
ExportGraphSignature,
|
|
InputKind,
|
|
InputSpec,
|
|
OutputKind,
|
|
OutputSpec,
|
|
SymIntArgument,
|
|
TensorArgument,
|
|
)
|
|
|
|
|
|
# Names exported by ``from <this module> import *``.
__all__ = [
    "ExportedProgram",
    "ModuleCallEntry",
    "ModuleCallSignature",
]
|
|
|
|
|
|
# Type of a graph pass accepted by ``ExportedProgram._transform``: a callable
# that takes a ``torch.fx.GraphModule`` and returns a ``PassResult`` or ``None``.
PassType = Callable[[torch.fx.GraphModule], Optional[PassResult]]
|
|
|
|
|
|
@dataclasses.dataclass
|
|
class ModuleCallSignature:
|
|
inputs: List[ArgumentSpec]
|
|
outputs: List[ArgumentSpec]
|
|
in_spec: pytree.TreeSpec
|
|
out_spec: pytree.TreeSpec
|
|
|
|
|
|
@dataclasses.dataclass
|
|
class ModuleCallEntry:
|
|
fqn: str
|
|
signature: Optional[ModuleCallSignature] = None
|
|
|
|
|
|
class ExportedProgram:
|
|
"""
|
|
Package of a program from :func:`export`. It contains
|
|
an :class:`torch.fx.Graph` that represents Tensor computation, a state_dict containing
|
|
tensor values of all lifted parameters and buffers, and various metadata.
|
|
|
|
You can call an ExportedProgram like the original callable traced by
|
|
:func:`export` with the same calling convention.
|
|
|
|
To perform transformations on the graph, use ``.module`` property to access
|
|
an :class:`torch.fx.GraphModule`. You can then use
|
|
`FX transformation <https://pytorch.org/docs/stable/fx.html#writing-transformations>`_
|
|
to rewrite the graph. Afterwards, you can simply use :func:`export`
|
|
again to construct a correct ExportedProgram.
|
|
"""
|
|
|
|
def __init__(
|
|
self,
|
|
root: Union[torch.nn.Module, Dict[str, Any]],
|
|
graph: torch.fx.Graph,
|
|
graph_signature: ExportGraphSignature,
|
|
state_dict: Dict[str, Union[torch.Tensor, torch.nn.Parameter]],
|
|
range_constraints: "Dict[sympy.Symbol, Any]",
|
|
equality_constraints: List[Tuple[Any, Any]],
|
|
module_call_graph: List[ModuleCallEntry],
|
|
example_inputs: Optional[Tuple[Tuple[Any, ...], Dict[str, Any]]] = None,
|
|
verifier: Optional[Type[Any]] = None, # TODO Change typing hint to Verifier.
|
|
):
|
|
from torch._export.exported_program import _create_graph_module_for_export
|
|
from torch._export.passes.add_runtime_assertions_for_constraints_pass import (
|
|
InputDim,
|
|
)
|
|
|
|
# Remove codegen related things from the graph. It should just be a flat graph.
|
|
graph._codegen = torch.fx.graph.CodeGen()
|
|
self._graph_module = _create_graph_module_for_export(root, graph)
|
|
if isinstance(root, torch.fx.GraphModule):
|
|
self._graph_module.meta.update(root.meta)
|
|
|
|
self._graph_signature: ExportGraphSignature = graph_signature
|
|
self._state_dict: Dict[str, Any] = state_dict
|
|
self._range_constraints: "Dict[sympy.Symbol, ValueRanges]" = range_constraints
|
|
self._equality_constraints: List[
|
|
Tuple[InputDim, InputDim]
|
|
] = equality_constraints
|
|
self._module_call_graph: List[ModuleCallEntry] = module_call_graph
|
|
self._example_inputs = example_inputs
|
|
|
|
from torch._export.verifier import Verifier
|
|
|
|
if verifier is None:
|
|
verifier = Verifier
|
|
assert issubclass(verifier, Verifier)
|
|
self._verifier = verifier
|
|
# Validate should be always the last step of the constructor.
|
|
self.verifier().check(self)
|
|
|
|
@property
|
|
@compatibility(is_backward_compatible=False)
|
|
def graph_module(self):
|
|
return self._graph_module
|
|
|
|
@property
|
|
@compatibility(is_backward_compatible=False)
|
|
def graph(self):
|
|
return self.graph_module.graph
|
|
|
|
@property
|
|
@compatibility(is_backward_compatible=False)
|
|
def graph_signature(self):
|
|
return self._graph_signature
|
|
|
|
@property
|
|
@compatibility(is_backward_compatible=False)
|
|
def state_dict(self):
|
|
return self._state_dict
|
|
|
|
@compatibility(is_backward_compatible=False)
|
|
def parameters(self) -> Iterator[torch.nn.Parameter]:
|
|
"""
|
|
Returns an iterator over original module's parameters.
|
|
"""
|
|
for _, param in self.named_parameters():
|
|
yield param
|
|
|
|
@compatibility(is_backward_compatible=False)
|
|
def named_parameters(self) -> Iterator[Tuple[str, torch.nn.Parameter]]:
|
|
"""
|
|
Returns an iterator over original module parameters, yielding
|
|
both the name of the parameter as well as the parameter itself.
|
|
"""
|
|
for param_name in self.graph_signature.parameters:
|
|
yield param_name, self.state_dict[param_name]
|
|
|
|
@compatibility(is_backward_compatible=False)
|
|
def buffers(self) -> Iterator[torch.Tensor]:
|
|
"""
|
|
Returns an iterator over original module buffers.
|
|
"""
|
|
for _, buf in self.named_buffers():
|
|
yield buf
|
|
|
|
@compatibility(is_backward_compatible=False)
|
|
def named_buffers(self) -> Iterator[Tuple[str, torch.Tensor]]:
|
|
"""
|
|
Returns an iterator over original module buffers, yielding
|
|
both the name of the buffer as well as the buffer itself.
|
|
"""
|
|
for buffer_name in self.graph_signature.buffers:
|
|
yield buffer_name, self.state_dict[buffer_name]
|
|
|
|
@property
|
|
@compatibility(is_backward_compatible=False)
|
|
def range_constraints(self):
|
|
return self._range_constraints
|
|
|
|
@property
|
|
@compatibility(is_backward_compatible=False)
|
|
def equality_constraints(self):
|
|
return self._equality_constraints
|
|
|
|
@property
|
|
@compatibility(is_backward_compatible=False)
|
|
def module_call_graph(self):
|
|
return self._module_call_graph
|
|
|
|
@property
|
|
@compatibility(is_backward_compatible=False)
|
|
def example_inputs(self):
|
|
return self._example_inputs
|
|
|
|
@property
|
|
@compatibility(is_backward_compatible=False)
|
|
def call_spec(self):
|
|
from torch._export.exported_program import CallSpec
|
|
|
|
if len(self.module_call_graph) == 0:
|
|
return CallSpec(in_spec=None, out_spec=None)
|
|
assert self.module_call_graph[0].fqn == ""
|
|
return CallSpec(
|
|
in_spec=self.module_call_graph[0].signature.in_spec,
|
|
out_spec=self.module_call_graph[0].signature.out_spec,
|
|
)
|
|
|
|
@property
|
|
@compatibility(is_backward_compatible=False)
|
|
def verifier(self) -> Any:
|
|
return self._verifier
|
|
|
|
@property
|
|
@compatibility(is_backward_compatible=False)
|
|
def dialect(self) -> str:
|
|
return self._verifier.dialect
|
|
|
|
def __call__(self, *args: Any, **kwargs: Any) -> Any:
|
|
import torch._export.error as error
|
|
from torch._export import combine_args_kwargs
|
|
|
|
if self.call_spec.in_spec is not None:
|
|
try:
|
|
user_args = combine_args_kwargs(args, kwargs)
|
|
args = fx_pytree.tree_flatten_spec(
|
|
user_args, self.call_spec.in_spec, exact_structural_match=True
|
|
) # type: ignore[assignment]
|
|
except Exception:
|
|
_, received_spec = pytree.tree_flatten(user_args)
|
|
raise TypeError( # noqa: TRY200
|
|
"Trying to flatten user inputs with exported input tree spec: \n"
|
|
f"{self.call_spec.in_spec}\n"
|
|
"but actually got inputs with tree spec of: \n"
|
|
f"{received_spec}"
|
|
)
|
|
|
|
ordered_params = tuple(
|
|
self.state_dict[name] for name in self.graph_signature.parameters
|
|
)
|
|
ordered_buffers = tuple(
|
|
self.state_dict[name] for name in self.graph_signature.buffers
|
|
)
|
|
self._check_input_constraints(*ordered_params, *ordered_buffers, *args)
|
|
|
|
# NOTE: calling convention is first params, then buffers, then args as user supplied them.
|
|
# See: torch/_functorch/aot_autograd.py#L1034
|
|
res = torch.fx.Interpreter(self.graph_module).run(
|
|
*ordered_params, *ordered_buffers, *args, enable_io_processing=False
|
|
)
|
|
|
|
if self.call_spec.out_spec is not None:
|
|
mutation = self.graph_signature.buffers_to_mutate
|
|
num_mutated = len(mutation)
|
|
mutated_buffers = res[:num_mutated]
|
|
|
|
# Exclude dependency token from final result.
|
|
assertion_dep_token = self.graph_signature.assertion_dep_token
|
|
if assertion_dep_token is not None:
|
|
assertion_dep_token_index = next(iter(assertion_dep_token.keys()))
|
|
res = res[:assertion_dep_token_index]
|
|
|
|
res = res[num_mutated:]
|
|
try:
|
|
res = pytree.tree_unflatten(res, self.call_spec.out_spec)
|
|
except Exception:
|
|
_, received_spec = pytree.tree_flatten(res)
|
|
raise error.InternalError( # noqa: TRY200
|
|
"Trying to flatten user outputs with exported output tree spec: \n"
|
|
f"{self.call_spec.out_spec}\n"
|
|
"but actually got outputs with tree spec of: \n"
|
|
f"{received_spec}"
|
|
)
|
|
finally:
|
|
ix = 0
|
|
for buffer in self.graph_signature.buffers_to_mutate.values():
|
|
self.state_dict[buffer] = mutated_buffers[ix]
|
|
ix += 1
|
|
return res
|
|
|
|
def __str__(self) -> str:
|
|
graph_module = self.graph_module.print_readable(print_output=False).replace(
|
|
"\n", "\n "
|
|
)
|
|
string = (
|
|
"ExportedProgram:\n"
|
|
f" {graph_module}\n"
|
|
f"Graph signature: {self.graph_signature}\n"
|
|
f"Range constraints: {self.range_constraints}\n"
|
|
f"Equality constraints: {self.equality_constraints}\n"
|
|
)
|
|
return string
|
|
|
|
def module(self, *, flat: bool = True) -> torch.nn.Module:
|
|
"""
|
|
Returns a self contained GraphModule with all the parameters/buffers inlined.
|
|
"""
|
|
from torch._export.exported_program import unlift_exported_program_lifted_states
|
|
from torch._export.unflatten import unflatten
|
|
|
|
if flat:
|
|
return unlift_exported_program_lifted_states(self)
|
|
else:
|
|
return unflatten(self)
|
|
|
|
def run_decompositions(
|
|
self, decomp_table: Optional[Dict[torch._ops.OperatorBase, Callable]] = None
|
|
) -> "ExportedProgram":
|
|
"""
|
|
Run a set of decompositions on the exported program and returns a new
|
|
exported program. By default we will run the Core ATen decompositions to
|
|
get operators in the
|
|
`Core ATen Operator Set <https://pytorch.org/docs/stable/torch.compiler_ir.html>`_.
|
|
|
|
For now, we do not decompose joint graphs.
|
|
"""
|
|
from torch._decomp import core_aten_decompositions
|
|
from torch._export.passes.add_runtime_assertions_for_constraints_pass import (
|
|
_AddRuntimeAssertionsForInlineConstraintsPass,
|
|
InputDim,
|
|
)
|
|
from torch._export.passes.lift_constant_tensor_pass import (
|
|
lift_constant_tensor_pass,
|
|
)
|
|
from torch._export.passes.replace_sym_size_ops_pass import (
|
|
_replace_sym_size_ops_pass,
|
|
)
|
|
from torch._functorch.aot_autograd import aot_export_module
|
|
|
|
def _get_placeholders(gm):
|
|
placeholders = []
|
|
for node in gm.graph.nodes:
|
|
if node.op != "placeholder":
|
|
break
|
|
placeholders.append(node)
|
|
return placeholders
|
|
|
|
decomp_table = decomp_table or core_aten_decompositions()
|
|
|
|
old_placeholders = _get_placeholders(self.graph_module)
|
|
fake_args = [node.meta["val"] for node in old_placeholders]
|
|
|
|
buffers_to_remove = [name for name, _ in self.graph_module.named_buffers()]
|
|
for name in buffers_to_remove:
|
|
delattr(self.graph_module, name)
|
|
# TODO(zhxhchen17) Return the new graph_signature directly.
|
|
gm, graph_signature = aot_export_module(
|
|
self.graph_module, fake_args, decompositions=decomp_table, trace_joint=False
|
|
)
|
|
|
|
# Update the signatures with the new placeholder names in case they
|
|
# changed when calling aot_export
|
|
new_placeholders = _get_placeholders(gm)
|
|
assert len(new_placeholders) == len(old_placeholders)
|
|
old_new_placeholder_map = {
|
|
old_node.name: new_node.name
|
|
for old_node, new_node in zip(old_placeholders, new_placeholders)
|
|
}
|
|
old_outputs = list(self.graph.nodes)[-1].args[0]
|
|
new_outputs = list(gm.graph.nodes)[-1].args[0]
|
|
assert len(new_outputs) == len(old_outputs)
|
|
old_new_output_map = {
|
|
old_node.name: new_node.name
|
|
for old_node, new_node in zip(old_outputs, new_outputs)
|
|
}
|
|
|
|
def make_argument_spec(old_node, node) -> ArgumentSpec:
|
|
if "val" not in node.meta:
|
|
assert len(node.users) == 0
|
|
val = old_node.meta["val"]
|
|
else:
|
|
val = node.meta["val"]
|
|
if isinstance(val, torch.Tensor):
|
|
return TensorArgument(name=node.name)
|
|
elif isinstance(val, torch.SymInt):
|
|
return SymIntArgument(name=node.name)
|
|
else:
|
|
return ConstantArgument(value=val)
|
|
|
|
input_specs, output_specs = _sig_to_specs(
|
|
user_inputs={
|
|
old_new_placeholder_map[inp] for inp in self.graph_signature.user_inputs
|
|
},
|
|
inputs_to_parameters={
|
|
old_new_placeholder_map[inp]: param
|
|
for inp, param in self.graph_signature.inputs_to_parameters.items()
|
|
},
|
|
inputs_to_buffers={
|
|
old_new_placeholder_map[inp]: buffer
|
|
for inp, buffer in self.graph_signature.inputs_to_buffers.items()
|
|
},
|
|
user_outputs={
|
|
old_new_output_map[out] for out in self.graph_signature.user_outputs
|
|
},
|
|
buffer_mutations={
|
|
old_new_output_map[out]: buffer
|
|
for out, buffer in self.graph_signature.buffers_to_mutate.items()
|
|
},
|
|
grad_params={},
|
|
grad_user_inputs={},
|
|
loss_output=None,
|
|
inputs=[
|
|
make_argument_spec(old_placeholders[i], node)
|
|
for i, node in enumerate(gm.graph.nodes)
|
|
if node.op == "placeholder"
|
|
],
|
|
outputs=[
|
|
make_argument_spec(old_outputs[i], node)
|
|
for i, node in enumerate(
|
|
pytree.tree_leaves(next(iter(reversed(gm.graph.nodes))).args)
|
|
)
|
|
],
|
|
)
|
|
|
|
new_graph_signature = ExportGraphSignature(
|
|
input_specs=input_specs, output_specs=output_specs
|
|
)
|
|
# NOTE: aot_export adds symint metadata for placeholders with int
|
|
# values; since these become specialized, we replace such metadata with
|
|
# the original values.
|
|
# Also, set the param/buffer metadata back to the placeholders.
|
|
for old_node, new_node in zip(old_placeholders, new_placeholders):
|
|
if not isinstance(old_node.meta["val"], torch.Tensor):
|
|
new_node.meta["val"] = old_node.meta["val"]
|
|
|
|
if (
|
|
new_node.target in new_graph_signature.inputs_to_parameters
|
|
or new_node.target in new_graph_signature.inputs_to_buffers
|
|
):
|
|
for k, v in old_node.meta.items():
|
|
new_node.meta[k] = v
|
|
|
|
# TODO unfortunately preserving graph-level metadata is not
|
|
# working well with aot_export. So we manually copy it.
|
|
# (The node-level meta is addressed above.)
|
|
gm.meta.update(self.graph_module.meta)
|
|
|
|
new_range_constraints = _get_updated_range_constraints(gm)
|
|
|
|
new_equality_constraints = [
|
|
(
|
|
InputDim(old_new_placeholder_map[inp_dim1.input_name], inp_dim1.dim),
|
|
InputDim(old_new_placeholder_map[inp_dim2.input_name], inp_dim2.dim),
|
|
)
|
|
for inp_dim1, inp_dim2 in self.equality_constraints
|
|
]
|
|
|
|
state_dict = self.state_dict.copy()
|
|
lift_constant_tensor_pass(gm, new_graph_signature, state_dict)
|
|
_replace_sym_size_ops_pass(gm)
|
|
exported_program = ExportedProgram(
|
|
gm,
|
|
gm.graph,
|
|
new_graph_signature,
|
|
state_dict,
|
|
new_range_constraints,
|
|
new_equality_constraints,
|
|
copy.deepcopy(self.module_call_graph),
|
|
self.example_inputs,
|
|
self.verifier,
|
|
)
|
|
|
|
if len(new_range_constraints) > 0 or len(new_equality_constraints) > 0:
|
|
exported_program = exported_program._transform(
|
|
_AddRuntimeAssertionsForInlineConstraintsPass(
|
|
new_range_constraints, new_equality_constraints
|
|
)
|
|
)
|
|
|
|
return exported_program
|
|
|
|
def _transform(self, *passes: PassType) -> "ExportedProgram":
|
|
pm = PassManager(list(passes))
|
|
res = pm(self.graph_module)
|
|
transformed_gm = res.graph_module if res is not None else self.graph_module
|
|
assert transformed_gm is not None
|
|
|
|
if transformed_gm is self.graph_module and not res.modified:
|
|
return self
|
|
|
|
# TODO(zhxchen17) Remove this.
|
|
def _get_updated_graph_signature(
|
|
old_signature: ExportGraphSignature,
|
|
new_gm: torch.fx.GraphModule,
|
|
) -> ExportGraphSignature:
|
|
"""
|
|
Update the graph signature's user_input/user_outputs.
|
|
"""
|
|
new_graph_inputs = [
|
|
node.name for node in new_gm.graph.nodes if node.op == "placeholder"
|
|
]
|
|
num_inputs = (
|
|
len(old_signature.parameters)
|
|
+ len(old_signature.buffers)
|
|
+ len(
|
|
[
|
|
s
|
|
for s in old_signature.input_specs
|
|
if s.kind == InputKind.USER_INPUT
|
|
]
|
|
)
|
|
)
|
|
|
|
assert len(new_graph_inputs) == num_inputs, (
|
|
f"Number of input nodes changed from {len(new_graph_inputs)} "
|
|
f"to {num_inputs} after transformation. This transformation "
|
|
"is currently not supported."
|
|
)
|
|
num_param_buffers = len(old_signature.buffers) + len(
|
|
old_signature.parameters
|
|
)
|
|
new_user_inputs = new_graph_inputs[num_param_buffers:]
|
|
|
|
output_node = list(new_gm.graph.nodes)[-1]
|
|
assert output_node.op == "output"
|
|
new_graph_outputs = [arg.name for arg in output_node.args[0]]
|
|
|
|
assert len(new_graph_outputs) == len(old_signature.buffers_to_mutate) + len(
|
|
[
|
|
s
|
|
for s in old_signature.output_specs
|
|
if s.kind == OutputKind.USER_OUTPUT
|
|
]
|
|
), (
|
|
f"Number of output nodes changed from {len(new_graph_outputs)} "
|
|
f"to {len(old_signature.buffers_to_mutate) + len(old_signature.user_outputs)} "
|
|
"after transformation. This transformation is currently not supported."
|
|
)
|
|
new_user_outputs = new_graph_outputs[len(old_signature.buffers_to_mutate) :]
|
|
|
|
def make_argument_spec(node) -> ArgumentSpec:
|
|
val = node.meta["val"]
|
|
if isinstance(val, torch.Tensor):
|
|
return TensorArgument(name=node.name)
|
|
elif isinstance(val, torch.SymInt):
|
|
return SymIntArgument(name=node.name)
|
|
else:
|
|
return ConstantArgument(value=val)
|
|
|
|
input_specs, output_specs = _sig_to_specs(
|
|
user_inputs=set(new_user_inputs),
|
|
inputs_to_parameters=old_signature.inputs_to_parameters,
|
|
inputs_to_buffers=old_signature.inputs_to_buffers,
|
|
user_outputs=set(new_user_outputs),
|
|
buffer_mutations=old_signature.buffers_to_mutate,
|
|
grad_params={},
|
|
grad_user_inputs={},
|
|
loss_output=None,
|
|
inputs=[
|
|
make_argument_spec(node)
|
|
for node in transformed_gm.graph.nodes
|
|
if node.op == "placeholder"
|
|
],
|
|
outputs=[
|
|
make_argument_spec(node)
|
|
for node in pytree.tree_flatten(
|
|
next(iter(reversed(transformed_gm.graph.nodes))).args
|
|
)[0]
|
|
],
|
|
)
|
|
new_signature = ExportGraphSignature(
|
|
input_specs=input_specs, output_specs=output_specs
|
|
)
|
|
return new_signature
|
|
|
|
transformed_ep = ExportedProgram(
|
|
transformed_gm,
|
|
transformed_gm.graph,
|
|
_get_updated_graph_signature(self.graph_signature, transformed_gm),
|
|
self.state_dict,
|
|
_get_updated_range_constraints(transformed_gm),
|
|
copy.deepcopy(self.equality_constraints),
|
|
copy.deepcopy(self._module_call_graph),
|
|
self.example_inputs,
|
|
self.verifier,
|
|
)
|
|
transformed_ep.graph_module.meta.update(self.graph_module.meta)
|
|
transformed_ep.graph_module.meta.update(res.graph_module.meta)
|
|
return transformed_ep
|
|
|
|
def _check_input_constraints(self, *args):
|
|
from torch._export.utils import _check_input_constraints_for_graph
|
|
|
|
_check_input_constraints_for_graph(
|
|
self.graph, self.range_constraints, self.equality_constraints
|
|
)(*args)
|
|
|
|
def _validate(self):
|
|
self.verifier().check(self)
|
|
|
|
|
|
def _get_updated_range_constraints(
|
|
gm: torch.fx.GraphModule,
|
|
) -> "Dict[sympy.Symbol, Any]":
|
|
def get_shape_env(gm):
|
|
vals = [
|
|
node.meta["val"]
|
|
for node in gm.graph.nodes
|
|
if node.meta.get("val", None) is not None
|
|
]
|
|
from torch._guards import detect_fake_mode
|
|
|
|
fake_mode = detect_fake_mode(vals)
|
|
if fake_mode is not None:
|
|
return fake_mode.shape_env
|
|
for v in vals:
|
|
if isinstance(v, torch.SymInt):
|
|
return v.node.shape_env
|
|
|
|
shape_env = get_shape_env(gm)
|
|
if shape_env is None:
|
|
return {}
|
|
range_constraints = {
|
|
k: v
|
|
for k, v in shape_env.var_to_range.items()
|
|
if k not in shape_env.replacements
|
|
}
|
|
return range_constraints
|