Add more CPU tests (#47369)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/47369

Test Plan: Imported from OSS

Reviewed By: ansley

Differential Revision: D24805251

Pulled By: eellison

fbshipit-source-id: f1a8210ffdc3cc88354cb4896652151d83a0345a
This commit is contained in:
parent b8a1070ec0
commit fe81faee5f
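
The diff below converts a batch of TensorExpr fuser tests that were previously gated on RUN_CUDA so that they also exercise the CPU fuser: the @unittest.skipIf(not RUN_CUDA, ...) decorators are dropped, the _cuda suffix is removed from the test names, and the test bodies either loop over self.devices or read it directly, creating tensors with device=device instead of device='cuda'. The following is a minimal standalone sketch of that pattern, not code from this commit; in particular, the devices property defined here is an assumption, since the diff only reads self.devices and never shows its definition.

    import unittest

    import torch


    class DeviceLoopExample(unittest.TestCase):
        @property
        def devices(self):
            # Assumed definition: always test CPU, add CUDA when available.
            return ["cpu"] + (["cuda"] if torch.cuda.is_available() else [])

        def test_scaleshift(self):
            for device in self.devices:
                x = torch.randn(4, 4, dtype=torch.float, device=device)
                scale = torch.randn(4, dtype=torch.float, device=device)
                shift = torch.randn(4, dtype=torch.float, device=device)
                # Same computation as the scaleshift helper in the first hunk.
                self.assertEqual((x * scale + shift).shape, (4, 4))


    if __name__ == "__main__":
        unittest.main()
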
@@ -170,18 +170,17 @@ class TestTEFuser(JitTestCase):
         traced_f = torch.jit.trace(f, (x, y,))
         self.assertEqual(traced_f(x.t().contiguous(), y), traced_f(x.t(), y))

-    @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
-    def test_broadcast_cuda(self):
+    def test_broadcast(self):
+        for device in self.devices:
             def scaleshift(x, scale, shift):
                 return x * scale + shift

             inputs = [
-                torch.randn(4, 4, dtype=torch.float, device='cuda'),
-                torch.randn(4, dtype=torch.float, device='cuda'),
-                torch.randn(4, dtype=torch.float, device='cuda'),
+                torch.randn(4, 4, dtype=torch.float, device=device),
+                torch.randn(4, dtype=torch.float, device=device),
+                torch.randn(4, dtype=torch.float, device=device),
             ]
             self.checkScript(scaleshift, inputs)
-        self.assertLastGraphAllFused()

     @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
     @unittest.skipIf(not RUN_CUDA_HALF, "no half support")
@@ -219,8 +218,8 @@ class TestTEFuser(JitTestCase):
         grads_half = [t.half() for t in grads]
         self.assertEqual(grads_half, fusion_grads)

-    @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
     def test_checks_cat_inputs(self):
+        for device in self.devices:
             # We shouldn't treat cat nodes as broadcasting. All their inputs
             # need to be checked for having the same map size, before we can
             # run the kernel.
@@ -229,20 +228,20 @@ class TestTEFuser(JitTestCase):

             # NOTE: y is broadcastable to x, but output of f(x, y) should have
             # shape 3x4, and not 4x4.
-            x = torch.randn(2, 4, dtype=torch.float, device='cuda')
-            y = torch.randn(1, 4, dtype=torch.float, device='cuda')
+            x = torch.randn(2, 4, dtype=torch.float, device=device)
+            y = torch.randn(1, 4, dtype=torch.float, device=device)

             scripted = self.checkScript(f, (x, y))
             self.assertEqual(scripted(x, y).shape, (3, 4))
             self.assertAllFused(scripted.graph_for(x, y))

-    @unittest.skipIf(not RUN_CUDA, "No CUDA")
-    def test_chunk_cuda(self):
+    def test_chunk(self):
+        for device in self.devices:
             def fn(x):
                 a, b, c = x.chunk(3, 1)
                 return a * b + c

-            inputs = [torch.randn(10, 6, dtype=torch.float, device='cuda')]
+            inputs = [torch.randn(10, 6, dtype=torch.float, device=device)]

             self.checkScript(fn, inputs)
             self.assertLastGraphAllFused()
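
For reference, the chunk call exercised by test_chunk splits a tensor into equal pieces along the given dimension; a quick standalone check of the shapes involved (plain PyTorch, independent of the test harness):

    import torch

    x = torch.randn(10, 6, dtype=torch.float)
    a, b, c = x.chunk(3, 1)   # split dim 1 (size 6) into 3 pieces of size 2
    print(a.shape, b.shape, c.shape)  # torch.Size([10, 2]) each
    print((a * b + c).shape)          # same computation as fn in the test
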
@@ -303,8 +302,8 @@ class TestTEFuser(JitTestCase):
                 "ConstantChunk", 1, exactly=True
             ).run(str(graph))

-    @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
     def test_chunk_motion_deduplicates_inputs(self):
+        for device in self.devices:
             def func1(x):
                 z = x * x
                 z0, z1 = z.chunk(2)
@@ -316,14 +315,14 @@ class TestTEFuser(JitTestCase):
                 return z0 * z1

             inputs = [
-                torch.tensor([1.1, 1.2], device='cuda', dtype=torch.float),
+                torch.tensor([1.1, 1.2], device=device, dtype=torch.float),
             ]
             for func in [func1, func2]:
                 self.checkScript(func, inputs)
                 self.assertLastGraphAllFused()

-    @unittest.skipIf(not RUN_CUDA, "No CUDA")
-    def test_chunk_multiple_cuda(self):
+    def test_chunk_multiple(self):
+        for device in self.devices:
             # The arguments are intentionally used out of order as a test to see
             # if the fusion compiler adds extra args in the correct order
             def fn(s, x, y, z):
@@ -333,16 +332,17 @@ class TestTEFuser(JitTestCase):
                 return s + x1 + x2 + x3 + y1 + y2 + z1 + z2

             inputs = [
-                torch.randn(5, 2, 3, dtype=torch.float, device='cuda'),
-                torch.randn(5, 6, 3, dtype=torch.float, device='cuda'),
-                torch.randn(10, 2, 3, dtype=torch.float, device='cuda'),
-                torch.randn(5, 2, 6, dtype=torch.float, device='cuda'),
+                torch.randn(5, 2, 3, dtype=torch.float, device=device),
+                torch.randn(5, 6, 3, dtype=torch.float, device=device),
+                torch.randn(10, 2, 3, dtype=torch.float, device=device),
+                torch.randn(5, 2, 6, dtype=torch.float, device=device),
             ]

             ge = self.checkScript(fn, inputs)
             self.assertAllFused(ge.graph_for(*inputs))

     def test_minmax(self):
+        for device in self.devices:
             def tmax(a, b):
                 return torch.max(2 * a, b)

@@ -353,20 +353,16 @@ class TestTEFuser(JitTestCase):
             b = torch.randn(4, 4, dtype=torch.float)
             nan = torch.tensor(float('nan'), dtype=torch.float)

-        devices = ["cpu"]
-        if torch.cuda.is_available():
-            devices.append("cuda")
             for f, inputs, device in product(
                     (tmax, tmin),
                     ([a, b], [a, nan], [b, nan]),
-                    devices):
+                    self.devices):
                 inputs = [t.to(device) for t in inputs]
                 s = self.checkScript(f, inputs)
                 self.assertAllFused(s.graph_for(*inputs))

-    # TODO: reenable the test after backwards passes start working in PE
-    @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
     def test_clamp(self):
+        for device in self.devices:
             def func2(a, b):
                 return torch.clamp(a + b, min=0, max=2)

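
test_minmax above uses itertools.product to enumerate every combination of function, input pair, and device, moving CPU-created tensors to the target device with .to(device). Here is a standalone sketch of that idiom outside the test harness; tmin is assumed to mirror tmax (its body is not shown in the hunk), and a hand-rolled device list stands in for self.devices.

    from itertools import product

    import torch


    def tmax(a, b):
        return torch.max(2 * a, b)


    def tmin(a, b):
        # Assumed to be symmetric to tmax; its body is not shown in the hunk above.
        return torch.min(2 * a, b)


    a = torch.randn(4, 4, dtype=torch.float)
    b = torch.randn(4, 4, dtype=torch.float)
    nan = torch.tensor(float('nan'), dtype=torch.float)

    # Stand-in for self.devices.
    devices = ["cpu"] + (["cuda"] if torch.cuda.is_available() else [])

    for f, inputs, device in product((tmax, tmin),
                                     ([a, b], [a, nan], [b, nan]),
                                     devices):
        inputs = [t.to(device) for t in inputs]
        print(f.__name__, device, f(*inputs).shape)
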
@@ -382,9 +378,9 @@ class TestTEFuser(JitTestCase):
             def funcOptMax(a, b):
                 return torch.clamp(a + b, min=0)

-            a = torch.randn(4, 4, dtype=torch.float, device='cuda', requires_grad=True)
-            b = torch.randn(4, 4, dtype=torch.float, device='cuda')
-            nan = torch.tensor(float('nan'), dtype=torch.float, device='cuda')
+            a = torch.randn(4, 4, dtype=torch.float, device=device, requires_grad=True)
+            b = torch.randn(4, 4, dtype=torch.float, device=device)
+            nan = torch.tensor(float('nan'), dtype=torch.float, device=device)

             funcs = (func2, funcInf, funcNegInf, funcOptMin, funcOptMax)
             for f, inputs in product(funcs, [[a, b], [a, nan]]):
@@ -425,31 +421,31 @@ class TestTEFuser(JitTestCase):
             ge = self.checkTrace(f, (x, y, z), inputs_require_grads=False)
             self.assertAllFused(ge.graph_for(x, y, z))

-    @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
     def test_mul_bool(self):
+        for device in self.devices:
             def f(x, y, z):
                 return x * y * z

-            x = torch.randint(0, 2, (4, 4), dtype=torch.bool, device='cuda')
-            y = torch.randint(0, 2, (4, 4), dtype=torch.bool, device='cuda')
-            z = torch.randint(0, 2, (4, 4), dtype=torch.bool, device='cuda')
+            x = torch.randint(0, 2, (4, 4), dtype=torch.bool, device=device)
+            y = torch.randint(0, 2, (4, 4), dtype=torch.bool, device=device)
+            z = torch.randint(0, 2, (4, 4), dtype=torch.bool, device=device)

             ge = self.checkTrace(f, (x, y, z), inputs_require_grads=False)
             self.assertAllFused(ge.graph_for(x, y, z))

-    @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
     def test_div_bool(self):
+        for device in self.devices:
             def f(x, y, z):
                 return (x + y) / z

-            x = torch.randint(0, 2, (4, 4), dtype=torch.bool, device='cuda')
-            y = torch.randint(0, 2, (4, 4), dtype=torch.bool, device='cuda')
-            z = torch.ones_like(x, dtype=torch.bool, device='cuda')
+            x = torch.randint(0, 2, (4, 4), dtype=torch.bool, device=device)
+            y = torch.randint(0, 2, (4, 4), dtype=torch.bool, device=device)
+            z = torch.ones_like(x, dtype=torch.bool, device=device)

             ge = self.checkTrace(f, (x, y, z), inputs_require_grads=False)
             self.assertAllFused(ge.graph_for(x, y, z))

-    @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
+    @unittest.skipIf(not torch._C._llvm_enabled(), "TODO: bugs in ir eval")
     def test_bitwise_ops(self):
         def apply(fn):
             return lambda x, y, z: fn(fn(x, y), z)
@@ -467,7 +463,7 @@ class TestTEFuser(JitTestCase):
             operator.__or__,
             operator.__xor__
         ]
-        devices = ["cuda"]
+        devices = self.devices
        for dtype, op, device in product(dtypes, binary_ops, devices):
            try:
                x = self.data_for(dtype, device)
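
test_bitwise_ops above lifts each binary operator into a ternary one with the apply helper and then sweeps dtypes, operators, and devices with product. A standalone sketch of just the apply idiom, using int32 tensors for illustration (the dtype sweep and the data_for helper from the test are not reproduced here):

    import operator

    import torch


    def apply(fn):
        # Fold a binary op over three operands: fn(fn(x, y), z).
        return lambda x, y, z: fn(fn(x, y), z)


    x = torch.randint(0, 2, (4, 4), dtype=torch.int32)
    y = torch.randint(0, 2, (4, 4), dtype=torch.int32)
    z = torch.randint(0, 2, (4, 4), dtype=torch.int32)

    xor3 = apply(operator.__xor__)
    print(xor3(x, y, z))  # elementwise (x ^ y) ^ z
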
@@ -528,8 +524,8 @@ class TestTEFuser(JitTestCase):
                    " ".join(["Failed:", str(dtype), op.__name__, device])
                )

-    @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
     def test_comparison_eq_ne(self):
+        for device in self.devices:
             def f(x, y):
                 mask = (x == 0).type_as(x)
                 z = x * mask + y
@@ -537,8 +533,8 @@ class TestTEFuser(JitTestCase):
                 z = z * mask + y
                 return z

-            x = torch.randn(4, 4, dtype=torch.float, device='cuda')
-            y = torch.randn(4, 4, dtype=torch.float, device='cuda')
+            x = torch.randn(4, 4, dtype=torch.float, device=device)
+            y = torch.randn(4, 4, dtype=torch.float, device=device)

             ge = self.checkTrace(f, (x, y))
             self.assertAllFused(ge.graph_for(x, y))
@@ -551,16 +547,16 @@ class TestTEFuser(JitTestCase):
         z = z * mask + y
         return z

-    @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
-    def test_comparison_gt_lt_cuda(self):
-        x = torch.randn(4, 4, dtype=torch.float, device='cuda')
-        y = torch.randn(4, 4, dtype=torch.float, device='cuda')
+    def test_comparison_gt_lt(self):
+        for device in self.devices:
+            x = torch.randn(4, 4, dtype=torch.float, device=device)
+            y = torch.randn(4, 4, dtype=torch.float, device=device)

             ge = self.checkTrace(self.fn_test_comparison_gt_lt, (x, y))
             self.assertAllFused(ge.graph_for(x, y))

-    @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
-    def test_comparison_ge_le_cuda(self):
+    def test_comparison_ge_le(self):
+        for device in self.devices:
             def f(x, y):
                 mask = (x >= 0).type_as(x)
                 z = x * mask + y
@@ -568,8 +564,8 @@ class TestTEFuser(JitTestCase):
                 z = z * mask + y
                 return z

-            x = torch.randn(4, 4, dtype=torch.float, device='cuda')
-            y = torch.randn(4, 4, dtype=torch.float, device='cuda')
+            x = torch.randn(4, 4, dtype=torch.float, device=device)
+            y = torch.randn(4, 4, dtype=torch.float, device=device)

             ge = self.checkTrace(f, (x, y))
             self.assertAllFused(ge.graph_for(x, y))
@@ -578,11 +574,11 @@ class TestTEFuser(JitTestCase):
             self.assertAllFused(ge.graph_for(x, y), except_for=("aten::size", "prim::BroadcastSizes",
                                                                 "aten::_size_if_not_equal"))

-    @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
-    def test_addcmul_cuda(self):
-        t = torch.randn(1, 4, dtype=torch.float, device='cuda')
-        t1 = torch.randn(4, 1, dtype=torch.float, device='cuda')
-        t2 = torch.randn(1, 4, dtype=torch.float, device='cuda')
+    def test_addcmul(self):
+        for device in self.devices:
+            t = torch.randn(1, 4, dtype=torch.float, device=device)
+            t1 = torch.randn(4, 1, dtype=torch.float, device=device)
+            t2 = torch.randn(1, 4, dtype=torch.float, device=device)

             def foo(t, t1, t2):
                 return t.addcmul(t + 1, t2, value=0.1)
@@ -184,6 +184,7 @@ def _jit_can_fuse_on_cpu() -> _bool: ...
 def _jit_can_fuse_on_gpu() -> _bool: ...
 def _jit_texpr_fuser_enabled() -> _bool: ...
 def _jit_nvfuser_enabled() -> _bool: ...
+def _llvm_enabled() -> _bool: ...
 def _jit_override_can_fuse_on_cpu(override: _bool): ...
 def _jit_override_can_fuse_on_gpu(override: _bool): ...
 def _jit_set_texpr_fuser_enabled(enable: _bool): ...
@@ -629,6 +629,13 @@ void initJITBindings(PyObject* module) {
         using namespace torch::jit::tensorexpr;
         getTEMustUseLLVMOnCPU() = use_llvm;
       })
+      .def("_llvm_enabled", []() {
+#ifdef TORCH_ENABLE_LLVM
+        return true;
+#else
+        return false;
+#endif
+      })
       .def(
           "_jit_pass_fuse_tensorexprs",
           [](std::shared_ptr<Graph>& g) { return FuseTensorExprs(g); })
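
The new _llvm_enabled binding reports whether the build was configured with TORCH_ENABLE_LLVM; the answer is fixed at compile time by the #ifdef and exposed to Python as a plain boolean, which is what lets the test diff above replace a RUN_CUDA gate with @unittest.skipIf(not torch._C._llvm_enabled(), ...). A sketch of gating a test on it from Python (the test body is illustrative only, and it assumes a PyTorch build that includes this binding):

    import unittest

    import torch


    class LLVMGatedTest(unittest.TestCase):
        @unittest.skipIf(not torch._C._llvm_enabled(), "requires an LLVM-enabled build")
        def test_requires_llvm_codegen(self):
            # Only runs when PyTorch was built with TORCH_ENABLE_LLVM.
            self.assertTrue(torch._C._llvm_enabled())


    if __name__ == "__main__":
        unittest.main()
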