[PyTorch][Static Runtime] Combine ProcessedNode::{native_,}fn_ (#65414)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/65414

Saves 24 bytes (`sizeof(std::function) - 8`) per ProcessedNode.
ghstack-source-id: 139999909

Test Plan: CI

Reviewed By: hlu1

Differential Revision: D31085561

fbshipit-source-id: 70734b8319e805736ba41aedaaf7fa3d463400c9
This commit is contained in:
Scott Wolchok 2021-10-08 18:06:50 -07:00 committed by Facebook GitHub Bot
parent 566922bbcd
commit 5a67ffe0ad
2 changed files with 21 additions and 16 deletions

View File

@ -1301,17 +1301,21 @@ ProcessedNode::ProcessedNode(
// TODO leverage type information
outputs_.resize(node->outputs().size());
if (enable_out_variant && (fn_ = getOutOfPlaceOperation(node))) {
VLOG(1) << "Switch to out variant for node: " << PrintNode(node);
return;
if (enable_out_variant) {
if (OutVariant fn = getOutOfPlaceOperation(node)) {
fn_.emplace<0>(std::move(fn));
VLOG(1) << "Switch to out variant for node: " << PrintNode(node);
return;
}
}
if (!fn_ && (native_fn_ = getNativeOperation(node))) {
if (NativeFunction fn = getNativeOperation(node)) {
fn_.emplace<1>(std::move(fn));
VLOG(1) << "Switch to native impl for node: " << PrintNode(node);
return;
}
{
const Operator& op = node->getOperator();
op_ = op.getOperation(node);
fn_.emplace<2>(op.getOperation(node));
VLOG(1) << "Fallback interpreter for node: " << PrintNode(node);
}
}
@ -1329,10 +1333,10 @@ std::vector<IValue> ProcessedNode::clone_inputs() const {
void ProcessedNode::run_impl() {
DCHECK(verify_no_memory_overlap());
if (fn_) {
fn_(this);
} else if (native_fn_) {
native_fn_(this);
if (fn_.index() == 0) {
c10::get<0>(fn_)(this);
} else if (fn_.index() == 1) {
c10::get<1>(fn_)(this);
} else {
std::vector<IValue> stack;
const size_t size = node_->inputs().size();
@ -1345,8 +1349,8 @@ void ProcessedNode::run_impl() {
stack.emplace_back(static_cast<int>(size));
}
DCHECK(op_);
op_->operator()(stack);
DCHECK(fn_.index() == 2);
c10::get<2>(fn_)(stack);
DCHECK_EQ(stack.size(), node_->outputs().size());
for (const auto i : c10::irange(node_->outputs().size())) {

View File

@ -3,6 +3,7 @@
#include <ATen/core/interned_strings.h>
#include <ATen/core/ivalue.h>
#include <c10/core/CPUAllocator.h>
#include <c10/util/variant.h>
#include <torch/csrc/jit/api/module.h>
#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/passes/constant_propagation.h>
@ -403,11 +404,11 @@ class TORCH_API ProcessedNode {
std::vector<IValue> clone_inputs() const;
bool has_out_variant() const {
return static_cast<bool>(fn_);
return fn_.index() == 0;
}
bool has_native() const {
return static_cast<bool>(native_fn_);
return fn_.index() == 1;
}
bool verify_no_memory_overlap() const;
@ -420,9 +421,9 @@ class TORCH_API ProcessedNode {
void run_impl();
Node* node_;
c10::optional<Operation> op_;
std::function<void(ProcessedNode*)> fn_;
std::function<void(ProcessedNode*)> native_fn_;
using OutVariant = std::function<void(ProcessedNode*)>;
using NativeFunction = std::function<void(ProcessedNode*)>;
c10::variant<OutVariant, NativeFunction, Operation> fn_;
std::vector<const IValue*> inputs_; // unowned
std::vector<IValue> outputs_;
const char* op_name_;