pytorch/torch/_inductor/codegen/autotuner.py
Jason Ansel c7c09722ad Move TorchDynamo into PyTorch core (#86461)
Context:
https://github.com/pytorch/torchdynamo/issues/1588

This PR moves [TorchDynamo](https://github.com/pytorch/torchdynamo) and TorchInductor into PyTorch core.
- `torchdynamo` becomes `torch._dynamo`
- `torchinductor` becomes `torch._inductor`
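For downstream code this is purely an import-path change; a minimal sketch (assuming code previously imported the standalone packages directly):

```python
# Before (standalone repos):
#   import torchdynamo
#   import torchinductor

# After this PR (modules now live in PyTorch core):
import torch._dynamo as torchdynamo
import torch._inductor as torchinductor
```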

This PR was generated by running `copy_to_core.sh` in https://github.com/pytorch/torchdynamo/pull/1538

Pull Request resolved: https://github.com/pytorch/pytorch/pull/86461
Approved by: https://github.com/voznesenskym
2022-10-13 23:18:06 +00:00

import builtins

import torch

from .. import config, triton_ops
from ..triton_ops.autotune import mm_autotune, mm_heuristics
from ..utils import dynamo_testing
from ..virtualized import V

aten = torch.ops.aten
rand_strided = dynamo_testing.rand_strided


def str2func(str):
    """Resolve a dotted kernel name such as "aten.convolution" to a callable."""
    module, *name = str.split(".")
    if module == "aten":
        runnable = aten
    elif module == "triton_ops":
        runnable = triton_ops
    elif module == "torch":
        runnable = torch
    else:
        raise Exception(f"{str} could not be called")

    for n in name:
        runnable = getattr(runnable, n)
    return runnable
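
# Illustrative resolutions (not executed; the dotted names below match the
# candidate-kernel strings used by the tuning helpers in this file):
#   str2func("aten.convolution")    -> torch.ops.aten.convolution
#   str2func("triton_ops.conv")     -> triton_ops.conv
#   str2func("torch.channels_last") -> torch.channels_last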


class Autotuner:
    def __init__(self):
        # cache of tuning decisions, keyed by (op name, shapes, layer params)
        self.cache = dict()

    def _bench(self, kernel, *args, **kwargs):
        def kernel_call():
            kernel(*args, **kwargs)

        # imported locally so triton is only required when benchmarking;
        # do_bench returns a tuple of timing statistics, and the callers
        # below unpack three values and use the first
        from triton.testing import do_bench

        return do_bench(kernel_call)


autotune = Autotuner()


def tuned_conv(
    x_shape,
    w_shape,
    x_stride,
    w_stride,
    stride,
    padding,
    dilation,
    transposed,
    output_padding,
    groups,
    device,
    dtype,
    adjust_triton=0.95,
):
    """
    Return the name of the best conv kernel for the given inputs and layer
    parameters. To account for potential pointwise fusion into the conv,
    triton timings are scaled by adjust_triton (default 0.95). An
    illustrative call follows this function.
    """
    sizevars = V.graph.sizevars
    x_shape = [sizevars.size_hint(s) for s in x_shape]
    w_shape = [sizevars.size_hint(s) for s in w_shape]
    x_stride = [sizevars.size_hint(s) for s in x_stride]
    w_stride = [sizevars.size_hint(s) for s in w_stride]
    x = rand_strided(x_shape, x_stride, device=device, dtype=dtype)
    w = rand_strided(w_shape, w_stride, device=device, dtype=dtype)

    # the identifiable args for the layers
    id_args = [
        *x_shape,
        *w_shape,
        stride,
        padding,
        dilation,
        transposed,
        output_padding,
        groups,
        # *x_stride,
        # *w_stride,
    ]
    use_cuda = x.is_cuda

    # gen_key
    key = tuple(id_args)
    key = ("conv",) + key

    # candidate kernels
    kernels = ["aten.convolution"]
    if use_cuda:
        kernels += ["triton_ops.conv"]
    # filter out kernels whose requirements are not met by the args/kwargs
    remove_kernels = []
    if groups > 1 or transposed:
        remove_kernels += ["triton_ops.conv"]
    kernels = [k for k in kernels if k not in remove_kernels]

    # if only one choice, return that kernel
    if len(kernels) == 1:
        kernel = kernels[0]
        # return kernel(
        #     x, w, stride, padding, dilation, transposed, output_padding, groups
        # )
        return kernel

    timings = {}
    if key not in autotune.cache:
        for kernel in kernels:
            runnable_kernel = str2func(kernel)
            if "triton_ops" in kernel:
                # because we use nhwc layout by default for triton conv
                x = x.to(memory_format=torch.channels_last)
            run_args = (
                x,
                w,
                None,
                stride,
                padding,
                dilation,
                transposed,
                output_padding,
                groups,
            )
            timing, _, _ = autotune._bench(runnable_kernel, *run_args)
            if "triton_ops" in kernel:
                timing = timing * adjust_triton
            timings[kernel] = timing

        autotune.cache[key] = builtins.min(timings, key=timings.get)
        if config.debug:
            print("for key = ", key)
            print("timing", timings)
            print("best_kernel", autotune.cache[key])

    best_kernel = autotune.cache[key]
    # if best_kernel == "triton_ops.conv":
    #     print(key, best_kernel)
    return best_kernel
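
# A hypothetical call, for illustration only (shapes/strides are made up;
# strides correspond to contiguous NCHW tensors):
#   tuned_conv(
#       x_shape=(8, 64, 56, 56), w_shape=(128, 64, 3, 3),
#       x_stride=(200704, 3136, 56, 1), w_stride=(576, 9, 3, 1),
#       stride=(1, 1), padding=(1, 1), dilation=(1, 1),
#       transposed=False, output_padding=(0, 0), groups=1,
#       device="cuda", dtype=torch.float32,
#   )
#   returns "aten.convolution" or "triton_ops.conv", whichever benchmarked
#   faster (triton timing scaled by adjust_triton), and caches the choice.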


def tuned_mm(
    a_shape,
    b_shape,
    a_stride,
    b_stride,
    device,
    dtype,
    adjust_triton=0.95,
):
    """
    Return the name of the best mm kernel for the given input sizes. To
    account for potential pointwise fusion into the mm, triton timings are
    scaled by adjust_triton (default 0.95). An illustrative call follows
    this function.
    """
    sizevars = V.graph.sizevars
    a_shape = [sizevars.size_hint(s) for s in a_shape]
    b_shape = [sizevars.size_hint(s) for s in b_shape]
    a_stride = [sizevars.size_hint(s) for s in a_stride]
    b_stride = [sizevars.size_hint(s) for s in b_stride]
    a = rand_strided(a_shape, a_stride, device=device, dtype=dtype)
    b = rand_strided(b_shape, b_stride, device=device, dtype=dtype)
    c = torch.empty((a_shape[0], b_shape[1]), device=device, dtype=dtype)

    id_args = [
        *a_shape,
        *b_shape,
    ]
    use_cuda = a.is_cuda

    # gen_key
    key = tuple(id_args)
    key = ("mm",) + key

    # candidate kernels
    kernels = ["aten.mm.out"]
    if use_cuda:
        kernels += ["triton_ops.matmul_out"]
    # if only one choice, return that kernel
    if len(kernels) == 1:
        kernel = kernels[0]
        return kernel

    timings = {}
    if key not in autotune.cache:
        # bench_start = time.time()
        for kernel in kernels:
            runnable_kernel = str2func(kernel)
            if "triton_ops" in kernel:
                run_args = (a, b, c)
                run_kwargs = {}
                inner_kernel = str2func(
                    kernel.replace("matmul_out", "_matmul_out") + ".kernel"
                )
                inner_kernel.kernel_decorators = []
                # fix SPLIT_K = 1 for fusable kernels
                mm_heuristics()(mm_autotune(get_io_bound_configs=False)(inner_kernel))
            else:
                run_args = (a, b)
                run_kwargs = {"out": c}
            timing, _, _ = autotune._bench(runnable_kernel, *run_args, **run_kwargs)
            if "triton_ops" in kernel:
                timing = timing * adjust_triton
            timings[kernel] = timing
        # bench_end = time.time()
        # bench_time = bench_end - bench_start

        autotune.cache[key] = builtins.min(timings, key=timings.get)
        if config.debug:
            print("for key = ", key)
            print("timing", timings)
            print("best_kernel", autotune.cache[key])

    best_kernel = autotune.cache[key]
    return best_kernel
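
# A hypothetical call, for illustration only (shapes/strides are made up;
# strides correspond to contiguous/row-major tensors):
#   tuned_mm(
#       a_shape=(1024, 512), b_shape=(512, 2048),
#       a_stride=(512, 1), b_stride=(2048, 1),
#       device="cuda", dtype=torch.float16,
#   )
#   returns "aten.mm.out" or "triton_ops.matmul_out", whichever benchmarked
#   faster (triton timing scaled by adjust_triton), and caches the choice.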


def tuned_conv_layout(
    kernel,
    x_shape,
    w_shape,
    stride,
    padding,
    dilation,
    transposed,
    output_padding,
    groups,
    device,
    dtype,
):
    """
    Return the best memory format ("torch.contiguous_format" or
    "torch.channels_last") for the given conv kernel and inputs, chosen by
    benchmarking the kernel under both layouts. An illustrative call follows
    this function.
    """
    sizevars = V.graph.sizevars
    x_shape = [sizevars.size_hint(s) for s in x_shape]
    w_shape = [sizevars.size_hint(s) for s in w_shape]
    x = torch.randn(x_shape, device=device, dtype=dtype)
    w = torch.randn(w_shape, device=device, dtype=dtype)

    id_args = [
        *x_shape,
        *w_shape,
        stride,
        padding,
        dilation,
        transposed,
        output_padding,
        groups,
    ]

    # gen_key
    key = tuple(id_args)
    key = ("conv_layout",) + key
    runnable_kernel = str2func(kernel)

    timings = {}
    if key not in autotune.cache:
        for memory_format in ["torch.contiguous_format", "torch.channels_last"]:
            x = x.to(memory_format=str2func(memory_format))
            run_args = (
                x,
                w,
                None,
                stride,
                padding,
                dilation,
                transposed,
                output_padding,
                groups,
            )
            timing, _, _ = autotune._bench(runnable_kernel, *run_args)
            timings[memory_format] = timing

        autotune.cache[key] = builtins.min(timings, key=timings.get)
        if config.debug:
            print("for key = ", key)
            print("timing", timings)
            print("best_layout", autotune.cache[key])

    best_layout = autotune.cache[key]
    return best_layout
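
# A hypothetical call, for illustration only (shapes are made up):
#   tuned_conv_layout(
#       "aten.convolution",
#       x_shape=(8, 64, 56, 56), w_shape=(128, 64, 3, 3),
#       stride=(1, 1), padding=(1, 1), dilation=(1, 1),
#       transposed=False, output_padding=(0, 0), groups=1,
#       device="cuda", dtype=torch.float32,
#   )
#   returns "torch.contiguous_format" or "torch.channels_last", whichever
#   benchmarked faster for this conv, and caches the choice.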