Zero bubble can be expressed through `ScheduleFlexibleInterleaved1F1B` by setting `enable_zero_bubble=True`. But instead of having to pass this flag at schedule initialization, we should create a separate `ZeroBubbleSchedule` and also transition `Interleaved1F1B` to derive from `ScheduleFlexibleInterleaved1F1B`. Then we don't need to expose `ScheduleFlexibleInterleaved1F1B`, since its name is not obvious.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/133467
Approved by: https://github.com/wconstab
ghstack dependencies: #132691
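Concretely, the intent is the following change in how a zero-bubble schedule is constructed (a minimal sketch; `stages`, `num_microbatches`, and `loss_fn` stand for the objects the tests below build):

    # Current spelling: zero bubble is a flag on the flexible schedule
    schedule = ScheduleFlexibleInterleaved1F1B(
        stages, num_microbatches, loss_fn=loss_fn, enable_zero_bubble=True
    )

    # Proposed spelling: a dedicated schedule class with the same behavior
    schedule = ScheduleInterleavedZeroBubble(stages, num_microbatches, loss_fn=loss_fn)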
# Copyright (c) Meta Platforms, Inc. and affiliates
# Owner(s): ["oncall: distributed"]
import copy
import logging
import os
import sys
import tempfile

from model_registry import ModelWithKwargs, MultiMLP, MultiMLPWithDw
from schedule_registry import ScheduleUnbalanced, ScheduleVShaped, ScheduleWithW

import torch
import torch.distributed as dist
from torch.distributed.pipelining import (
    _ScheduleForwardOnly,
    pipeline,
    PipelineStage,
    Schedule1F1B,
    ScheduleFlexibleInterleaved1F1B,
    ScheduleGPipe,
    ScheduleInterleaved1F1B,
    ScheduleInterleavedZeroBubble,
    ScheduleLoopedBFS,
)
from torch.distributed.pipelining.schedules import _PipelineScheduleRuntime
from torch.testing._internal.common_cuda import TEST_MULTIGPU
from torch.testing._internal.common_distributed import (
    MultiProcContinousTest,
    requires_nccl,
)
from torch.testing._internal.common_utils import (
    instantiate_parametrized_tests,
    parametrize,
    skip_but_pass_in_sandcastle_if,
)


logger = logging.getLogger(__name__)

d_hid = 512
batch_size = 256

torch.manual_seed(0)


class ScheduleTest(MultiProcContinousTest):
    @classmethod
    def backend_str(cls) -> str:
        # Testing with NCCL backend
        return "nccl"

    @classmethod
    def setUpClass(cls):
        """
        Class-scope test fixture. Run once for entire test class, before any test starts.
        Set up the device.
        """
        super().setUpClass()
        dev_id = cls.rank % torch.cuda.device_count()
        cls.device = torch.device(f"cuda:{dev_id}")

    @requires_nccl()
    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
    @parametrize("ScheduleClass", [_ScheduleForwardOnly])
    def test_forward_only(self, ScheduleClass):
        mod = MultiMLP(d_hid, n_layers=self.world_size)
        mod.to(self.device)

        mod_ref = copy.deepcopy(mod)

        x = torch.randn(batch_size, d_hid, device=self.device)
        x_clone = x.clone()

        num_microbatches = 4
        x_mb = x.chunk(num_microbatches)[0]

        # Create a pipeline
        split_spec = mod.split_spec if hasattr(mod, "split_spec") else None
        pipe = pipeline(
            mod,
            mb_args=(x_mb,),
            split_spec=split_spec,
        )

        stage = pipe.build_stage(
            self.rank,
            self.device,
        )

        # Attach to a schedule
        schedule = ScheduleClass(stage, num_microbatches)

        # Run
        num_iters = 20
        for _ in range(num_iters):
            if self.rank == 0:
                schedule.step(x)
                dist.recv(x, src=self.world_size - 1)
            elif self.rank == self.world_size - 1:
                out = schedule.step()
                dist.send(out, dst=0)
            else:
                schedule.step()

        # Validate pipelined output is the same as reference model
        if self.rank == self.world_size - 1:
            for _ in range(num_iters):
                x_clone = mod_ref(x_clone)

            torch.testing.assert_close(x_clone, out)

    @requires_nccl()
    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
    @parametrize("ScheduleClass", [ScheduleGPipe, Schedule1F1B])
    def test_multi_iter(self, ScheduleClass):
        mod = MultiMLP(d_hid, n_layers=self.world_size)
        mod.to(self.device)

        x = torch.randn(batch_size, d_hid, device=self.device)
        target = torch.randn(batch_size, d_hid, device=self.device)
        loss_fn = torch.nn.MSELoss(reduction="sum")

        chunks = 4
        x_mb = x.chunk(chunks)[0]

        # Create a pipeline
        split_spec = mod.split_spec if hasattr(mod, "split_spec") else None
        pipe = pipeline(
            mod,
            mb_args=(x_mb,),
            split_spec=split_spec,
        )

        stage = pipe.build_stage(
            self.rank,
            self.device,
        )

        # Attach to a schedule
        schedule = ScheduleClass(stage, chunks, loss_fn=loss_fn)

        # Run
        for _ in range(20):
            if self.rank == 0:
                schedule.step(x)
            elif self.rank == self.world_size - 1:
                losses = []
                out = schedule.step(target=target, losses=losses)
            else:
                schedule.step()

    @requires_nccl()
    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
    @parametrize("ScheduleClass", [ScheduleGPipe, Schedule1F1B])
    def test_kwargs_with_tracer(self, ScheduleClass):
        mod = ModelWithKwargs(d_hid)
        mod.to(self.device)

        x = torch.randn(batch_size, d_hid, device=self.device)
        y = torch.randn(batch_size, d_hid, device=self.device)
        target = torch.randn(batch_size, d_hid, device=self.device)
        loss_fn = torch.nn.MSELoss(reduction="sum")

        chunks = 4
        x_mb = x.chunk(chunks)[0]
        y_mb = y.chunk(chunks)[0]

        pipe = pipeline(
            mod,
            mb_args=(x_mb,),
            mb_kwargs={"y": y_mb},
        )

        stage = pipe.build_stage(
            self.rank,
            self.device,
        )

        # Attach to a schedule
        schedule = ScheduleClass(stage, chunks, loss_fn=loss_fn)

        # Run
        if self.rank == 0:
            schedule.step(x, y=y)
        elif self.rank == self.world_size - 1:
            losses = []
            out = schedule.step(target=target, losses=losses)
        else:
            schedule.step()

        dist.barrier()

        # Last rank checks result
        if self.rank == self.world_size - 1:
            ref_out = mod(x, y=y)
            ref_loss = loss_fn(ref_out, target)
            pipe_loss = sum(losses)
            torch.testing.assert_close(out, ref_out, rtol=1e-2, atol=5e-3)
            torch.testing.assert_close(pipe_loss, ref_loss)

    @requires_nccl()
    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
    @parametrize("ScheduleClass", [ScheduleGPipe, Schedule1F1B])
    @parametrize("ModelClass", [MultiMLP])
    def test_grad_with_tracer(self, ScheduleClass, ModelClass):
        mod = ModelClass(d_hid)
        mod.to(self.device)

        ref_mod = copy.deepcopy(mod)
        x = torch.randn(batch_size, d_hid, device=self.device)
        with torch.no_grad():
            y = ref_mod(x)
            # Add a small perturbation
            target = y + torch.randn(batch_size, d_hid, device=self.device)

        loss_fn = torch.nn.MSELoss(reduction="sum")

        # Run reference
        for _ in range(2):
            ref_mod.zero_grad()
            ref_out = ref_mod(x)
            ref_loss = loss_fn(ref_out, target)
            ref_loss.backward()

        # Create a pipeline
        chunks = 4
        x_mb = x.chunk(chunks)[0]
        split_spec = mod.split_spec if hasattr(mod, "split_spec") else None
        pipe = pipeline(
            mod,
            mb_args=(x_mb,),
            split_spec=split_spec,
        )

        stage = pipe.build_stage(
            self.rank,
            self.device,
        )

        # Attach to a schedule
        schedule = ScheduleClass(stage, chunks, loss_fn=loss_fn)

        # Run
        stage_module = pipe.get_stage_module(self.rank)
        for _ in range(2):
            # Zero gradients
            stage_module.zero_grad()
            if self.rank == 0:
                schedule.step(x)
            elif self.rank == self.world_size - 1:
                losses = []
                out = schedule.step(target=target, losses=losses)
            else:
                schedule.step()

        dist.barrier()

        # Last rank checks result
        if self.rank == self.world_size - 1:
            # Check output
            torch.testing.assert_close(out, ref_out)
            # Check loss
            # Since the reduction used in the loss function above is "sum", we use
            # "sum" here to reduce microbatch losses into a single value too.
            pipe_loss = sum(losses)
            torch.testing.assert_close(pipe_loss, ref_loss)

        # Every rank checks gradients
        for name, p in stage_module.named_parameters():
            ref_p = ref_mod.get_parameter(name)
            try:
                torch.testing.assert_close(p.grad, ref_p.grad, rtol=1e-5, atol=4e-5)
            except AssertionError:
                print(f"Gradient test failed for {name}: {p.grad} vs {ref_p.grad}")
                raise

    @requires_nccl()
    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
    @parametrize("ScheduleClass", [ScheduleGPipe, Schedule1F1B])
    def test_grad_with_manual(self, ScheduleClass):
        full_mod = MultiMLP(d_hid, n_layers=self.world_size)
        full_mod.to(self.device)

        ref_mod = copy.deepcopy(full_mod)
        x = torch.randn(batch_size, d_hid, device=self.device)
        with torch.no_grad():
            y = ref_mod(x)
            # Add a small perturbation
            target = y + torch.randn(batch_size, d_hid, device=self.device)

        loss_fn = torch.nn.MSELoss(reduction="sum")

        # Run reference
        for _ in range(2):
            ref_mod.zero_grad()
            ref_out = ref_mod(x)
            ref_loss = loss_fn(ref_out, target)
            ref_loss.backward()

        # Get a submodule, e.g. `layers.0` or `layers.1`
        submod_name = f"layers.{self.rank}"
        stage_module = full_mod.get_submodule(submod_name)
        chunks = 4
        # Create a pipeline stage to wrap that submodule
        stage = PipelineStage(
            stage_module,
            self.rank,
            self.world_size,
            self.device,
            input_args=x.chunk(chunks)[0],
        )

        # Attach to a schedule
        schedule = ScheduleClass(stage, chunks, loss_fn=loss_fn)

        # Run
        for _ in range(2):
            # Zero gradients
            stage_module.zero_grad()
            if self.rank == 0:
                schedule.step(x)
            elif self.rank == self.world_size - 1:
                losses = []
                out = schedule.step(target=target, losses=losses)
            else:
                schedule.step()

        dist.barrier()

        # Last rank checks result
        if self.rank == self.world_size - 1:
            # Check output
            torch.testing.assert_close(out, ref_out)
            # Check loss
            # Since the reduction used in the loss function above is "sum", we use
            # "sum" here to reduce microbatch losses into a single value too.
            pipe_loss = sum(losses)
            torch.testing.assert_close(pipe_loss, ref_loss)

        # Every rank checks gradients
        ref_submod = ref_mod.get_submodule(submod_name)
        for name, p in stage_module.named_parameters():
            ref_p = ref_submod.get_parameter(name)
            try:
                torch.testing.assert_close(p.grad, ref_p.grad, rtol=1e-5, atol=4e-5)
            except AssertionError:
                print(f"Gradient test failed for {name}: {p.grad} vs {ref_p.grad}")
                raise

    @requires_nccl()
    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
    @parametrize(
        "ScheduleClass",
        [ScheduleInterleaved1F1B, ScheduleLoopedBFS, ScheduleInterleavedZeroBubble],
    )
    @parametrize("use_new_runtime", [False, True])
    def test_grad_with_manual_interleaved(self, ScheduleClass, use_new_runtime):
        stages_per_rank = 2
        n_stages = stages_per_rank * self.world_size
        full_mod = MultiMLP(d_hid, n_layers=n_stages)
        full_mod.to(self.device)

        ref_mod = copy.deepcopy(full_mod)
        x = torch.randn(batch_size, d_hid, device=self.device)
        with torch.no_grad():
            y = ref_mod(x)
            # Add a small perturbation
            target = y + torch.randn(batch_size, d_hid, device=self.device)

        loss_fn = torch.nn.MSELoss(reduction="sum")

        # Run reference
        for _ in range(2):
            ref_mod.zero_grad()
            ref_out = ref_mod(x)
            ref_loss = loss_fn(ref_out, target)
            ref_loss.backward()

        # Get a submodule, e.g. `layers.0` or `layers.1`
        stage_indices = [
            self.rank + i * self.world_size for i in range(stages_per_rank)
        ]
        print(f"Rank {self.rank} stages: {stage_indices}")
        submod_names = [f"layers.{i}" for i in stage_indices]
        stage_modules = [
            full_mod.get_submodule(submod_name) for submod_name in submod_names
        ]
        # Create a pipeline stage to wrap that submodule
        num_microbatches = (
            ScheduleClass.num_microbatches
            if hasattr(ScheduleClass, "num_microbatches")
            else 8
        )
        input_args = x.chunk(num_microbatches)[0]
        stages = [
            PipelineStage(
                stage_module,
                stage_idx,
                n_stages,
                self.device,
                input_args=input_args,
            )
            for stage_module, stage_idx in zip(stage_modules, stage_indices)
        ]

        # Attach to a schedule
        schedule = ScheduleClass(stages, num_microbatches, loss_fn=loss_fn)
        if use_new_runtime:
            old_schedule = schedule
            tmp_schedule = _PipelineScheduleRuntime(
                stages,
                num_microbatches,
                loss_fn=loss_fn,
                stage_index_to_group_rank=old_schedule.stage_index_to_group_rank,
                use_full_backward=old_schedule.use_full_backward,
            )
            tmp_schedule._load_actions(old_schedule.pipeline_order)
            # test that csv round-trip works for compute_comms schedule
            schedule = _PipelineScheduleRuntime(
                stages,
                num_microbatches,
                loss_fn=loss_fn,
                stage_index_to_group_rank=old_schedule.stage_index_to_group_rank,
                use_full_backward=old_schedule.use_full_backward,
            )
            with tempfile.NamedTemporaryFile() as f:
                tmp_schedule._dump_csv(f.name)
                f.seek(0)
                schedule._load_csv(f.name, format="compute_comms")
            one_more_schedule = _PipelineScheduleRuntime(
                stages,
                num_microbatches,
                loss_fn=loss_fn,
                stage_index_to_group_rank=old_schedule.stage_index_to_group_rank,
                use_full_backward=old_schedule.use_full_backward,
            )
            one_more_schedule._load_actions(
                schedule.pipeline_order_with_comms, format="compute_comms"
            )
            self.assertEqual(
                len(schedule.pipeline_order_with_comms),
                len(
                    one_more_schedule.pipeline_order_with_comms,
                ),
            )
            for rank in schedule.pipeline_order_with_comms:
                self.assertEqual(
                    len(schedule.pipeline_order_with_comms[rank]),
                    len(
                        one_more_schedule.pipeline_order_with_comms[rank],
                    ),
                )
                for a, b in zip(
                    schedule.pipeline_order_with_comms[rank],
                    one_more_schedule.pipeline_order_with_comms[rank],
                ):
                    self.assertEqual(a, b)

        # Run
        for _ in range(2):
            # Zero gradients
            for stage_module in stage_modules:
                stage_module.zero_grad()
            if self.rank == 0:
                schedule.step(x)
            elif self.rank == self.world_size - 1:
                losses = []
                out = schedule.step(target=target, losses=losses)
            else:
                schedule.step()

        dist.barrier()

        # Last rank checks result
        if self.rank == self.world_size - 1:
            # Check output
            torch.testing.assert_close(out, ref_out)
            # Check loss
            # Since the reduction used in the loss function above is "sum", we use
            # "sum" here to reduce microbatch losses into a single value too.
            pipe_loss = sum(losses)
            torch.testing.assert_close(pipe_loss, ref_loss)

        # Every rank checks gradients
        for stage_module, submod_name in zip(stage_modules, submod_names):
            # Get corresponding submodule from reference model
            ref_submod = ref_mod.get_submodule(submod_name)
            # Check gradients per parameter
            for name, p in stage_module.named_parameters():
                ref_p = ref_submod.get_parameter(name)
                try:
                    torch.testing.assert_close(p.grad, ref_p.grad, rtol=1e-5, atol=4e-5)
                except AssertionError:
                    print(f"Gradient test failed for {name}: {p.grad} vs {ref_p.grad}")
                    raise

    @requires_nccl()
    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
    @parametrize("ScheduleClass", [ScheduleWithW, ScheduleFlexibleInterleaved1F1B])
    def test_schedule_with_native_zero_bubble(self, ScheduleClass):
        print(ScheduleClass)
        if ScheduleClass is ScheduleFlexibleInterleaved1F1B:
            n_stages = 4
            num_microbatches = 8
            rank_stages = {
                0: [0, 2],
                1: [1, 3],
            }
        else:
            n_stages = ScheduleClass.n_stages
            num_microbatches = ScheduleClass.num_microbatches
            rank_stages = ScheduleClass.rank_stages

        num_steps = 4
        full_mod = MultiMLP(d_hid, n_layers=n_stages)
        full_mod.to(self.device)

        ref_mod = copy.deepcopy(full_mod)
        x = torch.randn(batch_size, d_hid, device=self.device)
        # x = torch.randn(batch_size, d_hid, device=self.device, requires_grad=True)
        with torch.no_grad():
            y = ref_mod(x)
            # Add a small perturbation
            target = y + torch.randn(batch_size, d_hid, device=self.device)

        loss_fn = torch.nn.MSELoss(reduction="sum")

        # Create a pipeline stage to wrap that submodule
        input_args = x.chunk(num_microbatches)[0]
        stage_indices = rank_stages[self.rank]
        print(f"Rank {self.rank} stages: {stage_indices}")
        submod_names = [f"layers.{i}" for i in stage_indices]
        stage_modules = [
            full_mod.get_submodule(submod_name) for submod_name in submod_names
        ]
        stages = [
            PipelineStage(
                stage_module,
                stage_idx,
                n_stages,
                self.device,
                input_args=input_args,
            )
            for stage_module, stage_idx in zip(stage_modules, rank_stages[self.rank])
        ]

        schedule = ScheduleClass(
            stages, num_microbatches, loss_fn=loss_fn, enable_zero_bubble=True
        )

        # Run reference
        ref_x = x.clone().detach().requires_grad_(x.requires_grad)
        torch.testing.assert_close(x, ref_x)
        for _ in range(num_steps):
            ref_out = ref_mod(ref_x)
            ref_loss = loss_fn(ref_out, target)
            ref_loss.backward()

        # Run pipelined stages
        for _ in range(num_steps):
            if self.rank == 0:
                schedule.step(x)
            elif self.rank == self.world_size - 1:
                losses = []
                out = schedule.step(target=target, losses=losses)
            else:
                schedule.step()

        # Every rank checks parameters compared with the reference model
        for stage_module, submod_name in zip(stage_modules, submod_names):
            # Get corresponding submodule from reference model
            ref_submod = ref_mod.get_submodule(submod_name)
            # Check gradients per parameter
            for name, p in stage_module.named_parameters():
                ref_p = ref_submod.get_parameter(name)
                try:
                    torch.testing.assert_close(p.grad, ref_p.grad, rtol=1e-5, atol=4e-5)
                except AssertionError:
                    print(
                        f"Parameter test failed for {submod_name}.{name}: {p.grad} vs {ref_p.grad}"
                    )
                    raise

    @requires_nccl()
    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
    @parametrize("ScheduleClass", [ScheduleVShaped, ScheduleUnbalanced])
    def test_non_symmetric_stage_ids(self, ScheduleClass):
        n_stages = ScheduleClass.n_stages
        full_mod = MultiMLP(d_hid, n_layers=n_stages)
        full_mod.to(self.device)

        ref_mod = copy.deepcopy(full_mod)
        x = torch.randn(batch_size, d_hid, device=self.device)
        with torch.no_grad():
            y = ref_mod(x)
            # Add a small perturbation
            target = y + torch.randn(batch_size, d_hid, device=self.device)

        loss_fn = torch.nn.MSELoss(reduction="sum")

        # Run reference
        for _ in range(2):
            ref_mod.zero_grad()
            ref_out = ref_mod(x)
            ref_loss = loss_fn(ref_out, target)
            ref_loss.backward()

        # Create a pipeline stage to wrap that submodule
        chunks = 1
        input_args = x.chunk(chunks)[0]
        rank_stages = ScheduleClass.rank_stages
        stage_indices = rank_stages[self.rank]
        print(f"Rank {self.rank} stages: {stage_indices}")
        submod_names = [f"layers.{i}" for i in stage_indices]
        stage_modules = [
            full_mod.get_submodule(submod_name) for submod_name in submod_names
        ]
        stages = [
            PipelineStage(
                stage_module,
                stage_idx,
                n_stages,
                self.device,
                input_args=input_args,
            )
            for stage_module, stage_idx in zip(stage_modules, rank_stages[self.rank])
        ]

        # Attach to a schedule
        stage_index_to_group_rank = {
            value: key for key, values in rank_stages.items() for value in values
        }
        schedule = ScheduleClass(
            stages, chunks, stage_index_to_group_rank, loss_fn=loss_fn
        )

        # Run
        # TODO how to better specify .step() when first and last stage are on rank 0...
        for _ in range(2):
            # Zero gradients
            for stage_module in stage_modules:
                stage_module.zero_grad()
            if self.rank == 0:
                losses = []
                out = schedule.step(x, target=target, losses=losses)
            else:
                schedule.step()

        dist.barrier()

        # Last rank checks result
        if self.rank == 0:
            # Check output
            torch.testing.assert_close(out, ref_out)
            # Check loss
            # Since the reduction used in the loss function above is "sum", we use
            # "sum" here to reduce microbatch losses into a single value too.
            pipe_loss = sum(losses)
            torch.testing.assert_close(pipe_loss, ref_loss)

        # Every rank checks gradients
        for stage_module, submod_name in zip(stage_modules, submod_names):
            # Get corresponding submodule from reference model
            ref_submod = ref_mod.get_submodule(submod_name)
            # Check gradients per parameter
            for name, p in stage_module.named_parameters():
                ref_p = ref_submod.get_parameter(name)
                try:
                    torch.testing.assert_close(p.grad, ref_p.grad, rtol=1e-5, atol=4e-5)
                except AssertionError:
                    print(f"Gradient test failed for {name}: {p.grad} vs {ref_p.grad}")
                    raise

    @requires_nccl()
    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
    @parametrize("ScheduleClass", [ScheduleFlexibleInterleaved1F1B])
    def test_schedule_with_weight_update_mlp_e2e(self, ScheduleClass):
        stages_per_rank = 2
        n_stages = stages_per_rank * self.world_size
        full_mod = MultiMLPWithDw(d_hid, n_layers=n_stages)
        full_mod.to(self.device)

        ref_mod = copy.deepcopy(full_mod)
        x = torch.randn(batch_size, d_hid, device=self.device)
        with torch.no_grad():
            y = ref_mod(x)
            # Add a small perturbation
            target = y + torch.randn(batch_size, d_hid, device=self.device)

        ref_loss_fn = torch.nn.MSELoss(reduction="sum")
        full_loss_fn = torch.nn.MSELoss(reduction="sum")

        full_mod.toggle()

        # Get a submodule, e.g. `layers.0` or `layers.1`
        stage_indices = [
            self.rank + i * self.world_size for i in range(stages_per_rank)
        ]
        submod_names = [f"layers.{i}" for i in stage_indices]
        stage_modules = [
            full_mod.get_submodule(submod_name) for submod_name in submod_names
        ]

        # Run reference
        for _ in range(2):
            ref_stage_modules = [
                ref_mod.get_submodule(submod_name) for submod_name in submod_names
            ]
            for stage_module in ref_stage_modules:
                stage_module.zero_grad()

            ref_mod.zero_grad()
            ref_out = ref_mod(x)
            ref_loss = ref_loss_fn(ref_out, target)
            ref_loss.backward()

        class CustomState:
            def __init__(self, stage_module, stage_idx, rank):
                self.i = 0
                self.stage_module = stage_module
                self.stage_idx = stage_idx
                self.rank = rank

            def dw_builder(self):
                def dw_runner():
                    # This inner function would be called by PipelineStage during `backward_weight_one_chunk`
                    self.i += 1
                    print(
                        f"[Rank {self.rank}] dw_count={self.i} stage={self.stage_idx}"
                    )
                    self.stage_module.compute_dW()

                return dw_runner

        cs = {}
        for stage_module, stage_idx in zip(stage_modules, stage_indices):
            cs[stage_idx] = CustomState(stage_module, stage_idx, self.rank)

        # Create a pipeline stage to wrap that submodule
        chunks = 2
        input_args = x.chunk(chunks)[0]
        stages = [
            PipelineStage(
                stage_module,
                stage_idx,
                n_stages,
                self.device,
                input_args=input_args,
                dw_builder=cs[stage_idx].dw_builder,
            )
            for stage_module, stage_idx in zip(stage_modules, stage_indices)
        ]

        # Attach to a schedule
        schedule = ScheduleClass(
            stages, chunks, loss_fn=full_loss_fn, enable_zero_bubble=True
        )

        for _ in range(2):
            # Zero gradients
            for stage_module in stage_modules:
                stage_module.zero_grad()
            if self.rank == 0:
                schedule.step(x)
            elif self.rank == self.world_size - 1:
                losses = []
                out = schedule.step(target=target, losses=losses)
            else:
                schedule.step()

        dist.barrier()
        # Last rank checks result
        if self.rank == self.world_size - 1:
            # Check output
            torch.testing.assert_close(out, ref_out)

            # Check loss
            # Since the reduction used in the loss function above is "sum", we use
            # "sum" here to reduce microbatch losses into a single value too.
            pipe_loss = sum(losses)
            torch.testing.assert_close(pipe_loss, ref_loss)

        # Every rank checks gradients
        for stage_module, submod_name in zip(stage_modules, submod_names):
            # Get corresponding submodule from reference model
            ref_submod = ref_mod.get_submodule(submod_name)
            # Check gradients per parameter
            for name, p in stage_module.named_parameters():
                ref_p = ref_submod.get_parameter(name)
                torch.testing.assert_close(p.grad, ref_p.grad, rtol=1e-5, atol=4e-5)


instantiate_parametrized_tests(ScheduleTest)


if __name__ == "__main__":
    # Check if GPU and NCCL are available
    if not (
        dist.is_available()
        and dist.is_nccl_available()
        and torch.cuda.device_count() > 1
    ):
        print(
            "c10d NCCL not available or not enough GPUs, skipping tests",
            file=sys.stderr,
        )
        sys.exit(0)

    rank = int(os.getenv("RANK", -1))
    world_size = int(os.getenv("WORLD_SIZE", 2))

    if rank != -1:
        # Launched with torchrun or other multi-proc launchers. Directly run the test.
        ScheduleTest.run_rank(rank, world_size)
    else:
        # Launched as a single process. Spawn subprocess to run the tests.
        # Also need a rendezvous file for `init_process_group` purpose.
        rdvz_file = tempfile.NamedTemporaryFile(delete=False).name
        torch.multiprocessing.spawn(
            ScheduleTest.run_rank,
            nprocs=world_size,
            args=(world_size, rdvz_file),
        )