[inductor] Conditionally copy args to cpu to minimize memory overhead of autotuning (#136701)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/136701 Approved by: https://github.com/eellison
Parent: 900f57216f
Commit: c87c9f0a01
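In outline: autotuning benchmarks each candidate Triton config by launching the kernel repeatedly, and any argument the kernel mutates must be shielded from those runs. The existing approach clones mutated args on the GPU, which can raise peak memory. With this change, when optimize_mem is set, mutated args too large to clone within the current headroom (torch.cuda.max_memory_allocated() - torch.cuda.memory_allocated()) are backed up to pinned CPU memory before benchmarking and restored after each launch. A minimal, simplified sketch of the idea (the helper below is illustrative only, not the code added in this commit):

import torch

def benchmark_with_cpu_staging(kernel_call, mutated_args):
    """Sketch: stage over-budget mutated GPU args in pinned CPU memory around a benchmark run."""
    # Headroom we can use without raising the recorded peak memory.
    budget = torch.cuda.max_memory_allocated() - torch.cuda.memory_allocated()
    copies = {}
    for name, arg in mutated_args.items():
        nbytes = arg.numel() * arg.element_size()
        if nbytes > budget:
            # Cloning this arg on the GPU would raise the peak; back it up on the host instead.
            cpu_arg = torch.empty_strided(
                arg.size(), arg.stride(), dtype=arg.dtype, device="cpu", pin_memory=True
            )
            cpu_arg.copy_(arg, non_blocking=True)
            copies[name] = (arg, cpu_arg)
        else:
            budget -= nbytes  # small enough for the usual on-device clone path
    kernel_call()  # the benchmarked kernel may overwrite the mutated args here
    for arg, cpu_arg in copies.values():
        arg.copy_(cpu_arg, non_blocking=True)  # restore the original contents

Pinned (page-locked) host memory is used so the device-to-host and host-to-device copies can be issued with non_blocking=True. The diff below is what the commit actually changes.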
@@ -422,6 +422,7 @@ class CudaReproTests(TestCase):
             configs=configs,
             save_cache_hook=False,
             mutated_arg_names=["in_out_ptr0"],
+            optimize_mem=True,
             heuristic_type=HeuristicType.POINTWISE,
         )
 
@@ -126,6 +126,7 @@ class TestTritonHeuristics(TestCase):
             "configs": configs,
             "save_cache_hook": False,
             "mutated_arg_names": [],
+            "optimize_mem": True,
             "heuristic_type": HeuristicType.POINTWISE,
             "inductor_meta": inductor_meta,
         }
@@ -2764,10 +2764,16 @@ class TritonKernel(SIMDKernel):
             "constants": {},
         }
 
+        # Skip memory optimization for forward of the training loop where we expect
+        # every new node will increase the peak memory and our greedy approach would
+        # introduce a lot of unnecessary cpu copies.
+        optimize_mem = V.graph.is_inference or V.graph.is_backward
+
         inductor_meta = {
             "autotune_hints": set(self.autotune_hints),
             "kernel_name": str(Placeholder.DESCRIPTIVE_NAME),
             "mutated_arg_names": mutated_args,
+            "optimize_mem": optimize_mem,
             "no_x_dim": self.no_x_dim,
             "num_load": self.num_load,
             "num_reduction": self.num_reduction,
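The flag set here travels with the kernel's inductor_meta and is consumed on the runtime side before the CachingAutotuner is constructed (see the cached_autotune hunks further down). A small illustrative sketch of that hand-off, with a made-up kernel name:

# Codegen side (sketch): the flag is recorded next to the mutated-arg names.
inductor_meta = {
    "kernel_name": "triton_poi_fused_add_0",  # hypothetical kernel name
    "mutated_arg_names": ["in_out_ptr0"],
    "optimize_mem": True,  # inference or backward graph
}

# Runtime side (sketch): cached_autotune pops both entries and forwards them to
# CachingAutotuner(..., mutated_arg_names=..., optimize_mem=...), as in the hunks below.
mutated_arg_names = inductor_meta.pop("mutated_arg_names", ())
optimize_mem = inductor_meta.pop("optimize_mem", True)
assert mutated_arg_names == ["in_out_ptr0"] and optimize_mem is True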
@@ -811,6 +811,7 @@ def fx_codegen_and_compile(
                 user_visible_outputs=user_visible_outputs,
                 extern_node_serializer=extern_node_serializer,
                 is_inference=is_inference,
+                is_backward=is_backward,
                 is_const_graph=True,
             )
             with V.set_graph_handler(const_graph):
@@ -832,6 +833,7 @@ def fx_codegen_and_compile(
             user_visible_outputs=user_visible_outputs,
             extern_node_serializer=extern_node_serializer,
             is_inference=is_inference,
+            is_backward=is_backward,
             const_output_index=const_output_index,
             const_code=const_code,
             const_module=const_graph,
@@ -321,6 +321,7 @@ class GraphLowering(torch.fx.Interpreter):
             Callable[[List[ir.ExternKernelNode]], Any]
         ] = None,
         is_inference: bool = False,
+        is_backward: bool = False,
         is_const_graph: bool = False,
         const_output_index: Optional[Dict[str, int]] = None,
         const_code: Optional[str] = None,
@@ -336,6 +337,7 @@ class GraphLowering(torch.fx.Interpreter):
         )
         self.num_channels_last_conv = 0
         self.is_inference = is_inference
+        self.is_backward = is_backward
         self.is_const_graph = is_const_graph
         self.const_code = const_code
         self.const_module = const_module
@@ -659,6 +661,7 @@ class GraphLowering(torch.fx.Interpreter):
             aot_mode=self.aot_mode,
             extern_node_serializer=self.extern_node_serializer,
             is_inference=self.is_inference,
+            is_backward=self.is_backward,
             name=self.qualify_name(subgraph_name),
         )
 
@@ -15,7 +15,7 @@ import re
 import sys
 import threading
 import time
-from typing import Any, Dict, List, Optional, Set, Tuple
+from typing import Any, Container, Dict, List, Optional, Set, Tuple
 
 import torch
 
@@ -187,6 +187,7 @@ class CachingAutotuner(KernelInterface):
         configs,
         save_cache_hook,
         mutated_arg_names: List[str],  # see [Note: clone mutated buffers]
+        optimize_mem,
         heuristic_type,
         size_hints=None,
         inductor_meta=None,  # metadata not relevant to triton
@@ -210,6 +211,7 @@ class CachingAutotuner(KernelInterface):
         self.inductor_meta = {} if inductor_meta is None else inductor_meta
         self.save_cache_hook = save_cache_hook
         self.mutated_arg_names = mutated_arg_names
+        self.optimize_mem = optimize_mem
         self.configs = configs
         self.heuristic_type = heuristic_type
         self.custom_kernel = custom_kernel
@@ -683,14 +685,19 @@ class CachingAutotuner(KernelInterface):
         device_interface = self.get_device_interface()
         stream = device_interface.get_raw_stream(device_interface.current_device())
 
+        cpu_copies = self.copy_args_to_cpu_if_needed(*args, **kwargs)
+
         def kernel_call():
-            cloned_args, cloned_kwargs = self.clone_args(*args, **kwargs)
+            cloned_args, cloned_kwargs = self.maybe_clone_args(
+                cpu_copies, *args, **kwargs
+            )
             launcher(
                 *cloned_args,
                 **cloned_kwargs,
                 grid=grid,
                 stream=stream,
             )
+            self.restore_args_from_cpu(cpu_copies)
 
         if with_profiler:
             from torch._inductor.utils import do_bench_using_profiling
@@ -702,31 +709,80 @@ class CachingAutotuner(KernelInterface):
 
         return benchmarker.benchmark_gpu(kernel_call, rep=40)
 
-    def clone_args(self, *args, **kwargs) -> Tuple[List[Any], Dict[str, Any]]:
+    def copy_args_to_cpu_if_needed(self, *args, **kwargs):
+        """
+        To support benchmarking in the presence of mutated args, we need to avoid
+        autotuning contaminating them. We try to pass cloned args to the kernel.
+        If those clones would increase the peak memory usage, however, we instead
+        copy to cpu and restore them after each iteration. Figure out the args
+        to be copied and do the copying.
+        """
+        if not self.optimize_mem:
+            return {}
+
+        copies = {}
+        budget = torch.cuda.max_memory_allocated() - torch.cuda.memory_allocated()
+
+        def maybe_copy(name, arg):
+            if name in self.mutated_arg_names:
+                nonlocal budget
+                assert isinstance(arg, torch.Tensor)
+                assert not arg.is_cpu
+                size = arg.numel() * arg.element_size()
+                if size > budget:
+                    cpu_arg = torch.empty_strided(
+                        arg.size(),
+                        arg.stride(),
+                        dtype=arg.dtype,
+                        device="cpu",
+                        pin_memory=True,
+                    )
+                    cpu_arg.copy_(arg, non_blocking=True)
+                    copies[name] = (arg, cpu_arg)
+                else:
+                    budget -= size
+
+        for i, arg in enumerate(args):
+            maybe_copy(self.fn.arg_names[i], arg)
+
+        for name, arg in kwargs.items():
+            maybe_copy(name, arg)
+
+        return copies
+
+    def restore_args_from_cpu(self, cpu_copies):
+        for pair in cpu_copies.values():
+            arg, cpu_arg = pair
+            arg.copy_(cpu_arg, non_blocking=True)
+
+    def maybe_clone_args(
+        self, exclude: Container[str], *args, **kwargs
+    ) -> Tuple[List[Any], Dict[str, Any]]:
+        """
+        Prepare new args and kwargs by cloning any in-place buffers
+        (that are not in the provided exclusion list), to avoid autotune
+        contaminating them. Avoid cloning the other buffers because it
+        leads to increased memory usage.
+        """
         from ..compile_fx import clone_preserve_strides
 
         # [Note: clone mutated buffers]
         # clone inplace buffers to avoid autotune contaminating them if
         # the kernel does in-place stores. avoid cloning other buffers because
         # it leads to increase memory use
-        cloned_args = []
-        for i, arg in enumerate(args):
-            if self.fn.arg_names[i] in self.mutated_arg_names:
+        def prepare_arg(name, arg):
+            if name in self.mutated_arg_names and name not in exclude:
                 assert isinstance(arg, torch.Tensor)
-                cloned_args.append(clone_preserve_strides(arg))
+                return clone_preserve_strides(arg)
             else:
-                cloned_args.append(arg)
+                return arg
 
-        cloned_kwargs: Dict[str, Any] = {}
-        for name, arg in kwargs.items():
-            if name in self.mutated_arg_names:
-                assert isinstance(arg, torch.Tensor)
-                cloned_kwargs[name] = clone_preserve_strides(arg)
-            else:
-                cloned_kwargs[name] = arg
+        cloned_args = [
+            prepare_arg(self.fn.arg_names[i], arg) for i, arg in enumerate(args)
+        ]
+        cloned_kwargs = {name: prepare_arg(name, arg) for name, arg in kwargs.items()}
 
         return cloned_args, cloned_kwargs
 
+    def clone_args(self, *args, **kwargs) -> Tuple[List[Any], Dict[str, Any]]:
+        return self.maybe_clone_args(set(), *args, **kwargs)
+
     def benchmark_all_configs(self, *args, **kwargs):
         with dynamo_timed("CachingAutotuner.benchmark_all_configs"):
             timings = {
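The old clone_args entry point is kept and simply delegates to maybe_clone_args with an empty exclusion set, so existing callers are unchanged. To make the new exclusion semantics concrete, here is a hedged, CPU-only toy version of the prepare_arg logic above (maybe_clone, buf, and xnumel are made-up names; clone_preserve_strides is stood in for by Tensor.clone):

import torch
from typing import Container

def maybe_clone(mutated: Container[str], exclude: Container[str], **kwargs):
    # Clone only args that the kernel mutates and that were not already staged on CPU.
    return {
        name: arg.clone() if (name in mutated and name not in exclude) else arg
        for name, arg in kwargs.items()
    }

buf = torch.zeros(4)
out = maybe_clone({"in_out_ptr0"}, set(), in_out_ptr0=buf, xnumel=4)
assert out["in_out_ptr0"] is not buf   # mutated arg gets cloned
out = maybe_clone({"in_out_ptr0"}, {"in_out_ptr0"}, in_out_ptr0=buf, xnumel=4)
assert out["in_out_ptr0"] is buf       # already-staged arg is passed through untouched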
@@ -1070,6 +1126,7 @@ def cached_autotune(
         log.debug("autotune caching is disabled by config.force_disable_caches")
 
     mutated_arg_names = inductor_meta.pop("mutated_arg_names", ())
+    optimize_mem = inductor_meta.pop("optimize_mem", True)
 
     def decorator(fn):
         # Remove XBLOCK from config if it's not a function argument.
@@ -1096,6 +1153,7 @@ def cached_autotune(
                 configs=configs,
                 save_cache_hook=autotune_cache and autotune_cache.save,
                 mutated_arg_names=mutated_arg_names,
+                optimize_mem=optimize_mem,
                 heuristic_type=heuristic_type,
                 size_hints=size_hints,
                 custom_kernel=custom_kernel,
@@ -1108,6 +1166,7 @@ def cached_autotune(
                 configs=configs,
                 save_cache_hook=autotune_cache and autotune_cache.save,
                 mutated_arg_names=mutated_arg_names,
+                optimize_mem=optimize_mem,
                 heuristic_type=heuristic_type,
                 size_hints=size_hints,
                 custom_kernel=custom_kernel,