pytorch/torch/csrc/autograd/init.cpp
Zachary DeVito cc7f09a372
Add cudaEvent support to the profiler (#3734)
* Add cudaEvent support to the profiler

This adds the ability to record cuda timings using cudaEventRecord
in the profiler. Since it doesn't require nvprof it is easier
to run than the nvprof path.

This also records a thread id for each event, which will make
tracing results easier to understand

* Add flow arrows from cpu to cuda event

* Fix no cuda build

* Review comments

* Move CUDA checks to one place
2017-11-16 13:58:09 -08:00

63 lines
2.5 KiB
C++

#include <Python.h>
#include "torch/csrc/utils/pybind.h"
#include "torch/csrc/autograd/profiler.h"
#include "THP.h"
// Portable "this point can never be reached" hint for the optimizer.
// MSVC spells it __assume(0); GCC/Clang use __builtin_unreachable().
// NOTE(review): no uses are visible in this chunk — presumably referenced
// elsewhere in the file; confirm before removing.
#ifdef _MSC_VER
#define ENSURE_UNREACHABLE __assume(0);
#else
#define ENSURE_UNREACHABLE __builtin_unreachable();
#endif
// Initializes the C side of torch.autograd: caches the Python-defined
// Variable/Function/StochasticFunction classes and the batchnorm double
// backwards function, then registers the profiler bindings (ProfilerEvent,
// ProfilerState, enable/disable, push/pop range) on the torch.autograd
// module via pybind11.
//
// Returns Py_True on success; on failure the THPUtils_assert* macros
// return NULL with a Python exception set.
PyObject * THPAutograd_initExtension(PyObject *_unused)
{
  THPUtils_assert_PyImport("torch.autograd", autograd_module);
  PyObject *autograd_dict = PyModule_GetDict(autograd_module);
  THPVariableClass = PyMapping_GetItemString(autograd_dict,(char*)"Variable");
  THPFunctionClass = PyMapping_GetItemString(autograd_dict,(char*)"Function");
  THPUtils_assert_PyImport("torch.nn._functions.thnn", thnn_functions);
  THPBatchNormBackwardBackwardFunction = PyObject_GetAttrString(thnn_functions,(char*)"batchnorm_double_backwards_fn");
  THPStochasticFunctionClass = PyMapping_GetItemString(autograd_dict,(char*)"StochasticFunction");
  THPUtils_assert(THPVariableClass, "couldn't find Variable class in "
          "torch.autograd module");
  THPUtils_assert(THPFunctionClass, "couldn't find Function class in "
          "torch.autograd module");
  THPUtils_assert(THPStochasticFunctionClass, "couldn't find "
          "StochasticFunction class in torch.autograd module");
  // Fix: PyObject_GetAttrString returns NULL (with an exception set) on
  // failure; previously this lookup was the only one left unchecked, so a
  // missing attribute would surface much later as a confusing NULL deref.
  THPUtils_assert(THPBatchNormBackwardBackwardFunction, "couldn't find "
          "batchnorm_double_backwards_fn in torch.nn._functions.thnn module");

  // Expose the profiler types and entry points on torch.autograd.
  auto m = py::handle(autograd_module).cast<py::module>();
  py::class_<torch::autograd::profiler::Event>(m,"ProfilerEvent")
    .def("kind",&torch::autograd::profiler::Event::kind)
    .def("name",&torch::autograd::profiler::Event::name)
    .def("thread_id",&torch::autograd::profiler::Event::thread_id)
    .def("cpu_elapsed_us",&torch::autograd::profiler::Event::cpu_elapsed_us)
    .def("cuda_elapsed_us",&torch::autograd::profiler::Event::cuda_elapsed_us)
    .def("has_cuda",&torch::autograd::profiler::Event::has_cuda);
  py::enum_<torch::autograd::profiler::ProfilerState>(m,"ProfilerState")
    .value("Disabled", torch::autograd::profiler::ProfilerState::Disabled)
    .value("CPU", torch::autograd::profiler::ProfilerState::CPU)
    .value("CUDA", torch::autograd::profiler::ProfilerState::CUDA)
    .value("NVTX", torch::autograd::profiler::ProfilerState::NVTX);
  m.def("_enable_profiler", torch::autograd::profiler::enableProfiler);
  m.def("_disable_profiler", torch::autograd::profiler::disableProfiler);
  // push/pop are no-ops while the profiler is disabled so that Python-side
  // range annotations can be left in place unconditionally.
  m.def("_push_range", [](const char *name) {
    using namespace torch::autograd::profiler;
    if (state == ProfilerState::Disabled) return;
    pushRange(name);
  });
  m.def("_pop_range", []() {
    using namespace torch::autograd::profiler;
    if (state == ProfilerState::Disabled) return;
    popRange();
  });
  Py_RETURN_TRUE;
}