[1/N]Port 3 distributed/_tools test cases to Intel GPU (#159543)
As part of [#114850](https://github.com/pytorch/pytorch/issues/114850), we are porting the distributed tests to Intel GPU. This PR enables Intel GPU with the following methods, keeping the original code style as much as possible (the general pattern is sketched below):

1. Use `torch.accelerator.current_accelerator()` to determine the accelerator backend.
2. Enable XPU for some test paths.
3. Skip test cases that Intel GPU does not support.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/159543
Approved by: https://github.com/guangyey, https://github.com/d4l3k
Co-authored-by: Yu, Guangye <106960996+guangyey@users.noreply.github.com>
Parent: 42e51cd4b3 · Commit: ee1b0412b9
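The diff below applies one device-agnostic pattern throughout the three test files: resolve the active accelerator via `torch.accelerator`, then route memory-stats calls through `torch.get_device_module()` so the same test body runs on both CUDA and XPU. A minimal sketch of that pattern follows; the helper name `reset_and_read_active_bytes` and the standalone usage at the bottom are illustrative only and not part of the PR.

```python
import torch


def reset_and_read_active_bytes(dev: torch.device) -> int:
    # Hypothetical helper (not in the PR): shows the backend-neutral calls the
    # ported tests use in place of direct torch.cuda.* calls.
    mod = torch.get_device_module(dev)  # torch.cuda or torch.xpu, depending on dev
    mod.empty_cache()
    mod.reset_accumulated_memory_stats(dev)
    mod.reset_peak_memory_stats(dev)
    return mod.memory_stats(dev)["active_bytes.all.current"]


if torch.accelerator.is_available():
    # Backend-neutral replacements for torch.cuda.current_device()/device_count().
    dev = torch.device(torch.accelerator.current_device_index())
    world_size = min(4, torch.accelerator.device_count())
    active = reset_and_read_active_bytes(dev)
    print(f"device={dev} world_size={world_size} active_bytes={active}")
```

On a CUDA machine `torch.get_device_module()` resolves to `torch.cuda`, so behavior there is unchanged; on an Intel GPU the same lines dispatch to `torch.xpu`.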
@@ -37,15 +37,16 @@ def _init_cublas_workspace(dev: torch.device):
 
 
 def _reset_mem_stats(dev: torch.device):
-    torch.cuda.empty_cache()
-    torch.cuda.reset_accumulated_memory_stats(dev)
-    torch.cuda.reset_peak_memory_stats(dev)
+    mod = torch.get_device_module(dev)
+    mod.empty_cache()
+    mod.reset_accumulated_memory_stats(dev)
+    mod.reset_peak_memory_stats(dev)
 
 
 class TestTrackerFullyShard1DTrainingCore(FSDPTest):
     @property
     def world_size(self) -> int:
-        return min(4, torch.cuda.device_count())
+        return min(4, torch.accelerator.device_count())
 
     @skip_if_lt_x_gpu(2)
     def test_tracker_multi_group_eager(self):
@@ -77,17 +78,18 @@ class TestTrackerFullyShard1DTrainingCore(FSDPTest):
         mp_policy: MixedPrecisionPolicy,
     ):
         debug = False
-        dev = torch.device(torch.cuda.current_device())
+        dev = torch.device(torch.accelerator.current_device_index())
         _init_cublas_workspace(dev)
         gc.collect()
         _reset_mem_stats(dev)
-        mem_stats = torch.cuda.memory_stats(dev)
-        pre_cuda_active = mem_stats["active_bytes.all.current"]
+        mod = torch.get_device_module(dev)
+        mem_stats = mod.memory_stats(dev)
+        pre_acc_active = mem_stats["active_bytes.all.current"]
         torch.manual_seed(42)
         lin_dim, bsz = 2048, 8192
         with torch.device(dev):
             model = nn.Sequential(*[MLP(dim=lin_dim, device=dev) for _ in range(4)])
-        mesh = init_device_mesh("cuda", (self.world_size,))
+        mesh = init_device_mesh(dev.type, (self.world_size,))
         fully_shard_fn = functools.partial(
             fully_shard,
             mesh=mesh,
@@ -110,17 +112,19 @@ class TestTrackerFullyShard1DTrainingCore(FSDPTest):
                 optim.zero_grad()
                 if iter_idx == 0:
                     fmt.reset_mod_stats()
-        mem_stats = torch.cuda.memory_stats()
+        mem_stats = mod.memory_stats()
         tracker_max = fmt.get_tracker_snapshot("peak")[dev]["Total"]
-        cuda_max = mem_stats["active_bytes.all.peak"] - pre_cuda_active
-        accuracy = tracker_max / cuda_max
+        acc_max = mem_stats["active_bytes.all.peak"] - pre_acc_active
+        accuracy = tracker_max / acc_max
         if self.rank == 0 and debug:
-            print(f"Accuracy: {accuracy} Tracker Max:{tracker_max} CUDA Max:{cuda_max}")
+            print(
+                f"Accuracy: {accuracy} Tracker Max:{tracker_max} Accelerator Max:{acc_max}"
+            )
         self.assertAlmostEqual(
             accuracy,
             1.0,
             delta=0.1,
-            msg=f"Tracker Max:{tracker_max} CUDA Max:{cuda_max}",
+            msg=f"Tracker Max:{tracker_max} Accelerator Max:{acc_max}",
         )
         del model
         del inp
@@ -132,12 +136,13 @@ class TestTrackerFullyShard1DTrainingCore(FSDPTest):
         Tests tracker accuracy when running forward/backward through a non-root.
         """
         debug = False
-        dev = torch.device(torch.cuda.current_device())
+        dev = torch.device(torch.accelerator.current_device_index())
         _init_cublas_workspace(dev)
         gc.collect()
         _reset_mem_stats(dev)
-        mem_stats = torch.cuda.memory_stats(dev)
-        pre_cuda_active = mem_stats["active_bytes.all.current"]
+        mod = torch.get_device_module(dev)
+        mem_stats = mod.memory_stats(dev)
+        pre_acc_active = mem_stats["active_bytes.all.current"]
         torch.manual_seed(42)
         lin_dim, bsz = 2048, 8
         model = nn.Sequential(*[MLP(lin_dim, dev) for _ in range(3)])
@@ -157,17 +162,19 @@ class TestTrackerFullyShard1DTrainingCore(FSDPTest):
                 optim.zero_grad()
                 if iter_idx == 0:
                     fmt.reset_mod_stats()
-        mem_stats = torch.cuda.memory_stats()
+        mem_stats = mod.memory_stats()
         tracker_max = fmt.get_tracker_snapshot("peak")[dev]["Total"]
-        cuda_max = mem_stats["active_bytes.all.peak"] - pre_cuda_active
-        accuracy = tracker_max / cuda_max
+        acc_max = mem_stats["active_bytes.all.peak"] - pre_acc_active
+        accuracy = tracker_max / acc_max
         if self.rank == 0 and debug:
-            print(f"Accuracy: {accuracy} Tracker Max:{tracker_max} CUDA Max:{cuda_max}")
+            print(
+                f"Accuracy: {accuracy} Tracker Max:{tracker_max} Accelerator Max:{acc_max}"
+            )
         self.assertAlmostEqual(
             accuracy,
             1.0,
             delta=0.1,
-            msg=f"Tracker Max:{tracker_max} CUDA Max:{cuda_max}",
+            msg=f"Tracker Max:{tracker_max} Accelerator Max:{acc_max}",
         )
         del inp
         del model
@@ -177,7 +184,7 @@ class TestTrackerFullyShard1DTrainingCore(FSDPTest):
 class TestTrackerFullyShard1DTrainingCompose(FSDPTest):
     @property
     def world_size(self) -> int:
-        return min(torch.cuda.device_count(), 4)
+        return min(torch.accelerator.device_count(), 4)
 
     @skip_if_lt_x_gpu(2)
     def test_tracker_with_activation_checkpointing(self):
@@ -197,12 +204,13 @@ class TestTrackerFullyShard1DTrainingCompose(FSDPTest):
     ):
         assert checkpoint_impl in ("composable", "wrapper")
         debug = False
-        dev = torch.device(torch.cuda.current_device())
+        dev = torch.device(torch.accelerator.current_device_index())
         _init_cublas_workspace(dev)
         gc.collect()
         _reset_mem_stats(dev)
-        mem_stats = torch.cuda.memory_stats(dev)
-        pre_cuda_active = mem_stats["active_bytes.all.current"]
+        mod = torch.get_device_module(dev)
+        mem_stats = mod.memory_stats(dev)
+        pre_acc_active = mem_stats["active_bytes.all.current"]
         torch.manual_seed(42)
         vocab_size = 8192
         bsz, seq_len = 16, 512
@@ -249,17 +257,19 @@ class TestTrackerFullyShard1DTrainingCompose(FSDPTest):
                 optim.zero_grad()
                 if iter_idx == 0:
                     fmt.reset_mod_stats()
-        mem_stats = torch.cuda.memory_stats()
+        mem_stats = mod.memory_stats()
         tracker_max = fmt.get_tracker_snapshot("peak")[dev]["Total"]
-        cuda_max = mem_stats["active_bytes.all.peak"] - pre_cuda_active
-        accuracy = tracker_max / cuda_max
+        acc_max = mem_stats["active_bytes.all.peak"] - pre_acc_active
+        accuracy = tracker_max / acc_max
         if self.rank == 0 and debug:
-            print(f"Accuracy: {accuracy} Tracker Max:{tracker_max} CUDA Max:{cuda_max}")
+            print(
+                f"Accuracy: {accuracy} Tracker Max:{tracker_max} Accelerator Max:{acc_max}"
+            )
         self.assertAlmostEqual(
             accuracy,
             1.0,
             delta=0.1,
-            msg=f"Tracker Max:{tracker_max} CUDA Max:{cuda_max}",
+            msg=f"Tracker Max:{tracker_max} Accelerator Max:{acc_max}",
         )
         del inp
         del model
@@ -5,11 +5,12 @@ import unittest
 import torch
 import torch.nn as nn
 from torch.distributed._tools.mem_tracker import MemTracker
-from torch.testing._internal.common_cuda import TEST_CUDA
 from torch.testing._internal.common_utils import (
     run_tests,
     skipIfRocm,
     skipIfTorchDynamo,
+    TEST_CUDA,
+    TEST_XPU,
     TestCase,
 )
 from torch.utils.checkpoint import checkpoint
@@ -24,25 +25,29 @@ class TestMemTracker(TestCase):
         del inp
 
     def _reset_mem_stats(self, dev: torch.device):
-        torch.cuda.empty_cache()
-        torch.cuda.reset_accumulated_memory_stats(dev)
-        torch.cuda.reset_peak_memory_stats(dev)
+        mod = torch.get_device_module(dev)
+        mod.empty_cache()
+        mod.reset_accumulated_memory_stats(dev)
+        mod.reset_peak_memory_stats(dev)
 
     @skipIfTorchDynamo("https://github.com/pytorch/pytorch/issues/115653")
-    @unittest.skipIf(not TEST_CUDA, "CUDA not available")
+    @unittest.skipIf(
+        not TEST_CUDA and not TEST_XPU, "Neither CUDA or XPU is not available"
+    )
     @skipIfRocm()
-    def test_cuda_tracker_equivalence(
+    def test_accelerator_tracker_equivalence(
         self,
     ):
         """
         Tests that the tracker correctly calculates the peak memory.
         """
-        dev = torch.device(torch.cuda.current_device())
+        dev = torch.device(torch.accelerator.current_device_index())
         self._init_cublas_workspace(dev)
         gc.collect(1)
         self._reset_mem_stats(dev)
-        mem_stats = torch.cuda.memory_stats(dev)
-        pre_cuda_active = mem_stats["active_bytes.all.current"]
+        mod = torch.get_device_module(dev)
+        mem_stats = mod.memory_stats(dev)
+        pre_acc_active = mem_stats["active_bytes.all.current"]
         bsz, n_layers, dim, dtype = 16, 4, 512, torch.bfloat16
 
         class DummyModel(nn.Module):
@@ -74,25 +79,28 @@ class TestMemTracker(TestCase):
         # Check for accuracy of peak memory
 
         tracker_max = mt.get_tracker_snapshot("peak")[dev]["Total"]
-        mem_stats = torch.cuda.memory_stats(dev)
-        cuda_max = mem_stats["active_bytes.all.peak"] - pre_cuda_active
-        accuracy = tracker_max / cuda_max
+        mem_stats = mod.memory_stats(dev)
+        acc_max = mem_stats["active_bytes.all.peak"] - pre_acc_active
+        accuracy = tracker_max / acc_max
         self.assertAlmostEqual(accuracy, 1.0, delta=0.1)
 
     @skipIfTorchDynamo("https://github.com/pytorch/pytorch/issues/115653")
-    @unittest.skipIf(not TEST_CUDA, "CUDA not available")
+    @unittest.skipIf(
+        not TEST_CUDA and not TEST_XPU, "Neither CUDA or XPU is not available"
+    )
     def test_tracker_with_activation_checkpointing(
         self,
     ):
         """
         Tests that the tracker correctly computes the peak memory during activation checkpointing.
         """
-        dev = torch.device(torch.cuda.current_device())
+        dev = torch.device(torch.accelerator.current_device_index())
         self._init_cublas_workspace(dev)
         gc.collect(1)
         self._reset_mem_stats(dev)
-        mem_stats = torch.cuda.memory_stats(dev)
-        pre_cuda_active = mem_stats["active_bytes.all.current"]
+        mod = torch.get_device_module(dev)
+        mem_stats = mod.memory_stats(dev)
+        pre_acc_active = mem_stats["active_bytes.all.current"]
 
         bsz, n_layers, dim, dtype = 128, 4, 1024, torch.float16
 
@@ -144,9 +152,9 @@ class TestMemTracker(TestCase):
 
         # Check for accuracy of peak memory
         tracker_max = mt.get_tracker_snapshot("peak")[dev]["Total"]
-        mem_stats = torch.cuda.memory_stats(dev)
-        cuda_max = mem_stats["active_bytes.all.peak"] - pre_cuda_active
-        accuracy = tracker_max / cuda_max
+        mem_stats = mod.memory_stats(dev)
+        acc_max = mem_stats["active_bytes.all.peak"] - pre_acc_active
+        accuracy = tracker_max / acc_max
         self.assertAlmostEqual(accuracy, 1.0, delta=0.1)
 
     @skipIfTorchDynamo("https://github.com/pytorch/pytorch/issues/115653")
@@ -5,19 +5,18 @@ import unittest
 import torch
 import torch.nn as nn
 from torch.distributed._tools import MemoryTracker
-from torch.testing._internal.common_cuda import TEST_CUDA
-from torch.testing._internal.common_utils import run_tests, TEST_XPU, TestCase
+from torch.testing._internal.common_utils import run_tests, TestCase
 
 
 class TestMemoryTracker(TestCase):
-    @unittest.skipIf(not TEST_CUDA and not TEST_XPU, "no cuda/xpu")
+    @unittest.skipIf(not torch.accelerator.is_available(), "no accelerator")
     def test_local_model(self):
         """
         Minimal test case to check the memory tracker can collect the expected
         memory stats at operator level, as well as can print the summary result
         without crash.
         """
-        device = "cuda" if TEST_CUDA else "xpu"
+        device = torch.accelerator.current_accelerator()
         # Create a model with a hierarchy of modules
         torch.manual_seed(0)
         model = nn.Sequential(
@@ -35,9 +34,9 @@ class TestMemoryTracker(TestCase):
         tracker = MemoryTracker()
         tracker.start_monitor(model)
 
-        x = torch.randn(size=(2, 3, 224, 224), device=torch.device(device))
-        # torch.LongTensor expects cpu device type, not device type in
-        # constructor, so calling .to(device) outside constructor here.
+        x = torch.randn(size=(2, 3, 224, 224), device=device)
+        # torch.LongTensor expects cpu device type, not gpu device type in
+        # constructor, so calling .to() outside constructor here.
         target = torch.LongTensor([0, 1]).to(device)
         criterion = nn.CrossEntropyLoss()
         criterion(model(x), target).backward()