pytorch/test/distributed/test_compute_comm_reordering.py

396 lines
17 KiB
Python

# Owner(s): ["module: inductor"]
import unittest
from unittest.mock import patch
import torch
import torch._dynamo
import torch._dynamo.logging
import torch._dynamo.test_case
# for some reason importing functional collectives after dynamo breaks collectives handling!
import torch.distributed._functional_collectives as _functional_collectives
from torch._C import FileCheck
from torch._dynamo.utils import same
from torch._inductor import ir, scheduler
from torch._inductor.comm_analysis import (
baseLat,
hwLat,
llMaxBws,
NCCL_ALGO,
NCCL_HW,
NCCL_PROTO,
NVIDIA_GPU_TYPE,
)
from torch._inductor.utils import run_and_get_triton_code
from torch.testing._internal.common_distributed import (
_dynamo_dist_per_rank_init,
at_least_x_gpu,
DynamoDistributedMultiProcTestCase,
requires_nccl,
)
from torch.utils._triton import has_triton
def get_snode_runtime_for_reorder_compute_test(snode):
# NOTE: custom cost model to show that the compute reordering algorithm is working
# Collective kernels
if isinstance(snode.node, ir._CollectiveKernel):
return 100
elif isinstance(snode.node, ir._WaitKernel):
return 0
# High-arithmetic-intensity compute kernels
elif isinstance(snode.node, ir.ExternKernel):
return 5
# All other kernels
return 1
def create_grouped_node_for_allreduce_and_its_deps(snodes):
name_to_snode = {snode.node.name: snode for snode in snodes}
all_reduce_snodes = [
snode
for snode in snodes
if isinstance(snode.node, ir._CollectiveKernel)
and snode.node.op_overload == torch.ops._c10d_functional.all_reduce_.default
]
assert len(all_reduce_snodes) == 1
all_reduce_snode = all_reduce_snodes[0]
all_reduce_dep_snodes = [
name_to_snode[node.name] for node in all_reduce_snode.node.inputs
]
assert len(all_reduce_dep_snodes) == 1
all_reduce_dep_snode = all_reduce_dep_snodes[0]
grouped_snode = scheduler.GroupedSchedulerNode.create(
[all_reduce_dep_snode, all_reduce_snode]
)
new_snode_order = []
new_snode_order.append(grouped_snode)
for snode in snodes:
if snode in grouped_snode.snodes:
continue
new_snode_order.append(snode)
return new_snode_order
@requires_nccl()
class TestComputeCommReorderingMultiProc(DynamoDistributedMultiProcTestCase):
"""
Run correctness checks in multi-proc runner, mark with minimum # GPUs to run under
"""
def get_world_trs(self):
return {
"tag": "",
"ranks": list(range(self.world_size)),
"group_size": self.world_size,
}
@property
def world_size(self) -> int:
# hack: no matter whether we have 2 or 3 or 4 gpus, just run on 2
# works around issue with skipif<2 and workers with unpredictable #s gpu
return 2
@unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
@patch.object(torch._inductor.config, "allow_buffer_reuse", True)
# TODO: somehow inductor bg compile threads are causing hangs at exit with distributed work dtor
@patch.object(torch._inductor.config, "compile_threads", 1)
@patch.object(torch._inductor.config, "reorder_for_locality", False)
@patch.object(torch._inductor.config, "reorder_for_compute_comm_overlap", True)
@patch.object(
torch._inductor.config,
"reorder_for_compute_comm_overlap_passes",
[
"sink_waits",
],
)
def test_sink_waits(self):
def func(a):
ar = _functional_collectives.all_reduce(a, "sum", "0")
b = torch.matmul(a, a)
return torch.matmul(ar, b)
with _dynamo_dist_per_rank_init(
self.rank, self.world_size, fake_pg=not at_least_x_gpu(2)
):
inputs = torch.ones(4, 4, dtype=torch.float, device="cuda") + self.rank
compiled = torch.compile(func)
code = run_and_get_triton_code(compiled, inputs)
# Verify that the wait_tensor is sinked below the 1st matmul but
# above the 2nd matmul.
(
FileCheck()
.check("torch.ops._c10d_functional.all_reduce_.default")
.check("extern_kernels.mm")
.check("torch.ops._c10d_functional.wait_tensor.default")
.check("extern_kernels.mm")
.run(code)
)
out = compiled(inputs)
correct = func(inputs)
self.assertTrue(same(out, correct))
@unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
@patch.object(torch._inductor.config, "allow_buffer_reuse", True)
# TODO: somehow inductor bg compile threads are causing hangs at exit with distributed work dtor
@patch.object(torch._inductor.config, "compile_threads", 1)
@patch.object(torch._inductor.config, "reorder_for_locality", False)
@patch.object(torch._inductor.config, "reorder_for_compute_comm_overlap", True)
@patch.object(
torch._inductor.config,
"reorder_for_compute_comm_overlap_passes",
[
"raise_comms",
],
)
def test_raise_comms(self):
def func(a):
b = torch.matmul(a, a)
c = torch.relu(b)
d = torch.matmul(c, c)
e = _functional_collectives.all_reduce(b, "sum", "0")
return torch.matmul(d, e)
with _dynamo_dist_per_rank_init(
self.rank, self.world_size, fake_pg=not at_least_x_gpu(2)
):
inputs = torch.ones(4, 4, dtype=torch.float, device="cuda") + self.rank
compiled = torch.compile(func)
code = run_and_get_triton_code(compiled, inputs)
print(code)
# Verify that the all_reduce_ has been raised above the 2nd matmul
# but below the 1st matmul. Note that the all_reduce_ directly
# writes to the output buffer of the 1st matmul, which is an input
# to the first relu. Therefore, the all_reduce_ should be scheduled
# after the first relu.
(
FileCheck()
.check("extern_kernels.mm")
.check("triton_poi_fused_relu")
.check("torch.ops._c10d_functional.all_reduce_.default")
.check("extern_kernels.mm")
.check("torch.ops._c10d_functional.wait_tensor.default")
.check("extern_kernels.mm")
.run(code)
)
out = compiled(inputs)
correct = func(inputs)
self.assertTrue(same(out, correct))
@unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
@patch.object(torch._inductor.config, "allow_buffer_reuse", True)
# TODO: somehow inductor bg compile threads are causing hangs at exit with distributed work dtor
@patch.object(torch._inductor.config, "compile_threads", 1)
@patch.object(torch._inductor.config, "reorder_for_compute_comm_overlap", True)
@patch.object(
torch._inductor.config,
"reorder_for_compute_comm_overlap_passes",
[
"sink_waits",
"raise_comms",
],
)
def test_sink_waits_raise_comms(self):
def func(a, *, tag, ranks, group_size):
b = torch.matmul(a, a)
c = torch.relu(b)
d = torch.matmul(c, c)
e = _functional_collectives.all_reduce(b, "sum", "0")
f = torch.relu(d)
g = torch.matmul(f, f)
return torch.mm(e, g)
with _dynamo_dist_per_rank_init(
self.rank, self.world_size, fake_pg=not at_least_x_gpu(2)
):
inputs = torch.ones(4, 4, dtype=torch.float, device="cuda") + self.rank
compiled = torch.compile(func)
code = run_and_get_triton_code(compiled, inputs, **self.get_world_trs())
# Things to verify:
# - The clone prologue of the all_reduce_ should not be fused with
# any relus.
# - The all_reduce_ and its prologue should be raised above the 2nd
# matmul but below the 1st matmul.
# - The wait_tensor should be sinked below the 3rd matmul but above
# the 4th matmul.
(
FileCheck()
.check("extern_kernels.mm")
.check("triton_poi_fused_all_reduce_0")
.check("torch.ops._c10d_functional.all_reduce_.default")
.check("triton_poi_fused_relu")
.check("extern_kernels.mm")
.check("triton_poi_fused_relu")
.check("extern_kernels.mm")
.check("torch.ops._c10d_functional.wait_tensor.default")
.check("extern_kernels.mm")
.run(code)
)
out = compiled(inputs, **self.get_world_trs())
correct = func(inputs, **self.get_world_trs())
self.assertTrue(same(out, correct))
@unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
@patch.object(torch._inductor.config, "allow_buffer_reuse", True)
# TODO: somehow inductor bg compile threads are causing hangs at exit with distributed work dtor
@patch.object(torch._inductor.config, "compile_threads", 1)
@patch.object(torch._inductor.config, "reorder_for_compute_comm_overlap", True)
@patch.object(
torch._inductor.config,
"reorder_for_compute_comm_overlap_passes",
[
"reorder_compute_for_overlap",
],
)
def test_reorder_compute_for_overlap(self):
def func(a, *, tag, ranks, group_size):
ar = _functional_collectives.all_reduce(a, "sum", ranks, tag)
g = torch.matmul(a, a)
c = torch.relu(a)
d = torch.matmul(c, c)
f = d * c * ar
fr = _functional_collectives.all_reduce(f, "sum", ranks, tag)
e = torch.matmul(d + ar + fr, g)
return (e,)
with _dynamo_dist_per_rank_init(
self.rank, self.world_size, fake_pg=not at_least_x_gpu(2)
):
inputs = torch.ones(4, 4, dtype=torch.float, device="cuda") + self.rank
compiled = torch.compile(func)
code = run_and_get_triton_code(compiled, inputs, **self.get_world_trs())
# NOTE: after scheduling the first all_reduce:
# 1. we first schedule the ops (c and d) that ARE required for second all_reduce but DO NOT depend on first all_reduce.
# 2. then, we schedule the ops (g) that ARE NOT required for second all_reduce and DO NOT depend on first all_reduce.
# 3. then, we schedule the ops (f) that ARE required for second all_reduce and DO depend on first all_reduce.
# and then, we schedule the second all_reduce. And then schedule all ops that depend on second all_reduce.
(
FileCheck()
.check("torch.ops._c10d_functional.all_reduce_.default")
.check("triton_poi_fused_relu")
.check("extern_kernels.mm")
.check("extern_kernels.mm")
.check("torch.ops._c10d_functional.wait_tensor.default")
.check("triton_poi_fused_mul")
.check("torch.ops._c10d_functional.all_reduce_.default")
.check("torch.ops._c10d_functional.wait_tensor.default")
.check("triton_poi_fused_add")
.check("extern_kernels.mm")
.run(code)
)
out = compiled(inputs, **self.get_world_trs())
correct = func(inputs, **self.get_world_trs())
self.assertTrue(same(out, correct))
@unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
@patch.object(torch._inductor.config, "allow_buffer_reuse", True)
# TODO: somehow inductor bg compile threads are causing hangs at exit with distributed work dtor
@patch.object(torch._inductor.config, "compile_threads", 1)
@patch.object(torch._inductor.config, "reorder_for_compute_comm_overlap", True)
@patch.object(
torch._inductor.config,
"reorder_for_compute_comm_overlap_passes",
[
"reorder_compute_for_overlap",
],
)
@patch.object(
torch._inductor.config,
"estimate_op_runtime",
get_snode_runtime_for_reorder_compute_test,
)
def test_reorder_compute_for_overlap_custom_runtime_estimation(self):
def func(a, *, tag, ranks, group_size):
ar = _functional_collectives.all_reduce(a, "sum", ranks, tag)
g = torch.matmul(a, a)
c = torch.relu(a)
d = torch.matmul(c, c)
f = d * c * ar
fr = _functional_collectives.all_reduce(f, "sum", ranks, tag)
e = torch.matmul(d + ar + fr, g)
return (e,)
with _dynamo_dist_per_rank_init(
self.rank, self.world_size, fake_pg=not at_least_x_gpu(2)
):
inputs = torch.ones(4, 4, dtype=torch.float, device="cuda") + self.rank
compiled = torch.compile(func)
code = run_and_get_triton_code(compiled, inputs, **self.get_world_trs())
# NOTE: after scheduling the first all_reduce:
# 1. we first schedule the ops (c and d) that ARE required for second all_reduce but DO NOT depend on first all_reduce.
# 2. then, we schedule the ops (g) that ARE NOT required for second all_reduce and DO NOT depend on first all_reduce.
# 3. then, we schedule the ops (f) that ARE required for second all_reduce and DO depend on first all_reduce.
# and then, we schedule the second all_reduce. And then schedule all ops that depend on second all_reduce.
(
FileCheck()
.check("torch.ops._c10d_functional.all_reduce_.default")
.check("triton_poi_fused_relu")
.check("extern_kernels.mm")
.check("extern_kernels.mm")
.check("torch.ops._c10d_functional.wait_tensor.default")
.check("triton_poi_fused_mul")
.check("torch.ops._c10d_functional.all_reduce_.default")
.check("torch.ops._c10d_functional.wait_tensor.default")
.check("triton_poi_fused_add")
.check("extern_kernels.mm")
.run(code)
)
out = compiled(inputs, **self.get_world_trs())
correct = func(inputs, **self.get_world_trs())
self.assertTrue(same(out, correct))
@unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
# TODO: somehow inductor bg compile threads are causing hangs at exit with distributed work dtor
@patch.object(torch._inductor.config, "compile_threads", 1)
@patch.object(
torch._inductor.config,
"_pre_fusion_custom_pass",
create_grouped_node_for_allreduce_and_its_deps,
)
def test_grouped_scheduler_node(self):
def func(a, *, tag, ranks, group_size):
add = a + a
div = add / a
ar = _functional_collectives.all_reduce(div, "sum", ranks, tag)
# Normally, we would fuse `add = a + a`, `div = add / a` and `mul = a * a` together into a single fused op,
# but here in this unit test, we intentionally put `add`, `div` and `ar` computation
# into a GroupedSchedulerNode, which prevents them from being fused with any other ops.
mul = a * a
mm = torch.matmul(mul, ar)
return (mm,)
with _dynamo_dist_per_rank_init(
self.rank, self.world_size, fake_pg=not at_least_x_gpu(2)
):
inputs = torch.ones(4, 4, dtype=torch.float, device="cuda") + self.rank
compiled = torch.compile(func)
code = run_and_get_triton_code(compiled, inputs, **self.get_world_trs())
# Expectations:
# 1. `add = a + a` and `div = add / a` are still fused, which means fusion
# still happens among nodes within a GroupedSchedulerNode.
# 2. `mul = a * a` is not fused with `add` or `div`, because the latter two are within
# GroupedSchedulerNode and thus are prevented from being fused with any outside ops.
FileCheck().check("triton_poi_fused_add_div_0.").check(
"_c10d_functional.all_reduce_."
).check("triton_poi_fused_mul_1.").run(code)
out = compiled(inputs, **self.get_world_trs())
correct = func(inputs, **self.get_world_trs())
self.assertTrue(same(out, correct))
def test_nccl_heuristics(self):
assert len(baseLat) == len(NCCL_ALGO)
assert all(len(x) == len(NCCL_PROTO) for x in baseLat)
assert len(hwLat) == len(NCCL_HW)
assert all(len(x) == len(NCCL_ALGO) for x in hwLat)
assert all(len(y) == len(NCCL_PROTO) for x in hwLat for y in x)
assert len(llMaxBws) == len(NVIDIA_GPU_TYPE)
if __name__ == "__main__":
from torch._dynamo.test_case import run_tests
run_tests()