mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
A fix for reduction + pointwise + multi-level reduction optimization (#112935)
As titled (ATT): for cases like reduction + multiple pointwises + multi-level reduction, previously, to decide num_splits of the multi-level reduction, we only checked whether the input of the multi-level reduction — or the input of that input — was a reduction node (i.e. the maximum search depth was 2). This PR changes the behavior to search recursively for a reduction input node as long as the intervening input nodes are pointwise nodes. Performance-wise it looks fine.  Pull Request resolved: https://github.com/pytorch/pytorch/pull/112935 Approved by: https://github.com/chenyang78
This commit is contained in:
parent
2abfb8ec7d
commit
3b108a150a
|
|
@ -372,15 +372,11 @@ class FusionTests(TestCase):
|
|||
|
||||
def test_reduction_pointwise_multi_level_reduction(self):
|
||||
hidden_size = 4096
|
||||
layer_norm = torch.nn.LayerNorm(hidden_size).cuda().float()
|
||||
|
||||
@torch.inference_mode()
|
||||
def f(x, scale, amax_keep_dim):
|
||||
x = torch.nn.functional.layer_norm(
|
||||
x.to(dtype=torch.float),
|
||||
[hidden_size],
|
||||
weight=None,
|
||||
bias=None,
|
||||
eps=1e-05,
|
||||
)
|
||||
x = layer_norm(x.to(dtype=torch.float))
|
||||
amax = torch.amax(torch.abs(x), keepdim=amax_keep_dim)
|
||||
x_scaled = x * scale
|
||||
y = torch.nn.functional.sigmoid(x_scaled)
|
||||
|
|
@ -389,22 +385,26 @@ class FusionTests(TestCase):
|
|||
inp = (T(4, 2048, hidden_size, dtype=torch.float), T(1, dtype=torch.float))
|
||||
|
||||
# 3 kernels:
|
||||
# kernel 1: (input = X, scale, output = LN_pointwise(X), welford_reduction(X) * 2)
|
||||
# kernel 2: (input = X, welford_reduction(X) * 2, output = first-level amax (split-reduction))
|
||||
# kernel 1: (input = X, scale, LN scale, LN bias, output = LN_pointwise(X), welford_reduction(X) * 2)
|
||||
# kernel 2: (input = X, welford_reduction(X) * 2, LN scale, LN bias, output = first-level amax (split-reduction))
|
||||
# kernel 3: (input = first-level amax, output = final amax)
|
||||
# scale (1) + X (4*2048*hidden_size) * 3 + welford_reduction (4*2048) * 4 + amax (num_splits * 2 + 1)
|
||||
# scale (1) + X (4*2048*hidden_size) * 3 + welford_reduction (4*2048) * 4 +
|
||||
# LN scale (hidden_size) * 2 + LN bias (hidden_size) * 2 + amax (num_splits * 2 + 1)
|
||||
# num_splits depends on SM architectures.
|
||||
expected_amax_keep_dim_numel = 1 + 4 * 2048 * hidden_size * 3 + 4 * 2048 * 4 + 1
|
||||
expected_amax_keep_dim_numel = (
|
||||
1 + hidden_size * 4 + 4 * 2048 * hidden_size * 3 + 4 * 2048 * 4 + 1
|
||||
)
|
||||
self.assertGreaterAlmostEqual(
|
||||
count_numel(f, *inp, True), str(expected_amax_keep_dim_numel)
|
||||
int(count_numel(f, *inp, True)), expected_amax_keep_dim_numel
|
||||
)
|
||||
|
||||
# 2 kernels:
|
||||
# kernel 1: (input = X, scale, output = LN_pointwise(X), first-level amax (split-reduction))
|
||||
# kernel 1: (input = X, scale, LN scale, LN bias, output = LN_pointwise(X), first-level amax (split-reduction))
|
||||
# kernel 2: (input = first-level amax, output = final amax)
|
||||
# scale (1) + X (4*2048*hidden_size) * 2 + amax (4 * 2048 * 2 + 1)
|
||||
# scale (1) + X (4*2048*hidden_size) * 2 + LN scale (hidden_size) + LN bias (hidden_size) + amax (4 * 2048 * 2 + 1)
|
||||
|
||||
expected_amax_no_keep_dim_numel = (
|
||||
1 + 4 * 2048 * hidden_size * 2 + 4 * 2048 * 2 + 1
|
||||
1 + hidden_size * 2 + 4 * 2048 * hidden_size * 2 + 4 * 2048 * 2 + 1
|
||||
)
|
||||
self.assertExpectedInline(
|
||||
count_numel(f, *inp, False), str(expected_amax_no_keep_dim_numel)
|
||||
|
|
|
|||
|
|
@ -407,21 +407,36 @@ def extract_input_node_reduction_ranges(
|
|||
reads = input_node.get_reads()
|
||||
reduction_size = None
|
||||
size = None
|
||||
for read in reads:
|
||||
if not isinstance(read, MemoryDep):
|
||||
continue
|
||||
buffer = V.graph.get_buffer(read.name)
|
||||
if buffer is None:
|
||||
continue
|
||||
if isinstance(buffer, ComputedBuffer) and len(buffer.get_reduction_size()) > 0:
|
||||
if reduction_size is None:
|
||||
reduction_size = buffer.get_reduction_size()
|
||||
size = buffer.get_size()
|
||||
elif (
|
||||
reduction_size != buffer.get_reduction_size()
|
||||
or size != buffer.get_size()
|
||||
while reduction_size is None and len(reads) > 0:
|
||||
seen = set()
|
||||
new_reads = []
|
||||
for read in reads:
|
||||
if not isinstance(read, MemoryDep):
|
||||
continue
|
||||
if read.name in seen:
|
||||
continue
|
||||
seen.add(read.name)
|
||||
buffer = V.graph.get_buffer(read.name)
|
||||
if buffer is None:
|
||||
continue
|
||||
if (
|
||||
isinstance(buffer, ComputedBuffer)
|
||||
and len(buffer.get_reduction_size()) > 0
|
||||
):
|
||||
return (None, None)
|
||||
if reduction_size is None:
|
||||
reduction_size = buffer.get_reduction_size()
|
||||
size = buffer.get_size()
|
||||
elif (
|
||||
reduction_size != buffer.get_reduction_size()
|
||||
or size != buffer.get_size()
|
||||
):
|
||||
return (None, None)
|
||||
else:
|
||||
new_reads.extend(buffer.get_reads())
|
||||
if reads == new_reads:
|
||||
return (size, reduction_size)
|
||||
else:
|
||||
reads = new_reads
|
||||
return (size, reduction_size)
|
||||
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user