Mirror of https://github.com/zebrajr/pytorch.git, synced 2025-12-07 00:21:07 +01:00
[sigmoid] replace unflatten with upstream version (#115468)
As title.

Differential Revision: [D52000213](https://our.internmc.facebook.com/intern/diff/D52000213/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/115468
Approved by: https://github.com/zhxchen17
This commit is contained in:
parent 127cae7ec8
commit d2d129de65
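For orientation, here is a brief, hedged sketch of the export-then-unflatten workflow this commit touches; the toy module and shapes below are illustrative and not taken from the PR itself.

```python
# Minimal sketch of torch.export.export followed by unflatten (illustrative only).
import torch
from torch.export import export, unflatten


class Child(torch.nn.Module):
    def forward(self, x):
        return torch.relu(x)


class Parent(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.child = Child()
        self.linear = torch.nn.Linear(8, 8)

    def forward(self, x):
        return self.linear(self.child(x))


ep = export(Parent(), (torch.randn(2, 8),))  # flat ExportedProgram
um = unflatten(ep)                           # UnflattenedModule restoring the eager hierarchy
out = um(torch.randn(2, 8))
print(out.shape)                             # torch.Size([2, 8])
```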
@@ -58,13 +58,14 @@ __all__ = [
     "save",
     "unflatten",
     "FlatArgsAdapter",
+    "UnflattenedModule",
 ]
 
 
 from .dynamic_shapes import Constraint, Dim, dims, dynamic_dim
 from .exported_program import ExportedProgram, ModuleCallEntry, ModuleCallSignature
 from .graph_signature import ExportBackwardSignature, ExportGraphSignature
-from .unflatten import FlatArgsAdapter, unflatten
+from .unflatten import FlatArgsAdapter, unflatten, UnflattenedModule
 
 
 PassType = Callable[[torch.fx.GraphModule], Optional[PassResult]]
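The hunk above (apparently the package `__init__`, given the `__all__` list and relative imports) promotes `UnflattenedModule` to the public `torch.export` namespace alongside `FlatArgsAdapter` and `unflatten`. A hedged sketch of what that allows downstream:

```python
# Sketch only: with "UnflattenedModule" in __all__ and re-exported above,
# callers can import it from the top-level namespace.
from torch.export import FlatArgsAdapter, UnflattenedModule, unflatten


def is_unflattened(mod) -> bool:
    # Type checks can use the public class instead of reaching into the
    # torch.export.unflatten submodule directly.
    return isinstance(mod, UnflattenedModule)
```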
@@ -189,6 +189,7 @@ class UnflattenedModule(torch.nn.Module):
         self.input_placeholders = [
             node for node in self.graph.nodes if node.op == "placeholder"
         ]
+        self.check_input_constraints = True
 
     def forward(self, *args, **kwargs):
         if is_fx_tracing():
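This hunk adds a `check_input_constraints` attribute on `UnflattenedModule`, defaulting to `True`. A small, hedged sketch of toggling it; the attribute name comes from the diff, while the toy module and setup are assumptions:

```python
# Illustrative sketch: the attribute added above defaults to True and can be
# flipped off by callers that want forward() to skip input validation.
import torch
from torch.export import export, unflatten


class M(torch.nn.Module):
    def forward(self, x):
        return x * 2


um = unflatten(export(M(), (torch.randn(3),)))
print(um.check_input_constraints)   # True by default, per the diff above
um.check_input_constraints = False  # forward() will skip the constraint check
print(um(torch.randn(3)))
```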
@@ -227,13 +228,14 @@ class UnflattenedModule(torch.nn.Module):
                 f"Exported module: {signature.in_spec.num_leaves}"
             )
 
-        # Import here to avoid an unfortunate circular dependency.
-        # TODO(suo): untangle this.
-        from torch._export.utils import _check_input_constraints_for_graph
+        if self.check_input_constraints:
+            # Import here to avoid an unfortunate circular dependency.
+            # TODO(suo): untangle this.
+            from torch._export.utils import _check_input_constraints_for_graph
 
-        _check_input_constraints_for_graph(
-            self.input_placeholders, flat_args, self.range_constraints
-        )
+            _check_input_constraints_for_graph(
+                self.input_placeholders, flat_args, self.range_constraints
+            )
         tree_out = torch.fx.Interpreter(self, graph=self.graph).run(
             *flat_args, enable_io_processing=False
         )
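The last hunk wraps the existing `_check_input_constraints_for_graph` call in the new flag, so input validation can be switched off without touching the graph itself. A hedged illustration of the observable effect, using made-up dynamic-shape constraints; exact error types or messages for violations are not part of this diff and are not shown:

```python
# Hedged illustration: with a declared dynamic dimension, the range check in
# forward() only runs while check_input_constraints is True.
import torch
from torch.export import Dim, export, unflatten


class M(torch.nn.Module):
    def forward(self, x):
        return x + 1


batch = Dim("batch", min=2, max=16)
ep = export(M(), (torch.randn(4, 8),), dynamic_shapes={"x": {0: batch}})
um = unflatten(ep)

um(torch.randn(8, 8))               # batch within [2, 16]: check passes, graph runs

um.check_input_constraints = False
um(torch.randn(32, 8))              # out-of-range batch is no longer validated here
```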