PyWork: preserve Python reference counting when used in functional collectives (#146376)

@fegin found an issue where torchft is not compatible with functional collectives.

Found in https://github.com/pytorch/torchtitan/pull/806

The root cause is a nasty ownership bug that makes PyProcessGroup/PyWork incompatible with functional collectives.

PyWork relies on a pybind trampoline to propagate requests to Python. Unfortunately, with pybind the Python object owns the C++ object rather than there being some form of shared ownership. As a result, the PyWork Python object gets garbage collected when the work is returned to C++ from the PyProcessGroup, while the C++ PyWork object lives on. When that PyWork object is later used, it deadlocks because the corresponding Python object no longer exists.
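
For concreteness, here is a minimal sketch of the failure mode (it mirrors the `test_wait_tensor` regression test added in this PR): a Python `Work` subclass is handed off to C++ via the work registry, the last Python reference is dropped, and the later `wait_tensor` call still needs the Python override to be alive.

```python
import gc

import torch
import torch.distributed as dist


class MyWork(dist.Work):
    # Python-side override that C++ must be able to call back into later.
    def wait(self, _=None):
        return True


tensor = torch.rand(2, 2)
work = MyWork()

# Ownership of the work object is handed to C++ (functional collectives
# keep it in a tensor -> work registry until wait_tensor is called).
torch._C._distributed_c10d._register_work(tensor, work)

# Drop the only Python reference. Before this fix, pybind's one-way
# ownership meant the Python object could be collected here even though
# C++ still held the PyWork, so the wait below would deadlock.
del work
gc.collect()

# With PyWorkHolder keeping the py::object alive, this completes and
# dispatches to MyWork.wait().
torch.ops._c10d_functional.wait_tensor(tensor)
```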

To solve this, we introduce a new `PyWorkHolder` class which holds a reference to the `py::object` as well as to the trampoline class. This resolves the lifetime issue since C++ now holds ownership of both the Python and C++ objects.

To make this cleaner, we introduce a `WORK_OVERRIDE` macro, a patched version of `PYBIND11_OVERRIDE` that returns a `PyWorkHolder` rather than just a `PyWork`, and use it for all collectives in `PyProcessGroup`.

Test plan:

```
cd pytorch
pytest test/distributed/test_c10d_functional_native.py
```

```
cd torchft
pytest torchft/process_group_test.py -k functional -v -x -s
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/146376
Approved by: https://github.com/yifuwang
Tristan Rice, 2025-02-07 18:07:53 +00:00, committed by PyTorch MergeBot
commit 68631f6e87 (parent 76c8a2dc48)
3 changed files with 209 additions and 50 deletions

@@ -1,6 +1,9 @@
# Owner(s): ["module: c10d"]
import gc
import threading
import unittest
from datetime import timedelta
from typing import List, Optional
import torch
import torch.distributed as dist
@@ -435,22 +438,6 @@ class TestWithNCCL(MultiProcessTestCase):
        )
        self.assertEqual(torch._C._distributed_c10d._get_work_registry_size(), 1)

    @skip_if_lt_x_gpu(2)
    def test_py_work(self) -> None:
        self._init_process_group()

        wait_called = False

        class MyWork(dist.Work):
            def wait(self, _):
                nonlocal wait_called
                wait_called = True

        tensor = torch.rand(2, 2)
        torch._C._distributed_c10d._register_work(tensor, MyWork())
        torch.ops._c10d_functional.wait_tensor(tensor)
        self.assertTrue(wait_called)

    @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
    @skip_if_lt_x_gpu(2)
    @fresh_inductor_cache()
@@ -494,6 +481,158 @@ class TestWithNCCL(MultiProcessTestCase):
        t.join()


def dummy_init_pg() -> None:
    if not dist.is_initialized():
        dist.init_process_group(
            backend="gloo", rank=0, world_size=1, store=dist.HashStore()
        )


class _DummyWork(dist.Work):
    def __init__(self, pg: "ProcessGroupDummy") -> None:
        super().__init__()
        self.pg = pg

    def wait(self, timeout: Optional[timedelta] = None) -> bool:
        self.pg.waits += 1
        return True

    def __del__(self):
        self.pg.dels += 1


class ProcessGroupDummy(dist.ProcessGroup):
    """
    This process group discards all data passed to it and returns success. This
    is intended for rare cases where we want to discard certain operations
    without modifying the underlying library.

    This PG only supports world_size of 1.
    """

    def __init__(self) -> None:
        super().__init__(0, 1)
        self._group_name = "dummy:dummy"
        self.waits = 0
        self.dels = 0

    def broadcast(self, tensor_list: List[torch.Tensor], opts: object) -> dist.Work:
        return _DummyWork(self)

    def allgather_into_tensor_coalesced(
        self,
        output_lists: List[torch.Tensor],
        input_list: List[torch.Tensor],
        opts: object,
    ) -> dist.Work:
        return _DummyWork(self)

    def allreduce(self, tensors: List[torch.Tensor], opts: object) -> dist.Work:
        return _DummyWork(self)

    def reduce_scatter_tensor_coalesced(
        self,
        outputTensors: List[torch.Tensor],
        inputTensors: List[torch.Tensor],
        opts: object,
    ) -> dist.Work:
        return _DummyWork(self)

    @property
    def group_name(self) -> str:
        if self._group_name is None:
            raise ValueError("ProcessGroup name not set")
        return self._group_name

    def _set_group_name(self, name: str) -> None:
        self._group_name = name

    def register(self) -> dist.ProcessGroup:
        def create_pg(
            prefix_store: dist.PrefixStore, rank: int, world_size: int, timeout: float
        ) -> dist.ProcessGroup:
            return self

        dist.Backend.register_backend(self.group_name, create_pg, devices=["cpu"])

        return dist.new_group(
            ranks=[0],
            backend=self.group_name,
            group_desc=self.group_name,
            timeout=timedelta(seconds=60.0),  # this timeout isn't used
        )


class PyWorkTest(TestCase):
    """
    Native functional collectives have some interesting interactions with
    PyProcessGroup due to Python reference counting and pybind trampoline
    classes with C++ types. This validates that PyProcessGroup and PyWork
    aren't getting prematurely freed.
    """

    def test_wait_tensor(self) -> None:
        wait_called = False

        class MyWork(dist.Work):
            def wait(self, _):
                nonlocal wait_called
                wait_called = True

        # check registration and implicit unregistration
        tensor = torch.rand(2, 2)
        work = MyWork()
        torch._C._distributed_c10d._register_work(tensor, work)

        # Force GC collection of the MyWork object, if we're not doing correct
        # reference counting we'll deadlock in wait_tensor.
        del work
        gc.collect()

        torch.ops._c10d_functional.wait_tensor(tensor)
        self.assertTrue(wait_called)

    def test_collectives(self) -> None:
        dummy_init_pg()
        pg = ProcessGroupDummy().register()

        x = torch.rand(2, 2)
        x = funcol.all_reduce(x, "sum", group=pg)
        gc.collect()
        self.assertEqual(pg.dels, 0)
        x.wait()
        self.assertEqual(pg.waits, 1)
        self.assertEqual(pg.dels, 1)

        x = torch.rand(2, 2)
        x = funcol.broadcast(x, 0, group=pg)
        gc.collect()
        self.assertEqual(pg.dels, 1)
        x.wait()
        self.assertEqual(pg.waits, 2)
        self.assertEqual(pg.dels, 2)

        x = torch.rand(2, 2)
        x = funcol.all_gather_tensor(x, 0, group=pg)
        gc.collect()
        self.assertEqual(pg.dels, 2)
        x.wait()
        self.assertEqual(pg.waits, 3)
        self.assertEqual(pg.dels, 3)

        x = torch.rand(2, 2)
        x = funcol.reduce_scatter_tensor(x, "sum", 0, group=pg)
        gc.collect()
        self.assertEqual(pg.dels, 3)
        x.wait()
        self.assertEqual(pg.waits, 4)
        self.assertEqual(pg.dels, 4)


class CompileTest(TestCase):
    def setUp(self):
        # Allow testing aoti after torch.compile

@@ -41,19 +41,48 @@ class PyProcessGroup : public ProcessGroup {
return Work::getFuture();
}
};
// Take a reference of the corresponding py::object.
// With functional collectives, ownership of work objects is generally
// transferred to C++. For pure C++ work objects, it is sufficient to
// transfer the ownership of work object. For user-defined work objects in
// Python, it is necessary to keep the corresponding py::object alive in
// addition to ensure that the user-defined methods can be executed.
void ref_py_object() {
py_obj_ = py::cast(this);
#define WORK_OVERRIDE(cname, name, ...) \
do { \
pybind11::gil_scoped_acquire gil; \
pybind11::function override = \
pybind11::get_override(static_cast<const cname*>(this), #name); \
if (override) { \
auto o = override(__VA_ARGS__); \
return c10::make_intrusive<PyWorkHolder>(o); \
} \
return cname::name(__VA_ARGS__); \
} while (false)
// This class is used to wrap a PyWork trampoline with its corresponding
// Python object to prevent the Python object from being garbage collected.
class PyWorkHolder : public Work {
public:
PyWorkHolder(const c10::intrusive_ptr<Work>& work, py::object pyWork)
: work_(work), pyWork_(std::move(pyWork)) {}
PyWorkHolder(py::object pyWork)
: work_(pyWork.cast<c10::intrusive_ptr<Work>>()),
pyWork_(std::move(pyWork)) {}
~PyWorkHolder() override {
// GIL must be held when freeing python objects.
py::gil_scoped_acquire gil;
pyWork_ = py::object();
}
bool wait(std::chrono::milliseconds timeout = kNoTimeout) override {
return work_->wait(timeout);
}
c10::intrusive_ptr<c10::ivalue::Future> getFuture() override {
return work_->getFuture();
}
private:
py::object py_obj_;
c10::intrusive_ptr<Work> work_;
py::object pyWork_;
};
using ProcessGroup::ProcessGroup;
@@ -118,8 +147,7 @@ class PyProcessGroup : public ProcessGroup {
std::vector<std::vector<at::Tensor>>& outputTensors,
std::vector<at::Tensor>& inputTensors,
const AllgatherOptions& opts = AllgatherOptions()) override {
PYBIND11_OVERRIDE(
c10::intrusive_ptr<Work>, /* Return type */
WORK_OVERRIDE(
ProcessGroup, /* Parent class */
allgather, /* Name of function in C++ */
outputTensors,
@@ -131,8 +159,7 @@ class PyProcessGroup : public ProcessGroup {
std::vector<at::Tensor>& outputTensors,
std::vector<at::Tensor>& inputTensors,
const AllgatherOptions& opts = AllgatherOptions()) override {
PYBIND11_OVERRIDE(
c10::intrusive_ptr<Work>, /* Return type */
WORK_OVERRIDE(
ProcessGroup, /* Parent class */
allgather_into_tensor_coalesced, /* Name of function in C++ */
outputTensors,
@@ -143,8 +170,8 @@ class PyProcessGroup : public ProcessGroup {
c10::intrusive_ptr<Work> allreduce(
std::vector<at::Tensor>& tensors,
const AllreduceOptions& opts = AllreduceOptions()) override {
PYBIND11_OVERRIDE(
c10::intrusive_ptr<Work>, /* Return type */
WORK_OVERRIDE(
// py::object, /* Return type */
ProcessGroup, /* Parent class */
allreduce, /* Name of function in C++ */
tensors,
@@ -155,8 +182,7 @@ class PyProcessGroup : public ProcessGroup {
std::vector<at::Tensor>& tensors,
const AllreduceCoalescedOptions& opts =
AllreduceCoalescedOptions()) override {
PYBIND11_OVERRIDE(
c10::intrusive_ptr<Work>, /* Return type */
WORK_OVERRIDE(
ProcessGroup, /* Parent class */
allreduce_coalesced, /* Name of function in C++ */
tensors,
@@ -169,8 +195,7 @@ class PyProcessGroup : public ProcessGroup {
std::vector<int64_t>& outputSplitSizes,
std::vector<int64_t>& inputSplitSizes,
const AllToAllOptions& opts = AllToAllOptions()) override {
PYBIND11_OVERRIDE(
c10::intrusive_ptr<Work>, /* Return type */
WORK_OVERRIDE(
ProcessGroup, /* Parent class */
alltoall_base, /* Name of function in C++ */
outputBuffer,
@@ -182,8 +207,7 @@ class PyProcessGroup : public ProcessGroup {
c10::intrusive_ptr<Work> barrier(
const BarrierOptions& opts = BarrierOptions()) override {
PYBIND11_OVERRIDE(
c10::intrusive_ptr<Work>, /* Return type */
WORK_OVERRIDE(
ProcessGroup, /* Parent class */
barrier, /* Name of function in C++ */
opts);
@@ -192,8 +216,7 @@ class PyProcessGroup : public ProcessGroup {
c10::intrusive_ptr<Work> broadcast(
std::vector<at::Tensor>& tensors,
const BroadcastOptions& opts = BroadcastOptions()) override {
PYBIND11_OVERRIDE(
c10::intrusive_ptr<Work>, /* Return type */
WORK_OVERRIDE(
ProcessGroup, /* Parent class */
broadcast, /* Name of function in C++ */
tensors,
@@ -204,8 +227,7 @@ class PyProcessGroup : public ProcessGroup {
std::vector<at::Tensor>& outputTensors,
std::vector<std::vector<at::Tensor>>& inputTensors,
const ReduceScatterOptions& opts = ReduceScatterOptions()) override {
PYBIND11_OVERRIDE(
c10::intrusive_ptr<Work>, /* Return type */
WORK_OVERRIDE(
ProcessGroup, /* Parent class */
reduce_scatter, /* Name of function in C++ */
outputTensors,
@@ -217,8 +239,7 @@ class PyProcessGroup : public ProcessGroup {
std::vector<at::Tensor>& outputTensors,
std::vector<at::Tensor>& inputTensors,
const ReduceScatterOptions& opts = ReduceScatterOptions()) override {
PYBIND11_OVERRIDE(
c10::intrusive_ptr<Work>, /* Return type */
WORK_OVERRIDE(
ProcessGroup, /* Parent class */
reduce_scatter_tensor_coalesced, /* Name of function in C++ */
outputTensors,
@@ -230,8 +251,7 @@ class PyProcessGroup : public ProcessGroup {
std::vector<at::Tensor>& tensors,
int dstRank,
int tag) override {
PYBIND11_OVERRIDE(
c10::intrusive_ptr<Work>, /* Return type */
WORK_OVERRIDE(
ProcessGroup, /* Parent class */
send, /* Name of function in C++ */
tensors,
@@ -243,8 +263,7 @@ class PyProcessGroup : public ProcessGroup {
std::vector<at::Tensor>& tensors,
int srcRank,
int tag) override {
PYBIND11_OVERRIDE(
c10::intrusive_ptr<Work>, /* Return type */
WORK_OVERRIDE(
ProcessGroup, /* Parent class */
recv, /* Name of function in C++ */
tensors,

@@ -953,9 +953,10 @@ This class does not support ``__members__`` property.)");
"_register_work",
[](const at::Tensor& tensor,
const c10::intrusive_ptr<::c10d::Work>& work) {
dynamic_cast<::c10d::PyProcessGroup::PyWork*>(work.get())
->ref_py_object();
::c10d::register_work(tensor, work);
py::object obj = py::cast(work);
auto holder = c10::make_intrusive<::c10d::PyProcessGroup::PyWorkHolder>(
work, obj);
::c10d::register_work(tensor, holder);
},
py::arg("tensor"),
py::arg("work"));