Mirror of https://github.com/zebrajr/pytorch.git, synced 2025-12-07 12:21:27 +01:00
[TensorExpr] Fix lowering for aten::div. (#48329)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/48329
Test Plan: Imported from OSS
Reviewed By: eellison
Differential Revision: D25130750
Pulled By: ZolotukhinM
fbshipit-source-id: 7c6345adcaec5f92cd6ce78b01f6a7d5923c0004
This commit is contained in:
parent 5e1faa1d41
commit b967119906
test/test_jit_fuser_te.py

@@ -1308,6 +1308,7 @@ class TestTEFuser(JitTestCase):
            torch.max,
            lambda x, y: torch.lerp(x, y, 0.5),
            torch.atan2,
+           torch.div,

            # FIXME: comparison ops yield different results when fused
            # torch.eq,
@@ -1316,12 +1317,13 @@ class TestTEFuser(JitTestCase):
            # torch.gt,
            # torch.lt,

            # TODO: test operators exercising division too
            # FIXME: fails on CPU backend with int8
            # torch.fmod,
            # torch.remainder,

            # FIXME: segfaults on CPU backend
            # operator.__rshift__,
            # operator.__lshift__,
            # torch.div,
        ]
        devices = self.devices
        for dtype, op, device in product(dtypes, binary_ops, devices):
@@ -1358,7 +1360,7 @@ class TestTEFuser(JitTestCase):
            torch.float16,
            torch.float32,
            torch.float64,
            # torch.bool intentionally not included
            torch.bool
        ]
        binary_ops = [
            operator.__and__,
@@ -1375,12 +1377,51 @@ class TestTEFuser(JitTestCase):
            # torch.lt,
            # torch.gt,

            # FIXME: fails with integer dtype and scalar={3,0}
            # torch.div,

            # FIXME: segfaults on CPU backend
            # operator.__rshift__,
            # operator.__lshift__,
        ]
        devices = self.devices
        # Maybe we should split this into separate tests to speed it up by
        # only using scalar values relevant to particular ops
        scalars = [1.5, 3, 0, -2.0, -1]
        for dtype, op, device, scalar in product(dtypes, binary_ops, devices, scalars):
            try:
                x = self.data_for(dtype, device)
                fn = apply_with_scalar(op, scalar)
                ref = fn(x)
            except Exception:
                # If eager mode doesn't support a dtype/op/device combo,
                # neither does the fuser. Catch everything to avoid needing to
                # guess what errors might be thrown by eager.
                continue
            try:
                t = torch.jit.trace(fn, (x))
                self.assertEqual(ref, t(x))
                self.assertAllFused(t.graph_for(x))
            except Exception as e:
                raise RuntimeError(
                    " ".join(["Failed:", str(dtype), op.__name__, device])
                )

    def test_binary_div_ops(self):
        def apply_with_scalar(fn, scalar):
            return lambda x: fn(x, scalar)

        dtypes = [
            torch.int8,
            torch.uint8,
            torch.int16,
            torch.int32,
            torch.int64,
            # FIXME: breaks in IR eval
            # torch.float16,
            torch.float32,
            torch.float64,
            torch.bool
        ]
        binary_ops = [
            torch.div,

            # FIXME: wrong results with int8 on cpu
            # torch.remainder,
@@ -1389,7 +1430,7 @@ class TestTEFuser(JitTestCase):
        devices = self.devices
        # Maybe we should split this into separate tests to speed it up by
        # only using scalar values relevant to particular ops
-       scalars = [1.5, 3, 0, -2.0, -1]
+       scalars = [1.5, 3, -2.0, -1]  # skip 0
        for dtype, op, device, scalar in product(dtypes, binary_ops, devices, scalars):
            try:
                x = self.data_for(dtype, device)
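Outside the test suite, the pattern these tests exercise can be approximated with a standalone script. This is a hedged sketch, not the suite's own code: a plain torch.arange tensor stands in for the data_for helper, torch.allclose replaces assertEqual, and the assertAllFused fusion check is omitted, so it only verifies that eager and traced execution of div-with-a-scalar agree numerically:

import torch
from itertools import product

def apply_with_scalar(fn, scalar):
    # Same shape as the test helper: bind the scalar, leave the tensor free.
    return lambda x: fn(x, scalar)

dtypes = [torch.int32, torch.int64, torch.float32, torch.float64]
scalars = [1.5, 3, -2.0, -1]  # 0 skipped, as in the updated test

for dtype, scalar in product(dtypes, scalars):
    fn = apply_with_scalar(torch.div, scalar)
    x = torch.arange(1, 9, dtype=dtype)    # stand-in for self.data_for(...)
    ref = fn(x)                            # eager reference
    traced = torch.jit.trace(fn, (x,))
    out = traced(x)                        # may be executed by the TE fuser
    assert torch.allclose(ref, out), (dtype, scalar)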
torch/csrc/jit/tensorexpr/kernel.cpp

@@ -780,7 +780,7 @@ Tensor* TensorExprKernel::computeValue(const torch::jit::Value* v) {
    case aten::div: {
      return computeTwoOperand(
          "aten_div", v, [](const ExprHandle& lhs, const ExprHandle& rhs) {
-           return boolToInteger(lhs) / boolToInteger(rhs);
+           return promoteIntegerToFloat(lhs) / promoteIntegerToFloat(rhs);
          });
    } break;
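This one-line lowering change is the fix itself; the test changes above enable torch.div in the fused binary-op coverage and add a dedicated test_binary_div_ops. Eager-mode torch.div performs true division and returns a floating-point result for integer inputs, so a lowering that only converted bool operands to integers could presumably leave the fused kernel doing integer-typed division; promoting integer operands to float makes the fused computation follow the eager rule. A minimal sketch of that eager behavior (plain PyTorch, nothing fuser-specific):

import torch

# torch.div (aten::div) is true division: integer inputs produce a
# floating-point result, which is what the fused lowering must reproduce.
a = torch.tensor([3, 7, 1], dtype=torch.int32)
b = torch.tensor([2, 2, 4], dtype=torch.int32)

out = torch.div(a, b)
print(out)        # tensor([1.5000, 3.5000, 0.2500])
print(out.dtype)  # torch.float32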