mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
[te] Fix pow (#48213)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/48213 it was completely broken unless rhs was a constant. Test Plan: new unit test in test_jit_fuser_te.py Reviewed By: eellison Differential Revision: D25071639 fbshipit-source-id: ef1010a9fd551db646b83adfaa961648a5c388ae
This commit is contained in:
parent
ed57f804fa
commit
6da26fe79b
|
|
@ -1366,6 +1366,25 @@ class TestTEFuser(JitTestCase):
|
||||||
b = torch.randint(-2, 2, (1, 64), device='cuda', dtype=torch.long)
|
b = torch.randint(-2, 2, (1, 64), device='cuda', dtype=torch.long)
|
||||||
script = self.checkScript(eager, (a, b))
|
script = self.checkScript(eager, (a, b))
|
||||||
|
|
||||||
|
def test_neg_pow(self):
|
||||||
|
def eager_tt(a: torch.Tensor, b: torch.Tensor):
|
||||||
|
return torch.neg(torch.pow(a, b))
|
||||||
|
|
||||||
|
def eager_ts(a: torch.Tensor, b: float):
|
||||||
|
return torch.neg(torch.pow(a, b))
|
||||||
|
|
||||||
|
def eager_st(a: float, b: torch.Tensor):
|
||||||
|
return torch.neg(torch.pow(a, b))
|
||||||
|
|
||||||
|
a = torch.rand(1, dtype=torch.float)
|
||||||
|
b = torch.rand(1, dtype=torch.float)
|
||||||
|
s = b.item()
|
||||||
|
script = self.checkScript(eager_tt, (a, b))
|
||||||
|
self.assertAllFused(script.graph_for(a, b))
|
||||||
|
script = self.checkScript(eager_ts, (a, s))
|
||||||
|
self.assertAllFused(script.graph_for(a, s))
|
||||||
|
script = self.checkScript(eager_st, (s, b))
|
||||||
|
self.assertAllFused(script.graph_for(s, b))
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
run_tests()
|
run_tests()
|
||||||
|
|
|
||||||
|
|
@ -1026,10 +1026,11 @@ Tensor* TensorExprKernel::computeValue(const torch::jit::Value* v) {
|
||||||
case aten::pow: {
|
case aten::pow: {
|
||||||
return computeTwoOperand(
|
return computeTwoOperand(
|
||||||
"aten_pow", v, [](const ExprHandle& lhs, const ExprHandle& rhs) {
|
"aten_pow", v, [](const ExprHandle& lhs, const ExprHandle& rhs) {
|
||||||
double val = 0;
|
if (!rhs.node()->isConstant()) {
|
||||||
if (rhs.node()->isConstant()) {
|
return pow(lhs, rhs);
|
||||||
val = immediateAs<double>(IRSimplifier::simplify(rhs.node()));
|
|
||||||
}
|
}
|
||||||
|
double val =
|
||||||
|
immediateAs<double>(IRSimplifier::simplify(rhs.node()));
|
||||||
|
|
||||||
if (val == 1.0f) {
|
if (val == 1.0f) {
|
||||||
return lhs;
|
return lhs;
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user