mirror of https://github.com/zebrajr/pytorch.git, synced 2025-12-07 12:21:27 +01:00
add more fusable nodes to the graph compiler (#3559)
This commit is contained in:
parent 285ce10dbe
commit 25d3c25f50
@@ -16,6 +16,68 @@
 namespace torch { namespace jit {

+std::unordered_map<NodeKind, std::string> simple_map_ops = {
+  // unary
+  {kabs, "absf(${0})"},
+  {ksigmoid, "1.f / (1.f + expf(-${0}))"},
+  {klog, "logf(${0})"},
+  {klog1p, "log1pf(${0})"},
+  {klgamma, "lgammaf(${0})"},
+  {kexp, "expf(${0})"},
+  {kcos, "cosf(${0})"},
+  {kacos, "acosf(${0})"},
+  {kcosh, "coshf(${0})"},
+  {ksin, "sinf(${0})"},
+  {kasin, "asinf(${0})"},
+  {ksinh, "sinhf(${0})"},
+  {ktan, "tanf(${0})"},
+  {katan, "atanf(${0})"},
+  {ktanh, "tanhf(${0})"},
+  {ksqrt, "sqrtf(${0})"},
+  {krsqrt, "rsqrtf(${0})"},
+  {kceil, "ceilf(${0})"},
+  {kfloor, "floorf(${0})"},
+  {kround, "roundf(${0})"},
+  {ktrunc, "truncf(${0})"},
+  {kfrac, "fracf(${0})"},
+  {kreciprocal, "reciprocalf(${0})"},
+  {kneg, "-${0}"},
+
+  // simple binary
+  {katan2, "atan2(${0}, ${1})"},
+  {kmin, "fminf(${0}, ${1})"},
+  {kmax, "fmaxf(${0}, ${1})"},
+
+  // binary with 'other'
+  // TODO: some of these ops will not get generated because
+  // we only work on float inputs/outputs, but they are here to record
+  // that they are valid mappable ops once we handle more types
+  {k__and__, "${0} && ${1}"},
+  {k__lshift__, "${0} << ${1}"},
+  {k__or__, "${0} || ${1}"},
+  {k__rshift__, "${0} >> ${1}"},
+  {k__xor__, "${0} ^ ${1}"},
+  {kdiv, "${0} / ${1}"},
+  {keq, "${0} == ${1}"},
+  {kfmod, "fmodf(${0}, ${1})"},
+  {kge, "${0} >= ${1}"},
+  {kgt, "${0} > ${1}"},
+  {kle, "${0} <= ${1}"},
+  {klt, "${0} < ${1}"},
+  {kmul, "${0} * ${1}"},
+  {kne, "${0} != ${1}"},
+  {kremainder, "remainderf(${0}, ${1})"},
+  {kpow, "powf(${0}, ${1})"},
+
+  // alpha
+  {kadd, "${0} + ${alpha}*${1}"},
+  {ksub, "${0} - ${alpha}*${1}"},
+
+  // special
+  {klerp, "${0} + ${weight}*(${1} - ${0})"},
+  {kclamp, "min(max(${0},${min}),${max})"},
+
+};
+
 std::vector<bool> TensorDesc::findContiguous(
     const at::IntList& sizes,
     const at::IntList& strides) {
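For context (background, not part of the commit): each entry above is a code template in which ${0}, ${1}, ... name the node's inputs in order, and ${alpha}, ${weight}, ${min}, ${max} name scalar attributes. A minimal, self-contained sketch of the substitution step, using a toy render() in place of the real TemplateEnv/format machinery:

  #include <iostream>
  #include <map>
  #include <string>

  // Toy stand-in for the jit TemplateEnv/format pair: replaces each ${key}
  // in the template with env.at(key). Illustrative only; the real format()
  // lives in torch/csrc/jit and does more (indentation, lists, etc.).
  std::string render(const std::string & tmpl,
                     const std::map<std::string, std::string> & env) {
    std::string out;
    for (size_t i = 0; i < tmpl.size();) {
      if (tmpl[i] == '$' && i + 1 < tmpl.size() && tmpl[i + 1] == '{') {
        size_t close = tmpl.find('}', i + 2);
        out += env.at(tmpl.substr(i + 2, close - (i + 2)));
        i = close + 1;
      } else {
        out += tmpl[i++];
      }
    }
    return out;
  }

  int main() {
    // The 'add' template from the map above, with inputs x, y and alpha = 2
    std::map<std::string, std::string> env = {
        {"0", "x"}, {"1", "y"}, {"alpha", "2.f"}};
    std::cout << render("${0} + ${alpha}*${1}", env) << "\n"; // x + 2.f*y
  }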
@@ -136,21 +198,6 @@ std::string nodeName(Node * n) {
     std::to_string(s.toDouble());
 }

-// TODO: we need to support double-precision
-std::unordered_map<NodeKind,std::function<std::string(Node*)>> simple_map_ops = {
-  {ksigmoid, [](Node*) { return "1.f / (1.f + expf(-${0}))"; }},
-  {ktanh, [](Node*) { return "tanhf(${0})"; }},
-  {kmul, [](Node*) { return "${0} * ${1}"; }},
-  {kadd, [](Node* n) -> std::string {
-    if(n->inputs().size() == 2)
-      return "${0} + ${1}";
-    else
-      return "${0} + " + scalarValue(n->t(kother));
-  }},
-  {kneg, [](Node*) { return "(-${0})"; }},
-};
-
 const char * scalarTypeName(at::ScalarType type) {
   switch(type) {
 #define DEFINE_CASE(ctype,name,_) \
@@ -162,6 +209,31 @@ const char * scalarTypeName(at::ScalarType type) {
   }
 }

+std::string encodeRHS(Node * n) {
+  TemplateEnv env;
+  size_t i = 0;
+  for(auto in : n->inputs()) {
+    env.s(std::to_string(i++), nodeName(in));
+  }
+  // ops like div have a / b or a / 2, with the constant carried in the
+  // attribute 'other', so we add 'other' as an input if it is present;
+  // 'pow' is the same but uses 'exponent' as the attribute, so we handle
+  // that here as well
+  if(n->hasAttribute(kother)) {
+    env.s(std::to_string(i), scalarValue(n->t(kother)));
+  } else if(n->hasAttribute(kexponent)) {
+    env.s(std::to_string(i), scalarValue(n->t(kexponent)));
+  }
+  // we also add any other scalar tensors to the env for special ops
+  for(auto a : n->attributeNames()) {
+    if(n->kindOf(a) == AttributeKind::t) {
+      auto v = n->t(a);
+      if(v.dim() == 0) {
+        env.s(symbolToString(a), scalarValue(v));
+      }
+    }
+  }
+  const auto & str = simple_map_ops.at(n->kind());
+  return format(str, env);
+}
+
 std::vector<ConcatDesc> emitCompilationUnit(std::ostream & out,
     const std::string & name,
     AnnotatedGraph & agraph) {
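To make encodeRHS concrete (a hypothetical trace, not part of the diff): for a div node with one tensor input %x and a scalar attribute other = 2, the tensor inputs are bound first and the scalar is appended at the next index, so the single template "${0} / ${1}" serves both the tensor/tensor and tensor/scalar forms:

  // ${0} -> "x"                            bound from the node's inputs
  // ${1} -> scalarValue(other), e.g. "2"   appended after the inputs
  // "${0} / ${1}"  =>  "x / 2"
  // pow works the same way, but reads its scalar from 'exponent' instead of 'other'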
@@ -219,12 +291,8 @@ std::vector<ConcatDesc> emitCompilationUnit(std::ostream & out,
   for(auto n : subgraph.nodes()) {
     if(n->kind() == kcat)
       continue; // Concat nodes are handled by narrowing the output Tensors before the kernel runs
-    size_t i = 0;
-    for(auto in : n->inputs()) {
-      env.s(std::to_string(i++),nodeName(in));
-    }
     env.s("node",nodeName(n));
-    env.s("rhs",format(simple_map_ops.at(n->kind())(n),env));
+    env.s("rhs", encodeRHS(n));
     body << format("auto ${node} = ${rhs};\n",env);
   }
   for(auto o : flat_output_nodes) {
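As a concrete picture of this loop (hypothetical, with made-up value names): every non-concat node in the fused subgraph contributes one "auto ${node} = ${rhs};" line to the kernel body, so a small mul -> sigmoid subgraph would lower to something like:

  auto n2 = x * y;                    // from {kmul, "${0} * ${1}"}
  auto n3 = 1.f / (1.f + expf(-n2));  // from {ksigmoid, ...}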
@@ -75,7 +75,53 @@ _(shape) \
 _(axes) \
 _(group) \
 _(inplace) \
-_(other)
+_(other) \
+_(__and__) \
+_(__lshift__) \
+_(__or__) \
+_(__rshift__) \
+_(__xor__) \
+_(abs) \
+_(acos) \
+_(asin) \
+_(atan) \
+_(atan2) \
+_(ceil) \
+_(clamp) \
+_(cos) \
+_(cosh) \
+_(div) \
+_(eq) \
+_(equal) \
+_(exp) \
+_(floor) \
+_(fmod) \
+_(frac) \
+_(ge) \
+_(gt) \
+_(le) \
+_(lerp) \
+_(lgamma) \
+_(log) \
+_(log1p) \
+_(lt) \
+_(max) \
+_(min) \
+_(ne) \
+_(ones) \
+_(pow) \
+_(reciprocal) \
+_(remainder) \
+_(round) \
+_(rsqrt) \
+_(sin) \
+_(sinh) \
+_(sqrt) \
+_(sub) \
+_(tan) \
+_(trunc) \
+_(zeros) \
+_(exponent)

 enum BuiltinSymbol {
 #define DEFINE_SYMBOL(s) \
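For readers unfamiliar with the pattern (background, not part of the commit): this header is an X-macro list, so every _(sym) line added here is expanded wherever the list is instantiated, typically into an enum value ksym and an interned string "sym". A trimmed-down sketch of the mechanism, with a hypothetical three-symbol list:

  #include <iostream>

  // Trimmed-down stand-in for the builtin-symbol list in the header;
  // each _(sym) line in the real file participates in expansions like these.
  #define FORALL_DEMO_SYMBOLS(_) \
    _(abs)                       \
    _(atan2)                     \
    _(exponent)

  enum DemoSymbol {
  #define DEFINE_SYMBOL(s) k##s,
    FORALL_DEMO_SYMBOLS(DEFINE_SYMBOL)
  #undef DEFINE_SYMBOL
  };

  // The same list can also generate a name table for symbol -> string lookup.
  const char * demoSymbolName(DemoSymbol s) {
    switch (s) {
  #define DEFINE_CASE(sym) case k##sym: return #sym;
      FORALL_DEMO_SYMBOLS(DEFINE_CASE)
  #undef DEFINE_CASE
    }
    return "?";
  }

  int main() {
    std::cout << demoSymbolName(katan2) << "\n"; // prints "atan2"
  }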
@@ -13,15 +13,64 @@ namespace {
 // Some of these restrictions may be relaxable, but you should
 // carefully read the code first, as we rely on these assumptions.
 std::unordered_set<NodeKind> simple_mappable = {
-  ksigmoid,
-  ktanh,
-  kmul,
+  k__and__,
+  k__lshift__,
+  k__or__,
+  k__rshift__,
+  k__xor__,
+  kabs,
+  kacos,
   kadd,
+  kasin,
+  katan,
+  katan2,
+  kceil,
+  kclamp,
+  kcos,
+  kcosh,
+  kdiv,
+  keq,
+  kexp,
+  kfloor,
+  kfmod,
+  kfrac,
+  kge,
+  kgt,
+  kle,
+  klerp,
+  klgamma,
+  klog,
+  klog1p,
+  klt,
+  kmax,
+  kmin,
+  kmul,
+  kne,
   kneg,
+  kones,
+  kpow,
+  kreciprocal,
+  kremainder,
+  kround,
+  krsqrt,
+  ksigmoid,
+  ksin,
+  ksinh,
+  ksqrt,
+  ksub,
+  ktan,
+  ktanh,
+  ktrunc,
+  kzeros,
 };

 bool isSimpleMap(Node *node) {
-  return simple_mappable.count(node->kind());
+  if(simple_mappable.count(node->kind())) {
+    if(node->kind() == kmin || node->kind() == kmax)
+      return node->inputs().size() == 2; // unary min/max is a reduction...
+    return true;
+  }
+  return false;
 }

 struct GraphFuser {
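A quick illustration (not from the diff) of the min/max special case in isSimpleMap: with a single input, min and max reduce over the whole tensor, so only the two-input elementwise form counts as a simple map:

  // x.min()      one input, returns a scalar: a reduction, skipped here
  // min(x, y)    two inputs, elementwise fminf: a simple map, fusable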
@@ -40,18 +89,35 @@ struct GraphFuser {
   bool isCuda(Node * node) {
     return node->type()->expect<TensorType>()->device() != -1;
   }
-  // TODO: the fusion compiler needs to know how to handle 'alpha'
-  // and other attributes in code generation for us to be able to fuse them;
-  // then it is safe to remove the !hasSpecialAlpha check
-  bool hasSpecialAlpha(Node * node) {
-    if(!node->hasAttribute(kalpha))
-      return false;
-    return at::Scalar(node->t(kalpha)).toDouble() != 1;
-  }
+  // TODO: the fusion compiler has a lot of float-specific codegen,
+  // so for now we only consider nodes that operate on floating point numbers
+  bool hasFloatType(Node * node) {
+    if(!node->hasType()) {
+      return false;
+    }
+    if(auto tt = node->type()->cast<TensorType>()) {
+      return tt->scalarType() == at::kFloat;
+    } else {
+      return false;
+    }
+  }
+  bool allFloatIO(Node * node) {
+    for(auto & o : node->outputs()) {
+      if(!hasFloatType(o)) {
+        return false;
+      }
+    }
+    for(auto & i : node->inputs()) {
+      if(!hasFloatType(i)) {
+        return false;
+      }
+    }
+    return true;
+  }
   bool isFusable(Node * node) {
     if (!node->hasType()) return false;
     if (node->kind() == kFusionGroup) return true;
-    return isSimpleMap(node) && !hasSpecialAlpha(node) && isCuda(node);
+    return isSimpleMap(node) && allFloatIO(node) && isCuda(node);
   }

   // Can this node produce an _output_ of a fusion group?
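Taken together (an illustrative restatement of the checks above, not new behavior): isFusable admits a node only when its kind is simple-mappable, every input and output is a float tensor, and it runs on a CUDA device (device() == -1 denotes CPU here):

  // add(float tensor on GPU,  float tensor on GPU)   -> fusable
  // add(double tensor on GPU, double tensor on GPU)  -> rejected by allFloatIO
  // add(float tensor on CPU,  float tensor on CPU)   -> rejected by isCuda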
@@ -121,6 +121,7 @@ static void fusionTests() {
   auto p14 = appendNewNode(kmul,graph,{p20, i0});
   auto p11 = appendNewNode(kmul,graph,{p22, p18});
   auto o1 = appendNewNode(kadd,graph,{p14, p11});
+  o1->t_(kalpha, at::Scalar(1).toTensor());
   auto p5 = appendNewNode(ktanh,graph,{o1});
   auto o0 = appendNewNode(kmul,graph,{p16, p5});
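For intuition, a sketch (assumed, not output produced by this commit) of the straight-line kernel body the fusion compiler could emit for the chain above once kalpha is set, using the templates from simple_map_ops:

  auto p14 = p20 * i0;
  auto p11 = p22 * p18;
  auto o1 = p14 + 1*p11;   // kadd with ${alpha} == 1
  auto p5 = tanhf(o1);
  auto o0 = p16 * p5;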