mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
This reverts commit e64ddd1ab9.
Reverted https://github.com/pytorch/pytorch/pull/90689 on behalf of https://github.com/osalpekar because of build failures: one nvtx3 header could not be found in FRL jobs: [D42332540](https://www.internalfb.com/diff/D42332540)
25 lines
600 B
C++
25 lines
600 B
C++
#ifdef _WIN32
|
|
#include <wchar.h> // _wgetenv for nvtx
|
|
#endif
|
|
#include <nvToolsExt.h>
|
|
#include <torch/csrc/utils/pybind.h>
|
|
|
|
namespace torch {
|
|
namespace cuda {
|
|
namespace shared {
|
|
|
|
void initNvtxBindings(PyObject* module) {
|
|
auto m = py::handle(module).cast<py::module>();
|
|
|
|
auto nvtx = m.def_submodule("_nvtx", "libNvToolsExt.so bindings");
|
|
nvtx.def("rangePushA", nvtxRangePushA);
|
|
nvtx.def("rangePop", nvtxRangePop);
|
|
nvtx.def("rangeStartA", nvtxRangeStartA);
|
|
nvtx.def("rangeEnd", nvtxRangeEnd);
|
|
nvtx.def("markA", nvtxMarkA);
|
|
}
|
|
|
|
} // namespace shared
|
|
} // namespace cuda
|
|
} // namespace torch
|