Mirror of https://github.com/zebrajr/pytorch.git, synced 2025-12-08 07:39:33 +01:00
[export] Turn off output value from sources for export. (#115442)
Fixes #ISSUE_NUMBER
Pull Request resolved: https://github.com/pytorch/pytorch/pull/115442
Approved by: https://github.com/tugsbayasgalan
This commit is contained in:
parent af09fe256a
commit f78f23d753
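What the change does, as read from the diff below (my gloss, not the PR's own description): when Dynamo emits the bytecode that returns values at the end of a traced region, PyCodegen may reconstruct a value from its original source (for example, by re-reading a module attribute) instead of using the graph output. The PR adds a value_from_source flag to PyCodegen and disables that shortcut on the export path, so exported programs return the traced/lifted values. A minimal repro of the user-visible case, mirroring the test added in the first hunk:

    import torch

    class ModuleConstant(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.b = torch.randn(3, 2)

        def forward(self):
            # The output does not depend on any input; it is a constant attribute.
            return self.b

    mod = ModuleConstant()
    ep = torch.export.export(mod, ())   # export with an empty args tuple
    assert torch.equal(ep(), mod())     # the exported program returns the lifted constant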
@@ -1562,6 +1562,32 @@ def forward(self, l_x_):
         with self.assertRaisesRegex(RuntimeError, "shape\[0\] is specialized at 4"):
             ep_v2(*test_inp)
 
+    def test_constant_output(self):
+        class ModuleConstant(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.b = torch.randn(3, 2)
+
+            def forward(self):
+                return self.b
+
+        class ModuleNestedConstant(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.bff = torch.randn(3, 2)
+
+            def forward(self, x, y):
+                return {"prediction": (x + y, self.bff)}
+
+        mod = ModuleConstant()
+        ep = torch.export.export(mod, ())
+        self.assertEqual(ep(), mod())
+
+        args = (torch.randn(3, 2), torch.randn(3, 2))
+        mod = ModuleNestedConstant()
+        ep = torch.export.export(mod, args)
+        self.assertEqual(ep(*args), mod(*args))
+
     def test_non_arg_name_dynamic_shapes_api_with_kwarg(self):
         def foo(a, b, kw1, kw2):
             return a.sum() + b.sum() + kw1.sum() - kw2.sum()
@@ -60,14 +60,18 @@ class PyCodegen:
         self.cell_and_freevars = self.tx.cell_and_freevars
         self.new_var = self.tx.output.new_var
         self.mutable_side_effects_from_source = False
+        self.value_from_source: bool = True
 
-    def restore_stack(self, stack_values):
+    def restore_stack(self, stack_values, *, value_from_source=True):
         prior = self.mutable_side_effects_from_source
         self.mutable_side_effects_from_source = True
+        prev = self.value_from_source
+        self.value_from_source &= value_from_source
         try:
             self.foreach(stack_values)
         finally:
             self.mutable_side_effects_from_source = prior
+            self.value_from_source = prev
 
     def graph_output_vars(self):
         return [x.variable for x in self.graph_outputs.values()]
@@ -108,7 +112,7 @@ class PyCodegen:
                 self.top_of_stack = value
                 return
 
-        if value.source is not None and allow_cache:
+        if value.source is not None and allow_cache and self.value_from_source:
             output.extend(value.source.reconstruct(self))
         elif value.is_python_constant() and is_safe_constant(
             value.as_python_constant()
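A self-contained sketch of the flag plumbing these two PyCodegen hunks add, with hypothetical names mirroring the diff (the real class in torch/_dynamo/codegen.py also handles caching, graph outputs and side effects, which this toy omits):

    class CodegenSketch:
        def __init__(self):
            # When False, values are never reconstructed from their original
            # sources; they must come from the graph outputs instead.
            self.value_from_source: bool = True

        def restore_stack(self, stack_values, *, value_from_source=True):
            prev = self.value_from_source
            # `&=` lets a caller only narrow the flag, never re-enable it.
            self.value_from_source &= value_from_source
            try:
                for value in stack_values:
                    self.emit(value)
            finally:
                self.value_from_source = prev

        def emit(self, value):
            source = getattr(value, "source", None)
            if source is not None and self.value_from_source:
                print(f"reconstruct {value!r} from its source")
            else:
                print(f"materialize {value!r} from the graph output")

    # Export path: the flag is forced off, so even values that have a source
    # go through the graph output.
    CodegenSketch().restore_stack(["y"], value_from_source=False)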
@@ -911,7 +911,7 @@ class OutputGraph(Checkpointable[OutputGraphState]):
         pass1 = PyCodegen(tx, root, graph_output_var)
         self.side_effects.codegen_hooks(pass1)
         self.side_effects.codegen_save_tempvars(pass1)
-        pass1.restore_stack(stack_values)
+        pass1.restore_stack(stack_values, value_from_source=not tx.export)
         self.side_effects.codegen_update_mutated(pass1)
 
         # one more time now that we have established tempvars
@@ -923,7 +923,7 @@ class OutputGraph(Checkpointable[OutputGraphState]):
         )
         self.side_effects.codegen_hooks(pass2)
         self.side_effects.codegen_save_tempvars(pass2)
-        pass2.restore_stack(stack_values)
+        pass2.restore_stack(stack_values, value_from_source=not tx.export)
         self.side_effects.codegen_update_mutated(pass2)
 
         output = []
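The two OutputGraph hunks only thread the new flag through: both stack-restore passes now pass value_from_source=not tx.export, so source reconstruction is disabled exactly when tracing for export. A one-line illustration using the CodegenSketch above (is_export is a hypothetical stand-in for tx.export):

    def restore_for(codegen, stack_values, is_export):
        # Regular compile keeps source reconstruction; export turns it off.
        codegen.restore_stack(stack_values, value_from_source=not is_export)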
@@ -23,14 +23,13 @@ def lift_constant_tensor_pass(gm, graph_signature) -> Dict[str, torch.Tensor]:
     )
     assert fake_mode is not None
 
-    first_user_input_loc, first_user_input = None, None
-    for i, node in enumerate(gm.graph.nodes):
+    first_user_input_loc, first_user_input = 0, None
+    for node in gm.graph.nodes:
         if node.op == "placeholder" and node.name in graph_signature.user_inputs:
             first_user_input = node
-            first_user_input_loc = i
             break
+        first_user_input_loc += 1
 
-    assert first_user_input is not None and first_user_input_loc is not None
     tensor_constants = {}
 
     for node in gm.graph.nodes:
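A minimal, self-contained sketch of the revised insertion-point logic in lift_constant_tensor_pass, using a namedtuple in place of torch.fx nodes. As read from the diff, the counter now points just past all existing nodes when a graph has no user-input placeholder, instead of tripping the removed assert:

    from collections import namedtuple

    Node = namedtuple("Node", ["op", "name"])

    def find_first_user_input(nodes, user_inputs):
        # Index of the first user-input placeholder; len(nodes) if there is none.
        first_user_input_loc, first_user_input = 0, None
        for node in nodes:
            if node.op == "placeholder" and node.name in user_inputs:
                first_user_input = node
                break
            first_user_input_loc += 1
        return first_user_input_loc, first_user_input

    graph = [Node("placeholder", "b_lifted"), Node("placeholder", "x"), Node("call_function", "add")]
    assert find_first_user_input(graph, {"x"}) == (1, graph[1])
    # No user inputs at all: the location lands past every node instead of asserting.
    assert find_first_user_input(graph, set())[0] == 3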
@@ -19,10 +19,11 @@ from torch._export.passes.lift_constant_tensor_pass import lift_constant_tensor_
 from torch._export.wrappers import _wrap_submodules
 from torch._functorch.aot_autograd import aot_export_module, GraphSignature
 from torch._guards import detect_fake_mode
-from torch._subclasses.fake_tensor import FakeTensor
+from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode
 from torch.fx.experimental.symbolic_shapes import (
     ConstraintViolationError,
     GuardOnDataDependentSymNode,
+    ShapeEnv,
 )
 from torch.fx.graph import _PyTreeCodeGen, _PyTreeInfo
 from torch.utils._sympy.value_ranges import ValueRangeError
@@ -57,16 +58,8 @@ DEFAULT_EXPORT_DYNAMO_CONFIG = ExportDynamoConfig()
 
 
 def _convert_input_to_fake(gm, args, kwargs):
-    if (
-        len(args) == 0
-        and len(kwargs) == 0
-        and len(dict(gm.named_parameters())) == 0
-        and len(dict(gm.named_buffers())) == 0
-    ):
-        return [], {}, {}, None
-
+    params_buffers = _get_params_buffers(gm)
     fake_inps: List[torch.Tensor] = []
-    fake_mode = None
     for node in gm.graph.nodes:
         if node.op == "placeholder" and "val" in node.meta:
             fake_val = node.meta["val"]
@@ -75,10 +68,11 @@ def _convert_input_to_fake(gm, args, kwargs):
 
     if detected_fake_mode := detect_fake_mode(fake_inps):
         fake_mode = detected_fake_mode
+    else:
+        fake_mode = FakeTensorMode(shape_env=ShapeEnv())
 
-    assert (
-        fake_mode is not None
-    ), "Cannot find fake_mode attatched to the graph's placeholders."
+    if len(args) == 0 and len(kwargs) == 0:
+        return (), {}, params_buffers, fake_mode
 
     count = 0
 
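A hedged sketch of the fake-mode selection that these _convert_input_to_fake hunks introduce: a mode is detected from the placeholder metadata when possible, and a fresh FakeTensorMode with its own ShapeEnv is created otherwise, so the function can always return a mode (and can return early, with the real params/buffers, when there are no inputs at all). Imports follow the ones added in the diff; the helper name below is mine:

    from torch._guards import detect_fake_mode
    from torch._subclasses.fake_tensor import FakeTensorMode
    from torch.fx.experimental.symbolic_shapes import ShapeEnv

    def pick_fake_mode(fake_inps):
        # Prefer the mode already attached to the graph's placeholder values.
        detected = detect_fake_mode(fake_inps)
        if detected is not None:
            return detected
        # Nothing to detect (e.g. a module with no inputs): make a fresh mode.
        return FakeTensorMode(shape_env=ShapeEnv())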
@@ -94,10 +88,7 @@ def _convert_input_to_fake(gm, args, kwargs):
     fake_params_buffers = pytree.tree_map_only(
         torch.Tensor,
         functools.partial(fake_mode.from_tensor, static_shapes=True),
-        {
-            **dict(gm.named_parameters(remove_duplicate=False)),
-            **dict(gm.named_buffers(remove_duplicate=False)),
-        },
+        params_buffers,
     )
     return fake_args, fake_kwargs, fake_params_buffers, fake_mode
 
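For context on the hunk above (an illustrative example, not the PR's code): pytree.tree_map_only applies a function only to leaves of the given type, so passing the params/buffers dict through it fakeifies every tensor while leaving the dict structure alone.

    import functools

    import torch
    import torch.utils._pytree as pytree
    from torch._subclasses.fake_tensor import FakeTensorMode

    fake_mode = FakeTensorMode()
    params_buffers = {"linear.weight": torch.randn(4, 3), "linear.bias": torch.randn(4)}
    fake_params_buffers = pytree.tree_map_only(
        torch.Tensor,
        functools.partial(fake_mode.from_tensor, static_shapes=True),
        params_buffers,
    )
    # Every tensor leaf is now a FakeTensor; keys and nesting are unchanged.
    assert all(isinstance(t, torch._subclasses.FakeTensor) for t in fake_params_buffers.values())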
@@ -653,8 +644,6 @@ def _export(
     # The unbacked symint symbols are updated in aot_export
     # so we serialize them here instead of inside dynamo
 
-    # dynamo_fake_mode can be None if there's no placeholder in gm_torch_level
-    if dynamo_fake_mode:
     gm.meta["inline_constraints"] = {
         k: v
         for k, v in dynamo_fake_mode.shape_env.runtime_var_to_range.items()
@@ -667,7 +656,7 @@ def _export(
             for i, s in enumerate(export_graph_signature.input_specs)
             if s.kind == InputKind.USER_INPUT
         ),
-        0,
+        len(export_graph_signature.input_specs),
     )
     flat_args, orig_in_spec = pytree.tree_flatten((args, kwargs))
     range_constraints, equality_constraints = _process_constraints(