[inductor][ck] add kBatch_sweep to config.rocm (#148223)

Summary:
# Why

Enable tests and users to specify a set of kBatch values to try, rather than relying on our hand-written heuristic.

# What

Add `rocm.kBatch_sweep`, a list of kBatch values to try. This generates the product of CK instances and kBatch values (one instance per kBatch value for each existing op), though instances that are likely to fail at runtime are often filtered out.

Test Plan: n/a

Reviewed By: chenyang78

Differential Revision: D70226055

Pull Request resolved: https://github.com/pytorch/pytorch/pull/148223
Approved by: https://github.com/ColinPeppler
This commit is contained in:
Ruben Rodriguez Buchillon 2025-03-06 01:14:31 +00:00 committed by PyTorch MergeBot
parent 63fbc738dc
commit 32715a2311
2 changed files with 11 additions and 2 deletions

View File

@ -887,8 +887,11 @@ class CKGemmTemplate(CKTemplate):
M = X_meta.size[-2]
K = X_meta.size[-1]
N = W_meta.size[-1]
if K < 16 * max(M, N):
if K // max(M, N) < config.rocm.split_k_threshold:
return [1]
# if the user is telling us which kBatches to sweep, just use those
if config.rocm.kBatch_sweep is not None:
return config.rocm.kBatch_sweep
# Calculate the number of blocks needed for each dimension
total_k_blocks = math.ceil(K / op.k_per_block)
# we want to calculate how many blocks we need to fit per CU
@ -927,7 +930,6 @@ class CKGemmTemplate(CKTemplate):
assert generator is not None
# TODO(coconutruben): allow users to provide a list of kBatches to sweep over
rops = generator()
ops = []
for o in rops:

View File

@ -1379,6 +1379,13 @@ class rocm:
# Currently RCR and F16 only
use_preselected_instances: bool = False
# List of kBatch values to sweep over in splitK scenarios. When None (the default), we calculate
# a single kBatch in splitK scenarios. Non-splitK scenarios always run with kBatch=1.
kBatch_sweep: Optional[list[int]] = None
# The threshold at which we trigger a splitK config - K // max(M,N) has to be at least this
split_k_threshold: int = 16
# Backend to use for CPU codegen either "cpp" or "triton" (experimental) or "halide" (experimental)
cpu_backend: Literal["cpp", "triton", "halide"] = "cpp"