From fe81faee5f65dfd3c015c7337729f70b51bba10e Mon Sep 17 00:00:00 2001
From: Elias Ellison
Date: Thu, 12 Nov 2020 11:06:50 -0800
Subject: [PATCH] Add more CPU tests (#47369)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/47369

Test Plan: Imported from OSS

Reviewed By: ansley

Differential Revision: D24805251

Pulled By: eellison

fbshipit-source-id: f1a8210ffdc3cc88354cb4896652151d83a0345a
---
 test/test_jit_fuser_te.py      | 334 ++++++++++++++++-----------------
 torch/_C/__init__.pyi.in       |   1 +
 torch/csrc/jit/python/init.cpp |   7 +
 3 files changed, 173 insertions(+), 169 deletions(-)

diff --git a/test/test_jit_fuser_te.py b/test/test_jit_fuser_te.py
index 7392db3a757..bbefd1fc3ab 100644
--- a/test/test_jit_fuser_te.py
+++ b/test/test_jit_fuser_te.py
@@ -170,18 +170,17 @@ class TestTEFuser(JitTestCase):
         traced_f = torch.jit.trace(f, (x, y,))
         self.assertEqual(traced_f(x.t().contiguous(), y), traced_f(x.t(), y))
 
-    @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
-    def test_broadcast_cuda(self):
-        def scaleshift(x, scale, shift):
-            return x * scale + shift
+    def test_broadcast(self):
+        for device in self.devices:
+            def scaleshift(x, scale, shift):
+                return x * scale + shift
 
-        inputs = [
-            torch.randn(4, 4, dtype=torch.float, device='cuda'),
-            torch.randn(4, dtype=torch.float, device='cuda'),
-            torch.randn(4, dtype=torch.float, device='cuda'),
-        ]
-        self.checkScript(scaleshift, inputs)
-        self.assertLastGraphAllFused()
+            inputs = [
+                torch.randn(4, 4, dtype=torch.float, device=device),
+                torch.randn(4, dtype=torch.float, device=device),
+                torch.randn(4, dtype=torch.float, device=device),
+            ]
+            self.checkScript(scaleshift, inputs)
 
     @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
     @unittest.skipIf(not RUN_CUDA_HALF, "no half support")
@@ -219,33 +218,33 @@ class TestTEFuser(JitTestCase):
         grads_half = [t.half() for t in grads]
         self.assertEqual(grads_half, fusion_grads)
 
-    @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
     def test_checks_cat_inputs(self):
-        # We shouldn't treat cat nodes as broadcasting. All their inputs
-        # need to be checked for having the same map size, before we can
-        # run the kernel.
-        def f(x, y):
-            return torch.cat([x + 2 * x + x ** 2, y + 4 * y + y ** 3], dim=0)
+        for device in self.devices:
+            # We shouldn't treat cat nodes as broadcasting. All their inputs
+            # need to be checked for having the same map size, before we can
+            # run the kernel.
+            def f(x, y):
+                return torch.cat([x + 2 * x + x ** 2, y + 4 * y + y ** 3], dim=0)
 
-        # NOTE: y is broadcastable to x, but output of f(x, y) should have
-        # shape 3x4, and not 4x4.
-        x = torch.randn(2, 4, dtype=torch.float, device='cuda')
-        y = torch.randn(1, 4, dtype=torch.float, device='cuda')
+            # NOTE: y is broadcastable to x, but output of f(x, y) should have
+            # shape 3x4, and not 4x4.
+            x = torch.randn(2, 4, dtype=torch.float, device=device)
+            y = torch.randn(1, 4, dtype=torch.float, device=device)
 
-        scripted = self.checkScript(f, (x, y))
-        self.assertEqual(scripted(x, y).shape, (3, 4))
-        self.assertAllFused(scripted.graph_for(x, y))
+            scripted = self.checkScript(f, (x, y))
+            self.assertEqual(scripted(x, y).shape, (3, 4))
+            self.assertAllFused(scripted.graph_for(x, y))
 
-    @unittest.skipIf(not RUN_CUDA, "No CUDA")
-    def test_chunk_cuda(self):
-        def fn(x):
-            a, b, c = x.chunk(3, 1)
-            return a * b + c
+    def test_chunk(self):
+        for device in self.devices:
+            def fn(x):
+                a, b, c = x.chunk(3, 1)
+                return a * b + c
 
-        inputs = [torch.randn(10, 6, dtype=torch.float, device='cuda')]
+            inputs = [torch.randn(10, 6, dtype=torch.float, device=device)]
 
-        self.checkScript(fn, inputs)
-        self.assertLastGraphAllFused()
+            self.checkScript(fn, inputs)
+            self.assertLastGraphAllFused()
 
     @staticmethod
     def _test_chunk_correctness(self, device='cpu'):
@@ -303,99 +302,96 @@ class TestTEFuser(JitTestCase):
             "ConstantChunk", 1, exactly=True
         ).run(str(graph))
 
-    @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
     def test_chunk_motion_deduplicates_inputs(self):
-        def func1(x):
-            z = x * x
-            z0, z1 = z.chunk(2)
-            return z0 * z1
+        for device in self.devices:
+            def func1(x):
+                z = x * x
+                z0, z1 = z.chunk(2)
+                return z0 * z1
 
-        def func2(x):
-            z = x * x * x
-            z0, z1 = z.chunk(2)
-            return z0 * z1
+            def func2(x):
+                z = x * x * x
+                z0, z1 = z.chunk(2)
+                return z0 * z1
 
-        inputs = [
-            torch.tensor([1.1, 1.2], device='cuda', dtype=torch.float),
-        ]
-        for func in [func1, func2]:
-            self.checkScript(func, inputs)
-            self.assertLastGraphAllFused()
+            inputs = [
+                torch.tensor([1.1, 1.2], device=device, dtype=torch.float),
+            ]
+            for func in [func1, func2]:
+                self.checkScript(func, inputs)
+                self.assertLastGraphAllFused()
 
-    @unittest.skipIf(not RUN_CUDA, "No CUDA")
-    def test_chunk_multiple_cuda(self):
-        # The arguments are intentionally used out of order as a test to see
-        # if the fusion compiler adds extra args in the correct order
-        def fn(s, x, y, z):
-            z1, z2 = z.chunk(2, 2)
-            x1, x2, x3 = x.chunk(3, 1)
-            y1, y2 = y.chunk(2, 0)
-            return s + x1 + x2 + x3 + y1 + y2 + z1 + z2
+    def test_chunk_multiple(self):
+        for device in self.devices:
+            # The arguments are intentionally used out of order as a test to see
+            # if the fusion compiler adds extra args in the correct order
+            def fn(s, x, y, z):
+                z1, z2 = z.chunk(2, 2)
+                x1, x2, x3 = x.chunk(3, 1)
+                y1, y2 = y.chunk(2, 0)
+                return s + x1 + x2 + x3 + y1 + y2 + z1 + z2
 
-        inputs = [
-            torch.randn(5, 2, 3, dtype=torch.float, device='cuda'),
-            torch.randn(5, 6, 3, dtype=torch.float, device='cuda'),
-            torch.randn(10, 2, 3, dtype=torch.float, device='cuda'),
-            torch.randn(5, 2, 6, dtype=torch.float, device='cuda'),
-        ]
+            inputs = [
+                torch.randn(5, 2, 3, dtype=torch.float, device=device),
+                torch.randn(5, 6, 3, dtype=torch.float, device=device),
+                torch.randn(10, 2, 3, dtype=torch.float, device=device),
+                torch.randn(5, 2, 6, dtype=torch.float, device=device),
+            ]
 
-        ge = self.checkScript(fn, inputs)
-        self.assertAllFused(ge.graph_for(*inputs))
+            ge = self.checkScript(fn, inputs)
+            self.assertAllFused(ge.graph_for(*inputs))
 
     def test_minmax(self):
-        def tmax(a, b):
-            return torch.max(2 * a, b)
+        for device in self.devices:
+            def tmax(a, b):
+                return torch.max(2 * a, b)
 
-        def tmin(a, b):
-            return torch.min(2 * a, b)
+            def tmin(a, b):
+                return torch.min(2 * a, b)
 
-        a = torch.randn(4, 4, dtype=torch.float)
-        b = torch.randn(4, 4, dtype=torch.float)
-        nan = torch.tensor(float('nan'), dtype=torch.float)
+            a = torch.randn(4, 4, dtype=torch.float)
+            b = torch.randn(4, 4, dtype=torch.float)
+            nan = torch.tensor(float('nan'), dtype=torch.float)
 
-        devices = ["cpu"]
-        if torch.cuda.is_available():
-            devices.append("cuda")
-        for f, inputs, device in product(
-            (tmax, tmin),
-            ([a, b], [a, nan], [b, nan]),
-            devices):
-            inputs = [t.to(device) for t in inputs]
-            s = self.checkScript(f, inputs)
-            self.assertAllFused(s.graph_for(*inputs))
+            for f, inputs, device in product(
+                (tmax, tmin),
+                ([a, b], [a, nan], [b, nan]),
+                self.devices):
+                inputs = [t.to(device) for t in inputs]
+                s = self.checkScript(f, inputs)
+                self.assertAllFused(s.graph_for(*inputs))
 
-    # TODO: reenable the test after backwards passes start working in PE
-    @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
     def test_clamp(self):
-        def func2(a, b):
-            return torch.clamp(a + b, min=0, max=2)
+        for device in self.devices:
+            def func2(a, b):
+                return torch.clamp(a + b, min=0, max=2)
 
-        def funcInf(a, b):
-            return torch.clamp(a + b, min=0, max=float('inf'))
+            def funcInf(a, b):
+                return torch.clamp(a + b, min=0, max=float('inf'))
 
-        def funcNegInf(a, b):
-            return torch.clamp(a + b, min=float('-inf'), max=0)
+            def funcNegInf(a, b):
+                return torch.clamp(a + b, min=float('-inf'), max=0)
 
-        def funcOptMin(a, b):
-            return torch.clamp(a + b, max=2)
+            def funcOptMin(a, b):
+                return torch.clamp(a + b, max=2)
 
-        def funcOptMax(a, b):
-            return torch.clamp(a + b, min=0)
+            def funcOptMax(a, b):
+                return torch.clamp(a + b, min=0)
 
-        a = torch.randn(4, 4, dtype=torch.float, device='cuda', requires_grad=True)
-        b = torch.randn(4, 4, dtype=torch.float, device='cuda')
-        nan = torch.tensor(float('nan'), dtype=torch.float, device='cuda')
+            a = torch.randn(4, 4, dtype=torch.float, device=device, requires_grad=True)
+            b = torch.randn(4, 4, dtype=torch.float, device=device)
+            nan = torch.tensor(float('nan'), dtype=torch.float, device=device)
 
-        funcs = (func2, funcInf, funcNegInf, funcOptMin, funcOptMax)
-        for f, inputs in product(funcs, [[a, b], [a, nan]]):
-            inp1, inp2 = inputs
-            s = self.checkScript(f, (inp1, inp2), profiling=ProfilingMode.PROFILING)
-            self.assertAllFused(s.graph_for(inp1, inp2), except_for={'aten::size', 'aten::_size_if_not_equal'})
-            c = s(inp1, inp2)
-            with enable_profiling_mode_for_profiling_tests():
-                warmup_backward(c.sum())
-                graph = backward_graph(s)
-                self.assertAllFused(graph, except_for={'aten::Float', 'aten::_grad_sum_to_size'})
+            funcs = (func2, funcInf, funcNegInf, funcOptMin, funcOptMax)
+            for f, inputs in product(funcs, [[a, b], [a, nan]]):
+                inp1, inp2 = inputs
+                s = self.checkScript(f, (inp1, inp2), profiling=ProfilingMode.PROFILING)
+                self.assertAllFused(s.graph_for(inp1, inp2), except_for={'aten::size', 'aten::_size_if_not_equal'})
+                c = s(inp1, inp2)
+                with enable_profiling_mode_for_profiling_tests():
+                    warmup_backward(c.sum())
+                    graph = backward_graph(s)
+                    self.assertAllFused(graph, except_for={'aten::Float', 'aten::_grad_sum_to_size'})
 
     @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
     @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.LEGACY, "no half support with profiling on")
@@ -425,31 +421,31 @@ class TestTEFuser(JitTestCase):
         ge = self.checkTrace(f, (x, y, z), inputs_require_grads=False)
         self.assertAllFused(ge.graph_for(x, y, z))
 
-    @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
     def test_mul_bool(self):
-        def f(x, y, z):
-            return x * y * z
+        for device in self.devices:
+            def f(x, y, z):
+                return x * y * z
 
-        x = torch.randint(0, 2, (4, 4), dtype=torch.bool, device='cuda')
-        y = torch.randint(0, 2, (4, 4), dtype=torch.bool, device='cuda')
-        z = torch.randint(0, 2, (4, 4), dtype=torch.bool, device='cuda')
+            x = torch.randint(0, 2, (4, 4), dtype=torch.bool, device=device)
+            y = torch.randint(0, 2, (4, 4), dtype=torch.bool, device=device)
+            z = torch.randint(0, 2, (4, 4), dtype=torch.bool, device=device)
 
-        ge = self.checkTrace(f, (x, y, z), inputs_require_grads=False)
-        self.assertAllFused(ge.graph_for(x, y, z))
+            ge = self.checkTrace(f, (x, y, z), inputs_require_grads=False)
+            self.assertAllFused(ge.graph_for(x, y, z))
 
-    @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
     def test_div_bool(self):
-        def f(x, y, z):
-            return (x + y) / z
+        for device in self.devices:
+            def f(x, y, z):
+                return (x + y) / z
 
-        x = torch.randint(0, 2, (4, 4), dtype=torch.bool, device='cuda')
-        y = torch.randint(0, 2, (4, 4), dtype=torch.bool, device='cuda')
-        z = torch.ones_like(x, dtype=torch.bool, device='cuda')
+            x = torch.randint(0, 2, (4, 4), dtype=torch.bool, device=device)
+            y = torch.randint(0, 2, (4, 4), dtype=torch.bool, device=device)
+            z = torch.ones_like(x, dtype=torch.bool, device=device)
 
-        ge = self.checkTrace(f, (x, y, z), inputs_require_grads=False)
-        self.assertAllFused(ge.graph_for(x, y, z))
+            ge = self.checkTrace(f, (x, y, z), inputs_require_grads=False)
+            self.assertAllFused(ge.graph_for(x, y, z))
 
-    @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
+    @unittest.skipIf(not torch._C._llvm_enabled(), "TODO: bugs in ir eval")
     def test_bitwise_ops(self):
         def apply(fn):
            return lambda x, y, z: fn(fn(x, y), z)
@@ -467,7 +463,7 @@ class TestTEFuser(JitTestCase):
             operator.__or__,
             operator.__xor__
         ]
-        devices = ["cuda"]
+        devices = self.devices
         for dtype, op, device in product(dtypes, binary_ops, devices):
             try:
                 x = self.data_for(dtype, device)
@@ -528,20 +524,20 @@ class TestTEFuser(JitTestCase):
                     " ".join(["Failed:", str(dtype), op.__name__, device])
                 )
 
-    @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
     def test_comparison_eq_ne(self):
-        def f(x, y):
-            mask = (x == 0).type_as(x)
-            z = x * mask + y
-            mask = (x != 0).type_as(x)
-            z = z * mask + y
-            return z
+        for device in self.devices:
+            def f(x, y):
+                mask = (x == 0).type_as(x)
+                z = x * mask + y
+                mask = (x != 0).type_as(x)
+                z = z * mask + y
+                return z
 
-        x = torch.randn(4, 4, dtype=torch.float, device='cuda')
-        y = torch.randn(4, 4, dtype=torch.float, device='cuda')
+            x = torch.randn(4, 4, dtype=torch.float, device=device)
+            y = torch.randn(4, 4, dtype=torch.float, device=device)
 
-        ge = self.checkTrace(f, (x, y))
-        self.assertAllFused(ge.graph_for(x, y))
+            ge = self.checkTrace(f, (x, y))
+            self.assertAllFused(ge.graph_for(x, y))
 
     @staticmethod
     def fn_test_comparison_gt_lt(x, y):
@@ -551,47 +547,47 @@ class TestTEFuser(JitTestCase):
         z = z * mask + y
         return z
 
-    @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
-    def test_comparison_gt_lt_cuda(self):
-        x = torch.randn(4, 4, dtype=torch.float, device='cuda')
-        y = torch.randn(4, 4, dtype=torch.float, device='cuda')
+    def test_comparison_gt_lt(self):
+        for device in self.devices:
+            x = torch.randn(4, 4, dtype=torch.float, device=device)
+            y = torch.randn(4, 4, dtype=torch.float, device=device)
 
-        ge = self.checkTrace(self.fn_test_comparison_gt_lt, (x, y))
-        self.assertAllFused(ge.graph_for(x, y))
+            ge = self.checkTrace(self.fn_test_comparison_gt_lt, (x, y))
+            self.assertAllFused(ge.graph_for(x, y))
 
-    @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
-    def test_comparison_ge_le_cuda(self):
-        def f(x, y):
-            mask = (x >= 0).type_as(x)
-            z = x * mask + y
-            mask = (x <= 0).type_as(x)
-            z = z * mask + y
-            return z
+    def test_comparison_ge_le(self):
+        for device in self.devices:
+            def f(x, y):
+                mask = (x >= 0).type_as(x)
+                z = x * mask + y
+                mask = (x <= 0).type_as(x)
+                z = z * mask + y
+                return z
 
-        x = torch.randn(4, 4, dtype=torch.float, device='cuda')
-        y = torch.randn(4, 4, dtype=torch.float, device='cuda')
+            x = torch.randn(4, 4, dtype=torch.float, device=device)
+            y = torch.randn(4, 4, dtype=torch.float, device=device)
 
-        ge = self.checkTrace(f, (x, y))
-        self.assertAllFused(ge.graph_for(x, y))
-        x.requires_grad_(True)
-        y.requires_grad_(True)
-        self.assertAllFused(ge.graph_for(x, y), except_for=("aten::size", "prim::BroadcastSizes",
-                                                            "aten::_size_if_not_equal"))
+            ge = self.checkTrace(f, (x, y))
+            self.assertAllFused(ge.graph_for(x, y))
+            x.requires_grad_(True)
+            y.requires_grad_(True)
+            self.assertAllFused(ge.graph_for(x, y), except_for=("aten::size", "prim::BroadcastSizes",
+                                                                "aten::_size_if_not_equal"))
 
-    @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
-    def test_addcmul_cuda(self):
-        t = torch.randn(1, 4, dtype=torch.float, device='cuda')
-        t1 = torch.randn(4, 1, dtype=torch.float, device='cuda')
-        t2 = torch.randn(1, 4, dtype=torch.float, device='cuda')
+    def test_addcmul(self):
+        for device in self.devices:
+            t = torch.randn(1, 4, dtype=torch.float, device=device)
+            t1 = torch.randn(4, 1, dtype=torch.float, device=device)
+            t2 = torch.randn(1, 4, dtype=torch.float, device=device)
 
-        def foo(t, t1, t2):
-            return t.addcmul(t + 1, t2, value=0.1)
+            def foo(t, t1, t2):
+                return t.addcmul(t + 1, t2, value=0.1)
 
-        ge = self.checkTrace(foo, (t, t1, t2), allow_unused=True)
-        graph = ge.graph_for(t, t1, t2)
-        fusion_groups = self.findFusionGroups(graph)
-        self.assertEqual(len(fusion_groups), 1)
-        FileCheck().check("aten::add(").check("aten::addcmul(").run(str(fusion_groups[0]))
+            ge = self.checkTrace(foo, (t, t1, t2), allow_unused=True)
+            graph = ge.graph_for(t, t1, t2)
+            fusion_groups = self.findFusionGroups(graph)
+            self.assertEqual(len(fusion_groups), 1)
+            FileCheck().check("aten::add(").check("aten::addcmul(").run(str(fusion_groups[0]))
 
     # TODO: We leak CUDA memory here because the traced graph holds onto a
     # constant-ified tensor. Since the Python-global CompilationUnit is alive
diff --git a/torch/_C/__init__.pyi.in b/torch/_C/__init__.pyi.in
index bdebb355e33..a6af57152b9 100644
--- a/torch/_C/__init__.pyi.in
+++ b/torch/_C/__init__.pyi.in
@@ -184,6 +184,7 @@ def _jit_can_fuse_on_cpu() -> _bool: ...
 def _jit_can_fuse_on_gpu() -> _bool: ...
 def _jit_texpr_fuser_enabled() -> _bool: ...
 def _jit_nvfuser_enabled() -> _bool: ...
+def _llvm_enabled() -> _bool: ...
 def _jit_override_can_fuse_on_cpu(override: _bool): ...
 def _jit_override_can_fuse_on_gpu(override: _bool): ...
 def _jit_set_texpr_fuser_enabled(enable: _bool): ...
diff --git a/torch/csrc/jit/python/init.cpp b/torch/csrc/jit/python/init.cpp
index 9cc2f07b2e6..c2de6ec9292 100644
--- a/torch/csrc/jit/python/init.cpp
+++ b/torch/csrc/jit/python/init.cpp
@@ -629,6 +629,13 @@ void initJITBindings(PyObject* module) {
             using namespace torch::jit::tensorexpr;
             getTEMustUseLLVMOnCPU() = use_llvm;
           })
+      .def("_llvm_enabled", []() {
+#ifdef TORCH_ENABLE_LLVM
+        return true;
+#else
+        return false;
+#endif
+      })
       .def(
           "_jit_pass_fuse_tensorexprs",
          [](std::shared_ptr<Graph>& g) { return FuseTensorExprs(g); })
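
For reference, below is a minimal standalone sketch (not part of the patch) of the pattern the diff moves these tests to: iterating over a CPU/CUDA device list instead of hard-coding device='cuda', and gating CPU coverage on the new torch._C._llvm_enabled() binding added above. It assumes this patch is applied; the MiniFuserExample class, its devices list, and the scaleshift helper are illustrative names only, and the sketch only checks numerical agreement rather than fusion itself (the real suite uses JitTestCase helpers such as checkScript and assertAllFused, and enables CPU fusion via torch._C._jit_override_can_fuse_on_cpu in its setup).

import unittest

import torch


class MiniFuserExample(unittest.TestCase):
    # Assumed device list, mirroring the refactored tests: always CPU, plus
    # CUDA when it is available.
    devices = ["cpu"] + (["cuda"] if torch.cuda.is_available() else [])

    @unittest.skipIf(not torch._C._llvm_enabled(),
                     "CPU tensorexpr fusion needs the LLVM backend")
    def test_scaleshift(self):
        def scaleshift(x, scale, shift):
            return x * scale + shift

        scripted = torch.jit.script(scaleshift)
        for device in self.devices:
            x = torch.randn(4, 4, dtype=torch.float, device=device)
            scale = torch.randn(4, dtype=torch.float, device=device)
            shift = torch.randn(4, dtype=torch.float, device=device)
            # Run a few times so the profiling executor specializes the graph
            # and the tensorexpr fuser has a chance to kick in.
            for _ in range(3):
                out = scripted(x, scale, shift)
            self.assertTrue(torch.allclose(out, scaleshift(x, scale, shift)))


if __name__ == "__main__":
    unittest.main()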