[1/N] Port 3 distributed/_tools test cases to Intel GPU (#159543)

As part of [#114850](https://github.com/pytorch/pytorch/issues/114850), we are porting the distributed tests to Intel GPU.

We enable Intel GPU support with the following methods while keeping the original code style as much as possible (a minimal sketch of the resulting device-agnostic pattern follows this list):

1. use `torch.accelerator.current_accelerator()` to determine the accelerator backend
2. enable XPU in the relevant test paths
3. skip test cases that Intel GPU does not support
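
Below is that sketch. It is illustrative only; the helper names (`pick_accelerator`, `reset_memory_stats`) are made up and are not code from this PR.

```python
import torch


def pick_accelerator() -> torch.device:
    # torch.accelerator.current_accelerator() returns the active backend
    # device (e.g. cuda or xpu), or None on a CPU-only build.
    acc = torch.accelerator.current_accelerator()
    return acc if acc is not None else torch.device("cpu")


def reset_memory_stats(dev: torch.device) -> None:
    # Assumes dev is an accelerator device: torch.get_device_module(dev)
    # resolves to torch.cuda or torch.xpu, so the same call sites cover
    # both backends without per-backend branching.
    mod = torch.get_device_module(dev)
    mod.empty_cache()
    mod.reset_accumulated_memory_stats(dev)
    mod.reset_peak_memory_stats(dev)
```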

Pull Request resolved: https://github.com/pytorch/pytorch/pull/159543
Approved by: https://github.com/guangyey, https://github.com/d4l3k

Co-authored-by: Yu, Guangye <106960996+guangyey@users.noreply.github.com>
libohao 2025-08-13 12:48:57 +00:00 committed by PyTorch MergeBot
parent 42e51cd4b3
commit ee1b0412b9
3 changed files with 73 additions and 56 deletions

View File

@@ -37,15 +37,16 @@ def _init_cublas_workspace(dev: torch.device):
 def _reset_mem_stats(dev: torch.device):
-    torch.cuda.empty_cache()
-    torch.cuda.reset_accumulated_memory_stats(dev)
-    torch.cuda.reset_peak_memory_stats(dev)
+    mod = torch.get_device_module(dev)
+    mod.empty_cache()
+    mod.reset_accumulated_memory_stats(dev)
+    mod.reset_peak_memory_stats(dev)
 class TestTrackerFullyShard1DTrainingCore(FSDPTest):
     @property
     def world_size(self) -> int:
-        return min(4, torch.cuda.device_count())
+        return min(4, torch.accelerator.device_count())
     @skip_if_lt_x_gpu(2)
     def test_tracker_multi_group_eager(self):
@@ -77,17 +78,18 @@ class TestTrackerFullyShard1DTrainingCore(FSDPTest):
         mp_policy: MixedPrecisionPolicy,
     ):
         debug = False
-        dev = torch.device(torch.cuda.current_device())
+        dev = torch.device(torch.accelerator.current_device_index())
         _init_cublas_workspace(dev)
         gc.collect()
         _reset_mem_stats(dev)
-        mem_stats = torch.cuda.memory_stats(dev)
-        pre_cuda_active = mem_stats["active_bytes.all.current"]
+        mod = torch.get_device_module(dev)
+        mem_stats = mod.memory_stats(dev)
+        pre_acc_active = mem_stats["active_bytes.all.current"]
         torch.manual_seed(42)
         lin_dim, bsz = 2048, 8192
         with torch.device(dev):
             model = nn.Sequential(*[MLP(dim=lin_dim, device=dev) for _ in range(4)])
-        mesh = init_device_mesh("cuda", (self.world_size,))
+        mesh = init_device_mesh(dev.type, (self.world_size,))
         fully_shard_fn = functools.partial(
             fully_shard,
             mesh=mesh,
@@ -110,17 +112,19 @@ class TestTrackerFullyShard1DTrainingCore(FSDPTest):
             optim.zero_grad()
             if iter_idx == 0:
                 fmt.reset_mod_stats()
-        mem_stats = torch.cuda.memory_stats()
+        mem_stats = mod.memory_stats()
         tracker_max = fmt.get_tracker_snapshot("peak")[dev]["Total"]
-        cuda_max = mem_stats["active_bytes.all.peak"] - pre_cuda_active
-        accuracy = tracker_max / cuda_max
+        acc_max = mem_stats["active_bytes.all.peak"] - pre_acc_active
+        accuracy = tracker_max / acc_max
         if self.rank == 0 and debug:
-            print(f"Accuracy: {accuracy} Tracker Max:{tracker_max} CUDA Max:{cuda_max}")
+            print(
+                f"Accuracy: {accuracy} Tracker Max:{tracker_max} Accelerator Max:{acc_max}"
+            )
         self.assertAlmostEqual(
             accuracy,
             1.0,
             delta=0.1,
-            msg=f"Tracker Max:{tracker_max} CUDA Max:{cuda_max}",
+            msg=f"Tracker Max:{tracker_max} Accelerator Max:{acc_max}",
         )
         del model
         del inp
@@ -132,12 +136,13 @@ class TestTrackerFullyShard1DTrainingCore(FSDPTest):
         Tests tracker accuracy when running forward/backward through a non-root.
         """
         debug = False
-        dev = torch.device(torch.cuda.current_device())
+        dev = torch.device(torch.accelerator.current_device_index())
         _init_cublas_workspace(dev)
         gc.collect()
         _reset_mem_stats(dev)
-        mem_stats = torch.cuda.memory_stats(dev)
-        pre_cuda_active = mem_stats["active_bytes.all.current"]
+        mod = torch.get_device_module(dev)
+        mem_stats = mod.memory_stats(dev)
+        pre_acc_active = mem_stats["active_bytes.all.current"]
         torch.manual_seed(42)
         lin_dim, bsz = 2048, 8
         model = nn.Sequential(*[MLP(lin_dim, dev) for _ in range(3)])
@@ -157,17 +162,19 @@ class TestTrackerFullyShard1DTrainingCore(FSDPTest):
             optim.zero_grad()
             if iter_idx == 0:
                 fmt.reset_mod_stats()
-        mem_stats = torch.cuda.memory_stats()
+        mem_stats = mod.memory_stats()
         tracker_max = fmt.get_tracker_snapshot("peak")[dev]["Total"]
-        cuda_max = mem_stats["active_bytes.all.peak"] - pre_cuda_active
-        accuracy = tracker_max / cuda_max
+        acc_max = mem_stats["active_bytes.all.peak"] - pre_acc_active
+        accuracy = tracker_max / acc_max
         if self.rank == 0 and debug:
-            print(f"Accuracy: {accuracy} Tracker Max:{tracker_max} CUDA Max:{cuda_max}")
+            print(
+                f"Accuracy: {accuracy} Tracker Max:{tracker_max} Accelerator Max:{acc_max}"
+            )
         self.assertAlmostEqual(
             accuracy,
             1.0,
             delta=0.1,
-            msg=f"Tracker Max:{tracker_max} CUDA Max:{cuda_max}",
+            msg=f"Tracker Max:{tracker_max} Accelerator Max:{acc_max}",
         )
         del inp
         del model
@@ -177,7 +184,7 @@ class TestTrackerFullyShard1DTrainingCore(FSDPTest):
 class TestTrackerFullyShard1DTrainingCompose(FSDPTest):
     @property
     def world_size(self) -> int:
-        return min(torch.cuda.device_count(), 4)
+        return min(torch.accelerator.device_count(), 4)
     @skip_if_lt_x_gpu(2)
     def test_tracker_with_activation_checkpointing(self):
@@ -197,12 +204,13 @@ class TestTrackerFullyShard1DTrainingCompose(FSDPTest):
     ):
         assert checkpoint_impl in ("composable", "wrapper")
         debug = False
-        dev = torch.device(torch.cuda.current_device())
+        dev = torch.device(torch.accelerator.current_device_index())
         _init_cublas_workspace(dev)
         gc.collect()
         _reset_mem_stats(dev)
-        mem_stats = torch.cuda.memory_stats(dev)
-        pre_cuda_active = mem_stats["active_bytes.all.current"]
+        mod = torch.get_device_module(dev)
+        mem_stats = mod.memory_stats(dev)
+        pre_acc_active = mem_stats["active_bytes.all.current"]
         torch.manual_seed(42)
         vocab_size = 8192
         bsz, seq_len = 16, 512
@@ -249,17 +257,19 @@ class TestTrackerFullyShard1DTrainingCompose(FSDPTest):
             optim.zero_grad()
             if iter_idx == 0:
                 fmt.reset_mod_stats()
-        mem_stats = torch.cuda.memory_stats()
+        mem_stats = mod.memory_stats()
         tracker_max = fmt.get_tracker_snapshot("peak")[dev]["Total"]
-        cuda_max = mem_stats["active_bytes.all.peak"] - pre_cuda_active
-        accuracy = tracker_max / cuda_max
+        acc_max = mem_stats["active_bytes.all.peak"] - pre_acc_active
+        accuracy = tracker_max / acc_max
         if self.rank == 0 and debug:
-            print(f"Accuracy: {accuracy} Tracker Max:{tracker_max} CUDA Max:{cuda_max}")
+            print(
+                f"Accuracy: {accuracy} Tracker Max:{tracker_max} Accelerator Max:{acc_max}"
+            )
         self.assertAlmostEqual(
             accuracy,
             1.0,
             delta=0.1,
-            msg=f"Tracker Max:{tracker_max} CUDA Max:{cuda_max}",
+            msg=f"Tracker Max:{tracker_max} Accelerator Max:{acc_max}",
         )
         del inp
         del model

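One pattern worth calling out from the hunks above: the device mesh is now built from the device type of the current accelerator rather than a hard-coded "cuda". A hedged sketch of that pattern, assuming torch.distributed is already initialized (as it is under `FSDPTest`); this is illustrative, not the test code itself:

```python
import torch
from torch.distributed.device_mesh import init_device_mesh

# Resolve the current accelerator device (CUDA or XPU) by index.
dev = torch.device(torch.accelerator.current_device_index())

# dev.type is "cuda" or "xpu", so a single call covers both backends.
mesh = init_device_mesh(dev.type, (torch.accelerator.device_count(),))
```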
View File

@@ -5,11 +5,12 @@ import unittest
 import torch
 import torch.nn as nn
 from torch.distributed._tools.mem_tracker import MemTracker
-from torch.testing._internal.common_cuda import TEST_CUDA
 from torch.testing._internal.common_utils import (
     run_tests,
     skipIfRocm,
     skipIfTorchDynamo,
+    TEST_CUDA,
+    TEST_XPU,
     TestCase,
 )
 from torch.utils.checkpoint import checkpoint
@@ -24,25 +25,29 @@ class TestMemTracker(TestCase):
         del inp
     def _reset_mem_stats(self, dev: torch.device):
-        torch.cuda.empty_cache()
-        torch.cuda.reset_accumulated_memory_stats(dev)
-        torch.cuda.reset_peak_memory_stats(dev)
+        mod = torch.get_device_module(dev)
+        mod.empty_cache()
+        mod.reset_accumulated_memory_stats(dev)
+        mod.reset_peak_memory_stats(dev)
     @skipIfTorchDynamo("https://github.com/pytorch/pytorch/issues/115653")
-    @unittest.skipIf(not TEST_CUDA, "CUDA not available")
+    @unittest.skipIf(
+        not TEST_CUDA and not TEST_XPU, "Neither CUDA nor XPU is available"
+    )
     @skipIfRocm()
-    def test_cuda_tracker_equivalence(
+    def test_accelerator_tracker_equivalence(
         self,
     ):
         """
         Tests that the tracker correctly calculates the peak memory.
         """
-        dev = torch.device(torch.cuda.current_device())
+        dev = torch.device(torch.accelerator.current_device_index())
         self._init_cublas_workspace(dev)
         gc.collect(1)
         self._reset_mem_stats(dev)
-        mem_stats = torch.cuda.memory_stats(dev)
-        pre_cuda_active = mem_stats["active_bytes.all.current"]
+        mod = torch.get_device_module(dev)
+        mem_stats = mod.memory_stats(dev)
+        pre_acc_active = mem_stats["active_bytes.all.current"]
         bsz, n_layers, dim, dtype = 16, 4, 512, torch.bfloat16
         class DummyModel(nn.Module):
@@ -74,25 +79,28 @@ class TestMemTracker(TestCase):
         # Check for accuracy of peak memory
         tracker_max = mt.get_tracker_snapshot("peak")[dev]["Total"]
-        mem_stats = torch.cuda.memory_stats(dev)
-        cuda_max = mem_stats["active_bytes.all.peak"] - pre_cuda_active
-        accuracy = tracker_max / cuda_max
+        mem_stats = mod.memory_stats(dev)
+        acc_max = mem_stats["active_bytes.all.peak"] - pre_acc_active
+        accuracy = tracker_max / acc_max
         self.assertAlmostEqual(accuracy, 1.0, delta=0.1)
     @skipIfTorchDynamo("https://github.com/pytorch/pytorch/issues/115653")
-    @unittest.skipIf(not TEST_CUDA, "CUDA not available")
+    @unittest.skipIf(
+        not TEST_CUDA and not TEST_XPU, "Neither CUDA nor XPU is available"
+    )
     def test_tracker_with_activation_checkpointing(
         self,
     ):
         """
         Tests that the tracker correctly computes the peak memory during activation checkpointing.
         """
-        dev = torch.device(torch.cuda.current_device())
+        dev = torch.device(torch.accelerator.current_device_index())
         self._init_cublas_workspace(dev)
         gc.collect(1)
         self._reset_mem_stats(dev)
-        mem_stats = torch.cuda.memory_stats(dev)
-        pre_cuda_active = mem_stats["active_bytes.all.current"]
+        mod = torch.get_device_module(dev)
+        mem_stats = mod.memory_stats(dev)
+        pre_acc_active = mem_stats["active_bytes.all.current"]
         bsz, n_layers, dim, dtype = 128, 4, 1024, torch.float16
@@ -144,9 +152,9 @@ class TestMemTracker(TestCase):
         # Check for accuracy of peak memory
         tracker_max = mt.get_tracker_snapshot("peak")[dev]["Total"]
-        mem_stats = torch.cuda.memory_stats(dev)
-        cuda_max = mem_stats["active_bytes.all.peak"] - pre_cuda_active
-        accuracy = tracker_max / cuda_max
+        mem_stats = mod.memory_stats(dev)
+        acc_max = mem_stats["active_bytes.all.peak"] - pre_acc_active
+        accuracy = tracker_max / acc_max
         self.assertAlmostEqual(accuracy, 1.0, delta=0.1)
     @skipIfTorchDynamo("https://github.com/pytorch/pytorch/issues/115653")

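The gating change in this file, running a test when either CUDA or XPU is present, is a small reusable pattern. A hedged sketch follows; `TestAcceleratorGating` and its test method are hypothetical and shown only to illustrate the guard, while `TEST_CUDA` and `TEST_XPU` come from `torch.testing._internal.common_utils` as in the diff above.

```python
import unittest

import torch
from torch.testing._internal.common_utils import run_tests, TEST_CUDA, TEST_XPU, TestCase


class TestAcceleratorGating(TestCase):
    # Skip unless at least one supported GPU backend is present.
    @unittest.skipIf(not TEST_CUDA and not TEST_XPU, "Neither CUDA nor XPU is available")
    def test_runs_on_any_accelerator(self):
        dev = torch.device(torch.accelerator.current_device_index())
        # torch.get_device_module(dev) is torch.cuda on NVIDIA/ROCm and
        # torch.xpu on Intel GPU, so the test body stays backend-neutral.
        mod = torch.get_device_module(dev)
        self.assertTrue(mod.is_available())


if __name__ == "__main__":
    run_tests()
```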
View File

@@ -5,19 +5,18 @@ import unittest
 import torch
 import torch.nn as nn
 from torch.distributed._tools import MemoryTracker
-from torch.testing._internal.common_cuda import TEST_CUDA
-from torch.testing._internal.common_utils import run_tests, TEST_XPU, TestCase
+from torch.testing._internal.common_utils import run_tests, TestCase
 class TestMemoryTracker(TestCase):
-    @unittest.skipIf(not TEST_CUDA and not TEST_XPU, "no cuda/xpu")
+    @unittest.skipIf(not torch.accelerator.is_available(), "no accelerator")
     def test_local_model(self):
         """
         Minimal test case to check the memory tracker can collect the expected
         memory stats at operator level, as well as can print the summary result
         without crash.
         """
-        device = "cuda" if TEST_CUDA else "xpu"
+        device = torch.accelerator.current_accelerator()
         # Create a model with a hierarchy of modules
         torch.manual_seed(0)
         model = nn.Sequential(
@@ -35,9 +34,9 @@ class TestMemoryTracker(TestCase):
         tracker = MemoryTracker()
         tracker.start_monitor(model)
-        x = torch.randn(size=(2, 3, 224, 224), device=torch.device(device))
-        # torch.LongTensor expects cpu device type, not device type in
-        # constructor, so calling .to(device) outside constructor here.
+        x = torch.randn(size=(2, 3, 224, 224), device=device)
+        # torch.LongTensor expects cpu device type, not gpu device type in
+        # constructor, so calling .to() outside constructor here.
         target = torch.LongTensor([0, 1]).to(device)
         criterion = nn.CrossEntropyLoss()
         criterion(model(x), target).backward()
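
For reference, the hunk above relies on `torch.accelerator.current_accelerator()` returning a `torch.device` that can be passed straight to tensor factories and `.to()`. A brief illustrative sketch of that usage (not code from this PR):

```python
import torch

if torch.accelerator.is_available():
    device = torch.accelerator.current_accelerator()  # e.g. device(type="cuda") or device(type="xpu")
    x = torch.randn(2, 3, 224, 224, device=device)    # factories accept the device object directly
    # The legacy torch.LongTensor constructor builds on CPU, so move the tensor afterwards.
    target = torch.LongTensor([0, 1]).to(device)
```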