mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
[export] Turn off output value from sources for export. (#115442)
Fixes #ISSUE_NUMBER Pull Request resolved: https://github.com/pytorch/pytorch/pull/115442 Approved by: https://github.com/tugsbayasgalan
This commit is contained in:
parent
af09fe256a
commit
f78f23d753
|
|
@ -1562,6 +1562,32 @@ def forward(self, l_x_):
|
|||
with self.assertRaisesRegex(RuntimeError, "shape\[0\] is specialized at 4"):
|
||||
ep_v2(*test_inp)
|
||||
|
||||
def test_constant_output(self):
|
||||
class ModuleConstant(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.b = torch.randn(3, 2)
|
||||
|
||||
def forward(self):
|
||||
return self.b
|
||||
|
||||
class ModuleNestedConstant(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.bff = torch.randn(3, 2)
|
||||
|
||||
def forward(self, x, y):
|
||||
return {"prediction": (x + y, self.bff)}
|
||||
|
||||
mod = ModuleConstant()
|
||||
ep = torch.export.export(mod, ())
|
||||
self.assertEqual(ep(), mod())
|
||||
|
||||
args = (torch.randn(3, 2), torch.randn(3, 2))
|
||||
mod = ModuleNestedConstant()
|
||||
ep = torch.export.export(mod, args)
|
||||
self.assertEqual(ep(*args), mod(*args))
|
||||
|
||||
def test_non_arg_name_dynamic_shapes_api_with_kwarg(self):
|
||||
def foo(a, b, kw1, kw2):
|
||||
return a.sum() + b.sum() + kw1.sum() - kw2.sum()
|
||||
|
|
|
|||
|
|
@ -60,14 +60,18 @@ class PyCodegen:
|
|||
self.cell_and_freevars = self.tx.cell_and_freevars
|
||||
self.new_var = self.tx.output.new_var
|
||||
self.mutable_side_effects_from_source = False
|
||||
self.value_from_source: bool = True
|
||||
|
||||
def restore_stack(self, stack_values):
|
||||
def restore_stack(self, stack_values, *, value_from_source=True):
|
||||
prior = self.mutable_side_effects_from_source
|
||||
self.mutable_side_effects_from_source = True
|
||||
prev = self.value_from_source
|
||||
self.value_from_source &= value_from_source
|
||||
try:
|
||||
self.foreach(stack_values)
|
||||
finally:
|
||||
self.mutable_side_effects_from_source = prior
|
||||
self.value_from_source = prev
|
||||
|
||||
def graph_output_vars(self):
|
||||
return [x.variable for x in self.graph_outputs.values()]
|
||||
|
|
@ -108,7 +112,7 @@ class PyCodegen:
|
|||
self.top_of_stack = value
|
||||
return
|
||||
|
||||
if value.source is not None and allow_cache:
|
||||
if value.source is not None and allow_cache and self.value_from_source:
|
||||
output.extend(value.source.reconstruct(self))
|
||||
elif value.is_python_constant() and is_safe_constant(
|
||||
value.as_python_constant()
|
||||
|
|
|
|||
|
|
@ -911,7 +911,7 @@ class OutputGraph(Checkpointable[OutputGraphState]):
|
|||
pass1 = PyCodegen(tx, root, graph_output_var)
|
||||
self.side_effects.codegen_hooks(pass1)
|
||||
self.side_effects.codegen_save_tempvars(pass1)
|
||||
pass1.restore_stack(stack_values)
|
||||
pass1.restore_stack(stack_values, value_from_source=not tx.export)
|
||||
self.side_effects.codegen_update_mutated(pass1)
|
||||
|
||||
# one more time now that we have established tempvars
|
||||
|
|
@ -923,7 +923,7 @@ class OutputGraph(Checkpointable[OutputGraphState]):
|
|||
)
|
||||
self.side_effects.codegen_hooks(pass2)
|
||||
self.side_effects.codegen_save_tempvars(pass2)
|
||||
pass2.restore_stack(stack_values)
|
||||
pass2.restore_stack(stack_values, value_from_source=not tx.export)
|
||||
self.side_effects.codegen_update_mutated(pass2)
|
||||
|
||||
output = []
|
||||
|
|
|
|||
|
|
@ -23,14 +23,13 @@ def lift_constant_tensor_pass(gm, graph_signature) -> Dict[str, torch.Tensor]:
|
|||
)
|
||||
assert fake_mode is not None
|
||||
|
||||
first_user_input_loc, first_user_input = None, None
|
||||
for i, node in enumerate(gm.graph.nodes):
|
||||
first_user_input_loc, first_user_input = 0, None
|
||||
for node in gm.graph.nodes:
|
||||
if node.op == "placeholder" and node.name in graph_signature.user_inputs:
|
||||
first_user_input = node
|
||||
first_user_input_loc = i
|
||||
break
|
||||
first_user_input_loc += 1
|
||||
|
||||
assert first_user_input is not None and first_user_input_loc is not None
|
||||
tensor_constants = {}
|
||||
|
||||
for node in gm.graph.nodes:
|
||||
|
|
|
|||
|
|
@ -19,10 +19,11 @@ from torch._export.passes.lift_constant_tensor_pass import lift_constant_tensor_
|
|||
from torch._export.wrappers import _wrap_submodules
|
||||
from torch._functorch.aot_autograd import aot_export_module, GraphSignature
|
||||
from torch._guards import detect_fake_mode
|
||||
from torch._subclasses.fake_tensor import FakeTensor
|
||||
from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode
|
||||
from torch.fx.experimental.symbolic_shapes import (
|
||||
ConstraintViolationError,
|
||||
GuardOnDataDependentSymNode,
|
||||
ShapeEnv,
|
||||
)
|
||||
from torch.fx.graph import _PyTreeCodeGen, _PyTreeInfo
|
||||
from torch.utils._sympy.value_ranges import ValueRangeError
|
||||
|
|
@ -57,16 +58,8 @@ DEFAULT_EXPORT_DYNAMO_CONFIG = ExportDynamoConfig()
|
|||
|
||||
|
||||
def _convert_input_to_fake(gm, args, kwargs):
|
||||
if (
|
||||
len(args) == 0
|
||||
and len(kwargs) == 0
|
||||
and len(dict(gm.named_parameters())) == 0
|
||||
and len(dict(gm.named_buffers())) == 0
|
||||
):
|
||||
return [], {}, {}, None
|
||||
|
||||
params_buffers = _get_params_buffers(gm)
|
||||
fake_inps: List[torch.Tensor] = []
|
||||
fake_mode = None
|
||||
for node in gm.graph.nodes:
|
||||
if node.op == "placeholder" and "val" in node.meta:
|
||||
fake_val = node.meta["val"]
|
||||
|
|
@ -75,10 +68,11 @@ def _convert_input_to_fake(gm, args, kwargs):
|
|||
|
||||
if detected_fake_mode := detect_fake_mode(fake_inps):
|
||||
fake_mode = detected_fake_mode
|
||||
else:
|
||||
fake_mode = FakeTensorMode(shape_env=ShapeEnv())
|
||||
|
||||
assert (
|
||||
fake_mode is not None
|
||||
), "Cannot find fake_mode attatched to the graph's placeholders."
|
||||
if len(args) == 0 and len(kwargs) == 0:
|
||||
return (), {}, params_buffers, fake_mode
|
||||
|
||||
count = 0
|
||||
|
||||
|
|
@ -94,10 +88,7 @@ def _convert_input_to_fake(gm, args, kwargs):
|
|||
fake_params_buffers = pytree.tree_map_only(
|
||||
torch.Tensor,
|
||||
functools.partial(fake_mode.from_tensor, static_shapes=True),
|
||||
{
|
||||
**dict(gm.named_parameters(remove_duplicate=False)),
|
||||
**dict(gm.named_buffers(remove_duplicate=False)),
|
||||
},
|
||||
params_buffers,
|
||||
)
|
||||
return fake_args, fake_kwargs, fake_params_buffers, fake_mode
|
||||
|
||||
|
|
@ -653,13 +644,11 @@ def _export(
|
|||
# The unbacked symint symbols are updated in aot_export
|
||||
# so we serialize them here instead of inside dynamo
|
||||
|
||||
# dynamo_fake_mode can be None if there's no placeholder in gm_torch_level
|
||||
if dynamo_fake_mode:
|
||||
gm.meta["inline_constraints"] = {
|
||||
k: v
|
||||
for k, v in dynamo_fake_mode.shape_env.runtime_var_to_range.items()
|
||||
if re.match(r"^[if]\d+$", str(k))
|
||||
}
|
||||
gm.meta["inline_constraints"] = {
|
||||
k: v
|
||||
for k, v in dynamo_fake_mode.shape_env.runtime_var_to_range.items()
|
||||
if re.match(r"^[if]\d+$", str(k))
|
||||
}
|
||||
|
||||
num_lifted = next(
|
||||
(
|
||||
|
|
@ -667,7 +656,7 @@ def _export(
|
|||
for i, s in enumerate(export_graph_signature.input_specs)
|
||||
if s.kind == InputKind.USER_INPUT
|
||||
),
|
||||
0,
|
||||
len(export_graph_signature.input_specs),
|
||||
)
|
||||
flat_args, orig_in_spec = pytree.tree_flatten((args, kwargs))
|
||||
range_constraints, equality_constraints = _process_constraints(
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user