mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[PyTorch][Static Runtime] Combine ProcessedNode::{native_,}fn_ (#65414)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/65414

Saves 24 bytes (`sizeof(std::function) - 8`) per ProcessedNode.
ghstack-source-id: 139999909

Test Plan: CI

Reviewed By: hlu1

Differential Revision: D31085561

fbshipit-source-id: 70734b8319e805736ba41aedaaf7fa3d463400c9
This commit is contained in:
parent
566922bbcd
commit
5a67ffe0ad
|
|
@@ -1301,17 +1301,21 @@ ProcessedNode::ProcessedNode(
|
||||||
// TODO leverage type information
|
// TODO leverage type information
|
||||||
outputs_.resize(node->outputs().size());
|
outputs_.resize(node->outputs().size());
|
||||||
|
|
||||||
if (enable_out_variant && (fn_ = getOutOfPlaceOperation(node))) {
|
if (enable_out_variant) {
|
||||||
VLOG(1) << "Switch to out variant for node: " << PrintNode(node);
|
if (OutVariant fn = getOutOfPlaceOperation(node)) {
|
||||||
return;
|
fn_.emplace<0>(std::move(fn));
|
||||||
|
VLOG(1) << "Switch to out variant for node: " << PrintNode(node);
|
||||||
|
return;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
if (!fn_ && (native_fn_ = getNativeOperation(node))) {
|
if (NativeFunction fn = getNativeOperation(node)) {
|
||||||
|
fn_.emplace<1>(std::move(fn));
|
||||||
VLOG(1) << "Switch to native impl for node: " << PrintNode(node);
|
VLOG(1) << "Switch to native impl for node: " << PrintNode(node);
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
{
|
{
|
||||||
const Operator& op = node->getOperator();
|
const Operator& op = node->getOperator();
|
||||||
op_ = op.getOperation(node);
|
fn_.emplace<2>(op.getOperation(node));
|
||||||
VLOG(1) << "Fallback interpreter for node: " << PrintNode(node);
|
VLOG(1) << "Fallback interpreter for node: " << PrintNode(node);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@@ -1329,10 +1333,10 @@ std::vector<IValue> ProcessedNode::clone_inputs() const {
|
||||||
|
|
||||||
void ProcessedNode::run_impl() {
|
void ProcessedNode::run_impl() {
|
||||||
DCHECK(verify_no_memory_overlap());
|
DCHECK(verify_no_memory_overlap());
|
||||||
if (fn_) {
|
if (fn_.index() == 0) {
|
||||||
fn_(this);
|
c10::get<0>(fn_)(this);
|
||||||
} else if (native_fn_) {
|
} else if (fn_.index() == 1) {
|
||||||
native_fn_(this);
|
c10::get<1>(fn_)(this);
|
||||||
} else {
|
} else {
|
||||||
std::vector<IValue> stack;
|
std::vector<IValue> stack;
|
||||||
const size_t size = node_->inputs().size();
|
const size_t size = node_->inputs().size();
|
||||||
|
|
@@ -1345,8 +1349,8 @@ void ProcessedNode::run_impl() {
|
||||||
stack.emplace_back(static_cast<int>(size));
|
stack.emplace_back(static_cast<int>(size));
|
||||||
}
|
}
|
||||||
|
|
||||||
DCHECK(op_);
|
DCHECK(fn_.index() == 2);
|
||||||
op_->operator()(stack);
|
c10::get<2>(fn_)(stack);
|
||||||
|
|
||||||
DCHECK_EQ(stack.size(), node_->outputs().size());
|
DCHECK_EQ(stack.size(), node_->outputs().size());
|
||||||
for (const auto i : c10::irange(node_->outputs().size())) {
|
for (const auto i : c10::irange(node_->outputs().size())) {
|
||||||
|
|
|
||||||
|
|
@@ -3,6 +3,7 @@
|
||||||
#include <ATen/core/interned_strings.h>
|
#include <ATen/core/interned_strings.h>
|
||||||
#include <ATen/core/ivalue.h>
|
#include <ATen/core/ivalue.h>
|
||||||
#include <c10/core/CPUAllocator.h>
|
#include <c10/core/CPUAllocator.h>
|
||||||
|
#include <c10/util/variant.h>
|
||||||
#include <torch/csrc/jit/api/module.h>
|
#include <torch/csrc/jit/api/module.h>
|
||||||
#include <torch/csrc/jit/ir/ir.h>
|
#include <torch/csrc/jit/ir/ir.h>
|
||||||
#include <torch/csrc/jit/passes/constant_propagation.h>
|
#include <torch/csrc/jit/passes/constant_propagation.h>
|
||||||
|
|
@@ -403,11 +404,11 @@ class TORCH_API ProcessedNode {
|
||||||
std::vector<IValue> clone_inputs() const;
|
std::vector<IValue> clone_inputs() const;
|
||||||
|
|
||||||
bool has_out_variant() const {
|
bool has_out_variant() const {
|
||||||
return static_cast<bool>(fn_);
|
return fn_.index() == 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
bool has_native() const {
|
bool has_native() const {
|
||||||
return static_cast<bool>(native_fn_);
|
return fn_.index() == 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
bool verify_no_memory_overlap() const;
|
bool verify_no_memory_overlap() const;
|
||||||
|
|
@@ -420,9 +421,9 @@ class TORCH_API ProcessedNode {
|
||||||
void run_impl();
|
void run_impl();
|
||||||
|
|
||||||
Node* node_;
|
Node* node_;
|
||||||
c10::optional<Operation> op_;
|
using OutVariant = std::function<void(ProcessedNode*)>;
|
||||||
std::function<void(ProcessedNode*)> fn_;
|
using NativeFunction = std::function<void(ProcessedNode*)>;
|
||||||
std::function<void(ProcessedNode*)> native_fn_;
|
c10::variant<OutVariant, NativeFunction, Operation> fn_;
|
||||||
std::vector<const IValue*> inputs_; // unowned
|
std::vector<const IValue*> inputs_; // unowned
|
||||||
std::vector<IValue> outputs_;
|
std::vector<IValue> outputs_;
|
||||||
const char* op_name_;
|
const char* op_name_;
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user