Summary: Implements https://github.com/pytorch/pytorch/issues/51075#issuecomment-768884685 and additions discussed offline with ezyang and ngimel. (Calling it "simple" is charitable, but it's not too bad.)

[High level strategy](https://github.com/pytorch/pytorch/pull/51436/files#diff-acc6337586bf9cdcf0a684380779300ec171897d05b8569bf439820dc8c93bd5R57-R82)

The current design aggregates stats from private pools with the ordinary pools, which may or may not be what we want.

Instead of adding PrivatePools as an internal feature of DeviceAllocator, I could inherit from DeviceAllocator (e.g. `DevicePrivateAllocator : public DeviceAllocator`) and create separate per-graph instances of the inherited class. I'm not sure if that would be better.

Graph bindings in Python are almost unchanged from https://github.com/pytorch/pytorch/pull/48875:

```python
# Same bindings as 48875, but now implicitly grabs a private mempool
graph1.capture_begin()
graph1.capture_end()

# pool=... is new. It hints that allocations during graph2's capture may share graph1's mempool
graph2.capture_begin(pool=graph1.pool())
graph2.capture_end()

# graph3 also implicitly creates its own mempool
graph3.capture_begin()
graph3.capture_end()
```

Test plan (other suggestions appreciated):

- [x] Stop maintaining manual references for all the tensors in my existing graphs+RNG tests. If private pools somehow give bad allocations, those tests should start failing intermittently. They run eager ops and eager allocations mixed with graph replays, so they may expose whether eager ops and replays corrupt each other.
- [x] `test_graph_two_successive`: Capture successive graphs, with the second graph using the first graph's result. Try with and without sharing a pool. Check results, and also check memory stats to confirm that sharing a pool saves memory. (A hedged sketch of this scenario appears after this summary.)
- [x] `test_graph_concurrent_replay`: Capture some graphs in separate private pools, replay them concurrently in different streams, and check the results to make sure they don't corrupt each other's memory. Capture some graphs with a shared pool, replay them concurrently in different streams, check results, and confirm they DO corrupt each other's memory.
- [x] `test_graph_three_successive`: A three-graph case, checking the safe and unsafe replay patterns in [Restrictions of the Strawman API](https://github.com/pytorch/pytorch/issues/51075).
- [x] `test_graph_memory_stats_and_use_result_after_destroy_graph`: Comprehensively check the torch.cuda.memory_stats() changes that result from graph capture and deletion. Check that a tensor ref created during capture and held after graph deletion stays valid until the tensor itself is deleted.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/51436

Reviewed By: mruberry

Differential Revision: D26993790

Pulled By: ngimel

fbshipit-source-id: a992eaee1b8c23628e7b388a5a3c26e0f80e54da
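For illustration, here is a hedged sketch of the `test_graph_two_successive` pool-sharing scenario. The names `torch._C._CudaGraphBase`, `capture_begin(pool=...)`, `capture_end`, `replay`, and `pool` come from the bindings in the file below; the side-stream capture pattern and the exact `memory_stats()` key are assumptions about this revision, not a copy of the real test.

```python
# Sketch only: assumes a CUDA build that includes this PR's bindings.
import torch

side_stream = torch.cuda.Stream()
static_input = torch.ones(1000, device="cuda")
side_stream.wait_stream(torch.cuda.current_stream())

with torch.cuda.stream(side_stream):
    # Graph 1: y = x + 1, capturing into an implicitly created private mempool.
    g1 = torch._C._CudaGraphBase()
    g1.capture_begin()
    y = static_input + 1
    g1.capture_end()

    # Graph 2: z = y * 2, hinting that it may share graph 1's mempool.
    g2 = torch._C._CudaGraphBase()
    g2.capture_begin(pool=g1.pool())
    z = y * 2
    g2.capture_end()

torch.cuda.current_stream().wait_stream(side_stream)

# Replay both graphs in order and check that graph 2 consumed graph 1's output.
g1.replay()
g2.replay()
torch.cuda.synchronize()
assert torch.equal(z, (static_input + 1) * 2)

# Memory stats should reflect the shared pool (the flattened key below is an
# assumption about torch.cuda.memory_stats()'s layout).
print(torch.cuda.memory_stats()["allocated_bytes.all.current"])
```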
59 lines
2.4 KiB
C++
```cpp
#include <torch/csrc/python_headers.h>

#include <pybind11/chrono.h>

#include <torch/csrc/jit/python/pybind_utils.h>
#include <torch/csrc/utils/pybind.h>

#include <ATen/cuda/CUDAGraph.h>

// Cargo culted partially from csrc/distributed/c10d/init.cpp
// and partially from csrc/cuda/Stream.cpp.
// THCPStream_init is also declared at global scope.

// Because THCPGraph_init is forward declared in the only consumer (csrc/Module.cpp)
// I don't think we need a Graph.h.

template <typename T>
using shared_ptr_class_ = py::class_<T, std::shared_ptr<T>>;

void THCPGraph_init(PyObject *module) {
  // Pybind11 patch notes say "py::module_" is more up-to-date syntax,
  // but CI linter and some builds prefer "module".
  auto torch_C_m = py::handle(module).cast<py::module>();

  torch_C_m
      .def("_graph_pool_handle", &::at::cuda::graph_pool_handle);

  shared_ptr_class_<::at::cuda::CUDAGraph>(torch_C_m, "_CudaGraphBase")
      .def(py::init<>())
      // I'm not sure this is the correct order of all the arguments. Pybind11 docs
      // aren't clear. But it works.
      .def("capture_begin",
           &::at::cuda::CUDAGraph::capture_begin,
           py::call_guard<py::gil_scoped_release>(),
           R"(``capture_begin`` begins Cuda graph capture on the current stream.)",
           py::arg("pool") = c10::cuda::MempoolId_t{0, 0})
      .def("capture_end",
           &::at::cuda::CUDAGraph::capture_end,
           py::call_guard<py::gil_scoped_release>(),
           R"(``capture_end`` ends Cuda graph capture on the current stream.
           After ``capture_end``, ``replay`` may be called on this instance.)")
      .def("replay",
           &::at::cuda::CUDAGraph::replay,
           py::call_guard<py::gil_scoped_release>(),
           R"(``replay`` replays the Cuda graph captured by this instance.)")
      // reset is called in __del__ on the Python side
      // (see class Graph in torch/cuda/streams.py for reasons and caveats)
      .def("reset",
           &::at::cuda::CUDAGraph::reset,
           py::call_guard<py::gil_scoped_release>(),
           R"(``reset`` deletes the graph currently held by this instance.)")
      .def("pool",
           &::at::cuda::CUDAGraph::pool,
           py::call_guard<py::gil_scoped_release>(),
           R"(``pool`` retrieves the id of this graph's memory pool.
           This id can optionally be passed to another graph's capture_begin,
           which hints that other graph may share the same memory pool.)");
}
```
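For context on the `reset` binding and the `__del__` comment above: the real wrapper is `class Graph` in `torch/cuda/streams.py`, which handles more caveats than shown here. Below is a minimal, hypothetical sketch of that `__del__` -> `reset()` pattern, assuming the bindings above are reachable as `torch._C._CudaGraphBase` and `torch._C._graph_pool_handle`.

```python
import torch

class GraphSketch:
    """Hypothetical thin wrapper over torch._C._CudaGraphBase, illustrating the
    __del__ -> reset() pattern mentioned in the C++ comments above."""

    def __init__(self):
        self._graph = torch._C._CudaGraphBase()

    def capture_begin(self, pool=None):
        # pool may be another graph's pool() id or a fresh handle from
        # torch._C._graph_pool_handle(); omitting it lets this graph create
        # its own private mempool.
        if pool is None:
            self._graph.capture_begin()
        else:
            self._graph.capture_begin(pool=pool)

    def capture_end(self):
        self._graph.capture_end()

    def replay(self):
        self._graph.replay()

    def pool(self):
        return self._graph.pool()

    def __del__(self):
        # Free the captured graph as soon as the Python object dies rather than
        # at interpreter shutdown; see the "reasons and caveats" comment above.
        self._graph.reset()
```

Two such wrappers could hint at a shared mempool by pre-creating a handle, e.g. `shared = torch._C._graph_pool_handle()` and passing `pool=shared` to both `capture_begin` calls.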