mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
These happen when building with CMAKE_BUILD_TYPE=RelWithAssert. This should fix two types of failures that started with https://github.com/pytorch/pytorch/pull/163665 Disclaimer that I used a lot of AI since I don't know how pybind works or what refcounts and pointers are, so idk if this is a good solution, or even a solution at all (fwiw the tests pass now) The first type is Truncated: ``` default_pg, _ = _new_process_group_helper( File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/distributed/distributed_c10d.py", line 2096, in _new_process_group_helper backend_class = creator_fn(dist_backend_opts, backend_options) File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/testing/_internal/distributed/fake_pg.py", line 25, in _create_fake_pg return FakeProcessGroup._create_internal( RuntimeError: new_refcount != 1 INTERNAL ASSERT FAILED at "/var/lib/jenkins/workspace/c10/util/intrusive_ptr.h":319, please report a bug to PyTorch. intrusive_ptr: Cannot increase refcount after it reached zero. 
Exception raised from retain_ at /var/lib/jenkins/workspace/c10/util/intrusive_ptr.h:319 (most recent call first): C++ CapturedTraceback: #4 std::_Function_handler<std::shared_ptr<c10::LazyValue<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > > const> (), c10::SetStackTraceFetcher(std::function<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > ()>)::{lambda()#1}>::_M_invoke(std::_Any_data const&) from Logging.cpp:0 #5 c10::Error::Error(c10::SourceLocation, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >) from ??:0 #6 c10::detail::torchCheckFail(char const*, char const*, unsigned int, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&) from ??:0 #7 c10::detail::torchInternalAssertFail(char const*, char const*, unsigned int, char const*, char const*) from ??:0 #8 void pybind11::class_<c10d::FakeProcessGroup, (anonymous namespace)::IntrusivePtrNoGilDestructor<c10d::FakeProcessGroup> >::init_instance<(anonymous namespace)::IntrusivePtrNoGilDestructor<c10d::FakeProcessGroup>, 0>(pybind11::detail::instance*, void const*) from init.cpp:0 #9 pybind11::detail::type_caster_generic::cast(void const*, pybind11::return_value_policy, pybind11::handle, pybind11::detail::type_info const*, void* (*)(void const*), void* (*)(void const*), void const*) from :0 #10 pybind11::cpp_function::initialize<torch::distributed::c10d::(anonymous namespace)::c10d_init(_object*, _object*)::{lambda(int, int, c10::intrusive_ptr<c10d::FakeProcessGroup::Options, c10::detail::intrusive_target_default_null_type<c10d::FakeProcessGroup::Options> >)#127}, c10::intrusive_ptr<c10d::FakeProcessGroup, c10::detail::intrusive_target_default_null_type<c10d::FakeProcessGroup> >, int, int, c10::intrusive_ptr<c10d::FakeProcessGroup::Options, c10::detail::intrusive_target_default_null_type<c10d::FakeProcessGroup::Options> >, pybind11::name, pybind11::scope, pybind11::sibling, 
pybind11::arg, pybind11::arg, pybind11::arg_v>(torch::distributed::c10d::(anonymous namespace)::c10d_init(_object*, _object*)::{lambda(int, int, c10::intrusive_ptr<c10d::FakeProcessGroup::Options, c10::detail::intrusive_target_default_null_type<c10d::FakeProcessGroup::Options> >)#127}&&, c10::intrusive_ptr<c10d::FakeProcessGroup, c10::detail::intrusive_target_default_null_type<c10d::FakeProcessGroup> > (*)(int, int, c10::intrusive_ptr<c10d::FakeProcessGroup::Options, c10::detail::intrusive_target_default_null_type<c10d::FakeProcessGroup::Options> >), pybind11::name const&, pybind11::scope const&, pybind11::sibling const&, pybind11::arg const&, pybind11::arg const&, pybind11::arg_v const&)::{lambda(pybind11::detail::function_call&)#3}::_FUN(pybind11::detail::function_call&) from init.cpp:0 ``` and I fix it here by getting rid of `DontIncreaseRefcount` and using make_intrusive to do the ref count handling instead. However, I also had to move the constructor to be public, which I think is not good, based on the reasoning of the original PR The other one type is ``` Traceback (most recent call last): File "/var/lib/jenkins/workspace/test/test_testing.py", line 2415, in test_no_warning_on_import self.assertEqual(out, "") File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/testing/_internal/common_utils.py", line 4233, in assertEqual raise error_metas.pop()[0].to_error( # type: ignore[index] AssertionError: String comparison failed: "/opt/conda/envs/py_3.10/lib/python3.10/s[352 chars]):\n" != '' - /opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/distributed/__init__.py:29: FutureWarning: pybind11-bound class 'torch._C._distributed_c10d.FakeProcessGroup' is using an old-style placement-new '__init__' which has been deprecated. See the upgrade guide in pybind11's docs. This message is only visible when compiled in debug mode. 
- if is_available() and not torch._C._c10d_init(): To execute this test, run the following from the base repo dir: python test/test_testing.py TestImports.test_no_warning_on_import ``` which I fix by getting rid of the `__init__` which I think is ok since it'll just error if you try to make one? Pull Request resolved: https://github.com/pytorch/pytorch/pull/165479 Approved by: https://github.com/ezyang
254 lines
8.0 KiB
C++
254 lines
8.0 KiB
C++
#pragma once
|
|
|
|
#include <torch/csrc/distributed/c10d/Backend.hpp>
|
|
#include <torch/csrc/utils.h>
|
|
|
|
namespace c10d {
|
|
|
|
class FakeWork : public Work {
|
|
public:
|
|
int seq_id = -1;
|
|
bool wait(std::chrono::milliseconds timeout = kNoTimeout) override {
|
|
return true;
|
|
}
|
|
|
|
c10::intrusive_ptr<c10::ivalue::Future> getFuture() override {
|
|
auto fut = c10::make_intrusive<c10::ivalue::Future>(c10::NoneType::get());
|
|
fut->markCompleted();
|
|
return fut;
|
|
}
|
|
};
|
|
|
|
class FakeProcessGroup : public Backend {
|
|
public:
|
|
struct Options : Backend::Options {
|
|
explicit Options() : Backend::Options("fake") {}
|
|
|
|
int fake_option = 0;
|
|
bool error_on_collective = false;
|
|
};
|
|
|
|
// Static factory method for official APIs
|
|
static c10::intrusive_ptr<FakeProcessGroup> _create_internal(
|
|
int rank,
|
|
int size,
|
|
c10::intrusive_ptr<Options> options = c10::make_intrusive<Options>()) {
|
|
return c10::make_intrusive<FakeProcessGroup>(
|
|
rank, size, std::move(options));
|
|
}
|
|
|
|
const std::string getBackendName() const override {
|
|
return "fake";
|
|
}
|
|
|
|
c10::intrusive_ptr<Backend::Options> getBackendOptions() override {
|
|
return c10::static_intrusive_pointer_cast<Backend::Options>(options_);
|
|
}
|
|
|
|
c10::intrusive_ptr<Work> broadcast(
|
|
std::vector<at::Tensor>& /* tensors */,
|
|
const BroadcastOptions& /* opts */ = BroadcastOptions()) override {
|
|
checkCollectiveError();
|
|
return c10::make_intrusive<FakeWork>();
|
|
}
|
|
|
|
c10::intrusive_ptr<Work> allreduce(
|
|
std::vector<at::Tensor>& /* tensors */,
|
|
const AllreduceOptions& /* opts */ = AllreduceOptions()) override {
|
|
checkCollectiveError();
|
|
return c10::make_intrusive<FakeWork>();
|
|
}
|
|
|
|
c10::intrusive_ptr<Work> allreduce_sparse(
|
|
std::vector<at::Tensor>& /* tensors */,
|
|
const AllreduceOptions& /* opts */ = AllreduceOptions()) override {
|
|
checkCollectiveError();
|
|
return c10::make_intrusive<FakeWork>();
|
|
}
|
|
|
|
c10::intrusive_ptr<Work> allreduce_coalesced(
|
|
std::vector<at::Tensor>& /* tensors */,
|
|
const AllreduceCoalescedOptions& /* opts */ =
|
|
AllreduceCoalescedOptions()) override {
|
|
checkCollectiveError();
|
|
return c10::make_intrusive<FakeWork>();
|
|
}
|
|
|
|
c10::intrusive_ptr<Work> reduce(
|
|
std::vector<at::Tensor>& /* tensors */,
|
|
const ReduceOptions& /* opts */ = ReduceOptions()) override {
|
|
checkCollectiveError();
|
|
return c10::make_intrusive<FakeWork>();
|
|
}
|
|
|
|
// NOTE [allgather on FakeProcessGroup]
|
|
// Assume each rank have the same input tensor so we just copy to the results
|
|
// since it's not a real allgather, we simply make this copying logic to let
|
|
// some simple validation works (i.e. calling allgather to see if each rank
|
|
// have the same tensor or not).
|
|
//
|
|
// NOTE: in general it's not good form to try to make FakeProcessGroup work
|
|
// with real data, but the reasoning here is that we want FakeProcessGroup to
|
|
// work with DeviceMesh's init code that have the data validation, which
|
|
// makes it worth the tradeoff.
|
|
c10::intrusive_ptr<Work> allgather(
|
|
std::vector<std::vector<at::Tensor>>& outputTensors,
|
|
std::vector<at::Tensor>& inputTensors,
|
|
const AllgatherOptions& /* opts */ = AllgatherOptions()) override {
|
|
checkCollectiveError();
|
|
for (auto& tensor : outputTensors[0]) {
|
|
tensor.copy_(inputTensors[0]);
|
|
}
|
|
return c10::make_intrusive<FakeWork>();
|
|
}
|
|
|
|
c10::intrusive_ptr<Work> _allgather_base(
|
|
at::Tensor& outputBuffer,
|
|
at::Tensor& inputBuffer,
|
|
const AllgatherOptions& /* opts */ = AllgatherOptions()) override {
|
|
checkCollectiveError();
|
|
auto chunks = outputBuffer.chunk(size_);
|
|
for (auto& tensor : chunks) {
|
|
tensor.copy_(inputBuffer);
|
|
}
|
|
return c10::make_intrusive<FakeWork>();
|
|
}
|
|
|
|
c10::intrusive_ptr<Work> allgather_coalesced(
|
|
std::vector<std::vector<at::Tensor>>& /* outputTensorLists */,
|
|
std::vector<at::Tensor>& /* inputTensors */,
|
|
const AllgatherOptions& /* opts */ = AllgatherOptions()) override {
|
|
checkCollectiveError();
|
|
return c10::make_intrusive<FakeWork>();
|
|
}
|
|
|
|
c10::intrusive_ptr<Work> allgather_into_tensor_coalesced(
|
|
std::vector<at::Tensor>& outputs,
|
|
std::vector<at::Tensor>& inputs,
|
|
const AllgatherOptions& /* opts */ = AllgatherOptions()) override {
|
|
checkCollectiveError();
|
|
for (size_t i = 0; i < outputs.size(); ++i) {
|
|
auto chunks = outputs[i].chunk(size_);
|
|
for (auto& chunk : chunks) {
|
|
chunk.copy_(inputs[i]);
|
|
}
|
|
}
|
|
return c10::make_intrusive<FakeWork>();
|
|
}
|
|
|
|
c10::intrusive_ptr<Work> gather(
|
|
std::vector<std::vector<at::Tensor>>& /* outputTensors */,
|
|
std::vector<at::Tensor>& /* inputTensors */,
|
|
const GatherOptions& /* opts */ = GatherOptions()) override {
|
|
checkCollectiveError();
|
|
return c10::make_intrusive<FakeWork>();
|
|
}
|
|
|
|
c10::intrusive_ptr<Work> scatter(
|
|
std::vector<at::Tensor>& /* outputTensors */,
|
|
std::vector<std::vector<at::Tensor>>& /* inputTensors */,
|
|
const ScatterOptions& /* opts */ = ScatterOptions()) override {
|
|
checkCollectiveError();
|
|
return c10::make_intrusive<FakeWork>();
|
|
}
|
|
|
|
c10::intrusive_ptr<Work> reduce_scatter(
|
|
std::vector<at::Tensor>& /* outputTensors */,
|
|
std::vector<std::vector<at::Tensor>>& /* inputTensors */,
|
|
const ReduceScatterOptions& /* opts */ =
|
|
ReduceScatterOptions()) override {
|
|
checkCollectiveError();
|
|
return c10::make_intrusive<FakeWork>();
|
|
}
|
|
|
|
c10::intrusive_ptr<Work> _reduce_scatter_base(
|
|
at::Tensor& /* outputBuffer */,
|
|
at::Tensor& /* inputBuffer */,
|
|
const ReduceScatterOptions& /* opts */ =
|
|
ReduceScatterOptions()) override {
|
|
checkCollectiveError();
|
|
return c10::make_intrusive<FakeWork>();
|
|
}
|
|
|
|
c10::intrusive_ptr<Work> reduce_scatter_tensor_coalesced(
|
|
std::vector<at::Tensor>& /* outputs */,
|
|
std::vector<at::Tensor>& /* inputs */,
|
|
const ReduceScatterOptions& /* opts */ =
|
|
ReduceScatterOptions()) override {
|
|
checkCollectiveError();
|
|
return c10::make_intrusive<FakeWork>();
|
|
}
|
|
|
|
c10::intrusive_ptr<Work> alltoall_base(
|
|
at::Tensor& /* outputBuffer */,
|
|
at::Tensor& /* inputBuffer */,
|
|
std::vector<int64_t>& /* outputSplitSizes */,
|
|
std::vector<int64_t>& /* inputSplitSizes */,
|
|
const AllToAllOptions& /* opts */ = AllToAllOptions()) override {
|
|
checkCollectiveError();
|
|
return c10::make_intrusive<FakeWork>();
|
|
}
|
|
|
|
c10::intrusive_ptr<Work> alltoall(
|
|
std::vector<at::Tensor>& /* outputTensors */,
|
|
std::vector<at::Tensor>& /* inputTensors */,
|
|
const AllToAllOptions& opts = AllToAllOptions()) override {
|
|
checkCollectiveError();
|
|
return c10::make_intrusive<FakeWork>();
|
|
}
|
|
|
|
c10::intrusive_ptr<Work> send(
|
|
std::vector<at::Tensor>& /* tensors */,
|
|
int /* dstRank */,
|
|
int /* tag */) override {
|
|
return c10::make_intrusive<FakeWork>();
|
|
}
|
|
|
|
c10::intrusive_ptr<Work> recv(
|
|
std::vector<at::Tensor>& /* tensors */,
|
|
int /* srcRank */,
|
|
int /* tag */) override {
|
|
return c10::make_intrusive<FakeWork>();
|
|
}
|
|
|
|
c10::intrusive_ptr<Work> recvAnysource(
|
|
std::vector<at::Tensor>& /* tensors */,
|
|
int /* tag */) override {
|
|
return c10::make_intrusive<FakeWork>();
|
|
}
|
|
|
|
void startCoalescing() override {
|
|
// No-op
|
|
}
|
|
|
|
c10::intrusive_ptr<Work> endCoalescing(OpType /* optype */) {
|
|
checkCollectiveError();
|
|
return c10::make_intrusive<FakeWork>();
|
|
}
|
|
|
|
c10::intrusive_ptr<Work> endCoalescing() override {
|
|
checkCollectiveError();
|
|
return c10::make_intrusive<FakeWork>();
|
|
}
|
|
|
|
c10::intrusive_ptr<Work> barrier(
|
|
const BarrierOptions& /* opts */ = BarrierOptions()) override {
|
|
checkCollectiveError();
|
|
return c10::make_intrusive<FakeWork>();
|
|
}
|
|
|
|
// Private constructor used by official APIs
|
|
FakeProcessGroup(int rank, int size, c10::intrusive_ptr<Options> options)
|
|
: Backend(rank, size), options_(std::move(options)) {}
|
|
c10::intrusive_ptr<Options> options_;
|
|
|
|
private:
|
|
void checkCollectiveError() {
|
|
TORCH_CHECK(
|
|
!options_ || !options_->error_on_collective,
|
|
"FakeProcessGroup collective operation error (error_on_collective=true)");
|
|
}
|
|
};
|
|
|
|
} // namespace c10d
|