[inductor] consider pointwise nodes when deciding reduction hint (#124131)

In certain **rare** scenarios, inductor can generate a reduction kernel with really bad perf. E.g., if
- the reduction kernel contains a reduction node followed by a pointwise node,
- the pointwise node uses a transposed layout,
- the reduction node is an inner reduction,
- and rnumel <= 1024,

then inductor will generate a persistent reduction kernel, and the `tl.store` for the pointwise node gets really bad perf since we use a very skinny tile `(XBLOCK=1, RBLOCK=next_power_of_2(rnumel))` (sketched below).
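
To illustrate the access pattern, here is a hand-written sketch (not the kernel inductor actually emits; all names below are made up for illustration):

```python
import triton
import triton.language as tl


@triton.jit
def skinny_tile_store(out_ptr, in_ptr, xnumel, rnumel,
                      XBLOCK: tl.constexpr, RBLOCK: tl.constexpr):
    # With XBLOCK=1, each program owns a single row x and a whole row of
    # RBLOCK = next_power_of_2(rnumel) lanes along r.
    xoffset = tl.program_id(0) * XBLOCK
    rindex = tl.arange(0, RBLOCK)
    rmask = rindex < rnumel
    # Contiguous (row-major) load: lanes read adjacent addresses.
    val = tl.load(in_ptr + (xoffset * rnumel + rindex), rmask)
    # Transposed (column-major) store: the stride along r is xnumel, so the
    # lanes of this single store hit addresses xnumel elements apart instead
    # of being contiguous, i.e. poorly coalesced writes.
    tl.store(out_ptr + (xoffset + xnumel * rindex), val, rmask)
```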

I've tried a few versions of the fix:
- The first version: if any pointwise node in a reduction kernel uses a non-contiguous dependency, fall back to ReductionHint.DEFAULT. This caused an 8s compilation-time increase for huggingface with no perf wins, because ReductionHint.DEFAULT does more autotuning.
- Then I made the check more specific: we only change the hint from INNER to DEFAULT if we are sure that a pointwise node uses a stride > 1 for the lowest dimension (see the sketch after this list). Kernels meeting this condition should mostly have really bad perf anyway.
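
A minimal, self-contained sketch of that heuristic (the real implementation is `MemoryDep.stride1_for_last_dim` in the diff below; the concrete indices used here are made up):

```python
import sympy


def last_dim_has_stride_1(index, last_var):
    # Simplified sketch: the last iteration variable should show up in the
    # symbolic index with coefficient 1; a constant coefficient > 1 means a
    # strided (e.g. transposed) access along that dimension.
    terms = index.args if isinstance(index, sympy.Add) else [index]
    for term in terms:
        if term is last_var:
            return True
        if (
            isinstance(term, sympy.Mul)
            and len(term.args) == 2
            and term.args[1] is last_var
            and isinstance(term.args[0], sympy.Integer)
            and term.args[0] > 1
        ):
            return False
    return True


x0, r0 = sympy.symbols("x0 r0", integer=True)
print(last_dim_has_stride_1(1024 * x0 + r0, r0))   # True: contiguous along r0
print(last_dim_has_stride_1(x0 + 26624 * r0, r0))  # False: stride 26624 along r0
```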

The situation mentioned above is rare, but it has been reported by internal users. I'll also run one more perf test.

Testing script: https://gist.github.com/shunting314/9d3389891fa43633b49b8b7564ad6d8b. Something equivalent is also added as a unit test.

For this specific test from user reports, we improve the mentioned reduction kernel's perf by **4.14x** (451us -> 109us).

Pull Request resolved: https://github.com/pytorch/pytorch/pull/124131
Approved by: https://github.com/jansel
Shunting Zhang 2024-04-18 11:21:50 -07:00 committed by PyTorch MergeBot
parent 57f64197f3
commit c5a4ba2257
4 changed files with 160 additions and 3 deletions

@@ -82,6 +82,8 @@ from torch.utils._pytree import tree_flatten, tree_unflatten
from torch.utils._triton import has_triton
from torch.utils.weak import WeakTensorKeyDictionary
DO_PERF_TEST = os.environ.get("DO_PERF_TEST") == "1"
if IS_WINDOWS and IS_CI:
sys.stderr.write(
"Windows CI does not have necessary dependencies for test_torchinductor yet\n"
@@ -10332,6 +10334,69 @@ if HAS_GPU and RUN_GPU and not TEST_WITH_ASAN:
seq_nr_set.add(int(res.group(1)))
self.assertTrue(bwd_seq_nr_set.issubset(fwd_seq_nr_set))
@config.patch(
{
"coordinate_descent_tuning": True,
"triton.unique_kernel_names": True,
"benchmark_kernel": True,
}
)
@skipIfRocm
@unittest.skipIf(
torch.cuda.get_device_capability() < (9, 0),
"Triton does not support fp8 on A100",
)
def test_red_followed_by_transposed_pointwise(self):
bs = 26624
dim = 1024
@torch.compile(dynamic=False)
def f(in1, in2, a, b):
out = torch.nn.functional.silu(in1) * in2
out_row = (out / out.amax(dim=1, keepdim=True)).to(torch.float8_e4m3fn)
out_col = (out / out.amax(dim=0, keepdim=True)).to(torch.float8_e4m3fn)
# setup strides for _scaled_mm
out_row = out_row.contiguous()
out_col = out_col.t().contiguous().t()
return (
torch._scaled_mm(out_row, a, out_dtype=torch.bfloat16)[0],
torch._scaled_mm(b, out_col, out_dtype=torch.bfloat16)[0],
)
in1 = torch.randn((bs, dim), dtype=torch.bfloat16, device=GPU_TYPE)
in2 = torch.randn((bs, dim), dtype=torch.bfloat16, device=GPU_TYPE)
a = (
torch.randn((dim, dim), dtype=torch.bfloat16, device=GPU_TYPE)
.t()
.to(torch.float8_e4m3fn)
)
b = torch.randn((dim, bs), dtype=torch.bfloat16, device=GPU_TYPE).to(
torch.float8_e4m3fn
)
# warmup
_, (wrapper,) = run_and_get_code(f, in1, in2, a, b)
# Previously inductor decided the reduction hint for a reduction kernel without considering
# the pointwise nodes. That would make the third reduction kernel in this wrapper a
# persistent inner reduction and cause bad perf.
#
# We fix that by making the third reduction a non-persistent reduction
# and improve the perf by 4.14x (451us -> 109us)
self.assertEqual(3, wrapper.count("def triton_red_"))
self.assertEqual(0, wrapper.count("def triton_per_"))
if DO_PERF_TEST:
with torch.profiler.profile(
activities=[torch.profiler.ProfilerActivity.CUDA]
) as p:
for _ in range(1000):
f(in1, in2, a, b)
print(p.key_averages().table(max_name_column_width=200))
class RNNTest(TestCase):
class Model(torch.nn.Module):
def __init__(self):

@@ -3398,6 +3398,29 @@ class TritonScheduling(BaseScheduling):
return "tl.int32"
return "tl.int64"
def has_non_contiguous_pw_in_reduction_kernel(self, node_schedule, numel, rnumel):
pointwise_nodes = list(
filter(
lambda n: n not in (EnableReduction, DisableReduction)
and not n.is_reduction()
and n.group[1][0] == numel * rnumel,
node_schedule,
)
)
for node in pointwise_nodes:
# An index can be an integer when loading a random seed.
if not all(
not isinstance(dep, MemoryDep)
or dep.is_contiguous()
or isinstance(dep.index, (sympy.Integer, int))
or dep.stride1_for_last_dim()
for dep in itertools.chain(
node.read_writes.reads, node.read_writes.writes
)
):
return True
return False
def get_kernel_args(self, node_schedule, numel, reduction_numel):
reductions = list(
filter(
@@ -3412,6 +3435,14 @@ class TritonScheduling(BaseScheduling):
reduction_hint_val = hints[0]
else:
reduction_hint_val = ReductionHint.DEFAULT
if (
reduction_hint_val == ReductionHint.INNER
and self.has_non_contiguous_pw_in_reduction_kernel(
node_schedule, numel, reduction_numel
)
):
reduction_hint_val = ReductionHint.DEFAULT
else:
reduction_hint_val = ReductionHint.DEFAULT
@@ -3456,9 +3487,11 @@
from torch._inductor.codegen.triton_split_scan import TritonSplitScanKernel
tiled_groups = self.select_tiling(node_schedule, numel, reduction_numel)
reduction_hint_val, mutations, index_dtype = self.get_kernel_args(
node_schedule, numel, reduction_numel
)
(
reduction_hint_val,
mutations,
index_dtype,
) = self.get_kernel_args(node_schedule, numel, reduction_numel)
is_split_scan = any(
isinstance(node, BaseSchedulerNode) and node.is_split_scan()

@@ -71,6 +71,35 @@ class MemoryDep(typing.NamedTuple):
def is_contiguous(self) -> bool:
return isinstance(self.index, sympy.Symbol) and self.index in self.var_names
def stride1_for_last_dim(self, result_for_complex_expression=True) -> bool:
"""
Whether the stride for the last dimension is 1.
"""
# python test/inductor/test_torchinductor_opinfo.py -k test_comprehensive_masked_scatter_cuda_float16
# will exercise this corner case.
if len(self.var_names) == 0:
return True
terms = self.index.args if isinstance(self.index, sympy.Add) else [self.index]
last_sym = self.var_names[-1]
for term in terms:
if term is last_sym:
return True
# Having a stride > 1 for the last dimension is bad for perf,
# so return False.
if (
isinstance(term, sympy.Mul)
and len(term.args) == 2
and term.args[1] is last_sym
and isinstance(term.args[0], (int, sympy.Integer))
and term.args[0] > 1
):
return False
return result_for_complex_expression
def is_scalar(self) -> bool:
if isinstance(self.index, sympy.Symbol):
return self.index not in self.var_names and not self.is_indirect()

@@ -1558,3 +1558,33 @@ def use_scatter_fallback(
or (reduction_type == reduce_ty and self_dtype in {torch.bool, torch.int64})
or torch.are_deterministic_algorithms_enabled()
)
def dump_node_schedule(node_schedule):
"""
An API that can be used in pdb to dump a node_schedule.
Right now it mainly dumps the read/write dependencies, but more can be added as needed.
"""
from torch._inductor.codegen.triton import DisableReduction, EnableReduction
from torch._inductor.scheduler import SchedulerNode
print(f"Node schedule with {len(node_schedule)} nodes")
for idx, node in enumerate(node_schedule):
print(f" {idx:3}:")
if node is EnableReduction:
print("enable reduction")
elif node is DisableReduction:
print("disable reduction")
elif isinstance(node, SchedulerNode):
is_red = node.is_reduction()
print(f"{'red' if is_red else 'pw'} scheduler node")
if is_red:
print(f"original reduction hint {node.node.data.reduction_hint}") # type: ignore[attr-defined]
print("ReadDep:")
for dep in node.read_writes.reads:
print(dep)
print("WriteDep:")
for dep in node.read_writes.writes:
print(dep)
else:
raise RuntimeError(f"Unrecognized node type: {type(node)}")