mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
Revert "Fast standalone symbolize for unwinding (#123966)"
This reverts commit 772ae6da1e.
Reverted https://github.com/pytorch/pytorch/pull/123966 on behalf of https://github.com/jeanschmidt due to Breaking internal builds, check D56522678 ([comment](https://github.com/pytorch/pytorch/pull/123966#issuecomment-2076821043))
This commit is contained in:
parent
2d7f709752
commit
c0fd7894cc
|
|
@ -16,11 +16,8 @@ except ImportError:
|
|||
import collections
|
||||
import gc
|
||||
import json
|
||||
import mmap
|
||||
import os
|
||||
import random
|
||||
import re
|
||||
import struct
|
||||
import subprocess
|
||||
import sys
|
||||
import tempfile
|
||||
|
|
@ -77,9 +74,7 @@ from torch.testing._internal.common_cuda import TEST_CUDA, TEST_MULTIGPU
|
|||
from torch.testing._internal.common_device_type import skipCUDAVersionIn
|
||||
from torch.testing._internal.common_utils import (
|
||||
instantiate_parametrized_tests,
|
||||
IS_ARM64,
|
||||
IS_JETSON,
|
||||
IS_LINUX,
|
||||
IS_WINDOWS,
|
||||
parametrize,
|
||||
run_tests,
|
||||
|
|
@ -3602,70 +3597,6 @@ aten::mm""",
|
|||
finally:
|
||||
os.remove("torchtidy_report.json")
|
||||
|
||||
@unittest.skipIf(IS_ARM64 or not IS_LINUX, "x86 linux only cpp unwinding")
def test_fuzz_symbolize(self):
    """Fuzz the C++ symbolizers with random addresses from mapped libraries.

    Picks 200 pseudo-random addresses (fixed seed, so runs are reproducible)
    inside the ``.text`` sections of the shared libraries mapped into this
    process and checks that every symbolizer mode returns without raising
    or crashing.
    """
    map_files_dir = "/proc/self/map_files"

    def get_text_sections():
        """Return ``(start_address, size)`` of ``.text`` for each mapped ``.so``."""
        text_sections = []
        seen = set()
        for filename in os.listdir(map_files_dir):
            library = os.readlink(os.path.join(map_files_dir, filename))
            if ".so" not in library or library in seen:
                continue
            seen.add(library)
            # The link target is opened directly. Both the file and the
            # mapping are context-managed so they are released on every
            # path, including the early ``continue`` for non-ELF files.
            with open(library, "rb") as f, mmap.mmap(
                f.fileno(), 0, prot=mmap.PROT_READ
            ) as mm:

                def unpack(fmt, offset):
                    # Read straight from the mapping without copying a slice.
                    return struct.unpack_from(fmt, mm, offset)

                if mm[:4] != b"\x7fELF":
                    continue
                # ELF64 file header fields.
                (section_headers_start,) = unpack("Q", 40)  # e_shoff
                (section_header_size,) = unpack("H", 58)  # e_shentsize
                (num_section_headers,) = unpack("H", 60)  # e_shnum
                (shstrndx,) = unpack("H", 62)  # e_shstrndx
                # File offset (sh_offset) of the section-name string table.
                (shstrtab_offset,) = unpack(
                    "Q", section_headers_start + shstrndx * section_header_size + 24
                )
                for i in range(num_section_headers):
                    header = section_headers_start + i * section_header_size
                    (section_name_offset,) = unpack("I", header)  # sh_name
                    name_start = shstrtab_offset + section_name_offset
                    if mm[name_start : name_start + 6] != b".text\0":
                        continue
                    (section_offset,) = unpack("Q", header + 24)  # sh_offset
                    (section_size,) = unpack("Q", header + 32)  # sh_size
                    # NOTE(review): assumes the first map_files entry seen for
                    # a library is the mapping of file offset 0 -- confirm.
                    start = int(filename.split("-")[0], 16) + section_offset
                    if section_size > 0:
                        # An empty section has no address to sample from.
                        text_sections.append((start, section_size))
                    break
        return text_sections

    r = random.Random()
    r.seed(1)
    text_sections = get_text_sections()
    # Fail with a readable message instead of a ValueError from randrange.
    self.assertTrue(
        text_sections, "no mapped shared library with a .text section found"
    )
    addrs = []
    for _ in range(200):
        s = r.randrange(0, len(text_sections))
        start, size = text_sections[s]
        addrs.append(r.randrange(start, start + size))
    fast = torch._C._profiler.symbolize_addresses(addrs, "fast")
    # The dladdr mode is only exercised for "does not throw/crash".
    torch._C._profiler.symbolize_addresses(addrs, "dladdr")
    addr2line = torch._C._profiler.symbolize_addresses(addrs, "addr2line")
    self.assertEqual(len(fast), len(addrs))
    self.assertEqual(len(addr2line), len(fast))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
run_tests()
|
||||
|
|
|
|||
|
|
@ -16,6 +16,7 @@ import subprocess
|
|||
import random
|
||||
from random import randint
|
||||
import json
|
||||
|
||||
import torch
|
||||
import torch.cuda
|
||||
from torch.cuda._memory_viz import profile_plot, _profile_to_snapshot
|
||||
|
|
|
|||
|
|
@ -168,14 +168,12 @@ static PyObject* THPModule_initExtension(
|
|||
PyObject* shm_manager_path) {
|
||||
HANDLE_TH_ERRORS
|
||||
#if !defined(FBCODE_CAFFE2)
|
||||
if (torch::get_cpp_stacktraces_enabled()) {
|
||||
if (torch::get_cpp_stacktraces_enabled() && !torch::get_disable_addr2line()) {
|
||||
c10::SetStackTraceFetcher([]() -> std::string {
|
||||
auto tb = torch::CapturedTraceback::gather(false, false, true);
|
||||
if (torch::get_symbolize_mode() == torch::unwind::Mode::addr2line) {
|
||||
LOG(WARNING)
|
||||
<< "symbolizing C++ stack trace for exception; if this hangs, rerun with TORCH_DISABLE_ADDR2LINE=1..."
|
||||
<< std::endl;
|
||||
}
|
||||
LOG(WARNING)
|
||||
<< "symbolizing C++ stack trace for exception; if this hangs, rerun with TORCH_DISABLE_ADDR2LINE=1..."
|
||||
<< std::endl;
|
||||
auto s_tbs = torch::symbolize({tb.get()});
|
||||
std::stringstream oss;
|
||||
oss << "C++ CapturedTraceback:" << std::endl;
|
||||
|
|
|
|||
|
|
@ -1,5 +1,4 @@
|
|||
#include <torch/csrc/profiler/combined_traceback.h>
|
||||
#include <torch/csrc/utils/cpp_stacktraces.h>
|
||||
|
||||
namespace torch {
|
||||
|
||||
|
|
@ -78,7 +77,7 @@ SymbolizedTracebacks symbolize(
|
|||
}
|
||||
// gather symbol names for C++ frames
|
||||
if (!all_cpp_ips.empty()) {
|
||||
r.all_frames = unwind::symbolize(all_cpp_ips, torch::get_symbolize_mode());
|
||||
r.all_frames = unwind::symbolize(all_cpp_ips);
|
||||
}
|
||||
|
||||
// batch symbolization requests so we dedup frame objects
|
||||
|
|
|
|||
|
|
@ -79,7 +79,8 @@ PyTypeObject THPCapturedTracebackType = {
|
|||
nullptr, /* tp_new */
|
||||
};
|
||||
|
||||
namespace pybind11::detail {
|
||||
namespace pybind11 {
|
||||
namespace detail {
|
||||
|
||||
template <>
|
||||
struct type_caster<std::shared_ptr<torch::CapturedTraceback>> {
|
||||
|
|
@ -106,9 +107,11 @@ struct type_caster<std::shared_ptr<torch::CapturedTraceback>> {
|
|||
}
|
||||
};
|
||||
|
||||
} // namespace pybind11::detail
|
||||
} // namespace detail
|
||||
} // namespace pybind11
|
||||
|
||||
namespace torch::profiler {
|
||||
namespace torch {
|
||||
namespace profiler {
|
||||
|
||||
/* [NOTE: RecordFunctionFast]
|
||||
* This is an alternate way to call record_function from python.
|
||||
|
|
@ -604,33 +607,6 @@ void initPythonBindings(PyObject* module) {
|
|||
}
|
||||
return py_symbolize(tb_ptrs);
|
||||
});
|
||||
// directly convert address pointers to frames, used for testing symbolize
|
||||
m.def(
|
||||
"symbolize_addresses",
|
||||
[](const std::vector<uint64_t>& frames, const std::string& mode_s) {
|
||||
std::vector<std::tuple<std::string, int64_t, std::string>> frames_out;
|
||||
torch::unwind::Mode mode = torch::unwind::Mode::addr2line;
|
||||
if (mode_s == "fast") {
|
||||
mode = torch::unwind::Mode::fast;
|
||||
} else if (mode_s == "addr2line") {
|
||||
mode = torch::unwind::Mode::addr2line;
|
||||
} else if (mode_s == "dladdr") {
|
||||
mode = torch::unwind::Mode::dladdr;
|
||||
} else {
|
||||
TORCH_CHECK(false, "unexpected mode ", mode_s);
|
||||
}
|
||||
std::vector<void*> frames_p;
|
||||
frames_p.reserve(frames.size());
|
||||
for (auto f : frames) {
|
||||
frames_p.push_back((void*)f); // NOLINT
|
||||
}
|
||||
auto frame_objects = unwind::symbolize(frames_p, mode);
|
||||
frames_out.reserve(frame_objects.size());
|
||||
for (auto& frame : frame_objects) {
|
||||
frames_out.emplace_back(frame.filename, frame.lineno, frame.funcname);
|
||||
}
|
||||
return frames_out;
|
||||
});
|
||||
installCapturedTracebackPython();
|
||||
|
||||
// NOLINTNEXTLINE(*-c-arrays*)
|
||||
|
|
@ -664,4 +640,5 @@ void initPythonBindings(PyObject* module) {
|
|||
throw python_error();
|
||||
}
|
||||
}
|
||||
} // namespace torch::profiler
|
||||
} // namespace profiler
|
||||
} // namespace torch
|
||||
|
|
|
|||
|
|
@ -2,8 +2,6 @@
|
|||
#include <stdint.h>
|
||||
#include <ostream>
|
||||
|
||||
namespace torch::unwind {
|
||||
|
||||
enum {
|
||||
A_UNDEFINED = 0x0,
|
||||
A_REG_PLUS_DATA = 0x1, // exp = REG[reg] + data0
|
||||
|
|
@ -55,5 +53,3 @@ struct Action {
|
|||
return out;
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace torch::unwind
|
||||
|
|
|
|||
|
|
@ -5,7 +5,6 @@
|
|||
#include <unistd.h>
|
||||
#include <memory>
|
||||
|
||||
namespace torch::unwind {
|
||||
// helper to open a process with stdin/stdout/stderr streams.
|
||||
struct Communicate {
|
||||
Communicate(const char* command, const char** args) {
|
||||
|
|
@ -64,5 +63,3 @@ struct Communicate {
|
|||
std::unique_ptr<std::ostream> out_;
|
||||
std::unique_ptr<std::ostream> err_;
|
||||
};
|
||||
|
||||
} // namespace torch::unwind
|
||||
|
|
|
|||
|
|
@ -1,279 +0,0 @@
|
|||
#pragma once
|
||||
#include <torch/csrc/profiler/unwind/dwarf_enums.h>
|
||||
#include <torch/csrc/profiler/unwind/dwarf_symbolize_enums.h>
|
||||
#include <torch/csrc/profiler/unwind/lexer.h>
|
||||
#include <torch/csrc/profiler/unwind/sections.h>
|
||||
#include <torch/csrc/profiler/unwind/unwind_error.h>
|
||||
#include <cstdint>
|
||||
#include <optional>
|
||||
|
||||
namespace torch::unwind {
|
||||
|
||||
struct DebugInfo {
|
||||
DebugInfo(Sections& s) : s_(s) {}
|
||||
|
||||
void parse(uint64_t offset) {
|
||||
auto L = parseHeader(offset);
|
||||
parseCompileUnit(L);
|
||||
}
|
||||
unwind::optional<uint64_t> lineNumberProgramOffset() {
|
||||
return line_number_program_offset_;
|
||||
}
|
||||
uint64_t nextOffset() {
|
||||
return end_ - s_.debug_info.data;
|
||||
}
|
||||
std::vector<std::pair<uint64_t, uint64_t>> ranges() {
|
||||
if (range_ptr_) {
|
||||
auto offset = range_ptr_->first;
|
||||
if (range_ptr_->second == DW_FORM_rnglistx) {
|
||||
UNWIND_CHECK(rnglists_base_, "rnglistx but not rnglists_base_ set");
|
||||
LOG_INFO("index for rnglistx {:x} + {:x}\n", *rnglists_base_, offset);
|
||||
CheckedLexer L = s_.debug_rnglists.lexer(
|
||||
*rnglists_base_ + offset * sec_offset_size_);
|
||||
auto read = readSegmentOffset(L);
|
||||
offset = *rnglists_base_ + read;
|
||||
}
|
||||
return version_ == 4 ? readRanges4(offset) : readRanges5(offset);
|
||||
}
|
||||
if (!highpc_) {
|
||||
return {};
|
||||
}
|
||||
return {{lowpc_, lowpc_ + *highpc_}};
|
||||
}
|
||||
|
||||
bool is64bit() {
|
||||
return is_64bit_;
|
||||
}
|
||||
|
||||
private:
|
||||
CheckedLexer parseHeader(uint64_t offset) {
|
||||
offset_ = offset;
|
||||
CheckedLexer L = s_.debug_info.lexer(offset_);
|
||||
std::tie(length_, is_64bit_) = L.readSectionLength();
|
||||
sec_offset_size_ = is_64bit_ ? 8 : 4;
|
||||
end_ = (const char*)L.loc() + length_;
|
||||
version_ = L.read<uint16_t>();
|
||||
UNWIND_CHECK(
|
||||
version_ == 5 || version_ == 4,
|
||||
"unexpected dwarf version {}",
|
||||
version_);
|
||||
uint8_t address_size = 0;
|
||||
if (version_ == 5) {
|
||||
auto unit_type = L.read<uint8_t>();
|
||||
UNWIND_CHECK(unit_type == 0x1, "unexpected unit type {}", unit_type);
|
||||
address_size = L.read<uint8_t>();
|
||||
debug_abbrev_offset_ =
|
||||
is_64bit_ ? L.read<uint64_t>() : L.read<uint32_t>();
|
||||
} else {
|
||||
debug_abbrev_offset_ =
|
||||
is_64bit_ ? L.read<uint64_t>() : L.read<uint32_t>();
|
||||
address_size = L.read<uint8_t>();
|
||||
}
|
||||
LOG_INFO(
|
||||
"compilation unit at offset {:x} with length {:x} and debug_abbrev_offset {:x}\n",
|
||||
offset,
|
||||
length_,
|
||||
debug_abbrev_offset_);
|
||||
UNWIND_CHECK(
|
||||
address_size == 8,
|
||||
"expected 64-bit dwarf but found address size {}",
|
||||
address_size);
|
||||
return L;
|
||||
}
|
||||
|
||||
uint64_t readSegmentOffset(CheckedLexer& L) {
|
||||
return s_.readSegmentOffset(L, is_64bit_);
|
||||
}
|
||||
|
||||
uint64_t readEncoded(CheckedLexer& L, uint64_t encoding) {
|
||||
switch (encoding) {
|
||||
case DW_FORM_data8:
|
||||
case DW_FORM_addr:
|
||||
return L.read<uint64_t>();
|
||||
case DW_FORM_data4:
|
||||
return L.read<uint32_t>();
|
||||
case DW_FORM_addrx: {
|
||||
auto idx = L.readULEB128();
|
||||
return s_.debug_addr.lexer(address_base_ + sizeof(uint64_t) * idx)
|
||||
.read<uint64_t>();
|
||||
}
|
||||
case DW_FORM_sec_offset:
|
||||
return readSegmentOffset(L);
|
||||
case DW_FORM_rnglistx: {
|
||||
return L.readULEB128();
|
||||
}
|
||||
default:
|
||||
UNWIND_CHECK(false, "unexpected encoding");
|
||||
}
|
||||
}
|
||||
|
||||
void parseCompileUnit(CheckedLexer& L) {
|
||||
auto entry = L.readULEB128();
|
||||
auto A = findAbbrev(debug_abbrev_offset_, entry);
|
||||
while (true) {
|
||||
auto attr = A.readULEB128();
|
||||
auto form = A.readULEB128();
|
||||
if (attr == 0 && form == 0) {
|
||||
break;
|
||||
}
|
||||
if (form == DW_FORM_implicit_const) {
|
||||
A.readSLEB128();
|
||||
}
|
||||
if (attr == DW_AT_low_pc) {
|
||||
lowpc_ = readEncoded(L, form);
|
||||
LOG_INFO(" lowpc {:x}\n", lowpc_);
|
||||
} else if (attr == DW_AT_high_pc) {
|
||||
highpc_ = readEncoded(L, form);
|
||||
range_ptr_ = std::nullopt;
|
||||
LOG_INFO(" highpc {:x}\n", *highpc_);
|
||||
} else if (attr == DW_AT_addr_base) {
|
||||
UNWIND_CHECK(form == DW_FORM_sec_offset, "unexpected addr_base form");
|
||||
address_base_ = readSegmentOffset(L);
|
||||
LOG_INFO(" address base {:x}\n", address_base_);
|
||||
} else if (attr == DW_AT_rnglists_base) {
|
||||
UNWIND_CHECK(
|
||||
form == DW_FORM_sec_offset, "unexpected rnglists_base form");
|
||||
rnglists_base_ = readSegmentOffset(L);
|
||||
LOG_INFO(" range base {:x}\n", *rnglists_base_);
|
||||
} else if (form == DW_FORM_string) {
|
||||
L.readCString();
|
||||
} else if (attr == DW_AT_stmt_list) {
|
||||
UNWIND_CHECK(form == DW_FORM_sec_offset, "unexpected stmt_list form");
|
||||
LOG_INFO(" program table offset {:x}\n", *line_number_program_offset_);
|
||||
line_number_program_offset_ = readSegmentOffset(L);
|
||||
} else if (form == DW_FORM_exprloc) {
|
||||
auto sz = L.readULEB128();
|
||||
L.skip(int64_t(sz));
|
||||
} else if (form == DW_FORM_block1) {
|
||||
auto sz = L.read<uint8_t>();
|
||||
L.skip(int64_t(sz));
|
||||
} else if (attr == DW_AT_ranges) {
|
||||
auto range_offset = readEncoded(L, form);
|
||||
LOG_INFO("setting range_ptr to {:x} {:x}\n", range_offset, form);
|
||||
range_ptr_.emplace(range_offset, form);
|
||||
} else if (
|
||||
form == DW_FORM_udata || form == DW_FORM_rnglistx ||
|
||||
form == DW_FORM_strx || form == DW_FORM_loclistx ||
|
||||
form == DW_FORM_addrx) {
|
||||
L.readULEB128();
|
||||
} else if (form == DW_FORM_sdata) {
|
||||
L.readSLEB128();
|
||||
} else {
|
||||
auto sz = formSize(form, sec_offset_size_);
|
||||
UNWIND_CHECK(sz, "unsupported form in compilation unit {:x}", form);
|
||||
L.skip(int64_t(*sz));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<std::pair<uint64_t, uint64_t>> readRanges4(uint64_t offset) {
|
||||
CheckedLexer L = s_.debug_ranges.lexer(offset);
|
||||
std::vector<std::pair<uint64_t, uint64_t>> ranges;
|
||||
uint64_t base = lowpc_;
|
||||
while (true) {
|
||||
auto start = L.read<uint64_t>();
|
||||
auto end = L.read<uint64_t>();
|
||||
if (start == 0 && end == 0) {
|
||||
break;
|
||||
}
|
||||
if (start == std::numeric_limits<uint64_t>::max()) {
|
||||
base = end;
|
||||
} else {
|
||||
ranges.emplace_back(base + start, base + end);
|
||||
}
|
||||
}
|
||||
return ranges;
|
||||
}
|
||||
|
||||
std::vector<std::pair<uint64_t, uint64_t>> readRanges5(uint64_t offset) {
|
||||
CheckedLexer L = s_.debug_rnglists.lexer(offset);
|
||||
uint64_t base = 0;
|
||||
LOG_INFO("BEGIN RANGES {:x}\n", offset);
|
||||
std::vector<std::pair<uint64_t, uint64_t>> ranges;
|
||||
while (true) {
|
||||
auto op = L.read<uint8_t>();
|
||||
switch (op) {
|
||||
case DW_RLE_end_of_list:
|
||||
LOG_INFO("END RANGES\n");
|
||||
return ranges;
|
||||
case DW_RLE_base_addressx: {
|
||||
base = readEncoded(L, DW_FORM_addrx);
|
||||
LOG_INFO("BASE ADDRX {:x}\n", base);
|
||||
} break;
|
||||
case DW_RLE_startx_length: {
|
||||
auto s = readEncoded(L, DW_FORM_addrx);
|
||||
auto e = L.readULEB128();
|
||||
LOG_INFO("startx_length {:x} {:x}\n", s, e);
|
||||
ranges.emplace_back(s, s + e);
|
||||
} break;
|
||||
case DW_RLE_base_address:
|
||||
base = L.read<uint64_t>();
|
||||
LOG_INFO("BASE ADDR {:x}\n", base);
|
||||
break;
|
||||
case DW_RLE_offset_pair: {
|
||||
auto s = L.readULEB128();
|
||||
auto e = L.readULEB128();
|
||||
LOG_INFO("offset_pair {:x} {:x}\n", s, e);
|
||||
ranges.emplace_back(base + s, base + e);
|
||||
} break;
|
||||
case DW_RLE_start_length: {
|
||||
auto s = L.read<uint64_t>();
|
||||
auto e = L.readULEB128();
|
||||
LOG_INFO("start_length {:x} {:x}\n", s, e);
|
||||
ranges.emplace_back(s, s + e);
|
||||
} break;
|
||||
default:
|
||||
UNWIND_CHECK(false, "unknown range op: {}", op);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
CheckedLexer findAbbrev(uint64_t offset, uint64_t entry) {
|
||||
CheckedLexer L = s_.debug_abbrev.lexer(offset);
|
||||
while (true) {
|
||||
auto abbrev_code = L.readULEB128();
|
||||
UNWIND_CHECK(
|
||||
abbrev_code != 0,
|
||||
"could not find entry {} at offset {:x}",
|
||||
entry,
|
||||
offset);
|
||||
auto tag = L.readULEB128();
|
||||
L.read<uint8_t>(); // has children
|
||||
if (abbrev_code == entry) {
|
||||
UNWIND_CHECK(
|
||||
tag == DW_TAG_compile_unit,
|
||||
"first entry was not a compile unit but {}",
|
||||
tag);
|
||||
return L;
|
||||
}
|
||||
while (true) {
|
||||
auto attr = L.readULEB128();
|
||||
auto form = L.readULEB128();
|
||||
if (attr == 0 && form == 0) {
|
||||
break;
|
||||
}
|
||||
if (form == DW_FORM_implicit_const) {
|
||||
L.readSLEB128();
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Sections& s_;
|
||||
optional<uint64_t> line_number_program_offset_;
|
||||
uint64_t offset_ = 0;
|
||||
uint8_t sec_offset_size_ = 0;
|
||||
uint64_t length_ = 0;
|
||||
const char* end_ = nullptr;
|
||||
uint64_t debug_abbrev_offset_ = 0;
|
||||
bool is_64bit_ = false;
|
||||
|
||||
std::optional<std::pair<uint64_t, uint8_t>> range_ptr_;
|
||||
uint64_t lowpc_ = 0;
|
||||
optional<uint64_t> highpc_;
|
||||
uint16_t version_ = 0;
|
||||
uint64_t address_base_ = 0;
|
||||
optional<uint64_t> rnglists_base_;
|
||||
};
|
||||
|
||||
} // namespace torch::unwind
|
||||
|
|
@ -1,181 +0,0 @@
|
|||
#pragma once
|
||||
#include <torch/csrc/profiler/unwind/unwind_error.h>
|
||||
#include <cstdint>
|
||||
#include <optional>
|
||||
|
||||
// DWARF constants used by the symbolizer, grouped by kind and listed in
// ascending numeric order within each group.
enum {
  // Debugging-information-entry tags.
  DW_TAG_compile_unit = 0x11,
  DW_TAG_inlined_subroutine = 0x1d,
  DW_TAG_subprogram = 0x2e,

  // Attribute names (with the attribute class each one carries).
  DW_AT_sibling = 0x1, // reference
  DW_AT_name = 0x3, // string
  DW_AT_stmt_list = 0x10, // lineptr
  DW_AT_low_pc = 0x11, // address
  DW_AT_high_pc = 0x12, // address
  DW_AT_abstract_origin = 0x31, // reference
  DW_AT_specification = 0x47, // reference
  DW_AT_ranges = 0x55, // rnglist
  DW_AT_linkage_name = 0x6e, // string
  DW_AT_str_offsets_base = 0x72, // sec_offset
  DW_AT_addr_base = 0x73, // sec_offset
  DW_AT_rnglists_base = 0x74, // sec_offset

  // Attribute value forms.
  DW_FORM_addr = 0x01,
  DW_FORM_block2 = 0x03,
  DW_FORM_block4 = 0x04,
  DW_FORM_data2 = 0x05,
  DW_FORM_data4 = 0x06,
  DW_FORM_data8 = 0x07,
  DW_FORM_string = 0x08,
  DW_FORM_block = 0x09,
  DW_FORM_block1 = 0x0a,
  DW_FORM_data1 = 0x0b,
  DW_FORM_flag = 0x0c,
  DW_FORM_sdata = 0x0d,
  DW_FORM_strp = 0x0e,
  DW_FORM_udata = 0x0f,
  DW_FORM_ref_addr = 0x10,
  DW_FORM_ref1 = 0x11,
  DW_FORM_ref2 = 0x12,
  DW_FORM_ref4 = 0x13,
  DW_FORM_ref8 = 0x14,
  DW_FORM_ref_udata = 0x15,
  DW_FORM_indirect = 0x16,
  DW_FORM_sec_offset = 0x17,
  DW_FORM_exprloc = 0x18,
  DW_FORM_flag_present = 0x19,
  DW_FORM_strx = 0x1a,
  DW_FORM_addrx = 0x1b,
  DW_FORM_ref_sup4 = 0x1c,
  DW_FORM_strp_sup = 0x1d,
  DW_FORM_data16 = 0x1e,
  DW_FORM_line_strp = 0x1f,
  DW_FORM_ref_sig8 = 0x20,
  DW_FORM_implicit_const = 0x21,
  DW_FORM_loclistx = 0x22,
  DW_FORM_rnglistx = 0x23,
  DW_FORM_ref_sup8 = 0x24,
  DW_FORM_strx1 = 0x25,
  DW_FORM_strx2 = 0x26,
  DW_FORM_strx3 = 0x27,
  DW_FORM_strx4 = 0x28,
  DW_FORM_addrx1 = 0x29,
  DW_FORM_addrx2 = 0x2a,
  DW_FORM_addrx3 = 0x2b,
  DW_FORM_addrx4 = 0x2c,
  /* GNU Debug Fission extensions. */
  DW_FORM_GNU_addr_index = 0x1f01,
  DW_FORM_GNU_str_index = 0x1f02,
  DW_FORM_GNU_ref_alt = 0x1f20, /* offset in alternate .debuginfo. */
  DW_FORM_GNU_strp_alt = 0x1f21, /* offset in alternate .debug_str. */

  // Line-number header entry-format content types.
  DW_LNCT_path = 0x1,
  DW_LNCT_directory_index = 0x2,

  // Line-number program standard opcodes.
  DW_LNS_extended_op = 0x00,
  DW_LNS_copy = 0x01,
  DW_LNS_advance_pc = 0x02,
  DW_LNS_advance_line = 0x03,
  DW_LNS_set_file = 0x04,
  DW_LNS_const_add_pc = 0x08,
  DW_LNS_fixed_advance_pc = 0x09,

  // Line-number program extended opcodes.
  DW_LNE_end_sequence = 0x01,
  DW_LNE_set_address = 0x02,

  // Range-list entry kinds.
  DW_RLE_end_of_list = 0x0,
  DW_RLE_base_addressx = 0x1,
  DW_RLE_startx_endx = 0x2,
  DW_RLE_startx_length = 0x3,
  DW_RLE_offset_pair = 0x4,
  DW_RLE_base_address = 0x5,
  DW_RLE_start_end = 0x6,
  DW_RLE_start_length = 0x7
};
|
||||
|
||||
// Returns the number of bytes an attribute value of the given `form`
// occupies when that is fixed by the form alone, or std::nullopt when the
// size depends on the data (LEB128 values, strings, blocks, expressions,
// indirect forms) or the form is not recognized. `sec_offset_size` is the
// width of a section offset in the unit being read.
static torch::unwind::optional<size_t> formSize(
    uint64_t form,
    uint8_t sec_offset_size) {
  switch (form) {
    // No bytes in the entry: the value is implied or stored elsewhere.
    case DW_FORM_flag_present:
    case DW_FORM_implicit_const:
      return 0;

    case DW_FORM_data1:
    case DW_FORM_flag:
    case DW_FORM_ref1:
    case DW_FORM_strx1:
    case DW_FORM_addrx1:
      return 1;

    case DW_FORM_data2:
    case DW_FORM_ref2:
    case DW_FORM_strx2:
    case DW_FORM_addrx2:
      return 2;

    case DW_FORM_strx3:
    case DW_FORM_addrx3:
      return 3;

    case DW_FORM_data4:
    case DW_FORM_ref4:
    case DW_FORM_ref_sup4:
    case DW_FORM_strx4:
    case DW_FORM_addrx4:
      return 4;

    case DW_FORM_data8:
    case DW_FORM_ref8:
    case DW_FORM_ref_sig8:
    case DW_FORM_ref_sup8:
      return 8;

    case DW_FORM_data16:
      return 16;

    case DW_FORM_addr:
      return sizeof(void*);

    // Forms whose value is a section offset.
    case DW_FORM_strp:
    case DW_FORM_ref_addr:
    case DW_FORM_sec_offset:
    case DW_FORM_strp_sup:
    case DW_FORM_line_strp:
      return sec_offset_size;

    // Variable-length or unsupported forms have no size known up front.
    case DW_FORM_block:
    case DW_FORM_block1:
    case DW_FORM_block2:
    case DW_FORM_block4:
    case DW_FORM_string:
    case DW_FORM_sdata:
    case DW_FORM_udata:
    case DW_FORM_ref_udata:
    case DW_FORM_indirect:
    case DW_FORM_exprloc:
    case DW_FORM_strx:
    case DW_FORM_addrx:
    case DW_FORM_loclistx:
    case DW_FORM_rnglistx:
    case DW_FORM_GNU_addr_index:
    case DW_FORM_GNU_str_index:
    case DW_FORM_GNU_ref_alt:
    case DW_FORM_GNU_strp_alt:
    default:
      return std::nullopt;
  }
}
|
||||
|
|
@ -7,7 +7,6 @@
|
|||
|
||||
// Overview of the format described in
|
||||
// https://refspecs.linuxfoundation.org/LSB_1.3.0/gLSB/gLSB/ehframehdr.html
|
||||
namespace torch::unwind {
|
||||
|
||||
struct EHFrameHdr {
|
||||
EHFrameHdr(void* base) : base_(base) {
|
||||
|
|
@ -94,5 +93,3 @@ struct EHFrameHdr {
|
|||
int64_t fde_count_;
|
||||
uint32_t table_size_;
|
||||
};
|
||||
|
||||
} // namespace torch::unwind
|
||||
|
|
|
|||
|
|
@ -1,108 +0,0 @@
|
|||
#pragma once
|
||||
|
||||
#include <fmt/format.h>
|
||||
#include <sys/types.h>
|
||||
#include <torch/csrc/profiler/unwind/debug_info.h>
|
||||
#include <torch/csrc/profiler/unwind/line_number_program.h>
|
||||
#include <torch/csrc/profiler/unwind/sections.h>
|
||||
#include <torch/csrc/profiler/unwind/unwind.h>
|
||||
#include <torch/csrc/profiler/unwind/unwind_error.h>
|
||||
#include <cstddef>
|
||||
#include <memory>
|
||||
|
||||
namespace torch::unwind {
|
||||
|
||||
// Appends a message built with fmt::format(__VA_ARGS__) to the container `w`
// and logs the same text through LOG_INFO. No trailing semicolon after
// `while (0)`: the caller supplies it, so the macro expands to exactly one
// statement and stays safe inside an unbraced if/else.
#define UNWIND_WARN(w, ...)                     \
  do {                                          \
    (w).emplace_back(fmt::format(__VA_ARGS__)); \
    LOG_INFO("WARNING: {}\n", (w).back());      \
  } while (0)
|
||||
|
||||
struct FastSymbolizer {
|
||||
FastSymbolizer() = default;
|
||||
Frame symbolize(const std::string& library, uint64_t offset) {
|
||||
LOG_INFO("symbolizing {} + 0x{:x}\n", library, offset);
|
||||
Frame frame;
|
||||
frame.funcname = "??";
|
||||
frame.filename = library;
|
||||
frame.lineno = offset;
|
||||
auto s = getOrCreateSections(library);
|
||||
if (auto e = s->findSubprogramName(offset)) {
|
||||
frame.funcname = *e;
|
||||
} else {
|
||||
UNWIND_WARN(
|
||||
warnings_,
|
||||
"failed to find subprogram name for {} 0x{:x}",
|
||||
library,
|
||||
offset);
|
||||
}
|
||||
if (auto e = findLine(s, offset)) {
|
||||
frame.filename = e->first;
|
||||
frame.lineno = e->second;
|
||||
} else {
|
||||
UNWIND_WARN(
|
||||
warnings_, "failed to find file/line for {} 0x{:x}", library, offset);
|
||||
}
|
||||
return frame;
|
||||
}
|
||||
const std::vector<std::string>& warnings() {
|
||||
return warnings_;
|
||||
}
|
||||
|
||||
private:
|
||||
void parseDebugInfo(Sections* s) {
|
||||
uint64_t offset = 0;
|
||||
while (offset < s->debug_info.size) {
|
||||
DebugInfo info(*s);
|
||||
info.parse(offset);
|
||||
if (auto lnp_offset = info.lineNumberProgramOffset()) {
|
||||
for (auto r : info.ranges()) {
|
||||
s->addDebugInfoRange(r.first, r.second, line_number_programs_.size());
|
||||
}
|
||||
line_number_programs_.emplace_back(
|
||||
std::make_unique<LineNumberProgram>(*s, *lnp_offset));
|
||||
}
|
||||
offset = info.nextOffset();
|
||||
}
|
||||
}
|
||||
Sections* getOrCreateSections(const std::string& library) {
|
||||
auto it = libraries_.find(library);
|
||||
if (it == libraries_.end()) {
|
||||
it = libraries_.insert({library, std::make_unique<Sections>()}).first;
|
||||
try {
|
||||
Sections* s = it->second.get();
|
||||
s->parse(library.c_str());
|
||||
parseDebugInfo(s);
|
||||
} catch (UnwindError& err) {
|
||||
UNWIND_WARN(
|
||||
warnings_, "failed to parse library {}: {}", library, err.what());
|
||||
}
|
||||
}
|
||||
return it->second.get();
|
||||
}
|
||||
optional<std::pair<std::string, int64_t>> findLine(
|
||||
Sections* s,
|
||||
uint64_t offset) {
|
||||
if (auto idx = s->findDebugInfoOffset(offset)) {
|
||||
auto r = line_number_programs_.at(*idx).get();
|
||||
try {
|
||||
r->parse();
|
||||
} catch (UnwindError& err) {
|
||||
UNWIND_WARN(
|
||||
warnings_,
|
||||
"failed to read line number program [{:x}] {}",
|
||||
r->offset(),
|
||||
err.what());
|
||||
}
|
||||
if (auto e = r->find(offset)) {
|
||||
return std::make_pair(r->filename(e->file), e->line);
|
||||
}
|
||||
}
|
||||
return std::nullopt;
|
||||
}
|
||||
std::unordered_map<std::string, std::unique_ptr<Sections>> libraries_;
|
||||
std::vector<std::unique_ptr<LineNumberProgram>> line_number_programs_;
|
||||
std::vector<std::string> warnings_;
|
||||
};
|
||||
|
||||
} // namespace torch::unwind
|
||||
|
|
@ -7,8 +7,6 @@
|
|||
#include <sstream>
|
||||
#include <vector>
|
||||
|
||||
namespace torch::unwind {
|
||||
|
||||
struct TableState {
|
||||
Action cfa;
|
||||
std::array<Action, D_REG_SIZE> registers;
|
||||
|
|
@ -400,5 +398,3 @@ struct FDE {
|
|||
return strstr(augmentation_string_, s) != nullptr;
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace torch::unwind
|
||||
|
|
|
|||
|
|
@ -1,31 +1,19 @@
|
|||
#pragma once
|
||||
#include <cstdint>
|
||||
#include <cstring>
|
||||
#include <utility>
|
||||
#include <stdint.h>
|
||||
#include <string.h>
|
||||
|
||||
#include <torch/csrc/profiler/unwind/dwarf_enums.h>
|
||||
#include <torch/csrc/profiler/unwind/unwind_error.h>
|
||||
|
||||
namespace torch::unwind {
|
||||
|
||||
template <bool checked>
|
||||
struct LexerImpl {
|
||||
LexerImpl(void* data, void* base = nullptr, void* end = nullptr)
|
||||
: next_((const char*)data),
|
||||
base_((int64_t)base),
|
||||
end_((const char*)end) {}
|
||||
struct Lexer {
|
||||
Lexer(void* data, void* base = nullptr)
|
||||
: next_((const char*)data), base_((int64_t)base) {}
|
||||
|
||||
template <typename T>
|
||||
T read() {
|
||||
T result;
|
||||
auto end = next_ + sizeof(T);
|
||||
UNWIND_CHECK(
|
||||
!checked || end <= end_,
|
||||
"read out of bounds {} >= {}",
|
||||
(void*)end,
|
||||
(void*)end_);
|
||||
memcpy(&result, next_, sizeof(T));
|
||||
next_ = end;
|
||||
next_ += sizeof(T);
|
||||
return result;
|
||||
}
|
||||
|
||||
|
|
@ -33,7 +21,7 @@ struct LexerImpl {
|
|||
int64_t readSLEB128() {
|
||||
int64_t Value = 0;
|
||||
unsigned Shift = 0;
|
||||
uint8_t Byte = 0;
|
||||
uint8_t Byte;
|
||||
do {
|
||||
Byte = read<uint8_t>();
|
||||
uint64_t Slice = Byte & 0x7f;
|
||||
|
|
@ -41,12 +29,12 @@ struct LexerImpl {
|
|||
(Shift == 63 && Slice != 0 && Slice != 0x7f)) {
|
||||
throw UnwindError("sleb128 too big for int64");
|
||||
}
|
||||
Value |= int64_t(Slice << Shift);
|
||||
Value |= Slice << Shift;
|
||||
Shift += 7;
|
||||
} while (Byte >= 128);
|
||||
// Sign extend negative numbers if needed.
|
||||
if (Shift < 64 && (Byte & 0x40)) {
|
||||
Value |= int64_t((-1ULL) << Shift);
|
||||
Value |= (-1ULL) << Shift;
|
||||
}
|
||||
return Value;
|
||||
}
|
||||
|
|
@ -54,7 +42,7 @@ struct LexerImpl {
|
|||
uint64_t readULEB128() {
|
||||
uint64_t Value = 0;
|
||||
unsigned Shift = 0;
|
||||
uint8_t p = 0;
|
||||
uint8_t p;
|
||||
do {
|
||||
p = read<uint8_t>();
|
||||
uint64_t Slice = p & 0x7f;
|
||||
|
|
@ -68,17 +56,8 @@ struct LexerImpl {
|
|||
}
|
||||
const char* readCString() {
|
||||
auto result = next_;
|
||||
if (!checked) {
|
||||
next_ += strlen(next_) + 1;
|
||||
return result;
|
||||
}
|
||||
while (next_ < end_) {
|
||||
if (*next_++ == '\0') {
|
||||
return result;
|
||||
}
|
||||
}
|
||||
UNWIND_CHECK(
|
||||
false, "string is out of bounds {} >= {}", (void*)next_, (void*)end_);
|
||||
next_ += strlen(next_) + 1;
|
||||
return result;
|
||||
}
|
||||
int64_t readEncoded(uint8_t enc) {
|
||||
int64_t r = 0;
|
||||
|
|
@ -102,27 +81,20 @@ struct LexerImpl {
|
|||
}
|
||||
return readEncoded(enc);
|
||||
}
|
||||
|
||||
int64_t read4or8Length() {
|
||||
return readSectionLength().first;
|
||||
}
|
||||
|
||||
std::pair<int64_t, bool> readSectionLength() {
|
||||
int64_t length = read<uint32_t>();
|
||||
if (length == 0xFFFFFFFF) {
|
||||
return std::make_pair(read<int64_t>(), true);
|
||||
length = read<int64_t>();
|
||||
}
|
||||
return std::make_pair(length, false);
|
||||
return length;
|
||||
}
|
||||
|
||||
void* loc() const {
|
||||
return (void*)next_;
|
||||
}
|
||||
LexerImpl& skip(int64_t bytes) {
|
||||
Lexer& skip(int64_t bytes) {
|
||||
next_ += bytes;
|
||||
return *this;
|
||||
}
|
||||
|
||||
int64_t readEncodedValue(uint8_t enc) {
|
||||
switch (enc & 0xF) {
|
||||
case DW_EH_PE_udata2:
|
||||
|
|
@ -149,11 +121,4 @@ struct LexerImpl {
|
|||
private:
|
||||
const char* next_;
|
||||
int64_t base_;
|
||||
const char* end_;
|
||||
};
|
||||
|
||||
// using Lexer = LexerImpl<false>;
|
||||
using CheckedLexer = LexerImpl<true>;
|
||||
using Lexer = LexerImpl<false>;
|
||||
|
||||
} // namespace torch::unwind
|
||||
|
|
|
|||
|
|
@ -1,325 +0,0 @@
|
|||
#include <c10/util/irange.h>
|
||||
#include <torch/csrc/profiler/unwind/debug_info.h>
|
||||
#include <torch/csrc/profiler/unwind/dwarf_enums.h>
|
||||
#include <torch/csrc/profiler/unwind/dwarf_symbolize_enums.h>
|
||||
#include <torch/csrc/profiler/unwind/lexer.h>
|
||||
#include <torch/csrc/profiler/unwind/sections.h>
|
||||
#include <torch/csrc/profiler/unwind/unwind_error.h>
|
||||
#include <tuple>
|
||||
|
||||
namespace torch::unwind {
|
||||
|
||||
struct LineNumberProgram {
|
||||
LineNumberProgram(Sections& s, uint64_t offset) : s_(s), offset_(offset) {}
|
||||
|
||||
uint64_t offset() {
|
||||
return offset_;
|
||||
}
|
||||
void parse() {
|
||||
if (parsed_) {
|
||||
return;
|
||||
}
|
||||
parsed_ = true;
|
||||
CheckedLexer L = s_.debug_line.lexer(offset_);
|
||||
std::tie(length_, is_64bit_) = L.readSectionLength();
|
||||
program_end_ = (char*)L.loc() + length_;
|
||||
auto version = L.read<uint16_t>();
|
||||
UNWIND_CHECK(
|
||||
version == 5 || version == 4,
|
||||
"expected version 4 or 5 but found {}",
|
||||
version);
|
||||
if (version == 5) {
|
||||
auto address_size = L.read<uint8_t>();
|
||||
UNWIND_CHECK(
|
||||
address_size == 8,
|
||||
"expected 64-bit dwarf but found address size {}",
|
||||
address_size);
|
||||
segment_selector_size_ = L.read<uint8_t>();
|
||||
}
|
||||
header_length_ = is_64bit_ ? L.read<uint64_t>() : L.read<uint32_t>();
|
||||
program_ = L;
|
||||
program_.skip(int64_t(header_length_));
|
||||
minimum_instruction_length_ = L.read<uint8_t>();
|
||||
maximum_operations_per_instruction_ = L.read<uint8_t>();
|
||||
default_is_stmt_ = L.read<uint8_t>();
|
||||
line_base_ = L.read<int8_t>();
|
||||
line_range_ = L.read<uint8_t>();
|
||||
opcode_base_ = L.read<uint8_t>();
|
||||
UNWIND_CHECK(line_range_ != 0, "line_range_ must be non-zero");
|
||||
standard_opcode_lengths_.resize(opcode_base_);
|
||||
for (size_t i = 1; i < opcode_base_; i++) {
|
||||
standard_opcode_lengths_[i] = L.read<uint8_t>();
|
||||
}
|
||||
// fmt::print("{:x} {:x} {} {} {} {} {}\n", offset_, header_length_,
|
||||
// minimum_instruction_length_, maximum_operations_per_instruction_,
|
||||
// line_base_, line_range_, opcode_base_);
|
||||
uint8_t directory_entry_format_count = L.read<uint8_t>();
|
||||
|
||||
if (version == 5) {
|
||||
struct Member {
|
||||
uint64_t content_type;
|
||||
uint64_t form;
|
||||
};
|
||||
std::vector<Member> directory_members;
|
||||
for (size_t i = 0; i < directory_entry_format_count; i++) {
|
||||
directory_members.push_back({L.readULEB128(), L.readULEB128()});
|
||||
}
|
||||
uint64_t directories_count = L.readULEB128();
|
||||
for (size_t i = 0; i < directories_count; i++) {
|
||||
for (auto& member : directory_members) {
|
||||
switch (member.content_type) {
|
||||
case DW_LNCT_path: {
|
||||
include_directories_.emplace_back(
|
||||
s_.readString(L, member.form, is_64bit_, 0));
|
||||
} break;
|
||||
default: {
|
||||
skipForm(L, member.form);
|
||||
} break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for (auto i : c10::irange(directories_count)) {
|
||||
(void)i;
|
||||
LOG_INFO("{} {}\n", i, include_directories_[i]);
|
||||
}
|
||||
auto file_name_entry_format_count = L.read<uint8_t>();
|
||||
std::vector<Member> file_members;
|
||||
for (size_t i = 0; i < file_name_entry_format_count; i++) {
|
||||
file_members.push_back({L.readULEB128(), L.readULEB128()});
|
||||
}
|
||||
auto files_count = L.readULEB128();
|
||||
for (size_t i = 0; i < files_count; i++) {
|
||||
for (auto& member : file_members) {
|
||||
switch (member.content_type) {
|
||||
case DW_LNCT_path: {
|
||||
file_names_.emplace_back(
|
||||
s_.readString(L, member.form, is_64bit_, 0));
|
||||
} break;
|
||||
case DW_LNCT_directory_index: {
|
||||
file_directory_index_.emplace_back(readData(L, member.form));
|
||||
UNWIND_CHECK(
|
||||
file_directory_index_.back() < include_directories_.size(),
|
||||
"directory index out of range");
|
||||
} break;
|
||||
default: {
|
||||
skipForm(L, member.form);
|
||||
} break;
|
||||
}
|
||||
}
|
||||
}
|
||||
for (auto i : c10::irange(files_count)) {
|
||||
(void)i;
|
||||
LOG_INFO("{} {} {}\n", i, file_names_[i], file_directory_index_[i]);
|
||||
}
|
||||
} else {
|
||||
include_directories_.emplace_back(""); // implicit cwd
|
||||
while (true) {
|
||||
auto str = L.readCString();
|
||||
if (*str == '\0') {
|
||||
break;
|
||||
}
|
||||
include_directories_.emplace_back(str);
|
||||
}
|
||||
file_names_.emplace_back("");
|
||||
file_directory_index_.emplace_back(0);
|
||||
while (true) {
|
||||
auto str = L.readCString();
|
||||
if (*str == '\0') {
|
||||
break;
|
||||
}
|
||||
auto directory_index = L.readULEB128();
|
||||
L.readULEB128(); // mod_time
|
||||
L.readULEB128(); // file_length
|
||||
file_names_.emplace_back(str);
|
||||
file_directory_index_.push_back(directory_index);
|
||||
}
|
||||
}
|
||||
UNWIND_CHECK(
|
||||
maximum_operations_per_instruction_ == 1,
|
||||
"maximum_operations_per_instruction_ must be 1");
|
||||
UNWIND_CHECK(
|
||||
minimum_instruction_length_ == 1,
|
||||
"minimum_instruction_length_ must be 1");
|
||||
readProgram();
|
||||
}
|
||||
struct Entry {
|
||||
uint32_t file = 1;
|
||||
int64_t line = 1;
|
||||
};
|
||||
unwind::optional<Entry> find(uint64_t address) {
|
||||
auto e = program_index_.find(address);
|
||||
if (!e) {
|
||||
return std::nullopt;
|
||||
}
|
||||
return all_programs_.at(*e).find(address);
|
||||
}
|
||||
std::string filename(uint64_t index) {
|
||||
return fmt::format(
|
||||
"{}/{}",
|
||||
include_directories_.at(file_directory_index_.at(index)),
|
||||
file_names_.at(index));
|
||||
}
|
||||
|
||||
private:
|
||||
void skipForm(CheckedLexer& L, uint64_t form) {
|
||||
auto sz = formSize(form, is_64bit_ ? 8 : 4);
|
||||
UNWIND_CHECK(sz, "unsupported form {}", form);
|
||||
L.skip(int64_t(*sz));
|
||||
}
|
||||
|
||||
uint64_t readData(CheckedLexer& L, uint64_t encoding) {
|
||||
switch (encoding) {
|
||||
case DW_FORM_data1:
|
||||
return L.read<uint8_t>();
|
||||
case DW_FORM_data2:
|
||||
return L.read<uint16_t>();
|
||||
case DW_FORM_data4:
|
||||
return L.read<uint32_t>();
|
||||
case DW_FORM_data8:
|
||||
return L.read<uint64_t>();
|
||||
case DW_FORM_udata:
|
||||
return L.readULEB128();
|
||||
default:
|
||||
UNWIND_CHECK(false, "unsupported data encoding {}", encoding);
|
||||
}
|
||||
}
|
||||
|
||||
void produceEntry() {
|
||||
if (shadow_) {
|
||||
return;
|
||||
}
|
||||
if (ranges_.size() == 1) {
|
||||
start_address_ = address_;
|
||||
}
|
||||
PRINT_LINE_TABLE(
|
||||
"{:x}\t{}\t{}\n", address_, filename(entry_.file), entry_.line);
|
||||
UNWIND_CHECK(
|
||||
entry_.file < file_names_.size(),
|
||||
"file index {} > {} entries",
|
||||
entry_.file,
|
||||
file_names_.size());
|
||||
ranges_.add(address_, entry_, true);
|
||||
}
|
||||
void endSequence() {
|
||||
if (shadow_) {
|
||||
return;
|
||||
}
|
||||
PRINT_LINE_TABLE(
|
||||
"{:x}\tEND\n", address_, filename(entry_.file), entry_.line);
|
||||
program_index_.add(start_address_, all_programs_.size(), false);
|
||||
program_index_.add(address_, std::nullopt, false);
|
||||
all_programs_.emplace_back(std::move(ranges_));
|
||||
ranges_ = RangeTable<Entry>();
|
||||
}
|
||||
void readProgram() {
|
||||
while (program_.loc() < program_end_) {
|
||||
PRINT_INST("{:x}: ", (char*)program_.loc() - (s_.debug_line.data));
|
||||
uint8_t op = program_.read<uint8_t>();
|
||||
if (op >= opcode_base_) {
|
||||
auto op2 = int64_t(op - opcode_base_);
|
||||
address_ += op2 / line_range_;
|
||||
entry_.line += line_base_ + (op2 % line_range_);
|
||||
PRINT_INST(
|
||||
"address += {}, line += {}\n",
|
||||
op2 / line_range_,
|
||||
line_base_ + (op2 % line_range_));
|
||||
produceEntry();
|
||||
} else {
|
||||
switch (op) {
|
||||
case DW_LNS_extended_op: {
|
||||
auto len = program_.readULEB128();
|
||||
auto extended_op = program_.read<uint8_t>();
|
||||
switch (extended_op) {
|
||||
case DW_LNE_end_sequence: {
|
||||
PRINT_INST("end_sequence\n");
|
||||
endSequence();
|
||||
entry_ = Entry{};
|
||||
} break;
|
||||
case DW_LNE_set_address: {
|
||||
address_ = program_.read<uint64_t>();
|
||||
if (!shadow_) {
|
||||
PRINT_INST(
|
||||
"set address {:x} {:x} {:x}\n",
|
||||
address_,
|
||||
min_address_,
|
||||
max_address_);
|
||||
}
|
||||
shadow_ = address_ == 0;
|
||||
} break;
|
||||
default: {
|
||||
PRINT_INST("skip extended op {}\n", extended_op);
|
||||
program_.skip(int64_t(len - 1));
|
||||
} break;
|
||||
}
|
||||
} break;
|
||||
case DW_LNS_copy: {
|
||||
PRINT_INST("copy\n");
|
||||
produceEntry();
|
||||
} break;
|
||||
case DW_LNS_advance_pc: {
|
||||
PRINT_INST("advance pc\n");
|
||||
address_ += program_.readULEB128();
|
||||
} break;
|
||||
case DW_LNS_advance_line: {
|
||||
entry_.line += program_.readSLEB128();
|
||||
PRINT_INST("advance line {}\n", entry_.line);
|
||||
|
||||
} break;
|
||||
case DW_LNS_set_file: {
|
||||
PRINT_INST("set file\n");
|
||||
entry_.file = program_.readULEB128();
|
||||
} break;
|
||||
case DW_LNS_const_add_pc: {
|
||||
PRINT_INST("const add pc\n");
|
||||
address_ += (255 - opcode_base_) / line_range_;
|
||||
} break;
|
||||
case DW_LNS_fixed_advance_pc: {
|
||||
PRINT_INST("fixed advance pc\n");
|
||||
address_ += program_.read<uint16_t>();
|
||||
} break;
|
||||
default: {
|
||||
PRINT_INST("other {}\n", op);
|
||||
auto n = standard_opcode_lengths_[op];
|
||||
for (int i = 0; i < n; ++i) {
|
||||
program_.readULEB128();
|
||||
}
|
||||
} break;
|
||||
}
|
||||
}
|
||||
}
|
||||
PRINT_INST(
|
||||
"{:x}: end {:x}\n",
|
||||
((char*)program_.loc() - s_.debug_line.data),
|
||||
program_end_ - s_.debug_line.data);
|
||||
}
|
||||
|
||||
uint64_t address_ = 0;
|
||||
bool shadow_ = false;
|
||||
bool parsed_ = false;
|
||||
Entry entry_ = {};
|
||||
std::vector<std::string> include_directories_;
|
||||
std::vector<std::string> file_names_;
|
||||
std::vector<uint64_t> file_directory_index_;
|
||||
uint8_t segment_selector_size_ = 0;
|
||||
uint8_t minimum_instruction_length_ = 0;
|
||||
uint8_t maximum_operations_per_instruction_ = 0;
|
||||
int8_t line_base_ = 0;
|
||||
uint8_t line_range_ = 0;
|
||||
uint8_t opcode_base_ = 0;
|
||||
bool default_is_stmt_ = false;
|
||||
CheckedLexer program_ = {nullptr};
|
||||
char* program_end_ = nullptr;
|
||||
uint64_t header_length_ = 0;
|
||||
uint64_t length_ = 0;
|
||||
bool is_64bit_ = false;
|
||||
std::vector<uint8_t> standard_opcode_lengths_;
|
||||
Sections& s_;
|
||||
uint64_t offset_;
|
||||
uint64_t start_address_ = 0;
|
||||
RangeTable<uint64_t> program_index_;
|
||||
std::vector<RangeTable<Entry>> all_programs_;
|
||||
RangeTable<Entry> ranges_;
|
||||
};
|
||||
|
||||
} // namespace torch::unwind
|
||||
|
|
@ -1,150 +0,0 @@
|
|||
// Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
// All rights reserved.
|
||||
//
|
||||
// This source code is licensed under the BSD-style license found in the
|
||||
// LICENSE file in the root directory of this source tree.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <elf.h>
|
||||
#include <fcntl.h>
|
||||
#include <fmt/format.h>
|
||||
#include <sys/mman.h>
|
||||
#include <sys/stat.h>
|
||||
#include <torch/csrc/profiler/unwind/lexer.h>
|
||||
#include <torch/csrc/profiler/unwind/unwind_error.h>
|
||||
#include <unistd.h>
|
||||
#include <cerrno>
|
||||
#include <cstdio>
|
||||
#include <cstring>
|
||||
#include <iostream>
|
||||
|
||||
namespace torch::unwind {
|
||||
|
||||
struct Section {
|
||||
char* data = nullptr;
|
||||
size_t size = 0;
|
||||
const char* string(size_t offset) {
|
||||
return lexer(offset).readCString();
|
||||
}
|
||||
CheckedLexer lexer(size_t offset) {
|
||||
return CheckedLexer(data + offset, data, data + size);
|
||||
}
|
||||
};
|
||||
|
||||
/// Memory maps a file into the address space read-only, and manages the
|
||||
/// lifetime of the mapping. Here are a few use cases:
|
||||
/// 1. Used in the loader to read in initial image, and to inspect
|
||||
///    ELF files for dependencies before calling dlopen.
|
||||
///
|
||||
/// 2. Used in unity to load the elf file.
|
||||
struct MemFile {
|
||||
explicit MemFile(const char* filename_)
|
||||
: fd_(open(filename_, O_RDONLY)),
|
||||
mem_(nullptr),
|
||||
n_bytes_(0),
|
||||
name_(filename_) {
|
||||
UNWIND_CHECK(
|
||||
fd_ != -1, "failed to open {}: {}", filename_, strerror(errno));
|
||||
// NOLINTNEXTLINE
|
||||
struct stat s;
|
||||
if (-1 == fstat(fd_, &s)) {
|
||||
close(fd_); // destructors don't run during exceptions
|
||||
UNWIND_CHECK(false, "failed to stat {}: {}", filename_, strerror(errno));
|
||||
}
|
||||
n_bytes_ = s.st_size;
|
||||
UNWIND_CHECK(
|
||||
n_bytes_ > sizeof(Elf64_Ehdr), "empty shared library: {}", filename_);
|
||||
mem_ = (char*)mmap(nullptr, n_bytes_, PROT_READ, MAP_SHARED, fd_, 0);
|
||||
if (MAP_FAILED == mem_) {
|
||||
close(fd_);
|
||||
UNWIND_CHECK(false, "failed to mmap {}: {}", filename_, strerror(errno));
|
||||
}
|
||||
ehdr_ = (Elf64_Ehdr*)mem_;
|
||||
#define ELF_CHECK(cond) UNWIND_CHECK(cond, "not an ELF file: {}", filename_)
|
||||
ELF_CHECK(ehdr_->e_ident[EI_MAG0] == ELFMAG0);
|
||||
ELF_CHECK(ehdr_->e_ident[EI_MAG1] == ELFMAG1);
|
||||
ELF_CHECK(ehdr_->e_ident[EI_MAG2] == ELFMAG2);
|
||||
ELF_CHECK(ehdr_->e_ident[EI_MAG3] == ELFMAG3);
|
||||
ELF_CHECK(ehdr_->e_ident[EI_CLASS] == ELFCLASS64);
|
||||
ELF_CHECK(ehdr_->e_ident[EI_VERSION] == EV_CURRENT);
|
||||
ELF_CHECK(ehdr_->e_version == EV_CURRENT);
|
||||
ELF_CHECK(ehdr_->e_machine == EM_X86_64);
|
||||
#undef ELF_CHECK
|
||||
UNWIND_CHECK(
|
||||
ehdr_->e_shoff + sizeof(Elf64_Shdr) * ehdr_->e_shnum <= n_bytes_,
|
||||
"invalid section header table {} {} {}",
|
||||
ehdr_->e_shoff + sizeof(Elf64_Shdr) * ehdr_->e_shnum,
|
||||
n_bytes_,
|
||||
ehdr_->e_shnum);
|
||||
shdr_ = (Elf64_Shdr*)(mem_ + ehdr_->e_shoff);
|
||||
UNWIND_CHECK(
|
||||
ehdr_->e_shstrndx < ehdr_->e_shnum, "invalid strtab section offset");
|
||||
auto& strtab_hdr = shdr_[ehdr_->e_shstrndx];
|
||||
strtab_ = getSection(strtab_hdr);
|
||||
}
|
||||
|
||||
MemFile(const MemFile&) = delete;
|
||||
MemFile& operator=(const MemFile&) = delete;
|
||||
[[nodiscard]] const char* data() const {
|
||||
return (const char*)mem_;
|
||||
}
|
||||
|
||||
/// Returns whether or not the file descriptor
|
||||
/// of the underlying file is valid.
|
||||
int valid() {
|
||||
return fcntl(fd_, F_GETFD) != -1 || errno != EBADF;
|
||||
}
|
||||
|
||||
~MemFile() {
|
||||
if (mem_) {
|
||||
munmap((void*)mem_, n_bytes_);
|
||||
}
|
||||
if (fd_) {
|
||||
close(fd_);
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns the size of the underlying file defined by the `MemFile`
|
||||
size_t size() {
|
||||
return n_bytes_;
|
||||
}
|
||||
[[nodiscard]] int fd() const {
|
||||
return fd_;
|
||||
}
|
||||
|
||||
Section getSection(const Elf64_Shdr& shdr) {
|
||||
UNWIND_CHECK(shdr.sh_offset + shdr.sh_size <= n_bytes_, "invalid section");
|
||||
return Section{mem_ + shdr.sh_offset, shdr.sh_size};
|
||||
}
|
||||
|
||||
Section getSection(const char* name, bool optional) {
|
||||
for (int i = 0; i < ehdr_->e_shnum; i++) {
|
||||
if (strcmp(strtab_.string(shdr_[i].sh_name), name) == 0) {
|
||||
return getSection(shdr_[i]);
|
||||
}
|
||||
}
|
||||
UNWIND_CHECK(optional, "{} has no section {}", name_, name);
|
||||
return Section{nullptr, 0};
|
||||
}
|
||||
|
||||
Section strtab() {
|
||||
return strtab_;
|
||||
}
|
||||
|
||||
private:
|
||||
template <typename T>
|
||||
T* load(size_t offset) {
|
||||
UNWIND_CHECK(offset < n_bytes_, "out of range");
|
||||
return (T*)(mem_ + offset);
|
||||
}
|
||||
int fd_;
|
||||
char* mem_;
|
||||
size_t n_bytes_;
|
||||
std::string name_;
|
||||
Elf64_Ehdr* ehdr_;
|
||||
Elf64_Shdr* shdr_;
|
||||
Section strtab_ = {nullptr, 0};
|
||||
};
|
||||
|
||||
} // namespace torch::unwind
|
||||
|
|
@ -1,74 +0,0 @@
|
|||
#pragma once
|
||||
#include <torch/csrc/profiler/unwind/unwind_error.h>
|
||||
#include <algorithm>
|
||||
#include <memory>
|
||||
#include <optional>
|
||||
#include <unordered_map>
|
||||
#include <vector>
|
||||
|
||||
namespace torch::unwind {
|
||||
template <typename T>
|
||||
struct RangeTable {
|
||||
RangeTable() {
|
||||
// guarentee that lower_bound[-1] is always valid
|
||||
addresses_.push_back(0);
|
||||
payloads_.emplace_back(std::nullopt);
|
||||
}
|
||||
void add(uint64_t address, unwind::optional<T> payload, bool sorted) {
|
||||
if (addresses_.back() > address) {
|
||||
UNWIND_CHECK(!sorted, "expected addresses to be sorted");
|
||||
sorted_ = false;
|
||||
}
|
||||
addresses_.push_back(address);
|
||||
payloads_.emplace_back(std::move(payload));
|
||||
}
|
||||
unwind::optional<T> find(uint64_t address) {
|
||||
maybeSort();
|
||||
auto it = std::upper_bound(addresses_.begin(), addresses_.end(), address);
|
||||
return payloads_.at(it - addresses_.begin() - 1);
|
||||
}
|
||||
void dump() {
|
||||
for (size_t i = 0; i < addresses_.size(); i++) {
|
||||
fmt::print("{} {:x}: {}\n", i, addresses_[i], payloads_[i] ? "" : "END");
|
||||
}
|
||||
}
|
||||
size_t size() const {
|
||||
return addresses_.size();
|
||||
}
|
||||
uint64_t back() {
|
||||
maybeSort();
|
||||
return addresses_.back();
|
||||
}
|
||||
|
||||
private:
|
||||
void maybeSort() {
|
||||
if (sorted_) {
|
||||
return;
|
||||
}
|
||||
std::vector<uint64_t> indices;
|
||||
indices.reserve(addresses_.size());
|
||||
for (size_t i = 0; i < addresses_.size(); i++) {
|
||||
indices.push_back(i);
|
||||
}
|
||||
std::sort(indices.begin(), indices.end(), [&](uint64_t a, uint64_t b) {
|
||||
return addresses_[a] < addresses_[b] ||
|
||||
(addresses_[a] == addresses_[b] &&
|
||||
bool(payloads_[a]) < bool(payloads_[b]));
|
||||
});
|
||||
std::vector<uint64_t> addresses;
|
||||
std::vector<unwind::optional<T>> payloads;
|
||||
addresses.reserve(addresses_.size());
|
||||
payloads.reserve(addresses_.size());
|
||||
for (auto i : indices) {
|
||||
addresses.push_back(addresses_[i]);
|
||||
payloads.push_back(payloads_[i]);
|
||||
}
|
||||
addresses_ = std::move(addresses);
|
||||
payloads_ = std::move(payloads);
|
||||
sorted_ = true;
|
||||
}
|
||||
bool sorted_ = true;
|
||||
std::vector<uint64_t> addresses_;
|
||||
std::vector<unwind::optional<T>> payloads_;
|
||||
};
|
||||
} // namespace torch::unwind
|
||||
|
|
@ -1,124 +0,0 @@
|
|||
#pragma once
|
||||
#include <cxxabi.h>
|
||||
#include <elf.h>
|
||||
#include <torch/csrc/profiler/unwind/dwarf_enums.h>
|
||||
#include <torch/csrc/profiler/unwind/dwarf_symbolize_enums.h>
|
||||
#include <torch/csrc/profiler/unwind/mem_file.h>
|
||||
#include <torch/csrc/profiler/unwind/range_table.h>
|
||||
#include <torch/csrc/profiler/unwind/unwind_error.h>
|
||||
#include <cstdint>
|
||||
|
||||
namespace torch::unwind {
|
||||
|
||||
// Returns the demangled form of `mangled_name`, or `mangled_name` unchanged
// when abi::__cxa_demangle reports a failure.
static std::string demangle(const std::string& mangled_name) {
  int status = 0;
  char* demangled =
      abi::__cxa_demangle(mangled_name.c_str(), nullptr, nullptr, &status);
  if (status != 0) {
    // not a name the demangler understands: hand back the input
    return mangled_name;
  }
  std::string result(demangled);
  // the buffer returned by __cxa_demangle is malloc'ed and owned by us
  // NOLINTNEXTLINE
  free(demangled);
  return result;
}
|
||||
|
||||
struct Sections {
|
||||
Sections() = default;
|
||||
void parse(const char* name) {
|
||||
library_ = std::make_unique<MemFile>(name);
|
||||
strtab = library_->getSection(".strtab", false);
|
||||
|
||||
symtab = library_->getSection(".symtab", true);
|
||||
debug_info = library_->getSection(".debug_info", true);
|
||||
if (debug_info.size > 0) {
|
||||
debug_abbrev = library_->getSection(".debug_abbrev", false);
|
||||
debug_str = library_->getSection(".debug_str", false);
|
||||
debug_line = library_->getSection(".debug_line", false);
|
||||
// dwarf 5
|
||||
debug_line_str = library_->getSection(".debug_line_str", true);
|
||||
debug_rnglists = library_->getSection(".debug_rnglists", true);
|
||||
debug_addr = library_->getSection(".debug_addr", true);
|
||||
// dwarf 4
|
||||
debug_ranges = library_->getSection(".debug_ranges", true);
|
||||
}
|
||||
parseSymtab();
|
||||
}
|
||||
|
||||
Section debug_info;
|
||||
Section debug_abbrev;
|
||||
Section debug_str;
|
||||
Section debug_line;
|
||||
Section debug_line_str;
|
||||
Section debug_rnglists;
|
||||
Section debug_ranges;
|
||||
Section debug_addr;
|
||||
Section symtab;
|
||||
Section strtab;
|
||||
|
||||
const char* readString(
|
||||
CheckedLexer& data,
|
||||
uint64_t encoding,
|
||||
bool is_64bit,
|
||||
uint64_t str_offsets_base) {
|
||||
switch (encoding) {
|
||||
case DW_FORM_string: {
|
||||
return data.readCString();
|
||||
}
|
||||
case DW_FORM_strp: {
|
||||
return debug_str.string(readSegmentOffset(data, is_64bit));
|
||||
}
|
||||
case DW_FORM_line_strp: {
|
||||
return debug_line_str.string(readSegmentOffset(data, is_64bit));
|
||||
}
|
||||
default:
|
||||
UNWIND_CHECK(false, "unsupported string encoding {:x}", encoding);
|
||||
}
|
||||
}
|
||||
|
||||
uint64_t readSegmentOffset(CheckedLexer& data, bool is_64bit) {
|
||||
return is_64bit ? data.read<uint64_t>() : data.read<uint32_t>();
|
||||
}
|
||||
|
||||
unwind::optional<uint64_t> findDebugInfoOffset(uint64_t address) {
|
||||
return debug_info_offsets_.find(address);
|
||||
}
|
||||
size_t compilationUnitCount() {
|
||||
return debug_info_offsets_.size() / 2;
|
||||
}
|
||||
void addDebugInfoRange(
|
||||
uint64_t start,
|
||||
uint64_t end,
|
||||
uint64_t debug_info_offset) {
|
||||
debug_info_offsets_.add(start, debug_info_offset, false);
|
||||
debug_info_offsets_.add(end, std::nullopt, false);
|
||||
}
|
||||
optional<std::string> findSubprogramName(uint64_t address) {
|
||||
if (auto e = symbol_table_.find(address)) {
|
||||
return demangle(strtab.string(*e));
|
||||
}
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
private:
|
||||
void parseSymtab() {
|
||||
auto L = symtab.lexer(0);
|
||||
char* end = symtab.data + symtab.size;
|
||||
while (L.loc() < end) {
|
||||
auto symbol = L.read<Elf64_Sym>();
|
||||
if (symbol.st_shndx == SHN_UNDEF ||
|
||||
ELF64_ST_TYPE(symbol.st_info) != STT_FUNC) {
|
||||
continue;
|
||||
}
|
||||
symbol_table_.add(symbol.st_value, symbol.st_name, false);
|
||||
symbol_table_.add(symbol.st_value + symbol.st_size, std::nullopt, false);
|
||||
}
|
||||
}
|
||||
|
||||
std::unique_ptr<MemFile> library_;
|
||||
RangeTable<uint64_t> debug_info_offsets_;
|
||||
RangeTable<uint64_t> symbol_table_;
|
||||
};
|
||||
|
||||
} // namespace torch::unwind
|
||||
|
|
@ -1,7 +1,6 @@
|
|||
#include <c10/util/Exception.h>
|
||||
#include <torch/csrc/profiler/unwind/unwind.h>
|
||||
#include <torch/csrc/utils/cpp_stacktraces.h>
|
||||
#include <unordered_map>
|
||||
|
||||
#if !defined(__linux__) || !defined(__x86_64__) || !defined(__has_include) || \
|
||||
!__has_include("ext/stdio_filebuf.h")
|
||||
|
|
@ -19,7 +18,7 @@ c10::optional<std::pair<std::string, uint64_t>> libraryFor(void* addr) {
|
|||
}
|
||||
|
||||
#ifndef FBCODE_CAFFE2
|
||||
std::vector<Frame> symbolize(const std::vector<void*>& frames, Mode mode) {
|
||||
std::vector<Frame> symbolize(const std::vector<void*>& frames) {
|
||||
TORCH_CHECK(
|
||||
false,
|
||||
"record_context_cpp is not support on non-linux non-x86_64 platforms");
|
||||
|
|
@ -49,15 +48,10 @@ Stats stats() {
|
|||
#include <torch/csrc/profiler/unwind/communicate.h>
|
||||
#include <torch/csrc/profiler/unwind/dwarf_enums.h>
|
||||
#include <torch/csrc/profiler/unwind/eh_frame_hdr.h>
|
||||
#include <torch/csrc/profiler/unwind/fast_symbolizer.h>
|
||||
#include <torch/csrc/profiler/unwind/fde.h>
|
||||
#include <torch/csrc/profiler/unwind/unwinder.h>
|
||||
#include <shared_mutex>
|
||||
|
||||
extern "C" void unwind_c(std::vector<void*>* result, int64_t rsp, int64_t rbp);
|
||||
extern "C" void unwind_entry(std::vector<void*>* result);
|
||||
|
||||
namespace torch::unwind {
|
||||
struct UpgradeExclusive {
|
||||
UpgradeExclusive(std::shared_lock<std::shared_timed_mutex>& rdlock)
|
||||
: rdlock_(rdlock) {
|
||||
|
|
@ -203,7 +197,7 @@ struct UnwindCache {
|
|||
Unwinder unwinder = Unwinder::unknown();
|
||||
try {
|
||||
unwinder = libraryFor(addr).unwinderFor(addr);
|
||||
} catch (unwind::UnwindError& err) {
|
||||
} catch (UnwindError& err) {
|
||||
// because unwinders are cached this will only print
|
||||
// once per frame that cannot be unwound.
|
||||
TORCH_WARN("Unsupported unwinding pattern: ", err.what());
|
||||
|
|
@ -282,6 +276,46 @@ struct UnwindCache {
|
|||
static UnwindCache unwind_cache;
|
||||
static std::shared_timed_mutex cache_mutex_;
|
||||
|
||||
extern "C" void unwind_c(std::vector<void*>* result, int64_t rsp, int64_t rbp);
|
||||
extern "C" void unwind_c(std::vector<void*>* result, int64_t rsp, int64_t rbp) {
|
||||
std::shared_lock lock(cache_mutex_);
|
||||
UnwindState state{};
|
||||
// NOLINTNEXTLINE(performance-no-int-to-ptr)
|
||||
state.rip = *(int64_t*)(rsp);
|
||||
// +8 because we saved rsp after the return address was already pushed
|
||||
// to the stack
|
||||
state.rsp = rsp + 8;
|
||||
state.rbp = rbp;
|
||||
unwind_cache.checkRefresh(lock);
|
||||
while (true) { // unwind for _start sets rip as being undefined
|
||||
// NOLINTNEXTLINE(performance-no-int-to-ptr)
|
||||
result->push_back((void*)state.rip);
|
||||
const Unwinder& uw = unwind_cache.unwinderFor(state.rip, lock);
|
||||
if (uw.terminator()) {
|
||||
if (uw.isUnknown()) {
|
||||
result->push_back(nullptr);
|
||||
}
|
||||
break;
|
||||
}
|
||||
state = uw.run(state);
|
||||
}
|
||||
}
|
||||
|
||||
extern "C" void unwind_entry(std::vector<void*>* result);
|
||||
|
||||
// calling convention puts the first three pointer/int64_t arguments in
|
||||
// rdi rsi rdx (all caller-saved)
|
||||
// rdi already holds the pointer to the result vector
|
||||
// we add arguments for current rsp and rbp and then tail call
|
||||
// into unwind_c
|
||||
__asm__(
|
||||
".global unwind_entry\n"
|
||||
"unwind_entry:\n"
|
||||
"mov %rsp, %rsi;\n"
|
||||
"mov %rbp, %rdx;\n"
|
||||
"jmp unwind_c;\n");
|
||||
|
||||
namespace torch::unwind {
|
||||
std::vector<void*> unwind() {
|
||||
std::vector<void*> frames;
|
||||
unwind_entry(&frames);
|
||||
|
|
@ -301,15 +335,6 @@ c10::optional<std::pair<std::string, uint64_t>> libraryFor(void* addr) {
|
|||
library_info->name(), (uint64_t)addr - library_info->load_bias());
|
||||
}
|
||||
|
||||
static std::string dladdr_lookup(void* addr) {
|
||||
Dl_info dlinfo;
|
||||
std::string funcname = "??";
|
||||
if (dladdr(addr, &dlinfo) && dlinfo.dli_sname) {
|
||||
funcname = demangle(dlinfo.dli_sname);
|
||||
}
|
||||
return funcname;
|
||||
}
|
||||
|
||||
struct Symbolizer {
|
||||
Symbolizer() {
|
||||
auto envar = std::getenv("TORCH_ADDR2LINE_BINARY");
|
||||
|
|
@ -320,6 +345,9 @@ struct Symbolizer {
|
|||
} else {
|
||||
addr2line_binary_ = "addr2line"; // default
|
||||
}
|
||||
if (torch::get_disable_addr2line()) {
|
||||
addr2line_binary_ = nullptr;
|
||||
}
|
||||
}
|
||||
static std::lock_guard<std::mutex> guard() {
|
||||
static std::mutex mutex;
|
||||
|
|
@ -339,6 +367,16 @@ struct Symbolizer {
|
|||
frame_map_[addr] = Frame{"??", "<unwind unsupported>", 0};
|
||||
return;
|
||||
}
|
||||
if (addr2line_binary_ == nullptr) {
|
||||
Dl_info dlinfo;
|
||||
std::string funcname = "??";
|
||||
if (dladdr(addr, &dlinfo) && dlinfo.dli_sname) {
|
||||
funcname = demangle(dlinfo.dli_sname);
|
||||
}
|
||||
frame_map_[addr] = Frame{
|
||||
maybe_library->first, std::move(funcname), maybe_library->second - 1};
|
||||
return;
|
||||
}
|
||||
has_pending_results_ = true;
|
||||
auto& entry = getOrCreate(maybe_library->first);
|
||||
entry.queried.push_back(addr);
|
||||
|
|
@ -410,59 +448,23 @@ struct Symbolizer {
|
|||
frame_map_[e.queried[e.completed]] = std::move(frame);
|
||||
}
|
||||
}
|
||||
// Returns the demangled form of `mangled_name`; falls back to the input
// when abi::__cxa_demangle reports a non-zero status.
std::string demangle(const std::string& mangled_name) {
  int status = 0;
  char* buffer =
      abi::__cxa_demangle(mangled_name.c_str(), nullptr, nullptr, &status);
  const bool ok = status == 0;
  if (!ok) {
    return mangled_name;
  }
  std::string readable(buffer);
  // buffer was malloc'ed by __cxa_demangle
  // NOLINTNEXTLINE
  free(buffer);
  return readable;
}
|
||||
};
|
||||
|
||||
static std::vector<Frame> symbolize_fast(
|
||||
const std::vector<void*>& frames,
|
||||
Mode mode) {
|
||||
static std::mutex cache_mutex;
|
||||
static std::array<ska::flat_hash_map<void*, Frame>, 2> frame_maps;
|
||||
auto& frame_map = frame_maps[mode == Mode::fast ? 0 : 1];
|
||||
|
||||
std::vector<uint32_t> indices_to_lookup;
|
||||
std::vector<Frame> results;
|
||||
results.reserve(frames.size());
|
||||
{
|
||||
std::lock_guard<std::mutex> lock(cache_mutex);
|
||||
for (auto i : c10::irange(frames.size())) {
|
||||
void* f = frames.at(i);
|
||||
auto it = frame_map.find(f);
|
||||
if (it == frame_map.end()) {
|
||||
indices_to_lookup.push_back(i);
|
||||
results.emplace_back(Frame{"??", "??", 0});
|
||||
} else {
|
||||
results.emplace_back(it->second);
|
||||
}
|
||||
}
|
||||
}
|
||||
if (!indices_to_lookup.empty()) {
|
||||
// do symbolizer work
|
||||
FastSymbolizer symbolizer;
|
||||
for (auto i : indices_to_lookup) {
|
||||
void* addr = frames.at(i);
|
||||
Frame& f = results.at(i);
|
||||
auto library = libraryFor(frames.at(i));
|
||||
if (library) {
|
||||
if (mode == Mode::fast) {
|
||||
f = symbolizer.symbolize(library->first, library->second - 1);
|
||||
} else {
|
||||
f = Frame{library->first, "??", library->second - 1};
|
||||
}
|
||||
}
|
||||
if (f.funcname == "??") {
|
||||
f.funcname = dladdr_lookup(addr);
|
||||
}
|
||||
}
|
||||
std::lock_guard<std::mutex> lock(cache_mutex);
|
||||
for (auto i : indices_to_lookup) {
|
||||
frame_map.emplace(frames.at(i), results.at(i));
|
||||
}
|
||||
}
|
||||
return results;
|
||||
}
|
||||
|
||||
static std::vector<Frame> symbolize_addr2line(
|
||||
const std::vector<void*>& frames) {
|
||||
#ifndef FBCODE_CAFFE2
|
||||
std::vector<Frame> symbolize(const std::vector<void*>& frames) {
|
||||
auto guard = Symbolizer::guard();
|
||||
Symbolizer& s = Symbolizer::get();
|
||||
for (auto f : frames) {
|
||||
|
|
@ -475,16 +477,6 @@ static std::vector<Frame> symbolize_addr2line(
|
|||
}
|
||||
return results;
|
||||
}
|
||||
|
||||
// fbcode will use llvm symbolize since there is an llvm dependency already
|
||||
#ifndef FBCODE_CAFFE2
|
||||
std::vector<Frame> symbolize(const std::vector<void*>& frames, Mode mode) {
|
||||
if (mode == Mode::addr2line) {
|
||||
return symbolize_addr2line(frames);
|
||||
} else {
|
||||
return symbolize_fast(frames, mode);
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
Stats stats() {
|
||||
|
|
@ -492,42 +484,4 @@ Stats stats() {
|
|||
}
|
||||
|
||||
} // namespace torch::unwind
|
||||
|
||||
extern "C" void unwind_c(std::vector<void*>* result, int64_t rsp, int64_t rbp) {
|
||||
std::shared_lock lock(torch::unwind::cache_mutex_);
|
||||
torch::unwind::UnwindState state{};
|
||||
// NOLINTNEXTLINE(performance-no-int-to-ptr)
|
||||
state.rip = *(int64_t*)(rsp);
|
||||
// +8 because we saved rsp after the return address was already pushed
|
||||
// to the stack
|
||||
state.rsp = rsp + 8;
|
||||
state.rbp = rbp;
|
||||
torch::unwind::unwind_cache.checkRefresh(lock);
|
||||
while (true) { // unwind for _start sets rip as being undefined
|
||||
// NOLINTNEXTLINE(performance-no-int-to-ptr)
|
||||
result->push_back((void*)state.rip);
|
||||
const torch::unwind::Unwinder& uw =
|
||||
torch::unwind::unwind_cache.unwinderFor(state.rip, lock);
|
||||
if (uw.terminator()) {
|
||||
if (uw.isUnknown()) {
|
||||
result->push_back(nullptr);
|
||||
}
|
||||
break;
|
||||
}
|
||||
state = uw.run(state);
|
||||
}
|
||||
}
|
||||
|
||||
// calling convention puts the first three pointer/int64_t arguments in
|
||||
// rdi rsi rdx (all caller-saved)
|
||||
// rdi already holds the pointer to the result vector
|
||||
// we add arguments for current rsp and rbp and then tail call
|
||||
// into unwind_c
|
||||
__asm__(
|
||||
".global unwind_entry\n"
|
||||
"unwind_entry:\n"
|
||||
"mov %rsp, %rsi;\n"
|
||||
"mov %rbp, %rdx;\n"
|
||||
"jmp unwind_c;\n");
|
||||
|
||||
#endif
|
||||
|
|
|
|||
|
|
@ -1,11 +1,11 @@
|
|||
#pragma once
|
||||
#include <c10/macros/Export.h>
|
||||
#include <c10/util/Optional.h>
|
||||
#include <cstdint>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
namespace torch::unwind {
|
||||
namespace torch {
|
||||
namespace unwind {
|
||||
// gather current stack, relatively fast.
|
||||
// gets faster once the cache of program counter locations is warm.
|
||||
TORCH_API std::vector<void*> unwind();
|
||||
|
|
@ -16,17 +16,13 @@ struct Frame {
|
|||
uint64_t lineno;
|
||||
};
|
||||
|
||||
enum class Mode { addr2line, fast, dladdr };
|
||||
|
||||
// note: symbolize is really slow
|
||||
// it will launch an addr2line process that has to parse dwarf
|
||||
// information from the libraries that frames point into.
|
||||
// Callers should first batch up all the unique void* pointers
|
||||
// across a number of unwind states and make a single call to
|
||||
// symbolize.
|
||||
TORCH_API std::vector<Frame> symbolize(
|
||||
const std::vector<void*>& frames,
|
||||
Mode mode);
|
||||
TORCH_API std::vector<Frame> symbolize(const std::vector<void*>& frames);
|
||||
|
||||
// returns path to the library, and the offset of the addr inside the library
|
||||
TORCH_API c10::optional<std::pair<std::string, uint64_t>> libraryFor(
|
||||
|
|
@ -40,4 +36,5 @@ struct Stats {
|
|||
};
|
||||
Stats stats();
|
||||
|
||||
} // namespace torch::unwind
|
||||
} // namespace unwind
|
||||
} // namespace torch
|
||||
|
|
|
|||
|
|
@ -1,31 +1,6 @@
|
|||
#pragma once
|
||||
#include <c10/util/Optional.h>
|
||||
#include <fmt/format.h>
|
||||
#include <stdexcept>
|
||||
|
||||
namespace torch::unwind {
|
||||
|
||||
struct UnwindError : public std::runtime_error {
|
||||
using std::runtime_error::runtime_error;
|
||||
};
|
||||
|
||||
#define UNWIND_CHECK(cond, fmtstring, ...) \
|
||||
do { \
|
||||
if (!(cond)) { \
|
||||
throw unwind::UnwindError(fmt::format( \
|
||||
"{}:{}: " fmtstring, __FILE__, __LINE__, ##__VA_ARGS__)); \
|
||||
} \
|
||||
} while (0)
|
||||
|
||||
// #define LOG_INFO(...) fmt::print(__VA_ARGS__)
|
||||
#define LOG_INFO(...)
|
||||
|
||||
// #define PRINT_INST(...) LOG_INFO(__VA_ARGS__)
|
||||
#define PRINT_INST(...)
|
||||
|
||||
// #define PRINT_LINE_TABLE(...) LOG_INFO(__VA_ARGS__)
|
||||
#define PRINT_LINE_TABLE(...)
|
||||
|
||||
using c10::optional; // NOLINT
|
||||
|
||||
} // namespace torch::unwind
|
||||
|
|
|
|||
|
|
@ -6,9 +6,9 @@
|
|||
#include <torch/csrc/profiler/unwind/unwind.h>
|
||||
|
||||
namespace torch {
|
||||
namespace torch::unwind {
|
||||
namespace unwind {
|
||||
|
||||
std::vector<Frame> symbolize(const std::vector<void*>& frames, Mode mode) {
|
||||
std::vector<Frame> symbolize(const std::vector<void*>& frames) {
|
||||
static std::mutex symbolize_mutex;
|
||||
static llvm::symbolize::LLVMSymbolizer symbolizer;
|
||||
static ska::flat_hash_map<void*, Frame> frame_map_;
|
||||
|
|
@ -38,7 +38,7 @@ std::vector<Frame> symbolize(const std::vector<void*>& frames, Mode mode) {
|
|||
return results;
|
||||
}
|
||||
|
||||
} // namespace torch::unwind
|
||||
} // namespace unwind
|
||||
} // namespace torch
|
||||
|
||||
#endif
|
||||
|
|
|
|||
|
|
@ -4,8 +4,6 @@
|
|||
#include <cstdint>
|
||||
#include <limits>
|
||||
|
||||
namespace torch::unwind {
|
||||
|
||||
struct UnwindState {
|
||||
int64_t rip, rbp, rsp;
|
||||
};
|
||||
|
|
@ -77,5 +75,3 @@ struct Unwinder {
|
|||
int64_t rbp_off_;
|
||||
bool deref_{false};
|
||||
};
|
||||
|
||||
} // namespace torch::unwind
|
||||
|
|
|
|||
|
|
@ -47,31 +47,9 @@ bool get_cpp_stacktraces_enabled() {
|
|||
return enabled;
|
||||
}
|
||||
|
||||
static torch::unwind::Mode compute_symbolize_mode() {
|
||||
auto envar_c = std::getenv("TORCH_SYMBOLIZE_MODE");
|
||||
if (envar_c) {
|
||||
std::string envar = envar_c;
|
||||
if (envar == "dladdr") {
|
||||
return unwind::Mode::dladdr;
|
||||
} else if (envar == "addr2line") {
|
||||
return unwind::Mode::addr2line;
|
||||
} else if (envar == "fast") {
|
||||
return unwind::Mode::fast;
|
||||
} else {
|
||||
TORCH_CHECK(
|
||||
false,
|
||||
"expected {dladdr, addr2line, fast} for TORCH_SYMBOLIZE_MODE, got ",
|
||||
envar);
|
||||
}
|
||||
} else {
|
||||
return compute_disable_addr2line() ? unwind::Mode::dladdr
|
||||
: unwind::Mode::addr2line;
|
||||
}
|
||||
}
|
||||
|
||||
unwind::Mode get_symbolize_mode() {
|
||||
static unwind::Mode mode = compute_symbolize_mode();
|
||||
return mode;
|
||||
bool get_disable_addr2line() {
|
||||
static bool disabled = compute_disable_addr2line();
|
||||
return disabled;
|
||||
}
|
||||
|
||||
} // namespace torch
|
||||
|
|
|
|||
|
|
@ -1,9 +1,8 @@
|
|||
#pragma once
|
||||
|
||||
#include <torch/csrc/Export.h>
|
||||
#include <torch/csrc/profiler/unwind/unwind.h>
|
||||
|
||||
namespace torch {
|
||||
TORCH_API bool get_cpp_stacktraces_enabled();
|
||||
TORCH_API torch::unwind::Mode get_symbolize_mode();
|
||||
TORCH_API bool get_disable_addr2line();
|
||||
} // namespace torch
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user