Revert "Fast standalone symbolize for unwinding (#123966)"

This reverts commit 772ae6da1e.

Reverted https://github.com/pytorch/pytorch/pull/123966 on behalf of https://github.com/jeanschmidt due to Breaking internal builds, check D56522678 ([comment](https://github.com/pytorch/pytorch/pull/123966#issuecomment-2076821043))
This commit is contained in:
PyTorch MergeBot 2024-04-25 10:04:48 +00:00
parent 2d7f709752
commit c0fd7894cc
24 changed files with 111 additions and 1596 deletions

View File

@ -16,11 +16,8 @@ except ImportError:
import collections
import gc
import json
import mmap
import os
import random
import re
import struct
import subprocess
import sys
import tempfile
@ -77,9 +74,7 @@ from torch.testing._internal.common_cuda import TEST_CUDA, TEST_MULTIGPU
from torch.testing._internal.common_device_type import skipCUDAVersionIn
from torch.testing._internal.common_utils import (
instantiate_parametrized_tests,
IS_ARM64,
IS_JETSON,
IS_LINUX,
IS_WINDOWS,
parametrize,
run_tests,
@ -3602,70 +3597,6 @@ aten::mm""",
finally:
os.remove("torchtidy_report.json")
@unittest.skipIf(IS_ARM64 or not IS_LINUX, "x86 linux only cpp unwinding")
def test_fuzz_symbolize(self):
# generate some random addresses in the text section and make sure the
# symbolizers do not throw exceptions/crash
def get_text_sections():
text_sections = []
seen = set()
for filename in os.listdir("/proc/self/map_files"):
library = os.readlink("/proc/self/map_files/" + filename)
if ".so" not in library or library in seen:
continue
seen.add(library)
with open(os.path.join("/proc/self/map_files", library), "rb") as f:
mm = mmap.mmap(f.fileno(), 0, prot=mmap.PROT_READ)
def unpack(fmt, offset):
return struct.unpack(
fmt, mm[offset : offset + struct.calcsize(fmt)]
)
if mm[:4] != b"\x7fELF":
continue
(section_headers_start,) = unpack("Q", 40)
(section_header_size,) = unpack("H", 58)
(num_section_headers,) = unpack("H", 60)
(shstrndx,) = unpack("H", 62)
(shstrtab_offset,) = unpack(
"Q", section_headers_start + shstrndx * section_header_size + 24
)
for i in range(num_section_headers):
(section_name_offset,) = unpack(
"I", section_headers_start + i * section_header_size
)
name_start = shstrtab_offset + section_name_offset
section_name = mm[name_start : name_start + 6]
if section_name != b".text\0":
continue
(section_offset,) = unpack(
"Q", section_headers_start + i * section_header_size + 24
)
(section_size,) = unpack(
"Q", section_headers_start + i * section_header_size + 32
)
start = int(filename.split("-")[0], 16) + section_offset
text_sections.append((start, section_size))
break
mm.close()
return text_sections
r = random.Random()
r.seed(1)
text_sections = get_text_sections()
addrs = []
for i in range(200):
s = r.randrange(0, len(text_sections))
start, size = text_sections[s]
addr = r.randrange(start, start + size)
addrs.append(addr)
fast = torch._C._profiler.symbolize_addresses(addrs, "fast")
dladdr = torch._C._profiler.symbolize_addresses(addrs, "dladdr")
addr2line = torch._C._profiler.symbolize_addresses(addrs, "addr2line")
self.assertEqual(len(fast), len(addrs))
self.assertEqual(len(addr2line), len(fast))
if __name__ == "__main__":
run_tests()

View File

@ -16,6 +16,7 @@ import subprocess
import random
from random import randint
import json
import torch
import torch.cuda
from torch.cuda._memory_viz import profile_plot, _profile_to_snapshot

View File

@ -168,14 +168,12 @@ static PyObject* THPModule_initExtension(
PyObject* shm_manager_path) {
HANDLE_TH_ERRORS
#if !defined(FBCODE_CAFFE2)
if (torch::get_cpp_stacktraces_enabled()) {
if (torch::get_cpp_stacktraces_enabled() && !torch::get_disable_addr2line()) {
c10::SetStackTraceFetcher([]() -> std::string {
auto tb = torch::CapturedTraceback::gather(false, false, true);
if (torch::get_symbolize_mode() == torch::unwind::Mode::addr2line) {
LOG(WARNING)
<< "symbolizing C++ stack trace for exception; if this hangs, rerun with TORCH_DISABLE_ADDR2LINE=1..."
<< std::endl;
}
LOG(WARNING)
<< "symbolizing C++ stack trace for exception; if this hangs, rerun with TORCH_DISABLE_ADDR2LINE=1..."
<< std::endl;
auto s_tbs = torch::symbolize({tb.get()});
std::stringstream oss;
oss << "C++ CapturedTraceback:" << std::endl;

View File

@ -1,5 +1,4 @@
#include <torch/csrc/profiler/combined_traceback.h>
#include <torch/csrc/utils/cpp_stacktraces.h>
namespace torch {
@ -78,7 +77,7 @@ SymbolizedTracebacks symbolize(
}
// gather symbol names for C++ frames
if (!all_cpp_ips.empty()) {
r.all_frames = unwind::symbolize(all_cpp_ips, torch::get_symbolize_mode());
r.all_frames = unwind::symbolize(all_cpp_ips);
}
// batch symbolization requests so we dedup frame objects

View File

@ -79,7 +79,8 @@ PyTypeObject THPCapturedTracebackType = {
nullptr, /* tp_new */
};
namespace pybind11::detail {
namespace pybind11 {
namespace detail {
template <>
struct type_caster<std::shared_ptr<torch::CapturedTraceback>> {
@ -106,9 +107,11 @@ struct type_caster<std::shared_ptr<torch::CapturedTraceback>> {
}
};
} // namespace pybind11::detail
} // namespace detail
} // namespace pybind11
namespace torch::profiler {
namespace torch {
namespace profiler {
/* [NOTE: RecordFunctionFast]
* This is an alternate way to call record_function from python.
@ -604,33 +607,6 @@ void initPythonBindings(PyObject* module) {
}
return py_symbolize(tb_ptrs);
});
// directly convert address pointers to frames, used for testing symbolize
m.def(
"symbolize_addresses",
[](const std::vector<uint64_t>& frames, const std::string& mode_s) {
std::vector<std::tuple<std::string, int64_t, std::string>> frames_out;
torch::unwind::Mode mode = torch::unwind::Mode::addr2line;
if (mode_s == "fast") {
mode = torch::unwind::Mode::fast;
} else if (mode_s == "addr2line") {
mode = torch::unwind::Mode::addr2line;
} else if (mode_s == "dladdr") {
mode = torch::unwind::Mode::dladdr;
} else {
TORCH_CHECK(false, "unexpected mode ", mode_s);
}
std::vector<void*> frames_p;
frames_p.reserve(frames.size());
for (auto f : frames) {
frames_p.push_back((void*)f); // NOLINT
}
auto frame_objects = unwind::symbolize(frames_p, mode);
frames_out.reserve(frame_objects.size());
for (auto& frame : frame_objects) {
frames_out.emplace_back(frame.filename, frame.lineno, frame.funcname);
}
return frames_out;
});
installCapturedTracebackPython();
// NOLINTNEXTLINE(*-c-arrays*)
@ -664,4 +640,5 @@ void initPythonBindings(PyObject* module) {
throw python_error();
}
}
} // namespace torch::profiler
} // namespace profiler
} // namespace torch

View File

@ -2,8 +2,6 @@
#include <stdint.h>
#include <ostream>
namespace torch::unwind {
enum {
A_UNDEFINED = 0x0,
A_REG_PLUS_DATA = 0x1, // exp = REG[reg] + data0
@ -55,5 +53,3 @@ struct Action {
return out;
}
};
} // namespace torch::unwind

View File

@ -5,7 +5,6 @@
#include <unistd.h>
#include <memory>
namespace torch::unwind {
// helper to open a process with stdin/stdout/stderr streams.
struct Communicate {
Communicate(const char* command, const char** args) {
@ -64,5 +63,3 @@ struct Communicate {
std::unique_ptr<std::ostream> out_;
std::unique_ptr<std::ostream> err_;
};
} // namespace torch::unwind

View File

@ -1,279 +0,0 @@
#pragma once
#include <torch/csrc/profiler/unwind/dwarf_enums.h>
#include <torch/csrc/profiler/unwind/dwarf_symbolize_enums.h>
#include <torch/csrc/profiler/unwind/lexer.h>
#include <torch/csrc/profiler/unwind/sections.h>
#include <torch/csrc/profiler/unwind/unwind_error.h>
#include <cstdint>
#include <optional>
namespace torch::unwind {
struct DebugInfo {
DebugInfo(Sections& s) : s_(s) {}
void parse(uint64_t offset) {
auto L = parseHeader(offset);
parseCompileUnit(L);
}
unwind::optional<uint64_t> lineNumberProgramOffset() {
return line_number_program_offset_;
}
uint64_t nextOffset() {
return end_ - s_.debug_info.data;
}
std::vector<std::pair<uint64_t, uint64_t>> ranges() {
if (range_ptr_) {
auto offset = range_ptr_->first;
if (range_ptr_->second == DW_FORM_rnglistx) {
UNWIND_CHECK(rnglists_base_, "rnglistx but not rnglists_base_ set");
LOG_INFO("index for rnglistx {:x} + {:x}\n", *rnglists_base_, offset);
CheckedLexer L = s_.debug_rnglists.lexer(
*rnglists_base_ + offset * sec_offset_size_);
auto read = readSegmentOffset(L);
offset = *rnglists_base_ + read;
}
return version_ == 4 ? readRanges4(offset) : readRanges5(offset);
}
if (!highpc_) {
return {};
}
return {{lowpc_, lowpc_ + *highpc_}};
}
bool is64bit() {
return is_64bit_;
}
private:
CheckedLexer parseHeader(uint64_t offset) {
offset_ = offset;
CheckedLexer L = s_.debug_info.lexer(offset_);
std::tie(length_, is_64bit_) = L.readSectionLength();
sec_offset_size_ = is_64bit_ ? 8 : 4;
end_ = (const char*)L.loc() + length_;
version_ = L.read<uint16_t>();
UNWIND_CHECK(
version_ == 5 || version_ == 4,
"unexpected dwarf version {}",
version_);
uint8_t address_size = 0;
if (version_ == 5) {
auto unit_type = L.read<uint8_t>();
UNWIND_CHECK(unit_type == 0x1, "unexpected unit type {}", unit_type);
address_size = L.read<uint8_t>();
debug_abbrev_offset_ =
is_64bit_ ? L.read<uint64_t>() : L.read<uint32_t>();
} else {
debug_abbrev_offset_ =
is_64bit_ ? L.read<uint64_t>() : L.read<uint32_t>();
address_size = L.read<uint8_t>();
}
LOG_INFO(
"compilation unit at offset {:x} with length {:x} and debug_abbrev_offset {:x}\n",
offset,
length_,
debug_abbrev_offset_);
UNWIND_CHECK(
address_size == 8,
"expected 64-bit dwarf but found address size {}",
address_size);
return L;
}
uint64_t readSegmentOffset(CheckedLexer& L) {
return s_.readSegmentOffset(L, is_64bit_);
}
uint64_t readEncoded(CheckedLexer& L, uint64_t encoding) {
switch (encoding) {
case DW_FORM_data8:
case DW_FORM_addr:
return L.read<uint64_t>();
case DW_FORM_data4:
return L.read<uint32_t>();
case DW_FORM_addrx: {
auto idx = L.readULEB128();
return s_.debug_addr.lexer(address_base_ + sizeof(uint64_t) * idx)
.read<uint64_t>();
}
case DW_FORM_sec_offset:
return readSegmentOffset(L);
case DW_FORM_rnglistx: {
return L.readULEB128();
}
default:
UNWIND_CHECK(false, "unexpected encoding");
}
}
void parseCompileUnit(CheckedLexer& L) {
auto entry = L.readULEB128();
auto A = findAbbrev(debug_abbrev_offset_, entry);
while (true) {
auto attr = A.readULEB128();
auto form = A.readULEB128();
if (attr == 0 && form == 0) {
break;
}
if (form == DW_FORM_implicit_const) {
A.readSLEB128();
}
if (attr == DW_AT_low_pc) {
lowpc_ = readEncoded(L, form);
LOG_INFO(" lowpc {:x}\n", lowpc_);
} else if (attr == DW_AT_high_pc) {
highpc_ = readEncoded(L, form);
range_ptr_ = std::nullopt;
LOG_INFO(" highpc {:x}\n", *highpc_);
} else if (attr == DW_AT_addr_base) {
UNWIND_CHECK(form == DW_FORM_sec_offset, "unexpected addr_base form");
address_base_ = readSegmentOffset(L);
LOG_INFO(" address base {:x}\n", address_base_);
} else if (attr == DW_AT_rnglists_base) {
UNWIND_CHECK(
form == DW_FORM_sec_offset, "unexpected rnglists_base form");
rnglists_base_ = readSegmentOffset(L);
LOG_INFO(" range base {:x}\n", *rnglists_base_);
} else if (form == DW_FORM_string) {
L.readCString();
} else if (attr == DW_AT_stmt_list) {
UNWIND_CHECK(form == DW_FORM_sec_offset, "unexpected stmt_list form");
LOG_INFO(" program table offset {:x}\n", *line_number_program_offset_);
line_number_program_offset_ = readSegmentOffset(L);
} else if (form == DW_FORM_exprloc) {
auto sz = L.readULEB128();
L.skip(int64_t(sz));
} else if (form == DW_FORM_block1) {
auto sz = L.read<uint8_t>();
L.skip(int64_t(sz));
} else if (attr == DW_AT_ranges) {
auto range_offset = readEncoded(L, form);
LOG_INFO("setting range_ptr to {:x} {:x}\n", range_offset, form);
range_ptr_.emplace(range_offset, form);
} else if (
form == DW_FORM_udata || form == DW_FORM_rnglistx ||
form == DW_FORM_strx || form == DW_FORM_loclistx ||
form == DW_FORM_addrx) {
L.readULEB128();
} else if (form == DW_FORM_sdata) {
L.readSLEB128();
} else {
auto sz = formSize(form, sec_offset_size_);
UNWIND_CHECK(sz, "unsupported form in compilation unit {:x}", form);
L.skip(int64_t(*sz));
}
}
}
std::vector<std::pair<uint64_t, uint64_t>> readRanges4(uint64_t offset) {
CheckedLexer L = s_.debug_ranges.lexer(offset);
std::vector<std::pair<uint64_t, uint64_t>> ranges;
uint64_t base = lowpc_;
while (true) {
auto start = L.read<uint64_t>();
auto end = L.read<uint64_t>();
if (start == 0 && end == 0) {
break;
}
if (start == std::numeric_limits<uint64_t>::max()) {
base = end;
} else {
ranges.emplace_back(base + start, base + end);
}
}
return ranges;
}
std::vector<std::pair<uint64_t, uint64_t>> readRanges5(uint64_t offset) {
CheckedLexer L = s_.debug_rnglists.lexer(offset);
uint64_t base = 0;
LOG_INFO("BEGIN RANGES {:x}\n", offset);
std::vector<std::pair<uint64_t, uint64_t>> ranges;
while (true) {
auto op = L.read<uint8_t>();
switch (op) {
case DW_RLE_end_of_list:
LOG_INFO("END RANGES\n");
return ranges;
case DW_RLE_base_addressx: {
base = readEncoded(L, DW_FORM_addrx);
LOG_INFO("BASE ADDRX {:x}\n", base);
} break;
case DW_RLE_startx_length: {
auto s = readEncoded(L, DW_FORM_addrx);
auto e = L.readULEB128();
LOG_INFO("startx_length {:x} {:x}\n", s, e);
ranges.emplace_back(s, s + e);
} break;
case DW_RLE_base_address:
base = L.read<uint64_t>();
LOG_INFO("BASE ADDR {:x}\n", base);
break;
case DW_RLE_offset_pair: {
auto s = L.readULEB128();
auto e = L.readULEB128();
LOG_INFO("offset_pair {:x} {:x}\n", s, e);
ranges.emplace_back(base + s, base + e);
} break;
case DW_RLE_start_length: {
auto s = L.read<uint64_t>();
auto e = L.readULEB128();
LOG_INFO("start_length {:x} {:x}\n", s, e);
ranges.emplace_back(s, s + e);
} break;
default:
UNWIND_CHECK(false, "unknown range op: {}", op);
}
}
}
CheckedLexer findAbbrev(uint64_t offset, uint64_t entry) {
CheckedLexer L = s_.debug_abbrev.lexer(offset);
while (true) {
auto abbrev_code = L.readULEB128();
UNWIND_CHECK(
abbrev_code != 0,
"could not find entry {} at offset {:x}",
entry,
offset);
auto tag = L.readULEB128();
L.read<uint8_t>(); // has children
if (abbrev_code == entry) {
UNWIND_CHECK(
tag == DW_TAG_compile_unit,
"first entry was not a compile unit but {}",
tag);
return L;
}
while (true) {
auto attr = L.readULEB128();
auto form = L.readULEB128();
if (attr == 0 && form == 0) {
break;
}
if (form == DW_FORM_implicit_const) {
L.readSLEB128();
}
}
}
}
Sections& s_;
optional<uint64_t> line_number_program_offset_;
uint64_t offset_ = 0;
uint8_t sec_offset_size_ = 0;
uint64_t length_ = 0;
const char* end_ = nullptr;
uint64_t debug_abbrev_offset_ = 0;
bool is_64bit_ = false;
std::optional<std::pair<uint64_t, uint8_t>> range_ptr_;
uint64_t lowpc_ = 0;
optional<uint64_t> highpc_;
uint16_t version_ = 0;
uint64_t address_base_ = 0;
optional<uint64_t> rnglists_base_;
};
} // namespace torch::unwind

View File

@ -1,181 +0,0 @@
#pragma once
#include <torch/csrc/profiler/unwind/unwind_error.h>
#include <cstdint>
#include <optional>
enum {
DW_TAG_subprogram = 0x2e,
DW_TAG_inlined_subroutine = 0x1d,
DW_TAG_compile_unit = 0x11,
DW_AT_sibling = 0x1, // reference
DW_AT_name = 0x3, // string
DW_AT_stmt_list = 0x10, // lineptr
DW_AT_addr_base = 0x73, // sec_offset
DW_AT_rnglists_base = 0x74, // sec_offset
DW_AT_low_pc = 0x11, // address
DW_AT_high_pc = 0x12, // address
DW_AT_specification = 0x47, // reference
DW_AT_abstract_origin = 0x31, // reference
DW_AT_linkage_name = 0x6e, // string
DW_AT_ranges = 0x55, // rnglist
DW_AT_str_offsets_base = 0x72, // sec_offset
DW_FORM_addr = 0x01,
DW_FORM_block2 = 0x03,
DW_FORM_block4 = 0x04,
DW_FORM_data2 = 0x05,
DW_FORM_data4 = 0x06,
DW_FORM_data8 = 0x07,
DW_FORM_string = 0x08,
DW_FORM_block = 0x09,
DW_FORM_block1 = 0x0a,
DW_FORM_data1 = 0x0b,
DW_FORM_flag = 0x0c,
DW_FORM_sdata = 0x0d,
DW_FORM_strp = 0x0e,
DW_FORM_udata = 0x0f,
DW_FORM_ref_addr = 0x10,
DW_FORM_ref1 = 0x11,
DW_FORM_ref2 = 0x12,
DW_FORM_ref4 = 0x13,
DW_FORM_ref8 = 0x14,
DW_FORM_ref_udata = 0x15,
DW_FORM_indirect = 0x16,
DW_FORM_sec_offset = 0x17,
DW_FORM_exprloc = 0x18,
DW_FORM_flag_present = 0x19,
DW_FORM_strx = 0x1a,
DW_FORM_addrx = 0x1b,
DW_FORM_ref_sup4 = 0x1c,
DW_FORM_strp_sup = 0x1d,
DW_FORM_data16 = 0x1e,
DW_FORM_line_strp = 0x1f,
DW_FORM_ref_sig8 = 0x20,
DW_FORM_implicit_const = 0x21,
DW_FORM_loclistx = 0x22,
DW_FORM_rnglistx = 0x23,
DW_FORM_ref_sup8 = 0x24,
DW_FORM_strx1 = 0x25,
DW_FORM_strx2 = 0x26,
DW_FORM_strx3 = 0x27,
DW_FORM_strx4 = 0x28,
DW_FORM_addrx1 = 0x29,
DW_FORM_addrx2 = 0x2a,
DW_FORM_addrx3 = 0x2b,
DW_FORM_addrx4 = 0x2c,
/* GNU Debug Fission extensions. */
DW_FORM_GNU_addr_index = 0x1f01,
DW_FORM_GNU_str_index = 0x1f02,
DW_FORM_GNU_ref_alt = 0x1f20, /* offset in alternate .debuginfo. */
DW_FORM_GNU_strp_alt = 0x1f21, /* offset in alternate .debug_str. */
DW_LNCT_path = 0x1,
DW_LNCT_directory_index = 0x2,
DW_LNS_extended_op = 0x00,
DW_LNE_end_sequence = 0x01,
DW_LNE_set_address = 0x02,
DW_LNS_copy = 0x01,
DW_LNS_advance_pc = 0x02,
DW_LNS_advance_line = 0x03,
DW_LNS_set_file = 0x04,
DW_LNS_const_add_pc = 0x08,
DW_LNS_fixed_advance_pc = 0x09,
DW_RLE_end_of_list = 0x0,
DW_RLE_base_addressx = 0x1,
DW_RLE_startx_endx = 0x2,
DW_RLE_startx_length = 0x3,
DW_RLE_offset_pair = 0x4,
DW_RLE_base_address = 0x5,
DW_RLE_start_end = 0x6,
DW_RLE_start_length = 0x7
};
static torch::unwind::optional<size_t> formSize(
uint64_t form,
uint8_t sec_offset_size) {
switch (form) {
case DW_FORM_addr:
return sizeof(void*);
case DW_FORM_block2:
case DW_FORM_block4:
return std::nullopt;
case DW_FORM_data2:
return 2;
case DW_FORM_data4:
return 4;
case DW_FORM_data8:
return 8;
case DW_FORM_string:
case DW_FORM_block:
case DW_FORM_block1:
return std::nullopt;
case DW_FORM_data1:
case DW_FORM_flag:
return 1;
case DW_FORM_sdata:
return std::nullopt;
case DW_FORM_strp:
return sec_offset_size;
case DW_FORM_udata:
return std::nullopt;
case DW_FORM_ref_addr:
return sec_offset_size;
case DW_FORM_ref1:
return 1;
case DW_FORM_ref2:
return 2;
case DW_FORM_ref4:
return 4;
case DW_FORM_ref8:
return 8;
case DW_FORM_ref_udata:
case DW_FORM_indirect:
return std::nullopt;
case DW_FORM_sec_offset:
return sec_offset_size;
case DW_FORM_exprloc:
return std::nullopt;
case DW_FORM_flag_present:
return 0;
case DW_FORM_strx:
case DW_FORM_addrx:
return std::nullopt;
case DW_FORM_ref_sup4:
return 4;
case DW_FORM_strp_sup:
return sec_offset_size;
case DW_FORM_data16:
return 16;
case DW_FORM_line_strp:
return sec_offset_size;
case DW_FORM_ref_sig8:
return 8;
case DW_FORM_implicit_const:
return 0;
case DW_FORM_loclistx:
case DW_FORM_rnglistx:
return std::nullopt;
case DW_FORM_ref_sup8:
return 8;
case DW_FORM_strx1:
return 1;
case DW_FORM_strx2:
return 2;
case DW_FORM_strx3:
return 3;
case DW_FORM_strx4:
return 4;
case DW_FORM_addrx1:
return 1;
case DW_FORM_addrx2:
return 2;
case DW_FORM_addrx3:
return 3;
case DW_FORM_addrx4:
return 4;
case DW_FORM_GNU_addr_index:
case DW_FORM_GNU_str_index:
case DW_FORM_GNU_ref_alt:
case DW_FORM_GNU_strp_alt:
default:
return std::nullopt;
}
}

View File

@ -7,7 +7,6 @@
// Overview of the format described in
// https://refspecs.linuxfoundation.org/LSB_1.3.0/gLSB/gLSB/ehframehdr.html
namespace torch::unwind {
struct EHFrameHdr {
EHFrameHdr(void* base) : base_(base) {
@ -94,5 +93,3 @@ struct EHFrameHdr {
int64_t fde_count_;
uint32_t table_size_;
};
} // namespace torch::unwind

View File

@ -1,108 +0,0 @@
#pragma once
#include <fmt/format.h>
#include <sys/types.h>
#include <torch/csrc/profiler/unwind/debug_info.h>
#include <torch/csrc/profiler/unwind/line_number_program.h>
#include <torch/csrc/profiler/unwind/sections.h>
#include <torch/csrc/profiler/unwind/unwind.h>
#include <torch/csrc/profiler/unwind/unwind_error.h>
#include <cstddef>
#include <memory>
namespace torch::unwind {
#define UNWIND_WARN(w, ...) \
do { \
w.emplace_back(fmt::format(__VA_ARGS__)); \
LOG_INFO("WARNING: {}\n", w.back()); \
} while (0);
struct FastSymbolizer {
FastSymbolizer() = default;
Frame symbolize(const std::string& library, uint64_t offset) {
LOG_INFO("symbolizing {} + 0x{:x}\n", library, offset);
Frame frame;
frame.funcname = "??";
frame.filename = library;
frame.lineno = offset;
auto s = getOrCreateSections(library);
if (auto e = s->findSubprogramName(offset)) {
frame.funcname = *e;
} else {
UNWIND_WARN(
warnings_,
"failed to find subprogram name for {} 0x{:x}",
library,
offset);
}
if (auto e = findLine(s, offset)) {
frame.filename = e->first;
frame.lineno = e->second;
} else {
UNWIND_WARN(
warnings_, "failed to find file/line for {} 0x{:x}", library, offset);
}
return frame;
}
const std::vector<std::string>& warnings() {
return warnings_;
}
private:
void parseDebugInfo(Sections* s) {
uint64_t offset = 0;
while (offset < s->debug_info.size) {
DebugInfo info(*s);
info.parse(offset);
if (auto lnp_offset = info.lineNumberProgramOffset()) {
for (auto r : info.ranges()) {
s->addDebugInfoRange(r.first, r.second, line_number_programs_.size());
}
line_number_programs_.emplace_back(
std::make_unique<LineNumberProgram>(*s, *lnp_offset));
}
offset = info.nextOffset();
}
}
Sections* getOrCreateSections(const std::string& library) {
auto it = libraries_.find(library);
if (it == libraries_.end()) {
it = libraries_.insert({library, std::make_unique<Sections>()}).first;
try {
Sections* s = it->second.get();
s->parse(library.c_str());
parseDebugInfo(s);
} catch (UnwindError& err) {
UNWIND_WARN(
warnings_, "failed to parse library {}: {}", library, err.what());
}
}
return it->second.get();
}
optional<std::pair<std::string, int64_t>> findLine(
Sections* s,
uint64_t offset) {
if (auto idx = s->findDebugInfoOffset(offset)) {
auto r = line_number_programs_.at(*idx).get();
try {
r->parse();
} catch (UnwindError& err) {
UNWIND_WARN(
warnings_,
"failed to read line number program [{:x}] {}",
r->offset(),
err.what());
}
if (auto e = r->find(offset)) {
return std::make_pair(r->filename(e->file), e->line);
}
}
return std::nullopt;
}
std::unordered_map<std::string, std::unique_ptr<Sections>> libraries_;
std::vector<std::unique_ptr<LineNumberProgram>> line_number_programs_;
std::vector<std::string> warnings_;
};
} // namespace torch::unwind

View File

@ -7,8 +7,6 @@
#include <sstream>
#include <vector>
namespace torch::unwind {
struct TableState {
Action cfa;
std::array<Action, D_REG_SIZE> registers;
@ -400,5 +398,3 @@ struct FDE {
return strstr(augmentation_string_, s) != nullptr;
}
};
} // namespace torch::unwind

View File

@ -1,31 +1,19 @@
#pragma once
#include <cstdint>
#include <cstring>
#include <utility>
#include <stdint.h>
#include <string.h>
#include <torch/csrc/profiler/unwind/dwarf_enums.h>
#include <torch/csrc/profiler/unwind/unwind_error.h>
namespace torch::unwind {
template <bool checked>
struct LexerImpl {
LexerImpl(void* data, void* base = nullptr, void* end = nullptr)
: next_((const char*)data),
base_((int64_t)base),
end_((const char*)end) {}
struct Lexer {
Lexer(void* data, void* base = nullptr)
: next_((const char*)data), base_((int64_t)base) {}
template <typename T>
T read() {
T result;
auto end = next_ + sizeof(T);
UNWIND_CHECK(
!checked || end <= end_,
"read out of bounds {} >= {}",
(void*)end,
(void*)end_);
memcpy(&result, next_, sizeof(T));
next_ = end;
next_ += sizeof(T);
return result;
}
@ -33,7 +21,7 @@ struct LexerImpl {
int64_t readSLEB128() {
int64_t Value = 0;
unsigned Shift = 0;
uint8_t Byte = 0;
uint8_t Byte;
do {
Byte = read<uint8_t>();
uint64_t Slice = Byte & 0x7f;
@ -41,12 +29,12 @@ struct LexerImpl {
(Shift == 63 && Slice != 0 && Slice != 0x7f)) {
throw UnwindError("sleb128 too big for int64");
}
Value |= int64_t(Slice << Shift);
Value |= Slice << Shift;
Shift += 7;
} while (Byte >= 128);
// Sign extend negative numbers if needed.
if (Shift < 64 && (Byte & 0x40)) {
Value |= int64_t((-1ULL) << Shift);
Value |= (-1ULL) << Shift;
}
return Value;
}
@ -54,7 +42,7 @@ struct LexerImpl {
uint64_t readULEB128() {
uint64_t Value = 0;
unsigned Shift = 0;
uint8_t p = 0;
uint8_t p;
do {
p = read<uint8_t>();
uint64_t Slice = p & 0x7f;
@ -68,17 +56,8 @@ struct LexerImpl {
}
const char* readCString() {
auto result = next_;
if (!checked) {
next_ += strlen(next_) + 1;
return result;
}
while (next_ < end_) {
if (*next_++ == '\0') {
return result;
}
}
UNWIND_CHECK(
false, "string is out of bounds {} >= {}", (void*)next_, (void*)end_);
next_ += strlen(next_) + 1;
return result;
}
int64_t readEncoded(uint8_t enc) {
int64_t r = 0;
@ -102,27 +81,20 @@ struct LexerImpl {
}
return readEncoded(enc);
}
int64_t read4or8Length() {
return readSectionLength().first;
}
std::pair<int64_t, bool> readSectionLength() {
int64_t length = read<uint32_t>();
if (length == 0xFFFFFFFF) {
return std::make_pair(read<int64_t>(), true);
length = read<int64_t>();
}
return std::make_pair(length, false);
return length;
}
void* loc() const {
return (void*)next_;
}
LexerImpl& skip(int64_t bytes) {
Lexer& skip(int64_t bytes) {
next_ += bytes;
return *this;
}
int64_t readEncodedValue(uint8_t enc) {
switch (enc & 0xF) {
case DW_EH_PE_udata2:
@ -149,11 +121,4 @@ struct LexerImpl {
private:
const char* next_;
int64_t base_;
const char* end_;
};
// using Lexer = LexerImpl<false>;
using CheckedLexer = LexerImpl<true>;
using Lexer = LexerImpl<false>;
} // namespace torch::unwind

View File

@ -1,325 +0,0 @@
#include <c10/util/irange.h>
#include <torch/csrc/profiler/unwind/debug_info.h>
#include <torch/csrc/profiler/unwind/dwarf_enums.h>
#include <torch/csrc/profiler/unwind/dwarf_symbolize_enums.h>
#include <torch/csrc/profiler/unwind/lexer.h>
#include <torch/csrc/profiler/unwind/sections.h>
#include <torch/csrc/profiler/unwind/unwind_error.h>
#include <tuple>
namespace torch::unwind {
struct LineNumberProgram {
LineNumberProgram(Sections& s, uint64_t offset) : s_(s), offset_(offset) {}
uint64_t offset() {
return offset_;
}
void parse() {
if (parsed_) {
return;
}
parsed_ = true;
CheckedLexer L = s_.debug_line.lexer(offset_);
std::tie(length_, is_64bit_) = L.readSectionLength();
program_end_ = (char*)L.loc() + length_;
auto version = L.read<uint16_t>();
UNWIND_CHECK(
version == 5 || version == 4,
"expected version 4 or 5 but found {}",
version);
if (version == 5) {
auto address_size = L.read<uint8_t>();
UNWIND_CHECK(
address_size == 8,
"expected 64-bit dwarf but found address size {}",
address_size);
segment_selector_size_ = L.read<uint8_t>();
}
header_length_ = is_64bit_ ? L.read<uint64_t>() : L.read<uint32_t>();
program_ = L;
program_.skip(int64_t(header_length_));
minimum_instruction_length_ = L.read<uint8_t>();
maximum_operations_per_instruction_ = L.read<uint8_t>();
default_is_stmt_ = L.read<uint8_t>();
line_base_ = L.read<int8_t>();
line_range_ = L.read<uint8_t>();
opcode_base_ = L.read<uint8_t>();
UNWIND_CHECK(line_range_ != 0, "line_range_ must be non-zero");
standard_opcode_lengths_.resize(opcode_base_);
for (size_t i = 1; i < opcode_base_; i++) {
standard_opcode_lengths_[i] = L.read<uint8_t>();
}
// fmt::print("{:x} {:x} {} {} {} {} {}\n", offset_, header_length_,
// minimum_instruction_length_, maximum_operations_per_instruction_,
// line_base_, line_range_, opcode_base_);
uint8_t directory_entry_format_count = L.read<uint8_t>();
if (version == 5) {
struct Member {
uint64_t content_type;
uint64_t form;
};
std::vector<Member> directory_members;
for (size_t i = 0; i < directory_entry_format_count; i++) {
directory_members.push_back({L.readULEB128(), L.readULEB128()});
}
uint64_t directories_count = L.readULEB128();
for (size_t i = 0; i < directories_count; i++) {
for (auto& member : directory_members) {
switch (member.content_type) {
case DW_LNCT_path: {
include_directories_.emplace_back(
s_.readString(L, member.form, is_64bit_, 0));
} break;
default: {
skipForm(L, member.form);
} break;
}
}
}
for (auto i : c10::irange(directories_count)) {
(void)i;
LOG_INFO("{} {}\n", i, include_directories_[i]);
}
auto file_name_entry_format_count = L.read<uint8_t>();
std::vector<Member> file_members;
for (size_t i = 0; i < file_name_entry_format_count; i++) {
file_members.push_back({L.readULEB128(), L.readULEB128()});
}
auto files_count = L.readULEB128();
for (size_t i = 0; i < files_count; i++) {
for (auto& member : file_members) {
switch (member.content_type) {
case DW_LNCT_path: {
file_names_.emplace_back(
s_.readString(L, member.form, is_64bit_, 0));
} break;
case DW_LNCT_directory_index: {
file_directory_index_.emplace_back(readData(L, member.form));
UNWIND_CHECK(
file_directory_index_.back() < include_directories_.size(),
"directory index out of range");
} break;
default: {
skipForm(L, member.form);
} break;
}
}
}
for (auto i : c10::irange(files_count)) {
(void)i;
LOG_INFO("{} {} {}\n", i, file_names_[i], file_directory_index_[i]);
}
} else {
include_directories_.emplace_back(""); // implicit cwd
while (true) {
auto str = L.readCString();
if (*str == '\0') {
break;
}
include_directories_.emplace_back(str);
}
file_names_.emplace_back("");
file_directory_index_.emplace_back(0);
while (true) {
auto str = L.readCString();
if (*str == '\0') {
break;
}
auto directory_index = L.readULEB128();
L.readULEB128(); // mod_time
L.readULEB128(); // file_length
file_names_.emplace_back(str);
file_directory_index_.push_back(directory_index);
}
}
UNWIND_CHECK(
maximum_operations_per_instruction_ == 1,
"maximum_operations_per_instruction_ must be 1");
UNWIND_CHECK(
minimum_instruction_length_ == 1,
"minimum_instruction_length_ must be 1");
readProgram();
}
struct Entry {
uint32_t file = 1;
int64_t line = 1;
};
unwind::optional<Entry> find(uint64_t address) {
auto e = program_index_.find(address);
if (!e) {
return std::nullopt;
}
return all_programs_.at(*e).find(address);
}
std::string filename(uint64_t index) {
return fmt::format(
"{}/{}",
include_directories_.at(file_directory_index_.at(index)),
file_names_.at(index));
}
private:
void skipForm(CheckedLexer& L, uint64_t form) {
auto sz = formSize(form, is_64bit_ ? 8 : 4);
UNWIND_CHECK(sz, "unsupported form {}", form);
L.skip(int64_t(*sz));
}
uint64_t readData(CheckedLexer& L, uint64_t encoding) {
switch (encoding) {
case DW_FORM_data1:
return L.read<uint8_t>();
case DW_FORM_data2:
return L.read<uint16_t>();
case DW_FORM_data4:
return L.read<uint32_t>();
case DW_FORM_data8:
return L.read<uint64_t>();
case DW_FORM_udata:
return L.readULEB128();
default:
UNWIND_CHECK(false, "unsupported data encoding {}", encoding);
}
}
void produceEntry() {
if (shadow_) {
return;
}
if (ranges_.size() == 1) {
start_address_ = address_;
}
PRINT_LINE_TABLE(
"{:x}\t{}\t{}\n", address_, filename(entry_.file), entry_.line);
UNWIND_CHECK(
entry_.file < file_names_.size(),
"file index {} > {} entries",
entry_.file,
file_names_.size());
ranges_.add(address_, entry_, true);
}
void endSequence() {
if (shadow_) {
return;
}
PRINT_LINE_TABLE(
"{:x}\tEND\n", address_, filename(entry_.file), entry_.line);
program_index_.add(start_address_, all_programs_.size(), false);
program_index_.add(address_, std::nullopt, false);
all_programs_.emplace_back(std::move(ranges_));
ranges_ = RangeTable<Entry>();
}
void readProgram() {
while (program_.loc() < program_end_) {
PRINT_INST("{:x}: ", (char*)program_.loc() - (s_.debug_line.data));
uint8_t op = program_.read<uint8_t>();
if (op >= opcode_base_) {
auto op2 = int64_t(op - opcode_base_);
address_ += op2 / line_range_;
entry_.line += line_base_ + (op2 % line_range_);
PRINT_INST(
"address += {}, line += {}\n",
op2 / line_range_,
line_base_ + (op2 % line_range_));
produceEntry();
} else {
switch (op) {
case DW_LNS_extended_op: {
auto len = program_.readULEB128();
auto extended_op = program_.read<uint8_t>();
switch (extended_op) {
case DW_LNE_end_sequence: {
PRINT_INST("end_sequence\n");
endSequence();
entry_ = Entry{};
} break;
case DW_LNE_set_address: {
address_ = program_.read<uint64_t>();
if (!shadow_) {
PRINT_INST(
"set address {:x} {:x} {:x}\n",
address_,
min_address_,
max_address_);
}
shadow_ = address_ == 0;
} break;
default: {
PRINT_INST("skip extended op {}\n", extended_op);
program_.skip(int64_t(len - 1));
} break;
}
} break;
case DW_LNS_copy: {
PRINT_INST("copy\n");
produceEntry();
} break;
case DW_LNS_advance_pc: {
PRINT_INST("advance pc\n");
address_ += program_.readULEB128();
} break;
case DW_LNS_advance_line: {
entry_.line += program_.readSLEB128();
PRINT_INST("advance line {}\n", entry_.line);
} break;
case DW_LNS_set_file: {
PRINT_INST("set file\n");
entry_.file = program_.readULEB128();
} break;
case DW_LNS_const_add_pc: {
PRINT_INST("const add pc\n");
address_ += (255 - opcode_base_) / line_range_;
} break;
case DW_LNS_fixed_advance_pc: {
PRINT_INST("fixed advance pc\n");
address_ += program_.read<uint16_t>();
} break;
default: {
PRINT_INST("other {}\n", op);
auto n = standard_opcode_lengths_[op];
for (int i = 0; i < n; ++i) {
program_.readULEB128();
}
} break;
}
}
}
PRINT_INST(
"{:x}: end {:x}\n",
((char*)program_.loc() - s_.debug_line.data),
program_end_ - s_.debug_line.data);
}
uint64_t address_ = 0;
bool shadow_ = false;
bool parsed_ = false;
Entry entry_ = {};
std::vector<std::string> include_directories_;
std::vector<std::string> file_names_;
std::vector<uint64_t> file_directory_index_;
uint8_t segment_selector_size_ = 0;
uint8_t minimum_instruction_length_ = 0;
uint8_t maximum_operations_per_instruction_ = 0;
int8_t line_base_ = 0;
uint8_t line_range_ = 0;
uint8_t opcode_base_ = 0;
bool default_is_stmt_ = false;
CheckedLexer program_ = {nullptr};
char* program_end_ = nullptr;
uint64_t header_length_ = 0;
uint64_t length_ = 0;
bool is_64bit_ = false;
std::vector<uint8_t> standard_opcode_lengths_;
Sections& s_;
uint64_t offset_;
uint64_t start_address_ = 0;
RangeTable<uint64_t> program_index_;
std::vector<RangeTable<Entry>> all_programs_;
RangeTable<Entry> ranges_;
};
} // namespace torch::unwind

View File

@ -1,150 +0,0 @@
// Copyright (c) Meta Platforms, Inc. and affiliates.
// All rights reserved.
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.
#pragma once
#include <elf.h>
#include <fcntl.h>
#include <fmt/format.h>
#include <sys/mman.h>
#include <sys/stat.h>
#include <torch/csrc/profiler/unwind/lexer.h>
#include <torch/csrc/profiler/unwind/unwind_error.h>
#include <unistd.h>
#include <cerrno>
#include <cstdio>
#include <cstring>
#include <iostream>
namespace torch::unwind {
struct Section {
char* data = nullptr;
size_t size = 0;
const char* string(size_t offset) {
return lexer(offset).readCString();
}
CheckedLexer lexer(size_t offset) {
return CheckedLexer(data + offset, data, data + size);
}
};
/// Memory maps a file into the address space read-only, and manages the
/// lifetime of the mapping. Here are a few use cases:
/// 1. Used in the loader to read in initial image, and to inspect
// ELF files for dependencies before callling dlopen.
///
/// 2. Used in unity to load the elf file.
struct MemFile {
explicit MemFile(const char* filename_)
: fd_(open(filename_, O_RDONLY)),
mem_(nullptr),
n_bytes_(0),
name_(filename_) {
UNWIND_CHECK(
fd_ != -1, "failed to open {}: {}", filename_, strerror(errno));
// NOLINTNEXTLINE
struct stat s;
if (-1 == fstat(fd_, &s)) {
close(fd_); // destructors don't run during exceptions
UNWIND_CHECK(false, "failed to stat {}: {}", filename_, strerror(errno));
}
n_bytes_ = s.st_size;
UNWIND_CHECK(
n_bytes_ > sizeof(Elf64_Ehdr), "empty shared library: {}", filename_);
mem_ = (char*)mmap(nullptr, n_bytes_, PROT_READ, MAP_SHARED, fd_, 0);
if (MAP_FAILED == mem_) {
close(fd_);
UNWIND_CHECK(false, "failed to mmap {}: {}", filename_, strerror(errno));
}
ehdr_ = (Elf64_Ehdr*)mem_;
#define ELF_CHECK(cond) UNWIND_CHECK(cond, "not an ELF file: {}", filename_)
ELF_CHECK(ehdr_->e_ident[EI_MAG0] == ELFMAG0);
ELF_CHECK(ehdr_->e_ident[EI_MAG1] == ELFMAG1);
ELF_CHECK(ehdr_->e_ident[EI_MAG2] == ELFMAG2);
ELF_CHECK(ehdr_->e_ident[EI_MAG3] == ELFMAG3);
ELF_CHECK(ehdr_->e_ident[EI_CLASS] == ELFCLASS64);
ELF_CHECK(ehdr_->e_ident[EI_VERSION] == EV_CURRENT);
ELF_CHECK(ehdr_->e_version == EV_CURRENT);
ELF_CHECK(ehdr_->e_machine == EM_X86_64);
#undef ELF_CHECK
UNWIND_CHECK(
ehdr_->e_shoff + sizeof(Elf64_Shdr) * ehdr_->e_shnum <= n_bytes_,
"invalid section header table {} {} {}",
ehdr_->e_shoff + sizeof(Elf64_Shdr) * ehdr_->e_shnum,
n_bytes_,
ehdr_->e_shnum);
shdr_ = (Elf64_Shdr*)(mem_ + ehdr_->e_shoff);
UNWIND_CHECK(
ehdr_->e_shstrndx < ehdr_->e_shnum, "invalid strtab section offset");
auto& strtab_hdr = shdr_[ehdr_->e_shstrndx];
strtab_ = getSection(strtab_hdr);
}
MemFile(const MemFile&) = delete;
MemFile& operator=(const MemFile&) = delete;
[[nodiscard]] const char* data() const {
return (const char*)mem_;
}
/// Returns whether or not the file descriptor
/// of the underlying file is valid.
int valid() {
return fcntl(fd_, F_GETFD) != -1 || errno != EBADF;
}
~MemFile() {
if (mem_) {
munmap((void*)mem_, n_bytes_);
}
if (fd_) {
close(fd_);
}
}
/// Returns the size of the underlying file defined by the `MemFile`
size_t size() {
return n_bytes_;
}
[[nodiscard]] int fd() const {
return fd_;
}
Section getSection(const Elf64_Shdr& shdr) {
UNWIND_CHECK(shdr.sh_offset + shdr.sh_size <= n_bytes_, "invalid section");
return Section{mem_ + shdr.sh_offset, shdr.sh_size};
}
Section getSection(const char* name, bool optional) {
for (int i = 0; i < ehdr_->e_shnum; i++) {
if (strcmp(strtab_.string(shdr_[i].sh_name), name) == 0) {
return getSection(shdr_[i]);
}
}
UNWIND_CHECK(optional, "{} has no section {}", name_, name);
return Section{nullptr, 0};
}
Section strtab() {
return strtab_;
}
private:
template <typename T>
T* load(size_t offset) {
UNWIND_CHECK(offset < n_bytes_, "out of range");
return (T*)(mem_ + offset);
}
int fd_;
char* mem_;
size_t n_bytes_;
std::string name_;
Elf64_Ehdr* ehdr_;
Elf64_Shdr* shdr_;
Section strtab_ = {nullptr, 0};
};
} // namespace torch::unwind

View File

@ -1,74 +0,0 @@
#pragma once
#include <torch/csrc/profiler/unwind/unwind_error.h>
#include <algorithm>
#include <memory>
#include <optional>
#include <unordered_map>
#include <vector>
namespace torch::unwind {
template <typename T>
struct RangeTable {
RangeTable() {
// guarentee that lower_bound[-1] is always valid
addresses_.push_back(0);
payloads_.emplace_back(std::nullopt);
}
void add(uint64_t address, unwind::optional<T> payload, bool sorted) {
if (addresses_.back() > address) {
UNWIND_CHECK(!sorted, "expected addresses to be sorted");
sorted_ = false;
}
addresses_.push_back(address);
payloads_.emplace_back(std::move(payload));
}
unwind::optional<T> find(uint64_t address) {
maybeSort();
auto it = std::upper_bound(addresses_.begin(), addresses_.end(), address);
return payloads_.at(it - addresses_.begin() - 1);
}
void dump() {
for (size_t i = 0; i < addresses_.size(); i++) {
fmt::print("{} {:x}: {}\n", i, addresses_[i], payloads_[i] ? "" : "END");
}
}
size_t size() const {
return addresses_.size();
}
uint64_t back() {
maybeSort();
return addresses_.back();
}
private:
void maybeSort() {
if (sorted_) {
return;
}
std::vector<uint64_t> indices;
indices.reserve(addresses_.size());
for (size_t i = 0; i < addresses_.size(); i++) {
indices.push_back(i);
}
std::sort(indices.begin(), indices.end(), [&](uint64_t a, uint64_t b) {
return addresses_[a] < addresses_[b] ||
(addresses_[a] == addresses_[b] &&
bool(payloads_[a]) < bool(payloads_[b]));
});
std::vector<uint64_t> addresses;
std::vector<unwind::optional<T>> payloads;
addresses.reserve(addresses_.size());
payloads.reserve(addresses_.size());
for (auto i : indices) {
addresses.push_back(addresses_[i]);
payloads.push_back(payloads_[i]);
}
addresses_ = std::move(addresses);
payloads_ = std::move(payloads);
sorted_ = true;
}
bool sorted_ = true;
std::vector<uint64_t> addresses_;
std::vector<unwind::optional<T>> payloads_;
};
} // namespace torch::unwind

View File

@ -1,124 +0,0 @@
#pragma once
#include <cxxabi.h>
#include <elf.h>
#include <torch/csrc/profiler/unwind/dwarf_enums.h>
#include <torch/csrc/profiler/unwind/dwarf_symbolize_enums.h>
#include <torch/csrc/profiler/unwind/mem_file.h>
#include <torch/csrc/profiler/unwind/range_table.h>
#include <torch/csrc/profiler/unwind/unwind_error.h>
#include <cstdint>
namespace torch::unwind {
static std::string demangle(const std::string& mangled_name) {
int status = 0;
char* realname =
abi::__cxa_demangle(mangled_name.c_str(), nullptr, nullptr, &status);
if (status == 0) {
std::string demangled_name(realname);
// NOLINTNEXTLINE
free(realname);
return demangled_name;
} else {
return mangled_name;
}
}
struct Sections {
Sections() = default;
void parse(const char* name) {
library_ = std::make_unique<MemFile>(name);
strtab = library_->getSection(".strtab", false);
symtab = library_->getSection(".symtab", true);
debug_info = library_->getSection(".debug_info", true);
if (debug_info.size > 0) {
debug_abbrev = library_->getSection(".debug_abbrev", false);
debug_str = library_->getSection(".debug_str", false);
debug_line = library_->getSection(".debug_line", false);
// dwarf 5
debug_line_str = library_->getSection(".debug_line_str", true);
debug_rnglists = library_->getSection(".debug_rnglists", true);
debug_addr = library_->getSection(".debug_addr", true);
// dwarf 4
debug_ranges = library_->getSection(".debug_ranges", true);
}
parseSymtab();
}
Section debug_info;
Section debug_abbrev;
Section debug_str;
Section debug_line;
Section debug_line_str;
Section debug_rnglists;
Section debug_ranges;
Section debug_addr;
Section symtab;
Section strtab;
const char* readString(
CheckedLexer& data,
uint64_t encoding,
bool is_64bit,
uint64_t str_offsets_base) {
switch (encoding) {
case DW_FORM_string: {
return data.readCString();
}
case DW_FORM_strp: {
return debug_str.string(readSegmentOffset(data, is_64bit));
}
case DW_FORM_line_strp: {
return debug_line_str.string(readSegmentOffset(data, is_64bit));
}
default:
UNWIND_CHECK(false, "unsupported string encoding {:x}", encoding);
}
}
uint64_t readSegmentOffset(CheckedLexer& data, bool is_64bit) {
return is_64bit ? data.read<uint64_t>() : data.read<uint32_t>();
}
unwind::optional<uint64_t> findDebugInfoOffset(uint64_t address) {
return debug_info_offsets_.find(address);
}
size_t compilationUnitCount() {
return debug_info_offsets_.size() / 2;
}
void addDebugInfoRange(
uint64_t start,
uint64_t end,
uint64_t debug_info_offset) {
debug_info_offsets_.add(start, debug_info_offset, false);
debug_info_offsets_.add(end, std::nullopt, false);
}
optional<std::string> findSubprogramName(uint64_t address) {
if (auto e = symbol_table_.find(address)) {
return demangle(strtab.string(*e));
}
return std::nullopt;
}
private:
void parseSymtab() {
auto L = symtab.lexer(0);
char* end = symtab.data + symtab.size;
while (L.loc() < end) {
auto symbol = L.read<Elf64_Sym>();
if (symbol.st_shndx == SHN_UNDEF ||
ELF64_ST_TYPE(symbol.st_info) != STT_FUNC) {
continue;
}
symbol_table_.add(symbol.st_value, symbol.st_name, false);
symbol_table_.add(symbol.st_value + symbol.st_size, std::nullopt, false);
}
}
std::unique_ptr<MemFile> library_;
RangeTable<uint64_t> debug_info_offsets_;
RangeTable<uint64_t> symbol_table_;
};
} // namespace torch::unwind

View File

@ -1,7 +1,6 @@
#include <c10/util/Exception.h>
#include <torch/csrc/profiler/unwind/unwind.h>
#include <torch/csrc/utils/cpp_stacktraces.h>
#include <unordered_map>
#if !defined(__linux__) || !defined(__x86_64__) || !defined(__has_include) || \
!__has_include("ext/stdio_filebuf.h")
@ -19,7 +18,7 @@ c10::optional<std::pair<std::string, uint64_t>> libraryFor(void* addr) {
}
#ifndef FBCODE_CAFFE2
std::vector<Frame> symbolize(const std::vector<void*>& frames, Mode mode) {
std::vector<Frame> symbolize(const std::vector<void*>& frames) {
TORCH_CHECK(
false,
"record_context_cpp is not support on non-linux non-x86_64 platforms");
@ -49,15 +48,10 @@ Stats stats() {
#include <torch/csrc/profiler/unwind/communicate.h>
#include <torch/csrc/profiler/unwind/dwarf_enums.h>
#include <torch/csrc/profiler/unwind/eh_frame_hdr.h>
#include <torch/csrc/profiler/unwind/fast_symbolizer.h>
#include <torch/csrc/profiler/unwind/fde.h>
#include <torch/csrc/profiler/unwind/unwinder.h>
#include <shared_mutex>
extern "C" void unwind_c(std::vector<void*>* result, int64_t rsp, int64_t rbp);
extern "C" void unwind_entry(std::vector<void*>* result);
namespace torch::unwind {
struct UpgradeExclusive {
UpgradeExclusive(std::shared_lock<std::shared_timed_mutex>& rdlock)
: rdlock_(rdlock) {
@ -203,7 +197,7 @@ struct UnwindCache {
Unwinder unwinder = Unwinder::unknown();
try {
unwinder = libraryFor(addr).unwinderFor(addr);
} catch (unwind::UnwindError& err) {
} catch (UnwindError& err) {
// because unwinders are cached this will only print
// once per frame that cannot be unwound.
TORCH_WARN("Unsupported unwinding pattern: ", err.what());
@ -282,6 +276,46 @@ struct UnwindCache {
static UnwindCache unwind_cache;
static std::shared_timed_mutex cache_mutex_;
extern "C" void unwind_c(std::vector<void*>* result, int64_t rsp, int64_t rbp);
extern "C" void unwind_c(std::vector<void*>* result, int64_t rsp, int64_t rbp) {
std::shared_lock lock(cache_mutex_);
UnwindState state{};
// NOLINTNEXTLINE(performance-no-int-to-ptr)
state.rip = *(int64_t*)(rsp);
// +8 because we saved rsp after the return address was already pushed
// to the stack
state.rsp = rsp + 8;
state.rbp = rbp;
unwind_cache.checkRefresh(lock);
while (true) { // unwind for _start sets rip as being undefined
// NOLINTNEXTLINE(performance-no-int-to-ptr)
result->push_back((void*)state.rip);
const Unwinder& uw = unwind_cache.unwinderFor(state.rip, lock);
if (uw.terminator()) {
if (uw.isUnknown()) {
result->push_back(nullptr);
}
break;
}
state = uw.run(state);
}
}
extern "C" void unwind_entry(std::vector<void*>* result);
// calling convention puts the first three pointer/int64_t arguments in
// rdi rsi rdx (all caller-saved)
// rdi already holds the pointer to the result vector
// we add arguments for current rsp and rbp and then tail call
// into unwind_c
__asm__(
".global unwind_entry\n"
"unwind_entry:\n"
"mov %rsp, %rsi;\n"
"mov %rbp, %rdx;\n"
"jmp unwind_c;\n");
namespace torch::unwind {
std::vector<void*> unwind() {
std::vector<void*> frames;
unwind_entry(&frames);
@ -301,15 +335,6 @@ c10::optional<std::pair<std::string, uint64_t>> libraryFor(void* addr) {
library_info->name(), (uint64_t)addr - library_info->load_bias());
}
static std::string dladdr_lookup(void* addr) {
Dl_info dlinfo;
std::string funcname = "??";
if (dladdr(addr, &dlinfo) && dlinfo.dli_sname) {
funcname = demangle(dlinfo.dli_sname);
}
return funcname;
}
struct Symbolizer {
Symbolizer() {
auto envar = std::getenv("TORCH_ADDR2LINE_BINARY");
@ -320,6 +345,9 @@ struct Symbolizer {
} else {
addr2line_binary_ = "addr2line"; // default
}
if (torch::get_disable_addr2line()) {
addr2line_binary_ = nullptr;
}
}
static std::lock_guard<std::mutex> guard() {
static std::mutex mutex;
@ -339,6 +367,16 @@ struct Symbolizer {
frame_map_[addr] = Frame{"??", "<unwind unsupported>", 0};
return;
}
if (addr2line_binary_ == nullptr) {
Dl_info dlinfo;
std::string funcname = "??";
if (dladdr(addr, &dlinfo) && dlinfo.dli_sname) {
funcname = demangle(dlinfo.dli_sname);
}
frame_map_[addr] = Frame{
maybe_library->first, std::move(funcname), maybe_library->second - 1};
return;
}
has_pending_results_ = true;
auto& entry = getOrCreate(maybe_library->first);
entry.queried.push_back(addr);
@ -410,59 +448,23 @@ struct Symbolizer {
frame_map_[e.queried[e.completed]] = std::move(frame);
}
}
std::string demangle(const std::string& mangled_name) {
int status = 0;
char* realname =
abi::__cxa_demangle(mangled_name.c_str(), nullptr, nullptr, &status);
if (status == 0) {
std::string demangled_name(realname);
// NOLINTNEXTLINE
free(realname);
return demangled_name;
} else {
return mangled_name;
}
}
};
static std::vector<Frame> symbolize_fast(
const std::vector<void*>& frames,
Mode mode) {
static std::mutex cache_mutex;
static std::array<ska::flat_hash_map<void*, Frame>, 2> frame_maps;
auto& frame_map = frame_maps[mode == Mode::fast ? 0 : 1];
std::vector<uint32_t> indices_to_lookup;
std::vector<Frame> results;
results.reserve(frames.size());
{
std::lock_guard<std::mutex> lock(cache_mutex);
for (auto i : c10::irange(frames.size())) {
void* f = frames.at(i);
auto it = frame_map.find(f);
if (it == frame_map.end()) {
indices_to_lookup.push_back(i);
results.emplace_back(Frame{"??", "??", 0});
} else {
results.emplace_back(it->second);
}
}
}
if (!indices_to_lookup.empty()) {
// do symbolizer work
FastSymbolizer symbolizer;
for (auto i : indices_to_lookup) {
void* addr = frames.at(i);
Frame& f = results.at(i);
auto library = libraryFor(frames.at(i));
if (library) {
if (mode == Mode::fast) {
f = symbolizer.symbolize(library->first, library->second - 1);
} else {
f = Frame{library->first, "??", library->second - 1};
}
}
if (f.funcname == "??") {
f.funcname = dladdr_lookup(addr);
}
}
std::lock_guard<std::mutex> lock(cache_mutex);
for (auto i : indices_to_lookup) {
frame_map.emplace(frames.at(i), results.at(i));
}
}
return results;
}
static std::vector<Frame> symbolize_addr2line(
const std::vector<void*>& frames) {
#ifndef FBCODE_CAFFE2
std::vector<Frame> symbolize(const std::vector<void*>& frames) {
auto guard = Symbolizer::guard();
Symbolizer& s = Symbolizer::get();
for (auto f : frames) {
@ -475,16 +477,6 @@ static std::vector<Frame> symbolize_addr2line(
}
return results;
}
// fbcode will use llvm symbolize since there is an llvm dependency already
#ifndef FBCODE_CAFFE2
std::vector<Frame> symbolize(const std::vector<void*>& frames, Mode mode) {
if (mode == Mode::addr2line) {
return symbolize_addr2line(frames);
} else {
return symbolize_fast(frames, mode);
}
}
#endif
Stats stats() {
@ -492,42 +484,4 @@ Stats stats() {
}
} // namespace torch::unwind
extern "C" void unwind_c(std::vector<void*>* result, int64_t rsp, int64_t rbp) {
std::shared_lock lock(torch::unwind::cache_mutex_);
torch::unwind::UnwindState state{};
// NOLINTNEXTLINE(performance-no-int-to-ptr)
state.rip = *(int64_t*)(rsp);
// +8 because we saved rsp after the return address was already pushed
// to the stack
state.rsp = rsp + 8;
state.rbp = rbp;
torch::unwind::unwind_cache.checkRefresh(lock);
while (true) { // unwind for _start sets rip as being undefined
// NOLINTNEXTLINE(performance-no-int-to-ptr)
result->push_back((void*)state.rip);
const torch::unwind::Unwinder& uw =
torch::unwind::unwind_cache.unwinderFor(state.rip, lock);
if (uw.terminator()) {
if (uw.isUnknown()) {
result->push_back(nullptr);
}
break;
}
state = uw.run(state);
}
}
// calling convention puts the first three pointer/int64_t arguments in
// rdi rsi rdx (all caller-saved)
// rdi already holds the pointer to the result vector
// we add arguments for current rsp and rbp and then tail call
// into unwind_c
__asm__(
".global unwind_entry\n"
"unwind_entry:\n"
"mov %rsp, %rsi;\n"
"mov %rbp, %rdx;\n"
"jmp unwind_c;\n");
#endif

View File

@ -1,11 +1,11 @@
#pragma once
#include <c10/macros/Export.h>
#include <c10/util/Optional.h>
#include <cstdint>
#include <string>
#include <vector>
namespace torch::unwind {
namespace torch {
namespace unwind {
// gather current stack, relatively fast.
// gets faster once the cache of program counter locations is warm.
TORCH_API std::vector<void*> unwind();
@ -16,17 +16,13 @@ struct Frame {
uint64_t lineno;
};
enum class Mode { addr2line, fast, dladdr };
// note: symbolize is really slow
// it will launch an addr2line process that has to parse dwarf
// information from the libraries that frames point into.
// Callers should first batch up all the unique void* pointers
// across a number of unwind states and make a single call to
// symbolize.
TORCH_API std::vector<Frame> symbolize(
const std::vector<void*>& frames,
Mode mode);
TORCH_API std::vector<Frame> symbolize(const std::vector<void*>& frames);
// returns path to the library, and the offset of the addr inside the library
TORCH_API c10::optional<std::pair<std::string, uint64_t>> libraryFor(
@ -40,4 +36,5 @@ struct Stats {
};
Stats stats();
} // namespace torch::unwind
} // namespace unwind
} // namespace torch

View File

@ -1,31 +1,6 @@
#pragma once
#include <c10/util/Optional.h>
#include <fmt/format.h>
#include <stdexcept>
namespace torch::unwind {
struct UnwindError : public std::runtime_error {
using std::runtime_error::runtime_error;
};
#define UNWIND_CHECK(cond, fmtstring, ...) \
do { \
if (!(cond)) { \
throw unwind::UnwindError(fmt::format( \
"{}:{}: " fmtstring, __FILE__, __LINE__, ##__VA_ARGS__)); \
} \
} while (0)
// #define LOG_INFO(...) fmt::print(__VA_ARGS__)
#define LOG_INFO(...)
// #define PRINT_INST(...) LOG_INFO(__VA_ARGS__)
#define PRINT_INST(...)
// #define PRINT_LINE_TABLE(...) LOG_INFO(__VA_ARGS__)
#define PRINT_LINE_TABLE(...)
using c10::optional; // NOLINT
} // namespace torch::unwind

View File

@ -6,9 +6,9 @@
#include <torch/csrc/profiler/unwind/unwind.h>
namespace torch {
namespace torch::unwind {
namespace unwind {
std::vector<Frame> symbolize(const std::vector<void*>& frames, Mode mode) {
std::vector<Frame> symbolize(const std::vector<void*>& frames) {
static std::mutex symbolize_mutex;
static llvm::symbolize::LLVMSymbolizer symbolizer;
static ska::flat_hash_map<void*, Frame> frame_map_;
@ -38,7 +38,7 @@ std::vector<Frame> symbolize(const std::vector<void*>& frames, Mode mode) {
return results;
}
} // namespace torch::unwind
} // namespace unwind
} // namespace torch
#endif

View File

@ -4,8 +4,6 @@
#include <cstdint>
#include <limits>
namespace torch::unwind {
struct UnwindState {
int64_t rip, rbp, rsp;
};
@ -77,5 +75,3 @@ struct Unwinder {
int64_t rbp_off_;
bool deref_{false};
};
} // namespace torch::unwind

View File

@ -47,31 +47,9 @@ bool get_cpp_stacktraces_enabled() {
return enabled;
}
static torch::unwind::Mode compute_symbolize_mode() {
auto envar_c = std::getenv("TORCH_SYMBOLIZE_MODE");
if (envar_c) {
std::string envar = envar_c;
if (envar == "dladdr") {
return unwind::Mode::dladdr;
} else if (envar == "addr2line") {
return unwind::Mode::addr2line;
} else if (envar == "fast") {
return unwind::Mode::fast;
} else {
TORCH_CHECK(
false,
"expected {dladdr, addr2line, fast} for TORCH_SYMBOLIZE_MODE, got ",
envar);
}
} else {
return compute_disable_addr2line() ? unwind::Mode::dladdr
: unwind::Mode::addr2line;
}
}
unwind::Mode get_symbolize_mode() {
static unwind::Mode mode = compute_symbolize_mode();
return mode;
bool get_disable_addr2line() {
static bool disabled = compute_disable_addr2line();
return disabled;
}
} // namespace torch

View File

@ -1,9 +1,8 @@
#pragma once
#include <torch/csrc/Export.h>
#include <torch/csrc/profiler/unwind/unwind.h>
namespace torch {
TORCH_API bool get_cpp_stacktraces_enabled();
TORCH_API torch::unwind::Mode get_symbolize_mode();
TORCH_API bool get_disable_addr2line();
} // namespace torch