[TensorExpr] Fix lowering for aten::div. (#48329)

Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/48329

Test Plan: Imported from OSS

Reviewed By: eellison

Differential Revision: D25130750

Pulled By: ZolotukhinM

fbshipit-source-id: 7c6345adcaec5f92cd6ce78b01f6a7d5923c0004
commit b967119906 (parent 5e1faa1d41)
Author: Mikhail Zolotukhin
Date: 2020-11-21 09:14:03 -08:00
Committed by: Facebook GitHub Bot

2 changed files with 49 additions and 8 deletions
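
For context (not part of this diff): aten::div is true division, so torch.div returns a floating-point result even when both inputs are integer tensors, and the fused TensorExpr kernel has to reproduce that. A minimal eager-mode sketch of the contract, assuming a recent PyTorch build:

import torch

# Eager reference behavior the NNC lowering must match:
# integer inputs are promoted and divided as floats.
a = torch.tensor([7, 3, -5], dtype=torch.int32)
b = torch.tensor([2, 4, 2], dtype=torch.int32)
out = torch.div(a, b)
print(out)        # tensor([ 3.5000,  0.7500, -2.5000])
print(out.dtype)  # torch.float32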

Changed file: test/test_jit_fuser_te.py

@@ -1308,6 +1308,7 @@ class TestTEFuser(JitTestCase):
torch.max,
lambda x, y: torch.lerp(x, y, 0.5),
torch.atan2,
torch.div,
# FIXME: comparison ops yield different results when fused
# torch.eq,
@@ -1316,12 +1317,13 @@ class TestTEFuser(JitTestCase):
# torch.gt,
# torch.lt,
# TODO: test operators exercising division too
# FIXME: fails on CPU backend with int8
# torch.fmod,
# torch.remainder,
# FIXME: segfaults on CPU backend
# operator.__rshift__,
# operator.__lshift__,
# torch.div,
]
devices = self.devices
for dtype, op, device in product(dtypes, binary_ops, devices):
@@ -1358,7 +1360,7 @@ class TestTEFuser(JitTestCase):
torch.float16,
torch.float32,
torch.float64,
-# torch.bool intentionally not included
+torch.bool
]
binary_ops = [
operator.__and__,
@@ -1375,12 +1377,51 @@ class TestTEFuser(JitTestCase):
# torch.lt,
# torch.gt,
# FIXME: fails with integer dtype and scalar={3,0}
# torch.div,
# FIXME: segfaults on CPU backend
# operator.__rshift__,
# operator.__lshift__,
]
devices = self.devices
# Maybe we should split this into separate tests to speed it up by
# only using scalar values relevant to particular ops
scalars = [1.5, 3, 0, -2.0, -1]
for dtype, op, device, scalar in product(dtypes, binary_ops, devices, scalars):
try:
x = self.data_for(dtype, device)
fn = apply_with_scalar(op, scalar)
ref = fn(x)
except Exception:
# If eager mode doesn't support a dtype/op/device combo,
# neither does the fuser. Catch everything to avoid needing to
# guess what errors might be thrown by eager.
continue
try:
t = torch.jit.trace(fn, (x))
self.assertEqual(ref, t(x))
self.assertAllFused(t.graph_for(x))
except Exception as e:
raise RuntimeError(
" ".join(["Failed:", str(dtype), op.__name__, device])
)
def test_binary_div_ops(self):
def apply_with_scalar(fn, scalar):
return lambda x: fn(x, scalar)
dtypes = [
torch.int8,
torch.uint8,
torch.int16,
torch.int32,
torch.int64,
# FIXME: breaks in IR eval
# torch.float16,
torch.float32,
torch.float64,
torch.bool
]
binary_ops = [
torch.div,
# FIXME: wrong results with int8 on cpu
# torch.remainder,
@@ -1389,7 +1430,7 @@ class TestTEFuser(JitTestCase):
devices = self.devices
# Maybe we should split this into separate tests to speed it up by
# only using scalar values relevant to particular ops
-scalars = [1.5, 3, 0, -2.0, -1]
+scalars = [1.5, 3, -2.0, -1] # skip 0
for dtype, op, device, scalar in product(dtypes, binary_ops, devices, scalars):
try:
x = self.data_for(dtype, device)
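
As a side note (not from the PR), the new test_binary_div_ops follows the same recipe as test_binary_ops: bind the scalar into a one-argument closure, compute the eager reference, then trace and compare. A hedged, stand-alone sketch of that pattern; scalar 0 is deliberately left out, matching the "# skip 0" above:

import torch

def apply_with_scalar(fn, scalar):
    # Bind the scalar so the traced function takes a single tensor input.
    return lambda x: fn(x, scalar)

x = torch.arange(-4, 4, dtype=torch.int32)
fn = apply_with_scalar(torch.div, 3)      # scalar 0 is skipped, as in the test

ref = fn(x)                               # eager reference
traced = torch.jit.trace(fn, (x,))        # traced graph, eligible for TE fusion
assert torch.allclose(traced(x), ref)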

Changed file: torch/csrc/jit/tensorexpr/kernel.cpp

@@ -780,7 +780,7 @@ Tensor* TensorExprKernel::computeValue(const torch::jit::Value* v) {
case aten::div: {
return computeTwoOperand(
"aten_div", v, [](const ExprHandle& lhs, const ExprHandle& rhs) {
-return boolToInteger(lhs) / boolToInteger(rhs);
+return promoteIntegerToFloat(lhs) / promoteIntegerToFloat(rhs);
});
} break;
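
A small end-to-end check of what this one-line lowering change buys (a sketch under the assumption that the TE fuser picks the graph up; whether it actually fuses depends on executor/fuser settings):

import torch

def f(x, y):
    return torch.div(x, y)

x = torch.tensor([7, -7], dtype=torch.int32)
y = torch.tensor([2, 2], dtype=torch.int32)

scripted = torch.jit.script(f)
for _ in range(3):          # warm-up runs so the profiling executor can optimize/fuse
    out = scripted(x, y)

# With promoteIntegerToFloat the kernel divides as floats, matching eager:
assert out.dtype == torch.float32
assert torch.allclose(out, torch.div(x, y))   # tensor([ 3.5000, -3.5000])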