add more fusable nodes to the graph compiler (#3559)

This commit is contained in:
Zachary DeVito 2017-11-08 19:58:08 -08:00 committed by Soumith Chintala
parent 285ce10dbe
commit 25d3c25f50
4 changed files with 213 additions and 32 deletions

View File

@ -16,6 +16,68 @@
namespace torch { namespace jit {
// Codegen templates for elementwise ("simple-mappable") ops: each entry maps
// a node kind to the CUDA expression emitted for it.  ${0}, ${1} are the
// node's inputs; named placeholders (${alpha}, ${weight}, ${min}, ${max})
// are filled from the node's scalar tensor attributes (see encodeRHS).
std::unordered_map<NodeKind, std::string> simple_map_ops = {
  // unary
  {kabs, "absf(${0})"},
  {ksigmoid, "1.f / (1.f + expf(-${0}))"},
  {klog, "logf(${0})"},
  {klog1p, "log1pf(${0})"},
  {klgamma, "lgammaf(${0})"},
  {kexp, "expf(${0})"},
  {kcos, "cosf(${0})"},
  {kacos, "acosf(${0})"},
  {kcosh, "coshf(${0})"},
  {ksin, "sinf(${0})"},
  {kasin, "asinf(${0})"},
  {ksinh, "sinhf(${0})"},
  {ktan, "tanf(${0})"},
  {katan, "atanf(${0})"},
  {ktanh, "tanhf(${0})"},
  {ksqrt, "sqrtf(${0})"},
  {krsqrt, "rsqrtf(${0})"},
  {kceil, "ceilf(${0})"},
  {kfloor, "floorf(${0})"},
  {kround, "roundf(${0})"},
  {ktrunc, "truncf(${0})"},
  {kfrac, "fracf(${0})"},
  {kreciprocal, "reciprocalf(${0})"},
  {kneg, "-${0}"},
  // simple binary
  // BUG FIX: use the single-precision atan2f, consistent with the other
  // f-suffixed templates (this codegen only handles float inputs/outputs)
  {katan2, "atan2f(${0}, ${1})"},
  {kmin, "fminf(${0}, ${1})"},
  {kmax, "fmaxf(${0}, ${1})"},
  // binary with other
  // TODO: some of these ops will not get generated because
  // we only work on float inputs/outputs, but they are here to record
  // that they are valid mappable ops once we handle more types
  {k__and__, "${0} && ${1}"},
  {k__lshift__, "${0} << ${1}"},
  {k__or__, "${0} || ${1}"},
  {k__rshift__, "${0} >> ${1}"},
  {k__xor__, "${0} ^ ${1}"},
  {kdiv, "${0} / ${1}"},
  {keq, "${0} == ${1}"},
  {kfmod, "fmodf(${0}, ${1})"},
  {kge, "${0} >= ${1}"},  // BUG FIX: removed stray ')' that produced unbalanced parens in generated code
  {kgt, "${0} > ${1}"},
  {kle, "${0} <= ${1}"},  // BUG FIX: removed stray ')'
  {klt, "${0} < ${1}"},
  {kmul, "${0} * ${1}"},
  {kne, "${0} != ${1}"},
  {kremainder, "remainderf(${0}, ${1})"},
  {kpow, "powf(${0}, ${1})"},
  // alpha
  {kadd, "${0} + ${alpha}*${1}"},
  {ksub, "${0} - ${alpha}*${1}"},  // BUG FIX: removed stray ')'
  // special
  {klerp, "${0} + ${weight}*(${1} - ${0})"},
  {kclamp, "min(max(${0},${min}),${max})"},
};
std::vector<bool> TensorDesc::findContiguous(
const at::IntList& sizes,
const at::IntList& strides) {
@ -136,21 +198,6 @@ std::string nodeName(Node * n) {
std::to_string(s.toDouble());
}
// TODO: we need to support double-precision
std::unordered_map<NodeKind,std::function<std::string(Node*)>> simple_map_ops = {
{ksigmoid, [](Node*) { return "1.f / (1.f + expf(-${0}))"; }},
{ktanh, [](Node*) { return "tanhf(${0})"; }},
{kmul, [](Node*) { return "${0} * ${1}"; }},
{kadd, [](Node*n) -> std::string {
if(n->inputs().size() == 2)
return "${0} + ${1}";
else
return "${0} + " + scalarValue(n->t(kother));
}},
{kneg, [](Node*) { return "(-${0})"; }},
};
const char * scalarTypeName(at::ScalarType type) {
switch(type) {
#define DEFINE_CASE(ctype,name,_) \
@ -162,6 +209,31 @@ const char * scalarTypeName(at::ScalarType type) {
}
}
// Builds the right-hand-side expression for a simple-mappable node by filling
// its codegen template (from simple_map_ops) with the names of the node's
// inputs and any scalar tensor attributes.
std::string encodeRHS(Node * n) {
  TemplateEnv env;
  size_t i = 0;
  for(auto in : n->inputs()) {
    env.s(std::to_string(i++), nodeName(in));
  }
  // ops like div have a / b or a / 2 with the constant having the attribute other
  // so we add other as an input if it is present
  // 'pow' is the same but uses exponent as the attribute, so we handle that here as well
  if(n->hasAttribute(kother)) {
    env.s(std::to_string(i++), scalarValue(n->t(kother)));
  } else if(n->hasAttribute(kexponent)) {
    // BUG FIX: previously read n->t(kother) even when only kexponent was
    // present, so 'pow' with a scalar exponent looked up the wrong attribute
    env.s(std::to_string(i++), scalarValue(n->t(kexponent)));
  }
  // we also add any other scalar tensors to the env for special ops
  // (e.g. ${alpha}, ${weight}, ${min}, ${max}) keyed by the attribute name
  for(auto a : n->attributeNames()) {
    if(n->kindOf(a) == AttributeKind::t) {
      auto v = n->t(a);
      if(v.dim() == 0) {
        env.s(symbolToString(a), scalarValue(v));
      }
    }
  }
  const auto & str = simple_map_ops.at(n->kind());
  return format(str, env);
}
std::vector<ConcatDesc> emitCompilationUnit(std::ostream & out,
const std::string & name,
AnnotatedGraph & agraph) {
@ -219,12 +291,8 @@ std::vector<ConcatDesc> emitCompilationUnit(std::ostream & out,
for(auto n : subgraph.nodes()) {
if(n->kind() == kcat)
continue; // Concat nodes by narrowing the output Tensors before the kernel runs
size_t i = 0;
for(auto in : n->inputs()) {
env.s(std::to_string(i++),nodeName(in));
}
env.s("node",nodeName(n));
env.s("rhs",format(simple_map_ops.at(n->kind())(n),env));
env.s("rhs", encodeRHS(n));
body << format("auto ${node} = ${rhs};\n",env);
}
for(auto o : flat_output_nodes) {

View File

@ -75,7 +75,53 @@ _(shape) \
_(axes) \
_(group) \
_(inplace) \
_(other)
_(other) \
_(__and__) \
_(__lshift__) \
_(__or__) \
_(__rshift__) \
_(__xor__) \
_(abs) \
_(acos) \
_(asin) \
_(atan) \
_(atan2) \
_(ceil) \
_(clamp) \
_(cos) \
_(cosh) \
_(div) \
_(eq) \
_(equal) \
_(exp) \
_(floor) \
_(fmod) \
_(frac) \
_(ge) \
_(gt) \
_(le) \
_(lerp) \
_(lgamma) \
_(log) \
_(log1p) \
_(lt) \
_(max) \
_(min) \
_(ne) \
_(ones) \
_(pow) \
_(reciprocal) \
_(remainder) \
_(round) \
_(rsqrt) \
_(sin) \
_(sinh) \
_(sqrt) \
_(sub) \
_(tan) \
_(trunc) \
_(zeros) \
_(exponent)
enum BuiltinSymbol {
#define DEFINE_SYMBOL(s) \

View File

@ -13,15 +13,64 @@ namespace {
// Some of these restrictions may be relaxable, but you should
// carefully read the code first, as we rely on these assumptions.
// Node kinds the fuser treats as elementwise maps (candidates for fusion).
// BUG FIX: removed duplicate entries (ksigmoid, ktanh, kmul) that appeared
// twice in the initializer; duplicates are redundant in an unordered_set.
// NOTE(review): kones/kzeros are listed here but have no codegen template in
// simple_map_ops — confirm they are handled before code generation.
std::unordered_set<NodeKind> simple_mappable = {
  k__and__,
  k__lshift__,
  k__or__,
  k__rshift__,
  k__xor__,
  kabs,
  kacos,
  kadd,
  kasin,
  katan,
  katan2,
  kceil,
  kclamp,
  kcos,
  kcosh,
  kdiv,
  keq,
  kexp,
  kfloor,
  kfmod,
  kfrac,
  kge,
  kgt,
  kle,
  klerp,
  klgamma,
  klog,
  klog1p,
  klt,
  kmax,
  kmin,
  kmul,
  kne,
  kneg,
  kones,
  kpow,
  kreciprocal,
  kremainder,
  kround,
  krsqrt,
  ksigmoid,
  ksin,
  ksinh,
  ksqrt,
  ksub,
  ktan,
  ktanh,
  ktrunc,
  kzeros,
};
// True when the node is an elementwise map the fuser can generate code for.
// BUG FIX: removed a stale unconditional `return simple_mappable.count(...)`
// at the top of the body that made the min/max special-casing unreachable.
bool isSimpleMap(Node *node) {
  if(simple_mappable.count(node->kind()) == 0)
    return false;
  if(node->kind() == kmin || node->kind() == kmax)
    return node->inputs().size() == 2; // unary min/max is a reduction...
  return true;
}
struct GraphFuser {
@ -40,18 +89,35 @@ struct GraphFuser {
// True when the node's output tensor is placed on a device other than -1
// (presumably -1 denotes CPU in TensorType — confirm against TensorType).
bool isCuda(Node * node) {
  const auto device = node->type()->expect<TensorType>()->device();
  return device != -1;
}
// TODO: the fusion compiler needs to know how to handle 'alpha'
// and other attributes in code generation for us to be able to fuse them
// then it is safe to remove the !hasSpecialAlpha check
bool hasSpecialAlpha(Node * node) {
  if(!node->hasAttribute(kalpha))
    return false;
  // alpha == 1 is the identity scaling and needs no special codegen
  return at::Scalar(node->t(kalpha)).toDouble() != 1;
}
// TODO: the fusion compiler has a lot of float-specific codegen
// so for now we only consider nodes that operate on floating point numbers
bool hasFloatType(Node * node) {
  if(!node->hasType()) {
    return false;
  }
  if(auto tt = node->type()->cast<TensorType>()) {
    // BUG FIX: was `!=`, which inverted the predicate and reported every
    // non-float tensor as float (and floats as non-float)
    return tt->scalarType() == at::kFloat;
  } else {
    return false;
  }
}
// Checks that every input and every output of `node` is a float tensor.
bool allFloatIO(Node * node) {
  for(auto & value : node->inputs()) {
    if(!hasFloatType(value))
      return false;
  }
  for(auto & value : node->outputs()) {
    if(!hasFloatType(value))
      return false;
  }
  return true;
}
// A node may join a fusion group if it is typed, elementwise-mappable,
// operates purely on float tensors, and runs on a CUDA device.
// Existing fusion groups are themselves fusable.
bool isFusable(Node * node) {
  if (!node->hasType()) return false;
  if (node->kind() == kFusionGroup) return true;
  // BUG FIX: removed a stale duplicate return statement
  // (`... && !hasSpecialAlpha(node) && ...`) that made the line below
  // unreachable and reinstated the superseded alpha-based gating.
  return isSimpleMap(node) && allFloatIO(node) && isCuda(node);
}
// Can this node produce an _output_ of a fusion group?

View File

@ -121,6 +121,7 @@ static void fusionTests() {
auto p14 = appendNewNode(kmul,graph,{p20, i0});
auto p11 = appendNewNode(kmul,graph,{p22, p18});
auto o1 = appendNewNode(kadd,graph,{p14, p11});
o1->t_(kalpha, at::Scalar(1).toTensor());
auto p5 = appendNewNode(ktanh,graph,{o1});
auto o0 = appendNewNode(kmul,graph,{p16, p5});