mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[PyTorch][Static Runtime] Combine ProcessedNode::{native_,}fn_ (#65414)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/65414 Saves 24 bytes (`sizeof(std::function) - 8`) per ProcessedNode. ghstack-source-id: 139999909 Test Plan: CI Reviewed By: hlu1 Differential Revision: D31085561 fbshipit-source-id: 70734b8319e805736ba41aedaaf7fa3d463400c9
This commit is contained in:
parent
566922bbcd
commit
5a67ffe0ad
|
|
@ -1301,17 +1301,21 @@ ProcessedNode::ProcessedNode(
|
|||
// TODO leverage type information
|
||||
outputs_.resize(node->outputs().size());
|
||||
|
||||
if (enable_out_variant && (fn_ = getOutOfPlaceOperation(node))) {
|
||||
VLOG(1) << "Switch to out variant for node: " << PrintNode(node);
|
||||
return;
|
||||
if (enable_out_variant) {
|
||||
if (OutVariant fn = getOutOfPlaceOperation(node)) {
|
||||
fn_.emplace<0>(std::move(fn));
|
||||
VLOG(1) << "Switch to out variant for node: " << PrintNode(node);
|
||||
return;
|
||||
}
|
||||
}
|
||||
if (!fn_ && (native_fn_ = getNativeOperation(node))) {
|
||||
if (NativeFunction fn = getNativeOperation(node)) {
|
||||
fn_.emplace<1>(std::move(fn));
|
||||
VLOG(1) << "Switch to native impl for node: " << PrintNode(node);
|
||||
return;
|
||||
}
|
||||
{
|
||||
const Operator& op = node->getOperator();
|
||||
op_ = op.getOperation(node);
|
||||
fn_.emplace<2>(op.getOperation(node));
|
||||
VLOG(1) << "Fallback interpreter for node: " << PrintNode(node);
|
||||
}
|
||||
}
|
||||
|
|
@ -1329,10 +1333,10 @@ std::vector<IValue> ProcessedNode::clone_inputs() const {
|
|||
|
||||
void ProcessedNode::run_impl() {
|
||||
DCHECK(verify_no_memory_overlap());
|
||||
if (fn_) {
|
||||
fn_(this);
|
||||
} else if (native_fn_) {
|
||||
native_fn_(this);
|
||||
if (fn_.index() == 0) {
|
||||
c10::get<0>(fn_)(this);
|
||||
} else if (fn_.index() == 1) {
|
||||
c10::get<1>(fn_)(this);
|
||||
} else {
|
||||
std::vector<IValue> stack;
|
||||
const size_t size = node_->inputs().size();
|
||||
|
|
@ -1345,8 +1349,8 @@ void ProcessedNode::run_impl() {
|
|||
stack.emplace_back(static_cast<int>(size));
|
||||
}
|
||||
|
||||
DCHECK(op_);
|
||||
op_->operator()(stack);
|
||||
DCHECK(fn_.index() == 2);
|
||||
c10::get<2>(fn_)(stack);
|
||||
|
||||
DCHECK_EQ(stack.size(), node_->outputs().size());
|
||||
for (const auto i : c10::irange(node_->outputs().size())) {
|
||||
|
|
|
|||
|
|
@ -3,6 +3,7 @@
|
|||
#include <ATen/core/interned_strings.h>
|
||||
#include <ATen/core/ivalue.h>
|
||||
#include <c10/core/CPUAllocator.h>
|
||||
#include <c10/util/variant.h>
|
||||
#include <torch/csrc/jit/api/module.h>
|
||||
#include <torch/csrc/jit/ir/ir.h>
|
||||
#include <torch/csrc/jit/passes/constant_propagation.h>
|
||||
|
|
@ -403,11 +404,11 @@ class TORCH_API ProcessedNode {
|
|||
std::vector<IValue> clone_inputs() const;
|
||||
|
||||
bool has_out_variant() const {
|
||||
return static_cast<bool>(fn_);
|
||||
return fn_.index() == 0;
|
||||
}
|
||||
|
||||
bool has_native() const {
|
||||
return static_cast<bool>(native_fn_);
|
||||
return fn_.index() == 1;
|
||||
}
|
||||
|
||||
bool verify_no_memory_overlap() const;
|
||||
|
|
@ -420,9 +421,9 @@ class TORCH_API ProcessedNode {
|
|||
void run_impl();
|
||||
|
||||
Node* node_;
|
||||
c10::optional<Operation> op_;
|
||||
std::function<void(ProcessedNode*)> fn_;
|
||||
std::function<void(ProcessedNode*)> native_fn_;
|
||||
using OutVariant = std::function<void(ProcessedNode*)>;
|
||||
using NativeFunction = std::function<void(ProcessedNode*)>;
|
||||
c10::variant<OutVariant, NativeFunction, Operation> fn_;
|
||||
std::vector<const IValue*> inputs_; // unowned
|
||||
std::vector<IValue> outputs_;
|
||||
const char* op_name_;
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user