mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[Static Runtime] Remove ProcessedNode::num_outputs_ (#72592)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/72592 Only code paths that are not perf-critical read `ProcessedNode::num_outputs_` and also its static feature of the op that `ProcessedNode` instance is executing. Therefore, it's better to move `ProcessedNode::num_outputs_` into `ProcessedFunction::num_outputs_` and let `ProcessedNode` access it via `ProcessedNode::fn_` for its occasional use. Note that this prevents duplicating num_outputs_ per node & per Static Runtime instance since `ProcessedFunction` instances are shared across all runtime instances. It's confirmed that this change reduces the `sizeof(ProcessedNode)` by 14% from local instrumentation as follows: - Before -- sizeof(ProcessedNode): 56 - After -- sizeof(ProcessedNode): 48 Test Plan: `buck test //caffe2/benchmarks/static_runtime:static_runtime_cpptest` Reviewed By: mikeiovine Differential Revision: D33984792 fbshipit-source-id: e29ffc97b799e679215f42e1e85cd3fcd7e88983
This commit is contained in:
parent
74f94447fc
commit
0f7003f4df
|
|
@ -1734,7 +1734,8 @@ ProcessedFunction::ProcessedFunction(
|
|||
Node* node,
|
||||
bool enable_out_variant,
|
||||
bool check_memory_overlap)
|
||||
: check_memory_overlap_(check_memory_overlap) {
|
||||
: check_memory_overlap_(check_memory_overlap),
|
||||
num_outputs_(node->outputs().size()) {
|
||||
if (enable_out_variant) {
|
||||
f_ = getOutOfPlaceOperation(node);
|
||||
if (f_) {
|
||||
|
|
@ -1791,13 +1792,7 @@ ProcessedNode::ProcessedNode(
|
|||
fn_(fn),
|
||||
inputs_(std::move(inputs)),
|
||||
outputs_offset_(outputs_offset) {
|
||||
TORCH_CHECK(
|
||||
node->outputs().size() < (1 << (sizeof(num_outputs_) * 8)),
|
||||
node->outputs().size(),
|
||||
" outputs to ProcessedNode ",
|
||||
node->kind().toQualString(),
|
||||
" is too many to use 2-byte indexing");
|
||||
num_outputs_ = node->outputs().size();
|
||||
TORCH_CHECK(num_outputs() == node->outputs().size());
|
||||
}
|
||||
|
||||
std::vector<IValue> ProcessedNode::inputs_ivalue_vec() const {
|
||||
|
|
@ -1869,12 +1864,12 @@ bool ProcessedNode::verify_no_memory_overlap(bool force_check) const {
|
|||
}
|
||||
|
||||
bool ProcessedNode::verify_outputs_dont_overlap_each_other() const {
|
||||
for (const auto i : c10::irange(num_outputs_)) {
|
||||
for (const auto i : c10::irange(num_outputs())) {
|
||||
if (!Output(i).isTensor()) {
|
||||
continue;
|
||||
}
|
||||
const auto& out0_t = Output(i).toTensor();
|
||||
for (const auto j : c10::irange(i + 1, num_outputs_)) {
|
||||
for (const auto j : c10::irange(i + 1, num_outputs())) {
|
||||
if (!Output(j).isTensor()) {
|
||||
continue;
|
||||
}
|
||||
|
|
@ -1894,7 +1889,7 @@ bool ProcessedNode::verify_inputs_dont_overlap_outputs(bool force_check) const {
|
|||
// skip memory overlap check for mutable or view ops with only one output
|
||||
bool skip_check = !schema ||
|
||||
((schema->is_mutable() || !fn_->checkMemoryOverlap()) &&
|
||||
num_outputs_ == 1);
|
||||
num_outputs() == 1);
|
||||
if (!force_check && skip_check) {
|
||||
if (!schema) {
|
||||
VLOG(2) << "Detected that op schema is null";
|
||||
|
|
@ -1902,7 +1897,7 @@ bool ProcessedNode::verify_inputs_dont_overlap_outputs(bool force_check) const {
|
|||
}
|
||||
VLOG(2) << "schema->is_mutable: " << schema->is_mutable()
|
||||
<< ", fn_->checkMemoryOverlap: " << fn_->checkMemoryOverlap()
|
||||
<< ", num_outputs_: " << num_outputs_;
|
||||
<< ", num_outputs_: " << num_outputs();
|
||||
return true;
|
||||
}
|
||||
|
||||
|
|
@ -1912,7 +1907,7 @@ bool ProcessedNode::verify_inputs_dont_overlap_outputs(bool force_check) const {
|
|||
continue;
|
||||
}
|
||||
const auto& in_t = in->toTensor();
|
||||
for (const auto j : c10::irange(num_outputs_)) {
|
||||
for (const auto j : c10::irange(num_outputs())) {
|
||||
const IValue& out = Output(j);
|
||||
if (!out.isTensor()) {
|
||||
continue;
|
||||
|
|
@ -1949,7 +1944,7 @@ void ProcessedNode::verify_and_correct_memory_overlap() {
|
|||
continue;
|
||||
}
|
||||
const auto& in_t = in.toTensor();
|
||||
for (const auto j : c10::irange(num_outputs_)) {
|
||||
for (const auto j : c10::irange(num_outputs())) {
|
||||
auto& output = Output(j);
|
||||
if (output.isTensor()) {
|
||||
check_and_correct_overlap_with(in_t, output);
|
||||
|
|
|
|||
|
|
@ -752,10 +752,15 @@ class TORCH_API ProcessedFunction {
|
|||
return check_memory_overlap_;
|
||||
}
|
||||
|
||||
size_t num_outputs() const {
|
||||
return num_outputs_;
|
||||
}
|
||||
|
||||
private:
|
||||
std::function<void(ProcessedNode*)> f_;
|
||||
Kind kind_{ProcessedFunction::Kind::kOutVariant};
|
||||
bool check_memory_overlap_{false};
|
||||
size_t num_outputs_{0};
|
||||
};
|
||||
|
||||
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
|
||||
|
|
@ -777,10 +782,9 @@ class TORCH_API ProcessedNode {
|
|||
ProcessedNode(const ProcessedNode& other)
|
||||
: node_(other.node_),
|
||||
fn_(other.fn_),
|
||||
overlap_detected_(other.overlap_detected_),
|
||||
inputs_(other.inputs_),
|
||||
outputs_offset_(other.outputs_offset_),
|
||||
num_outputs_(other.num_outputs_),
|
||||
overlap_detected_(other.overlap_detected_),
|
||||
values_(other.values_),
|
||||
// It doesn't really make sense to copy block runners,
|
||||
// each processed node needs its own. This is OK to do
|
||||
|
|
@ -797,10 +801,9 @@ class TORCH_API ProcessedNode {
|
|||
}
|
||||
node_ = other.node_;
|
||||
fn_ = other.fn_;
|
||||
overlap_detected_ = other.overlap_detected_;
|
||||
inputs_ = other.inputs_;
|
||||
outputs_offset_ = other.outputs_offset_;
|
||||
num_outputs_ = other.num_outputs_;
|
||||
overlap_detected_ = other.overlap_detected_;
|
||||
values_ = other.values_;
|
||||
block_runners_ = nullptr;
|
||||
return *this;
|
||||
|
|
@ -825,21 +828,23 @@ class TORCH_API ProcessedNode {
|
|||
|
||||
// Output is readwrite
|
||||
IValue& Output(uint32_t i) {
|
||||
DCHECK(i < num_outputs_);
|
||||
DCHECK(i < num_outputs());
|
||||
return values_[outputs_offset_ + i];
|
||||
}
|
||||
|
||||
C10_NODISCARD const IValue& Output(uint32_t i) const {
|
||||
DCHECK(i < num_outputs_);
|
||||
DCHECK(i < num_outputs());
|
||||
return values_[outputs_offset_ + i];
|
||||
}
|
||||
|
||||
C10_NODISCARD c10::ArrayRef<const IValue> outputs() const {
|
||||
return c10::ArrayRef<const IValue>(values_ + outputs_offset_, num_outputs_);
|
||||
size_t num_outputs() const {
|
||||
DCHECK(fn_ != nullptr);
|
||||
return fn_->num_outputs();
|
||||
}
|
||||
|
||||
C10_NODISCARD auto num_outputs() const {
|
||||
return num_outputs_;
|
||||
C10_NODISCARD c10::ArrayRef<const IValue> outputs() const {
|
||||
return c10::ArrayRef<const IValue>(
|
||||
values_ + outputs_offset_, num_outputs());
|
||||
}
|
||||
|
||||
C10_NODISCARD uint16_t num_inputs() const {
|
||||
|
|
@ -885,7 +890,7 @@ class TORCH_API ProcessedNode {
|
|||
}
|
||||
|
||||
C10_NODISCARD uint16_t output_ivalue_index(uint16_t i) const {
|
||||
DCHECK(i < num_outputs_);
|
||||
DCHECK(i < num_outputs());
|
||||
return outputs_offset_ + i;
|
||||
}
|
||||
// used in debug mode
|
||||
|
|
@ -907,10 +912,9 @@ class TORCH_API ProcessedNode {
|
|||
|
||||
Node* node_;
|
||||
const ProcessedFunction* fn_;
|
||||
bool overlap_detected_{false};
|
||||
ProcessedNodeInputs inputs_;
|
||||
uint16_t outputs_offset_;
|
||||
uint16_t num_outputs_;
|
||||
bool overlap_detected_{false};
|
||||
IValue* values_ = nullptr; // unowned
|
||||
// For control flow; processed nodes may have sub-blocks which can
|
||||
// be executed by op implementations.
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user