Ensure export joint with descriptors + compile works (#159337)

Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/159337
Approved by: https://github.com/wconstab
ghstack dependencies: #159336
Edward Z. Yang, 2025-07-28 20:51:36 -07:00, committed by PyTorch MergeBot
parent 2f0db0444e
commit 31b3b38e3a
5 changed files with 44 additions and 7 deletions

View File

@@ -739,6 +739,28 @@ class inner_f(torch.nn.Module):
        self.assertEqual(param_count, len(param_nodes))

    def test_export_and_compile(self):
        class SimpleModule(nn.Module):
            def __init__(self):
                super().__init__()
                self.linear = nn.Linear(3, 2)

            def forward(self, x):
                return self.linear(x)

        model = SimpleModule()
        inputs = (torch.randn(4, 3),)

        with ExitStack() as stack:
            joint_with_descriptors = aot_export_joint_with_descriptors(
                stack, model, inputs
            )
            model_fn = aot_compile_joint_with_descriptors(joint_with_descriptors)

        compiled_fn = torch.compile(fullgraph=True)(model_fn)

        compiled_fn(*dict(model.named_parameters()).values(), inputs).sum().backward()
        self.assertIsNotNone(model.linear.weight.grad)


if __name__ == "__main__":
    run_tests()
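
For readers following along, here is a self-contained sketch of the same flow the new test exercises, with an extra inspection step between export and compile. The import path is inferred from the torch/_functorch/aot_autograd.py changes later in this diff, and the jd.graph_module attribute is an assumption (only jd.in_spec is visible in this commit); treat it as illustrative, not canonical usage.

    # Illustrative sketch only; the import path and the graph_module attribute
    # are assumptions based on this diff, not documented API.
    from contextlib import ExitStack

    import torch
    import torch.nn as nn
    from torch._functorch.aot_autograd import (
        aot_compile_joint_with_descriptors,
        aot_export_joint_with_descriptors,
    )

    model = nn.Linear(3, 2)
    inputs = (torch.randn(4, 3),)

    with ExitStack() as stack:
        # Capture the joint forward/backward graph plus input descriptors.
        jd = aot_export_joint_with_descriptors(stack, model, inputs)
        # (Assumed attribute) inspect or transform the joint graph before lowering.
        jd.graph_module.print_readable()
        # Lower the joint graph into a runnable callable.
        model_fn = aot_compile_joint_with_descriptors(jd)

    # Per the test above, the callable takes parameters first, then the model inputs.
    compiled_fn = torch.compile(fullgraph=True)(model_fn)
    compiled_fn(*dict(model.named_parameters()).values(), inputs).sum().backward()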

View File

@@ -1762,7 +1762,10 @@ def aot_stage2_autograd(
                placeholder_list[i] = ph_arg.as_strided(ph_arg.size(), real_stride)

    compiled_bw_func = None
    if num_symints_saved_for_bw > 0:
    if (
        num_symints_saved_for_bw > 0
        or aot_config.force_non_lazy_backward_lowering
    ):
        try:
            # See Note: [Backward graph lazy lowering]
            with torch._subclasses.fake_tensor.unset_fake_temporarily():
@@ -1775,6 +1778,8 @@
                )
            del bw_module_copy
        except Exception as e:
            if aot_config.force_non_lazy_backward_lowering:
                raise
            exc = e
            trace_structured(
                "artifact",

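To make the intent of this hunk easier to follow: previously the backward graph was lowered eagerly only when symints were saved for backward, and any compile failure fell back to lazy lowering at the first backward call. The new force_non_lazy_backward_lowering flag opts into eager lowering unconditionally and re-raises compile errors instead of deferring them. Below is a simplified standalone mock of that control flow (not the real AOTAutograd code), just to show the shape of the decision:

    # Simplified mock of the eager-vs-lazy backward lowering decision above.
    # Names are illustrative; this is not the real AOTAutograd implementation.
    from typing import Callable, Optional


    def maybe_lower_backward(
        bw_compiler: Callable[[], Callable],
        num_symints_saved_for_bw: int,
        force_non_lazy_backward_lowering: bool,
    ) -> Optional[Callable]:
        compiled_bw_func = None
        if num_symints_saved_for_bw > 0 or force_non_lazy_backward_lowering:
            try:
                # Lower eagerly, at compile time.
                compiled_bw_func = bw_compiler()
            except Exception:
                if force_non_lazy_backward_lowering:
                    # Eager lowering was explicitly requested: surface the error.
                    raise
                # Otherwise swallow and fall back to lazy lowering at runtime.
        # None means "lower lazily on the first backward call".
        return compiled_bw_func

This matters for aot_export_joint_with_descriptors (last file in this diff): its backward compiler is just boxed_nop, so lazy lowering buys nothing, and deferring it trips Inductor's forward/backward metrics scoping ("Metric(s) {'is_forward'} have already been set in the current context").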
View File

@@ -1949,7 +1949,6 @@ class AOTDispatchAutograd:
                expected_meta = meta.meta
                runtime_type = type(x)
                if torch._dynamo.compiled_autograd.in_compiled_autograd_region:
                    # When we're inside compiled autograd's AOTDispatcher step,
                    # regular Tensors look like FunctionalTensors.
                    # Tensor subclasses still look like Tensor subclasses though.

View File

@@ -982,6 +982,7 @@ class AOTConfig:
    # Used only by standalone_compile.
    ignore_shape_env: bool = False
    precompile_backend_id: Optional[str] = None
    force_non_lazy_backward_lowering: bool = False

    def __post_init__(self):
        if self.pre_dispatch:

View File

@@ -893,6 +893,8 @@ def prepare_aot_module_simplified(
    boxed_forward_device_index: BoxedDeviceIndex,
    ignore_shape_env: bool,
    flatten: bool,
    *,
    force_non_lazy_backward_lowering: bool = False,
):
    if not flatten:
        assert kwargs is None
@@ -982,6 +984,7 @@ prepare_aot_module_simplified(
        cache_info=None,
        ignore_shape_env=ignore_shape_env,
        precompile_backend_id=getattr(mod, "_backend_id", None),
        force_non_lazy_backward_lowering=force_non_lazy_backward_lowering,
    )
    fake_mode, shape_env = construct_fake_mode(full_args, aot_config)
    # NB: full_args_descs not needed here, fake_flat_args is 1:1 with full_args
@@ -1225,6 +1228,12 @@ def aot_export_joint_with_descriptors(
        None,
        ignore_shape_env,
        flatten=True,
        # Without this, we will attempt to "compile" the backward lazily
        # at runtime, but this is pointless because it's just boxed_nop,
        # it's trivial. But this will get Inductor confused about scoping:
        #   Metric(s) {'is_forward'} have already been set in the current
        #   context.
        force_non_lazy_backward_lowering=True,
    )

    # TODO: Maybe this should be in create_aot_state? Not sure, that would
@@ -1271,6 +1280,7 @@ def aot_compile_joint_with_descriptors(jd: JointWithDescriptors) -> callable:
    # Cribbed from torch/export/pt2_archive/_package.py
    @simple_wraps(compiled_fn)
    @torch._dynamo.nonstrict_trace  # allow recursive compilation
    def unflattened_compiled_fn(*args, **kwargs):
        flat_inputs = pytree.tree_flatten((args, reorder_kwargs(kwargs, jd.in_spec)))[0]
        # TODO: do I need to filter? I hope not!
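
As a side note on unflattened_compiled_fn above: the wrapper's job is to take whatever (args, kwargs) structure the caller uses, flatten it with pytree, and feed the flat list to the compiled graph, while the @torch._dynamo.nonstrict_trace decorator allows an outer torch.compile to trace through it. A standalone illustration of that flattening pattern follows (not the real helper; reorder_kwargs and descriptor handling are intentionally omitted):

    # Toy illustration of the flatten-then-call wrapping pattern; not the real
    # unflattened_compiled_fn, and descriptor handling is omitted.
    import torch
    import torch.utils._pytree as pytree


    def make_unflattened(flat_fn):
        def wrapper(*args, **kwargs):
            # Flatten the caller's (args, kwargs) pytree into a flat list of leaves.
            flat_inputs = pytree.tree_flatten((args, kwargs))[0]
            return flat_fn(flat_inputs)

        return wrapper


    # flat_fn stands in for a compiled graph that consumes a flat list of tensors.
    flat_sum = make_unflattened(lambda leaves: sum(t.sum() for t in leaves))
    out = flat_sum(torch.ones(2), scale=torch.tensor(3.0))  # kwargs are flattened too
    assert out.item() == 5.0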