mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Another attempt to update NVTX to NVTX3. We now avoid changing NVTX header inclusion of existing code. Pull Request resolved: https://github.com/pytorch/pytorch/pull/109843 Approved by: https://github.com/peterbell10
21 lines
545 B
C++
21 lines
545 B
C++
#ifdef _WIN32
|
|
#include <wchar.h> // _wgetenv for nvtx
|
|
#endif
|
|
#include <nvtx3/nvToolsExt.h>
|
|
#include <torch/csrc/utils/pybind.h>
|
|
|
|
namespace torch::cuda::shared {
|
|
|
|
void initNvtxBindings(PyObject* module) {
|
|
auto m = py::handle(module).cast<py::module>();
|
|
|
|
auto nvtx = m.def_submodule("_nvtx", "nvtx3 bindings");
|
|
nvtx.def("rangePushA", nvtxRangePushA);
|
|
nvtx.def("rangePop", nvtxRangePop);
|
|
nvtx.def("rangeStartA", nvtxRangeStartA);
|
|
nvtx.def("rangeEnd", nvtxRangeEnd);
|
|
nvtx.def("markA", nvtxMarkA);
|
|
}
|
|
|
|
} // namespace torch::cuda::shared
|