[sigmoid] replace unflatten with upstream version (#115468)

as title

Differential Revision: [D52000213](https://our.internmc.facebook.com/intern/diff/D52000213/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/115468
Approved by: https://github.com/zhxchen17
Author: suo
Date: 2023-12-21 12:55:25 -08:00
Committed by: PyTorch MergeBot
Parent: 127cae7ec8
Commit: d2d129de65
2 changed files with 10 additions and 7 deletions

--- a/torch/export/__init__.py
+++ b/torch/export/__init__.py

@@ -58,13 +58,14 @@ __all__ = [
     "save",
     "unflatten",
     "FlatArgsAdapter",
+    "UnflattenedModule",
 ]


 from .dynamic_shapes import Constraint, Dim, dims, dynamic_dim
 from .exported_program import ExportedProgram, ModuleCallEntry, ModuleCallSignature
 from .graph_signature import ExportBackwardSignature, ExportGraphSignature
-from .unflatten import FlatArgsAdapter, unflatten
+from .unflatten import FlatArgsAdapter, unflatten, UnflattenedModule


 PassType = Callable[[torch.fx.GraphModule], Optional[PassResult]]
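
With `UnflattenedModule` now exported from the public `torch.export` namespace, downstream code can refer to the unflattened result's type directly. A minimal sketch of what this enables; the toy module and input shapes below are illustrative assumptions, not part of this diff:

import torch
from torch.export import export, unflatten, UnflattenedModule

class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(4, 4)  # hypothetical toy submodule

    def forward(self, x):
        return self.linear(x)

# Export, then recover the original module hierarchy from the flat graph.
ep = export(M(), (torch.randn(2, 4),))
unflat = unflatten(ep)

# The result type is now importable, so isinstance checks and type
# annotations against UnflattenedModule work without reaching into
# private modules.
assert isinstance(unflat, UnflattenedModule)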

--- a/torch/export/unflatten.py
+++ b/torch/export/unflatten.py

@@ -189,6 +189,7 @@ class UnflattenedModule(torch.nn.Module):
         self.input_placeholders = [
             node for node in self.graph.nodes if node.op == "placeholder"
         ]
+        self.check_input_constraints = True

     def forward(self, *args, **kwargs):
         if is_fx_tracing():
@@ -227,13 +228,14 @@ class UnflattenedModule(torch.nn.Module):
                 f"Exported module: {signature.in_spec.num_leaves}"
             )

-        # Import here to avoid an unfortunate circular dependency.
-        # TODO(suo): untangle this.
-        from torch._export.utils import _check_input_constraints_for_graph
+        if self.check_input_constraints:
+            # Import here to avoid an unfortunate circular dependency.
+            # TODO(suo): untangle this.
+            from torch._export.utils import _check_input_constraints_for_graph

-        _check_input_constraints_for_graph(
-            self.input_placeholders, flat_args, self.range_constraints
-        )
+            _check_input_constraints_for_graph(
+                self.input_placeholders, flat_args, self.range_constraints
+            )
         tree_out = torch.fx.Interpreter(self, graph=self.graph).run(
             *flat_args, enable_io_processing=False
         )
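
The net effect of the two hunks above: `UnflattenedModule` gains a `check_input_constraints` attribute (defaulting to `True`), and the per-call `_check_input_constraints_for_graph` validation now only runs when that flag is set. A hedged sketch of how a caller might use the new flag, continuing the toy example above:

# By default, every forward() call validates the flattened inputs
# against the exported range constraints.
out = unflat(torch.randn(2, 4))

# Callers that already trust their inputs (e.g. a hot serving path) can
# now opt out of the per-call validation via the new attribute.
unflat.check_input_constraints = False
out = unflat(torch.randn(2, 4))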