From 68ddca94498fd7961cc5ebcb0dffafb8c2f4baca Mon Sep 17 00:00:00 2001 From: Simon Fan Date: Mon, 24 Feb 2025 20:03:05 -0800 Subject: [PATCH] [ca] trace saved variable unpacking (#147242) ## Before Previously, CA will always unpack all saved variables stored in the autograd graph before executing it. This meant that we can't capture unpack hooks as part of the CA graph, and they would fire out of order wrt to other backward hooks. For memory saving APIs built on top of saved tensor hooks like non-reentrant checkpointing and offloading, we couldn't achieve any savings because all activations would be recomputed/loaded and active at the same time, resulting in no-op. ## After We add unpack hooks into the CA graph so that they can be executed progressively. The python hook and hook input themselves are wrapped by non-traceable code, so CA polyfills the wrapping as: ```python # pseudocode class SavedVariable: def unpack(self): if self.hook: return self.hook(self.packed_data) else: return self.packed_data # This approach won't directly work when we add support for Forward AD or double-backward. ``` Directly executing the CA graph (without torch.compiling it) under checkpointing/offloading, memory profile is expected to stay the same as when using the eager autograd engine. If AOT backward is in the autograd graph, memory profile is expected to be better than the eager autograd engine, since we can now delay saved activations unpacking into the AOT backward's execution. All tests pass when running the CA graph directly, the remaining issues are in Dynamo. Pull Request resolved: https://github.com/pytorch/pytorch/pull/147242 Approved by: https://github.com/jansel --- test/inductor/test_compiled_autograd.py | 309 ++++++++++++++++-- torch/_dynamo/compiled_autograd.py | 55 +++- torch/csrc/autograd/engine.cpp | 1 + .../autograd/python_saved_variable_hooks.cpp | 9 + .../autograd/python_saved_variable_hooks.h | 3 + torch/csrc/autograd/saved_variable.cpp | 12 +- torch/csrc/autograd/saved_variable.h | 10 + torch/csrc/autograd/saved_variable_hooks.h | 6 + torch/csrc/dynamo/compiled_autograd.h | 201 +++++++----- .../csrc/dynamo/python_compiled_autograd.cpp | 25 +- 10 files changed, 511 insertions(+), 120 deletions(-) diff --git a/test/inductor/test_compiled_autograd.py b/test/inductor/test_compiled_autograd.py index af933b70b7e..d8ed0107fc1 100644 --- a/test/inductor/test_compiled_autograd.py +++ b/test/inductor/test_compiled_autograd.py @@ -10,6 +10,7 @@ import os import re import subprocess import sys +import tempfile import unittest from copy import deepcopy from importlib.machinery import SourceFileLoader @@ -48,11 +49,17 @@ from torch.testing._internal.logging_utils import logs_to_string # note: these tests are not run on windows due to inductor_utils.HAS_CPU -def make_compiler_fn(fullgraph=True, dynamic=True, backend="inductor"): - assert backend in ["inductor", "aot_eager"] +def make_compiler_fn( + fullgraph=True, dynamic=True, backend="inductor", gm_hook=lambda gm: None +): + assert backend in ["inductor", "aot_eager", "ca_eager"] def _compiler_fn(gm): """Same as torch.compile() but counts number of compiles""" + gm_hook(gm) + + if backend == "ca_eager": + return gm def _inner_compiler(gm_, example_inputs_): counters["compiled_autograd"]["compiles"] += 1 @@ -915,7 +922,8 @@ main() inputs=[param, activ], sizes=(), scalars=(), - hooks=(), + hooks=[], + packed_inputs=[], ) finally: handle.remove() @@ -3336,7 +3344,7 @@ TORCH_LIBRARY(test_cudagraphs_cpu_scalar_used_in_cpp_custom_op, m) { graph_code, """\ class CompiledAutograd0(torch.nn.Module): - def forward(self, inputs, sizes, scalars, hooks): + def forward(self, inputs, sizes, scalars, hooks, packed_data): getitem = inputs[0] getitem_1 = inputs[1] getitem_2 = inputs[2] @@ -3511,9 +3519,7 @@ class CompiledAutograd0(torch.nn.Module): fn, count=2, compiler_fn=make_compiler_fn(backend="aot_eager") ) - @unittest.expectedFailure def test_saved_tensor_unpack_hook_ordering(self): - # not the correct behaviour, I'm just preventing this from changing silently def f(x, y): return x * y @@ -3531,8 +3537,6 @@ class CompiledAutograd0(torch.nn.Module): return x def tensor_hook(_): - # in eager, tensor_hook is fired before unpack_hook - # but in compiled autograd, tensor_hook is lifted whereas unpack_hook is not self.assertEqual(unpack_count, 0) x = torch.ones(4, requires_grad=True) @@ -3544,21 +3548,252 @@ class CompiledAutograd0(torch.nn.Module): self.assertEqual(pack_count, 1) self.assertEqual(unpack_count, 0) loss = out_test.sum() - loss.register_hook(tensor_hook) + loss.register_hook( + tensor_hook + ) # scheduled to fire before any saved activations loss.backward() self.assertEqual(pack_count, 1) self.assertEqual(unpack_count, 1) - def test_reentrant_checkpointing(self): - def fn(x): - y = x.sin() - z = y.cos() - return (y * z).sum() + @parametrize("reentrant", (True, False)) + def test_checkpointing_simple(self, reentrant): + def fn(): + def _fn(x): + y = x.sin() + z = y.cos() + return (y * z).sum() - inp = torch.rand(10, 10, requires_grad=True) - out = torch.utils.checkpoint.checkpoint(fn, inp, use_reentrant=True) - with torch._dynamo.compiled_autograd._enable(torch.compile): + inp = torch.rand(10, 10, requires_grad=True) + out = torch.utils.checkpoint.checkpoint(_fn, inp, use_reentrant=reentrant) out.backward() + yield inp.grad + + if reentrant: + self.check_output_and_recompiles( + fn, count=[1, 3], compiler_fn=make_compiler_fn(fullgraph=False) + ) + else: + # dynamo issues, just run the CA graph directly for now + def check(gm): + graph_code = normalize_gm(gm.print_readable(print_output=False)) + self.assertExpectedInline( + graph_code, + """\ +class CompiledAutograd0(torch.nn.Module): + def forward(self, inputs, sizes, scalars, hooks, packed_data): + getitem = inputs[0] + getitem_1 = inputs[1]; inputs = None + + validate_outputs = torch__dynamo_compiled_autograd_ops_validate_outputs([getitem], [((None, None, device(type='cpu'), 6, 0, None), [], False)]); getitem = None + getitem_2 = validate_outputs[0]; validate_outputs = None + + sum_backward0 = torch__dynamo_compiled_autograd_ops_SumBackward0([getitem_2], [True], [10, 10]); getitem_2 = None + getitem_3 = sum_backward0[0]; sum_backward0 = None + validate_outputs_1 = torch__dynamo_compiled_autograd_ops_validate_outputs([getitem_3], [((None, None, device(type='cpu'), 6, 0, None), [10, 10], False)]); getitem_3 = None + getitem_4 = validate_outputs_1[0]; validate_outputs_1 = None + + getitem_5 = hooks[0] + getitem_6 = packed_data[0] + getitem_7 = hooks[1] + getitem_8 = packed_data[1] + call_hook = torch__dynamo_external_utils_call_hook(getitem_5, getitem_6, hook_type = 'unpack_hook'); getitem_5 = getitem_6 = None + call_hook_1 = torch__dynamo_external_utils_call_hook(getitem_7, getitem_8, hook_type = 'unpack_hook'); getitem_7 = getitem_8 = None + mul_backward0 = torch__dynamo_compiled_autograd_ops_MulBackward0([getitem_4], [True, True], call_hook, 6, call_hook_1, 6); getitem_4 = call_hook = call_hook_1 = None + getitem_9 = mul_backward0[0] + getitem_10 = mul_backward0[1]; mul_backward0 = None + validate_outputs_2 = torch__dynamo_compiled_autograd_ops_validate_outputs([getitem_9, getitem_10], [((None, None, device(type='cpu'), 6, 0, None), [10, 10], False), ((None, None, device(type='cpu'), 6, 0, None), [10, 10], False)]); getitem_9 = getitem_10 = None + getitem_11 = validate_outputs_2[0] + getitem_12 = validate_outputs_2[1]; validate_outputs_2 = None + + getitem_13 = hooks[2] + getitem_14 = packed_data[2] + call_hook_2 = torch__dynamo_external_utils_call_hook(getitem_13, getitem_14, hook_type = 'unpack_hook'); getitem_13 = getitem_14 = None + cos_backward0 = torch__dynamo_compiled_autograd_ops_CosBackward0([getitem_12], [True], call_hook_2); getitem_12 = call_hook_2 = None + getitem_15 = cos_backward0[0]; cos_backward0 = None + validate_outputs_3 = torch__dynamo_compiled_autograd_ops_validate_outputs([getitem_15], [((None, None, device(type='cpu'), 6, 0, None), [10, 10], False)]); getitem_15 = None + getitem_16 = validate_outputs_3[0]; validate_outputs_3 = None + add = torch.add(getitem_11, getitem_16); getitem_11 = getitem_16 = None + + getitem_17 = hooks[3]; hooks = None + getitem_18 = packed_data[3]; packed_data = None + call_hook_3 = torch__dynamo_external_utils_call_hook(getitem_17, getitem_18, hook_type = 'unpack_hook'); getitem_17 = getitem_18 = None + sin_backward0 = torch__dynamo_compiled_autograd_ops_SinBackward0([add], [True], call_hook_3); add = call_hook_3 = None + getitem_19 = sin_backward0[0]; sin_backward0 = None + validate_outputs_4 = torch__dynamo_compiled_autograd_ops_validate_outputs([getitem_19], [((None, None, device(type='cpu'), 6, 0, None), [10, 10], False)]); getitem_19 = None + getitem_20 = validate_outputs_4[0]; validate_outputs_4 = None + + accumulate_grad_ = torch.ops.inductor.accumulate_grad_.default(getitem_1, getitem_20); getitem_1 = getitem_20 = accumulate_grad_ = None + _exec_final_callbacks_stub = torch__dynamo_external_utils__exec_final_callbacks_stub(); _exec_final_callbacks_stub = None + return [] +""", # noqa: B950 + ) + + self.check_output_and_recompiles( + fn, + count=[1, 0], + compiler_fn=make_compiler_fn(backend="ca_eager", gm_hook=check), + ) + + @unittest.skipIf(not HAS_CUDA, "requires cuda") + def test_cpu_offloading(self): + def fn(): + def pack(x): + return x.cpu() + + def unpack(x): + return x.cuda() + + class MyMatMul(torch.autograd.Function): + @staticmethod + def forward(ctx, x): + ctx.save_for_backward(x) + return torch.matmul(x, x) + + @staticmethod + def backward(ctx, grad_out): + (x,) = ctx.saved_tensors + return grad_out * x + + with torch.autograd.graph.saved_tensors_hooks(pack, unpack): + for i in [10, 100, 10, 20, 30]: + x = torch.randn(i, requires_grad=True).cuda() + MyMatMul.apply(x).sum().backward() + yield x.grad + + i = 0 + + def check(gm): + nonlocal i + if i == 0: + i += 1 + return + + graph_code = normalize_gm(gm.print_readable(print_output=False)) + self.assertExpectedInline( + graph_code, + """\ +class CompiledAutograd1(torch.nn.Module): + def forward(self, inputs, sizes, scalars, hooks, packed_data): + getitem = inputs[0] + getitem_1 = inputs[1]; inputs = None + getitem_2 = sizes[0]; getitem_2 = None + getitem_3 = sizes[1] + getitem_4 = sizes[2]; sizes = None + + validate_outputs = torch__dynamo_compiled_autograd_ops_validate_outputs([getitem], [((None, None, device(type='cuda', index=0), 6, 0, None), [], False)]); getitem = None + getitem_5 = validate_outputs[0]; validate_outputs = None + + sum_backward0 = torch__dynamo_compiled_autograd_ops_SumBackward0([getitem_5], [True], []); getitem_5 = None + getitem_6 = sum_backward0[0]; sum_backward0 = None + validate_outputs_1 = torch__dynamo_compiled_autograd_ops_validate_outputs([getitem_6], [((None, None, device(type='cuda', index=0), 6, 0, None), [], False)]); getitem_6 = None + getitem_7 = validate_outputs_1[0]; validate_outputs_1 = None + + getitem_8 = hooks[0] + getitem_9 = packed_data[0]; packed_data = None + getitem_10 = hooks[1]; hooks = None + call_hook = torch__dynamo_external_utils_call_hook(getitem_8, getitem_9, hook_type = 'unpack_hook'); getitem_8 = getitem_9 = None + call_backward = torch__dynamo_external_utils_call_backward(getitem_10, (call_hook,), getitem_7); getitem_10 = call_hook = getitem_7 = None + getitem_12 = call_backward[0]; call_backward = None + validate_outputs_2 = torch__dynamo_compiled_autograd_ops_validate_outputs([getitem_12], [((None, None, device(type='cuda', index=0), 6, 0, None), [getitem_3], False)]); getitem_12 = getitem_3 = None + getitem_13 = validate_outputs_2[0]; validate_outputs_2 = None + + to_copy_backward0 = torch__dynamo_compiled_autograd_ops_ToCopyBackward0([getitem_13], [True], (None, None, device(type='cpu'), 6, 0, None)); getitem_13 = None + getitem_14 = to_copy_backward0[0]; to_copy_backward0 = None + validate_outputs_3 = torch__dynamo_compiled_autograd_ops_validate_outputs([getitem_14], [((None, None, device(type='cpu'), 6, 0, None), [getitem_4], False)]); getitem_14 = getitem_4 = None + getitem_15 = validate_outputs_3[0]; validate_outputs_3 = None + + accumulate_grad_ = torch.ops.inductor.accumulate_grad_.default(getitem_1, getitem_15); getitem_1 = getitem_15 = accumulate_grad_ = None + _exec_final_callbacks_stub = torch__dynamo_external_utils__exec_final_callbacks_stub(); _exec_final_callbacks_stub = None + return [] +""", # noqa: B950 + ) + + self.check_output_and_recompiles( + fn, count=2, compiler_fn=make_compiler_fn(gm_hook=check) + ) + + @skipIfWindows(msg="temp dir not compatible") + def test_disk_offloading(self): + with tempfile.TemporaryDirectory() as d: + + def fn(): + pack_count = 0 + + def pack(x): + nonlocal pack_count + path = f"{d}/{pack_count}.pt" + torch.save(x, path) + return path + + def unpack(path): + x = torch.load(path) + return x + + class MyMatMul(torch.autograd.Function): + @staticmethod + def forward(ctx, x): + ctx.save_for_backward(x) + return torch.matmul(x, x) + + @staticmethod + def backward(ctx, grad_out): + (x,) = ctx.saved_tensors + return grad_out * x + + with torch.autograd.graph.saved_tensors_hooks(pack, unpack): + for i in [10, 100, 10, 20, 30]: + x = torch.randn(i, requires_grad=True) + MyMatMul.apply(x).sum().backward() + yield x.grad + + i = 0 + + def check(gm): + nonlocal i + if i == 0: + i += 1 + return + + graph_code = normalize_gm(gm.print_readable(print_output=False)) + self.assertExpectedInline( + graph_code, + """\ +class CompiledAutograd1(torch.nn.Module): + def forward(self, inputs, sizes, scalars, hooks, packed_data): + getitem = inputs[0] + getitem_1 = inputs[1]; inputs = None + getitem_2 = sizes[0]; getitem_2 = None + getitem_3 = sizes[1]; sizes = None + + validate_outputs = torch__dynamo_compiled_autograd_ops_validate_outputs([getitem], [((None, None, device(type='cpu'), 6, 0, None), [], False)]); getitem = None + getitem_4 = validate_outputs[0]; validate_outputs = None + + sum_backward0 = torch__dynamo_compiled_autograd_ops_SumBackward0([getitem_4], [True], []); getitem_4 = None + getitem_5 = sum_backward0[0]; sum_backward0 = None + validate_outputs_1 = torch__dynamo_compiled_autograd_ops_validate_outputs([getitem_5], [((None, None, device(type='cpu'), 6, 0, None), [], False)]); getitem_5 = None + getitem_6 = validate_outputs_1[0]; validate_outputs_1 = None + + getitem_7 = hooks[0] + getitem_8 = packed_data[0]; packed_data = None + getitem_9 = hooks[1]; hooks = None + call_hook = torch__dynamo_external_utils_call_hook(getitem_7, getitem_8, hook_type = 'unpack_hook'); getitem_7 = getitem_8 = None + call_backward = torch__dynamo_external_utils_call_backward(getitem_9, (call_hook,), getitem_6); getitem_9 = call_hook = getitem_6 = None + getitem_11 = call_backward[0]; call_backward = None + validate_outputs_2 = torch__dynamo_compiled_autograd_ops_validate_outputs([getitem_11], [((None, None, device(type='cpu'), 6, 0, None), [getitem_3], False)]); getitem_11 = getitem_3 = None + getitem_12 = validate_outputs_2[0]; validate_outputs_2 = None + + accumulate_grad_ = torch.ops.inductor.accumulate_grad_.default(getitem_1, getitem_12); getitem_1 = getitem_12 = accumulate_grad_ = None + _exec_final_callbacks_stub = torch__dynamo_external_utils__exec_final_callbacks_stub(); _exec_final_callbacks_stub = None + return [] +""", # noqa: B950 + ) + + # 1 graph break on torch.load -> 2 dynamo graphs + self.check_output_and_recompiles( + fn, + count=[2, 4], + compiler_fn=make_compiler_fn(fullgraph=False, gm_hook=check), + ) @skipIfWindows(msg="node name demangling inconsistent on windows") def test_backward_hook_relative_ordering_partial(self): @@ -3617,7 +3852,7 @@ class CompiledAutograd0(torch.nn.Module): self.check_output_and_recompiles(fn) - def test_sac(self): + def test_checkpointing_sac(self): # circular import from torch.utils.checkpoint import ( checkpoint, @@ -3666,7 +3901,9 @@ class CompiledAutograd0(torch.nn.Module): yield model.layer4.weight.grad yield model.layer4.bias.grad - self.check_output_and_recompiles(fn) + self.check_output_and_recompiles( + fn, count=[1, 5], compiler_fn=make_compiler_fn(fullgraph=False) + ) def load_test_module(name): @@ -3754,6 +3991,22 @@ known_graph_breaks_tests = { "test_deep_reentrant", # reentrant .backward "test_reentrant_priority", # reentrant .backward "test_simple_reentrant", # reentrant .backward + "test_checkpoint_detects_non_determinism", # unpack hook in skip files + "test_checkpoint_valid_reset_on_error", # unpack hook in skip files + "test_checkpointing_non_reentrant_autocast_cpu", # unpack hook in skip files + "test_checkpointing_non_reentrant_autocast_gpu", # unpack hook in skip files + "test_checkpointing_without_reentrant_arbitrary_input_output", # unpack hook in skip files + "test_checkpointing_without_reentrant_correct_grad", # unpack hook in skip files + "test_checkpointing_without_reentrant_custom_function_works", # unpack hook in skip files + "test_checkpointing_without_reentrant_dataparallel", # _get_device_index in skip files + "test_checkpointing_without_reentrant_detached_tensor_use_reentrant_True", # reentrant .backward + "test_checkpointing_without_reentrant_parameter_used_in_an_out", # unpack hook in skip files + "test_checkpointing_without_reentrant_with_context_fn", # unpack hook in skip files + "test_save_on_cpu_and_checkpoint", # unpack hook in skip files + "test_saved_tensor_hooks_custom_error_propagation", # CustomError + "test_access_saved_tensor_twice_without_recomputation_works", # unpack hook in skip files + "test_saved_tensor_hooks_extra_enter_during_bw_no_leak", # ctx in skip files + "test_saved_tensor_hooks_extra_exit_during_bw_no_crash", # ctx in skip files } test_contexts = { @@ -3764,9 +4017,7 @@ test_contexts = { } # These groups of tests aren't supported yet -known_failures_re = re.compile( - r"^test_(sparse|profiler|gradcheck|checkpoint|named_tensor)" -) +known_failures_re = re.compile(r"^test_(sparse|profiler|gradcheck|named_tensor)") # Bugs needing investigation: skipped_tests = { @@ -3837,7 +4088,7 @@ known_failing_tests = { # IndexError: list index out of range (NB: x.grad = y where both x and y are input tensors) "test_grad_nonleaf_register_hook", "test_backward_twice_without_saved_values", # https://github.com/pytorch/pytorch/issues/129938 - # Category: Dynamo + # Category: Dynamo (pass when directly running CA graph) "test_accumulate_grad_tensor_reference", # Out of bounds: frame_state_entry.stride[i] is None "test_custom_function_exception", # torch.no_grad(), torch._dynamo.exc.Unsupported: missing: WITH_EXCEPT_START "test_to_sparse_backward", # Out of bounds: frame_state_entry.stride[i] is None @@ -3849,7 +4100,16 @@ known_failing_tests = { "test_return_duplicate", # gradient batching rule not implemented for aten::sym_size.int "test_return_duplicate_inplace", # gradient batching rule not implemented for aten::sym_size.int "test_setitem", # CopySlices accuracy error - # Category: Inductor + "test_save_on_cpu_and_checkpoint", # https://github.com/pytorch/pytorch/issues/147565 + "test_checkpoint_detects_non_determinism", # different error + "test_checkpointing_non_reentrant_autocast_cpu", # saved != recompute + "test_checkpointing_non_reentrant_autocast_gpu", # saved != recompute + "test_checkpointing_without_reentrant_saved_object_identity", # same as https://github.com/pytorch/pytorch/issues/136193 + "test_saved_variable_packing_unpacking_did_not_save_original_with_hooks", # register_hooks multiple times + "test_saved_variable_saved_original_inplace_detach", # RuntimeError not raised + "test_access_saved_tensor_twice_without_recomputation_works", # saved != recompute + "test_checkpointing_without_reentrant_dataparallel", # https://github.com/pytorch/pytorch/issues/127115 + # Category: Inductor (pass on backend="aot_eager") "test_input_buffer_accum", # does not support sparse_grad=True: https://github.com/pytorch/pytorch/issues/120267 "test_graph_save_on_cpu", # does not support pin_memory: https://github.com/pytorch/pytorch/issues/134173 # Category: FakeTensor @@ -3861,6 +4121,7 @@ known_failing_tests = { "test_invalid_gradients", # can't give autograd error due to inaccurate output metadata of lifted backward "test_autograd_node_isinstance", # backward ctx is a fake cls and not directly a Node instance "test_backward_hook_relative_ordering", # compiled autograd collects breadth first, and module backward hook not supported + "test_checkpointing_without_reentrant_custom_function_works", # ctx.saved_tensors are cached by CA # Category: Subclasses "test_dtensor_basic", "test_dtensor_contiguous_dtensor_noncontiguous_local_as_tangent", diff --git a/torch/_dynamo/compiled_autograd.py b/torch/_dynamo/compiled_autograd.py index a2b4b010c0a..c39bb809c77 100644 --- a/torch/_dynamo/compiled_autograd.py +++ b/torch/_dynamo/compiled_autograd.py @@ -134,7 +134,7 @@ class Op: ops = OpNamespace() -_graph_placeholders = ["inputs", "sizes", "scalars", "hooks"] +_graph_placeholders = ["inputs", "sizes", "scalars", "hooks", "packed_data"] _impure_targets = OrderedSet( [ call_hook, @@ -206,7 +206,13 @@ class AutogradCompilerInstance: self.fx_tracer.graph = torch.fx.Graph(tracer_cls=PythonKeyTracer) self.fx_tracer.tensor_attrs = {} self.symnode_proxy_lookup = {} - args_proxy, self.sizes_proxy, self.scalars_proxy, self.hooks_proxy = ( + ( + args_proxy, + self.sizes_proxy, + self.scalars_proxy, + self.hooks_proxy, + self.packed_data_proxy, + ) = ( self.fx_tracer.create_proxy("placeholder", name, (), {}) for name in _graph_placeholders ) @@ -268,7 +274,12 @@ class AutogradCompilerInstance: self.stack.enter_context( torch.fx.experimental.symbolic_shapes._suppress_guards(env) ) - return str(CompileContext.current_compile_id()), inputs, sizes, scalars + return ( + str(CompileContext.current_compile_id()), + inputs, + sizes, + scalars, + ) def log_compile_reasons( self, @@ -567,6 +578,19 @@ class AutogradCompilerInstance: kwargs, ) + def unpack_hook(self, hook_id, data_id): + assert self.hooks_proxy is not None + hook = self.hooks_proxy[hook_id] # type: ignore[index] + data = self.packed_data_proxy[data_id] # type: ignore[index] + proxy = self.proxy_call_hook( + hook, + data, + hook_type="unpack_hook", + ) + out = self.allocate_dummy() + self.bind_objects_to_proxies([out], [proxy]) + return out + def tensor_pre_hook(self, inputs, hook_id, i: int): assert self.hooks_proxy is not None hook = self.hooks_proxy[hook_id] # type: ignore[index] @@ -706,6 +730,9 @@ class AutogradCompilerInstance: after = len(self.fx_tracer.graph.nodes) verbose_log.debug("DCE removed %d nodes", before - after) + def create_graph_module(self, id): + return GraphModule(self.fx_tracer.root, self.fx_tracer.graph, id) + def end_capture(self, outputs): self.fx_tracer.create_proxy( "call_function", @@ -745,6 +772,7 @@ class AutogradCompilerInstance: ).print_readable(print_output=False), ) self.rename_aot_dispatcher_nodes() + self.delay_unpack_hook_nodes() self.reorder_tensor_pre_hook_nodes() self.reorder_pre_hook_nodes_to_schedule_asap() self.reorder_accumulate_grad_nodes() @@ -763,9 +791,7 @@ class AutogradCompilerInstance: # should prevent these ops from going into the CA graph. self.dce() - graph = GraphModule( - self.fx_tracer.root, self.fx_tracer.graph, f"CompiledAutograd{self.id}" - ) + graph = self.create_graph_module(f"CompiledAutograd{self.id}") set_locals_to_steal(graph, ["inputs"]) lazy_graph_code = lazy_format_graph_code( "Compiled autograd graph", @@ -781,7 +807,7 @@ class AutogradCompilerInstance: payload_fn=lambda: graph.print_readable(print_output=False), ) - def runtime_wrapper(compiled_fn, inputs, sizes, scalars, hooks): + def runtime_wrapper(compiled_fn, inputs, sizes, scalars, hooks, packed_inputs): global in_compiled_autograd_region try: in_compiled_autograd_region = True @@ -789,7 +815,7 @@ class AutogradCompilerInstance: inputs[i] = inputs[i].pin_memory().cuda(non_blocking=True) with _disable(), make_compile_context(self.id): - return compiled_fn(inputs, sizes, scalars, hooks) + return compiled_fn(inputs, sizes, scalars, hooks, packed_inputs) finally: in_compiled_autograd_region = False @@ -938,6 +964,19 @@ class AutogradCompilerInstance: if getitem_node is not None: arg.append(getitem_node) + def delay_unpack_hook_nodes(self): + """ + We can delay unpack hooks until they are needed, even later than in the eager autograd engine. + """ + for node in self.fx_tracer.graph.find_nodes( + op="call_function", target=call_hook + ): + if node.kwargs.get("hook_type", None) != "unpack_hook": + continue + + first_user = min(node.users) + first_user.prepend(node) + def reorder_tensor_pre_hook_nodes(self): """ Usage of AOTAutograd causes all the tensor_pre_hook nodes to get pushed diff --git a/torch/csrc/autograd/engine.cpp b/torch/csrc/autograd/engine.cpp index c8d465211a6..a33fb25e500 100644 --- a/torch/csrc/autograd/engine.cpp +++ b/torch/csrc/autograd/engine.cpp @@ -1334,6 +1334,7 @@ auto Engine::execute( !AnomalyMode::is_enabled(), "compiled_autograd does not support AnomalyMode") GraphTaskGuard guard(graph_task); + CheckpointValidGuard cpvguard(graph_task); return (*compiled_autograd)( graph_root, *graph_task, accumulate_grad, outputs); } diff --git a/torch/csrc/autograd/python_saved_variable_hooks.cpp b/torch/csrc/autograd/python_saved_variable_hooks.cpp index 30fde593e0c..431a98f43d1 100644 --- a/torch/csrc/autograd/python_saved_variable_hooks.cpp +++ b/torch/csrc/autograd/python_saved_variable_hooks.cpp @@ -46,6 +46,15 @@ at::Tensor PySavedVariableHooks::call_unpack_hook() { // unpack_hook_ will be manually decrefed when the saved variable is released } +std::optional> +PySavedVariableHooks::retrieve_unpack_hook_data() const { + Py_INCREF(unpack_hook_); + Py_INCREF(data_); + return std::make_pair( + c10::SafePyObject(unpack_hook_, getPyInterpreter()), + c10::SafePyObject(data_, getPyInterpreter())); +} + // NOLINTNEXTLINE(bugprone-exception-escape) PySavedVariableHooks::~PySavedVariableHooks() { // If python is already dead, leak the wrapped python objects diff --git a/torch/csrc/autograd/python_saved_variable_hooks.h b/torch/csrc/autograd/python_saved_variable_hooks.h index ed7e1a28768..151d221458d 100644 --- a/torch/csrc/autograd/python_saved_variable_hooks.h +++ b/torch/csrc/autograd/python_saved_variable_hooks.h @@ -1,6 +1,7 @@ #pragma once #include +#include #include #include #include @@ -17,6 +18,8 @@ struct PySavedVariableHooks : public SavedVariableHooks { void call_pack_hook(const at::Tensor& tensor) override; at::Tensor call_unpack_hook() override; ~PySavedVariableHooks() override; + std::optional> + retrieve_unpack_hook_data() const override; private: PyObject* pack_hook_; diff --git a/torch/csrc/autograd/saved_variable.cpp b/torch/csrc/autograd/saved_variable.cpp index a85275ceb04..4eeccb18977 100644 --- a/torch/csrc/autograd/saved_variable.cpp +++ b/torch/csrc/autograd/saved_variable.cpp @@ -59,6 +59,7 @@ SavedVariable::SavedVariable( if (maybe_hooks && !variable.unsafeGetTensorImpl()->is_wrapped_number()) { save_metadata(variable); set_hooks_and_pack_data(std::move(maybe_hooks), variable); + TORCH_INTERNAL_ASSERT(!data_.defined()); return; } @@ -134,9 +135,14 @@ Variable SavedVariable::unpack(std::shared_ptr saved_for) const { // We want grad_fn here to provide the most helpful debug message to the user // if versions don't match - auto grad_fn = is_inplace_on_view_ ? weak_grad_fn_.lock() - : !hooks_ ? saved_original_ ? data_.grad_fn() : nullptr - : grad_fn_; + std::shared_ptr grad_fn; + if (is_inplace_on_view_) { + grad_fn = weak_grad_fn_.lock(); + } else if (!hooks_) { + grad_fn = saved_original_ ? data_.grad_fn() : nullptr; + } else { + grad_fn = grad_fn_; + } if (!is_leaf_ && !grad_fn) { // This issue was introduced when we added logic to save the original diff --git a/torch/csrc/autograd/saved_variable.h b/torch/csrc/autograd/saved_variable.h index 0d28c95e19a..78510969400 100644 --- a/torch/csrc/autograd/saved_variable.h +++ b/torch/csrc/autograd/saved_variable.h @@ -1,5 +1,6 @@ #pragma once +#include #include #include #include @@ -53,6 +54,15 @@ class TORCH_API SavedVariable { return (bool)hooks_; } + // Used by compiled autograd + std::optional> + retrieve_unpack_hook_data() const { + if (!hooks_) { + return std::nullopt; + } + return hooks_->retrieve_unpack_hook_data(); + } + private: // This field contains either: // 1. the variable to save diff --git a/torch/csrc/autograd/saved_variable_hooks.h b/torch/csrc/autograd/saved_variable_hooks.h index 2bbc8f92d42..ed255d34a04 100644 --- a/torch/csrc/autograd/saved_variable_hooks.h +++ b/torch/csrc/autograd/saved_variable_hooks.h @@ -1,6 +1,7 @@ #pragma once #include +#include namespace torch::autograd { @@ -8,6 +9,11 @@ struct TORCH_API SavedVariableHooks { virtual void call_pack_hook(const at::Tensor& tensor) = 0; virtual at::Tensor call_unpack_hook() = 0; virtual ~SavedVariableHooks() = default; + virtual std::optional> + retrieve_unpack_hook_data() const { + throw std::runtime_error( + "Compiled Autograd only supports python saved tensor hooks "); + } }; } // namespace torch::autograd diff --git a/torch/csrc/dynamo/compiled_autograd.h b/torch/csrc/dynamo/compiled_autograd.h index a775218ac67..6cb353b18a2 100644 --- a/torch/csrc/dynamo/compiled_autograd.h +++ b/torch/csrc/dynamo/compiled_autograd.h @@ -17,6 +17,78 @@ namespace torch::dynamo::autograd { using namespace torch::autograd; +// This is a layer of indirection for calling methods on the Python +// AutogradCompilerInstance (referred to as the "py_compiler") from +// libtorch_cpu (where Python is not available). +// A PyCompilerInterfaceImpl in libtorch_python subclasses it and +// overrides the methods to do the actual calls back to Python. +struct TORCH_API PyCompilerInterface { + PyCompilerInterface() = default; + PyCompilerInterface(const PyCompilerInterface&) = delete; + PyCompilerInterface& operator=(const PyCompilerInterface&) = delete; + PyCompilerInterface(PyCompilerInterface&&) = delete; + PyCompilerInterface& operator=(PyCompilerInterface&&) = delete; + virtual ~PyCompilerInterface() = default; + + // Invokes py_compiler.bind_function + virtual std::string bind_function( + PyObject* py_compiler, + const std::string& fn_name, + // NOLINTNEXTLINE(performance-unnecessary-value-param) + functional_apply_t fn, + // NOLINTNEXTLINE(performance-unnecessary-value-param) + std::vector packed_args_schema, + bool is_custom_function = false, + bool is_traceable = true) { + TORCH_INTERNAL_ASSERT(false, "Needs to be overridden"); + } + + // Invokes py_compiler.method_name(fn_name, inputs, packed_args, + // output_metadata) + virtual variable_list call_function( + PyObject* py_compiler, + const char* method_name, + const std::string& fn_name, + const variable_list& inputs, + const ivalue_list& packed_args, + const c10::IValue& output_metadata) { + TORCH_INTERNAL_ASSERT(false, "Needs to be overridden"); + } + virtual variable_list call_copy_slices_prologue( + PyObject* py_compiler, + const variable_list& inputs, + const at::TensorGeometry& base, + const at::TensorGeometry& view) { + TORCH_INTERNAL_ASSERT(false, "Needs to be overridden"); + } + virtual variable_list call_copy_slices_epilogue( + PyObject* py_compiler, + const std::vector& needs_input_grad, + const at::Tensor& result, + const variable_list& res, + const at::Tensor& grad_slice) { + TORCH_INTERNAL_ASSERT(false, "Needs to be overridden"); + } + virtual at::Tensor call_unpack( + PyObject* py_compiler, + std::optional hook_id, + size_t hook_input_id) { + TORCH_INTERNAL_ASSERT(false, "Needs to be overridden"); + } +}; + +TORCH_API const std::unique_ptr& getPyCompilerInterface(); +TORCH_API void setPyCompilerInterface( + std::unique_ptr&& impl); +TORCH_API void resetPyCompilerInterface(); + +// including torch/csrc/autograd/engine.h breaks BC by somehow introducing +// symbol resolution issues. Instead requiring downstream users to include +// engine.h to access collect_input_metadata, we provide it here (with a +// different name to avoid ambigous symbols...) +TORCH_API std::vector> get_input_metadata( + const edge_list& edges); + struct SizeInput { // Note: int value is still needed when dynamic to pass as an arg enum DynType : uint8_t { STATIC = 0, DYNAMIC = 1 }; @@ -154,9 +226,14 @@ struct TensorArgs { } TensorArg& lookup(const SavedVariable& sv) { - auto it = _saved_variables.find(&sv); - TORCH_INTERNAL_ASSERT(it != _saved_variables.end()); - return *it->second; + if (auto it = _saved_variables.find(&sv); it != _saved_variables.end()) { + // unpacked before graph + return *it->second; + } + // unpacked in graph + auto it2 = _saved_variables_proxies.find(&sv); + TORCH_INTERNAL_ASSERT(it2 != _saved_variables_proxies.end()); + return *it2->second; } TensorArg& add(const at::Tensor& tensor) { @@ -164,9 +241,7 @@ struct TensorArgs { } TensorArg& add(const SavedVariable& sv, const std::shared_ptr& node) { - // TODO(jansel): Here we unpack the SavedVariable exactly once. This might - // fire SavedTensor hooks. In the future we should try to put saved tensor - // hooks into the graph. + // no unpack hooks in this codepath at::Tensor tensor = sv.unpack(node); TensorArg& arg = add(tensor); _saved_variables.emplace(&sv, &arg); @@ -185,6 +260,7 @@ struct TensorArgs { // Every TensorArg from this is actually owned by _args (or _undefined) and // that's why we have an un-owned pointer here. std::unordered_map _saved_variables; + std::unordered_map _saved_variables_proxies; TensorArg _undefined; uint32_t _next_id = 1; // id=0 used by _undefined }; @@ -245,6 +321,11 @@ struct AutogradCompilerCall { return hooks.size() - 1; } + size_t emplace_packed_input(c10::SafePyObject&& input) { + packed_inputs.emplace_back(std::move(input)); + return packed_inputs.size() - 1; + } + void set_active_node_call_idx(size_t node_call_idx) { active_node_call_idx = node_call_idx; } @@ -255,10 +336,13 @@ struct AutogradCompilerCall { LiftedIValueArgs lifted_ivalue_args; std::vector dyn_size_inputs; std::vector hooks; + std::vector packed_inputs; NodeCalls node_calls; SizeInput::DynType default_dyn_type; // NodeCall id of each size, only when verbose logging is enabled std::vector size_input_origins; + std::unordered_map> + sv_to_hooks; }; class CompiledNodeArgs { @@ -285,8 +369,19 @@ class CompiledNodeArgs { collect(_compiler.tensor_args.add(t)); } void collect(const SavedVariable& sv, bool is_output) { - collect( - _compiler.tensor_args.add(sv, is_output ? _node_call.node : nullptr)); + if (auto hook_data = sv.retrieve_unpack_hook_data(); + hook_data.has_value()) { + // hooks, unpack in graph + auto& [hook, packed_input] = hook_data.value(); + size_t hook_id = _compiler.emplace_hook(std::move(hook)); + // rely on dynamo to dedup packed tensors from unpacked tensors + size_t input_id = _compiler.emplace_packed_input(std::move(packed_input)); + _compiler.sv_to_hooks.emplace(&sv, std::make_pair(hook_id, input_id)); + } else { + // no hooks, unpack now + collect( + _compiler.tensor_args.add(sv, is_output ? _node_call.node : nullptr)); + } } void collect(const c10::SymInt& t) { _compiler.add_size_input(t); @@ -655,13 +750,26 @@ class SwapSavedVariables { } void before(SavedVariable& t) { - TensorArg& arg = compiler.tensor_args.lookup(t); - stashed_variables.save(&t, std::move(t)); - if (arg.defined()) { + if (auto it = compiler.sv_to_hooks.find(&t); + it != compiler.sv_to_hooks.end()) { + const auto& pyinterface = + torch::dynamo::autograd::getPyCompilerInterface(); + auto proxy_tensor = pyinterface->call_unpack( + get_py_compiler(), it->second.first, it->second.second); + stashed_variables.save(&t, std::move(t)); bool prior = at::SavedTensorDefaultHooks::set_tracing(true); - TORCH_INTERNAL_ASSERT(arg.proxy_tensor.defined()); - t = SavedVariable(arg.proxy_tensor, false); + t = SavedVariable(proxy_tensor, false); at::SavedTensorDefaultHooks::set_tracing(prior); + } else { + // no hooks, was already unpacked + TensorArg& arg = compiler.tensor_args.lookup(t); + stashed_variables.save(&t, std::move(t)); + if (arg.defined()) { + bool prior = at::SavedTensorDefaultHooks::set_tracing(true); + TORCH_INTERNAL_ASSERT(arg.proxy_tensor.defined()); + t = SavedVariable(arg.proxy_tensor, false); + at::SavedTensorDefaultHooks::set_tracing(prior); + } } } void after(SavedVariable& t) { @@ -1370,73 +1478,6 @@ struct PackedArgs { int64_t idx = 0; }; -// This is a layer of indirection for calling methods on the Python -// AutogradCompilerInstance (referred to as the "py_compiler") from -// libtorch_cpu (where Python is not available). -// A PyCompilerInterfaceImpl in libtorch_python subclasses it and -// overrides the methods to do the actual calls back to Python. -struct TORCH_API PyCompilerInterface { - PyCompilerInterface() = default; - PyCompilerInterface(const PyCompilerInterface&) = delete; - PyCompilerInterface& operator=(const PyCompilerInterface&) = delete; - PyCompilerInterface(PyCompilerInterface&&) = delete; - PyCompilerInterface& operator=(PyCompilerInterface&&) = delete; - virtual ~PyCompilerInterface() = default; - - // Invokes py_compiler.bind_function - virtual std::string bind_function( - PyObject* py_compiler, - const std::string& fn_name, - // NOLINTNEXTLINE(performance-unnecessary-value-param) - functional_apply_t fn, - // NOLINTNEXTLINE(performance-unnecessary-value-param) - std::vector packed_args_schema, - bool is_custom_function = false, - bool is_traceable = true) { - TORCH_INTERNAL_ASSERT(false, "Needs to be overridden"); - } - - // Invokes py_compiler.method_name(fn_name, inputs, packed_args, - // output_metadata) - virtual variable_list call_function( - PyObject* py_compiler, - const char* method_name, - const std::string& fn_name, - const variable_list& inputs, - const ivalue_list& packed_args, - const c10::IValue& output_metadata) { - TORCH_INTERNAL_ASSERT(false, "Needs to be overridden"); - } - - virtual variable_list call_copy_slices_prologue( - PyObject* py_compiler, - const variable_list& inputs, - const at::TensorGeometry& base, - const at::TensorGeometry& view) { - TORCH_INTERNAL_ASSERT(false, "Needs to be overridden"); - } - virtual variable_list call_copy_slices_epilogue( - PyObject* py_compiler, - const std::vector& needs_input_grad, - const at::Tensor& result, - const variable_list& res, - const at::Tensor& grad_slice) { - TORCH_INTERNAL_ASSERT(false, "Needs to be overridden"); - } -}; - -TORCH_API const std::unique_ptr& getPyCompilerInterface(); -TORCH_API void setPyCompilerInterface( - std::unique_ptr&& impl); -TORCH_API void resetPyCompilerInterface(); - -// including torch/csrc/autograd/engine.h breaks BC by somehow introducing -// symbol resolution issues. Instead requiring downstream users to include -// engine.h to access collect_input_metadata, we provide it here (with a -// different name to avoid ambigous symbols...) -TORCH_API std::vector> get_input_metadata( - const edge_list& edges); - } // namespace torch::dynamo::autograd template <> diff --git a/torch/csrc/dynamo/python_compiled_autograd.cpp b/torch/csrc/dynamo/python_compiled_autograd.cpp index b96439ff6d4..7f105f920ca 100644 --- a/torch/csrc/dynamo/python_compiled_autograd.cpp +++ b/torch/csrc/dynamo/python_compiled_autograd.cpp @@ -203,6 +203,16 @@ struct PyCompilerInterfaceImpl : PyCompilerInterface { auto output = py::cast>>(stuff); return toTensorList(output); } + at::Tensor call_unpack( + PyObject* py_compiler, + std::optional hook_id, + size_t hook_input_id) override { + py::handle handle(py_compiler); + py::object proxy = handle.attr("unpack_hook")(hook_id, hook_input_id); + auto tmp = py::cast>(proxy); + TORCH_INTERNAL_ASSERT(tmp.has_value()); + return tmp.value(); + } }; static PyObject* wrap_int_list(const std::vector& inputs) { @@ -213,7 +223,7 @@ static PyObject* wrap_int_list(const std::vector& inputs) { return pyinput; } -static PyObject* convert_hook_list(std::vector& inputs) { +static PyObject* convert_pyobj_list(std::vector& inputs) { // inplace, consumes the input hooks PyObject* pyinput = PyTuple_New(static_cast(inputs.size())); for (const auto i : c10::irange(inputs.size())) { @@ -654,7 +664,7 @@ static PyObject* wrap_string_list(const std::vector& strs) { return pystrs; } -std::string unwrap_string(PyObject* pystr) { +static std::string unwrap_string(PyObject* pystr) { TORCH_INTERNAL_ASSERT(PyUnicode_Check(pystr)); const char* str = PyUnicode_AsUTF8(pystr); TORCH_INTERNAL_ASSERT(str != nullptr); @@ -796,7 +806,8 @@ static CacheNode* _compiled_autograd_impl( THPObjectPtr* graph_arg_inputs, THPObjectPtr* graph_arg_sizes, THPObjectPtr* graph_arg_ivalue_args, - THPObjectPtr* graph_arg_hooks) { + THPObjectPtr* graph_arg_hooks, + THPObjectPtr* graph_arg_packed_inputs) { std::unordered_map& dependencies = graph_task.dependencies_; std::vector> worklist{graph_root}; AutogradCompilerCall compiler_call(get_default_dyn_type()); @@ -1052,7 +1063,8 @@ static CacheNode* _compiled_autograd_impl( *graph_arg_sizes = wrap_int_list(compiler_call.dyn_size_inputs); *graph_arg_ivalue_args = wrap_lifted_ivalue_args(compiler_call.lifted_ivalue_args.args); - *graph_arg_hooks = convert_hook_list(compiler_call.hooks); + *graph_arg_hooks = convert_pyobj_list(compiler_call.hooks); + *graph_arg_packed_inputs = convert_pyobj_list(compiler_call.packed_inputs); return cache; } @@ -1093,6 +1105,7 @@ static variable_list compiled_autograd( THPObjectPtr sizes; THPObjectPtr ivalue_args; THPObjectPtr hooks; + THPObjectPtr packed_inputs; CacheNode* cache = _compiled_autograd_impl( graph_root, graph_task, @@ -1101,7 +1114,8 @@ static variable_list compiled_autograd( &inputs, &sizes, &ivalue_args, - &hooks); + &hooks, + &packed_inputs); THPObjectPtr pyresult(check(PyObject_CallFunctionObjArgs( cache->runtime_wrapper.get(), @@ -1110,6 +1124,7 @@ static variable_list compiled_autograd( sizes.get(), ivalue_args.get(), hooks.get(), + packed_inputs.get(), NULL))); variable_list outputs = THPVariable_UnpackList(pyresult); TORCH_INTERNAL_ASSERT(outputs.size() == output_edges.size());