[PyTorch][Static Runtime] Combine ProcessedNode::{native_,}fn_ (#65414)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/65414

Saves 24 bytes (`sizeof(std::function) - 8`) per ProcessedNode.
ghstack-source-id: 139999909

Test Plan: CI

Reviewed By: hlu1

Differential Revision: D31085561

fbshipit-source-id: 70734b8319e805736ba41aedaaf7fa3d463400c9
This commit is contained in:
Scott Wolchok 2021-10-08 18:06:50 -07:00 committed by Facebook GitHub Bot
parent 566922bbcd
commit 5a67ffe0ad
2 changed files with 21 additions and 16 deletions

View File

@ -1301,17 +1301,21 @@ ProcessedNode::ProcessedNode(
// TODO leverage type information // TODO leverage type information
outputs_.resize(node->outputs().size()); outputs_.resize(node->outputs().size());
if (enable_out_variant && (fn_ = getOutOfPlaceOperation(node))) { if (enable_out_variant) {
VLOG(1) << "Switch to out variant for node: " << PrintNode(node); if (OutVariant fn = getOutOfPlaceOperation(node)) {
return; fn_.emplace<0>(std::move(fn));
VLOG(1) << "Switch to out variant for node: " << PrintNode(node);
return;
}
} }
if (!fn_ && (native_fn_ = getNativeOperation(node))) { if (NativeFunction fn = getNativeOperation(node)) {
fn_.emplace<1>(std::move(fn));
VLOG(1) << "Switch to native impl for node: " << PrintNode(node); VLOG(1) << "Switch to native impl for node: " << PrintNode(node);
return; return;
} }
{ {
const Operator& op = node->getOperator(); const Operator& op = node->getOperator();
op_ = op.getOperation(node); fn_.emplace<2>(op.getOperation(node));
VLOG(1) << "Fallback interpreter for node: " << PrintNode(node); VLOG(1) << "Fallback interpreter for node: " << PrintNode(node);
} }
} }
@ -1329,10 +1333,10 @@ std::vector<IValue> ProcessedNode::clone_inputs() const {
void ProcessedNode::run_impl() { void ProcessedNode::run_impl() {
DCHECK(verify_no_memory_overlap()); DCHECK(verify_no_memory_overlap());
if (fn_) { if (fn_.index() == 0) {
fn_(this); c10::get<0>(fn_)(this);
} else if (native_fn_) { } else if (fn_.index() == 1) {
native_fn_(this); c10::get<1>(fn_)(this);
} else { } else {
std::vector<IValue> stack; std::vector<IValue> stack;
const size_t size = node_->inputs().size(); const size_t size = node_->inputs().size();
@ -1345,8 +1349,8 @@ void ProcessedNode::run_impl() {
stack.emplace_back(static_cast<int>(size)); stack.emplace_back(static_cast<int>(size));
} }
DCHECK(op_); DCHECK(fn_.index() == 2);
op_->operator()(stack); c10::get<2>(fn_)(stack);
DCHECK_EQ(stack.size(), node_->outputs().size()); DCHECK_EQ(stack.size(), node_->outputs().size());
for (const auto i : c10::irange(node_->outputs().size())) { for (const auto i : c10::irange(node_->outputs().size())) {

View File

@ -3,6 +3,7 @@
#include <ATen/core/interned_strings.h> #include <ATen/core/interned_strings.h>
#include <ATen/core/ivalue.h> #include <ATen/core/ivalue.h>
#include <c10/core/CPUAllocator.h> #include <c10/core/CPUAllocator.h>
#include <c10/util/variant.h>
#include <torch/csrc/jit/api/module.h> #include <torch/csrc/jit/api/module.h>
#include <torch/csrc/jit/ir/ir.h> #include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/passes/constant_propagation.h> #include <torch/csrc/jit/passes/constant_propagation.h>
@ -403,11 +404,11 @@ class TORCH_API ProcessedNode {
std::vector<IValue> clone_inputs() const; std::vector<IValue> clone_inputs() const;
bool has_out_variant() const { bool has_out_variant() const {
return static_cast<bool>(fn_); return fn_.index() == 0;
} }
bool has_native() const { bool has_native() const {
return static_cast<bool>(native_fn_); return fn_.index() == 1;
} }
bool verify_no_memory_overlap() const; bool verify_no_memory_overlap() const;
@ -420,9 +421,9 @@ class TORCH_API ProcessedNode {
void run_impl(); void run_impl();
Node* node_; Node* node_;
c10::optional<Operation> op_; using OutVariant = std::function<void(ProcessedNode*)>;
std::function<void(ProcessedNode*)> fn_; using NativeFunction = std::function<void(ProcessedNode*)>;
std::function<void(ProcessedNode*)> native_fn_; c10::variant<OutVariant, NativeFunction, Operation> fn_;
std::vector<const IValue*> inputs_; // unowned std::vector<const IValue*> inputs_; // unowned
std::vector<IValue> outputs_; std::vector<IValue> outputs_;
const char* op_name_; const char* op_name_;