Revert "[inductor] Enable CudaWrapperCodeGen for non-AOT mode (#98264)"
This reverts commit 77f32eb6cc. Reverted https://github.com/pytorch/pytorch/pull/98264 on behalf of https://github.com/huydhn due to: Sorry for reverting your PR, but this is failing in trunk with a name error ("fake_mode_from_tensors is not defined"), see 67d1a77086. This is probably a land race.
This commit is contained in:
parent 3b6e94cb8c
commit f228b3977b
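
The trunk failure behind this revert is visible in the reverted code further down: compile_fx_with_cpp_wrapper calls fake_mode_from_tensors(example_inputs) in its first (JIT) pass, but no import for that name is added, so the CUDA two-pass path raises a NameError at compile time. Below is a minimal sketch of the failure and of the one-line forward fix it suggests; the import location (torch._dynamo.utils) is an assumption about this era of the codebase, not something this commit shows.

```python
import torch

# Hypothetical stand-in inputs; the real ones come from the compiled graph's example inputs.
example_inputs = [torch.randn(2, 2)]

# The reverted first pass effectively ran:
#     fake_mode = fake_mode_from_tensors(example_inputs)
# without importing the helper, producing:
#     NameError: name 'fake_mode_from_tensors' is not defined
#
# A forward fix would presumably just import it (assumed location for this revision):
from torch._dynamo.utils import fake_mode_from_tensors

fake_mode = fake_mode_from_tensors(example_inputs)
print(fake_mode)  # None when no FakeTensor is among the inputs
```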
@@ -5,13 +5,8 @@ from typing import NamedTuple
-import torch._dynamo
-from torch._inductor import config
-from torch.testing._internal.common_utils import (
-    IS_MACOS,
-    TEST_WITH_ASAN,
-    TestCase as TorchTestCase,
-)
-from torch.testing._internal.inductor_utils import HAS_CPU, HAS_CUDA
+from torch.testing._internal.common_utils import IS_MACOS, TestCase as TorchTestCase
+from torch.testing._internal.inductor_utils import HAS_CPU

 try:
     try:

@@ -25,26 +20,14 @@ except unittest.SkipTest:
     raise

-RUN_CPU = HAS_CPU and not torch.backends.mps.is_available() and not IS_MACOS
-RUN_CUDA = HAS_CUDA and not TEST_WITH_ASAN
-
-
 class CppWrapperTemplate:
     pass


-class CudaWrapperTemplate:
-    pass
-
-
 class TestCppWrapper(TorchTestCase):
     device = "cpu"


-class TestCudaWrapper(TorchTestCase):
-    device = "cuda"
-
-
 def make_test_case(name, device, tests):
     test_name = f"{name}_{device}" if device else name

@@ -62,12 +45,10 @@ def make_test_case(name, device, tests):
         tests.tearDownClass()

     fn.__name__ = test_name
-    setattr(
-        CppWrapperTemplate if device == "cpu" else CudaWrapperTemplate, test_name, fn
-    )
+    setattr(CppWrapperTemplate, test_name, fn)


-if RUN_CPU:
+if HAS_CPU and not torch.backends.mps.is_available() and not IS_MACOS:

     class BaseTest(NamedTuple):
         name: str

@@ -100,42 +81,8 @@ if RUN_CPU:

     test_torchinductor.copy_tests(CppWrapperTemplate, TestCppWrapper, "cpp_wrapper")

-if RUN_CUDA:
-
-    class BaseTest(NamedTuple):
-        name: str
-        device: str = "cuda"
-        tests: TorchTestCase = test_torchinductor.CudaTests()
-
-    # Maintain two separate test lists for cuda and cpp for now
-    for item in [
-        BaseTest("test_as_strided"),  # buffer reuse
-        BaseTest("test_bitwise"),  # int32
-        BaseTest("test_bmm1"),
-        BaseTest("test_bmm2"),
-        BaseTest("test_cat"),  # alias
-        BaseTest("test_linear1"),
-        BaseTest("test_linear2"),
-        BaseTest("test_linear_packed"),
-        BaseTest("test_linear_unary"),
-        # BaseTest("test_lowmem_dropout1"),  # None as output
-        BaseTest("test_mm_views"),
-        BaseTest("test_profiler_mark_wrapper_call"),
-        BaseTest("test_reduction1"),  # Reduction
-        BaseTest("test_relu"),  # multiple inputs
-        BaseTest("test_scalar_input"),
-        BaseTest("test_silu"),  # single input, single output
-        BaseTest("test_sum_dtype"),  # float64
-        BaseTest("test_sum_int"),  # bool, int64, int8, uint8
-        BaseTest("test_transpose"),  # multiple outputs, buffer clear
-    ]:
-        make_test_case(item.name, item.device, item.tests)
-
-    test_torchinductor.copy_tests(CudaWrapperTemplate, TestCudaWrapper, "cuda_wrapper")


 if __name__ == "__main__":
     from torch._dynamo.test_case import run_tests

-    if RUN_CPU or RUN_CUDA:
+    if HAS_CPU and not torch.backends.mps.is_available() and not IS_MACOS:
         run_tests(needs="filelock")

@@ -515,6 +515,7 @@ def get_include_and_linking_paths(
         libs = ["c10", "torch", "torch_cpu", "torch_python"]
         if cuda:
             libs += ["c10_cuda", "cuda", "torch_cuda"]
+            ipaths += [f"{cpp_extension._TORCH_PATH}/../aten/src/"]
         else:
             libs += ["gomp"]
         macros = vec_isa.build_macro()

@@ -577,28 +578,6 @@ def cpp_compile_command(
     ).strip()


-class CudaKernelParamCache:
-    cache = dict()
-    clear = staticmethod(cache.clear)
-
-    @classmethod
-    def set(cls, key, params, cubin):
-        from filelock import FileLock
-
-        cubin_path = os.path.join(cubin_cache_dir(), f"{key}.cubin")
-        params["cubin_path"] = cubin_path
-        lock_dir = get_lock_dir()
-        lock = FileLock(os.path.join(lock_dir, key + ".lock"), timeout=LOCK_TIMEOUT)
-        with lock:
-            cls.cache[key] = params
-            with open(cubin_path, "wb") as f:
-                f.write(cubin)
-
-    @classmethod
-    def get(cls, key):
-        return cls.cache.get(key, None)
-
-
 class AotCodeCache:
     cache = dict()
     clear = staticmethod(cache.clear)

@@ -1698,8 +1698,8 @@ class TritonScheduling:

     def define_kernel(self, src_code, node_schedule):
         wrapper = V.graph.wrapper_code
-        if src_code in wrapper.src_to_kernel:
-            kernel_name = wrapper.src_to_kernel[src_code]
+        if src_code in wrapper.kernels:
+            kernel_name = wrapper.kernels[src_code]
         else:
             fused_name = (
                 get_fused_kernel_name(node_schedule)

@@ -1710,8 +1710,7 @@ class TritonScheduling:
             kernel_name = "_".join(
                 ["triton", kernel_category, fused_name, wrapper.next_kernel_suffix()]
             )
-            # use the original src_code as the key
-            wrapper.src_to_kernel[src_code] = kernel_name
+            wrapper.kernels[src_code] = kernel_name
         subs_name = kernel_name if config.triton.unique_kernel_names else "triton_"
         src_code = src_code.replace("KERNEL_NAME", subs_name)

@@ -1719,9 +1718,7 @@ class TritonScheduling:
         # not use BracesBuffer, so we have no good indicator of a C++ buffer atm.
         src_code = src_code.replace("#pragma CMT", "#")

-        basename, _, kernel_path = get_code_path(src_code, "py", extra="")
-        wrapper.kernel_to_hash[kernel_name] = basename
-
+        _, _, kernel_path = get_code_path(src_code, "py", extra="")
         compile_wrapper = IndentedBuffer()
         compile_wrapper.writeline("async_compile.triton('''")
         compile_wrapper.splice(src_code, strip=True)

@@ -12,7 +12,7 @@ from sympy import Expr

 from torch._dynamo.utils import dynamo_timed
 from .. import codecache, config, ir
-from ..codecache import CudaKernelParamCache
+from ..codecache import cubin_cache_dir
 from ..utils import (
     cache_on_self,
     get_benchmark_name,

@@ -58,6 +58,21 @@ def is_float(s: str):
     return True


+class KernelParamCache:
+    cache = dict()
+
+    def __init__(self):
+        self.prev_cache = None
+
+    def __enter__(self):
+        self.prev_cache = KernelParamCache.cache
+        KernelParamCache.cache = dict()
+
+    def __exit__(self, *args):
+        KernelParamCache.cache.clear()
+        KernelParamCache.cache = self.prev_cache
+
+
 class MemoryPlanningState:
     def __init__(self):
         super().__init__()

@@ -192,8 +207,7 @@ class WrapperCodeGen(CodeGen):
         self.header = IndentedBuffer()
         self.prefix = IndentedBuffer()
         self.wrapper_call = IndentedBuffer()
-        self.src_to_kernel = {}
-        self.kernel_to_hash = {}
+        self.kernels = {}
         self.lines = []
         self.need_seed = False
         self.declare = ""

@@ -706,7 +720,6 @@ class CppWrapperCodeGen(WrapperCodeGen):
         self.size = "sizes()"
         self.stride = "strides()"
         self.call_func_name = "inductor_cpp_entry"
-        self.cuda = False

     def seed(self):
         """

@@ -817,10 +830,9 @@ class CppWrapperCodeGen(WrapperCodeGen):
         warning_all_flag = codecache.get_warning_all_flag()
         cpp_flags = codecache.cpp_flags()
         ipaths, lpaths, libs, macros = codecache.get_include_and_linking_paths(
-            vec_isa=codecache.pick_vec_isa(),
-            cuda=self.cuda,
+            vec_isa=codecache.pick_vec_isa()
         )
-        optimization_flags = codecache.optimization_flags(cuda=self.cuda)
+        optimization_flags = codecache.optimization_flags()
         use_custom_generated_macros = codecache.use_custom_generated_macros()

         extra_cflags = f"{cpp_flags} {optimization_flags} {warning_all_flag} {macros} {use_custom_generated_macros}"

@@ -952,35 +964,27 @@ class CudaWrapperCodeGen(CppWrapperCodeGen):

     def __init__(self):
         super().__init__()
-        self.kernel_callsite_id = count()
-        self.arg_var_id = count()
-        self.cuda = True
+        self.kernel_callsite_id = 0
+        self.arg_var_id = 0

     def write_prefix(self):
         self.prefix.splice(
             """
             #include <ATen/ATen.h>
             #include <c10/util/Exception.h>
             #include <ATen/cuda/CUDAContext.h>
             #include <ATen/cuda/Exceptions.h>
+            #include <ATen/cuda/nvrtc_stub/ATenNVRTC.h>
             #include <c10/cuda/CUDAGuard.h>

-            #define AT_CUDA_DRIVER_CHECK_OVERRIDE(EXPR) \
-            do { \
-                CUresult __err = EXPR; \
-                if (__err != CUDA_SUCCESS) { \
-                    AT_ERROR("CUDA driver error: ", static_cast<int>(__err)); \
-                } \
-            } while (0)
-
-            static inline CUfunction loadKernel(const std::string &filePath,
-                                                const std::string &funcName) {
+            CUfunction loadKernel(const std::string &filePath, const std::string &funcName) {
                 CUmodule mod;
                 CUfunction func;
-                AT_CUDA_DRIVER_CHECK_OVERRIDE(cuModuleLoad(&mod, filePath.c_str()));
-                AT_CUDA_DRIVER_CHECK_OVERRIDE(cuModuleGetFunction(&func, mod, funcName.c_str()));
+                AT_CUDA_DRIVER_CHECK(cuModuleLoad(&mod, filePath.c_str()));
+                AT_CUDA_DRIVER_CHECK(cuModuleGetFunction(&func, mod, funcName.c_str()));
                 return func;
             }

-            static inline void launchKernel(
+            void launchKernel(
                 CUfunction func,
                 int gridX,
                 int gridY,

@@ -989,7 +993,7 @@ class CudaWrapperCodeGen(CppWrapperCodeGen):
                 int sharedMemBytes,
                 void* args[],
                 int device_index) {
-                AT_CUDA_DRIVER_CHECK_OVERRIDE(cuLaunchKernel(
+                AT_CUDA_DRIVER_CHECK(cuLaunchKernel(
                     func, gridX, gridY, gridZ, 32*numWraps, 1, 1, sharedMemBytes,
                     at::cuda::getCurrentCUDAStream(device_index), args, nullptr));
             }

@@ -1001,15 +1005,19 @@ class CudaWrapperCodeGen(CppWrapperCodeGen):

     def generate(self):
         self.prefix.writeline("\n")
-        for kernel in self.src_to_kernel.values():
+        for kernel in self.kernels.values():
             self.prefix.writeline(f"static CUfunction {kernel} = nullptr;")
         self.prefix.writeline("\n")
         return super().generate()

-    def generate_load_kernel(self, name, params):
+    def generate_load_kernel(self, name: str = None):
+        params = KernelParamCache.cache.get(name, None)
+        assert (
+            params is not None
+        ), "cuda kernel parameters should already exist at this moment"
         mangled_name = params.get("mangled_name", None)
         assert mangled_name is not None, "missing mangled_name"
-        cubin_path = params.get("cubin_path", None)
+        cubin_path = os.path.join(cubin_cache_dir(), f"{name}.cubin")
         assert os.path.exists(
             cubin_path
         ), "cubin file should already exist at this moment"

@@ -1024,7 +1032,7 @@ class CudaWrapperCodeGen(CppWrapperCodeGen):
         # TODO: only works for constant now, need type info
         new_args = []
         for arg in call_args:
-            var_name = f"var_{next(self.arg_var_id)}"
+            var_name = f"var_{self.arg_var_id}"
             if is_int(arg):
                 self.writeline(f"int {var_name} = {arg};")
             elif is_float(arg):

@@ -1034,29 +1042,27 @@ class CudaWrapperCodeGen(CppWrapperCodeGen):
                     f"CUdeviceptr {var_name} = reinterpret_cast<CUdeviceptr>({arg}.data_ptr());"
                 )
             new_args.append(f"&{var_name}")
+            self.arg_var_id += 1

         return ", ".join(new_args)

     def generate_kernel_call(self, name, call_args, device_index):
-        params = CudaKernelParamCache.get(self.kernel_to_hash.get(name, None))
+        params = KernelParamCache.cache.get(name, None)
         assert (
             params is not None
         ), "cuda kernel parameters should already exist at this moment"
+        grid_x = params.get("grid_x", None)
+        grid_y = params.get("grid_y", None)
+        grid_z = params.get("grid_z", None)
+        num_warps = params.get("num_warps", None)
+        shared_mem = params.get("shared_mem", None)

-        self.generate_load_kernel(name, params)
+        self.generate_load_kernel(name)

         call_args = self.generate_args_decl(call_args)
-        kernel_args_var = f"kernel_args_var_{next(self.kernel_callsite_id)}"
-        self.writeline(f"void* {kernel_args_var}[] = {{{call_args}}};")
+        args_name = f"kernel_args_{self.kernel_callsite_id}"
+        self.kernel_callsite_id += 1
+        self.writeline(f"void* {args_name}[] = {{{call_args}}};")
         self.writeline(
-            "launchKernel({}, {}, {}, {}, {}, {}, {}, {});".format(
-                name,
-                params["grid_x"],
-                params["grid_y"],
-                params["grid_z"],
-                params["num_warps"],
-                params["shared_mem"],
-                kernel_args_var,
-                device_index,
-            )
+            f"launchKernel({name}, {grid_x}, {grid_y}, {grid_z}, {num_warps}, {shared_mem}, {args_name}, {device_index});"
         )

@@ -21,21 +21,16 @@ from torch._functorch.aot_autograd import make_boxed_func
 from torch._ops import OpOverload
-from torch._subclasses.fake_tensor import FakeTensor
 from torch.fx.passes.fake_tensor_prop import FakeTensorProp
-from torch.utils._mode_utils import no_dispatch

 from .._dynamo.backends.common import aot_autograd
 from ..fx.graph import _PyTreeCodeGen
 from . import config, metrics, overrides, pattern_matcher
+from .codegen.wrapper import KernelParamCache
 from .debug import DebugContext
 from .decomposition import select_decomp_table
 from .graph import GraphLowering
 from .mkldnn import convert_outplace_to_inplace
-from .utils import (
-    developer_warning,
-    get_dtype_size,
-    has_incompatible_cudagraph_ops,
-    is_cpu_device,
-)
+from .utils import developer_warning, get_dtype_size, has_incompatible_cudagraph_ops
 from .virtualized import V

 log = logging.getLogger(__name__)

@@ -156,7 +151,6 @@ def compile_fx_inner(
     num_fixed=0,
     is_backward=False,
     graph_id=None,
-    cpp_wrapper=False,
     aot_mode=False,
     is_inference=False,
     boxed_forward_device_index=None,

@@ -206,8 +200,8 @@ def compile_fx_inner(
             shape_env=shape_env,
             num_static_inputs=num_fixed,
             graph_id=graph_id,
-            cpp_wrapper=cpp_wrapper,
             aot_mode=aot_mode,
+            cpp_wrapper=config.cpp_wrapper,
         )
         with V.set_graph_handler(graph):
             graph.run(*example_inputs)

@@ -498,93 +492,74 @@ def count_tangents(fx_g: torch.fx.GraphModule):
     return len(static_arg_idxs)


-def compile_fx_with_cpp_wrapper(
+def compile_fx_aot(
     module: torch.fx.GraphModule,
     example_inputs: List[torch.Tensor],
     inner_compile,
     decompositions: Optional[Dict[OpOverload, Callable]] = None,
 ):
-    """
-    Compile into cpp wrapper:
-    For CPU, this is currently done in one pass.
-    For GPU, this is done in two passes: JIT-compile the model with python wrapper code
-    and run it to generate autotuned kernel binaries in the first pass; and then generate
-    cpp wrapper code and compile it to a dynamic library in the second pass.
-    """
-    from torch.ao.quantization.fx.utils import assert_and_get_unique_device
-
-    # Turns off cpp_wrapper before calling back into compile_fx
-    config_patches = {"cpp_wrapper": False}
-    device = assert_and_get_unique_device(module)
-
-    if is_cpu_device(example_inputs):
-        assert device is None or device.type == "cpu"
-        with config.patch(config_patches):
-            return compile_fx(
-                module,
-                example_inputs,
-                inner_compile=functools.partial(inner_compile, cpp_wrapper=True),
-                decompositions=decompositions,
-            )
-    else:
-        assert device is None or device.type == "cuda"
-
-        config_patches.update(
-            {
-                "triton.cudagraphs": False,
-                "triton.store_cubin": True,
-            }
-        )
-        with config.patch(config_patches):
-            # first pass
-            module_copy = deepcopy(module)
-            fake_mode = fake_mode_from_tensors(example_inputs)
-
-            if fake_mode:
-                with no_dispatch():
-
-                    def to_real_tensor(e):
-                        if isinstance(e, FakeTensor):
-                            out = torch.zeros_like(e, device=e.fake_device)
-                            return out
-                        return e
-
-                inputs_copy = [to_real_tensor(t) for t in example_inputs]
-            else:
-                inputs_copy = deepcopy(example_inputs)
-
-            compiled = compile_fx(
-                module_copy,
-                inputs_copy,
-                inner_compile=functools.partial(inner_compile, cpp_wrapper=False),
-                decompositions=decompositions,
-            )
-            compiled(*inputs_copy)
-            del module_copy, inputs_copy
-
-            # second pass
-            return compile_fx(
-                module,
-                example_inputs,
-                inner_compile=functools.partial(inner_compile, cpp_wrapper=True),
-                decompositions=decompositions,
-            )
-
-
-def compile_fx_aot(
-    model_: torch.fx.GraphModule,
-    example_inputs_: List[torch.Tensor],
-    inner_compile=compile_fx_inner,
-    config_patches: Optional[Dict[str, Any]] = None,
-    decompositions: Optional[Dict[OpOverload, Callable]] = None,
-):
-    return compile_fx(
-        model_,
-        example_inputs_,
-        inner_compile=functools.partial(inner_compile, aot_mode=True),
-        config_patches=config_patches,
-        decompositions=decompositions,
-    )
+    """
+    JIT-compile the model and run it to generate kernel binaries in the first pass;
+    Generate cpp wrapper code and compile it to a dynamic library in the second pass
+    """
+    from torch.ao.quantization.fx.utils import assert_and_get_unique_device
+
+    # Do we need to check inputs device as well?
+    device = assert_and_get_unique_device(module)
+    new_config_patches = config_patches.copy() if config_patches is not None else dict()
+
+    if device.type == "cuda":
+        # So far, we only need to do this for the Triton backend
+        with KernelParamCache():
+            new_config_patches.update(
+                {
+                    "cpp_wrapper": False,
+                    "triton.cudagraphs": False,
+                    "triton.unique_kernel_names": True,
+                    "triton.store_cubin": True,
+                }
+            )
+            with config.patch(new_config_patches):
+                module_copy = deepcopy(module)
+                inputs_copy = deepcopy(example_inputs)
+                compiled = compile_fx(
+                    module_copy,
+                    inputs_copy,
+                    inner_compile,
+                    new_config_patches,
+                    decompositions,
+                )
+                compiled(inputs_copy)
+                del module_copy, inputs_copy
+
+            new_config_patches.update(
+                {
+                    "cpp_wrapper": True,
+                    "triton.store_cubin": False,
+                }
+            )
+            return compile_fx(
+                module,
+                example_inputs,
+                inner_compile=functools.partial(inner_compile, aot_mode=True),
+                config_patches=new_config_patches,
+                decompositions=decompositions,
+            )
+    else:
+        assert device.type == "cpu"
+        new_config_patches.update(
+            {
+                "cpp_wrapper": True,
+            }
+        )
+        return compile_fx(
+            module,
+            example_inputs,
+            inner_compile=functools.partial(inner_compile, aot_mode=True),
+            config_patches=new_config_patches,
+            decompositions=decompositions,
+        )


 _graph_counter = itertools.count(0)

@@ -608,14 +583,6 @@ def compile_fx(
         decompositions=decompositions,
     )

-    if config.cpp_wrapper:
-        return compile_fx_with_cpp_wrapper(
-            model_,
-            example_inputs_,
-            inner_compile=inner_compile,
-            decompositions=decompositions,
-        )
-
     recursive_compile_fx = functools.partial(
         compile_fx,
         inner_compile=inner_compile,

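For orientation, a hedged usage sketch of the restored compile_fx_aot shown in the hunk above: on CUDA it JIT-compiles once inside a KernelParamCache() scope with triton.store_cubin and triton.unique_kernel_names enabled, runs the result to record each kernel's cubin and launch parameters, then recompiles with cpp_wrapper=True. The module path and the explicit inner_compile argument are assumptions about this revision, not guarantees.

```python
# Hedged sketch only: assumes this revision exposes compile_fx_aot and compile_fx_inner
# under torch._inductor.compile_fx and that a CUDA device is available.
import torch
from torch._inductor.compile_fx import compile_fx_aot, compile_fx_inner


class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.lin = torch.nn.Linear(4, 4)

    def forward(self, x):
        return torch.relu(self.lin(x))


gm = torch.fx.symbolic_trace(M()).cuda()            # GraphModule, parameters on CUDA
example_inputs = [torch.randn(4, 4, device="cuda")]

# Pass 1 (inside KernelParamCache()): JIT compile with triton.store_cubin=True and
# unique kernel names, then run once so each Triton kernel's cubin and launch params
# are captured. Pass 2: recompile with cpp_wrapper=True; the generated C++ reloads
# those cubins through loadKernel()/launchKernel() from CudaWrapperCodeGen.
compiled = compile_fx_aot(gm, example_inputs, compile_fx_inner)
out = compiled(example_inputs)
```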
@@ -55,7 +55,7 @@ log = logging.getLogger(__name__)
 output_code_log = torch._logging.getArtifactLogger(__name__, "output_code")


-def supported_dtype_of_cpp_wrapper(dtype, cuda):
+def supported_dtype_of_cpp_wrapper(dtype):
     supported_dtype = {
         torch.float32,
         torch.float64,

@@ -68,9 +68,6 @@ def supported_dtype_of_cpp_wrapper(dtype, cuda):
         torch.bfloat16,
         # torch.float16, # TODO: implement this
     }
-    if cuda:
-        supported_dtype.add(torch.float16)
-
     return dtype in supported_dtype


@@ -141,8 +138,8 @@ class GraphLowering(torch.fx.Interpreter):
         shape_env=None,
         num_static_inputs=None,
         graph_id=None,
-        cpp_wrapper=False,
         aot_mode=False,
+        cpp_wrapper=False,
     ):
         super().__init__(gm)
         self.extra_traceback = False  # we do our own error wrapping

@@ -172,8 +169,9 @@ class GraphLowering(torch.fx.Interpreter):
         self.name_to_buffer: Dict[str, ir.ComputedBuffer] = {}
         self.creation_time = time.time()
         self.name = "GraphLowering"
-        self.cpp_wrapper = cpp_wrapper
+        # TODO: aot_mode and cpp_wrapper are tangled now. Some refactoring is needed.
         self.aot_mode = aot_mode
+        self.cpp_wrapper = cpp_wrapper
         self.graph_id = graph_id
         self.scheduler = None
         self._warned_fallback = {"aten.convolution_backward"}

@@ -583,7 +581,12 @@
     def get_single_device(self):
         return list(self.device_types)[0] if len(self.device_types) == 1 else None

-    def check_input_for_cpp_buffer(self, cuda):
+    def check_device_for_cpp_buffer(self):
+        device = self.get_single_device()
+        if self.get_single_device() is None:
+            self.disable_cpp_wrapper("device not CPU or CUDA")
+
+    def check_input_for_cpp_buffer(self):
         for _, value in self.graph_inputs.items():
             dtype = None
             if isinstance(value, TensorBox):

@@ -591,33 +594,37 @@
             elif isinstance(value, sympy.Symbol):
                 dtype = may_get_constant_buffer_dtype(value)

-            if not supported_dtype_of_cpp_wrapper(dtype, cuda):
+            if not supported_dtype_of_cpp_wrapper(dtype):
                 self.disable_cpp_wrapper("unsupported inputs dtype")

     def check_constant_for_cpp_buffer(self):
         if self.constants:
             self.disable_cpp_wrapper("Constants")

-    def check_cpp_wrapper(self, cuda):
+    def check_cpp_wrapper(self):
         self.check_cpp_codegen_disabled()
         self.check_platform()
-        self.check_input_for_cpp_buffer(cuda)
+        self.check_device_for_cpp_buffer()
+        self.check_input_for_cpp_buffer()
         self.check_constant_for_cpp_buffer()

     def init_wrapper_code(self):
-        if self.cpp_wrapper:
+        if self.aot_mode:
             device = self.get_single_device()
-            assert device == "cpu" or device == "cuda"
-            cuda = device == "cuda"
-            self.check_cpp_wrapper(cuda)
-            # Re-check self.cpp_wrapper because it might be disabled due to failed checking
+            self.check_cpp_wrapper()
+            if device == "cpu":
+                self.wrapper_code = CppWrapperCodeGen()
+            else:
+                assert device == "cuda", "Non-supported device for AOT compilation"
+                self.wrapper_code = CudaWrapperCodeGen()
+        elif self.cpp_wrapper:
+            self.check_cpp_wrapper()
             if self.cpp_wrapper:
-                self.wrapper_code = (
-                    CudaWrapperCodeGen() if cuda else CppWrapperCodeGen()
-                )
-                return
-
-        self.wrapper_code = WrapperCodeGen()
+                self.wrapper_code = CppWrapperCodeGen()
+            else:
+                self.wrapper_code = WrapperCodeGen()
+        else:
+            self.wrapper_code = WrapperCodeGen()

     def codegen(self):
         from .scheduler import Scheduler

@@ -16,7 +16,7 @@ import torch
 from torch._dynamo.utils import dynamo_timed

 from . import config
-from .codecache import cache_dir, CudaKernelParamCache
+from .codecache import cache_dir, cubin_cache_dir

 from .ir import ReductionHint, TileHint
 from .utils import (

@@ -141,7 +141,7 @@ class CachingAutotuner(KernelInterface):
         launcher.store_cubin = config.triton.store_cubin
         # store this global varible to avoid the high overhead of reading it when calling run
         if launcher.store_cubin:
-            launcher.fn = self.fn
+            launcher.kernel_name = self.fn.__name__
             launcher.bin = binary

         return launcher

@@ -192,12 +192,21 @@ class CachingAutotuner(KernelInterface):
             self.save_cache_hook(self.launchers[0].config)

     def save_cuda_kernel(self, grid, stream, launcher):
+        from .codegen.wrapper import KernelParamCache
+
+        # Make sure kernel_name is enough for distiguishing kernels
+        assert config.triton.unique_kernel_names
+
         if callable(grid):
             grid_x, grid_y, grid_z = grid(launcher.config.kwargs)
         else:
             grid_x, grid_y, grid_z = grid

-        key = launcher.fn.module.split(".")[-1]
+        kernel_name = launcher.kernel_name
+        cubin_path = os.path.join(cubin_cache_dir(), f"{kernel_name}.cubin")
+        with open(cubin_path, "wb") as f:
+            f.write(launcher.bin.asm["cubin"])
+
         params = {
             "mangled_name": launcher.bin.metadata["name"],
             "grid_x": grid_x,

@@ -207,7 +216,14 @@ class CachingAutotuner(KernelInterface):
             "shared_mem": launcher.bin.shared,
             "stream": stream,
         }
-        CudaKernelParamCache.set(key, params, launcher.bin.asm["cubin"])
+        with self.lock:
+            if KernelParamCache.cache.get(kernel_name, None):
+                assert (
+                    KernelParamCache.cache[kernel_name].get("mangled_name", None)
+                    == launcher.bin.metadata["name"]
+                )
+            else:
+                KernelParamCache.cache[kernel_name] = params

     def run(self, *args, grid, stream):
         if len(self.launchers) != 1:

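The last two files above form a producer/consumer pair: CachingAutotuner.save_cuda_kernel writes each kernel's cubin to cubin_cache_dir() and records its launch parameters in KernelParamCache.cache, and CudaWrapperCodeGen.generate_load_kernel / generate_kernel_call read the same entries back when emitting C++. A toy sketch of that handoff pattern follows; every name below is a hypothetical stand-in, not the Inductor API.

```python
import os
import tempfile

kernel_param_cache = {}                        # stand-in for KernelParamCache.cache
cubin_dir = tempfile.mkdtemp(prefix="cubin_")  # stand-in for cubin_cache_dir()


def record_kernel(kernel_name, cubin_bytes, grid, num_warps, shared_mem, mangled_name):
    """Producer side (autotuner): persist the binary and remember the launch params."""
    with open(os.path.join(cubin_dir, f"{kernel_name}.cubin"), "wb") as f:
        f.write(cubin_bytes)
    kernel_param_cache[kernel_name] = {
        "mangled_name": mangled_name,
        "grid_x": grid[0], "grid_y": grid[1], "grid_z": grid[2],
        "num_warps": num_warps, "shared_mem": shared_mem,
    }


def emit_kernel_call(kernel_name, args_var, device_index=0):
    """Consumer side (wrapper codegen): emit a launchKernel(...) line from the cache."""
    params = kernel_param_cache[kernel_name]
    cubin_path = os.path.join(cubin_dir, f"{kernel_name}.cubin")
    assert os.path.exists(cubin_path), "cubin file should already exist at this moment"
    return (
        f"launchKernel({kernel_name}, {params['grid_x']}, {params['grid_y']}, "
        f"{params['grid_z']}, {params['num_warps']}, {params['shared_mem']}, "
        f"{args_var}, {device_index});"
    )


record_kernel("triton_fused_relu_0", b"\x00", (128, 1, 1), 4, 0, "relu_kernel")
print(emit_kernel_call("triton_fused_relu_0", "kernel_args_0"))
```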