mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/58249 Test Plan: Imported from OSS Reviewed By: ejguan Differential Revision: D28425629 Pulled By: Krovatkin fbshipit-source-id: 494ab165d548aa76f036344ab1c19c5fd64bae82
899 lines
37 KiB
Python
899 lines
37 KiB
Python
import unittest
|
|
import os
|
|
|
|
import torch
|
|
|
|
from torch.testing._internal.common_utils import run_tests, ProfilingMode, GRAPH_EXECUTOR
|
|
from torch.testing._internal.codegen.random_topo_test import runDefaultTestWithSeed
|
|
|
|
from test_jit import JitTestCase, RUN_CUDA
|
|
|
|
from jit.test_fuser_common import TestFuserCommon # noqa: F401
|
|
|
|
import itertools
|
|
import numpy as np
|
|
|
|
os.environ['PYTORCH_CUDA_FUSER_DISABLE_FALLBACK'] = '1'
|
|
os.environ['PYTORCH_CUDA_FUSER_DISABLE_FMA'] = '1'
|
|
os.environ['PYTORCH_CUDA_FUSER_JIT_OPT_LEVEL'] = '0'
|
|
|
|
if GRAPH_EXECUTOR == ProfilingMode.PROFILING:
|
|
torch._C._jit_set_texpr_fuser_enabled(False)
|
|
torch._C._jit_set_profiling_executor(True)
|
|
torch._C._jit_set_profiling_mode(True)
|
|
|
|
FUSION_GROUP = 'prim::CudaFusionGroup'
|
|
FUSION_GUARD = 'prim::CudaFusionGuard'
|
|
|
|
class TestCudaFuser(JitTestCase):
|
|
|
|
def _getSubgraphInFusion(self, graph):
|
|
num_node = 0
|
|
subgraph = None
|
|
|
|
def count(block, ret):
|
|
for n in block.nodes():
|
|
if n.kind() == FUSION_GROUP:
|
|
ret[0] = ret[0] + 1
|
|
self.assertTrue(n.hasAttribute('Subgraph'))
|
|
ret[1] = n.g('Subgraph')
|
|
for block in n.blocks():
|
|
count(block, ret)
|
|
ret = [num_node, subgraph]
|
|
count(graph, ret)
|
|
self.assertEqual(ret[0], 1)
|
|
return ret[1]
|
|
|
|
def setUp(self):
|
|
super(TestCudaFuser, self).setUp()
|
|
self.old_cpu_fuse = torch._C._jit_can_fuse_on_cpu()
|
|
self.old_gpu_fuse = torch._C._jit_can_fuse_on_gpu()
|
|
torch._C._jit_override_can_fuse_on_cpu(False)
|
|
torch._C._jit_override_can_fuse_on_gpu(False)
|
|
self.old_guard = torch._C._jit_set_nvfuser_guard_mode(False)
|
|
|
|
if(RUN_CUDA):
|
|
self.old_nvfuser = torch._C._jit_set_nvfuser_enabled(True)
|
|
|
|
def tearDown(self):
|
|
if(RUN_CUDA):
|
|
torch._C._jit_set_nvfuser_enabled(self.old_nvfuser)
|
|
torch._C._jit_override_can_fuse_on_cpu(self.old_cpu_fuse)
|
|
torch._C._jit_override_can_fuse_on_gpu(self.old_gpu_fuse)
|
|
torch._C._jit_set_nvfuser_guard_mode(self.old_guard)
|
|
super(TestCudaFuser, self).tearDown()
|
|
|
|
def _run_helper(self, jit_op, op, *args):
|
|
torch.cuda.manual_seed_all(123)
|
|
jit_o = jit_op(*args)
|
|
torch.cuda.manual_seed_all(123)
|
|
jit_o = jit_op(*args)
|
|
torch.cuda.manual_seed_all(123)
|
|
o = op(*args)
|
|
self.assertEqual(o, jit_o)
|
|
self.assertGraphContains(jit_op.graph_for(*args), FUSION_GUARD)
|
|
|
|
def _run_training_helper(self, jit_op, op, grads, *args):
|
|
torch.cuda.manual_seed_all(123)
|
|
jit_o = jit_op(*args)
|
|
jit_g = jit_o.backward(grads)
|
|
torch.cuda.manual_seed_all(123)
|
|
jit_o = jit_op(*args)
|
|
jit_g = jit_o.backward(grads)
|
|
torch.cuda.manual_seed_all(123)
|
|
jit_o = jit_op(*args)
|
|
jit_g = jit_o.backward(grads)
|
|
torch.cuda.manual_seed_all(123)
|
|
o = op(*args)
|
|
g = o.backward(grads)
|
|
self.assertEqual(o, jit_o)
|
|
self.assertEqual(g, jit_g)
|
|
self.assertGraphContainsExactly(jit_op.graph_for(*args), FUSION_GUARD, 1, consider_subgraphs=True)
|
|
bwd_graph = list(
|
|
list(jit_op.get_debug_state().execution_plans.values())[
|
|
0].code.grad_executor_states()[0].execution_plans.values()
|
|
)[0].graph
|
|
self.assertGraphContainsExactly(bwd_graph, FUSION_GUARD, 1, consider_subgraphs=True)
|
|
|
|
@unittest.skipIf(not RUN_CUDA, "requires CUDA")
|
|
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
|
|
"Requires fusion optimization pass to be effective")
|
|
def test_half(self):
|
|
def t(x: torch.Tensor, y: torch.Tensor, z: torch.Tensor, alpha: float):
|
|
o_16 = torch.add(x, y)
|
|
o_32_a = torch.add(y, z, alpha=alpha)
|
|
o_32_b = torch.add(o_16, z)
|
|
return (o_16, o_32_a, o_32_b)
|
|
|
|
t_jit = torch.jit.script(t)
|
|
alpha = 0.5
|
|
# stick to integers, this avoid the numerical difference due to our
|
|
# promotion
|
|
x = torch.randint(0, 256, (4, 8)).to(dtype=torch.float16, device="cuda")
|
|
y = torch.randint(0, 256, (4, 8)).to(dtype=torch.float16, device="cuda")
|
|
z = torch.randint(0, 256, (4, 8)).to(dtype=torch.float16, device="cuda")
|
|
jit_o = t_jit(x, y, z, alpha)
|
|
jit_o = t_jit(x, y, z, alpha)
|
|
o = t(x, y, z, alpha)
|
|
for oo, jit_oo in zip(o, jit_o):
|
|
self.assertEqual(oo.dtype, jit_oo.dtype)
|
|
self.assertEqual(oo, jit_oo)
|
|
self.assertGraphContains(t_jit.graph_for(x, y, z, alpha), FUSION_GUARD)
|
|
|
|
@unittest.skipIf(not RUN_CUDA, "requires CUDA")
|
|
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
|
|
"Requires fusion optimization pass to be effective")
|
|
def test_const(self):
|
|
def t(x, y):
|
|
o = x + y
|
|
o = o + 2.0
|
|
return o
|
|
t_jit = torch.jit.script(t)
|
|
x = torch.randn(4, 8, dtype=torch.float, device="cuda")
|
|
y = torch.randn(4, 8, dtype=torch.float, device="cuda")
|
|
jit_o = t_jit(x, y)
|
|
jit_o = t_jit(x, y)
|
|
o = t(x, y)
|
|
self.assertEqual(o, jit_o)
|
|
self.assertGraphContains(t_jit.graph_for(x, y), FUSION_GUARD)
|
|
|
|
@unittest.skipIf(not RUN_CUDA, "requires CUDA")
|
|
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
|
|
"Requires fusion optimization pass to be effective")
|
|
def test_chunk(self):
|
|
def t(x, y, z, q):
|
|
o = x + q
|
|
x0, x1 = torch.chunk(o, 2)
|
|
o = x0 + x1
|
|
o = o + y
|
|
o = o * z
|
|
o = torch.relu(o)
|
|
return o
|
|
t_jit = torch.jit.script(t)
|
|
x = torch.randn(4, 8, dtype=torch.float, device="cuda")
|
|
y = torch.randn(2, 8, dtype=torch.float, device="cuda")
|
|
z = torch.randn(2, 8, dtype=torch.float, device="cuda")
|
|
q = torch.randn(4, 8, dtype=torch.float, device="cuda")
|
|
jit_o = t_jit(x, y, z, q)
|
|
jit_o = t_jit(x, y, z, q)
|
|
o = t(x, y, z, q)
|
|
self.assertEqual(o, jit_o)
|
|
self.assertGraphContains(t_jit.graph_for(x, y, z, q), FUSION_GUARD)
|
|
|
|
@unittest.skipIf(not RUN_CUDA, "requires CUDA")
|
|
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
|
|
"Requires fusion optimization pass to be effective")
|
|
def test_scalar_input(self):
|
|
def t(x: torch.Tensor, y: torch.Tensor, z: float):
|
|
o = x + y
|
|
o = o + z
|
|
return o
|
|
t_jit = torch.jit.script(t)
|
|
x = torch.randn(4, 8, 32, 32, dtype=torch.float, device="cuda")
|
|
y = torch.randn(4, 8, 1, 32, dtype=torch.float, device="cuda")
|
|
y = y.expand(4, 8, 32, 32)
|
|
jit_o = t_jit(x, y, 2.0)
|
|
jit_o = t_jit(x, y, 2.0)
|
|
o = t(x, y, 2.0)
|
|
self.assertEqual(o, jit_o)
|
|
self.assertGraphContains(t_jit.graph_for(x, y, 2.0), FUSION_GUARD)
|
|
|
|
@unittest.skipIf(not RUN_CUDA, "requires CUDA")
|
|
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
|
|
"Requires fusion optimization pass to be effective")
|
|
def test_broadcasting_0(self):
|
|
|
|
def t(x: torch.Tensor, y: torch.Tensor, z: float):
|
|
o = x + y
|
|
o = o + z
|
|
return o
|
|
t_jit = torch.jit.script(t)
|
|
x = torch.randn(4, 8, 32, 32, dtype=torch.float, device="cuda")
|
|
y = torch.randn(32, 32, dtype=torch.float, device="cuda")
|
|
jit_o = t_jit(x, y, 2.0)
|
|
jit_o = t_jit(x, y, 2.0)
|
|
o = t(x, y, 2.0)
|
|
self.assertEqual(o, jit_o)
|
|
subgraph = self._getSubgraphInFusion(t_jit.graph_for(x, y, 2.0))
|
|
self.assertGraphContainsExactly(subgraph, 'aten::add', 2, consider_subgraphs=False)
|
|
|
|
@unittest.skipIf(not RUN_CUDA, "requires CUDA")
|
|
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
|
|
"Requires fusion optimization pass to be effective")
|
|
def test_broadcasting_1(self):
|
|
|
|
def t(x: torch.Tensor, y: torch.Tensor, z: float):
|
|
o = x + y
|
|
o = o + z
|
|
return o
|
|
t_jit = torch.jit.script(t)
|
|
x = torch.randn(4, 8, 32, 32, dtype=torch.float, device="cuda")
|
|
y = torch.randn(1, 32, 32, dtype=torch.float, device="cuda")
|
|
jit_o = t_jit(x, y, 2.0)
|
|
jit_o = t_jit(x, y, 2.0)
|
|
o = t(x, y, 2.0)
|
|
self.assertEqual(o, jit_o)
|
|
subgraph = self._getSubgraphInFusion(t_jit.graph_for(x, y, 2.0))
|
|
self.assertGraphContainsExactly(subgraph, 'aten::add', 2, consider_subgraphs=False)
|
|
|
|
@unittest.skipIf(not RUN_CUDA, "requires CUDA")
|
|
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
|
|
"Requires fusion optimization pass to be effective")
|
|
def test_broadcasting_2(self):
|
|
|
|
def t(x: torch.Tensor, y: torch.Tensor, z: float):
|
|
o = x + y
|
|
o = o + z
|
|
return o
|
|
t_jit = torch.jit.script(t)
|
|
x = torch.randn(4, 1, 32, 32, dtype=torch.float, device="cuda")
|
|
y = torch.randn(8, 32, 32, dtype=torch.float, device="cuda")
|
|
jit_o = t_jit(x, y, 2.0)
|
|
jit_o = t_jit(x, y, 2.0)
|
|
o = t(x, y, 2.0)
|
|
self.assertEqual(o, jit_o)
|
|
subgraph = self._getSubgraphInFusion(t_jit.graph_for(x, y, 2.0))
|
|
self.assertGraphContainsExactly(subgraph, 'aten::add', 2, consider_subgraphs=False)
|
|
|
|
@unittest.skipIf(not RUN_CUDA, "requires CUDA")
|
|
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
|
|
"Requires fusion optimization pass to be effective")
|
|
def test_broadcasting_3(self):
|
|
|
|
def t(x: torch.Tensor, y: torch.Tensor, z: float):
|
|
o = x + y
|
|
o = o + z
|
|
return o
|
|
t_jit = torch.jit.script(t)
|
|
x = torch.randn(8, 17, 8, dtype=torch.float, device="cuda")
|
|
y = torch.randn(8, 17, 1, dtype=torch.float, device="cuda")
|
|
jit_o = t_jit(x, y, 2.0)
|
|
jit_o = t_jit(x, y, 2.0)
|
|
o = t(x, y, 2.0)
|
|
self.assertEqual(o, jit_o)
|
|
subgraph = self._getSubgraphInFusion(t_jit.graph_for(x, y, 2.0))
|
|
self.assertGraphContainsExactly(subgraph, 'aten::add', 2, consider_subgraphs=False)
|
|
|
|
# test_broadcasting_partition_logic_X
|
|
# Testing partition logic that is capable to avoid creating unsupported
|
|
# broadcasting semantics in CudaFusionGroup
|
|
@unittest.skipIf(not RUN_CUDA, "requires CUDA")
|
|
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
|
|
"Requires fusion optimization pass to be effective")
|
|
def test_broadcasting_partition_logic_0(self):
|
|
|
|
def t(x: torch.Tensor, y: torch.Tensor, z: torch.Tensor):
|
|
x = x + 12.0
|
|
o1 = x + y
|
|
o2 = x + z
|
|
o = o1 + o2
|
|
return o
|
|
t_jit = torch.jit.script(t)
|
|
x = torch.randn(4, 8, 6, 8, dtype=torch.float32, device="cuda")
|
|
y = torch.randn(8, 6, 8, dtype=torch.float32, device="cuda")
|
|
z = torch.randn(6, 8, dtype=torch.float32, device="cuda")
|
|
jit_o = t_jit(x, y, z)
|
|
jit_o = t_jit(x, y, z)
|
|
o = t(x, y, z)
|
|
self.assertEqual(o, jit_o)
|
|
subgraph = self._getSubgraphInFusion(t_jit.graph_for(x, y, z))
|
|
self.assertGraphContainsExactly(subgraph, 'aten::add', 4, consider_subgraphs=False)
|
|
|
|
@unittest.skipIf(not RUN_CUDA, "requires CUDA")
|
|
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
|
|
"Requires fusion optimization pass to be effective")
|
|
def test_broadcasting_partition_logic_1(self):
|
|
|
|
def t(x: torch.Tensor, y: torch.Tensor, z: torch.Tensor):
|
|
x = x + 12.0
|
|
o1 = x + y
|
|
o2 = x + z
|
|
o = o1 + o2
|
|
return o
|
|
t_jit = torch.jit.script(t)
|
|
x = torch.randn(8, 6, 8, dtype=torch.float32, device="cuda")
|
|
y = torch.randn(4, 8, 6, 8, dtype=torch.float32, device="cuda")
|
|
z = torch.randn(4, 1, 6, 8, dtype=torch.float32, device="cuda")
|
|
jit_o = t_jit(x, y, z)
|
|
jit_o = t_jit(x, y, z)
|
|
o = t(x, y, z)
|
|
self.assertEqual(o, jit_o)
|
|
subgraph = self._getSubgraphInFusion(t_jit.graph_for(x, y, z))
|
|
self.assertGraphContainsExactly(subgraph, 'aten::add', 2, consider_subgraphs=False)
|
|
|
|
@unittest.skipIf(True, "Broadcast with different output not supported yet")
|
|
@unittest.skipIf(not RUN_CUDA, "requires CUDA")
|
|
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
|
|
"Requires fusion optimization pass to be effective")
|
|
def test_broadcasting_multiple_output_shape(self):
|
|
def t(x: torch.Tensor, y: torch.Tensor, z: torch.Tensor):
|
|
o = x + 12
|
|
o1 = o + y
|
|
o2 = o + z
|
|
oo = o1.sum() + o2.sum()
|
|
return oo
|
|
t_jit = torch.jit.script(t)
|
|
x = torch.randn(32, 32, dtype=torch.float, device="cuda")
|
|
y = torch.randn(2, 32, 32, dtype=torch.float, device="cuda")
|
|
z = torch.randn(4, 32, 32, dtype=torch.float, device="cuda")
|
|
jit_o = t_jit(x, y, z)
|
|
jit_o = t_jit(x, y, z)
|
|
o = t(x, y, z)
|
|
self.assertEqual(o, jit_o)
|
|
# Currently cannot fuse this
|
|
self.assertGraphContains(t_jit.graph_for(x, y, z), FUSION_GUARD)
|
|
|
|
@unittest.skipIf(True, "broadcast on branches can't be resolved yet")
|
|
@unittest.skipIf(not RUN_CUDA, "requires CUDA")
|
|
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
|
|
"Requires fusion optimization pass to be effective")
|
|
def test_broadcasting_multiple_output(self):
|
|
def t(x: torch.Tensor, y: torch.Tensor, z: torch.Tensor):
|
|
o = x + 12
|
|
o1 = o + y
|
|
o2 = o + z
|
|
oo = o1.sum() + o2.sum()
|
|
return oo
|
|
t_jit = torch.jit.script(t)
|
|
x = torch.randn(32, 32, dtype=torch.float, device="cuda")
|
|
y = torch.randn(4, 32, 32, dtype=torch.float, device="cuda")
|
|
z = torch.randn(4, 32, 32, dtype=torch.float, device="cuda")
|
|
jit_o = t_jit(x, y, z)
|
|
jit_o = t_jit(x, y, z)
|
|
o = t(x, y, z)
|
|
self.assertEqual(o, jit_o)
|
|
# Currently cannot fuse this
|
|
self.assertGraphContains(t_jit.graph_for(x, y, z), FUSION_GUARD)
|
|
|
|
def _binary_test_helper(self, operation):
|
|
def t(x: torch.Tensor, y: torch.Tensor, z: float):
|
|
o = x + z
|
|
o = operation(o, y)
|
|
return o
|
|
t_jit = torch.jit.script(t)
|
|
x = torch.randn(4, 8, 32, 32, dtype=torch.float, device="cuda")
|
|
y = torch.randn(4, 8, 32, 32, dtype=torch.float, device="cuda")
|
|
jit_o = t_jit(x, y, 2.0)
|
|
jit_o = t_jit(x, y, 2.0)
|
|
o = t(x, y, 2.0)
|
|
self.assertEqual(o, jit_o)
|
|
self.assertGraphContains(t_jit.graph_for(x, y, 2.0), FUSION_GUARD)
|
|
|
|
def _unary_test_helper(self, operation):
|
|
def t(x: torch.Tensor, z: float):
|
|
o = x + z
|
|
o = operation(o)
|
|
return o
|
|
t_jit = torch.jit.script(t)
|
|
x = torch.randn(4, 8, 32, 32, dtype=torch.float, device="cuda")
|
|
jit_o = t_jit(x, 2.0)
|
|
jit_o = t_jit(x, 2.0)
|
|
o = t(x, 2.0)
|
|
self.assertEqual(o, jit_o)
|
|
self.assertGraphContains(t_jit.graph_for(x, 2.0), FUSION_GUARD)
|
|
|
|
@unittest.skipIf(not RUN_CUDA, "requires CUDA")
|
|
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
|
|
"Requires fusion optimization pass to be effective")
|
|
def test_unary_ops(self):
|
|
operations = [torch.neg,
|
|
torch.abs,
|
|
torch.log,
|
|
torch.log10,
|
|
torch.log1p,
|
|
torch.log2,
|
|
torch.lgamma,
|
|
torch.exp,
|
|
torch.expm1,
|
|
torch.erf,
|
|
torch.erfc,
|
|
torch.cos,
|
|
torch.acos,
|
|
torch.cosh,
|
|
torch.sin,
|
|
torch.asin,
|
|
torch.tan,
|
|
torch.atan,
|
|
torch.sqrt,
|
|
torch.rsqrt,
|
|
torch.ceil,
|
|
torch.floor,
|
|
torch.round,
|
|
torch.trunc,
|
|
torch.frac,
|
|
torch.reciprocal,
|
|
torch.relu,
|
|
torch.sigmoid,
|
|
torch.tanh,
|
|
torch.nn.functional.gelu]
|
|
for op in operations:
|
|
self._unary_test_helper(op)
|
|
|
|
@unittest.skipIf(not RUN_CUDA, "requires CUDA")
|
|
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
|
|
"Requires fusion optimization pass to be effective")
|
|
def test_binary_ops(self):
|
|
operations = [torch.div,
|
|
torch.mul,
|
|
torch.atan2,
|
|
torch.max,
|
|
torch.min,
|
|
torch.pow,
|
|
torch.remainder,
|
|
torch.fmod,
|
|
torch.eq,
|
|
torch.ne,
|
|
torch.ge,
|
|
torch.gt,
|
|
torch.le,
|
|
torch.lt]
|
|
for op in operations:
|
|
self._binary_test_helper(op)
|
|
|
|
@unittest.skipIf(not RUN_CUDA, "requires CUDA")
|
|
# legacy fuser does not work for rand_like, see issue #34361
|
|
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING, "Requires fusion optimization pass to be effective")
|
|
def test_ternary_ops(self):
|
|
x = torch.randn(4, 8, 32, 32, dtype=torch.float, device="cuda")
|
|
y = torch.randn(4, 8, 32, 32, dtype=torch.float, device="cuda")
|
|
z = torch.randn(4, 8, 32, 32, dtype=torch.float, device="cuda")
|
|
cond = torch.randint(0, 2, (4, 8, 32, 32)).to(dtype=torch.bool, device="cuda")
|
|
|
|
def add(x: torch.Tensor, other: torch.Tensor, alpha: float):
|
|
o = torch.relu(x)
|
|
o = torch.add(o, other=other, alpha=alpha)
|
|
return o
|
|
add_jit = torch.jit.script(add)
|
|
self._run_helper(add_jit, add, x, y, 2.0)
|
|
|
|
def clamp0(x: torch.Tensor, f: float):
|
|
o = torch.rand_like(x)
|
|
o = o * torch.clamp(x, min=f)
|
|
return o
|
|
clamp0_jit = torch.jit.script(clamp0)
|
|
self._run_helper(clamp0_jit, clamp0, x, 0.5)
|
|
|
|
def clamp1(x: torch.Tensor, f: float, ff: float):
|
|
o = torch.rand_like(x)
|
|
o = o * torch.clamp(x, min=f, max=ff)
|
|
return o
|
|
clamp1_jit = torch.jit.script(clamp1)
|
|
self._run_helper(clamp1_jit, clamp1, x, -0.2, 0.7)
|
|
|
|
def threshold(x: torch.Tensor, th: float, val: float):
|
|
o = torch.rand_like(x)
|
|
o = x * torch.threshold(o, th, val)
|
|
return o
|
|
threshold_jit = torch.jit.script(threshold)
|
|
self._run_helper(threshold_jit, threshold, x, 0.2, 0.9)
|
|
|
|
def where(x: torch.Tensor, y: torch.Tensor, cond: torch.Tensor):
|
|
o = torch.rand_like(x)
|
|
o = o * torch.where(cond, x, y)
|
|
return o
|
|
where_jit = torch.jit.script(where)
|
|
self._run_helper(where_jit, where, x, y, cond)
|
|
|
|
def lerp(x: torch.Tensor, y: torch.Tensor, z: torch.Tensor):
|
|
o = torch.rand_like(x)
|
|
o = o * torch.lerp(x, y, z)
|
|
return o
|
|
lerp_jit = torch.jit.script(lerp)
|
|
self._run_helper(lerp_jit, lerp, x, y, z)
|
|
|
|
def lerp_scale(x: torch.Tensor, y: torch.Tensor, z: float):
|
|
o = torch.rand_like(x)
|
|
o = o * torch.lerp(x, y, z)
|
|
return o
|
|
lerp_scale_jit = torch.jit.script(lerp_scale)
|
|
self._run_helper(lerp_scale_jit, lerp_scale, x, y, 0.5)
|
|
|
|
@unittest.skipIf(not RUN_CUDA, "requires CUDA")
|
|
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING, "Requires profiling node to run cuda fuser")
|
|
def test_addcmul_ops(self):
|
|
x = torch.randn(4, 8, 32, 32, dtype=torch.float, device="cuda")
|
|
y = torch.randn(4, 8, 32, 32, dtype=torch.float, device="cuda")
|
|
z = torch.randn(4, 8, 32, 32, dtype=torch.float, device="cuda")
|
|
|
|
def addcmul(x: torch.Tensor, y: torch.Tensor, z: torch.Tensor, value: float):
|
|
o = torch.add(x, 0.5)
|
|
o = torch.addcmul(o, y, z, value=value)
|
|
return o
|
|
addcmul_jit = torch.jit.script(addcmul)
|
|
self._run_helper(addcmul_jit, addcmul, x, y, z, 2.0)
|
|
|
|
def addcmul_no_alpha(x: torch.Tensor, y: torch.Tensor, z: torch.Tensor):
|
|
o = torch.add(x, 0.5)
|
|
o = torch.addcmul(o, y, z)
|
|
return o
|
|
addcmul_no_alpha_jit = torch.jit.script(addcmul_no_alpha)
|
|
self._run_helper(addcmul_no_alpha_jit, addcmul_no_alpha, x, y, z)
|
|
|
|
def addcmul_const_alpha(x: torch.Tensor, y: torch.Tensor, z: torch.Tensor):
|
|
o = torch.add(x, 0.5)
|
|
o = torch.addcmul(o, y, z, value=0.75)
|
|
return o
|
|
addcmul_const_alpha_jit = torch.jit.script(addcmul_const_alpha)
|
|
self._run_helper(addcmul_const_alpha_jit, addcmul_const_alpha, x, y, z)
|
|
|
|
@unittest.skipIf(not RUN_CUDA, "requires CUDA")
|
|
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
|
|
"Requires fusion optimization pass to be effective")
|
|
def test_dynamic_size(self):
|
|
old_guard = torch._C._jit_set_nvfuser_guard_mode(True)
|
|
torch._C._jit_set_bailout_depth(20)
|
|
|
|
def t(x: torch.Tensor, y: torch.Tensor, z: float):
|
|
o = x + y
|
|
o = o + z
|
|
return o
|
|
t_jit = torch.jit.script(t)
|
|
x = torch.randn(4, 8, 32, 32, dtype=torch.float, device="cuda")
|
|
y = torch.randn(32, 32, dtype=torch.float, device="cuda")
|
|
jit_o = t_jit(x, y, 2.0)
|
|
jit_o = t_jit(x, y, 2.0)
|
|
o = t(x, y, 2.0)
|
|
self.assertEqual(o, jit_o)
|
|
subgraph = self._getSubgraphInFusion(t_jit.graph_for(x, y, 2.0))
|
|
self.assertGraphContainsExactly(subgraph, 'aten::add', 2, consider_subgraphs=False)
|
|
|
|
# this test is not ideal, as we rely on the bailout to test it and we
|
|
# don't know a way to verify the bailout graph to validate the proper
|
|
# fusion.
|
|
x = torch.randn(8, 32, 16, 8, dtype=torch.float, device="cuda")
|
|
y = torch.randn(16, 8, dtype=torch.float, device="cuda")
|
|
jit_o = t_jit(x, y, 2.0)
|
|
jit_o = t_jit(x, y, 2.0)
|
|
o = t(x, y, 2.0)
|
|
self.assertEqual(o, jit_o)
|
|
self.assertGraphContains(t_jit.graph_for(x, y, 2.0), FUSION_GUARD)
|
|
x = torch.randn(8, 17, 8, dtype=torch.float, device="cuda")
|
|
y = torch.randn(8, 17, 1, dtype=torch.float, device="cuda")
|
|
jit_o = t_jit(x, y, 2.0)
|
|
jit_o = t_jit(x, y, 2.0)
|
|
o = t(x, y, 2.0)
|
|
self.assertEqual(o, jit_o)
|
|
self.assertGraphContains(t_jit.graph_for(x, y, 2.0), FUSION_GUARD)
|
|
torch._C._jit_set_nvfuser_guard_mode(old_guard)
|
|
|
|
@unittest.skipIf(not RUN_CUDA, "requires CUDA")
|
|
def test_random_topo(self):
|
|
os.environ["PYTORCH_CUDA_FUSER_DISABLE_FALLBACK"] = "1"
|
|
self.assertTrue(runDefaultTestWithSeed(28449))
|
|
|
|
def _compare(self, desc, inp1, inp2, error):
|
|
a = inp1.clone().detach().cpu().numpy()
|
|
b = inp2.clone().detach().cpu().numpy()
|
|
close = np.allclose(a, b, error, error)
|
|
if not close:
|
|
print(desc, close)
|
|
z = a - b
|
|
index = (np.abs(z) >= error + error * np.abs(b)).nonzero()
|
|
print("dif : ", z[index])
|
|
print("inp1 : ", a[index])
|
|
print("inp2 : ", b[index])
|
|
return close
|
|
|
|
# Permutation helper that applies binary operation between two tensors:
|
|
# 1. applies separate permutation `perm0` & `perm1` to two inputs
|
|
# 2. reduce dimension `broadcast_axis` of operand two to size 1
|
|
# The purpose of this test is to ensure permutation works well in
|
|
# complicated cases with arbitrary stride order and broadcasting dimensions
|
|
def _permutation_helper(self, sizes, broadcast_axis, dtype, device, perm0, perm1):
|
|
def t(x: torch.Tensor, y: torch.Tensor):
|
|
o = torch.add(x, y)
|
|
o = torch.relu(o)
|
|
return o
|
|
|
|
x = torch.randn([sizes[i] for i in perm0], dtype=dtype, device=device).permute([perm0.index(i) for i in range(len(sizes))])
|
|
if broadcast_axis >= 0:
|
|
sizes[broadcast_axis] = 1
|
|
y = torch.randn([sizes[i] for i in perm1], dtype=dtype, device=device).permute([perm1.index(i) for i in range(len(sizes))])
|
|
t_jit = torch.jit.script(t)
|
|
jit_o = t_jit(x, y)
|
|
jit_o = t_jit(x, y)
|
|
o = t(x, y)
|
|
self.assertEqual(o.dtype, jit_o.dtype)
|
|
self.assertEqual(o, jit_o)
|
|
self.assertGraphContains(t_jit.graph_for(x, y), FUSION_GUARD)
|
|
|
|
# end-2-end test of permutation & contiguity handling in integration.
|
|
# we are testing inputs with all combination of permutation order, just to
|
|
# ensure that integration would be able to generate functionally correct
|
|
# kernels
|
|
@unittest.skipIf(not RUN_CUDA, "requires CUDA")
|
|
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
|
|
"Requires fusion optimization pass to be effective")
|
|
def test_binary_ops_permutation(self):
|
|
# note that num_dim is exclusive from len(x), so we are not reducing
|
|
# to single element (codegen limitation at this moment)
|
|
x = [7, 8, 12]
|
|
b_axes = range(-1, len(x))
|
|
for b_axis in b_axes:
|
|
for perm0 in itertools.permutations(range(len(x))):
|
|
for perm1 in itertools.permutations(range(len(x))):
|
|
x = [7, 8, 12]
|
|
self._permutation_helper(x, b_axis, torch.float32, "cuda", perm0, perm1)
|
|
|
|
def _reduction_helper(self, sizes, reduction_axis, dtype, device, perm0, perm1):
|
|
class MyReduction(torch.nn.Module):
|
|
__constants__ = ['reduction_axis']
|
|
|
|
def __init__(self):
|
|
super(MyReduction, self).__init__()
|
|
self.reduction_axis = reduction_axis
|
|
|
|
def forward(self, x: torch.Tensor, y: torch.Tensor):
|
|
o = torch.add(x, y)
|
|
o = torch.sum(o, dim=self.reduction_axis)
|
|
return o
|
|
|
|
t = MyReduction()
|
|
|
|
x = torch.randn([sizes[i] for i in perm0], dtype=dtype, device=device).permute([perm0.index(i) for i in range(len(sizes))])
|
|
y = torch.randn([sizes[i] for i in perm1], dtype=dtype, device=device).permute([perm1.index(i) for i in range(len(sizes))])
|
|
t_jit = torch.jit.script(t)
|
|
jit_o = t_jit(x, y)
|
|
jit_o = t_jit(x, y)
|
|
o = t(x, y)
|
|
self.assertEqual(o.dtype, jit_o.dtype)
|
|
# numerical issues here due to our scheduling.
|
|
# can't use `self.assertEqual(o, jit_o)`
|
|
self.assertTrue(self._compare("comparing output failed", o, jit_o, 1e-4))
|
|
self.assertGraphContains(t_jit.graph_for(x, y), FUSION_GUARD)
|
|
|
|
@unittest.skipIf(not RUN_CUDA, "requires CUDA")
|
|
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
|
|
"Requires fusion optimization pass to be effective")
|
|
def test_reduction(self):
|
|
for x in ([7, 8, 12], [12, 8, 7, 9, 15], [128, 16, 8, 32]):
|
|
# note that num_dim is exclusive from len(x), so we are not reducing
|
|
# to single element (codegen limitation at this moment)
|
|
for num_reduce_dim in range(1, len(x)):
|
|
for axes in itertools.combinations(range(len(x)), num_reduce_dim):
|
|
perm0 = range(len(x))
|
|
perm1 = range(len(x))
|
|
self._reduction_helper(x, axes, torch.float32, "cuda", perm0, perm1)
|
|
|
|
@unittest.skipIf(not RUN_CUDA, "requires CUDA")
|
|
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
|
|
"Requires fusion optimization pass to be effective")
|
|
def test_reduction_permutation(self):
|
|
x = [7, 8, 12]
|
|
# note that num_dim is exclusive from len(x), so we are not reducing
|
|
# to single element (codegen limitation at this moment)
|
|
for num_reduce_dim in range(1, len(x)):
|
|
for axes in itertools.combinations(range(len(x)), num_reduce_dim):
|
|
for perm0 in itertools.permutations(range(len(x))):
|
|
for perm1 in itertools.permutations(range(len(x))):
|
|
self._reduction_helper(x, axes, torch.float32, "cuda", perm0, perm1)
|
|
|
|
@unittest.skipIf(not RUN_CUDA, "requires CUDA")
|
|
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
|
|
"Requires fusion optimization pass to be effective")
|
|
def test_reduction_multiple_output(self):
|
|
old_guard = torch._C._jit_set_nvfuser_guard_mode(True)
|
|
torch._C._jit_set_bailout_depth(20)
|
|
|
|
def t(x: torch.Tensor, y: torch.Tensor, scale: float, z: torch.Tensor):
|
|
o = torch.mul(x, y)
|
|
o = torch.mul(o, scale)
|
|
out1 = torch.mul(o, z)
|
|
out2 = torch.sum(out1, dim=[2])
|
|
return out1, out2
|
|
|
|
t_jit = torch.jit.script(t)
|
|
x = torch.randn(8, 4, 10, 16, dtype=torch.float, device="cuda")
|
|
y = torch.randn(8, 4, 10, 16, dtype=torch.float, device="cuda")
|
|
z = torch.randn(8, 4, 10, 16, dtype=torch.float, device="cuda")
|
|
scale = 0.5
|
|
jit_o = t_jit(x, y, scale, z)
|
|
jit_o = t_jit(x, y, scale, z)
|
|
o = t(x, y, scale, z)
|
|
for oo, jit_oo in zip(o, jit_o):
|
|
self.assertEqual(oo.dtype, jit_oo.dtype)
|
|
self.assertEqual(oo, jit_oo)
|
|
self.assertGraphContains(t_jit.graph_for(x, y, scale, z), FUSION_GUARD)
|
|
|
|
x = x.to(memory_format=torch.channels_last)
|
|
y = y.to(memory_format=torch.channels_last)
|
|
z = z.to(memory_format=torch.channels_last)
|
|
jit_o = t_jit(x, y, scale, z)
|
|
jit_o = t_jit(x, y, scale, z)
|
|
o = t(x, y, scale, z)
|
|
for oo, jit_oo in zip(o, jit_o):
|
|
self.assertEqual(oo.dtype, jit_oo.dtype)
|
|
self.assertEqual(oo, jit_oo)
|
|
self.assertGraphContains(t_jit.graph_for(x, y, scale, z), FUSION_GUARD)
|
|
torch._C._jit_set_nvfuser_guard_mode(old_guard)
|
|
|
|
@unittest.skipIf(not RUN_CUDA, "requires CUDA")
|
|
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
|
|
"Requires fusion optimization pass to be effective")
|
|
def test_reduction_dtype(self):
|
|
def t(x: torch.Tensor):
|
|
o = torch.mul(x, 1.0)
|
|
o = torch.sum(o, dim=[2], dtype=torch.float32)
|
|
return o
|
|
t_jit = torch.jit.script(t)
|
|
|
|
x = torch.randn(8, 4, 16, dtype=torch.float, device="cuda")
|
|
jit_o = t_jit(x)
|
|
jit_o = t_jit(x)
|
|
o = t(x)
|
|
self.assertEqual(o.dtype, jit_o.dtype)
|
|
self.assertTrue(self._compare("comparing output failed", o, jit_o, 1e-4))
|
|
self.assertGraphContains(t_jit.graph_for(x), FUSION_GUARD)
|
|
|
|
@unittest.skipIf(not RUN_CUDA, "requires CUDA")
|
|
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
|
|
"Requires fusion optimization pass to be effective")
|
|
def test_reduction_half(self):
|
|
def t(x: torch.Tensor):
|
|
o = torch.mul(x, 1.0)
|
|
o = torch.sum(o, dim=[2])
|
|
return o
|
|
|
|
t_jit = torch.jit.script(t)
|
|
x = torch.randn(8, 4, 16, dtype=torch.float16, device="cuda")
|
|
jit_o = t_jit(x)
|
|
jit_o = t_jit(x)
|
|
o = t(x)
|
|
self.assertEqual(o.dtype, jit_o.dtype)
|
|
self.assertTrue(self._compare("comparing output failed", o, jit_o, 1e-4))
|
|
self.assertGraphContains(t_jit.graph_for(x), FUSION_GUARD)
|
|
|
|
@unittest.skipIf(not RUN_CUDA, "requires CUDA")
|
|
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
|
|
"Requires fusion optimization pass to be effective")
|
|
def test_pw_single_reduction_partition(self):
|
|
sizes = [8, 8, 8]
|
|
dtype = torch.float
|
|
device = "cuda"
|
|
x = torch.randn(sizes, dtype=dtype, device=device)
|
|
y = torch.randn(sizes, dtype=dtype, device=device)
|
|
z = torch.randn(sizes, dtype=dtype, device=device)
|
|
|
|
def t(x: torch.Tensor, y: torch.Tensor, z: torch.Tensor):
|
|
o = torch.add(x, y)
|
|
o = torch.sum(o, dim=[0])
|
|
o = torch.add(o, z)
|
|
return o
|
|
t_jit = torch.jit.script(t)
|
|
jit_o = t_jit(x, y, z)
|
|
jit_o = t_jit(x, y, z)
|
|
o = t(x, y, z)
|
|
self.assertEqual(o.dtype, jit_o.dtype)
|
|
self.assertEqual(o, jit_o)
|
|
self.assertGraphContains(t_jit.graph_for(x, y, z), FUSION_GUARD)
|
|
|
|
@unittest.skipIf(not RUN_CUDA, "requires CUDA")
|
|
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
|
|
"Requires fusion optimization pass to be effective")
|
|
def test_single_reduction_broadcast(self):
|
|
dtype = torch.float
|
|
device = "cuda"
|
|
x = torch.randn([7, 4, 8], dtype=dtype, device=device)
|
|
y = torch.randn([4, 8], dtype=dtype, device=device)
|
|
z = torch.randn([1, 4, 8], dtype=dtype, device=device)
|
|
|
|
def t(x: torch.Tensor, y: torch.Tensor, z: torch.Tensor):
|
|
o = torch.add(x, y)
|
|
o = torch.add(o, z)
|
|
o = torch.sum(o, dim=[0])
|
|
return o
|
|
t_jit = torch.jit.script(t)
|
|
jit_o = t_jit(x, y, z)
|
|
jit_o = t_jit(x, y, z)
|
|
o = t(x, y, z)
|
|
self.assertEqual(o.dtype, jit_o.dtype)
|
|
self.assertEqual(o, jit_o)
|
|
self.assertGraphContains(t_jit.graph_for(x, y, z), FUSION_GUARD)
|
|
|
|
@unittest.skipIf(not RUN_CUDA, "requires CUDA")
|
|
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
|
|
"Requires fusion optimization pass to be effective")
|
|
def test_profiling_node(self):
|
|
dtype = torch.float
|
|
device = "cuda"
|
|
x = torch.randn(4, 8, 8, 8, dtype=dtype, device=device)
|
|
|
|
def repro(x: torch.Tensor, alpha: float):
|
|
o = torch.rand_like(x)
|
|
o = torch.add(o, alpha)
|
|
return o
|
|
repro_jit = torch.jit.script(repro)
|
|
self._run_helper(repro_jit, repro, x, 0.6)
|
|
|
|
@unittest.skipIf(not RUN_CUDA, "requires CUDA")
|
|
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
|
|
"Requires fusion optimization pass to be effective")
|
|
def test_reduction_sizes_op(self):
|
|
dtype = torch.float
|
|
device = "cuda"
|
|
x = torch.randn(2, 3, 4, 5, dtype=dtype, device=device)
|
|
y = torch.randn(2, 3, 4, 5, dtype=dtype, device=device)
|
|
|
|
def t(x: torch.Tensor, y: torch.Tensor):
|
|
o = x + y
|
|
o = torch.relu(o)
|
|
o = o.sum((1, 3))
|
|
return o.size()
|
|
t_jit = torch.jit.script(t)
|
|
jit_o = t_jit(x, y)
|
|
jit_o = t_jit(x, y)
|
|
o = t(x, y)
|
|
self.assertEqual(o, jit_o)
|
|
# since the output value is not used at all, the fusion operator should
|
|
# have been optimized away
|
|
self.assertGraphContainsExactly(t_jit.graph_for(x, y), FUSION_GUARD, 0)
|
|
|
|
@unittest.skipIf(not RUN_CUDA, "requires CUDA")
|
|
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
|
|
"Requires fusion optimization pass to be effective")
|
|
def test_gelu_fusion(self):
|
|
dtype = torch.float
|
|
device = "cuda"
|
|
x = torch.randn([64, 128, 1024], dtype=dtype, device=device, requires_grad=True)
|
|
grads = torch.randn([64, 128, 1024], dtype=dtype, device=device)
|
|
|
|
def t(x: torch.Tensor):
|
|
o = torch.nn.functional.gelu(x)
|
|
o = o * 1.0
|
|
return o
|
|
|
|
t_jit = torch.jit.script(t)
|
|
|
|
self._run_training_helper(t_jit, t, grads, x)
|
|
|
|
class TestPassManagerCudaFuser(JitTestCase):
|
|
|
|
@unittest.skipIf(not RUN_CUDA, "requires CUDA")
|
|
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
|
|
"Requires fusion optimization pass to be effective")
|
|
def test_context_manager_test(self):
|
|
x = torch.randn(4, 8, dtype=torch.float, device="cuda")
|
|
y = torch.randn(4, 8, dtype=torch.float, device="cuda")
|
|
with torch.jit.fuser('fuser2'):
|
|
with torch.jit.fuser('fuser2'):
|
|
|
|
def t1(x, y):
|
|
o = x + y
|
|
o = o + 2.0
|
|
return o
|
|
t_jit = torch.jit.script(t1)
|
|
t_jit(x, y)
|
|
t_jit(x, y)
|
|
self.assertGraphContains(t_jit.graph_for(x, y), FUSION_GUARD)
|
|
|
|
def t2(x, y):
|
|
o = x + y
|
|
o = o + 3.0
|
|
return o
|
|
t_jit_2 = torch.jit.script(t2)
|
|
t_jit_2(x, y)
|
|
t_jit_2(x, y)
|
|
self.assertGraphContains(t_jit_2.graph_for(x, y), FUSION_GUARD)
|
|
|
|
def t3(x, y):
|
|
o = x + y
|
|
o = o + 4.0
|
|
return o
|
|
t_jit_3 = torch.jit.script(t3)
|
|
t_jit_3(x, y)
|
|
t_jit_3(x, y)
|
|
self.assertGraphContainsExactly(t_jit_3.graph_for(x, y), FUSION_GUARD, 0)
|
|
|
|
@unittest.skipIf(not RUN_CUDA, "requires CUDA")
|
|
def test_register_fuser(self):
|
|
self.assertFalse(torch._C._jit_set_nvfuser_enabled(True))
|
|
self.assertTrue(torch._C._jit_nvfuser_enabled())
|
|
self.assertTrue(torch._C._jit_set_nvfuser_enabled(True))
|
|
self.assertTrue(torch._C._jit_nvfuser_enabled())
|
|
self.assertTrue(torch._C._jit_set_nvfuser_enabled(False))
|
|
self.assertFalse(torch._C._jit_nvfuser_enabled())
|
|
|
|
|
|
if __name__ == '__main__':
|
|
run_tests()
|