A fix for reduction + pointwise + multi-level reduction optimization (#112935)

As titled: for cases like reduction + multiple pointwise ops + multi-level reduction, the num_splits of the multi-level reduction was previously decided by checking only whether the input of the multi-level reduction, or the input of that input, is a reduction node (i.e. the maximum search depth was 2). This PR changes the behavior to keep searching recursively for a reduction input node as long as the intervening input nodes are pointwise.
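
For concreteness, the pattern being optimized looks roughly like the sketch below (a simplified stand-in for the graph the unit test exercises; on a tensor of this size `torch.compile` typically lowers the final `amax` as a two-level split reduction):

```python
import torch

@torch.compile
def f(x, scale):
    # reduction: layer norm computes mean/variance (Welford) over the last dim
    x = torch.nn.functional.layer_norm(x, [x.shape[-1]])
    # pointwise chain feeding the multi-level reduction
    y = torch.sigmoid(x * scale)
    # multi-level reduction: a full amax over a tensor this large is typically
    # lowered as a first-level split reduction followed by a final reduction
    return y, torch.amax(torch.abs(x))

x = torch.randn(4, 2048, 4096, device="cuda")
scale = torch.tensor(2.0, device="cuda")
y, amax = f(x, scale)
```

With the deeper search, the split-reduction heuristic for `amax` can now see the upstream layer-norm reduction through the intervening pointwise ops when picking num_splits.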

Performance-wise, the change looks fine (screenshot below).
![Screenshot 2023-11-15 at 11 52 28 PM](https://github.com/pytorch/pytorch/assets/10527447/e726948c-0c00-4839-87a4-bcf9044c66d7)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/112935
Approved by: https://github.com/chenyang78
commit 3b108a150a (parent 2abfb8ec7d)
Author: Ying Zhang, 2023-11-21 20:34:02 +00:00 (committed by PyTorch MergeBot)
2 changed files with 44 additions and 29 deletions


@@ -372,15 +372,11 @@ class FusionTests(TestCase):
     def test_reduction_pointwise_multi_level_reduction(self):
         hidden_size = 4096
+        layer_norm = torch.nn.LayerNorm(hidden_size).cuda().float()

         @torch.inference_mode()
         def f(x, scale, amax_keep_dim):
-            x = torch.nn.functional.layer_norm(
-                x.to(dtype=torch.float),
-                [hidden_size],
-                weight=None,
-                bias=None,
-                eps=1e-05,
-            )
+            x = layer_norm(x.to(dtype=torch.float))
             amax = torch.amax(torch.abs(x), keepdim=amax_keep_dim)
             x_scaled = x * scale
             y = torch.nn.functional.sigmoid(x_scaled)
@ -389,22 +385,26 @@ class FusionTests(TestCase):
inp = (T(4, 2048, hidden_size, dtype=torch.float), T(1, dtype=torch.float))
# 3 kernels:
# kernel 1: (input = X, scale, output = LN_pointwise(X), welford_reduction(X) * 2)
# kernel 2: (input = X, welford_reduction(X) * 2, output = first-level amax (split-reduction))
# kernel 1: (input = X, scale, LN scale, LN bias, output = LN_pointwise(X), welford_reduction(X) * 2)
# kernel 2: (input = X, welford_reduction(X) * 2, LN scale, LN bias, output = first-level amax (split-reduction))
# kernel 3: (input = first-level amax, output = final amax)
# scale (1) + X (4*2048*hidden_size) * 3 + welford_reduction (4*2048) * 4 + amax (num_splits * 2 + 1)
# scale (1) + X (4*2048*hidden_size) * 3 + welford_reduction (4*2048) * 4 +
# LN scale (hidden_size) * 2 + LN bias (hidden_size) * 2 + amax (num_splits * 2 + 1)
# num_splits depends on SM architectures.
expected_amax_keep_dim_numel = 1 + 4 * 2048 * hidden_size * 3 + 4 * 2048 * 4 + 1
expected_amax_keep_dim_numel = (
1 + hidden_size * 4 + 4 * 2048 * hidden_size * 3 + 4 * 2048 * 4 + 1
)
self.assertGreaterAlmostEqual(
count_numel(f, *inp, True), str(expected_amax_keep_dim_numel)
int(count_numel(f, *inp, True)), expected_amax_keep_dim_numel
)
# 2 kernels:
# kernel 1: (input = X, scale, output = LN_pointwise(X), first-level amax (split-reduction))
# kernel 1: (input = X, scale, LN scale, LN bias, output = LN_pointwise(X), first-level amax (split-reduction))
# kernel 2: (input = first-level amax, output = final amax)
# scale (1) + X (4*2048*hidden_size) * 2 + amax (4 * 2048 * 2 + 1)
# scale (1) + X (4*2048*hidden_size) * 2 + LN scale (hidden_size) + LN bias (hidden_size) + amax (4 * 2048 * 2 + 1)
expected_amax_no_keep_dim_numel = (
1 + 4 * 2048 * hidden_size * 2 + 4 * 2048 * 2 + 1
1 + hidden_size * 2 + 4 * 2048 * hidden_size * 2 + 4 * 2048 * 2 + 1
)
self.assertExpectedInline(
count_numel(f, *inp, False), str(expected_amax_no_keep_dim_numel)
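
As a quick sanity check on the comment arithmetic in the hunk above, the two expected totals evaluate as follows (a standalone sketch, not part of the test):

```python
hidden_size = 4096
# keep_dim (3 kernels): scale + LN scale/bias (x2 each) + X traffic x3
# + welford outputs + final amax; the num_splits-dependent first-level amax
# terms are covered by the ">= expected" assertion in the test.
amax_keep_dim = 1 + hidden_size * 4 + 4 * 2048 * hidden_size * 3 + 4 * 2048 * 4 + 1
# no keep_dim (2 kernels): scale + LN scale/bias + X traffic x2
# + first-level amax (4*2048*2) + final amax
amax_no_keep_dim = 1 + hidden_size * 2 + 4 * 2048 * hidden_size * 2 + 4 * 2048 * 2 + 1
print(amax_keep_dim)     # 100712450
print(amax_no_keep_dim)  # 67133442
```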


@@ -407,21 +407,36 @@ def extract_input_node_reduction_ranges(
     reads = input_node.get_reads()
     reduction_size = None
     size = None
-    for read in reads:
-        if not isinstance(read, MemoryDep):
-            continue
-        buffer = V.graph.get_buffer(read.name)
-        if buffer is None:
-            continue
-        if isinstance(buffer, ComputedBuffer) and len(buffer.get_reduction_size()) > 0:
-            if reduction_size is None:
-                reduction_size = buffer.get_reduction_size()
-                size = buffer.get_size()
-            elif (
-                reduction_size != buffer.get_reduction_size()
-                or size != buffer.get_size()
-            ):
-                return (None, None)
+    while reduction_size is None and len(reads) > 0:
+        seen = set()
+        new_reads = []
+        for read in reads:
+            if not isinstance(read, MemoryDep):
+                continue
+            if read.name in seen:
+                continue
+            seen.add(read.name)
+            buffer = V.graph.get_buffer(read.name)
+            if buffer is None:
+                continue
+            if (
+                isinstance(buffer, ComputedBuffer)
+                and len(buffer.get_reduction_size()) > 0
+            ):
+                if reduction_size is None:
+                    reduction_size = buffer.get_reduction_size()
+                    size = buffer.get_size()
+                elif (
+                    reduction_size != buffer.get_reduction_size()
+                    or size != buffer.get_size()
+                ):
+                    return (None, None)
+            else:
+                new_reads.extend(buffer.get_reads())
+        if reads == new_reads:
+            return (size, reduction_size)
+        else:
+            reads = new_reads
     return (size, reduction_size)
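
To make the control flow of the new loop concrete, here is a self-contained simulation on a chain reduction -> pointwise -> pointwise; `StubRead`, `StubBuffer`, and the `graph` dict are made-up stand-ins for inductor's MemoryDep, ComputedBuffer, and V.graph.get_buffer, used only to exercise the same fixed-point search:

```python
class StubRead:
    def __init__(self, name):
        self.name = name

class StubBuffer:
    def __init__(self, size, reduction_size, reads):
        self._size = size
        self._reduction_size = reduction_size
        self._reads = [StubRead(r) for r in reads]

    def get_size(self):
        return self._size

    def get_reduction_size(self):
        return self._reduction_size

    def get_reads(self):
        return self._reads

# buf0 is a reduction over the last dim; buf1 and buf2 are pointwise consumers.
graph = {
    "buf0": StubBuffer(size=[4, 2048], reduction_size=[4096], reads=[]),
    "buf1": StubBuffer(size=[4, 2048, 4096], reduction_size=[], reads=["buf0"]),
    "buf2": StubBuffer(size=[4, 2048, 4096], reduction_size=[], reads=["buf1"]),
}

def extract_ranges(reads):
    # Same search as the new extract_input_node_reduction_ranges: follow reads
    # through pointwise buffers until a reduction buffer (or a fixed point) is hit.
    reduction_size = None
    size = None
    while reduction_size is None and len(reads) > 0:
        seen, new_reads = set(), []
        for read in reads:
            if read.name in seen:
                continue
            seen.add(read.name)
            buffer = graph.get(read.name)
            if buffer is None:
                continue
            if len(buffer.get_reduction_size()) > 0:
                size = buffer.get_size()
                reduction_size = buffer.get_reduction_size()
            else:
                new_reads.extend(buffer.get_reads())
        if reads == new_reads:
            return (size, reduction_size)
        reads = new_reads
    return (size, reduction_size)

# With the old depth-2 check, buf2 and buf1 (both pointwise) would have ended the
# search; the new loop keeps going and recovers buf0's ranges.
print(extract_ranges([StubRead("buf2")]))  # ([4, 2048], [4096])
```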