Revert "[inductor] Enable CudaWrapperCodeGen for non-AOT mode (#98264)"

This reverts commit 77f32eb6cc.

Reverted https://github.com/pytorch/pytorch/pull/98264 on behalf of https://github.com/huydhn due to: Sorry for reverting your PR, but this is failing in trunk with a name error, fake_mode_from_tensors is not defined (67d1a77086). This is probably a land race.
PyTorch MergeBot 2023-04-06 19:00:05 +00:00
parent 3b6e94cb8c
commit f228b3977b
7 changed files with 172 additions and 253 deletions

View File

@@ -5,13 +5,8 @@ from typing import NamedTuple
import torch._dynamo
from torch._inductor import config
from torch.testing._internal.common_utils import (
IS_MACOS,
TEST_WITH_ASAN,
TestCase as TorchTestCase,
)
from torch.testing._internal.inductor_utils import HAS_CPU, HAS_CUDA
from torch.testing._internal.common_utils import IS_MACOS, TestCase as TorchTestCase
from torch.testing._internal.inductor_utils import HAS_CPU
try:
try:
@@ -25,26 +20,14 @@ except unittest.SkipTest:
raise
RUN_CPU = HAS_CPU and not torch.backends.mps.is_available() and not IS_MACOS
RUN_CUDA = HAS_CUDA and not TEST_WITH_ASAN
class CppWrapperTemplate:
pass
class CudaWrapperTemplate:
pass
class TestCppWrapper(TorchTestCase):
device = "cpu"
class TestCudaWrapper(TorchTestCase):
device = "cuda"
def make_test_case(name, device, tests):
test_name = f"{name}_{device}" if device else name
@@ -62,12 +45,10 @@ def make_test_case(name, device, tests):
tests.tearDownClass()
fn.__name__ = test_name
setattr(
CppWrapperTemplate if device == "cpu" else CudaWrapperTemplate, test_name, fn
)
setattr(CppWrapperTemplate, test_name, fn)
if RUN_CPU:
if HAS_CPU and not torch.backends.mps.is_available() and not IS_MACOS:
class BaseTest(NamedTuple):
name: str
@@ -100,42 +81,8 @@ if RUN_CPU:
test_torchinductor.copy_tests(CppWrapperTemplate, TestCppWrapper, "cpp_wrapper")
if RUN_CUDA:
class BaseTest(NamedTuple):
name: str
device: str = "cuda"
tests: TorchTestCase = test_torchinductor.CudaTests()
# Maintain two separate test lists for cuda and cpp for now
for item in [
BaseTest("test_as_strided"), # buffer reuse
BaseTest("test_bitwise"), # int32
BaseTest("test_bmm1"),
BaseTest("test_bmm2"),
BaseTest("test_cat"), # alias
BaseTest("test_linear1"),
BaseTest("test_linear2"),
BaseTest("test_linear_packed"),
BaseTest("test_linear_unary"),
# BaseTest("test_lowmem_dropout1"), # None as output
BaseTest("test_mm_views"),
BaseTest("test_profiler_mark_wrapper_call"),
BaseTest("test_reduction1"), # Reduction
BaseTest("test_relu"), # multiple inputs
BaseTest("test_scalar_input"),
BaseTest("test_silu"), # single input, single output
BaseTest("test_sum_dtype"), # float64
BaseTest("test_sum_int"), # bool, int64, int8, uint8
BaseTest("test_transpose"), # multiple outputs, buffer clear
]:
make_test_case(item.name, item.device, item.tests)
test_torchinductor.copy_tests(CudaWrapperTemplate, TestCudaWrapper, "cuda_wrapper")
if __name__ == "__main__":
from torch._dynamo.test_case import run_tests
if RUN_CPU or RUN_CUDA:
if HAS_CPU and not torch.backends.mps.is_available() and not IS_MACOS:
run_tests(needs="filelock")
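Editor's note: the test file above builds its cases dynamically. make_test_case wraps a shared test body per device, attaches it to a template class via setattr, and copy_tests later clones those methods onto the concrete TestCase classes. Below is a minimal, self-contained sketch of that pattern; WrapperTemplate, copy_tests, and _relu_impl are my own stand-ins, not the real PyTorch helpers.

```python
# Illustrative-only sketch of dynamic test generation via setattr + cloning.
import unittest


class WrapperTemplate:
    pass


def make_test_case(name, device, impl):
    test_name = f"{name}_{device}" if device else name

    def fn(self):
        impl(self, device)  # run the shared test body for this device

    fn.__name__ = test_name
    setattr(WrapperTemplate, test_name, fn)


def copy_tests(template, cls, suffix):
    # Clone every generated test onto the concrete TestCase, tagging a suffix.
    for attr in dir(template):
        if attr.startswith("test_"):
            setattr(cls, f"{attr}_{suffix}", getattr(template, attr))


class TestWrapper(unittest.TestCase):
    pass


def _relu_impl(self, device):
    self.assertEqual(max(0, -1), 0)  # placeholder assertion


make_test_case("test_relu", "cpu", _relu_impl)
copy_tests(WrapperTemplate, TestWrapper, "cpp_wrapper")

if __name__ == "__main__":
    unittest.main()
```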

View File

@@ -515,6 +515,7 @@ def get_include_and_linking_paths(
libs = ["c10", "torch", "torch_cpu", "torch_python"]
if cuda:
libs += ["c10_cuda", "cuda", "torch_cuda"]
ipaths += [f"{cpp_extension._TORCH_PATH}/../aten/src/"]
else:
libs += ["gomp"]
macros = vec_isa.build_macro()
@@ -577,28 +578,6 @@ def cpp_compile_command(
).strip()
class CudaKernelParamCache:
cache = dict()
clear = staticmethod(cache.clear)
@classmethod
def set(cls, key, params, cubin):
from filelock import FileLock
cubin_path = os.path.join(cubin_cache_dir(), f"{key}.cubin")
params["cubin_path"] = cubin_path
lock_dir = get_lock_dir()
lock = FileLock(os.path.join(lock_dir, key + ".lock"), timeout=LOCK_TIMEOUT)
with lock:
cls.cache[key] = params
with open(cubin_path, "wb") as f:
f.write(cubin)
@classmethod
def get(cls, key):
return cls.cache.get(key, None)
class AotCodeCache:
cache = dict()
clear = staticmethod(cache.clear)
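Editor's note: the CudaKernelParamCache shown in the hunk above pairs an in-process dict of launch parameters with a cubin file written to disk under a file lock. A hedged sketch of that pattern using the filelock package follows; the cache directory, lock layout, key, and function names are illustrative, not the real inductor cache.

```python
# Sketch of a keyed cache whose binary artifact is persisted under a file lock.
import os
import tempfile

from filelock import FileLock  # pip install filelock

CACHE_DIR = os.path.join(tempfile.gettempdir(), "demo_cubin_cache")
LOCK_TIMEOUT = 600
_cache = {}


def set_params(key, params, blob):
    os.makedirs(CACHE_DIR, exist_ok=True)
    blob_path = os.path.join(CACHE_DIR, f"{key}.cubin")
    params["cubin_path"] = blob_path
    with FileLock(os.path.join(CACHE_DIR, f"{key}.lock"), timeout=LOCK_TIMEOUT):
        _cache[key] = params
        with open(blob_path, "wb") as f:
            f.write(blob)


def get_params(key):
    return _cache.get(key, None)


set_params("triton_poi_fused_add_0", {"grid_x": 1}, b"\x00fake-cubin")
print(get_params("triton_poi_fused_add_0")["cubin_path"])
```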

View File

@@ -1698,8 +1698,8 @@ class TritonScheduling:
def define_kernel(self, src_code, node_schedule):
wrapper = V.graph.wrapper_code
if src_code in wrapper.src_to_kernel:
kernel_name = wrapper.src_to_kernel[src_code]
if src_code in wrapper.kernels:
kernel_name = wrapper.kernels[src_code]
else:
fused_name = (
get_fused_kernel_name(node_schedule)
@@ -1710,8 +1710,7 @@ class TritonScheduling:
kernel_name = "_".join(
["triton", kernel_category, fused_name, wrapper.next_kernel_suffix()]
)
# use the original src_code as the key
wrapper.src_to_kernel[src_code] = kernel_name
wrapper.kernels[src_code] = kernel_name
subs_name = kernel_name if config.triton.unique_kernel_names else "triton_"
src_code = src_code.replace("KERNEL_NAME", subs_name)
@@ -1719,9 +1718,7 @@ class TritonScheduling:
# not use BracesBuffer, so we have no good indicator of a C++ buffer atm.
src_code = src_code.replace("#pragma CMT", "#")
basename, _, kernel_path = get_code_path(src_code, "py", extra="")
wrapper.kernel_to_hash[kernel_name] = basename
_, _, kernel_path = get_code_path(src_code, "py", extra="")
compile_wrapper = IndentedBuffer()
compile_wrapper.writeline("async_compile.triton('''")
compile_wrapper.splice(src_code, strip=True)
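Editor's note: on both sides of the hunks above, define_kernel keys the wrapper's kernel dict on the original kernel source, so identical source text maps to one stable kernel name. A small illustrative sketch of that deduplication with simplified naming (not the real TritonScheduling API):

```python
# Content-keyed kernel naming: same source -> same name, generated once.
from itertools import count

_suffix = count()
src_to_kernel = {}


def define_kernel(src_code, category="poi", fused_name="fused_add"):
    if src_code in src_to_kernel:
        return src_to_kernel[src_code]
    kernel_name = "_".join(["triton", category, fused_name, str(next(_suffix))])
    # Key on the original source so a later identical kernel reuses the name.
    src_to_kernel[src_code] = kernel_name
    return kernel_name


a = define_kernel("def kernel(): KERNEL_NAME ...")
b = define_kernel("def kernel(): KERNEL_NAME ...")
assert a == b  # deduplicated by source text
```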

View File

@@ -12,7 +12,7 @@ from sympy import Expr
from torch._dynamo.utils import dynamo_timed
from .. import codecache, config, ir
from ..codecache import CudaKernelParamCache
from ..codecache import cubin_cache_dir
from ..utils import (
cache_on_self,
get_benchmark_name,
@@ -58,6 +58,21 @@ def is_float(s: str):
return True
class KernelParamCache:
cache = dict()
def __init__(self):
self.prev_cache = None
def __enter__(self):
self.prev_cache = KernelParamCache.cache
KernelParamCache.cache = dict()
def __exit__(self, *args):
KernelParamCache.cache.clear()
KernelParamCache.cache = self.prev_cache
class MemoryPlanningState:
def __init__(self):
super().__init__()
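Editor's note: the KernelParamCache context manager in the hunk above scopes a module-level dict: entering swaps in a fresh cache, exiting clears it and restores the previous one. A minimal standalone sketch of that scoped-global pattern (ScopedCache is a hypothetical name, not the inductor class):

```python
# Scoped module-level cache: fresh dict inside the context, restored after.
class ScopedCache:
    cache = {}

    def __init__(self):
        self.prev_cache = None

    def __enter__(self):
        self.prev_cache = ScopedCache.cache
        ScopedCache.cache = {}
        return self

    def __exit__(self, *args):
        ScopedCache.cache.clear()
        ScopedCache.cache = self.prev_cache


ScopedCache.cache["outer"] = 1
with ScopedCache():
    ScopedCache.cache["inner"] = 2           # visible only inside the context
    assert "outer" not in ScopedCache.cache
assert ScopedCache.cache == {"outer": 1}     # outer state restored on exit
```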
@@ -192,8 +207,7 @@ class WrapperCodeGen(CodeGen):
self.header = IndentedBuffer()
self.prefix = IndentedBuffer()
self.wrapper_call = IndentedBuffer()
self.src_to_kernel = {}
self.kernel_to_hash = {}
self.kernels = {}
self.lines = []
self.need_seed = False
self.declare = ""
@@ -706,7 +720,6 @@ class CppWrapperCodeGen(WrapperCodeGen):
self.size = "sizes()"
self.stride = "strides()"
self.call_func_name = "inductor_cpp_entry"
self.cuda = False
def seed(self):
"""
@@ -817,10 +830,9 @@ class CppWrapperCodeGen(WrapperCodeGen):
warning_all_flag = codecache.get_warning_all_flag()
cpp_flags = codecache.cpp_flags()
ipaths, lpaths, libs, macros = codecache.get_include_and_linking_paths(
vec_isa=codecache.pick_vec_isa(),
cuda=self.cuda,
vec_isa=codecache.pick_vec_isa()
)
optimization_flags = codecache.optimization_flags(cuda=self.cuda)
optimization_flags = codecache.optimization_flags()
use_custom_generated_macros = codecache.use_custom_generated_macros()
extra_cflags = f"{cpp_flags} {optimization_flags} {warning_all_flag} {macros} {use_custom_generated_macros}"
@@ -952,35 +964,27 @@ class CudaWrapperCodeGen(CppWrapperCodeGen):
def __init__(self):
super().__init__()
self.kernel_callsite_id = count()
self.arg_var_id = count()
self.cuda = True
self.kernel_callsite_id = 0
self.arg_var_id = 0
def write_prefix(self):
self.prefix.splice(
"""
#include <ATen/ATen.h>
#include <c10/util/Exception.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/Exceptions.h>
#include <ATen/cuda/nvrtc_stub/ATenNVRTC.h>
#include <c10/cuda/CUDAGuard.h>
#define AT_CUDA_DRIVER_CHECK_OVERRIDE(EXPR) \
do { \
CUresult __err = EXPR; \
if (__err != CUDA_SUCCESS) { \
AT_ERROR("CUDA driver error: ", static_cast<int>(__err)); \
} \
} while (0)
static inline CUfunction loadKernel(const std::string &filePath,
const std::string &funcName) {
CUfunction loadKernel(const std::string &filePath, const std::string &funcName) {
CUmodule mod;
CUfunction func;
AT_CUDA_DRIVER_CHECK_OVERRIDE(cuModuleLoad(&mod, filePath.c_str()));
AT_CUDA_DRIVER_CHECK_OVERRIDE(cuModuleGetFunction(&func, mod, funcName.c_str()));
AT_CUDA_DRIVER_CHECK(cuModuleLoad(&mod, filePath.c_str()));
AT_CUDA_DRIVER_CHECK(cuModuleGetFunction(&func, mod, funcName.c_str()));
return func;
}
static inline void launchKernel(
void launchKernel(
CUfunction func,
int gridX,
int gridY,
@@ -989,7 +993,7 @@ class CudaWrapperCodeGen(CppWrapperCodeGen):
int sharedMemBytes,
void* args[],
int device_index) {
AT_CUDA_DRIVER_CHECK_OVERRIDE(cuLaunchKernel(
AT_CUDA_DRIVER_CHECK(cuLaunchKernel(
func, gridX, gridY, gridZ, 32*numWraps, 1, 1, sharedMemBytes,
at::cuda::getCurrentCUDAStream(device_index), args, nullptr));
}
@@ -1001,15 +1005,19 @@ class CudaWrapperCodeGen(CppWrapperCodeGen):
def generate(self):
self.prefix.writeline("\n")
for kernel in self.src_to_kernel.values():
for kernel in self.kernels.values():
self.prefix.writeline(f"static CUfunction {kernel} = nullptr;")
self.prefix.writeline("\n")
return super().generate()
def generate_load_kernel(self, name, params):
def generate_load_kernel(self, name: str = None):
params = KernelParamCache.cache.get(name, None)
assert (
params is not None
), "cuda kernel parameters should already exist at this moment"
mangled_name = params.get("mangled_name", None)
assert mangled_name is not None, "missing mangled_name"
cubin_path = params.get("cubin_path", None)
cubin_path = os.path.join(cubin_cache_dir(), f"{name}.cubin")
assert os.path.exists(
cubin_path
), "cubin file should already exist at this moment"
@@ -1024,7 +1032,7 @@ class CudaWrapperCodeGen(CppWrapperCodeGen):
# TODO: only works for constant now, need type info
new_args = []
for arg in call_args:
var_name = f"var_{next(self.arg_var_id)}"
var_name = f"var_{self.arg_var_id}"
if is_int(arg):
self.writeline(f"int {var_name} = {arg};")
elif is_float(arg):
@@ -1034,29 +1042,27 @@ class CudaWrapperCodeGen(CppWrapperCodeGen):
f"CUdeviceptr {var_name} = reinterpret_cast<CUdeviceptr>({arg}.data_ptr());"
)
new_args.append(f"&{var_name}")
self.arg_var_id += 1
return ", ".join(new_args)
def generate_kernel_call(self, name, call_args, device_index):
params = CudaKernelParamCache.get(self.kernel_to_hash.get(name, None))
params = KernelParamCache.cache.get(name, None)
assert (
params is not None
), "cuda kernel parameters should already exist at this moment"
grid_x = params.get("grid_x", None)
grid_y = params.get("grid_y", None)
grid_z = params.get("grid_z", None)
num_warps = params.get("num_warps", None)
shared_mem = params.get("shared_mem", None)
self.generate_load_kernel(name, params)
self.generate_load_kernel(name)
call_args = self.generate_args_decl(call_args)
kernel_args_var = f"kernel_args_var_{next(self.kernel_callsite_id)}"
self.writeline(f"void* {kernel_args_var}[] = {{{call_args}}};")
args_name = f"kernel_args_{self.kernel_callsite_id}"
self.kernel_callsite_id += 1
self.writeline(f"void* {args_name}[] = {{{call_args}}};")
self.writeline(
"launchKernel({}, {}, {}, {}, {}, {}, {}, {});".format(
name,
params["grid_x"],
params["grid_y"],
params["grid_z"],
params["num_warps"],
params["shared_mem"],
kernel_args_var,
device_index,
)
f"launchKernel({name}, {grid_x}, {grid_y}, {grid_z}, {num_warps}, {shared_mem}, {args_name}, {device_index});"
)
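Editor's note: generate_args_decl and generate_kernel_call above marshal each Python-level call argument into a typed C++ variable declaration and then emit a launchKernel call over pointers to those variables. A simplified Python sketch of the argument-marshalling step; the type checks and the emitted C++ text are abbreviations of what the diff shows, and the helper names are mine.

```python
# Emit one typed C++ declaration per argument and collect pointer expressions.
from itertools import count

_arg_id = count()


def _is_int(s):
    try:
        int(s)
        return True
    except ValueError:
        return False


def _is_float(s):
    try:
        float(s)
        return True
    except ValueError:
        return False


def generate_args_decl(call_args, writeline):
    new_args = []
    for arg in call_args:
        var = f"var_{next(_arg_id)}"
        if _is_int(arg):
            writeline(f"int {var} = {arg};")
        elif _is_float(arg):
            writeline(f"float {var} = {arg};")
        else:
            # Assume a tensor name: take its device pointer.
            writeline(
                f"CUdeviceptr {var} = reinterpret_cast<CUdeviceptr>({arg}.data_ptr());"
            )
        new_args.append(f"&{var}")
    return ", ".join(new_args)


lines = []
print(generate_args_decl(["buf0", "128", "0.5"], lines.append))
print("\n".join(lines))
```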

View File

@@ -21,21 +21,16 @@ from torch._functorch.aot_autograd import make_boxed_func
from torch._ops import OpOverload
from torch._subclasses.fake_tensor import FakeTensor
from torch.fx.passes.fake_tensor_prop import FakeTensorProp
from torch.utils._mode_utils import no_dispatch
from .._dynamo.backends.common import aot_autograd
from ..fx.graph import _PyTreeCodeGen
from . import config, metrics, overrides, pattern_matcher
from .codegen.wrapper import KernelParamCache
from .debug import DebugContext
from .decomposition import select_decomp_table
from .graph import GraphLowering
from .mkldnn import convert_outplace_to_inplace
from .utils import (
developer_warning,
get_dtype_size,
has_incompatible_cudagraph_ops,
is_cpu_device,
)
from .utils import developer_warning, get_dtype_size, has_incompatible_cudagraph_ops
from .virtualized import V
log = logging.getLogger(__name__)
@@ -156,7 +151,6 @@ def compile_fx_inner(
num_fixed=0,
is_backward=False,
graph_id=None,
cpp_wrapper=False,
aot_mode=False,
is_inference=False,
boxed_forward_device_index=None,
@@ -206,8 +200,8 @@ def compile_fx_inner(
shape_env=shape_env,
num_static_inputs=num_fixed,
graph_id=graph_id,
cpp_wrapper=cpp_wrapper,
aot_mode=aot_mode,
cpp_wrapper=config.cpp_wrapper,
)
with V.set_graph_handler(graph):
graph.run(*example_inputs)
@@ -498,93 +492,74 @@ def count_tangents(fx_g: torch.fx.GraphModule):
return len(static_arg_idxs)
def compile_fx_with_cpp_wrapper(
def compile_fx_aot(
module: torch.fx.GraphModule,
example_inputs: List[torch.Tensor],
inner_compile,
decompositions: Optional[Dict[OpOverload, Callable]] = None,
):
"""
Compile into cpp wrapper:
For CPU, this is currently done in one pass.
For GPU, this is done in two passes: JIT-compile the model with python wrapper code
and run it to generate autotuned kernel binaries in the first pass; and then generate
cpp wrapper code and compile it to a dynamic library in the second pass.
"""
from torch.ao.quantization.fx.utils import assert_and_get_unique_device
# Turns off cpp_wrapper before calling back into compile_fx
config_patches = {"cpp_wrapper": False}
device = assert_and_get_unique_device(module)
if is_cpu_device(example_inputs):
assert device is None or device.type == "cpu"
with config.patch(config_patches):
return compile_fx(
module,
example_inputs,
inner_compile=functools.partial(inner_compile, cpp_wrapper=True),
decompositions=decompositions,
)
else:
assert device is None or device.type == "cuda"
config_patches.update(
{
"triton.cudagraphs": False,
"triton.store_cubin": True,
}
)
with config.patch(config_patches):
# first pass
module_copy = deepcopy(module)
fake_mode = fake_mode_from_tensors(example_inputs)
if fake_mode:
with no_dispatch():
def to_real_tensor(e):
if isinstance(e, FakeTensor):
out = torch.zeros_like(e, device=e.fake_device)
return out
return e
inputs_copy = [to_real_tensor(t) for t in example_inputs]
else:
inputs_copy = deepcopy(example_inputs)
compiled = compile_fx(
module_copy,
inputs_copy,
inner_compile=functools.partial(inner_compile, cpp_wrapper=False),
decompositions=decompositions,
)
compiled(*inputs_copy)
del module_copy, inputs_copy
# second pass
return compile_fx(
module,
example_inputs,
inner_compile=functools.partial(inner_compile, cpp_wrapper=True),
decompositions=decompositions,
)
def compile_fx_aot(
model_: torch.fx.GraphModule,
example_inputs_: List[torch.Tensor],
inner_compile=compile_fx_inner,
config_patches: Optional[Dict[str, Any]] = None,
decompositions: Optional[Dict[OpOverload, Callable]] = None,
):
return compile_fx(
model_,
example_inputs_,
inner_compile=functools.partial(inner_compile, aot_mode=True),
config_patches=config_patches,
decompositions=decompositions,
)
"""
JIT-compile the model and run it to generate kernel binaries in the first pass;
Generate cpp wrapper code and compile it to a dynamic library in the second pass
"""
from torch.ao.quantization.fx.utils import assert_and_get_unique_device
# Do we need to check inputs device as well?
device = assert_and_get_unique_device(module)
new_config_patches = config_patches.copy() if config_patches is not None else dict()
if device.type == "cuda":
# So far, we only need to do this for the Triton backend
with KernelParamCache():
new_config_patches.update(
{
"cpp_wrapper": False,
"triton.cudagraphs": False,
"triton.unique_kernel_names": True,
"triton.store_cubin": True,
}
)
with config.patch(new_config_patches):
module_copy = deepcopy(module)
inputs_copy = deepcopy(example_inputs)
compiled = compile_fx(
module_copy,
inputs_copy,
inner_compile,
new_config_patches,
decompositions,
)
compiled(inputs_copy)
del module_copy, inputs_copy
new_config_patches.update(
{
"cpp_wrapper": True,
"triton.store_cubin": False,
}
)
return compile_fx(
module,
example_inputs,
inner_compile=functools.partial(inner_compile, aot_mode=True),
config_patches=new_config_patches,
decompositions=decompositions,
)
else:
assert device.type == "cpu"
new_config_patches.update(
{
"cpp_wrapper": True,
}
)
return compile_fx(
module,
example_inputs,
inner_compile=functools.partial(inner_compile, aot_mode=True),
config_patches=new_config_patches,
decompositions=decompositions,
)
_graph_counter = itertools.count(0)
@@ -608,14 +583,6 @@ def compile_fx(
decompositions=decompositions,
)
if config.cpp_wrapper:
return compile_fx_with_cpp_wrapper(
model_,
example_inputs_,
inner_compile=inner_compile,
decompositions=decompositions,
)
recursive_compile_fx = functools.partial(
compile_fx,
inner_compile=inner_compile,
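Editor's note: the docstrings in the hunks above describe a two-pass flow for the GPU path: first JIT-compile with the Python wrapper (store_cubin on, cudagraphs off) and run the model once so kernel binaries land in the cache, then recompile with cpp_wrapper to produce the C++ wrapper. Below is a self-contained sketch of that control flow under stated assumptions: compile_fn and the plain config dict stand in for compile_fx and the real inductor config, and fake_compile is a toy placeholder.

```python
# Two-pass compile sketch: pass 1 caches kernels, pass 2 emits the C++ wrapper.
from copy import deepcopy


def two_pass_compile(compile_fn, module, example_inputs, base_config):
    # Pass 1: Python wrapper, cubin storing on, cudagraphs off; run once so the
    # autotuned kernel binaries end up in the kernel parameter cache.
    first = {**base_config, "cpp_wrapper": False,
             "triton.store_cubin": True, "triton.cudagraphs": False}
    module_copy, inputs_copy = deepcopy(module), deepcopy(example_inputs)
    compiled = compile_fn(module_copy, inputs_copy, first)
    compiled(*inputs_copy)
    del module_copy, inputs_copy
    # Pass 2: regenerate the wrapper as C++ against the cached kernels.
    second = {**base_config, "cpp_wrapper": True, "triton.store_cubin": False}
    return compile_fn(module, example_inputs, second)


def fake_compile(mod, inputs, cfg):
    # Toy stand-in only to show the call shape; compile_fx takes its place.
    return lambda *xs: [x * 2 for x in xs]


runner = two_pass_compile(fake_compile, {"graph": "toy"}, [1, 2, 3],
                          {"cpp_wrapper": False})
print(runner(1, 2, 3))  # [2, 4, 6]
```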

View File

@@ -55,7 +55,7 @@ log = logging.getLogger(__name__)
output_code_log = torch._logging.getArtifactLogger(__name__, "output_code")
def supported_dtype_of_cpp_wrapper(dtype, cuda):
def supported_dtype_of_cpp_wrapper(dtype):
supported_dtype = {
torch.float32,
torch.float64,
@@ -68,9 +68,6 @@ def supported_dtype_of_cpp_wrapper(dtype, cuda):
torch.bfloat16,
# torch.float16, # TODO: implement this
}
if cuda:
supported_dtype.add(torch.float16)
return dtype in supported_dtype
@@ -141,8 +138,8 @@ class GraphLowering(torch.fx.Interpreter):
shape_env=None,
num_static_inputs=None,
graph_id=None,
cpp_wrapper=False,
aot_mode=False,
cpp_wrapper=False,
):
super().__init__(gm)
self.extra_traceback = False # we do our own error wrapping
@@ -172,8 +169,9 @@ class GraphLowering(torch.fx.Interpreter):
self.name_to_buffer: Dict[str, ir.ComputedBuffer] = {}
self.creation_time = time.time()
self.name = "GraphLowering"
self.cpp_wrapper = cpp_wrapper
# TODO: aot_mode and cpp_wrapper are tangled now. Some refactoring is needed.
self.aot_mode = aot_mode
self.cpp_wrapper = cpp_wrapper
self.graph_id = graph_id
self.scheduler = None
self._warned_fallback = {"aten.convolution_backward"}
@@ -583,7 +581,12 @@ class GraphLowering(torch.fx.Interpreter):
def get_single_device(self):
return list(self.device_types)[0] if len(self.device_types) == 1 else None
def check_input_for_cpp_buffer(self, cuda):
def check_device_for_cpp_buffer(self):
device = self.get_single_device()
if self.get_single_device() is None:
self.disable_cpp_wrapper("device not CPU or CUDA")
def check_input_for_cpp_buffer(self):
for _, value in self.graph_inputs.items():
dtype = None
if isinstance(value, TensorBox):
@@ -591,33 +594,37 @@ class GraphLowering(torch.fx.Interpreter):
elif isinstance(value, sympy.Symbol):
dtype = may_get_constant_buffer_dtype(value)
if not supported_dtype_of_cpp_wrapper(dtype, cuda):
if not supported_dtype_of_cpp_wrapper(dtype):
self.disable_cpp_wrapper("unsupported inputs dtype")
def check_constant_for_cpp_buffer(self):
if self.constants:
self.disable_cpp_wrapper("Constants")
def check_cpp_wrapper(self, cuda):
def check_cpp_wrapper(self):
self.check_cpp_codegen_disabled()
self.check_platform()
self.check_input_for_cpp_buffer(cuda)
self.check_device_for_cpp_buffer()
self.check_input_for_cpp_buffer()
self.check_constant_for_cpp_buffer()
def init_wrapper_code(self):
if self.cpp_wrapper:
if self.aot_mode:
device = self.get_single_device()
assert device == "cpu" or device == "cuda"
cuda = device == "cuda"
self.check_cpp_wrapper(cuda)
# Re-check self.cpp_wrapper because it might be disabled due to failed checking
self.check_cpp_wrapper()
if device == "cpu":
self.wrapper_code = CppWrapperCodeGen()
else:
assert device == "cuda", "Non-supported device for AOT compilation"
self.wrapper_code = CudaWrapperCodeGen()
elif self.cpp_wrapper:
self.check_cpp_wrapper()
if self.cpp_wrapper:
self.wrapper_code = (
CudaWrapperCodeGen() if cuda else CppWrapperCodeGen()
)
return
self.wrapper_code = WrapperCodeGen()
self.wrapper_code = CppWrapperCodeGen()
else:
self.wrapper_code = WrapperCodeGen()
else:
self.wrapper_code = WrapperCodeGen()
def codegen(self):
from .scheduler import Scheduler
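Editor's note: init_wrapper_code above chooses a wrapper code generator only after a series of capability checks that may silently fall back by disabling cpp_wrapper (for example, an unsupported input dtype or a device that is neither CPU nor CUDA). A simplified composite sketch of the selection logic from both sides of this hunk, with hypothetical names and dtypes represented as strings rather than torch dtypes; this is not the real GraphLowering API.

```python
# Pick a wrapper codegen class name given device, input dtypes, and config.
SUPPORTED_CPP_DTYPES = {"float32", "float64", "int64", "int32", "bool", "bfloat16"}


def pick_wrapper(cpp_wrapper, device, input_dtypes, cuda=False):
    supported = set(SUPPORTED_CPP_DTYPES)
    if cuda:
        supported.add("float16")  # fp16 is only accepted on the CUDA path
    if cpp_wrapper and (device not in ("cpu", "cuda")
                        or any(d not in supported for d in input_dtypes)):
        cpp_wrapper = False  # fall back, mirroring disable_cpp_wrapper()
    if not cpp_wrapper:
        return "WrapperCodeGen"
    return "CudaWrapperCodeGen" if device == "cuda" else "CppWrapperCodeGen"


print(pick_wrapper(True, "cuda", ["float16"], cuda=True))  # CudaWrapperCodeGen
print(pick_wrapper(True, "cpu", ["float16"]))              # WrapperCodeGen (fallback)
```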

View File

@@ -16,7 +16,7 @@ import torch
from torch._dynamo.utils import dynamo_timed
from . import config
from .codecache import cache_dir, CudaKernelParamCache
from .codecache import cache_dir, cubin_cache_dir
from .ir import ReductionHint, TileHint
from .utils import (
@@ -141,7 +141,7 @@ class CachingAutotuner(KernelInterface):
launcher.store_cubin = config.triton.store_cubin
# store this global variable to avoid the high overhead of reading it when calling run
if launcher.store_cubin:
launcher.fn = self.fn
launcher.kernel_name = self.fn.__name__
launcher.bin = binary
return launcher
@@ -192,12 +192,21 @@ class CachingAutotuner(KernelInterface):
self.save_cache_hook(self.launchers[0].config)
def save_cuda_kernel(self, grid, stream, launcher):
from .codegen.wrapper import KernelParamCache
# Make sure kernel_name is enough for distinguishing kernels
assert config.triton.unique_kernel_names
if callable(grid):
grid_x, grid_y, grid_z = grid(launcher.config.kwargs)
else:
grid_x, grid_y, grid_z = grid
key = launcher.fn.module.split(".")[-1]
kernel_name = launcher.kernel_name
cubin_path = os.path.join(cubin_cache_dir(), f"{kernel_name}.cubin")
with open(cubin_path, "wb") as f:
f.write(launcher.bin.asm["cubin"])
params = {
"mangled_name": launcher.bin.metadata["name"],
"grid_x": grid_x,
@@ -207,7 +216,14 @@ class CachingAutotuner(KernelInterface):
"shared_mem": launcher.bin.shared,
"stream": stream,
}
CudaKernelParamCache.set(key, params, launcher.bin.asm["cubin"])
with self.lock:
if KernelParamCache.cache.get(kernel_name, None):
assert (
KernelParamCache.cache[kernel_name].get("mangled_name", None)
== launcher.bin.metadata["name"]
)
else:
KernelParamCache.cache[kernel_name] = params
def run(self, *args, grid, stream):
if len(self.launchers) != 1:
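Editor's note: save_cuda_kernel above resolves the launch grid (which may be a callable of the autotuned config kwargs) and records the launch parameters under the kernel name, with a consistency check on the mangled name when an entry already exists. A hedged, standalone sketch of that bookkeeping; the function names and signature are mine, not the CachingAutotuner API, and no cubin is actually written here.

```python
# Record launch parameters for a kernel, keyed by its (unique) kernel name.
def resolve_grid(grid, config_kwargs):
    if callable(grid):
        return grid(config_kwargs)
    return grid


def save_kernel_params(cache, kernel_name, grid, num_warps, shared_mem, stream,
                       mangled_name, config_kwargs=None):
    grid_x, grid_y, grid_z = resolve_grid(grid, config_kwargs or {})
    params = {
        "mangled_name": mangled_name,
        "grid_x": grid_x, "grid_y": grid_y, "grid_z": grid_z,
        "num_warps": num_warps, "shared_mem": shared_mem, "stream": stream,
    }
    existing = cache.get(kernel_name)
    if existing is not None:
        # The same kernel name must always map to the same compiled symbol.
        assert existing["mangled_name"] == mangled_name
    else:
        cache[kernel_name] = params
    return params


cache = {}
save_kernel_params(cache, "triton_poi_fused_relu_0", lambda kw: (64, 1, 1),
                   num_warps=4, shared_mem=0, stream=0,
                   mangled_name="triton_poi_fused_relu_0_kernel")
print(cache["triton_poi_fused_relu_0"]["grid_x"])  # 64
```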