#pragma once #include #include namespace c10d { class FakeWork : public Work { public: int seq_id = -1; bool wait(std::chrono::milliseconds timeout = kNoTimeout) override { return true; } c10::intrusive_ptr getFuture() override { auto fut = c10::make_intrusive(c10::NoneType::get()); fut->markCompleted(); return fut; } }; class FakeProcessGroup : public Backend { public: struct Options : Backend::Options { explicit Options() : Backend::Options("fake") {} int fake_option = 0; bool error_on_collective = false; }; // Static factory method for official APIs static c10::intrusive_ptr _create_internal( int rank, int size, c10::intrusive_ptr options = c10::make_intrusive()) { return c10::make_intrusive( rank, size, std::move(options)); } const std::string getBackendName() const override { return "fake"; } c10::intrusive_ptr getBackendOptions() override { return c10::static_intrusive_pointer_cast(options_); } c10::intrusive_ptr broadcast( std::vector& /* tensors */, const BroadcastOptions& /* opts */ = BroadcastOptions()) override { checkCollectiveError(); return c10::make_intrusive(); } c10::intrusive_ptr allreduce( std::vector& /* tensors */, const AllreduceOptions& /* opts */ = AllreduceOptions()) override { checkCollectiveError(); return c10::make_intrusive(); } c10::intrusive_ptr allreduce_sparse( std::vector& /* tensors */, const AllreduceOptions& /* opts */ = AllreduceOptions()) override { checkCollectiveError(); return c10::make_intrusive(); } c10::intrusive_ptr allreduce_coalesced( std::vector& /* tensors */, const AllreduceCoalescedOptions& /* opts */ = AllreduceCoalescedOptions()) override { checkCollectiveError(); return c10::make_intrusive(); } c10::intrusive_ptr reduce( std::vector& /* tensors */, const ReduceOptions& /* opts */ = ReduceOptions()) override { checkCollectiveError(); return c10::make_intrusive(); } // NOTE [allgather on FakeProcessGroup] // Assume each rank have the same input tensor so we just copy to the results // since it's not a real allgather, we simply make this copying logic to let // some simple validation works (i.e. calling allgather to see if each rank // have the same tensor or not). // // NOTE: in general it's not good form to try to make FakeProcessGroup work // with real data, but the reasoning here is that we want FakeProcessGroup to // work with DeviceMesh's init code that have the data validation, which // makes it worth the tradeoff. c10::intrusive_ptr allgather( std::vector>& outputTensors, std::vector& inputTensors, const AllgatherOptions& /* opts */ = AllgatherOptions()) override { checkCollectiveError(); for (auto& tensor : outputTensors[0]) { tensor.copy_(inputTensors[0]); } return c10::make_intrusive(); } c10::intrusive_ptr _allgather_base( at::Tensor& outputBuffer, at::Tensor& inputBuffer, const AllgatherOptions& /* opts */ = AllgatherOptions()) override { checkCollectiveError(); auto chunks = outputBuffer.chunk(size_); for (auto& tensor : chunks) { tensor.copy_(inputBuffer); } return c10::make_intrusive(); } c10::intrusive_ptr allgather_coalesced( std::vector>& /* outputTensorLists */, std::vector& /* inputTensors */, const AllgatherOptions& /* opts */ = AllgatherOptions()) override { checkCollectiveError(); return c10::make_intrusive(); } c10::intrusive_ptr allgather_into_tensor_coalesced( std::vector& outputs, std::vector& inputs, const AllgatherOptions& /* opts */ = AllgatherOptions()) override { checkCollectiveError(); for (size_t i = 0; i < outputs.size(); ++i) { auto chunks = outputs[i].chunk(size_); for (auto& chunk : chunks) { chunk.copy_(inputs[i]); } } return c10::make_intrusive(); } c10::intrusive_ptr gather( std::vector>& /* outputTensors */, std::vector& /* inputTensors */, const GatherOptions& /* opts */ = GatherOptions()) override { checkCollectiveError(); return c10::make_intrusive(); } c10::intrusive_ptr scatter( std::vector& /* outputTensors */, std::vector>& /* inputTensors */, const ScatterOptions& /* opts */ = ScatterOptions()) override { checkCollectiveError(); return c10::make_intrusive(); } c10::intrusive_ptr reduce_scatter( std::vector& /* outputTensors */, std::vector>& /* inputTensors */, const ReduceScatterOptions& /* opts */ = ReduceScatterOptions()) override { checkCollectiveError(); return c10::make_intrusive(); } c10::intrusive_ptr _reduce_scatter_base( at::Tensor& /* outputBuffer */, at::Tensor& /* inputBuffer */, const ReduceScatterOptions& /* opts */ = ReduceScatterOptions()) override { checkCollectiveError(); return c10::make_intrusive(); } c10::intrusive_ptr reduce_scatter_tensor_coalesced( std::vector& /* outputs */, std::vector& /* inputs */, const ReduceScatterOptions& /* opts */ = ReduceScatterOptions()) override { checkCollectiveError(); return c10::make_intrusive(); } c10::intrusive_ptr alltoall_base( at::Tensor& /* outputBuffer */, at::Tensor& /* inputBuffer */, std::vector& /* outputSplitSizes */, std::vector& /* inputSplitSizes */, const AllToAllOptions& /* opts */ = AllToAllOptions()) override { checkCollectiveError(); return c10::make_intrusive(); } c10::intrusive_ptr alltoall( std::vector& /* outputTensors */, std::vector& /* inputTensors */, const AllToAllOptions& opts = AllToAllOptions()) override { checkCollectiveError(); return c10::make_intrusive(); } c10::intrusive_ptr send( std::vector& /* tensors */, int /* dstRank */, int /* tag */) override { return c10::make_intrusive(); } c10::intrusive_ptr recv( std::vector& /* tensors */, int /* srcRank */, int /* tag */) override { return c10::make_intrusive(); } c10::intrusive_ptr recvAnysource( std::vector& /* tensors */, int /* tag */) override { return c10::make_intrusive(); } void startCoalescing() override { // No-op } c10::intrusive_ptr endCoalescing(OpType /* optype */) { checkCollectiveError(); return c10::make_intrusive(); } c10::intrusive_ptr endCoalescing() override { checkCollectiveError(); return c10::make_intrusive(); } c10::intrusive_ptr barrier( const BarrierOptions& /* opts */ = BarrierOptions()) override { checkCollectiveError(); return c10::make_intrusive(); } // Private constructor used by official APIs FakeProcessGroup(int rank, int size, c10::intrusive_ptr options) : Backend(rank, size), options_(std::move(options)) {} c10::intrusive_ptr options_; private: void checkCollectiveError() { TORCH_CHECK( !options_ || !options_->error_on_collective, "FakeProcessGroup collective operation error (error_on_collective=true)"); } }; } // namespace c10d