Add more CPU tests (#47369)

Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/47369

Test Plan: Imported from OSS

Reviewed By: ansley

Differential Revision: D24805251

Pulled By: eellison

fbshipit-source-id: f1a8210ffdc3cc88354cb4896652151d83a0345a
Elias Ellison 2020-11-12 11:06:50 -08:00 committed by Facebook GitHub Bot
parent b8a1070ec0
commit fe81faee5f
3 changed files with 173 additions and 169 deletions
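The test-file hunks below convert a set of CUDA-only TE fuser tests (previously gated on `RUN_CUDA`) into device-parametrized tests that loop over `self.devices`, so the TensorExpr fuser is also exercised on CPU. A minimal sketch of the pattern, assuming `self.devices` is populated the same way as the inline logic the diff removes ("cpu" always, "cuda" when available); the class name and test body here are illustrative, not part of the PR:

```python
import torch
from torch.testing._internal.jit_utils import JitTestCase

class DeviceLoopSketch(JitTestCase):
    def setUp(self):
        super().setUp()
        # Assumed to mirror the removed per-test logic: CPU always, CUDA when present.
        self.devices = ['cpu'] if not torch.cuda.is_available() else ['cpu', 'cuda']

    def test_add_mul(self):
        def f(x, y):
            return x * y + y

        # One test now covers every available device instead of being skipped
        # outright on machines without CUDA.
        for device in self.devices:
            inputs = [torch.randn(4, 4, device=device), torch.randn(4, 4, device=device)]
            self.checkScript(f, inputs)
```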

View File

@@ -170,18 +170,17 @@ class TestTEFuser(JitTestCase):
         traced_f = torch.jit.trace(f, (x, y,))
         self.assertEqual(traced_f(x.t().contiguous(), y), traced_f(x.t(), y))

-    @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
-    def test_broadcast_cuda(self):
+    def test_broadcast(self):
+        for device in self.devices:
             def scaleshift(x, scale, shift):
                 return x * scale + shift

             inputs = [
-                torch.randn(4, 4, dtype=torch.float, device='cuda'),
-                torch.randn(4, dtype=torch.float, device='cuda'),
-                torch.randn(4, dtype=torch.float, device='cuda'),
+                torch.randn(4, 4, dtype=torch.float, device=device),
+                torch.randn(4, dtype=torch.float, device=device),
+                torch.randn(4, dtype=torch.float, device=device),
             ]
             self.checkScript(scaleshift, inputs)
-            self.assertLastGraphAllFused()

     @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
     @unittest.skipIf(not RUN_CUDA_HALF, "no half support")
@@ -219,8 +218,8 @@
         grads_half = [t.half() for t in grads]
         self.assertEqual(grads_half, fusion_grads)

-    @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
     def test_checks_cat_inputs(self):
+        for device in self.devices:
             # We shouldn't treat cat nodes as broadcasting. All their inputs
             # need to be checked for having the same map size, before we can
             # run the kernel.
@@ -229,20 +228,20 @@
             # NOTE: y is broadcastable to x, but output of f(x, y) should have
             # shape 3x4, and not 4x4.
-            x = torch.randn(2, 4, dtype=torch.float, device='cuda')
-            y = torch.randn(1, 4, dtype=torch.float, device='cuda')
+            x = torch.randn(2, 4, dtype=torch.float, device=device)
+            y = torch.randn(1, 4, dtype=torch.float, device=device)
             scripted = self.checkScript(f, (x, y))
             self.assertEqual(scripted(x, y).shape, (3, 4))
             self.assertAllFused(scripted.graph_for(x, y))

-    @unittest.skipIf(not RUN_CUDA, "No CUDA")
-    def test_chunk_cuda(self):
+    def test_chunk(self):
+        for device in self.devices:
             def fn(x):
                 a, b, c = x.chunk(3, 1)
                 return a * b + c

-            inputs = [torch.randn(10, 6, dtype=torch.float, device='cuda')]
+            inputs = [torch.randn(10, 6, dtype=torch.float, device=device)]
             self.checkScript(fn, inputs)
             self.assertLastGraphAllFused()
@@ -303,8 +302,8 @@
             "ConstantChunk", 1, exactly=True
         ).run(str(graph))

-    @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
     def test_chunk_motion_deduplicates_inputs(self):
+        for device in self.devices:
             def func1(x):
                 z = x * x
                 z0, z1 = z.chunk(2)
@@ -316,14 +315,14 @@
                 return z0 * z1

             inputs = [
-                torch.tensor([1.1, 1.2], device='cuda', dtype=torch.float),
+                torch.tensor([1.1, 1.2], device=device, dtype=torch.float),
             ]
             for func in [func1, func2]:
                 self.checkScript(func, inputs)
                 self.assertLastGraphAllFused()

-    @unittest.skipIf(not RUN_CUDA, "No CUDA")
-    def test_chunk_multiple_cuda(self):
+    def test_chunk_multiple(self):
+        for device in self.devices:
             # The arguments are intentionally used out of order as a test to see
             # if the fusion compiler adds extra args in the correct order
             def fn(s, x, y, z):
@@ -333,16 +332,17 @@
                 return s + x1 + x2 + x3 + y1 + y2 + z1 + z2

             inputs = [
-                torch.randn(5, 2, 3, dtype=torch.float, device='cuda'),
-                torch.randn(5, 6, 3, dtype=torch.float, device='cuda'),
-                torch.randn(10, 2, 3, dtype=torch.float, device='cuda'),
-                torch.randn(5, 2, 6, dtype=torch.float, device='cuda'),
+                torch.randn(5, 2, 3, dtype=torch.float, device=device),
+                torch.randn(5, 6, 3, dtype=torch.float, device=device),
+                torch.randn(10, 2, 3, dtype=torch.float, device=device),
+                torch.randn(5, 2, 6, dtype=torch.float, device=device),
             ]
             ge = self.checkScript(fn, inputs)
             self.assertAllFused(ge.graph_for(*inputs))

     def test_minmax(self):
+        for device in self.devices:
             def tmax(a, b):
                 return torch.max(2 * a, b)
@@ -353,20 +353,16 @@
             b = torch.randn(4, 4, dtype=torch.float)
             nan = torch.tensor(float('nan'), dtype=torch.float)

-        devices = ["cpu"]
-        if torch.cuda.is_available():
-            devices.append("cuda")
             for f, inputs, device in product(
                     (tmax, tmin),
                     ([a, b], [a, nan], [b, nan]),
-                    devices):
+                    self.devices):
                 inputs = [t.to(device) for t in inputs]
                 s = self.checkScript(f, inputs)
                 self.assertAllFused(s.graph_for(*inputs))

-    # TODO: reenable the test after backwards passes start working in PE
-    @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
     def test_clamp(self):
+        for device in self.devices:
             def func2(a, b):
                 return torch.clamp(a + b, min=0, max=2)
@@ -382,9 +378,9 @@
             def funcOptMax(a, b):
                 return torch.clamp(a + b, min=0)

-            a = torch.randn(4, 4, dtype=torch.float, device='cuda', requires_grad=True)
-            b = torch.randn(4, 4, dtype=torch.float, device='cuda')
-            nan = torch.tensor(float('nan'), dtype=torch.float, device='cuda')
+            a = torch.randn(4, 4, dtype=torch.float, device=device, requires_grad=True)
+            b = torch.randn(4, 4, dtype=torch.float, device=device)
+            nan = torch.tensor(float('nan'), dtype=torch.float, device=device)
             funcs = (func2, funcInf, funcNegInf, funcOptMin, funcOptMax)
             for f, inputs in product(funcs, [[a, b], [a, nan]]):
@@ -425,31 +421,31 @@
         ge = self.checkTrace(f, (x, y, z), inputs_require_grads=False)
         self.assertAllFused(ge.graph_for(x, y, z))

-    @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
     def test_mul_bool(self):
+        for device in self.devices:
             def f(x, y, z):
                 return x * y * z

-            x = torch.randint(0, 2, (4, 4), dtype=torch.bool, device='cuda')
-            y = torch.randint(0, 2, (4, 4), dtype=torch.bool, device='cuda')
-            z = torch.randint(0, 2, (4, 4), dtype=torch.bool, device='cuda')
+            x = torch.randint(0, 2, (4, 4), dtype=torch.bool, device=device)
+            y = torch.randint(0, 2, (4, 4), dtype=torch.bool, device=device)
+            z = torch.randint(0, 2, (4, 4), dtype=torch.bool, device=device)
             ge = self.checkTrace(f, (x, y, z), inputs_require_grads=False)
             self.assertAllFused(ge.graph_for(x, y, z))

-    @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
     def test_div_bool(self):
+        for device in self.devices:
             def f(x, y, z):
                 return (x + y) / z

-            x = torch.randint(0, 2, (4, 4), dtype=torch.bool, device='cuda')
-            y = torch.randint(0, 2, (4, 4), dtype=torch.bool, device='cuda')
-            z = torch.ones_like(x, dtype=torch.bool, device='cuda')
+            x = torch.randint(0, 2, (4, 4), dtype=torch.bool, device=device)
+            y = torch.randint(0, 2, (4, 4), dtype=torch.bool, device=device)
+            z = torch.ones_like(x, dtype=torch.bool, device=device)
             ge = self.checkTrace(f, (x, y, z), inputs_require_grads=False)
             self.assertAllFused(ge.graph_for(x, y, z))

-    @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
+    @unittest.skipIf(not torch._C._llvm_enabled(), "TODO: bugs in ir eval")
     def test_bitwise_ops(self):
         def apply(fn):
             return lambda x, y, z: fn(fn(x, y), z)
@@ -467,7 +463,7 @@
             operator.__or__,
             operator.__xor__
         ]
-        devices = ["cuda"]
+        devices = self.devices
        for dtype, op, device in product(dtypes, binary_ops, devices):
            try:
                x = self.data_for(dtype, device)
@@ -528,8 +524,8 @@
                     " ".join(["Failed:", str(dtype), op.__name__, device])
                 )

-    @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
     def test_comparison_eq_ne(self):
+        for device in self.devices:
             def f(x, y):
                 mask = (x == 0).type_as(x)
                 z = x * mask + y
@@ -537,8 +533,8 @@
                 z = z * mask + y
                 return z

-            x = torch.randn(4, 4, dtype=torch.float, device='cuda')
-            y = torch.randn(4, 4, dtype=torch.float, device='cuda')
+            x = torch.randn(4, 4, dtype=torch.float, device=device)
+            y = torch.randn(4, 4, dtype=torch.float, device=device)
             ge = self.checkTrace(f, (x, y))
             self.assertAllFused(ge.graph_for(x, y))
@@ -551,16 +547,16 @@
         z = z * mask + y
         return z

-    @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
-    def test_comparison_gt_lt_cuda(self):
-        x = torch.randn(4, 4, dtype=torch.float, device='cuda')
-        y = torch.randn(4, 4, dtype=torch.float, device='cuda')
+    def test_comparison_gt_lt(self):
+        for device in self.devices:
+            x = torch.randn(4, 4, dtype=torch.float, device=device)
+            y = torch.randn(4, 4, dtype=torch.float, device=device)
             ge = self.checkTrace(self.fn_test_comparison_gt_lt, (x, y))
             self.assertAllFused(ge.graph_for(x, y))

-    @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
-    def test_comparison_ge_le_cuda(self):
+    def test_comparison_ge_le(self):
+        for device in self.devices:
             def f(x, y):
                 mask = (x >= 0).type_as(x)
                 z = x * mask + y
@@ -568,8 +564,8 @@
                 z = z * mask + y
                 return z

-            x = torch.randn(4, 4, dtype=torch.float, device='cuda')
-            y = torch.randn(4, 4, dtype=torch.float, device='cuda')
+            x = torch.randn(4, 4, dtype=torch.float, device=device)
+            y = torch.randn(4, 4, dtype=torch.float, device=device)
             ge = self.checkTrace(f, (x, y))
             self.assertAllFused(ge.graph_for(x, y))
@@ -578,11 +574,11 @@
             self.assertAllFused(ge.graph_for(x, y), except_for=("aten::size", "prim::BroadcastSizes",
                                                                 "aten::_size_if_not_equal"))

-    @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
-    def test_addcmul_cuda(self):
-        t = torch.randn(1, 4, dtype=torch.float, device='cuda')
-        t1 = torch.randn(4, 1, dtype=torch.float, device='cuda')
-        t2 = torch.randn(1, 4, dtype=torch.float, device='cuda')
+    def test_addcmul(self):
+        for device in self.devices:
+            t = torch.randn(1, 4, dtype=torch.float, device=device)
+            t1 = torch.randn(4, 1, dtype=torch.float, device=device)
+            t2 = torch.randn(1, 4, dtype=torch.float, device=device)

             def foo(t, t1, t2):
                 return t.addcmul(t + 1, t2, value=0.1)

View File

@@ -184,6 +184,7 @@ def _jit_can_fuse_on_cpu() -> _bool: ...
 def _jit_can_fuse_on_gpu() -> _bool: ...
 def _jit_texpr_fuser_enabled() -> _bool: ...
 def _jit_nvfuser_enabled() -> _bool: ...
+def _llvm_enabled() -> _bool: ...
 def _jit_override_can_fuse_on_cpu(override: _bool): ...
 def _jit_override_can_fuse_on_gpu(override: _bool): ...
 def _jit_set_texpr_fuser_enabled(enable: _bool): ...

View File

@@ -629,6 +629,13 @@ void initJITBindings(PyObject* module) {
             using namespace torch::jit::tensorexpr;
             getTEMustUseLLVMOnCPU() = use_llvm;
           })
+      .def("_llvm_enabled", []() {
+#ifdef TORCH_ENABLE_LLVM
+        return true;
+#else
+        return false;
+#endif
+      })
       .def(
           "_jit_pass_fuse_tensorexprs",
           [](std::shared_ptr<Graph>& g) { return FuseTensorExprs(g); })
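The new binding simply reports whether PyTorch was built with TORCH_ENABLE_LLVM, and the stub added above exposes it to type checkers as torch._C._llvm_enabled(). Its intended use is the decorator change in the test file: test_bitwise_ops is now gated on LLVM support rather than CUDA, since the skip message notes the non-LLVM IR evaluator still has bugs for these ops. A hedged sketch of that usage; the test class and body here are illustrative, not part of the PR:

```python
import unittest
import torch

class LLVMGatedTests(unittest.TestCase):
    # Mirrors the decorator change above: skip when PyTorch was built
    # without TORCH_ENABLE_LLVM, instead of skipping when CUDA is absent.
    @unittest.skipIf(not torch._C._llvm_enabled(), "TODO: bugs in ir eval")
    def test_bitwise_like(self):
        x = torch.randint(0, 8, (4, 4), dtype=torch.int32)
        self.assertTrue(bool(((x & x) == x).all()))
```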