[Profiler] Pull helper methods into dedicated file. (And start torch/csrc/profiler folder.) (#69255)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/69255

One thing that I've found as I optimize the profiler is that there's a lot of intermingled code, where the kineto profiler relies on the legacy (autograd) profiler for generic operations. This made optimization hard because I had to manage too many complex dependencies. (Exacerbated by the USE_KINETO #ifdefs sprinkled around.) This PR is the first of several to restructure the profiler(s) so the later optimizations go in easier.

Test Plan: Unit tests

Reviewed By: aaronenyeshi

Differential Revision: D32671972

fbshipit-source-id: efa83b40dde4216f368f2a5fa707360031a85707
parent b23890177f
commit ebc66bfeea
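The mechanical pattern the diff applies is sketched below for orientation (a minimal sketch assembled from the hunks that follow, not an excerpt of any single file): shared helpers move into the new torch/csrc/profiler/util.{h,cpp} under the namespace torch::profiler::impl, the old torch::autograd::profiler namespace re-exports a few of them via using-declarations, and call sites switch to the qualified names.

// New home for the shared helpers (torch/csrc/profiler/util.h)
namespace torch { namespace profiler { namespace impl {
uint64_t TORCH_API computeFlops(
    const std::string& op_name,
    const std::unordered_map<std::string, c10::IValue>& extra_args);
}}} // namespace torch::profiler::impl

// The legacy namespace keeps compiling via using-declarations (end of util.h):
namespace torch { namespace autograd { namespace profiler {
using torch::profiler::impl::computeFlops;
}}} // namespace torch::autograd::profiler

// Call sites in the kineto/legacy profilers spell out the new namespace, e.g.:
//   evt.setFlops(torch::profiler::impl::computeFlops(std::string(fn.name()), evt.extraArgs()));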
setup.py (1 line changed)
@@ -1029,6 +1029,7 @@ if __name__ == '__main__':
 'include/torch/csrc/jit/tensorexpr/*.h',
 'include/torch/csrc/jit/tensorexpr/operators/*.h',
 'include/torch/csrc/onnx/*.h',
+'include/torch/csrc/profiler/*.h',
 'include/torch/csrc/utils/*.h',
 'include/torch/csrc/tensor/*.h',
 'include/torch/csrc/lazy/core/*.h',
@@ -105,7 +105,6 @@ jit_core_sources = [
 # list for the shared files.

 core_sources_common = [
-"torch/csrc/autograd/profiler_utils.cpp",
 "torch/csrc/autograd/autograd_meta.cpp",
 "torch/csrc/autograd/forward_grad.cpp",
 "torch/csrc/jit/frontend/edit_distance.cpp",
@@ -122,6 +121,7 @@ core_sources_common = [
 "torch/csrc/jit/mobile/promoted_prim_ops.cpp",
 "torch/csrc/jit/mobile/prim_ops_registery.cpp",
 "torch/csrc/jit/operator_upgraders/upgraders.cpp",
+"torch/csrc/profiler/util.cpp",
 ]

 torch_unpickler_common = [
@@ -237,23 +237,8 @@ PyObject* THPAutograd_initExtension(PyObject* _unused, PyObject *unused) {
 py::arg("scopes") = std::unordered_set<at::RecordScope>());
 m.def("_disable_profiler", disableProfiler);
 m.def("_prepare_profiler", prepareProfiler);
-
-m.def("_add_metadata_json", [](const std::string& key, const std::string& value) {
-#ifdef USE_KINETO
-addMetadataJson(key, value);
-#else
-LOG(WARNING) << "Adding profiling metadata requires using "
-<< "torch.profiler with Kineto support (USE_KINETO=1)";
-#endif // USE_KINETO
-});
-
-m.def("kineto_available", []() {
-#ifdef USE_KINETO
-return true;
-#else
-return false;
-#endif
-});
+m.def("_add_metadata_json", torch::profiler::impl::addMetadataJson); // Only if `USE_KINETO` is set
+m.def("kineto_available", []() { return torch::profiler::kKinetoAvailable; });

 // NOTICE: These record functions are not torch operators and may not show up
 // in TorchScript tracing, FX transforms, or operator serialization. For these
@@ -1,8 +1,9 @@
|
|||
#include <torch/csrc/autograd/profiler_kineto.h>
|
||||
|
||||
#include <c10/macros/Export.h>
|
||||
#include <c10/util/flat_hash_map.h>
|
||||
#include <c10/util/irange.h>
|
||||
#include <limits>
|
||||
#include <torch/csrc/autograd/profiler_kineto.h>
|
||||
|
||||
#include <torch/csrc/jit/frontend/tracer.h>
|
||||
#include <torch/csrc/jit/runtime/operator.h>
|
||||
|
|
@@ -104,11 +105,6 @@ void _push_reverse_order(PyTraceEvent* e, std::vector<std::string>& names) {
|
|||
|
||||
namespace {
|
||||
|
||||
std::string shapesToStr(const std::vector<std::vector<int64_t>>& shapes);
|
||||
std::string stacksToStr(const std::vector<std::string>& stacks, const char* delim);
|
||||
std::string dtypesToStr(const std::vector<std::string>& types);
|
||||
std::vector<std::string> inputTypes(const at::RecordFunction& fn);
|
||||
|
||||
// Assumption: Total threads number will not exceed 2^16-1, and total ops will not exceed 2^48 -1.
|
||||
static inline uint64_t getForwardThreadKey(uint64_t tid, uint64_t seqNr) {
|
||||
return (((tid) << 48) | ((seqNr) & (((uint64_t)1 << 48) - 1)));
|
||||
|
|
@@ -180,7 +176,7 @@ struct KinetoThreadLocalState : public ProfilerThreadLocalState {
 kineto_events_.back().moduleHierarchy(*ctx->module_hierarchy);
 }
 if (ctx->extraArgs && !ctx->extraArgs->empty()) {
-kineto_events_.back().flops(computeFlops(std::string(evt_name), *ctx->extraArgs));
+kineto_events_.back().flops(torch::profiler::impl::computeFlops(std::string(evt_name), *ctx->extraArgs));
 }
 kineto_events_.back().cuda_event_start_ = ctx->cuda_event_start_;
 kineto_events_.back().cuda_event_end_ = ctx->cuda_event_end_;
@@ -325,18 +321,18 @@ struct KinetoThreadLocalState : public ProfilerThreadLocalState {
|
|||
auto& activity = cpu_trace->activities[idx];
|
||||
|
||||
if (kineto_event.hasShapes()) {
|
||||
activity.addMetadata("Input Dims", shapesToStr(kineto_event.shapes()));
|
||||
activity.addMetadata("Input Dims", torch::profiler::impl::shapesToStr(kineto_event.shapes()));
|
||||
}
|
||||
if (kineto_event.hasStack()) {
|
||||
// NB: This is only for the JIT stack. The python stack (if applicable)
|
||||
// is constructed later.
|
||||
activity.addMetadata("Call stack", stacksToStr(kineto_event.stack(), ";"));
|
||||
activity.addMetadata("Call stack", torch::profiler::impl::stacksToStr(kineto_event.stack(), ";"));
|
||||
}
|
||||
if (kineto_event.hasModuleHierarchy()) {
|
||||
activity.addMetadata("Module Hierarchy", stacksToStr(kineto_event.moduleHierarchy(), "."));
|
||||
activity.addMetadata("Module Hierarchy", torch::profiler::impl::stacksToStr(kineto_event.moduleHierarchy(), "."));
|
||||
}
|
||||
if (kineto_event.hasTypes()) {
|
||||
activity.addMetadata("Input type", dtypesToStr(kineto_event.dtypes()));
|
||||
activity.addMetadata("Input type", torch::profiler::impl::dtypesToStr(kineto_event.dtypes()));
|
||||
}
|
||||
if (!kineto_event.backend().empty()) {
|
||||
activity.addMetadata("Backend", "\"" + kineto_event.backend() + "\"");
|
||||
|
|
@@ -472,7 +468,7 @@ struct KinetoThreadLocalState : public ProfilerThreadLocalState {
 std::vector<std::string> py_names;
 _push_reverse_order(python_caller, py_names);
 kineto_events_[idx].stack(py_names);
-activity.addMetadata("Call stack", stacksToStr(py_names, ";"));
+activity.addMetadata("Call stack", torch::profiler::impl::stacksToStr(py_names, ";"));
 }

 cpu_trace->activities.push_back(activity);
@@ -532,27 +528,6 @@ struct KinetoThreadLocalState : public ProfilerThreadLocalState {
|
|||
std::function<void(std::vector<KinetoEvent>&)> event_post_process_cb_;
|
||||
};
|
||||
|
||||
std::vector<std::string> inputTypes(const at::RecordFunction& fn) {
|
||||
std::vector<std::string> types;
|
||||
types.reserve(fn.inputs().size());
|
||||
for (const c10::IValue& input : fn.inputs()) {
|
||||
if (input.isTensor()) {
|
||||
const at::Tensor& tensor = input.toTensor();
|
||||
if (tensor.defined()) {
|
||||
types.push_back(
|
||||
static_cast<std::string>(input.toTensor().dtype().name()));
|
||||
} else {
|
||||
types.emplace_back();
|
||||
}
|
||||
} else if (input.isScalar() || input.isList()) {
|
||||
types.push_back(input.tagKind());
|
||||
} else {
|
||||
types.emplace_back();
|
||||
}
|
||||
}
|
||||
return types;
|
||||
}
|
||||
|
||||
KinetoThreadLocalState* getProfilerTLSState() {
|
||||
const auto& state = c10::ThreadLocalDebugInfo::get(
|
||||
c10::DebugInfoKind::PROFILER_STATE);
|
||||
|
|
@@ -582,12 +557,12 @@ void pushProfilingCallbacks(const std::unordered_set<at::RecordScope>& scopes) {
|
|||
ctx_ptr->debug_handle = fn.debugHandle();
|
||||
|
||||
if (config.report_input_shapes) {
|
||||
ctx_ptr->shapes = inputSizes(fn);
|
||||
ctx_ptr->dtypes = inputTypes(fn);
|
||||
ctx_ptr->shapes = torch::profiler::impl::inputSizes(fn);
|
||||
ctx_ptr->dtypes = torch::profiler::impl::inputTypes(fn);
|
||||
}
|
||||
|
||||
if (config.with_flops) {
|
||||
ctx_ptr->extraArgs = saveExtraArgs(fn);
|
||||
ctx_ptr->extraArgs = torch::profiler::impl::saveExtraArgs(fn);
|
||||
}
|
||||
|
||||
ctx_ptr->sequenceNr = fn.seqNr();
|
||||
|
|
@@ -599,7 +574,7 @@ void pushProfilingCallbacks(const std::unordered_set<at::RecordScope>& scopes) {
 // TODO: consider using C++ stack trace
 if (config.with_stack &&
 fn.scope() != at::RecordScope::BACKWARD_FUNCTION) {
-auto cs = prepareCallstack(jit::currentCallstack());
+auto cs = torch::profiler::impl::prepareCallstack(jit::currentCallstack());
 ctx_ptr->stack = callstackStr(cs);
 }
 if (config.with_modules &&
@@ -619,9 +594,9 @@ void pushProfilingCallbacks(const std::unordered_set<at::RecordScope>& scopes) {
|
|||
} else if (config.state == ProfilerState::NVTX) {
|
||||
std::vector<std::vector<int64_t>> shapes;
|
||||
if (config.report_input_shapes) {
|
||||
shapes = inputSizes(fn);
|
||||
shapes = torch::profiler::impl::inputSizes(fn);
|
||||
}
|
||||
cudaStubs()->nvtxRangePushA(getNvtxStr(
|
||||
cudaStubs()->nvtxRangePushA(torch::profiler::impl::getNvtxStr(
|
||||
fn.name(), fn.seqNr(), shapes).c_str());
|
||||
}
|
||||
return nullptr;
|
||||
|
|
@@ -662,59 +637,6 @@ void pushProfilingCallbacks(const std::unordered_set<at::RecordScope>& scopes) {
|
|||
state_ptr->setCallbackHandle(handle);
|
||||
}
|
||||
|
||||
std::string shapesToStr(const std::vector<std::vector<int64_t>>& shapes) {
|
||||
std::ostringstream oss;
|
||||
oss << "[";
|
||||
for (const auto t_idx : c10::irange(shapes.size())) {
|
||||
if (t_idx > 0) {
|
||||
oss << ", ";
|
||||
}
|
||||
oss << "[";
|
||||
for (size_t s_idx = 0; s_idx < shapes[t_idx].size(); ++s_idx) {
|
||||
if (s_idx > 0) {
|
||||
oss << ", ";
|
||||
}
|
||||
oss << shapes[t_idx][s_idx];
|
||||
}
|
||||
oss << "]";
|
||||
}
|
||||
oss << "]";
|
||||
return oss.str();
|
||||
}
|
||||
|
||||
std::string dtypesToStr(const std::vector<std::string>& types) {
|
||||
if (types.empty()) {
|
||||
return "[]";
|
||||
} else {
|
||||
std::ostringstream oss;
|
||||
std::transform(
|
||||
types.begin(),
|
||||
types.end(),
|
||||
std::ostream_iterator<std::string>(oss, ", "),
|
||||
[](std::string s) -> std::string { return "\"" + s + "\""; });
|
||||
auto rc = oss.str();
|
||||
rc.erase(rc.length() - 2); // remove last ", "
|
||||
return "[" + rc + "]";
|
||||
}
|
||||
}
|
||||
|
||||
std::string stacksToStr(const std::vector<std::string>& stacks, const char* delim) {
|
||||
std::ostringstream oss;
|
||||
std::transform(
|
||||
stacks.begin(),
|
||||
stacks.end(),
|
||||
std::ostream_iterator<std::string>(oss, delim),
|
||||
[](std::string s) -> std::string {
|
||||
#ifdef _WIN32
|
||||
// replace the windows backslash with forward slash
|
||||
std::replace(s.begin(), s.end(), '\\', '/');
|
||||
#endif
|
||||
return s;
|
||||
});
|
||||
auto rc = oss.str();
|
||||
return "\"" + rc + "\"";
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
void reportBackendEventToActiveKinetoProfiler(
|
||||
|
|
@@ -883,16 +805,6 @@ std::unique_ptr<ProfilerResult> disableProfiler() {
|
|||
#endif // USE_KINETO
|
||||
}
|
||||
|
||||
void addMetadataJson(const std::string& key, const std::string& value) {
|
||||
#ifdef USE_KINETO
|
||||
if (libkineto::api().isProfilerInitialized()) {
|
||||
libkineto::api().activityProfiler().addMetadata(key, value);
|
||||
} else {
|
||||
LOG(WARNING) << "Profiler is not initialized: skipping profiling metadata";
|
||||
}
|
||||
#endif // USE_KINETO
|
||||
}
|
||||
|
||||
int64_t KinetoEvent::cudaElapsedUs() const {
|
||||
if (!cuda_event_start_ || !cuda_event_end_) {
|
||||
return -1;
|
||||
|
|
|
|||
|
|
@@ -4,17 +4,6 @@
|
|||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#ifdef USE_KINETO
|
||||
// skip Kineto dependency on mobile
|
||||
// unless explicitly asked for.
|
||||
// When is it explicitly asked for?
|
||||
// KinetoEdgeCPUProfiler uses KinetoProfiler for cpu
|
||||
// event profiling. This has dependency on cpu only libkineto
|
||||
#if defined(C10_MOBILE) && !defined(EDGE_PROFILER_USE_KINETO)
|
||||
#undef USE_KINETO
|
||||
#endif
|
||||
#endif
|
||||
|
||||
#ifdef USE_KINETO
|
||||
namespace libkineto {
|
||||
struct TraceActivity;
|
||||
|
|
@@ -407,10 +396,6 @@ TORCH_API void prepareProfiler(
|
|||
const ProfilerConfig& config,
|
||||
const std::set<ActivityType>& activities);
|
||||
|
||||
TORCH_API void addMetadataJson(
|
||||
const std::string& key, const std::string& value);
|
||||
|
||||
|
||||
namespace python_tracer {
|
||||
|
||||
/*
|
||||
|
|
|
|||
|
|
@@ -1,4 +1,5 @@
|
|||
#include <torch/csrc/autograd/profiler.h>
|
||||
#include <torch/csrc/autograd/profiler_legacy.h>
|
||||
|
||||
#include <torch/csrc/autograd/function.h>
|
||||
#include <torch/csrc/jit/frontend/code_template.h>
|
||||
|
||||
|
|
@@ -24,34 +25,6 @@
|
|||
|
||||
namespace torch { namespace autograd { namespace profiler {
|
||||
|
||||
std::vector<FileLineFunc> prepareCallstack(const std::vector<jit::StackEntry>& cs) {
|
||||
std::vector<FileLineFunc> entries;
|
||||
entries.reserve(cs.size());
|
||||
for (const auto& entry : cs) {
|
||||
auto& range = entry.range;
|
||||
if (range.source()) {
|
||||
auto& src = range.source();
|
||||
if (src && src->filename()) {
|
||||
auto line = src->starting_line_no() +
|
||||
src->lineno_for_offset(range.start());
|
||||
entries.emplace_back(FileLineFunc{*(src->filename()), line, entry.filename});
|
||||
}
|
||||
}
|
||||
}
|
||||
return entries;
|
||||
}
|
||||
|
||||
std::vector<std::string> callstackStr(const std::vector<FileLineFunc>& cs) {
|
||||
std::vector<std::string> cs_str;
|
||||
cs_str.reserve(cs.size());
|
||||
for (const auto& entry : cs) {
|
||||
std::stringstream loc;
|
||||
loc << entry.filename << "(" << entry.line << "): " << entry.funcname;
|
||||
cs_str.push_back(loc.str());
|
||||
}
|
||||
return cs_str;
|
||||
}
|
||||
|
||||
// We decompose the profiler logic into the following components:
|
||||
//
|
||||
// ThreadLocalDebugInfo:
|
||||
|
|
@@ -216,7 +189,7 @@ void ProfilerThreadLocalState::pushRange(
 return;
 }
 if (config_.state == ProfilerState::NVTX) {
-cuda_stubs()->nvtxRangePushA(getNvtxStr(
+cuda_stubs()->nvtxRangePushA(torch::profiler::impl::getNvtxStr(
 fn.name(), fn.seqNr(), shapes).c_str());
 } else {
 LegacyEvent evt(
@@ -232,8 +205,8 @@ void ProfilerThreadLocalState::pushRange(
|
|||
evt.setFwdThreadId(fn.forwardThreadId());
|
||||
evt.setScope((uint8_t)fn.scope());
|
||||
if (config_.with_flops) {
|
||||
evt.setExtraArgs(saveExtraArgs(fn));
|
||||
evt.setFlops(computeFlops(std::string(fn.name()), evt.extraArgs()));
|
||||
evt.setExtraArgs(torch::profiler::impl::saveExtraArgs(fn));
|
||||
evt.setFlops(torch::profiler::impl::computeFlops(std::string(fn.name()), evt.extraArgs()));
|
||||
}
|
||||
|
||||
// TODO: will unify the two macros BUILD_LITE_INTERPRETER and C10_MOBILE soon.
|
||||
|
|
@@ -241,9 +214,9 @@ void ProfilerThreadLocalState::pushRange(
|
|||
// backward nodes source range corresponds to the forward node
|
||||
// TODO: consider using C++ stack trace
|
||||
if (config_.with_stack && fn.scope() != at::RecordScope::BACKWARD_FUNCTION) {
|
||||
auto cs = prepareCallstack(jit::currentCallstack());
|
||||
auto cs = torch::profiler::impl::prepareCallstack(jit::currentCallstack());
|
||||
if (cs.empty()) {
|
||||
cs = prepareCallstack(jit::tracer::pythonCallstack());
|
||||
cs = torch::profiler::impl::prepareCallstack(jit::tracer::pythonCallstack());
|
||||
}
|
||||
evt.setStack(callstackStr(cs));
|
||||
}
|
||||
|
|
@@ -296,53 +269,6 @@ bool ProfilerThreadLocalState::memoryProfilingEnabled() const {
|
|||
return config_.profile_memory;
|
||||
}
|
||||
|
||||
std::string getNvtxStr(
|
||||
const char* name,
|
||||
int64_t sequence_nr,
|
||||
const std::vector<std::vector<int64_t>>& shapes) {
|
||||
if (sequence_nr >= -1 || shapes.size() > 0) {
|
||||
std::stringstream s;
|
||||
#if defined(USE_ROCM)
|
||||
s << name;
|
||||
#endif
|
||||
if (sequence_nr >= 0) {
|
||||
#if defined(USE_ROCM)
|
||||
s << ", seq = " << sequence_nr;
|
||||
#else
|
||||
s << name << ", seq = " << sequence_nr;
|
||||
#endif
|
||||
} else if (sequence_nr == -1) {
|
||||
#if !defined(USE_ROCM)
|
||||
s << name;
|
||||
#endif
|
||||
}
|
||||
if (shapes.size() > 0) {
|
||||
s << ", sizes = [";
|
||||
for (const auto idx : c10::irange(shapes.size())) {
|
||||
if (shapes[idx].size() > 0) {
|
||||
s << "[";
|
||||
for (size_t dim = 0; dim < shapes[idx].size(); ++dim) {
|
||||
s << shapes[idx][dim];
|
||||
if (dim < shapes[idx].size() - 1) {
|
||||
s << ", ";
|
||||
}
|
||||
}
|
||||
s << "]";
|
||||
} else {
|
||||
s << "[]";
|
||||
}
|
||||
if (idx < shapes.size() - 1) {
|
||||
s << ", ";
|
||||
}
|
||||
}
|
||||
s << "]";
|
||||
}
|
||||
return s.str();
|
||||
} else {
|
||||
return name;
|
||||
}
|
||||
}
|
||||
|
||||
RangeEventList& ProfilerThreadLocalState::getEventList(int64_t thread_id) {
|
||||
if (thread_id < 0) {
|
||||
thread_id = at::RecordFunction::currentThreadId();
|
||||
|
|
@@ -360,24 +286,6 @@ RangeEventList& ProfilerThreadLocalState::getEventList(int64_t thread_id) {
|
|||
return *list_ptr;
|
||||
}
|
||||
|
||||
std::vector<std::vector<int64_t>> inputSizes(const at::RecordFunction& fn) {
|
||||
std::vector<std::vector<int64_t>> sizes;
|
||||
sizes.reserve(fn.inputs().size());
|
||||
for (const c10::IValue& input : fn.inputs()) {
|
||||
if (!input.isTensor()) {
|
||||
sizes.emplace_back();
|
||||
continue;
|
||||
}
|
||||
const at::Tensor& tensor = input.toTensor();
|
||||
if (tensor.defined()) {
|
||||
sizes.push_back(input.toTensor().sizes().vec());
|
||||
} else {
|
||||
sizes.emplace_back();
|
||||
}
|
||||
}
|
||||
return sizes;
|
||||
}
|
||||
|
||||
namespace {
|
||||
|
||||
enum EventIValueIdx {
|
||||
|
|
@@ -442,7 +350,7 @@ void pushProfilingCallbacksLegacy() {
 }

 if (state_ptr->config().report_input_shapes) {
-auto sizes = inputSizes(fn);
+auto sizes = torch::profiler::impl::inputSizes(fn);
 state_ptr->pushRange(fn, record_cuda, std::move(sizes));
 } else {
 state_ptr->pushRange(fn, record_cuda);
@@ -10,18 +10,8 @@
|
|||
#include <forward_list>
|
||||
#include <tuple>
|
||||
#include <ATen/ATen.h>
|
||||
#include <torch/csrc/profiler/util.h>
|
||||
#include <torch/csrc/Export.h>
|
||||
#include <torch/csrc/autograd/profiler_utils.h>
|
||||
#ifndef _WIN32
|
||||
#include <ctime>
|
||||
#endif
|
||||
#if defined(C10_IOS) && defined(C10_MOBILE)
|
||||
#include <sys/time.h> // for gettimeofday()
|
||||
#endif
|
||||
|
||||
#include <ATen/record_function.h>
|
||||
|
||||
#include <torch/csrc/jit/frontend/source_range.h>
|
||||
|
||||
struct CUevent_st;
|
||||
typedef std::shared_ptr<CUevent_st> CUDAEventStub;
|
||||
|
|
@ -69,33 +59,6 @@ private:
|
|||
TORCH_API void registerCUDAMethods(CUDAStubs* stubs);
|
||||
TORCH_API const CUDAStubs* cudaStubs();
|
||||
|
||||
constexpr inline size_t ceilToMultiple(size_t a, size_t b) {
|
||||
return ((a + b - 1) / b) * b;
|
||||
}
|
||||
|
||||
inline int64_t getTime(bool allow_monotonic = false) {
|
||||
#if defined(C10_IOS) && defined(C10_MOBILE)
|
||||
// clock_gettime is only available on iOS 10.0 or newer. Unlike OS X, iOS can't rely on
|
||||
// CLOCK_REALTIME, as it is defined no matter if clock_gettime is implemented or not
|
||||
struct timeval now;
|
||||
gettimeofday(&now, NULL);
|
||||
return static_cast<int64_t>(now.tv_sec) * 1000000000 + static_cast<int64_t>(now.tv_usec) * 1000;
|
||||
#elif defined(_WIN32) || defined(__MACH__)
|
||||
using namespace std::chrono;
|
||||
using clock = std::conditional<high_resolution_clock::is_steady, high_resolution_clock, steady_clock>::type;
|
||||
return duration_cast<nanoseconds>(clock::now().time_since_epoch()).count();
|
||||
#else
|
||||
// clock_gettime is *much* faster than std::chrono implementation on Linux
|
||||
struct timespec t{};
|
||||
auto mode = CLOCK_REALTIME;
|
||||
if (allow_monotonic) {
|
||||
mode = CLOCK_MONOTONIC;
|
||||
}
|
||||
clock_gettime(mode, &t);
|
||||
return static_cast<int64_t>(t.tv_sec) * 1000000000 + static_cast<int64_t>(t.tv_nsec);
|
||||
#endif
|
||||
}
|
||||
|
||||
enum class C10_API_ENUM EventKind : uint16_t {
|
||||
Mark,
|
||||
PushRange,
|
||||
|
|
@@ -394,11 +357,6 @@ struct RangeEventList {
|
|||
static const size_t kReservedCapacity = 1024;
|
||||
};
|
||||
|
||||
std::string getNvtxStr(
|
||||
const char* name,
|
||||
int64_t sequence_nr,
|
||||
const std::vector<std::vector<int64_t>>& shapes);
|
||||
|
||||
enum class C10_API_ENUM ProfilerState {
|
||||
Disabled = 0,
|
||||
CPU, // CPU-only profiling
|
||||
|
|
@@ -526,15 +484,6 @@ struct TORCH_API TLSLegacyProfilerGuard {
|
|||
const c10::optional<ProfilerDisableOptions> profilerDisableOptions_;
|
||||
};
|
||||
|
||||
struct TORCH_API FileLineFunc {
|
||||
std::string filename;
|
||||
size_t line;
|
||||
std::string funcname;
|
||||
};
|
||||
TORCH_API std::vector<FileLineFunc> prepareCallstack(const std::vector<jit::StackEntry>& cs);
|
||||
TORCH_API std::vector<std::string> callstackStr(const std::vector<FileLineFunc>& cs);
|
||||
TORCH_API std::vector<std::vector<int64_t>> inputSizes(const at::RecordFunction& fn);
|
||||
|
||||
struct TORCH_API ProfilerThreadLocalState : public c10::MemoryReportingInfoBase {
|
||||
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
|
||||
explicit ProfilerThreadLocalState(const ProfilerConfig& config)
|
||||
|
|
@@ -603,8 +552,6 @@ struct TORCH_API ProfilerThreadLocalState : public c10::MemoryReportingInfoBase
|
|||
namespace torch {
|
||||
namespace profiler {
|
||||
namespace impl {
|
||||
using torch::autograd::profiler::computeFlops;
|
||||
using torch::autograd::profiler::getTime;
|
||||
using torch::autograd::profiler::ProfilerConfig;
|
||||
using torch::autograd::profiler::ProfilerState;
|
||||
} // impl
|
||||
|
|
|
|||
|
|
@@ -1,16 +0,0 @@
|
|||
#pragma once
|
||||
|
||||
#include <string>
|
||||
#include <unordered_map>
|
||||
#include <ATen/ATen.h>
|
||||
#include <ATen/record_function.h>
|
||||
|
||||
namespace torch { namespace autograd {
|
||||
namespace profiler {
|
||||
|
||||
std::unordered_map<std::string, c10::IValue> TORCH_API saveExtraArgs(const at::RecordFunction& fn);
|
||||
|
||||
uint64_t TORCH_API computeFlops(const std::string &op_name,
|
||||
const std::unordered_map<std::string, c10::IValue> &extra_args);
|
||||
|
||||
}}}
|
||||
torch/csrc/profiler/util.cpp (new file, 563 lines)
@@ -0,0 +1,563 @@
|
|||
#include <torch/csrc/profiler/util.h>
|
||||
|
||||
#include <c10/util/ArrayRef.h>
|
||||
|
||||
#ifdef USE_KINETO
|
||||
#include <libkineto.h>
|
||||
#endif
|
||||
|
||||
namespace torch {
|
||||
namespace profiler {
|
||||
namespace impl {
|
||||
|
||||
void addMetadataJson(const std::string& key, const std::string& value) {
|
||||
#ifdef USE_KINETO
|
||||
if (libkineto::api().isProfilerInitialized()) {
|
||||
libkineto::api().activityProfiler().addMetadata(key, value);
|
||||
} else {
|
||||
LOG(WARNING) << "Profiler is not initialized: skipping profiling metadata";
|
||||
}
|
||||
#else
|
||||
LOG(WARNING) << "Adding profiling metadata requires using "
|
||||
<< "torch.profiler with Kineto support (USE_KINETO=1)";
|
||||
#endif // USE_KINETO
|
||||
}
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
// -- NVTX --------------------------------------------------------------------
|
||||
// ----------------------------------------------------------------------------
|
||||
std::string getNvtxStr(
|
||||
const char* name,
|
||||
int64_t sequence_nr,
|
||||
const std::vector<std::vector<int64_t>>& shapes) {
|
||||
if (sequence_nr >= -1 || shapes.size() > 0) {
|
||||
std::stringstream s;
|
||||
#if defined(USE_ROCM)
|
||||
s << name;
|
||||
#endif
|
||||
if (sequence_nr >= 0) {
|
||||
#if defined(USE_ROCM)
|
||||
s << ", seq = " << sequence_nr;
|
||||
#else
|
||||
s << name << ", seq = " << sequence_nr;
|
||||
#endif
|
||||
} else if (sequence_nr == -1) {
|
||||
#if !defined(USE_ROCM)
|
||||
s << name;
|
||||
#endif
|
||||
}
|
||||
if (shapes.size() > 0) {
|
||||
s << ", sizes = [";
|
||||
for (const auto idx : c10::irange(shapes.size())) {
|
||||
if (shapes[idx].size() > 0) {
|
||||
s << "[";
|
||||
for (size_t dim = 0; dim < shapes[idx].size(); ++dim) {
|
||||
s << shapes[idx][dim];
|
||||
if (dim < shapes[idx].size() - 1) {
|
||||
s << ", ";
|
||||
}
|
||||
}
|
||||
s << "]";
|
||||
} else {
|
||||
s << "[]";
|
||||
}
|
||||
if (idx < shapes.size() - 1) {
|
||||
s << ", ";
|
||||
}
|
||||
}
|
||||
s << "]";
|
||||
}
|
||||
return s.str();
|
||||
} else {
|
||||
return name;
|
||||
}
|
||||
}
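// Illustration (not part of the diff): with the formatting above,
//   getNvtxStr("aten::mm", 5, {{2, 3}, {3, 4}})
// produces the NVTX range name
//   "aten::mm, seq = 5, sizes = [[2, 3], [3, 4]]"
// while sequence_nr == -1 with no shapes falls back to the bare op name.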
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
// -- Op context (shapes, call stack) -----------------------------------------
|
||||
// ----------------------------------------------------------------------------
|
||||
std::vector<FileLineFunc> prepareCallstack(
|
||||
const std::vector<jit::StackEntry>& cs) {
|
||||
std::vector<FileLineFunc> entries;
|
||||
entries.reserve(cs.size());
|
||||
for (const auto& entry : cs) {
|
||||
auto& range = entry.range;
|
||||
if (range.source()) {
|
||||
auto& src = range.source();
|
||||
if (src && src->filename()) {
|
||||
auto line =
|
||||
src->starting_line_no() + src->lineno_for_offset(range.start());
|
||||
entries.emplace_back(
|
||||
FileLineFunc{*(src->filename()), line, entry.filename});
|
||||
}
|
||||
}
|
||||
}
|
||||
return entries;
|
||||
}
|
||||
|
||||
std::vector<std::string> callstackStr(const std::vector<FileLineFunc>& cs) {
|
||||
std::vector<std::string> cs_str;
|
||||
cs_str.reserve(cs.size());
|
||||
for (const auto& entry : cs) {
|
||||
std::stringstream loc;
|
||||
loc << entry.filename << "(" << entry.line << "): " << entry.funcname;
|
||||
cs_str.push_back(loc.str());
|
||||
}
|
||||
return cs_str;
|
||||
}
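// Illustration (not part of the diff): a FileLineFunc{"model.py", 42, "forward"}
// entry renders as "model.py(42): forward".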
|
||||
|
||||
std::string stacksToStr(
|
||||
const std::vector<std::string>& stacks,
|
||||
const char* delim) {
|
||||
std::ostringstream oss;
|
||||
std::transform(
|
||||
stacks.begin(),
|
||||
stacks.end(),
|
||||
std::ostream_iterator<std::string>(oss, delim),
|
||||
[](std::string s) -> std::string {
|
||||
#ifdef _WIN32
|
||||
// replace the windows backslash with forward slash
|
||||
std::replace(s.begin(), s.end(), '\\', '/');
|
||||
#endif
|
||||
return s;
|
||||
});
|
||||
auto rc = oss.str();
|
||||
return "\"" + rc + "\"";
|
||||
}
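// Illustration (not part of the diff): stacksToStr({"model.py(42): forward",
// "linear.py(10): __call__"}, ";") returns
//   "model.py(42): forward;linear.py(10): __call__;"
// (quoted, and with a trailing delimiter because std::ostream_iterator appends
// the delimiter after every element, including the last).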
|
||||
|
||||
std::vector<std::vector<int64_t>> inputSizes(const at::RecordFunction& fn) {
|
||||
std::vector<std::vector<int64_t>> sizes;
|
||||
sizes.reserve(fn.inputs().size());
|
||||
for (const c10::IValue& input : fn.inputs()) {
|
||||
if (!input.isTensor()) {
|
||||
sizes.emplace_back();
|
||||
continue;
|
||||
}
|
||||
const at::Tensor& tensor = input.toTensor();
|
||||
if (tensor.defined()) {
|
||||
sizes.push_back(input.toTensor().sizes().vec());
|
||||
} else {
|
||||
sizes.emplace_back();
|
||||
}
|
||||
}
|
||||
return sizes;
|
||||
}
|
||||
|
||||
std::string shapesToStr(const std::vector<std::vector<int64_t>>& shapes) {
|
||||
std::ostringstream oss;
|
||||
oss << "[";
|
||||
for (const auto t_idx : c10::irange(shapes.size())) {
|
||||
if (t_idx > 0) {
|
||||
oss << ", ";
|
||||
}
|
||||
oss << "[";
|
||||
for (size_t s_idx = 0; s_idx < shapes[t_idx].size(); ++s_idx) {
|
||||
if (s_idx > 0) {
|
||||
oss << ", ";
|
||||
}
|
||||
oss << shapes[t_idx][s_idx];
|
||||
}
|
||||
oss << "]";
|
||||
}
|
||||
oss << "]";
|
||||
return oss.str();
|
||||
}
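// Illustration (not part of the diff): shapesToStr({{2, 3}, {}}) returns
// "[[2, 3], []]", and an empty argument returns "[]".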
|
||||
|
||||
std::string dtypesToStr(const std::vector<std::string>& types) {
|
||||
if (types.empty()) {
|
||||
return "[]";
|
||||
} else {
|
||||
std::ostringstream oss;
|
||||
std::transform(
|
||||
types.begin(),
|
||||
types.end(),
|
||||
std::ostream_iterator<std::string>(oss, ", "),
|
||||
[](std::string s) -> std::string { return "\"" + s + "\""; });
|
||||
auto rc = oss.str();
|
||||
rc.erase(rc.length() - 2); // remove last ", "
|
||||
return "[" + rc + "]";
|
||||
}
|
||||
}
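// Illustration (not part of the diff): dtypesToStr({"float", "long int"}) returns
// ["float", "long int"] with each dtype name individually double-quoted inside
// the brackets; an empty vector returns "[]".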
|
||||
|
||||
std::vector<std::string> inputTypes(const at::RecordFunction& fn) {
|
||||
std::vector<std::string> types;
|
||||
types.reserve(fn.inputs().size());
|
||||
for (const c10::IValue& input : fn.inputs()) {
|
||||
if (input.isTensor()) {
|
||||
const at::Tensor& tensor = input.toTensor();
|
||||
if (tensor.defined()) {
|
||||
types.push_back(
|
||||
static_cast<std::string>(input.toTensor().dtype().name()));
|
||||
} else {
|
||||
types.emplace_back();
|
||||
}
|
||||
} else if (input.isScalar() || input.isList()) {
|
||||
types.push_back(input.tagKind());
|
||||
} else {
|
||||
types.emplace_back();
|
||||
}
|
||||
}
|
||||
return types;
|
||||
}
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
// -- FLOPS -------------------------------------------------------------------
|
||||
// ----------------------------------------------------------------------------
|
||||
static constexpr auto kConv2dStride = 3;
|
||||
static constexpr auto kConv2dPadding = 4;
|
||||
static constexpr auto kConv2dDilation = 5;
|
||||
static constexpr auto kConv2dGroups = 6;
|
||||
|
||||
// List of supported operators
|
||||
static constexpr auto kConv2dOp = "aten::conv2d";
|
||||
static constexpr auto kMMOp = "aten::mm";
|
||||
static constexpr auto kAddMMOp = "aten::addmm";
|
||||
static constexpr auto kMulOp = "aten::mul";
|
||||
static constexpr auto kAddOp = "aten::add";
|
||||
static constexpr auto kBMMOp = "aten::bmm";
|
||||
static constexpr auto kBAddBMMOp = "aten::baddbmm";
|
||||
|
||||
static constexpr auto kInputSize = "input_size";
|
||||
static constexpr auto kWeightSize = "weight_size";
|
||||
static constexpr auto kGroups = "groups";
|
||||
static constexpr auto kPadding = "padding";
|
||||
static constexpr auto kStride = "stride";
|
||||
static constexpr auto kDilation = "dilation";
|
||||
static constexpr auto kMatSize = "mat_size";
|
||||
static constexpr auto kMat1Size = "mat1_size";
|
||||
static constexpr auto kMat2Size = "mat2_size";
|
||||
|
||||
static bool validateInput(
|
||||
const std::string& op_name,
|
||||
size_t min_size,
|
||||
const std::vector<c10::IValue>& inputs,
|
||||
const c10::ArrayRef<int>& should_be_tensor) {
|
||||
std::stringstream ss;
|
||||
if (inputs.size() < min_size) {
|
||||
ss << "Failed to save extra arguments for flops compuation of op "
|
||||
<< op_name << ", min size: " << min_size
|
||||
<< ", actual size: " << inputs.size();
|
||||
TORCH_WARN(ss.str());
|
||||
return false;
|
||||
}
|
||||
for (auto index : should_be_tensor) {
|
||||
if (!inputs[index].isTensor()) {
|
||||
ss << "Failed to save extra arguments for flops compuation of op "
|
||||
<< op_name << ", input[" << index << "] must be a tensor.";
|
||||
TORCH_WARN(ss.str());
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
std::unordered_map<std::string, c10::IValue> saveExtraArgs(
|
||||
const at::RecordFunction& fn) {
|
||||
// for specific types of fn, return the saved extra args for computing flops
|
||||
std::unordered_map<std::string, c10::IValue> map;
|
||||
std::vector<c10::IValue> inputs = fn.inputs();
|
||||
std::string fname(fn.name());
|
||||
|
||||
if (inputs.empty()) {
|
||||
// Input shape is unavailable, return empty map
|
||||
return map;
|
||||
}
|
||||
|
||||
if (fname == kConv2dOp) {
|
||||
bool check = validateInput(fname, kConv2dGroups + 1, inputs, {0, 1});
|
||||
if (!check) {
|
||||
return map;
|
||||
}
|
||||
|
||||
at::Tensor input = inputs[0].toTensor();
|
||||
at::Tensor weight = inputs[1].toTensor();
|
||||
if (weight.sizes().size() != 4) {
|
||||
TORCH_WARN(
|
||||
"Failed to compute flops for op aten::conv2d because it requires a 4D kernel tensor.");
|
||||
return map;
|
||||
}
|
||||
map[kInputSize] = at::IValue(input.sizes());
|
||||
map[kWeightSize] = at::IValue(weight.sizes());
|
||||
map[kStride] = inputs[kConv2dStride];
|
||||
map[kPadding] = inputs[kConv2dPadding];
|
||||
map[kDilation] = inputs[kConv2dDilation];
|
||||
map[kGroups] = inputs[kConv2dGroups];
|
||||
} else if (fname == kMMOp) {
|
||||
bool check = validateInput(fname, 2, inputs, {0, 1});
|
||||
if (!check) {
|
||||
return map;
|
||||
}
|
||||
|
||||
at::Tensor left = inputs[0].toTensor();
|
||||
at::Tensor right = inputs[1].toTensor();
|
||||
map[kMat1Size] = at::IValue(left.sizes());
|
||||
map[kMat2Size] = at::IValue(right.sizes());
|
||||
} else if (fname == kAddMMOp) {
|
||||
bool check = validateInput(fname, 3, inputs, {0, 1, 2});
|
||||
if (!check) {
|
||||
return map;
|
||||
}
|
||||
|
||||
// Exact FLOP count depends on scaling factors alpha and beta but
|
||||
// just assume these are +=1.
|
||||
// (similar to http://www.netlib.org/lapack/lawnspdf/lawn41.pdf,
|
||||
// "Operations Count for the BLAS and LAPACK", Table 3, SGEMM)
|
||||
at::Tensor left = inputs[1].toTensor();
|
||||
at::Tensor right = inputs[2].toTensor();
|
||||
map[kMat1Size] = at::IValue(left.sizes());
|
||||
map[kMat2Size] = at::IValue(right.sizes());
|
||||
} else if (fname == kMulOp) {
|
||||
bool check = validateInput(fname, 1, inputs, {0});
|
||||
if (!check) {
|
||||
return map;
|
||||
}
|
||||
|
||||
at::Tensor mat = inputs[0].toTensor();
|
||||
map[kMatSize] = at::IValue(mat.sizes());
|
||||
} else if (fname == kAddOp) {
|
||||
bool check = validateInput(fname, 1, inputs, {0});
|
||||
if (!check) {
|
||||
return map;
|
||||
}
|
||||
|
||||
at::Tensor mat = inputs[0].toTensor();
|
||||
map[kMatSize] = at::IValue(mat.sizes());
|
||||
} else if (fname == kBMMOp) {
|
||||
bool check = validateInput(fname, 2, inputs, {0, 1});
|
||||
if (!check) {
|
||||
return map;
|
||||
}
|
||||
|
||||
at::Tensor left = inputs[0].toTensor();
|
||||
at::Tensor right = inputs[1].toTensor();
|
||||
map[kMat1Size] = at::IValue(left.sizes());
|
||||
map[kMat2Size] = at::IValue(right.sizes());
|
||||
} else if (fname == kBAddBMMOp) {
|
||||
bool check = validateInput(fname, 3, inputs, {0, 1, 2});
|
||||
if (!check) {
|
||||
return map;
|
||||
}
|
||||
|
||||
// Exact FLOP count depends on scaling factors alpha and beta but
|
||||
// just assume these are +=1.
|
||||
// (similar to http://www.netlib.org/lapack/lawnspdf/lawn41.pdf,
|
||||
// "Operations Count for the BLAS and LAPACK", Table 3, SGEMM)
|
||||
at::Tensor left = inputs[1].toTensor();
|
||||
at::Tensor right = inputs[2].toTensor();
|
||||
map[kMat1Size] = at::IValue(left.sizes());
|
||||
map[kMat2Size] = at::IValue(right.sizes());
|
||||
}
|
||||
|
||||
return map;
|
||||
}
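// Illustration (not part of the diff): for aten::mm called on tensors of shape
// [64, 128] and [128, 32], the returned map holds
//   mat1_size -> [64, 128], mat2_size -> [128, 32]
// which is exactly the input computeFlops() below expects.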
|
||||
|
||||
uint64_t computeFlops(
|
||||
const std::string& op_name,
|
||||
const std::unordered_map<std::string, c10::IValue>& extra_args) {
|
||||
if (op_name == kConv2dOp) {
|
||||
if (extra_args.find(kInputSize) == extra_args.end() ||
|
||||
extra_args.find(kWeightSize) == extra_args.end() ||
|
||||
extra_args.find(kGroups) == extra_args.end() ||
|
||||
extra_args.find(kPadding) == extra_args.end() ||
|
||||
extra_args.find(kStride) == extra_args.end() ||
|
||||
extra_args.find(kDilation) == extra_args.end()) {
|
||||
TORCH_WARN(
|
||||
"Calculating flops for aten::conv2d requires groups, padding, stride, dilation, input_size, and weight_size in saved arguments.");
|
||||
return 0;
|
||||
}
|
||||
auto input_sizes_ref = extra_args.at(kInputSize);
|
||||
auto kernel_sizes_ref = extra_args.at(kWeightSize);
|
||||
auto groups_ref = extra_args.at(kGroups);
|
||||
auto padding_ref = extra_args.at(kPadding);
|
||||
auto stride_ref = extra_args.at(kStride);
|
||||
auto dilation_ref = extra_args.at(kDilation);
|
||||
if (!input_sizes_ref.isIntList() || !kernel_sizes_ref.isIntList()) {
|
||||
TORCH_WARN(
|
||||
"Failed to compute flops for op aten::conv2d because it requires input and weight tensor sizes.");
|
||||
return 0;
|
||||
}
|
||||
if (!padding_ref.isIntList() || !stride_ref.isIntList() ||
|
||||
!dilation_ref.isIntList()) {
|
||||
TORCH_WARN(
|
||||
"Failed to compute flops for op aten::conv2d because it requires padding, stride, and dilation values.");
|
||||
return 0;
|
||||
}
|
||||
|
||||
const std::vector<int64_t> input_sizes = input_sizes_ref.toIntVector();
|
||||
const std::vector<int64_t> kernel_sizes = kernel_sizes_ref.toIntVector();
|
||||
const uint64_t groups = groups_ref.toInt();
|
||||
const std::vector<int64_t> padding = padding_ref.toIntVector();
|
||||
const std::vector<int64_t> stride = stride_ref.toIntVector();
|
||||
const std::vector<int64_t> dilation = dilation_ref.toIntVector();
|
||||
if (input_sizes.size() != 4 || kernel_sizes.size() != 4) {
|
||||
TORCH_WARN(
|
||||
"Failed to compute flops for op aten::conv2d because both input and weight must be size 4.");
|
||||
return 0;
|
||||
}
|
||||
if (!groups) {
|
||||
TORCH_WARN(
|
||||
"Failed to compute flops for op aten::conv2d because group size must not be 0.");
|
||||
return 0;
|
||||
}
|
||||
if (padding.size() != 2 || dilation.size() != 2) {
|
||||
TORCH_WARN(
|
||||
"Failed to compute flops for op aten::conv2d because both padding and dilation must be size 2.");
|
||||
return 0;
|
||||
}
|
||||
if (stride.size() != 2 || (stride[0] * stride[1] == 0)) {
|
||||
TORCH_WARN(
|
||||
"Failed to compute flops for op aten::conv2d because stride must be size 2 and cannot be 0.");
|
||||
return 0;
|
||||
}
|
||||
// format of the input is defined in torch.nn.quantized.functional.conv2d()
|
||||
uint64_t minibatch = 0, in_channels = 0, input_h = 0, input_w = 0;
|
||||
uint64_t out_channels = 0, kernel_h = 0, kernel_w = 0;
|
||||
const uint64_t conv2d_multiply_factor = 2;
|
||||
std::tie(minibatch, in_channels, input_h, input_w) = std::make_tuple(
|
||||
input_sizes[0], input_sizes[1], input_sizes[2], input_sizes[3]);
|
||||
std::tie(out_channels, std::ignore, kernel_h, kernel_w) = std::make_tuple(
|
||||
kernel_sizes[0], kernel_sizes[1], kernel_sizes[2], kernel_sizes[3]);
|
||||
uint64_t output_h =
|
||||
(input_h + 2 * padding[0] - dilation[0] * (kernel_h - 1) - 1) /
|
||||
stride[0] +
|
||||
1;
|
||||
uint64_t output_w =
|
||||
(input_w + 2 * padding[1] - dilation[1] * (kernel_w - 1) - 1) /
|
||||
stride[1] +
|
||||
1;
|
||||
|
||||
return conv2d_multiply_factor * minibatch * output_h * output_w * kernel_h *
|
||||
kernel_w * in_channels * out_channels / groups;
|
||||
} else if (op_name == kMMOp || op_name == kAddMMOp) {
|
||||
if (extra_args.find(kMat1Size) == extra_args.end() ||
|
||||
extra_args.find(kMat2Size) == extra_args.end()) {
|
||||
TORCH_WARN(
|
||||
"Calculating flops for ",
|
||||
op_name,
|
||||
" requires mat1_size and mat2_size in saved arguments.");
|
||||
return 0;
|
||||
}
|
||||
auto mat1_sizes_ref = extra_args.at(kMat1Size);
|
||||
auto mat2_sizes_ref = extra_args.at(kMat2Size);
|
||||
if (!mat1_sizes_ref.isIntList() || !mat2_sizes_ref.isIntList()) {
|
||||
TORCH_WARN(
|
||||
"Failed to compute flops for op ",
|
||||
op_name,
|
||||
" because it requires mat1_size and mat2_size to be IntList.");
|
||||
return 0;
|
||||
}
|
||||
|
||||
std::vector<int64_t> mat1_size = mat1_sizes_ref.toIntVector();
|
||||
std::vector<int64_t> mat2_size = mat2_sizes_ref.toIntVector();
|
||||
if (mat1_size.size() == 0) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
int64_t overlap_dim = mat1_size.back();
|
||||
if (overlap_dim == 0) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
const uint64_t gemm_multiply_factor = 2;
|
||||
uint64_t flops = 1;
|
||||
for (int64_t dim : mat1_size) {
|
||||
flops *= dim;
|
||||
}
|
||||
flops /= overlap_dim;
|
||||
for (int64_t dim : mat2_size) {
|
||||
flops *= dim;
|
||||
}
|
||||
flops *= gemm_multiply_factor;
|
||||
return flops;
|
||||
} else if (op_name == kBMMOp || op_name == kBAddBMMOp) {
|
||||
if (extra_args.find(kMat1Size) == extra_args.end() ||
|
||||
extra_args.find(kMat2Size) == extra_args.end()) {
|
||||
TORCH_WARN(
|
||||
"Calculating flops for ",
|
||||
op_name,
|
||||
" requires mat1_size and mat2_size in saved arguments.");
|
||||
return 0;
|
||||
}
|
||||
auto mat1_sizes_ref = extra_args.at(kMat1Size);
|
||||
auto mat2_sizes_ref = extra_args.at(kMat2Size);
|
||||
if (!mat1_sizes_ref.isIntList() || !mat2_sizes_ref.isIntList()) {
|
||||
TORCH_WARN(
|
||||
"Failed to compute flops for op ",
|
||||
op_name,
|
||||
" because it requires mat1_size and mat2_size to be IntList.");
|
||||
return 0;
|
||||
}
|
||||
|
||||
std::vector<int64_t> mat1_size = mat1_sizes_ref.toIntVector();
|
||||
std::vector<int64_t> mat2_size = mat2_sizes_ref.toIntVector();
|
||||
if (mat1_size.size() == 0) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
int64_t batch_size = mat1_size.front();
|
||||
if (batch_size == 0) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
int64_t overlap_dim = mat1_size.back();
|
||||
if (overlap_dim == 0) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
const uint64_t gemm_multiply_factor = 2;
|
||||
uint64_t flops = 1;
|
||||
for (int64_t dim : mat1_size) {
|
||||
flops *= dim;
|
||||
}
|
||||
flops /= overlap_dim;
|
||||
flops /= batch_size;
|
||||
for (int64_t dim : mat2_size) {
|
||||
flops *= dim;
|
||||
}
|
||||
flops *= gemm_multiply_factor;
|
||||
return flops;
|
||||
} else if (op_name == kMulOp) {
|
||||
if (extra_args.find(kMatSize) == extra_args.end()) {
|
||||
TORCH_WARN(
|
||||
"Calculating flops for aten::mul.Tensor requires mat_size in saved arguments.");
|
||||
return 0;
|
||||
}
|
||||
auto mat_sizes = extra_args.at(kMatSize);
|
||||
if (!mat_sizes.isIntList()) {
|
||||
TORCH_WARN(
|
||||
"Failed to compute flops for op aten::mul because it requires mat_size to be IntList.");
|
||||
return 0;
|
||||
}
|
||||
|
||||
std::vector<int64_t> mat_size = mat_sizes.toIntVector();
|
||||
uint64_t flops = 1;
|
||||
for (int64_t dim : mat_size) {
|
||||
flops *= dim;
|
||||
}
|
||||
return flops;
|
||||
} else if (op_name == kAddOp) {
|
||||
if (extra_args.find(kMatSize) == extra_args.end()) {
|
||||
TORCH_WARN(
|
||||
"Calculating flops for aten::add.Tensor requires mat_size in saved arguments.");
|
||||
return 0;
|
||||
}
|
||||
auto mat_sizes = extra_args.at(kMatSize);
|
||||
if (!mat_sizes.isIntList()) {
|
||||
TORCH_WARN(
|
||||
"Failed to compute flops for op aten::add because it requires mat_size to be IntList.");
|
||||
return 0;
|
||||
}
|
||||
|
||||
std::vector<int64_t> mat_size = mat_sizes.toIntVector();
|
||||
uint64_t flops = 1;
|
||||
for (int64_t dim : mat_size) {
|
||||
flops *= dim;
|
||||
}
|
||||
return flops;
|
||||
}
|
||||
return 0;
|
||||
}
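// Illustration (not part of the diff): the branches above reduce to
//   aten::mm / aten::addmm    : 2 * M * K * N       for mat1 [M, K], mat2 [K, N]
//   aten::bmm / aten::baddbmm : 2 * B * M * K * N   for mat1 [B, M, K], mat2 [B, K, N]
//   aten::conv2d              : 2 * N * H_out * W_out * kH * kW * C_in * C_out / groups
//   aten::mul / aten::add     : the number of elements of the saved mat_size
// e.g. mm of [64, 128] x [128, 32] is counted as 2 * 64 * 128 * 32 = 524,288 FLOPs.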
|
||||
|
||||
} // namespace impl
|
||||
} // namespace profiler
|
||||
} // namespace torch
|
||||
torch/csrc/profiler/util.h (new file, 119 lines)
@@ -0,0 +1,119 @@
|
|||
#pragma once
|
||||
|
||||
#include <cstddef>
|
||||
#include <cstdint>
|
||||
#include <string>
|
||||
#include <unordered_map>
|
||||
#include <vector>
|
||||
|
||||
#include <c10/macros/Macros.h>
|
||||
#include <ATen/ATen.h>
|
||||
#include <ATen/record_function.h>
|
||||
#include <torch/csrc/Export.h>
|
||||
#include <torch/csrc/jit/frontend/source_range.h>
|
||||
|
||||
#ifndef _WIN32
|
||||
#include <ctime>
|
||||
#endif
|
||||
#if defined(C10_IOS) && defined(C10_MOBILE)
|
||||
#include <sys/time.h> // for gettimeofday()
|
||||
#endif
|
||||
|
||||
// skip Kineto dependency on mobile unless explicitly asked for.
|
||||
// When is it explicitly asked for?
|
||||
// KinetoEdgeCPUProfiler uses KinetoProfiler for cpu
|
||||
// event profiling. This has dependency on cpu only libkineto
|
||||
#if defined(USE_KINETO) && defined(C10_MOBILE) && \
|
||||
!defined(EDGE_PROFILER_USE_KINETO)
|
||||
#undef USE_KINETO
|
||||
#endif
|
||||
|
||||
namespace torch {
|
||||
namespace profiler {
|
||||
|
||||
#ifdef USE_KINETO
|
||||
constexpr bool kKinetoAvailable {true};
|
||||
#else
|
||||
constexpr bool kKinetoAvailable {false};
|
||||
#endif
|
||||
|
||||
namespace impl {
|
||||
|
||||
inline int64_t getTime(bool allow_monotonic = false) {
|
||||
#if defined(C10_IOS) && defined(C10_MOBILE)
|
||||
// clock_gettime is only available on iOS 10.0 or newer. Unlike OS X, iOS
|
||||
// can't rely on CLOCK_REALTIME, as it is defined no matter if clock_gettime
|
||||
// is implemented or not
|
||||
struct timeval now;
|
||||
gettimeofday(&now, NULL);
|
||||
return static_cast<int64_t>(now.tv_sec) * 1000000000 +
|
||||
static_cast<int64_t>(now.tv_usec) * 1000;
|
||||
#elif defined(_WIN32) || defined(__MACH__)
|
||||
using namespace std::chrono;
|
||||
using clock = std::conditional<
|
||||
high_resolution_clock::is_steady,
|
||||
high_resolution_clock,
|
||||
steady_clock>::type;
|
||||
return duration_cast<nanoseconds>(clock::now().time_since_epoch()).count();
|
||||
#else
|
||||
// clock_gettime is *much* faster than std::chrono implementation on Linux
|
||||
struct timespec t {};
|
||||
auto mode = CLOCK_REALTIME;
|
||||
if (allow_monotonic) {
|
||||
mode = CLOCK_MONOTONIC;
|
||||
}
|
||||
clock_gettime(mode, &t);
|
||||
return static_cast<int64_t>(t.tv_sec) * 1000000000 +
|
||||
static_cast<int64_t>(t.tv_nsec);
|
||||
#endif
|
||||
}
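// Illustration (not part of the diff): every branch above returns a timestamp in
// nanoseconds, so a microsecond-resolution measurement looks like
//   int64_t start = getTime();
//   /* ... timed region ... */
//   int64_t elapsed_us = (getTime() - start) / 1000;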
|
||||
|
||||
// NB: This only works if USE_KINETO is set. (Otherwise it just logs a warning)
|
||||
TORCH_API void addMetadataJson(
|
||||
const std::string& key,
|
||||
const std::string& value);
|
||||
|
||||
std::string getNvtxStr(
|
||||
const char* name,
|
||||
int64_t sequence_nr,
|
||||
const std::vector<std::vector<int64_t>>& shapes);
|
||||
|
||||
struct TORCH_API FileLineFunc {
|
||||
std::string filename;
|
||||
size_t line;
|
||||
std::string funcname;
|
||||
};
|
||||
|
||||
TORCH_API std::vector<FileLineFunc> prepareCallstack(
|
||||
const std::vector<jit::StackEntry>& cs);
|
||||
TORCH_API std::vector<std::string> callstackStr(
|
||||
const std::vector<FileLineFunc>& cs);
|
||||
TORCH_API std::string stacksToStr(
|
||||
const std::vector<std::string>& stacks,
|
||||
const char* delim);
|
||||
TORCH_API std::vector<std::vector<int64_t>> inputSizes(
|
||||
const at::RecordFunction& fn);
|
||||
TORCH_API std::string shapesToStr(
|
||||
const std::vector<std::vector<int64_t>>& shapes);
|
||||
TORCH_API std::string dtypesToStr(const std::vector<std::string>& types);
|
||||
TORCH_API std::vector<std::string> inputTypes(const at::RecordFunction& fn);
|
||||
|
||||
std::unordered_map<std::string, c10::IValue> TORCH_API
|
||||
saveExtraArgs(const at::RecordFunction& fn);
|
||||
|
||||
uint64_t TORCH_API computeFlops(
|
||||
const std::string& op_name,
|
||||
const std::unordered_map<std::string, c10::IValue>& extra_args);
|
||||
|
||||
} // namespace impl
|
||||
} // namespace profiler
|
||||
} // namespace torch
|
||||
|
||||
namespace torch {
|
||||
namespace autograd {
|
||||
namespace profiler {
|
||||
using torch::profiler::impl::getTime;
|
||||
using torch::profiler::impl::computeFlops;
|
||||
} // namespace profiler
|
||||
} // namespace autograd
|
||||
} // namespace torch
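// Illustration (not part of the diff): the using-declarations above keep existing
// call sites building unchanged, e.g. code written as
//   torch::autograd::profiler::getTime()
// now resolves to torch::profiler::impl::getTime() without edits at the call site.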