Ensure export joint with descriptors + compile works (#159337)

Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/159337
Approved by: https://github.com/wconstab
ghstack dependencies: #159336

This commit is contained in:
parent 2f0db0444e
commit 31b3b38e3a
@@ -739,6 +739,28 @@ class inner_f(torch.nn.Module):
         self.assertEqual(param_count, len(param_nodes))
 
+    def test_export_and_compile(self):
+        class SimpleModule(nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.linear = nn.Linear(3, 2)
+
+            def forward(self, x):
+                return self.linear(x)
+
+        model = SimpleModule()
+        inputs = (torch.randn(4, 3),)
+
+        with ExitStack() as stack:
+            joint_with_descriptors = aot_export_joint_with_descriptors(
+                stack, model, inputs
+            )
+            model_fn = aot_compile_joint_with_descriptors(joint_with_descriptors)
+
+        compiled_fn = torch.compile(fullgraph=True)(model_fn)
+        compiled_fn(*dict(model.named_parameters()).values(), inputs).sum().backward()
+        self.assertIsNotNone(model.linear.weight.grad)
+
 
 if __name__ == "__main__":
     run_tests()
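
For context, a minimal standalone sketch of the flow this test exercises, outside the unittest harness. The import location (torch._functorch.aot_autograd) and the exact calling convention of the compiled callable are assumptions inferred from this diff, not guaranteed by it:

from contextlib import ExitStack

import torch
import torch.nn as nn

# Assumption: the two helpers used in the test live in torch._functorch.aot_autograd.
from torch._functorch.aot_autograd import (
    aot_compile_joint_with_descriptors,
    aot_export_joint_with_descriptors,
)

model = nn.Linear(3, 2)
inputs = (torch.randn(4, 3),)

with ExitStack() as stack:
    # Export the joint (forward + backward) graph with descriptors ...
    jd = aot_export_joint_with_descriptors(stack, model, inputs)
    # ... then compile it into a callable.
    model_fn = aot_compile_joint_with_descriptors(jd)

# As in the test: parameters are passed flat, followed by the inputs tuple.
compiled = torch.compile(fullgraph=True)(model_fn)
out = compiled(*dict(model.named_parameters()).values(), inputs)
out.sum().backward()
assert model.weight.grad is not None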
@@ -1762,7 +1762,10 @@ def aot_stage2_autograd(
             placeholder_list[i] = ph_arg.as_strided(ph_arg.size(), real_stride)
 
     compiled_bw_func = None
-    if num_symints_saved_for_bw > 0:
+    if (
+        num_symints_saved_for_bw > 0
+        or aot_config.force_non_lazy_backward_lowering
+    ):
         try:
             # See Note: [Backward graph lazy lowering]
             with torch._subclasses.fake_tensor.unset_fake_temporarily():
@@ -1775,6 +1778,8 @@ def aot_stage2_autograd(
                 )
             del bw_module_copy
         except Exception as e:
+            if aot_config.force_non_lazy_backward_lowering:
+                raise
            exc = e
            trace_structured(
                "artifact",
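
To make the control-flow change in these two hunks easier to follow, here is an illustrative, self-contained sketch of the decision being added: lower the backward eagerly when symints are saved for backward or when force_non_lazy_backward_lowering is set, and re-raise eager-lowering failures instead of falling back when the flag is set. The two lowering helpers are hypothetical stand-ins, not the real code paths in aot_stage2_autograd:

def _eager_lowering():
    # Hypothetical stand-in for invoking the backward compiler ahead of time.
    return "backward compiled eagerly"

def _lazy_lowering():
    # Hypothetical stand-in for deferring compilation to the first backward call.
    return "backward compiled lazily at runtime"

def lower_backward(num_symints_saved_for_bw: int, force_non_lazy: bool):
    compiled_bw_func = None
    if num_symints_saved_for_bw > 0 or force_non_lazy:
        try:
            compiled_bw_func = _eager_lowering()
        except Exception:
            if force_non_lazy:
                # New behavior: if eager lowering was explicitly requested,
                # surface the failure instead of silently falling back.
                raise
            # Previous behavior: swallow the error and fall back to lazy lowering.
    if compiled_bw_func is None:
        compiled_bw_func = _lazy_lowering()
    return compiled_bw_func

print(lower_backward(num_symints_saved_for_bw=0, force_non_lazy=True))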
@@ -1949,12 +1949,11 @@ class AOTDispatchAutograd:
             expected_meta = meta.meta
 
             runtime_type = type(x)
-            if torch._dynamo.compiled_autograd.in_compiled_autograd_region:
-                # When we're inside compiled autograd's AOTDispatcher step,
-                # regular Tensors look like FunctionalTensors.
-                # Tensor subclasses still look like Tensor subclasses though.
-                if isinstance(x, torch._subclasses.functional_tensor.FunctionalTensor):
-                    runtime_type = torch.Tensor
+            # When we're inside compiled autograd's AOTDispatcher step,
+            # regular Tensors look like FunctionalTensors.
+            # Tensor subclasses still look like Tensor subclasses though.
+            if isinstance(x, torch._subclasses.functional_tensor.FunctionalTensor):
+                runtime_type = torch.Tensor
 
             runtime_meta = None
             runtime_subclass_keys: Sequence[str] = []
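
A brief note on the AOTDispatchAutograd hunk: dropping the in_compiled_autograd_region guard means the FunctionalTensor-to-Tensor normalization now applies whenever a plain tensor arrives wrapped as a FunctionalTensor, not only inside compiled autograd. A hedged sketch of that normalization in isolation (the helper name is illustrative, not part of the real class):

import torch
from torch._subclasses.functional_tensor import FunctionalTensor

def runtime_type_for_subclass_check(x: torch.Tensor) -> type:
    # Regular tensors can show up wrapped as FunctionalTensor (e.g. inside
    # compiled autograd's AOTDispatcher step); treat them as plain Tensor.
    # Genuine tensor subclasses keep their own type.
    runtime_type = type(x)
    if isinstance(x, FunctionalTensor):
        runtime_type = torch.Tensor
    return runtime_type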
@@ -982,6 +982,7 @@ class AOTConfig:
     # Used only by standalone_compile.
     ignore_shape_env: bool = False
     precompile_backend_id: Optional[str] = None
+    force_non_lazy_backward_lowering: bool = False
 
     def __post_init__(self):
         if self.pre_dispatch:
@@ -893,6 +893,8 @@ def prepare_aot_module_simplified(
     boxed_forward_device_index: BoxedDeviceIndex,
     ignore_shape_env: bool,
     flatten: bool,
+    *,
+    force_non_lazy_backward_lowering: bool = False,
 ):
     if not flatten:
         assert kwargs is None
@@ -982,6 +984,7 @@ def prepare_aot_module_simplified(
         cache_info=None,
         ignore_shape_env=ignore_shape_env,
         precompile_backend_id=getattr(mod, "_backend_id", None),
+        force_non_lazy_backward_lowering=force_non_lazy_backward_lowering,
     )
     fake_mode, shape_env = construct_fake_mode(full_args, aot_config)
     # NB: full_args_descs not needed here, fake_flat_args is 1:1 with full_args
@@ -1225,6 +1228,12 @@ def aot_export_joint_with_descriptors(
         None,
         ignore_shape_env,
         flatten=True,
+        # Without this, we will attempt to "compile" the backward lazily
+        # at runtime, but this is pointless because it's just boxed_nop,
+        # it's trivial. But this will get Inductor confused about scoping
+        # Metric(s) {'is_forward'} have already been set in the current
+        # context.
+        force_non_lazy_backward_lowering=True,
     )
 
     # TODO: Maybe this should be in create_aot_state? Not sure, that would
@@ -1271,6 +1280,7 @@ def aot_compile_joint_with_descriptors(jd: JointWithDescriptors) -> callable:
 
     # Cribbed from torch/export/pt2_archive/_package.py
     @simple_wraps(compiled_fn)
+    @torch._dynamo.nonstrict_trace  # allow recursive compilation
     def unflattened_compiled_fn(*args, **kwargs):
         flat_inputs = pytree.tree_flatten((args, reorder_kwargs(kwargs, jd.in_spec)))[0]
         # TODO: do I need to filter? I hope not!
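
The new @torch._dynamo.nonstrict_trace decorator marks the unflattening wrapper so that an outer torch.compile region can call into it (the "allow recursive compilation" comment). A minimal, hedged sketch of that usage pattern, assuming a recent PyTorch build where torch._dynamo.nonstrict_trace is available as a decorator; precompiled_piece and outer are illustrative names:

import torch

@torch._dynamo.nonstrict_trace
def precompiled_piece(x: torch.Tensor) -> torch.Tensor:
    # Stand-in for a previously AOT-compiled callable that an outer
    # torch.compile region should trace through non-strictly.
    return x.sin() + 1

@torch.compile(fullgraph=True)
def outer(x: torch.Tensor) -> torch.Tensor:
    return precompiled_piece(x) * 2

print(outer(torch.randn(3)))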