mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
[JIT] Add a phase to perform inplace<->functional conversion for activation operators (#57477)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/57477 Currently the conversion only deals with activation operators. The legality check is somewhat strict for now. Test Plan: ``` python test/test_jit.py -k test_functional_to_inplace_activation python test/test_jit.py -k test_inplace_to_functional_activation ``` Reviewed By: mrshenli Differential Revision: D28155153 Pulled By: desertfire fbshipit-source-id: df092830c4dff3ce9578ff76285eb7a566b7d81b
This commit is contained in:
parent
91b7bcf4c0
commit
add291cf66
|
|
@ -38,6 +38,7 @@
|
|||
#include <torch/csrc/jit/passes/lower_tuples.h>
|
||||
#include <torch/csrc/jit/passes/pass_manager.h>
|
||||
#include <torch/csrc/jit/passes/requires_grad_analysis.h>
|
||||
#include <torch/csrc/jit/passes/restore_mutation.h>
|
||||
#include <torch/csrc/jit/passes/shape_analysis.h>
|
||||
#include <torch/csrc/jit/passes/utils/subgraph_utils.h>
|
||||
#include <torch/csrc/jit/runtime/argument_spec.h>
|
||||
|
|
@ -2593,5 +2594,39 @@ graph(%x.1 : Tensor):
|
|||
testing::FileCheck().check_not("aten::add_")->run(*graph);
|
||||
}
|
||||
|
||||
TEST(TestInplaceToFunctionalActivation, Basic) {
|
||||
auto graph = std::make_shared<Graph>();
|
||||
std::unordered_map<std::string, Value*> vmap;
|
||||
parseIR(
|
||||
R"IR(
|
||||
graph(%x.1 : Tensor):
|
||||
%2 : int = prim::Constant[value=1]()
|
||||
%x.3 : Tensor = aten::add(%x.1, %2, %2)
|
||||
%y : Tensor = aten::relu_(%x.3)
|
||||
return (%y))IR",
|
||||
&*graph,
|
||||
vmap);
|
||||
InplaceToFunctionalActivation(graph);
|
||||
testing::FileCheck().check("aten::relu")->run(*graph);
|
||||
testing::FileCheck().check_not("aten::relu_")->run(*graph);
|
||||
}
|
||||
|
||||
TEST(TestFunctionalToInplaceActivation, Basic) {
|
||||
auto graph = std::make_shared<Graph>();
|
||||
std::unordered_map<std::string, Value*> vmap;
|
||||
parseIR(
|
||||
R"IR(
|
||||
graph(%x.1 : Tensor):
|
||||
%2 : int = prim::Constant[value=1]()
|
||||
%x.3 : Tensor = aten::add(%x.1, %2, %2)
|
||||
%y : Tensor = aten::relu(%x.3)
|
||||
return (%y))IR",
|
||||
&*graph,
|
||||
vmap);
|
||||
FunctionalToInplaceActivation(graph);
|
||||
testing::FileCheck().check("aten::relu_")->run(*graph);
|
||||
testing::FileCheck().check_not("aten::relu(")->run(*graph);
|
||||
}
|
||||
|
||||
} // namespace jit
|
||||
} // namespace torch
|
||||
|
|
|
|||
175
test/jit/test_convert_activation.py
Normal file
175
test/jit/test_convert_activation.py
Normal file
|
|
@ -0,0 +1,175 @@
|
|||
import os
|
||||
import sys
|
||||
|
||||
from itertools import product
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from torch.testing import FileCheck
|
||||
import unittest
|
||||
|
||||
try:
|
||||
import torchvision
|
||||
HAS_TORCHVISION = True
|
||||
except ImportError:
|
||||
HAS_TORCHVISION = False
|
||||
skipIfNoTorchVision = unittest.skipIf(not HAS_TORCHVISION, "no torchvision")
|
||||
|
||||
# Make the helper files in test/ importable
|
||||
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
|
||||
sys.path.append(pytorch_test_dir)
|
||||
from torch.testing._internal.jit_utils import JitTestCase
|
||||
|
||||
if __name__ == '__main__':
|
||||
raise RuntimeError("This test file is not meant to be run directly, use:\n\n"
|
||||
"\tpython test/test_jit.py TESTNAME\n\n"
|
||||
"instead.")
|
||||
|
||||
activations = [
|
||||
F.celu,
|
||||
F.elu,
|
||||
F.hardsigmoid,
|
||||
F.hardswish,
|
||||
F.hardtanh,
|
||||
F.leaky_relu,
|
||||
F.relu,
|
||||
F.relu6,
|
||||
F.rrelu,
|
||||
F.selu,
|
||||
F.silu,
|
||||
]
|
||||
|
||||
class TestFunctionalToInplaceActivation(JitTestCase):
|
||||
def test_check_no_type_promotion(self):
|
||||
dtypes = [
|
||||
torch.bool,
|
||||
torch.int8,
|
||||
torch.int16,
|
||||
torch.int32,
|
||||
torch.int64,
|
||||
torch.float32,
|
||||
torch.float64,
|
||||
]
|
||||
# restore_mutation.h contains a mapping from activation operators
|
||||
# to whether they allow type conversion. Use this checking to
|
||||
# guard the mapping, and if any later change breaks the assumption
|
||||
# we need to update the mapping correspondingly.
|
||||
for activation, dtype in product(activations, dtypes):
|
||||
inp = torch.normal(0, 5, size=(4, 4)).to(dtype)
|
||||
try:
|
||||
out = activation(inp)
|
||||
self.assertEqual(dtype, out.dtype)
|
||||
except RuntimeError:
|
||||
# Skip the not implemented error
|
||||
pass
|
||||
|
||||
def test_functional_to_inplace_activation(self):
|
||||
for activation in activations:
|
||||
def test_basic(x):
|
||||
y = x + 1
|
||||
z = activation(y)
|
||||
return z
|
||||
|
||||
fn = torch.jit.script(test_basic)
|
||||
self.run_pass("inline", fn.graph)
|
||||
self.run_pass("constant_propagation", fn.graph)
|
||||
FileCheck().check(f"aten::{activation.__name__}(").run(fn.graph)
|
||||
self.run_pass('functional_to_inplace_activation', fn.graph)
|
||||
FileCheck().check_not(f"aten::{activation.__name__}(").run(fn.graph)
|
||||
FileCheck().check(f"aten::{activation.__name__}_").run(fn.graph)
|
||||
inp = torch.rand([2, 2])
|
||||
self.assertEqual(fn(inp), test_basic(inp))
|
||||
|
||||
def test_no_functional_to_inplace(self):
|
||||
# inplace conversion should not happen because sigmoid may
|
||||
# perform type conversion
|
||||
def test1():
|
||||
y = torch.ones([2, 2])
|
||||
z = torch.sigmoid(y)
|
||||
return z
|
||||
|
||||
fn = torch.jit.script(test1)
|
||||
self.run_pass('functional_to_inplace_activation', fn.graph)
|
||||
FileCheck().check_not("aten::sigmoid_").run(fn.graph)
|
||||
|
||||
# inplace conversion should not happen because y is alias
|
||||
# the input x
|
||||
def test2(x):
|
||||
y = x[0]
|
||||
z = torch.relu(y)
|
||||
return z
|
||||
|
||||
fn = torch.jit.script(test2)
|
||||
self.run_pass('functional_to_inplace_activation', fn.graph)
|
||||
FileCheck().check_not("aten::relu_").run(fn.graph)
|
||||
|
||||
# inplace conversion should not happen because self.x is
|
||||
# at the global scope
|
||||
class Test3(nn.Module):
|
||||
def __init__(self, x):
|
||||
super(Test3, self).__init__()
|
||||
self.x = x
|
||||
|
||||
def forward(self):
|
||||
y = torch.relu(self.x)
|
||||
return y
|
||||
|
||||
fn = torch.jit.script(Test3(torch.rand([2, 2])).eval())
|
||||
self.run_pass('functional_to_inplace_activation', fn.graph)
|
||||
FileCheck().check_not("aten::relu_").run(fn.graph)
|
||||
|
||||
@skipIfNoTorchVision
|
||||
def test_resnet18_correctness(self):
|
||||
model = torchvision.models.resnet18()
|
||||
frozen_model = torch.jit.freeze(torch.jit.script(model.eval()))
|
||||
N, C, H, W, = 10, 3, 224, 224
|
||||
inp = torch.randn(N, C, H, W)
|
||||
self.run_pass('functional_to_inplace_activation', frozen_model.graph)
|
||||
self.assertEqual(model(inp), frozen_model(inp))
|
||||
|
||||
|
||||
class TestInplaceToFunctionalActivation(JitTestCase):
|
||||
def test_inplace_to_functional_activation(self):
|
||||
for activation in activations:
|
||||
def test_basic(x):
|
||||
y = x + 1
|
||||
activation(y, inplace=True)
|
||||
return y
|
||||
|
||||
fn = torch.jit.script(test_basic)
|
||||
self.run_pass("inline", fn.graph)
|
||||
self.run_pass("constant_propagation", fn.graph)
|
||||
FileCheck().check(f"aten::{activation.__name__}_").run(fn.graph)
|
||||
self.run_pass('inplace_to_functional_activation', fn.graph)
|
||||
FileCheck().check_not(f"aten::{activation.__name__}_").run(fn.graph)
|
||||
FileCheck().check(f"aten::{activation.__name__}(").run(fn.graph)
|
||||
|
||||
for activation in [
|
||||
torch.relu_,
|
||||
torch.sigmoid_,
|
||||
torch.tanh_,
|
||||
]:
|
||||
def test_basic(x):
|
||||
y = x + 1
|
||||
activation(y)
|
||||
return y
|
||||
|
||||
fn = torch.jit.script(test_basic)
|
||||
self.run_pass("inline", fn.graph)
|
||||
self.run_pass("constant_propagation", fn.graph)
|
||||
FileCheck().check(f"aten::{activation.__name__}").run(fn.graph)
|
||||
self.run_pass('inplace_to_functional_activation', fn.graph)
|
||||
FileCheck().check_not(f"aten::{activation.__name__}").run(fn.graph)
|
||||
FileCheck().check(f"aten::{activation.__name__[:-1]}(").run(fn.graph)
|
||||
|
||||
inp = torch.rand([2, 2])
|
||||
self.assertEqual(fn(inp), test_basic(inp))
|
||||
|
||||
@skipIfNoTorchVision
|
||||
def test_resnet18_correctness(self):
|
||||
model = torchvision.models.resnet18()
|
||||
frozen_model = torch.jit.freeze(torch.jit.script(model.eval()))
|
||||
N, C, H, W, = 10, 3, 224, 224
|
||||
inp = torch.randn(N, C, H, W)
|
||||
self.run_pass('inplace_to_functional_activation', frozen_model.graph)
|
||||
self.assertEqual(model(inp), frozen_model(inp))
|
||||
|
|
@ -55,6 +55,7 @@ from jit.test_pdt import TestPDT # noqa: F401
|
|||
from jit.test_tensor_creation_ops import TestTensorCreationOps # noqa: F401
|
||||
from jit.test_module_apis import TestModuleAPIs # noqa: F401
|
||||
from jit.test_script_profile import TestScriptProfile # noqa: F401
|
||||
from jit.test_convert_activation import TestFunctionalToInplaceActivation, TestInplaceToFunctionalActivation # noqa: F401
|
||||
|
||||
# Torch
|
||||
from torch import Tensor
|
||||
|
|
|
|||
|
|
@ -182,6 +182,7 @@ core_sources_full_mobile = [
|
|||
"torch/csrc/jit/passes/concat_opt.cpp",
|
||||
"torch/csrc/jit/passes/constant_pooling.cpp",
|
||||
"torch/csrc/jit/passes/constant_propagation.cpp",
|
||||
"torch/csrc/jit/passes/restore_mutation.cpp",
|
||||
"torch/csrc/jit/passes/create_autodiff_subgraphs.cpp",
|
||||
"torch/csrc/jit/passes/dead_code_elimination.cpp",
|
||||
"torch/csrc/jit/passes/remove_redundant_profiles.cpp",
|
||||
|
|
|
|||
|
|
@ -1,4 +1,5 @@
|
|||
#include <torch/csrc/jit/passes/remove_mutation.h>
|
||||
#include <torch/csrc/jit/passes/restore_mutation.h>
|
||||
|
||||
namespace torch {
|
||||
namespace jit {
|
||||
|
|
@ -11,7 +12,7 @@ bool MutationRemover::removeTensorMutation() {
|
|||
return RemoveTensorMutation(graph_->block());
|
||||
}
|
||||
|
||||
bool MutationRemover::newMemoryLocation(Value* v) {
|
||||
bool MutationRemover::hasSideEffectOrAlias(Value* v, AliasDb* aliasDb) {
|
||||
// bail on nodes with side effects, blocks, or graph / graph inputs
|
||||
Node* n = v->node();
|
||||
bool unhandled_node = n->blocks().size() != 0 ||
|
||||
|
|
@ -20,9 +21,8 @@ bool MutationRemover::newMemoryLocation(Value* v) {
|
|||
|
||||
// if the output isn't contained or alias by the inputs to its node, it's
|
||||
// unique
|
||||
return !unhandled_node &&
|
||||
!getOrCreateAliasDb()->mayContainAlias(v->node()->inputs(), v) &&
|
||||
!(v->node()->kind() == prim::Param);
|
||||
return unhandled_node || aliasDb->mayContainAlias(v->node()->inputs(), v) ||
|
||||
(v->node()->kind() == prim::Param);
|
||||
}
|
||||
|
||||
Node* MutationRemover::createSpecialMappedOp(Node* n) {
|
||||
|
|
@ -81,7 +81,7 @@ bool MutationRemover::tryMakeCreationAndMutationAtomic(
|
|||
// We can only remove mutation to values that are unique aliases in the
|
||||
// graph. if x = y[0] or y = self.y, then removing the mutation could
|
||||
// change observable semantics
|
||||
if (!newMemoryLocation(mutated_value)) {
|
||||
if (hasSideEffectOrAlias(mutated_value, getOrCreateAliasDb())) {
|
||||
return false;
|
||||
}
|
||||
|
||||
|
|
@ -116,7 +116,8 @@ bool MutationRemover::tryMakeUnaliasedIfOutputAndMutationAtomic(
|
|||
return false;
|
||||
}
|
||||
|
||||
if (!newMemoryLocation(true_value) || !newMemoryLocation(false_value)) {
|
||||
if (hasSideEffectOrAlias(true_value, getOrCreateAliasDb()) ||
|
||||
hasSideEffectOrAlias(false_value, getOrCreateAliasDb())) {
|
||||
return false;
|
||||
}
|
||||
|
||||
|
|
@ -329,5 +330,20 @@ bool RemoveTensorMutation(
|
|||
return mr.removeTensorMutation();
|
||||
}
|
||||
|
||||
static const std::unordered_set<Symbol> activation_ops = []() {
|
||||
std::unordered_set<Symbol> target_ops;
|
||||
for (const auto& iter : activation_type_promotion_mapping) {
|
||||
std::string name = std::string(iter.first.toQualString()) + "_";
|
||||
target_ops.insert(Symbol::fromQualString(name));
|
||||
}
|
||||
return target_ops;
|
||||
}();
|
||||
|
||||
bool InplaceToFunctionalActivation(const std::shared_ptr<Graph>& graph) {
|
||||
return RemoveTensorMutation(graph, [](Node* node) {
|
||||
return activation_ops.count(node->kind()) != 0;
|
||||
});
|
||||
}
|
||||
|
||||
} // namespace jit
|
||||
} // namespace torch
|
||||
|
|
|
|||
|
|
@ -33,8 +33,9 @@ struct TORCH_API MutationRemover {
|
|||
|
||||
bool inplaceOpVariant(Node* n);
|
||||
|
||||
static bool hasSideEffectOrAlias(Value* v, AliasDb* aliasDb);
|
||||
|
||||
private:
|
||||
bool newMemoryLocation(Value* v);
|
||||
Node* createSpecialMappedOp(Node* n);
|
||||
bool listMutationFollowingListConstruct(Node* n);
|
||||
bool tryMakeCreationAndMutationAtomic(
|
||||
|
|
@ -73,5 +74,9 @@ TORCH_API bool RemoveTensorMutation(
|
|||
const std::shared_ptr<Graph>& graph,
|
||||
c10::optional<std::function<bool(Node*)>> mutation_filter = c10::nullopt);
|
||||
|
||||
// Replaces in-place aten activation ops with their functional equivalence
|
||||
TORCH_API bool InplaceToFunctionalActivation(
|
||||
const std::shared_ptr<Graph>& graph);
|
||||
|
||||
} // namespace jit
|
||||
} // namespace torch
|
||||
|
|
|
|||
85
torch/csrc/jit/passes/restore_mutation.cpp
Normal file
85
torch/csrc/jit/passes/restore_mutation.cpp
Normal file
|
|
@ -0,0 +1,85 @@
|
|||
#include <ATen/core/interned_strings.h>
|
||||
#include <ATen/core/jit_type.h>
|
||||
#include <torch/csrc/jit/passes/remove_mutation.h>
|
||||
#include <torch/csrc/jit/passes/restore_mutation.h>
|
||||
|
||||
namespace torch {
|
||||
namespace jit {
|
||||
|
||||
FunctionalToInplaceRewriter::FunctionalToInplaceRewriter(
|
||||
std::shared_ptr<Graph> graph)
|
||||
: aliasDb_(nullptr), graph_(std::move(graph)) {}
|
||||
|
||||
bool FunctionalToInplaceRewriter::CanBeInplace(Node* node) {
|
||||
if (activation_type_promotion_mapping.find(node->kind()) ==
|
||||
activation_type_promotion_mapping.end()) {
|
||||
return false;
|
||||
}
|
||||
|
||||
Symbol inplace_op =
|
||||
Symbol::fromQualString(std::string(node->kind().toQualString()) + "_");
|
||||
if (!inplace_op) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// If type promotion is allowed, then perform dtype check
|
||||
bool check_dtype = activation_type_promotion_mapping.at(node->kind());
|
||||
|
||||
Value* input = node->inputs().at(0);
|
||||
Value* output = node->outputs().at(0);
|
||||
auto inputDtype = input->type()->expect<TensorType>()->scalarType();
|
||||
auto outputDtype = output->type()->expect<TensorType>()->scalarType();
|
||||
|
||||
// In general, we don't need to check shape for activation ops as they
|
||||
// element-wise. But for those where type promotion could happen, we need to
|
||||
// make sure the dtype of input and output are the same. For now the dtype
|
||||
// checking will always fail until the type inference is ready.
|
||||
if (check_dtype &&
|
||||
(!inputDtype || !outputDtype ||
|
||||
inputDtype.value() != outputDtype.value())) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// Skip if input's def node has side effect or input has alias
|
||||
if (MutationRemover::hasSideEffectOrAlias(input, getOrCreateAliasDb())) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// If x has more than one use, skip the converson.
|
||||
// TODO: Use liveness analysis to catch more general scenario
|
||||
return (input->uses().size() == 1);
|
||||
}
|
||||
|
||||
bool FunctionalToInplaceRewriter::FunctionalToInplace(Block* block) {
|
||||
bool changed = false;
|
||||
for (auto it = block->nodes().begin(); it != block->nodes().end();) {
|
||||
auto* node = *it;
|
||||
it++;
|
||||
|
||||
for (Block* sub_block : node->blocks()) {
|
||||
changed |= FunctionalToInplace(sub_block);
|
||||
}
|
||||
|
||||
if (!CanBeInplace(node)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
changed = true;
|
||||
Node* inplace_node = node->replaceWithNewSymbol(
|
||||
Symbol::fromQualString(node->schema().name() + "_"));
|
||||
inplace_node->output()->replaceAllUsesWith(node->inputs().at(0));
|
||||
getOrCreateAliasDb()->replaceWithNewValue(
|
||||
node->output(), inplace_node->output());
|
||||
|
||||
node->destroy();
|
||||
}
|
||||
return changed;
|
||||
}
|
||||
|
||||
bool FunctionalToInplaceActivation(const std::shared_ptr<Graph>& graph) {
|
||||
FunctionalToInplaceRewriter rewriter(graph);
|
||||
return rewriter.FunctionalToInplace(graph->block());
|
||||
}
|
||||
|
||||
} // namespace jit
|
||||
} // namespace torch
|
||||
64
torch/csrc/jit/passes/restore_mutation.h
Normal file
64
torch/csrc/jit/passes/restore_mutation.h
Normal file
|
|
@ -0,0 +1,64 @@
|
|||
#pragma once
|
||||
|
||||
#include <ATen/core/interned_strings.h>
|
||||
#include <c10/util/Exception.h>
|
||||
#include <torch/csrc/WindowsTorchApiMacro.h>
|
||||
#include <torch/csrc/jit/ir/alias_analysis.h>
|
||||
#include <torch/csrc/jit/ir/ir.h>
|
||||
#include <torch/csrc/utils/memory.h>
|
||||
|
||||
namespace torch {
|
||||
namespace jit {
|
||||
|
||||
// A map which stores if an activation operator can perform type promotion
|
||||
const std::unordered_map<Symbol, bool> activation_type_promotion_mapping = {
|
||||
{aten::sigmoid, true},
|
||||
{aten::tanh, true},
|
||||
{aten::celu, false},
|
||||
{aten::elu, false},
|
||||
{aten::gelu, false},
|
||||
{aten::glu, false},
|
||||
{aten::hardshrink, false},
|
||||
{aten::hardsigmoid, false},
|
||||
{aten::hardswish, false},
|
||||
{aten::hardtanh, false},
|
||||
{aten::leaky_relu, false},
|
||||
{aten::prelu, false},
|
||||
{aten::relu6, false},
|
||||
{aten::relu, false},
|
||||
{aten::rrelu, false},
|
||||
{aten::selu, false},
|
||||
{aten::silu, false}};
|
||||
|
||||
class FunctionalToInplaceRewriter {
|
||||
public:
|
||||
FunctionalToInplaceRewriter(std::shared_ptr<Graph> graph);
|
||||
|
||||
bool FunctionalToInplace(Block* block);
|
||||
|
||||
private:
|
||||
AliasDb* getOrCreateAliasDb() {
|
||||
if (!aliasDb_) {
|
||||
aliasDb_ = std::make_unique<AliasDb>(graph_);
|
||||
}
|
||||
return aliasDb_.get();
|
||||
}
|
||||
|
||||
bool CanBeInplace(Node* node);
|
||||
|
||||
std::unique_ptr<AliasDb> aliasDb_ = nullptr;
|
||||
std::shared_ptr<Graph> graph_;
|
||||
};
|
||||
|
||||
// A common application scenario is to apply InplaceToFunctionalActivation
|
||||
// before some JIT optimization passes, so that those passes are less
|
||||
// constrained by in-place ops. After those passes are done, we can call
|
||||
// FunctionalToInplaceActivation to recover in-place activation ops,
|
||||
// so that we won't lose the performance benefit coming from memory reduction.
|
||||
|
||||
// Replaces functional aten activation ops with their in-place equivalents
|
||||
TORCH_API bool FunctionalToInplaceActivation(
|
||||
const std::shared_ptr<Graph>& graph);
|
||||
|
||||
} // namespace jit
|
||||
} // namespace torch
|
||||
|
|
@ -67,6 +67,7 @@
|
|||
#include <torch/csrc/jit/passes/remove_expands.h>
|
||||
#include <torch/csrc/jit/passes/remove_inplace_ops.h>
|
||||
#include <torch/csrc/jit/passes/remove_mutation.h>
|
||||
#include <torch/csrc/jit/passes/restore_mutation.h>
|
||||
#include <torch/csrc/jit/passes/shape_analysis.h>
|
||||
#include <torch/csrc/jit/passes/specialize_autogradzero.h>
|
||||
#include <torch/csrc/jit/passes/subgraph_rewrite.h>
|
||||
|
|
@ -419,6 +420,16 @@ void initJITBindings(PyObject* module) {
|
|||
RemoveListMutation(g);
|
||||
return RemoveTensorMutation(g);
|
||||
})
|
||||
.def(
|
||||
"_jit_pass_functional_to_inplace_activation",
|
||||
[](std::shared_ptr<Graph>& g) {
|
||||
return FunctionalToInplaceActivation(g);
|
||||
})
|
||||
.def(
|
||||
"_jit_pass_inplace_to_functional_activation",
|
||||
[](std::shared_ptr<Graph>& g) {
|
||||
return InplaceToFunctionalActivation(g);
|
||||
})
|
||||
.def(
|
||||
"_jit_pass_inline_functional_graphs",
|
||||
[](std::shared_ptr<Graph>& g) { return InlineFunctionalGraphs(g); })
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user