This PR gets rid of the dim_groups attribute from DeviceMesh. The main
motivation is that c10d should store the process groups when they are
created, rather than DeviceMesh holding them; DeviceMesh should just
handle ranks correctly.
This could enable DTensor to become picklable (torch.save/load could
become possible), which I will try in the next PR.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/103105
Approved by: https://github.com/XilunWu, https://github.com/fduwjj
Also not sure whether this should be a public function. Leaving it private for now, but let me know if you'd prefer it to be public.
FYI @nikitaved this will logically conflict with your triton kernel PR.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/101420
Approved by: https://github.com/malfet
Summary:
Currently there are build configs where the torchdynamo import trips over a
strange SystemError, related to some module's __dict__.items() returning NULL,
while torchdynamo tries to iterate over all torch modules and process them for
its allowed functions list.
While this is hard to repro, we should be able to work around it and then fix
it properly.
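As an illustration of the kind of workaround described, here is a minimal sketch (not the actual patch) that defensively iterates a module's __dict__ so a broken module is skipped instead of aborting the scan; the helper name is hypothetical.
```python
# Hypothetical sketch only: tolerate modules whose __dict__.items() raises,
# instead of letting a SystemError abort the whole allowed-functions scan.
import types
from typing import Any, Iterator, Tuple

def _iter_module_members(mod: types.ModuleType) -> Iterator[Tuple[str, Any]]:
    try:
        items = list(mod.__dict__.items())
    except SystemError:
        # Some build configs hit a SystemError here (__dict__.items() returns
        # NULL internally); skip the module rather than failing the import.
        return
    for name, value in items:
        yield name, value
```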
Test Plan: Rely on others to test this, assuming CI passes.
Reviewed By: anijain2305
Differential Revision: D45663313
Pull Request resolved: https://github.com/pytorch/pytorch/pull/100901
Approved by: https://github.com/yanboliang, https://github.com/malfet
We do it by making it possible to register multiple tensors for the same
work object and to coordinate waiting/cleanup among them.
This ensures that waiting on any number of the output tensors results in a
single stream sync, which simplifies codegen in Inductor.
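Conceptually, the bookkeeping is something like the toy below; the class and method names are illustrative, not the real _functional_collectives internals.
```python
# Toy sketch of the idea: several output tensors share one work handle, the
# first wait performs the single stream sync, and later waits on sibling
# tensors become no-ops (cleanup happens as entries are popped).
class _ToyWorkRegistry:
    def __init__(self) -> None:
        self._entry_by_tensor = {}            # id(tensor) -> shared entry

    def register(self, work, tensors) -> None:
        entry = {"work": work, "waited": False}
        for t in tensors:
            self._entry_by_tensor[id(t)] = entry

    def wait_tensor(self, tensor) -> None:
        entry = self._entry_by_tensor.pop(id(tensor), None)
        if entry is None:
            return                            # never registered or already cleaned up
        if not entry["waited"]:
            entry["work"].wait()              # the single stream sync
            entry["waited"] = True
```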
Pull Request resolved: https://github.com/pytorch/pytorch/pull/99763
Approved by: https://github.com/wanchaol
Summary:
This diff is reverting D45387167
D45387167: Basic dynamo support for traceable collectives (#94440) by wconstab has been identified as causing the following test or build failures (internal).
If you believe this diff has been generated in error you may Commandeer and Abandon it.
Test Plan: NA
Reviewed By: s4ayub
Differential Revision: D45448312
Pull Request resolved: https://github.com/pytorch/pytorch/pull/100424
Approved by: https://github.com/rohan-varma, https://github.com/kumpera
Make traceable collectives work with torchdynamo,
bypassing problems with tracing the AsyncTensor subclass.
Accept a suboptimal solution for now, and optimize it later.
For now, wait happens immediately, which generally forces an early sync.
Later, find a way either in dynamo or in the AOT stack to handle
AsyncCollectiveTensor so the wait lands in the optimal place.
Note on implementation:
- Dynamo traces 'user-level' functional collective APIs that are designed to behave
differently in eager vs compiled mode. In eager mode, there is work-obj registration
and a wrapper subclass inserts a 'wait' call at the appropriate time.
In compile/trace mode, wait is called immediately, and work-obj registration
is required to be handled by the compile backend at runtime (a toy version of
this split is sketched after these notes).
- Dynamo needs to trace into some of the helper functions in the 'user-level'
api, such as '_expand_group', which is essentially a constant transformation.
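As a rough, self-contained toy of the split described in the first note above (FakeWork and AsyncTensorStub are stand-ins, not PyTorch classes):
```python
# Toy illustration of the eager-vs-compiled behavior split. FakeWork and
# AsyncTensorStub are invented for this sketch; they are not PyTorch APIs.
import torch

class FakeWork:
    def wait(self) -> None:
        pass  # stand-in for the real work object's stream sync

class AsyncTensorStub:
    """Toy analogue of AsyncCollectiveTensor: defers wait() until first use."""
    def __init__(self, tensor: torch.Tensor, work: FakeWork) -> None:
        self._tensor, self._work, self._waited = tensor, work, False

    def materialize(self) -> torch.Tensor:
        if not self._waited:
            self._work.wait()
            self._waited = True
        return self._tensor

def user_level_collective(tensor: torch.Tensor, *, compiling: bool):
    work = FakeWork()                     # pretend the collective was issued
    if compiling:
        work.wait()                       # traced path: wait immediately; the
        return tensor                     # backend handles work registration at runtime
    return AsyncTensorStub(tensor, work)  # eager path: wrapper defers the wait
```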
Pull Request resolved: https://github.com/pytorch/pytorch/pull/94440
Approved by: https://github.com/kumpera
Summary:
Original commit changeset: ba36f8751adc
Original Phabricator Diff: D44788697
Test Plan: model loading is fine after reverting the diff
Reviewed By: zyan0, sayitmemory
Differential Revision: D44921259
Pull Request resolved: https://github.com/pytorch/pytorch/pull/99168
Approved by: https://github.com/izaitsevfb
Inductor codegen is suboptimal when calling all_reduce_coalesced with input args. We need to fix Inductor's calling convention for that, or find another approach.
It might not work if any of the outputs is unused.
Test code:
```python
import torch
import torch.distributed as dist
import torch.nn.functional as F
from functorch import make_fx
import os
import torch.distributed._functional_collectives as ft_c
from torch.testing._internal.common_distributed import (
    spawn_threads_and_init_comms,
)
from torch._inductor.compile_fx import compile_fx_inner


def my_fun(a, b):
    c = a * 3
    tensors = ft_c.all_reduce_coalesced([a, c, b], "sum", [0])
    return ((tensors[1] + tensors[0] + tensors[2]).sum(),)


@spawn_threads_and_init_comms(world_size=1)
def inductor_main(self):
    x = torch.arange(4).cuda() * (dist.get_rank() + 1)
    y = torch.arange(4).cuda() * (dist.get_rank() + 1)
    x = x.to(torch.float)
    y = y.to(torch.float) * 0.5

    res = make_fx(my_fun)(x, y)
    print(f"fx graph:\n{res.graph}")

    ind = compile_fx_inner(res, [x, y])
    print(f"inductor done:\n{ind}")


os.environ["PROXY_TENSOR_TRACING"] = "1"
os.environ["TORCH_COMPILE_DEBUG"] = "1"
torch._dynamo.config.output_code = True

if __name__ == "__main__":
    inductor_main(None)
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/97157
Approved by: https://github.com/fegin
Among the changes is the introduction of gather_dim and scatter_dim in DeviceMesh collectives to simplify user code.
The current plan is to keep padding and gather/scatter dim support in DeviceMesh while we explore optimization opportunities in Inductor.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/96226
Approved by: https://github.com/wanchaol
_functional_collectives.py: Ensure we always wait on all collectives.
derivatives.yaml: Mark all_reduce as non-differentiable.
gen_variable_type.py: Add all_reduce to DONT_ENFORCE_TENSOR_IMPL_USE_COUNT.
common_dtensor.py: Replace dist.barrier with all_reduce.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/95897
Approved by: https://github.com/wconstab, https://github.com/fegin
Inductor implementations of collectives/wait must match
the eager impls in _functional_collectives in terms of how they interact
with the _register_tensor_work API. If they do, then splitting
a collective-wait pair so that one half is in a compiled graph should
work fine.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/95893
Approved by: https://github.com/kumpera
BC: This changes the signature and semantics of DeviceMesh::all_reduce.
DeviceMesh::all_reduce now uses a functional collective under the hood which makes it more easily traceable.
You no longer need to use CommTensor to get a trace.
all_reduce is now async-only and uses AsyncCollectiveTensor to ensure proper stream synchronization.
Signature change: the `async_op` param was removed, and the return type changed from `Optional[Work]` to `torch.Tensor`.
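For illustration only, the call-site change might look roughly like the sketch below; apart from `async_op` and the return types stated above, the parameters and mesh setup are assumed, and a real run needs an initialized process group and a DeviceMesh.
```python
# Hypothetical before/after of the DeviceMesh.all_reduce change described above.
# Only the stated facts (async_op removed, Optional[Work] -> torch.Tensor) come
# from this PR; everything else is assumed for illustration.
import torch

def old_call_site(mesh, t: torch.Tensor) -> torch.Tensor:
    work = mesh.all_reduce(t, async_op=True)   # old: returned Optional[Work]
    if work is not None:
        work.wait()                            # caller synced explicitly
    return t

def new_call_site(mesh, t: torch.Tensor) -> torch.Tensor:
    # New: functional collective under the hood; the returned tensor is an
    # AsyncCollectiveTensor that syncs the stream when it is first used.
    return mesh.all_reduce(t)
```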
Pull Request resolved: https://github.com/pytorch/pytorch/pull/95009
Approved by: https://github.com/wanchaol