diff --git a/test/test_jit_fuser_te.py b/test/test_jit_fuser_te.py index d690d994224..f381759b4cf 100644 --- a/test/test_jit_fuser_te.py +++ b/test/test_jit_fuser_te.py @@ -1366,6 +1366,25 @@ class TestTEFuser(JitTestCase): b = torch.randint(-2, 2, (1, 64), device='cuda', dtype=torch.long) script = self.checkScript(eager, (a, b)) + def test_neg_pow(self): + def eager_tt(a: torch.Tensor, b: torch.Tensor): + return torch.neg(torch.pow(a, b)) + + def eager_ts(a: torch.Tensor, b: float): + return torch.neg(torch.pow(a, b)) + + def eager_st(a: float, b: torch.Tensor): + return torch.neg(torch.pow(a, b)) + + a = torch.rand(1, dtype=torch.float) + b = torch.rand(1, dtype=torch.float) + s = b.item() + script = self.checkScript(eager_tt, (a, b)) + self.assertAllFused(script.graph_for(a, b)) + script = self.checkScript(eager_ts, (a, s)) + self.assertAllFused(script.graph_for(a, s)) + script = self.checkScript(eager_st, (s, b)) + self.assertAllFused(script.graph_for(s, b)) if __name__ == '__main__': run_tests() diff --git a/torch/csrc/jit/tensorexpr/kernel.cpp b/torch/csrc/jit/tensorexpr/kernel.cpp index 66b81039f33..3311f6338ca 100644 --- a/torch/csrc/jit/tensorexpr/kernel.cpp +++ b/torch/csrc/jit/tensorexpr/kernel.cpp @@ -1026,10 +1026,11 @@ Tensor* TensorExprKernel::computeValue(const torch::jit::Value* v) { case aten::pow: { return computeTwoOperand( "aten_pow", v, [](const ExprHandle& lhs, const ExprHandle& rhs) { - double val = 0; - if (rhs.node()->isConstant()) { - val = immediateAs(IRSimplifier::simplify(rhs.node())); + if (!rhs.node()->isConstant()) { + return pow(lhs, rhs); } + double val = + immediateAs(IRSimplifier::simplify(rhs.node())); if (val == 1.0f) { return lhs;