Don't require FakeStore to be passed into fake backend (#162164)

Signed-off-by: Edward Yang <ezyang@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/162164
Approved by: https://github.com/bdhirsh, https://github.com/albanD, https://github.com/wconstab
This commit is contained in:
Edward Yang 2025-09-04 09:41:47 -04:00 committed by PyTorch MergeBot
parent 1ebd70d0c0
commit 248355faf5
6 changed files with 14 additions and 27 deletions

View File

@ -40,16 +40,14 @@ class TestFakePG(TestCase):
pass
def test_all_reduce(self):
store = FakeStore()
dist.init_process_group(backend="fake", rank=1, world_size=2, store=store)
dist.init_process_group(backend="fake", rank=1, world_size=2)
output = torch.ones(3, 3) * dist.get_rank()
dist.all_reduce(output)
self.assertEqual(tuple(output.shape), (3, 3))
def test_allgather(self):
store = FakeStore()
dist.init_process_group(backend="fake", rank=1, world_size=2, store=store)
dist.init_process_group(backend="fake", rank=1, world_size=2)
input_tensor = torch.ones(3, 3) * dist.get_rank()
output_tensors = [torch.empty_like(input_tensor) for _ in range(2)]
@ -106,8 +104,7 @@ class TestFakePG(TestCase):
FileCheck().check("all_gather").check("wait_tensor").run(str(gm.graph))
def test_broadcast(self):
store = FakeStore()
dist.init_process_group(backend="fake", rank=0, world_size=2, store=store)
dist.init_process_group(backend="fake", rank=0, world_size=2)
# src == rank
output = torch.ones(3, 3)

View File

@ -13,7 +13,6 @@ from functorch import make_fx
from torch._inductor.utils import run_and_get_code
from torch.testing import FileCheck
from torch.testing._internal.common_device_type import instantiate_device_type_tests
from torch.testing._internal.distributed.fake_pg import FakeStore
from torch.testing._internal.inductor_utils import HAS_GPU
@ -431,12 +430,10 @@ class TestMakeFx(TestCase):
# so create a fake_pg.
self.rank = 0
self.world_size = 2
store = FakeStore()
dist.init_process_group(
backend="fake",
world_size=self.world_size,
rank=self.rank,
store=store,
)
def tearDown(self):
@ -598,7 +595,6 @@ class TestCollectivesWithDistributedBackend(DistributedTestBase):
backend="fake",
rank=0,
world_size=8,
store=FakeStore(),
)
allreduce(torch.randn(8, device=device), pg=dist.group.WORLD)
dist.destroy_process_group()

View File

@ -14,7 +14,6 @@ if dist.is_available():
wait_tensor,
)
from torch.distributed.device_mesh import init_device_mesh
from torch.testing._internal.distributed.fake_pg import FakeStore
def normalize_graph(gm):
@ -25,8 +24,7 @@ def normalize_graph(gm):
class TestFakeDistributed(DynamoTestCase):
def setUp(self):
# Use FakeProcessGroup to run tests on a single process
self.store = FakeStore()
dist.init_process_group(backend="fake", rank=0, world_size=2, store=self.store)
dist.init_process_group(backend="fake", rank=0, world_size=2)
self.local_rank = 0
self.world_size = 2

View File

@ -15529,13 +15529,10 @@ class GraphModule(torch.nn.Module):
@contextmanager
def distributed_env(self, world_size):
try:
from torch.testing._internal.distributed.fake_pg import FakeStore
torch.distributed.init_process_group(
backend="fake",
world_size=world_size,
rank=0,
store=FakeStore(),
)
yield

View File

@ -338,8 +338,6 @@ class TestDCE(TestCase):
Test that DCE doesn't remove collective ops even if the results are not used.
"""
from torch.testing._internal.distributed.fake_pg import FakeStore
class TestModule(torch.nn.Module):
def forward(
self, a: torch.Tensor, b: torch.Tensor, c: torch.Tensor
@ -354,7 +352,6 @@ class TestDCE(TestCase):
backend="fake",
world_size=2,
rank=0,
store=FakeStore(),
)
# collective nodes should not be removed because they have side effects.
self._run_dce_and_test(TestModule(), expect_dce_changes=False, custom=False)
@ -366,8 +363,6 @@ class TestDCE(TestCase):
Test that DCE doesn't remove collective ops (no overload version) even if the results are not used.
"""
from torch.testing._internal.distributed.fake_pg import FakeStore
class TestModule(torch.nn.Module):
def forward(
self, a: torch.Tensor, b: torch.Tensor, c: torch.Tensor
@ -382,7 +377,6 @@ class TestDCE(TestCase):
backend="fake",
world_size=2,
rank=0,
store=FakeStore(),
)
# collective nodes should not be removed because they have side effects.
self._run_dce_and_test(TestModule(), expect_dce_changes=False, custom=False)

View File

@ -1753,11 +1753,16 @@ def init_process_group(
else:
# backward compatible API
if store is None:
rendezvous_iterator = rendezvous(
not_none(init_method), rank, world_size, timeout=timeout
)
store, rank, world_size = next(rendezvous_iterator)
store.set_timeout(timeout)
if backend == "fake":
from torch.testing._internal.distributed.fake_pg import FakeStore
store = FakeStore()
else:
rendezvous_iterator = rendezvous(
not_none(init_method), rank, world_size, timeout=timeout
)
store, rank, world_size = next(rendezvous_iterator)
store.set_timeout(timeout)
# Use a PrefixStore to avoid accidental overrides of keys used by
# different systems (e.g. RPC) in case the store is multi-tenant.