# Owner(s): ["oncall: jit"] import contextlib import unittest import os import random import enum import copy from functools import reduce import operator import warnings import torch from torch.nn import functional from torch.profiler import profile, ProfilerActivity from torch.testing._internal.codegen.random_topo_test import runDefaultTestWithSeed from torch.testing._internal.common_cuda import TEST_MULTIGPU from torch.testing._internal.common_device_type import instantiate_device_type_tests, ops, OpDTypes from torch.testing._internal.common_jit import JitCommonTestCase from torch.testing._internal.common_methods_invocations import op_db, SampleInput from torch.testing._internal.common_utils import run_tests, ProfilingMode, GRAPH_EXECUTOR, TEST_WITH_ROCM, slowTest, \ is_iterable_of_tensors, freeze_rng_state from torch.testing._internal.jit_utils import clone_inputs, get_traced_sample_variant_pairs, JitTestCase, RUN_CUDA from torch.testing._internal.jit_metaprogramming_utils import create_traced_fn from torch.testing import FileCheck from jit.test_fuser_common import TestFuserCommon # noqa: F401 import itertools import numpy as np import math from torch.autograd.gradcheck import gradcheck from typing import List RUN_NVFUSER = RUN_CUDA and not TEST_WITH_ROCM CUDA_MAJOR, CUDA_MINOR = 0, 0 if RUN_NVFUSER and torch.version.cuda is not None: CUDA_MAJOR, CUDA_MINOR = (int(x) for x in torch.version.cuda.split('.')[:2]) os.environ['PYTORCH_NVFUSER_DISABLE'] = 'fallback,fma,unroll_with_rng' os.environ['PYTORCH_NVFUSER_JIT_OPT_LEVEL'] = '0' # TODO: enable complex when we fixes the extremal cases in OpInfo # see issue https://github.com/csarofeen/pytorch/issues/1730" # os.environ['PYTORCH_NVFUSER_ENABLE'] = 'complex' if GRAPH_EXECUTOR == ProfilingMode.PROFILING: torch._C._jit_set_texpr_fuser_enabled(False) torch._C._jit_set_profiling_executor(True) torch._C._jit_set_profiling_mode(True) FUSION_GROUP = 'prim::CudaFusionGroup' FUSION_GUARD = 'prim::CudaFusionGuard' # TODO: revert disabled alias ops ALIAS_TEST_DISABLED = True @contextlib.contextmanager def nvfuser_singleton_fusion(flag): old_value = torch._C._jit_set_nvfuser_single_node_mode(flag) try: yield finally: torch._C._jit_set_nvfuser_single_node_mode(old_value) @contextlib.contextmanager def nvfuser_horizontal_fusion(flag): old_value = torch._C._jit_set_nvfuser_horizontal_mode(flag) try: yield finally: torch._C._jit_set_nvfuser_horizontal_mode(old_value) def is_pre_volta(): if not RUN_NVFUSER: return False prop = torch.cuda.get_device_properties(torch.cuda.current_device()) return prop.major < 7 TEST_BF16 = RUN_NVFUSER and torch.cuda.is_bf16_supported() class CudaFuserTestOptions(): def __init__(self): self.old_cpu_fuse = torch._C._jit_can_fuse_on_cpu() self.old_gpu_fuse = torch._C._jit_can_fuse_on_gpu() torch._C._jit_override_can_fuse_on_cpu(False) torch._C._jit_override_can_fuse_on_gpu(False) self.old_guard = torch._C._jit_set_nvfuser_guard_mode(False) torch._C._debug_set_autodiff_subgraph_inlining(False) self.old_value = torch._C._jit_set_autocast_mode(True) if(RUN_CUDA): self.old_nvfuser = torch._C._jit_set_nvfuser_enabled(True) def restore(self): if(RUN_CUDA): torch._C._jit_set_nvfuser_enabled(self.old_nvfuser) torch._C._jit_override_can_fuse_on_cpu(self.old_cpu_fuse) torch._C._jit_override_can_fuse_on_gpu(self.old_gpu_fuse) torch._C._jit_set_nvfuser_guard_mode(self.old_guard) torch._C._debug_set_autodiff_subgraph_inlining(True) torch._C._jit_set_autocast_mode(self.old_value) class TestCudaFuser(JitTestCase): def assertEqual(self, *args, 
        kwargs["exact_layout"] = True
        super(JitTestCase, self).assertEqual(*args, **kwargs)

    def _getSubgraphInFusion(self, graph):
        num_node = 0
        subgraph = None

        def count(block, ret):
            for n in block.nodes():
                if n.kind() == FUSION_GROUP:
                    ret[0] = ret[0] + 1
                    self.assertTrue(n.hasAttribute('Subgraph'))
                    ret[1] = n.g('Subgraph')
                for block in n.blocks():
                    count(block, ret)
        ret = [num_node, subgraph]
        count(graph, ret)
        self.assertEqual(ret[0], 1)
        return ret[1]

    def setUp(self):
        super(TestCudaFuser, self).setUp()
        self.skip_node_list = []
        disabled_ops = ("aten::batch_norm",
                        "aten::_batch_norm_impl_index",
                        "aten::_batch_norm_impl_index_backward",
                        "aten::native_batch_norm_backward")
        for op in disabled_ops:
            disabled_flag = torch._C._jit_set_nvfuser_skip_node_kind(op, False)
            if disabled_flag:
                torch._C._jit_set_nvfuser_skip_node_kind(op, True)
                self.skip_node_list.append(op)

        # cpu backup to avoid errors in case this is run on a CPU-only machine
        dev = 'cuda' if RUN_NVFUSER else 'cpu'
        self.special_values = torch.tensor(
            [float("-inf"), -10, -math.pi,
             -1, -0.5, 0, 1, 0.5,
             math.pi, 10, float("inf"),
             float("nan")], dtype=torch.float, device=dev)

        self.int_types = [
            torch.int8,
            torch.uint8,
            torch.int16,
            torch.int32,
            torch.int64
        ]

        self.support_tensor_dtypes = [
            torch.int32,
            torch.int64,
            torch.float16,
            torch.float32,
            torch.float64,
            torch.bool,
            torch.complex64,
            torch.complex128,
        ]
        if TEST_BF16:
            self.support_tensor_dtypes.append(torch.bfloat16)

        if RUN_NVFUSER:
            self.cuda_fuser_options = CudaFuserTestOptions()

    def tearDown(self):
        # restoring skip node to the configuration before tests
        for op in self.skip_node_list:
            disabled_flag = torch._C._jit_set_nvfuser_skip_node_kind(op, False)
            if not disabled_flag:
                torch._C._jit_set_nvfuser_skip_node_kind(op, True)

        if RUN_NVFUSER:
            self.cuda_fuser_options.restore()
        super(TestCudaFuser, self).tearDown()

    def _run_helper(self, jit_op, op, *args, check_stride=False, num_fusion=1):
        torch.cuda.manual_seed_all(123)
        jit_o = jit_op(*args)
        torch.cuda.manual_seed_all(123)
        jit_o = jit_op(*args)
        torch.cuda.manual_seed_all(123)
        o = op(*args)

        if type(jit_o) is torch.Tensor:
            jit_o = [jit_o, ]
            o = [o, ]

        for oo, jit_oo in zip(o, jit_o):
            self.assertEqual(oo.dtype, jit_oo.dtype)
            self.assertEqual(oo, jit_oo)
            if check_stride:
                self.assertEqual(oo.stride(), jit_oo.stride())
        self.assertGraphContainsExactly(jit_op.graph_for(*args), FUSION_GUARD, num_fusion, consider_subgraphs=True)

    def _run_training_helper(self, jit_op, op, grads, *args):
        torch.cuda.manual_seed_all(123)
        jit_o = jit_op(*args)
        jit_g = jit_o.backward(grads)
        torch.cuda.manual_seed_all(123)
        jit_o = jit_op(*args)
        jit_g = jit_o.backward(grads)
        torch.cuda.manual_seed_all(123)
        jit_o = jit_op(*args)
        jit_g = jit_o.backward(grads)
        torch.cuda.manual_seed_all(123)
        o = op(*args)
        g = o.backward(grads)
        self.assertEqual(o, jit_o)
        self.assertEqual(g, jit_g)
        self.assertGraphContainsExactly(jit_op.graph_for(*args), FUSION_GUARD, 1, consider_subgraphs=True)
        bwd_graph = list(
            list(jit_op.get_debug_state().execution_plans.values())[
                0].code.grad_executor_states()[0].execution_plans.values()
        )[0].graph
        self.assertGraphContainsExactly(bwd_graph, FUSION_GUARD, 1, consider_subgraphs=True)

    @unittest.skipIf(not RUN_NVFUSER, "requires CUDA")
    @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING, "Requires fusion optimization pass to be effective")
    def test_half(self):
        def t(x: torch.Tensor, y: torch.Tensor, z: torch.Tensor, alpha: float):
            o_16 = torch.add(x, y)
            o_32_a = torch.add(y, z, alpha=alpha)
            o_32_b = torch.add(o_16, z)
            return (o_16, o_32_a, o_32_b)

        t_jit = torch.jit.script(t)
        alpha = 0.5
        # stick to integers, this avoids the numerical difference due to our
        # promotion
        x = torch.randint(0, 256, (4, 8)).to(dtype=torch.float16, device="cuda")
        y = torch.randint(0, 256, (4, 8)).to(dtype=torch.float16, device="cuda")
        z = torch.randint(0, 256, (4, 8)).to(dtype=torch.float16, device="cuda")
        jit_o = t_jit(x, y, z, alpha)
        jit_o = t_jit(x, y, z, alpha)
        o = t(x, y, z, alpha)
        for oo, jit_oo in zip(o, jit_o):
            self.assertEqual(oo.dtype, jit_oo.dtype)
            self.assertEqual(oo, jit_oo)
        self.assertGraphContains(t_jit.graph_for(x, y, z, alpha), FUSION_GUARD)

    @unittest.skipIf(not TEST_BF16, "device does not support BFloat16")
    @unittest.skipIf(not RUN_NVFUSER, "requires CUDA")
    @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING, "Requires fusion optimization pass to be effective")
    def test_bfloat(self):
        def t(x: torch.Tensor, y: torch.Tensor, z: torch.Tensor, alpha: float):
            o_16 = torch.add(x, y)
            o_32_a = torch.add(y, z, alpha=alpha)
            o_32_b = torch.add(o_16, z)
            return (o_16, o_32_a, o_32_b)

        t_jit = torch.jit.script(t)
        alpha = 0.5
        # stick to integers, this avoids the numerical difference due to our
        # promotion
        x = torch.randint(0, 256, (4, 8)).to(dtype=torch.bfloat16, device="cuda")
        y = torch.randint(0, 256, (4, 8)).to(dtype=torch.bfloat16, device="cuda")
        z = torch.randint(0, 256, (4, 8)).to(dtype=torch.bfloat16, device="cuda")
        jit_o = t_jit(x, y, z, alpha)
        jit_o = t_jit(x, y, z, alpha)
        o = t(x, y, z, alpha)
        for oo, jit_oo in zip(o, jit_o):
            self.assertEqual(oo.dtype, jit_oo.dtype)
            self.assertEqual(oo, jit_oo)
        self.assertGraphContains(t_jit.graph_for(x, y, z, alpha), FUSION_GUARD)

    @unittest.skipIf(not RUN_NVFUSER, "requires CUDA")
    @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING, "Requires fusion optimization pass to be effective")
    def test_const(self):
        def t(x, y):
            o = x + y
            o = o + 2.0
            return o

        t_jit = torch.jit.script(t)
        x = torch.randn(4, 8, dtype=torch.float, device="cuda")
        y = torch.randn(4, 8, dtype=torch.float, device="cuda")
        jit_o = t_jit(x, y)
        jit_o = t_jit(x, y)
        o = t(x, y)
        self.assertEqual(o, jit_o)
        self.assertGraphContains(t_jit.graph_for(x, y), FUSION_GUARD)

    @unittest.skipIf(not RUN_NVFUSER, "requires CUDA")
    @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING, "Requires fusion optimization pass to be effective")
    def test_chunk(self):
        def t(x, y, z, q):
            o = x + q
            x0, x1 = torch.chunk(o, 2)
            o = x0 + x1
            o = o + y
            o = o * z
            o = torch.relu(o)
            return o

        t_jit = torch.jit.script(t)
        x = torch.randn(4, 8, dtype=torch.float, device="cuda")
        y = torch.randn(2, 8, dtype=torch.float, device="cuda")
        z = torch.randn(2, 8, dtype=torch.float, device="cuda")
        q = torch.randn(4, 8, dtype=torch.float, device="cuda")
        jit_o = t_jit(x, y, z, q)
        jit_o = t_jit(x, y, z, q)
        o = t(x, y, z, q)
        self.assertEqual(o, jit_o)
        self.assertGraphContains(t_jit.graph_for(x, y, z, q), FUSION_GUARD)

    @unittest.skipIf(is_pre_volta(), "reduction not supported in pre volta device")
    @unittest.skipIf(not RUN_NVFUSER, "requires CUDA")
    @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING, "Requires fusion optimization pass to be effective")
    def test_reduction_dtypes_axis(self):
        for op in [torch.sum, torch.mean, torch.amax, torch.var, torch.std]:
            for dtype in [torch.float16, torch.float32, torch.double]:
                for axis in [-1, 2, 0]:
                    def make_func(op):
                        def func(x: torch.Tensor):
                            o = torch.mul(x, 2.0)
                            o = op(o, dim=[axis])
                            return o
                        return func

                    x = torch.randn(8, 4, 16, dtype=dtype, device="cuda")
                    t = make_func(op)
                    t_jit = torch.jit.trace(t, x)
                    jit_o = t_jit(x)
                    jit_o = t_jit(x)
                    o = t(x)
self.assertEqual(o.dtype, jit_o.dtype) self.assertTrue(self._compare("comparing output failed", o, jit_o, 1e-4)) self.assertGraphContains(t_jit.graph_for(x), FUSION_GUARD) @unittest.skipIf(is_pre_volta(), "reduction not supported in pre volta device") @unittest.skipIf(not RUN_NVFUSER, "requires CUDA") @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING, "Requires fusion optimization pass to be effective") def test_variance(self): for op in [torch.var, torch.std]: for dtype in [torch.float16, torch.float32, torch.double]: for axis in [-2, -1, 2, 1]: for unbiased in [False, True]: def make_func(op): def func(x: torch.Tensor): o = torch.mul(x, 2.0) o = op(o, dim=[axis]) return o return func x = torch.randn(8, 4, 16, dtype=dtype, device="cuda") t = make_func(op) t_jit = torch.jit.trace(t, x) jit_o = t_jit(x) jit_o = t_jit(x) o = t(x) self.assertEqual(o.dtype, jit_o.dtype) self.assertTrue(self._compare("comparing output failed", o, jit_o, 1e-4)) self.assertGraphContains(t_jit.graph_for(x), FUSION_GUARD) @unittest.skipIf(not RUN_NVFUSER, "requires CUDA") @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING, "Requires fusion optimization pass to be effective") def test_scalar_input(self): def t(x: torch.Tensor, y: torch.Tensor, z: float): o = x + y o = o + z return o t_jit = torch.jit.script(t) x = torch.randn(4, 8, 32, 32, dtype=torch.float, device="cuda") y = torch.randn(4, 8, 1, 32, dtype=torch.float, device="cuda") y = y.expand(4, 8, 32, 32) jit_o = t_jit(x, y, 2.0) jit_o = t_jit(x, y, 2.0) o = t(x, y, 2.0) self.assertEqual(o, jit_o) self.assertGraphContains(t_jit.graph_for(x, y, 2.0), FUSION_GUARD) @unittest.skipIf(not RUN_NVFUSER, "requires CUDA") @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING, "Requires fusion optimization pass to be effective") def test_broadcasting_0(self): def t(x: torch.Tensor, y: torch.Tensor, z: float): o = x + y o = o + z return o t_jit = torch.jit.script(t) x = torch.randn(4, 8, 32, 32, dtype=torch.float, device="cuda") y = torch.randn(32, 32, dtype=torch.float, device="cuda") jit_o = t_jit(x, y, 2.0) jit_o = t_jit(x, y, 2.0) o = t(x, y, 2.0) self.assertEqual(o, jit_o) subgraph = self._getSubgraphInFusion(t_jit.graph_for(x, y, 2.0)) self.assertGraphContainsExactly(subgraph, 'aten::add', 2, consider_subgraphs=False) @unittest.skipIf(not RUN_NVFUSER, "requires CUDA") @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING, "Requires fusion optimization pass to be effective") def test_broadcasting_1(self): def t(x: torch.Tensor, y: torch.Tensor, z: float): o = x + y o = o + z return o t_jit = torch.jit.script(t) x = torch.randn(4, 8, 32, 32, dtype=torch.float, device="cuda") y = torch.randn(1, 32, 32, dtype=torch.float, device="cuda") jit_o = t_jit(x, y, 2.0) jit_o = t_jit(x, y, 2.0) o = t(x, y, 2.0) self.assertEqual(o, jit_o) subgraph = self._getSubgraphInFusion(t_jit.graph_for(x, y, 2.0)) self.assertGraphContainsExactly(subgraph, 'aten::add', 2, consider_subgraphs=False) @unittest.skipIf(not RUN_NVFUSER, "requires CUDA") @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING, "Requires fusion optimization pass to be effective") def test_broadcasting_2(self): def t(x: torch.Tensor, y: torch.Tensor, z: float): o = x + y o = o + z return o t_jit = torch.jit.script(t) x = torch.randn(4, 1, 32, 32, dtype=torch.float, device="cuda") y = torch.randn(8, 32, 32, dtype=torch.float, device="cuda") jit_o = t_jit(x, y, 2.0) jit_o = t_jit(x, y, 2.0) o = t(x, y, 2.0) self.assertEqual(o, jit_o) subgraph = self._getSubgraphInFusion(t_jit.graph_for(x, 
y, 2.0)) self.assertGraphContainsExactly(subgraph, 'aten::add', 2, consider_subgraphs=False) @unittest.skipIf(not RUN_NVFUSER, "requires CUDA") @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING, "Requires fusion optimization pass to be effective") def test_broadcasting_3(self): def t(x: torch.Tensor, y: torch.Tensor, z: float): o = x + y o = o + z return o t_jit = torch.jit.script(t) x = torch.randn(8, 17, 8, dtype=torch.float, device="cuda") y = torch.randn(8, 17, 1, dtype=torch.float, device="cuda") jit_o = t_jit(x, y, 2.0) jit_o = t_jit(x, y, 2.0) o = t(x, y, 2.0) self.assertEqual(o, jit_o) subgraph = self._getSubgraphInFusion(t_jit.graph_for(x, y, 2.0)) self.assertGraphContainsExactly(subgraph, 'aten::add', 2, consider_subgraphs=False) # test_broadcasting_partition_logic_X # Testing partition logic that is capable to avoid creating unsupported # broadcasting semantics in CudaFusionGroup @unittest.skipIf(not RUN_NVFUSER, "requires CUDA") @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING, "Requires fusion optimization pass to be effective") def test_broadcasting_partition_logic_0(self): def t(x: torch.Tensor, y: torch.Tensor, z: torch.Tensor): x = x + 12.0 o1 = x + y o2 = x + z o = o1 + o2 return o t_jit = torch.jit.script(t) x = torch.randn(4, 8, 6, 8, dtype=torch.float32, device="cuda") y = torch.randn(8, 6, 8, dtype=torch.float32, device="cuda") z = torch.randn(6, 8, dtype=torch.float32, device="cuda") jit_o = t_jit(x, y, z) jit_o = t_jit(x, y, z) o = t(x, y, z) self.assertEqual(o, jit_o) subgraph = self._getSubgraphInFusion(t_jit.graph_for(x, y, z)) self.assertGraphContainsExactly(subgraph, 'aten::add', 4, consider_subgraphs=False) @unittest.skipIf(not RUN_NVFUSER, "requires CUDA") @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING, "Requires fusion optimization pass to be effective") def test_broadcasting_partition_logic_1(self): def t(x: torch.Tensor, y: torch.Tensor, z: torch.Tensor): x = x + 12.0 o1 = x + y o2 = x + z o = o1 + o2 return o t_jit = torch.jit.script(t) x = torch.randn(8, 6, 8, dtype=torch.float32, device="cuda") y = torch.randn(4, 8, 6, 8, dtype=torch.float32, device="cuda") z = torch.randn(4, 1, 6, 8, dtype=torch.float32, device="cuda") jit_o = t_jit(x, y, z) jit_o = t_jit(x, y, z) o = t(x, y, z) self.assertEqual(o, jit_o) subgraph = self._getSubgraphInFusion(t_jit.graph_for(x, y, z)) self.assertGraphContainsExactly(subgraph, 'aten::add', 4, consider_subgraphs=False) @unittest.skipIf(True, "Broadcast with different output not supported yet") @unittest.skipIf(not RUN_NVFUSER, "requires CUDA") @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING, "Requires fusion optimization pass to be effective") def test_broadcasting_multiple_output_shape(self): def t(x: torch.Tensor, y: torch.Tensor, z: torch.Tensor): o = x + 12 o1 = o + y o2 = o + z oo = o1.sum() + o2.sum() return oo t_jit = torch.jit.script(t) x = torch.randn(32, 32, dtype=torch.float, device="cuda") y = torch.randn(2, 32, 32, dtype=torch.float, device="cuda") z = torch.randn(4, 32, 32, dtype=torch.float, device="cuda") jit_o = t_jit(x, y, z) jit_o = t_jit(x, y, z) o = t(x, y, z) self.assertEqual(o, jit_o) # Currently cannot fuse this self.assertGraphContains(t_jit.graph_for(x, y, z), FUSION_GUARD) @unittest.skipIf(True, "broadcast on branches can't be resolved yet") @unittest.skipIf(not RUN_NVFUSER, "requires CUDA") @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING, "Requires fusion optimization pass to be effective") def test_broadcasting_multiple_output(self): def 
t(x: torch.Tensor, y: torch.Tensor, z: torch.Tensor): o = x + 12 o1 = o + y o2 = o + z oo = o1.sum() + o2.sum() return oo t_jit = torch.jit.script(t) x = torch.randn(32, 32, dtype=torch.float, device="cuda") y = torch.randn(4, 32, 32, dtype=torch.float, device="cuda") z = torch.randn(4, 32, 32, dtype=torch.float, device="cuda") jit_o = t_jit(x, y, z) jit_o = t_jit(x, y, z) o = t(x, y, z) self.assertEqual(o, jit_o) # Currently cannot fuse this self.assertGraphContains(t_jit.graph_for(x, y, z), FUSION_GUARD) def _unary_test_helper(self, operation, dtype, random_data): gradient_check = (dtype == torch.float64) and random_data shape = self.special_values.shape torch.cuda.manual_seed_all(211) # need additional def of t for boolean ops def t(x: torch.Tensor, y: torch.Tensor): o = x * y o = o + 5e-3 o = operation(o) return o y = torch.rand(shape, dtype=torch.float32, device="cuda", requires_grad=gradient_check) y = y.to(dtype=dtype) if random_data: x = torch.rand(shape, dtype=torch.float32, device="cuda", requires_grad=gradient_check) if dtype in self.int_types: # prefer a larger variance for integer types x = x * 5 x = x.to(dtype=dtype) else: x = self.special_values.to(dtype=dtype) try: ref = t(x, y) except Exception: # same way as TE checker, if eager mode throws, ignore this test return t_jit = torch.jit.script(t) jit_o = t_jit(x, y) jit_o = t_jit(x, y) jit_o = t_jit(x, y) if gradient_check: if jit_o.dtype != torch.bool: # bool dtype has no `-` gradcheck(t_jit, [x, y], nondet_tol=1e-5) elif dtype in self.support_tensor_dtypes: self.assertGraphContains(t_jit.graph_for(x, y), FUSION_GUARD) o = t(x, y) self.assertEqual(o.dtype, jit_o.dtype) if dtype == torch.bfloat16: # compare with the actual ground truth for # bfloat16 kernels instead of eager mode # implementation, since mismatch in cast # adds excessive noise. 
o = t(x.to(torch.float64), y.to(torch.float64)) if o.dtype.is_floating_point: o = o.to(torch.bfloat16) else: o = t(x, y) self.assertTrue(self._compare("failing case {}\n{}\n{}\n{}".format(dtype, operation, x, y), o, jit_o, 1e-2)) @unittest.skipIf(not RUN_NVFUSER, "requires CUDA") @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING, "Requires fusion optimization pass to be effective") def test_unary_ops(self): data_types = [ *self.int_types, torch.float16, torch.float32, torch.float64, # TODO: revert this # see issue https://github.com/csarofeen/pytorch/issues/1730" # torch.cfloat, # torch.cdouble, ] if TEST_BF16: data_types.append(torch.bfloat16) operations = [torch.neg, torch.abs, torch.log, torch.log10, torch.log1p, torch.log2, torch.lgamma, torch.exp, torch.expm1, torch.erf, torch.erfc, torch.cos, torch.acos, torch.cosh, torch.sin, torch.asin, torch.sinh, torch.tan, torch.atan, torch.sqrt, torch.rsqrt, torch.ceil, torch.floor, torch.round, torch.trunc, torch.frac, torch.reciprocal, torch.isfinite, torch.isinf, torch.isnan, torch.isneginf, torch.isposinf, torch.isreal, torch.nn.functional.softplus, torch.nn.functional.gelu, torch.relu, torch.sigmoid, torch.bitwise_not, torch.tan, torch.tanh, torch.nn.functional.silu] skip_complex = {torch.rsqrt, torch.reciprocal} for op, dtype in itertools.product(operations, data_types): if dtype.is_complex and op in skip_complex: continue self._unary_test_helper(op, dtype, False) # test special numbers self._unary_test_helper(op, dtype, True) # test random data @unittest.skipIf(not RUN_NVFUSER, "requires CUDA") @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING, "Requires fusion optimization pass to be effective") def test_category_rule(self): def run_tensor(x, z): def t(x: torch.Tensor, z: torch.Tensor): o = x + z o = torch.abs(o) return o t_jit = torch.jit.script(t) jit_o = t_jit(x, z) jit_o = t_jit(x, z) o = t(x, z) self.assertEqual(o.dtype, jit_o.dtype) self.assertEqual(o, jit_o) self.assertGraphContains(t_jit.graph_for(x, z), FUSION_GUARD) def run_scalar(x, z): def t(x: torch.Tensor, z: float): o = x + z o = torch.abs(o) return o t_jit = torch.jit.script(t) jit_o = t_jit(x, z) jit_o = t_jit(x, z) o = t(x, z) self.assertEqual(o.dtype, jit_o.dtype) self.assertEqual(o, jit_o) self.assertGraphContains(t_jit.graph_for(x, z), FUSION_GUARD) # n-dim with 0-dim (no type-promote) x = torch.randn(4, 8, 32, 32, dtype=torch.float, device="cuda") z = torch.tensor(2.0, dtype=torch.double, device="cuda") run_tensor(x, z) # n-dim with 0-dim (type-promote) x = torch.randn(4, 8, 32, 32, device="cuda").to(dtype=torch.long) z = torch.tensor(2.0, dtype=torch.double, device="cuda") run_tensor(x, z) # n-dim with n-dim (type-promote) x = torch.randn(4, 8, 32, 32, dtype=torch.float, device="cuda") z = torch.randn(4, 8, 32, 32, dtype=torch.double, device="cuda") run_tensor(x, z) # n-dim with scalar (no type-promote) x = torch.randn(4, 8, 32, 32, dtype=torch.float16, device="cuda") z = torch.tensor(3., dtype=torch.double) run_scalar(x, z) if TEST_BF16: # n-dim with scalar (no type-promote) x = torch.randn(4, 8, 32, 32, dtype=torch.bfloat16, device="cuda") z = torch.tensor(3., dtype=torch.double) run_scalar(x, z) # n-dim with scalar (type-promote) x = torch.randn(4, 8, 32, 32, device="cuda").to(dtype=torch.long) z = torch.tensor(3., dtype=torch.double) run_scalar(x, z) @unittest.skipIf(not RUN_NVFUSER, "requires CUDA") @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING, "Requires fusion optimization pass to be effective") def test_unary_bitwise(self): 
        def bit_not(x: torch.Tensor):
            return ~(x + 1)

        jitted = torch.jit.script(bit_not)
        x = torch.randn(4, 8, 32, 32, dtype=torch.float, device="cuda").mul(5).to(torch.long)
        jit_o = jitted(x)
        jit_o = jitted(x)
        o = bit_not(x)
        self.assertEqual(o, jit_o)
        jitted.graph_for(x)  # Shows up in second instance, not first
        self.assertGraphContains(jitted.graph_for(x), FUSION_GUARD)

        def bool_not(x: torch.Tensor, y: torch.Tensor):
            return ~(x & y)

        jitted = torch.jit.script(bool_not)
        x = torch.rand(4, 8, 32, 32, dtype=torch.float, device="cuda").round().to(torch.bool)
        y = torch.rand(4, 8, 32, 32, dtype=torch.float, device="cuda").round().to(torch.bool)
        jit_o = jitted(x, y)
        jit_o = jitted(x, y)
        o = bool_not(x, y)
        self.assertEqual(o, jit_o)
        jitted.graph_for(x, y)  # Shows up in second instance, not first
        self.assertGraphContains(jitted.graph_for(x, y), FUSION_GUARD)

    def _get_scalar_binary_test_fn(self, category_and_type1, category_and_type2, operation):
        category1, dtype_arg1 = category_and_type1
        category2, dtype_arg2 = category_and_type2

        def t_intx_tensory(x: int, y: torch.Tensor):
            o = operation(x, y)
            o = 2 + o
            return o

        def t_doublex_tensory(x: float, y: torch.Tensor):
            o = operation(x, y)
            o = 2 + o
            return o

        def t_cdoublex_tensory(x: complex, y: torch.Tensor):
            o = operation(x, y)
            o = 2 + o
            return o

        # Omit both scalar cases and swap cases
        assert category1 == "scalar" and category2 != "scalar"
        if dtype_arg1.is_floating_point:
            return t_doublex_tensory
        if dtype_arg1 == torch.int64 or dtype_arg1 == torch.int32:
            return t_intx_tensory
        if dtype_arg1.is_complex or dtype_arg1 == torch.int32:
            return t_cdoublex_tensory
        raise NotImplementedError

    def _binary_test_helper(self, operation, dtypes, random_data, categories="ndim"):
        if isinstance(dtypes, tuple):
            dtype_arg1, dtype_arg2 = dtypes
        else:
            dtype_arg1 = dtype_arg2 = dtypes

        if isinstance(categories, tuple) and random_data:
            category1, category2 = categories
        elif not random_data:
            category1 = category2 = "ndim"
        else:
            category1 = category2 = categories

        def is_cpu_category(x):
            return x == "0dimcpu" or x == "scalar"

        # skip unsupported cases
        if is_cpu_category(category1) and is_cpu_category(category2):
            return

        # only test cases with first operand as scalar
        if category2 == "scalar":
            return

        # skip ops that don't support scalar inputs in eager
        if operation in [
            torch.atan2,
            torch.max,
            torch.min,
            torch.remainder,  # unsupported in nvfuser
        ]:
            if category1 == "scalar" or category2 == "scalar":
                return

        if operation in [
            torch.fmod,
            torch.eq,
            torch.ne,
            torch.ge,
            torch.gt,
            torch.le,
            torch.lt
        ]:
            if category1 == "scalar":
                return

        # operators that do not support bfloat16
        if operation in [torch.fmod]:
            if dtype_arg1 == torch.bfloat16 or dtype_arg2 == torch.bfloat16:
                return

        def t(x: torch.Tensor, y: torch.Tensor, z: torch.Tensor):
            o = operation(x, y)
            o = o + z
            return o

        shape = (4, 32, 32)
        shapex = shape if category1 == "ndim" else ()
        shapey = shape if category2 == "ndim" else ()

        if random_data:
            x = (torch.randn(shapex, dtype=torch.float, device="cuda") * 5).to(dtype_arg1)
            y = (torch.randn(shapey, dtype=torch.float, device="cuda") * 5).to(dtype_arg2)
        else:
            x = self.special_values.to(dtype=dtype_arg1)
            y = (torch.rand_like(self.special_values) * 5).to(dtype_arg2)

        r"""
        Category conversion
        """
        has_scalar = False
        if category1 == "scalar":
            has_scalar = True
            x = x.item()

        if category1 == "0dimcpu":
            x = x.to(device="cpu")

        if category2 == "scalar":
            has_scalar = True
            y = y.item()

        if category2 == "0dimcpu":
            y = y.to(device="cpu")

        z = torch.tensor([2], device="cuda").to(dtype_arg1)
        is_dtype_arg1_int = dtype_arg1
== torch.int32 or dtype_arg1 == torch.int64 is_dtype_arg2_int = dtype_arg2 == torch.int32 or dtype_arg2 == torch.int64 if operation in [torch.pow]: if is_dtype_arg1_int and is_dtype_arg2_int: if category2 == "scalar": # RuntimeError: Integers to negative integer powers are not allowed y = abs(y) if category2 == "0dimcpu" and y == -1: # https://github.com/pytorch/pytorch/issues/73196 y = y - 1 if category2 == "0dimcpu" and y == -2: # avoid pow(0, -2), which gives inconsistent results on integer tensor y = y - 1 # Avoid division by zero for integer tensors div_like = [torch.div, torch.fmod, torch.remainder] if operation in div_like and (dtype_arg2 == torch.int32 or dtype_arg2 == torch.int64): y[y == 0] = 1 test_value = True if dtype_arg1 == torch.half or dtype_arg2 == torch.half: test_value = False if dtype_arg1 == torch.bfloat16 or dtype_arg2 == torch.bfloat16: test_value = False try: if not has_scalar: o = t(x, y, z) t_jit = torch.jit.script(t) jit_o = t_jit(x, y, z) jit_o = t_jit(x, y, z) jit_o = t_jit(x, y, z) self.assertEqual(o.dtype, jit_o.dtype) if test_value: self.assertEqual(o, jit_o) self.assertGraphContains(t_jit.graph_for(x, y, z), FUSION_GUARD) elif category2 != "scalar": # only test the case where first is scalar test_fn = self._get_scalar_binary_test_fn((category1, dtype_arg1), (category2, dtype_arg2), operation) o = test_fn(x, y) t_jit = torch.jit.script(test_fn) jit_o = t_jit(x, y) jit_o = t_jit(x, y) jit_o = t_jit(x, y) self.assertEqual(o.dtype, jit_o.dtype) if test_value: self.assertEqual(o, jit_o) self.assertGraphContains(t_jit.graph_for(x, y), FUSION_GUARD) except Exception as e: print("failing test for op: ", operation.__name__) print("with input\n\tx: ", x) print("\ty: ", y) print("\tz: ", z) raise e @unittest.skipIf(not RUN_NVFUSER, "requires CUDA") @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING, "Requires fusion optimization pass to be effective") def test_binary_ops(self): data_types = [ torch.int32, torch.int64, torch.float16, torch.float32, torch.float64, ] if TEST_BF16: data_types.append(torch.bfloat16) operations = [torch.mul, torch.div, torch.atan2, torch.max, torch.min, torch.pow, torch.remainder, torch.fmod, torch.eq, torch.ne, torch.ge, torch.gt, torch.le, torch.lt] category_types = [ "scalar", "0dim", "0dimcpu", "ndim" ] binary_dtype_combinations = list(itertools.combinations(data_types, 2)) category_combinations = list(itertools.combinations(category_types, 2)) for op, dtypes, categories in itertools.product(operations, binary_dtype_combinations, category_combinations): self._binary_test_helper(op, dtypes, True, categories) # random data for op, dtypes in itertools.product(operations, binary_dtype_combinations): self._binary_test_helper(op, dtypes, False) # special numbers # TODO: revert this @unittest.skipIf(True, "see issue https://github.com/csarofeen/pytorch/issues/1730") @unittest.skipIf(not RUN_NVFUSER, "requires CUDA") @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING, "Requires fusion optimization pass to be effective") def test_binary_ops_complex(self): data_types = [torch.cfloat, torch.cdouble] operations = [torch.mul, torch.div, torch.pow, torch.eq, torch.ne] category_types = [ "scalar", "0dim", "0dimcpu", "ndim" ] binary_dtype_combinations = list(itertools.combinations(data_types, 2)) category_combinations = list(itertools.combinations(category_types, 2)) for op, dtypes, categories in itertools.product(operations, binary_dtype_combinations, category_combinations): self._binary_test_helper(op, dtypes, True, categories) # 
random data for op, dtypes in itertools.product(operations, binary_dtype_combinations): self._binary_test_helper(op, dtypes, False) # special numbers @unittest.skipIf(not RUN_NVFUSER, "requires CUDA") @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING, "Requires fusion optimization pass to be effective") def test_binary_bitwise(self): dtypes = [torch.bool, torch.int32, torch.int64] for dtype1, dtype2, dtype3 in itertools.product(dtypes, repeat=3): def jit_and(x: torch.Tensor, y: torch.Tensor, z: torch.Tensor): return torch.bitwise_and(x, y) & z def jit_or(x: torch.Tensor, y: torch.Tensor, z: torch.Tensor): return torch.bitwise_or(x, y) | z def jit_xor(x: torch.Tensor, y: torch.Tensor, z: torch.Tensor): return torch.bitwise_xor(x, y) ^ z def jit_lshift(x: torch.Tensor, y: torch.Tensor, z: torch.Tensor): return torch.bitwise_left_shift(x, y) << z def jit_rshift(x: torch.Tensor, y: torch.Tensor, z: torch.Tensor): return torch.bitwise_right_shift(x, y) >> z for jit_func in [jit_and, jit_or, jit_xor, jit_lshift, jit_rshift]: if torch.bool in {dtype1, dtype2, dtype3} and jit_func in {jit_lshift, jit_rshift}: continue x = torch.randn(4, 8, 32, 32, dtype=torch.float, device="cuda").mul(5).to(dtype1) y = torch.randn(4, 8, 32, 32, dtype=torch.float, device="cuda").mul(5).to(dtype2) z = torch.randn(4, 8, 32, 32, dtype=torch.float, device="cuda").mul(2).to(dtype3) jitted = torch.jit.script(jit_func) jit_o = jitted(x, y, z) jit_o = jitted(x, y, z) o = jit_func(x, y, z) self.assertEqual(o, jit_o) self.assertGraphContains(jitted.graph_for(x, y, z), FUSION_GUARD) @unittest.skipIf(not RUN_NVFUSER, "requires CUDA") @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING, "Requires fusion optimization pass to be effective") def test_type_as_op(self): def t(x: torch.Tensor, y: torch.Tensor, z: float): o = torch.lt(x, z) o = o.type_as(y) return o t_jit = torch.jit.script(t) x = torch.randn(4, 8, 32, 32, dtype=torch.float, device="cuda") y = torch.randn(4, 8, 32, 32, dtype=torch.float, device="cuda") jit_o = t_jit(x, y, 0.5) jit_o = t_jit(x, y, 0.5) o = t(x, y, 0.5) self.assertEqual(o, jit_o) self.assertGraphContains(t_jit.graph_for(x, y, 0.5), FUSION_GUARD) def _ternary_integer_test_helper(self, dtype_arg1): shape = (4, 8, 32, 32) magnitude = 100 if (dtype_arg1 in self.int_types): x = torch.randint(-magnitude, magnitude, shape, dtype=dtype_arg1, device="cuda") else: x = torch.randn(shape, dtype=dtype_arg1, device="cuda") * magnitude arg2 = int(0) arg3 = int(magnitude * 0.1) def clamp0(x: torch.Tensor, f: int): o = 2. * torch.clamp(x, min=f) return o clamp0_jit = torch.jit.script(clamp0) self._run_helper(clamp0_jit, clamp0, x, arg2) def clamp1(x: torch.Tensor, f: int, ff: int): o = 2. * torch.clamp(x, min=f, max=ff) return o clamp1_jit = torch.jit.script(clamp1) self._run_helper(clamp1_jit, clamp1, x, arg2, arg3) def clamp2(x: torch.Tensor, f: float, ff: int): o = 2. * torch.clamp(x, min=f, max=ff) return o clamp2_jit = torch.jit.script(clamp2) self._run_helper(clamp2_jit, clamp2, x, float(arg2), arg3) def clamp3(x: torch.Tensor, f: int, ff: float): o = 2. * torch.clamp(x, min=f, max=ff) return o clamp3_jit = torch.jit.script(clamp3) self._run_helper(clamp3_jit, clamp3, x, arg2, float(arg3)) def threshold(x: torch.Tensor, th: int, val: int): o = 2. 
* torch.threshold(x, th, val) return o threshold_jit = torch.jit.script(threshold) self._run_helper(threshold_jit, threshold, x, arg2, arg3) @unittest.skipIf(not RUN_NVFUSER, "requires CUDA") @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING, "Requires fusion optimization pass to be effective") def test_ternary_ops_integer_compatibility(self): data_types = [ torch.float16, torch.float32, torch.float64 ] for dtype in data_types: self._ternary_integer_test_helper(dtype) def _ternary_test_helper(self, operation, dtypes, random_data): if isinstance(dtypes, tuple): dtype_arg1, dtype_arg2, dtype_arg3 = dtypes else: dtype_arg1 = dtype_arg2 = dtype_arg3 = dtypes def t(x: torch.Tensor, y: torch.Tensor, z: torch.Tensor, alpha: torch.Tensor): o = operation(x, y, z) o = o + alpha return o shape = (4, 32, 32) if operation is torch.where: dtype_arg1 = torch.bool if random_data: x = torch.randint(0, 2, shape).to(dtype=torch.bool, device="cuda") y = (torch.randn(shape, dtype=torch.float, device="cuda") * 5).to(dtype_arg2) z = (torch.randn(shape, dtype=torch.float, device="cuda") * 5).to(dtype_arg3) else: x = torch.randint(0, 2, self.special_values.size()).to(dtype=torch.bool, device="cuda") y = self.special_values.to(dtype=dtype_arg2) z = (torch.rand_like(self.special_values) * 5).to(dtype_arg3) elif random_data: x = (torch.randn(shape, dtype=torch.float, device="cuda") * 5).to(dtype_arg1) y = (torch.randn(shape, dtype=torch.float, device="cuda") * 5).to(dtype_arg2) z = (torch.randn(shape, dtype=torch.float, device="cuda") * 5).to(dtype_arg3) else: x = self.special_values.to(dtype=dtype_arg1) y = (torch.rand_like(self.special_values) * 5).to(dtype_arg2) z = (torch.rand_like(self.special_values) * 5).to(dtype_arg3) alpha = torch.tensor([2], device="cuda").to(dtype_arg1) o = t(x, y, z, alpha) t_jit = torch.jit.script(t) jit_o = t_jit(x, y, z, alpha) jit_o = t_jit(x, y, z, alpha) self.assertEqual(o.dtype, jit_o.dtype) self.assertEqual(o, jit_o) self.assertGraphContains(t_jit.graph_for(x, y, z), FUSION_GUARD) @unittest.skipIf(not RUN_NVFUSER, "requires CUDA") @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING, "Requires fusion optimization pass to be effective") def test_ternary_ops_type_promotion(self): # TODO: update accuracy tolerance for bf16 / fp16 data types data_types = [ # torch.float16, torch.float32, torch.float64 ] ''' if TEST_BF16: data_types.append(torch.bfloat16) ''' # TODO: Add Tensor support for clamp operations = [torch.clamp] ternary_dtype_combinations = itertools.combinations(data_types, 3) for op, dtypes in itertools.product(operations, ternary_dtype_combinations): self._ternary_test_helper(op, dtypes, True) # random data self._ternary_test_helper(op, dtypes, False) # special numbers # We can't test the scalar version of rsub from python @unittest.skipIf(not RUN_NVFUSER, "requires CUDA") @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING, "Requires fusion optimization pass to be effective") def test_rsub(self): x = torch.randn(4, 8, 32, 32, dtype=torch.float, device="cuda") y = torch.randn(4, 8, 32, 32, dtype=torch.float, device="cuda") def rsub(x: torch.Tensor, y: torch.Tensor): o = torch.rsub(x, y) o = o * 2. 
return o rsub_jit = torch.jit.script(rsub) self._run_helper(rsub_jit, rsub, x, y) @unittest.skipIf(not RUN_NVFUSER, "requires CUDA") # legacy fuser does not work for rand_like, see issue #34361 @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING, "Requires fusion optimization pass to be effective") def test_ternary_ops(self): x = torch.randn(4, 8, 32, 32, dtype=torch.float, device="cuda") y = torch.randn(4, 8, 32, 32, dtype=torch.float, device="cuda") z = torch.randn(4, 8, 32, 32, dtype=torch.float, device="cuda") cond = torch.randint(0, 2, (4, 8, 32, 32)).to(dtype=torch.bool, device="cuda") def add(x: torch.Tensor, other: torch.Tensor, alpha: float): o = torch.relu(x) o = torch.add(o, other=other, alpha=alpha) return o add_jit = torch.jit.script(add) self._run_helper(add_jit, add, x, y, 2.0) def clamp0(x: torch.Tensor, f: float): o = 2. * torch.clamp(x, min=f) return o clamp0_jit = torch.jit.script(clamp0) self._run_helper(clamp0_jit, clamp0, x, 0.5) def clamp1(x: torch.Tensor, f: float, ff: float): o = 2. * torch.clamp(x, min=f, max=ff) return o clamp1_jit = torch.jit.script(clamp1) self._run_helper(clamp1_jit, clamp1, x, -0.2, 0.7) def threshold(x: torch.Tensor, th: float, val: float): o = 2. * torch.threshold(x, th, val) return o threshold_jit = torch.jit.script(threshold) self._run_helper(threshold_jit, threshold, x, 0.2, 0.9) def where(x: torch.Tensor, y: torch.Tensor, cond: torch.Tensor): o = 2. * torch.where(cond, x, y) return o where_jit = torch.jit.script(where) self._run_helper(where_jit, where, x, y, cond) def lerp(x: torch.Tensor, y: torch.Tensor, z: torch.Tensor): o = 2. * torch.lerp(x, y, z) return o lerp_jit = torch.jit.script(lerp) self._run_helper(lerp_jit, lerp, x, y, z) def lerp_scale(x: torch.Tensor, y: torch.Tensor, z: float): o = 2. 
* torch.lerp(x, y, z) return o lerp_scale_jit = torch.jit.script(lerp_scale) self._run_helper(lerp_scale_jit, lerp_scale, x, y, 0.5) @unittest.skipIf(not RUN_NVFUSER, "requires CUDA") @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING, "Requires profiling node to run cuda fuser") def test_addcmul_ops(self): x = torch.randn(4, 8, 32, 32, dtype=torch.float, device="cuda") y = torch.randn(4, 8, 32, 32, dtype=torch.float, device="cuda") z = torch.randn(4, 8, 32, 32, dtype=torch.float, device="cuda") def addcmul(x: torch.Tensor, y: torch.Tensor, z: torch.Tensor, value: float): o = torch.add(x, 0.5) o = torch.addcmul(o, y, z, value=value) return o addcmul_jit = torch.jit.script(addcmul) self._run_helper(addcmul_jit, addcmul, x, y, z, 2.0) def addcmul_no_alpha(x: torch.Tensor, y: torch.Tensor, z: torch.Tensor): o = torch.add(x, 0.5) o = torch.addcmul(o, y, z) return o addcmul_no_alpha_jit = torch.jit.script(addcmul_no_alpha) self._run_helper(addcmul_no_alpha_jit, addcmul_no_alpha, x, y, z) def addcmul_const_alpha(x: torch.Tensor, y: torch.Tensor, z: torch.Tensor): o = torch.add(x, 0.5) o = torch.addcmul(o, y, z, value=0.75) return o addcmul_const_alpha_jit = torch.jit.script(addcmul_const_alpha) self._run_helper(addcmul_const_alpha_jit, addcmul_const_alpha, x, y, z) @unittest.skipIf(not RUN_NVFUSER, "requires CUDA") @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING, "Requires fusion optimization pass to be effective") def test_dynamic_size(self): old_guard = torch._C._jit_set_nvfuser_guard_mode(True) torch._C._jit_set_bailout_depth(20) def t(x: torch.Tensor, y: torch.Tensor, z: float): o = x + y o = o + z return o t_jit = torch.jit.script(t) x = torch.randn(4, 8, 32, 32, dtype=torch.float, device="cuda") y = torch.randn(32, 32, dtype=torch.float, device="cuda") jit_o = t_jit(x, y, 2.0) jit_o = t_jit(x, y, 2.0) o = t(x, y, 2.0) self.assertEqual(o, jit_o) subgraph = self._getSubgraphInFusion(t_jit.graph_for(x, y, 2.0)) self.assertGraphContainsExactly(subgraph, 'aten::add', 2, consider_subgraphs=False) # this test is not ideal, as we rely on the bailout to test it and we # don't know a way to verify the bailout graph to validate the proper # fusion. 
x = torch.randn(8, 32, 16, 8, dtype=torch.float, device="cuda") y = torch.randn(16, 8, dtype=torch.float, device="cuda") jit_o = t_jit(x, y, 2.0) jit_o = t_jit(x, y, 2.0) o = t(x, y, 2.0) self.assertEqual(o, jit_o) self.assertGraphContains(t_jit.graph_for(x, y, 2.0), FUSION_GUARD) x = torch.randn(8, 17, 8, dtype=torch.float, device="cuda") y = torch.randn(8, 17, 1, dtype=torch.float, device="cuda") jit_o = t_jit(x, y, 2.0) jit_o = t_jit(x, y, 2.0) o = t(x, y, 2.0) self.assertEqual(o, jit_o) self.assertGraphContains(t_jit.graph_for(x, y, 2.0), FUSION_GUARD) torch._C._jit_set_nvfuser_guard_mode(old_guard) @unittest.skipIf(not RUN_NVFUSER, "requires CUDA") @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING, "Requires fusion optimization pass to be effective") def test_random_topo(self): os.environ["PYTORCH_NVFUSER_DISABLE_FALLBACK"] = "1" self.assertTrue(runDefaultTestWithSeed(28449)) def _compare(self, desc, inp1, inp2, error): a = inp1.clone() b = inp2.clone() close = torch.allclose(a, b, rtol=error, atol=error, equal_nan=True) if not close: print(desc, close) z = a - b index = (torch.abs(z) >= error + error * torch.abs(b)).nonzero() print("dif : ", z[index]) print("inp1 : ", a[index]) print("inp2 : ", b[index]) print("maximum difference", z[index].max()) return close # Permutation helper that applies binary operation between two tensors: # 1. applies separate permutation `perm0` & `perm1` to two inputs # 2. reduce dimension `broadcast_axis` of operand two to size 1 # The purpose of this test is to ensure permutation works well in # complicated cases with arbitrary stride order and broadcasting dimensions def _permutation_helper(self, sizes, broadcast_axis, dtype, device, perm0, perm1): def t(x: torch.Tensor, y: torch.Tensor): o = torch.add(x, y) o = torch.relu(o) return o x = torch.randn([sizes[i] for i in perm0], dtype=dtype, device=device).permute( [perm0.index(i) for i in range(len(sizes))]) if broadcast_axis >= 0: sizes[broadcast_axis] = 1 y = torch.randn([sizes[i] for i in perm1], dtype=dtype, device=device).permute( [perm1.index(i) for i in range(len(sizes))]) t_jit = torch.jit.script(t) jit_o = t_jit(x, y) jit_o = t_jit(x, y) o = t(x, y) self.assertEqual(o.dtype, jit_o.dtype) self.assertEqual(o, jit_o) self.assertEqual(o.stride(), jit_o.stride()) self.assertGraphContains(t_jit.graph_for(x, y), FUSION_GUARD) # end-2-end test of permutation & contiguity handling in integration. 
# we are testing inputs with all combination of permutation order, just to # ensure that integration would be able to generate functionally correct # kernels @unittest.skipIf(not RUN_NVFUSER, "requires CUDA") @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING, "Requires fusion optimization pass to be effective") def test_binary_ops_permutation(self): # note that num_dim is exclusive from len(x), so we are not reducing # to single element (codegen limitation at this moment) x = [7, 8, 12] b_axes = range(-1, len(x)) for b_axis in b_axes: for perm0 in itertools.permutations(range(len(x))): for perm1 in itertools.permutations(range(len(x))): x = [7, 8, 12] self._permutation_helper(x, b_axis, torch.float32, "cuda", perm0, perm1) @unittest.skipIf(not RUN_NVFUSER, "requires CUDA") @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING, "Requires fusion optimization pass to be effective") def test_binary_ops_channels_last_with_bcast(self): device = "cuda" x = torch.randn([4, 3, 2, 5], device=device).to(memory_format=torch.channels_last) w = torch.randn([2, 5], device=device) def t(x: torch.Tensor, b: torch.Tensor): o = x + b return torch.relu(o) t_jit = torch.jit.script(t) jit_o = t_jit(x, w) jit_o = t_jit(x, w) jit_o = t_jit(x, w) o = t(x, w) self.assertEqual(o.dtype, jit_o.dtype) self.assertTrue(self._compare("comparing output failed", o, jit_o, 1e-4)) self.assertGraphContains(t_jit.graph_for(x, w), FUSION_GUARD) def _reduction_helper(self, sizes, reduction_axis, dtype, device, perm0, perm1, keepdim=False): class MyReduction(torch.nn.Module): __constants__ = ['reduction_axis', 'keepdim'] def __init__(self): super(MyReduction, self).__init__() self.reduction_axis = reduction_axis self.keepdim = keepdim def forward(self, x: torch.Tensor, y: torch.Tensor): o = torch.add(x, y) o = torch.sum(o, dim=self.reduction_axis, keepdim=self.keepdim) return o t = MyReduction() x = torch.randn([sizes[i] for i in perm0], dtype=dtype, device=device).permute( [perm0.index(i) for i in range(len(sizes))]) y = torch.randn([sizes[i] for i in perm1], dtype=dtype, device=device).permute( [perm1.index(i) for i in range(len(sizes))]) t_jit = torch.jit.script(t) jit_o = t_jit(x, y) jit_o = t_jit(x, y) o = t(x, y) self.assertEqual(o.dtype, jit_o.dtype) # numerical issues here due to our scheduling. 
# can't use `self.assertEqual(o, jit_o)` self.assertTrue(self._compare("comparing output failed", o, jit_o, 1e-4)) self.assertGraphContains(t_jit.graph_for(x, y), FUSION_GUARD) @unittest.skipIf(is_pre_volta(), "reduction not supported in pre volta device") @unittest.skipIf(not RUN_NVFUSER, "requires CUDA") @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING, "Requires fusion optimization pass to be effective") def test_reduction(self): for x in ([7, 8, 12], [12, 8, 7, 9, 15], [128, 16, 8, 32]): # note that num_dim is exclusive from len(x), so we are not reducing # to single element (codegen limitation at this moment) for num_reduce_dim in range(1, len(x)): for axes in itertools.combinations(range(len(x)), num_reduce_dim): for keepdim in (True, False): perm0 = range(len(x)) perm1 = range(len(x)) self._reduction_helper(x, axes, torch.float32, "cuda", perm0, perm1, keepdim) def _layer_norm_autodiff_helper(self, model, grad, shapes, args): jit_model = torch.jit.script(model) eps = np.random.random() * 1e-4 use_cudnn = bool(np.random.randint(0, 2)) # profile/optimization runs for i in range(3): jit_o = jit_model(shapes, *args, eps, use_cudnn) jit_o.backward(grad) ref_args = [t.detach().clone().requires_grad_() for t in args] [t.grad.zero_() for t in args] jit_o = jit_model(shapes, *args, eps, use_cudnn) jit_o.backward(grad) o = model(shapes, *ref_args, eps, use_cudnn) o.backward(grad) self.assertEqual(jit_o, o) for arg, ref_arg in zip(args, ref_args): self.assertEqual(arg.grad, ref_arg.grad) # check fusion in fw & bw g = jit_model.graph_for(shapes, *args, eps, use_cudnn) for node in g.nodes(): n = node dbg_state = jit_model.get_debug_state() for val in dbg_state.execution_plans.values(): v = val state2 = v.code.grad_executor_states() for val in state2[0].execution_plans.values(): v2 = val FileCheck().check(FUSION_GUARD).run(g) FileCheck().check(FUSION_GUARD).run(v2.graph) @unittest.skipIf(is_pre_volta(), "reduction not supported in pre volta device") @unittest.skipIf(not RUN_NVFUSER, "requires CUDA") @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING, "Requires fusion optimization pass to be effective") def test_layer_norm_autodiff(self): def t_wb(shapes: List[int], x, w, b, eps: float, cudnn: bool): o = torch.layer_norm(x, shapes, w, b, eps, cudnn) o = torch.relu(o) return o def t_w(shapes: List[int], x, w, eps: float, cudnn: bool): o = torch.layer_norm(x, shapes, w, None, eps, cudnn) o = torch.relu(o) return o def t_b(shapes: List[int], x, b, eps: float, cudnn: bool): o = torch.layer_norm(x, shapes, None, b, eps, cudnn) o = torch.relu(o) return o def t(shapes: List[int], x, eps: float, cudnn: bool): o = torch.layer_norm(x, shapes, None, None, eps, cudnn) o = torch.relu(o) return o model = {3: t_wb, 2: t_w, 1: t_b, 0: t} for w, b in itertools.product([True, False], repeat=2): batch = [2] # note: awkward shape here to avoid vectorized fast kernel, which is # buggy in aten shapes = [2, 7, 3] m = model[w * 2 + b] grad = torch.randn(batch + shapes, dtype=torch.float32, device="cuda") args = [torch.randn(batch + shapes, dtype=torch.float32, device="cuda").requires_grad_()] if w: args.append(torch.randn(shapes, dtype=torch.float32, device="cuda").requires_grad_()) if b: args.append(torch.randn(shapes, dtype=torch.float32, device="cuda").requires_grad_()) self._layer_norm_autodiff_helper(m, grad, shapes, args) @unittest.skipIf(is_pre_volta(), "reduction not supported in pre volta device") @unittest.skipIf(not RUN_NVFUSER, "requires CUDA") @unittest.skipIf(GRAPH_EXECUTOR != 
ProfilingMode.PROFILING, "Requires fusion optimization pass to be effective") def test_layer_norm_parser(self): dtype = torch.float32 device = "cuda" x = torch.randn([4, 4, 2], dtype=dtype, device=device) w = torch.randn([4, 2], dtype=dtype, device=device) b = torch.randn([4, 2], dtype=dtype, device=device) def t(x: torch.Tensor, w: torch.Tensor, b: torch.Tensor): o = torch.relu(x) o = torch.layer_norm(o, [4, 2], w, b, 1e-5) return o o = t(x, w, b) t_jit = torch.jit.script(t) jit_o = t_jit(x, w, b) jit_o = t_jit(x, w, b) o = t(x, w, b) self.assertGraphContains(t_jit.graph_for(x, w, b), FUSION_GUARD) def _native_layer_norm_helper(self, shape, norm_shape, dtype, device, error, affine=True): class MyLayerNorm(torch.nn.Module): __constants__ = ['norm_shape'] def __init__(self, elementwise_affine=True): super(MyLayerNorm, self).__init__() self.norm_shape = norm_shape if elementwise_affine: self.weight = torch.randn(norm_shape, dtype=dtype, device=device) self.bias = torch.randn(norm_shape, dtype=dtype, device=device) with torch.no_grad(): self.weight.fill_(1) self.bias.fill_(0) else: self.weight = None self.bias = None def forward(self, x: torch.Tensor): o = torch.relu(x) o = torch.native_layer_norm(o, self.norm_shape, self.weight, self.bias, 1e-5) return o t = MyLayerNorm(affine) x = torch.randn(shape, dtype=dtype, device=device) t_jit = torch.jit.script(t) jit_o, jit_mean, jit_rstd = t_jit(x) jit_o, jit_mean, jit_rstd = t_jit(x) o, mean, rstd = t(x) self.assertEqual(o.dtype, jit_o.dtype) # numerical issues here due to our scheduling. # can't use `self.assertEqual(o, jit_o)` self.assertTrue(self._compare("comparing output failed", o, jit_o, error)) self.assertTrue(self._compare("comparing mean failed", mean, jit_mean, error)) self.assertTrue(self._compare("comparing rstd failed", rstd, jit_rstd, error)) self.assertGraphContains(t_jit.graph_for(x), FUSION_GUARD) @unittest.skipIf(is_pre_volta(), "reduction not supported in pre volta device") @unittest.skipIf(not RUN_NVFUSER, "requires CUDA") @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING, "Requires fusion optimization pass to be effective") def test_native_layer_norm(self): dims = 4 rnds = 3 for idx in range(rnds): for offset in range(1, dims): for affine in (True, False): input_shape = [random.randint(10, 30) for idx in range(dims)] norm_shape = [input_shape[idx] for idx in range(dims - offset, dims)] self._native_layer_norm_helper(input_shape, norm_shape, torch.float32, "cuda", 1e-4, affine) @unittest.skipIf(is_pre_volta(), "reduction not supported in pre volta device") @unittest.skipIf(not RUN_NVFUSER, "requires CUDA") @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING, "Requires fusion optimization pass to be effective") def test_native_layer_norm_half(self): dims = 4 rnds = 3 for idx in range(rnds): for offset in range(1, dims): input_shape = [random.randint(10, 30) for idx in range(dims)] norm_shape = [input_shape[idx] for idx in range(dims - offset, dims)] self._native_layer_norm_helper(input_shape, norm_shape, torch.float16, "cuda", 5e-3) @unittest.skipIf(is_pre_volta(), "reduction not supported in pre volta device") @unittest.skipIf(not RUN_NVFUSER, "requires CUDA") @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING, "Requires fusion optimization pass to be effective") @unittest.skipIf(not TEST_BF16, "device does not support BFloat16") def test_native_layer_norm_bfloat(self): dims = 4 rnds = 3 for idx in range(rnds): for offset in range(1, dims): input_shape = [random.randint(10, 30) for idx in range(dims)] 
norm_shape = [input_shape[idx] for idx in range(dims - offset, dims)] self._native_layer_norm_helper(input_shape, norm_shape, torch.bfloat16, "cuda", 1e-1) def _norm_helper(self, shape, dtype, device, error, is_batch_norm_else_instance_norm, memory_format=torch.contiguous_format, *, layer_dtype=torch.float32): class MyBatchNorm(torch.nn.Module): def __init__(self): super(MyBatchNorm, self).__init__() def forward(self, x: torch.Tensor, r_mean: torch.Tensor, r_var: torch.Tensor): o = torch.nn.functional.batch_norm(x, r_mean, r_var, training=True) o = torch.relu(o) return o class MyInstanceNorm(torch.nn.Module): def __init__(self): super(MyInstanceNorm, self).__init__() def forward(self, x: torch.Tensor, r_mean: torch.Tensor, r_var: torch.Tensor): o = torch.nn.functional.instance_norm(x, r_mean, r_var, use_input_stats=True) o = torch.relu(o) return o t = MyBatchNorm() if is_batch_norm_else_instance_norm else MyInstanceNorm() x = torch.randn(shape, dtype=dtype, device=device).to(memory_format=memory_format) running_mean = torch.zeros(shape[1], dtype=layer_dtype, device=device) running_var = torch.ones(shape[1], dtype=layer_dtype, device=device) t_jit = torch.jit.script(t) eager_running_mean = running_mean.clone() eager_running_var = running_var.clone() jit_running_mean = running_mean.clone() jit_running_var = running_var.clone() jit_o = t_jit(x, running_mean.clone(), running_var.clone()) self.assertTrue(self._compare("prerun comparing running_mean failed", eager_running_mean, jit_running_mean, error)) self.assertTrue(self._compare("prerun comparing running_var failed", eager_running_var, jit_running_var, error)) jit_o = t_jit(x, jit_running_mean, jit_running_var) o = t(x, eager_running_mean, eager_running_var) self.assertEqual(o.dtype, jit_o.dtype) self.assertEqual(o.stride(), jit_o.stride()) # numerical issues here due to our scheduling. 
# can't use `self.assertEqual(o, jit_o)` self.assertTrue(self._compare("comparing output failed", o, jit_o, error)) self.assertTrue(self._compare("comparing running_mean failed", eager_running_mean, jit_running_mean, error)) self.assertTrue(self._compare("comparing running_var failed", eager_running_var, jit_running_var, error)) self.assertGraphContains(t_jit.graph_for(x, running_mean, running_var), FUSION_GUARD) @unittest.skipIf(is_pre_volta(), "reduction not supported in pre volta device") @unittest.skipIf(not RUN_NVFUSER, "requires CUDA") @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING, "Requires fusion optimization pass to be effective") def test_layer_norm_trivial_reduce_dim(self): def t_wb(shapes: List[int], x, w, b, eps: float, cudnn: bool): o = torch.layer_norm(x, shapes, w, b, eps, cudnn) o = torch.relu(o) return o batch = [1] shapes = [2, 7, 3] grad = torch.randn(batch + shapes, dtype=torch.float32, device="cuda") args = [torch.randn(batch + shapes, dtype=torch.float32, device="cuda").requires_grad_()] args.append(torch.randn(shapes, dtype=torch.float32, device="cuda").requires_grad_()) args.append(torch.randn(shapes, dtype=torch.float32, device="cuda").requires_grad_()) self._layer_norm_autodiff_helper(t_wb, grad, shapes, args) @unittest.skipIf(is_pre_volta(), "reduction not supported in pre volta device") @unittest.skipIf(not RUN_NVFUSER, "requires CUDA") @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING, "Requires fusion optimization pass to be effective") def test_norm_half_layer(self): size = [2, 4, 2, 2] for is_batch_norm_else_instance_norm in [False, True]: for mf in [torch.channels_last, torch.contiguous_format]: self._norm_helper(size, torch.float16, "cuda", 1e-3, is_batch_norm_else_instance_norm, memory_format=mf, layer_dtype=torch.float16) @unittest.skipIf(is_pre_volta(), "reduction not supported in pre volta device") @unittest.skipIf(not RUN_NVFUSER, "requires CUDA") @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING, "Requires fusion optimization pass to be effective") def test_norm_channels_last(self): size = [3, 4, 5, 6] with torch.backends.cudnn.flags(enabled=False): for is_batch_norm_else_instance_norm in [False, True]: for mf in [torch.channels_last, torch.contiguous_format]: self._norm_helper(size, torch.float32, "cuda", 1e-4, is_batch_norm_else_instance_norm, memory_format=mf) @unittest.skipIf(is_pre_volta(), "reduction not supported in pre volta device") @unittest.skipIf(not RUN_NVFUSER, "requires CUDA") @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING, "Requires fusion optimization pass to be effective") def test_norm(self): output_elements = 10000 channel_sizes = [67, 457, 1024, 4096] with torch.backends.cudnn.flags(enabled=False): for is_batch_norm_else_instance_norm in [False, True]: for dims in range(3, 6): output_size = int(pow(output_elements, 1. / (dims - 1))) for C in channel_sizes: x = [output_size for idx in range(dims)] x[1] = C self._norm_helper(x, torch.float32, "cuda", 1e-4, is_batch_norm_else_instance_norm) @unittest.skipIf(is_pre_volta(), "reduction not supported in pre volta device") @unittest.skipIf(not RUN_NVFUSER, "requires CUDA") @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING, "Requires fusion optimization pass to be effective") def test_norm_large(self): output_elements = 262144 channel_sizes = 67, 457, 1024 for is_batch_norm_else_instance_norm in [True, False]: for dims in range(3, 6): output_size = int(pow(output_elements, 1. 
/ (dims - 1))) for C in channel_sizes: x = [output_size for idx in range(dims)] x[1] = C self._norm_helper(x, torch.float32, "cuda", 1e-4, is_batch_norm_else_instance_norm) @unittest.skipIf(is_pre_volta(), "reduction not supported in pre volta device") @unittest.skipIf(not RUN_NVFUSER, "requires CUDA") @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING, "Requires fusion optimization pass to be effective") def test_norm_half(self): output_elements = 10000 channel_sizes = [67, 457, 1024, 4096] with torch.backends.cudnn.flags(enabled=False): for is_batch_norm_else_instance_norm in [False, True]: for dims in range(3, 6): output_size = int(pow(output_elements, 1. / (dims - 1))) for C in channel_sizes: x = [output_size for idx in range(dims)] x[1] = C self._norm_helper(x, torch.float16, "cuda", 5e-3, is_batch_norm_else_instance_norm) @unittest.skipIf(is_pre_volta(), "reduction not supported in pre volta device") @unittest.skipIf(not RUN_NVFUSER, "requires CUDA") @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING, "Requires fusion optimization pass to be effective") @unittest.skipIf(not TEST_BF16, "device does not support BFloat16") def test_norm_bfloat(self): output_elements = 10000 channel_sizes = [67, 457, 1024, 4096] with torch.backends.cudnn.flags(enabled=False): for is_batch_norm_else_instance_norm in [False, True]: for dims in range(3, 6): output_size = int(pow(output_elements, 1. / (dims - 1))) for C in channel_sizes: x = [output_size for idx in range(dims)] x[1] = C self._norm_helper(x, torch.bfloat16, "cuda", 1e-1, is_batch_norm_else_instance_norm) def _softmax_helper(self, shape, reduction_axis, is_log_softmax, dtype, device, error): class MySoftmax(torch.nn.Module): __constants__ = ['reduction_axis'] def __init__(self): super(MySoftmax, self).__init__() self.reduction_axis = reduction_axis def forward(self, x: torch.Tensor, y: torch.Tensor): o = torch.add(x, y) o = torch.nn.functional.softmax(o, dim=self.reduction_axis) return o class MyLogSoftmax(torch.nn.Module): __constants__ = ['reduction_axis'] def __init__(self): super(MyLogSoftmax, self).__init__() self.reduction_axis = reduction_axis def forward(self, x: torch.Tensor, y: torch.Tensor): o = torch.add(x, y) o = torch.nn.functional.log_softmax(o, dim=self.reduction_axis) return o gradient_check = (dtype == torch.float64) t = MyLogSoftmax() if is_log_softmax else MySoftmax() x = torch.randn(shape, dtype=dtype, device=device, requires_grad=gradient_check) y = torch.randn(shape, dtype=dtype, device=device, requires_grad=gradient_check) t_jit = torch.jit.script(t) jit_o = t_jit(x, y) jit_o = t_jit(x, y) jit_o = t_jit(x, y) if gradient_check: gradcheck(t_jit.forward, [x, y], nondet_tol=1e-5) else: o = t(x, y) self.assertEqual(o.dtype, jit_o.dtype) # numerical issues here due to our scheduling. 
# can't use `self.assertEqual(o, jit_o)` self.assertTrue(self._compare("comparing output failed", o, jit_o, error)) self.assertGraphContains(t_jit.graph_for(x, y), FUSION_GUARD) @unittest.skipIf(is_pre_volta(), "reduction not supported in pre volta device") @unittest.skipIf(not RUN_NVFUSER, "requires CUDA") @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING, "Requires fusion optimization pass to be effective") def test_softmax_dtype(self): def t(x: torch.Tensor, y: torch.Tensor): o = torch.mul(x, y) o = torch.nn.functional.softmax(o, dim=0, dtype=torch.float32) return o x = torch.randn([4, 4], dtype=torch.float16, device="cuda").requires_grad_() y = torch.randn_like(x).requires_grad_() grad = torch.randn_like(x).float() ref_x = x.detach().requires_grad_() ref_y = y.detach().requires_grad_() o = t(ref_x, ref_y) o.backward(grad) t_jit = torch.jit.script(t) jit_o = t_jit(x, y) jit_o.backward(grad) jit_o = t_jit(x, y) jit_o.backward(grad) jit_o = t_jit(x, y) jit_o.backward(grad) x.grad.zero_() y.grad.zero_() jit_o = t_jit(x, y) jit_o.backward(grad) self.assertEqual(o.dtype, jit_o.dtype) self.assertEqual(ref_x.grad, x.grad) self.assertEqual(ref_y.grad, y.grad) self.assertTrue(self._compare("comparing output failed", o, jit_o, 1e-3)) self.assertGraphContainsExactly(t_jit.graph_for(x, y), FUSION_GUARD, 1, consider_subgraphs=True) bwd_graph = list( list(t_jit.get_debug_state().execution_plans.values())[ 0].code.grad_executor_states()[0].execution_plans.values() )[0].graph FileCheck().check(FUSION_GUARD).run(bwd_graph) @unittest.skipIf(is_pre_volta(), "reduction not supported in pre volta device") @unittest.skipIf(not RUN_NVFUSER, "requires CUDA") @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING, "Requires fusion optimization pass to be effective") def test__softmax_function(self): def t(x: torch.Tensor, y: torch.Tensor): o = torch.mul(x, y) o = torch._softmax(o, dim=-1, half_to_float=False) return o x = torch.randn([4, 4], dtype=torch.float16, device="cuda") y = torch.randn_like(x) o = t(x, y) t_jit = torch.jit.script(t) jit_o = t_jit(x, y) jit_o = t_jit(x, y) jit_o = t_jit(x, y) self.assertEqual(o.dtype, jit_o.dtype) self.assertTrue(self._compare("comparing output failed", o, jit_o, 1e-3)) self.assertGraphContainsExactly(t_jit.graph_for(x, y), FUSION_GUARD, 1, consider_subgraphs=True) @unittest.skipIf(is_pre_volta(), "reduction not supported in pre volta device") @unittest.skipIf(not RUN_NVFUSER, "requires CUDA") @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING, "Requires fusion optimization pass to be effective") def test__softmax_function_half_to_float(self): def t(x: torch.Tensor, y: torch.Tensor): o = torch.mul(x, y) o = torch._softmax(o, dim=-1, half_to_float=True) return o x = torch.randn([4, 4], dtype=torch.float16, device="cuda") y = torch.randn_like(x) o = t(x, y) t_jit = torch.jit.script(t) jit_o = t_jit(x, y) jit_o = t_jit(x, y) jit_o = t_jit(x, y) self.assertEqual(o.dtype, jit_o.dtype) self.assertTrue(self._compare("comparing output failed", o, jit_o, 1e-3)) self.assertGraphContainsExactly(t_jit.graph_for(x, y), FUSION_GUARD, 1, consider_subgraphs=True) @unittest.skipIf(is_pre_volta(), "reduction not supported in pre volta device") @unittest.skipIf(not RUN_NVFUSER, "requires CUDA") @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING, "Requires fusion optimization pass to be effective") def test_softmax(self): output_size = 10000 dims = 4 output_size = int(pow(output_size, 1. 
/ dims)) reduction_sizes = [67, 256, 1024, 4096] # gradient check for reduction_dim in range(dims): for is_log_softmax in [False, True]: shape = [output_size for idx in range(dims)] self._softmax_helper(shape, reduction_dim, is_log_softmax, torch.float64, "cuda", 1e-4) for reduction_dim in range(dims): for reduction_size in reduction_sizes: x = [output_size for idx in range(dims)] x[reduction_dim] = reduction_size for is_log_softmax in [False, True]: self._softmax_helper(x, reduction_dim, is_log_softmax, torch.float32, "cuda", 1e-4) @unittest.skipIf(is_pre_volta(), "reduction not supported in pre volta device") @unittest.skipIf(not RUN_NVFUSER, "requires CUDA") @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING, "Requires fusion optimization pass to be effective") def test_softmax_half(self): output_size = 10000 dims = 4 output_size = int(pow(output_size, 1. / dims)) reduction_sizes = [67, 256, 1024, 4096] for reduction_dim in range(dims): for reduction_size in reduction_sizes: x = [output_size for idx in range(dims)] x[reduction_dim] = reduction_size for is_log_softmax in [False, True]: self._softmax_helper(x, reduction_dim, is_log_softmax, torch.float16, "cuda", 5e-3) @unittest.skipIf(is_pre_volta(), "reduction not supported in pre volta device") @unittest.skipIf(not RUN_NVFUSER, "requires CUDA") @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING, "Requires fusion optimization pass to be effective") @unittest.skipIf(not TEST_BF16, "device does not support BFloat16") def test_softmax_bfloat(self): output_size = 10000 dims = 4 output_size = int(pow(output_size, 1. / dims)) reduction_sizes = [67, 256, 1024, 4096] for reduction_dim in range(dims): for reduction_size in reduction_sizes: x = [output_size for idx in range(dims)] x[reduction_dim] = reduction_size for is_log_softmax in [False, True]: self._softmax_helper(x, reduction_dim, is_log_softmax, torch.bfloat16, "cuda", 1e-1) @unittest.skipIf(is_pre_volta(), "reduction not supported in pre volta device") @unittest.skipIf(not RUN_NVFUSER, "requires CUDA") @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING, "Requires fusion optimization pass to be effective") def test_reduction_permutation(self): x = [7, 8, 12] # note that num_dim is exclusive from len(x), so we are not reducing # to single element (codegen limitation at this moment) for num_reduce_dim in range(1, len(x)): for axes in itertools.combinations(range(len(x)), num_reduce_dim): for perm0 in itertools.permutations(range(len(x))): for perm1 in itertools.permutations(range(len(x))): self._reduction_helper(x, axes, torch.float32, "cuda", perm0, perm1) @unittest.skipIf(is_pre_volta(), "reduction not supported in pre volta device") @unittest.skipIf(not RUN_NVFUSER, "requires CUDA") @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING, "Requires fusion optimization pass to be effective") def test_reduction_multiple_output(self): old_guard = torch._C._jit_set_nvfuser_guard_mode(True) torch._C._jit_set_bailout_depth(20) def t(x: torch.Tensor, y: torch.Tensor, scale: float, z: torch.Tensor): o = torch.mul(x, y) o = torch.mul(o, scale) out1 = torch.mul(o, z) out2 = torch.sum(out1, dim=[2]) return out1, out2 t_jit = torch.jit.script(t) x = torch.randn(8, 4, 10, 16, dtype=torch.float, device="cuda") y = torch.randn(8, 4, 10, 16, dtype=torch.float, device="cuda") z = torch.randn(8, 4, 10, 16, dtype=torch.float, device="cuda") scale = 0.5 jit_o = t_jit(x, y, scale, z) jit_o = t_jit(x, y, scale, z) o = t(x, y, scale, z) for oo, jit_oo in zip(o, jit_o): 
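# Illustrative sketch (not part of the test suite): test_reduction_multiple_output
# above flips the nvfuser guard mode and the bailout depth by hand and restores the
# guard at the end of the test body. The same save/set/restore bookkeeping could be
# wrapped in a context manager; this is only a sketch, and the name
# _sketch_nvfuser_guard_mode is made up for illustration.
import contextlib
import torch

@contextlib.contextmanager
def _sketch_nvfuser_guard_mode(flag: bool):
    # _jit_set_nvfuser_guard_mode returns the previous setting, which makes the
    # restore in the finally block straightforward
    old_value = torch._C._jit_set_nvfuser_guard_mode(flag)
    try:
        yield
    finally:
        torch._C._jit_set_nvfuser_guard_mode(old_value)

# usage would mirror the manual version in the test:
#     with _sketch_nvfuser_guard_mode(True):
#         jit_o = t_jit(x, y, scale, z)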
self.assertEqual(oo.dtype, jit_oo.dtype) self.assertEqual(oo, jit_oo) self.assertGraphContains(t_jit.graph_for(x, y, scale, z), FUSION_GUARD) x = x.to(memory_format=torch.channels_last) y = y.to(memory_format=torch.channels_last) z = z.to(memory_format=torch.channels_last) jit_o = t_jit(x, y, scale, z) jit_o = t_jit(x, y, scale, z) o = t(x, y, scale, z) for oo, jit_oo in zip(o, jit_o): self.assertEqual(oo.dtype, jit_oo.dtype) self.assertEqual(oo, jit_oo) self.assertGraphContains(t_jit.graph_for(x, y, scale, z), FUSION_GUARD) torch._C._jit_set_nvfuser_guard_mode(old_guard) @unittest.skipIf(not RUN_NVFUSER, "requires CUDA") @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING, "Requires fusion optimization pass to be effective") def test_channels_last_with_broadcast(self): # setting this true forces a new graph to be generated with a new # input a different broadcast shape torch._C._jit_set_nvfuser_guard_mode(True) def t(x: torch.Tensor, y: torch.Tensor): o = torch.mul(x, y) o = o + 2.0 return o t_jit = torch.jit.script(t) # Single Channel broadcasts # Test 1 x = torch.randn(8, 4, 10, 16, dtype=torch.float, device="cuda") x = x.to(memory_format=torch.channels_last) y = torch.randn(8, 4, 10, 1, dtype=torch.float, device="cuda") y = y.to(memory_format=torch.channels_last) jit_o = t_jit(x, y) jit_o = t_jit(x, y) o = t(x, y) self.assertEqual(o.dtype, jit_o.dtype) self.assertEqual(o.is_contiguous(memory_format=torch.channels_last), jit_o.is_contiguous(memory_format=torch.channels_last)) self.assertEqual(o, jit_o) # Test 2 y = torch.randn(8, 4, 1, 16, dtype=torch.float, device="cuda") y = y.to(memory_format=torch.channels_last) jit_o = t_jit(x, y) jit_o = t_jit(x, y) o = t(x, y) self.assertEqual(o.dtype, jit_o.dtype) self.assertEqual(o.is_contiguous(memory_format=torch.channels_last), jit_o.is_contiguous(memory_format=torch.channels_last)) self.assertEqual(o, jit_o) # Test 3 y = torch.randn(8, 1, 10, 16, dtype=torch.float, device="cuda") y = y.to(memory_format=torch.channels_last) jit_o = t_jit(x, y) jit_o = t_jit(x, y) o = t(x, y) self.assertEqual(o.dtype, jit_o.dtype) self.assertEqual(o.is_contiguous(memory_format=torch.channels_last), jit_o.is_contiguous(memory_format=torch.channels_last)) self.assertEqual(o, jit_o) # Test 3 y = torch.randn(1, 4, 10, 16, dtype=torch.float, device="cuda") y = y.to(memory_format=torch.channels_last) jit_o = t_jit(x, y) jit_o = t_jit(x, y) o = t(x, y) self.assertEqual(o.dtype, jit_o.dtype) self.assertEqual(o.is_contiguous(memory_format=torch.channels_last), jit_o.is_contiguous(memory_format=torch.channels_last)) self.assertEqual(o, jit_o) ''' Currently, the JIT doesn't have tensor merge logic to handle adding a broadcast tensor with more than one broadcast into a non-broadcast tensor. Therefore, either of these tests can fail depending on the sort implementation. The second test is known to fail. 
# Two Channel broadcasts # Test 1 y = torch.randn(8, 4, 1, 1, dtype=torch.float, device="cuda") y = y.to(memory_format=torch.channels_last) jit_o = t_jit(x, y) jit_o = t_jit(x, y) o = t(x, y) self.assertEqual(o.dtype, jit_o.dtype) self.assertEqual(o.is_contiguous(memory_format=torch.channels_last), jit_o.is_contiguous(memory_format=torch.channels_last)) self.assertEqual(o, jit_o) # Test 2 y = torch.randn(8, 4, 1, 1, dtype=torch.float, device="cuda") y = y.to(memory_format=torch.channels_last).transpose(2,3) x = x.transpose(2,3) jit_o = t_jit(x, y) jit_o = t_jit(x, y) o = t(x, y) self.assertEqual(o.dtype, jit_o.dtype) self.assertEqual(o.is_contiguous(memory_format=torch.channels_last), jit_o.is_contiguous(memory_format=torch.channels_last)) self.assertEqual(o, jit_o) ''' @unittest.skipIf(is_pre_volta(), "reduction not supported in pre volta device") @unittest.skipIf(not RUN_NVFUSER, "requires CUDA") @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING, "Requires fusion optimization pass to be effective") def test_pw_single_reduction_partition(self): sizes = [2, 2, 2] dtype = torch.float device = "cuda" x = torch.randn(sizes, dtype=dtype, device=device) y = torch.randn(sizes, dtype=dtype, device=device) z = torch.randn(sizes, dtype=dtype, device=device) def t(x: torch.Tensor, y: torch.Tensor, z: torch.Tensor): o = torch.add(x, y) o = torch.sum(o, dim=[0]) o = torch.add(o, z) return o t_jit = torch.jit.script(t) jit_o = t_jit(x, y, z) jit_o = t_jit(x, y, z) o = t(x, y, z) self.assertEqual(o.dtype, jit_o.dtype) self.assertEqual(o, jit_o) self.assertGraphContains(t_jit.graph_for(x, y, z), FUSION_GUARD) @unittest.skipIf(is_pre_volta(), "reduction not supported in pre volta device") @unittest.skipIf(not RUN_NVFUSER, "requires CUDA") @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING, "Requires fusion optimization pass to be effective") def test_permutation_preservation(self): sizes = [2, 3, 4, 5] dtype = torch.float device = "cuda" with nvfuser_singleton_fusion(True): def t(x: torch.Tensor): return torch.relu(x) t_jit = torch.jit.script(t) x = torch.randn(sizes, dtype=dtype, device=device).to(memory_format=torch.channels_last) self._run_helper(t_jit, t, x, check_stride=True) def t(x: torch.Tensor, y: torch.Tensor): return torch.add(x, y) t_jit = torch.jit.script(t) x = torch.randn(sizes, dtype=dtype, device=device).to(memory_format=torch.channels_last) y = torch.randn(sizes[1:], dtype=dtype, device=device) self._run_helper(t_jit, t, x, y, check_stride=True) @unittest.skipIf(not RUN_NVFUSER, "requires CUDA") @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING, "Requires fusion optimization pass to be effective") def test_permutation_preservation_edge_case_0(self): sizes = [2, 3, 4, 5] dtype = torch.float device = "cuda" x = torch.randn(sizes, dtype=dtype, device=device).to(memory_format=torch.channels_last) # mismatch rank with *note* different permutation recognized by PE bias = torch.randn(3, dtype=dtype, device=device).unsqueeze(-1).unsqueeze(-1) def t(x, y): return x + y t_jit = torch.jit.script(t) with nvfuser_singleton_fusion(True): self._run_helper(t_jit, t, x, bias, check_stride=True) @unittest.skipIf(not RUN_NVFUSER, "requires CUDA") @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING, "Requires fusion optimization pass to be effective") def test_permutation_preservation_edge_case_1_broken(self): sizes = [2, 3, 4, 5] dtype = torch.float device = "cuda" x = torch.randn(sizes, dtype=dtype, device=device).to(memory_format=torch.channels_last) # in-compatible 
permutation, this will cause format propagation to break bias = torch.randn(4, 5, dtype=dtype, device=device) def t(x, y): return x + y t_jit = torch.jit.script(t) with nvfuser_singleton_fusion(True): for _ in range(5): jit_o = t_jit(x, bias) o = t(x, bias) self.assertEqual(o.dtype, jit_o.dtype) self.assertEqual(o, jit_o) try: # nvfuser does not support incompatible permutation, this will throw self.assertEqual(o.stride(), jit_o.stride()) except Exception as e: warnings.warn( "permutation propagation is broken, proper support should come after nvfuser permutation scheduler update") self.assertGraphContains(t_jit.graph_for(x, bias), FUSION_GUARD) @unittest.skipIf(not RUN_NVFUSER, "requires CUDA") @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING, "Requires fusion optimization pass to be effective") def test_permutation_preservation_edge_case_2(self): sizes = [2, 3, 4, 5] dtype = torch.float device = "cuda" x = torch.randn(sizes, dtype=dtype, device=device).to(memory_format=torch.channels_last) y = torch.randn(sizes, dtype=dtype, device=device).to(memory_format=torch.channels_last) z = torch.randn(sizes, dtype=dtype, device=device).to(memory_format=torch.channels_last) def t(x, y, w): tmp = torch.lerp(x, y, w) tmp = torch.clamp(tmp, -1.0, 0.5) tmp = torch.nn.functional.softplus(tmp) return torch.threshold(tmp, -2.0, 0.5) t_jit = torch.jit.script(t) self._run_helper(t_jit, t, x, y, z, check_stride=True) @unittest.skipIf(is_pre_volta(), "reduction not supported in pre volta device") @unittest.skipIf(not RUN_NVFUSER, "requires CUDA") @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING, "Requires fusion optimization pass to be effective") def test_normalization_partition(self): sizes = [3, 8, 5] dtype = torch.float device = "cuda" x = torch.randn(sizes, dtype=dtype, device=device) y = torch.randn(sizes, dtype=dtype, device=device) z = torch.randn(sizes, dtype=dtype, device=device) r_m = torch.randn(8, dtype=dtype, device=device) r_v = torch.randn(8, dtype=dtype, device=device) def t(x: torch.Tensor, y: torch.Tensor, z: torch.Tensor, r_mean: torch.Tensor, r_var: torch.Tensor): o = torch.add(x, y) o = torch.nn.functional.softmax(o, dim=0) o = torch.add(o, z) o = torch.nn.functional.batch_norm(o, r_mean, r_var, training=True) return o t_jit = torch.jit.script(t) jit_o = t_jit(x, y, z, r_m, r_v) jit_o = t_jit(x, y, z, r_m, r_v) o = t(x, y, z, r_m, r_v) self.assertEqual(o.dtype, jit_o.dtype) self.assertEqual(o, jit_o) self.assertGraphContains(t_jit.graph_for(x, y, z, r_m, r_v), FUSION_GUARD) @unittest.skipIf(is_pre_volta(), "reduction not supported in pre volta device") @unittest.skipIf(not RUN_NVFUSER, "requires CUDA") @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING, "Requires fusion optimization pass to be effective") def test_sum_to_one(self): dtype = torch.float device = "cuda" x = torch.randn([4, 5, 6], dtype=dtype, device=device) def t(x: torch.Tensor): o = torch.add(x, 1) o = torch.sum(o, dim=[0, 1, 2]) return o t_jit = torch.jit.script(t) jit_o = t_jit(x) jit_o = t_jit(x) o = t(x) self.assertEqual(o.dtype, jit_o.dtype) self.assertEqual(o, jit_o) self.assertGraphContains(t_jit.graph_for(x), FUSION_GUARD) @unittest.skipIf(is_pre_volta(), "reduction not supported in pre volta device") @unittest.skipIf(not RUN_NVFUSER, "requires CUDA") @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING, "Requires fusion optimization pass to be effective") def test_single_reduction_broadcast(self): dtype = torch.float device = "cuda" x = torch.randn([7, 4, 8], dtype=dtype,
device=device) y = torch.randn([4, 8], dtype=dtype, device=device) z = torch.randn([1, 4, 8], dtype=dtype, device=device) def t(x: torch.Tensor, y: torch.Tensor, z: torch.Tensor): o = torch.add(x, y) o = torch.add(o, z) o = torch.sum(o, dim=[0]) return o t_jit = torch.jit.script(t) jit_o = t_jit(x, y, z) jit_o = t_jit(x, y, z) o = t(x, y, z) self.assertEqual(o.dtype, jit_o.dtype) self.assertEqual(o, jit_o) self.assertGraphContains(t_jit.graph_for(x, y, z), FUSION_GUARD) @unittest.skipIf(is_pre_volta(), "reduction not supported in pre volta device") @unittest.skipIf(not RUN_NVFUSER, "requires CUDA") @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING, "Requires fusion optimization pass to be effective") def test_trivial_reduction(self): dtype = torch.float device = "cuda" x = torch.randn([1, 4, 8], dtype=dtype, device=device) def t(x: torch.Tensor): o = torch.add(x, 1) o = torch.sum(o, dim=[0]) o = torch.sum(o, dim=[0]) return o t_jit = torch.jit.script(t) jit_o = t_jit(x) jit_o = t_jit(x) o = t(x) self.assertEqual(o.dtype, jit_o.dtype) self.assertEqual(o, jit_o) self.assertGraphContains(t_jit.graph_for(x), FUSION_GUARD) @unittest.skipIf(not RUN_NVFUSER, "requires CUDA") @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING, "Requires fusion optimization pass to be effective") def test_profiling_node(self): dtype = torch.float device = "cuda" x = torch.randn(4, 8, 8, 8, dtype=dtype, device=device) def repro(x: torch.Tensor, alpha: float): o = torch.rand_like(x) o = torch.add(o, alpha) return o repro_jit = torch.jit.script(repro) self._run_helper(repro_jit, repro, x, 0.6) @unittest.skipIf(not RUN_NVFUSER, "requires CUDA") @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING, "Requires fusion optimization pass to be effective") def test_rand_like(self): dtype = torch.float device = "cuda" def t(x: torch.Tensor, alpha: float): o = torch.rand_like(x) o = torch.add(o, alpha) return o # disabling cache so new inputs would generate new graph t.__disable_jit_function_caching__ = True for m_format in [torch.contiguous_format, torch.channels_last]: x = torch.randn(4, 5, 6, 7, dtype=dtype, device=device).to(memory_format=m_format) t_jit = torch.jit.script(t) self._run_helper(t_jit, t, x, 0.6, check_stride=True) @unittest.skipIf(is_pre_volta(), "reduction not supported in pre volta device") @unittest.skipIf(not RUN_NVFUSER, "requires CUDA") @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING, "Requires fusion optimization pass to be effective") def test_reduction_sizes_op(self): dtype = torch.float device = "cuda" x = torch.randn(2, 3, 4, 5, dtype=dtype, device=device) y = torch.randn(2, 3, 4, 5, dtype=dtype, device=device) def t(x: torch.Tensor, y: torch.Tensor): o = x + y o = torch.relu(o) o = o.sum((1, 3)) return o.size() t_jit = torch.jit.script(t) jit_o = t_jit(x, y) jit_o = t_jit(x, y) o = t(x, y) self.assertEqual(o, jit_o) # since the output value is not used at all, the fusion operator should # have been optimized away self.assertGraphContainsExactly(t_jit.graph_for(x, y), FUSION_GUARD, 0) @unittest.skipIf(is_pre_volta(), "reduction not supported in pre volta device") @unittest.skipIf(not RUN_NVFUSER, "requires CUDA") @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING, "Requires fusion optimization pass to be effective") def test_profile_ivalue(self): dtype = torch.float device = "cuda" x = torch.randn([7, 4, 7], dtype=dtype, device=device) y = torch.randn([7, 4, 7], dtype=dtype, device=device) def t(x: torch.Tensor, y: torch.Tensor, dim: List[int], keepdim: 
bool): o = torch.add(x, y) o = o.sum(dim, keepdim=keepdim) return o t_jit = torch.jit.script(t) jit_o = t_jit(x, y, (0, 1), False) jit_o = t_jit(x, y, (0, 1), False) o = t(x, y, (0, 1), False) self.assertEqual(o.dtype, jit_o.dtype) self.assertEqual(o, jit_o) self.assertGraphContains(t_jit.graph_for(x, y, (0, 1), False), FUSION_GUARD) @unittest.skipIf(is_pre_volta(), "reduction not supported in pre volta device") @unittest.skipIf(not RUN_NVFUSER, "requires CUDA") @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING, "Requires fusion optimization pass to be effective") def test_profile_ivalue_multiple_profiles(self): dtype = torch.float device = "cuda" x = torch.randn([7, 4, 7], dtype=dtype, device=device) def t(x, num: int): for i in range(num): # varying reduction axes should break profile_ivalue tmp = x.sum(i, keepdim=True) # inplace add on input/output, can't be functionalized/fused x += tmp return x with nvfuser_singleton_fusion(True): t_jit = torch.jit.script(t) self._run_helper(t_jit, t, x, 3, num_fusion=0) @unittest.skipIf(is_pre_volta(), "reduction not supported in pre volta device") @unittest.skipIf(not RUN_NVFUSER, "requires CUDA") @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING, "Requires fusion optimization pass to be effective") def test_sum_to_size(self): dtype = torch.float device = "cuda" x = torch.randn([2, 4, 4], dtype=dtype, device=device) y = torch.randn([2, 4, 4], dtype=dtype, device=device) def t(x: torch.Tensor, y: torch.Tensor, new_size: List[int]): o = torch.add(x, y) o = o.sum_to_size(new_size) return o t_jit = torch.jit.script(t) self._run_helper(t_jit, t, x, y, (4, 1)) # update shape: old kernel should handle dynamic shape well without # recompilation x = torch.randn([2, 5, 8], dtype=dtype, device=device) y = torch.randn([2, 5, 8], dtype=dtype, device=device) # (TODO) check executed kernel, should extend autograd.profiler to fused # kernels self._run_helper(t_jit, t, x, y, (5, 1)) with nvfuser_singleton_fusion(True): x = torch.randn([2, 5, 8], dtype=dtype, device=device) def t(x: torch.Tensor): # no-op reduction return x.sum_to_size((2, 5, 8)) t_jit = torch.jit.script(t) self._run_helper(t_jit, t, x) @unittest.skipIf(is_pre_volta(), "reduction not supported in pre volta device") @unittest.skipIf(not RUN_NVFUSER, "requires CUDA") @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING, "Requires fusion optimization pass to be effective") def test_grad_sum_to_size(self): dtype = torch.float device = "cuda" x = torch.randn([2, 4, 4], dtype=dtype, device=device).requires_grad_() y = torch.randn([4], dtype=dtype, device=device).requires_grad_() grad = torch.randn([2, 4, 4], dtype=dtype, device=device) ref_x = x.detach().clone().requires_grad_() ref_y = y.detach().clone().requires_grad_() def t(x: torch.Tensor, y: torch.Tensor): o = torch.add(x, y) o = torch.relu(o) return o # profiling runs for forward & backward t_jit = torch.jit.script(t) jit_o = t_jit(x, y) jit_o.backward(grad) jit_o = t_jit(x, y) jit_o.backward(grad) x.grad = None y.grad = None jit_o = t_jit(x, y) jit_o.backward(grad) o = t(ref_x, ref_y) o.backward(grad) self.assertEqual(o.dtype, jit_o.dtype) self.assertEqual(o, jit_o) self.assertEqual(x.grad, ref_x.grad) self.assertEqual(y.grad, ref_y.grad) bwd_graph = list( list(t_jit.get_debug_state().execution_plans.values())[ 0].code.grad_executor_states()[0].execution_plans.values() )[0].graph FileCheck().check(FUSION_GUARD).run(bwd_graph) # update shape: old kernel should handle dynamic shape well without # recompilation x = 
torch.randn([2, 5, 8], dtype=dtype, device=device).requires_grad_() y = torch.randn([8], dtype=dtype, device=device).requires_grad_() ref_x = x.detach().clone().requires_grad_() ref_y = y.detach().clone().requires_grad_() grad = torch.randn([2, 5, 8], dtype=dtype, device=device) jit_o = t_jit(x, y) # (TODO) check executed kernel, should extend autograd.profiler to fused # kernels jit_o.backward(grad) o = t(ref_x, ref_y) o.backward(grad) self.assertEqual(o.dtype, jit_o.dtype) self.assertEqual(o, jit_o) self.assertEqual(x.grad, ref_x.grad) self.assertEqual(y.grad, ref_y.grad) @unittest.skipIf(not RUN_NVFUSER, "requires CUDA") @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING, "Requires fusion optimization pass to be effective") def test_dropout_inference_fusion(self): dtype = torch.float device = "cuda" x = torch.randn([10, 4, 8], dtype=dtype, device=device) def t(x: torch.Tensor, p: float, train: bool): o = torch.nn.functional.dropout(x, p, training=train) o = o + 1.0 return o t_jit = torch.jit.script(t) self._run_helper(t_jit, t, x, 0.15, False) @unittest.skipIf(not RUN_NVFUSER, "requires CUDA") @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING, "Requires fusion optimization pass to be effective") def test_dropout_train_nograd_fusion(self): dtype = torch.float device = "cuda" x = torch.randn([10, 4, 8], dtype=dtype, device=device) def t(x: torch.Tensor, p: float, train: bool): o = torch.nn.functional.dropout(x, p, training=train) o = o + 1.0 return o t_jit = torch.jit.script(t) self._run_helper(t_jit, t, x, 0.0, True) @unittest.skipIf(not RUN_NVFUSER, "requires CUDA") @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING, "Requires fusion optimization pass to be effective") def test_dropout_train_nograd_prob_check(self): dtype = torch.float device = "cuda" x = torch.randn([1024, 1024], dtype=dtype, device=device) def t(x: torch.Tensor, p: float, train: bool): o = torch.nn.functional.dropout(x, p, training=train) o = o * 2.0 return o t_jit = torch.jit.script(t) for prob in [0.0, 0.15, 0.5, 0.85, 1.]: torch.cuda.manual_seed_all(123) jit_o = t_jit(x, prob, True) torch.cuda.manual_seed_all(123) jit_o = t_jit(x, prob, True) self.assertTrue(jit_o.detach().isfinite().all().item()) num_elems = x.numel() num_zeros = num_elems - jit_o.detach().count_nonzero().item() percent_zeros = num_zeros / num_elems self.assertTrue((percent_zeros >= (prob - 0.01)) and (percent_zeros <= (prob + 0.01))) self.assertGraphContainsExactly(t_jit.graph_for(x, prob, True), FUSION_GUARD, 1, consider_subgraphs=True) @unittest.skipIf(is_pre_volta(), "reduction not supported in pre volta device") @unittest.skipIf(not RUN_NVFUSER, "requires CUDA") @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING, "Requires fusion optimization pass to be effective") def test_dropout_training_fusion(self): dtype = torch.float device = "cuda" sizes = [2, 3, 4, 5] def t(x: torch.Tensor, p: float, train: bool): o = torch.nn.functional.dropout(x, p, training=train) o = o * 2.0 return o def t2(x: torch.Tensor, p: float, train: bool): o = torch.nn.functional.softmax(x, dim=-1) o = torch.nn.functional.dropout(o, p, training=train) return o # disabling cache so new inputs would generate new graph t.__disable_jit_function_caching__ = True t2.__disable_jit_function_caching__ = True for fn in [t, t2]: for m_format in [torch.contiguous_format, torch.channels_last]: fn_jit = torch.jit.script(fn) x = torch.randn(sizes, dtype=dtype, device=device, requires_grad=True).to(memory_format=m_format) grads = torch.randn(sizes, 
dtype=dtype, device=device).to(memory_format=m_format) # The drop probability needs to be set to zero given that the order of picking random # numbers between eager mode and the jit is different self._run_training_helper(fn_jit, fn, grads, x, 0.0, True) @unittest.skipIf(not RUN_NVFUSER, "requires CUDA") @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING, "Requires fusion optimization pass to be effective") def test_gelu(self): old_guard = torch._C._jit_set_nvfuser_guard_mode(True) dtype = torch.float device = "cuda" x = torch.randn([1024, 1024], dtype=dtype, device=device, requires_grad=True) grads = torch.randn([1024, 1024], dtype=dtype, device=device, requires_grad=False) def t(x: torch.Tensor, mode: str): o = torch.nn.functional.gelu(x, approximate=mode) o = o * 2.0 return o t_jit = torch.jit.script(t) self._run_training_helper(t_jit, t, grads, x, 'none') self._run_training_helper(t_jit, t, grads, x, 'tanh') torch._C._jit_set_nvfuser_guard_mode(old_guard) @unittest.skipIf(not RUN_NVFUSER, "requires CUDA") @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING, "Requires fusion optimization pass to be effective") def test_dropout_training_prob_check(self): dtype = torch.float device = "cuda" x = torch.randn([1024, 1024], dtype=dtype, device=device, requires_grad=True) x_nograd = torch.randn([1024, 1024], dtype=dtype, device=device) def t(x: torch.Tensor, p: float, train: bool): o = torch.nn.functional.dropout(x, p, training=train) o = o * 2.0 return o t_jit = torch.jit.script(t) for prob in [0.0, 0.15, 0.5, 0.85, 1.]: torch.cuda.manual_seed_all(123) jit_o = t_jit(x, prob, True) torch.cuda.manual_seed_all(123) jit_o = t_jit(x, prob, True) torch.cuda.manual_seed_all(123) jit_o = t_jit(x, prob, True) self.assertTrue(jit_o.detach().isfinite().all().item()) num_elems = x.numel() num_zeros = num_elems - jit_o.detach().count_nonzero().item() percent_zeros = num_zeros / num_elems self.assertTrue((percent_zeros >= (prob - 0.01)) and (percent_zeros <= (prob + 0.01))) self.assertGraphContainsExactly(t_jit.graph_for(x, prob, True), FUSION_GUARD, 1, consider_subgraphs=True) @unittest.skipIf(not RUN_NVFUSER, "requires CUDA") @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING, "Requires fusion optimization pass to be effective") def test_linear(self): in_feature = 2 out_feature = 8 # Changing the input dims to be 3-D to avoid eager mode bias fusion # The bias fusion causes some precision issues with TF-32 weight = torch.randn(out_feature, in_feature, dtype=torch.float32, device='cuda') bias = torch.randn(out_feature, dtype=torch.float32, device='cuda') def t(x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor): o = torch.nn.functional.linear(x, weight, bias) o = torch.relu(o) return o # disabling cache so new inputs would generate new graph t.__disable_jit_function_caching__ = True sizes = [in_feature, ] for i in range(4): # increase input rank in each iteration sizes.insert(0, i + 2) x = torch.randn(*sizes, dtype=torch.float32, device='cuda') t_jit = torch.jit.script(t) # fusion only happens for input rank >= 4 has_fusion = 0 if len(sizes) < 4 else 1 self._run_helper(t_jit, t, x, weight, bias, check_stride=True, num_fusion=has_fusion) @unittest.skipIf(not RUN_NVFUSER, "requires CUDA") @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING, "Requires fusion optimization pass to be effective") def test_linear_symbolic_shapes(self): def fn(x: int): y = torch.zeros((3, 4, x, x + 2)).cuda() for i in range(2): inp = torch.rand((3, 4, x, x + i)).cuda() weight = torch.rand((x + 
2, x + i)).cuda() bias = torch.rand((x, x + 2)).cuda() y += torch.sin(torch.nn.functional.linear(inp, weight, bias)) return y fn_s = torch.jit.script(fn) fn_s(5) fn_s(5) @unittest.skipIf(not RUN_NVFUSER, "requires CUDA") @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING, "Requires fusion optimization pass to be effective") def test_conv2d_symbolic_shapes(self): def fn(x: int): responses = [] for i in range(2): inp = torch.rand((3, 3, 32, 32)).cuda() weight = torch.rand((x + i, 3, 7, 7)).cuda() bias = torch.rand((x + i)).cuda() res = torch.nn.functional.conv2d(inp, weight, bias, padding=3) responses.append(res) return responses fn_s = torch.jit.script(fn) fn_s(5) fn_s(5) @unittest.skipIf(not RUN_NVFUSER, "requires CUDA") @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING, "Requires fusion optimization pass to be effective") def test_backward_type(self): # not super useful to check gradient of integer/bool, so skipping here type_pairs = [ (torch.float, torch.half), (torch.double, torch.half), (torch.float, torch.double), ] if TEST_BF16: type_pairs += [ (torch.float, torch.bfloat16), (torch.double, torch.bfloat16), ] for x_type, y_type in type_pairs: x = torch.randn(4, 2, dtype=x_type, device='cuda', requires_grad=True) y = torch.randn(4, 2, dtype=y_type, device='cuda', requires_grad=True) grad = torch.randn(4, 2, dtype=torch.float, device='cuda') def test1(x: torch.Tensor, y: torch.Tensor): o = torch.add(x, y) o = torch.add(o, y) o = torch.add(o, y) o = torch.add(o, y) o = o + 1.0 return o test1_jit = torch.jit.script(test1) for i in range(3): jit_o = test1_jit(x, y) jit_o.backward(grad) bwd_graph = list( list(test1_jit.get_debug_state().execution_plans.values())[ 0].code.grad_executor_states()[0].execution_plans.values() )[0].graph FileCheck().check(FUSION_GROUP).run(bwd_graph) self.assertEqual(x.grad.dtype, x.dtype) self.assertEqual(y.grad.dtype, y.dtype) @unittest.skipIf(is_pre_volta(), "reduction not supported in pre volta device") @unittest.skipIf(not RUN_NVFUSER, "requires CUDA") @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING, "Requires fusion optimization pass to be effective") def test_autocast_1(self): def t(x: torch.Tensor, y: torch.Tensor): o = x * 2.0 o = torch.softmax(o, dim=-1) o = o * 3.0 o = torch._C._nn.linear(o, y) return o x = torch.randn(8, 4, dtype=torch.half, device='cuda', requires_grad=True) y = torch.randn(4, 4, dtype=torch.float, device='cuda', requires_grad=True) grad = torch.randn(8, 4, dtype=torch.half, device='cuda', requires_grad=False) t_jit = torch.jit.script(t) for i in range(3): with torch.cuda.amp.autocast(): jit_o = t_jit(x, y) if i == 2: fwd_graph = t_jit.graph_for(x, y) jit_o.backward(grad) self.assertGraphContainsExactly(fwd_graph, FUSION_GUARD, 1, consider_subgraphs=True) with torch.cuda.amp.autocast(): bwd_graph = list( list(t_jit.get_debug_state().execution_plans.values())[ 0].code.grad_executor_states()[0].execution_plans.values() )[0].graph FileCheck().check(FUSION_GROUP).run(bwd_graph) self.assertEqual(jit_o.dtype, torch.half) self.assertEqual(x.grad.dtype, x.dtype) self.assertEqual(y.grad.dtype, y.dtype) @unittest.skipIf(is_pre_volta(), "reduction not supported in pre volta device") @unittest.skipIf(not RUN_NVFUSER, "requires CUDA") @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING, "Requires fusion optimization pass to be effective") def test_autocast_2(self): def t(x: torch.Tensor): o = x * 2.0 o = torch.softmax(o, dim=-1) o = o * 3.0 o = torch.softmax(o, dim=-1) o = o * 4.0 return o x = torch.randn(8, 4, 
dtype=torch.half, device='cuda', requires_grad=True) grad = torch.randn(8, 4, dtype=torch.float, device='cuda', requires_grad=False) t_jit = torch.jit.script(t) for i in range(3): with torch.cuda.amp.autocast(): jit_o = t_jit(x) if i == 2: fwd_graph = t_jit.graph_for(x) jit_o.backward(grad) self.assertGraphContainsExactly(fwd_graph, FUSION_GUARD, 1, consider_subgraphs=True) with torch.cuda.amp.autocast(): bwd_graph = list( list(t_jit.get_debug_state().execution_plans.values())[ 0].code.grad_executor_states()[0].execution_plans.values() )[0].graph FileCheck().check(FUSION_GROUP).run(bwd_graph) self.assertEqual(jit_o.dtype, torch.float) self.assertEqual(x.grad.dtype, x.dtype) @unittest.skipIf(is_pre_volta(), "reduction not supported in pre volta device") @unittest.skipIf(not RUN_NVFUSER, "requires CUDA") @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING, "Requires fusion optimization pass to be effective") @unittest.skipIf(not TEST_BF16, "device does not support BFloat16") def test_autocast_1_bfloat(self): def t(x: torch.Tensor, y: torch.Tensor): o = x * 2.0 o = torch.softmax(o, dim=-1) o = o * 3.0 o = torch._C._nn.linear(o, y) return o x = torch.randn(8, 4, dtype=torch.bfloat16, device='cuda', requires_grad=True) y = torch.randn(4, 4, dtype=torch.float, device='cuda', requires_grad=True) grad = torch.randn(8, 4, dtype=torch.bfloat16, device='cuda', requires_grad=False) t_jit = torch.jit.script(t) for i in range(3): with torch.cuda.amp.autocast(dtype=torch.bfloat16): jit_o = t_jit(x, y) if i == 2: fwd_graph = t_jit.graph_for(x, y) jit_o.backward(grad) self.assertGraphContainsExactly(fwd_graph, FUSION_GUARD, 1, consider_subgraphs=True) with torch.cuda.amp.autocast(dtype=torch.bfloat16): bwd_graph = list( list(t_jit.get_debug_state().execution_plans.values())[ 0].code.grad_executor_states()[0].execution_plans.values() )[0].graph FileCheck().check(FUSION_GROUP).run(bwd_graph) self.assertEqual(jit_o.dtype, torch.bfloat16) self.assertEqual(x.grad.dtype, x.dtype) self.assertEqual(y.grad.dtype, y.dtype) @unittest.skipIf(is_pre_volta(), "reduction not supported in pre volta device") @unittest.skipIf(not RUN_NVFUSER, "requires CUDA") @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING, "Requires fusion optimization pass to be effective") @unittest.skipIf(not TEST_BF16, "device does not support BFloat16") def test_autocast_2_bfloat(self): def t(x: torch.Tensor): o = x * 2.0 o = torch.softmax(o, dim=-1) o = o * 3.0 o = torch.softmax(o, dim=-1) o = o * 4.0 return o x = torch.randn(8, 4, dtype=torch.bfloat16, device='cuda', requires_grad=True) grad = torch.randn(8, 4, dtype=torch.float, device='cuda', requires_grad=False) t_jit = torch.jit.script(t) for i in range(3): with torch.cuda.amp.autocast(dtype=torch.bfloat16): jit_o = t_jit(x) if i == 2: fwd_graph = t_jit.graph_for(x) jit_o.backward(grad) self.assertGraphContainsExactly(fwd_graph, FUSION_GUARD, 1, consider_subgraphs=True) with torch.cuda.amp.autocast(dtype=torch.bfloat16): bwd_graph = list( list(t_jit.get_debug_state().execution_plans.values())[ 0].code.grad_executor_states()[0].execution_plans.values() )[0].graph FileCheck().check(FUSION_GROUP).run(bwd_graph) self.assertEqual(jit_o.dtype, torch.float) self.assertEqual(x.grad.dtype, x.dtype) @unittest.skipIf(not RUN_NVFUSER, "requires CUDA") @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING, "Requires fusion optimization pass to be effective") def test_to_dtype_fp32_to_fp16(self): def t(x: torch.Tensor): o = x * 2.0 o = o.to(dtype=torch.half) o = o * 3.0 return o x = 
torch.randn(8, 4, dtype=torch.float, device='cuda') t_jit = torch.jit.script(t) for i in range(3): jit_o = t_jit(x) self.assertGraphContainsExactly(t_jit.graph_for(x), FUSION_GUARD, 1) self.assertEqual(jit_o.dtype, torch.half) @unittest.skipIf(not RUN_NVFUSER, "requires CUDA") @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING, "Requires fusion optimization pass to be effective") def test_to_dtype_fp16_to_fp32(self): def t(x: torch.Tensor): o = x * 2.0 o = o.to(dtype=torch.float) o = o * 3.0 return o x = torch.randn(8, 4, dtype=torch.half, device='cuda') t_jit = torch.jit.script(t) for i in range(3): jit_o = t_jit(x) self.assertGraphContainsExactly(t_jit.graph_for(x), FUSION_GUARD, 1) self.assertEqual(jit_o.dtype, torch.float) @unittest.skipIf(not RUN_NVFUSER, "requires CUDA") @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING, "Requires fusion optimization pass to be effective") def test_to_dtype_fp16_to_fp16(self): def t(x: torch.Tensor): o = x * 2.0 o = o.to(dtype=torch.half) o = o * 3.0 return o x = torch.randn(8, 4, dtype=torch.half, device='cuda') t_jit = torch.jit.script(t) for i in range(3): jit_o = t_jit(x) self.assertGraphContainsExactly(t_jit.graph_for(x), FUSION_GUARD, 1) self.assertEqual(jit_o.dtype, torch.half) @unittest.skipIf(not RUN_NVFUSER, "requires CUDA") @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING, "Requires fusion optimization pass to be effective") @unittest.skipIf(not TEST_BF16, "device does not support BFloat16") def test_to_dtype_fp32_to_bf16(self): def t(x: torch.Tensor): o = x * 2.0 o = o.to(dtype=torch.bfloat16) o = o * 3.0 return o x = torch.randn(8, 4, dtype=torch.float, device='cuda') t_jit = torch.jit.script(t) for i in range(3): jit_o = t_jit(x) self.assertGraphContainsExactly(t_jit.graph_for(x), FUSION_GUARD, 1) self.assertEqual(jit_o.dtype, torch.bfloat16) @unittest.skipIf(not RUN_NVFUSER, "requires CUDA") @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING, "Requires fusion optimization pass to be effective") @unittest.skipIf(not TEST_BF16, "device does not support BFloat16") def test_to_dtype_bf16_to_fp32(self): def t(x: torch.Tensor): o = x * 2.0 o = o.to(dtype=torch.float) o = o * 3.0 return o x = torch.randn(8, 4, dtype=torch.bfloat16, device='cuda') t_jit = torch.jit.script(t) for i in range(3): jit_o = t_jit(x) self.assertGraphContainsExactly(t_jit.graph_for(x), FUSION_GUARD, 1) self.assertEqual(jit_o.dtype, torch.float) @unittest.skipIf(not RUN_NVFUSER, "requires CUDA") @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING, "Requires fusion optimization pass to be effective") @unittest.skipIf(not TEST_BF16, "device does not support BFloat16") def test_to_dtype_bf16_to_bf16(self): def t(x: torch.Tensor): o = x * 2.0 o = o.to(dtype=torch.bfloat16) o = o * 3.0 return o x = torch.randn(8, 4, dtype=torch.bfloat16, device='cuda') t_jit = torch.jit.script(t) for i in range(3): jit_o = t_jit(x) self.assertGraphContainsExactly(t_jit.graph_for(x), FUSION_GUARD, 1) self.assertEqual(jit_o.dtype, torch.bfloat16) @unittest.skipIf(not RUN_NVFUSER, "requires CUDA") @unittest.skipIf(not TEST_MULTIGPU, "requires multiple CUDA device") @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING, "Requires fusion optimization pass to be effective") def test_multiple_device_pw(self): def t(x): o = x + 1.0 o = torch.relu(o) return o x = torch.randn(2, dtype=torch.float32, device="cuda") t_jit = torch.jit.script(t) for i in range(3): jit_o = t_jit(x) self.assertGraphContainsExactly(t_jit.graph_for(x), FUSION_GUARD, 1) 
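# Illustrative sketch (not part of the test suite): the test_to_dtype_* cases above
# all share one shape: an elementwise op, an intermediate dtype cast, and another
# elementwise op, scripted and run a few times so the cast can be folded into the
# fusion. A standalone version of that pattern is shown below; the shapes, dtypes,
# and helper name are arbitrary assumptions, and a CUDA device is assumed.
def _sketch_intermediate_cast():
    import torch

    def fn(x: torch.Tensor):
        o = x * 2.0
        o = o.to(dtype=torch.half)
        return o * 3.0

    fn_jit = torch.jit.script(fn)
    x = torch.randn(8, 4, dtype=torch.float, device="cuda")
    for _ in range(3):
        out = fn_jit(x)
    # the intermediate cast determines the output dtype
    assert out.dtype == torch.half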
torch.cuda.device(1) x = x.to("cuda:1") jit_o = t_jit(x) @unittest.skipIf(not RUN_NVFUSER, "requires CUDA") @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING, "Requires fusion optimization pass to be effective") def test_graph_for_with_missing_optimized_engine(self): x = torch.randn(8, 4, 2, dtype=torch.float, device="cuda").requires_grad_() def t(x: torch.Tensor, flag: bool): x = x + 1.0 x = torch.relu(x) if flag: o = x + 1.0 o = torch.relu(o) else: o = x + 2.0 o = torch.relu(o) return o t_jit = torch.jit.script(t) jit_o = t_jit(x, False) jit_o = t_jit(x, False) jit_o = t_jit(x, True) o = t(x, True) self.assertEqual(o, jit_o) # even though this branch was only profiled once, graph_for should still # report exactly one fusion group self.assertGraphContainsExactly(t_jit.graph_for(x, True), FUSION_GUARD, 1, True) @unittest.skipIf(not RUN_NVFUSER, "requires CUDA") @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING, "Requires fusion optimization pass to be effective") def test_branches(self): in_feature = 2 out_feature = 4 x = torch.randn(4, in_feature, dtype=torch.float32, device='cuda') weight = torch.randn(out_feature, in_feature, dtype=torch.float32, device='cuda') bias = torch.randn(out_feature, dtype=torch.float32, device='cuda') def t(x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor, flag: bool): if flag: o = torch.nn.functional.linear(x, weight, bias) o = o + 1.0 o = torch.relu(o) else: o = x.sum() o = o + 2.0 o = torch.relu(o) return o t_jit = torch.jit.script(t) jit_o = t_jit(x, weight, bias, True) jit_o = t_jit(x, weight, bias, True) o = t(x, weight, bias, True) self.assertEqual(o, jit_o) # only the executed branch is profiled and fused, so exactly one fusion # group is expected self.assertGraphContainsExactly(t_jit.graph_for(x, weight, bias, True), FUSION_GUARD, 1) @unittest.skipIf(not RUN_NVFUSER, "requires CUDA") @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING, "Requires fusion optimization pass to be effective") def test_scalar_tensor(self): x = torch.empty([], device="cuda", dtype=torch.float32) def t(x: torch.Tensor): o = x + 1.0 o = torch.nn.functional.relu(o) return o
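# Illustrative sketch (not part of the test suite): test_graph_for_with_missing_optimized_engine
# above warms the profiling executor up on one value of the bool flag and then asks
# for graph_for with the other value, so only one branch has an optimized plan when
# it is queried. The bare warm-up pattern in isolation looks like this; the shape
# and helper name are arbitrary assumptions, and a CUDA device is assumed.
def _sketch_branch_profiling():
    import torch

    def fn(x: torch.Tensor, flag: bool):
        if flag:
            return torch.relu(x + 1.0)
        return torch.relu(x + 2.0)

    fn_jit = torch.jit.script(fn)
    x = torch.randn(4, 4, device="cuda")
    for _ in range(3):
        fn_jit(x, False)   # only the flag == False branch gets profiled
    # the first call on the other branch has no optimized plan yet and falls
    # back to the unspecialized graph, but the result must still be correct
    assert torch.allclose(fn_jit(x, True), fn(x, True))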
t_jit = torch.jit.script(t) jit_o = t_jit(x) jit_o = t_jit(x) o = t(x) self.assertEqual(o, jit_o) # since the output value is not used at all, the fusion operator should # have been optimized away self.assertGraphContainsExactly(t_jit.graph_for(x), FUSION_GUARD, 1) @unittest.skipIf(os.environ.get('PYTORCH_NO_CUDA_MEMORY_CACHING') is not None, "skipping graph_rng when caching allocator is disabled") @unittest.skipIf(not RUN_NVFUSER, "requires CUDA") @unittest.skipIf(CUDA_MAJOR < 11, "requires CUDA11 or above") @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING, "Requires fusion optimization pass to be effective") def test_graph_rng(self): self.assertTrue(torch._C._jit_nvfuser_enabled()) size = 10000 a = torch.randn((size,), device="cuda", dtype=torch.float) def t(x): o = x + 1.0 o = torch.nn.functional.dropout(o, p=0.1) o = o + 1.0 o = torch.nn.functional.dropout(o, p=0.1) return o t_jit = torch.jit.script(t) for _ in range(3): t_jit(a) self.assertGraphContainsExactly(t_jit.graph_for(a), FUSION_GUARD, 1) # Control (jitted, ungraphed) torch.cuda.manual_seed(5) eager_out = a.clone() for _ in range(3): eager_out = t_jit(eager_out) graph_in = a.clone() g = torch.cuda.CUDAGraph() s = torch.cuda.Stream() s.wait_stream(torch.cuda.current_stream()) with torch.cuda.stream(s): torch.cuda.manual_seed(5) g.capture_begin() graph_out = t_jit(graph_in) g.capture_end() torch.cuda.current_stream().wait_stream(s) # g is now a jitted, graphed version of t. # Runs a (jitted, graphed) -> (jitted, ungraphed) -> (jitted, graphed) sequence. # The ops in the overall sequence should be the same as Control. g.replay() # graph_out is now filled with g's result. Use it as ungraphed input. out = t_jit(graph_out) graph_in.copy_(out) g.replay() # If replay() updated RNG state correctly, graph_out should now equal eager_out self.assertEqual(graph_out, eager_out) def _test_batch_norm_impl_index_helper(self, batch, c, hw, affine=True, track_running_stats=True, train=True, dtype=torch.float32): # enabling inlining to avoid counter increment in BN forward torch._C._debug_set_autodiff_subgraph_inlining(True) class MyModule(torch.nn.Module): def __init__(self, num_features=10, affine=True, track_running_stats=True): super(MyModule, self).__init__() self.bn = torch.nn.BatchNorm2d(num_features, 1e-5, affine=affine, track_running_stats=track_running_stats).to(dtype=dtype) def forward(self, x): o = self.bn(x) o = o * 2.0 return o x = torch.randn(batch, c, hw, hw, dtype=torch.float, device="cuda").to(dtype=dtype).requires_grad_() grad = torch.randint(-20, 20, (batch, c, hw, hw), device="cuda").to(dtype=dtype).div(-10) my_module = MyModule(c, affine, track_running_stats).cuda() ref_module = MyModule(c, affine, track_running_stats).cuda() if not train: my_module.eval() ref_module.eval() t_jit = torch.jit.script(my_module) ref_module.load_state_dict(my_module.state_dict()) ref_x = x.detach().requires_grad_() for i in range(0, 3): jit_o = t_jit(x) jit_o.backward(grad) # TODO: remove this run? o = ref_module(ref_x) o.backward(grad) has_affine = ref_module.bn.weight is not None has_running_stats = ref_module.bn.running_mean is not None if has_running_stats: my_module.bn.running_mean.zero_() my_module.bn.running_var.fill_(1.0) ref_module.bn.running_mean.zero_() ref_module.bn.running_var.fill_(1.0) # Verify that when train is False, we don't have grad for weight/bias. 
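# Illustrative sketch (not part of the test suite): test_graph_rng above captures
# the scripted dropout function into a CUDA graph and checks that replay() advances
# the RNG state the same way as the equivalent ungraphed calls. The capture/replay
# skeleton it builds on is shown below as a sketch; the helper name and arguments
# are made up, and the callable is assumed to have been warmed up before capture,
# as in the test.
def _sketch_cuda_graph_capture(fn_jit, example_input):
    import torch
    side_stream = torch.cuda.Stream()
    side_stream.wait_stream(torch.cuda.current_stream())
    graph = torch.cuda.CUDAGraph()
    with torch.cuda.stream(side_stream):
        graph.capture_begin()
        captured_out = fn_jit(example_input)
        graph.capture_end()
    torch.cuda.current_stream().wait_stream(side_stream)
    # each replay() re-runs the captured kernels and refreshes captured_out in place
    graph.replay()
    return captured_out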
if has_affine and train: my_module.bn.weight.grad.zero_() my_module.bn.bias.grad.zero_() ref_module.bn.weight.grad.zero_() ref_module.bn.bias.grad.zero_() x.grad.zero_() ref_x.grad.zero_() # real runs jit_o = t_jit(x) jit_o.backward(grad) o = ref_module(ref_x) o.backward(grad) # assert forward graph fusion self.assertGraphContainsExactly(t_jit.graph_for(x), FUSION_GUARD, 1, consider_subgraphs=True) # assert backward graph fusion bwd_graph = list( list(t_jit.get_debug_state().execution_plans.values())[0].code.grad_executor_states()[0] .execution_plans.values())[0].graph self.assertGraphContainsExactly(bwd_graph, FUSION_GUARD, 1, consider_subgraphs=True) e0 = 1e-5 if dtype is not torch.half else 1e-3 e1 = 1e-4 if dtype is not torch.half else 1e-3 e2 = 1e-3 if dtype is not torch.half else 1e-2 self.assertTrue(self._compare("comparing output failed", jit_o, o, e0)) self.assertTrue(self._compare("comparing input grad failed", x.grad, ref_x.grad, e1)) # TODO: switch to welford and reduce this to 1e-5 # The 1e-3 looks bad, but we don't have welford in codegen, so numeric # is very different between reference and codegen. if has_affine and train: self.assertTrue(self._compare("comparing weight grad failed", my_module.bn.weight.grad, ref_module.bn.weight.grad, e2)) self.assertTrue(self._compare("comparing bias grad failed", my_module.bn.bias.grad, ref_module.bn.bias.grad, e1)) if has_running_stats: self.assertTrue(self._compare("comparing running_mean failed", my_module.bn.running_mean, ref_module.bn.running_mean, e0)) self.assertTrue(self._compare("comparing running_var failed", my_module.bn.running_var, ref_module.bn.running_var, e0)) @unittest.skipIf(is_pre_volta(), "reduction not supported in pre volta device") @unittest.skipIf(not RUN_NVFUSER, "requires CUDA") @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING, "Requires fusion optimization pass to be effective") def test_batch_norm_half(self): with torch.backends.cudnn.flags(enabled=True): setups = [ [True, True], [False, False], [True, False], [False, True]] for training_and_track, affine in itertools.product(setups, [True, False]): training, track_running_stats = training_and_track self._test_batch_norm_impl_index_helper(4, 8, 5, affine, track_running_stats, training, torch.half) @unittest.skipIf(is_pre_volta(), "reduction not supported in pre volta device") @unittest.skipIf(not RUN_NVFUSER, "requires CUDA") @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING, "Requires fusion optimization pass to be effective") def test_batch_norm_impl_index_inner_bcast(self): # the repro self._test_batch_norm_impl_index_helper(2, 1, 1, False, True, True) # running the full set setups = [ [True, True], [False, False], [True, False], [False, True]] for training_and_track, affine in itertools.product(setups, [True, False]): training, track_running_stats = training_and_track self._test_batch_norm_impl_index_helper(2, 1, 1, affine, track_running_stats, training) @unittest.skipIf(is_pre_volta(), "reduction not supported in pre volta device") @unittest.skipIf(not RUN_NVFUSER, "requires CUDA") @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING, "Requires fusion optimization pass to be effective") def test_batch_norm_impl_index_correctness(self): with torch.backends.cudnn.flags(enabled=True): batch = [2, 7, 16] channels = [4, 89, 19, 32] hw = [1, 8, 17, 32] # avoid tolerance failure in CI torch.cuda.manual_seed_all(211) # failing sizes (2, 1, 1, 1) # failing sizes (2, 89, 8, 8) training False, track True, affine: False for b, c, hw in 
itertools.product(batch, channels, hw): setups = [ [True, True], [False, False], [True, False], [False, True]] for training_and_track, affine in itertools.product(setups, [True, False]): training, track_running_stats = training_and_track self._test_batch_norm_impl_index_helper(b, c, hw, affine, track_running_stats, training) @unittest.skipIf(not RUN_NVFUSER, "requires CUDA") @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING, "Requires fusion optimization pass to be effective") def test_softplus_fuser(self): def shifted_softplus(x: torch.Tensor, shift: float): return functional.softplus(x) - shift jitted = torch.jit.script(shifted_softplus) inp = torch.randn(4, 2, dtype=torch.float32, device="cuda").requires_grad_() inp_ref = inp.detach().clone().requires_grad_() grad = torch.randn(4, 2, dtype=torch.float32, device="cuda") aten_o = shifted_softplus(inp_ref, 0.693147) aten_o.backward(grad) aten_grad = inp_ref.grad for i in range(3): jit_o = jitted(inp, 0.693147) inp.grad = None # avoid accumulation on grad jit_o.backward(grad) jit_grad = inp.grad assert torch.allclose(jit_o, aten_o) assert torch.allclose(jit_grad, aten_grad) self.assertGraphContains(jitted.graph_for(inp, 0.693147), FUSION_GROUP, True) @unittest.skipIf(not RUN_NVFUSER, "requires CUDA") @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING, "Requires fusion optimization pass to be effective") def test_inplace_removal(self): def t(x: torch.Tensor): o = torch.nn.functional.softmax(x, dim=0) o += x return o.relu_() jitted = torch.jit.script(t) inp = torch.randn(4, 2, dtype=torch.float32, device="cuda") for i in range(3): jit_o = jitted(inp) graph = jitted.graph_for(inp) self.assertGraphContains(graph, FUSION_GROUP, True) self.assertGraphContains(graph, 'aten::add', True) self.assertGraphContains(graph, 'aten::relu', True) @unittest.skipIf(not RUN_NVFUSER, "requires CUDA") @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING, "Requires fusion optimization pass to be effective") def test_conv2d_bias(self): def t(x: torch.Tensor, w: torch.Tensor, bias: torch.Tensor): o = torch.nn.functional.conv2d(x, w, bias) return o.relu() jitted = torch.jit.script(t) inp = torch.randn(4, 5, 3, 3, dtype=torch.float32, device="cuda") weight = torch.randn(2, 5, 2, 2, dtype=torch.float32, device="cuda") bias = torch.randn(2, dtype=torch.float32, device="cuda") for i in range(3): jit_o = jitted(inp, weight, bias) graph = jitted.graph_for(inp) self.assertGraphContains(graph, FUSION_GROUP, True) def t_not_fused(x: torch.Tensor, w: torch.Tensor): o = torch.nn.functional.conv2d(x, w) return o.relu() jitted_not_fused = torch.jit.script(t_not_fused) for i in range(3): jit_o = jitted_not_fused(inp, weight) graph = jitted_not_fused.graph_for(inp) self.assertGraphContainsExactly(graph, FUSION_GROUP, 0) self.assertGraphContains(graph, 'aten::relu', True) def t_bias(x: torch.Tensor, w: torch.Tensor, bias: torch.Tensor): o = torch.nn.functional.conv2d(x, w, bias) return o.relu() jitted_bias = torch.jit.script(t_bias) for i in range(3): jit_o = jitted_bias(inp, weight, bias) graph = jitted_bias.graph_for(inp) self.assertGraphContains(graph, FUSION_GROUP, True) self.assertGraphContains(graph, 'prim::add_optional', True) @unittest.skipIf(is_pre_volta(), "reduction not supported in pre volta device") @unittest.skipIf(not RUN_NVFUSER, "requires CUDA") @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING, "Requires fusion optimization pass to be effective") def test_remove_output_used_only_in_dtype(self): class MyModule(torch.nn.Module): def 
__init__(self, num_features=4): super(MyModule, self).__init__() self.bn0 = torch.nn.BatchNorm2d(num_features) self.bn1 = torch.nn.BatchNorm2d(num_features) def forward(self, x, y): o1 = self.bn0(x) o2 = self.bn1(y) return torch.relu(o1 + o2) t = MyModule(4).float().cuda() jitted = torch.jit.script(t) x = torch.randn(3, 4, 2, 5, dtype=torch.float32, device="cuda") y = torch.randn(3, 4, 2, 5, dtype=torch.float32, device="cuda") with torch.cuda.amp.autocast(True): for i in range(5): jit_o = jitted(x, y) jit_o = jitted(x, y) o = t(x, y) self.assertTrue(torch.allclose(jit_o, o)) graph = jitted.graph_for(x, y) self.assertGraphContains(graph, FUSION_GROUP, True) @unittest.skipIf(is_pre_volta(), "reduction not supported in pre volta device") @unittest.skipIf(not RUN_NVFUSER, "requires CUDA") @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING, "Requires fusion optimization pass to be effective") def test_fix_shape_expression_bn(self): class MyModule(torch.nn.Module): def __init__(self, num_features=4): super(MyModule, self).__init__() self.bn = torch.nn.BatchNorm2d(num_features) def forward(self, x, y): out1 = self.bn(x) out2 = out1 + y out3 = torch.relu(out2) return out3 t = MyModule(4).float().cuda() jitted = torch.jit.script(t) x = torch.randn(3, 4, 2, 5, dtype=torch.float32, device="cuda") y = torch.randn(3, 4, 2, 5, dtype=torch.float32, device="cuda") with torch.cuda.amp.autocast(True): for i in range(5): jit_o = jitted(x, y) jit_o = jitted(x, y) o = t(x, y) self.assertTrue(torch.allclose(jit_o, o)) graph = jitted.graph_for(x, y) self.assertGraphContains(graph, FUSION_GROUP, True) def _run_fwd_helper(self, func, ops, *args): jitted = torch.jit.script(func) for i in range(3): jit_o = jitted(*args) jit_o = jitted(*args) o = func(*args) for oo, jit_oo in zip(o, jit_o): self.assertEqual(oo.dtype, jit_oo.dtype) self.assertEqual(oo, jit_oo) graph = jitted.graph_for(*args) self.assertGraphContains(graph, FUSION_GROUP, True) for op in ops: self.assertGraphContainsExactly(graph, op, 0) @unittest.skipIf(is_pre_volta(), "reduction not supported in pre volta device") @unittest.skipIf(not RUN_NVFUSER, "requires CUDA") @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING, "Requires fusion optimization pass to be effective") def test_sibling_fusion(self): device = "cuda" dtype = torch.float x = torch.randn(2, 5, dtype=dtype, device=device) y = torch.randn(2, 5, dtype=dtype, device=device) def t(x: torch.Tensor): o1 = x + 1.0 o2 = x * 0.5 return o1, o2 self._run_fwd_helper(t, ['aten::add', 'aten::mul'], x) def t2(x: torch.Tensor, y: torch.Tensor): o1 = x.sum(0) o2 = (x * y).sum(0) return o1, o2 self._run_fwd_helper(t2, ['aten::sum', 'aten::mul'], x, y) @unittest.skipIf(not RUN_NVFUSER, "requires CUDA") @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING, "Requires fusion optimization pass to be effective") def test_clean_profile_ivalue(self): device = "cuda" dtype = torch.float x = torch.randn(2, 5, dtype=dtype, device=device, requires_grad=True) # turn on autodiff subgraph inlining # this is to verify that we clean up profile_ivalue node out side of # fusion code path. 
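# Note: profile_ivalue nodes are how the profiling executor records observed
# non-tensor inputs (here the boolean `flag` fed to torch.dropout). The loop
# below warms up with flag=True and then calls with flag=False, so stale
# profile_ivalue nodes must be cleaned up even when they end up outside of a
# CudaFusionGroup.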
torch._C._debug_set_autodiff_subgraph_inlining(True) def t(x: torch.Tensor, flag: bool): return torch.dropout(x, 0.5, flag) jit_t = torch.jit.script(t) for idx in range(5): out = jit_t(x, True) graph = jit_t.graph_for(x, True) out = jit_t(x, False) @unittest.skipIf(not RUN_NVFUSER, "requires CUDA") @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING, "Requires fusion optimization pass to be effective") def test_sibling_fusion_no_scalar_inputs(self): device = "cuda" dtype = torch.float x = torch.randn(2, 5, dtype=dtype, device=device) y = torch.randn(3, dtype=dtype, device=device) # no tensor dependency between o1/o2, we shouldn't be fusing them def t(x: torch.Tensor, y: torch.Tensor): o1 = x + 1 o2 = y - 1 return o1, o2 jitted = torch.jit.script(t) for i in range(3): jit_o = jitted(x, y) graph = jitted.graph_for(x, y) self.assertGraphContainsExactly(graph, FUSION_GROUP, 0) def _bias_view_relu_helper(self, shape, output_shape, dtype, device, error): class BiasViewRelu(torch.nn.Module): def __init__(self): super(BiasViewRelu, self).__init__() self.bias = torch.nn.Parameter(torch.randn(shape, dtype=dtype, device=device), requires_grad=False) with torch.no_grad(): self.bias.fill_(10) def forward(self, inputs: torch.Tensor, view_shape: List[int]): o = inputs + self.bias o = o.view(view_shape) return torch.relu(o) t = BiasViewRelu() x = torch.randn(shape, dtype=dtype, device=device, requires_grad=False) t_jit = torch.jit.script(t) # profiling jit_o = t_jit(x, output_shape) # optimization jit_o = t_jit(x, output_shape) # final jit_o = t_jit(x, output_shape) # eager - baseline o = t(x, output_shape) self.assertEqual(o.dtype, jit_o.dtype) self.assertTrue(self._compare("comparing output failed", o, jit_o, error)) graph = t_jit.graph_for(x, output_shape) has_inferred_dimension = any([dim == -1 for dim in output_shape]) if has_inferred_dimension: # prohibit fusing when view_shape contains an inferred dimension self.assertGraphContainsExactly(graph, FUSION_GROUP, 0) self.assertGraphContainsExactly(graph, 'prim::view_copy', 0) else: self.assertGraphContains(graph, FUSION_GUARD) self.assertGraphContains(graph, 'prim::view_copy', True) def _alias_bias_view_relu_helper(self, shape, output_shape, dtype, device, error): class BiasViewRelu(torch.nn.Module): def __init__(self): super(BiasViewRelu, self).__init__() self.bias = torch.nn.Parameter(torch.randn(shape, dtype=dtype, device=device), requires_grad=False) with torch.no_grad(): self.bias.fill_(10) def forward(self, inputs : torch.Tensor, bias : torch.Tensor, view_shape : List[int]): o = inputs.view(view_shape) inputs.add_(bias) return torch.relu(o) t = BiasViewRelu() x = torch.randn(shape, dtype=dtype, device=device, requires_grad=False) bias = torch.randn(shape, dtype=dtype, device=device, requires_grad=False) t_jit = torch.jit.script(t) # profiling jit_o = t_jit(x.clone(), bias, output_shape) # optimization jit_o = t_jit(x.clone(), bias, output_shape) # final jit_o = t_jit(x.clone(), bias, output_shape) # eager - baseline o = t(x.clone(), bias, output_shape) self.assertEqual(o.dtype, jit_o.dtype) self.assertTrue(self._compare("comparing output failed", o, jit_o, error)) graph = t_jit.graph_for(x, bias, output_shape) self.assertGraphContainsExactly(graph, FUSION_GUARD, 0) self.assertGraphContainsExactly(graph, 'prim::view_copy', 0) # generate random view given original view def _random_view(self, original_view, max_len=8, max_views=10000): class Moves(enum.Enum): Merge = 0 Split = 1 Broadcast = 2 ImplicitBroadcast = 3 Keep = 4 def valid(old_view, 
new_view): old_view_size = reduce(operator.mul, old_view) new_view_size = reduce(operator.mul, new_view) return old_view_size == new_view_size # given a random starting number, find the nearest divisor def find_nearest_divisor(N): if 2 >= (N - 1): return -1 result = random.randint(2, N - 1) while (N % result) != 0: result += 1 return result complete_views = set([tuple(original_view)]) to_visit = [] # empty new view, curent originaal view, start pos=0, move count = 0, last_move to_visit.append(([], original_view, 0, [], Moves.Keep)) # depth-first search of view shapes, starting from the original view while len(to_visit) > 0 and len(complete_views) < max_views: new_view, old_view, odx, move_list, last_move = to_visit[-1] to_visit.pop() # iterate over each move type for idx in range(len(Moves)): state = Moves(idx) new_view_clone = copy.deepcopy(new_view) old_view_clone = copy.deepcopy(old_view) new_move_list = move_list + [state] new_odx = odx # Update state using Move state if state == Moves.Keep: new_size = old_view_clone[odx] new_view_clone.append(new_size) new_odx += 1 elif state == Moves.Merge: if odx + 1 < len(old_view_clone): new_size = old_view_clone[odx] * old_view_clone[odx + 1] new_view_clone.append(new_size) new_odx += 2 else: continue elif state == Moves.Broadcast and last_move != Moves.Broadcast: new_view_clone.append(1) elif state == Moves.Split: new_size = find_nearest_divisor(old_view_clone[odx]) if new_size == -1: continue new_view_clone.append(new_size) old_view_clone[odx] = int(old_view[odx] / new_size) if old_view_clone[odx] == 1: new_odx += 1 elif state == Moves.ImplicitBroadcast: old_view_clone.insert(odx + 1, 1) new_size = old_view[odx] * 1 new_view_clone.append(new_size) new_odx += 2 if new_odx < len(old_view_clone) and len(new_move_list) < max_len: to_visit.append((new_view_clone, old_view_clone, new_odx, new_move_list, state)) elif (valid(original_view, new_view_clone)): final_new_view = tuple(new_view_clone) complete_views.add(final_new_view) return list(complete_views) # ndims - number of dimensions # test_fn - view test function def _view_test_generator(self, ndims, test_fn): # create random tensor # max value for each dimension max_size = 10e7 max_value = max(int(pow(max_size, 1. 
/ ndims)), 1) sizes = [random.randint(1, max_value) for idx in range(ndims)] x = torch.randn(sizes) original_sizes = list(x.size()) all_views = self._random_view(original_sizes) random.shuffle(all_views) max_samples = 20 max_views = min(len(all_views), max_samples) total = 0 correct = 0 # test random combinations of compatible views for idx in range(max_views): for jdx in range(idx + 1, max_views): total += 1 test_fn(all_views[idx], all_views[jdx], torch.float, 'cuda', 1e-6) @unittest.skipIf(ALIAS_TEST_DISABLED, "skipping this test since view is disabled now") @unittest.skipIf(not RUN_NVFUSER, "requires CUDA") @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING, "Requires fusion optimization pass to be effective") def test_view(self): torch._C._jit_set_nvfuser_guard_mode(True) self._bias_view_relu_helper([2, 3, 4, 5], [-1, 4, 5], torch.float, 'cuda', 1e-6) for ndims in range(1, 5): self._view_test_generator(ndims, self._bias_view_relu_helper) self._alias_bias_view_relu_helper([2, 3, 4, 5], [1, 6, 1, 2, 2, 5, 1], torch.float, 'cuda', 1e-6) def _bias_flatten_relu_helper(self, shape, start_dim, end_dim, dtype, device, error): class BiasFlattenRelu(torch.nn.Module): def __init__(self): super(BiasFlattenRelu, self).__init__() self.bias = torch.nn.Parameter(torch.randn(shape, dtype=dtype, device=device), requires_grad=False) with torch.no_grad(): self.bias.fill_(10) def forward(self, inputs : torch.Tensor, start_dim : int, end_dim : int): o = inputs + self.bias o = o.flatten(start_dim, end_dim) return torch.relu(o) t = BiasFlattenRelu() x = torch.randn(shape, dtype=dtype, device=device, requires_grad=False) t_jit = torch.jit.script(t) self._run_helper(t_jit, t, x, start_dim, end_dim) self.assertGraphContains(t_jit.graph_for(x, start_dim, end_dim), 'prim::flatten_copy', True) def _alias_bias_flatten_relu_helper(self, shape, start_dim, end_dim, dtype, device, error): class BiasFlattenRelu(torch.nn.Module): def __init__(self): super(BiasFlattenRelu, self).__init__() self.bias = torch.nn.Parameter(torch.randn(shape, dtype=dtype, device=device), requires_grad=False) with torch.no_grad(): self.bias.fill_(10) def forward(self, inputs : torch.Tensor, bias : torch.Tensor, start_dim : int, end_dim : int): o = inputs.flatten(start_dim, end_dim) inputs.add_(bias) return torch.relu(o) t = BiasFlattenRelu() x = torch.randn(shape, dtype=dtype, device=device, requires_grad=False) bias = torch.randn(shape, dtype=dtype, device=device, requires_grad=False) t_jit = torch.jit.script(t) # profiling jit_o = t_jit(x.clone(), bias, start_dim, end_dim) # optimization jit_o = t_jit(x.clone(), bias, start_dim, end_dim) # final jit_o = t_jit(x.clone(), bias, start_dim, end_dim) # eager - baseline o = t(x.clone(), bias, start_dim, end_dim) self.assertEqual(o.dtype, jit_o.dtype) self.assertTrue(self._compare("comparing output failed", o, jit_o, error)) graph = t_jit.graph_for(x, bias, start_dim, end_dim) self.assertGraphContainsExactly(graph, FUSION_GUARD, 0) self.assertGraphContainsExactly(graph, 'prim::flatten_copy', 0) @unittest.skipIf(ALIAS_TEST_DISABLED, "skipping this test since flatten is disabled now") @unittest.skipIf(not RUN_NVFUSER, "requires CUDA") @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING, "Requires fusion optimization pass to be effective") def test_flatten(self): torch._C._jit_set_nvfuser_guard_mode(True) self._bias_flatten_relu_helper([2, 3, 4, 5], 0, -1, torch.float, 'cuda', 1e-6) self._bias_flatten_relu_helper([2, 3, 4, 5], 1, -1, torch.float, 'cuda', 1e-6) 
self._bias_flatten_relu_helper([2, 3, 4, 5], 2, -1, torch.float, 'cuda', 1e-6) self._bias_flatten_relu_helper([2, 3, 4, 5], 0, 3, torch.float, 'cuda', 1e-6) self._bias_flatten_relu_helper([2, 3, 4, 5], 1, 2, torch.float, 'cuda', 1e-6) self._bias_flatten_relu_helper([2, 3, 4, 5], 2, 2, torch.float, 'cuda', 1e-6) self._alias_bias_flatten_relu_helper([2, 3, 4, 5], 0, -1, torch.float, 'cuda', 1e-6) self._alias_bias_flatten_relu_helper([2, 3, 4, 5], 1, -1, torch.float, 'cuda', 1e-6) self._alias_bias_flatten_relu_helper([2, 3, 4, 5], 2, -1, torch.float, 'cuda', 1e-6) self._alias_bias_flatten_relu_helper([2, 3, 4, 5], 0, 3, torch.float, 'cuda', 1e-6) self._alias_bias_flatten_relu_helper([2, 3, 4, 5], 1, 2, torch.float, 'cuda', 1e-6) self._alias_bias_flatten_relu_helper([2, 3, 4, 5], 2, 2, torch.float, 'cuda', 1e-6) @unittest.skipIf(not RUN_NVFUSER, "requires CUDA") @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING, "Requires fusion optimization pass to be effective") def test_strict_fusion(self): def success(x): with torch.jit.strict_fusion(): return x + x + x scripted = self.checkScript(success, (torch.rand([4], device='cuda'),)) g = torch.jit.last_executed_optimized_graph() FileCheck().check_not("aten::add").check("prim::CudaFusionGroup").run(g) def failure(x): with torch.jit.strict_fusion(): return x + torch.mm(x, x) + x with self.assertRaises(Exception) as error_out: foo_s = torch.jit.script(failure) foo_s(torch.rand([4, 4])) foo_s(torch.rand([4, 4])) fc = FileCheck().check("Found unfused operators") fc.check("aten::mm").run(str(error_out.exception)) def _ltc_helper(self, shape, dtype, device, error, approximate=True): # modeled after LTC linear layer class LTC(torch.nn.Module): def __init__(self): super(LTC, self).__init__() self.weight = torch.nn.Parameter(torch.randn([1024, 1024], dtype=dtype, device=device), requires_grad=False) self.bias = torch.nn.Parameter(torch.randn([1, 1024], dtype=dtype, device=device), requires_grad=False) def forward(self, inputs : torch.Tensor): o = inputs.view([32768, 1024]) o = torch.mm(o, self.weight) o = o.view([256, 128, 1024]) o = o + self.bias o = o.view([32768, 1024]) o = o.view([256, 128, 1024]) return torch.nn.functional.gelu(o) t = LTC() x = torch.randn(shape, dtype=dtype, device=device, requires_grad=False) t_jit = torch.jit.script(t) # profile/optimization runs for i in range(3): jit_o = t_jit(x) o = t(x) self.assertEqual(o.dtype, jit_o.dtype) self.assertTrue(self._compare("comparing output failed", o, jit_o, error)) graph = t_jit.graph_for(x) self.assertGraphContains(graph, FUSION_GUARD) self.assertGraphContains(graph, 'prim::view_copy', True) @unittest.skipIf(ALIAS_TEST_DISABLED, "skipping this test since view is disabled now") @unittest.skipIf(not RUN_NVFUSER, "requires CUDA") @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING, "Requires fusion optimization pass to be effective") def test_nested_view(self): self._ltc_helper([256, 128, 1024], torch.float, 'cuda', 1e-6) def _bias_squeeze_relu_helper(self, shape, dtype, device, error): class BiasSqueezeRelu(torch.nn.Module): def __init__(self): super(BiasSqueezeRelu, self).__init__() def forward(self, inputs: torch.Tensor, bias: torch.Tensor): o = inputs + bias o = torch.squeeze(o) return torch.relu(o) t = BiasSqueezeRelu() x = torch.randn(shape, dtype=dtype, device=device, requires_grad=False) bias = torch.randn(shape, dtype=dtype, device=device, requires_grad=False) t_jit = torch.jit.script(t) jit_o = t_jit(x, bias) jit_o = t_jit(x, bias) jit_o = t_jit(x, bias) o = t(x, bias) 
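# The three consecutive t_jit calls above follow the profiling-executor warm-up
# pattern used throughout this file (profiling run, optimization run, final run
# on the fused plan); the eager call to t() provides the baseline for the
# comparisons below.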
self.assertEqual(o.dtype, jit_o.dtype) self.assertTrue(self._compare("comparing output failed", o, jit_o, error)) graph = t_jit.graph_for(x, bias) self.assertGraphContains(graph, FUSION_GUARD) self.assertGraphContains(graph, 'prim::squeeze_copy', True) def _alias_bias_squeeze_relu_helper(self, shape, dtype, device, error): class BiasSqueezeRelu(torch.nn.Module): def __init__(self): super(BiasSqueezeRelu, self).__init__() def forward(self, inputs: torch.Tensor, bias: torch.Tensor): o = torch.squeeze(inputs) inputs.add_(bias) return torch.relu(o) t = BiasSqueezeRelu() x = torch.randn(shape, dtype=dtype, device=device, requires_grad=False) bias = torch.randn(shape, dtype=dtype, device=device, requires_grad=False) t_jit = torch.jit.script(t) jit_o = t_jit(x.clone(), bias) jit_o = t_jit(x.clone(), bias) jit_o = t_jit(x.clone(), bias) o = t(x.clone(), bias) self.assertEqual(o.dtype, jit_o.dtype) self.assertTrue(self._compare("comparing output failed", o, jit_o, error)) graph = t_jit.graph_for(x, bias) self.assertGraphContainsExactly(graph, FUSION_GUARD, 0) self.assertGraphContainsExactly(graph, 'prim::squeeze_copy', 0) @unittest.skipIf(ALIAS_TEST_DISABLED, "skipping this test since squeeze/unsqueeze is disabled now") @unittest.skipIf(not RUN_NVFUSER, "requires CUDA") @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING, "Requires fusion optimization pass to be effective") def test_squeeze(self): self._bias_squeeze_relu_helper([1, 6, 1, 2, 2, 5, 1], torch.float, 'cuda', 1e-6) self._alias_bias_squeeze_relu_helper([1, 6, 1, 2, 2, 5, 1], torch.float, 'cuda', 1e-6) @unittest.skipIf(ALIAS_TEST_DISABLED, "skipping this test since squeeze/unsqueeze is disabled now") # remove this after opinfo tests are enabled @unittest.skipIf(not RUN_NVFUSER, "requires CUDA") @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING, "Requires fusion optimization pass to be effective") def test_squeeze_zero(self): x = torch.tensor(1.0, dtype=torch.float, device="cuda") def squeeze_0(x: torch.Tensor): o = x + 1. o = torch.squeeze(o, 0) o = o * 2. return o def squeeze_1(x: torch.Tensor): o = x + 1. 
o = torch.squeeze(o, -1) o = o + .5 return o squeeze_0_jit = torch.jit.script(squeeze_0) self._run_helper(squeeze_0_jit, squeeze_0, x) squeeze_1_jit = torch.jit.script(squeeze_1) self._run_helper(squeeze_1_jit, squeeze_1, x) def _bias_unsqueeze_relu_helper(self, shape, dtype, device, error): class BiasUnsqueezeRelu(torch.nn.Module): def __init__(self): super(BiasUnsqueezeRelu, self).__init__() def forward(self, inputs: torch.Tensor, bias: torch.Tensor): o = inputs + bias o = torch.unsqueeze(o, 0) return torch.relu(o) t = BiasUnsqueezeRelu() x = torch.randn(shape, dtype=dtype, device=device, requires_grad=False) bias = torch.randn(shape, dtype=dtype, device=device, requires_grad=False) t_jit = torch.jit.script(t) jit_o = t_jit(x, bias) jit_o = t_jit(x, bias) jit_o = t_jit(x, bias) o = t(x, bias) self.assertEqual(o.dtype, jit_o.dtype) self.assertTrue(self._compare("comparing output failed", o, jit_o, error)) graph = t_jit.graph_for(x, bias) self.assertGraphContains(graph, FUSION_GUARD) self.assertGraphContains(graph, 'prim::unsqueeze_copy', True) def _alias_bias_unsqueeze_relu_helper(self, shape, dtype, device, error): class BiasUnsqueezeRelu(torch.nn.Module): def __init__(self): super(BiasUnsqueezeRelu, self).__init__() def forward(self, inputs : torch.Tensor, bias : torch.Tensor): o = torch.unsqueeze(inputs, 0) inputs.add_(bias) return torch.relu(o) t = BiasUnsqueezeRelu() x = torch.randn(shape, dtype=dtype, device=device, requires_grad=False) bias = torch.randn(shape, dtype=dtype, device=device, requires_grad=False) t_jit = torch.jit.script(t) jit_o = t_jit(x.clone(), bias) jit_o = t_jit(x.clone(), bias) jit_o = t_jit(x.clone(), bias) o = t(x.clone(), bias) self.assertEqual(o.dtype, jit_o.dtype) self.assertTrue(self._compare("comparing output failed", o, jit_o, error)) graph = t_jit.graph_for(x, bias) self.assertGraphContainsExactly(graph, FUSION_GUARD, 0) self.assertGraphContainsExactly(graph, 'prim::unsqueeze_copy', 0) @unittest.skipIf(ALIAS_TEST_DISABLED, "skipping this test since squeeze/unsqueeze is disabled now") @unittest.skipIf(not RUN_NVFUSER, "requires CUDA") @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING, "Requires fusion optimization pass to be effective") def test_unsqueeze(self): self._bias_unsqueeze_relu_helper([2, 3, 4, 5], torch.float, 'cuda', 1e-6) self._alias_bias_unsqueeze_relu_helper([2, 3, 4, 5], torch.float, 'cuda', 1e-6) @unittest.skipIf(ALIAS_TEST_DISABLED, "skipping this test since unsqueeze is disabled now") @unittest.skipIf(not RUN_NVFUSER, "requires CUDA") @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING, "Requires fusion optimization pass to be effective") def test_alias_pass_fix(self): x = torch.randn(4, 24, 2, 2, dtype=torch.float, device="cuda") w = torch.randn(24, 24, 1, 1, dtype=torch.float, device="cuda") b = torch.randn(24, dtype=torch.float, device="cuda") def t(x, w, b): b2 = b + 1.0 o = torch.conv2d(x, w, b2) return o t_jit = torch.jit.script(t) self._run_helper(t_jit, t, x, w, b) @unittest.skipIf(ALIAS_TEST_DISABLED, "skipping this test since squeeze/unsqueeze is disabled now") @unittest.skipIf(not RUN_NVFUSER, "requires CUDA") @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING, "Requires fusion optimization pass to be effective") def test_squeeze_negative_dim(self): x = torch.randn(4, 24, 1, 2, dtype=torch.float, device="cuda") def t(x): o = x + 1.0 o = o.squeeze(-2) o = o * 2.0 return o t_jit = torch.jit.script(t) self._run_helper(t_jit, t, x) @unittest.skipIf(not RUN_NVFUSER, "requires CUDA") 
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING, "Requires fusion optimization pass to be effective") def test_singleton_fusion(self): x = torch.randn(4, 2, device="cuda") with nvfuser_singleton_fusion(True): def t(x): return x.relu() t_jit = torch.jit.script(t) self._run_helper(t_jit, t, x) @unittest.skipIf(not RUN_NVFUSER, "requires CUDA") @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING, "Requires fusion optimization pass to be effective") def test_issue1445_fusion(self): def f(t0, t1, t2, t3): masked_input = torch.where(t1, t2, t3) total = masked_input.sum([0, 1, 2, 3]) sizes : List[int] = [] t10 = torch.reshape(t0, sizes) t7 = total / t10 t4 = t7.to(dtype=torch.float) return t4 x = torch.randn(1, 1, 1, 1, device='cuda').to(dtype=torch.long) y = torch.randn(3, 2, 1, 1, device='cuda').to(dtype=torch.bool).expand([3, 2, 1, 2]) z = torch.randn(3, 2, 1, 2, device='cuda') w = torch.tensor(1.5, device='cuda') f_jit = torch.jit.script(f) for i in range(5): out_jit = f_jit(x, y, z, w) out = f(x, y, z, w) self.assertEqual(out, out_jit) self.assertGraphContainsExactly(f_jit.graph_for(x, y, z, w), FUSION_GROUP, 1) @unittest.skipIf(not RUN_NVFUSER, "requires CUDA") @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING, "Requires fusion optimization pass to be effective") def test_disable_sibling_fuse(self): x = torch.randn(4, 2, device="cuda") y = torch.randn(8, device="cuda") s = torch.tensor(1.5, device="cuda") with nvfuser_horizontal_fusion(False): def t(x, y, s): o1 = x + s o2 = y + s return o1, o2 t_jit = torch.jit.script(t) for i in range(5): t_jit(x, y, s) # sibling fusion should be disabled with the flag self.assertGraphContainsExactly(t_jit.graph_for(x, y, s), FUSION_GUARD, 0) @unittest.skipIf(not RUN_NVFUSER, "requires CUDA") @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING, "Requires fusion optimization pass to be effective") def test_build_shape_expression_native_dropout(self): x = torch.randn(4, 2, device="cuda") def t(x): o, mask = torch.native_dropout(x, 0.0, True) o1 = o.sigmoid() o2 = mask.float().sigmoid() return (o1, o2) t_jit = torch.jit.script(t) jit_o = t_jit(x) jit_o = t_jit(x) o = t(x) for oo, jit_oo in zip(o, jit_o): self.assertEqual(oo.dtype, jit_oo.dtype) self.assertEqual(oo, jit_oo) self.assertGraphContains(t_jit.graph_for(x), FUSION_GUARD) @unittest.skipIf(not RUN_NVFUSER, "requires CUDA") @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING, "Requires fusion optimization pass to be effective") def test_scalar_tensor_permuted(self): x = torch.randn(4, 2, 3, device="cuda").permute([1, 2, 0]) y = torch.tensor(1.0, device="cuda") with nvfuser_singleton_fusion(True): def t(x, y): return x + y t_jit = torch.jit.script(t) self._run_helper(t_jit, t, x, y) @unittest.skipIf(not RUN_NVFUSER, "requires CUDA") @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING, "Requires fusion optimization pass to be effective") def test_cpu_scalar(self): x = torch.randn(4, 2, 3, device="cuda") y = torch.tensor(1.0, device="cpu") z = torch.tensor(2.0, device="cpu") with nvfuser_singleton_fusion(True): # testing cpu scalar tensor promotion def t(x, y): return x + y t_jit = torch.jit.script(t) self._run_helper(t_jit, t, x, y) # scalar cpu tensor add should NOT be fused @torch.jit.script def t1(y, z): return y * z for _ in range(5): t1(y, z) self.assertGraphContainsExactly(t1.graph_for(y, z), FUSION_GUARD, 0) # everything, including scalar cpu tensor add should be fused @torch.jit.script def t2(x, y, z): tmp = y + z return tmp + x for _ in range(5): 
t2(x, y, z) self.assertGraphContainsExactly(t2.graph_for(x, y, z), 'aten::add', 0) self.assertGraphContainsExactly(t2.graph_for(x, y, z), FUSION_GUARD, 1) # 'cpu_tmp = y + z' shouldn't be fused. @torch.jit.script def t3(x, y, z): cpu_tmp = y + z out = x + y return cpu_tmp, out for _ in range(5): t3(x, y, z) self.assertGraphContainsExactly(t3.graph_for(x, y, z), FUSION_GUARD, 1) self.assertGraphContainsExactly(t3.graph_for(x, y, z), 'aten::add', 1) @unittest.skipIf(ALIAS_TEST_DISABLED, "skipping this test since squeeze/unsqueeze is disabled now") @unittest.skipIf(not RUN_NVFUSER, "requires CUDA") @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING, "Requires fusion optimization pass to be effective") def test_shape_expression(self): x = torch.randn(4, 2, 1, 3, device="cuda") def t_unsqueeze(x): t0 = x.relu() t1 = t0.unsqueeze(1) t2 = t1 + 1.0 t3 = t1.size() return t2, t3 def t_squeeze(x): t0 = x.relu() t1 = t0.squeeze() t2 = t1 + 1.0 t3 = t1.size() return t2, t3 def t_squeeze_dim(x): t0 = x.relu() t1 = t0.squeeze(-2) t2 = t1 + 1.0 t3 = t1.size() return t2, t3 # squeezing a non-size 1 dimension should be a no op def t_squeeze_dim_no_op(x): t0 = x.relu() t1 = t0.squeeze(1) t2 = t1 + 1.0 t3 = t1.size() return t2, t3 def run(fn): jit_fn = torch.jit.script(fn) jit_o = jit_fn(x) jit_o = jit_fn(x) jit_o = jit_fn(x) o = fn(x) # output 0 is a tensor, so we check dtype and value self.assertEqual(o[0].dtype, jit_o[0].dtype) self.assertEqual(o[0], jit_o[0]) # output 1 is shape self.assertEqual(o[1], jit_o[1]) self.assertGraphContainsExactly(jit_fn.graph_for(x), FUSION_GUARD, 1) for t in [t_unsqueeze, t_squeeze, t_squeeze_dim, t_squeeze_dim_no_op]: run(t) @unittest.skipIf(not RUN_NVFUSER, "requires CUDA") @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING, "Requires fusion optimization pass to be effective") def test_scalar_cuda_tensor(self): x = torch.tensor(2.0, device="cuda") with nvfuser_singleton_fusion(True): def t(x): return x + 1.0 t_jit = torch.jit.script(t) self._run_helper(t_jit, t, x) @torch.jit.script def t_jitted(x): return x.sum(0) for i in range(5): t_jitted(x) self.assertGraphContainsExactly(t_jitted.graph_for(x), FUSION_GUARD, 0) @unittest.skipIf(not RUN_NVFUSER, "requires CUDA") @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING, "Requires fusion optimization pass to be effective") def test_overlapped_input(self): x = torch.randn(8, device="cuda").as_strided((2, 4), (1, 1)) with nvfuser_singleton_fusion(True): def t(x): return x + 1.0 t_jit = torch.jit.script(t) self._run_helper(t_jit, t, x) @unittest.skipIf(not RUN_NVFUSER, "requires CUDA") @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING, "Requires fusion optimization pass to be effective") @unittest.skipIf(is_pre_volta(), "reduction not supported in pre volta device") def test_reduction_empty_axes(self): x = torch.randn(4, 2, 3, device="cuda").permute([1, 2, 0]) with nvfuser_singleton_fusion(True): def t(x): sizes : List[int] = [] return x.sum(sizes) t_jit = torch.jit.script(t) self._run_helper(t_jit, t, x) @unittest.skipIf(not RUN_NVFUSER, "requires CUDA") @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING, "Requires fusion optimization pass to be effective") @unittest.skipIf(is_pre_volta(), "reduction not supported in pre volta device") def test_int_tensor_input(self): x = torch.randn(4, 2, device="cuda").to(dtype=torch.int) with nvfuser_singleton_fusion(True): def t(x): return x.amax(dim=0) t_jit = torch.jit.script(t) self._run_helper(t_jit, t, x) @unittest.skipIf(not RUN_NVFUSER, 
"requires CUDA") @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING, "Requires fusion optimization pass to be effective") def test_to_boolean(self): x = torch.randn(4, 2, device="cuda") with nvfuser_singleton_fusion(True): def t(x): return x.to(dtype=torch.bool) t_jit = torch.jit.script(t) self._run_helper(t_jit, t, x) @unittest.skipIf(ALIAS_TEST_DISABLED, "skipping this test since reshape is disabled now") @unittest.skipIf(not RUN_NVFUSER, "requires CUDA") @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING, "Requires fusion optimization pass to be effective") def test_view_copy_graph_guard(self): x = torch.randn(4, 2, 3, device="cuda").permute([1, 2, 0]) y = [4, 6] with nvfuser_singleton_fusion(True): def t(x, y : List[int]): t1 = x + 1.0 t2 = t1 * 1.0 out = t2.reshape(y) return out.relu() t_jit = torch.jit.script(t) self._run_helper(t_jit, t, x, y) @unittest.skipIf(ALIAS_TEST_DISABLED, "skipping this test since view is disabled now") @unittest.skipIf(not RUN_NVFUSER, "requires CUDA") @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING, "Requires fusion optimization pass to be effective") def test_view_copy_graph_guard_double_fusion(self): x = torch.randn(2, 2, 5, device="cuda") w = torch.randn(5, 5, device="cuda") with nvfuser_singleton_fusion(True): def t(x, w): o = x.view([4, x.size()[-1]]) o = torch.matmul(o, w) o = o.view([2, 2, o.size()[1]]) return o t_jit = torch.jit.script(t) for i in range(3): jit_o = t_jit(x, w) o = t(x, w) self.assertEqual(jit_o, o) self.assertGraphContainsExactly(t_jit.graph_for(x, w), FUSION_GUARD, 2, consider_subgraphs=True) @unittest.skipIf(not RUN_NVFUSER, "requires CUDA") @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING, "Requires fusion optimization pass to be effective") def test_input_output_passthrough(self): def t(t0, t1, t2): mask = t1.to(dtype=torch.bool) masked_input = torch.where(t0, mask, t2) return masked_input, mask t_jit = torch.jit.script(t) # stick to integers, this avoid the numerical difference due to our # promotion x = torch.randn(4, 4, device='cuda').to(dtype=torch.bool) y = torch.randn(4, 4, device='cuda').to(dtype=torch.bool) z = torch.tensor(1.0, device='cuda').to(dtype=torch.bool) jit_o = t_jit(x, y, z) jit_o = t_jit(x, y, z) o = t(x, y, z) for oo, jit_oo in zip(o, jit_o): self.assertEqual(oo.dtype, jit_oo.dtype) self.assertEqual(oo, jit_oo) self.assertGraphContains(t_jit.graph_for(x, y, z), FUSION_GUARD) @unittest.skipIf(not RUN_NVFUSER, "requires CUDA") @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING, "Requires fusion optimization pass to be effective") def test_pointwise_reference_tensor(self): def t(input1, input2, scalar): _unsafe_view = torch.ops.aten._unsafe_view(input1, [2, 4, 16]) add_ = torch.ops.aten.add_(_unsafe_view, input2) gelu_ = torch.ops.aten.gelu(add_) view_ = torch.ops.aten.view(gelu_, [8, 16]) mul_ = torch.ops.aten.mul(add_, scalar) return [view_, mul_] x = torch.randn(8, 16, device="cuda") bias = torch.randn(16, device="cuda") scalar = torch.ones(torch.Size([]), device="cuda") t_jit = torch.jit.script(t) for i in range(3): jit_o = t_jit(x, bias, scalar) o = t(x, bias, scalar) self.assertEqual(jit_o, o) self.assertGraphContains(t_jit.graph_for(x, bias, scalar), FUSION_GUARD) @unittest.skipIf(not RUN_NVFUSER, "requires CUDA") @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING, "Requires fusion optimization pass to be effective") @unittest.skipIf(is_pre_volta(), "reduction not supported in pre volta device") def test_native_batch_norm_backward(self): grad_output 
= torch.randn(4, 2, 3, device="cuda") input = torch.randn(4, 2, 3, device="cuda") weight = torch.randn(2, device="cuda") r_m = torch.randn(2, device="cuda") r_v = torch.randn(2, device="cuda").abs() save_mean = torch.randn(2, device="cuda") save_invstd = torch.randn(2, device="cuda").abs() with nvfuser_singleton_fusion(True): def t(grad_out, input, weight, r_m, r_v, save_mean, save_invstd, train: bool, eps: float, mask: List[bool]): return torch.ops.aten.native_batch_norm_backward(grad_out, input, weight, r_m, r_v, save_mean, save_invstd, train, eps, mask) t_jit = torch.jit.script(t) for i in range(4): jit_o = t_jit(grad_output, input, weight, r_m.clone(), r_v.clone(), save_mean, save_invstd, True, 1e-5, [True, True, True]) ref_m = r_m.clone() ref_v = r_v.clone() jit_o = t_jit(grad_output, input, weight, r_m, r_v, save_mean, save_invstd, True, 1e-5, [True, True, True]) o = t(grad_output, input, weight, ref_m, ref_v, save_mean, save_invstd, True, 1e-5, [True, True, True]) for oo, jit_oo in zip(o, jit_o): self.assertEqual(oo.dtype, jit_oo.dtype) self.assertEqual(oo, jit_oo) self.assertEqual(ref_m.dtype, r_m.dtype) self.assertEqual(ref_m, r_m) self.assertEqual(ref_v.dtype, r_v.dtype) self.assertEqual(ref_v, r_v) self.assertGraphContains(t_jit.graph_for(grad_output, input, weight, r_m.clone(), r_v.clone(), save_mean, save_invstd, True, 1e-5, [True, True, True]), FUSION_GUARD) @unittest.skipIf(not RUN_NVFUSER, "requires CUDA") @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING, "Requires fusion optimization pass to be effective") def test_contiguous_on_broadcasted(self): x = torch.randn(4, 1, device="cuda") y = torch.randn(4, 128, device="cuda") with nvfuser_singleton_fusion(True): def t(x, y): t1 = x.expand([4, 128]) t2 = t1 * y return t2 t_jit = torch.jit.script(t) self._run_helper(t_jit, t, x, y) @unittest.skipIf(not RUN_NVFUSER, "requires CUDA") @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING, "Requires fusion optimization pass to be effective") def test_skip_parser(self): x = torch.randn(4, 12, device="cuda") with nvfuser_singleton_fusion(True): def fn(x): t1 = x + 1.0 return t1.relu() fn_jit = torch.jit.script(fn) self._run_helper(fn_jit, fn, x) # add node should have been merged into fusion self.assertGraphContains(fn_jit.graph_for(x), FUSION_GUARD) self.assertGraphContainsExactly(fn_jit.graph_for(x), 'aten::add', 0) # flips skip parse for `aten::add`, following fusion should skip the # add node self.assertFalse(torch._C._jit_set_nvfuser_skip_node_kind("aten::add", True)) def fn_1(x): t1 = x + 2.0 # change const value so we'll not reuse plan return t1.relu() fn_1_jit = torch.jit.script(fn_1) self._run_helper(fn_1_jit, fn_1, x) # add node should have been skipped and left out of the fusion self.assertGraphContains(fn_1_jit.graph_for(x), FUSION_GUARD) self.assertGraphContainsExactly(fn_1_jit.graph_for(x), 'aten::add', 1) # flips skip parse for `aten::add`, next fusion should fuse add node self.assertTrue(torch._C._jit_set_nvfuser_skip_node_kind("aten::add", True)) def fn_2(x): t1 = x + 2.0 # change const value so we'll not reuse plan return t1.relu() fn_2_jit = torch.jit.script(fn_2) self._run_helper(fn_2_jit, fn_2, x) # add node should have been merged into fusion self.assertGraphContains(fn_2_jit.graph_for(x), FUSION_GUARD) self.assertGraphContainsExactly(fn_2_jit.graph_for(x), 'aten::add', 0) @unittest.skipIf(not RUN_NVFUSER, "requires CUDA") @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING, "Requires fusion optimization pass to be effective") def
test_cuda_fusion_guard(self): old_guard = torch._C._jit_set_nvfuser_guard_mode(True) class ConvModule(torch.nn.Module): def __init__(self): super().__init__() def forward(self, x): return x.sin().sigmoid() mod = ConvModule().to(device="cuda") inputs = [torch.randn(20, 16, 50, 100, device="cuda", requires_grad=True)] def reduce_scalar(temp): return temp.sum() scripted = torch.jit.script(mod) with torch.no_grad(): scripted(*inputs) res = scripted(*inputs) reduce_scalar(res).backward() torch._C._jit_set_nvfuser_guard_mode(old_guard) @unittest.skipIf(not RUN_NVFUSER, "requires CUDA") @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING, "Requires fusion optimization pass to be effective") def test_nvfuser_comparison_callbacks_with_fallback(self): try: fused_result = None unfused_result = None graph_ir = None def callback(fused_outputs, unfused_outputs, graph_str): nonlocal unfused_result nonlocal fused_result nonlocal graph_ir unfused_result = unfused_outputs[-1] fused_result = fused_outputs[-1] graph_ir = graph_str torch._C._jit_nvfuser_set_comparison_callback(True, callback) def fn(x, y): z = torch.add(x, y) return torch.relu(z) x = torch.rand((4, 4)).cuda() - 0.5 y = torch.rand((4, 4)).cuda() - 0.5 fn_s = torch.jit.script(fn) fn_s(x, y) fn_s(x, y) fn_s(x, y) expected = fn(x, y) self.assertEqual(expected, fused_result) self.assertEqual(expected, unfused_result) FileCheck().check("aten::add").run(graph_ir) finally: torch._C._jit_nvfuser_clear_comparison_callback() @unittest.skipIf(not RUN_NVFUSER, "requires CUDA") @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING, "Requires fusion optimization pass to be effective") def test_nvfuser_comparison_callbacks_without_fallback(self): try: fused_result = None unfused_result = None graph_ir = None def callback(fused_outputs, unfused_outputs, graph_str): nonlocal unfused_result nonlocal fused_result nonlocal graph_ir if len(unfused_outputs) > 0: unfused_result = unfused_outputs[-1] fused_result = fused_outputs[-1] graph_ir = graph_str torch._C._jit_nvfuser_set_comparison_callback(False, callback) def fn(x, y): z = torch.add(x, y) return torch.relu(z) x = torch.rand((4, 4)).cuda() - 0.5 y = torch.rand((4, 4)).cuda() - 0.5 fn_s = torch.jit.script(fn) fn_s(x, y) fn_s(x, y) fn_s(x, y) expected = fn(x, y) self.assertEqual(expected, fused_result) self.assertEqual(None, unfused_result) FileCheck().check("aten::add").run(graph_ir) finally: torch._C._jit_nvfuser_clear_comparison_callback() @unittest.skipIf(not RUN_NVFUSER, "requires NVFuser") @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING, "Requires fusion optimization pass to be effective") def test_cuda_fusion_guard_backward(self): old_guard = torch._C._jit_set_nvfuser_guard_mode(True) inp = torch.randn(10, device="cuda", requires_grad=True) grad = torch.randn(10, device="cuda") def f(x): a = x.cos().cos() return a scripted = torch.jit.script(f) with profile(activities=[ProfilerActivity.CPU]) as prof: for _ in range(5): inp.grad = None out = scripted(inp) out.backward(grad) # check that we do not have fallback triggered self.assertEqual(prof.events().table().find("fallback"), -1) torch._C._jit_set_nvfuser_guard_mode(old_guard) # TODO: generalize this @unittest.skipIf(not RUN_NVFUSER, "requires CUDA") @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING, "Requires fusion optimization pass to be effective") @unittest.skipIf(is_pre_volta(), "reduction not supported in pre volta device") def test_inf_quick_patch(self): inputs = [torch.tensor([-float('inf'), float('inf'), 4.0], 
device="cuda"), torch.tensor([1.0, float('inf'), 4.0], device="cuda"), torch.tensor([-float('inf'), -1.5, 4.0], device="cuda"), torch.tensor([1.0, -3.0, float('nan')], device="cuda"), torch.tensor([-float('inf'), -float('inf'), -float('inf')], device="cuda"), torch.tensor([float('inf'), float('inf'), float('inf')], device="cuda"), torch.tensor([float('nan'), float('nan'), float('nan')], device="cuda")] def fn_amax(x): return x.amax(dim=0) def fn_amin(x): return x.amin(dim=0) def fn_add_nan(x): return x.relu() + float('nan') def fn_add(x): return x + 1.0 with nvfuser_singleton_fusion(True): for t in [fn_amax, fn_amin, fn_add, fn_add_nan]: for x in inputs: t_jit = torch.jit.script(t) self._run_helper(t_jit, t, x) @unittest.skipIf(not RUN_NVFUSER, "requires CUDA") @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING, "Requires fusion optimization pass to be effective") def test_clamp_reversed_bound(self): x = torch.tensor([1., -float('inf'), 2., float('inf'), float('nan')], device="cuda") def t(x): return x.clamp(min=1., max=0.5) with nvfuser_singleton_fusion(True): jit_t = torch.jit.script(t) self._run_helper(jit_t, t, x) @unittest.skipIf(not RUN_NVFUSER, "requires CUDA") @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING, "Requires fusion optimization pass to be effective") def test_high_rank_fusion(self): # currently we want to limit fusion to node with input where rank <= 8 rank_limit = 8 shapes = [4 for i in range(rank_limit + 1)] x = torch.randn(shapes, device="cuda") with nvfuser_singleton_fusion(True): def t(x): return x.relu() jit_t = torch.jit.script(t) for i in range(5): jit_t(x) self.assertGraphContainsExactly(jit_t.graph_for(x), FUSION_GUARD, 0) @unittest.skipIf(not RUN_NVFUSER, "requires CUDA") @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING, "Requires fusion optimization pass to be effective") def test_clamp(self): x = torch.tensor([1., float('inf'), 2., float('nan'), float('-inf')], device="cuda") def clamp_max(x): return x.clamp(max=1.5) def clamp_min_max(x): return x.clamp(min=1.5) def clamp_min(x): return x.clamp(min=1., max=3.) 
with nvfuser_singleton_fusion(True): for t in [clamp_max, clamp_min, clamp_min_max]: t_jit = torch.jit.script(t) self._run_helper(t_jit, t, x) @unittest.skipIf(not RUN_NVFUSER, "requires CUDA") @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING, "Requires fusion optimization pass to be effective") def test_device_constant(self): x = torch.randn(4, 2, device="cuda") def t(x): return torch.rand_like(x, device=torch.device(type='cuda')) # cpu tensor shouldn't be fused def t_cpu(x): return torch.rand_like(x, device=torch.device(type='cpu')) with nvfuser_singleton_fusion(True): t_jit = torch.jit.script(t) self._run_helper(t_jit, t, x) t_cpu_jit = torch.jit.script(t_cpu) for i in range(5): t_cpu_jit(x) self.assertGraphContainsExactly(t_cpu_jit.graph_for(x), FUSION_GUARD, 0) @unittest.skipIf(not RUN_NVFUSER, "requires CUDA") @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING, "Requires fusion optimization pass to be effective") def test_scheduler_with_polymorphic_broadcast(self): device = "cuda" x0 = torch.randn(10, 128, device=device) x1 = torch.rand_like(x0) x2 = torch.randn(10, device=device) def t(x0, x1, x2): x3 = x2.unsqueeze(-1) x4 = x3 + x0 x5 = x3 + x1 x6 = x5.sum(0) return x4, x6 t_jit = torch.jit.script(t) self._run_helper(t_jit, t, x0, x1, x2, check_stride=True) x2 = torch.randn(128, device=device) def t2(x0, x1, x2): x3 = x2.unsqueeze(0) x4 = x3 + x0 x5 = x3 + x1 x6 = x5.sum(1) return x4, x6 t2_jit = torch.jit.script(t2) self._run_helper(t2_jit, t2, x0, x1, x2, check_stride=True) class TestPassManagerCudaFuser(JitTestCase): def setUp(self): super().setUp() if RUN_NVFUSER: self.is_enabled = torch._C._jit_set_nvfuser_enabled(False) def tearDown(self): if RUN_NVFUSER: torch._C._jit_set_nvfuser_enabled(self.is_enabled) super().tearDown() @unittest.skipIf(not RUN_NVFUSER, "requires CUDA") @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING, "Requires fusion optimization pass to be effective") def test_context_manager_test(self): x = torch.randn(4, 8, dtype=torch.float, device="cuda") y = torch.randn(4, 8, dtype=torch.float, device="cuda") with torch.jit.fuser('fuser2'): with torch.jit.fuser('fuser2'): def t1(x, y): o = x + y o = o + 2.0 return o t_jit = torch.jit.script(t1) t_jit(x, y) t_jit(x, y) self.assertGraphContains(t_jit.graph_for(x, y), FUSION_GUARD) def t2(x, y): o = x + y o = o + 3.0 return o t_jit_2 = torch.jit.script(t2) t_jit_2(x, y) t_jit_2(x, y) self.assertGraphContains(t_jit_2.graph_for(x, y), FUSION_GUARD) def t3(x, y): o = x + y o = o + 4.0 return o t_jit_3 = torch.jit.script(t3) t_jit_3(x, y) t_jit_3(x, y) self.assertGraphContainsExactly(t_jit_3.graph_for(x, y), FUSION_GUARD, 0) @unittest.skipIf(not RUN_NVFUSER, "requires CUDA") def test_register_fuser(self): self.assertFalse(torch._C._jit_set_nvfuser_enabled(True)) self.assertTrue(torch._C._jit_nvfuser_enabled()) self.assertTrue(torch._C._jit_set_nvfuser_enabled(True)) self.assertTrue(torch._C._jit_nvfuser_enabled()) self.assertTrue(torch._C._jit_set_nvfuser_enabled(False)) self.assertFalse(torch._C._jit_nvfuser_enabled()) @unittest.skipIf(RUN_CUDA, "Testing on CPU only") def test_register_fuser_cpu(self): with self.assertRaises(RuntimeError): torch._C._jit_set_nvfuser_enabled(True) torch._C._jit_set_nvfuser_enabled(False) @unittest.skipIf(not RUN_CUDA, "requires CUDA") @unittest.skipIf(not TEST_WITH_ROCM, "ROCM test only") def test_register_fuser_rocm(self): with self.assertRaises(RuntimeError): torch._C._jit_set_nvfuser_enabled(True) torch._C._jit_set_nvfuser_enabled(False) # See 
TestNNCOpInfoParent class TestCudaFuserOpInfoParent(JitCommonTestCase): pass class TestCudaFuserOpInfo(TestCudaFuserOpInfoParent): def setUp(self): super(TestCudaFuserOpInfoParent, self).setUp() if RUN_NVFUSER: self.cuda_fuser_options = CudaFuserTestOptions() # enables guard mode since tracing could change graph to violate guard. torch._C._jit_set_nvfuser_guard_mode(True) self.nvfuser_single_node_mode = torch._C._jit_set_nvfuser_single_node_mode(True) def tearDown(self): if RUN_NVFUSER: self.cuda_fuser_options.restore() torch._C._jit_set_nvfuser_single_node_mode(self.nvfuser_single_node_mode) super(TestCudaFuserOpInfoParent, self).tearDown() @slowTest @unittest.skipIf(not RUN_NVFUSER, "requires CUDA") @ops(op_db, dtypes=OpDTypes.supported) def test_nvfuser_correctness(self, device, dtype, op): variant_sample_pairs = get_traced_sample_variant_pairs(device, dtype, op) for variant, sample in variant_sample_pairs: trace = create_traced_fn(self, variant, cache_traced_fn=True) ref = variant(*clone_inputs((sample.input, *sample.args)), **sample.kwargs) trace(*clone_inputs((sample.input, *sample.args)), **sample.kwargs) val = trace(*clone_inputs((sample.input, *sample.args)), **sample.kwargs) self.assertEqual(ref, val, exact_layout=True) # Note: Clearing CU after NVFuser tests # https://github.com/pytorch/pytorch/issues/35600 # each torch.jit.trace adds state to the _python_cu compilation unit # since this test traces a lot of functions, out-of-memory can occur # if the CU is not cleared. torch.jit._state._python_cu.drop_all_functions() @slowTest @unittest.skipIf(not RUN_NVFUSER, "requires CUDA") @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING, "Requires fusion optimization pass to be effective") @ops(op_db, allowed_dtypes=(torch.float16, torch.bfloat16, torch.float32, torch.float64, torch.complex64, torch.complex128)) def test_nvfuser_extremal_values(self, device, dtype, op): variant_sample_pairs = get_traced_sample_variant_pairs(device, dtype, op) def _get_extremal_tensor(x, val, dtype): if x.dtype != dtype: return x return torch.full_like(x, val) def _get_extremal_input(x, val, dtype): if isinstance(x, torch.Tensor): return _get_extremal_tensor(x, val, dtype) elif is_iterable_of_tensors(x): return [_get_extremal_tensor(y, val, dtype) for y in x] return x def _get_extremal_sample(sample: SampleInput, val, dtype): extremal_sample = SampleInput( input=_get_extremal_input(sample.input, val, dtype), args=[_get_extremal_input(x, val, dtype) for x in sample.args], kwargs={k: _get_extremal_input(v, val, dtype) for k, v in sample.kwargs.items()}, ) return extremal_sample def _get_extremal_samples(sample: SampleInput, dtype): vals = [float('inf'), float('-inf'), float('nan')] if dtype.is_complex: complex_vals = itertools.product(vals, vals) vals = list(map(lambda x: complex(*x), complex_vals)) for val in vals: yield _get_extremal_sample(sample, val, dtype) variant_sample_pairs = get_traced_sample_variant_pairs(device, dtype, op) for variant, sample in variant_sample_pairs: trace = create_traced_fn(self, variant, cache_traced_fn=True) trace(*clone_inputs((sample.input, *sample.args)), **sample.kwargs) trace(*clone_inputs((sample.input, *sample.args)), **sample.kwargs) for extremal_sample in _get_extremal_samples(sample, dtype): try: with freeze_rng_state(): ref = variant(*clone_inputs((extremal_sample.input, *extremal_sample.args)), **extremal_sample.kwargs) except (torch._C._LinAlgError, RuntimeError, ValueError): # if eager errors out, then don't expect NVFuser to pass continue with 
freeze_rng_state(): val = trace(*clone_inputs((extremal_sample.input, *extremal_sample.args)), **extremal_sample.kwargs) self.assertEqual(val, ref, equal_nan=True, exact_device=True) # See [Note: Clearing CU after NVFuser tests] torch.jit._state._python_cu.drop_all_functions() instantiate_device_type_tests(TestCudaFuserOpInfo, globals(), only_for=("cuda")) if __name__ == '__main__': run_tests()
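# A minimal, self-contained sketch of the warm-up-then-inspect pattern the tests
# above rely on, assuming the RUN_NVFUSER / FUSION_GUARD globals defined at the
# top of this file. The helper name `_example_check_relu_add_fusion` is
# illustrative only and is not invoked anywhere in the suite.
def _example_check_relu_add_fusion():
    if not RUN_NVFUSER:
        return
    # enable nvfuser for the duration of the check, mirroring the test fixtures
    old_enabled = torch._C._jit_set_nvfuser_enabled(True)
    try:
        def f(a, b):
            # a single pointwise chain, which nvfuser is expected to fuse
            return (a + b).relu()

        scripted = torch.jit.script(f)
        a = torch.randn(4, 4, device="cuda")
        b = torch.randn(4, 4, device="cuda")
        for _ in range(3):
            # warm-up runs: profiling, optimization, then the fused plan
            scripted(a, b)
        # the optimized graph should now contain a prim::CudaFusionGuard node
        FileCheck().check(FUSION_GUARD).run(str(scripted.graph_for(a, b)))
    finally:
        torch._C._jit_set_nvfuser_enabled(old_enabled)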