pytorch/test/test_jit_cuda_fuser.py
jjsjann123 df741c589f [NVFuser] Upstream push 0809 (#83067)
Syncing nvfuser devel branch to upstream master. https://github.com/csarofeen/pytorch/

Code changes include:

- codegen improvements:
  1. removes unnecessary syncs via redundant thread compute analysis
  2. symmetric API for BestEffortReplay
  3. support merge on trivial reductions
  4. Ampere async copy improvements
- bug fixes:
  1. vectorization bug fixes
  2. type inference patch: fixes upstream #81725
  3. segmenter bug fix with deterministic iteration ordering
- parser update
  1. added leaky_relu (illustrated in the sketch after this list)
- scheduler
  1. normalization scheduler cleanup
  2. simplifies matmul scheduling with new transform propagator
  3. merge all dimensions in the pointwise scheduler
  4. various gemm related improvements
- debuggability
  1. nsight compute support
  2. debug dump for InlinePropagator
  3. Add `UnaryOpType::Print`

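For illustration, a minimal sketch of the user-facing effect of the parser update (this assumes a CUDA build with nvfuser available, and mirrors the patterns used in the test file below rather than any new API added by this PR):

```python
import torch
import torch.nn.functional as F

torch._C._jit_set_nvfuser_enabled(True)

@torch.jit.script
def f(x):
    return F.leaky_relu(x + 1.0)

x = torch.randn(4, 8, device="cuda")
f(x)
f(x)  # warm-up runs let the profiling executor specialize shapes and fuse
# After warm-up, f.graph_for(x) should contain a prim::CudaFusionGroup with
# the newly parsed leaky_relu inside it.
```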
Squashed commits to work around the GitHub API.
Commits actually in this PR from the devel branch:

```
dfe02f3faed4c64477e5f5c678f21f33415d0195 Merge remote-tracking branch 'csarofeen/devel' into HEAD
16173732ecfafc4797e93c2449cfb778015a6c7a Add `TensorViewBuilder::shape(std::vector<Val*> shape)` (#1884)
7cfb7796bdcf055eb61d600b7b5c9df292950290 Merge pull request #1887 from csarofeen/upstream_merge_0803
3399f6de62061d30781de50ef1862bbfb1615173 Merge remote-tracking branch 'origin/viable/strict' into HEAD
01208f5bba3bc158d41ccbefa0ee2c5ceea7aedb Add `UnaryOpType::Print` which can be helpful for debugging (#1878)
0646522454aa715ef164c88a73fb8bdddc706805 Remove redundant TORCH_INTERNAL_ASSERT in lower_magic_zero.cpp (#1881)
7bc76aa219293a59e4166e258d76289fe13633ca Fix most inlined propagator for mismatched dims (#1875)
501f4aa270bf4dd47b0d2f4860bc6f23ebc32a38 Nonaffine swizzle formulation ep.2: Loop swizzle variant. (#1826)
d863d690f923047a85b5229a787118708f810741 Ampere async copy ep.2: circular buffering extension to support pipelined matmul operand load (#1827)
e0ae11a61c87cd998e88ddd79a496548171c31e0 Larger sized mma instructions to support full vectorization (#1824)
9bb4cf7a66b098f04c9d95a2d34ab2bceee151b3 fragment iteration to support fully unrolled mma ops (#1823)
a48270a18dc2d3accc2626758d14d5858ae55032 Merge all dims in pointwise scheduler (#1872)
172fb3673fb4aaf4c1e889922a4fc5c06cbd59f7 Make MostInlined and BestEffort inline propagation no longer assert replayed (#1868)
a64462a5ac2fcf57a177bf36b0f26c61a4e252a4 Allow trivial reduction to be merged (#1871)
440102bcda6eb1dcd42d5fa5aeab9d6b049956bc Symmetric API for BestEffortReplay (#1870)
d1caf330c08ea8002f7133ca655bbd5b28c4eb98 Some misc cleanups/refactor split out from #1854 (#1867)
1013eda50be38eac96c00ba781340ac199d5a136 Remove some welford specific logic. (#1864)
51589d36be5a101d06e641fe0400b39028b7cb81 Some cleanups on tests and heuristics params (#1866)
a6b3e70da5dee51dbc246347228ea21384e46ac3 Segmenter bug fix, and deterministic iteration ordering.  (#1865)
1b665b9b5e562d6f0caba5e7319e83e5df64104f Add nullptr checks to IrBuilder (#1861)
1cd9451d7493f631c2837ba07c1ea93a74e83a15 Simplify matmul scheduling with the new transform propagator.  (#1817)
bbc1fb9b8c454f557ab9fcf5b1c3cef9b9e136d0 Add leaky_relu operation (#1852)
e842a9bab5e9f7289b7ce33ee37a682b22373f49 Minor cleanup in pointwise scheduler (#1858)
9ee850ca2f7f51dd5269bffb1255e485f809282d Fix stringstream usage (#1857)
20a36c1e4f28c4ff9837e56784be2686d17435f3 Improve nsight compute support (#1855)
405910308301097297b55c34d560aab6a360e897 Remove debugging `true ||` from getPointwiseHeuristics (#1822)
01117bfe8fdfacdbfdcfba9a624cdf900fe044d4 Misc cleanup (#1853)
5cc64943dc381a568223140bce0f22163c01e29f Apply the magic-zero protection to each indexed domain individually for predicate indexing (#1846)
92e6f0207e3a89fe90fd5cd3ffc575dfd766ba00 Cleanup normalization scheduler (#1845)
db89c6591a2f21130599a93675e0615e55564e41 Type inference patch (#1848)
102fe93a4605ca465cda26ebaee4ba1af2026901 Add debug dump for InlinePropagator (#1847)
b7a4d93d375a6e2ddef483763c93ffddc62ec452 Redundant thread compute analysis to avoid un-necessary sync insertion (#1687)
942be5b256056d0e02877361b814ae6af32ca15f Upstream ci build fixes (#1842)
0b83645915029d67f9345aa4649b8c6f62b0061b Fix vectorization bug introduced in #1831 (#1840)
63630f1ae091180e541932a9d9dc598e0a9902dd Move MaxProducerPosUpdater into InlinePropagator::tearDown (#1825)
9135a963c01d97ba34b1a7d2f106e78a13fd6651 Fix transpose benchmark dtype (#1839)
2c9a6c02312d5bf4f83cde653b847b4f85849432 Add extra configurability to `parallelizeAllLike` (#1831)
```

RUN_TORCHBENCH: nvfuser

Differential Revision: [D38543000](https://our.internmc.facebook.com/intern/diff/D38543000)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/83067
Approved by: https://github.com/davidberard98
2022-08-10 21:02:56 +00:00


# Owner(s): ["oncall: jit"]
import contextlib
import unittest
import os
import random
import enum
import copy
from functools import reduce
import operator
import warnings
import torch
from torch.nn import functional
from torch.profiler import profile, ProfilerActivity
from torch.testing._internal.codegen.random_topo_test import runDefaultTestWithSeed
from torch.testing._internal.common_cuda import TEST_MULTIGPU
from torch.testing._internal.common_device_type import instantiate_device_type_tests, ops, OpDTypes
from torch.testing._internal.common_jit import JitCommonTestCase
from torch.testing._internal.common_methods_invocations import op_db, SampleInput
from torch.testing._internal.common_utils import run_tests, ProfilingMode, GRAPH_EXECUTOR, TEST_WITH_ROCM, slowTest, \
    is_iterable_of_tensors, freeze_rng_state
from torch.testing._internal.jit_utils import clone_inputs, get_traced_sample_variant_pairs, JitTestCase, RUN_CUDA
from torch.testing._internal.jit_metaprogramming_utils import create_traced_fn
from torch.testing import FileCheck
from jit.test_fuser_common import TestFuserCommon # noqa: F401
import itertools
import numpy as np
import math
from torch.autograd.gradcheck import gradcheck
from typing import List
RUN_NVFUSER = RUN_CUDA and not TEST_WITH_ROCM
CUDA_MAJOR, CUDA_MINOR = 0, 0
if RUN_NVFUSER and torch.version.cuda is not None:
    CUDA_MAJOR, CUDA_MINOR = (int(x) for x in torch.version.cuda.split('.')[:2])
os.environ['PYTORCH_NVFUSER_ENABLE'] = 'linear_decomposition,conv_decomposition'
os.environ['PYTORCH_NVFUSER_DISABLE'] = 'fallback,fma,unroll_with_rng'
os.environ['PYTORCH_NVFUSER_JIT_OPT_LEVEL'] = '0'
# TODO: enable complex when we fix the extremal cases in OpInfo
# see issue https://github.com/csarofeen/pytorch/issues/1730
# os.environ['PYTORCH_NVFUSER_ENABLE'] = 'complex'
if GRAPH_EXECUTOR == ProfilingMode.PROFILING:
    torch._C._jit_set_texpr_fuser_enabled(False)
    torch._C._jit_set_profiling_executor(True)
    torch._C._jit_set_profiling_mode(True)
FUSION_GROUP = 'prim::CudaFusionGroup'
FUSION_GUARD = 'prim::CudaFusionGuard'
# TODO: revert disabled alias ops
ALIAS_TEST_DISABLED = True
@contextlib.contextmanager
def nvfuser_singleton_fusion(flag):
    old_value = torch._C._jit_set_nvfuser_single_node_mode(flag)
    try:
        yield
    finally:
        torch._C._jit_set_nvfuser_single_node_mode(old_value)
@contextlib.contextmanager
def nvfuser_horizontal_fusion(flag):
    old_value = torch._C._jit_set_nvfuser_horizontal_mode(flag)
    try:
        yield
    finally:
        torch._C._jit_set_nvfuser_horizontal_mode(old_value)
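# Usage sketch for the two toggles above (this mirrors how tests in this file
# use them; the old flag value is restored even if the body raises):
#
#     with nvfuser_singleton_fusion(True):
#         jit_fn(x)  # request that each supported node be fused on its own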
def is_pre_volta():
    if not RUN_NVFUSER:
        return False
    prop = torch.cuda.get_device_properties(torch.cuda.current_device())
    return prop.major < 7
TEST_BF16 = RUN_NVFUSER and torch.cuda.is_bf16_supported()
TEST_LARGE_TENSOR = RUN_NVFUSER
if RUN_NVFUSER:
    torch.ones(1).cuda()  # initialize cuda context
    TEST_LARGE_TENSOR = torch.cuda.get_device_properties(0).total_memory >= 12e9
class CudaFuserTestOptions:
    def __init__(self):
        self.old_cpu_fuse = torch._C._jit_can_fuse_on_cpu()
        self.old_gpu_fuse = torch._C._jit_can_fuse_on_gpu()
        torch._C._jit_override_can_fuse_on_cpu(False)
        torch._C._jit_override_can_fuse_on_gpu(False)
        self.old_guard = torch._C._jit_set_nvfuser_guard_mode(False)
        torch._C._debug_set_autodiff_subgraph_inlining(False)
        self.old_value = torch._C._jit_set_autocast_mode(True)
        if RUN_CUDA:
            self.old_nvfuser = torch._C._jit_set_nvfuser_enabled(True)

    def restore(self):
        if RUN_CUDA:
            torch._C._jit_set_nvfuser_enabled(self.old_nvfuser)
        torch._C._jit_override_can_fuse_on_cpu(self.old_cpu_fuse)
        torch._C._jit_override_can_fuse_on_gpu(self.old_gpu_fuse)
        torch._C._jit_set_nvfuser_guard_mode(self.old_guard)
        torch._C._debug_set_autodiff_subgraph_inlining(True)
        torch._C._jit_set_autocast_mode(self.old_value)
class TestCudaFuser(JitTestCase):
    def assertEqual(self, *args, **kwargs):
        kwargs["exact_layout"] = True
        super(JitTestCase, self).assertEqual(*args, **kwargs)

    def _getSubgraphInFusion(self, graph):
        # walk the graph (including nested blocks), assert there is exactly
        # one fusion group, and return its subgraph
        num_node = 0
        subgraph = None

        def count(block, ret):
            for n in block.nodes():
                if n.kind() == FUSION_GROUP:
                    ret[0] = ret[0] + 1
                    self.assertTrue(n.hasAttribute('Subgraph'))
                    ret[1] = n.g('Subgraph')
                for block in n.blocks():
                    count(block, ret)
        ret = [num_node, subgraph]
        count(graph, ret)
        self.assertEqual(ret[0], 1)
        return ret[1]

    def setUp(self):
        super(TestCudaFuser, self).setUp()
        self.skip_node_list = []
        disabled_ops = ("aten::batch_norm",
                        "aten::_batch_norm_impl_index",
                        "aten::_batch_norm_impl_index_backward",
                        "aten::native_batch_norm_backward")
        for op in disabled_ops:
            disabled_flag = torch._C._jit_set_nvfuser_skip_node_kind(op, False)
            if disabled_flag:
                torch._C._jit_set_nvfuser_skip_node_kind(op, True)
                self.skip_node_list.append(op)
        # cpu fallback to avoid errors in case this is run on a CPU-only machine
        dev = 'cuda' if RUN_NVFUSER else 'cpu'
        self.special_values = torch.tensor(
            [float("-inf"), -10, -math.pi,
             -1, -0.5, 0, 1, 0.5,
             math.pi, 10, float("inf"),
             float("nan")], dtype=torch.float, device=dev)
        self.int_types = [
            torch.int8,
            torch.uint8,
            torch.int16,
            torch.int32,
            torch.int64
        ]
        self.support_tensor_dtypes = [
            torch.int32,
            torch.int64,
            torch.float16,
            torch.float32,
            torch.float64,
            torch.bool,
            torch.complex64,
            torch.complex128,
        ]
        if TEST_BF16:
            self.support_tensor_dtypes.append(torch.bfloat16)
        if RUN_NVFUSER:
            self.cuda_fuser_options = CudaFuserTestOptions()

    def tearDown(self):
        # restore the skip-node configuration from before the test
        for op in self.skip_node_list:
            disabled_flag = torch._C._jit_set_nvfuser_skip_node_kind(op, False)
            if not disabled_flag:
                torch._C._jit_set_nvfuser_skip_node_kind(op, True)
        if RUN_NVFUSER:
            self.cuda_fuser_options.restore()
        super(TestCudaFuser, self).tearDown()
    def _run_helper(self, jit_op, op, *args, check_stride=False, num_fusion=1, check_runs=1):
        seed = 123
        torch.cuda.manual_seed_all(seed)
        jit_o = jit_op(*args)
        for i in range(check_runs):
            torch.cuda.manual_seed_all(seed + i)
            jit_o = jit_op(*args)
            torch.cuda.manual_seed_all(seed + i)
            o = op(*args)
            if type(jit_o) is torch.Tensor:
                jit_o = [jit_o, ]
                o = [o, ]
            for oo, jit_oo in zip(o, jit_o):
                self.assertEqual(oo.dtype, jit_oo.dtype)
                self.assertEqual(oo, jit_oo)
                if check_stride:
                    self.assertEqual(oo.stride(), jit_oo.stride())
        self.assertGraphContainsExactly(jit_op.graph_for(*args), FUSION_GUARD, num_fusion, consider_subgraphs=True)

    def _run_training_helper(self, jit_op, op, grads, *args):
        torch.cuda.manual_seed_all(123)
        jit_o = jit_op(*args)
        jit_g = jit_o.backward(grads)
        torch.cuda.manual_seed_all(123)
        jit_o = jit_op(*args)
        jit_g = jit_o.backward(grads)
        torch.cuda.manual_seed_all(123)
        jit_o = jit_op(*args)
        jit_g = jit_o.backward(grads)
        torch.cuda.manual_seed_all(123)
        o = op(*args)
        g = o.backward(grads)
        self.assertEqual(o, jit_o)
        self.assertEqual(g, jit_g)
        self.assertGraphContainsExactly(jit_op.graph_for(*args), FUSION_GUARD, 1, consider_subgraphs=True)
        bwd_graph = list(
            list(jit_op.get_debug_state().execution_plans.values())[
                0].code.grad_executor_states()[0].execution_plans.values()
        )[0].graph
        self.assertGraphContainsExactly(bwd_graph, FUSION_GUARD, 1, consider_subgraphs=True)
    @unittest.skipIf(not RUN_NVFUSER, "requires CUDA")
    @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
                     "Requires fusion optimization pass to be effective")
    def test_half(self):
        def t(x: torch.Tensor, y: torch.Tensor, z: torch.Tensor, alpha: float):
            o_16 = torch.add(x, y)
            o_32_a = torch.add(y, z, alpha=alpha)
            o_32_b = torch.add(o_16, z)
            return (o_16, o_32_a, o_32_b)
        t_jit = torch.jit.script(t)
        alpha = 0.5
        # stick to integers; this avoids numerical differences due to our
        # type promotion
        x = torch.randint(0, 256, (4, 8)).to(dtype=torch.float16, device="cuda")
        y = torch.randint(0, 256, (4, 8)).to(dtype=torch.float16, device="cuda")
        z = torch.randint(0, 256, (4, 8)).to(dtype=torch.float16, device="cuda")
        jit_o = t_jit(x, y, z, alpha)
        jit_o = t_jit(x, y, z, alpha)
        o = t(x, y, z, alpha)
        for oo, jit_oo in zip(o, jit_o):
            self.assertEqual(oo.dtype, jit_oo.dtype)
            self.assertEqual(oo, jit_oo)
        self.assertGraphContains(t_jit.graph_for(x, y, z, alpha), FUSION_GUARD)
    @unittest.skipIf(not TEST_BF16, "device does not support BFloat16")
    @unittest.skipIf(not RUN_NVFUSER, "requires CUDA")
    @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
                     "Requires fusion optimization pass to be effective")
    def test_bfloat(self):
        def t(x: torch.Tensor, y: torch.Tensor, z: torch.Tensor, alpha: float):
            o_16 = torch.add(x, y)
            o_32_a = torch.add(y, z, alpha=alpha)
            o_32_b = torch.add(o_16, z)
            return (o_16, o_32_a, o_32_b)
        t_jit = torch.jit.script(t)
        alpha = 0.5
        # stick to integers; this avoids numerical differences due to our
        # type promotion
        x = torch.randint(0, 256, (4, 8)).to(dtype=torch.bfloat16, device="cuda")
        y = torch.randint(0, 256, (4, 8)).to(dtype=torch.bfloat16, device="cuda")
        z = torch.randint(0, 256, (4, 8)).to(dtype=torch.bfloat16, device="cuda")
        jit_o = t_jit(x, y, z, alpha)
        jit_o = t_jit(x, y, z, alpha)
        o = t(x, y, z, alpha)
        for oo, jit_oo in zip(o, jit_o):
            self.assertEqual(oo.dtype, jit_oo.dtype)
            self.assertEqual(oo, jit_oo)
        self.assertGraphContains(t_jit.graph_for(x, y, z, alpha), FUSION_GUARD)
    @unittest.skipIf(not RUN_NVFUSER, "requires CUDA")
    @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
                     "Requires fusion optimization pass to be effective")
    def test_const(self):
        def t(x, y):
            o = x + y
            o = o + 2.0
            return o
        t_jit = torch.jit.script(t)
        x = torch.randn(4, 8, dtype=torch.float, device="cuda")
        y = torch.randn(4, 8, dtype=torch.float, device="cuda")
        jit_o = t_jit(x, y)
        jit_o = t_jit(x, y)
        o = t(x, y)
        self.assertEqual(o, jit_o)
        self.assertGraphContains(t_jit.graph_for(x, y), FUSION_GUARD)

    @unittest.skipIf(not RUN_NVFUSER, "requires CUDA")
    @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
                     "Requires fusion optimization pass to be effective")
    def test_chunk(self):
        def t(x, y, z, q):
            o = x + q
            x0, x1 = torch.chunk(o, 2)
            o = x0 + x1
            o = o + y
            o = o * z
            o = torch.relu(o)
            return o
        t_jit = torch.jit.script(t)
        x = torch.randn(4, 8, dtype=torch.float, device="cuda")
        y = torch.randn(2, 8, dtype=torch.float, device="cuda")
        z = torch.randn(2, 8, dtype=torch.float, device="cuda")
        q = torch.randn(4, 8, dtype=torch.float, device="cuda")
        jit_o = t_jit(x, y, z, q)
        jit_o = t_jit(x, y, z, q)
        o = t(x, y, z, q)
        self.assertEqual(o, jit_o)
        self.assertGraphContains(t_jit.graph_for(x, y, z, q), FUSION_GUARD)
    @unittest.skipIf(is_pre_volta(), "reduction not supported on pre-Volta devices")
    @unittest.skipIf(not RUN_NVFUSER, "requires CUDA")
    @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
                     "Requires fusion optimization pass to be effective")
    def test_reduction_dtypes_axis(self):
        for op in [torch.sum, torch.mean, torch.amax, torch.var, torch.std]:
            for dtype in [torch.float16, torch.float32, torch.double]:
                for axis in [-1, 2, 0]:
                    def make_func(op):
                        def func(x: torch.Tensor):
                            o = torch.mul(x, 2.0)
                            o = op(o, dim=[axis])
                            return o
                        return func
                    x = torch.randn(8, 4, 16, dtype=dtype, device="cuda")
                    t = make_func(op)
                    t_jit = torch.jit.trace(t, x)
                    jit_o = t_jit(x)
                    jit_o = t_jit(x)
                    o = t(x)
                    self.assertEqual(o.dtype, jit_o.dtype)
                    self.assertTrue(self._compare("comparing output failed", o, jit_o, 1e-4))
                    self.assertGraphContains(t_jit.graph_for(x), FUSION_GUARD)

    @unittest.skipIf(is_pre_volta(), "reduction not supported on pre-Volta devices")
    @unittest.skipIf(not RUN_NVFUSER, "requires CUDA")
    @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
                     "Requires fusion optimization pass to be effective")
    def test_variance(self):
        for op in [torch.var, torch.std]:
            for dtype in [torch.float16, torch.float32, torch.double]:
                for axis in [-2, -1, 2, 1]:
                    for unbiased in [False, True]:
                        def make_func(op):
                            def func(x: torch.Tensor):
                                o = torch.mul(x, 2.0)
                                o = op(o, dim=[axis])
                                return o
                            return func
                        x = torch.randn(8, 4, 16, dtype=dtype, device="cuda")
                        t = make_func(op)
                        t_jit = torch.jit.trace(t, x)
                        jit_o = t_jit(x)
                        jit_o = t_jit(x)
                        o = t(x)
                        self.assertEqual(o.dtype, jit_o.dtype)
                        self.assertTrue(self._compare("comparing output failed", o, jit_o, 1e-4))
                        self.assertGraphContains(t_jit.graph_for(x), FUSION_GUARD)
    @unittest.skipIf(not RUN_NVFUSER, "requires CUDA")
    @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
                     "Requires fusion optimization pass to be effective")
    def test_scalar_input(self):
        def t(x: torch.Tensor, y: torch.Tensor, z: float):
            o = x + y
            o = o + z
            return o
        t_jit = torch.jit.script(t)
        x = torch.randn(4, 8, 32, 32, dtype=torch.float, device="cuda")
        y = torch.randn(4, 8, 1, 32, dtype=torch.float, device="cuda")
        y = y.expand(4, 8, 32, 32)
        jit_o = t_jit(x, y, 2.0)
        jit_o = t_jit(x, y, 2.0)
        o = t(x, y, 2.0)
        self.assertEqual(o, jit_o)
        self.assertGraphContains(t_jit.graph_for(x, y, 2.0), FUSION_GUARD)

    @unittest.skipIf(not RUN_NVFUSER, "requires CUDA")
    @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
                     "Requires fusion optimization pass to be effective")
    def test_broadcasting_0(self):
        def t(x: torch.Tensor, y: torch.Tensor, z: float):
            o = x + y
            o = o + z
            return o
        t_jit = torch.jit.script(t)
        x = torch.randn(4, 8, 32, 32, dtype=torch.float, device="cuda")
        y = torch.randn(32, 32, dtype=torch.float, device="cuda")
        jit_o = t_jit(x, y, 2.0)
        jit_o = t_jit(x, y, 2.0)
        o = t(x, y, 2.0)
        self.assertEqual(o, jit_o)
        subgraph = self._getSubgraphInFusion(t_jit.graph_for(x, y, 2.0))
        self.assertGraphContainsExactly(subgraph, 'aten::add', 2, consider_subgraphs=False)

    @unittest.skipIf(not RUN_NVFUSER, "requires CUDA")
    @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
                     "Requires fusion optimization pass to be effective")
    def test_broadcasting_1(self):
        def t(x: torch.Tensor, y: torch.Tensor, z: float):
            o = x + y
            o = o + z
            return o
        t_jit = torch.jit.script(t)
        x = torch.randn(4, 8, 32, 32, dtype=torch.float, device="cuda")
        y = torch.randn(1, 32, 32, dtype=torch.float, device="cuda")
        jit_o = t_jit(x, y, 2.0)
        jit_o = t_jit(x, y, 2.0)
        o = t(x, y, 2.0)
        self.assertEqual(o, jit_o)
        subgraph = self._getSubgraphInFusion(t_jit.graph_for(x, y, 2.0))
        self.assertGraphContainsExactly(subgraph, 'aten::add', 2, consider_subgraphs=False)

    @unittest.skipIf(not RUN_NVFUSER, "requires CUDA")
    @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
                     "Requires fusion optimization pass to be effective")
    def test_broadcasting_2(self):
        def t(x: torch.Tensor, y: torch.Tensor, z: float):
            o = x + y
            o = o + z
            return o
        t_jit = torch.jit.script(t)
        x = torch.randn(4, 1, 32, 32, dtype=torch.float, device="cuda")
        y = torch.randn(8, 32, 32, dtype=torch.float, device="cuda")
        jit_o = t_jit(x, y, 2.0)
        jit_o = t_jit(x, y, 2.0)
        o = t(x, y, 2.0)
        self.assertEqual(o, jit_o)
        subgraph = self._getSubgraphInFusion(t_jit.graph_for(x, y, 2.0))
        self.assertGraphContainsExactly(subgraph, 'aten::add', 2, consider_subgraphs=False)

    @unittest.skipIf(not RUN_NVFUSER, "requires CUDA")
    @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
                     "Requires fusion optimization pass to be effective")
    def test_broadcasting_3(self):
        def t(x: torch.Tensor, y: torch.Tensor, z: float):
            o = x + y
            o = o + z
            return o
        t_jit = torch.jit.script(t)
        x = torch.randn(8, 17, 8, dtype=torch.float, device="cuda")
        y = torch.randn(8, 17, 1, dtype=torch.float, device="cuda")
        jit_o = t_jit(x, y, 2.0)
        jit_o = t_jit(x, y, 2.0)
        o = t(x, y, 2.0)
        self.assertEqual(o, jit_o)
        subgraph = self._getSubgraphInFusion(t_jit.graph_for(x, y, 2.0))
        self.assertGraphContainsExactly(subgraph, 'aten::add', 2, consider_subgraphs=False)
    # test_broadcasting_partition_logic_X
    # Tests that the partition logic is able to avoid creating unsupported
    # broadcasting semantics in a CudaFusionGroup
    @unittest.skipIf(not RUN_NVFUSER, "requires CUDA")
    @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
                     "Requires fusion optimization pass to be effective")
    def test_broadcasting_partition_logic_0(self):
        def t(x: torch.Tensor, y: torch.Tensor, z: torch.Tensor):
            x = x + 12.0
            o1 = x + y
            o2 = x + z
            o = o1 + o2
            return o
        t_jit = torch.jit.script(t)
        x = torch.randn(4, 8, 6, 8, dtype=torch.float32, device="cuda")
        y = torch.randn(8, 6, 8, dtype=torch.float32, device="cuda")
        z = torch.randn(6, 8, dtype=torch.float32, device="cuda")
        jit_o = t_jit(x, y, z)
        jit_o = t_jit(x, y, z)
        o = t(x, y, z)
        self.assertEqual(o, jit_o)
        subgraph = self._getSubgraphInFusion(t_jit.graph_for(x, y, z))
        self.assertGraphContainsExactly(subgraph, 'aten::add', 4, consider_subgraphs=False)

    @unittest.skipIf(not RUN_NVFUSER, "requires CUDA")
    @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
                     "Requires fusion optimization pass to be effective")
    def test_broadcasting_partition_logic_1(self):
        def t(x: torch.Tensor, y: torch.Tensor, z: torch.Tensor):
            x = x + 12.0
            o1 = x + y
            o2 = x + z
            o = o1 + o2
            return o
        t_jit = torch.jit.script(t)
        x = torch.randn(8, 6, 8, dtype=torch.float32, device="cuda")
        y = torch.randn(4, 8, 6, 8, dtype=torch.float32, device="cuda")
        z = torch.randn(4, 1, 6, 8, dtype=torch.float32, device="cuda")
        jit_o = t_jit(x, y, z)
        jit_o = t_jit(x, y, z)
        o = t(x, y, z)
        self.assertEqual(o, jit_o)
        subgraph = self._getSubgraphInFusion(t_jit.graph_for(x, y, z))
        self.assertGraphContainsExactly(subgraph, 'aten::add', 4, consider_subgraphs=False)
    @unittest.skipIf(True, "Broadcast with different output not supported yet")
    @unittest.skipIf(not RUN_NVFUSER, "requires CUDA")
    @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
                     "Requires fusion optimization pass to be effective")
    def test_broadcasting_multiple_output_shape(self):
        def t(x: torch.Tensor, y: torch.Tensor, z: torch.Tensor):
            o = x + 12
            o1 = o + y
            o2 = o + z
            oo = o1.sum() + o2.sum()
            return oo
        t_jit = torch.jit.script(t)
        x = torch.randn(32, 32, dtype=torch.float, device="cuda")
        y = torch.randn(2, 32, 32, dtype=torch.float, device="cuda")
        z = torch.randn(4, 32, 32, dtype=torch.float, device="cuda")
        jit_o = t_jit(x, y, z)
        jit_o = t_jit(x, y, z)
        o = t(x, y, z)
        self.assertEqual(o, jit_o)
        # Currently cannot fuse this
        self.assertGraphContains(t_jit.graph_for(x, y, z), FUSION_GUARD)

    @unittest.skipIf(True, "broadcast on branches can't be resolved yet")
    @unittest.skipIf(not RUN_NVFUSER, "requires CUDA")
    @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
                     "Requires fusion optimization pass to be effective")
    def test_broadcasting_multiple_output(self):
        def t(x: torch.Tensor, y: torch.Tensor, z: torch.Tensor):
            o = x + 12
            o1 = o + y
            o2 = o + z
            oo = o1.sum() + o2.sum()
            return oo
        t_jit = torch.jit.script(t)
        x = torch.randn(32, 32, dtype=torch.float, device="cuda")
        y = torch.randn(4, 32, 32, dtype=torch.float, device="cuda")
        z = torch.randn(4, 32, 32, dtype=torch.float, device="cuda")
        jit_o = t_jit(x, y, z)
        jit_o = t_jit(x, y, z)
        o = t(x, y, z)
        self.assertEqual(o, jit_o)
        # Currently cannot fuse this
        self.assertGraphContains(t_jit.graph_for(x, y, z), FUSION_GUARD)
    def _unary_test_helper(self, operation, dtype, random_data):
        gradient_check = (dtype == torch.float64) and random_data
        shape = self.special_values.shape
        torch.cuda.manual_seed_all(211)

        # need an additional def of t for boolean ops
        def t(x: torch.Tensor, y: torch.Tensor):
            o = x * y
            o = o + 5e-3
            o = operation(o)
            return o
        y = torch.rand(shape, dtype=torch.float32, device="cuda", requires_grad=gradient_check)
        y = y.to(dtype=dtype)
        if random_data:
            x = torch.rand(shape, dtype=torch.float32, device="cuda", requires_grad=gradient_check)
            if dtype in self.int_types:
                # prefer a larger variance for integer types
                x = x * 5
            x = x.to(dtype=dtype)
        else:
            x = self.special_values.to(dtype=dtype)
        try:
            ref = t(x, y)
        except Exception:
            # same as the TE checker: if eager mode throws, skip this test
            return
        t_jit = torch.jit.script(t)
        jit_o = t_jit(x, y)
        jit_o = t_jit(x, y)
        jit_o = t_jit(x, y)
        if gradient_check:
            if jit_o.dtype != torch.bool:
                # bool dtype has no `-`
                gradcheck(t_jit, [x, y], nondet_tol=1e-5)
        elif dtype in self.support_tensor_dtypes:
            self.assertGraphContains(t_jit.graph_for(x, y), FUSION_GUARD)
            o = t(x, y)
            self.assertEqual(o.dtype, jit_o.dtype)
            if dtype == torch.bfloat16:
                # compare bfloat16 kernels against a float64 ground truth
                # instead of the eager bfloat16 implementation, since
                # mismatches in casting add excessive noise
                o = t(x.to(torch.float64), y.to(torch.float64))
                if o.dtype.is_floating_point:
                    o = o.to(torch.bfloat16)
            else:
                o = t(x, y)
            self.assertTrue(self._compare("failing case {}\n{}\n{}\n{}".format(dtype, operation, x, y), o, jit_o, 1e-2))
    @unittest.skipIf(not RUN_NVFUSER, "requires CUDA")
    @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
                     "Requires fusion optimization pass to be effective")
    def test_unary_ops(self):
        data_types = [
            *self.int_types,
            torch.float16,
            torch.float32,
            torch.float64,
            # TODO: revert this
            # see issue https://github.com/csarofeen/pytorch/issues/1730
            # torch.cfloat,
            # torch.cdouble,
        ]
        if TEST_BF16:
            data_types.append(torch.bfloat16)
        operations = [torch.neg,
                      torch.abs,
                      torch.log,
                      torch.log10,
                      torch.log1p,
                      torch.log2,
                      torch.lgamma,
                      torch.exp,
                      torch.expm1,
                      torch.erf,
                      torch.erfc,
                      torch.cos,
                      torch.acos,
                      torch.cosh,
                      torch.sin,
                      torch.asin,
                      torch.sinh,
                      torch.tan,
                      torch.atan,
                      torch.sqrt,
                      torch.rsqrt,
                      torch.ceil,
                      torch.floor,
                      torch.round,
                      torch.trunc,
                      torch.frac,
                      torch.reciprocal,
                      torch.isfinite,
                      torch.isinf,
                      torch.isnan,
                      torch.isneginf,
                      torch.isposinf,
                      torch.isreal,
                      torch.nn.functional.softplus,
                      torch.nn.functional.gelu,
                      torch.nn.functional.leaky_relu,
                      torch.nn.functional.silu,
                      torch.relu,
                      torch.sigmoid,
                      torch.bitwise_not,
                      torch.tan,
                      torch.tanh]
        skip_complex = {torch.rsqrt, torch.reciprocal}
        for op, dtype in itertools.product(operations, data_types):
            if dtype.is_complex and op in skip_complex:
                continue
            self._unary_test_helper(op, dtype, False)  # test special numbers
            self._unary_test_helper(op, dtype, True)   # test random data
    @unittest.skipIf(not RUN_NVFUSER, "requires CUDA")
    @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
                     "Requires fusion optimization pass to be effective")
    def test_category_rule(self):
        def run_tensor(x, z):
            def t(x: torch.Tensor, z: torch.Tensor):
                o = x + z
                o = torch.abs(o)
                return o
            t_jit = torch.jit.script(t)
            jit_o = t_jit(x, z)
            jit_o = t_jit(x, z)
            o = t(x, z)
            self.assertEqual(o.dtype, jit_o.dtype)
            self.assertEqual(o, jit_o)
            self.assertGraphContains(t_jit.graph_for(x, z), FUSION_GUARD)

        def run_scalar(x, z):
            def t(x: torch.Tensor, z: float):
                o = x + z
                o = torch.abs(o)
                return o
            t_jit = torch.jit.script(t)
            jit_o = t_jit(x, z)
            jit_o = t_jit(x, z)
            o = t(x, z)
            self.assertEqual(o.dtype, jit_o.dtype)
            self.assertEqual(o, jit_o)
            self.assertGraphContains(t_jit.graph_for(x, z), FUSION_GUARD)

        # n-dim with 0-dim (no type-promote)
        x = torch.randn(4, 8, 32, 32, dtype=torch.float, device="cuda")
        z = torch.tensor(2.0, dtype=torch.double, device="cuda")
        run_tensor(x, z)
        # n-dim with 0-dim (type-promote)
        x = torch.randn(4, 8, 32, 32, device="cuda").to(dtype=torch.long)
        z = torch.tensor(2.0, dtype=torch.double, device="cuda")
        run_tensor(x, z)
        # n-dim with n-dim (type-promote)
        x = torch.randn(4, 8, 32, 32, dtype=torch.float, device="cuda")
        z = torch.randn(4, 8, 32, 32, dtype=torch.double, device="cuda")
        run_tensor(x, z)
        # n-dim with scalar (no type-promote)
        x = torch.randn(4, 8, 32, 32, dtype=torch.float16, device="cuda")
        z = torch.tensor(3., dtype=torch.double)
        run_scalar(x, z)
        if TEST_BF16:
            # n-dim with scalar (no type-promote)
            x = torch.randn(4, 8, 32, 32, dtype=torch.bfloat16, device="cuda")
            z = torch.tensor(3., dtype=torch.double)
            run_scalar(x, z)
        # n-dim with scalar (type-promote)
        x = torch.randn(4, 8, 32, 32, device="cuda").to(dtype=torch.long)
        z = torch.tensor(3., dtype=torch.double)
        run_scalar(x, z)
    @unittest.skipIf(not RUN_NVFUSER, "requires CUDA")
    @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
                     "Requires fusion optimization pass to be effective")
    def test_unary_bitwise(self):
        def bit_not(x: torch.Tensor):
            return ~(x + 1)

        jitted = torch.jit.script(bit_not)
        x = torch.randn(4, 8, 32, 32, dtype=torch.float, device="cuda").mul(5).to(torch.long)
        jit_o = jitted(x)
        jit_o = jitted(x)
        o = bit_not(x)
        self.assertEqual(o, jit_o)
        jitted.graph_for(x)  # fusion shows up in the second run, not the first
        self.assertGraphContains(jitted.graph_for(x), FUSION_GUARD)

        def bool_not(x: torch.Tensor, y: torch.Tensor):
            return ~(x & y)

        jitted = torch.jit.script(bool_not)
        x = torch.rand(4, 8, 32, 32, dtype=torch.float, device="cuda").round().to(torch.bool)
        y = torch.rand(4, 8, 32, 32, dtype=torch.float, device="cuda").round().to(torch.bool)
        jit_o = jitted(x, y)
        jit_o = jitted(x, y)
        o = bool_not(x, y)
        self.assertEqual(o, jit_o)
        jitted.graph_for(x, y)  # fusion shows up in the second run, not the first
        self.assertGraphContains(jitted.graph_for(x, y), FUSION_GUARD)
    def _get_scalar_binary_test_fn(self, category_and_type1, category_and_type2, operation):
        category1, dtype_arg1 = category_and_type1
        category2, dtype_arg2 = category_and_type2

        def t_intx_tensory(x: int, y: torch.Tensor):
            o = operation(x, y)
            o = 2 + o
            return o

        def t_doublex_tensory(x: float, y: torch.Tensor):
            o = operation(x, y)
            o = 2 + o
            return o

        def t_cdoublex_tensory(x: complex, y: torch.Tensor):
            o = operation(x, y)
            o = 2 + o
            return o

        # Omit both scalar cases and swap cases
        assert category1 == "scalar" and category2 != "scalar"
        if dtype_arg1.is_floating_point:
            return t_doublex_tensory
        if dtype_arg1 == torch.int64 or dtype_arg1 == torch.int32:
            return t_intx_tensory
        if dtype_arg1.is_complex or dtype_arg1 == torch.int32:
            return t_cdoublex_tensory
        raise NotImplementedError
    def _binary_test_helper(self, operation, dtypes, random_data, categories="ndim"):
        if isinstance(dtypes, tuple):
            dtype_arg1, dtype_arg2 = dtypes
        else:
            dtype_arg1 = dtype_arg2 = dtypes
        if isinstance(categories, tuple) and random_data:
            category1, category2 = categories
        elif not random_data:
            category1 = category2 = "ndim"
        else:
            category1 = category2 = categories

        def is_cpu_category(x):
            return x == "0dimcpu" or x == "scalar"

        # skip unsupported cases
        if is_cpu_category(category1) and is_cpu_category(category2):
            return
        # only test cases with the first operand as a scalar
        if category2 == "scalar":
            return
        # skip ops that don't support scalar inputs in eager
        if operation in [
            torch.atan2,
            torch.max,
            torch.min,
            torch.remainder,  # unsupported in nvfuser
        ]:
            if category1 == "scalar" or category2 == "scalar":
                return
        if operation in [
            torch.fmod,
            torch.eq,
            torch.ne,
            torch.ge,
            torch.gt,
            torch.le,
            torch.lt
        ]:
            if category1 == "scalar":
                return
        # operators that do not support bfloat16
        if operation in [torch.fmod]:
            if dtype_arg1 == torch.bfloat16 or dtype_arg2 == torch.bfloat16:
                return

        def t(x: torch.Tensor, y: torch.Tensor, z: torch.Tensor):
            o = operation(x, y)
            o = o + z
            return o
        shape = (4, 32, 32)
        shapex = shape if category1 == "ndim" else ()
        shapey = shape if category2 == "ndim" else ()
        if random_data:
            x = (torch.randn(shapex, dtype=torch.float, device="cuda") * 5).to(dtype_arg1)
            y = (torch.randn(shapey, dtype=torch.float, device="cuda") * 5).to(dtype_arg2)
        else:
            x = self.special_values.to(dtype=dtype_arg1)
            y = (torch.rand_like(self.special_values) * 5).to(dtype_arg2)
        r"""
        Category conversion
        """
        has_scalar = False
        if category1 == "scalar":
            has_scalar = True
            x = x.item()
        if category1 == "0dimcpu":
            x = x.to(device="cpu")
        if category2 == "scalar":
            has_scalar = True
            y = y.item()
        if category2 == "0dimcpu":
            y = y.to(device="cpu")
        z = torch.tensor([2], device="cuda").to(dtype_arg1)
        is_dtype_arg1_int = dtype_arg1 == torch.int32 or dtype_arg1 == torch.int64
        is_dtype_arg2_int = dtype_arg2 == torch.int32 or dtype_arg2 == torch.int64
        if operation in [torch.pow]:
            if is_dtype_arg1_int and is_dtype_arg2_int:
                if category2 == "scalar":
                    # RuntimeError: Integers to negative integer powers are not allowed
                    y = abs(y)
                if category2 == "0dimcpu" and y == -1:
                    # https://github.com/pytorch/pytorch/issues/73196
                    y = y - 1
                if category2 == "0dimcpu" and y == -2:
                    # avoid pow(0, -2), which gives inconsistent results on integer tensors
                    y = y - 1
        # Avoid division by zero for integer tensors
        div_like = [torch.div, torch.fmod, torch.remainder]
        if operation in div_like and (dtype_arg2 == torch.int32 or dtype_arg2 == torch.int64):
            y[y == 0] = 1
        test_value = True
        if dtype_arg1 == torch.half or dtype_arg2 == torch.half:
            test_value = False
        if dtype_arg1 == torch.bfloat16 or dtype_arg2 == torch.bfloat16:
            test_value = False
        try:
            if not has_scalar:
                o = t(x, y, z)
                t_jit = torch.jit.script(t)
                jit_o = t_jit(x, y, z)
                jit_o = t_jit(x, y, z)
                jit_o = t_jit(x, y, z)
                self.assertEqual(o.dtype, jit_o.dtype)
                if test_value:
                    self.assertEqual(o, jit_o)
                self.assertGraphContains(t_jit.graph_for(x, y, z), FUSION_GUARD)
            elif category2 != "scalar":  # only test the case where the first operand is a scalar
                test_fn = self._get_scalar_binary_test_fn((category1, dtype_arg1), (category2, dtype_arg2), operation)
                o = test_fn(x, y)
                t_jit = torch.jit.script(test_fn)
                jit_o = t_jit(x, y)
                jit_o = t_jit(x, y)
                jit_o = t_jit(x, y)
                self.assertEqual(o.dtype, jit_o.dtype)
                if test_value:
                    self.assertEqual(o, jit_o)
                self.assertGraphContains(t_jit.graph_for(x, y), FUSION_GUARD)
        except Exception as e:
            print("failing test for op: ", operation.__name__)
            print("with input\n\tx: ", x)
            print("\ty: ", y)
            print("\tz: ", z)
            raise e
    @unittest.skipIf(not RUN_NVFUSER, "requires CUDA")
    @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
                     "Requires fusion optimization pass to be effective")
    def test_binary_ops(self):
        data_types = [
            torch.int32,
            torch.int64,
            torch.float16,
            torch.float32,
            torch.float64,
        ]
        if TEST_BF16:
            data_types.append(torch.bfloat16)
        operations = [torch.mul,
                      torch.div,
                      torch.atan2,
                      torch.max,
                      torch.min,
                      torch.pow,
                      torch.remainder,
                      torch.fmod,
                      torch.eq,
                      torch.ne,
                      torch.ge,
                      torch.gt,
                      torch.le,
                      torch.lt]
        category_types = [
            "scalar",
            "0dim",
            "0dimcpu",
            "ndim"
        ]
        binary_dtype_combinations = list(itertools.combinations(data_types, 2))
        category_combinations = list(itertools.combinations(category_types, 2))
        for op, dtypes, categories in itertools.product(operations, binary_dtype_combinations, category_combinations):
            self._binary_test_helper(op, dtypes, True, categories)  # random data
        for op, dtypes in itertools.product(operations, binary_dtype_combinations):
            self._binary_test_helper(op, dtypes, False)  # special numbers
    # TODO: revert this
    @unittest.skipIf(True, "see issue https://github.com/csarofeen/pytorch/issues/1730")
    @unittest.skipIf(not RUN_NVFUSER, "requires CUDA")
    @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
                     "Requires fusion optimization pass to be effective")
    def test_binary_ops_complex(self):
        data_types = [torch.cfloat, torch.cdouble]
        operations = [torch.mul, torch.div, torch.pow, torch.eq, torch.ne]
        category_types = [
            "scalar",
            "0dim",
            "0dimcpu",
            "ndim"
        ]
        binary_dtype_combinations = list(itertools.combinations(data_types, 2))
        category_combinations = list(itertools.combinations(category_types, 2))
        for op, dtypes, categories in itertools.product(operations, binary_dtype_combinations, category_combinations):
            self._binary_test_helper(op, dtypes, True, categories)  # random data
        for op, dtypes in itertools.product(operations, binary_dtype_combinations):
            self._binary_test_helper(op, dtypes, False)  # special numbers
    @unittest.skipIf(not RUN_NVFUSER, "requires CUDA")
    @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
                     "Requires fusion optimization pass to be effective")
    def test_binary_bitwise(self):
        dtypes = [torch.bool, torch.int32, torch.int64]
        for dtype1, dtype2, dtype3 in itertools.product(dtypes, repeat=3):
            def jit_and(x: torch.Tensor, y: torch.Tensor, z: torch.Tensor):
                return torch.bitwise_and(x, y) & z

            def jit_or(x: torch.Tensor, y: torch.Tensor, z: torch.Tensor):
                return torch.bitwise_or(x, y) | z

            def jit_xor(x: torch.Tensor, y: torch.Tensor, z: torch.Tensor):
                return torch.bitwise_xor(x, y) ^ z

            def jit_lshift(x: torch.Tensor, y: torch.Tensor, z: torch.Tensor):
                return torch.bitwise_left_shift(x, y) << z

            def jit_rshift(x: torch.Tensor, y: torch.Tensor, z: torch.Tensor):
                return torch.bitwise_right_shift(x, y) >> z

            for jit_func in [jit_and, jit_or, jit_xor, jit_lshift, jit_rshift]:
                if torch.bool in {dtype1, dtype2, dtype3} and jit_func in {jit_lshift, jit_rshift}:
                    continue
                x = torch.randn(4, 8, 32, 32, dtype=torch.float, device="cuda").mul(5).to(dtype1)
                y = torch.randn(4, 8, 32, 32, dtype=torch.float, device="cuda").mul(5).to(dtype2)
                z = torch.randn(4, 8, 32, 32, dtype=torch.float, device="cuda").mul(2).to(dtype3)
                jitted = torch.jit.script(jit_func)
                jit_o = jitted(x, y, z)
                jit_o = jitted(x, y, z)
                o = jit_func(x, y, z)
                self.assertEqual(o, jit_o)
                self.assertGraphContains(jitted.graph_for(x, y, z), FUSION_GUARD)
    @unittest.skipIf(not RUN_NVFUSER, "requires CUDA")
    @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
                     "Requires fusion optimization pass to be effective")
    def test_type_as_op(self):
        def t(x: torch.Tensor, y: torch.Tensor, z: float):
            o = torch.lt(x, z)
            o = o.type_as(y)
            return o
        t_jit = torch.jit.script(t)
        x = torch.randn(4, 8, 32, 32, dtype=torch.float, device="cuda")
        y = torch.randn(4, 8, 32, 32, dtype=torch.float, device="cuda")
        jit_o = t_jit(x, y, 0.5)
        jit_o = t_jit(x, y, 0.5)
        o = t(x, y, 0.5)
        self.assertEqual(o, jit_o)
        self.assertGraphContains(t_jit.graph_for(x, y, 0.5), FUSION_GUARD)
    def _ternary_integer_test_helper(self, dtype_arg1):
        shape = (4, 8, 32, 32)
        magnitude = 100
        if dtype_arg1 in self.int_types:
            x = torch.randint(-magnitude, magnitude, shape, dtype=dtype_arg1, device="cuda")
        else:
            x = torch.randn(shape, dtype=dtype_arg1, device="cuda") * magnitude
        arg2 = int(0)
        arg3 = int(magnitude * 0.1)

        def clamp0(x: torch.Tensor, f: int):
            o = 2. * torch.clamp(x, min=f)
            return o
        clamp0_jit = torch.jit.script(clamp0)
        self._run_helper(clamp0_jit, clamp0, x, arg2)

        def clamp1(x: torch.Tensor, f: int, ff: int):
            o = 2. * torch.clamp(x, min=f, max=ff)
            return o
        clamp1_jit = torch.jit.script(clamp1)
        self._run_helper(clamp1_jit, clamp1, x, arg2, arg3)

        def clamp2(x: torch.Tensor, f: float, ff: int):
            o = 2. * torch.clamp(x, min=f, max=ff)
            return o
        clamp2_jit = torch.jit.script(clamp2)
        self._run_helper(clamp2_jit, clamp2, x, float(arg2), arg3)

        def clamp3(x: torch.Tensor, f: int, ff: float):
            o = 2. * torch.clamp(x, min=f, max=ff)
            return o
        clamp3_jit = torch.jit.script(clamp3)
        self._run_helper(clamp3_jit, clamp3, x, arg2, float(arg3))

        def threshold(x: torch.Tensor, th: int, val: int):
            o = 2. * torch.threshold(x, th, val)
            return o
        threshold_jit = torch.jit.script(threshold)
        self._run_helper(threshold_jit, threshold, x, arg2, arg3)
    @unittest.skipIf(not RUN_NVFUSER, "requires CUDA")
    @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
                     "Requires fusion optimization pass to be effective")
    def test_ternary_ops_integer_compatibility(self):
        data_types = [
            torch.float16,
            torch.float32,
            torch.float64
        ]
        for dtype in data_types:
            self._ternary_integer_test_helper(dtype)
    def _ternary_test_helper(self, operation, dtypes, random_data):
        if isinstance(dtypes, tuple):
            dtype_arg1, dtype_arg2, dtype_arg3 = dtypes
        else:
            dtype_arg1 = dtype_arg2 = dtype_arg3 = dtypes

        def t(x: torch.Tensor, y: torch.Tensor, z: torch.Tensor, alpha: torch.Tensor):
            o = operation(x, y, z)
            o = o + alpha
            return o
        shape = (4, 32, 32)
        if operation is torch.where:
            dtype_arg1 = torch.bool
            if random_data:
                x = torch.randint(0, 2, shape).to(dtype=torch.bool, device="cuda")
                y = (torch.randn(shape, dtype=torch.float, device="cuda") * 5).to(dtype_arg2)
                z = (torch.randn(shape, dtype=torch.float, device="cuda") * 5).to(dtype_arg3)
            else:
                x = torch.randint(0, 2, self.special_values.size()).to(dtype=torch.bool, device="cuda")
                y = self.special_values.to(dtype=dtype_arg2)
                z = (torch.rand_like(self.special_values) * 5).to(dtype_arg3)
        elif random_data:
            x = (torch.randn(shape, dtype=torch.float, device="cuda") * 5).to(dtype_arg1)
            y = (torch.randn(shape, dtype=torch.float, device="cuda") * 5).to(dtype_arg2)
            z = (torch.randn(shape, dtype=torch.float, device="cuda") * 5).to(dtype_arg3)
        else:
            x = self.special_values.to(dtype=dtype_arg1)
            y = (torch.rand_like(self.special_values) * 5).to(dtype_arg2)
            z = (torch.rand_like(self.special_values) * 5).to(dtype_arg3)
        alpha = torch.tensor([2], device="cuda").to(dtype_arg1)
        o = t(x, y, z, alpha)
        t_jit = torch.jit.script(t)
        jit_o = t_jit(x, y, z, alpha)
        jit_o = t_jit(x, y, z, alpha)
        self.assertEqual(o.dtype, jit_o.dtype)
        self.assertEqual(o, jit_o)
        self.assertGraphContains(t_jit.graph_for(x, y, z), FUSION_GUARD)
    @unittest.skipIf(not RUN_NVFUSER, "requires CUDA")
    @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
                     "Requires fusion optimization pass to be effective")
    def test_ternary_ops_type_promotion(self):
        # TODO: update accuracy tolerance for bf16 / fp16 data types
        data_types = [
            # torch.float16,
            torch.float32,
            torch.float64
        ]
        '''
        if TEST_BF16:
            data_types.append(torch.bfloat16)
        '''
        # TODO: Add Tensor support for clamp
        operations = [torch.clamp]
        ternary_dtype_combinations = itertools.combinations(data_types, 3)
        for op, dtypes in itertools.product(operations, ternary_dtype_combinations):
            self._ternary_test_helper(op, dtypes, True)   # random data
            self._ternary_test_helper(op, dtypes, False)  # special numbers
    # We can't test the scalar version of rsub from python
    @unittest.skipIf(not RUN_NVFUSER, "requires CUDA")
    @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING, "Requires fusion optimization pass to be effective")
    def test_rsub(self):
        x = torch.randn(4, 8, 32, 32, dtype=torch.float, device="cuda")
        y = torch.randn(4, 8, 32, 32, dtype=torch.float, device="cuda")

        def rsub(x: torch.Tensor, y: torch.Tensor):
            o = torch.rsub(x, y)
            o = o * 2.
            return o
        rsub_jit = torch.jit.script(rsub)
        self._run_helper(rsub_jit, rsub, x, y)

    @unittest.skipIf(not RUN_NVFUSER, "requires CUDA")
    # legacy fuser does not work for rand_like, see issue #34361
    @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING, "Requires fusion optimization pass to be effective")
    def test_ternary_ops(self):
        x = torch.randn(4, 8, 32, 32, dtype=torch.float, device="cuda")
        y = torch.randn(4, 8, 32, 32, dtype=torch.float, device="cuda")
        z = torch.randn(4, 8, 32, 32, dtype=torch.float, device="cuda")
        cond = torch.randint(0, 2, (4, 8, 32, 32)).to(dtype=torch.bool, device="cuda")

        def add(x: torch.Tensor, other: torch.Tensor, alpha: float):
            o = torch.relu(x)
            o = torch.add(o, other=other, alpha=alpha)
            return o
        add_jit = torch.jit.script(add)
        self._run_helper(add_jit, add, x, y, 2.0)

        def clamp0(x: torch.Tensor, f: float):
            o = 2. * torch.clamp(x, min=f)
            return o
        clamp0_jit = torch.jit.script(clamp0)
        self._run_helper(clamp0_jit, clamp0, x, 0.5)

        def clamp1(x: torch.Tensor, f: float, ff: float):
            o = 2. * torch.clamp(x, min=f, max=ff)
            return o
        clamp1_jit = torch.jit.script(clamp1)
        self._run_helper(clamp1_jit, clamp1, x, -0.2, 0.7)

        def threshold(x: torch.Tensor, th: float, val: float):
            o = 2. * torch.threshold(x, th, val)
            return o
        threshold_jit = torch.jit.script(threshold)
        self._run_helper(threshold_jit, threshold, x, 0.2, 0.9)

        def where(x: torch.Tensor, y: torch.Tensor, cond: torch.Tensor):
            o = 2. * torch.where(cond, x, y)
            return o
        where_jit = torch.jit.script(where)
        self._run_helper(where_jit, where, x, y, cond)

        def lerp(x: torch.Tensor, y: torch.Tensor, z: torch.Tensor):
            o = 2. * torch.lerp(x, y, z)
            return o
        lerp_jit = torch.jit.script(lerp)
        self._run_helper(lerp_jit, lerp, x, y, z)

        def lerp_scale(x: torch.Tensor, y: torch.Tensor, z: float):
            o = 2. * torch.lerp(x, y, z)
            return o
        lerp_scale_jit = torch.jit.script(lerp_scale)
        self._run_helper(lerp_scale_jit, lerp_scale, x, y, 0.5)
    @unittest.skipIf(not RUN_NVFUSER, "requires CUDA")
    @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING, "Requires profiling node to run cuda fuser")
    def test_addcmul_ops(self):
        x = torch.randn(4, 8, 32, 32, dtype=torch.float, device="cuda")
        y = torch.randn(4, 8, 32, 32, dtype=torch.float, device="cuda")
        z = torch.randn(4, 8, 32, 32, dtype=torch.float, device="cuda")

        def addcmul(x: torch.Tensor, y: torch.Tensor, z: torch.Tensor, value: float):
            o = torch.add(x, 0.5)
            o = torch.addcmul(o, y, z, value=value)
            return o
        addcmul_jit = torch.jit.script(addcmul)
        self._run_helper(addcmul_jit, addcmul, x, y, z, 2.0)

        def addcmul_no_alpha(x: torch.Tensor, y: torch.Tensor, z: torch.Tensor):
            o = torch.add(x, 0.5)
            o = torch.addcmul(o, y, z)
            return o
        addcmul_no_alpha_jit = torch.jit.script(addcmul_no_alpha)
        self._run_helper(addcmul_no_alpha_jit, addcmul_no_alpha, x, y, z)

        def addcmul_const_alpha(x: torch.Tensor, y: torch.Tensor, z: torch.Tensor):
            o = torch.add(x, 0.5)
            o = torch.addcmul(o, y, z, value=0.75)
            return o
        addcmul_const_alpha_jit = torch.jit.script(addcmul_const_alpha)
        self._run_helper(addcmul_const_alpha_jit, addcmul_const_alpha, x, y, z)
    @unittest.skipIf(not RUN_NVFUSER, "requires CUDA")
    @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
                     "Requires fusion optimization pass to be effective")
    def test_dynamic_size(self):
        old_guard = torch._C._jit_set_nvfuser_guard_mode(True)
        torch._C._jit_set_bailout_depth(20)

        def t(x: torch.Tensor, y: torch.Tensor, z: float):
            o = x + y
            o = o + z
            return o
        t_jit = torch.jit.script(t)
        x = torch.randn(4, 8, 32, 32, dtype=torch.float, device="cuda")
        y = torch.randn(32, 32, dtype=torch.float, device="cuda")
        jit_o = t_jit(x, y, 2.0)
        jit_o = t_jit(x, y, 2.0)
        o = t(x, y, 2.0)
        self.assertEqual(o, jit_o)
        subgraph = self._getSubgraphInFusion(t_jit.graph_for(x, y, 2.0))
        self.assertGraphContainsExactly(subgraph, 'aten::add', 2, consider_subgraphs=False)
        # this test is not ideal: we rely on the bailout to exercise it, and we
        # don't know a way to inspect the bailout graph to validate the proper
        # fusion
        x = torch.randn(8, 32, 16, 8, dtype=torch.float, device="cuda")
        y = torch.randn(16, 8, dtype=torch.float, device="cuda")
        jit_o = t_jit(x, y, 2.0)
        jit_o = t_jit(x, y, 2.0)
        o = t(x, y, 2.0)
        self.assertEqual(o, jit_o)
        self.assertGraphContains(t_jit.graph_for(x, y, 2.0), FUSION_GUARD)
        x = torch.randn(8, 17, 8, dtype=torch.float, device="cuda")
        y = torch.randn(8, 17, 1, dtype=torch.float, device="cuda")
        jit_o = t_jit(x, y, 2.0)
        jit_o = t_jit(x, y, 2.0)
        o = t(x, y, 2.0)
        self.assertEqual(o, jit_o)
        self.assertGraphContains(t_jit.graph_for(x, y, 2.0), FUSION_GUARD)
        torch._C._jit_set_nvfuser_guard_mode(old_guard)
    @unittest.skipIf(not RUN_NVFUSER, "requires CUDA")
    @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
                     "Requires fusion optimization pass to be effective")
    def test_random_topo(self):
        os.environ["PYTORCH_NVFUSER_DISABLE_FALLBACK"] = "1"
        self.assertTrue(runDefaultTestWithSeed(28449))

    def _compare(self, desc, inp1, inp2, error):
        a = inp1.clone()
        b = inp2.clone()
        close = torch.allclose(a, b, rtol=error, atol=error, equal_nan=True)
        if not close:
            print(desc, close)
            z = a - b
            index = (torch.abs(z) >= error + error * torch.abs(b)).nonzero()
            print("diff : ", z[index])
            print("inp1 : ", a[index])
            print("inp2 : ", b[index])
            print("maximum difference", z[index].max())
        return close
    # Permutation helper that applies a binary operation between two tensors:
    #   1. applies separate permutations `perm0` & `perm1` to the two inputs
    #   2. reduces dimension `broadcast_axis` of operand two to size 1
    # The purpose of this test is to ensure permutation works correctly in
    # complicated cases with arbitrary stride order and broadcasting dimensions
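    # Worked example (hypothetical values): with sizes=[4, 8, 12] and
    # perm0=(2, 0, 1), x is allocated with shape [12, 4, 8] and then permuted
    # back so that its logical shape is [4, 8, 12] while its strides still
    # reflect the [12, 4, 8] allocation, i.e. a non-contiguous stride order.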
    def _permutation_helper(self, sizes, broadcast_axis, dtype, device, perm0, perm1):
        def t(x: torch.Tensor, y: torch.Tensor):
            o = torch.add(x, y)
            o = torch.relu(o)
            return o
        x = torch.randn([sizes[i] for i in perm0], dtype=dtype, device=device).permute(
            [perm0.index(i) for i in range(len(sizes))])
        if broadcast_axis >= 0:
            sizes[broadcast_axis] = 1
        y = torch.randn([sizes[i] for i in perm1], dtype=dtype, device=device).permute(
            [perm1.index(i) for i in range(len(sizes))])
        t_jit = torch.jit.script(t)
        jit_o = t_jit(x, y)
        jit_o = t_jit(x, y)
        o = t(x, y)
        self.assertEqual(o.dtype, jit_o.dtype)
        self.assertEqual(o, jit_o)
        self.assertEqual(o.stride(), jit_o.stride())
        self.assertGraphContains(t_jit.graph_for(x, y), FUSION_GUARD)
    # end-to-end test of permutation & contiguity handling in integration.
    # we test inputs with all combinations of permutation order, just to
    # ensure that integration is able to generate functionally correct
    # kernels
    @unittest.skipIf(not RUN_NVFUSER, "requires CUDA")
    @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
                     "Requires fusion optimization pass to be effective")
    def test_binary_ops_permutation(self):
        # note that num_dim is exclusive from len(x), so we are not reducing
        # to a single element (codegen limitation at this moment)
        x = [7, 8, 12]
        b_axes = range(-1, len(x))
        for b_axis in b_axes:
            for perm0 in itertools.permutations(range(len(x))):
                for perm1 in itertools.permutations(range(len(x))):
                    x = [7, 8, 12]
                    self._permutation_helper(x, b_axis, torch.float32, "cuda", perm0, perm1)
    @unittest.skipIf(not RUN_NVFUSER, "requires CUDA")
    @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
                     "Requires fusion optimization pass to be effective")
    def test_binary_ops_channels_last_with_bcast(self):
        device = "cuda"
        x = torch.randn([4, 3, 2, 5], device=device).to(memory_format=torch.channels_last)
        w = torch.randn([2, 5], device=device)

        def t(x: torch.Tensor, b: torch.Tensor):
            o = x + b
            return torch.relu(o)
        t_jit = torch.jit.script(t)
        jit_o = t_jit(x, w)
        jit_o = t_jit(x, w)
        jit_o = t_jit(x, w)
        o = t(x, w)
        self.assertEqual(o.dtype, jit_o.dtype)
        self.assertTrue(self._compare("comparing output failed", o, jit_o, 1e-4))
        self.assertGraphContains(t_jit.graph_for(x, w), FUSION_GUARD)
    def _reduction_helper(self, sizes, reduction_axis, dtype, device, perm0, perm1, keepdim=False):
        class MyReduction(torch.nn.Module):
            __constants__ = ['reduction_axis', 'keepdim']

            def __init__(self):
                super(MyReduction, self).__init__()
                self.reduction_axis = reduction_axis
                self.keepdim = keepdim

            def forward(self, x: torch.Tensor, y: torch.Tensor):
                o = torch.add(x, y)
                o = torch.sum(o, dim=self.reduction_axis, keepdim=self.keepdim)
                return o

        t = MyReduction()
        x = torch.randn([sizes[i] for i in perm0], dtype=dtype, device=device).permute(
            [perm0.index(i) for i in range(len(sizes))])
        y = torch.randn([sizes[i] for i in perm1], dtype=dtype, device=device).permute(
            [perm1.index(i) for i in range(len(sizes))])
        t_jit = torch.jit.script(t)
        jit_o = t_jit(x, y)
        jit_o = t_jit(x, y)
        o = t(x, y)
        self.assertEqual(o.dtype, jit_o.dtype)
        # numerical issues here due to our scheduling.
        # can't use `self.assertEqual(o, jit_o)`
        self.assertTrue(self._compare("comparing output failed", o, jit_o, 1e-4))
        self.assertGraphContains(t_jit.graph_for(x, y), FUSION_GUARD)
    @unittest.skipIf(is_pre_volta(), "reduction not supported on pre-Volta devices")
    @unittest.skipIf(not RUN_NVFUSER, "requires CUDA")
    @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
                     "Requires fusion optimization pass to be effective")
    def test_reduction(self):
        for x in ([7, 8, 12], [12, 8, 7, 9, 15], [128, 16, 8, 32]):
            # note that num_dim is exclusive from len(x), so we are not reducing
            # to a single element (codegen limitation at this moment)
            for num_reduce_dim in range(1, len(x)):
                for axes in itertools.combinations(range(len(x)), num_reduce_dim):
                    for keepdim in (True, False):
                        perm0 = range(len(x))
                        perm1 = range(len(x))
                        self._reduction_helper(x, axes, torch.float32, "cuda", perm0, perm1, keepdim)
    def _layer_norm_autodiff_helper(self, model, grad, shapes, args):
        jit_model = torch.jit.script(model)
        eps = np.random.random() * 1e-4
        use_cudnn = bool(np.random.randint(0, 2))
        # profile/optimization runs
        for i in range(3):
            jit_o = jit_model(shapes, *args, eps, use_cudnn)
            jit_o.backward(grad)
        ref_args = [t.detach().clone().requires_grad_() for t in args]
        [t.grad.zero_() for t in args]
        jit_o = jit_model(shapes, *args, eps, use_cudnn)
        jit_o.backward(grad)
        o = model(shapes, *ref_args, eps, use_cudnn)
        o.backward(grad)
        self.assertEqual(jit_o, o)
        for arg, ref_arg in zip(args, ref_args):
            self.assertEqual(arg.grad, ref_arg.grad)
        # check fusion in fw & bw; the loops below simply walk to the last
        # element of each collection to reach the backward graph
        g = jit_model.graph_for(shapes, *args, eps, use_cudnn)
        for node in g.nodes():
            n = node
        dbg_state = jit_model.get_debug_state()
        for val in dbg_state.execution_plans.values():
            v = val
        state2 = v.code.grad_executor_states()
        for val in state2[0].execution_plans.values():
            v2 = val
        FileCheck().check(FUSION_GUARD).run(g)
        FileCheck().check(FUSION_GUARD).run(v2.graph)
    @unittest.skipIf(is_pre_volta(), "reduction not supported on pre-Volta devices")
    @unittest.skipIf(not RUN_NVFUSER, "requires CUDA")
    @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
                     "Requires fusion optimization pass to be effective")
    def test_layer_norm_autodiff(self):
        def t_wb(shapes: List[int], x, w, b, eps: float, cudnn: bool):
            o = torch.layer_norm(x, shapes, w, b, eps, cudnn)
            o = torch.relu(o)
            return o

        def t_w(shapes: List[int], x, w, eps: float, cudnn: bool):
            o = torch.layer_norm(x, shapes, w, None, eps, cudnn)
            o = torch.relu(o)
            return o

        def t_b(shapes: List[int], x, b, eps: float, cudnn: bool):
            o = torch.layer_norm(x, shapes, None, b, eps, cudnn)
            o = torch.relu(o)
            return o

        def t(shapes: List[int], x, eps: float, cudnn: bool):
            o = torch.layer_norm(x, shapes, None, None, eps, cudnn)
            o = torch.relu(o)
            return o

        model = {3: t_wb, 2: t_w, 1: t_b, 0: t}
        for w, b in itertools.product([True, False], repeat=2):
            batch = [2]
            # note: awkward shape here to avoid the vectorized fast kernel,
            # which is buggy in aten
            shapes = [2, 7, 3]
            m = model[w * 2 + b]
            grad = torch.randn(batch + shapes, dtype=torch.float32, device="cuda")
            args = [torch.randn(batch + shapes, dtype=torch.float32, device="cuda").requires_grad_()]
            if w:
                args.append(torch.randn(shapes, dtype=torch.float32, device="cuda").requires_grad_())
            if b:
                args.append(torch.randn(shapes, dtype=torch.float32, device="cuda").requires_grad_())
            self._layer_norm_autodiff_helper(m, grad, shapes, args)
    @unittest.skipIf(is_pre_volta(), "reduction not supported on pre-Volta devices")
    @unittest.skipIf(not RUN_NVFUSER, "requires CUDA")
    @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
                     "Requires fusion optimization pass to be effective")
    def test_layer_norm_parser(self):
        dtype = torch.float32
        device = "cuda"
        x = torch.randn([4, 4, 2], dtype=dtype, device=device)
        w = torch.randn([4, 2], dtype=dtype, device=device)
        b = torch.randn([4, 2], dtype=dtype, device=device)

        def t(x: torch.Tensor, w: torch.Tensor, b: torch.Tensor):
            o = torch.relu(x)
            o = torch.layer_norm(o, [4, 2], w, b, 1e-5)
            return o
        o = t(x, w, b)
        t_jit = torch.jit.script(t)
        jit_o = t_jit(x, w, b)
        jit_o = t_jit(x, w, b)
        o = t(x, w, b)
        self.assertGraphContains(t_jit.graph_for(x, w, b), FUSION_GUARD)
    def _native_layer_norm_helper(self, shape, norm_shape, dtype, device, error, affine=True):
        class MyLayerNorm(torch.nn.Module):
            __constants__ = ['norm_shape']

            def __init__(self, elementwise_affine=True):
                super(MyLayerNorm, self).__init__()
                self.norm_shape = norm_shape
                if elementwise_affine:
                    self.weight = torch.randn(norm_shape, dtype=dtype, device=device)
                    self.bias = torch.randn(norm_shape, dtype=dtype, device=device)
                    with torch.no_grad():
                        self.weight.fill_(1)
                        self.bias.fill_(0)
                else:
                    self.weight = None
                    self.bias = None

            def forward(self, x: torch.Tensor):
                o = torch.relu(x)
                o = torch.native_layer_norm(o, self.norm_shape, self.weight, self.bias, 1e-5)
                return o

        t = MyLayerNorm(affine)
        x = torch.randn(shape, dtype=dtype, device=device)
        t_jit = torch.jit.script(t)
        jit_o, jit_mean, jit_rstd = t_jit(x)
        jit_o, jit_mean, jit_rstd = t_jit(x)
        o, mean, rstd = t(x)
        self.assertEqual(o.dtype, jit_o.dtype)
        # numerical issues here due to our scheduling.
        # can't use `self.assertEqual(o, jit_o)`
        self.assertTrue(self._compare("comparing output failed", o, jit_o, error))
        self.assertTrue(self._compare("comparing mean failed", mean, jit_mean, error))
        self.assertTrue(self._compare("comparing rstd failed", rstd, jit_rstd, error))
        self.assertGraphContains(t_jit.graph_for(x), FUSION_GUARD)
    @unittest.skipIf(is_pre_volta(), "reduction not supported on pre-Volta devices")
    @unittest.skipIf(not RUN_NVFUSER, "requires CUDA")
    @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
                     "Requires fusion optimization pass to be effective")
    def test_native_layer_norm(self):
        dims = 4
        rnds = 3
        for idx in range(rnds):
            for offset in range(1, dims):
                for affine in (True, False):
                    input_shape = [random.randint(10, 30) for idx in range(dims)]
                    norm_shape = [input_shape[idx] for idx in range(dims - offset, dims)]
                    self._native_layer_norm_helper(input_shape, norm_shape, torch.float32, "cuda", 1e-4, affine)

    @unittest.skipIf(is_pre_volta(), "reduction not supported on pre-Volta devices")
    @unittest.skipIf(not RUN_NVFUSER, "requires CUDA")
    @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
                     "Requires fusion optimization pass to be effective")
    def test_native_layer_norm_half(self):
        dims = 4
        rnds = 3
        for idx in range(rnds):
            for offset in range(1, dims):
                input_shape = [random.randint(10, 30) for idx in range(dims)]
                norm_shape = [input_shape[idx] for idx in range(dims - offset, dims)]
                self._native_layer_norm_helper(input_shape, norm_shape, torch.float16, "cuda", 5e-3)

    @unittest.skipIf(is_pre_volta(), "reduction not supported on pre-Volta devices")
    @unittest.skipIf(not RUN_NVFUSER, "requires CUDA")
    @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
                     "Requires fusion optimization pass to be effective")
    @unittest.skipIf(not TEST_BF16, "device does not support BFloat16")
    def test_native_layer_norm_bfloat(self):
        dims = 4
        rnds = 3
        for idx in range(rnds):
            for offset in range(1, dims):
                input_shape = [random.randint(10, 30) for idx in range(dims)]
                norm_shape = [input_shape[idx] for idx in range(dims - offset, dims)]
                self._native_layer_norm_helper(input_shape, norm_shape, torch.bfloat16, "cuda", 1e-1)
def _norm_helper(self,
shape,
dtype,
device,
error,
is_batch_norm_else_instance_norm,
memory_format=torch.contiguous_format,
*,
layer_dtype=torch.float32):
class MyBatchNorm(torch.nn.Module):
def __init__(self):
super(MyBatchNorm, self).__init__()
def forward(self, x: torch.Tensor, r_mean: torch.Tensor, r_var: torch.Tensor):
o = torch.nn.functional.batch_norm(x, r_mean, r_var, training=True)
o = torch.relu(o)
return o
class MyInstanceNorm(torch.nn.Module):
def __init__(self):
super(MyInstanceNorm, self).__init__()
def forward(self, x: torch.Tensor, r_mean: torch.Tensor, r_var: torch.Tensor):
o = torch.nn.functional.instance_norm(x, r_mean, r_var, use_input_stats=True)
o = torch.relu(o)
return o
t = MyBatchNorm() if is_batch_norm_else_instance_norm else MyInstanceNorm()
x = torch.randn(shape, dtype=dtype, device=device).to(memory_format=memory_format)
running_mean = torch.zeros(shape[1], dtype=layer_dtype, device=device)
running_var = torch.ones(shape[1], dtype=layer_dtype, device=device)
t_jit = torch.jit.script(t)
eager_running_mean = running_mean.clone()
eager_running_var = running_var.clone()
jit_running_mean = running_mean.clone()
jit_running_var = running_var.clone()
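# warm up with throwaway clones of the running stats so profiling does
# not perturb the eager/jit copies compared below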
jit_o = t_jit(x, running_mean.clone(), running_var.clone())
self.assertTrue(self._compare("prerun comparing running_mean failed", eager_running_mean, jit_running_mean, error))
self.assertTrue(self._compare("prerun comparing running_var failed", eager_running_var, jit_running_var, error))
jit_o = t_jit(x, jit_running_mean, jit_running_var)
o = t(x, eager_running_mean, eager_running_var)
self.assertEqual(o.dtype, jit_o.dtype)
self.assertEqual(o.stride(), jit_o.stride())
# numerical issues here due to our scheduling.
# can't use `self.assertEqual(o, jit_o)`
self.assertTrue(self._compare("comparing output failed", o, jit_o, error))
self.assertTrue(self._compare("comparing running_mean failed", eager_running_mean, jit_running_mean, error))
self.assertTrue(self._compare("comparing running_var failed", eager_running_var, jit_running_var, error))
self.assertGraphContains(t_jit.graph_for(x, running_mean, running_var), FUSION_GUARD)
@unittest.skipIf(is_pre_volta(), "reduction not supported in pre volta device")
@unittest.skipIf(not RUN_NVFUSER, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
def test_layer_norm_trivial_reduce_dim(self):
def t_wb(shapes: List[int], x, w, b, eps: float, cudnn: bool):
o = torch.layer_norm(x, shapes, w, b, eps, cudnn)
o = torch.relu(o)
return o
batch = [1]
shapes = [2, 7, 3]
grad = torch.randn(batch + shapes, dtype=torch.float32, device="cuda")
args = [torch.randn(batch + shapes, dtype=torch.float32, device="cuda").requires_grad_()]
args.append(torch.randn(shapes, dtype=torch.float32, device="cuda").requires_grad_())
args.append(torch.randn(shapes, dtype=torch.float32, device="cuda").requires_grad_())
self._layer_norm_autodiff_helper(t_wb, grad, shapes, args)
@unittest.skipIf(is_pre_volta(), "reduction not supported in pre volta device")
@unittest.skipIf(not RUN_NVFUSER, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
def test_norm_half_layer(self):
size = [2, 4, 2, 2]
for is_batch_norm_else_instance_norm in [False, True]:
for mf in [torch.channels_last, torch.contiguous_format]:
self._norm_helper(size, torch.float16, "cuda", 1e-3, is_batch_norm_else_instance_norm,
memory_format=mf, layer_dtype=torch.float16)
@unittest.skipIf(is_pre_volta(), "reduction not supported in pre volta device")
@unittest.skipIf(not RUN_NVFUSER, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
def test_norm_channels_last(self):
size = [3, 4, 5, 6]
with torch.backends.cudnn.flags(enabled=False):
for is_batch_norm_else_instance_norm in [False, True]:
for mf in [torch.channels_last, torch.contiguous_format]:
self._norm_helper(size, torch.float32, "cuda", 1e-4, is_batch_norm_else_instance_norm, memory_format=mf)
@unittest.skipIf(is_pre_volta(), "reduction not supported in pre volta device")
@unittest.skipIf(not RUN_NVFUSER, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
def test_norm(self):
output_elements = 10000
channel_sizes = [67, 457, 1024, 4096]
with torch.backends.cudnn.flags(enabled=False):
for is_batch_norm_else_instance_norm in [False, True]:
for dims in range(3, 6):
output_size = int(pow(output_elements, 1. / (dims - 1)))
for C in channel_sizes:
x = [output_size for idx in range(dims)]
x[1] = C
self._norm_helper(x, torch.float32, "cuda", 1e-4, is_batch_norm_else_instance_norm)
@unittest.skipIf(is_pre_volta(), "reduction not supported in pre volta device")
@unittest.skipIf(not RUN_NVFUSER, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
def test_norm_large(self):
output_elements = 262144
channel_sizes = [67, 457, 1024]
for is_batch_norm_else_instance_norm in [True, False]:
for dims in range(3, 6):
output_size = int(pow(output_elements, 1. / (dims - 1)))
for C in channel_sizes:
x = [output_size for idx in range(dims)]
x[1] = C
self._norm_helper(x, torch.float32, "cuda", 1e-4, is_batch_norm_else_instance_norm)
@unittest.skipIf(is_pre_volta(), "reduction not supported in pre volta device")
@unittest.skipIf(not RUN_NVFUSER, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
def test_norm_half(self):
output_elements = 10000
channel_sizes = [67, 457, 1024, 4096]
with torch.backends.cudnn.flags(enabled=False):
for is_batch_norm_else_instance_norm in [False, True]:
for dims in range(3, 6):
output_size = int(pow(output_elements, 1. / (dims - 1)))
for C in channel_sizes:
x = [output_size for idx in range(dims)]
x[1] = C
self._norm_helper(x, torch.float16, "cuda", 5e-3, is_batch_norm_else_instance_norm)
@unittest.skipIf(is_pre_volta(), "reduction not supported in pre volta device")
@unittest.skipIf(not RUN_NVFUSER, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
@unittest.skipIf(not TEST_BF16, "device does not support BFloat16")
def test_norm_bfloat(self):
output_elements = 10000
channel_sizes = [67, 457, 1024, 4096]
with torch.backends.cudnn.flags(enabled=False):
for is_batch_norm_else_instance_norm in [False, True]:
for dims in range(3, 6):
output_size = int(pow(output_elements, 1. / (dims - 1)))
for C in channel_sizes:
x = [output_size for idx in range(dims)]
x[1] = C
self._norm_helper(x, torch.bfloat16, "cuda", 1e-1, is_batch_norm_else_instance_norm)
def _softmax_helper(self, shape, reduction_axis, is_log_softmax, dtype, device, error):
class MySoftmax(torch.nn.Module):
__constants__ = ['reduction_axis']
def __init__(self):
super(MySoftmax, self).__init__()
self.reduction_axis = reduction_axis
def forward(self, x: torch.Tensor, y: torch.Tensor):
o = torch.add(x, y)
o = torch.nn.functional.softmax(o, dim=self.reduction_axis)
return o
class MyLogSoftmax(torch.nn.Module):
__constants__ = ['reduction_axis']
def __init__(self):
super(MyLogSoftmax, self).__init__()
self.reduction_axis = reduction_axis
def forward(self, x: torch.Tensor, y: torch.Tensor):
o = torch.add(x, y)
o = torch.nn.functional.log_softmax(o, dim=self.reduction_axis)
return o
gradient_check = (dtype == torch.float64)
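# gradcheck compares analytic with numeric (finite-difference) gradients,
# which is only reliable in double precision, so it is exercised for the
# float64 configuration only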
t = MyLogSoftmax() if is_log_softmax else MySoftmax()
x = torch.randn(shape, dtype=dtype, device=device, requires_grad=gradient_check)
y = torch.randn(shape, dtype=dtype, device=device, requires_grad=gradient_check)
t_jit = torch.jit.script(t)
jit_o = t_jit(x, y)
jit_o = t_jit(x, y)
jit_o = t_jit(x, y)
if gradient_check:
gradcheck(t_jit.forward, [x, y], nondet_tol=1e-5)
else:
o = t(x, y)
self.assertEqual(o.dtype, jit_o.dtype)
# numerical issues here due to our scheduling.
# can't use `self.assertEqual(o, jit_o)`
self.assertTrue(self._compare("comparing output failed", o, jit_o, error))
self.assertGraphContains(t_jit.graph_for(x, y), FUSION_GUARD)
@unittest.skipIf(is_pre_volta(), "reduction not supported in pre volta device")
@unittest.skipIf(not RUN_NVFUSER, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
def test_softmax_dtype(self):
def t(x: torch.Tensor, y: torch.Tensor):
o = torch.mul(x, y)
o = torch.nn.functional.softmax(o, dim=0, dtype=torch.float32)
return o
x = torch.randn([4, 4], dtype=torch.float16, device="cuda").requires_grad_()
y = torch.randn_like(x).requires_grad_()
grad = torch.randn_like(x).float()
ref_x = x.detach().requires_grad_()
ref_y = y.detach().requires_grad_()
o = t(ref_x, ref_y)
o.backward(grad)
t_jit = torch.jit.script(t)
jit_o = t_jit(x, y)
jit_o.backward(grad)
jit_o = t_jit(x, y)
jit_o.backward(grad)
jit_o = t_jit(x, y)
jit_o.backward(grad)
x.grad.zero_()
y.grad.zero_()
jit_o = t_jit(x, y)
jit_o.backward(grad)
self.assertEqual(o.dtype, jit_o.dtype)
self.assertEqual(ref_x.grad, x.grad)
self.assertEqual(ref_y.grad, y.grad)
self.assertTrue(self._compare("comparing output failed", o, jit_o, 1e-3))
self.assertGraphContainsExactly(t_jit.graph_for(x, y), FUSION_GUARD, 1, consider_subgraphs=True)
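# walk the debug state to the differentiated subgraph: take the (single)
# forward execution plan, then its gradient executor's execution plan,
# whose graph is the backward graph that should contain a fusion guard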
bwd_graph = list(
list(t_jit.get_debug_state().execution_plans.values())[
0].code.grad_executor_states()[0].execution_plans.values()
)[0].graph
FileCheck().check(FUSION_GUARD).run(bwd_graph)
@unittest.skipIf(is_pre_volta(), "reduction not supported in pre volta device")
@unittest.skipIf(not RUN_NVFUSER, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
def test__softmax_function(self):
def t(x: torch.Tensor, y: torch.Tensor):
o = torch.mul(x, y)
o = torch._softmax(o, dim=-1, half_to_float=False)
return o
x = torch.randn([4, 4], dtype=torch.float16, device="cuda")
y = torch.randn_like(x)
o = t(x, y)
t_jit = torch.jit.script(t)
jit_o = t_jit(x, y)
jit_o = t_jit(x, y)
jit_o = t_jit(x, y)
self.assertEqual(o.dtype, jit_o.dtype)
self.assertTrue(self._compare("comparing output failed", o, jit_o, 1e-3))
self.assertGraphContainsExactly(t_jit.graph_for(x, y), FUSION_GUARD, 1, consider_subgraphs=True)
@unittest.skipIf(is_pre_volta(), "reduction not supported in pre volta device")
@unittest.skipIf(not RUN_NVFUSER, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
def test__softmax_function_half_to_float(self):
def t(x: torch.Tensor, y: torch.Tensor):
o = torch.mul(x, y)
o = torch._softmax(o, dim=-1, half_to_float=True)
return o
x = torch.randn([4, 4], dtype=torch.float16, device="cuda")
y = torch.randn_like(x)
o = t(x, y)
t_jit = torch.jit.script(t)
jit_o = t_jit(x, y)
jit_o = t_jit(x, y)
jit_o = t_jit(x, y)
self.assertEqual(o.dtype, jit_o.dtype)
self.assertTrue(self._compare("comparing output failed", o, jit_o, 1e-3))
self.assertGraphContainsExactly(t_jit.graph_for(x, y), FUSION_GUARD, 1, consider_subgraphs=True)
@unittest.skipIf(is_pre_volta(), "reduction not supported in pre volta device")
@unittest.skipIf(not RUN_NVFUSER, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
def test_softmax(self):
output_size = 10000
dims = 4
output_size = int(pow(output_size, 1. / dims))
reduction_sizes = [67, 256, 1024, 4096]
# gradient check
for reduction_dim in range(dims):
for is_log_softmax in [False, True]:
shape = [output_size for idx in range(dims)]
self._softmax_helper(shape, reduction_dim, is_log_softmax, torch.float64, "cuda", 1e-4)
for reduction_dim in range(dims):
for reduction_size in reduction_sizes:
x = [output_size for idx in range(dims)]
x[reduction_dim] = reduction_size
for is_log_softmax in [False, True]:
self._softmax_helper(x, reduction_dim, is_log_softmax, torch.float32, "cuda", 1e-4)
@unittest.skipIf(is_pre_volta(), "reduction not supported in pre volta device")
@unittest.skipIf(not RUN_NVFUSER, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
def test_softmax_half(self):
output_size = 10000
dims = 4
output_size = int(pow(output_size, 1. / dims))
reduction_sizes = [67, 256, 1024, 4096]
for reduction_dim in range(dims):
for reduction_size in reduction_sizes:
x = [output_size for idx in range(dims)]
x[reduction_dim] = reduction_size
for is_log_softmax in [False, True]:
self._softmax_helper(x, reduction_dim, is_log_softmax, torch.float16, "cuda", 5e-3)
@unittest.skipIf(is_pre_volta(), "reduction not supported in pre volta device")
@unittest.skipIf(not RUN_NVFUSER, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
@unittest.skipIf(not TEST_BF16, "device does not support BFloat16")
def test_softmax_bfloat(self):
output_size = 10000
dims = 4
output_size = int(pow(output_size, 1. / dims))
reduction_sizes = [67, 256, 1024, 4096]
for reduction_dim in range(dims):
for reduction_size in reduction_sizes:
x = [output_size for idx in range(dims)]
x[reduction_dim] = reduction_size
for is_log_softmax in [False, True]:
self._softmax_helper(x, reduction_dim, is_log_softmax, torch.bfloat16, "cuda", 1e-1)
@unittest.skipIf(is_pre_volta(), "reduction not supported in pre volta device")
@unittest.skipIf(not RUN_NVFUSER, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
def test_reduction_permutation(self):
x = [7, 8, 12]
# note that num_reduce_dim stops short of len(x), so we never reduce
# down to a single element (a codegen limitation at the moment)
for num_reduce_dim in range(1, len(x)):
for axes in itertools.combinations(range(len(x)), num_reduce_dim):
for perm0 in itertools.permutations(range(len(x))):
for perm1 in itertools.permutations(range(len(x))):
self._reduction_helper(x, axes, torch.float32, "cuda", perm0, perm1)
@unittest.skipIf(is_pre_volta(), "reduction not supported in pre volta device")
@unittest.skipIf(not RUN_NVFUSER, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
def test_reduction_multiple_output(self):
old_guard = torch._C._jit_set_nvfuser_guard_mode(True)
torch._C._jit_set_bailout_depth(20)
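# note: guard mode wraps fusion groups in runtime checks on the observed
# input shapes, and the bailout depth (roughly) caps how many
# re-specializations the profiling executor attempts before falling back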
def t(x: torch.Tensor, y: torch.Tensor, scale: float, z: torch.Tensor):
o = torch.mul(x, y)
o = torch.mul(o, scale)
out1 = torch.mul(o, z)
out2 = torch.sum(out1, dim=[2])
return out1, out2
t_jit = torch.jit.script(t)
x = torch.randn(8, 4, 10, 16, dtype=torch.float, device="cuda")
y = torch.randn(8, 4, 10, 16, dtype=torch.float, device="cuda")
z = torch.randn(8, 4, 10, 16, dtype=torch.float, device="cuda")
scale = 0.5
jit_o = t_jit(x, y, scale, z)
jit_o = t_jit(x, y, scale, z)
o = t(x, y, scale, z)
for oo, jit_oo in zip(o, jit_o):
self.assertEqual(oo.dtype, jit_oo.dtype)
self.assertEqual(oo, jit_oo)
self.assertGraphContains(t_jit.graph_for(x, y, scale, z), FUSION_GUARD)
x = x.to(memory_format=torch.channels_last)
y = y.to(memory_format=torch.channels_last)
z = z.to(memory_format=torch.channels_last)
jit_o = t_jit(x, y, scale, z)
jit_o = t_jit(x, y, scale, z)
o = t(x, y, scale, z)
for oo, jit_oo in zip(o, jit_o):
self.assertEqual(oo.dtype, jit_oo.dtype)
self.assertEqual(oo, jit_oo)
self.assertGraphContains(t_jit.graph_for(x, y, scale, z), FUSION_GUARD)
torch._C._jit_set_nvfuser_guard_mode(old_guard)
@unittest.skipIf(not RUN_NVFUSER, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
def test_channels_last_with_broadcast(self):
# setting this to true forces a new graph to be generated for each new
# input with a different broadcast shape
torch._C._jit_set_nvfuser_guard_mode(True)
def t(x: torch.Tensor, y: torch.Tensor):
o = torch.mul(x, y)
o = o + 2.0
return o
t_jit = torch.jit.script(t)
# Single Channel broadcasts
# Test 1
x = torch.randn(8, 4, 10, 16, dtype=torch.float, device="cuda")
x = x.to(memory_format=torch.channels_last)
y = torch.randn(8, 4, 10, 1, dtype=torch.float, device="cuda")
y = y.to(memory_format=torch.channels_last)
jit_o = t_jit(x, y)
jit_o = t_jit(x, y)
o = t(x, y)
self.assertEqual(o.dtype, jit_o.dtype)
self.assertEqual(o.is_contiguous(memory_format=torch.channels_last),
jit_o.is_contiguous(memory_format=torch.channels_last))
self.assertEqual(o, jit_o)
# Test 2
y = torch.randn(8, 4, 1, 16, dtype=torch.float, device="cuda")
y = y.to(memory_format=torch.channels_last)
jit_o = t_jit(x, y)
jit_o = t_jit(x, y)
o = t(x, y)
self.assertEqual(o.dtype, jit_o.dtype)
self.assertEqual(o.is_contiguous(memory_format=torch.channels_last),
jit_o.is_contiguous(memory_format=torch.channels_last))
self.assertEqual(o, jit_o)
# Test 3
y = torch.randn(8, 1, 10, 16, dtype=torch.float, device="cuda")
y = y.to(memory_format=torch.channels_last)
jit_o = t_jit(x, y)
jit_o = t_jit(x, y)
o = t(x, y)
self.assertEqual(o.dtype, jit_o.dtype)
self.assertEqual(o.is_contiguous(memory_format=torch.channels_last),
jit_o.is_contiguous(memory_format=torch.channels_last))
self.assertEqual(o, jit_o)
# Test 4
y = torch.randn(1, 4, 10, 16, dtype=torch.float, device="cuda")
y = y.to(memory_format=torch.channels_last)
jit_o = t_jit(x, y)
jit_o = t_jit(x, y)
o = t(x, y)
self.assertEqual(o.dtype, jit_o.dtype)
self.assertEqual(o.is_contiguous(memory_format=torch.channels_last),
jit_o.is_contiguous(memory_format=torch.channels_last))
self.assertEqual(o, jit_o)
'''
Currently, the JIT doesn't have the tensor merge logic to handle adding
a broadcast tensor with more than one broadcast dimension to a
non-broadcast tensor. Therefore, either of these tests can fail
depending on the sort implementation. The second test is known to fail.
# Two Channel broadcasts
# Test 1
y = torch.randn(8, 4, 1, 1, dtype=torch.float, device="cuda")
y = y.to(memory_format=torch.channels_last)
jit_o = t_jit(x, y)
jit_o = t_jit(x, y)
o = t(x, y)
self.assertEqual(o.dtype, jit_o.dtype)
self.assertEqual(o.is_contiguous(memory_format=torch.channels_last),
jit_o.is_contiguous(memory_format=torch.channels_last))
self.assertEqual(o, jit_o)
# Test 2
y = torch.randn(8, 4, 1, 1, dtype=torch.float, device="cuda")
y = y.to(memory_format=torch.channels_last).transpose(2,3)
x = x.transpose(2,3)
jit_o = t_jit(x, y)
jit_o = t_jit(x, y)
o = t(x, y)
self.assertEqual(o.dtype, jit_o.dtype)
self.assertEqual(o.is_contiguous(memory_format=torch.channels_last),
jit_o.is_contiguous(memory_format=torch.channels_last))
self.assertEqual(o, jit_o)
'''
@unittest.skipIf(is_pre_volta(), "reduction not supported in pre volta device")
@unittest.skipIf(not RUN_NVFUSER, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
def test_pw_single_reduction_partition(self):
sizes = [2, 2, 2]
dtype = torch.float
device = "cuda"
x = torch.randn(sizes, dtype=dtype, device=device)
y = torch.randn(sizes, dtype=dtype, device=device)
z = torch.randn(sizes, dtype=dtype, device=device)
def t(x: torch.Tensor, y: torch.Tensor, z: torch.Tensor):
o = torch.add(x, y)
o = torch.sum(o, dim=[0])
o = torch.add(o, z)
return o
t_jit = torch.jit.script(t)
jit_o = t_jit(x, y, z)
jit_o = t_jit(x, y, z)
o = t(x, y, z)
self.assertEqual(o.dtype, jit_o.dtype)
self.assertEqual(o, jit_o)
self.assertGraphContains(t_jit.graph_for(x, y, z), FUSION_GUARD)
@unittest.skipIf(is_pre_volta(), "reduction not supported in pre volta device")
@unittest.skipIf(not RUN_NVFUSER, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
def test_permutation_preservation(self):
sizes = [2, 3, 4, 5]
dtype = torch.float
device = "cuda"
with nvfuser_singleton_fusion(True):
def t(x: torch.Tensor):
return torch.relu(x)
t_jit = torch.jit.script(t)
x = torch.randn(sizes, dtype=dtype, device=device).to(memory_format=torch.channels_last)
self._run_helper(t_jit, t, x, check_stride=True)
def t(x: torch.Tensor, y: torch.Tensor):
return torch.add(x, y)
t_jit = torch.jit.script(t)
x = torch.randn(sizes, dtype=dtype, device=device).to(memory_format=torch.channels_last)
y = torch.randn(sizes[1:], dtype=dtype, device=device)
self._run_helper(t_jit, t, x, y, check_stride=True)
@unittest.skipIf(not RUN_NVFUSER, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
def test_permutation_preservation_edge_case_0(self):
sizes = [2, 3, 4, 5]
dtype = torch.float
device = "cuda"
x = torch.randn(sizes, dtype=dtype, device=device).to(memory_format=torch.channels_last)
# mismatched rank; *note* a different permutation is recognized by the PE
bias = torch.randn(3, dtype=dtype, device=device).unsqueeze(-1).unsqueeze(-1)
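# bias: shape (3,) -> unsqueeze twice -> (3, 1, 1), which broadcasts
# against the channel dim of the 4-D channels-last input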
def t(x, y):
return x + y
t_jit = torch.jit.script(t)
with nvfuser_singleton_fusion(True):
self._run_helper(t_jit, t, x, bias, check_stride=True)
@unittest.skipIf(not RUN_NVFUSER, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
def test_permutation_preservation_edge_case_1_broken(self):
sizes = [2, 3, 4, 5]
dtype = torch.float
device = "cuda"
x = torch.randn(sizes, dtype=dtype, device=device).to(memory_format=torch.channels_last)
# incompatible permutation; this will cause format propagation to break
bias = torch.randn(4, 5, dtype=dtype, device=device)
def t(x, y):
return x + y
t_jit = torch.jit.script(t)
with nvfuser_singleton_fusion(True):
for _ in range(5):
jit_o = t_jit(x, bias)
o = t(x, bias)
self.assertEqual(o.dtype, jit_o.dtype)
self.assertEqual(o, jit_o)
try:
# nvfuser does not support incompatible permutations; this will throw
self.assertEqual(o.stride(), jit_o.stride())
except Exception as e:
warnings.warn(
"permutation propagation is broken; proper support should come after the nvfuser permutation scheduler update")
self.assertGraphContains(t_jit.graph_for(x, bias), FUSION_GUARD)
@unittest.skipIf(not RUN_NVFUSER, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
def test_permutation_preservation_edge_case_2(self):
sizes = [2, 3, 4, 5]
dtype = torch.float
device = "cuda"
x = torch.randn(sizes, dtype=dtype, device=device).to(memory_format=torch.channels_last)
y = torch.randn(sizes, dtype=dtype, device=device).to(memory_format=torch.channels_last)
z = torch.randn(sizes, dtype=dtype, device=device).to(memory_format=torch.channels_last)
def t(x, y, w):
tmp = torch.lerp(x, y, w)
tmp = torch.clamp(tmp, -1.0, 0.5)
tmp = torch.nn.functional.softplus(tmp)
return torch.threshold(tmp, -2.0, 0.5)
t_jit = torch.jit.script(t)
self._run_helper(t_jit, t, x, y, z, check_stride=True)
@unittest.skipIf(is_pre_volta(), "reduction not supported in pre volta device")
@unittest.skipIf(not RUN_NVFUSER, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
def test_normalization_partition(self):
sizes = [3, 8, 5]
dtype = torch.float
device = "cuda"
x = torch.randn(sizes, dtype=dtype, device=device)
y = torch.randn(sizes, dtype=dtype, device=device)
z = torch.randn(sizes, dtype=dtype, device=device)
r_m = torch.randn(8, dtype=dtype, device=device)
r_v = torch.randn(8, dtype=dtype, device=device)
def t(x: torch.Tensor, y: torch.Tensor, z: torch.Tensor, r_mean: torch.Tensor, r_var: torch.Tensor):
o = torch.add(x, y)
o = torch.nn.functional.softmax(o, dim=0)
o = torch.add(o, z)
o = torch.nn.functional.batch_norm(o, r_mean, r_var, training=True)
return o
t_jit = torch.jit.script(t)
jit_o = t_jit(x, y, z, r_m, r_v)
jit_o = t_jit(x, y, z, r_m, r_v)
o = t(x, y, z, r_m, r_v)
self.assertEqual(o.dtype, jit_o.dtype)
self.assertEqual(o, jit_o)
self.assertGraphContains(t_jit.graph_for(x, y, z, r_m, r_v), FUSION_GUARD)
@unittest.skipIf(is_pre_volta(), "reduction not supported in pre volta device")
@unittest.skipIf(not RUN_NVFUSER, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
def test_sum_to_one(self):
dtype = torch.float
device = "cuda"
x = torch.randn([4, 5, 6], dtype=dtype, device=device)
def t(x: torch.Tensor):
o = torch.add(x, 1)
o = torch.sum(o, dim=[0, 1, 2])
return o
t_jit = torch.jit.script(t)
jit_o = t_jit(x)
jit_o = t_jit(x)
o = t(x)
self.assertEqual(o.dtype, jit_o.dtype)
self.assertEqual(o, jit_o)
self.assertGraphContains(t_jit.graph_for(x), FUSION_GUARD)
@unittest.skipIf(is_pre_volta(), "reduction not supported in pre volta device")
@unittest.skipIf(not RUN_NVFUSER, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
def test_single_reduction_broadcast(self):
dtype = torch.float
device = "cuda"
x = torch.randn([7, 4, 8], dtype=dtype, device=device)
y = torch.randn([4, 8], dtype=dtype, device=device)
z = torch.randn([1, 4, 8], dtype=dtype, device=device)
def t(x: torch.Tensor, y: torch.Tensor, z: torch.Tensor):
o = torch.add(x, y)
o = torch.add(o, z)
o = torch.sum(o, dim=[0])
return o
t_jit = torch.jit.script(t)
jit_o = t_jit(x, y, z)
jit_o = t_jit(x, y, z)
o = t(x, y, z)
self.assertEqual(o.dtype, jit_o.dtype)
self.assertEqual(o, jit_o)
self.assertGraphContains(t_jit.graph_for(x, y, z), FUSION_GUARD)
@unittest.skipIf(is_pre_volta(), "reduction not supported in pre volta device")
@unittest.skipIf(not RUN_NVFUSER, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
def test_trivial_reduction(self):
dtype = torch.float
device = "cuda"
x = torch.randn([1, 4, 8], dtype=dtype, device=device)
def t(x: torch.Tensor):
o = torch.add(x, 1)
o = torch.sum(o, dim=[0])
o = torch.sum(o, dim=[0])
return o
t_jit = torch.jit.script(t)
jit_o = t_jit(x)
jit_o = t_jit(x)
o = t(x)
self.assertEqual(o.dtype, jit_o.dtype)
self.assertEqual(o, jit_o)
self.assertGraphContains(t_jit.graph_for(x), FUSION_GUARD)
@unittest.skipIf(not RUN_NVFUSER, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
def test_profiling_node(self):
dtype = torch.float
device = "cuda"
x = torch.randn(4, 8, 8, 8, dtype=dtype, device=device)
def repro(x: torch.Tensor, alpha: float):
o = torch.rand_like(x)
o = torch.add(o, alpha)
return o
repro_jit = torch.jit.script(repro)
self._run_helper(repro_jit, repro, x, 0.6)
@unittest.skipIf(not RUN_NVFUSER, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
def test_rand_like(self):
dtype = torch.float
device = "cuda"
def t(x: torch.Tensor, alpha: float):
o = torch.rand_like(x)
o = torch.add(o, alpha)
return o
# disabling caching so new inputs will generate a new graph
t.__disable_jit_function_caching__ = True
for m_format in [torch.contiguous_format, torch.channels_last]:
x = torch.randn(4, 5, 6, 7, dtype=dtype, device=device).to(memory_format=m_format)
t_jit = torch.jit.script(t)
self._run_helper(t_jit, t, x, 0.6, check_stride=True)
@unittest.skipIf(is_pre_volta(), "reduction not supported in pre volta device")
@unittest.skipIf(not RUN_NVFUSER, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
def test_reduction_sizes_op(self):
dtype = torch.float
device = "cuda"
x = torch.randn(2, 3, 4, 5, dtype=dtype, device=device)
y = torch.randn(2, 3, 4, 5, dtype=dtype, device=device)
def t(x: torch.Tensor, y: torch.Tensor):
o = x + y
o = torch.relu(o)
o = o.sum((1, 3))
return o.size()
t_jit = torch.jit.script(t)
jit_o = t_jit(x, y)
jit_o = t_jit(x, y)
o = t(x, y)
self.assertEqual(o, jit_o)
# since the output value is not used at all, the fusion operator should
# have been optimized away
self.assertGraphContainsExactly(t_jit.graph_for(x, y), FUSION_GUARD, 0)
@unittest.skipIf(is_pre_volta(), "reduction not supported in pre volta device")
@unittest.skipIf(not RUN_NVFUSER, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
def test_profile_ivalue(self):
dtype = torch.float
device = "cuda"
x = torch.randn([7, 4, 7], dtype=dtype, device=device)
y = torch.randn([7, 4, 7], dtype=dtype, device=device)
def t(x: torch.Tensor, y: torch.Tensor, dim: List[int], keepdim: bool):
o = torch.add(x, y)
o = o.sum(dim, keepdim=keepdim)
return o
t_jit = torch.jit.script(t)
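# profile_ivalue specializes the fusion on the scalar/list values observed
# at runtime (here dim=(0, 1) and keepdim=False) rather than treating
# them as dynamic inputs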
jit_o = t_jit(x, y, (0, 1), False)
jit_o = t_jit(x, y, (0, 1), False)
o = t(x, y, (0, 1), False)
self.assertEqual(o.dtype, jit_o.dtype)
self.assertEqual(o, jit_o)
self.assertGraphContains(t_jit.graph_for(x, y, (0, 1), False), FUSION_GUARD)
@unittest.skipIf(is_pre_volta(), "reduction not supported in pre volta device")
@unittest.skipIf(not RUN_NVFUSER, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
def test_profile_ivalue_multiple_profiles(self):
dtype = torch.float
device = "cuda"
x = torch.randn([7, 4, 7], dtype=dtype, device=device)
def t(x, num: int):
for i in range(num):
# varying reduction axes should break profile_ivalue
tmp = x.sum(i, keepdim=True)
# inplace add on input/output, can't be functionalized/fused
x += tmp
return x
with nvfuser_singleton_fusion(True):
t_jit = torch.jit.script(t)
self._run_helper(t_jit, t, x, 3, num_fusion=0)
@unittest.skipIf(is_pre_volta(), "reduction not supported in pre volta device")
@unittest.skipIf(not RUN_NVFUSER, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
def test_sum_to_size(self):
dtype = torch.float
device = "cuda"
x = torch.randn([2, 4, 4], dtype=dtype, device=device)
y = torch.randn([2, 4, 4], dtype=dtype, device=device)
def t(x: torch.Tensor, y: torch.Tensor, new_size: List[int]):
o = torch.add(x, y)
o = o.sum_to_size(new_size)
return o
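# e.g. for inputs of shape [2, 4, 4], sum_to_size((4, 1)) sums over the
# leading batch dim and collapses the last dim to 1, yielding a [4, 1]
# tensor that still broadcasts against the original shape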
t_jit = torch.jit.script(t)
self._run_helper(t_jit, t, x, y, (4, 1))
# update shape: old kernel should handle dynamic shape well without
# recompilation
x = torch.randn([2, 5, 8], dtype=dtype, device=device)
y = torch.randn([2, 5, 8], dtype=dtype, device=device)
# (TODO) check executed kernel, should extend autograd.profiler to fused
# kernels
self._run_helper(t_jit, t, x, y, (5, 1))
with nvfuser_singleton_fusion(True):
x = torch.randn([2, 5, 8], dtype=dtype, device=device)
def t(x: torch.Tensor):
# no-op reduction
return x.sum_to_size((2, 5, 8))
t_jit = torch.jit.script(t)
self._run_helper(t_jit, t, x)
@unittest.skipIf(is_pre_volta(), "reduction not supported in pre volta device")
@unittest.skipIf(not RUN_NVFUSER, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
def test_grad_sum_to_size(self):
dtype = torch.float
device = "cuda"
x = torch.randn([2, 4, 4], dtype=dtype, device=device).requires_grad_()
y = torch.randn([4], dtype=dtype, device=device).requires_grad_()
grad = torch.randn([2, 4, 4], dtype=dtype, device=device)
ref_x = x.detach().clone().requires_grad_()
ref_y = y.detach().clone().requires_grad_()
def t(x: torch.Tensor, y: torch.Tensor):
o = torch.add(x, y)
o = torch.relu(o)
return o
# profiling runs for forward & backward
t_jit = torch.jit.script(t)
jit_o = t_jit(x, y)
jit_o.backward(grad)
jit_o = t_jit(x, y)
jit_o.backward(grad)
x.grad = None
y.grad = None
jit_o = t_jit(x, y)
jit_o.backward(grad)
o = t(ref_x, ref_y)
o.backward(grad)
self.assertEqual(o.dtype, jit_o.dtype)
self.assertEqual(o, jit_o)
self.assertEqual(x.grad, ref_x.grad)
self.assertEqual(y.grad, ref_y.grad)
bwd_graph = list(
list(t_jit.get_debug_state().execution_plans.values())[
0].code.grad_executor_states()[0].execution_plans.values()
)[0].graph
FileCheck().check(FUSION_GUARD).run(bwd_graph)
# update shape: old kernel should handle dynamic shape well without
# recompilation
x = torch.randn([2, 5, 8], dtype=dtype, device=device).requires_grad_()
y = torch.randn([8], dtype=dtype, device=device).requires_grad_()
ref_x = x.detach().clone().requires_grad_()
ref_y = y.detach().clone().requires_grad_()
grad = torch.randn([2, 5, 8], dtype=dtype, device=device)
jit_o = t_jit(x, y)
# (TODO) check executed kernel, should extend autograd.profiler to fused
# kernels
jit_o.backward(grad)
o = t(ref_x, ref_y)
o.backward(grad)
self.assertEqual(o.dtype, jit_o.dtype)
self.assertEqual(o, jit_o)
self.assertEqual(x.grad, ref_x.grad)
self.assertEqual(y.grad, ref_y.grad)
@unittest.skipIf(not RUN_NVFUSER, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
def test_dropout_inference_fusion(self):
dtype = torch.float
device = "cuda"
x = torch.randn([10, 4, 8], dtype=dtype, device=device)
def t(x: torch.Tensor, p: float, train: bool):
o = torch.nn.functional.dropout(x, p, training=train)
o = o + 1.0
return o
t_jit = torch.jit.script(t)
self._run_helper(t_jit, t, x, 0.15, False)
@unittest.skipIf(not TEST_LARGE_TENSOR, "not enough memory")
@unittest.skipIf(not RUN_NVFUSER, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
def test_dropout_train_nograd_fusion(self):
dtype = torch.float
device = "cuda"
x = torch.randn([64, 128, 1024], dtype=dtype, device=device)
def t(x: torch.Tensor, p: float, train: bool):
o = torch.nn.functional.dropout(x, p, training=train)
o = o + 1.0
return o
t_jit = torch.jit.script(t)
self._run_helper(t_jit, t, x, 0.0, True, check_runs=20)
self._run_helper(t_jit, t, x, 1.0, True, check_runs=20)
@unittest.skipIf(not RUN_NVFUSER, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
def test_dropout_train_nograd_prob_check(self):
dtype = torch.float
device = "cuda"
x = torch.randn([1024, 1024], dtype=dtype, device=device)
def t(x: torch.Tensor, p: float, train: bool):
o = torch.nn.functional.dropout(x, p, training=train)
o = o * 2.0
return o
t_jit = torch.jit.script(t)
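# reseed before each run so the profiling run and the optimized run draw
# the same dropout mask; then verify the observed zero fraction lands
# within +/-1% of p and the fused output stays finite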
for prob in [0.0, 0.15, 0.5, 0.85, 1.]:
torch.cuda.manual_seed_all(123)
jit_o = t_jit(x, prob, True)
torch.cuda.manual_seed_all(123)
jit_o = t_jit(x, prob, True)
self.assertTrue(jit_o.detach().isfinite().all().item())
num_elems = x.numel()
num_zeros = num_elems - jit_o.detach().count_nonzero().item()
percent_zeros = num_zeros / num_elems
self.assertTrue((percent_zeros >= (prob - 0.01)) and (percent_zeros <= (prob + 0.01)))
self.assertGraphContainsExactly(t_jit.graph_for(x, prob, True), FUSION_GUARD, 1, consider_subgraphs=True)
@unittest.skipIf(is_pre_volta(), "reduction not supported in pre volta device")
@unittest.skipIf(not RUN_NVFUSER, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
def test_dropout_training_fusion(self):
dtype = torch.float
device = "cuda"
sizes = [2, 3, 4, 5]
def t(x: torch.Tensor, p: float, train: bool):
o = torch.nn.functional.dropout(x, p, training=train)
o = o * 2.0
return o
def t2(x: torch.Tensor, p: float, train: bool):
o = torch.nn.functional.softmax(x, dim=-1)
o = torch.nn.functional.dropout(o, p, training=train)
return o
# disabling caching so new inputs will generate a new graph
t.__disable_jit_function_caching__ = True
t2.__disable_jit_function_caching__ = True
for fn in [t, t2]:
for m_format in [torch.contiguous_format, torch.channels_last]:
fn_jit = torch.jit.script(fn)
x = torch.randn(sizes, dtype=dtype, device=device, requires_grad=True).to(memory_format=m_format)
grads = torch.randn(sizes, dtype=dtype, device=device).to(memory_format=m_format)
# The drop probability needs to be set to zero given that the order of picking random
# numbers between eager mode and the jit is different
self._run_training_helper(fn_jit, fn, grads, x, 0.0, True)
@unittest.skipIf(not RUN_NVFUSER, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
def test_gelu(self):
old_guard = torch._C._jit_set_nvfuser_guard_mode(True)
dtype = torch.float
device = "cuda"
x = torch.randn([1024, 1024], dtype=dtype, device=device, requires_grad=True)
grads = torch.randn([1024, 1024], dtype=dtype, device=device, requires_grad=False)
def t(x: torch.Tensor, mode: str):
o = torch.nn.functional.gelu(x, approximate=mode)
o = o * 2.0
return o
t_jit = torch.jit.script(t)
self._run_training_helper(t_jit, t, grads, x, 'none')
self._run_training_helper(t_jit, t, grads, x, 'tanh')
torch._C._jit_set_nvfuser_guard_mode(old_guard)
@unittest.skipIf(not RUN_NVFUSER, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
def test_dropout_training_prob_check(self):
dtype = torch.float
device = "cuda"
x = torch.randn([1024, 1024], dtype=dtype, device=device, requires_grad=True)
def t(x: torch.Tensor, p: float, train: bool):
o = torch.nn.functional.dropout(x, p, training=train)
o = o * 2.0
return o
t_jit = torch.jit.script(t)
for prob in [0.0, 0.15, 0.5, 0.85, 1.]:
torch.cuda.manual_seed_all(123)
jit_o = t_jit(x, prob, True)
torch.cuda.manual_seed_all(123)
jit_o = t_jit(x, prob, True)
torch.cuda.manual_seed_all(123)
jit_o = t_jit(x, prob, True)
self.assertTrue(jit_o.detach().isfinite().all().item())
num_elems = x.numel()
num_zeros = num_elems - jit_o.detach().count_nonzero().item()
percent_zeros = num_zeros / num_elems
self.assertTrue((percent_zeros >= (prob - 0.01)) and (percent_zeros <= (prob + 0.01)))
self.assertGraphContainsExactly(t_jit.graph_for(x, prob, True), FUSION_GUARD, 1, consider_subgraphs=True)
@unittest.skipIf(not RUN_NVFUSER, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
def test_linear(self):
in_feature = 2
out_feature = 8
# Changing the input dims to be 3-D to avoid eager mode bias fusion
# The bias fusion causes some precision issues with TF-32
weight = torch.randn(out_feature, in_feature, dtype=torch.float32, device='cuda')
bias = torch.randn(out_feature, dtype=torch.float32, device='cuda')
def t(x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor):
o = torch.nn.functional.linear(x, weight, bias)
o = torch.relu(o)
return o
# disabling caching so new inputs will generate a new graph
t.__disable_jit_function_caching__ = True
sizes = [in_feature, ]
for i in range(4):
# increase input rank in each iteration
sizes.insert(0, i + 2)
x = torch.randn(*sizes, dtype=torch.float32, device='cuda')
t_jit = torch.jit.script(t)
# fusion only happens for input rank >= 4
has_fusion = 0 if len(sizes) < 4 else 1
self._run_helper(t_jit, t, x, weight, bias, check_stride=True, num_fusion=has_fusion)
@unittest.skipIf(not RUN_NVFUSER, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
def test_linear_symbolic_shapes(self):
def fn(x: int):
y = torch.zeros((3, 4, x, x + 2)).cuda()
for i in range(2):
inp = torch.rand((3, 4, x, x + i)).cuda()
weight = torch.rand((x + 2, x + i)).cuda()
bias = torch.rand((x, x + 2)).cuda()
y += torch.sin(torch.nn.functional.linear(inp, weight, bias))
return y
fn_s = torch.jit.script(fn)
fn_s(5)
fn_s(5)
@unittest.skipIf(not RUN_NVFUSER, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
def test_conv2d_symbolic_shapes(self):
def fn(x: int):
responses = []
for i in range(2):
inp = torch.rand((3, 3, 32, 32)).cuda()
weight = torch.rand((x + i, 3, 7, 7)).cuda()
bias = torch.rand((x + i)).cuda()
res = torch.nn.functional.conv2d(inp, weight, bias, padding=3)
responses.append(res)
return responses
fn_s = torch.jit.script(fn)
fn_s(5)
fn_s(5)
@unittest.skipIf(not RUN_NVFUSER, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
def test_backward_type(self):
# not super useful to check gradient of integer/bool, so skipping here
type_pairs = [
(torch.float, torch.half),
(torch.double, torch.half),
(torch.float, torch.double),
]
if TEST_BF16:
type_pairs += [
(torch.float, torch.bfloat16),
(torch.double, torch.bfloat16),
]
for x_type, y_type in type_pairs:
x = torch.randn(4, 2, dtype=x_type, device='cuda', requires_grad=True)
y = torch.randn(4, 2, dtype=y_type, device='cuda', requires_grad=True)
grad = torch.randn(4, 2, dtype=torch.float, device='cuda')
def test1(x: torch.Tensor, y: torch.Tensor):
o = torch.add(x, y)
o = torch.add(o, y)
o = torch.add(o, y)
o = torch.add(o, y)
o = o + 1.0
return o
test1_jit = torch.jit.script(test1)
for i in range(3):
jit_o = test1_jit(x, y)
jit_o.backward(grad)
bwd_graph = list(
list(test1_jit.get_debug_state().execution_plans.values())[
0].code.grad_executor_states()[0].execution_plans.values()
)[0].graph
FileCheck().check(FUSION_GROUP).run(bwd_graph)
self.assertEqual(x.grad.dtype, x.dtype)
self.assertEqual(y.grad.dtype, y.dtype)
@unittest.skipIf(is_pre_volta(), "reduction not supported in pre volta device")
@unittest.skipIf(not RUN_NVFUSER, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
def test_autocast_1(self):
def t(x: torch.Tensor, y: torch.Tensor):
o = x * 2.0
o = torch.softmax(o, dim=-1)
o = o * 3.0
o = torch._C._nn.linear(o, y)
return o
x = torch.randn(8, 4, dtype=torch.half, device='cuda', requires_grad=True)
y = torch.randn(4, 4, dtype=torch.float, device='cuda', requires_grad=True)
grad = torch.randn(8, 4, dtype=torch.half, device='cuda', requires_grad=False)
t_jit = torch.jit.script(t)
for i in range(3):
with torch.cuda.amp.autocast():
jit_o = t_jit(x, y)
if i == 2:
fwd_graph = t_jit.graph_for(x, y)
jit_o.backward(grad)
self.assertGraphContainsExactly(fwd_graph, FUSION_GUARD, 1, consider_subgraphs=True)
with torch.cuda.amp.autocast():
bwd_graph = list(
list(t_jit.get_debug_state().execution_plans.values())[
0].code.grad_executor_states()[0].execution_plans.values()
)[0].graph
FileCheck().check(FUSION_GROUP).run(bwd_graph)
self.assertEqual(jit_o.dtype, torch.half)
self.assertEqual(x.grad.dtype, x.dtype)
self.assertEqual(y.grad.dtype, y.dtype)
@unittest.skipIf(is_pre_volta(), "reduction not supported in pre volta device")
@unittest.skipIf(not RUN_NVFUSER, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
def test_autocast_2(self):
def t(x: torch.Tensor):
o = x * 2.0
o = torch.softmax(o, dim=-1)
o = o * 3.0
o = torch.softmax(o, dim=-1)
o = o * 4.0
return o
x = torch.randn(8, 4, dtype=torch.half, device='cuda', requires_grad=True)
grad = torch.randn(8, 4, dtype=torch.float, device='cuda', requires_grad=False)
t_jit = torch.jit.script(t)
for i in range(3):
with torch.cuda.amp.autocast():
jit_o = t_jit(x)
if i == 2:
fwd_graph = t_jit.graph_for(x)
jit_o.backward(grad)
self.assertGraphContainsExactly(fwd_graph, FUSION_GUARD, 1, consider_subgraphs=True)
with torch.cuda.amp.autocast():
bwd_graph = list(
list(t_jit.get_debug_state().execution_plans.values())[
0].code.grad_executor_states()[0].execution_plans.values()
)[0].graph
FileCheck().check(FUSION_GROUP).run(bwd_graph)
self.assertEqual(jit_o.dtype, torch.float)
self.assertEqual(x.grad.dtype, x.dtype)
@unittest.skipIf(is_pre_volta(), "reduction not supported in pre volta device")
@unittest.skipIf(not RUN_NVFUSER, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
@unittest.skipIf(not TEST_BF16, "device does not support BFloat16")
def test_autocast_1_bfloat(self):
def t(x: torch.Tensor, y: torch.Tensor):
o = x * 2.0
o = torch.softmax(o, dim=-1)
o = o * 3.0
o = torch._C._nn.linear(o, y)
return o
x = torch.randn(8, 4, dtype=torch.bfloat16, device='cuda', requires_grad=True)
y = torch.randn(4, 4, dtype=torch.float, device='cuda', requires_grad=True)
grad = torch.randn(8, 4, dtype=torch.bfloat16, device='cuda', requires_grad=False)
t_jit = torch.jit.script(t)
for i in range(3):
with torch.cuda.amp.autocast(dtype=torch.bfloat16):
jit_o = t_jit(x, y)
if i == 2:
fwd_graph = t_jit.graph_for(x, y)
jit_o.backward(grad)
self.assertGraphContainsExactly(fwd_graph, FUSION_GUARD, 1, consider_subgraphs=True)
with torch.cuda.amp.autocast(dtype=torch.bfloat16):
bwd_graph = list(
list(t_jit.get_debug_state().execution_plans.values())[
0].code.grad_executor_states()[0].execution_plans.values()
)[0].graph
FileCheck().check(FUSION_GROUP).run(bwd_graph)
self.assertEqual(jit_o.dtype, torch.bfloat16)
self.assertEqual(x.grad.dtype, x.dtype)
self.assertEqual(y.grad.dtype, y.dtype)
@unittest.skipIf(is_pre_volta(), "reduction not supported in pre volta device")
@unittest.skipIf(not RUN_NVFUSER, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
@unittest.skipIf(not TEST_BF16, "device does not support BFloat16")
def test_autocast_2_bfloat(self):
def t(x: torch.Tensor):
o = x * 2.0
o = torch.softmax(o, dim=-1)
o = o * 3.0
o = torch.softmax(o, dim=-1)
o = o * 4.0
return o
x = torch.randn(8, 4, dtype=torch.bfloat16, device='cuda', requires_grad=True)
grad = torch.randn(8, 4, dtype=torch.float, device='cuda', requires_grad=False)
t_jit = torch.jit.script(t)
for i in range(3):
with torch.cuda.amp.autocast(dtype=torch.bfloat16):
jit_o = t_jit(x)
if i == 2:
fwd_graph = t_jit.graph_for(x)
jit_o.backward(grad)
self.assertGraphContainsExactly(fwd_graph, FUSION_GUARD, 1, consider_subgraphs=True)
with torch.cuda.amp.autocast(dtype=torch.bfloat16):
bwd_graph = list(
list(t_jit.get_debug_state().execution_plans.values())[
0].code.grad_executor_states()[0].execution_plans.values()
)[0].graph
FileCheck().check(FUSION_GROUP).run(bwd_graph)
self.assertEqual(jit_o.dtype, torch.float)
self.assertEqual(x.grad.dtype, x.dtype)
@unittest.skipIf(not RUN_NVFUSER, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
def test_to_dtype_fp32_to_fp16(self):
def t(x: torch.Tensor):
o = x * 2.0
o = o.to(dtype=torch.half)
o = o * 3.0
return o
x = torch.randn(8, 4, dtype=torch.float, device='cuda')
t_jit = torch.jit.script(t)
for i in range(3):
jit_o = t_jit(x)
self.assertGraphContainsExactly(t_jit.graph_for(x), FUSION_GUARD, 1)
self.assertEqual(jit_o.dtype, torch.half)
@unittest.skipIf(not RUN_NVFUSER, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
def test_to_dtype_fp16_to_fp32(self):
def t(x: torch.Tensor):
o = x * 2.0
o = o.to(dtype=torch.float)
o = o * 3.0
return o
x = torch.randn(8, 4, dtype=torch.half, device='cuda')
t_jit = torch.jit.script(t)
for i in range(3):
jit_o = t_jit(x)
self.assertGraphContainsExactly(t_jit.graph_for(x), FUSION_GUARD, 1)
self.assertEqual(jit_o.dtype, torch.float)
@unittest.skipIf(not RUN_NVFUSER, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
def test_to_dtype_fp16_to_fp16(self):
def t(x: torch.Tensor):
o = x * 2.0
o = o.to(dtype=torch.half)
o = o * 3.0
return o
x = torch.randn(8, 4, dtype=torch.half, device='cuda')
t_jit = torch.jit.script(t)
for i in range(3):
jit_o = t_jit(x)
self.assertGraphContainsExactly(t_jit.graph_for(x), FUSION_GUARD, 1)
self.assertEqual(jit_o.dtype, torch.half)
@unittest.skipIf(not RUN_NVFUSER, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
@unittest.skipIf(not TEST_BF16, "device does not support BFloat16")
def test_to_dtype_fp32_to_bf16(self):
def t(x: torch.Tensor):
o = x * 2.0
o = o.to(dtype=torch.bfloat16)
o = o * 3.0
return o
x = torch.randn(8, 4, dtype=torch.float, device='cuda')
t_jit = torch.jit.script(t)
for i in range(3):
jit_o = t_jit(x)
self.assertGraphContainsExactly(t_jit.graph_for(x), FUSION_GUARD, 1)
self.assertEqual(jit_o.dtype, torch.bfloat16)
@unittest.skipIf(not RUN_NVFUSER, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
@unittest.skipIf(not TEST_BF16, "device does not support BFloat16")
def test_to_dtype_bf16_to_fp32(self):
def t(x: torch.Tensor):
o = x * 2.0
o = o.to(dtype=torch.float)
o = o * 3.0
return o
x = torch.randn(8, 4, dtype=torch.bfloat16, device='cuda')
t_jit = torch.jit.script(t)
for i in range(3):
jit_o = t_jit(x)
self.assertGraphContainsExactly(t_jit.graph_for(x), FUSION_GUARD, 1)
self.assertEqual(jit_o.dtype, torch.float)
@unittest.skipIf(not RUN_NVFUSER, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
@unittest.skipIf(not TEST_BF16, "device does not support BFloat16")
def test_to_dtype_bf16_to_bf16(self):
def t(x: torch.Tensor):
o = x * 2.0
o = o.to(dtype=torch.bfloat16)
o = o * 3.0
return o
x = torch.randn(8, 4, dtype=torch.bfloat16, device='cuda')
t_jit = torch.jit.script(t)
for i in range(3):
jit_o = t_jit(x)
self.assertGraphContainsExactly(t_jit.graph_for(x), FUSION_GUARD, 1)
self.assertEqual(jit_o.dtype, torch.bfloat16)
@unittest.skipIf(not RUN_NVFUSER, "requires CUDA")
@unittest.skipIf(not TEST_MULTIGPU, "requires multiple CUDA device")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
def test_multiple_device_pw(self):
def t(x):
o = x + 1.0
o = torch.relu(o)
return o
x = torch.randn(2, dtype=torch.float32, device="cuda")
t_jit = torch.jit.script(t)
for i in range(3):
jit_o = t_jit(x)
self.assertGraphContainsExactly(t_jit.graph_for(x), FUSION_GUARD, 1)
torch.cuda.set_device(1)
x = x.to("cuda:1")
jit_o = t_jit(x)
@unittest.skipIf(not RUN_NVFUSER, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
def test_graph_for_with_missing_optimized_engine(self):
x = torch.randn(8, 4, 2, dtype=torch.float, device="cuda").requires_grad_()
def t(x: torch.Tensor, flag: bool):
x = x + 1.0
x = torch.relu(x)
if flag:
o = x + 1.0
o = torch.relu(o)
else:
o = x + 2.0
o = torch.relu(o)
return o
t_jit = torch.jit.script(t)
jit_o = t_jit(x, False)
jit_o = t_jit(x, False)
jit_o = t_jit(x, True)
o = t(x, True)
self.assertEqual(o, jit_o)
# even with the optimized plan for this branch missing, graph_for should
# still report exactly one fusion group (counting subgraphs)
self.assertGraphContainsExactly(t_jit.graph_for(x, True), FUSION_GUARD, 1, True)
@unittest.skipIf(not RUN_NVFUSER, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
def test_branches(self):
in_feature = 2
out_feature = 4
x = torch.randn(4, in_feature, dtype=torch.float32, device='cuda')
weight = torch.randn(out_feature, in_feature, dtype=torch.float32, device='cuda')
bias = torch.randn(out_feature, dtype=torch.float32, device='cuda')
def t(x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor, flag: bool):
if flag:
o = torch.nn.functional.linear(x, weight, bias)
o = o + 1.0
o = torch.relu(o)
else:
o = x.sum()
o = o + 2.0
o = torch.relu(o)
return o
t_jit = torch.jit.script(t)
jit_o = t_jit(x, weight, bias, True)
jit_o = t_jit(x, weight, bias, True)
o = t(x, weight, bias, True)
self.assertEqual(o, jit_o)
# the linear branch is profiled and should produce exactly one fusion group
self.assertGraphContainsExactly(t_jit.graph_for(x, weight, bias, True), FUSION_GUARD, 1)
@unittest.skipIf(not RUN_NVFUSER, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
def test_scalar_tensor(self):
x = torch.empty([], device="cuda", dtype=torch.float32)
def t(x: torch.Tensor):
o = x + 1.0
o = torch.nn.functional.relu(o)
return o
t_jit = torch.jit.script(t)
jit_o = t_jit(x)
jit_o = t_jit(x)
o = t(x)
self.assertEqual(o, jit_o)
# a scalar (0-dim) tensor input should still produce exactly one fusion group
self.assertGraphContainsExactly(t_jit.graph_for(x), FUSION_GUARD, 1)
@unittest.skipIf(os.environ.get('PYTORCH_NO_CUDA_MEMORY_CACHING') is not None,
"skipping graph_rng when caching allocator is disabled")
@unittest.skipIf(not RUN_NVFUSER, "requires CUDA")
@unittest.skipIf(CUDA_MAJOR < 11, "requires CUDA11 or above")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
def test_graph_rng(self):
self.assertTrue(torch._C._jit_nvfuser_enabled())
size = 10000
a = torch.randn((size,), device="cuda", dtype=torch.float)
def t(x):
o = x + 1.0
o = torch.nn.functional.dropout(o, p=0.1)
o = o + 1.0
o = torch.nn.functional.dropout(o, p=0.1)
return o
t_jit = torch.jit.script(t)
for _ in range(3):
t_jit(a)
self.assertGraphContainsExactly(t_jit.graph_for(a), FUSION_GUARD, 1)
# Control (jitted, ungraphed)
torch.cuda.manual_seed(5)
eager_out = a.clone()
for _ in range(3):
eager_out = t_jit(eager_out)
graph_in = a.clone()
g = torch.cuda.CUDAGraph()
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
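# CUDA graph capture has to happen on a non-default stream; the
# wait_stream calls order the capture against prior work on the default
# stream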
with torch.cuda.stream(s):
torch.cuda.manual_seed(5)
g.capture_begin()
graph_out = t_jit(graph_in)
g.capture_end()
torch.cuda.current_stream().wait_stream(s)
# g is now a jitted, graphed version of t.
# Runs a (jitted, graphed) -> (jitted, ungraphed) -> (jitted, graphed) sequence.
# The ops in the overall sequence should be the same as Control.
g.replay()
# graph_out is now filled with g's result. Use it as ungraphed input.
out = t_jit(graph_out)
graph_in.copy_(out)
g.replay()
# If replay() updated RNG state correctly, graph_out should now equal eager_out
self.assertEqual(graph_out, eager_out)
def _test_batch_norm_impl_index_helper(self, batch, c, hw, affine=True,
track_running_stats=True, train=True,
dtype=torch.float32):
# enabling inlining to avoid counter increment in BN forward
torch._C._debug_set_autodiff_subgraph_inlining(True)
class MyModule(torch.nn.Module):
def __init__(self, num_features=10, affine=True, track_running_stats=True):
super(MyModule, self).__init__()
self.bn = torch.nn.BatchNorm2d(num_features,
1e-5,
affine=affine,
track_running_stats=track_running_stats).to(dtype=dtype)
def forward(self, x):
o = self.bn(x)
o = o * 2.0
return o
x = torch.randn(batch, c, hw, hw, dtype=torch.float, device="cuda").to(dtype=dtype).requires_grad_()
grad = torch.randint(-20, 20, (batch, c, hw, hw), device="cuda").to(dtype=dtype).div(-10)
my_module = MyModule(c, affine, track_running_stats).cuda()
ref_module = MyModule(c, affine, track_running_stats).cuda()
if not train:
my_module.eval()
ref_module.eval()
t_jit = torch.jit.script(my_module)
ref_module.load_state_dict(my_module.state_dict())
ref_x = x.detach().requires_grad_()
for i in range(0, 3):
jit_o = t_jit(x)
jit_o.backward(grad)
# TODO: remove this run?
o = ref_module(ref_x)
o.backward(grad)
has_affine = ref_module.bn.weight is not None
has_running_stats = ref_module.bn.running_mean is not None
if has_running_stats:
my_module.bn.running_mean.zero_()
my_module.bn.running_var.fill_(1.0)
ref_module.bn.running_mean.zero_()
ref_module.bn.running_var.fill_(1.0)
# when train is False there is no grad for weight/bias, so only zero
# them when training with affine params
if has_affine and train:
my_module.bn.weight.grad.zero_()
my_module.bn.bias.grad.zero_()
ref_module.bn.weight.grad.zero_()
ref_module.bn.bias.grad.zero_()
x.grad.zero_()
ref_x.grad.zero_()
# real runs
jit_o = t_jit(x)
jit_o.backward(grad)
o = ref_module(ref_x)
o.backward(grad)
# assert forward graph fusion
self.assertGraphContainsExactly(t_jit.graph_for(x), FUSION_GUARD, 1, consider_subgraphs=True)
# assert backward graph fusion
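# dig the backward graph out of the profiling executor state: take the
# forward's first execution plan, then the first execution plan of its
# grad executor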
bwd_graph = list(
list(t_jit.get_debug_state().execution_plans.values())[0].code.grad_executor_states()[0]
.execution_plans.values())[0].graph
self.assertGraphContainsExactly(bwd_graph, FUSION_GUARD, 1, consider_subgraphs=True)
e0 = 1e-5 if dtype is not torch.half else 1e-3
e1 = 1e-4 if dtype is not torch.half else 1e-3
e2 = 1e-3 if dtype is not torch.half else 1e-2
self.assertTrue(self._compare("comparing output failed", jit_o, o, e0))
self.assertTrue(self._compare("comparing input grad failed", x.grad, ref_x.grad, e1))
# TODO: switch to welford and reduce this to 1e-5
# The 1e-3 looks bad, but we don't have welford in codegen, so the
# numerics differ considerably between the reference and codegen.
if has_affine and train:
self.assertTrue(self._compare("comparing weight grad failed",
my_module.bn.weight.grad,
ref_module.bn.weight.grad,
e2))
self.assertTrue(self._compare("comparing bias grad failed",
my_module.bn.bias.grad,
ref_module.bn.bias.grad,
e1))
if has_running_stats:
self.assertTrue(self._compare("comparing running_mean failed",
my_module.bn.running_mean,
ref_module.bn.running_mean,
e0))
self.assertTrue(self._compare("comparing running_var failed",
my_module.bn.running_var,
ref_module.bn.running_var,
e0))
@unittest.skipIf(is_pre_volta(), "reduction not supported in pre volta device")
@unittest.skipIf(not RUN_NVFUSER, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
def test_batch_norm_half(self):
with torch.backends.cudnn.flags(enabled=True):
setups = [
[True, True],
[False, False],
[True, False],
[False, True]]
for training_and_track, affine in itertools.product(setups, [True, False]):
training, track_running_stats = training_and_track
self._test_batch_norm_impl_index_helper(4, 8, 5, affine, track_running_stats, training, torch.half)
@unittest.skipIf(is_pre_volta(), "reduction not supported in pre volta device")
@unittest.skipIf(not RUN_NVFUSER, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
def test_batch_norm_impl_index_inner_bcast(self):
# the repro
self._test_batch_norm_impl_index_helper(2, 1, 1, False, True, True)
# running the full set
setups = [
[True, True],
[False, False],
[True, False],
[False, True]]
for training_and_track, affine in itertools.product(setups, [True, False]):
training, track_running_stats = training_and_track
self._test_batch_norm_impl_index_helper(2, 1, 1, affine, track_running_stats, training)
@unittest.skipIf(is_pre_volta(), "reduction not supported in pre volta device")
@unittest.skipIf(not RUN_NVFUSER, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
def test_batch_norm_impl_index_correctness(self):
with torch.backends.cudnn.flags(enabled=True):
batch = [2, 7, 16]
channels = [4, 89, 19, 32]
hw = [1, 8, 17, 32]
# avoid tolerance failure in CI
torch.cuda.manual_seed_all(211)
# failing sizes (2, 1, 1, 1)
# failing sizes (2, 89, 8, 8) training False, track True, affine: False
for b, c, hw in itertools.product(batch, channels, hw):
setups = [
[True, True],
[False, False],
[True, False],
[False, True]]
for training_and_track, affine in itertools.product(setups, [True, False]):
training, track_running_stats = training_and_track
self._test_batch_norm_impl_index_helper(b, c, hw, affine, track_running_stats, training)
@unittest.skipIf(not RUN_NVFUSER, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
def test_softplus_fuser(self):
def shifted_softplus(x: torch.Tensor, shift: float):
return functional.softplus(x) - shift
jitted = torch.jit.script(shifted_softplus)
inp = torch.randn(4, 2, dtype=torch.float32, device="cuda").requires_grad_()
inp_ref = inp.detach().clone().requires_grad_()
grad = torch.randn(4, 2, dtype=torch.float32, device="cuda")
aten_o = shifted_softplus(inp_ref, 0.693147)
aten_o.backward(grad)
aten_grad = inp_ref.grad
for i in range(3):
jit_o = jitted(inp, 0.693147)
inp.grad = None  # clear grad to avoid accumulation across iterations
jit_o.backward(grad)
jit_grad = inp.grad
assert torch.allclose(jit_o, aten_o)
assert torch.allclose(jit_grad, aten_grad)
self.assertGraphContains(jitted.graph_for(inp, 0.693147), FUSION_GROUP, True)
@unittest.skipIf(not RUN_NVFUSER, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
def test_inplace_removal(self):
def t(x: torch.Tensor):
o = torch.nn.functional.softmax(x, dim=0)
o += x
return o.relu_()
jitted = torch.jit.script(t)
inp = torch.randn(4, 2, dtype=torch.float32, device="cuda")
for i in range(3):
jit_o = jitted(inp)
graph = jitted.graph_for(inp)
self.assertGraphContains(graph, FUSION_GROUP, True)
self.assertGraphContains(graph, 'aten::add', True)
self.assertGraphContains(graph, 'aten::relu', True)
@unittest.skipIf(not RUN_NVFUSER, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
def test_conv2d_bias(self):
def t(x: torch.Tensor, w: torch.Tensor, bias: torch.Tensor):
o = torch.nn.functional.conv2d(x, w, bias)
return o.relu()
jitted = torch.jit.script(t)
inp = torch.randn(4, 5, 3, 3, dtype=torch.float32, device="cuda")
weight = torch.randn(2, 5, 2, 2, dtype=torch.float32, device="cuda")
bias = torch.randn(2, dtype=torch.float32, device="cuda")
for i in range(3):
jit_o = jitted(inp, weight, bias)
graph = jitted.graph_for(inp)
self.assertGraphContains(graph, FUSION_GROUP, True)
def t_not_fused(x: torch.Tensor, w: torch.Tensor):
o = torch.nn.functional.conv2d(x, w)
return o.relu()
jitted_not_fused = torch.jit.script(t_not_fused)
for i in range(3):
jit_o = jitted_not_fused(inp, weight)
graph = jitted_not_fused.graph_for(inp)
self.assertGraphContainsExactly(graph, FUSION_GROUP, 0)
self.assertGraphContains(graph, 'aten::relu', True)
def t_bias(x: torch.Tensor, w: torch.Tensor, bias: torch.Tensor):
o = torch.nn.functional.conv2d(x, w, bias)
return o.relu()
jitted_bias = torch.jit.script(t_bias)
for i in range(3):
jit_o = jitted_bias(inp, weight, bias)
graph = jitted_bias.graph_for(inp)
self.assertGraphContains(graph, FUSION_GROUP, True)
self.assertGraphContains(graph, 'prim::add_optional', True)
@unittest.skipIf(is_pre_volta(), "reduction not supported in pre volta device")
@unittest.skipIf(not RUN_NVFUSER, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
def test_remove_output_used_only_in_dtype(self):
class MyModule(torch.nn.Module):
def __init__(self, num_features=4):
super(MyModule, self).__init__()
self.bn0 = torch.nn.BatchNorm2d(num_features)
self.bn1 = torch.nn.BatchNorm2d(num_features)
def forward(self, x, y):
o1 = self.bn0(x)
o2 = self.bn1(y)
return torch.relu(o1 + o2)
t = MyModule(4).float().cuda()
jitted = torch.jit.script(t)
x = torch.randn(3, 4, 2, 5, dtype=torch.float32, device="cuda")
y = torch.randn(3, 4, 2, 5, dtype=torch.float32, device="cuda")
with torch.cuda.amp.autocast(True):
for i in range(5):
jit_o = jitted(x, y)
jit_o = jitted(x, y)
o = t(x, y)
self.assertTrue(torch.allclose(jit_o, o))
graph = jitted.graph_for(x, y)
self.assertGraphContains(graph, FUSION_GROUP, True)
@unittest.skipIf(is_pre_volta(), "reduction not supported in pre volta device")
@unittest.skipIf(not RUN_NVFUSER, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
def test_fix_shape_expression_bn(self):
class MyModule(torch.nn.Module):
def __init__(self, num_features=4):
super(MyModule, self).__init__()
self.bn = torch.nn.BatchNorm2d(num_features)
def forward(self, x, y):
out1 = self.bn(x)
out2 = out1 + y
out3 = torch.relu(out2)
return out3
t = MyModule(4).float().cuda()
jitted = torch.jit.script(t)
x = torch.randn(3, 4, 2, 5, dtype=torch.float32, device="cuda")
y = torch.randn(3, 4, 2, 5, dtype=torch.float32, device="cuda")
with torch.cuda.amp.autocast(True):
for i in range(5):
jit_o = jitted(x, y)
jit_o = jitted(x, y)
o = t(x, y)
self.assertTrue(torch.allclose(jit_o, o))
graph = jitted.graph_for(x, y)
self.assertGraphContains(graph, FUSION_GROUP, True)
def _run_fwd_helper(self, func, ops, *args):
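# run func jitted and eager, compare dtype/value of each output, then
# check that every op in `ops` was absorbed into the fusion group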
jitted = torch.jit.script(func)
for i in range(3):
jit_o = jitted(*args)
jit_o = jitted(*args)
o = func(*args)
for oo, jit_oo in zip(o, jit_o):
self.assertEqual(oo.dtype, jit_oo.dtype)
self.assertEqual(oo, jit_oo)
graph = jitted.graph_for(*args)
self.assertGraphContains(graph, FUSION_GROUP, True)
for op in ops:
self.assertGraphContainsExactly(graph, op, 0)
@unittest.skipIf(is_pre_volta(), "reduction not supported in pre volta device")
@unittest.skipIf(not RUN_NVFUSER, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
def test_sibling_fusion(self):
device = "cuda"
dtype = torch.float
x = torch.randn(2, 5, dtype=dtype, device=device)
y = torch.randn(2, 5, dtype=dtype, device=device)
def t(x: torch.Tensor):
o1 = x + 1.0
o2 = x * 0.5
return o1, o2
self._run_fwd_helper(t, ['aten::add', 'aten::mul'], x)
def t2(x: torch.Tensor, y: torch.Tensor):
o1 = x.sum(0)
o2 = (x * y).sum(0)
return o1, o2
self._run_fwd_helper(t2, ['aten::sum', 'aten::mul'], x, y)
@unittest.skipIf(not RUN_NVFUSER, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
def test_clean_profile_ivalue(self):
device = "cuda"
dtype = torch.float
x = torch.randn(2, 5, dtype=dtype, device=device, requires_grad=True)
# turn on autodiff subgraph inlining
# this is to verify that we clean up profile_ivalue nodes outside of
# the fusion code path.
torch._C._debug_set_autodiff_subgraph_inlining(True)
def t(x: torch.Tensor, flag: bool):
return torch.dropout(x, 0.5, flag)
jit_t = torch.jit.script(t)
for idx in range(5):
out = jit_t(x, True)
graph = jit_t.graph_for(x, True)
out = jit_t(x, False)
@unittest.skipIf(not RUN_NVFUSER, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
def test_sibling_fusion_no_scalar_inputs(self):
device = "cuda"
dtype = torch.float
x = torch.randn(2, 5, dtype=dtype, device=device)
y = torch.randn(3, dtype=dtype, device=device)
# no tensor dependency between o1/o2, so we shouldn't fuse them
def t(x: torch.Tensor, y: torch.Tensor):
o1 = x + 1
o2 = y - 1
return o1, o2
jitted = torch.jit.script(t)
for i in range(3):
jit_o = jitted(x, y)
graph = jitted.graph_for(x, y)
self.assertGraphContainsExactly(graph, FUSION_GROUP, 0)
def _bias_view_relu_helper(self, shape, output_shape, dtype, device, error):
class BiasViewRelu(torch.nn.Module):
def __init__(self):
super(BiasViewRelu, self).__init__()
self.bias = torch.nn.Parameter(torch.randn(shape, dtype=dtype, device=device), requires_grad=False)
with torch.no_grad():
self.bias.fill_(10)
def forward(self, inputs: torch.Tensor, view_shape: List[int]):
o = inputs + self.bias
o = o.view(view_shape)
return torch.relu(o)
t = BiasViewRelu()
x = torch.randn(shape, dtype=dtype, device=device, requires_grad=False)
t_jit = torch.jit.script(t)
# profiling
jit_o = t_jit(x, output_shape)
# optimization
jit_o = t_jit(x, output_shape)
# final
jit_o = t_jit(x, output_shape)
# eager - baseline
o = t(x, output_shape)
self.assertEqual(o.dtype, jit_o.dtype)
self.assertTrue(self._compare("comparing output failed", o, jit_o, error))
graph = t_jit.graph_for(x, output_shape)
has_inferred_dimension = any([dim == -1 for dim in output_shape])
if has_inferred_dimension:
# prohibit fusing when view_shape contains an inferred dimension
self.assertGraphContainsExactly(graph, FUSION_GROUP, 0)
self.assertGraphContainsExactly(graph, 'prim::view_copy', 0)
else:
self.assertGraphContains(graph, FUSION_GUARD)
self.assertGraphContains(graph, 'prim::view_copy', True)
def _alias_bias_view_relu_helper(self, shape, output_shape, dtype, device, error):
class BiasViewRelu(torch.nn.Module):
def __init__(self):
super(BiasViewRelu, self).__init__()
self.bias = torch.nn.Parameter(torch.randn(shape, dtype=dtype, device=device), requires_grad=False)
with torch.no_grad():
self.bias.fill_(10)
def forward(self, inputs : torch.Tensor, bias : torch.Tensor, view_shape : List[int]):
o = inputs.view(view_shape)
inputs.add_(bias)
return torch.relu(o)
t = BiasViewRelu()
x = torch.randn(shape, dtype=dtype, device=device, requires_grad=False)
bias = torch.randn(shape, dtype=dtype, device=device, requires_grad=False)
t_jit = torch.jit.script(t)
# profiling
jit_o = t_jit(x.clone(), bias, output_shape)
# optimization
jit_o = t_jit(x.clone(), bias, output_shape)
# final
jit_o = t_jit(x.clone(), bias, output_shape)
# eager - baseline
o = t(x.clone(), bias, output_shape)
self.assertEqual(o.dtype, jit_o.dtype)
self.assertTrue(self._compare("comparing output failed", o, jit_o, error))
graph = t_jit.graph_for(x, bias, output_shape)
self.assertGraphContainsExactly(graph, FUSION_GUARD, 0)
self.assertGraphContainsExactly(graph, 'prim::view_copy', 0)
# generate random view given original view
def _random_view(self, original_view, max_len=8, max_views=10000):
class Moves(enum.Enum):
Merge = 0
Split = 1
Broadcast = 2
ImplicitBroadcast = 3
Keep = 4
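# each move consumes dims of the old view and emits dims of the new view:
# Keep copies a dim, Merge combines two adjacent dims, Split factors a
# dim, Broadcast appends a size-1 dim to the new view, and
# ImplicitBroadcast inserts a size-1 dim into the old view that the new
# view leaves implicit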
def valid(old_view, new_view):
old_view_size = reduce(operator.mul, old_view)
new_view_size = reduce(operator.mul, new_view)
return old_view_size == new_view_size
# pick a random starting point and return the next divisor of N at or above it
def find_nearest_divisor(N):
if 2 >= (N - 1):
return -1
result = random.randint(2, N - 1)
while (N % result) != 0:
result += 1
return result
complete_views = set([tuple(original_view)])
to_visit = []
# state: empty new view, current original view, start pos = 0, empty move list, last move
to_visit.append(([], original_view, 0, [], Moves.Keep))
# depth-first search of view shapes, starting from the original view
while len(to_visit) > 0 and len(complete_views) < max_views:
new_view, old_view, odx, move_list, last_move = to_visit[-1]
to_visit.pop()
# iterate over each move type
for idx in range(len(Moves)):
state = Moves(idx)
new_view_clone = copy.deepcopy(new_view)
old_view_clone = copy.deepcopy(old_view)
new_move_list = move_list + [state]
new_odx = odx
# update the search state according to the chosen move
if state == Moves.Keep:
new_size = old_view_clone[odx]
new_view_clone.append(new_size)
new_odx += 1
elif state == Moves.Merge:
if odx + 1 < len(old_view_clone):
new_size = old_view_clone[odx] * old_view_clone[odx + 1]
new_view_clone.append(new_size)
new_odx += 2
else:
continue
elif state == Moves.Broadcast and last_move != Moves.Broadcast:
new_view_clone.append(1)
elif state == Moves.Split:
new_size = find_nearest_divisor(old_view_clone[odx])
if new_size == -1:
continue
new_view_clone.append(new_size)
old_view_clone[odx] = old_view[odx] // new_size
if old_view_clone[odx] == 1:
new_odx += 1
elif state == Moves.ImplicitBroadcast:
old_view_clone.insert(odx + 1, 1)
new_size = old_view[odx] * 1
new_view_clone.append(new_size)
new_odx += 2
if new_odx < len(old_view_clone) and len(new_move_list) < max_len:
to_visit.append((new_view_clone, old_view_clone, new_odx, new_move_list, state))
elif (valid(original_view, new_view_clone)):
final_new_view = tuple(new_view_clone)
complete_views.add(final_new_view)
return list(complete_views)
# ndims - number of dimensions
# test_fn - view test function
def _view_test_generator(self, ndims, test_fn):
# create a random tensor, capping each dimension so the total element
# count stays below max_size
max_size = 10e7
max_value = max(int(pow(max_size, 1. / ndims)), 1)
sizes = [random.randint(1, max_value) for idx in range(ndims)]
x = torch.randn(sizes)
original_sizes = list(x.size())
all_views = self._random_view(original_sizes)
random.shuffle(all_views)
max_samples = 20
max_views = min(len(all_views), max_samples)
total = 0
correct = 0
# test random combinations of compatible views
for idx in range(max_views):
for jdx in range(idx + 1, max_views):
total += 1
test_fn(all_views[idx], all_views[jdx], torch.float, 'cuda', 1e-6)
@unittest.skipIf(ALIAS_TEST_DISABLED, "skipping this test since view is disabled now")
@unittest.skipIf(not RUN_NVFUSER, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
def test_view(self):
torch._C._jit_set_nvfuser_guard_mode(True)
self._bias_view_relu_helper([2, 3, 4, 5], [-1, 4, 5], torch.float, 'cuda', 1e-6)
for ndims in range(1, 5):
self._view_test_generator(ndims, self._bias_view_relu_helper)
self._alias_bias_view_relu_helper([2, 3, 4, 5], [1, 6, 1, 2, 2, 5, 1], torch.float, 'cuda', 1e-6)
def _bias_flatten_relu_helper(self, shape, start_dim, end_dim, dtype, device, error):
class BiasFlattenRelu(torch.nn.Module):
def __init__(self):
super(BiasFlattenRelu, self).__init__()
self.bias = torch.nn.Parameter(torch.randn(shape, dtype=dtype, device=device), requires_grad=False)
with torch.no_grad():
self.bias.fill_(10)
def forward(self, inputs : torch.Tensor, start_dim : int, end_dim : int):
o = inputs + self.bias
o = o.flatten(start_dim, end_dim)
return torch.relu(o)
t = BiasFlattenRelu()
x = torch.randn(shape, dtype=dtype, device=device, requires_grad=False)
t_jit = torch.jit.script(t)
self._run_helper(t_jit, t, x, start_dim, end_dim)
self.assertGraphContains(t_jit.graph_for(x, start_dim, end_dim), 'prim::flatten_copy', True)
def _alias_bias_flatten_relu_helper(self, shape, start_dim, end_dim, dtype, device, error):
class BiasFlattenRelu(torch.nn.Module):
def __init__(self):
super(BiasFlattenRelu, self).__init__()
self.bias = torch.nn.Parameter(torch.randn(shape, dtype=dtype, device=device), requires_grad=False)
with torch.no_grad():
self.bias.fill_(10)
def forward(self, inputs : torch.Tensor, bias : torch.Tensor, start_dim : int, end_dim : int):
o = inputs.flatten(start_dim, end_dim)
inputs.add_(bias)
return torch.relu(o)
t = BiasFlattenRelu()
x = torch.randn(shape, dtype=dtype, device=device, requires_grad=False)
bias = torch.randn(shape, dtype=dtype, device=device, requires_grad=False)
t_jit = torch.jit.script(t)
# profiling
jit_o = t_jit(x.clone(), bias, start_dim, end_dim)
# optimization
jit_o = t_jit(x.clone(), bias, start_dim, end_dim)
# final
jit_o = t_jit(x.clone(), bias, start_dim, end_dim)
# eager - baseline
o = t(x.clone(), bias, start_dim, end_dim)
self.assertEqual(o.dtype, jit_o.dtype)
self.assertTrue(self._compare("comparing output failed", o, jit_o, error))
graph = t_jit.graph_for(x, bias, start_dim, end_dim)
self.assertGraphContainsExactly(graph, FUSION_GUARD, 0)
self.assertGraphContainsExactly(graph, 'prim::flatten_copy', 0)
@unittest.skipIf(ALIAS_TEST_DISABLED, "skipping this test since flatten is disabled now")
@unittest.skipIf(not RUN_NVFUSER, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
def test_flatten(self):
torch._C._jit_set_nvfuser_guard_mode(True)
self._bias_flatten_relu_helper([2, 3, 4, 5], 0, -1, torch.float, 'cuda', 1e-6)
self._bias_flatten_relu_helper([2, 3, 4, 5], 1, -1, torch.float, 'cuda', 1e-6)
self._bias_flatten_relu_helper([2, 3, 4, 5], 2, -1, torch.float, 'cuda', 1e-6)
self._bias_flatten_relu_helper([2, 3, 4, 5], 0, 3, torch.float, 'cuda', 1e-6)
self._bias_flatten_relu_helper([2, 3, 4, 5], 1, 2, torch.float, 'cuda', 1e-6)
self._bias_flatten_relu_helper([2, 3, 4, 5], 2, 2, torch.float, 'cuda', 1e-6)
self._alias_bias_flatten_relu_helper([2, 3, 4, 5], 0, -1, torch.float, 'cuda', 1e-6)
self._alias_bias_flatten_relu_helper([2, 3, 4, 5], 1, -1, torch.float, 'cuda', 1e-6)
self._alias_bias_flatten_relu_helper([2, 3, 4, 5], 2, -1, torch.float, 'cuda', 1e-6)
self._alias_bias_flatten_relu_helper([2, 3, 4, 5], 0, 3, torch.float, 'cuda', 1e-6)
self._alias_bias_flatten_relu_helper([2, 3, 4, 5], 1, 2, torch.float, 'cuda', 1e-6)
self._alias_bias_flatten_relu_helper([2, 3, 4, 5], 2, 2, torch.float, 'cuda', 1e-6)
@unittest.skipIf(not RUN_NVFUSER, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
def test_strict_fusion(self):
def success(x):
with torch.jit.strict_fusion():
return x + x + x
scripted = self.checkScript(success, (torch.rand([4], device='cuda'),))
g = torch.jit.last_executed_optimized_graph()
FileCheck().check_not("aten::add").check("prim::CudaFusionGroup").run(g)
def failure(x):
with torch.jit.strict_fusion():
return x + torch.mm(x, x) + x
with self.assertRaises(Exception) as error_out:
foo_s = torch.jit.script(failure)
foo_s(torch.rand([4, 4]))
foo_s(torch.rand([4, 4]))
fc = FileCheck().check("Found unfused operators")
fc.check("aten::mm").run(str(error_out.exception))
def _ltc_helper(self, shape, dtype, device, error, approximate=True):
# modeled after LTC linear layer
class LTC(torch.nn.Module):
def __init__(self):
super(LTC, self).__init__()
self.weight = torch.nn.Parameter(torch.randn([1024, 1024], dtype=dtype, device=device), requires_grad=False)
self.bias = torch.nn.Parameter(torch.randn([1, 1024], dtype=dtype, device=device), requires_grad=False)
def forward(self, inputs : torch.Tensor):
o = inputs.view([32768, 1024])
o = torch.mm(o, self.weight)
o = o.view([256, 128, 1024])
o = o + self.bias
o = o.view([32768, 1024])
o = o.view([256, 128, 1024])
return torch.nn.functional.gelu(o)
t = LTC()
x = torch.randn(shape, dtype=dtype, device=device, requires_grad=False)
t_jit = torch.jit.script(t)
# profile/optimization runs
for i in range(3):
jit_o = t_jit(x)
o = t(x)
self.assertEqual(o.dtype, jit_o.dtype)
self.assertTrue(self._compare("comparing output failed", o, jit_o, error))
graph = t_jit.graph_for(x)
self.assertGraphContains(graph, FUSION_GUARD)
self.assertGraphContains(graph, 'prim::view_copy', True)
@unittest.skipIf(ALIAS_TEST_DISABLED, "skipping this test since view is disabled now")
@unittest.skipIf(not RUN_NVFUSER, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
def test_nested_view(self):
self._ltc_helper([256, 128, 1024], torch.float, 'cuda', 1e-6)
def _bias_squeeze_relu_helper(self, shape, dtype, device, error):
class BiasSqueezeRelu(torch.nn.Module):
def __init__(self):
super(BiasSqueezeRelu, self).__init__()
def forward(self, inputs: torch.Tensor, bias: torch.Tensor):
o = inputs + bias
o = torch.squeeze(o)
return torch.relu(o)
t = BiasSqueezeRelu()
x = torch.randn(shape, dtype=dtype, device=device, requires_grad=False)
bias = torch.randn(shape, dtype=dtype, device=device, requires_grad=False)
t_jit = torch.jit.script(t)
jit_o = t_jit(x, bias)
jit_o = t_jit(x, bias)
jit_o = t_jit(x, bias)
o = t(x, bias)
self.assertEqual(o.dtype, jit_o.dtype)
self.assertTrue(self._compare("comparing output failed", o, jit_o, error))
graph = t_jit.graph_for(x, bias)
self.assertGraphContains(graph, FUSION_GUARD)
self.assertGraphContains(graph, 'prim::squeeze_copy', True)
def _alias_bias_squeeze_relu_helper(self, shape, dtype, device, error):
class BiasSqueezeRelu(torch.nn.Module):
def __init__(self):
super(BiasSqueezeRelu, self).__init__()
def forward(self, inputs: torch.Tensor, bias: torch.Tensor):
o = torch.squeeze(inputs)
inputs.add_(bias)
return torch.relu(o)
t = BiasSqueezeRelu()
x = torch.randn(shape, dtype=dtype, device=device, requires_grad=False)
bias = torch.randn(shape, dtype=dtype, device=device, requires_grad=False)
t_jit = torch.jit.script(t)
jit_o = t_jit(x.clone(), bias)
jit_o = t_jit(x.clone(), bias)
jit_o = t_jit(x.clone(), bias)
o = t(x.clone(), bias)
self.assertEqual(o.dtype, jit_o.dtype)
self.assertTrue(self._compare("comparing output failed", o, jit_o, error))
graph = t_jit.graph_for(x, bias)
self.assertGraphContainsExactly(graph, FUSION_GUARD, 0)
self.assertGraphContainsExactly(graph, 'prim::squeeze_copy', 0)
@unittest.skipIf(ALIAS_TEST_DISABLED, "skipping this test since squeeze/unsqueeze is disabled now")
@unittest.skipIf(not RUN_NVFUSER, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
def test_squeeze(self):
self._bias_squeeze_relu_helper([1, 6, 1, 2, 2, 5, 1], torch.float, 'cuda', 1e-6)
self._alias_bias_squeeze_relu_helper([1, 6, 1, 2, 2, 5, 1], torch.float, 'cuda', 1e-6)
@unittest.skipIf(ALIAS_TEST_DISABLED, "skipping this test since squeeze/unsqueeze is disabled now")
# remove this after opinfo tests are enabled
@unittest.skipIf(not RUN_NVFUSER, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
def test_squeeze_zero(self):
x = torch.tensor(1.0, dtype=torch.float, device="cuda")
def squeeze_0(x: torch.Tensor):
o = x + 1.
o = torch.squeeze(o, 0)
o = o * 2.
return o
def squeeze_1(x: torch.Tensor):
o = x + 1.
o = torch.squeeze(o, -1)
o = o + .5
return o
squeeze_0_jit = torch.jit.script(squeeze_0)
self._run_helper(squeeze_0_jit, squeeze_0, x)
squeeze_1_jit = torch.jit.script(squeeze_1)
self._run_helper(squeeze_1_jit, squeeze_1, x)
def _bias_unsqueeze_relu_helper(self, shape, dtype, device, error):
class BiasUnsqueezeRelu(torch.nn.Module):
def __init__(self):
super(BiasUnsqueezeRelu, self).__init__()
def forward(self, inputs: torch.Tensor, bias: torch.Tensor):
o = inputs + bias
o = torch.unsqueeze(o, 0)
return torch.relu(o)
t = BiasUnsqueezeRelu()
x = torch.randn(shape, dtype=dtype, device=device, requires_grad=False)
bias = torch.randn(shape, dtype=dtype, device=device, requires_grad=False)
t_jit = torch.jit.script(t)
jit_o = t_jit(x, bias)
jit_o = t_jit(x, bias)
jit_o = t_jit(x, bias)
o = t(x, bias)
self.assertEqual(o.dtype, jit_o.dtype)
self.assertTrue(self._compare("comparing output failed", o, jit_o, error))
graph = t_jit.graph_for(x, bias)
self.assertGraphContains(graph, FUSION_GUARD)
self.assertGraphContains(graph, 'prim::unsqueeze_copy', True)
def _alias_bias_unsqueeze_relu_helper(self, shape, dtype, device, error):
class BiasUnsqueezeRelu(torch.nn.Module):
def __init__(self):
super(BiasUnsqueezeRelu, self).__init__()
def forward(self, inputs : torch.Tensor, bias : torch.Tensor):
o = torch.unsqueeze(inputs, 0)
inputs.add_(bias)
return torch.relu(o)
t = BiasUnsqueezeRelu()
x = torch.randn(shape, dtype=dtype, device=device, requires_grad=False)
bias = torch.randn(shape, dtype=dtype, device=device, requires_grad=False)
t_jit = torch.jit.script(t)
jit_o = t_jit(x.clone(), bias)
jit_o = t_jit(x.clone(), bias)
jit_o = t_jit(x.clone(), bias)
o = t(x.clone(), bias)
self.assertEqual(o.dtype, jit_o.dtype)
self.assertTrue(self._compare("comparing output failed", o, jit_o, error))
graph = t_jit.graph_for(x, bias)
self.assertGraphContainsExactly(graph, FUSION_GUARD, 0)
self.assertGraphContainsExactly(graph, 'prim::unsqueeze_copy', 0)
@unittest.skipIf(ALIAS_TEST_DISABLED, "skipping this test since squeeze/unsqueeze is disabled now")
@unittest.skipIf(not RUN_NVFUSER, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
def test_unsqueeze(self):
self._bias_unsqueeze_relu_helper([2, 3, 4, 5], torch.float, 'cuda', 1e-6)
self._alias_bias_unsqueeze_relu_helper([2, 3, 4, 5], torch.float, 'cuda', 1e-6)
@unittest.skipIf(ALIAS_TEST_DISABLED, "skipping this test since unsqueeze is disabled now")
@unittest.skipIf(not RUN_NVFUSER, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
def test_alias_pass_fix(self):
x = torch.randn(4, 24, 2, 2, dtype=torch.float, device="cuda")
w = torch.randn(24, 24, 1, 1, dtype=torch.float, device="cuda")
b = torch.randn(24, dtype=torch.float, device="cuda")
def t(x, w, b):
b2 = b + 1.0
o = torch.conv2d(x, w, b2)
return o
t_jit = torch.jit.script(t)
self._run_helper(t_jit, t, x, w, b)
@unittest.skipIf(ALIAS_TEST_DISABLED, "skipping this test since squeeze/unsqueeze is disabled now")
@unittest.skipIf(not RUN_NVFUSER, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
def test_squeeze_negative_dim(self):
x = torch.randn(4, 24, 1, 2, dtype=torch.float, device="cuda")
def t(x):
o = x + 1.0
o = o.squeeze(-2)
o = o * 2.0
return o
t_jit = torch.jit.script(t)
self._run_helper(t_jit, t, x)
@unittest.skipIf(not RUN_NVFUSER, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
def test_singleton_fusion(self):
x = torch.randn(4, 2, device="cuda")
with nvfuser_singleton_fusion(True):
def t(x):
return x.relu()
t_jit = torch.jit.script(t)
self._run_helper(t_jit, t, x)
@unittest.skipIf(not RUN_NVFUSER, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
def test_issue1445_fusion(self):
def f(t0, t1, t2, t3):
masked_input = torch.where(t1, t2, t3)
total = masked_input.sum([0, 1, 2, 3])
sizes : List[int] = []
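# reshaping with an empty shape list produces a 0-dim (scalar) tensor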
t10 = torch.reshape(t0, sizes)
t7 = total / t10
t4 = t7.to(dtype=torch.float)
return t4
x = torch.randn(1, 1, 1, 1, device='cuda').to(dtype=torch.long)
y = torch.randn(3, 2, 1, 1, device='cuda').to(dtype=torch.bool).expand([3, 2, 1, 2])
z = torch.randn(3, 2, 1, 2, device='cuda')
w = torch.tensor(1.5, device='cuda')
f_jit = torch.jit.script(f)
for i in range(5):
out_jit = f_jit(x, y, z, w)
out = f(x, y, z, w)
self.assertEqual(out, out_jit)
self.assertGraphContainsExactly(f_jit.graph_for(x, y, z, w), FUSION_GROUP, 1)
@unittest.skipIf(not RUN_NVFUSER, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
def test_disable_sibling_fuse(self):
x = torch.randn(4, 2, device="cuda")
y = torch.randn(8, device="cuda")
s = torch.tensor(1.5, device="cuda")
with nvfuser_horizontal_fusion(False):
def t(x, y, s):
o1 = x + s
o2 = y + s
return o1, o2
t_jit = torch.jit.script(t)
for i in range(5):
t_jit(x, y, s)
# sibling fusion should be disabled with the flag
self.assertGraphContainsExactly(t_jit.graph_for(x, y, s), FUSION_GUARD, 0)
@unittest.skipIf(not RUN_NVFUSER, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
def test_build_shape_expression_native_dropout(self):
x = torch.randn(4, 2, device="cuda")
def t(x):
o, mask = torch.native_dropout(x, 0.0, True)
o1 = o.sigmoid()
o2 = mask.float().sigmoid()
return (o1, o2)
t_jit = torch.jit.script(t)
jit_o = t_jit(x)
jit_o = t_jit(x)
o = t(x)
for oo, jit_oo in zip(o, jit_o):
self.assertEqual(oo.dtype, jit_oo.dtype)
self.assertEqual(oo, jit_oo)
self.assertGraphContains(t_jit.graph_for(x), FUSION_GUARD)
@unittest.skipIf(not RUN_NVFUSER, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
def test_scalar_tensor_permuted(self):
x = torch.randn(4, 2, 3, device="cuda").permute([1, 2, 0])
y = torch.tensor(1.0, device="cuda")
with nvfuser_singleton_fusion(True):
def t(x, y):
return x + y
t_jit = torch.jit.script(t)
self._run_helper(t_jit, t, x, y)
@unittest.skipIf(not RUN_NVFUSER, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
def test_cpu_scalar(self):
x = torch.randn(4, 2, 3, device="cuda")
y = torch.tensor(1.0, device="cpu")
z = torch.tensor(2.0, device="cpu")
with nvfuser_singleton_fusion(True):
# testing cpu scalar tensor promotion
def t(x, y):
return x + y
t_jit = torch.jit.script(t)
self._run_helper(t_jit, t, x, y)
# a scalar cpu tensor op on its own should NOT be fused
@torch.jit.script
def t1(y, z):
return y * z
for _ in range(5):
t1(y, z)
self.assertGraphContainsExactly(t1.graph_for(y, z), FUSION_GUARD, 0)
# everything, including scalar cpu tensor add should be fused
@torch.jit.script
def t2(x, y, z):
tmp = y + z
return tmp + x
for _ in range(5):
t2(x, y, z)
self.assertGraphContainsExactly(t2.graph_for(x, y, z), 'aten::add', 0)
self.assertGraphContainsExactly(t2.graph_for(x, y, z), FUSION_GUARD, 1)
# 'cpu_tmp = y + z' shouldn't be fused.
@torch.jit.script
def t3(x, y, z):
cpu_tmp = y + z
out = x + y
return cpu_tmp, out
for _ in range(5):
t3(x, y, z)
self.assertGraphContainsExactly(t3.graph_for(x, y, z), FUSION_GUARD, 1)
self.assertGraphContainsExactly(t3.graph_for(x, y, z), 'aten::add', 1)
@unittest.skipIf(ALIAS_TEST_DISABLED, "skipping this test since squeeze/unsqueeze is disabled now")
@unittest.skipIf(not RUN_NVFUSER, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
def test_shape_expression(self):
x = torch.randn(4, 2, 1, 3, device="cuda")
def t_unsqueeze(x):
t0 = x.relu()
t1 = t0.unsqueeze(1)
t2 = t1 + 1.0
t3 = t1.size()
return t2, t3
def t_squeeze(x):
t0 = x.relu()
t1 = t0.squeeze()
t2 = t1 + 1.0
t3 = t1.size()
return t2, t3
def t_squeeze_dim(x):
t0 = x.relu()
t1 = t0.squeeze(-2)
t2 = t1 + 1.0
t3 = t1.size()
return t2, t3
# squeezing a dimension whose size is not 1 should be a no-op
def t_squeeze_dim_no_op(x):
t0 = x.relu()
t1 = t0.squeeze(1)
t2 = t1 + 1.0
t3 = t1.size()
return t2, t3
def run(fn):
jit_fn = torch.jit.script(fn)
jit_o = jit_fn(x)
jit_o = jit_fn(x)
jit_o = jit_fn(x)
o = fn(x)
# output 0 is a tensor, so we check dtype and value
self.assertEqual(o[0].dtype, jit_o[0].dtype)
self.assertEqual(o[0], jit_o[0])
# output 1 is shape
self.assertEqual(o[1], jit_o[1])
self.assertGraphContainsExactly(jit_fn.graph_for(x), FUSION_GUARD, 1)
for t in [t_unsqueeze, t_squeeze, t_squeeze_dim, t_squeeze_dim_no_op]:
run(t)
@unittest.skipIf(not RUN_NVFUSER, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
def test_scalar_cuda_tensor(self):
x = torch.tensor(2.0, device="cuda")
with nvfuser_singleton_fusion(True):
def t(x):
return x + 1.0
t_jit = torch.jit.script(t)
self._run_helper(t_jit, t, x)
@torch.jit.script
def t_jitted(x):
return x.sum(0)
for i in range(5):
t_jitted(x)
self.assertGraphContainsExactly(t_jitted.graph_for(x), FUSION_GUARD, 0)
@unittest.skipIf(not RUN_NVFUSER, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
def test_overlapped_input(self):
x = torch.randn(8, device="cuda").as_strided((2, 4), (1, 1))
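# strides of (1, 1) make the rows of x overlap in memory, so its elements alias each other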
with nvfuser_singleton_fusion(True):
def t(x):
return x + 1.0
t_jit = torch.jit.script(t)
self._run_helper(t_jit, t, x)
@unittest.skipIf(not RUN_NVFUSER, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
@unittest.skipIf(is_pre_volta(), "reduction not supported in pre volta device")
def test_reduction_empty_axes(self):
x = torch.randn(4, 2, 3, device="cuda").permute([1, 2, 0])
with nvfuser_singleton_fusion(True):
def t(x):
sizes : List[int] = []
return x.sum(sizes)
t_jit = torch.jit.script(t)
self._run_helper(t_jit, t, x)
@unittest.skipIf(not RUN_NVFUSER, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
@unittest.skipIf(is_pre_volta(), "reduction not supported in pre volta device")
def test_int_tensor_input(self):
x = torch.randn(4, 2, device="cuda").to(dtype=torch.int)
with nvfuser_singleton_fusion(True):
def t(x):
return x.amax(dim=0)
t_jit = torch.jit.script(t)
self._run_helper(t_jit, t, x)
@unittest.skipIf(not RUN_NVFUSER, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
def test_to_boolean(self):
x = torch.randn(4, 2, device="cuda")
with nvfuser_singleton_fusion(True):
def t(x):
return x.to(dtype=torch.bool)
t_jit = torch.jit.script(t)
self._run_helper(t_jit, t, x)
@unittest.skipIf(not RUN_NVFUSER, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
def test_to_copy(self):
x = torch.randn(4, 2, device="cuda")
with nvfuser_singleton_fusion(True):
def t(x, dtype : torch.dtype):
o = torch.ops.aten._to_copy(x, dtype=dtype)
return o
t.__disable_jit_function_caching__ = True
t_jit = torch.jit.script(t)
for dtype in [torch.float16, torch.bool, torch.float64]:
self._run_helper(t_jit, t, x, dtype)
def t_none(x):
with torch.jit.strict_fusion():
o = torch.ops.aten._to_copy(x, dtype=None)
return o
t_jit_none = torch.jit.script(t_none)
self._run_helper(t_jit_none, t_none, x)
@unittest.skipIf(ALIAS_TEST_DISABLED, "skipping this test since reshape is disabled now")
@unittest.skipIf(not RUN_NVFUSER, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
def test_view_copy_graph_guard(self):
x = torch.randn(4, 2, 3, device="cuda").permute([1, 2, 0])
y = [4, 6]
with nvfuser_singleton_fusion(True):
def t(x, y : List[int]):
t1 = x + 1.0
t2 = t1 * 1.0
out = t2.reshape(y)
return out.relu()
t_jit = torch.jit.script(t)
self._run_helper(t_jit, t, x, y)
@unittest.skipIf(ALIAS_TEST_DISABLED, "skipping this test since view is disabled now")
@unittest.skipIf(not RUN_NVFUSER, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
def test_view_copy_graph_guard_double_fusion(self):
x = torch.randn(2, 2, 5, device="cuda")
w = torch.randn(5, 5, device="cuda")
with nvfuser_singleton_fusion(True):
def t(x, w):
o = x.view([4, x.size()[-1]])
o = torch.matmul(o, w)
o = o.view([2, 2, o.size()[1]])
return o
t_jit = torch.jit.script(t)
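# the unfusible matmul splits the pointwise/view ops into two fusion
# groups, hence the expected count of 2 guards below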
for i in range(3):
jit_o = t_jit(x, w)
o = t(x, w)
self.assertEqual(jit_o, o)
self.assertGraphContainsExactly(t_jit.graph_for(x, w), FUSION_GUARD, 2, consider_subgraphs=True)
@unittest.skipIf(not RUN_NVFUSER, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
def test_input_output_passthrough(self):
def t(t0, t1, t2):
mask = t1.to(dtype=torch.bool)
masked_input = torch.where(t0, mask, t2)
return masked_input, mask
t_jit = torch.jit.script(t)
# stick to integers; this avoids the numerical differences due to our
# promotion
x = torch.randn(4, 4, device='cuda').to(dtype=torch.bool)
y = torch.randn(4, 4, device='cuda').to(dtype=torch.bool)
z = torch.tensor(1.0, device='cuda').to(dtype=torch.bool)
jit_o = t_jit(x, y, z)
jit_o = t_jit(x, y, z)
o = t(x, y, z)
for oo, jit_oo in zip(o, jit_o):
self.assertEqual(oo.dtype, jit_oo.dtype)
self.assertEqual(oo, jit_oo)
self.assertGraphContains(t_jit.graph_for(x, y, z), FUSION_GUARD)
@unittest.skipIf(not RUN_NVFUSER, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
def test_pointwise_reference_tensor(self):
def t(input1, input2, scalar):
_unsafe_view = torch.ops.aten._unsafe_view(input1, [2, 4, 16])
add_ = torch.ops.aten.add_(_unsafe_view, input2)
gelu_ = torch.ops.aten.gelu(add_)
view_ = torch.ops.aten.view(gelu_, [8, 16])
mul_ = torch.ops.aten.mul(add_, scalar)
return [view_, mul_]
x = torch.randn(8, 16, device="cuda")
bias = torch.randn(16, device="cuda")
scalar = torch.ones(torch.Size([]), device="cuda")
t_jit = torch.jit.script(t)
for i in range(3):
jit_o = t_jit(x, bias, scalar)
o = t(x, bias, scalar)
self.assertEqual(jit_o, o)
self.assertGraphContains(t_jit.graph_for(x, bias, scalar), FUSION_GUARD)
@unittest.skipIf(not RUN_NVFUSER, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
@unittest.skipIf(is_pre_volta(), "reduction not supported in pre volta device")
def test_native_batch_norm_backward(self):
grad_output = torch.randn(4, 2, 3, device="cuda")
input = torch.randn(4, 2, 3, device="cuda")
weight = torch.randn(2, device="cuda")
r_m = torch.randn(2, device="cuda")
r_v = torch.randn(2, device="cuda").abs()
save_mean = torch.randn(2, device="cuda")
save_invstd = torch.randn(2, device="cuda").abs()
with nvfuser_singleton_fusion(True):
def t(grad_out, input, weight, r_m, r_v, save_mean, save_invstd, train: bool, eps: float, mask: List[bool]):
return torch.ops.aten.native_batch_norm_backward(grad_out, input, weight, r_m, r_v, save_mean,
save_invstd, train, eps, mask)
t_jit = torch.jit.script(t)
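# the [True, True, True] output mask requests grad_input, grad_weight and grad_bias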
for i in range(4):
jit_o = t_jit(grad_output, input, weight, r_m.clone(), r_v.clone(),
save_mean, save_invstd, True, 1e-5, [True, True, True])
ref_m = r_m.clone()
ref_v = r_v.clone()
jit_o = t_jit(grad_output, input, weight, r_m, r_v, save_mean, save_invstd, True, 1e-5, [True, True, True])
o = t(grad_output, input, weight, ref_m, ref_v, save_mean, save_invstd, True, 1e-5, [True, True, True])
for oo, jit_oo in zip(o, jit_o):
self.assertEqual(oo.dtype, jit_oo.dtype)
self.assertEqual(oo, jit_oo)
self.assertEqual(ref_m.dtype, r_m.dtype)
self.assertEqual(ref_m, r_m)
self.assertEqual(ref_v.dtype, r_v.dtype)
self.assertEqual(ref_v, r_v)
self.assertGraphContains(t_jit.graph_for(grad_output, input, weight, r_m.clone(), r_v.clone(), save_mean,
save_invstd, True, 1e-5, [True, True, True]), FUSION_GUARD)
@unittest.skipIf(not RUN_NVFUSER, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
def test_contiguous_on_broadcasted(self):
x = torch.randn(4, 1, device="cuda")
y = torch.randn(4, 128, device="cuda")
with nvfuser_singleton_fusion(True):
def t(x, y):
t1 = x.expand([4, 128])
t2 = t1 * y
return t2
t_jit = torch.jit.script(t)
self._run_helper(t_jit, t, x, y)
@unittest.skipIf(not RUN_NVFUSER, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
def test_skip_parser(self):
x = torch.randn(4, 12, device="cuda")
with nvfuser_singleton_fusion(True):
def fn(x):
t1 = x + 1.0
return t1.relu()
fn_jit = torch.jit.script(fn)
self._run_helper(fn_jit, fn, x)
# add node should have been merged into fusion
self.assertGraphContains(fn_jit.graph_for(x), FUSION_GUARD)
self.assertGraphContainsExactly(fn_jit.graph_for(x), 'aten::add', 0)
# flip skip parsing for `aten::add`; the following fusion should skip the
# add node
self.assertFalse(torch._C._jit_set_nvfuser_skip_node_kind("aten::add", True))
def fn_1(x):
t1 = x + 2.0 # change const value so we'll not reuse plan
return t1.relu()
fn_1_jit = torch.jit.script(fn_1)
self._run_helper(fn_1_jit, fn_1, x)
# with skip parsing enabled, the add node should NOT have been merged into fusion
self.assertGraphContains(fn_1_jit.graph_for(x), FUSION_GUARD)
self.assertGraphContainsExactly(fn_1_jit.graph_for(x), 'aten::add', 1)
# flip skip parsing for `aten::add` back off; the next fusion should fuse the add node again
self.assertTrue(torch._C._jit_set_nvfuser_skip_node_kind("aten::add", True))
def fn_2(x):
t1 = x + 2.0 # change const value so we'll not reuse plan
return t1.relu()
fn_2_jit = torch.jit.script(fn_2)
self._run_helper(fn_2_jit, fn_2, x)
# add node should have been merged into fusion
self.assertGraphContains(fn_2_jit.graph_for(x), FUSION_GUARD)
self.assertGraphContainsExactly(fn_2_jit.graph_for(x), 'aten::add', 0)
@unittest.skipIf(not RUN_NVFUSER, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
def test_cuda_fusion_guard(self):
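# the setter returns the previous guard mode; save it so it can be
# restored at the end of the test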
old_guard = torch._C._jit_set_nvfuser_guard_mode(True)
class ConvModule(torch.nn.Module):
def __init__(self):
super().__init__()
def forward(self, x):
return x.sin().sigmoid()
mod = ConvModule().to(device="cuda")
inputs = [torch.randn(20, 16, 50, 100, device="cuda", requires_grad=True)]
def reduce_scalar(temp):
return temp.sum()
scripted = torch.jit.script(mod)
with torch.no_grad():
scripted(*inputs)
res = scripted(*inputs)
reduce_scalar(res).backward()
torch._C._jit_set_nvfuser_guard_mode(old_guard)
@unittest.skipIf(not RUN_NVFUSER, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
def test_nvfuser_comparison_callbacks_with_fallback(self):
try:
fused_result = None
unfused_result = None
graph_ir = None
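# the callback receives outputs from the fused run, outputs from the
# unfused fallback run, and the string IR of the fusion group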
def callback(fused_outputs, unfused_outputs, graph_str):
nonlocal unfused_result
nonlocal fused_result
nonlocal graph_ir
unfused_result = unfused_outputs[-1]
fused_result = fused_outputs[-1]
graph_ir = graph_str
torch._C._jit_nvfuser_set_comparison_callback(True, callback)
def fn(x, y):
z = torch.add(x, y)
return torch.relu(z)
x = torch.rand((4, 4)).cuda() - 0.5
y = torch.rand((4, 4)).cuda() - 0.5
fn_s = torch.jit.script(fn)
fn_s(x, y)
fn_s(x, y)
fn_s(x, y)
expected = fn(x, y)
self.assertEqual(expected, fused_result)
self.assertEqual(expected, unfused_result)
FileCheck().check("aten::add").run(graph_ir)
finally:
torch._C._jit_nvfuser_clear_comparison_callback()
@unittest.skipIf(not RUN_NVFUSER, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
def test_nvfuser_comparison_callbacks_without_fallback(self):
try:
fused_result = None
unfused_result = None
graph_ir = None
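# with the first argument set to False the unfused fallback is not run,
# so unfused_outputs is expected to arrive empty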
def callback(fused_outputs, unfused_outputs, graph_str):
nonlocal unfused_result
nonlocal fused_result
nonlocal graph_ir
if len(unfused_outputs) > 0:
unfused_result = unfused_outputs[-1]
fused_result = fused_outputs[-1]
graph_ir = graph_str
torch._C._jit_nvfuser_set_comparison_callback(False, callback)
def fn(x, y):
z = torch.add(x, y)
return torch.relu(z)
x = torch.rand((4, 4)).cuda() - 0.5
y = torch.rand((4, 4)).cuda() - 0.5
fn_s = torch.jit.script(fn)
fn_s(x, y)
fn_s(x, y)
fn_s(x, y)
expected = fn(x, y)
self.assertEqual(expected, fused_result)
self.assertEqual(None, unfused_result)
FileCheck().check("aten::add").run(graph_ir)
finally:
torch._C._jit_nvfuser_clear_comparison_callback()
@unittest.skipIf(not RUN_NVFUSER, "requires NVFuser")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
def test_cuda_fusion_guard_backward(self):
old_guard = torch._C._jit_set_nvfuser_guard_mode(True)
inp = torch.randn(10, device="cuda", requires_grad=True)
grad = torch.randn(10, device="cuda")
def f(x):
a = x.cos().cos()
return a
scripted = torch.jit.script(f)
with profile(activities=[ProfilerActivity.CPU]) as prof:
for _ in range(5):
inp.grad = None
out = scripted(inp)
out.backward(grad)
# check that no fallback was triggered
self.assertEqual(prof.events().table().find("fallback"), -1)
torch._C._jit_set_nvfuser_guard_mode(old_guard)
# TODO: generalize this
@unittest.skipIf(not RUN_NVFUSER, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
@unittest.skipIf(is_pre_volta(), "reduction not supported in pre volta device")
def test_inf_quick_patch(self):
inputs = [torch.tensor([-float('inf'), float('inf'), 4.0], device="cuda"),
torch.tensor([1.0, float('inf'), 4.0], device="cuda"),
torch.tensor([-float('inf'), -1.5, 4.0], device="cuda"),
torch.tensor([1.0, -3.0, float('nan')], device="cuda"),
torch.tensor([-float('inf'), -float('inf'), -float('inf')], device="cuda"),
torch.tensor([float('inf'), float('inf'), float('inf')], device="cuda"),
torch.tensor([float('nan'), float('nan'), float('nan')], device="cuda")]
def fn_amax(x):
return x.amax(dim=0)
def fn_amin(x):
return x.amin(dim=0)
def fn_add_nan(x):
return x.relu() + float('nan')
def fn_add(x):
return x + 1.0
with nvfuser_singleton_fusion(True):
for t in [fn_amax, fn_amin, fn_add, fn_add_nan]:
for x in inputs:
t_jit = torch.jit.script(t)
self._run_helper(t_jit, t, x)
@unittest.skipIf(not RUN_NVFUSER, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
def test_clamp_reversed_bound(self):
x = torch.tensor([1., -float('inf'), 2., float('inf'), float('nan')], device="cuda")
def t(x):
return x.clamp(min=1., max=0.5)
with nvfuser_singleton_fusion(True):
jit_t = torch.jit.script(t)
self._run_helper(jit_t, t, x)
@unittest.skipIf(not RUN_NVFUSER, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
def test_issue_1785(self):
class Fusion(torch.nn.Module):
def __init__(self):
super(Fusion, self).__init__()
def forward(self, x, a, b):
out = torch.mul(x.unsqueeze(-1), a)
out = out + b
return out
x = torch.randn(1024, 192, 3, device='cuda')
a = torch.randn(3, 128, device='cuda')
b = torch.randn(3, 128, device='cuda')
model = Fusion()
jit_model = torch.jit.script(model)
with torch.jit.fuser('fuser2'):
for _ in range(4):
out_ref = model(x, a, b)
out_jit = jit_model(x, a, b)
out_ref = model(x, a, b)
out_jit = jit_model(x, a, b)
self.assertTrue(self._compare("comparing output failed", out_ref, out_jit, 1e-5))
@unittest.skipIf(not RUN_NVFUSER, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
def test_high_rank_fusion(self):
# currently we want to limit fusion to nodes whose tensor inputs have rank <= 8
rank_limit = 8
shapes = [4 for i in range(rank_limit + 1)]
x = torch.randn(shapes, device="cuda")
with nvfuser_singleton_fusion(True):
def t(x):
return x.relu()
jit_t = torch.jit.script(t)
for i in range(5):
jit_t(x)
self.assertGraphContainsExactly(jit_t.graph_for(x), FUSION_GUARD, 0)
@unittest.skipIf(not RUN_NVFUSER, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
def test_clamp(self):
x = torch.tensor([1., float('inf'), 2., float('nan'), float('-inf')], device="cuda")
def clamp_max(x):
return x.clamp(max=1.5)
def clamp_min(x):
return x.clamp(min=1.5)
def clamp_min_max(x):
return x.clamp(min=1., max=3.)
with nvfuser_singleton_fusion(True):
for t in [clamp_max, clamp_min, clamp_min_max]:
t_jit = torch.jit.script(t)
self._run_helper(t_jit, t, x)
@unittest.skipIf(not RUN_NVFUSER, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
def test_device_constant(self):
x = torch.randn(4, 2, device="cuda")
def t(x):
return torch.rand_like(x, device=torch.device(type='cuda'))
# cpu tensor shouldn't be fused
def t_cpu(x):
return torch.rand_like(x, device=torch.device(type='cpu'))
with nvfuser_singleton_fusion(True):
t_jit = torch.jit.script(t)
self._run_helper(t_jit, t, x)
t_cpu_jit = torch.jit.script(t_cpu)
for i in range(5):
t_cpu_jit(x)
self.assertGraphContainsExactly(t_cpu_jit.graph_for(x), FUSION_GUARD, 0)
@unittest.skipIf(not RUN_NVFUSER, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
def test_expand(self):
device = "cuda"
x = torch.randn(3, 5, device=device)
y = torch.randn(4, 2, 3, 5, device=device)
def t(x, y):
with torch.jit.strict_fusion():
x = x.relu()
o0 = x.expand(2, 3, 5)
o1 = x.expand_as(y)
return o0, o1
t_jit = torch.jit.script(t)
self._run_helper(t_jit, t, x, y, check_stride=True)
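# the in-place add mutates x, which the expanded outputs alias, so no
# fusion is expected here (num_fusion=0)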
def t2(x, y):
o0 = x.expand(2, 3, 5)
o1 = x.expand_as(y)
x.add_(1)
return o0, o1
t2_jit = torch.jit.script(t2)
self._run_helper(t2_jit, t2, x, y, check_stride=True, num_fusion=0)
@unittest.skipIf(not RUN_NVFUSER, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
def test_scheduler_with_polymorphic_broadcast(self):
device = "cuda"
x0 = torch.randn(10, 128, device=device)
x1 = torch.rand_like(x0)
x2 = torch.randn(10, device=device)
def t(x0, x1, x2):
x3 = x2.unsqueeze(-1)
x4 = x3 + x0
x5 = x3 + x1
x6 = x5.sum(0)
return x4, x6
t_jit = torch.jit.script(t)
self._run_helper(t_jit, t, x0, x1, x2, check_stride=True)
x2 = torch.randn(128, device=device)
def t2(x0, x1, x2):
x3 = x2.unsqueeze(0)
x4 = x3 + x0
x5 = x3 + x1
x6 = x5.sum(1)
return x4, x6
t2_jit = torch.jit.script(t2)
self._run_helper(t2_jit, t2, x0, x1, x2, check_stride=True)
@unittest.skipIf(not RUN_NVFUSER, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
def test_type_inference(self):
device = "cuda"
x0 = torch.randn(10, 128, device=device)
x1 = torch.rand_like(x0)
x2 = torch.rand_like(x0)
def t(x0, x1, x2, flag : bool = True):
x3 = 2.0 * x0
x4 = 2.0 * x1
x5 = 2.0 * x2
if flag:
return torch.stack([x3, x4, x5], dim=-1)
# the second code path doesn't run through profiling,
# hence it exercises type inference using the recorded profiling information
return x0 + x1 + x2
t_jit = torch.jit.script(t)
self._run_helper(t_jit, t, x0, x1, x2, check_stride=True)
class TestEnableDisableCudaFuser(JitTestCase):
def setUp(self):
super().setUp()
if RUN_NVFUSER:
self.is_enabled = torch._C._jit_set_nvfuser_enabled(False)
def tearDown(self):
if RUN_NVFUSER:
torch._C._jit_set_nvfuser_enabled(self.is_enabled)
super().tearDown()
@unittest.skipIf(not RUN_NVFUSER, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
def test_context_manager_test(self):
x = torch.randn(4, 8, dtype=torch.float, device="cuda")
y = torch.randn(4, 8, dtype=torch.float, device="cuda")
with torch.jit.fuser('fuser2'):
with torch.jit.fuser('fuser2'):
def t1(x, y):
o = x + y
o = o + 2.0
return o
t_jit = torch.jit.script(t1)
t_jit(x, y)
t_jit(x, y)
self.assertGraphContains(t_jit.graph_for(x, y), FUSION_GUARD)
def t2(x, y):
o = x + y
o = o + 3.0
return o
t_jit_2 = torch.jit.script(t2)
t_jit_2(x, y)
t_jit_2(x, y)
self.assertGraphContains(t_jit_2.graph_for(x, y), FUSION_GUARD)
def t3(x, y):
o = x + y
o = o + 4.0
return o
t_jit_3 = torch.jit.script(t3)
t_jit_3(x, y)
t_jit_3(x, y)
self.assertGraphContainsExactly(t_jit_3.graph_for(x, y), FUSION_GUARD, 0)
@unittest.skipIf(not RUN_NVFUSER, "requires CUDA")
def test_register_fuser(self):
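# _jit_set_nvfuser_enabled returns the previous enablement state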
self.assertFalse(torch._C._jit_set_nvfuser_enabled(True))
self.assertTrue(torch._C._jit_nvfuser_enabled())
self.assertTrue(torch._C._jit_set_nvfuser_enabled(True))
self.assertTrue(torch._C._jit_nvfuser_enabled())
self.assertTrue(torch._C._jit_set_nvfuser_enabled(False))
self.assertFalse(torch._C._jit_nvfuser_enabled())
@unittest.skipIf(RUN_CUDA, "Testing on CPU only")
def test_register_fuser_cpu(self):
with self.assertRaises(RuntimeError):
torch._C._jit_set_nvfuser_enabled(True)
torch._C._jit_set_nvfuser_enabled(False)
@unittest.skipIf(not RUN_CUDA, "requires CUDA")
@unittest.skipIf(not TEST_WITH_ROCM, "ROCM test only")
def test_register_fuser_rocm(self):
with self.assertRaises(RuntimeError):
torch._C._jit_set_nvfuser_enabled(True)
torch._C._jit_set_nvfuser_enabled(False)
def test_can_be_enabled_nvfuser(self):
if TEST_WITH_ROCM:
expected = False
else:
expected = RUN_CUDA
self.assertEqual(expected, torch._C._jit_nvfuser_can_be_enabled())
# See TestNNCOpInfoParent
class TestCudaFuserOpInfoParent(JitCommonTestCase):
pass
class TestCudaFuserOpInfo(TestCudaFuserOpInfoParent):
def setUp(self):
super(TestCudaFuserOpInfoParent, self).setUp()
if RUN_NVFUSER:
self.cuda_fuser_options = CudaFuserTestOptions()
# enable guard mode, since tracing could change the graph in ways that violate the guard.
torch._C._jit_set_nvfuser_guard_mode(True)
self.nvfuser_single_node_mode = torch._C._jit_set_nvfuser_single_node_mode(True)
def tearDown(self):
if RUN_NVFUSER:
self.cuda_fuser_options.restore()
torch._C._jit_set_nvfuser_single_node_mode(self.nvfuser_single_node_mode)
super(TestCudaFuserOpInfoParent, self).tearDown()
@slowTest
@unittest.skipIf(not RUN_NVFUSER, "requires CUDA")
@ops(op_db, dtypes=OpDTypes.supported)
def test_nvfuser_correctness(self, device, dtype, op):
if not op.supports_tracing:
self.skipTest("nvfuser requires tracing support")
variant_sample_pairs = get_traced_sample_variant_pairs(device, dtype, op)
for variant, sample in variant_sample_pairs:
trace = create_traced_fn(self, variant, cache_traced_fn=True)
ref = variant(*clone_inputs((sample.input, *sample.args)), **sample.kwargs)
trace(*clone_inputs((sample.input, *sample.args)), **sample.kwargs)
val = trace(*clone_inputs((sample.input, *sample.args)), **sample.kwargs)
self.assertEqual(ref, val, exact_layout=True)
# Note: Clearing CU after NVFuser tests
# https://github.com/pytorch/pytorch/issues/35600
# each torch.jit.trace adds state to the _python_cu compilation unit
# since this test traces a lot of functions, out-of-memory can occur
# if the CU is not cleared.
torch.jit._state._python_cu.drop_all_functions()
@slowTest
@unittest.skipIf(not RUN_NVFUSER, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
@ops(op_db, allowed_dtypes=(torch.float16, torch.bfloat16, torch.float32,
torch.float64, torch.complex64, torch.complex128))
def test_nvfuser_extremal_values(self, device, dtype, op):
if not op.supports_tracing:
self.skipTest("nvfuser requires tracing support")
variant_sample_pairs = get_traced_sample_variant_pairs(device, dtype, op)
def _get_extremal_tensor(x, val, dtype):
if x.dtype != dtype:
return x
return torch.full_like(x, val)
def _get_extremal_input(x, val, dtype):
if isinstance(x, torch.Tensor):
return _get_extremal_tensor(x, val, dtype)
elif is_iterable_of_tensors(x):
return [_get_extremal_tensor(y, val, dtype) for y in x]
return x
def _get_extremal_sample(sample: SampleInput, val, dtype):
extremal_sample = SampleInput(
input=_get_extremal_input(sample.input, val, dtype),
args=tuple(_get_extremal_input(x, val, dtype) for x in sample.args),
kwargs={k: _get_extremal_input(v, val, dtype) for k, v in sample.kwargs.items()},
)
return extremal_sample
def _get_extremal_samples(sample: SampleInput, dtype):
vals = [float('inf'), float('-inf'), float('nan')]
if dtype.is_complex:
complex_vals = itertools.product(vals, vals)
vals = tuple(map(lambda x: complex(*x), complex_vals))
for val in vals:
yield _get_extremal_sample(sample, val, dtype)
for variant, sample in variant_sample_pairs:
trace = create_traced_fn(self, variant, cache_traced_fn=True)
trace(*clone_inputs((sample.input, *sample.args)), **sample.kwargs)
trace(*clone_inputs((sample.input, *sample.args)), **sample.kwargs)
for extremal_sample in _get_extremal_samples(sample, dtype):
try:
with freeze_rng_state():
ref = variant(*clone_inputs((extremal_sample.input, *extremal_sample.args)),
**extremal_sample.kwargs)
except (torch._C._LinAlgError, RuntimeError, ValueError):
# if eager errors out, then don't expect NVFuser to pass
continue
with freeze_rng_state():
val = trace(*clone_inputs((extremal_sample.input, *extremal_sample.args)),
**extremal_sample.kwargs)
self.assertEqual(val, ref, equal_nan=True, exact_device=True)
# See [Note: Clearing CU after NVFuser tests]
torch.jit._state._python_cu.drop_all_functions()
instantiate_device_type_tests(TestCudaFuserOpInfo, globals(), only_for=("cuda"))
if __name__ == '__main__':
run_tests()