mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Don't require FakeStore to be passed into fake backend (#162164)
Signed-off-by: Edward Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/162164
Approved by: https://github.com/bdhirsh, https://github.com/albanD, https://github.com/wconstab
This commit is contained in:
parent
1ebd70d0c0
commit
248355faf5
|
|
@ -40,16 +40,14 @@ class TestFakePG(TestCase):
|
|||
pass
|
||||
|
||||
def test_all_reduce(self):
|
||||
store = FakeStore()
|
||||
dist.init_process_group(backend="fake", rank=1, world_size=2, store=store)
|
||||
dist.init_process_group(backend="fake", rank=1, world_size=2)
|
||||
|
||||
output = torch.ones(3, 3) * dist.get_rank()
|
||||
dist.all_reduce(output)
|
||||
self.assertEqual(tuple(output.shape), (3, 3))
|
||||
|
||||
def test_allgather(self):
|
||||
store = FakeStore()
|
||||
dist.init_process_group(backend="fake", rank=1, world_size=2, store=store)
|
||||
dist.init_process_group(backend="fake", rank=1, world_size=2)
|
||||
|
||||
input_tensor = torch.ones(3, 3) * dist.get_rank()
|
||||
output_tensors = [torch.empty_like(input_tensor) for _ in range(2)]
|
||||
|
|
@ -106,8 +104,7 @@ class TestFakePG(TestCase):
|
|||
FileCheck().check("all_gather").check("wait_tensor").run(str(gm.graph))
|
||||
|
||||
def test_broadcast(self):
|
||||
store = FakeStore()
|
||||
dist.init_process_group(backend="fake", rank=0, world_size=2, store=store)
|
||||
dist.init_process_group(backend="fake", rank=0, world_size=2)
|
||||
|
||||
# src == rank
|
||||
output = torch.ones(3, 3)
|
||||
|
|
|
|||
|
|
@ -13,7 +13,6 @@ from functorch import make_fx
|
|||
from torch._inductor.utils import run_and_get_code
|
||||
from torch.testing import FileCheck
|
||||
from torch.testing._internal.common_device_type import instantiate_device_type_tests
|
||||
from torch.testing._internal.distributed.fake_pg import FakeStore
|
||||
from torch.testing._internal.inductor_utils import HAS_GPU
|
||||
|
||||
|
||||
|
|
@ -431,12 +430,10 @@ class TestMakeFx(TestCase):
|
|||
# so create a fake_pg.
|
||||
self.rank = 0
|
||||
self.world_size = 2
|
||||
store = FakeStore()
|
||||
dist.init_process_group(
|
||||
backend="fake",
|
||||
world_size=self.world_size,
|
||||
rank=self.rank,
|
||||
store=store,
|
||||
)
|
||||
|
||||
def tearDown(self):
|
||||
|
|
@ -598,7 +595,6 @@ class TestCollectivesWithDistributedBackend(DistributedTestBase):
|
|||
backend="fake",
|
||||
rank=0,
|
||||
world_size=8,
|
||||
store=FakeStore(),
|
||||
)
|
||||
allreduce(torch.randn(8, device=device), pg=dist.group.WORLD)
|
||||
dist.destroy_process_group()
|
||||
|
|
|
|||
|
|
@ -14,7 +14,6 @@ if dist.is_available():
|
|||
wait_tensor,
|
||||
)
|
||||
from torch.distributed.device_mesh import init_device_mesh
|
||||
from torch.testing._internal.distributed.fake_pg import FakeStore
|
||||
|
||||
|
||||
def normalize_graph(gm):
|
||||
|
|
@ -25,8 +24,7 @@ def normalize_graph(gm):
|
|||
class TestFakeDistributed(DynamoTestCase):
|
||||
def setUp(self):
|
||||
# Use FakeProcessGroup to run tests on a single process
|
||||
self.store = FakeStore()
|
||||
dist.init_process_group(backend="fake", rank=0, world_size=2, store=self.store)
|
||||
dist.init_process_group(backend="fake", rank=0, world_size=2)
|
||||
self.local_rank = 0
|
||||
self.world_size = 2
|
||||
|
||||
|
|
|
|||
|
|
@ -15529,13 +15529,10 @@ class GraphModule(torch.nn.Module):
|
|||
@contextmanager
|
||||
def distributed_env(self, world_size):
|
||||
try:
|
||||
from torch.testing._internal.distributed.fake_pg import FakeStore
|
||||
|
||||
torch.distributed.init_process_group(
|
||||
backend="fake",
|
||||
world_size=world_size,
|
||||
rank=0,
|
||||
store=FakeStore(),
|
||||
)
|
||||
yield
|
||||
|
||||
|
|
|
|||
|
|
@ -338,8 +338,6 @@ class TestDCE(TestCase):
|
|||
Test that DCE doesn't remove collective ops even if the results are not used.
|
||||
"""
|
||||
|
||||
from torch.testing._internal.distributed.fake_pg import FakeStore
|
||||
|
||||
class TestModule(torch.nn.Module):
|
||||
def forward(
|
||||
self, a: torch.Tensor, b: torch.Tensor, c: torch.Tensor
|
||||
|
|
@ -354,7 +352,6 @@ class TestDCE(TestCase):
|
|||
backend="fake",
|
||||
world_size=2,
|
||||
rank=0,
|
||||
store=FakeStore(),
|
||||
)
|
||||
# collective nodes should not be removed because they have side effects.
|
||||
self._run_dce_and_test(TestModule(), expect_dce_changes=False, custom=False)
|
||||
|
|
@ -366,8 +363,6 @@ class TestDCE(TestCase):
|
|||
Test that DCE doesn't remove collective ops (no overload version) even if the results are not used.
|
||||
"""
|
||||
|
||||
from torch.testing._internal.distributed.fake_pg import FakeStore
|
||||
|
||||
class TestModule(torch.nn.Module):
|
||||
def forward(
|
||||
self, a: torch.Tensor, b: torch.Tensor, c: torch.Tensor
|
||||
|
|
@ -382,7 +377,6 @@ class TestDCE(TestCase):
|
|||
backend="fake",
|
||||
world_size=2,
|
||||
rank=0,
|
||||
store=FakeStore(),
|
||||
)
|
||||
# collective nodes should not be removed because they have side effects.
|
||||
self._run_dce_and_test(TestModule(), expect_dce_changes=False, custom=False)
|
||||
|
|
|
|||
|
|
@ -1753,11 +1753,16 @@ def init_process_group(
|
|||
else:
|
||||
# backward compatible API
|
||||
if store is None:
|
||||
rendezvous_iterator = rendezvous(
|
||||
not_none(init_method), rank, world_size, timeout=timeout
|
||||
)
|
||||
store, rank, world_size = next(rendezvous_iterator)
|
||||
store.set_timeout(timeout)
|
||||
if backend == "fake":
|
||||
from torch.testing._internal.distributed.fake_pg import FakeStore
|
||||
|
||||
store = FakeStore()
|
||||
else:
|
||||
rendezvous_iterator = rendezvous(
|
||||
not_none(init_method), rank, world_size, timeout=timeout
|
||||
)
|
||||
store, rank, world_size = next(rendezvous_iterator)
|
||||
store.set_timeout(timeout)
|
||||
|
||||
# Use a PrefixStore to avoid accidental overrides of keys used by
|
||||
# different systems (e.g. RPC) in case the store is multi-tenant.
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user