PyWork: preserve Python reference counting when used in functional collectives (#146376)
@fegin found an issue where torchft is not compatible with functional collectives; see https://github.com/pytorch/torchtitan/pull/806.

The root cause is that PyProcessGroup/PyWork are not compatible with functional collectives due to a nasty ownership bug. PyWork relies on a pybind trampoline to propagate requests to Python. Unfortunately, the way pybind11 works, the Python object owns the C++ object rather than there being some form of shared ownership. As a result, the PyWork Python object gets collected when it is returned to C++ from the PyProcessGroup, while the C++ PyWork object still exists. When the PyWork object is later used, this causes a deadlock, as the corresponding Python object no longer exists.

To solve this, we introduce a new `PyWorkHolder` class which holds a reference to the `py::object` as well as to the trampoline class. This resolves the ownership issue, since we can now hold ownership in C++ of both the Python and C++ objects.

To make this cleaner, we introduce a `WORK_OVERRIDE` macro, a patched version of `PYBIND11_OVERRIDE` that returns a `PyWorkHolder` rather than just a `PyWork`, and use it for all collectives in PyProcessGroup.

Test plan:

```
cd pytorch
pytest test/distributed/test_c10d_functional_native.py
```

```
cd torchft
pytest torchft/process_group_test.py -k functional -v -x -s
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/146376

Approved by: https://github.com/yifuwang
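A minimal sketch of the failure mode, distilled from the `PyWorkTest.test_wait_tensor` case added in this PR. It uses the private `torch._C._distributed_c10d._register_work` binding exercised by the PR, so treat it as illustrative rather than public API:

```python
import gc

import torch
import torch.distributed as dist


class MyWork(dist.Work):
    def wait(self, _):
        # Called back from C++ when the registered tensor is waited on.
        return True


tensor = torch.rand(2, 2)
work = MyWork()
torch._C._distributed_c10d._register_work(tensor, work)

# Drop the only Python reference. Before this fix, the Python half of the
# pybind trampoline could be collected here while C++ still held the Work.
del work
gc.collect()

# Previously this could deadlock because the Python object backing the
# trampoline was gone; with PyWorkHolder it calls MyWork.wait() as expected.
torch.ops._c10d_functional.wait_tensor(tensor)
```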
This commit is contained in:
parent 76c8a2dc48
commit 68631f6e87
test/distributed/test_c10d_functional_native.py:

```diff
@@ -1,6 +1,9 @@
 # Owner(s): ["module: c10d"]
+import gc
 import threading
 import unittest
+from datetime import timedelta
+from typing import List, Optional
 
 import torch
 import torch.distributed as dist
@@ -435,22 +438,6 @@ class TestWithNCCL(MultiProcessTestCase):
         )
         self.assertEqual(torch._C._distributed_c10d._get_work_registry_size(), 1)
-
-    @skip_if_lt_x_gpu(2)
-    def test_py_work(self) -> None:
-        self._init_process_group()
-
-        wait_called = False
-
-        class MyWork(dist.Work):
-            def wait(self, _):
-                nonlocal wait_called
-                wait_called = True
-
-        tensor = torch.rand(2, 2)
-        torch._C._distributed_c10d._register_work(tensor, MyWork())
-        torch.ops._c10d_functional.wait_tensor(tensor)
-        self.assertTrue(wait_called)
 
     @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
     @skip_if_lt_x_gpu(2)
     @fresh_inductor_cache()
@@ -494,6 +481,158 @@ class TestWithNCCL(MultiProcessTestCase):
         t.join()
 
 
+def dummy_init_pg() -> None:
+    if not dist.is_initialized():
+        dist.init_process_group(
+            backend="gloo", rank=0, world_size=1, store=dist.HashStore()
+        )
+
+
+class _DummyWork(dist.Work):
+    def __init__(self, pg: "ProcessGroupDummy") -> None:
+        super().__init__()
+        self.pg = pg
+
+    def wait(self, timeout: Optional[timedelta] = None) -> bool:
+        self.pg.waits += 1
+        return True
+
+    def __del__(self):
+        self.pg.dels += 1
+
+
+class ProcessGroupDummy(dist.ProcessGroup):
+    """
+    This process group discards all data passed to it and returns success. This
+    is intended for rare cases where we want to discard certain operations
+    without modifying the underlying library.
+
+    This PG only supports world_size of 1.
+    """
+
+    def __init__(self) -> None:
+        super().__init__(0, 1)
+
+        self._group_name = "dummy:dummy"
+
+        self.waits = 0
+        self.dels = 0
+
+    def broadcast(self, tensor_list: List[torch.Tensor], opts: object) -> dist.Work:
+        return _DummyWork(self)
+
+    def allgather_into_tensor_coalesced(
+        self,
+        output_lists: List[torch.Tensor],
+        input_list: List[torch.Tensor],
+        opts: object,
+    ) -> dist.Work:
+        return _DummyWork(self)
+
+    def allreduce(self, tensors: List[torch.Tensor], opts: object) -> dist.Work:
+        return _DummyWork(self)
+
+    def reduce_scatter_tensor_coalesced(
+        self,
+        outputTensors: List[torch.Tensor],
+        inputTensors: List[torch.Tensor],
+        opts: object,
+    ) -> dist.Work:
+        return _DummyWork(self)
+
+    @property
+    def group_name(self) -> str:
+        if self._group_name is None:
+            raise ValueError("ProcessGroup name not set")
+        return self._group_name
+
+    def _set_group_name(self, name: str) -> None:
+        self._group_name = name
+
+    def register(self) -> dist.ProcessGroup:
+        def create_pg(
+            prefix_store: dist.PrefixStore, rank: int, world_size: int, timeout: float
+        ) -> dist.ProcessGroup:
+            return self
+
+        dist.Backend.register_backend(self.group_name, create_pg, devices=["cpu"])
+
+        return dist.new_group(
+            ranks=[0],
+            backend=self.group_name,
+            group_desc=self.group_name,
+            timeout=timedelta(seconds=60.0),  # this timeout isn't used
+        )
+
+
+class PyWorkTest(TestCase):
+    """
+    Native functional collectives have some interesting interactions with
+    PyProcessGroup due to Python reference counting and pybind trampoline
+    classes with C++ types. This validates that PyProcessGroup and PyWork
+    aren't getting prematurely freed.
+    """
+
+    def test_wait_tensor(self) -> None:
+        wait_called = False
+
+        class MyWork(dist.Work):
+            def wait(self, _):
+                nonlocal wait_called
+                wait_called = True
+
+        # check registration and implicit unregistration
+
+        tensor = torch.rand(2, 2)
+        work = MyWork()
+        torch._C._distributed_c10d._register_work(tensor, work)
+
+        # Force GC collection of the MyWork object, if we're not doing correct
+        # reference counting we'll deadlock in wait_tensor.
+        del work
+        gc.collect()
+
+        torch.ops._c10d_functional.wait_tensor(tensor)
+        self.assertTrue(wait_called)
+
+    def test_collectives(self) -> None:
+        dummy_init_pg()
+
+        pg = ProcessGroupDummy().register()
+
+        x = torch.rand(2, 2)
+        x = funcol.all_reduce(x, "sum", group=pg)
+        gc.collect()
+        self.assertEqual(pg.dels, 0)
+        x.wait()
+        self.assertEqual(pg.waits, 1)
+        self.assertEqual(pg.dels, 1)
+
+        x = torch.rand(2, 2)
+        x = funcol.broadcast(x, 0, group=pg)
+        gc.collect()
+        self.assertEqual(pg.dels, 1)
+        x.wait()
+        self.assertEqual(pg.waits, 2)
+        self.assertEqual(pg.dels, 2)
+
+        x = torch.rand(2, 2)
+        x = funcol.all_gather_tensor(x, 0, group=pg)
+        gc.collect()
+        self.assertEqual(pg.dels, 2)
+        x.wait()
+        self.assertEqual(pg.waits, 3)
+        self.assertEqual(pg.dels, 3)
+
+        x = torch.rand(2, 2)
+        x = funcol.reduce_scatter_tensor(x, "sum", 0, group=pg)
+        gc.collect()
+        self.assertEqual(pg.dels, 3)
+        x.wait()
+        self.assertEqual(pg.waits, 4)
+        self.assertEqual(pg.dels, 4)
+
+
 class CompileTest(TestCase):
     def setUp(self):
         # Allow testing aoti after torch.compile
```
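For context, here is roughly the user-visible flow that the new `test_collectives` exercises: a Python-defined process group is registered as a backend and then driven through functional collectives. This sketch reuses `dummy_init_pg` and `ProcessGroupDummy` from the diff above; `funcol` is assumed to be `torch.distributed._functional_collectives`, consistent with the rest of this test file:

```python
import torch
import torch.distributed._functional_collectives as funcol

# dummy_init_pg and ProcessGroupDummy are the helpers defined in the test above.
dummy_init_pg()                      # one-rank gloo group so dist is initialized
pg = ProcessGroupDummy().register()  # expose the Python PG as a backend

x = torch.rand(2, 2)
y = funcol.all_reduce(x, "sum", group=pg)  # the Python PG returns a Python Work
y.wait()  # resolves the registered Work; before this fix it could deadlock

assert pg.waits == 1  # the Python-defined wait() actually ran
```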
torch/csrc/distributed/c10d/PyProcessGroup.hpp:

```diff
@@ -41,19 +41,48 @@ class PyProcessGroup : public ProcessGroup {
 
       return Work::getFuture();
     }
-
-    // Take a reference of the corresponding py::object.
-    // With functional collectives, ownership of work objects is generally
-    // transferred to C++. For pure C++ work objects, it is sufficient to
-    // transfer the ownership of work object. For user-defined work objects in
-    // Python, it is necessary to keep the corresponding py::object alive in
-    // addition to ensure that the user-defined methods can be executed.
-    void ref_py_object() {
-      py_obj_ = py::cast(this);
-    }
-
-   private:
-    py::object py_obj_;
   };
 
+#define WORK_OVERRIDE(cname, name, ...)                                  \
+  do {                                                                   \
+    pybind11::gil_scoped_acquire gil;                                    \
+    pybind11::function override =                                        \
+        pybind11::get_override(static_cast<const cname*>(this), #name);  \
+    if (override) {                                                      \
+      auto o = override(__VA_ARGS__);                                    \
+      return c10::make_intrusive<PyWorkHolder>(o);                       \
+    }                                                                    \
+    return cname::name(__VA_ARGS__);                                     \
+  } while (false)
+
+  // This class is used to wrap a PyWork trampoline with its corresponding
+  // Python object to prevent the Python object from being garbage collected.
+  class PyWorkHolder : public Work {
+   public:
+    PyWorkHolder(const c10::intrusive_ptr<Work>& work, py::object pyWork)
+        : work_(work), pyWork_(std::move(pyWork)) {}
+
+    PyWorkHolder(py::object pyWork)
+        : work_(pyWork.cast<c10::intrusive_ptr<Work>>()),
+          pyWork_(std::move(pyWork)) {}
+
+    ~PyWorkHolder() override {
+      // GIL must be held when freeing python objects.
+      py::gil_scoped_acquire gil;
+      pyWork_ = py::object();
+    }
+
+    bool wait(std::chrono::milliseconds timeout = kNoTimeout) override {
+      return work_->wait(timeout);
+    }
+
+    c10::intrusive_ptr<c10::ivalue::Future> getFuture() override {
+      return work_->getFuture();
+    }
+
+   private:
+    c10::intrusive_ptr<Work> work_;
+    py::object pyWork_;
+  };
+
   using ProcessGroup::ProcessGroup;
@@ -118,8 +147,7 @@ class PyProcessGroup : public ProcessGroup {
       std::vector<std::vector<at::Tensor>>& outputTensors,
       std::vector<at::Tensor>& inputTensors,
      const AllgatherOptions& opts = AllgatherOptions()) override {
-    PYBIND11_OVERRIDE(
-        c10::intrusive_ptr<Work>, /* Return type */
+    WORK_OVERRIDE(
         ProcessGroup, /* Parent class */
         allgather, /* Name of function in C++ */
         outputTensors,
@@ -131,8 +159,7 @@ class PyProcessGroup : public ProcessGroup {
       std::vector<at::Tensor>& outputTensors,
       std::vector<at::Tensor>& inputTensors,
       const AllgatherOptions& opts = AllgatherOptions()) override {
-    PYBIND11_OVERRIDE(
-        c10::intrusive_ptr<Work>, /* Return type */
+    WORK_OVERRIDE(
         ProcessGroup, /* Parent class */
         allgather_into_tensor_coalesced, /* Name of function in C++ */
         outputTensors,
@@ -143,8 +170,8 @@ class PyProcessGroup : public ProcessGroup {
   c10::intrusive_ptr<Work> allreduce(
       std::vector<at::Tensor>& tensors,
       const AllreduceOptions& opts = AllreduceOptions()) override {
-    PYBIND11_OVERRIDE(
-        c10::intrusive_ptr<Work>, /* Return type */
+    WORK_OVERRIDE(
+        // py::object, /* Return type */
         ProcessGroup, /* Parent class */
         allreduce, /* Name of function in C++ */
         tensors,
@@ -155,8 +182,7 @@ class PyProcessGroup : public ProcessGroup {
       std::vector<at::Tensor>& tensors,
       const AllreduceCoalescedOptions& opts =
          AllreduceCoalescedOptions()) override {
-    PYBIND11_OVERRIDE(
-        c10::intrusive_ptr<Work>, /* Return type */
+    WORK_OVERRIDE(
         ProcessGroup, /* Parent class */
         allreduce_coalesced, /* Name of function in C++ */
         tensors,
@@ -169,8 +195,7 @@ class PyProcessGroup : public ProcessGroup {
       std::vector<int64_t>& outputSplitSizes,
       std::vector<int64_t>& inputSplitSizes,
       const AllToAllOptions& opts = AllToAllOptions()) override {
-    PYBIND11_OVERRIDE(
-        c10::intrusive_ptr<Work>, /* Return type */
+    WORK_OVERRIDE(
         ProcessGroup, /* Parent class */
         alltoall_base, /* Name of function in C++ */
         outputBuffer,
@@ -182,8 +207,7 @@ class PyProcessGroup : public ProcessGroup {
 
   c10::intrusive_ptr<Work> barrier(
       const BarrierOptions& opts = BarrierOptions()) override {
-    PYBIND11_OVERRIDE(
-        c10::intrusive_ptr<Work>, /* Return type */
+    WORK_OVERRIDE(
         ProcessGroup, /* Parent class */
         barrier, /* Name of function in C++ */
         opts);
@@ -192,8 +216,7 @@ class PyProcessGroup : public ProcessGroup {
   c10::intrusive_ptr<Work> broadcast(
       std::vector<at::Tensor>& tensors,
       const BroadcastOptions& opts = BroadcastOptions()) override {
-    PYBIND11_OVERRIDE(
-        c10::intrusive_ptr<Work>, /* Return type */
+    WORK_OVERRIDE(
         ProcessGroup, /* Parent class */
         broadcast, /* Name of function in C++ */
         tensors,
@@ -204,8 +227,7 @@ class PyProcessGroup : public ProcessGroup {
       std::vector<at::Tensor>& outputTensors,
       std::vector<std::vector<at::Tensor>>& inputTensors,
       const ReduceScatterOptions& opts = ReduceScatterOptions()) override {
-    PYBIND11_OVERRIDE(
-        c10::intrusive_ptr<Work>, /* Return type */
+    WORK_OVERRIDE(
         ProcessGroup, /* Parent class */
         reduce_scatter, /* Name of function in C++ */
         outputTensors,
@@ -217,8 +239,7 @@ class PyProcessGroup : public ProcessGroup {
       std::vector<at::Tensor>& outputTensors,
       std::vector<at::Tensor>& inputTensors,
       const ReduceScatterOptions& opts = ReduceScatterOptions()) override {
-    PYBIND11_OVERRIDE(
-        c10::intrusive_ptr<Work>, /* Return type */
+    WORK_OVERRIDE(
         ProcessGroup, /* Parent class */
         reduce_scatter_tensor_coalesced, /* Name of function in C++ */
         outputTensors,
@@ -230,8 +251,7 @@ class PyProcessGroup : public ProcessGroup {
       std::vector<at::Tensor>& tensors,
       int dstRank,
       int tag) override {
-    PYBIND11_OVERRIDE(
-        c10::intrusive_ptr<Work>, /* Return type */
+    WORK_OVERRIDE(
         ProcessGroup, /* Parent class */
         send, /* Name of function in C++ */
         tensors,
@@ -243,8 +263,7 @@ class PyProcessGroup : public ProcessGroup {
       std::vector<at::Tensor>& tensors,
       int srcRank,
       int tag) override {
-    PYBIND11_OVERRIDE(
-        c10::intrusive_ptr<Work>, /* Return type */
+    WORK_OVERRIDE(
         ProcessGroup, /* Parent class */
         recv, /* Name of function in C++ */
         tensors,
```
torch/csrc/distributed/c10d/init.cpp:

```diff
@@ -953,9 +953,10 @@ This class does not support ``__members__`` property.)");
           "_register_work",
           [](const at::Tensor& tensor,
              const c10::intrusive_ptr<::c10d::Work>& work) {
-            dynamic_cast<::c10d::PyProcessGroup::PyWork*>(work.get())
-                ->ref_py_object();
-            ::c10d::register_work(tensor, work);
+            py::object obj = py::cast(work);
+            auto holder = c10::make_intrusive<::c10d::PyProcessGroup::PyWorkHolder>(
+                work, obj);
+            ::c10d::register_work(tensor, holder);
           },
           py::arg("tensor"),
           py::arg("work"));
```
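As a rough, hypothetical check of the rebound `_register_work`: it now wraps the incoming work in a `PyWorkHolder` that owns the `py::object`, so the Python `Work` should stay alive (observable via `weakref`) until the registered work is consumed. This is a sketch against private APIs and my reading of the diff, not documented behavior:

```python
import gc
import weakref

import torch
import torch.distributed as dist


class NoopWork(dist.Work):
    def wait(self, _=None):
        return True


t = torch.rand(2, 2)
w = NoopWork()
ref = weakref.ref(w)

torch._C._distributed_c10d._register_work(t, w)
del w
gc.collect()

# The PyWorkHolder holds a py::object reference, so the instance survives
# even though no Python references remain.
assert ref() is not None

# wait_tensor() consumes the registered work; the holder is then destroyed
# (acquiring the GIL in its destructor) and the Python object can be freed.
torch.ops._c10d_functional.wait_tensor(t)
```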