[te] Fix pow (#48213)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/48213

It was completely broken unless `rhs` was a constant.

Test Plan: new unit test in test_jit_fuser_te.py

Reviewed By: eellison

Differential Revision: D25071639

fbshipit-source-id: ef1010a9fd551db646b83adfaa961648a5c388ae
This commit is contained in:
Bert Maher 2020-11-18 22:42:25 -08:00 committed by Facebook GitHub Bot
parent ed57f804fa
commit 6da26fe79b
2 changed files with 23 additions and 3 deletions

View File

@ -1366,6 +1366,25 @@ class TestTEFuser(JitTestCase):
b = torch.randint(-2, 2, (1, 64), device='cuda', dtype=torch.long)
script = self.checkScript(eager, (a, b))
def test_neg_pow(self):
    """Check that neg(pow(...)) fuses for all three operand combinations:
    tensor**tensor, tensor**scalar, and scalar**tensor.
    """
    def eager_tt(a: torch.Tensor, b: torch.Tensor):
        return torch.neg(torch.pow(a, b))

    def eager_ts(a: torch.Tensor, b: float):
        return torch.neg(torch.pow(a, b))

    def eager_st(a: float, b: torch.Tensor):
        return torch.neg(torch.pow(a, b))

    a = torch.rand(1, dtype=torch.float)
    b = torch.rand(1, dtype=torch.float)
    s = b.item()

    # Each variant must both script correctly and produce a fully fused graph.
    for fn, args in ((eager_tt, (a, b)), (eager_ts, (a, s)), (eager_st, (s, b))):
        script = self.checkScript(fn, args)
        self.assertAllFused(script.graph_for(*args))
# Standard test-file entry point: run all tests in this module when executed directly.
if __name__ == '__main__':
run_tests()

View File

@ -1026,10 +1026,11 @@ Tensor* TensorExprKernel::computeValue(const torch::jit::Value* v) {
case aten::pow: {
return computeTwoOperand(
"aten_pow", v, [](const ExprHandle& lhs, const ExprHandle& rhs) {
double val = 0;
if (rhs.node()->isConstant()) {
val = immediateAs<double>(IRSimplifier::simplify(rhs.node()));
if (!rhs.node()->isConstant()) {
return pow(lhs, rhs);
}
double val =
immediateAs<double>(IRSimplifier::simplify(rhs.node()));
if (val == 1.0f) {
return lhs;