pytorch/test/test_jit_cuda_fuser.py
jiej 2d110d514f Nvfuser code bump 2_1_2022 (#72127)
Summary:
Things changed in this PR that require review:
1. aten/src/ATen/core/interned_strings.h
2. torch/csrc/jit/ir/alias_analysis.h : exposing createValue to allow efficient mutation
3. torch/csrc/jit/runtime/symbolic_shape_registry.cpp : added gelu/tanh/erf in registry
4. torch/jit/_script.py : raise an error when scripting a model that uses autocast as a decorator, since that is not supported

nvfuser code update:
1. codegen improvements and performance tuning
2. integration bug fixes for shape expression logic
3. kernel segmentation update to address perf regression from horizontal fusion
4. scalar CPU tensor promotion to support inter-device operations between a CPU scalar tensor and a CUDA tensor (a minimal sketch follows below)
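
A minimal sketch (illustrative only, not code from this PR) of the mixed-device
pattern item 4 refers to, assuming a CUDA device is available:

    import torch

    def f(x, s):
        # `s` is a 0-dim CPU tensor; it is promoted like a Python scalar
        return torch.relu(x * s + 1.0)

    x = torch.randn(4, 8, device="cuda")
    s = torch.tensor(2.0)  # scalar tensor living on the CPU
    f_jit = torch.jit.script(f)
    f_jit(x, s)  # warm-up runs for the profiling executor
    f_jit(x, s)
    out = f_jit(x, s)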

Things reverted from local changes:
aten::gelu with approximation (tracked in PR: https://github.com/pytorch/pytorch/pull/61439)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/72127

Reviewed By: HamidShojanazeri

Differential Revision: D34113233

Pulled By: jbschlosser

fbshipit-source-id: b82cde32b71e324eca0ea57cb8c9f9647278ca74
(cherry picked from commit e009bc5c4e)
2022-02-15 00:43:16 +00:00

# Owner(s): ["oncall: jit"]
import unittest
import os
import random
import enum
import copy
from functools import reduce
import operator
import torch
from torch.nn import functional
from torch.testing._internal.common_utils import run_tests, ProfilingMode, GRAPH_EXECUTOR # TEST_WITH_ROCM
from torch.testing._internal.common_cuda import TEST_MULTIGPU
from torch.testing._internal.codegen.random_topo_test import runDefaultTestWithSeed
from torch.testing import FileCheck
from test_jit import JitTestCase, RUN_CUDA
from jit.test_fuser_common import TestFuserCommon # noqa: F401
import itertools
import numpy as np
import math
from torch.autograd.gradcheck import gradcheck
from typing import List
CUDA_MAJOR, CUDA_MINOR = (int(x) for x in torch.version.cuda.split('.'))
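# Run nvfuser in a strict, deterministic configuration: disable the silent fallback
# path so fusion failures surface as errors, and disable FMA, fast math, kernel
# optimization and RNG unrolling so results stay numerically comparable to eager mode.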
os.environ['PYTORCH_NVFUSER_DISABLE_FALLBACK'] = '1'
os.environ['PYTORCH_NVFUSER_DISABLE_FMA'] = '1'
os.environ['PYTORCH_NVFUSER_DISABLE_FASTMATH'] = '1'
os.environ['PYTORCH_NVFUSER_JIT_OPT_LEVEL'] = '0'
os.environ['PYTORCH_NVFUSER_DISABLE_RNG_UNROLL'] = '1'
if GRAPH_EXECUTOR == ProfilingMode.PROFILING:
torch._C._jit_set_texpr_fuser_enabled(False)
torch._C._jit_set_profiling_executor(True)
torch._C._jit_set_profiling_mode(True)
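# Node kinds emitted by the nvfuser pass: CudaFusionGroup holds the fused subgraph,
# CudaFusionGuard checks input tensor properties before dispatching to it.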
FUSION_GROUP = 'prim::CudaFusionGroup'
FUSION_GUARD = 'prim::CudaFusionGuard'
import contextlib
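# Context managers that toggle nvfuser's single-node and horizontal fusion modes
# for the duration of a test and restore the previous setting afterwards.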
@contextlib.contextmanager
def nvfuser_singleton_fusion(flag):
old_value = torch._C._jit_set_nvfuser_single_node_mode(flag)
try:
yield
finally:
torch._C._jit_set_nvfuser_single_node_mode(old_value)
@contextlib.contextmanager
def nvfuser_horizontal_fusion(flag):
old_value = torch._C._jit_set_nvfuser_horizontal_mode(flag)
try:
yield
finally:
torch._C._jit_set_nvfuser_horizontal_mode(old_value)
def is_pre_volta():
prop = torch.cuda.get_device_properties(torch.cuda.current_device())
return prop.major < 7
TEST_BF16 = torch.cuda.is_bf16_supported()
class TestCudaFuser(JitTestCase):
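# Boundary values (infinities, NaN, zero, +/-pi, ...) used in addition to random
# data to stress unary/binary ops on corner cases.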
special_values = torch.tensor(
[float("-inf"), -10, -math.pi,
-1, -0.5, 0, 1, 0.5,
math.pi, 10, float("inf"),
float("nan")], dtype=torch.float, device='cuda')
int_types = [
torch.int8,
torch.uint8,
torch.int16,
torch.int32,
torch.int64
]
support_tensor_dtypes = [
torch.int32,
torch.int64,
torch.float16,
torch.float32,
torch.float64,
torch.bool
]
if TEST_BF16:
support_tensor_dtypes.append(torch.bfloat16)
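# Returns the Subgraph attribute of the single CudaFusionGroup node found in
# `graph` (nested blocks are searched too); asserts exactly one fusion group exists.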
def _getSubgraphInFusion(self, graph):
num_node = 0
subgraph = None
def count(block, ret):
for n in block.nodes():
if n.kind() == FUSION_GROUP:
ret[0] = ret[0] + 1
self.assertTrue(n.hasAttribute('Subgraph'))
ret[1] = n.g('Subgraph')
for block in n.blocks():
count(block, ret)
ret = [num_node, subgraph]
count(graph, ret)
self.assertEqual(ret[0], 1)
return ret[1]
def setUp(self):
super(TestCudaFuser, self).setUp()
self.old_cpu_fuse = torch._C._jit_can_fuse_on_cpu()
self.old_gpu_fuse = torch._C._jit_can_fuse_on_gpu()
torch._C._jit_override_can_fuse_on_cpu(False)
torch._C._jit_override_can_fuse_on_gpu(False)
self.old_guard = torch._C._jit_set_nvfuser_guard_mode(False)
torch._C._debug_set_autodiff_subgraph_inlining(False)
self.old_value = torch._C._jit_set_autocast_mode(True)
if RUN_CUDA:
self.old_nvfuser = torch._C._jit_set_nvfuser_enabled(True)
def tearDown(self):
if RUN_CUDA:
torch._C._jit_set_nvfuser_enabled(self.old_nvfuser)
torch._C._jit_override_can_fuse_on_cpu(self.old_cpu_fuse)
torch._C._jit_override_can_fuse_on_gpu(self.old_gpu_fuse)
torch._C._jit_set_nvfuser_guard_mode(self.old_guard)
torch._C._debug_set_autodiff_subgraph_inlining(True)
torch._C._jit_set_autocast_mode(self.old_value)
super(TestCudaFuser, self).tearDown()
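# The helpers below run the scripted function multiple times because the profiling
# executor needs warm-up runs to record shapes/dtypes before fusion kicks in; the
# final result is then compared against eager mode and the graph is checked for a
# single fusion guard.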
def _run_helper(self, jit_op, op, *args):
torch.cuda.manual_seed_all(123)
jit_o = jit_op(*args)
torch.cuda.manual_seed_all(123)
jit_o = jit_op(*args)
torch.cuda.manual_seed_all(123)
o = op(*args)
self.assertEqual(o.dtype, jit_o.dtype)
self.assertEqual(o, jit_o)
self.assertGraphContainsExactly(jit_op.graph_for(*args), FUSION_GUARD, 1, consider_subgraphs=True)
def _run_training_helper(self, jit_op, op, grads, *args):
torch.cuda.manual_seed_all(123)
jit_o = jit_op(*args)
jit_g = jit_o.backward(grads)
torch.cuda.manual_seed_all(123)
jit_o = jit_op(*args)
jit_g = jit_o.backward(grads)
torch.cuda.manual_seed_all(123)
jit_o = jit_op(*args)
jit_g = jit_o.backward(grads)
torch.cuda.manual_seed_all(123)
o = op(*args)
g = o.backward(grads)
self.assertEqual(o, jit_o)
self.assertEqual(g, jit_g)
self.assertGraphContainsExactly(jit_op.graph_for(*args), FUSION_GUARD, 1, consider_subgraphs=True)
bwd_graph = list(
list(jit_op.get_debug_state().execution_plans.values())[
0].code.grad_executor_states()[0].execution_plans.values()
)[0].graph
self.assertGraphContainsExactly(bwd_graph, FUSION_GUARD, 1, consider_subgraphs=True)
@unittest.skipIf(not RUN_CUDA, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
def test_half(self):
def t(x: torch.Tensor, y: torch.Tensor, z: torch.Tensor, alpha: float):
o_16 = torch.add(x, y)
o_32_a = torch.add(y, z, alpha=alpha)
o_32_b = torch.add(o_16, z)
return (o_16, o_32_a, o_32_b)
t_jit = torch.jit.script(t)
alpha = 0.5
# stick to integers; this avoids numerical differences due to our
# type promotion
x = torch.randint(0, 256, (4, 8)).to(dtype=torch.float16, device="cuda")
y = torch.randint(0, 256, (4, 8)).to(dtype=torch.float16, device="cuda")
z = torch.randint(0, 256, (4, 8)).to(dtype=torch.float16, device="cuda")
jit_o = t_jit(x, y, z, alpha)
jit_o = t_jit(x, y, z, alpha)
o = t(x, y, z, alpha)
for oo, jit_oo in zip(o, jit_o):
self.assertEqual(oo.dtype, jit_oo.dtype)
self.assertEqual(oo, jit_oo)
self.assertGraphContains(t_jit.graph_for(x, y, z, alpha), FUSION_GUARD)
@unittest.skipIf(not TEST_BF16, "device does not support BFloat16")
@unittest.skipIf(not RUN_CUDA, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
def test_bfloat(self):
def t(x: torch.Tensor, y: torch.Tensor, z: torch.Tensor, alpha: float):
o_16 = torch.add(x, y)
o_32_a = torch.add(y, z, alpha=alpha)
o_32_b = torch.add(o_16, z)
return (o_16, o_32_a, o_32_b)
t_jit = torch.jit.script(t)
alpha = 0.5
# stick to integers; this avoids numerical differences due to our
# type promotion
x = torch.randint(0, 256, (4, 8)).to(dtype=torch.bfloat16, device="cuda")
y = torch.randint(0, 256, (4, 8)).to(dtype=torch.bfloat16, device="cuda")
z = torch.randint(0, 256, (4, 8)).to(dtype=torch.bfloat16, device="cuda")
jit_o = t_jit(x, y, z, alpha)
jit_o = t_jit(x, y, z, alpha)
o = t(x, y, z, alpha)
for oo, jit_oo in zip(o, jit_o):
self.assertEqual(oo.dtype, jit_oo.dtype)
self.assertEqual(oo, jit_oo)
self.assertGraphContains(t_jit.graph_for(x, y, z, alpha), FUSION_GUARD)
@unittest.skipIf(not RUN_CUDA, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
def test_const(self):
def t(x, y):
o = x + y
o = o + 2.0
return o
t_jit = torch.jit.script(t)
x = torch.randn(4, 8, dtype=torch.float, device="cuda")
y = torch.randn(4, 8, dtype=torch.float, device="cuda")
jit_o = t_jit(x, y)
jit_o = t_jit(x, y)
o = t(x, y)
self.assertEqual(o, jit_o)
self.assertGraphContains(t_jit.graph_for(x, y), FUSION_GUARD)
@unittest.skipIf(not RUN_CUDA, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
def test_chunk(self):
def t(x, y, z, q):
o = x + q
x0, x1 = torch.chunk(o, 2)
o = x0 + x1
o = o + y
o = o * z
o = torch.relu(o)
return o
t_jit = torch.jit.script(t)
x = torch.randn(4, 8, dtype=torch.float, device="cuda")
y = torch.randn(2, 8, dtype=torch.float, device="cuda")
z = torch.randn(2, 8, dtype=torch.float, device="cuda")
q = torch.randn(4, 8, dtype=torch.float, device="cuda")
jit_o = t_jit(x, y, z, q)
jit_o = t_jit(x, y, z, q)
o = t(x, y, z, q)
self.assertEqual(o, jit_o)
self.assertGraphContains(t_jit.graph_for(x, y, z, q), FUSION_GUARD)
@unittest.skipIf(is_pre_volta(), "reduction not supported in pre volta device")
@unittest.skipIf(not RUN_CUDA, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
def test_reduction_dtypes_axis(self):
for op in [torch.sum, torch.mean, torch.amax]:
for dtype in [torch.float16, torch.float32, torch.double]:
for axis in [-1, 2]:
def make_func(op):
def func(x: torch.Tensor):
o = torch.mul(x, 2.0)
o = op(o, dim=[axis])
return o
return func
x = torch.randn(8, 4, 16, dtype=dtype, device="cuda")
t = make_func(op)
t_jit = torch.jit.trace(t, x)
jit_o = t_jit(x)
jit_o = t_jit(x)
o = t(x)
self.assertEqual(o.dtype, jit_o.dtype)
self.assertTrue(self._compare("comparing output failed", o, jit_o, 1e-4))
self.assertGraphContains(t_jit.graph_for(x), FUSION_GUARD)
@unittest.skipIf(not RUN_CUDA, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
def test_scalar_input(self):
def t(x: torch.Tensor, y: torch.Tensor, z: float):
o = x + y
o = o + z
return o
t_jit = torch.jit.script(t)
x = torch.randn(4, 8, 32, 32, dtype=torch.float, device="cuda")
y = torch.randn(4, 8, 1, 32, dtype=torch.float, device="cuda")
y = y.expand(4, 8, 32, 32)
jit_o = t_jit(x, y, 2.0)
jit_o = t_jit(x, y, 2.0)
o = t(x, y, 2.0)
self.assertEqual(o, jit_o)
self.assertGraphContains(t_jit.graph_for(x, y, 2.0), FUSION_GUARD)
@unittest.skipIf(not RUN_CUDA, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
def test_broadcasting_0(self):
def t(x: torch.Tensor, y: torch.Tensor, z: float):
o = x + y
o = o + z
return o
t_jit = torch.jit.script(t)
x = torch.randn(4, 8, 32, 32, dtype=torch.float, device="cuda")
y = torch.randn(32, 32, dtype=torch.float, device="cuda")
jit_o = t_jit(x, y, 2.0)
jit_o = t_jit(x, y, 2.0)
o = t(x, y, 2.0)
self.assertEqual(o, jit_o)
subgraph = self._getSubgraphInFusion(t_jit.graph_for(x, y, 2.0))
self.assertGraphContainsExactly(subgraph, 'aten::add', 2, consider_subgraphs=False)
@unittest.skipIf(not RUN_CUDA, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
def test_broadcasting_1(self):
def t(x: torch.Tensor, y: torch.Tensor, z: float):
o = x + y
o = o + z
return o
t_jit = torch.jit.script(t)
x = torch.randn(4, 8, 32, 32, dtype=torch.float, device="cuda")
y = torch.randn(1, 32, 32, dtype=torch.float, device="cuda")
jit_o = t_jit(x, y, 2.0)
jit_o = t_jit(x, y, 2.0)
o = t(x, y, 2.0)
self.assertEqual(o, jit_o)
subgraph = self._getSubgraphInFusion(t_jit.graph_for(x, y, 2.0))
self.assertGraphContainsExactly(subgraph, 'aten::add', 2, consider_subgraphs=False)
@unittest.skipIf(not RUN_CUDA, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
def test_broadcasting_2(self):
def t(x: torch.Tensor, y: torch.Tensor, z: float):
o = x + y
o = o + z
return o
t_jit = torch.jit.script(t)
x = torch.randn(4, 1, 32, 32, dtype=torch.float, device="cuda")
y = torch.randn(8, 32, 32, dtype=torch.float, device="cuda")
jit_o = t_jit(x, y, 2.0)
jit_o = t_jit(x, y, 2.0)
o = t(x, y, 2.0)
self.assertEqual(o, jit_o)
subgraph = self._getSubgraphInFusion(t_jit.graph_for(x, y, 2.0))
self.assertGraphContainsExactly(subgraph, 'aten::add', 2, consider_subgraphs=False)
@unittest.skipIf(not RUN_CUDA, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
def test_broadcasting_3(self):
def t(x: torch.Tensor, y: torch.Tensor, z: float):
o = x + y
o = o + z
return o
t_jit = torch.jit.script(t)
x = torch.randn(8, 17, 8, dtype=torch.float, device="cuda")
y = torch.randn(8, 17, 1, dtype=torch.float, device="cuda")
jit_o = t_jit(x, y, 2.0)
jit_o = t_jit(x, y, 2.0)
o = t(x, y, 2.0)
self.assertEqual(o, jit_o)
subgraph = self._getSubgraphInFusion(t_jit.graph_for(x, y, 2.0))
self.assertGraphContainsExactly(subgraph, 'aten::add', 2, consider_subgraphs=False)
# test_broadcasting_partition_logic_X
# Tests that the partitioning logic is able to avoid creating unsupported
# broadcasting semantics inside a CudaFusionGroup
@unittest.skipIf(not RUN_CUDA, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
def test_broadcasting_partition_logic_0(self):
def t(x: torch.Tensor, y: torch.Tensor, z: torch.Tensor):
x = x + 12.0
o1 = x + y
o2 = x + z
o = o1 + o2
return o
t_jit = torch.jit.script(t)
x = torch.randn(4, 8, 6, 8, dtype=torch.float32, device="cuda")
y = torch.randn(8, 6, 8, dtype=torch.float32, device="cuda")
z = torch.randn(6, 8, dtype=torch.float32, device="cuda")
jit_o = t_jit(x, y, z)
jit_o = t_jit(x, y, z)
o = t(x, y, z)
self.assertEqual(o, jit_o)
subgraph = self._getSubgraphInFusion(t_jit.graph_for(x, y, z))
self.assertGraphContainsExactly(subgraph, 'aten::add', 4, consider_subgraphs=False)
@unittest.skipIf(not RUN_CUDA, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
def test_broadcasting_partition_logic_1(self):
def t(x: torch.Tensor, y: torch.Tensor, z: torch.Tensor):
x = x + 12.0
o1 = x + y
o2 = x + z
o = o1 + o2
return o
t_jit = torch.jit.script(t)
x = torch.randn(8, 6, 8, dtype=torch.float32, device="cuda")
y = torch.randn(4, 8, 6, 8, dtype=torch.float32, device="cuda")
z = torch.randn(4, 1, 6, 8, dtype=torch.float32, device="cuda")
jit_o = t_jit(x, y, z)
jit_o = t_jit(x, y, z)
o = t(x, y, z)
self.assertEqual(o, jit_o)
subgraph = self._getSubgraphInFusion(t_jit.graph_for(x, y, z))
self.assertGraphContainsExactly(subgraph, 'aten::add', 4, consider_subgraphs=False)
@unittest.skipIf(True, "Broadcast with different output not supported yet")
@unittest.skipIf(not RUN_CUDA, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
def test_broadcasting_multiple_output_shape(self):
def t(x: torch.Tensor, y: torch.Tensor, z: torch.Tensor):
o = x + 12
o1 = o + y
o2 = o + z
oo = o1.sum() + o2.sum()
return oo
t_jit = torch.jit.script(t)
x = torch.randn(32, 32, dtype=torch.float, device="cuda")
y = torch.randn(2, 32, 32, dtype=torch.float, device="cuda")
z = torch.randn(4, 32, 32, dtype=torch.float, device="cuda")
jit_o = t_jit(x, y, z)
jit_o = t_jit(x, y, z)
o = t(x, y, z)
self.assertEqual(o, jit_o)
# Currently cannot fuse this
self.assertGraphContains(t_jit.graph_for(x, y, z), FUSION_GUARD)
@unittest.skipIf(True, "broadcast on branches can't be resolved yet")
@unittest.skipIf(not RUN_CUDA, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
def test_broadcasting_multiple_output(self):
def t(x: torch.Tensor, y: torch.Tensor, z: torch.Tensor):
o = x + 12
o1 = o + y
o2 = o + z
oo = o1.sum() + o2.sum()
return oo
t_jit = torch.jit.script(t)
x = torch.randn(32, 32, dtype=torch.float, device="cuda")
y = torch.randn(4, 32, 32, dtype=torch.float, device="cuda")
z = torch.randn(4, 32, 32, dtype=torch.float, device="cuda")
jit_o = t_jit(x, y, z)
jit_o = t_jit(x, y, z)
o = t(x, y, z)
self.assertEqual(o, jit_o)
# Currently cannot fuse this
self.assertGraphContains(t_jit.graph_for(x, y, z), FUSION_GUARD)
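# Runs `operation` inside a small pointwise expression, either on random data or
# on the special values above; float64 random inputs go through gradcheck instead
# of the fusion-graph assertion.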
def _unary_test_helper(self, operation, dtype, random_data):
gradient_check = (dtype == torch.float64) and random_data
shape = (8, 7)
torch.cuda.manual_seed_all(211)
# need additional def of t for boolean ops
def t(x: torch.Tensor, y: torch.Tensor):
o = x * y
o = o + 5e-3
o = operation(o)
return o
y = torch.rand(shape, dtype=torch.float32, device="cuda", requires_grad=gradient_check)
y = y.to(dtype=dtype)
if random_data:
x = torch.rand(shape, dtype=torch.float32, device="cuda", requires_grad=gradient_check)
if dtype in self.int_types:
# prefer a larger variance for integer types
x = x * 5
x = x.to(dtype=dtype)
else:
x = self.special_values.to(dtype=dtype)
try:
ref = t(x, y)
except Exception:
# same as the TE checker: if eager mode throws, skip this case
return
t_jit = torch.jit.script(t)
jit_o = t_jit(x, y)
jit_o = t_jit(x, y)
jit_o = t_jit(x, y)
if gradient_check:
gradcheck(t_jit, [x, y], nondet_tol=1e-5)
elif dtype in self.support_tensor_dtypes:
self.assertGraphContains(t_jit.graph_for(x, y), FUSION_GUARD)
o = t(x, y)
self.assertEqual(o.dtype, jit_o.dtype)
self.assertTrue(self._compare("failing case {}\n{}\n{}\n{}".format(dtype, operation, x, y), o, jit_o, 1e-2))
@unittest.skipIf(not RUN_CUDA, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
def test_unary_ops(self):
data_types = [
*self.int_types,
torch.float16,
torch.float32,
torch.float64
]
if TEST_BF16:
data_types.append(torch.bfloat16)
operations = [torch.neg,
torch.abs,
torch.log,
torch.log10,
torch.log1p,
torch.log2,
torch.lgamma,
torch.exp,
torch.expm1,
torch.erf,
torch.erfc,
torch.cos,
torch.acos,
torch.cosh,
torch.sin,
torch.asin,
torch.sinh,
torch.tan,
torch.atan,
torch.sqrt,
torch.rsqrt,
torch.ceil,
torch.floor,
torch.round,
torch.trunc,
torch.frac,
torch.reciprocal,
torch.nn.functional.softplus,
torch.nn.functional.gelu,
torch.relu,
torch.sigmoid,
torch.bitwise_not,
torch.tan,
torch.tanh,
torch.nn.functional.silu]
for op, dtype in itertools.product(operations, data_types):
self._unary_test_helper(op, dtype, False) # test special numbers
self._unary_test_helper(op, dtype, True) # test random data
@unittest.skipIf(not RUN_CUDA, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
def test_category_rule(self):
def run_tensor(x, z):
def t(x: torch.Tensor, z: torch.Tensor):
o = x + z
o = torch.abs(o)
return o
t_jit = torch.jit.script(t)
jit_o = t_jit(x, z)
jit_o = t_jit(x, z)
o = t(x, z)
self.assertEqual(o.dtype, jit_o.dtype)
self.assertEqual(o, jit_o)
self.assertGraphContains(t_jit.graph_for(x, z), FUSION_GUARD)
def run_scalar(x, z):
def t(x: torch.Tensor, z: float):
o = x + z
o = torch.abs(o)
return o
t_jit = torch.jit.script(t)
jit_o = t_jit(x, z)
jit_o = t_jit(x, z)
o = t(x, z)
self.assertEqual(o.dtype, jit_o.dtype)
self.assertEqual(o, jit_o)
self.assertGraphContains(t_jit.graph_for(x, z), FUSION_GUARD)
# n-dim with 0-dim (no type-promote)
x = torch.randn(4, 8, 32, 32, dtype=torch.float, device="cuda")
z = torch.tensor(2.0, dtype=torch.double, device="cuda")
run_tensor(x, z)
# n-dim with 0-dim (type-promote)
x = torch.randn(4, 8, 32, 32, device="cuda").to(dtype=torch.long)
z = torch.tensor(2.0, dtype=torch.double, device="cuda")
run_tensor(x, z)
# n-dim with n-dim (type-promote)
x = torch.randn(4, 8, 32, 32, dtype=torch.float, device="cuda")
z = torch.randn(4, 8, 32, 32, dtype=torch.double, device="cuda")
run_tensor(x, z)
# n-dim with scalar (no type-promote)
x = torch.randn(4, 8, 32, 32, dtype=torch.float16, device="cuda")
z = torch.tensor(3., dtype=torch.double)
run_scalar(x, z)
if TEST_BF16:
# n-dim with scalar (no type-promote)
x = torch.randn(4, 8, 32, 32, dtype=torch.bfloat16, device="cuda")
z = torch.tensor(3., dtype=torch.double)
run_scalar(x, z)
# n-dim with scalar (type-promote)
x = torch.randn(4, 8, 32, 32, device="cuda").to(dtype=torch.long)
z = torch.tensor(3., dtype=torch.double)
run_scalar(x, z)
@unittest.skipIf(not RUN_CUDA, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
def test_unary_bitwise(self):
def bit_not(x: torch.Tensor):
return ~(x + 1)
jitted = torch.jit.script(bit_not)
x = torch.randn(4, 8, 32, 32, dtype=torch.float, device="cuda").mul(5).to(torch.long)
jit_o = jitted(x)
jit_o = jitted(x)
o = bit_not(x)
self.assertEqual(o, jit_o)
jitted.graph_for(x) # Shows up in second instance, not first
self.assertGraphContains(jitted.graph_for(x), FUSION_GUARD)
def bool_not(x: torch.Tensor, y: torch.Tensor):
return ~(x & y)
jitted = torch.jit.script(bool_not)
x = torch.rand(4, 8, 32, 32, dtype=torch.float, device="cuda").round().to(torch.bool)
y = torch.rand(4, 8, 32, 32, dtype=torch.float, device="cuda").round().to(torch.bool)
jit_o = jitted(x, y)
jit_o = jitted(x, y)
o = bool_not(x, y)
self.assertEqual(o, jit_o)
jitted.graph_for(x, y) # Shows up in second instance, not first
self.assertGraphContains(jitted.graph_for(x, y), FUSION_GUARD)
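# Binary-op helper: applies `operation` to operands of (possibly different) dtypes,
# using random data or the special values, and checks dtype, values and the fusion
# guard against eager mode.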
def _binary_test_helper(self, operation, dtypes, random_data):
if isinstance(dtypes, tuple):
dtype_arg1, dtype_arg2 = dtypes
else:
dtype_arg1 = dtype_arg2 = dtypes
def t(x: torch.Tensor, y: torch.Tensor, z: torch.Tensor):
o = operation(x, y)
o = o + z
return o
def t_int(x: torch.Tensor, y: torch.Tensor):
o = operation(x, y)
o = 2 + o
return o
def t_float(x: torch.Tensor, y: torch.Tensor):
o = operation(x, y)
o = 2. + o
return o
shape = (4, 32, 32)
if random_data:
x = (torch.randn(shape, dtype=torch.float, device="cuda") * 5).to(dtype_arg1)
y = (torch.randn(shape, dtype=torch.float, device="cuda") * 5).to(dtype_arg2)
else:
x = self.special_values.to(dtype=dtype_arg1)
y = (torch.rand_like(self.special_values) * 5).to(dtype_arg2)
z = torch.tensor([2], device="cuda").to(dtype_arg1)
# Avoid division by zero for integer tensors
div_like = [torch.div, torch.fmod, torch.remainder]
if operation in div_like and (dtype_arg2 == torch.int32 or dtype_arg2 == torch.int64):
y[y == 0] = 1
for test_fn in [t, t_int, t_float]:
# t consumes the extra tensor argument; t_int / t_float add a scalar constant instead
args = (x, y, z) if test_fn is t else (x, y)
o = test_fn(*args)
t_jit = torch.jit.script(test_fn)
jit_o = t_jit(*args)
jit_o = t_jit(*args)
jit_o = t_jit(*args)
self.assertEqual(o.dtype, jit_o.dtype)
self.assertEqual(o, jit_o)
self.assertGraphContains(t_jit.graph_for(*args), FUSION_GUARD)
@unittest.skipIf(not RUN_CUDA, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
def test_binary_ops(self):
# disabled bf16 / fp16 data types because of accuracy tolerance
data_types = [
torch.int32,
torch.int64,
# torch.float16,
torch.float32,
torch.float64
]
'''
if TEST_BF16:
data_types.append(torch.bfloat16)
'''
operations = [torch.mul,
torch.div,
torch.atan2,
torch.max,
torch.min,
torch.pow,
torch.remainder,
torch.fmod,
torch.eq,
torch.ne,
torch.ge,
torch.gt,
torch.le,
torch.lt]
binary_dtype_combinations = itertools.combinations(data_types, 2)
for op, dtypes in itertools.product(operations, binary_dtype_combinations):
self._binary_test_helper(op, dtypes, True) # random data
self._binary_test_helper(op, dtypes, False) # special numbers
@unittest.skipIf(not RUN_CUDA, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
def test_binary_bitwise(self):
def jit_or(x: torch.Tensor, y: torch.Tensor, z: torch.Tensor):
return (x & y) | z
def jit_xor(x: torch.Tensor, y: torch.Tensor, z: torch.Tensor):
return (x & y) ^ z
def jit_lshift(x: torch.Tensor, y: torch.Tensor, z: torch.Tensor):
return (x & y) << z
def jit_rshift(x: torch.Tensor, y: torch.Tensor, z: torch.Tensor):
return (x & y) >> z
for jit_func in [jit_or, jit_xor, jit_lshift, jit_rshift]:
x = torch.randn(4, 8, 32, 32, dtype=torch.float, device="cuda").mul(5).to(torch.long)
y = torch.randn(4, 8, 32, 32, dtype=torch.float, device="cuda").mul(5).to(torch.long)
z = torch.randn(4, 8, 32, 32, dtype=torch.float, device="cuda").mul(2).to(torch.long)
jitted = torch.jit.script(jit_func)
jit_o = jitted(x, y, z)
jit_o = jitted(x, y, z)
o = jit_func(x, y, z)
self.assertEqual(o, jit_o)
self.assertGraphContains(jitted.graph_for(x, y, z), FUSION_GUARD)
# We shouldn't need this redefinition of the function, but otherwise it won't recompile for a new type
def jit_or(x: torch.Tensor, y: torch.Tensor, z: torch.Tensor):
return (x & y) | z
def jit_xor(x: torch.Tensor, y: torch.Tensor, z: torch.Tensor):
return (x & y) ^ z
for jit_func in [jit_or, jit_xor]:
x = torch.rand(4, 2, dtype=torch.float, device="cuda").round().to(torch.bool)
y = torch.rand(4, 2, dtype=torch.float, device="cuda").round().to(torch.bool)
z = torch.rand(4, 2, dtype=torch.float, device="cuda").round().to(torch.bool)
jitted = torch.jit.script(jit_func)
jit_o = jitted(x, y, z)
jit_o = jitted(x, y, z)
o = jit_func(x, y, z)
self.assertEqual(o, jit_o)
self.assertGraphContains(jitted.graph_for(x, y, z), FUSION_GUARD)
@unittest.skipIf(not RUN_CUDA, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
def test_type_as_op(self):
def t(x: torch.Tensor, y: torch.Tensor, z: float):
o = torch.lt(x, z)
o = o.type_as(y)
return o
t_jit = torch.jit.script(t)
x = torch.randn(4, 8, 32, 32, dtype=torch.float, device="cuda")
y = torch.randn(4, 8, 32, 32, dtype=torch.float, device="cuda")
jit_o = t_jit(x, y, 0.5)
jit_o = t_jit(x, y, 0.5)
o = t(x, y, 0.5)
self.assertEqual(o, jit_o)
self.assertGraphContains(t_jit.graph_for(x, y, 0.5), FUSION_GUARD)
def _ternary_integer_test_helper(self, dtype_arg1):
shape = (4, 8, 32, 32)
magnitude = 100
if (dtype_arg1 in self.int_types):
x = torch.randint(-magnitude, magnitude, shape, dtype=dtype_arg1, device="cuda")
else:
x = torch.randn(shape, dtype=dtype_arg1, device="cuda") * magnitude
arg2 = int(0)
arg3 = int(magnitude * 0.1)
def clamp0(x: torch.Tensor, f: int):
o = 2. * torch.clamp(x, min=f)
return o
clamp0_jit = torch.jit.script(clamp0)
self._run_helper(clamp0_jit, clamp0, x, arg2)
def clamp1(x: torch.Tensor, f: int, ff: int):
o = 2. * torch.clamp(x, min=f, max=ff)
return o
clamp1_jit = torch.jit.script(clamp1)
self._run_helper(clamp1_jit, clamp1, x, arg2, arg3)
def clamp2(x: torch.Tensor, f: float, ff: int):
o = 2. * torch.clamp(x, min=f, max=ff)
return o
clamp2_jit = torch.jit.script(clamp2)
self._run_helper(clamp2_jit, clamp2, x, float(arg2), arg3)
def clamp3(x: torch.Tensor, f: int, ff: float):
o = 2. * torch.clamp(x, min=f, max=ff)
return o
clamp3_jit = torch.jit.script(clamp3)
self._run_helper(clamp3_jit, clamp3, x, arg2, float(arg3))
def threshold(x: torch.Tensor, th: int, val: int):
o = 2. * torch.threshold(x, th, val)
return o
threshold_jit = torch.jit.script(threshold)
self._run_helper(threshold_jit, threshold, x, arg2, arg3)
@unittest.skipIf(not RUN_CUDA, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
def test_ternary_ops_integer_compatibility(self):
data_types = [
torch.float16,
torch.float32,
torch.float64
]
for dtype in data_types:
self._ternary_integer_test_helper(dtype)
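# Ternary-op helper: feeds three operands of (possibly different) dtypes through
# `operation` (torch.where gets a boolean first operand), then compares the scripted
# result, dtype and fusion graph against eager mode.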
def _ternary_test_helper(self, operation, dtypes, random_data):
if isinstance(dtypes, tuple):
dtype_arg1, dtype_arg2, dtype_arg3 = dtypes
else:
dtype_arg1 = dtype_arg2 = dtype_arg3 = dtypes
def t(x: torch.Tensor, y: torch.Tensor, z: torch.Tensor, alpha: torch.Tensor):
o = operation(x, y, z)
o = o + alpha
return o
shape = (4, 32, 32)
if operation is torch.where:
dtype_arg1 = torch.bool
if random_data:
x = torch.randint(0, 2, shape).to(dtype=torch.bool, device="cuda")
y = (torch.randn(shape, dtype=torch.float, device="cuda") * 5).to(dtype_arg2)
z = (torch.randn(shape, dtype=torch.float, device="cuda") * 5).to(dtype_arg3)
else:
x = torch.randint(0, 2, self.special_values.size()).to(dtype=torch.bool, device="cuda")
y = self.special_values.to(dtype=dtype_arg2)
z = (torch.rand_like(self.special_values) * 5).to(dtype_arg3)
elif random_data:
x = (torch.randn(shape, dtype=torch.float, device="cuda") * 5).to(dtype_arg1)
y = (torch.randn(shape, dtype=torch.float, device="cuda") * 5).to(dtype_arg2)
z = (torch.randn(shape, dtype=torch.float, device="cuda") * 5).to(dtype_arg3)
else:
x = self.special_values.to(dtype=dtype_arg1)
y = (torch.rand_like(self.special_values) * 5).to(dtype_arg2)
z = (torch.rand_like(self.special_values) * 5).to(dtype_arg3)
alpha = torch.tensor([2], device="cuda").to(dtype_arg1)
o = t(x, y, z, alpha)
t_jit = torch.jit.script(t)
jit_o = t_jit(x, y, z, alpha)
jit_o = t_jit(x, y, z, alpha)
self.assertEqual(o.dtype, jit_o.dtype)
self.assertEqual(o, jit_o)
self.assertGraphContains(t_jit.graph_for(x, y, z, alpha), FUSION_GUARD)
@unittest.skipIf(not RUN_CUDA, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
def test_ternary_ops_type_promotion(self):
# TODO: update accuracy tolerance for bf16 / fp16 data types
data_types = [
# torch.float16,
torch.float32,
torch.float64
]
'''
if TEST_BF16:
data_types.append(torch.bfloat16)
'''
# TODO: Add Tensor support for clamp
operations = [torch.clamp]
ternary_dtype_combinations = itertools.combinations(data_types, 3)
for op, dtypes in itertools.product(operations, ternary_dtype_combinations):
self._ternary_test_helper(op, dtypes, True) # random data
self._ternary_test_helper(op, dtypes, False) # special numbers
# We can't test the scalar version of rsub from python
@unittest.skipIf(not RUN_CUDA, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING, "Requires fusion optimization pass to be effective")
def test_rsub(self):
x = torch.randn(4, 8, 32, 32, dtype=torch.float, device="cuda")
y = torch.randn(4, 8, 32, 32, dtype=torch.float, device="cuda")
def rsub(x: torch.Tensor, y: torch.Tensor):
o = torch.rsub(x, y)
o = o * 2.
return o
rsub_jit = torch.jit.script(rsub)
self._run_helper(rsub_jit, rsub, x, y)
@unittest.skipIf(not RUN_CUDA, "requires CUDA")
# legacy fuser does not work for rand_like, see issue #34361
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING, "Requires fusion optimization pass to be effective")
def test_ternary_ops(self):
x = torch.randn(4, 8, 32, 32, dtype=torch.float, device="cuda")
y = torch.randn(4, 8, 32, 32, dtype=torch.float, device="cuda")
z = torch.randn(4, 8, 32, 32, dtype=torch.float, device="cuda")
cond = torch.randint(0, 2, (4, 8, 32, 32)).to(dtype=torch.bool, device="cuda")
def add(x: torch.Tensor, other: torch.Tensor, alpha: float):
o = torch.relu(x)
o = torch.add(o, other=other, alpha=alpha)
return o
add_jit = torch.jit.script(add)
self._run_helper(add_jit, add, x, y, 2.0)
def clamp0(x: torch.Tensor, f: float):
o = 2. * torch.clamp(x, min=f)
return o
clamp0_jit = torch.jit.script(clamp0)
self._run_helper(clamp0_jit, clamp0, x, 0.5)
def clamp1(x: torch.Tensor, f: float, ff: float):
o = 2. * torch.clamp(x, min=f, max=ff)
return o
clamp1_jit = torch.jit.script(clamp1)
self._run_helper(clamp1_jit, clamp1, x, -0.2, 0.7)
def threshold(x: torch.Tensor, th: float, val: float):
o = 2. * torch.threshold(x, th, val)
return o
threshold_jit = torch.jit.script(threshold)
self._run_helper(threshold_jit, threshold, x, 0.2, 0.9)
def where(x: torch.Tensor, y: torch.Tensor, cond: torch.Tensor):
o = 2. * torch.where(cond, x, y)
return o
where_jit = torch.jit.script(where)
self._run_helper(where_jit, where, x, y, cond)
def lerp(x: torch.Tensor, y: torch.Tensor, z: torch.Tensor):
o = 2. * torch.lerp(x, y, z)
return o
lerp_jit = torch.jit.script(lerp)
self._run_helper(lerp_jit, lerp, x, y, z)
def lerp_scale(x: torch.Tensor, y: torch.Tensor, z: float):
o = 2. * torch.lerp(x, y, z)
return o
lerp_scale_jit = torch.jit.script(lerp_scale)
self._run_helper(lerp_scale_jit, lerp_scale, x, y, 0.5)
@unittest.skipIf(not RUN_CUDA, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING, "Requires profiling node to run cuda fuser")
def test_addcmul_ops(self):
x = torch.randn(4, 8, 32, 32, dtype=torch.float, device="cuda")
y = torch.randn(4, 8, 32, 32, dtype=torch.float, device="cuda")
z = torch.randn(4, 8, 32, 32, dtype=torch.float, device="cuda")
def addcmul(x: torch.Tensor, y: torch.Tensor, z: torch.Tensor, value: float):
o = torch.add(x, 0.5)
o = torch.addcmul(o, y, z, value=value)
return o
addcmul_jit = torch.jit.script(addcmul)
self._run_helper(addcmul_jit, addcmul, x, y, z, 2.0)
def addcmul_no_alpha(x: torch.Tensor, y: torch.Tensor, z: torch.Tensor):
o = torch.add(x, 0.5)
o = torch.addcmul(o, y, z)
return o
addcmul_no_alpha_jit = torch.jit.script(addcmul_no_alpha)
self._run_helper(addcmul_no_alpha_jit, addcmul_no_alpha, x, y, z)
def addcmul_const_alpha(x: torch.Tensor, y: torch.Tensor, z: torch.Tensor):
o = torch.add(x, 0.5)
o = torch.addcmul(o, y, z, value=0.75)
return o
addcmul_const_alpha_jit = torch.jit.script(addcmul_const_alpha)
self._run_helper(addcmul_const_alpha_jit, addcmul_const_alpha, x, y, z)
@unittest.skipIf(not RUN_CUDA, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
def test_dynamic_size(self):
old_guard = torch._C._jit_set_nvfuser_guard_mode(True)
torch._C._jit_set_bailout_depth(20)
def t(x: torch.Tensor, y: torch.Tensor, z: float):
o = x + y
o = o + z
return o
t_jit = torch.jit.script(t)
x = torch.randn(4, 8, 32, 32, dtype=torch.float, device="cuda")
y = torch.randn(32, 32, dtype=torch.float, device="cuda")
jit_o = t_jit(x, y, 2.0)
jit_o = t_jit(x, y, 2.0)
o = t(x, y, 2.0)
self.assertEqual(o, jit_o)
subgraph = self._getSubgraphInFusion(t_jit.graph_for(x, y, 2.0))
self.assertGraphContainsExactly(subgraph, 'aten::add', 2, consider_subgraphs=False)
# this test is not ideal: we rely on the bailout mechanism to exercise it, and
# we don't have a way to inspect the bailout graph to verify that the proper
# fusion happened.
x = torch.randn(8, 32, 16, 8, dtype=torch.float, device="cuda")
y = torch.randn(16, 8, dtype=torch.float, device="cuda")
jit_o = t_jit(x, y, 2.0)
jit_o = t_jit(x, y, 2.0)
o = t(x, y, 2.0)
self.assertEqual(o, jit_o)
self.assertGraphContains(t_jit.graph_for(x, y, 2.0), FUSION_GUARD)
x = torch.randn(8, 17, 8, dtype=torch.float, device="cuda")
y = torch.randn(8, 17, 1, dtype=torch.float, device="cuda")
jit_o = t_jit(x, y, 2.0)
jit_o = t_jit(x, y, 2.0)
o = t(x, y, 2.0)
self.assertEqual(o, jit_o)
self.assertGraphContains(t_jit.graph_for(x, y, 2.0), FUSION_GUARD)
torch._C._jit_set_nvfuser_guard_mode(old_guard)
@unittest.skipIf(not RUN_CUDA, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
def test_random_topo(self):
os.environ["PYTORCH_NVFUSER_DISABLE_FALLBACK"] = "1"
self.assertTrue(runDefaultTestWithSeed(28449))
def _compare(self, desc, inp1, inp2, error):
a = inp1.clone()
b = inp2.clone()
close = torch.allclose(a, b, rtol=error, atol=error)
if not close:
print(desc, close)
z = a - b
index = (torch.abs(z) >= error + error * torch.abs(b)).nonzero()
print("dif : ", z[index])
print("inp1 : ", a[index])
print("inp2 : ", b[index])
print("maximum difference", z[index].max())
return close
# Permutation helper that applies a binary operation between two tensors:
# 1. applies separate permutations `perm0` & `perm1` to the two inputs
# 2. reduces dimension `broadcast_axis` of operand two to size 1
# The purpose of this test is to ensure permutation works well in
# complicated cases with arbitrary stride orders and broadcast dimensions
def _permutation_helper(self, sizes, broadcast_axis, dtype, device, perm0, perm1):
def t(x: torch.Tensor, y: torch.Tensor):
o = torch.add(x, y)
o = torch.relu(o)
return o
x = torch.randn([sizes[i] for i in perm0], dtype=dtype, device=device).permute(
[perm0.index(i) for i in range(len(sizes))])
if broadcast_axis >= 0:
sizes[broadcast_axis] = 1
y = torch.randn([sizes[i] for i in perm1], dtype=dtype, device=device).permute(
[perm1.index(i) for i in range(len(sizes))])
t_jit = torch.jit.script(t)
jit_o = t_jit(x, y)
jit_o = t_jit(x, y)
o = t(x, y)
self.assertEqual(o.dtype, jit_o.dtype)
self.assertEqual(o, jit_o)
self.assertEqual(o.stride(), jit_o.stride())
self.assertGraphContains(t_jit.graph_for(x, y), FUSION_GUARD)
# End-to-end test of permutation & contiguity handling in the integration.
# We test inputs with every combination of permutation orders, just to
# ensure that the integration can generate functionally correct
# kernels.
@unittest.skipIf(not RUN_CUDA, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
def test_binary_ops_permutation(self):
# note that num_dim is exclusive from len(x), so we are not reducing
# to single element (codegen limitation at this moment)
x = [7, 8, 12]
b_axes = range(-1, len(x))
for b_axis in b_axes:
for perm0 in itertools.permutations(range(len(x))):
for perm1 in itertools.permutations(range(len(x))):
x = [7, 8, 12]
self._permutation_helper(x, b_axis, torch.float32, "cuda", perm0, perm1)
@unittest.skipIf(not RUN_CUDA, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
def test_binary_ops_channels_last_with_bcast(self):
device = "cuda"
x = torch.randn([4, 3, 2, 5], device=device).to(memory_format=torch.channels_last)
w = torch.randn([2, 5], device=device)
def t(x: torch.Tensor, b: torch.Tensor):
o = x + b
return torch.relu(o)
t_jit = torch.jit.script(t)
jit_o = t_jit(x, w)
jit_o = t_jit(x, w)
jit_o = t_jit(x, w)
o = t(x, w)
self.assertEqual(o.dtype, jit_o.dtype)
self.assertTrue(self._compare("comparing output failed", o, jit_o, 1e-4))
self.assertGraphContains(t_jit.graph_for(x, w), FUSION_GUARD)
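# Reduction helper: adds two permuted (non-contiguous) inputs and sums over
# `reduction_axis`, comparing against eager with a relaxed tolerance because the
# fused reduction is scheduled differently from eager mode.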
def _reduction_helper(self, sizes, reduction_axis, dtype, device, perm0, perm1, keepdim=False):
class MyReduction(torch.nn.Module):
__constants__ = ['reduction_axis', 'keepdim']
def __init__(self):
super(MyReduction, self).__init__()
self.reduction_axis = reduction_axis
self.keepdim = keepdim
def forward(self, x: torch.Tensor, y: torch.Tensor):
o = torch.add(x, y)
o = torch.sum(o, dim=self.reduction_axis, keepdim=self.keepdim)
return o
t = MyReduction()
x = torch.randn([sizes[i] for i in perm0], dtype=dtype, device=device).permute(
[perm0.index(i) for i in range(len(sizes))])
y = torch.randn([sizes[i] for i in perm1], dtype=dtype, device=device).permute(
[perm1.index(i) for i in range(len(sizes))])
t_jit = torch.jit.script(t)
jit_o = t_jit(x, y)
jit_o = t_jit(x, y)
o = t(x, y)
self.assertEqual(o.dtype, jit_o.dtype)
# numerical issues here due to our scheduling.
# can't use `self.assertEqual(o, jit_o)`
self.assertTrue(self._compare("comparing output failed", o, jit_o, 1e-4))
self.assertGraphContains(t_jit.graph_for(x, y), FUSION_GUARD)
@unittest.skipIf(is_pre_volta(), "reduction not supported in pre volta device")
@unittest.skipIf(not RUN_CUDA, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
def test_reduction(self):
for x in ([7, 8, 12], [12, 8, 7, 9, 15], [128, 16, 8, 32]):
# note that num_reduce_dim stays below len(x), so we never reduce
# to a single element (codegen limitation at this moment)
for num_reduce_dim in range(1, len(x)):
for axes in itertools.combinations(range(len(x)), num_reduce_dim):
for keepdim in (True, False):
perm0 = range(len(x))
perm1 = range(len(x))
self._reduction_helper(x, axes, torch.float32, "cuda", perm0, perm1, keepdim)
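# Layer-norm autodiff helper: runs forward/backward several times to trigger
# profiling and autodiff, re-runs against a reference with cloned leaves, and then
# checks that both the forward graph and the recorded backward graph contain a
# fusion guard.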
def _layer_norm_autodiff_helper(self, model, grad, shapes, args):
jit_model = torch.jit.script(model)
eps = np.random.random() * 1e-4
use_cudnn = bool(np.random.randint(0, 2))
# profile/optimization runs
for i in range(3):
jit_o = jit_model(shapes, *args, eps, use_cudnn)
jit_o.backward(grad)
ref_args = [t.detach().clone().requires_grad_() for t in args]
[t.grad.zero_() for t in args]
jit_o = jit_model(shapes, *args, eps, use_cudnn)
jit_o.backward(grad)
o = model(shapes, *ref_args, eps, use_cudnn)
o.backward(grad)
self.assertEqual(jit_o, o)
for arg, ref_arg in zip(args, ref_args):
self.assertEqual(arg.grad, ref_arg.grad)
# check fusion in fw & bw
g = jit_model.graph_for(shapes, *args, eps, use_cudnn)
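# the loops below simply walk to the last node / execution plan so we can pull the
# backward graph recorded by the gradient executor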
for node in g.nodes():
n = node
dbg_state = jit_model.get_debug_state()
for val in dbg_state.execution_plans.values():
v = val
state2 = v.code.grad_executor_states()
for val in state2[0].execution_plans.values():
v2 = val
FileCheck().check(FUSION_GUARD).run(g)
FileCheck().check(FUSION_GUARD).run(v2.graph)
@unittest.skipIf(is_pre_volta(), "reduction not supported in pre volta device")
@unittest.skipIf(not RUN_CUDA, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
def test_layer_norm_autodiff(self):
def t_wb(shapes: List[int], x, w, b, eps: float, cudnn: bool):
o = torch.layer_norm(x, shapes, w, b, eps, cudnn)
o = torch.relu(o)
return o
def t_w(shapes: List[int], x, w, eps: float, cudnn: bool):
o = torch.layer_norm(x, shapes, w, None, eps, cudnn)
o = torch.relu(o)
return o
def t_b(shapes: List[int], x, b, eps: float, cudnn: bool):
o = torch.layer_norm(x, shapes, None, b, eps, cudnn)
o = torch.relu(o)
return o
def t(shapes: List[int], x, eps: float, cudnn: bool):
o = torch.layer_norm(x, shapes, None, None, eps, cudnn)
o = torch.relu(o)
return o
model = {3: t_wb, 2: t_w, 1: t_b, 0: t}
for w, b in itertools.product([True, False], repeat=2):
batch = [2]
# note: awkward shape here to avoid vectorized fast kernel, which is
# buggy in aten
shapes = [2, 7, 3]
m = model[w * 2 + b]
grad = torch.randn(batch + shapes, dtype=torch.float32, device="cuda")
args = [torch.randn(batch + shapes, dtype=torch.float32, device="cuda").requires_grad_()]
if w:
args.append(torch.randn(shapes, dtype=torch.float32, device="cuda").requires_grad_())
if b:
args.append(torch.randn(shapes, dtype=torch.float32, device="cuda").requires_grad_())
self._layer_norm_autodiff_helper(m, grad, shapes, args)
@unittest.skipIf(is_pre_volta(), "reduction not supported in pre volta device")
@unittest.skipIf(not RUN_CUDA, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
def test_layer_norm_parser(self):
dtype = torch.float32
device = "cuda"
x = torch.randn([4, 4, 2], dtype=dtype, device=device)
w = torch.randn([4, 2], dtype=dtype, device=device)
b = torch.randn([4, 2], dtype=dtype, device=device)
def t(x: torch.Tensor, w: torch.Tensor, b: torch.Tensor):
o = torch.relu(x)
o = torch.layer_norm(o, [4, 2], w, b, 1e-5)
return o
o = t(x, w, b)
t_jit = torch.jit.script(t)
jit_o = t_jit(x, w, b)
jit_o = t_jit(x, w, b)
o = t(x, w, b)
self.assertGraphContains(t_jit.graph_for(x, w, b), FUSION_GUARD)
def _native_layer_norm_helper(self, shape, norm_shape, dtype, device, error, affine=True):
class MyLayerNorm(torch.nn.Module):
__constants__ = ['norm_shape']
def __init__(self, elementwise_affine=True):
super(MyLayerNorm, self).__init__()
self.norm_shape = norm_shape
if elementwise_affine:
self.weight = torch.randn(norm_shape, dtype=dtype, device=device)
self.bias = torch.randn(norm_shape, dtype=dtype, device=device)
with torch.no_grad():
self.weight.fill_(1)
self.bias.fill_(0)
else:
self.weight = None
self.bias = None
def forward(self, x: torch.Tensor):
o = torch.relu(x)
o = torch.native_layer_norm(o, self.norm_shape, self.weight, self.bias, 1e-5)
return o
t = MyLayerNorm(affine)
x = torch.randn(shape, dtype=dtype, device=device)
t_jit = torch.jit.script(t)
jit_o, jit_mean, jit_rstd = t_jit(x)
jit_o, jit_mean, jit_rstd = t_jit(x)
o, mean, rstd = t(x)
self.assertEqual(o.dtype, jit_o.dtype)
# numerical issues here due to our scheduling.
# can't use `self.assertEqual(o, jit_o)`
self.assertTrue(self._compare("comparing output failed", o, jit_o, error))
self.assertTrue(self._compare("comparing mean failed", mean, jit_mean, error))
self.assertTrue(self._compare("comparing rstd failed", rstd, jit_rstd, error))
self.assertGraphContains(t_jit.graph_for(x), FUSION_GUARD)
@unittest.skipIf(is_pre_volta(), "reduction not supported in pre volta device")
@unittest.skipIf(not RUN_CUDA, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
def test_native_layer_norm(self):
dims = 4
rnds = 3
for idx in range(rnds):
for offset in range(1, dims):
for affine in (True, False):
input_shape = [random.randint(10, 30) for idx in range(dims)]
norm_shape = [input_shape[idx] for idx in range(dims - offset, dims)]
self._native_layer_norm_helper(input_shape, norm_shape, torch.float32, "cuda", 1e-4, affine)
@unittest.skipIf(is_pre_volta(), "reduction not supported in pre volta device")
@unittest.skipIf(not RUN_CUDA, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
def test_native_layer_norm_half(self):
dims = 4
rnds = 3
for idx in range(rnds):
for offset in range(1, dims):
input_shape = [random.randint(10, 30) for idx in range(dims)]
norm_shape = [input_shape[idx] for idx in range(dims - offset, dims)]
self._native_layer_norm_helper(input_shape, norm_shape, torch.float16, "cuda", 5e-3)
@unittest.skipIf(is_pre_volta(), "reduction not supported in pre volta device")
@unittest.skipIf(not RUN_CUDA, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
@unittest.skipIf(not TEST_BF16, "device does not support BFloat16")
def test_native_layer_norm_bfloat(self):
dims = 4
rnds = 3
for idx in range(rnds):
for offset in range(1, dims):
input_shape = [random.randint(10, 30) for idx in range(dims)]
norm_shape = [input_shape[idx] for idx in range(dims - offset, dims)]
self._native_layer_norm_helper(input_shape, norm_shape, torch.bfloat16, "cuda", 1e-1)
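# Norm helper: runs either batch_norm or instance_norm (training mode, with running
# stats) through the fuser, comparing outputs and the updated running_mean/var
# against eager mode, for an arbitrary memory format.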
def _norm_helper(self, shape, dtype, device, error, is_batch_norm_else_instance_norm, memory_format=torch.contiguous_format):
class MyBatchNorm(torch.nn.Module):
def __init__(self):
super(MyBatchNorm, self).__init__()
def forward(self, x: torch.Tensor, r_mean: torch.Tensor, r_var: torch.Tensor):
o = torch.nn.functional.batch_norm(x, r_mean, r_var, training=True)
o = torch.relu(o)
return o
class MyInstanceNorm(torch.nn.Module):
def __init__(self):
super(MyInstanceNorm, self).__init__()
def forward(self, x: torch.Tensor, r_mean: torch.Tensor, r_var: torch.Tensor):
o = torch.nn.functional.instance_norm(x, r_mean, r_var, use_input_stats=True)
o = torch.relu(o)
return o
t = MyBatchNorm() if is_batch_norm_else_instance_norm else MyInstanceNorm()
x = torch.randn(shape, dtype=dtype, device=device).to(memory_format=memory_format)
running_mean = torch.zeros(shape[1], dtype=torch.float32, device=device)
running_var = torch.ones(shape[1], dtype=torch.float32, device=device)
t_jit = torch.jit.script(t)
eager_running_mean = running_mean.clone()
eager_running_var = running_var.clone()
jit_running_mean = running_mean.clone()
jit_running_var = running_var.clone()
jit_o = t_jit(x, running_mean.clone(), running_var.clone())
self.assertTrue(self._compare("prerun comparing running_mean failed", eager_running_mean, jit_running_mean, error))
self.assertTrue(self._compare("prerun comparing running_var failed", eager_running_var, jit_running_var, error))
jit_o = t_jit(x, jit_running_mean, jit_running_var)
o = t(x, eager_running_mean, eager_running_var)
self.assertEqual(o.dtype, jit_o.dtype)
self.assertEqual(o.stride(), jit_o.stride())
# numerical issues here due to our scheduling.
# can't use `self.assertEqual(o, jit_o)`
self.assertTrue(self._compare("comparing output failed", o, jit_o, error))
self.assertTrue(self._compare("comparing running_mean failed", eager_running_mean, jit_running_mean, error))
self.assertTrue(self._compare("comparing running_var failed", eager_running_var, jit_running_var, error))
self.assertGraphContains(t_jit.graph_for(x, running_mean, running_var), FUSION_GUARD)
@unittest.skipIf(is_pre_volta(), "reduction not supported in pre volta device")
@unittest.skipIf(not RUN_CUDA, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
def test_norm_channels_last(self):
size = [3, 4, 5, 6]
with torch.backends.cudnn.flags(enabled=False):
for is_batch_norm_else_instance_norm in [False, True]:
for mf in [torch.channels_last, torch.contiguous_format]:
self._norm_helper(size, torch.float32, "cuda", 1e-4, is_batch_norm_else_instance_norm, memory_format=mf)
@unittest.skipIf(is_pre_volta(), "reduction not supported in pre volta device")
@unittest.skipIf(not RUN_CUDA, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
def test_norm(self):
output_elements = 10000
channel_sizes = [67, 457, 1024, 4096]
with torch.backends.cudnn.flags(enabled=False):
for is_batch_norm_else_instance_norm in [False, True]:
for dims in range(3, 6):
output_size = int(pow(output_elements, 1. / (dims - 1)))
for C in channel_sizes:
x = [output_size for idx in range(dims)]
x[1] = C
self._norm_helper(x, torch.float32, "cuda", 1e-4, is_batch_norm_else_instance_norm)
@unittest.skipIf(is_pre_volta(), "reduction not supported in pre volta device")
@unittest.skipIf(not RUN_CUDA, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
def test_norm_large(self):
output_elements = 262144
channel_sizes = 67, 457, 1024
for is_batch_norm_else_instance_norm in [True, False]:
for dims in range(3, 6):
output_size = int(pow(output_elements, 1. / (dims - 1)))
for C in channel_sizes:
x = [output_size for idx in range(dims)]
x[1] = C
self._norm_helper(x, torch.float32, "cuda", 1e-4, is_batch_norm_else_instance_norm)
@unittest.skipIf(is_pre_volta(), "reduction not supported in pre volta device")
@unittest.skipIf(not RUN_CUDA, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
def test_norm_half(self):
output_elements = 10000
channel_sizes = [67, 457, 1024, 4096]
with torch.backends.cudnn.flags(enabled=False):
for is_batch_norm_else_instance_norm in [False, True]:
for dims in range(3, 6):
output_size = int(pow(output_elements, 1. / (dims - 1)))
for C in channel_sizes:
x = [output_size for idx in range(dims)]
x[1] = C
self._norm_helper(x, torch.float16, "cuda", 5e-3, is_batch_norm_else_instance_norm)
@unittest.skipIf(is_pre_volta(), "reduction not supported in pre volta device")
@unittest.skipIf(not RUN_CUDA, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
@unittest.skipIf(not TEST_BF16, "device does not support BFloat16")
def test_norm_bfloat(self):
output_elements = 10000
channel_sizes = [67, 457, 1024, 4096]
with torch.backends.cudnn.flags(enabled=False):
for is_batch_norm_else_instance_norm in [False, True]:
for dims in range(3, 6):
output_size = int(pow(output_elements, 1. / (dims - 1)))
for C in channel_sizes:
x = [output_size for idx in range(dims)]
x[1] = C
self._norm_helper(x, torch.bfloat16, "cuda", 1e-1, is_batch_norm_else_instance_norm)
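# Softmax helper: softmax over `reduction_axis` after an add, compared against
# eager mode with a dtype-dependent tolerance.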
def _softmax_helper(self, shape, reduction_axis, dtype, device, error):
class MySoftmax(torch.nn.Module):
__constants__ = ['reduction_axis']
def __init__(self):
super(MySoftmax, self).__init__()
self.reduction_axis = reduction_axis
def forward(self, x: torch.Tensor, y: torch.Tensor):
o = torch.add(x, y)
o = torch.nn.functional.softmax(o, dim=self.reduction_axis)
return o
t = MySoftmax()
x = torch.randn(shape, dtype=dtype, device=device)
y = torch.randn(shape, dtype=dtype, device=device)
t_jit = torch.jit.script(t)
jit_o = t_jit(x, y)
jit_o = t_jit(x, y)
o = t(x, y)
self.assertEqual(o.dtype, jit_o.dtype)
# numerical issues here due to our scheduling.
# can't use `self.assertEqual(o, jit_o)`
self.assertTrue(self._compare("comparing output failed", o, jit_o, error))
self.assertGraphContains(t_jit.graph_for(x, y), FUSION_GUARD)
@unittest.skipIf(is_pre_volta(), "reduction not supported in pre volta device")
@unittest.skipIf(not RUN_CUDA, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
def test_softmax_dtype(self):
def t(x: torch.Tensor, y: torch.Tensor):
o = torch.mul(x, y)
o = torch.nn.functional.softmax(o, dim=0, dtype=torch.float32)
return o
x = torch.randn([4, 4], dtype=torch.float16, device="cuda").requires_grad_()
y = torch.randn_like(x).requires_grad_()
grad = torch.randn_like(x).float()
ref_x = x.detach().requires_grad_()
ref_y = y.detach().requires_grad_()
o = t(ref_x, ref_y)
o.backward(grad)
t_jit = torch.jit.script(t)
jit_o = t_jit(x, y)
jit_o.backward(grad)
jit_o = t_jit(x, y)
jit_o.backward(grad)
jit_o = t_jit(x, y)
jit_o.backward(grad)
x.grad.zero_()
y.grad.zero_()
jit_o = t_jit(x, y)
jit_o.backward(grad)
self.assertEqual(o.dtype, jit_o.dtype)
self.assertEqual(ref_x.grad, x.grad)
self.assertEqual(ref_y.grad, y.grad)
self.assertTrue(self._compare("comparing output failed", o, jit_o, 1e-3))
self.assertGraphContainsExactly(t_jit.graph_for(x, y), FUSION_GUARD, 1, consider_subgraphs=True)
bwd_graph = list(
list(t_jit.get_debug_state().execution_plans.values())[
0].code.grad_executor_states()[0].execution_plans.values()
)[0].graph
FileCheck().check(FUSION_GUARD).run(bwd_graph)
@unittest.skipIf(is_pre_volta(), "reduction not supported in pre volta device")
@unittest.skipIf(not RUN_CUDA, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
def test__softmax_function(self):
def t(x: torch.Tensor, y: torch.Tensor):
o = torch.mul(x, y)
o = torch._softmax(o, dim=-1, half_to_float=False)
return o
x = torch.randn([4, 4], dtype=torch.float16, device="cuda")
y = torch.randn_like(x)
o = t(x, y)
t_jit = torch.jit.script(t)
jit_o = t_jit(x, y)
jit_o = t_jit(x, y)
jit_o = t_jit(x, y)
self.assertEqual(o.dtype, jit_o.dtype)
self.assertTrue(self._compare("comparing output failed", o, jit_o, 1e-3))
self.assertGraphContainsExactly(t_jit.graph_for(x, y), FUSION_GUARD, 1, consider_subgraphs=True)
@unittest.skipIf(is_pre_volta(), "reduction not supported in pre volta device")
@unittest.skipIf(not RUN_CUDA, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
def test__softmax_function_half_to_float(self):
def t(x: torch.Tensor, y: torch.Tensor):
o = torch.mul(x, y)
o = torch._softmax(o, dim=-1, half_to_float=True)
return o
x = torch.randn([4, 4], dtype=torch.float16, device="cuda")
y = torch.randn_like(x)
o = t(x, y)
t_jit = torch.jit.script(t)
jit_o = t_jit(x, y)
jit_o = t_jit(x, y)
jit_o = t_jit(x, y)
self.assertEqual(o.dtype, jit_o.dtype)
self.assertTrue(self._compare("comparing output failed", o, jit_o, 1e-3))
self.assertGraphContainsExactly(t_jit.graph_for(x, y), FUSION_GUARD, 1, consider_subgraphs=True)
@unittest.skipIf(is_pre_volta(), "reduction not supported in pre volta device")
@unittest.skipIf(not RUN_CUDA, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
def test_softmax(self):
output_size = 10000
dims = 4
output_size = int(pow(output_size, 1. / dims))
reduction_sizes = [67, 256, 1024, 4096]
for reduction_dim in range(dims):
for reduction_size in reduction_sizes:
x = [output_size for idx in range(dims)]
x[reduction_dim] = reduction_size
self._softmax_helper(x, reduction_dim, torch.float32, "cuda", 1e-4)
@unittest.skipIf(is_pre_volta(), "reduction not supported in pre volta device")
@unittest.skipIf(not RUN_CUDA, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
def test_softmax_half(self):
output_size = 10000
dims = 4
output_size = int(pow(output_size, 1. / dims))
reduction_sizes = [67, 256, 1024, 4096]
for reduction_dim in range(dims):
for reduction_size in reduction_sizes:
x = [output_size for idx in range(dims)]
x[reduction_dim] = reduction_size
self._softmax_helper(x, reduction_dim, torch.float16, "cuda", 5e-3)
@unittest.skipIf(is_pre_volta(), "reduction not supported in pre volta device")
@unittest.skipIf(not RUN_CUDA, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
@unittest.skipIf(not TEST_BF16, "device does not support BFloat16")
def test_softmax_bfloat(self):
output_size = 10000
dims = 4
output_size = int(pow(output_size, 1. / dims))
reduction_sizes = [67, 256, 1024, 4096]
for reduction_dim in range(dims):
for reduction_size in reduction_sizes:
x = [output_size for idx in range(dims)]
x[reduction_dim] = reduction_size
self._softmax_helper(x, reduction_dim, torch.bfloat16, "cuda", 1e-1)
@unittest.skipIf(is_pre_volta(), "reduction not supported in pre volta device")
@unittest.skipIf(not RUN_CUDA, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
def test_reduction_permutation(self):
x = [7, 8, 12]
        # note that num_reduce_dim stops short of len(x), so we never reduce
        # all the way down to a single element (a codegen limitation at the moment)
for num_reduce_dim in range(1, len(x)):
for axes in itertools.combinations(range(len(x)), num_reduce_dim):
for perm0 in itertools.permutations(range(len(x))):
for perm1 in itertools.permutations(range(len(x))):
self._reduction_helper(x, axes, torch.float32, "cuda", perm0, perm1)
@unittest.skipIf(is_pre_volta(), "reduction not supported in pre volta device")
@unittest.skipIf(not RUN_CUDA, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
def test_reduction_multiple_output(self):
old_guard = torch._C._jit_set_nvfuser_guard_mode(True)
torch._C._jit_set_bailout_depth(20)
def t(x: torch.Tensor, y: torch.Tensor, scale: float, z: torch.Tensor):
o = torch.mul(x, y)
o = torch.mul(o, scale)
out1 = torch.mul(o, z)
out2 = torch.sum(out1, dim=[2])
return out1, out2
t_jit = torch.jit.script(t)
x = torch.randn(8, 4, 10, 16, dtype=torch.float, device="cuda")
y = torch.randn(8, 4, 10, 16, dtype=torch.float, device="cuda")
z = torch.randn(8, 4, 10, 16, dtype=torch.float, device="cuda")
scale = 0.5
jit_o = t_jit(x, y, scale, z)
jit_o = t_jit(x, y, scale, z)
o = t(x, y, scale, z)
for oo, jit_oo in zip(o, jit_o):
self.assertEqual(oo.dtype, jit_oo.dtype)
self.assertEqual(oo, jit_oo)
self.assertGraphContains(t_jit.graph_for(x, y, scale, z), FUSION_GUARD)
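        # repeat with channels_last inputs to exercise the permuted-stride path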
x = x.to(memory_format=torch.channels_last)
y = y.to(memory_format=torch.channels_last)
z = z.to(memory_format=torch.channels_last)
jit_o = t_jit(x, y, scale, z)
jit_o = t_jit(x, y, scale, z)
o = t(x, y, scale, z)
for oo, jit_oo in zip(o, jit_o):
self.assertEqual(oo.dtype, jit_oo.dtype)
self.assertEqual(oo, jit_oo)
self.assertGraphContains(t_jit.graph_for(x, y, scale, z), FUSION_GUARD)
torch._C._jit_set_nvfuser_guard_mode(old_guard)
@unittest.skipIf(not RUN_CUDA, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
def test_channels_last_with_broadcast(self):
        # setting this to True forces a new graph to be generated whenever an
        # input arrives with a different broadcast shape
torch._C._jit_set_nvfuser_guard_mode(True)
def t(x: torch.Tensor, y: torch.Tensor):
o = torch.mul(x, y)
o = o + 2.0
return o
t_jit = torch.jit.script(t)
# Single Channel broadcasts
# Test 1
x = torch.randn(8, 4, 10, 16, dtype=torch.float, device="cuda")
x = x.to(memory_format=torch.channels_last)
y = torch.randn(8, 4, 10, 1, dtype=torch.float, device="cuda")
y = y.to(memory_format=torch.channels_last)
jit_o = t_jit(x, y)
jit_o = t_jit(x, y)
o = t(x, y)
self.assertEqual(o.dtype, jit_o.dtype)
self.assertEqual(o.is_contiguous(memory_format=torch.channels_last),
jit_o.is_contiguous(memory_format=torch.channels_last))
self.assertEqual(o, jit_o)
# Test 2
y = torch.randn(8, 4, 1, 16, dtype=torch.float, device="cuda")
y = y.to(memory_format=torch.channels_last)
jit_o = t_jit(x, y)
jit_o = t_jit(x, y)
o = t(x, y)
self.assertEqual(o.dtype, jit_o.dtype)
self.assertEqual(o.is_contiguous(memory_format=torch.channels_last),
jit_o.is_contiguous(memory_format=torch.channels_last))
self.assertEqual(o, jit_o)
# Test 3
y = torch.randn(8, 1, 10, 16, dtype=torch.float, device="cuda")
y = y.to(memory_format=torch.channels_last)
jit_o = t_jit(x, y)
jit_o = t_jit(x, y)
o = t(x, y)
self.assertEqual(o.dtype, jit_o.dtype)
self.assertEqual(o.is_contiguous(memory_format=torch.channels_last),
jit_o.is_contiguous(memory_format=torch.channels_last))
self.assertEqual(o, jit_o)
        # Test 4
y = torch.randn(1, 4, 10, 16, dtype=torch.float, device="cuda")
y = y.to(memory_format=torch.channels_last)
jit_o = t_jit(x, y)
jit_o = t_jit(x, y)
o = t(x, y)
self.assertEqual(o.dtype, jit_o.dtype)
self.assertEqual(o.is_contiguous(memory_format=torch.channels_last),
jit_o.is_contiguous(memory_format=torch.channels_last))
self.assertEqual(o, jit_o)
'''
        Currently, the JIT lacks the tensor-merge logic needed to add a
        tensor with more than one broadcast dimension to a non-broadcast
        tensor. Either of these tests can therefore fail depending on the
        dimension-sort implementation; the second test is known to fail.
# Two Channel broadcasts
# Test 1
y = torch.randn(8, 4, 1, 1, dtype=torch.float, device="cuda")
y = y.to(memory_format=torch.channels_last)
jit_o = t_jit(x, y)
jit_o = t_jit(x, y)
o = t(x, y)
self.assertEqual(o.dtype, jit_o.dtype)
self.assertEqual(o.is_contiguous(memory_format=torch.channels_last),
jit_o.is_contiguous(memory_format=torch.channels_last))
self.assertEqual(o, jit_o)
# Test 2
y = torch.randn(8, 4, 1, 1, dtype=torch.float, device="cuda")
y = y.to(memory_format=torch.channels_last).transpose(2,3)
x = x.transpose(2,3)
jit_o = t_jit(x, y)
jit_o = t_jit(x, y)
o = t(x, y)
self.assertEqual(o.dtype, jit_o.dtype)
self.assertEqual(o.is_contiguous(memory_format=torch.channels_last),
jit_o.is_contiguous(memory_format=torch.channels_last))
self.assertEqual(o, jit_o)
'''
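        # Illustration of the ambiguity described above (strides worked out by
        # hand, not asserted in this test): a channels_last tensor of shape
        # (8, 4, 1, 1) has strides (4, 1, 4, 4), so the size-1 dimensions tie
        # with the batch stride and the stride sort used to recover the
        # permutation is not unique.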
@unittest.skipIf(is_pre_volta(), "reduction not supported in pre volta device")
@unittest.skipIf(not RUN_CUDA, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
def test_pw_single_reduction_partition(self):
sizes = [2, 2, 2]
dtype = torch.float
device = "cuda"
x = torch.randn(sizes, dtype=dtype, device=device)
y = torch.randn(sizes, dtype=dtype, device=device)
z = torch.randn(sizes, dtype=dtype, device=device)
def t(x: torch.Tensor, y: torch.Tensor, z: torch.Tensor):
o = torch.add(x, y)
o = torch.sum(o, dim=[0])
o = torch.add(o, z)
return o
t_jit = torch.jit.script(t)
jit_o = t_jit(x, y, z)
jit_o = t_jit(x, y, z)
o = t(x, y, z)
self.assertEqual(o.dtype, jit_o.dtype)
self.assertEqual(o, jit_o)
self.assertGraphContains(t_jit.graph_for(x, y, z), FUSION_GUARD)
@unittest.skipIf(is_pre_volta(), "reduction not supported in pre volta device")
@unittest.skipIf(not RUN_CUDA, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
def test_permutation_preservation(self):
sizes = [2, 3, 4, 5]
dtype = torch.float
device = "cuda"
x = torch.randn(sizes, dtype=dtype, device=device).to(memory_format=torch.channels_last)
def t(x: torch.Tensor):
o = torch.relu(x)
o = torch.sum(o, dim=[0])
return o
t_jit = torch.jit.script(t)
jit_o = t_jit(x)
jit_o = t_jit(x)
o = t(x)
self.assertEqual(o.dtype, jit_o.dtype)
self.assertEqual(o, jit_o)
self.assertGraphContains(t_jit.graph_for(x), FUSION_GUARD)
# TODO: we could preserve permutation to inputs
self.assertEqual(o.stride(), jit_o.stride())
def t(x: torch.Tensor):
o = torch.relu(x)
o = torch.add(o, 1.0)
return o
t_jit = torch.jit.script(t)
jit_o = t_jit(x)
jit_o = t_jit(x)
o = t(x)
self.assertEqual(o.dtype, jit_o.dtype)
self.assertEqual(o, jit_o)
self.assertGraphContains(t_jit.graph_for(x), FUSION_GUARD)
self.assertTrue(jit_o.is_contiguous(memory_format=torch.channels_last))
@unittest.skipIf(is_pre_volta(), "reduction not supported in pre volta device")
@unittest.skipIf(not RUN_CUDA, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
def test_normalization_partition(self):
sizes = [3, 8, 5]
dtype = torch.float
device = "cuda"
x = torch.randn(sizes, dtype=dtype, device=device)
y = torch.randn(sizes, dtype=dtype, device=device)
z = torch.randn(sizes, dtype=dtype, device=device)
r_m = torch.randn(8, dtype=dtype, device=device)
r_v = torch.randn(8, dtype=dtype, device=device)
def t(x: torch.Tensor, y: torch.Tensor, z: torch.Tensor, r_mean: torch.Tensor, r_var: torch.Tensor):
o = torch.add(x, y)
o = torch.nn.functional.softmax(o, dim=0)
o = torch.add(o, z)
o = torch.nn.functional.batch_norm(o, r_mean, r_var, training=True)
return o
t_jit = torch.jit.script(t)
jit_o = t_jit(x, y, z, r_m, r_v)
jit_o = t_jit(x, y, z, r_m, r_v)
o = t(x, y, z, r_m, r_v)
self.assertEqual(o.dtype, jit_o.dtype)
self.assertEqual(o, jit_o)
self.assertGraphContains(t_jit.graph_for(x, y, z, r_m, r_v), FUSION_GUARD)
@unittest.skipIf(is_pre_volta(), "reduction not supported in pre volta device")
@unittest.skipIf(not RUN_CUDA, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
def test_sum_to_one(self):
dtype = torch.float
device = "cuda"
x = torch.randn([4, 5, 6], dtype=dtype, device=device)
def t(x: torch.Tensor):
o = torch.add(x, 1)
o = torch.sum(o, dim=[0, 1, 2])
return o
t_jit = torch.jit.script(t)
jit_o = t_jit(x)
jit_o = t_jit(x)
o = t(x)
self.assertEqual(o.dtype, jit_o.dtype)
self.assertEqual(o, jit_o)
self.assertGraphContains(t_jit.graph_for(x), FUSION_GUARD)
@unittest.skipIf(is_pre_volta(), "reduction not supported in pre volta device")
@unittest.skipIf(not RUN_CUDA, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
def test_single_reduction_broadcast(self):
dtype = torch.float
device = "cuda"
x = torch.randn([7, 4, 8], dtype=dtype, device=device)
y = torch.randn([4, 8], dtype=dtype, device=device)
z = torch.randn([1, 4, 8], dtype=dtype, device=device)
def t(x: torch.Tensor, y: torch.Tensor, z: torch.Tensor):
o = torch.add(x, y)
o = torch.add(o, z)
o = torch.sum(o, dim=[0])
return o
t_jit = torch.jit.script(t)
jit_o = t_jit(x, y, z)
jit_o = t_jit(x, y, z)
o = t(x, y, z)
self.assertEqual(o.dtype, jit_o.dtype)
self.assertEqual(o, jit_o)
self.assertGraphContains(t_jit.graph_for(x, y, z), FUSION_GUARD)
@unittest.skipIf(is_pre_volta(), "reduction not supported in pre volta device")
@unittest.skipIf(not RUN_CUDA, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
def test_trivial_reduction(self):
dtype = torch.float
device = "cuda"
x = torch.randn([1, 4, 8], dtype=dtype, device=device)
def t(x: torch.Tensor):
o = torch.add(x, 1)
o = torch.sum(o, dim=[0])
o = torch.sum(o, dim=[0])
return o
t_jit = torch.jit.script(t)
jit_o = t_jit(x)
jit_o = t_jit(x)
o = t(x)
self.assertEqual(o.dtype, jit_o.dtype)
self.assertEqual(o, jit_o)
self.assertGraphContains(t_jit.graph_for(x), FUSION_GUARD)
@unittest.skipIf(not RUN_CUDA, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
def test_profiling_node(self):
dtype = torch.float
device = "cuda"
x = torch.randn(4, 8, 8, 8, dtype=dtype, device=device)
def repro(x: torch.Tensor, alpha: float):
o = torch.rand_like(x)
o = torch.add(o, alpha)
return o
repro_jit = torch.jit.script(repro)
self._run_helper(repro_jit, repro, x, 0.6)
@unittest.skipIf(is_pre_volta(), "reduction not supported in pre volta device")
@unittest.skipIf(not RUN_CUDA, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
def test_reduction_sizes_op(self):
dtype = torch.float
device = "cuda"
x = torch.randn(2, 3, 4, 5, dtype=dtype, device=device)
y = torch.randn(2, 3, 4, 5, dtype=dtype, device=device)
def t(x: torch.Tensor, y: torch.Tensor):
o = x + y
o = torch.relu(o)
o = o.sum((1, 3))
return o.size()
t_jit = torch.jit.script(t)
jit_o = t_jit(x, y)
jit_o = t_jit(x, y)
o = t(x, y)
self.assertEqual(o, jit_o)
# since the output value is not used at all, the fusion operator should
# have been optimized away
self.assertGraphContainsExactly(t_jit.graph_for(x, y), FUSION_GUARD, 0)
@unittest.skipIf(is_pre_volta(), "reduction not supported in pre volta device")
@unittest.skipIf(not RUN_CUDA, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
def test_profile_ivalue(self):
dtype = torch.float
device = "cuda"
x = torch.randn([7, 4, 7], dtype=dtype, device=device)
y = torch.randn([7, 4, 7], dtype=dtype, device=device)
def t(x: torch.Tensor, y: torch.Tensor, dim: List[int], keepdim: bool):
o = torch.add(x, y)
o = o.sum(dim, keepdim=keepdim)
return o
t_jit = torch.jit.script(t)
jit_o = t_jit(x, y, (0, 1), False)
jit_o = t_jit(x, y, (0, 1), False)
o = t(x, y, (0, 1), False)
self.assertEqual(o.dtype, jit_o.dtype)
self.assertEqual(o, jit_o)
self.assertGraphContains(t_jit.graph_for(x, y, (0, 1), False), FUSION_GUARD)
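        # dim/keepdim are non-tensor arguments; they are captured via
        # profile_ivalue so the fusion can specialize on them, hence the graph
        # lookup above uses the same constant values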
@unittest.skipIf(is_pre_volta(), "reduction not supported in pre volta device")
@unittest.skipIf(not RUN_CUDA, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
def test_sum_to_size(self):
dtype = torch.float
device = "cuda"
x = torch.randn([2, 4, 4], dtype=dtype, device=device)
y = torch.randn([2, 4, 4], dtype=dtype, device=device)
def t(x: torch.Tensor, y: torch.Tensor, new_size: List[int]):
o = torch.add(x, y)
o = o.sum_to_size(new_size)
return o
t_jit = torch.jit.script(t)
jit_o = t_jit(x, y, (4, 1))
jit_o = t_jit(x, y, (4, 1))
o = t(x, y, (4, 1))
self.assertEqual(o.dtype, jit_o.dtype)
self.assertEqual(o, jit_o)
self.assertGraphContains(t_jit.graph_for(x, y, (4, 1)), FUSION_GUARD)
# update shape: old kernel should handle dynamic shape well without
# recompilation
x = torch.randn([2, 5, 8], dtype=dtype, device=device)
y = torch.randn([2, 5, 8], dtype=dtype, device=device)
# (TODO) check executed kernel, should extend autograd.profiler to fused
# kernels
jit_o = t_jit(x, y, (5, 1))
o = t(x, y, (5, 1))
self.assertEqual(o.dtype, jit_o.dtype)
self.assertEqual(o, jit_o)
@unittest.skipIf(is_pre_volta(), "reduction not supported in pre volta device")
@unittest.skipIf(not RUN_CUDA, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
def test_grad_sum_to_size(self):
dtype = torch.float
device = "cuda"
x = torch.randn([2, 4, 4], dtype=dtype, device=device).requires_grad_()
y = torch.randn([4], dtype=dtype, device=device).requires_grad_()
grad = torch.randn([2, 4, 4], dtype=dtype, device=device)
ref_x = x.detach().clone().requires_grad_()
ref_y = y.detach().clone().requires_grad_()
def t(x: torch.Tensor, y: torch.Tensor):
o = torch.add(x, y)
o = torch.relu(o)
return o
# profiling runs for forward & backward
t_jit = torch.jit.script(t)
jit_o = t_jit(x, y)
jit_o.backward(grad)
jit_o = t_jit(x, y)
jit_o.backward(grad)
x.grad = None
y.grad = None
jit_o = t_jit(x, y)
jit_o.backward(grad)
o = t(ref_x, ref_y)
o.backward(grad)
self.assertEqual(o.dtype, jit_o.dtype)
self.assertEqual(o, jit_o)
self.assertEqual(x.grad, ref_x.grad)
self.assertEqual(y.grad, ref_y.grad)
bwd_graph = list(
list(t_jit.get_debug_state().execution_plans.values())[
0].code.grad_executor_states()[0].execution_plans.values()
)[0].graph
FileCheck().check(FUSION_GUARD).run(bwd_graph)
# update shape: old kernel should handle dynamic shape well without
# recompilation
x = torch.randn([2, 5, 8], dtype=dtype, device=device).requires_grad_()
y = torch.randn([8], dtype=dtype, device=device).requires_grad_()
ref_x = x.detach().clone().requires_grad_()
ref_y = y.detach().clone().requires_grad_()
grad = torch.randn([2, 5, 8], dtype=dtype, device=device)
jit_o = t_jit(x, y)
# (TODO) check executed kernel, should extend autograd.profiler to fused
# kernels
jit_o.backward(grad)
o = t(ref_x, ref_y)
o.backward(grad)
self.assertEqual(o.dtype, jit_o.dtype)
self.assertEqual(o, jit_o)
self.assertEqual(x.grad, ref_x.grad)
self.assertEqual(y.grad, ref_y.grad)
@unittest.skipIf(not RUN_CUDA, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
def test_dropout_inference_fusion(self):
dtype = torch.float
device = "cuda"
x = torch.randn([10, 4, 8], dtype=dtype, device=device)
def t(x: torch.Tensor, p: float, train: bool):
o = torch.nn.functional.dropout(x, p, training=train)
o = o + 1.0
return o
t_jit = torch.jit.script(t)
self._run_helper(t_jit, t, x, 0.15, False)
@unittest.skipIf(not RUN_CUDA, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
def test_dropout_train_nograd_fusion(self):
dtype = torch.float
device = "cuda"
x = torch.randn([10, 4, 8], dtype=dtype, device=device)
def t(x: torch.Tensor, p: float, train: bool):
o = torch.nn.functional.dropout(x, p, training=train)
o = o + 1.0
return o
t_jit = torch.jit.script(t)
self._run_helper(t_jit, t, x, 0.0, True)
@unittest.skipIf(not RUN_CUDA, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
def test_dropout_train_nograd_prob_check(self):
dtype = torch.float
device = "cuda"
x = torch.randn([1024, 1024], dtype=dtype, device=device)
def t(x: torch.Tensor, p: float, train: bool):
o = torch.nn.functional.dropout(x, p, training=train)
o = o * 2.0
return o
t_jit = torch.jit.script(t)
for prob in [0.0, 0.15, 0.5, 0.85, 1.]:
torch.cuda.manual_seed_all(123)
jit_o = t_jit(x, prob, True)
torch.cuda.manual_seed_all(123)
jit_o = t_jit(x, prob, True)
self.assertTrue(jit_o.detach().isfinite().all().item())
num_elems = x.numel()
num_zeros = num_elems - jit_o.detach().count_nonzero().item()
percent_zeros = num_zeros / num_elems
self.assertTrue((percent_zeros >= (prob - 0.01)) and (percent_zeros <= (prob + 0.01)))
self.assertGraphContainsExactly(t_jit.graph_for(x, prob, True), FUSION_GUARD, 1, consider_subgraphs=True)
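            # back-of-the-envelope check for the +/- 0.01 tolerance above: with
            # 1024 * 1024 elements the standard deviation of the observed zero
            # fraction is sqrt(p * (1 - p) / n) <= 0.0005, so 0.01 is a generous
            # margin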
@unittest.skipIf(is_pre_volta(), "reduction not supported in pre volta device")
@unittest.skipIf(not RUN_CUDA, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
def test_dropout_training_fusion(self):
dtype = torch.float
device = "cuda"
x = torch.randn([10, 4, 8], dtype=dtype, device=device, requires_grad=True)
grads = torch.randn([10, 4, 8], dtype=dtype, device=device)
def t(x: torch.Tensor, p: float, train: bool):
o = torch.nn.functional.dropout(x, p, training=train)
o = o * 2.0
return o
t_jit = torch.jit.script(t)
        # The drop probability has to be zero here because eager mode and the
        # JIT draw random numbers in a different order
self._run_training_helper(t_jit, t, grads, x, 0.0, True)
def t2(x: torch.Tensor, p: float, train: bool):
o = torch.nn.functional.softmax(x, dim=-1)
o = torch.nn.functional.dropout(o, p, training=train)
return o
t2_jit = torch.jit.script(t2)
        # The drop probability has to be zero here because eager mode and the
        # JIT draw random numbers in a different order
self._run_training_helper(t2_jit, t2, grads, x, 0.0, True)
@unittest.skipIf(not RUN_CUDA, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
def test_gelu(self):
old_guard = torch._C._jit_set_nvfuser_guard_mode(True)
dtype = torch.float
device = "cuda"
x = torch.randn([1024, 1024], dtype=dtype, device=device, requires_grad=True)
grads = torch.randn([1024, 1024], dtype=dtype, device=device, requires_grad=False)
def t(x: torch.Tensor, mode : str):
o = torch.nn.functional.gelu(x, approximate=mode)
o = o * 2.0
return o
t_jit = torch.jit.script(t)
self._run_training_helper(t_jit, t, grads, x, 'none')
self._run_training_helper(t_jit, t, grads, x, 'tanh')
torch._C._jit_set_nvfuser_guard_mode(old_guard)
@unittest.skipIf(not RUN_CUDA, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
def test_dropout_training_prob_check(self):
dtype = torch.float
device = "cuda"
x = torch.randn([1024, 1024], dtype=dtype, device=device, requires_grad=True)
x_nograd = torch.randn([1024, 1024], dtype=dtype, device=device)
def t(x: torch.Tensor, p: float, train: bool):
o = torch.nn.functional.dropout(x, p, training=train)
o = o * 2.0
return o
t_jit = torch.jit.script(t)
for prob in [0.0, 0.15, 0.5, 0.85, 1.]:
torch.cuda.manual_seed_all(123)
jit_o = t_jit(x, prob, True)
torch.cuda.manual_seed_all(123)
jit_o = t_jit(x, prob, True)
torch.cuda.manual_seed_all(123)
jit_o = t_jit(x, prob, True)
self.assertTrue(jit_o.detach().isfinite().all().item())
num_elems = x.numel()
num_zeros = num_elems - jit_o.detach().count_nonzero().item()
percent_zeros = num_zeros / num_elems
self.assertTrue((percent_zeros >= (prob - 0.01)) and (percent_zeros <= (prob + 0.01)))
self.assertGraphContainsExactly(t_jit.graph_for(x, prob, True), FUSION_GUARD, 1, consider_subgraphs=True)
@unittest.skipIf(not RUN_CUDA, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
def test_linear(self):
in_feature = 2
out_feature = 8
x = torch.randn(4, in_feature, dtype=torch.float32, device='cuda')
weight = torch.randn(out_feature, in_feature, dtype=torch.float32, device='cuda')
bias = torch.randn(out_feature, dtype=torch.float32, device='cuda')
def t(x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor):
o = torch.nn.functional.linear(x, weight, bias)
o = torch.relu(o)
return o
        # linear with an explicit bias
t_jit = torch.jit.script(t)
jit_o = t_jit(x, weight, bias)
jit_o = t_jit(x, weight, bias)
o = t(x, weight, bias)
self.assertEqual(o, jit_o)
        # the pointwise tail should be fused, so exactly one fusion guard is
        # expected in the optimized graph
self.assertGraphContainsExactly(t_jit.graph_for(x, weight, bias), FUSION_GUARD, 1)
@unittest.skipIf(not RUN_CUDA, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
def test_backward_type(self):
# not super useful to check gradient of integer/bool, so skipping here
type_pairs = [
(torch.float, torch.half),
(torch.double, torch.half),
(torch.float, torch.double),
]
if TEST_BF16:
type_pairs += [
(torch.float, torch.bfloat16),
(torch.double, torch.bfloat16),
]
for x_type, y_type in type_pairs:
x = torch.randn(4, 2, dtype=x_type, device='cuda', requires_grad=True)
y = torch.randn(4, 2, dtype=y_type, device='cuda', requires_grad=True)
grad = torch.randn(4, 2, dtype=torch.float, device='cuda')
def test1(x: torch.Tensor, y: torch.Tensor):
o = torch.add(x, y)
o = torch.add(o, y)
o = torch.add(o, y)
o = torch.add(o, y)
o = o + 1.0
return o
test1_jit = torch.jit.script(test1)
for i in range(3):
jit_o = test1_jit(x, y)
jit_o.backward(grad)
bwd_graph = list(
list(test1_jit.get_debug_state().execution_plans.values())[
0].code.grad_executor_states()[0].execution_plans.values()
)[0].graph
FileCheck().check(FUSION_GROUP).run(bwd_graph)
self.assertEqual(x.grad.dtype, x.dtype)
self.assertEqual(y.grad.dtype, y.dtype)
@unittest.skipIf(is_pre_volta(), "reduction not supported in pre volta device")
@unittest.skipIf(not RUN_CUDA, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
def test_autocast_1(self):
def t(x: torch.Tensor, y: torch.Tensor):
o = x * 2.0
o = torch.softmax(o, dim=-1)
o = o * 3.0
o = torch._C._nn.linear(o, y)
return o
x = torch.randn(8, 4, dtype=torch.half, device='cuda', requires_grad=True)
y = torch.randn(4, 4, dtype=torch.float, device='cuda', requires_grad=True)
grad = torch.randn(8, 4, dtype=torch.half, device='cuda', requires_grad=False)
t_jit = torch.jit.script(t)
for i in range(3):
with torch.cuda.amp.autocast():
jit_o = t_jit(x, y)
                if i == 2:
fwd_graph = t_jit.graph_for(x, y)
jit_o.backward(grad)
self.assertGraphContainsExactly(fwd_graph, FUSION_GUARD, 1, consider_subgraphs=True)
with torch.cuda.amp.autocast():
bwd_graph = list(
list(t_jit.get_debug_state().execution_plans.values())[
0].code.grad_executor_states()[0].execution_plans.values()
)[0].graph
FileCheck().check(FUSION_GROUP).run(bwd_graph)
self.assertEqual(jit_o.dtype, torch.half)
self.assertEqual(x.grad.dtype, x.dtype)
self.assertEqual(y.grad.dtype, y.dtype)
@unittest.skipIf(is_pre_volta(), "reduction not supported in pre volta device")
@unittest.skipIf(not RUN_CUDA, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
def test_autocast_2(self):
def t(x: torch.Tensor):
o = x * 2.0
o = torch.softmax(o, dim=-1)
o = o * 3.0
o = torch.softmax(o, dim=-1)
o = o * 4.0
return o
x = torch.randn(8, 4, dtype=torch.half, device='cuda', requires_grad=True)
grad = torch.randn(8, 4, dtype=torch.float, device='cuda', requires_grad=False)
t_jit = torch.jit.script(t)
for i in range(3):
            with torch.cuda.amp.autocast():
jit_o = t_jit(x)
                if i == 2:
fwd_graph = t_jit.graph_for(x)
jit_o.backward(grad)
self.assertGraphContainsExactly(fwd_graph, FUSION_GUARD, 1, consider_subgraphs=True)
with torch.cuda.amp.autocast():
bwd_graph = list(
list(t_jit.get_debug_state().execution_plans.values())[
0].code.grad_executor_states()[0].execution_plans.values()
)[0].graph
FileCheck().check(FUSION_GROUP).run(bwd_graph)
self.assertEqual(jit_o.dtype, torch.float)
self.assertEqual(x.grad.dtype, x.dtype)
@unittest.skipIf(is_pre_volta(), "reduction not supported in pre volta device")
@unittest.skipIf(not RUN_CUDA, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
@unittest.skipIf(not TEST_BF16, "device does not support BFloat16")
def test_autocast_1_bfloat(self):
def t(x: torch.Tensor, y: torch.Tensor):
o = x * 2.0
o = torch.softmax(o, dim=-1)
o = o * 3.0
o = torch._C._nn.linear(o, y)
return o
x = torch.randn(8, 4, dtype=torch.bfloat16, device='cuda', requires_grad=True)
y = torch.randn(4, 4, dtype=torch.float, device='cuda', requires_grad=True)
grad = torch.randn(8, 4, dtype=torch.bfloat16, device='cuda', requires_grad=False)
t_jit = torch.jit.script(t)
for i in range(3):
with torch.cuda.amp.autocast(dtype=torch.bfloat16):
jit_o = t_jit(x, y)
                if i == 2:
fwd_graph = t_jit.graph_for(x, y)
jit_o.backward(grad)
self.assertGraphContainsExactly(fwd_graph, FUSION_GUARD, 1, consider_subgraphs=True)
with torch.cuda.amp.autocast(dtype=torch.bfloat16):
bwd_graph = list(
list(t_jit.get_debug_state().execution_plans.values())[
0].code.grad_executor_states()[0].execution_plans.values()
)[0].graph
FileCheck().check(FUSION_GROUP).run(bwd_graph)
self.assertEqual(jit_o.dtype, torch.bfloat16)
self.assertEqual(x.grad.dtype, x.dtype)
self.assertEqual(y.grad.dtype, y.dtype)
@unittest.skipIf(is_pre_volta(), "reduction not supported in pre volta device")
@unittest.skipIf(not RUN_CUDA, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
@unittest.skipIf(not TEST_BF16, "device does not support BFloat16")
def test_autocast_2_bfloat(self):
def t(x: torch.Tensor):
o = x * 2.0
o = torch.softmax(o, dim=-1)
o = o * 3.0
o = torch.softmax(o, dim=-1)
o = o * 4.0
return o
x = torch.randn(8, 4, dtype=torch.bfloat16, device='cuda', requires_grad=True)
grad = torch.randn(8, 4, dtype=torch.float, device='cuda', requires_grad=False)
t_jit = torch.jit.script(t)
for i in range(3):
            with torch.cuda.amp.autocast(dtype=torch.bfloat16):
jit_o = t_jit(x)
                if i == 2:
fwd_graph = t_jit.graph_for(x)
jit_o.backward(grad)
self.assertGraphContainsExactly(fwd_graph, FUSION_GUARD, 1, consider_subgraphs=True)
with torch.cuda.amp.autocast(dtype=torch.bfloat16):
bwd_graph = list(
list(t_jit.get_debug_state().execution_plans.values())[
0].code.grad_executor_states()[0].execution_plans.values()
)[0].graph
FileCheck().check(FUSION_GROUP).run(bwd_graph)
self.assertEqual(jit_o.dtype, torch.float)
self.assertEqual(x.grad.dtype, x.dtype)
@unittest.skipIf(not RUN_CUDA, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
def test_to_dtype_fp32_to_fp16(self):
def t(x: torch.Tensor):
o = x * 2.0
o = o.to(dtype=torch.half)
o = o * 3.0
return o
x = torch.randn(8, 4, dtype=torch.float, device='cuda')
t_jit = torch.jit.script(t)
for i in range(3):
jit_o = t_jit(x)
self.assertGraphContainsExactly(t_jit.graph_for(x), FUSION_GUARD, 1)
self.assertEqual(jit_o.dtype, torch.half)
@unittest.skipIf(not RUN_CUDA, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
def test_to_dtype_fp16_to_fp32(self):
def t(x: torch.Tensor):
o = x * 2.0
o = o.to(dtype=torch.float)
o = o * 3.0
return o
x = torch.randn(8, 4, dtype=torch.half, device='cuda')
t_jit = torch.jit.script(t)
for i in range(3):
jit_o = t_jit(x)
self.assertGraphContainsExactly(t_jit.graph_for(x), FUSION_GUARD, 1)
self.assertEqual(jit_o.dtype, torch.float)
@unittest.skipIf(not RUN_CUDA, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
def test_to_dtype_fp16_to_fp16(self):
def t(x: torch.Tensor):
o = x * 2.0
o = o.to(dtype=torch.half)
o = o * 3.0
return o
x = torch.randn(8, 4, dtype=torch.half, device='cuda')
t_jit = torch.jit.script(t)
for i in range(3):
jit_o = t_jit(x)
self.assertGraphContainsExactly(t_jit.graph_for(x), FUSION_GUARD, 1)
self.assertEqual(jit_o.dtype, torch.half)
@unittest.skipIf(not RUN_CUDA, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
@unittest.skipIf(not TEST_BF16, "device does not support BFloat16")
def test_to_dtype_fp32_to_bf16(self):
def t(x: torch.Tensor):
o = x * 2.0
o = o.to(dtype=torch.bfloat16)
o = o * 3.0
return o
x = torch.randn(8, 4, dtype=torch.float, device='cuda')
t_jit = torch.jit.script(t)
for i in range(3):
jit_o = t_jit(x)
self.assertGraphContainsExactly(t_jit.graph_for(x), FUSION_GUARD, 1)
self.assertEqual(jit_o.dtype, torch.bfloat16)
@unittest.skipIf(not RUN_CUDA, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
@unittest.skipIf(not TEST_BF16, "device does not support BFloat16")
def test_to_dtype_bf16_to_fp32(self):
def t(x: torch.Tensor):
o = x * 2.0
o = o.to(dtype=torch.float)
o = o * 3.0
return o
x = torch.randn(8, 4, dtype=torch.bfloat16, device='cuda')
t_jit = torch.jit.script(t)
for i in range(3):
jit_o = t_jit(x)
self.assertGraphContainsExactly(t_jit.graph_for(x), FUSION_GUARD, 1)
self.assertEqual(jit_o.dtype, torch.float)
@unittest.skipIf(not RUN_CUDA, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
@unittest.skipIf(not TEST_BF16, "device does not support BFloat16")
def test_to_dtype_bf16_to_bf16(self):
def t(x: torch.Tensor):
o = x * 2.0
o = o.to(dtype=torch.bfloat16)
o = o * 3.0
return o
x = torch.randn(8, 4, dtype=torch.bfloat16, device='cuda')
t_jit = torch.jit.script(t)
for i in range(3):
jit_o = t_jit(x)
self.assertGraphContainsExactly(t_jit.graph_for(x), FUSION_GUARD, 1)
self.assertEqual(jit_o.dtype, torch.bfloat16)
@unittest.skipIf(not RUN_CUDA, "requires CUDA")
@unittest.skipIf(not TEST_MULTIGPU, "requires multiple CUDA device")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
def test_multiple_device_pw(self):
def t(x):
o = x + 1.0
o = torch.relu(o)
return o
x = torch.randn(2, dtype=torch.float32, device="cuda")
t_jit = torch.jit.script(t)
for i in range(3):
jit_o = t_jit(x)
self.assertGraphContainsExactly(t_jit.graph_for(x), FUSION_GUARD, 1)
        with torch.cuda.device(1):
            x = x.to("cuda:1")
            jit_o = t_jit(x)
@unittest.skipIf(not RUN_CUDA, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
def test_graph_for_with_missing_optimized_engine(self):
x = torch.randn(8, 4, 2, dtype=torch.float, device="cuda").requires_grad_()
def t(x: torch.Tensor, flag: bool):
x = x + 1.0
x = torch.relu(x)
if flag:
o = x + 1.0
o = torch.relu(o)
else:
o = x + 2.0
o = torch.relu(o)
return o
t_jit = torch.jit.script(t)
jit_o = t_jit(x, False)
jit_o = t_jit(x, False)
jit_o = t_jit(x, True)
o = t(x, True)
self.assertEqual(o, jit_o)
        # exactly one fusion guard is expected, counting subgraphs of the taken
        # branch
self.assertGraphContainsExactly(t_jit.graph_for(x, True), FUSION_GUARD, 1, True)
@unittest.skipIf(not RUN_CUDA, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
def test_branches(self):
in_feature = 2
out_feature = 4
x = torch.randn(4, in_feature, dtype=torch.float32, device='cuda')
weight = torch.randn(out_feature, in_feature, dtype=torch.float32, device='cuda')
bias = torch.randn(out_feature, dtype=torch.float32, device='cuda')
def t(x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor, flag: bool):
if flag:
o = torch.nn.functional.linear(x, weight, bias)
o = o + 1.0
o = torch.relu(o)
else:
o = x.sum()
o = o + 2.0
o = torch.relu(o)
return o
t_jit = torch.jit.script(t)
jit_o = t_jit(x, weight, bias, True)
jit_o = t_jit(x, weight, bias, True)
o = t(x, weight, bias, True)
self.assertEqual(o, jit_o)
        # the taken branch should produce exactly one fusion group
self.assertGraphContainsExactly(t_jit.graph_for(x, weight, bias, True), FUSION_GUARD, 1)
@unittest.skipIf(not RUN_CUDA, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
def test_scalar_tensor(self):
x = torch.empty([], device="cuda", dtype=torch.float32)
def t(x: torch.Tensor):
o = x + 1.0
o = torch.nn.functional.relu(o)
return o
t_jit = torch.jit.script(t)
jit_o = t_jit(x)
jit_o = t_jit(x)
o = t(x)
self.assertEqual(o, jit_o)
        # a 0-dim (scalar) tensor input should still produce exactly one fusion
        # group
self.assertGraphContainsExactly(t_jit.graph_for(x), FUSION_GUARD, 1)
@unittest.skipIf(os.environ.get('PYTORCH_NO_CUDA_MEMORY_CACHING') is not None,
"skipping graph_rng when caching allocator is disabled")
@unittest.skipIf(not RUN_CUDA, "requires CUDA")
@unittest.skipIf(CUDA_MAJOR < 11, "requires CUDA11 or above")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
def test_graph_rng(self):
self.assertTrue(torch._C._jit_nvfuser_enabled())
size = 10000
a = torch.randn((size,), device="cuda", dtype=torch.float)
def t(x):
o = x + 1.0
o = torch.nn.functional.dropout(o, p=0.1)
o = o + 1.0
o = torch.nn.functional.dropout(o, p=0.1)
return o
t_jit = torch.jit.script(t)
for _ in range(3):
t_jit(a)
self.assertGraphContainsExactly(t_jit.graph_for(a), FUSION_GUARD, 1)
# Control (jitted, ungraphed)
torch.cuda.manual_seed(5)
eager_out = a.clone()
for _ in range(3):
eager_out = t_jit(eager_out)
graph_in = a.clone()
g = torch.cuda.CUDAGraph()
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
torch.cuda.manual_seed(5)
g.capture_begin()
graph_out = t_jit(graph_in)
g.capture_end()
torch.cuda.current_stream().wait_stream(s)
# g is now a jitted, graphed version of t.
# Runs a (jitted, graphed) -> (jitted, ungraphed) -> (jitted, graphed) sequence.
# The ops in the overall sequence should be the same as Control.
g.replay()
# graph_out is now filled with g's result. Use it as ungraphed input.
out = t_jit(graph_out)
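        # a captured graph always reads from the same input buffer, so feed the
        # next replay by copying into graph_in rather than rebinding the name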
graph_in.copy_(out)
g.replay()
# If replay() updated RNG state correctly, graph_out should now equal eager_out
self.assertEqual(graph_out, eager_out)
def _test_batch_norm_impl_index_helper(self, batch, c, hw, affine=True,
track_running_stats=True, train=True,
dtype=torch.float32):
# enabling inlining to avoid counter increment in BN forward
torch._C._debug_set_autodiff_subgraph_inlining(True)
class MyModule(torch.nn.Module):
def __init__(self, num_features=10, affine=True, track_running_stats=True):
super(MyModule, self).__init__()
self.bn = torch.nn.BatchNorm2d(num_features,
1e-5,
affine=affine,
track_running_stats=track_running_stats).to(dtype=dtype)
def forward(self, x):
o = self.bn(x)
o = o * 2.0
return o
x = torch.randn(batch, c, hw, hw, dtype=torch.float, device="cuda").to(dtype=dtype).requires_grad_()
grad = torch.randint(-20, 20, (batch, c, hw, hw), device="cuda").to(dtype=dtype).div(-10)
my_module = MyModule(c, affine, track_running_stats).cuda()
ref_module = MyModule(c, affine, track_running_stats).cuda()
if not train:
my_module.eval()
ref_module.eval()
t_jit = torch.jit.script(my_module)
ref_module.load_state_dict(my_module.state_dict())
ref_x = x.detach().requires_grad_()
for i in range(0, 3):
jit_o = t_jit(x)
jit_o.backward(grad)
# TODO: remove this run?
o = ref_module(ref_x)
o.backward(grad)
has_affine = ref_module.bn.weight is not None
has_running_stats = ref_module.bn.running_mean is not None
if has_running_stats:
my_module.bn.running_mean.zero_()
my_module.bn.running_var.fill_(1.0)
ref_module.bn.running_mean.zero_()
ref_module.bn.running_var.fill_(1.0)
        # when train is False there is no grad on weight/bias, so only reset
        # them in training mode
if has_affine and train:
my_module.bn.weight.grad.zero_()
my_module.bn.bias.grad.zero_()
ref_module.bn.weight.grad.zero_()
ref_module.bn.bias.grad.zero_()
x.grad.zero_()
ref_x.grad.zero_()
# real runs
jit_o = t_jit(x)
jit_o.backward(grad)
o = ref_module(ref_x)
o.backward(grad)
# assert forward graph fusion
self.assertGraphContainsExactly(t_jit.graph_for(x), FUSION_GUARD, 1, consider_subgraphs=True)
# assert backward graph fusion
bwd_graph = list(
list(t_jit.get_debug_state().execution_plans.values())[0].code.grad_executor_states()[0]
.execution_plans.values())[0].graph
self.assertGraphContainsExactly(bwd_graph, FUSION_GUARD, 1, consider_subgraphs=True)
e0 = 1e-5 if dtype is not torch.half else 1e-3
e1 = 1e-4 if dtype is not torch.half else 1e-3
e2 = 1e-3 if dtype is not torch.half else 1e-2
self.assertTrue(self._compare("comparing output failed", jit_o, o, e0))
self.assertTrue(self._compare("comparing input grad failed", x.grad, ref_x.grad, e1))
        # TODO: switch to Welford and tighten this to 1e-5
        # The 1e-3 tolerance looks loose, but codegen does not use Welford yet,
        # so its numerics differ noticeably from the reference.
if has_affine and train:
self.assertTrue(self._compare("comparing weight grad failed",
my_module.bn.weight.grad,
ref_module.bn.weight.grad,
e2))
self.assertTrue(self._compare("comparing bias grad failed",
my_module.bn.bias.grad,
ref_module.bn.bias.grad,
e1))
if has_running_stats:
self.assertTrue(self._compare("comparing running_mean failed",
my_module.bn.running_mean,
ref_module.bn.running_mean,
e0))
self.assertTrue(self._compare("comparing running_var failed",
my_module.bn.running_var,
ref_module.bn.running_var,
e0))
@unittest.skipIf(is_pre_volta(), "reduction not supported in pre volta device")
@unittest.skipIf(not RUN_CUDA, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
def test_batch_norm_half(self):
with torch.backends.cudnn.flags(enabled=True):
setups = [
[True, True],
[False, False],
[True, False],
[False, True]]
for training_and_track, affine in itertools.product(setups, [True, False]):
training, track_running_stats = training_and_track
self._test_batch_norm_impl_index_helper(4, 8, 5, affine, track_running_stats, training, torch.half)
@unittest.skipIf(is_pre_volta(), "reduction not supported in pre volta device")
@unittest.skipIf(not RUN_CUDA, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
def test_batch_norm_impl_index_correctness(self):
with torch.backends.cudnn.flags(enabled=True):
batch = [2, 7, 16]
channels = [4, 89, 19, 32]
hw = [1, 8, 17, 32]
# avoid tolerance failure in CI
torch.cuda.manual_seed_all(211)
# failing sizes (2, 1, 1, 1)
# failing sizes (2, 89, 8, 8) training False, track True, affine: False
for b, c, hw in itertools.product(batch, channels, hw):
setups = [
[True, True],
[False, False],
[True, False],
[False, True]]
for training_and_track, affine in itertools.product(setups, [True, False]):
training, track_running_stats = training_and_track
self._test_batch_norm_impl_index_helper(b, c, hw, affine, track_running_stats, training)
@unittest.skipIf(not RUN_CUDA, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
def test_softplus_fuser(self):
def shifted_softplus(x: torch.Tensor, shift: float):
return functional.softplus(x) - shift
jitted = torch.jit.script(shifted_softplus)
inp = torch.randn(4, 2, dtype=torch.float32, device="cuda").requires_grad_()
inp_ref = inp.detach().clone().requires_grad_()
grad = torch.randn(4, 2, dtype=torch.float32, device="cuda")
aten_o = shifted_softplus(inp_ref, 0.693147)
aten_o.backward(grad)
aten_grad = inp_ref.grad
for i in range(3):
jit_o = jitted(inp, 0.693147)
inp.grad = None # avoid accumulation on grad
jit_o.backward(grad)
jit_grad = inp.grad
assert torch.allclose(jit_o, aten_o)
assert torch.allclose(jit_grad, aten_grad)
self.assertGraphContains(jitted.graph_for(inp, 0.693147), FUSION_GROUP, True)
@unittest.skipIf(not RUN_CUDA, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
def test_inplace_removal(self):
def t(x: torch.Tensor):
o = torch.nn.functional.softmax(x, dim=0)
o += x
return o.relu_()
jitted = torch.jit.script(t)
inp = torch.randn(4, 2, dtype=torch.float32, device="cuda")
for i in range(3):
jit_o = jitted(inp)
graph = jitted.graph_for(inp)
self.assertGraphContains(graph, FUSION_GROUP, True)
self.assertGraphContains(graph, 'aten::add', True)
self.assertGraphContains(graph, 'aten::relu', True)
@unittest.skipIf(not RUN_CUDA, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
def test_conv2d_bias(self):
def t(x: torch.Tensor, w: torch.Tensor, bias: torch.Tensor):
o = torch.nn.functional.conv2d(x, w, bias)
return o.relu()
jitted = torch.jit.script(t)
inp = torch.randn(4, 5, 3, 3, dtype=torch.float32, device="cuda")
weight = torch.randn(2, 5, 2, 2, dtype=torch.float32, device="cuda")
bias = torch.randn(2, dtype=torch.float32, device="cuda")
for i in range(3):
jit_o = jitted(inp, weight, bias)
        graph = jitted.graph_for(inp, weight, bias)
self.assertGraphContains(graph, FUSION_GROUP, True)
def t_not_fused(x: torch.Tensor, w: torch.Tensor):
o = torch.nn.functional.conv2d(x, w)
return o.relu()
jitted_not_fused = torch.jit.script(t_not_fused)
for i in range(3):
jit_o = jitted_not_fused(inp, weight)
        graph = jitted_not_fused.graph_for(inp, weight)
self.assertGraphContainsExactly(graph, FUSION_GROUP, 0)
self.assertGraphContains(graph, 'aten::relu', True)
def t_bias(x: torch.Tensor, w: torch.Tensor, bias: torch.Tensor):
return torch.nn.functional.conv2d(x, w, bias)
jitted_bias = torch.jit.script(t_bias)
for i in range(3):
jit_o = jitted_bias(inp, weight, bias)
        graph = jitted_bias.graph_for(inp, weight, bias)
self.assertGraphContainsExactly(graph, FUSION_GROUP, 0)
self.assertGraphContains(graph, 'prim::add_optional', True)
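        # The three variants above check that: conv2d with a bias followed by a
        # pointwise op yields a fusion group; conv2d without a bias leaves the
        # trailing relu unfused; and conv2d with a bias but no pointwise epilogue
        # keeps the split-out bias add as prim::add_optional instead of fusing it.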
@unittest.skipIf(is_pre_volta(), "reduction not supported in pre volta device")
@unittest.skipIf(not RUN_CUDA, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
def test_remove_output_used_only_in_dtype(self):
class MyModule(torch.nn.Module):
def __init__(self, num_features=4):
super(MyModule, self).__init__()
self.bn0 = torch.nn.BatchNorm2d(num_features)
self.bn1 = torch.nn.BatchNorm2d(num_features)
def forward(self, x, y):
o1 = self.bn0(x)
o2 = self.bn1(y)
return torch.relu(o1 + o2)
t = MyModule(4).float().cuda()
jitted = torch.jit.script(t)
x = torch.randn(3, 4, 2, 5, dtype=torch.float32, device="cuda")
y = torch.randn(3, 4, 2, 5, dtype=torch.float32, device="cuda")
with torch.cuda.amp.autocast(True):
for i in range(5):
jit_o = jitted(x, y)
jit_o = jitted(x, y)
o = t(x, y)
self.assertTrue(torch.allclose(jit_o, o))
graph = jitted.graph_for(x, y)
self.assertGraphContains(graph, FUSION_GROUP, True)
@unittest.skipIf(is_pre_volta(), "reduction not supported in pre volta device")
@unittest.skipIf(not RUN_CUDA, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
def test_fix_shape_expression_bn(self):
class MyModule(torch.nn.Module):
def __init__(self, num_features=4):
super(MyModule, self).__init__()
self.bn = torch.nn.BatchNorm2d(num_features)
def forward(self, x, y):
out1 = self.bn(x)
out2 = out1 + y
out3 = torch.relu(out2)
return out3
t = MyModule(4).float().cuda()
jitted = torch.jit.script(t)
x = torch.randn(3, 4, 2, 5, dtype=torch.float32, device="cuda")
y = torch.randn(3, 4, 2, 5, dtype=torch.float32, device="cuda")
with torch.cuda.amp.autocast(True):
for i in range(5):
jit_o = jitted(x, y)
jit_o = jitted(x, y)
o = t(x, y)
self.assertTrue(torch.allclose(jit_o, o))
graph = jitted.graph_for(x, y)
self.assertGraphContains(graph, FUSION_GROUP, True)
@unittest.skipIf(not RUN_CUDA, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
def test_linear_1d_weight_mismatch_bias_dtype(self):
def t(x: torch.Tensor, w: torch.Tensor, b: torch.Tensor):
o = torch.nn.functional.linear(x, w, b)
return o.relu()
device = "cuda"
jitted = torch.jit.script(t)
x = torch.randn(2, 5, 5, dtype=torch.half, device=device)
w = torch.randn(5, dtype=torch.half, device=device)
b = torch.randn(5, dtype=torch.float32, device=device)
for i in range(3):
jit_o = jitted(x, w, b)
jit_o = jitted(x, w, b)
o = t(x, w, b)
self.assertEqual(o, jit_o)
self.assertEqual(o.dtype, jit_o.dtype)
self.assertEqual(o.size(), jit_o.size())
graph = jitted.graph_for(x, w, b)
self.assertGraphContains(graph, FUSION_GROUP, True)
self.assertGraphContains(graph, 'aten::matmul', True)
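    # Shared forward-only checker: runs `func` through the JIT a few times,
    # compares every output against eager, asserts that a fusion group exists,
    # and asserts that none of the ops listed in `ops` remain in the graph.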
def _run_fwd_helper(self, func, ops, *args):
jitted = torch.jit.script(func)
for i in range(3):
jit_o = jitted(*args)
jit_o = jitted(*args)
o = func(*args)
for oo, jit_oo in zip(o, jit_o):
self.assertEqual(oo.dtype, jit_oo.dtype)
self.assertEqual(oo, jit_oo)
graph = jitted.graph_for(*args)
self.assertGraphContains(graph, FUSION_GROUP, True)
for op in ops:
self.assertGraphContainsExactly(graph, op, 0)
@unittest.skipIf(is_pre_volta(), "reduction not supported in pre volta device")
@unittest.skipIf(not RUN_CUDA, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
def test_sibling_fusion(self):
device = "cuda"
dtype = torch.float
x = torch.randn(2, 5, dtype=dtype, device=device)
y = torch.randn(2, 5, dtype=dtype, device=device)
def t(x: torch.Tensor):
o1 = x + 1.0
o2 = x * 0.5
return o1, o2
self._run_fwd_helper(t, ['aten::add', 'aten::mul'], x)
def t2(x: torch.Tensor, y: torch.Tensor):
o1 = x.sum(0)
o2 = (x * y).sum(0)
return o1, o2
self._run_fwd_helper(t2, ['aten::sum', 'aten::mul'], x, y)
@unittest.skipIf(not RUN_CUDA, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
def test_clean_profile_ivalue(self):
device = "cuda"
dtype = torch.float
x = torch.randn(2, 5, dtype=dtype, device=device, requires_grad=True)
        # turn on autodiff subgraph inlining to verify that profile_ivalue nodes
        # are cleaned up outside of the fusion code path
torch._C._debug_set_autodiff_subgraph_inlining(True)
def t(x: torch.Tensor, flag: bool):
return torch.dropout(x, 0.5, flag)
jit_t = torch.jit.script(t)
        for idx in range(5):
out = jit_t(x, True)
graph = jit_t.graph_for(x, True)
out = jit_t(x, False)
@unittest.skipIf(not RUN_CUDA, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
def test_sibling_fusion_no_scalar_inputs(self):
device = "cuda"
dtype = torch.float
x = torch.randn(2, 5, dtype=dtype, device=device)
y = torch.randn(3, dtype=dtype, device=device)
        # no tensor dependency between o1/o2, so they should not be fused together
def t(x: torch.Tensor, y: torch.Tensor):
o1 = x + 1
o2 = y - 1
return o1, o2
jitted = torch.jit.script(t)
for i in range(3):
jit_o = jitted(x, y)
graph = jitted.graph_for(x, y)
self.assertGraphContainsExactly(graph, FUSION_GROUP, 0)
def _bias_view_relu_helper(self, shape, output_shape, dtype, device, error):
class BiasViewRelu(torch.nn.Module):
def __init__(self):
super(BiasViewRelu, self).__init__()
self.bias = torch.nn.Parameter(torch.randn(shape, dtype=dtype, device=device), requires_grad=False)
with torch.no_grad():
self.bias.fill_(10)
def forward(self, inputs : torch.Tensor, view_shape : List[int]):
o = inputs + self.bias
o = o.view(view_shape)
return torch.relu(o)
t = BiasViewRelu()
x = torch.randn(shape, dtype=dtype, device=device, requires_grad=False)
t_jit = torch.jit.script(t)
# profiling
jit_o = t_jit(x, output_shape)
# optimization
jit_o = t_jit(x, output_shape)
# final
jit_o = t_jit(x, output_shape)
# eager - baseline
o = t(x, output_shape)
self.assertEqual(o.dtype, jit_o.dtype)
self.assertTrue(self._compare("comparing output failed", o, jit_o, error))
graph = t_jit.graph_for(x, output_shape)
has_inferred_dimension = any([dim == -1 for dim in output_shape])
if has_inferred_dimension:
# prohibit fusing when view_shape contains an inferred dimension
self.assertGraphContainsExactly(graph, FUSION_GROUP, 0)
self.assertGraphContainsExactly(graph, 'prim::view_copy', 0)
else:
self.assertGraphContains(graph, FUSION_GUARD)
self.assertGraphContains(graph, 'prim::view_copy', True)
def _alias_bias_view_relu_helper(self, shape, output_shape, dtype, device, error):
class BiasViewRelu(torch.nn.Module):
def __init__(self):
super(BiasViewRelu, self).__init__()
self.bias = torch.nn.Parameter(torch.randn(shape, dtype=dtype, device=device), requires_grad=False)
with torch.no_grad():
self.bias.fill_(10)
def forward(self, inputs : torch.Tensor, view_shape : List[int]):
o = inputs.view(view_shape)
inputs = inputs * self.bias
return torch.relu(o)
t = BiasViewRelu()
x = torch.randn(shape, dtype=dtype, device=device, requires_grad=False)
t_jit = torch.jit.script(t)
# profiling
jit_o = t_jit(x, output_shape)
# optimization
jit_o = t_jit(x, output_shape)
# final
jit_o = t_jit(x, output_shape)
# eager - baseline
o = t(x, output_shape)
self.assertEqual(o.dtype, jit_o.dtype)
self.assertTrue(self._compare("comparing output failed", o, jit_o, error))
graph = t_jit.graph_for(x, output_shape)
self.assertGraphContainsExactly(graph, FUSION_GUARD, 0)
self.assertGraphContainsExactly(graph, 'prim::view_copy', 0)
# generate random view given original view
def _random_view(self, original_view, max_len=8, max_views=10000):
class Moves(enum.Enum):
Merge = 0
Split = 1
Broadcast = 2
ImplicitBroadcast = 3
Keep = 4
def valid(old_view, new_view):
old_view_size = reduce(operator.mul, old_view)
new_view_size = reduce(operator.mul, new_view)
return old_view_size == new_view_size
        # pick a random starting point in [2, N - 1] and scan upward for the
        # next divisor of N; returns -1 when N <= 3
def find_nearest_divisor(N):
if 2 >= (N - 1):
return -1
result = random.randint(2, N - 1)
while (N % result) != 0:
result += 1
return result
complete_views = set([tuple(original_view)])
to_visit = []
        # (empty new view, current original view, start pos = 0, empty move list, last move)
to_visit.append(([], original_view, 0, [], Moves.Keep))
# depth-first search of view shapes, starting from the original view
while len(to_visit) > 0 and len(complete_views) < max_views:
new_view, old_view, odx, move_list, last_move = to_visit[-1]
to_visit.pop()
# iterate over each move type
for idx in range(len(Moves)):
state = Moves(idx)
new_view_clone = copy.deepcopy(new_view)
old_view_clone = copy.deepcopy(old_view)
new_move_list = move_list + [state]
new_odx = odx
# Update state using Move state
if state == Moves.Keep:
new_size = old_view_clone[odx]
new_view_clone.append(new_size)
new_odx += 1
elif state == Moves.Merge:
if odx + 1 < len(old_view_clone):
new_size = old_view_clone[odx] * old_view_clone[odx + 1]
new_view_clone.append(new_size)
new_odx += 2
else:
continue
elif state == Moves.Broadcast and last_move != Moves.Broadcast:
new_view_clone.append(1)
elif state == Moves.Split:
new_size = find_nearest_divisor(old_view_clone[odx])
if new_size == -1:
continue
new_view_clone.append(new_size)
old_view_clone[odx] = int(old_view[odx] / new_size)
if old_view_clone[odx] == 1:
new_odx += 1
elif state == Moves.ImplicitBroadcast:
old_view_clone.insert(odx + 1, 1)
new_size = old_view[odx] * 1
new_view_clone.append(new_size)
new_odx += 2
if new_odx < len(old_view_clone) and len(new_move_list) < max_len:
to_visit.append((new_view_clone, old_view_clone, new_odx, new_move_list, state))
elif (valid(original_view, new_view_clone)):
final_new_view = tuple(new_view_clone)
complete_views.add(final_new_view)
return list(complete_views)
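    # A minimal usage sketch of _random_view (illustrative only, not asserted
    # anywhere):
    #
    #   views = self._random_view([2, 3, 8])
    #   # possible entries include (2, 3, 8), (6, 8), (2, 24), (2, 1, 3, 8) and
    #   # (2, 3, 2, 4); every entry preserves the element count 2 * 3 * 8 = 48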
# ndims - number of dimensions
# test_fn - view test function
def _view_test_generator(self, ndims, test_fn):
# create random tensor
# max value for each dimension
max_size = 10e7
max_value = max(int(pow(max_size, 1. / ndims)), 1)
sizes = [random.randint(1, max_value) for idx in range(ndims)]
x = torch.randn(sizes)
original_sizes = list(x.size())
all_views = self._random_view(original_sizes)
random.shuffle(all_views)
max_samples = 20
max_views = min(len(all_views), max_samples)
total = 0
correct = 0
# test random combinations of compatible views
for idx in range(max_views):
for jdx in range(idx + 1, max_views):
total += 1
test_fn(all_views[idx], all_views[jdx], torch.float, 'cuda', 1e-6)
@unittest.skipIf(not RUN_CUDA, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
def test_view(self):
torch._C._jit_set_nvfuser_guard_mode(True)
self._bias_view_relu_helper([2, 3, 4, 5], [-1, 4, 5], torch.float, 'cuda', 1e-6)
for ndims in range(1, 5):
self._view_test_generator(ndims, self._bias_view_relu_helper)
self._alias_bias_view_relu_helper([2, 3, 4, 5], [1, 6, 1, 2, 2, 5, 1], torch.float, 'cuda', 1e-6)
def _bias_squeeze_relu_helper(self, shape, dtype, device, error):
class BiasSqueezeRelu(torch.nn.Module):
def __init__(self):
super(BiasSqueezeRelu, self).__init__()
def forward(self, inputs : torch.Tensor, bias : torch.Tensor):
o = inputs + bias
o = torch.squeeze(o)
return torch.relu(o)
t = BiasSqueezeRelu()
x = torch.randn(shape, dtype=dtype, device=device, requires_grad=False)
bias = torch.randn(shape, dtype=dtype, device=device, requires_grad=False)
t_jit = torch.jit.script(t)
jit_o = t_jit(x, bias)
jit_o = t_jit(x, bias)
jit_o = t_jit(x, bias)
o = t(x, bias)
self.assertEqual(o.dtype, jit_o.dtype)
self.assertTrue(self._compare("comparing output failed", o, jit_o, error))
        graph = t_jit.graph_for(x, bias)
self.assertGraphContains(graph, FUSION_GUARD)
self.assertGraphContains(graph, 'prim::squeeze_copy', True)
def _alias_bias_squeeze_relu_helper(self, shape, dtype, device, error):
class BiasSqueezeRelu(torch.nn.Module):
def __init__(self):
super(BiasSqueezeRelu, self).__init__()
def forward(self, inputs : torch.Tensor, bias : torch.Tensor):
o = torch.squeeze(inputs)
inputs = inputs * bias
return torch.relu(o)
t = BiasSqueezeRelu()
x = torch.randn(shape, dtype=dtype, device=device, requires_grad=False)
bias = torch.randn(shape, dtype=dtype, device=device, requires_grad=False)
t_jit = torch.jit.script(t)
jit_o = t_jit(x, bias)
jit_o = t_jit(x, bias)
jit_o = t_jit(x, bias)
o = t(x, bias)
self.assertEqual(o.dtype, jit_o.dtype)
self.assertTrue(self._compare("comparing output failed", o, jit_o, error))
graph = t_jit.graph_for(x, bias)
self.assertGraphContainsExactly(graph, FUSION_GUARD, 0)
self.assertGraphContainsExactly(graph, 'prim::squeeze_copy', 0)
@unittest.skipIf(not RUN_CUDA, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
def test_squeeze(self):
self._bias_squeeze_relu_helper([1, 6, 1, 2, 2, 5, 1], torch.float, 'cuda', 1e-6)
self._alias_bias_squeeze_relu_helper([1, 6, 1, 2, 2, 5, 1], torch.float, 'cuda', 1e-6)
def _bias_unsqueeze_relu_helper(self, shape, dtype, device, error):
class BiasUnsqueezeRelu(torch.nn.Module):
def __init__(self):
super(BiasUnsqueezeRelu, self).__init__()
def forward(self, inputs : torch.Tensor, bias : torch.Tensor):
o = inputs + bias
o = torch.unsqueeze(o, 0)
return torch.relu(o)
t = BiasUnsqueezeRelu()
x = torch.randn(shape, dtype=dtype, device=device, requires_grad=False)
bias = torch.randn(shape, dtype=dtype, device=device, requires_grad=False)
t_jit = torch.jit.script(t)
jit_o = t_jit(x, bias)
jit_o = t_jit(x, bias)
jit_o = t_jit(x, bias)
o = t(x, bias)
self.assertEqual(o.dtype, jit_o.dtype)
self.assertTrue(self._compare("comparing output failed", o, jit_o, error))
        graph = t_jit.graph_for(x, bias)
self.assertGraphContains(graph, FUSION_GUARD)
self.assertGraphContains(graph, 'prim::unsqueeze_copy', True)
def _alias_bias_unsqueeze_relu_helper(self, shape, dtype, device, error):
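        # negative case: asserts that neither a fusion guard nor
        # prim::unsqueeze_copy appears in the optimized graph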
class BiasUnsqueezeRelu(torch.nn.Module):
def __init__(self):
super(BiasUnsqueezeRelu, self).__init__()
def forward(self, inputs : torch.Tensor, bias : torch.Tensor):
                o = torch.unsqueeze(inputs, 0)
inputs = inputs * bias
return torch.relu(o)
t = BiasUnsqueezeRelu()
x = torch.randn(shape, dtype=dtype, device=device, requires_grad=False)
bias = torch.randn(shape, dtype=dtype, device=device, requires_grad=False)
t_jit = torch.jit.script(t)
jit_o = t_jit(x, bias)
jit_o = t_jit(x, bias)
jit_o = t_jit(x, bias)
o = t(x, bias)
self.assertEqual(o.dtype, jit_o.dtype)
self.assertTrue(self._compare("comparing output failed", o, jit_o, error))
        graph = t_jit.graph_for(x, bias)
self.assertGraphContainsExactly(graph, FUSION_GUARD, 0)
self.assertGraphContainsExactly(graph, 'prim::unsqueeze_copy', 0)
@unittest.skipIf(not RUN_CUDA, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
def test_unsqueeze(self):
self._bias_unsqueeze_relu_helper([2, 3, 4, 5], torch.float, 'cuda', 1e-6)
self._alias_bias_unsqueeze_relu_helper([2, 3, 4, 5], torch.float, 'cuda', 1e-6)
@unittest.skipIf(not RUN_CUDA, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
def test_alias_pass_fix(self):
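        # the pointwise `b + 1.0` feeds aten::conv2d as its bias; scripted and
        # eager results are compared via _run_helper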
x = torch.randn(4, 24, 2, 2, dtype=torch.float, device="cuda")
w = torch.randn(24, 24, 1, 1, dtype=torch.float, device="cuda")
b = torch.randn(24, dtype=torch.float, device="cuda")
def t(x, w, b):
b2 = b + 1.0
o = torch.conv2d(x, w, b2)
return o
t_jit = torch.jit.script(t)
self._run_helper(t_jit, t, x, w, b)
@unittest.skipIf(not RUN_CUDA, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
def test_squeeze_negative_dim(self):
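        # squeeze with a negative dim (-2 removes the size-1 dim of the
        # [4, 24, 1, 2] input) sandwiched between pointwise ops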
x = torch.randn(4, 24, 1, 2, dtype=torch.float, device="cuda")
def t(x):
o = x + 1.0
o = o.squeeze(-2)
o = o * 2.0
return o
t_jit = torch.jit.script(t)
self._run_helper(t_jit, t, x)
@unittest.skipIf(not RUN_CUDA, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
def test_singleton_fusion(self):
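        # with single-node fusion enabled, even a lone relu is expected to be
        # picked up by nvfuser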
x = torch.randn(4, 2, device="cuda")
with nvfuser_singleton_fusion(True):
def t(x):
return x.relu()
t_jit = torch.jit.script(t)
self._run_helper(t_jit, t, x)
@unittest.skipIf(not RUN_CUDA, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
def test_disable_sibling_fuse(self):
x = torch.randn(4, 2, device="cuda")
y = torch.randn(8, device="cuda")
s = torch.tensor(1.5, device="cuda")
with nvfuser_horizontal_fusion(False):
def t(x, y, s):
o1 = x + s
o2 = y + s
return o1, o2
t_jit = torch.jit.script(t)
for i in range(5):
t_jit(x, y, s)
# sibling fusion should be disabled with the flag
self.assertGraphContainsExactly(t_jit.graph_for(x, y, s), FUSION_GUARD, 0)
@unittest.skipIf(not RUN_CUDA, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
def test_build_shape_expression_native_dropout(self):
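        # torch.native_dropout returns (output, mask); both results feed later
        # ops, so the fuser has to build shape expressions for a multi-output op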
x = torch.randn(4, 2, device="cuda")
def t(x):
o, mask = torch.native_dropout(x, 0.0, True)
o1 = o.sigmoid()
o2 = mask.float().sigmoid()
return (o1, o2)
t_jit = torch.jit.script(t)
jit_o = t_jit(x)
jit_o = t_jit(x)
o = t(x)
for oo, jit_oo in zip(o, jit_o):
self.assertEqual(oo.dtype, jit_oo.dtype)
self.assertEqual(oo, jit_oo)
self.assertGraphContains(t_jit.graph_for(x), FUSION_GUARD)
@unittest.skipIf(not RUN_CUDA, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
def test_scalar_tensor_permuted(self):
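        # broadcast a 0-dim cuda tensor against a permuted (non-contiguous) input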
x = torch.randn(4, 2, 3, device="cuda").permute([1, 2, 0])
y = torch.tensor(1.0, device="cuda")
with nvfuser_singleton_fusion(True):
def t(x, y):
return x + y
t_jit = torch.jit.script(t)
self._run_helper(t_jit, t, x, y)
@unittest.skipIf(not RUN_CUDA, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
def test_cpu_scalar(self):
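        # four sub-cases: a cuda + cpu-scalar add that fuses, a cpu-only op that
        # must not fuse, a graph where the cpu scalar add is absorbed into the
        # fusion, and one where the cpu-only add stays outside the fusion group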
x = torch.randn(4, 2, 3, device="cuda")
y = torch.tensor(1.0, device="cpu")
z = torch.tensor(2.0, device="cpu")
with nvfuser_singleton_fusion(True):
# testing cpu scalar tensor promotion
def t(x, y):
return x + y
t_jit = torch.jit.script(t)
self._run_helper(t_jit, t, x, y)
            # an op between two cpu scalar tensors should NOT be fused
@torch.jit.script
def t1(y, z):
return y * z
for _ in range(5):
t1(y, z)
self.assertGraphContainsExactly(t1.graph_for(y, z), FUSION_GUARD, 0)
# everything, including scalar cpu tensor add should be fused
@torch.jit.script
def t2(x, y, z):
tmp = y + z
return tmp + x
for _ in range(5):
t2(x, y, z)
self.assertGraphContainsExactly(t2.graph_for(x, y, z), 'aten::add', 0)
self.assertGraphContainsExactly(t2.graph_for(x, y, z), FUSION_GUARD, 1)
# 'cpu_tmp = y + z' shouldn't be fused.
@torch.jit.script
def t3(x, y, z):
cpu_tmp = y + z
out = x + y
return cpu_tmp, out
for _ in range(5):
t3(x, y, z)
self.assertGraphContainsExactly(t3.graph_for(x, y, z), FUSION_GUARD, 1)
self.assertGraphContainsExactly(t3.graph_for(x, y, z), 'aten::add', 1)
@unittest.skipIf(not RUN_CUDA, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
def test_shape_expression(self):
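        # each fn returns a tensor computed through the fusion plus the size()
        # of an intermediate; run() checks values, the reported shape, and that
        # exactly one fusion guard is present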
x = torch.randn(4, 2, 1, 3, device="cuda")
def t_unsqueeze(x):
t0 = x.relu()
t1 = t0.unsqueeze(1)
t2 = t1 + 1.0
t3 = t1.size()
return t2, t3
def t_squeeze(x):
t0 = x.relu()
t1 = t0.squeeze()
t2 = t1 + 1.0
t3 = t1.size()
return t2, t3
def t_squeeze_dim(x):
t0 = x.relu()
t1 = t0.squeeze(-2)
t2 = t1 + 1.0
t3 = t1.size()
return t2, t3
        # squeezing a dimension whose size is not 1 should be a no-op
def t_squeeze_dim_no_op(x):
t0 = x.relu()
t1 = t0.squeeze(1)
t2 = t1 + 1.0
t3 = t1.size()
return t2, t3
def run(fn):
jit_fn = torch.jit.script(fn)
jit_o = jit_fn(x)
jit_o = jit_fn(x)
jit_o = jit_fn(x)
o = fn(x)
# output 0 is a tensor, so we check dtype and value
self.assertEqual(o[0].dtype, jit_o[0].dtype)
self.assertEqual(o[0], jit_o[0])
# output 1 is shape
self.assertEqual(o[1], jit_o[1])
self.assertGraphContainsExactly(jit_fn.graph_for(x), FUSION_GUARD, 1)
for t in [t_unsqueeze, t_squeeze, t_squeeze_dim, t_squeeze_dim_no_op]:
run(t)
class TestPassManagerCudaFuser(JitTestCase):
@unittest.skipIf(not RUN_CUDA, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
def test_context_manager_test(self):
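        # nested torch.jit.fuser('fuser2') contexts: fusion stays active at both
        # nesting levels and is disabled again once the outer context exits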
x = torch.randn(4, 8, dtype=torch.float, device="cuda")
y = torch.randn(4, 8, dtype=torch.float, device="cuda")
with torch.jit.fuser('fuser2'):
with torch.jit.fuser('fuser2'):
def t1(x, y):
o = x + y
o = o + 2.0
return o
t_jit = torch.jit.script(t1)
t_jit(x, y)
t_jit(x, y)
self.assertGraphContains(t_jit.graph_for(x, y), FUSION_GUARD)
def t2(x, y):
o = x + y
o = o + 3.0
return o
t_jit_2 = torch.jit.script(t2)
t_jit_2(x, y)
t_jit_2(x, y)
self.assertGraphContains(t_jit_2.graph_for(x, y), FUSION_GUARD)
def t3(x, y):
o = x + y
o = o + 4.0
return o
t_jit_3 = torch.jit.script(t3)
t_jit_3(x, y)
t_jit_3(x, y)
self.assertGraphContainsExactly(t_jit_3.graph_for(x, y), FUSION_GUARD, 0)
@unittest.skipIf(not RUN_CUDA, "requires CUDA")
def test_register_fuser(self):
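        # _jit_set_nvfuser_enabled returns the previous value of the flag, so
        # the first call reports False and later calls track the prior toggles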
self.assertFalse(torch._C._jit_set_nvfuser_enabled(True))
self.assertTrue(torch._C._jit_nvfuser_enabled())
self.assertTrue(torch._C._jit_set_nvfuser_enabled(True))
self.assertTrue(torch._C._jit_nvfuser_enabled())
self.assertTrue(torch._C._jit_set_nvfuser_enabled(False))
self.assertFalse(torch._C._jit_nvfuser_enabled())
if __name__ == '__main__':
run_tests()