Revert "[aot][ca] save bw_module in AOTAutogradCache (#151860)"
This reverts commit 613bd46272.
Reverted https://github.com/pytorch/pytorch/pull/151860 on behalf of https://github.com/huydhn, after chatting with @xmfan and deciding to revert and reland this instead ([comment](https://github.com/pytorch/pytorch/pull/151860#issuecomment-2856709646))
This commit is contained in:
parent f6db749e60
commit a28dcdba2c
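The largest hunk below removes the test `test_cache_lazy_backward_for_compiled_autograd`, which exercised the cache-artifact save/load round trip together with compiled autograd. As orientation, here is a minimal sketch of that round trip using only the calls the deleted test itself makes (`torch.compile`, `torch.compiler.save_cache_artifacts` / `load_cache_artifacts`, and the internal `torch._dynamo.compiled_autograd._enable` hook); the test harness, counter assertions, and cache-directory handling are omitted:

```python
import torch


def fn(x, y):
    return x.sin() @ y


a = torch.rand(100, 100, requires_grad=True)
b = torch.rand(100, 100, requires_grad=True)

# Warm run: compile the forward, then run the backward under compiled
# autograd so both populate the compiler caches.
compiled_fn = torch.compile(fn, dynamic=True)
out = compiled_fn(a, b)
with torch._dynamo.compiled_autograd._enable(torch.compile(dynamic=True)):
    grads = torch.autograd.grad(out.sum(), inputs=(a, b))

# Serialize everything that was cached during this run.
artifacts = torch.compiler.save_cache_artifacts()
assert artifacts is not None
artifact_bytes, cache_info = artifacts

# Later (for example in a fresh process), hot-load the artifacts so the same
# torch.compile calls hit the caches instead of recompiling.
torch.compiler.load_cache_artifacts(artifact_bytes)
```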
@@ -66,118 +66,6 @@ class AOTAutogradCacheTests(InductorTestCase):
        torch._dynamo.reset()
        torch._inductor.codecache.PyCodeCache.cache_clear(purge=True)

    @functorch_config.patch({"enable_autograd_cache": True})
    @inductor_config.patch(
        {
            "fx_graph_cache": True,
            "fx_graph_remote_cache": False,
            "autotune_local_cache": True,
        }
    )
    def test_cache_lazy_backward_for_compiled_autograd(self):
        device = "cpu"
        dtype = torch.float32
        dynamic = True
        """
        Verify that we can populate and hot load functions from the cache.
        """
        if device == GPU_TYPE and not HAS_GPU:
            raise unittest.SkipTest(f"requires {GPU_TYPE}")
        if device == "cuda" and dtype == torch.bfloat16 and not SM80OrLater:
            raise unittest.SkipTest("requires SM80 or later")

        def fn(x, y):
            return x.sin() @ y

        a = torch.rand(100, 100, dtype=dtype, device=device, requires_grad=True)
        b = torch.rand(100, 100, dtype=dtype, device=device, requires_grad=True)

        # Record artifacts
        with fresh_inductor_cache():
            compiled_fn = torch.compile(fn, dynamic=dynamic)

            # A first call should miss in the cache.
            eager_result = fn(a, b)
            expected_grads = torch.autograd.grad(eager_result.sum(), inputs=(a, b))
            compiled_result = compiled_fn(a, b)
            with torch._dynamo.compiled_autograd._enable(
                torch.compile(dynamic=dynamic)
            ):
                actual_grads = torch.autograd.grad(compiled_result.sum(), inputs=(a, b))
            if hasattr(a, "_dynamo_weak_dynamic_indices"):
                del a._dynamo_weak_dynamic_indices
            self.assertEqual(eager_result, compiled_result)
            self.assertEqual(expected_grads[0], actual_grads[0])
            self.assertEqual(expected_grads[1], actual_grads[1])
            self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 3)
            self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 0)
            self.assertEqual(counters["inductor"]["fxgraph_lookup_write_file"], 0)
            self.assertEqual(counters["aot_autograd"]["autograd_cache_miss"], 1)
            self.assertEqual(counters["aot_autograd"]["autograd_cache_hit"], 0)
            self.assertEqual(counters["aot_autograd"]["autograd_cache_saved"], 1)
            self.assertEqual(counters["compiled_autograd"]["captures"], 1)

            artifacts = torch.compiler.save_cache_artifacts()

            self.assertIsNotNone(artifacts)

            artifact_bytes, cache_info = artifacts

            autotune_expect = 2 if device == GPU_TYPE else 0

            self.assertEqual(len(cache_info.inductor_artifacts), 3)
            self.assertEqual(len(cache_info.autotune_artifacts), autotune_expect)
            self.assertEqual(len(cache_info.aot_autograd_artifacts), 1)
            self.assertEqual(len(cache_info.pgo_artifacts), 0)

        self._clear_all_caches()

        # Clean triton kernels
        shutil.rmtree(os.path.join(cache_dir(), "triton"), ignore_errors=True)

        # Hot load and hit, should not recompile
        with fresh_inductor_cache():
            cache_info = torch.compiler.load_cache_artifacts(artifact_bytes)

            self.assertEqual(len(cache_info.inductor_artifacts), 3)
            self.assertEqual(len(cache_info.autotune_artifacts), autotune_expect)
            self.assertEqual(len(cache_info.aot_autograd_artifacts), 1)
            self.assertEqual(len(cache_info.pgo_artifacts), 0)

            for i in range(3):
                counters.clear()
                eager_result = fn(a, b)
                expected_grads = torch.autograd.grad(eager_result.sum(), inputs=(a, b))
                compiled_result = compiled_fn(a, b)
                with torch._dynamo.compiled_autograd._enable(
                    torch.compile(dynamic=dynamic)
                ):
                    actual_grads = torch.autograd.grad(
                        compiled_result.sum(), inputs=(a, b)
                    )
                if hasattr(a, "_dynamo_weak_dynamic_indices"):
                    del a._dynamo_weak_dynamic_indices
                self.assertEqual(eager_result, compiled_result)
                self.assertEqual(expected_grads[0], actual_grads[0])
                self.assertEqual(expected_grads[1], actual_grads[1])

                if i == 0:
                    # initial compile
                    self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 0)
                    self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 3)
                    self.assertEqual(
                        counters["inductor"]["fxgraph_lookup_write_file"], 3
                    )
                    self.assertEqual(counters["aot_autograd"]["autograd_cache_miss"], 0)
                    self.assertEqual(counters["aot_autograd"]["autograd_cache_hit"], 1)
                    self.assertEqual(
                        counters["aot_autograd"]["autograd_cache_saved"], 0
                    )
                    self.assertEqual(counters["compiled_autograd"]["captures"], 1)
                else:
                    # no recompiles
                    self.assertFalse(counters)

    @requires_triton()
    @functorch_config.patch({"enable_autograd_cache": True})
    @inductor_config.patch(
@@ -457,13 +457,12 @@ class StructuredTraceTest(TestCase):
{"inductor_post_grad_graph": {}, "frame_id": 2, "frame_compile_id": 0, "attempt": 1, "has_payload": "HASH"}
{"inductor_output_code": {"filename": "FILENAME"}, "frame_id": 2, "frame_compile_id": 0, "attempt": 1, "has_payload": "HASH"}
{"artifact": {"name": "fx_graph_cache_miss", "encoding": "json"}, "frame_id": 2, "frame_compile_id": 0, "attempt": 1, "has_payload": "HASH"}
{"dynamo_start": {"stack": "STACK"}, "frame_id": 4, "frame_compile_id": 0, "attempt": 0}
{"bwd_compilation_metrics": "METRICS", "frame_id": 2, "frame_compile_id": 0, "attempt": 1}
{"dynamo_start": {"stack": "STACK"}, "frame_id": 5, "frame_compile_id": 0, "attempt": 0}
{"describe_storage": {"id": 0, "describer_id": "ID", "size": 4000000}, "frame_id": 5, "frame_compile_id": 0, "attempt": 0}
{"describe_tensor": {"id": 0, "ndim": 2, "dtype": "torch.float32", "device": "device(type='cpu')", "size": [1000, 1000], "requires_grad": true, "stride": [1000, 1], "storage": 0, "view_func": "VIEW_FUNC", "describer_id": "ID"}, "frame_id": 5, "frame_compile_id": 0, "attempt": 0}
{"describe_source": {"describer_id": "ID", "id": 0, "source": "L['output']"}, "frame_id": 5, "frame_compile_id": 0, "attempt": 0}
{"compilation_metrics": "METRICS", "frame_id": 5, "frame_compile_id": 0, "attempt": 0}
{"dynamo_start": {"stack": "STACK"}, "frame_id": 4, "frame_compile_id": 0, "attempt": 0}
{"describe_storage": {"id": 0, "describer_id": "ID", "size": 4000000}, "frame_id": 4, "frame_compile_id": 0, "attempt": 0}
{"describe_tensor": {"id": 0, "ndim": 2, "dtype": "torch.float32", "device": "device(type='cpu')", "size": [1000, 1000], "requires_grad": true, "stride": [1000, 1], "storage": 0, "view_func": "VIEW_FUNC", "describer_id": "ID"}, "frame_id": 4, "frame_compile_id": 0, "attempt": 0}
{"describe_source": {"describer_id": "ID", "id": 0, "source": "L['output']"}, "frame_id": 4, "frame_compile_id": 0, "attempt": 0}
{"compilation_metrics": "METRICS", "frame_id": 4, "frame_compile_id": 0, "attempt": 0}
""", # noqa: B950
        )
@@ -409,11 +409,7 @@ class AutogradCompilerInstance:
        metadata = CompiledFunction.metadata
        maybe_subclass_metadata = CompiledFunction.maybe_subclass_metadata
        aot_id = CompiledFunction._aot_id
        bw_module = ctx._bw_module
        aot_symints = ctx.symints
        symints = ctx._get_compiled_autograd_symints()
        del CompiledFunction
        del ctx

        @torch._dynamo.allow_in_graph # type: ignore[misc]
        def call_aot_bwd_prologue(ctx_saved_tensors, ctx_symints, *flat_args):
@@ -455,12 +451,13 @@ class AutogradCompilerInstance:

        # set up the proxy inputs to ctx._bw_module
        # the calling convention is: [*symints, *args (primals and tangents), backward_state]
        num_args = num_inputs(bw_module.graph)
        num_args = num_inputs(ctx._bw_module.graph)
        pall_args = [
            pgrads[i] for i in range(num_args - int(pbackward_state is not None))
        ]
        # replace the symints with our symints
        assert len(symints) == len(aot_symints)
        symints = ctx._get_compiled_autograd_symints()
        assert len(symints) == len(ctx.symints)
        psymints = [self.to_proxy(e) for e in symints]
        pall_args[: len(symints)] = psymints
        # Add backward_state
@@ -484,7 +481,7 @@ class AutogradCompilerInstance:
            # make it both informative and unique
            return f"aot{deduped_aot_id}_{node_name}"

        for node in bw_module.graph.nodes:
        for node in ctx._bw_module.graph.nodes:
            if node.op == "placeholder":
                ph = pall_args[args_idx].node
                ph.name = make_unique(node.name)
@@ -501,7 +498,9 @@ class AutogradCompilerInstance:
            elif node.op == "get_attr":
                name = node.target
                qualname = self.fx_tracer.get_fresh_qualname(name)
                setattr(self.fx_tracer.root, qualname, getattr(bw_module, name))
                setattr(
                    self.fx_tracer.root, qualname, getattr(ctx._bw_module, name)
                )
                result = self.fx_tracer.create_node("get_attr", qualname, (), {})
                result.name = make_unique(node.name)
                value_remap[node] = result
@@ -1271,6 +1270,11 @@ class AutogradCompilerInstance:
        forward_cls = pyobj._forward_cls # type: ignore[attr-defined]
        if hasattr(forward_cls, "_aot_id"):
            # backward was created by AOT Dispatcher
            if forward_cls._lazy_backward_info is None:
                raise RuntimeError(
                    """This compiled backward function was saved by AOTAutogradCache, which does not support
compiled autograd. Please turn off AOTAutogradCache using `TORCHINDUCTOR_AUTOGRAD_CACHE=0`."""
                )
            maybe_aot_id = forward_cls._aot_id
        new_code = f"{node_name}{maybe_aot_id} (NodeCall {nodecall_index})"
        raw_stack_trace = CapturedTraceback.extract().format()[-1]
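The hunk above restores the runtime check that rejects compiled autograd when a cached backward has no `_lazy_backward_info`, and its error message points users at the `TORCHINDUCTOR_AUTOGRAD_CACHE` switch. A minimal sketch of that workaround follows; the environment variable comes straight from the error message, while the in-process config toggle mirrors the `enable_autograd_cache` key the test file patches and is an assumption here:

```python
# Option 1: disable the AOTAutograd cache via the switch named in the error
# message, set before the process starts:
#
#   TORCHINDUCTOR_AUTOGRAD_CACHE=0 python train.py
#
# Option 2 (assumed in-process equivalent, mirroring the config key the test
# suite patches): turn the config flag off before compiling.
import torch._functorch.config as functorch_config

functorch_config.enable_autograd_cache = False
```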
@@ -57,7 +57,6 @@ from torchgen.utils import dataclass_repr
from .runtime_wrappers import (
    AOTDispatchAutograd,
    AOTDispatchSubclassWrapper,
    CachedAutogradLazyBackwardCompileInfo,
    CompilerWrapper,
    FunctionalizedRngRuntimeWrapper,
    post_compile,
@@ -552,9 +551,6 @@ class GenericAOTAutogradCacheEntry(Generic[TForward, TBackward]):

    guards_expr: Optional[str]

    # # Used by compiled autograd
    cached_lazy_backward_info: Optional[CachedAutogradLazyBackwardCompileInfo]

    # Turn cache entry into the original callable
    def wrap_post_compile(
        self,
@@ -700,7 +696,7 @@ class GenericAOTAutogradCacheEntry(Generic[TForward, TBackward]):
            self.compiled_bw.backward_state_indices,
            disable_amp,
            self.indices_of_inps_to_detach,
            self.cached_lazy_backward_info,
            None, # lazy_backward_info
            aot_config,
            fw_metadata=self.runtime_metadata,
            try_save_cache_entry=None,
@@ -55,7 +55,6 @@ from .runtime_wrappers import (
    AOTDispatchSubclassWrapper,
    AOTSyntheticBaseWrapper,
    AutogradLazyBackwardCompileInfo,
    CachedAutogradLazyBackwardCompileInfo,
    CompilerWrapper,
    DebugAssertWrapper,
    EffectTokensWrapper,
@@ -279,7 +278,6 @@ def aot_dispatch_base(
            backward_time_taken_ns=0,
            sanitized_aot_config=sanitize_aot_config(aot_config),
            guards_expr=guards_expr,
            cached_lazy_backward_info=None,
        )
        AOTAutogradCache.save(
            cache_info.cache_key, entry, remote=should_use_remote_autograd_cache()
@@ -1283,13 +1281,8 @@ def aot_dispatch_autograd(
        # close over aot_config.cache_info, since aot_config never changes.
        # But closing over random variables is confusing IMO, so I'm leaving it.
        def try_save_cache_entry( # noqa: F811
            compiled_bw_func, lazy_backward_info, _fw_metadata, aot_config
            compiled_bw_func, _fw_metadata, aot_config
        ):
            bw_module = lazy_backward_info.bw_module
            bw_module.meta = {}
            for node in bw_module.graph.nodes:
                node.meta = {}

            fw_key = getattr(compiled_fw_func, "_fx_graph_cache_key", None)
            fw_debug_lines = getattr(
                compiled_fw_func, "_fx_graph_cache_debug_lines", []
@@ -1335,18 +1328,13 @@ def aot_dispatch_autograd(
                backward_time_taken_ns,
                sanitized_aot_config=sanitize_aot_config(aot_config),
                guards_expr=guards_expr,
                cached_lazy_backward_info=CachedAutogradLazyBackwardCompileInfo(
                    bw_module
                ),
            )
            remote = should_use_remote_autograd_cache()
            AOTAutogradCache.save(cache_info.cache_key, entry, remote)

        if compiled_bw_func is not None:
            # If we already compiled it we can just run it right now without waiting
            try_save_cache_entry(
                compiled_bw_func, lazy_backward_info, fw_metadata, aot_config
            )
            try_save_cache_entry(compiled_bw_func, fw_metadata, aot_config)
            try_save_cache_entry = None

        compiled_fn = AOTDispatchAutograd.post_compile(
@@ -1484,20 +1484,12 @@ def merge_view_inputs(
# with compiled autograd. See: https://github.com/pytorch/pytorch/pull/149229#discussion_r2002122645.
@dataclass
class AutogradLazyBackwardCompileInfo:
    bw_module: torch.fx.GraphModule
    bw_module: Callable
    placeholder_list: list[Any]
    saved_context: Optional[TracingContext]
    saved_compile_context: Optional[CompileContext]


# On an AOT Autograd cache hit, we already have a lowered backward, so there is usually
# no need to keep information around for a new lazy compilation. Except for compiled autograd,
# which wants to retrace this backward into a larger graph, and it needs the graph module to do so.
@dataclass
class CachedAutogradLazyBackwardCompileInfo:
    bw_module: torch.fx.GraphModule # missing a couple of fields compared to AutogradLazyBackwardCompileInfo's bw_module


def _raise_if_functorch_active():
    # not ideal but prevent the user from seeing a nasty traceback - See #138422
    stack = torch._C._functorch.peek_interpreter_stack()
@@ -1917,11 +1909,7 @@ To fix this, your tensor subclass must implement the dunder method __force_to_sa
        backward_state_indices: list[int],
        disable_amp: bool,
        indices_of_inps_to_detach: list[int],
        lazy_backward_info: Optional[
            Union[
                AutogradLazyBackwardCompileInfo, CachedAutogradLazyBackwardCompileInfo
            ]
        ],
        lazy_backward_info: Optional[AutogradLazyBackwardCompileInfo],
        aot_config: AOTConfig,
        *,
        fw_metadata: ViewAndMutationMeta, # runtime metadata
@@ -2229,9 +2217,6 @@ To fix this, your tensor subclass must implement the dunder method __force_to_sa

        if CompiledFunction.compiled_bw is None:
            assert lazy_backward_info is not None
            assert isinstance(
                lazy_backward_info, AutogradLazyBackwardCompileInfo
            )

            if not saved_tensors_use_once:
                fw_metadata.bw_donated_idxs = []
@@ -2276,7 +2261,6 @@ To fix this, your tensor subclass must implement the dunder method __force_to_sa
        if try_save_cache_entry is not None:
            try_save_cache_entry(
                CompiledFunction.compiled_bw,
                lazy_backward_info,
                fw_metadata,
                aot_config,
            )