diff --git a/torch/csrc/jit/runtime/static/impl.cpp b/torch/csrc/jit/runtime/static/impl.cpp
index a81dc3dcce9..ecde9278bb8 100644
--- a/torch/csrc/jit/runtime/static/impl.cpp
+++ b/torch/csrc/jit/runtime/static/impl.cpp
@@ -736,6 +736,10 @@ ProcessedNode::ProcessedNode(
     std::ostringstream ss;
     node->print(ss, 0, nullptr, false);
     VLOG(1) << "Switch to native impl for node: " << ss.str();
+  } else {
+    std::ostringstream ss;
+    node->print(ss, 0, nullptr, false);
+    VLOG(1) << "Fallback interpreter for node: " << ss.str();
   }
 }
 
diff --git a/torch/csrc/jit/runtime/static/ops.cpp b/torch/csrc/jit/runtime/static/ops.cpp
index 016f9eacbbd..d7d3e77f754 100644
--- a/torch/csrc/jit/runtime/static/ops.cpp
+++ b/torch/csrc/jit/runtime/static/ops.cpp
@@ -19,8 +19,9 @@ bool canRunOutOfPlace(Node* n) {
       "aten::bmm",
       "aten::cat",
       "aten::clamp",
+      "aten::clone",
       "aten::leaky_relu",
       "aten::logit",
       "aten::mul",
       "aten::nan_to_num",
       "aten::relu",
@@ -37,12 +38,24 @@ bool canRunNatively(Node* n) {
   // In alphabetical order
   const static std::unordered_set native_nodes{
       "aten::flatten",
+      "aten::permute",
+      "aten::reshape",
+      "aten::slice",
+      "aten::to",
       "aten::transpose",
       "prim::ListConstruct",
       "prim::ListUnpack",
       "prim::TupleConstruct"};
   auto str = std::string(n->kind().toQualString());
-  return native_nodes.count(str) > 0;
+  if (!native_nodes.count(str)) {
+    return false;
+  }
+  // The native aten::to implementation reads exactly five inputs, so any
+  // other arity must not be routed to it.
+  if (str == "aten::to") {
+    return n->inputs().size() == 5;
+  }
+  return true;
 }
 
 std::function&)>
@@ -217,7 +230,11 @@ getOutOfPlaceOperation(Node* n) {
   } else if (n->kind() == c10::Symbol::fromQualString("aten::logit")) {
     return [](const ProcessedNode* p_node, std::vector& reg) {
       auto in0_t = p_node->Input(0, reg).toTensor();
-      auto in1_d = p_node->Input(1, reg).toDouble();
+      // NOTE(review): -1.0 is substituted when the node has no second input;
+      // presumably a "not provided" sentinel -- confirm with in1_d's consumer.
+      double in1_d = p_node->input_regs().size() > 1
+          ? p_node->Input(1, reg).toDouble()
+          : -1.0;
       if (p_node->Output(0, reg).isNone()) {
         p_node->Output(0, reg) = create_empty_from(in0_t);
       }
@@ -310,6 +327,45 @@ getNativeOperation(Node* n) {
         p_node->Output(i, reg) = std::move(stack[i]);
       }
     };
+  } else if (n->kind() == c10::Symbol::fromQualString("aten::permute")) {
+    return [](const ProcessedNode* p_node, std::vector& reg) {
+      auto in0_t = p_node->Input(0, reg).toTensor();
+      auto in1_iv = p_node->Input(1, reg).toIntVector();
+      p_node->Output(0, reg) = at::native::permute(in0_t, in1_iv);
+    };
+  } else if (n->kind() == c10::Symbol::fromQualString("aten::reshape")) {
+    return [](const ProcessedNode* p_node, std::vector& reg) {
+      auto in0_t = p_node->Input(0, reg).toTensor();
+      auto in1_iv = p_node->Input(1, reg).toIntVector();
+      p_node->Output(0, reg) = at::native::reshape(in0_t, in1_iv);
+    };
+  } else if (n->kind() == c10::Symbol::fromQualString("aten::slice")) {
+    return [](const ProcessedNode* p_node, std::vector& reg) {
+      auto in0_t = p_node->Input(0, reg).toTensor();
+      auto in1_i = p_node->Input(1, reg).toInt();
+      auto in2_i = p_node->Input(2, reg).toInt();
+      auto in3_i = p_node->Input(3, reg).toInt();
+      auto in4_i = p_node->Input(4, reg).toInt();
+      p_node->Output(0, reg) =
+          at::native::slice(in0_t, in1_i, in2_i, in3_i, in4_i);
+    };
+  } else if (n->kind() == c10::Symbol::fromQualString("aten::to")) {
+    return [](const ProcessedNode* p_node, std::vector& reg) {
+      // Debug-only guard: canRunNatively admits just the five-input form.
+      DCHECK(p_node->input_regs().size() == 5);
+      auto in0_t = p_node->Input(0, reg).toTensor();
+      auto in1_i = p_node->Input(1, reg).toScalarType();
+      auto in2_i = p_node->Input(2, reg).toBool();
+      auto in3_i = p_node->Input(3, reg).toBool();
+      if (p_node->Input(4, reg).isNone()) {
+        p_node->Output(0, reg) =
+            at::native::to(in0_t, in1_i, in2_i, in3_i, c10::nullopt);
+      } else {
+        auto in4_o = p_node->Input(4, reg).toMemoryFormat();
+        p_node->Output(0, reg) =
+            at::native::to(in0_t, in1_i, in2_i, in3_i, in4_o);
+      }
+    };
   }
   return [](const ProcessedNode*, std::vector&) { TORCH_CHECK(0); };
 }