Ensure export joint with descriptors + compile works (#159337)
Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/159337
Approved by: https://github.com/wconstab
ghstack dependencies: #159336
parent 2f0db0444e
commit 31b3b38e3a
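The new test in the diff below doubles as a usage example for the two-step flow this commit fixes: export the joint graph with descriptors, compile it, then run the result under torch.compile. Read as a standalone script, it amounts roughly to the following sketch; the import path for the two aot_* helpers is an assumption, since the test file's import block is not part of this diff.

# Minimal sketch of the tested flow; the torch._functorch.aot_autograd import
# path is an assumption (the diff does not show the test file's imports).
from contextlib import ExitStack

import torch
import torch.nn as nn
from torch._functorch.aot_autograd import (  # assumed import location
    aot_compile_joint_with_descriptors,
    aot_export_joint_with_descriptors,
)


class SimpleModule(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(3, 2)

    def forward(self, x):
        return self.linear(x)


model = SimpleModule()
inputs = (torch.randn(4, 3),)

with ExitStack() as stack:
    # Export the joint forward/backward graph together with its descriptors.
    joint_with_descriptors = aot_export_joint_with_descriptors(stack, model, inputs)
    # Lower the joint graph into a callable that takes (parameters..., inputs).
    model_fn = aot_compile_joint_with_descriptors(joint_with_descriptors)

# The compiled callable should itself survive torch.compile(fullgraph=True);
# that combination is what this commit ensures works.
compiled_fn = torch.compile(fullgraph=True)(model_fn)
compiled_fn(*dict(model.named_parameters()).values(), inputs).sum().backward()
assert model.linear.weight.grad is not None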
@@ -739,6 +739,28 @@ class inner_f(torch.nn.Module):
         self.assertEqual(param_count, len(param_nodes))
 
+    def test_export_and_compile(self):
+        class SimpleModule(nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.linear = nn.Linear(3, 2)
+
+            def forward(self, x):
+                return self.linear(x)
+
+        model = SimpleModule()
+        inputs = (torch.randn(4, 3),)
+
+        with ExitStack() as stack:
+            joint_with_descriptors = aot_export_joint_with_descriptors(
+                stack, model, inputs
+            )
+            model_fn = aot_compile_joint_with_descriptors(joint_with_descriptors)
+
+        compiled_fn = torch.compile(fullgraph=True)(model_fn)
+        compiled_fn(*dict(model.named_parameters()).values(), inputs).sum().backward()
+        self.assertIsNotNone(model.linear.weight.grad)
+
 
 if __name__ == "__main__":
     run_tests()
@@ -1762,7 +1762,10 @@ def aot_stage2_autograd(
             placeholder_list[i] = ph_arg.as_strided(ph_arg.size(), real_stride)
 
     compiled_bw_func = None
-    if num_symints_saved_for_bw > 0:
+    if (
+        num_symints_saved_for_bw > 0
+        or aot_config.force_non_lazy_backward_lowering
+    ):
         try:
             # See Note: [Backward graph lazy lowering]
             with torch._subclasses.fake_tensor.unset_fake_temporarily():
@@ -1775,6 +1778,8 @@ def aot_stage2_autograd(
                 )
                 del bw_module_copy
         except Exception as e:
+            if aot_config.force_non_lazy_backward_lowering:
+                raise
             exc = e
             trace_structured(
                 "artifact",
@@ -1949,12 +1949,11 @@ class AOTDispatchAutograd:
                 expected_meta = meta.meta
 
             runtime_type = type(x)
-            if torch._dynamo.compiled_autograd.in_compiled_autograd_region:
-                # When we're inside compiled autograd's AOTDispatcher step,
-                # regular Tensors look like FunctionalTensors.
-                # Tensor subclasses still look like Tensor subclasses though.
-                if isinstance(x, torch._subclasses.functional_tensor.FunctionalTensor):
-                    runtime_type = torch.Tensor
+            # When we're inside compiled autograd's AOTDispatcher step,
+            # regular Tensors look like FunctionalTensors.
+            # Tensor subclasses still look like Tensor subclasses though.
+            if isinstance(x, torch._subclasses.functional_tensor.FunctionalTensor):
+                runtime_type = torch.Tensor
 
             runtime_meta = None
             runtime_subclass_keys: Sequence[str] = []
@@ -982,6 +982,7 @@ class AOTConfig:
     # Used only by standalone_compile.
     ignore_shape_env: bool = False
     precompile_backend_id: Optional[str] = None
+    force_non_lazy_backward_lowering: bool = False
 
     def __post_init__(self):
         if self.pre_dispatch:
@@ -893,6 +893,8 @@ def prepare_aot_module_simplified(
     boxed_forward_device_index: BoxedDeviceIndex,
     ignore_shape_env: bool,
     flatten: bool,
+    *,
+    force_non_lazy_backward_lowering: bool = False,
 ):
     if not flatten:
         assert kwargs is None
@@ -982,6 +984,7 @@ def prepare_aot_module_simplified(
         cache_info=None,
         ignore_shape_env=ignore_shape_env,
         precompile_backend_id=getattr(mod, "_backend_id", None),
+        force_non_lazy_backward_lowering=force_non_lazy_backward_lowering,
     )
     fake_mode, shape_env = construct_fake_mode(full_args, aot_config)
     # NB: full_args_descs not needed here, fake_flat_args is 1:1 with full_args
@@ -1225,6 +1228,12 @@ def aot_export_joint_with_descriptors(
         None,
         ignore_shape_env,
         flatten=True,
+        # Without this, we will attempt to "compile" the backward lazily
+        # at runtime, but this is pointless because it's just boxed_nop,
+        # it's trivial. But this will get Inductor confused about scoping
+        # Metric(s) {'is_forward'} have already been set in the current
+        # context.
+        force_non_lazy_backward_lowering=True,
     )
 
     # TODO: Maybe this should be in create_aot_state? Not sure, that would
@@ -1271,6 +1280,7 @@ def aot_compile_joint_with_descriptors(jd: JointWithDescriptors) -> callable:
 
     # Cribbed from torch/export/pt2_archive/_package.py
     @simple_wraps(compiled_fn)
+    @torch._dynamo.nonstrict_trace  # allow recursive compilation
     def unflattened_compiled_fn(*args, **kwargs):
         flat_inputs = pytree.tree_flatten((args, reorder_kwargs(kwargs, jd.in_spec)))[0]
         # TODO: do I need to filter? I hope not!