[inductor] Conditionally copy args to cpu to minimize memory overhead of autotuning (#136701)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/136701
Approved by: https://github.com/eellison
Sam Larsen 2024-10-04 15:15:44 -07:00 committed by PyTorch MergeBot
parent 900f57216f
commit c87c9f0a01
6 changed files with 91 additions and 19 deletions
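
The commit message above is terse, so here is a minimal standalone sketch of the idea (illustrative only, assuming a CUDA device; the helper name stash_if_needed is not part of the PR): when a mutated arg would not fit in the remaining GPU headroom, the autotuner mirrors it into pinned CPU memory before benchmarking and copies it back after each iteration, so repeated kernel launches neither corrupt user buffers nor raise peak GPU memory.

import torch

def stash_if_needed(arg, budget):
    # Mirror a mutated GPU tensor into pinned CPU memory when a GPU clone would
    # blow the remaining memory budget; otherwise spend budget on a GPU clone,
    # which is what the autotuner does via clone_preserve_strides.
    size = arg.numel() * arg.element_size()
    if size > budget:
        cpu_copy = torch.empty_strided(
            arg.size(), arg.stride(), dtype=arg.dtype, device="cpu", pin_memory=True
        )
        cpu_copy.copy_(arg, non_blocking=True)
        return cpu_copy, budget
    return None, budget - size

if torch.cuda.is_available():
    buf = torch.randn(1024, device="cuda")
    headroom = torch.cuda.max_memory_allocated() - torch.cuda.memory_allocated()
    mirror, headroom = stash_if_needed(buf, headroom)
    buf.add_(1.0)  # stand-in for a mutating kernel launch during benchmarking
    if mirror is not None:
        buf.copy_(mirror, non_blocking=True)  # restore between iterations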

View File

@@ -422,6 +422,7 @@ class CudaReproTests(TestCase):
configs=configs,
save_cache_hook=False,
mutated_arg_names=["in_out_ptr0"],
optimize_mem=True,
heuristic_type=HeuristicType.POINTWISE,
)

View File

@@ -126,6 +126,7 @@ class TestTritonHeuristics(TestCase):
"configs": configs,
"save_cache_hook": False,
"mutated_arg_names": [],
"optimize_mem": True,
"heuristic_type": HeuristicType.POINTWISE,
"inductor_meta": inductor_meta,
}

View File

@@ -2764,10 +2764,16 @@ class TritonKernel(SIMDKernel):
"constants": {},
}
# Skip memory optimization for the forward pass of a training graph, where we
# expect every new node to increase peak memory; there, our greedy approach
# would introduce a lot of unnecessary cpu copies.
optimize_mem = V.graph.is_inference or V.graph.is_backward
inductor_meta = {
"autotune_hints": set(self.autotune_hints),
"kernel_name": str(Placeholder.DESCRIPTIVE_NAME),
"mutated_arg_names": mutated_args,
"optimize_mem": optimize_mem,
"no_x_dim": self.no_x_dim,
"num_load": self.num_load,
"num_reduction": self.num_reduction,

View File

@@ -811,6 +811,7 @@ def fx_codegen_and_compile(
user_visible_outputs=user_visible_outputs,
extern_node_serializer=extern_node_serializer,
is_inference=is_inference,
is_backward=is_backward,
is_const_graph=True,
)
with V.set_graph_handler(const_graph):
@@ -832,6 +833,7 @@ def fx_codegen_and_compile(
user_visible_outputs=user_visible_outputs,
extern_node_serializer=extern_node_serializer,
is_inference=is_inference,
is_backward=is_backward,
const_output_index=const_output_index,
const_code=const_code,
const_module=const_graph,

View File

@@ -321,6 +321,7 @@ class GraphLowering(torch.fx.Interpreter):
Callable[[List[ir.ExternKernelNode]], Any]
] = None,
is_inference: bool = False,
is_backward: bool = False,
is_const_graph: bool = False,
const_output_index: Optional[Dict[str, int]] = None,
const_code: Optional[str] = None,
@@ -336,6 +337,7 @@ class GraphLowering(torch.fx.Interpreter):
)
self.num_channels_last_conv = 0
self.is_inference = is_inference
self.is_backward = is_backward
self.is_const_graph = is_const_graph
self.const_code = const_code
self.const_module = const_module
@@ -659,6 +661,7 @@ class GraphLowering(torch.fx.Interpreter):
aot_mode=self.aot_mode,
extern_node_serializer=self.extern_node_serializer,
is_inference=self.is_inference,
is_backward=self.is_backward,
name=self.qualify_name(subgraph_name),
)

View File

@@ -15,7 +15,7 @@ import re
import sys
import threading
import time
from typing import Any, Dict, List, Optional, Set, Tuple
from typing import Any, Container, Dict, List, Optional, Set, Tuple
import torch
@@ -187,6 +187,7 @@ class CachingAutotuner(KernelInterface):
configs,
save_cache_hook,
mutated_arg_names: List[str], # see [Note: clone mutated buffers]
optimize_mem,
heuristic_type,
size_hints=None,
inductor_meta=None, # metadata not relevant to triton
@@ -210,6 +211,7 @@ class CachingAutotuner(KernelInterface):
self.inductor_meta = {} if inductor_meta is None else inductor_meta
self.save_cache_hook = save_cache_hook
self.mutated_arg_names = mutated_arg_names
self.optimize_mem = optimize_mem
self.configs = configs
self.heuristic_type = heuristic_type
self.custom_kernel = custom_kernel
@@ -683,14 +685,19 @@ class CachingAutotuner(KernelInterface):
device_interface = self.get_device_interface()
stream = device_interface.get_raw_stream(device_interface.current_device())
cpu_copies = self.copy_args_to_cpu_if_needed(*args, **kwargs)
def kernel_call():
cloned_args, cloned_kwargs = self.clone_args(*args, **kwargs)
cloned_args, cloned_kwargs = self.maybe_clone_args(
cpu_copies, *args, **kwargs
)
launcher(
*cloned_args,
**cloned_kwargs,
grid=grid,
stream=stream,
)
self.restore_args_from_cpu(cpu_copies)
if with_profiler:
from torch._inductor.utils import do_bench_using_profiling
@@ -702,31 +709,80 @@
return benchmarker.benchmark_gpu(kernel_call, rep=40)
def clone_args(self, *args, **kwargs) -> Tuple[List[Any], Dict[str, Any]]:
def copy_args_to_cpu_if_needed(self, *args, **kwargs):
"""
To support benchmarking in the presence of mutated args, we need to avoid
autotuning contaminating them. We try to pass cloned args to the kernel.
If those clones would increase the peak memory usage, however, we instead
copy to cpu and restore them after each iteration. Figure out the args
to be copied and do the copying.
"""
if not self.optimize_mem:
return {}
copies = {}
budget = torch.cuda.max_memory_allocated() - torch.cuda.memory_allocated()
def maybe_copy(name, arg):
if name in self.mutated_arg_names:
nonlocal budget
assert isinstance(arg, torch.Tensor)
assert not arg.is_cpu
size = arg.numel() * arg.element_size()
if size > budget:
cpu_arg = torch.empty_strided(
arg.size(),
arg.stride(),
dtype=arg.dtype,
device="cpu",
pin_memory=True,
)
cpu_arg.copy_(arg, non_blocking=True)
copies[name] = (arg, cpu_arg)
else:
budget -= size
for i, arg in enumerate(args):
maybe_copy(self.fn.arg_names[i], arg)
for name, arg in kwargs.items():
maybe_copy(name, arg)
return copies
def restore_args_from_cpu(self, cpu_copies):
for pair in cpu_copies.values():
arg, cpu_arg = pair
arg.copy_(cpu_arg, non_blocking=True)
def maybe_clone_args(
self, exclude: Container[str], *args, **kwargs
) -> Tuple[List[Any], Dict[str, Any]]:
"""
Prepare new args and kwargs by cloning any in-place buffers
(that are not in the provided exclusion list), to avoid autotune
contaminating them. Avoid cloning the other buffers because it
leads to increased memory usage.
"""
from ..compile_fx import clone_preserve_strides
# [Note: clone mutated buffers]
# clone inplace buffers to avoid autotune contaminating them if
# the kernel does in-place stores. avoid cloning other buffers because
# it leads to increased memory use
cloned_args = []
for i, arg in enumerate(args):
if self.fn.arg_names[i] in self.mutated_arg_names:
def prepare_arg(name, arg):
if name in self.mutated_arg_names and name not in exclude:
assert isinstance(arg, torch.Tensor)
cloned_args.append(clone_preserve_strides(arg))
return clone_preserve_strides(arg)
else:
cloned_args.append(arg)
return arg
cloned_kwargs: Dict[str, Any] = {}
for name, arg in kwargs.items():
if name in self.mutated_arg_names:
assert isinstance(arg, torch.Tensor)
cloned_kwargs[name] = clone_preserve_strides(arg)
else:
cloned_kwargs[name] = arg
cloned_args = [
prepare_arg(self.fn.arg_names[i], arg) for i, arg in enumerate(args)
]
cloned_kwargs = {name: prepare_arg(name, arg) for name, arg in kwargs.items()}
return cloned_args, cloned_kwargs
def clone_args(self, *args, **kwargs) -> Tuple[List[Any], Dict[str, Any]]:
return self.maybe_clone_args(set(), *args, **kwargs)
def benchmark_all_configs(self, *args, **kwargs):
with dynamo_timed("CachingAutotuner.benchmark_all_configs"):
timings = {
@@ -1070,6 +1126,7 @@ def cached_autotune(
log.debug("autotune caching is disabled by config.force_disable_caches")
mutated_arg_names = inductor_meta.pop("mutated_arg_names", ())
optimize_mem = inductor_meta.pop("optimize_mem", True)
def decorator(fn):
# Remove XBLOCK from config if it's not a function argument.
@@ -1096,6 +1153,7 @@ def cached_autotune(
configs=configs,
save_cache_hook=autotune_cache and autotune_cache.save,
mutated_arg_names=mutated_arg_names,
optimize_mem=optimize_mem,
heuristic_type=heuristic_type,
size_hints=size_hints,
custom_kernel=custom_kernel,
@@ -1108,6 +1166,7 @@ def cached_autotune(
configs=configs,
save_cache_hook=autotune_cache and autotune_cache.save,
mutated_arg_names=mutated_arg_names,
optimize_mem=optimize_mem,
heuristic_type=heuristic_type,
size_hints=size_hints,
custom_kernel=custom_kernel,
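
Finally, to illustrate why maybe_clone_args takes an exclusion set (a runnable toy with assumed names, not the autotuner's actual call path): args that were already mirrored to CPU are launched in place and restored afterwards, so cloning them again on the GPU would give back exactly the memory the CPU copy saved.

import torch

def prepare(name, arg, mutated_arg_names, exclude):
    # Clone only mutated tensors that are NOT protected by a CPU mirror.
    if name in mutated_arg_names and name not in exclude:
        return arg.clone()
    return arg

args = {"in_out_ptr0": torch.ones(4), "in_ptr0": torch.zeros(4)}
mutated = {"in_out_ptr0"}
stashed_on_cpu = {"in_out_ptr0"}  # pretend this arg exceeded the GPU budget

prepared = {k: prepare(k, v, mutated, stashed_on_cpu) for k, v in args.items()}
assert prepared["in_out_ptr0"] is args["in_out_ptr0"]  # run in place, restored from CPU later
assert prepared["in_ptr0"] is args["in_ptr0"]          # never mutated, never cloned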