mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[Inductor] Make combo kernel MAX_NUM_ARGS configurable (#166274)
The MAX_NUM_ARGS of ComboKernel is currently a fixed number. We need to tune this number to avoid overly large fusions for MTIA, so this change makes it configurable via the new Inductor config option `combo_kernel_max_num_args` (default 250). Differential Revision: [D85509352](https://our.internmc.facebook.com/intern/diff/D85509352/) Pull Request resolved: https://github.com/pytorch/pytorch/pull/166274 Approved by: https://github.com/eellison
This commit is contained in:
parent
a076b4d7ac
commit
eb2bad5bb5
|
|
@ -172,15 +172,13 @@ class PartitionState:
|
||||||
|
|
||||||
|
|
||||||
class ComboKernel(Kernel):
|
class ComboKernel(Kernel):
|
||||||
MAX_NUM_ARGS = 250 # number where I would no longer get triton errors
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _update_partition(
|
def _update_partition(
|
||||||
partition_state: PartitionState,
|
partition_state: PartitionState,
|
||||||
node_rw_count: int,
|
node_rw_count: int,
|
||||||
node_info: BaseSchedulerNode,
|
node_info: BaseSchedulerNode,
|
||||||
) -> None:
|
) -> None:
|
||||||
if partition_state.cur_count + node_rw_count > ComboKernel.MAX_NUM_ARGS:
|
if partition_state.cur_count + node_rw_count > config.combo_kernel_max_num_args:
|
||||||
partition_state.partitions.append(partition_state.cur_partition)
|
partition_state.partitions.append(partition_state.cur_partition)
|
||||||
partition_state.cur_partition = [node_info]
|
partition_state.cur_partition = [node_info]
|
||||||
partition_state.cur_count = node_rw_count
|
partition_state.cur_count = node_rw_count
|
||||||
|
|
|
||||||
|
|
@ -745,6 +745,8 @@ combo_kernels_autotune = 1
|
||||||
combo_kernel_allow_mixed_sizes = 1
|
combo_kernel_allow_mixed_sizes = 1
|
||||||
# Enable dynamic shapes for foreach kernels
|
# Enable dynamic shapes for foreach kernels
|
||||||
combo_kernel_foreach_dynamic_shapes = True
|
combo_kernel_foreach_dynamic_shapes = True
|
||||||
|
# Maximum number of arguments (read/write buffers) allowed in a combo kernel
|
||||||
|
combo_kernel_max_num_args = 250
|
||||||
|
|
||||||
# constant folding on the joint graph
|
# constant folding on the joint graph
|
||||||
joint_graph_constant_folding = True
|
joint_graph_constant_folding = True
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user