[1/N] Port 3 distributed/_tools test cases to Intel GPU (#159543)

As part of [#114850](https://github.com/pytorch/pytorch/issues/114850), we are porting distributed tests to Intel GPU.

We enable Intel GPU with the following methods, trying our best to keep the original code style (see the sketch after this list):

1. use `torch.accelerator.current_accelerator()` to determine the accelerator backend
2. enable XPU for some test paths
3. skip test cases that Intel GPU does not support
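
A minimal sketch of the resulting device-agnostic pattern (every call below is taken from the diffs in this commit and assumes an accelerator is present):

```python
import torch

# Resolve the accelerator backend: device(type="cuda") on NVIDIA GPUs,
# device(type="xpu") on Intel GPUs.
acc = torch.accelerator.current_accelerator()
dev = torch.device(torch.accelerator.current_device_index())

# torch.get_device_module(dev) returns torch.cuda or torch.xpu, so the same
# memory-stats calls work on either backend.
mod = torch.get_device_module(dev)
mod.empty_cache()
mod.reset_peak_memory_stats(dev)
peak = mod.memory_stats(dev)["active_bytes.all.peak"]

# Derive world sizes and mesh device types instead of hard-coding "cuda".
world_size = min(4, torch.accelerator.device_count())
print(acc.type, world_size, peak)
```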

Pull Request resolved: https://github.com/pytorch/pytorch/pull/159543
Approved by: https://github.com/guangyey, https://github.com/d4l3k

Co-authored-by: Yu, Guangye <106960996+guangyey@users.noreply.github.com>
libohao 2025-08-13 12:48:57 +00:00 committed by PyTorch MergeBot
parent 42e51cd4b3
commit ee1b0412b9
3 changed files with 73 additions and 56 deletions

test/distributed/_tools/test_fsdp2_mem_tracker.py

@@ -37,15 +37,16 @@ def _init_cublas_workspace(dev: torch.device):
 def _reset_mem_stats(dev: torch.device):
-    torch.cuda.empty_cache()
-    torch.cuda.reset_accumulated_memory_stats(dev)
-    torch.cuda.reset_peak_memory_stats(dev)
+    mod = torch.get_device_module(dev)
+    mod.empty_cache()
+    mod.reset_accumulated_memory_stats(dev)
+    mod.reset_peak_memory_stats(dev)


 class TestTrackerFullyShard1DTrainingCore(FSDPTest):
     @property
     def world_size(self) -> int:
-        return min(4, torch.cuda.device_count())
+        return min(4, torch.accelerator.device_count())

     @skip_if_lt_x_gpu(2)
     def test_tracker_multi_group_eager(self):
@@ -77,17 +78,18 @@ class TestTrackerFullyShard1DTrainingCore(FSDPTest):
         mp_policy: MixedPrecisionPolicy,
     ):
         debug = False
-        dev = torch.device(torch.cuda.current_device())
+        dev = torch.device(torch.accelerator.current_device_index())
         _init_cublas_workspace(dev)
         gc.collect()
         _reset_mem_stats(dev)
-        mem_stats = torch.cuda.memory_stats(dev)
-        pre_cuda_active = mem_stats["active_bytes.all.current"]
+        mod = torch.get_device_module(dev)
+        mem_stats = mod.memory_stats(dev)
+        pre_acc_active = mem_stats["active_bytes.all.current"]
         torch.manual_seed(42)
         lin_dim, bsz = 2048, 8192
         with torch.device(dev):
             model = nn.Sequential(*[MLP(dim=lin_dim, device=dev) for _ in range(4)])
-        mesh = init_device_mesh("cuda", (self.world_size,))
+        mesh = init_device_mesh(dev.type, (self.world_size,))
         fully_shard_fn = functools.partial(
             fully_shard,
             mesh=mesh,
@@ -110,17 +112,19 @@ class TestTrackerFullyShard1DTrainingCore(FSDPTest):
             optim.zero_grad()
             if iter_idx == 0:
                 fmt.reset_mod_stats()
-        mem_stats = torch.cuda.memory_stats()
+        mem_stats = mod.memory_stats()
         tracker_max = fmt.get_tracker_snapshot("peak")[dev]["Total"]
-        cuda_max = mem_stats["active_bytes.all.peak"] - pre_cuda_active
-        accuracy = tracker_max / cuda_max
+        acc_max = mem_stats["active_bytes.all.peak"] - pre_acc_active
+        accuracy = tracker_max / acc_max
         if self.rank == 0 and debug:
-            print(f"Accuracy: {accuracy} Tracker Max:{tracker_max} CUDA Max:{cuda_max}")
+            print(
+                f"Accuracy: {accuracy} Tracker Max:{tracker_max} Accelerator Max:{acc_max}"
+            )
         self.assertAlmostEqual(
             accuracy,
             1.0,
             delta=0.1,
-            msg=f"Tracker Max:{tracker_max} CUDA Max:{cuda_max}",
+            msg=f"Tracker Max:{tracker_max} Accelerator Max:{acc_max}",
         )
         del model
         del inp
@@ -132,12 +136,13 @@ class TestTrackerFullyShard1DTrainingCore(FSDPTest):
         Tests tracker accuracy when running forward/backward through a non-root.
         """
         debug = False
-        dev = torch.device(torch.cuda.current_device())
+        dev = torch.device(torch.accelerator.current_device_index())
         _init_cublas_workspace(dev)
         gc.collect()
         _reset_mem_stats(dev)
-        mem_stats = torch.cuda.memory_stats(dev)
-        pre_cuda_active = mem_stats["active_bytes.all.current"]
+        mod = torch.get_device_module(dev)
+        mem_stats = mod.memory_stats(dev)
+        pre_acc_active = mem_stats["active_bytes.all.current"]
         torch.manual_seed(42)
         lin_dim, bsz = 2048, 8
         model = nn.Sequential(*[MLP(lin_dim, dev) for _ in range(3)])
@@ -157,17 +162,19 @@ class TestTrackerFullyShard1DTrainingCore(FSDPTest):
             optim.zero_grad()
             if iter_idx == 0:
                 fmt.reset_mod_stats()
-        mem_stats = torch.cuda.memory_stats()
+        mem_stats = mod.memory_stats()
         tracker_max = fmt.get_tracker_snapshot("peak")[dev]["Total"]
-        cuda_max = mem_stats["active_bytes.all.peak"] - pre_cuda_active
-        accuracy = tracker_max / cuda_max
+        acc_max = mem_stats["active_bytes.all.peak"] - pre_acc_active
+        accuracy = tracker_max / acc_max
         if self.rank == 0 and debug:
-            print(f"Accuracy: {accuracy} Tracker Max:{tracker_max} CUDA Max:{cuda_max}")
+            print(
+                f"Accuracy: {accuracy} Tracker Max:{tracker_max} Accelerator Max:{acc_max}"
+            )
         self.assertAlmostEqual(
             accuracy,
             1.0,
             delta=0.1,
-            msg=f"Tracker Max:{tracker_max} CUDA Max:{cuda_max}",
+            msg=f"Tracker Max:{tracker_max} Accelerator Max:{acc_max}",
         )
         del inp
         del model
@@ -177,7 +184,7 @@ class TestTrackerFullyShard1DTrainingCore(FSDPTest):
 class TestTrackerFullyShard1DTrainingCompose(FSDPTest):
     @property
     def world_size(self) -> int:
-        return min(torch.cuda.device_count(), 4)
+        return min(torch.accelerator.device_count(), 4)

     @skip_if_lt_x_gpu(2)
     def test_tracker_with_activation_checkpointing(self):
@@ -197,12 +204,13 @@ class TestTrackerFullyShard1DTrainingCompose(FSDPTest):
     ):
         assert checkpoint_impl in ("composable", "wrapper")
         debug = False
-        dev = torch.device(torch.cuda.current_device())
+        dev = torch.device(torch.accelerator.current_device_index())
         _init_cublas_workspace(dev)
         gc.collect()
         _reset_mem_stats(dev)
-        mem_stats = torch.cuda.memory_stats(dev)
-        pre_cuda_active = mem_stats["active_bytes.all.current"]
+        mod = torch.get_device_module(dev)
+        mem_stats = mod.memory_stats(dev)
+        pre_acc_active = mem_stats["active_bytes.all.current"]
         torch.manual_seed(42)
         vocab_size = 8192
         bsz, seq_len = 16, 512
@@ -249,17 +257,19 @@ class TestTrackerFullyShard1DTrainingCompose(FSDPTest):
             optim.zero_grad()
             if iter_idx == 0:
                 fmt.reset_mod_stats()
-        mem_stats = torch.cuda.memory_stats()
+        mem_stats = mod.memory_stats()
         tracker_max = fmt.get_tracker_snapshot("peak")[dev]["Total"]
-        cuda_max = mem_stats["active_bytes.all.peak"] - pre_cuda_active
-        accuracy = tracker_max / cuda_max
+        acc_max = mem_stats["active_bytes.all.peak"] - pre_acc_active
+        accuracy = tracker_max / acc_max
         if self.rank == 0 and debug:
-            print(f"Accuracy: {accuracy} Tracker Max:{tracker_max} CUDA Max:{cuda_max}")
+            print(
+                f"Accuracy: {accuracy} Tracker Max:{tracker_max} Accelerator Max:{acc_max}"
+            )
         self.assertAlmostEqual(
             accuracy,
             1.0,
             delta=0.1,
-            msg=f"Tracker Max:{tracker_max} CUDA Max:{cuda_max}",
+            msg=f"Tracker Max:{tracker_max} Accelerator Max:{acc_max}",
         )
         del inp
         del model

test/distributed/_tools/test_mem_tracker.py

@@ -5,11 +5,12 @@ import unittest
 import torch
 import torch.nn as nn
 from torch.distributed._tools.mem_tracker import MemTracker
-from torch.testing._internal.common_cuda import TEST_CUDA
 from torch.testing._internal.common_utils import (
     run_tests,
     skipIfRocm,
     skipIfTorchDynamo,
+    TEST_CUDA,
+    TEST_XPU,
     TestCase,
 )
 from torch.utils.checkpoint import checkpoint
@@ -24,25 +25,29 @@ class TestMemTracker(TestCase):
         del inp

     def _reset_mem_stats(self, dev: torch.device):
-        torch.cuda.empty_cache()
-        torch.cuda.reset_accumulated_memory_stats(dev)
-        torch.cuda.reset_peak_memory_stats(dev)
+        mod = torch.get_device_module(dev)
+        mod.empty_cache()
+        mod.reset_accumulated_memory_stats(dev)
+        mod.reset_peak_memory_stats(dev)

     @skipIfTorchDynamo("https://github.com/pytorch/pytorch/issues/115653")
-    @unittest.skipIf(not TEST_CUDA, "CUDA not available")
+    @unittest.skipIf(
+        not TEST_CUDA and not TEST_XPU, "Neither CUDA nor XPU is available"
+    )
     @skipIfRocm()
-    def test_cuda_tracker_equivalence(
+    def test_accelerator_tracker_equivalence(
         self,
     ):
         """
         Tests that the tracker correctly calculates the peak memory.
         """
-        dev = torch.device(torch.cuda.current_device())
+        dev = torch.device(torch.accelerator.current_device_index())
         self._init_cublas_workspace(dev)
         gc.collect(1)
         self._reset_mem_stats(dev)
-        mem_stats = torch.cuda.memory_stats(dev)
-        pre_cuda_active = mem_stats["active_bytes.all.current"]
+        mod = torch.get_device_module(dev)
+        mem_stats = mod.memory_stats(dev)
+        pre_acc_active = mem_stats["active_bytes.all.current"]
         bsz, n_layers, dim, dtype = 16, 4, 512, torch.bfloat16

         class DummyModel(nn.Module):
@@ -74,25 +79,28 @@ class TestMemTracker(TestCase):
         # Check for accuracy of peak memory
         tracker_max = mt.get_tracker_snapshot("peak")[dev]["Total"]
-        mem_stats = torch.cuda.memory_stats(dev)
-        cuda_max = mem_stats["active_bytes.all.peak"] - pre_cuda_active
-        accuracy = tracker_max / cuda_max
+        mem_stats = mod.memory_stats(dev)
+        acc_max = mem_stats["active_bytes.all.peak"] - pre_acc_active
+        accuracy = tracker_max / acc_max
         self.assertAlmostEqual(accuracy, 1.0, delta=0.1)

     @skipIfTorchDynamo("https://github.com/pytorch/pytorch/issues/115653")
-    @unittest.skipIf(not TEST_CUDA, "CUDA not available")
+    @unittest.skipIf(
+        not TEST_CUDA and not TEST_XPU, "Neither CUDA nor XPU is available"
+    )
     def test_tracker_with_activation_checkpointing(
         self,
     ):
         """
         Tests that the tracker correctly computes the peak memory during activation checkpointing.
         """
-        dev = torch.device(torch.cuda.current_device())
+        dev = torch.device(torch.accelerator.current_device_index())
         self._init_cublas_workspace(dev)
         gc.collect(1)
         self._reset_mem_stats(dev)
-        mem_stats = torch.cuda.memory_stats(dev)
-        pre_cuda_active = mem_stats["active_bytes.all.current"]
+        mod = torch.get_device_module(dev)
+        mem_stats = mod.memory_stats(dev)
+        pre_acc_active = mem_stats["active_bytes.all.current"]
         bsz, n_layers, dim, dtype = 128, 4, 1024, torch.float16
@@ -144,9 +152,9 @@
         # Check for accuracy of peak memory
         tracker_max = mt.get_tracker_snapshot("peak")[dev]["Total"]
-        mem_stats = torch.cuda.memory_stats(dev)
-        cuda_max = mem_stats["active_bytes.all.peak"] - pre_cuda_active
-        accuracy = tracker_max / cuda_max
+        mem_stats = mod.memory_stats(dev)
+        acc_max = mem_stats["active_bytes.all.peak"] - pre_acc_active
+        accuracy = tracker_max / acc_max
         self.assertAlmostEqual(accuracy, 1.0, delta=0.1)

     @skipIfTorchDynamo("https://github.com/pytorch/pytorch/issues/115653")

test/distributed/_tools/test_memory_tracker.py

@@ -5,19 +5,18 @@ import unittest
 import torch
 import torch.nn as nn
 from torch.distributed._tools import MemoryTracker
-from torch.testing._internal.common_cuda import TEST_CUDA
-from torch.testing._internal.common_utils import run_tests, TEST_XPU, TestCase
+from torch.testing._internal.common_utils import run_tests, TestCase


 class TestMemoryTracker(TestCase):
-    @unittest.skipIf(not TEST_CUDA and not TEST_XPU, "no cuda/xpu")
+    @unittest.skipIf(not torch.accelerator.is_available(), "no accelerator")
     def test_local_model(self):
         """
         Minimal test case to check the memory tracker can collect the expected
         memory stats at operator level, as well as can print the summary result
         without crash.
         """
-        device = "cuda" if TEST_CUDA else "xpu"
+        device = torch.accelerator.current_accelerator()
         # Create a model with a hierarchy of modules
         torch.manual_seed(0)
         model = nn.Sequential(
@@ -35,9 +34,9 @@ class TestMemoryTracker(TestCase):
         tracker = MemoryTracker()
         tracker.start_monitor(model)

-        x = torch.randn(size=(2, 3, 224, 224), device=torch.device(device))
-        # torch.LongTensor expects cpu device type, not device type in
-        # constructor, so calling .to(device) outside constructor here.
+        x = torch.randn(size=(2, 3, 224, 224), device=device)
+        # torch.LongTensor expects cpu device type, not gpu device type in
+        # constructor, so calling .to() outside constructor here.
         target = torch.LongTensor([0, 1]).to(device)
         criterion = nn.CrossEntropyLoss()
         criterion(model(x), target).backward()
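
An aside on the comment updated in the last hunk: `torch.LongTensor` is a legacy constructor that always allocates on CPU, hence the separate `.to(device)` call. A minimal sketch of the two options (assuming an accelerator is available):

```python
import torch

device = torch.accelerator.current_accelerator()  # e.g. device(type="xpu")

# Legacy constructor: builds on CPU, then moves to the accelerator.
target = torch.LongTensor([0, 1]).to(device)

# Modern equivalent: construct directly on the accelerator.
target = torch.tensor([0, 1], dtype=torch.long, device=device)
```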