Sandeep Narendranath Karjala
76ca23c41c
[dynamo] Add FakeProcessGroup support for fx_graph_runnable with distributed collectives (#157162)
Stack from [ghstack](https://github.com/ezyang/ghstack) (oldest at bottom):
Summary:
- Modified `generate_compiler_repro_string()` to automatically detect distributed operations and inject FakeProcessGroup setup code
- Added distributed collective tests in `test/dynamo/test_fx_graph_runnable.py` that use the FakeProcessGroup API to exercise distributed collective operations (a minimal sketch of the scenario appears after the generated output below)
- Generated `fx_graph_runnable` code now runs standalone when it contains distributed operations
```
import os
os.environ['TORCHINDUCTOR_CACHE_DIR'] = '/var/folders/fd/kcv8m1kn0lqgxz42wvgr46sc0000gn/T/torchinductor_skarjala'

import torch
from torch import tensor, device
import torch.fx as fx
from torch._dynamo.testing import rand_strided
from math import inf
import torch._inductor.inductor_prims
import torch.distributed as dist
from torch.testing._internal.distributed.fake_pg import FakeStore

import torch._dynamo.config
import torch._inductor.config
import torch._functorch.config
import torch.fx.experimental._config

torch._functorch.config.functionalize_rng_ops = False
torch._functorch.config.fake_tensor_allow_unsafe_data_ptr_access = True
torch._functorch.config.unlift_effect_tokens = True

isolate_fails_code_str = None

# torch version: 2.9.0a0+gitf23d314
# torch cuda version: None
# torch git version: f23d31463ca452918e23063409a2bdc55efc0d46

# torch.cuda.is_available()==False, no GPU info collected

from torch.nn import *

class Repro(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()

    def forward(self, arg0_1):
        all_reduce = torch.ops._c10d_functional.all_reduce.default(arg0_1, 'sum', '0')
        wait_tensor = torch.ops._c10d_functional.wait_tensor.default(all_reduce); all_reduce = None
        mul = torch.ops.aten.mul.Tensor(wait_tensor, 2)
        copy_ = torch.ops.aten.copy_.default(arg0_1, wait_tensor); arg0_1 = wait_tensor = copy_ = None
        return (mul,)

def load_args(reader):
    buf0 = reader.storage(None, 64)
    reader.tensor(buf0, (4, 4), is_leaf=True)  # arg0_1
load_args._version = 0

mod = Repro()

if __name__ == '__main__':
    from torch._dynamo.repro.after_aot import run_repro

    # Initialize FakeProcessGroup for distributed operations
    store = FakeStore()
    dist.init_process_group(
        backend="fake",
        rank=0,
        world_size=2,
        store=store
    )

    with torch.no_grad():
        run_repro(mod, load_args, accuracy=False, command='run', save_dir=None, tracing_mode='real', check_str=None)
        # To run it separately, do
        # mod, args = run_repro(mod, load_args, accuracy=False, command='get_args', save_dir=None, tracing_mode='real', check_str=None)
        # mod(*args)

    dist.destroy_process_group()
```
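For context on the new tests mentioned above, here is a minimal sketch (not the PR's actual test code) of the scenario they exercise: a function containing a collective is compiled under a FakeProcessGroup, so the traced graph contains `_c10d_functional` ops like the ones in the generated repro, while the fake backend avoids real ranks or transport. The function name `all_reduce_then_scale` is illustrative only.

```
import torch
import torch.distributed as dist
from torch.testing._internal.distributed.fake_pg import FakeStore

def all_reduce_then_scale(x):
    # dynamo traces this collective into functional-collective ops
    # (_c10d_functional.all_reduce / wait_tensor), as seen in the repro above
    dist.all_reduce(x)
    return x * 2

if __name__ == "__main__":
    # FakeProcessGroup: no real ranks or transport, so the collective is
    # effectively a no-op while the graph shape stays intact
    store = FakeStore()
    dist.init_process_group(backend="fake", rank=0, world_size=2, store=store)
    try:
        compiled = torch.compile(all_reduce_then_scale)
        out = compiled(torch.ones(4, 4))
        print(out)
    finally:
        dist.destroy_process_group()
```

With this change, the fx_graph_runnable artifact produced for a compiled function like the one above carries its own FakeProcessGroup setup, so it can be executed standalone as shown in the generated output.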
Pull Request resolved: https://github.com/pytorch/pytorch/pull/157162
Approved by: https://github.com/xmfan
2025-07-10 20:30:27 +00:00