[Inductor] Make combo kernel MAX_NUM_ARGS configurable (#166274)

The MAX_NUM_ARGS of ComboKernel is currently a fixed number. We need to tune this number to avoid large fusion for MTIA, thus making it configurable.

Differential Revision: [D85509352](https://our.internmc.facebook.com/intern/diff/D85509352/)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166274
Approved by: https://github.com/eellison
This commit is contained in:
anwang 2025-10-26 15:08:19 -07:00 committed by PyTorch MergeBot
parent a076b4d7ac
commit eb2bad5bb5
2 changed files with 3 additions and 3 deletions

View File

@ -172,15 +172,13 @@ class PartitionState:
class ComboKernel(Kernel): class ComboKernel(Kernel):
MAX_NUM_ARGS = 250 # number where I would no longer get triton errors
@staticmethod @staticmethod
def _update_partition( def _update_partition(
partition_state: PartitionState, partition_state: PartitionState,
node_rw_count: int, node_rw_count: int,
node_info: BaseSchedulerNode, node_info: BaseSchedulerNode,
) -> None: ) -> None:
if partition_state.cur_count + node_rw_count > ComboKernel.MAX_NUM_ARGS: if partition_state.cur_count + node_rw_count > config.combo_kernel_max_num_args:
partition_state.partitions.append(partition_state.cur_partition) partition_state.partitions.append(partition_state.cur_partition)
partition_state.cur_partition = [node_info] partition_state.cur_partition = [node_info]
partition_state.cur_count = node_rw_count partition_state.cur_count = node_rw_count

View File

@ -745,6 +745,8 @@ combo_kernels_autotune = 1
combo_kernel_allow_mixed_sizes = 1 combo_kernel_allow_mixed_sizes = 1
# Enable dynamic shapes for foreach kernels # Enable dynamic shapes for foreach kernels
combo_kernel_foreach_dynamic_shapes = True combo_kernel_foreach_dynamic_shapes = True
# Maximum number of arguments (read/write buffers) allowed in a combo kernel
combo_kernel_max_num_args = 250
# constant folding on the joint graph # constant folding on the joint graph
joint_graph_constant_folding = True joint_graph_constant_folding = True