diff --git a/torch/csrc/jit/fusion_compiler.cpp b/torch/csrc/jit/fusion_compiler.cpp
index 9c12feebe96..131ad24ab57 100644
--- a/torch/csrc/jit/fusion_compiler.cpp
+++ b/torch/csrc/jit/fusion_compiler.cpp
@@ -16,6 +16,68 @@ namespace torch { namespace jit {
+std::unordered_map<NodeKind, std::string> simple_map_ops = {
+  // unary
+  {kabs, "fabsf(${0})"},
+  {ksigmoid, "1.f / (1.f + expf(-${0}))"},
+  {klog, "logf(${0})"},
+  {klog1p, "log1pf(${0})"},
+  {klgamma, "lgammaf(${0})"},
+  {kexp, "expf(${0})"},
+  {kcos, "cosf(${0})"},
+  {kacos, "acosf(${0})"},
+  {kcosh, "coshf(${0})"},
+  {ksin, "sinf(${0})"},
+  {kasin, "asinf(${0})"},
+  {ksinh, "sinhf(${0})"},
+  {ktan, "tanf(${0})"},
+  {katan, "atanf(${0})"},
+  {ktanh, "tanhf(${0})"},
+  {ksqrt, "sqrtf(${0})"},
+  {krsqrt, "rsqrtf(${0})"},
+  {kceil, "ceilf(${0})"},
+  {kfloor, "floorf(${0})"},
+  {kround, "roundf(${0})"},
+  {ktrunc, "truncf(${0})"},
+  {kfrac, "${0} - truncf(${0})"},
+  {kreciprocal, "1.f/${0}"},
+  {kneg, "-${0}"},
+  // simple binary
+  {katan2, "atan2f(${0}, ${1})"},
+  {kmin, "fminf(${0}, ${1})"},
+  {kmax, "fmaxf(${0}, ${1})"},
+
+  // binary with other
+  // TODO: some of these ops will not get generated because
+  // we only work on float inputs/outputs, but they are here to record
+  // that they are valid mappable ops once we handle more types
+  {k__and__, "${0} && ${1}"},
+  {k__lshift__, "${0} << ${1}"},
+  {k__or__, "${0} || ${1}"},
+  {k__rshift__, "${0} >> ${1}"},
+  {k__xor__, "${0} ^ ${1}"},
+  {kdiv, "${0} / ${1}"},
+  {keq, "${0} == ${1}"},
+  {kfmod, "fmodf(${0}, ${1})"},
+  {kge, "${0} >= ${1}"},
+  {kgt, "${0} > ${1}"},
+  {kle, "${0} <= ${1}"},
+  {klt, "${0} < ${1}"},
+  {kmul, "${0} * ${1}"},
+  {kne, "${0} != ${1}"},
+  {kremainder, "remainderf(${0}, ${1})"},
+  {kpow, "powf(${0}, ${1})"},
+
+  // alpha
+  {kadd, "${0} + ${alpha}*${1}"},
+  {ksub, "${0} - ${alpha}*${1}"},
+
+  // special
+  {klerp, "${0} + ${weight}*(${1} - ${0})"},
+  {kclamp, "min(max(${0},${min}),${max})"},
+
+};
+
 std::vector<bool> TensorDesc::findContiguous(
     const at::IntList& sizes,
     const at::IntList& strides) {
@@ -136,21 +198,6 @@ std::string nodeName(Node * n) {
     std::to_string(s.toDouble());
 }
 
-// TODO: we need to support double-precision
-std::unordered_map<NodeKind, std::function<std::string(Node*)>> simple_map_ops = {
-  {ksigmoid, [](Node*) { return "1.f / (1.f + expf(-${0}))"; }},
-  {ktanh, [](Node*) { return "tanhf(${0})"; }},
-  {kmul, [](Node*) { return "${0} * ${1}"; }},
-  {kadd, [](Node*n) -> std::string {
-    if(n->inputs().size() == 2)
-      return "${0} + ${1}";
-    else
-      return "${0} + " + scalarValue(n->t(kother));
-  }},
-  {kneg, [](Node*) { return "(-${0})"; }},
-
-};
-
 const char * scalarTypeName(at::ScalarType type) {
   switch(type) {
 #define DEFINE_CASE(ctype,name,_) \
@@ -162,6 +209,31 @@ const char * scalarTypeName(at::ScalarType type) {
   }
 }
 
+std::string encodeRHS(Node * n) {
+  TemplateEnv env;
+  size_t i = 0;
+  for(auto in : n->inputs()) {
+    env.s(std::to_string(i++),nodeName(in));
+  }
+  // ops like div have a / b or a / 2 with the constant having the attribute other
+  // so we add other as an input if it is present
+  // 'pow' is the same but uses exponent as the attribute, so we handle that here as well
+  if(n->hasAttribute(kother) || n->hasAttribute(kexponent)) {
+    env.s(std::to_string(i), scalarValue(n->t(n->hasAttribute(kother) ? kother : kexponent)));
+  }
+  // we also add any other scalar tensors to the env for special ops
+  for(auto a : n->attributeNames()) {
+    if(n->kindOf(a) == AttributeKind::t) {
+      auto v = n->t(a);
+      if(v.dim() == 0) {
+        env.s(symbolToString(a), scalarValue(v));
+      }
+    }
+  }
+  const auto & str = simple_map_ops.at(n->kind());
+  return format(str, env);
+}
+
 std::vector<ConcatDesc> emitCompilationUnit(std::ostream & out,
     const std::string & name,
     AnnotatedGraph & agraph) {
@@ -219,12 +291,8 @@ std::vector<ConcatDesc> emitCompilationUnit(std::ostream & out,
   for(auto n : subgraph.nodes()) {
     if(n->kind() == kcat)
       continue; // Concat nodes by narrowing the output Tensors before the kernel runs
-    size_t i = 0;
-    for(auto in : n->inputs()) {
-      env.s(std::to_string(i++),nodeName(in));
-    }
     env.s("node",nodeName(n));
-    env.s("rhs",format(simple_map_ops.at(n->kind())(n),env));
+    env.s("rhs", encodeRHS(n));
     body << format("auto ${node} = ${rhs};\n",env);
   }
   for(auto o : flat_output_nodes) {
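
Note for reviewers: the ${...} placeholders above are filled by plain string substitution via TemplateEnv/format. Below is a minimal standalone sketch of that mechanism, with a hand-rolled substitute() standing in for the real format() (illustrative only, not code from this patch):

    #include <iostream>
    #include <map>
    #include <string>

    // Stand-in for TemplateEnv/format: replace each ${key} with env.at(key).
    std::string substitute(const std::string & tmpl,
                           const std::map<std::string, std::string> & env) {
      std::string out;
      size_t pos = 0;
      while(pos < tmpl.size()) {
        size_t start = tmpl.find("${", pos);
        if(start == std::string::npos) {
          out += tmpl.substr(pos);
          break;
        }
        size_t end = tmpl.find('}', start);
        out += tmpl.substr(pos, start - pos);
        out += env.at(tmpl.substr(start + 2, end - start - 2));
        pos = end + 1;
      }
      return out;
    }

    int main() {
      // What encodeRHS builds for an add node with inputs n0, n1 and alpha == 2:
      // positional keys come from the inputs, named keys from scalar attributes.
      std::map<std::string, std::string> env = {{"0", "n0"}, {"1", "n1"}, {"alpha", "2"}};
      std::cout << substitute("${0} + ${alpha}*${1}", env) << "\n"; // prints: n0 + 2*n1
    }
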
diff --git a/torch/csrc/jit/interned_strings.h b/torch/csrc/jit/interned_strings.h
index 840d67916a2..8c990340f2c 100644
--- a/torch/csrc/jit/interned_strings.h
+++ b/torch/csrc/jit/interned_strings.h
@@ -75,7 +75,53 @@ _(shape) \
 _(axes) \
 _(group) \
 _(inplace) \
-_(other)
+_(other) \
+_(__and__) \
+_(__lshift__) \
+_(__or__) \
+_(__rshift__) \
+_(__xor__) \
+_(abs) \
+_(acos) \
+_(asin) \
+_(atan) \
+_(atan2) \
+_(ceil) \
+_(clamp) \
+_(cos) \
+_(cosh) \
+_(div) \
+_(eq) \
+_(equal) \
+_(exp) \
+_(floor) \
+_(fmod) \
+_(frac) \
+_(ge) \
+_(gt) \
+_(le) \
+_(lerp) \
+_(lgamma) \
+_(log) \
+_(log1p) \
+_(lt) \
+_(max) \
+_(min) \
+_(ne) \
+_(ones) \
+_(pow) \
+_(reciprocal) \
+_(remainder) \
+_(round) \
+_(rsqrt) \
+_(sin) \
+_(sinh) \
+_(sqrt) \
+_(sub) \
+_(tan) \
+_(trunc) \
+_(zeros) \
+_(exponent)
 
 enum BuiltinSymbol {
 #define DEFINE_SYMBOL(s) \
diff --git a/torch/csrc/jit/passes/graph_fuser.cpp b/torch/csrc/jit/passes/graph_fuser.cpp
index b27726b9799..c9efdeda3ad 100644
--- a/torch/csrc/jit/passes/graph_fuser.cpp
+++ b/torch/csrc/jit/passes/graph_fuser.cpp
@@ -13,15 +13,64 @@ namespace {
 // Some of these restrictions may be relaxable, but you should
 // carefully read the code first, as we rely on these assumptions.
 std::unordered_set<NodeKind> simple_mappable = {
-  ksigmoid,
-  ktanh,
-  kmul,
+  k__and__,
+  k__lshift__,
+  k__or__,
+  k__rshift__,
+  k__xor__,
+  kabs,
+  kacos,
   kadd,
+  kasin,
+  katan,
+  katan2,
+  kceil,
+  kclamp,
+  kcos,
+  kcosh,
+  kdiv,
+  keq,
+  kexp,
+  kfloor,
+  kfmod,
+  kfrac,
+  kge,
+  kgt,
+  kle,
+  klerp,
+  klgamma,
+  klog,
+  klog1p,
+  klt,
+  kmax,
+  kmin,
+  kmul,
+  kne,
   kneg,
+  kones,
+  kpow,
+  kreciprocal,
+  kremainder,
+  kround,
+  krsqrt,
+  ksigmoid,
+  ksin,
+  ksinh,
+  ksqrt,
+  ksub,
+  ktan,
+  ktanh,
+  ktrunc,
+  kzeros,
 };
 
 bool isSimpleMap(Node *node) {
-  return simple_mappable.count(node->kind());
+  if(simple_mappable.count(node->kind())) {
+    if(node->kind() == kmin || node->kind() == kmax)
+      return node->inputs().size() == 2; // unary min/max is a reduction, not a simple map
+    return true;
+  }
+  return false;
 }
 
 struct GraphFuser {
@@ -40,18 +89,35 @@ struct GraphFuser {
   bool isCuda(Node * node) {
     return node->type()->expect<TensorType>()->device() != -1;
   }
-  // TODO: the fusion compiler needs to know how to handle 'alpha'
-  // and other attributes in code generation for us to be able to fuse them
-  // then it is safe to remove the !hasSpecialAlpha check
-  bool hasSpecialAlpha(Node * node) {
-    if(!node->hasAttribute(kalpha))
+  // TODO: the fusion compiler has a lot of float-specific codegen
+  // so for now we only consider nodes that operate on floating point numbers
+  bool hasFloatType(Node * node) {
+    if(!node->hasType()) {
       return false;
-    return at::Scalar(node->t(kalpha)).toDouble() != 1;
+    }
+    if(auto tt = node->type()->cast<TensorType>()) {
+      return tt->scalarType() == at::kFloat;
+    } else {
+      return false;
+    }
+  }
+  bool allFloatIO(Node * node) {
+    for(auto & o : node->outputs()) {
+      if(!hasFloatType(o)) {
+        return false;
+      }
+    }
+    for(auto & in : node->inputs()) {
+      if(!hasFloatType(in)) {
+        return false;
+      }
+    }
+    return true;
   }
   bool isFusable(Node * node) {
     if (!node->hasType()) return false;
     if (node->kind() == kFusionGroup) return true;
-    return isSimpleMap(node) && !hasSpecialAlpha(node) && isCuda(node);
+    return isSimpleMap(node) && allFloatIO(node) && isCuda(node);
   }
 
   // Can this node produce an _output_ of a fusion group?
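
For review purposes, the net effect of the new checks can be restated in isolation. The sketch below uses an invented FakeNode struct rather than torch::jit::Node, so it is a summary of the logic, not code from this patch:

    #include <cstddef>
    #include <string>
    #include <unordered_set>

    // Invented stand-in for torch::jit::Node, holding just the fields the
    // fusability decision looks at.
    struct FakeNode {
      std::string kind;    // e.g. "add", "min"
      size_t num_inputs;   // arity
      bool all_float_io;   // every input/output is a float tensor (allFloatIO)
      bool on_cuda;        // device != -1 (isCuda)
    };

    bool isFusable(const FakeNode & n,
                   const std::unordered_set<std::string> & simple_mappable) {
      if(!simple_mappable.count(n.kind))
        return false;
      // unary min/max is a reduction, not an elementwise map, so require two inputs
      if((n.kind == "min" || n.kind == "max") && n.num_inputs != 2)
        return false;
      // codegen is float-only for now, and only CUDA kernels are emitted
      return n.all_float_io && n.on_cuda;
    }
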
diff --git a/torch/csrc/jit/test_jit.cpp b/torch/csrc/jit/test_jit.cpp
index c28c2f347fc..75db2c3a75f 100644
--- a/torch/csrc/jit/test_jit.cpp
+++ b/torch/csrc/jit/test_jit.cpp
@@ -121,6 +121,7 @@ static void fusionTests() {
   auto p14 = appendNewNode(kmul,graph,{p20, i0});
   auto p11 = appendNewNode(kmul,graph,{p22, p18});
   auto o1 = appendNewNode(kadd,graph,{p14, p11});
+  o1->t_(kalpha, at::Scalar(1).toTensor());
   auto p5 = appendNewNode(ktanh,graph,{o1});
   auto o0 = appendNewNode(kmul,graph,{p16, p5});
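
The test change follows from the new kadd template: "${0} + ${alpha}*${1}" unconditionally references ${alpha}, so a hand-built add node must carry that attribute or encodeRHS cannot fill the placeholder. With alpha set to 1, the expression the fused kernel evaluates for o1 is equivalent to the sketch below (variable names invented; this shows the semantics, not the generated source):

    #include <iostream>

    int main() {
      // Semantics of the kadd template "${0} + ${alpha}*${1}" with alpha == 1:
      float p14 = 2.f, p11 = 3.f, alpha = 1.f;
      float o1 = p14 + alpha * p11;  // same result as plain addition when alpha is 1
      std::cout << o1 << "\n";       // prints: 5
    }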