#pragma once #include #include #include #include #include #include #include namespace torch::jit { namespace profiling { struct Datapoint { using Timepoint = std::chrono::time_point; SourceRange sourceRange; Timepoint start; Timepoint end; explicit Datapoint(SourceRange sr) : sourceRange(std::move(sr)), start(std::chrono::steady_clock::now()) {} }; class TORCH_API InstructionSpan { public: explicit InstructionSpan(Node& /*node*/); ~InstructionSpan(); InstructionSpan(InstructionSpan&&) = delete; InstructionSpan& operator=(InstructionSpan&&) = delete; private: std::unique_ptr datapoint_; }; bool TORCH_API isProfilingOngoing(); } // namespace profiling struct TORCH_API InstructionStats : public CustomClassHolder { int64_t count{0}; std::chrono::nanoseconds duration{0}; }; class TORCH_API SourceStats : public CustomClassHolder { public: using LineMap = c10::Dict>; SourceStats(SourceRef source, const LineMap& lineMap) : source_(std::move(source)), lineMap_(lineMap) {} const SourceRef& getSourceRef() const { return source_; } const LineMap& getLineMap() const { return lineMap_; } private: SourceRef source_; LineMap lineMap_; }; /** * ScriptProfile is an underlying C++ implementation for TorchScript profiling. * The profiling section is specified by calling enable() and disable(): * * ... * scriptProfile.enable(); * ... * (scripts) * ... * scriptProfile.disable(); * ... * * NOTE: you cannot attach the profiler while the script is running. * * To retrieve collected runtime data, users may call dumpStats() and do * arbitrary filtering on the data they want. Note that dumpStats() should * not be called inside a profiling section. * In general, stats are aggregated per source function body, and then by line * number. */ class TORCH_API ScriptProfile : public CustomClassHolder { // Aggregates datapoints by function source id, then by line number. using LineMap = std::map; using SourceMap = std::map>; public: void enable(); void disable(); const SourceMap& dumpStats(); void addDatapoint(std::shared_ptr /*datapoint*/); ~ScriptProfile() override; private: bool enabled_{false}; std::vector> datapoints_; SourceMap sourceMap_; }; } // namespace torch::jit