diff --git a/CMakeLists.txt b/CMakeLists.txt index 94996175f64..bd694c9d3e6 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -687,6 +687,8 @@ if(USE_FBGEMM AND ((CMAKE_SYSTEM_PROCESSOR STREQUAL "x86_64" AND CMAKE_SIZEOF_VO set(USE_FBGEMM OFF) endif() +set(BUILD_ONEDNN_GRAPH OFF) + include(cmake/Dependencies.cmake) if(USE_CUDA AND (CMAKE_CUDA_COMPILER_VERSION VERSION_LESS 10.2) AND (CMAKE_HOST_SYSTEM_NAME MATCHES "Windows")) diff --git a/aten/src/ATen/core/interned_strings.h b/aten/src/ATen/core/interned_strings.h index ec61336f92c..0b307552d16 100644 --- a/aten/src/ATen/core/interned_strings.h +++ b/aten/src/ATen/core/interned_strings.h @@ -43,6 +43,8 @@ namespace c10 { _(prim, FusionGroup) \ _(prim, CudaFusionGroup) \ _(prim, CudaFusionGuard) \ + _(prim, oneDNNFusionGroup) \ + _(prim, oneDNNFusionGuard) \ _(prim, FunctionalGraph) \ _(prim, add_optional) \ _(prim, view_copy) \ @@ -316,6 +318,7 @@ namespace c10 { _(attr, cache_id) \ _(attr, new_axis) \ _(attr, warn_id) \ + _(attr, output_layouts) \ _(attr, allowzero) \ _(attr, seen_none) \ _(attr, overload_name) diff --git a/caffe2/CMakeLists.txt b/caffe2/CMakeLists.txt index 32acdc6b762..788db5d4999 100644 --- a/caffe2/CMakeLists.txt +++ b/caffe2/CMakeLists.txt @@ -656,6 +656,26 @@ if(NOT INTERN_BUILD_MOBILE OR NOT BUILD_CAFFE2_MOBILE) set_source_files_properties(${TORCH_SRC_DIR}/csrc/jit/passes/frozen_conv_add_relu_fusion.cpp PROPERTIES COMPILE_FLAGS "-DUSE_CUDA=1") endif() + if(USE_MLCOMPUTE) + include(../mlc/mlc_build.cmake) + endif() + + if(BUILD_ONEDNN_GRAPH) + list(APPEND Caffe2_CPU_SRCS + ${TORCH_SRC_DIR}/csrc/jit/codegen/onednn/LlgaTensorImpl.cpp + ${TORCH_SRC_DIR}/csrc/jit/codegen/onednn/graph_fuser.cpp + ${TORCH_SRC_DIR}/csrc/jit/codegen/onednn/graph_rewriter.cpp + ${TORCH_SRC_DIR}/csrc/jit/codegen/onednn/graph_helper.cpp + ${TORCH_SRC_DIR}/csrc/jit/codegen/onednn/register_interface.cpp + ${TORCH_SRC_DIR}/csrc/jit/codegen/onednn/interface.cpp + ${TORCH_SRC_DIR}/csrc/jit/codegen/onednn/kernel.cpp + ${TORCH_SRC_DIR}/csrc/jit/codegen/onednn/defer_size_check.cpp + ${TORCH_SRC_DIR}/csrc/jit/codegen/onednn/layout_propagation.cpp + ${TORCH_SRC_DIR}/csrc/jit/codegen/onednn/prepare_binary.cpp + ${TORCH_SRC_DIR}/csrc/jit/codegen/onednn/guard_shape.cpp + ) + endif() + if(USE_ROCM) list(APPEND Caffe2_HIP_SRCS ${Caffe2_GPU_HIP_JIT_FUSERS_SRCS}) if(USE_NCCL) diff --git a/cmake/Modules/FindMKLDNN.cmake b/cmake/Modules/FindMKLDNN.cmake index 4d3febbdfc4..b17bb29a140 100644 --- a/cmake/Modules/FindMKLDNN.cmake +++ b/cmake/Modules/FindMKLDNN.cmake @@ -12,86 +12,116 @@ # MKLDNN_USE_NATIVE_ARCH : Whether native CPU instructions should be used in MKLDNN. This should be turned off for # general packaging to avoid incompatible CPU instructions. Default: OFF. 
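The `prim::oneDNNFusionGroup` and `prim::oneDNNFusionGuard` symbols registered in `interned_strings.h` above are the node kinds the new fuser writes into the TorchScript IR. A minimal sketch of how they surface, assuming a build in which `BUILD_ONEDNN_GRAPH` ends up ON (Linux CPU, non-lite interpreter, as the FindMKLDNN.cmake change below arranges); the toy module here is our own:

```python
import torch
import torch.nn as nn

torch.jit.enable_onednn_fusion(True)

model = nn.Sequential(nn.Conv2d(3, 8, kernel_size=3), nn.ReLU()).eval()
x = torch.rand(1, 3, 32, 32)
with torch.no_grad():
    traced = torch.jit.freeze(torch.jit.trace(model, x))
    traced(x)  # a couple of warm-up runs let profiling and fusion kick in
    traced(x)
    # Fused regions appear as prim::oneDNNFusionGroup nodes, guarded on the
    # fallback path by prim::oneDNNFusionGuard checks (the "Type Guard" step).
    print(traced.graph_for(x))
```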
-IF (NOT MKLDNN_FOUND) +IF(NOT MKLDNN_FOUND) + SET(MKLDNN_LIBRARIES) + SET(MKLDNN_INCLUDE_DIR) -SET(MKLDNN_LIBRARIES) -SET(MKLDNN_INCLUDE_DIR) + SET(IDEEP_ROOT "${PROJECT_SOURCE_DIR}/third_party/ideep") + SET(MKLDNN_ROOT "${PROJECT_SOURCE_DIR}/third_party/ideep/mkl-dnn/third_party/oneDNN") + IF(NOT APPLE AND NOT WIN32 AND NOT BUILD_LITE_INTERPRETER) + MESSAGE("-- Will build oneDNN Graph") + SET(LLGA_ROOT "${PROJECT_SOURCE_DIR}/third_party/ideep/mkl-dnn") + SET(BUILD_ONEDNN_GRAPH ON) + ENDIF(NOT APPLE AND NOT WIN32 AND NOT BUILD_LITE_INTERPRETER) -SET(IDEEP_ROOT "${PROJECT_SOURCE_DIR}/third_party/ideep") -SET(MKLDNN_ROOT "${IDEEP_ROOT}/mkl-dnn/third_party/oneDNN") + FIND_PACKAGE(BLAS) + FIND_PATH(IDEEP_INCLUDE_DIR ideep.hpp PATHS ${IDEEP_ROOT} PATH_SUFFIXES include) + FIND_PATH(MKLDNN_INCLUDE_DIR dnnl.hpp dnnl.h PATHS ${MKLDNN_ROOT} PATH_SUFFIXES include) + IF(NOT MKLDNN_INCLUDE_DIR) + EXECUTE_PROCESS(COMMAND git${CMAKE_EXECUTABLE_SUFFIX} submodule update --init --jobs 0 mkl-dnn WORKING_DIRECTORY ${IDEEP_ROOT}) + FIND_PATH(MKLDNN_INCLUDE_DIR dnnl.hpp dnnl.h PATHS ${MKLDNN_ROOT} PATH_SUFFIXES include) + ENDIF(NOT MKLDNN_INCLUDE_DIR) + IF(BUILD_ONEDNN_GRAPH) + FIND_PATH(LLGA_INCLUDE_DIR oneapi/dnnl/dnnl_graph.hpp PATHS ${LLGA_ROOT} PATH_SUFFIXES include) + ENDIF(BUILD_ONEDNN_GRAPH) -FIND_PACKAGE(BLAS) -FIND_PATH(IDEEP_INCLUDE_DIR ideep.hpp PATHS ${IDEEP_ROOT} PATH_SUFFIXES include) -FIND_PATH(MKLDNN_INCLUDE_DIR dnnl.hpp dnnl.h PATHS ${MKLDNN_ROOT} PATH_SUFFIXES include) -IF (NOT MKLDNN_INCLUDE_DIR) - EXECUTE_PROCESS(COMMAND git${CMAKE_EXECUTABLE_SUFFIX} submodule update --init --jobs 0 mkl-dnn WORKING_DIRECTORY ${IDEEP_ROOT}) - FIND_PATH(MKLDNN_INCLUDE_DIR mkldnn.hpp mkldnn.h PATHS ${MKLDNN_ROOT} PATH_SUFFIXES include) -ENDIF(NOT MKLDNN_INCLUDE_DIR) + IF(NOT IDEEP_INCLUDE_DIR OR NOT MKLDNN_INCLUDE_DIR) + MESSAGE(STATUS "MKLDNN source files not found!") + RETURN() + ENDIF(NOT IDEEP_INCLUDE_DIR OR NOT MKLDNN_INCLUDE_DIR) + LIST(APPEND MKLDNN_INCLUDE_DIR ${IDEEP_INCLUDE_DIR}) + IF(BUILD_ONEDNN_GRAPH) + LIST(APPEND MKLDNN_INCLUDE_DIR ${LLGA_INCLUDE_DIR}) + ENDIF(BUILD_ONEDNN_GRAPH) + IF(MKL_FOUND) + ADD_DEFINITIONS(-DIDEEP_USE_MKL) + # Append to mkldnn dependencies + LIST(APPEND MKLDNN_LIBRARIES ${MKL_LIBRARIES}) + LIST(APPEND MKLDNN_INCLUDE_DIR ${MKL_INCLUDE_DIR}) + ELSE(MKL_FOUND) + SET(MKLDNN_USE_MKL "NONE" CACHE STRING "" FORCE) + ENDIF(MKL_FOUND) -IF (NOT IDEEP_INCLUDE_DIR OR NOT MKLDNN_INCLUDE_DIR) - MESSAGE(STATUS "MKLDNN source files not found!") - RETURN() -ENDIF(NOT IDEEP_INCLUDE_DIR OR NOT MKLDNN_INCLUDE_DIR) -LIST(APPEND MKLDNN_INCLUDE_DIR ${IDEEP_INCLUDE_DIR}) -IF(MKL_FOUND) - ADD_DEFINITIONS(-DIDEEP_USE_MKL) - # Append to mkldnn dependencies - LIST(APPEND MKLDNN_LIBRARIES ${MKL_LIBRARIES}) - LIST(APPEND MKLDNN_INCLUDE_DIR ${MKL_INCLUDE_DIR}) -ELSE(MKL_FOUND) - SET(MKLDNN_USE_MKL "NONE" CACHE STRING "" FORCE) -ENDIF(MKL_FOUND) + SET(MKL_cmake_included TRUE) + IF(NOT MKLDNN_CPU_RUNTIME) + SET(MKLDNN_CPU_RUNTIME "OMP" CACHE STRING "") + ELSEIF(MKLDNN_CPU_RUNTIME STREQUAL "TBB") + IF(USE_TBB) + MESSAGE(STATUS "MKL-DNN is using TBB") -SET(MKL_cmake_included TRUE) -IF (NOT MKLDNN_CPU_RUNTIME) - SET(MKLDNN_CPU_RUNTIME "OMP" CACHE STRING "") -ELSEIF (MKLDNN_CPU_RUNTIME STREQUAL "TBB") - IF (USE_TBB) - MESSAGE(STATUS "MKL-DNN is using TBB") + SET(TBB_cmake_included TRUE) + SET(Threading_cmake_included TRUE) - SET(TBB_cmake_included TRUE) - SET(Threading_cmake_included TRUE) - - SET(DNNL_CPU_THREADING_RUNTIME ${MKLDNN_CPU_RUNTIME}) - INCLUDE_DIRECTORIES(${TBB_INCLUDE_DIR}) - LIST(APPEND 
EXTRA_SHARED_LIBS TBB::tbb) - ELSE() - MESSAGE(FATAL_ERROR "MKLDNN_CPU_RUNTIME is set to TBB but TBB is not used") - ENDIF() -ENDIF() -MESSAGE(STATUS "MKLDNN_CPU_RUNTIME = ${MKLDNN_CPU_RUNTIME}") - -SET(MKLDNN_CPU_RUNTIME ${MKLDNN_CPU_RUNTIME} CACHE STRING "" FORCE) -SET(DNNL_BUILD_TESTS FALSE CACHE BOOL "" FORCE) -SET(DNNL_BUILD_EXAMPLES FALSE CACHE BOOL "" FORCE) -SET(DNNL_LIBRARY_TYPE STATIC CACHE STRING "" FORCE) -SET(DNNL_ENABLE_PRIMITIVE_CACHE TRUE CACHE BOOL "" FORCE) -IF(MKLDNN_USE_NATIVE_ARCH) # Disable HostOpts in MKLDNN unless MKLDNN_USE_NATIVE_ARCH is set. - SET(DNNL_ARCH_OPT_FLAGS "HostOpts" CACHE STRING "" FORCE) -ELSE() - IF(CMAKE_CXX_COMPILER_ID STREQUAL "GNU" OR CMAKE_CXX_COMPILER_ID STREQUAL "Clang") - IF(CPU_INTEL) - SET(DNNL_ARCH_OPT_FLAGS "-msse4" CACHE STRING "" FORCE) + SET(DNNL_CPU_THREADING_RUNTIME ${MKLDNN_CPU_RUNTIME}) + INCLUDE_DIRECTORIES(${TBB_INCLUDE_DIR}) + LIST(APPEND EXTRA_SHARED_LIBS TBB::tbb) + ELSE() + MESSAGE(FATAL_ERROR "MKLDNN_CPU_RUNTIME is set to TBB but TBB is not used") ENDIF() - ELSE() - SET(DNNL_ARCH_OPT_FLAGS "" CACHE STRING "" FORCE) ENDIF() -ENDIF() + MESSAGE(STATUS "MKLDNN_CPU_RUNTIME = ${MKLDNN_CPU_RUNTIME}") -ADD_SUBDIRECTORY(${MKLDNN_ROOT}) -IF(NOT TARGET dnnl) - MESSAGE("Failed to include MKL-DNN target") - RETURN() -ENDIF(NOT TARGET dnnl) + SET(MKLDNN_CPU_RUNTIME ${MKLDNN_CPU_RUNTIME} CACHE STRING "" FORCE) + SET(DNNL_BUILD_TESTS FALSE CACHE BOOL "" FORCE) + SET(DNNL_BUILD_EXAMPLES FALSE CACHE BOOL "" FORCE) + SET(DNNL_LIBRARY_TYPE STATIC CACHE STRING "" FORCE) + SET(DNNL_ENABLE_PRIMITIVE_CACHE TRUE CACHE BOOL "" FORCE) + IF(BUILD_ONEDNN_GRAPH) + SET(DNNL_GRAPH_LIBRARY_TYPE STATIC CACHE STRING "" FORCE) + ENDIF(BUILD_ONEDNN_GRAPH) + IF(MKLDNN_USE_NATIVE_ARCH) # Disable HostOpts in MKLDNN unless MKLDNN_USE_NATIVE_ARCH is set. 
+ SET(DNNL_ARCH_OPT_FLAGS "HostOpts" CACHE STRING "" FORCE) + ELSE() + IF(CMAKE_CXX_COMPILER_ID STREQUAL "GNU" OR CMAKE_CXX_COMPILER_ID STREQUAL "Clang") + IF(CPU_INTEL) + SET(DNNL_ARCH_OPT_FLAGS "-msse4" CACHE STRING "" FORCE) + ENDIF() + ELSE() + SET(DNNL_ARCH_OPT_FLAGS "" CACHE STRING "" FORCE) + ENDIF() + ENDIF() -IF(NOT APPLE AND CMAKE_COMPILER_IS_GNUCC) - TARGET_COMPILE_OPTIONS(dnnl PRIVATE -Wno-maybe-uninitialized) - TARGET_COMPILE_OPTIONS(dnnl PRIVATE -Wno-strict-overflow) - TARGET_COMPILE_OPTIONS(dnnl PRIVATE -Wno-error=strict-overflow) -ENDIF(NOT APPLE AND CMAKE_COMPILER_IS_GNUCC) -LIST(APPEND MKLDNN_LIBRARIES dnnl) + IF(BUILD_ONEDNN_GRAPH) + ADD_SUBDIRECTORY(${LLGA_ROOT}) + IF(NOT TARGET dnnl_graph) + MESSAGE("Failed to include LLGA target") + RETURN() + ENDIF(NOT TARGET dnnl_graph) -SET(MKLDNN_FOUND TRUE) -MESSAGE(STATUS "Found MKL-DNN: TRUE") + IF(CMAKE_COMPILER_IS_GNUCC) + TARGET_COMPILE_OPTIONS(dnnl_graph PRIVATE -Wno-maybe-uninitialized) + TARGET_COMPILE_OPTIONS(dnnl_graph PRIVATE -Wno-strict-overflow) + TARGET_COMPILE_OPTIONS(dnnl_graph PRIVATE -Wno-error=strict-overflow) + ENDIF(CMAKE_COMPILER_IS_GNUCC) + ENDIF(BUILD_ONEDNN_GRAPH) + + IF(NOT TARGET dnnl) + MESSAGE("Failed to include MKL-DNN target") + RETURN() + ENDIF(NOT TARGET dnnl) + + IF(NOT APPLE AND CMAKE_COMPILER_IS_GNUCC) + TARGET_COMPILE_OPTIONS(dnnl PRIVATE -Wno-maybe-uninitialized) + TARGET_COMPILE_OPTIONS(dnnl PRIVATE -Wno-strict-overflow) + TARGET_COMPILE_OPTIONS(dnnl PRIVATE -Wno-error=strict-overflow) + ENDIF(NOT APPLE AND CMAKE_COMPILER_IS_GNUCC) + LIST(APPEND MKLDNN_LIBRARIES ${MKL_OPENMP_LIBRARY}) + IF(BUILD_ONEDNN_GRAPH) + LIST(APPEND MKLDNN_LIBRARIES "$") + ENDIF(BUILD_ONEDNN_GRAPH) + LIST(APPEND MKLDNN_LIBRARIES dnnl) + + SET(MKLDNN_FOUND TRUE) + MESSAGE(STATUS "Found MKL-DNN: TRUE") ENDIF(NOT MKLDNN_FOUND) diff --git a/cmake/public/mkldnn.cmake b/cmake/public/mkldnn.cmake index 87935625f9b..50404d3b30d 100644 --- a/cmake/public/mkldnn.cmake +++ b/cmake/public/mkldnn.cmake @@ -16,3 +16,15 @@ set_property( set_property( TARGET caffe2::mkldnn PROPERTY INTERFACE_LINK_LIBRARIES ${MKLDNN_LIBRARIES}) +if(BUILD_ONEDNN_GRAPH) + if(NOT TARGET caffe2::dnnl_graph) + add_library(caffe2::dnnl_graph INTERFACE IMPORTED) + endif() + + set_property( + TARGET caffe2::dnnl_graph PROPERTY INTERFACE_INCLUDE_DIRECTORIES + ${MKLDNN_INCLUDE_DIR}) + set_property( + TARGET caffe2::dnnl_graph PROPERTY INTERFACE_LINK_LIBRARIES + ${MKLDNN_LIBRARIES}) +endif() diff --git a/docs/source/jit.rst b/docs/source/jit.rst index dec190fac18..70c5f26c284 100644 --- a/docs/source/jit.rst +++ b/docs/source/jit.rst @@ -61,6 +61,8 @@ Creating TorchScript Code ScriptFunction freeze optimize_for_inference + enable_onednn_fusion + onednn_fusion_enabled set_fusion_strategy strict_fusion save diff --git a/test/test_jit_llga_fuser.py b/test/test_jit_llga_fuser.py new file mode 100644 index 00000000000..7c60cee984f --- /dev/null +++ b/test/test_jit_llga_fuser.py @@ -0,0 +1,489 @@ +# Owner(s): ["module: mkldnn"] +import torch +import unittest +import itertools + +import torch.nn as nn +import torch.nn.functional as F +from torch.testing._internal.jit_utils import JitTestCase +from torch.testing._internal.common_utils import run_tests, TEST_SCIPY, IS_WINDOWS, IS_MACOS + +LLGA_FUSION_GROUP = 'prim::oneDNNFusionGroup' +LLGA_NOT_ENABLED = not torch._C.has_mkldnn or IS_WINDOWS or IS_MACOS + + +def warmup_forward(f, *args, profiling_count=2): + for i in range(profiling_count): + results = f(*args) + + return results + + +class JitLlgaTestCase(JitTestCase): + 
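+    # Shared helpers: trace the model (and freeze it when it is an nn.Module),
+    # run two warm-up forward passes so the profiling executor can specialize,
+    # then compare against eager mode and return the optimized graph_for() graph.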
def setUp(self): + torch.jit.enable_onednn_fusion(True) + + def tearDown(self): + torch.jit.enable_onednn_fusion(False) + + def checkTrace(self, m, x, *args, **kwargs): + if isinstance(m, torch.nn.Module): + m.eval() + with torch.no_grad(), \ + torch._jit_internal._disable_emit_hooks(): + traced = torch.jit.trace(m, x) + if isinstance(m, torch.nn.Module): + traced = torch.jit.freeze(traced) + warmup_forward(traced, *x) + fwd_graph = traced.graph_for(*x) + + ref_o = m(*x) + jit_o = traced(*x) + self.assertEqual(jit_o, ref_o) + return traced, fwd_graph + + def assertFused(self, graph, fused_patterns): + for pat in fused_patterns: + self.assertGraphContainsExactly(graph, pat, 0) + + +try: + import torchvision + HAS_TORCHVISION = True +except ImportError: + HAS_TORCHVISION = False +except RuntimeError: + HAS_TORCHVISION = False +skipIfNoTorchVision = unittest.skipIf(not HAS_TORCHVISION, 'no torchvision') + +def get_eltwise_fn(name): + if hasattr(torch, name): + return getattr(torch, name) + elif hasattr(F, name): + return getattr(F, name) + else: + raise NameError('Eltwise function %s not found' % name) + + +@unittest.skipIf(LLGA_NOT_ENABLED, "MKL-DNN build is disabled") +class TestOp(JitLlgaTestCase): + def test_conv2d(self): + for [spatial, in_channels, out_channels, kernel, padding, stride, dilation, g, bias] in itertools.product( + [7, 8], + [8, 15], + [7, 16], + [3, 4], + [0, 2], + [1, 2], + [1, 2], + [1, 2], + [True, False]): + + m = nn.Conv2d(in_channels=in_channels * g, + out_channels=out_channels * g, + kernel_size=kernel, + padding=padding, + stride=stride, + dilation=dilation, + groups=g, + bias=bias) + + x = torch.rand(1, in_channels * g, spatial, spatial) + _, graph = self.checkTrace(m, [x]) + self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 1) + + def test_bn2d(self): + m = nn.BatchNorm2d(32).eval() + x = torch.rand(1, 32, 28, 28) + _, graph = self.checkTrace(m, [x]) + # single-op partition shouldn't be created for softmax + self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 0) + + def test_eltwise(self): + class M(nn.Module): + def __init__(self, eltwise_fn): + super(M, self).__init__() + self.eltwise = eltwise_fn + + def forward(self, x): + return self.eltwise(x) + + for eltwise in ['relu', 'gelu']: + eltwise_fn = get_eltwise_fn(eltwise) + m = M(eltwise_fn) + x = torch.rand(1, 32, 28, 28) + _, graph = self.checkTrace(m, [x]) + # single-op partition shouldn't be created. 
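+            # (a standalone eltwise op is left to the default fusers: the bridge
+            # only keeps single-op oneDNN Graph partitions for convs, GEMMs and a
+            # few ops like pooling -- see isBetterSuitedForLLGA in graph_helper.cpp)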
+ self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 0) + + def test_max_pool2d(self): + for [spatial, kernel, padding, stride, dilation, ceil_mode] in itertools.product( + [15, 16, 17, 18, 19], + [4, 5], + [0, 1, 2], + [1, 2], # [1, 2, 4], TODO: fix issue in pad calculation + [1], # [1, 2], TODO: backend support for dilation + [True, False]): + + m = nn.MaxPool2d(kernel_size=kernel, + stride=stride, + padding=padding, + dilation=dilation, + ceil_mode=ceil_mode) + + x = torch.rand(1, 4, spatial, spatial) + _, graph = self.checkTrace(m, [x]) + self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 1) + + def test_avg_pool2d(self): + for [spatial, kernel, padding, stride, ceil_mode, count_include_pad] in itertools.product( + [15, 16, 17, 18, 19], + [4, 5], + [0, 1, 2], + [1, 2, 4], + [False], # TODO: oneDNN Graph does not fully support ceil_mode=True + [True, False]): + + m = nn.AvgPool2d(kernel_size=kernel, + stride=stride, + padding=padding, + ceil_mode=ceil_mode, + count_include_pad=count_include_pad) + + x = torch.rand(1, 4, spatial, spatial) + _, graph = self.checkTrace(m, [x]) + self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 1) + + def test_variable_kernel_avg_pool2d(self): + class M(nn.Module): + def __init__(self): + super(M, self).__init__() + + def forward(self, x): + x = F.avg_pool2d(x, kernel_size=(x.size(2), x.size(3)), padding=0, count_include_pad=False) + return x + + x = torch.randn(1, 1000, 1, 1) + m = M() + _, graph = self.checkTrace(m, [x]) + # kernel_size is not Constant, shouldn't have any LLGA_FUSION_GROUP + # TODO: with shape specialization, should have 1 LLGA_FUSION_GROUP + self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 0) + + def test_softmax(self): + for dim in [-4, -3, -2, -1, 0, 1, 2, 3]: + m = nn.Softmax(dim=dim) + x = torch.rand(8, 12, 12, 12) + _, graph = self.checkTrace(m, [x]) + # single-op partition shouldn't be created for softmax + self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 0) + + def test_linear(self): + for bias in [True, False]: + x = torch.rand(32, 28) + m = torch.nn.Linear(in_features=28, out_features=64, bias=bias) + _, graph = self.checkTrace(m, [x]) + self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 1) + self.assertFused(graph, ['aten::linear']) + + def _gen_binary_inputs(self, gen_permute=True): + for xshape, yshape in [ + [[1, 32, 28, 28], [1, 32, 28, 28]], + [[1, 32, 28, 28], [1, 1, 28, 28]], + [[1, 32, 28, 28], [28]], + [[1, 32, 28, 28], [1]], + + ]: + yield torch.rand(xshape), torch.rand(yshape) + if gen_permute and xshape != yshape: + yield torch.rand(yshape), torch.rand(xshape) + + def test_add(self): + def forward_add(x, y): + return torch.add(x, y, alpha=2) + + for x, y in self._gen_binary_inputs(): + _, graph = self.checkTrace(forward_add, [x, y]) + # single-op partitions shouldn't be created + self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 0) + + def test_add_scalar(self): + def add_scalar(x): + return 42 + x + 3.14 + + x = torch.rand(32, 32) + _, graph = self.checkTrace(add_scalar, [x]) + # single-op partitions shouldn't be created. + self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 0) + + def test_addmm(self): + def addmm(x, y, z): + # alpha and beta are 1, by default + return torch.addmm(z, x, y) + + x = torch.rand(64, 32) + y = torch.rand(32, 32) + z = torch.rand(64, 32) + _, graph = self.checkTrace(addmm, [x, y, z]) + # single-op partition should be created for matmul with bias. 
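+        # (addmm with the default alpha=beta=1 is mapped to a oneDNN Graph MatMul
+        # with a bias input -- see the aten::addmm case in graph_helper.cpp below)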
+ self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 1) + + def test_mul(self): + def forward_mul(x, y): + return torch.mul(x, y) * 3 + + for x, y in self._gen_binary_inputs(): + _, graph = self.checkTrace(forward_mul, [x, y]) + # single-op partitions shouldn't be created + self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 0) + + def test_identity_binary(self): + def forward(x): + return x * 1 + 0.0 + + x = torch.rand(32) + _, graph = self.checkTrace(forward, [x]) + self.assertFused(graph, ['aten::add', 'aten::mul']) + + def test_layer_norm(self): + # TODO: support more normalized_shape + m = torch.nn.LayerNorm(10) + x = torch.randn(2, 5, 10, 10) + _, graph = self.checkTrace(m, [x]) + self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 1) + + def test_cat(self): + def cat_along_dim(d): + def forward_cat(*inputs): + return torch.cat(inputs, d) + return forward_cat + + for xshape in [ + [8, 8, 8, 8], + [64, 8, 32], + [2048, 64], + ]: + for d in range(len(xshape)): + x = torch.rand(xshape) + _, graph = self.checkTrace(cat_along_dim(d), [x, x, x]) + self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 1) + + def test_typecheck(self): + x = torch.rand(32, 28) + m = torch.nn.Linear(in_features=28, out_features=64, bias=True) + traced, graph = self.checkTrace(m, [x]) + self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 1) + self.assertFused(graph, ['aten::linear']) + # change the shape of the input, we should enter fallback graph + x = torch.rand(5, 28) + self.assertEqual(m(x), traced(x)) + + +@unittest.skipIf(LLGA_NOT_ENABLED, "MKL-DNN build is disabled") +class TestFusionPattern(JitLlgaTestCase): + def test_conv2d_eltwise(self): + class M(nn.Module): + def __init__(self, eltwise_fn): + super(M, self).__init__() + self.conv1 = nn.Conv2d(32, 32, 3, padding=1, bias=True) + self.conv2 = nn.Conv2d(32, 32, 3, padding=1, bias=False) + self.eltwise = eltwise_fn + + def forward(self, x): + x = self.conv1(x) + x = self.eltwise(x) + x = self.conv2(x) + x = self.eltwise(x) + return x + + # for eltwise in ['relu', 'sigmoid', 'sqrt', 'abs', 'square', 'hardtanh']: + for eltwise in ['relu']: + for inplace in [True, False]: + eltwise_fn_name = eltwise + '_' if inplace else eltwise + eltwise_fn = get_eltwise_fn(eltwise_fn_name) + + m = M(eltwise_fn) + x = torch.rand(1, 32, 28, 28) + _, graph = self.checkTrace(m, [x]) + self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 2) + # test if relu_ is replace with relu by mutation removal pass + self.assertFused(graph, ['aten::' + eltwise_fn_name]) + # test if relu is fused into the fusion group + self.assertFused(graph, ['aten::' + eltwise]) + + def test_conv2d_bn(self): + class M(nn.Module): + def __init__(self): + super(M, self).__init__() + self.conv1 = nn.Conv2d(32, 32, 3, padding=1, bias=True) + self.bn1 = nn.BatchNorm2d(32) + + def forward(self, x): + x = self.conv1(x) + x = self.bn1(x) + return x + + m = M().eval() + x = torch.rand(1, 32, 28, 28) + _, graph = self.checkTrace(m, [x]) + self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 1) + self.assertFused(graph, ['aten::_convolution', 'aten::batch_norm']) + + + def test_conv2d_bn_relu(self): + class M(nn.Module): + def __init__(self): + super(M, self).__init__() + self.conv1 = nn.Conv2d(32, 32, 3, padding=1, bias=True) + self.bn1 = nn.BatchNorm2d(32) + + def forward(self, x): + x = self.conv1(x) + x = self.bn1(x) + x = F.relu(x) + return x + + m = M().eval() + x = torch.rand(1, 32, 28, 28) + _, graph = self.checkTrace(m, [x]) + self.assertGraphContainsExactly(graph, 
LLGA_FUSION_GROUP, 1) + self.assertFused(graph, ['aten::_convolution', 'aten::batch_norm', + 'aten::relu']) + + def test_bn2d_eltwise(self): + class M(nn.Module): + def __init__(self, eltwise_fn): + super(M, self).__init__() + self.eltwise = eltwise_fn + self.bn = nn.BatchNorm2d(32) + + def forward(self, x): + x = self.bn(x) + x = self.eltwise(x) + return x + + for eltwise in ['relu']: + eltwise_fn = get_eltwise_fn(eltwise) + m = M(eltwise_fn).eval() + x = torch.rand(1, 32, 28, 28) + _, graph = self.checkTrace(m, [x]) + self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 1) + self.assertFused(graph, ['aten::' + eltwise]) + + def test_linear_eltwise(self): + class M(nn.Module): + def __init__(self, eltwise_fn, bias): + super(M, self).__init__() + self.linear = nn.Linear(28, 64, bias) + self.eltwise = eltwise_fn + + def forward(self, x): + x = self.linear(x) + x = self.eltwise(x) + return x + + for [has_bias, eltwise] in itertools.product( + [True, False], + ['relu', 'gelu', 'sigmoid', 'hardtanh', 'relu6', 'elu']): + + eltwise_fn = get_eltwise_fn(eltwise) + m = M(eltwise_fn, has_bias) + x = torch.rand(32, 28, requires_grad=False) + _, graph = self.checkTrace(m, [x]) + self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 1) + self.assertFused(graph, ['aten::' + eltwise]) + + def test_conv2d_sum(self): + class M(nn.Module): + def __init__(self, bias=False): + super(M, self).__init__() + self.conv1 = nn.Conv2d(32, 32, 3, padding=1, bias=bias) + self.bn1 = nn.BatchNorm2d(32) + self.conv2 = nn.Conv2d(32, 32, 3, padding=1, bias=bias) + self.bn2 = nn.BatchNorm2d(32) + self.relu = nn.ReLU() + self.conv3 = nn.Conv2d(32, 32, 3, padding=1, bias=bias) + self.bn3 = nn.BatchNorm2d(32) + + def forward(self, x, y): + x = self.conv1(x) + x = self.bn1(x) + y = self.conv2(y) + y = self.bn2(y) + z = self.relu(x + y) + z = self.conv3(z) + z = self.bn3(z) + return z + + for bias in [True, False]: + m = M(bias).eval() + x = torch.rand(1, 32, 16, 16, requires_grad=False) + y = torch.rand(1, 32, 16, 16, requires_grad=False) + _, graph = self.checkTrace(m, [x, y]) + self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 3) + + def test_wildcard(self): + class M(nn.Module): + def __init__(self): + super(M, self).__init__() + self.conv1 = nn.Conv2d(32, 32, 3, padding=1, bias=True) + self.eltwise = nn.ReLU() + + def forward(self, x): + x = self.conv1(x) + y = self.eltwise(x) + return [x, y] + + # The pattern is as the following: + # conv + # | \ + # eltwise \ + # | \ + # ListConstruct + # + # The output of conv is used by a wildcard op: ListConstruct. + # Thus conv-eltwise cannot be selected into the same Partition. 
+ m = M() + x = torch.rand(1, 32, 28, 28) + _, graph = self.checkTrace(m, [x]) + # conv can exist in a single-op oneDNN Graph partition but not relu + self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 1) + self.assertFused(graph, ['aten::_convolution']) + + +@unittest.skipIf(LLGA_NOT_ENABLED, "MKL-DNN build is disabled") +class TestModel(JitLlgaTestCase): + @skipIfNoTorchVision + def _test_vision(self, model_name): + m = getattr(torchvision.models, model_name)().eval() + x = torch.rand(1, 3, 224, 224) / 10 + _, graph = self.checkTrace(m, [x]) + self.assertFused(graph, ['aten::_convolution', 'aten::batch_norm', + 'aten::relu', 'aten::linear', + 'aten::avg_pool2d', 'aten::max_pool2d']) + + +for model_name, enabled in [ + ['resnet50', True], + ['resnext50_32x4d', True], + ['resnext101_32x8d', True], + ['densenet121', True], + ['googlenet', TEST_SCIPY], + ['mobilenet_v2', True], + ['mnasnet1_0', True], + ['squeezenet1_0', True], + ['vgg16', True], + ['alexnet', True], + ['shufflenet_v2_x1_0', True], + ['wide_resnet50_2', True], +]: + def wrapper(mname): + @unittest.skipIf(not enabled, 'Disabled') + def test(self): + return self._test_vision(mname) + return test + + setattr(TestModel, 'test_vision_%s' % model_name, wrapper(model_name)) + +if __name__ == '__main__': + run_tests() diff --git a/torch/CMakeLists.txt b/torch/CMakeLists.txt index cf28010f2c6..15bad203945 100644 --- a/torch/CMakeLists.txt +++ b/torch/CMakeLists.txt @@ -14,7 +14,7 @@ if(NOT BUILD_PYTHON) endif() if(USE_TBB) -include_directories(${TBB_INCLUDE_DIR}) + include_directories(${TBB_INCLUDE_DIR}) endif() set(TORCH_SRC_DIR "${CMAKE_CURRENT_SOURCE_DIR}") @@ -423,6 +423,10 @@ target_compile_options(torch_python PRIVATE ${TORCH_PYTHON_COMPILE_OPTIONS}) target_include_directories(torch_python PUBLIC ${TORCH_PYTHON_INCLUDE_DIRECTORIES}) +if(BUILD_ONEDNN_GRAPH) + target_compile_definitions(torch_python PRIVATE "-DBUILD_ONEDNN_GRAPH") + target_compile_definitions(torch_cpu PRIVATE "-DBUILD_ONEDNN_GRAPH") +endif() if(NOT TORCH_PYTHON_LINK_FLAGS STREQUAL "") set_target_properties(torch_python PROPERTIES LINK_FLAGS ${TORCH_PYTHON_LINK_FLAGS}) diff --git a/torch/_C/__init__.pyi.in b/torch/_C/__init__.pyi.in index 25a98ce3902..09995479ce9 100644 --- a/torch/_C/__init__.pyi.in +++ b/torch/_C/__init__.pyi.in @@ -228,6 +228,8 @@ def _debug_get_fusion_group_inlining() -> _bool: ... def _debug_set_fusion_group_inlining(enable: _bool): ... def _jit_texpr_fuser_enabled() -> _bool: ... def _jit_nvfuser_enabled() -> _bool: ... +def _jit_llga_enabled() -> _bool: ... +def _jit_set_llga_enabled(enable: _bool): ... def _llvm_enabled() -> _bool: ... def _jit_override_can_fuse_on_cpu(override: _bool): ... def _jit_override_can_fuse_on_gpu(override: _bool): ... 
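The two private bindings declared here back the public `torch.jit.enable_onednn_fusion` / `torch.jit.onednn_fusion_enabled` entries added to `docs/source/jit.rst` above. A small sketch of the expected round trip, assuming the public helpers simply wrap these bindings:

```python
import torch

# Public toggle documented in docs/source/jit.rst.
torch.jit.enable_onednn_fusion(True)
assert torch.jit.onednn_fusion_enabled()

# The same state should be visible through the private bindings
# declared in torch/_C/__init__.pyi.in.
assert torch._C._jit_llga_enabled()

torch._C._jit_set_llga_enabled(False)
assert not torch.jit.onednn_fusion_enabled()
```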
diff --git a/torch/csrc/jit/codegen/onednn/LlgaTensorImpl.cpp b/torch/csrc/jit/codegen/onednn/LlgaTensorImpl.cpp new file mode 100644 index 00000000000..4953ba7c75c --- /dev/null +++ b/torch/csrc/jit/codegen/onednn/LlgaTensorImpl.cpp @@ -0,0 +1,153 @@ +#include + +#if AT_MKLDNN_ENABLED() +#include +#include + +namespace torch { +namespace jit { +namespace fuser { +namespace onednn { + +dnnl::graph::engine& Engine::getEngine() { + static dnnl::graph::engine cpu_engine( + dnnl::graph::engine::kind::cpu, /* device_id = */ 0); + return cpu_engine; +} + +dnnl::graph::stream& Stream::getStream() { + static dnnl::graph::stream cpu_stream{Engine::getEngine(), nullptr}; + return cpu_stream; +} + +LlgaTensorImpl::LlgaTensorImpl( + at::Storage&& storage, + const caffe2::TypeMeta& data_type, + const LlgaTensorDesc& desc) + : at::TensorImpl( + std::move(storage), + c10::DispatchKeySet(c10::DispatchKey::MkldnnCPU), + data_type), + desc_(desc) { + sizes_and_strides_.set_sizes(desc.sizes()); + refresh_numel(); +} + +// The following are publically exposed as methods of Tensor +c10::IntArrayRef LlgaTensorImpl::strides() const { + TORCH_CHECK(false, "Cannot get strides of LlgaTensorImpl"); +} +int64_t LlgaTensorImpl::stride(int64_t d) const { + TORCH_CHECK(false, "Cannot get strides of LlgaTensorImpl"); +} +bool LlgaTensorImpl::is_contiguous(at::MemoryFormat memory_format) const { + TORCH_CHECK(false, "Cannot get query is_contiguous on LlgaTensorImpl"); +} +const at::Storage& LlgaTensorImpl::storage() const { + TORCH_CHECK(false, "Cannot access the storage() of LlgaTensorImpl"); +} +int64_t LlgaTensorImpl::storage_offset() const { + TORCH_CHECK(false, "Cannot access the storage_offset() of LlgaTensorImpl"); +} + +// The following are some internal inherited methods that we do not support. +// They should never get called. 
+void LlgaTensorImpl::set_size(int64_t dim, int64_t new_size) { + TORCH_INTERNAL_ASSERT(false, "Cannot set_size for LlgaTensorImpl"); +} +void LlgaTensorImpl::set_stride(int64_t dim, int64_t new_stride) { + TORCH_INTERNAL_ASSERT(false, "Cannot set_stride for LlgaTensorImpl"); +} +void LlgaTensorImpl::set_storage_offset(int64_t storage_offset) { + TORCH_INTERNAL_ASSERT(false, "Cannot set_storage_offset for LlgaTensorImpl"); +} +bool LlgaTensorImpl::has_storage() const { + return true; +} + +at::Tensor empty_llga( + const LlgaTensorDesc& desc, + const c10::TensorOptions& options) { + auto nbytes = desc.storage_size(); + + auto allocator = at::GetCPUAllocator(); + auto storage_impl = c10::make_intrusive( + c10::StorageImpl::use_byte_size_t(), + nbytes, + allocator->allocate(nbytes), + allocator, + /*resizable=*/false); + + return at::detail::make_tensor( + std::move(storage_impl), options.dtype(), desc); +} + +const LlgaTensorDesc& get_llga_desc(const at::Tensor& tensor) { + TORCH_INTERNAL_ASSERT( + tensor.is_mkldnn(), "get_llga_desc expects Mkldnn tensor input"); + return static_cast(tensor.unsafeGetTensorImpl())->desc(); +} + +dnnl::graph::tensor llga_from_aten_tensor(const at::Tensor& tensor) { + return { + get_llga_desc(tensor).logical_tensor(), + torch::jit::fuser::onednn::Engine::getEngine(), + tensor.data_ptr()}; +} + +using data_type = dnnl::graph::logical_tensor::data_type; + +data_type getLlgaDataType(at::ScalarType dt) { + switch (dt) { + case at::ScalarType::Float: + return data_type::f32; + case at::ScalarType::BFloat16: + return data_type::bf16; + case at::kInt: + return data_type::s32; + case at::ScalarType::QInt8: + return data_type::s8; + case at::ScalarType::QUInt8: + return data_type::u8; + default: + TORCH_CHECK(false, "Not support data type ", dt); + } +} + +LlgaTensorDesc LlgaTensorDesc::supplementTensorInfo(const at::Tensor& t) const { + if (t.is_mkldnn()) { + // if input tensor is of mkldnn, it's originated from an upstream + // LLGA partition which carries opaque layout info + return get_llga_desc(t).tid(tid_); + } else { + // if input tensor is not an mkldnn tensor, use default layout + auto sizes = t.sizes().vec(); + auto strides = t.strides().vec(); + auto dtype = getLlgaDataType(t.scalar_type()); + return {tid_, sizes, strides, dtype, property_type_}; + } +} + +at::ScalarType LlgaTensorDesc::aten_scalar_type() const { + switch (dtype_) { + case data_type::f32: + return at::ScalarType::Float; + case data_type::bf16: + return at::ScalarType::BFloat16; + case data_type::s32: + return at::kInt; + case data_type::s8: + return at::ScalarType::QInt8; + case data_type::u8: + return at::ScalarType::QUInt8; + default: + TORCH_CHECK(false, "Invalid data type ", static_cast(dtype_)); + } +} + +} // namespace onednn +} // namespace fuser +} // namespace jit +} // namespace torch + +#endif // AT_MKLDNN_ENABLED() diff --git a/torch/csrc/jit/codegen/onednn/LlgaTensorImpl.h b/torch/csrc/jit/codegen/onednn/LlgaTensorImpl.h new file mode 100644 index 00000000000..4fe9888369b --- /dev/null +++ b/torch/csrc/jit/codegen/onednn/LlgaTensorImpl.h @@ -0,0 +1,279 @@ +#pragma once + +#include +#include + +#include +#include + +namespace torch { +namespace jit { +namespace fuser { +namespace onednn { + +// Engine represents a device and its context. From the device kind, the engine +// knows how to generate code for the target device and what kind of device +// object to be expected. The device id ensures that there is a unique engine +// being created for each device. 
The device handle passed from PyTorch allows +// oneDNN Graph implementation to work on the device specified by PyTorch, which +// is currently CPU, so we only have one engine. +// Ref: https://spec.oneapi.io/onednn-graph/latest/programming_model.html#engine +struct Engine { + // CPU engine singleton + static dnnl::graph::engine& getEngine(); + Engine(const Engine&) = delete; + void operator=(const Engine&) = delete; +}; + +// Stream is the logical abstraction for execution units. It is created on top +// of oneDNN Graph engine. A compiled oneDNN Graph partition is submitted to a +// stream for execution. +struct Stream { + // CPU stream singleton + static dnnl::graph::stream& getStream(); + Stream(const Stream&) = delete; + void operator=(const Stream&) = delete; +}; + +struct LlgaTensorDesc { + using desc = dnnl::graph::logical_tensor; + + LlgaTensorDesc( + size_t tid, + std::vector sizes, + std::vector strides, + desc::data_type dtype, + desc::property_type property_type) + : tid_(tid), + sizes_(sizes), + strides_(strides), + dtype_(dtype), + property_type_(property_type), + layout_type_(desc::layout_type::strided), + layout_id_(-1) {} + + LlgaTensorDesc(const desc& t) + : tid_(t.get_id()), + sizes_(t.get_dims()), + strides_({-1}), + dtype_(t.get_data_type()), + property_type_(t.get_property_type()), + layout_type_(t.get_layout_type()), + layout_id_(-1) { + if (is_opaque()) { + layout_id_ = t.get_layout_id(); + } + if (is_strided()) { + strides_ = t.get_strides(); + } + } + + // TODO: llga need set input/output type constraints while it seems that we + // cannot get the dtype during compile time, hard-coded to fp32 for now to be + // able to add_op + LlgaTensorDesc(const torch::jit::Value* v) + : LlgaTensorDesc( + v->unique(), + {}, + {}, + desc::data_type::f32, + get_property_type(v)) { + if (v->type()->isSubtypeOf(TensorType::get())) { + auto tt = v->type()->cast(); + + auto sizes = tt->sizes(); + if (sizes.sizes()) { + for (auto d : *sizes.sizes()) { + sizes_.push_back(d.value_or(DNNL_GRAPH_UNKNOWN_DIM)); + } + } + + auto strides = tt->strides(); + if (strides.sizes()) { + for (auto d : *strides.sizes()) { + strides_.push_back(d.value_or(DNNL_GRAPH_UNKNOWN_DIM)); + } + } + } + } + + LlgaTensorDesc supplementTensorInfo(const at::Tensor& t) const; + + at::ScalarType aten_scalar_type() const; + + const std::vector& sizes() const { + return sizes_; + } + + const std::vector& strides() const { + TORCH_CHECK(!is_opaque(), "Cannot get strides on opaque layout"); + return strides_; + } + + size_t tid() const { + return tid_; + } + + LlgaTensorDesc tid(uint64_t new_id) const { + auto ret = *this; + ret.tid_ = new_id; + return ret; + } + + desc::data_type dtype() const { + return dtype_; + } + + LlgaTensorDesc dtype(desc::data_type new_dtype) const { + return LlgaTensorDesc(tid_, sizes_, strides_, new_dtype, property_type_); + } + + desc::layout_type layout_type() const { + return layout_type_; + } + + LlgaTensorDesc layout_type(desc::layout_type new_layout_type) { + auto ret = *this; + ret.layout_type_ = new_layout_type; + return ret; + } + + desc::property_type get_property_type(const torch::jit::Value* v) { + switch (v->node()->kind()) { + case prim::Constant: + return desc::property_type::constant; + default: + return desc::property_type::variable; + } + } + + LlgaTensorDesc any() { + return layout_type(desc::layout_type::any); + } + + size_t storage_size() const { + return logical_tensor().get_mem_size(); + } + + desc logical_tensor() const { + if (is_dimensionality_unknown()) { + return 
desc( + tid_, dtype_, DNNL_GRAPH_UNKNOWN_NDIMS, layout_type_, property_type_); + } else if (is_opaque()) { + return desc(tid_, dtype_, sizes_, layout_id_, property_type_); + } else if (is_any()) { + return desc(tid_, dtype_, sizes_, layout_type_, property_type_); + } else { + return desc(tid_, dtype_, sizes_, strides_, property_type_); + } + } + + bool is_strided() const { + return layout_type_ == desc::layout_type::strided; + } + + bool is_any() const { + return layout_type_ == desc::layout_type::any; + } + + bool is_opaque() const { + return layout_type_ == desc::layout_type::opaque; + } + + bool operator==(const LlgaTensorDesc& desc) const { + return tid_ == desc.tid_ && sizes_ == desc.sizes_ && + dtype_ == desc.dtype_ && layout_type_ == desc.layout_type_ && + ((is_opaque() && layout_id_ == desc.layout_id_) || + strides_ == desc.strides_); + } + + bool operator!=(const LlgaTensorDesc& desc) const { + return (tid_ != desc.tid_) || (sizes_ != desc.sizes_) || + (dtype_ != desc.dtype_) || (layout_type_ != desc.layout_type_) || + !((is_opaque() && (layout_id_ == desc.layout_id_)) || + (strides_ == desc.strides_)); + } + + static size_t hash(const LlgaTensorDesc& desc) { + return c10::get_hash( + desc.tid_, + desc.sizes_, + desc.dtype_, + desc.layout_type_, + desc.layout_id_); + } + + void set_compute_inplace() { + compute_inplace_ = true; + } + + void set_input_tensor_index(size_t index) { + input_tensor_index_ = index; + } + + bool reuses_input_tensor() { + return compute_inplace_; + } + + size_t get_input_tensor_index() { + return input_tensor_index_; + } + + private: + bool is_dimensionality_unknown() const { + return sizes_.size() == 0; + } + + size_t tid_; + std::vector sizes_; + std::vector strides_; + desc::data_type dtype_; + desc::property_type property_type_; + desc::layout_type layout_type_; + size_t layout_id_; + // If this is an output tensor, and querying the compiled partition would + // determine that this tensor would reuse its input tensor, then + // compute_inplace would be true, and input_tensor_index would be the index of + // the corresponding input tensor in inputSpecs_ of the LlgaKernel object. + bool compute_inplace_ = false; + size_t input_tensor_index_; +}; + +struct TORCH_API LlgaTensorImpl : public c10::TensorImpl { + LlgaTensorImpl( + at::Storage&& storage, + const caffe2::TypeMeta& data_type, + const LlgaTensorDesc& desc); + + const LlgaTensorDesc& desc() const { + return desc_; + } + + // Override a bunch of methods inherited from TensorImpl to return error + // messages. 
+ bool is_contiguous( + at::MemoryFormat memory_format = + at::MemoryFormat::Contiguous) const override; + c10::IntArrayRef strides() const override; + int64_t stride(int64_t d) const override; + void set_size(int64_t dim, int64_t new_size) override; + void set_stride(int64_t dim, int64_t new_stride) override; + void set_storage_offset(int64_t storage_offset) override; + bool has_storage() const override; + const at::Storage& storage() const override; + int64_t storage_offset() const override; + + private: + LlgaTensorDesc desc_; +}; + +at::Tensor empty_llga( + const LlgaTensorDesc& desc, + const c10::TensorOptions& options); + +dnnl::graph::tensor llga_from_aten_tensor(const at::Tensor& tensor); + +} // namespace onednn +} // namespace fuser +} // namespace jit +} // namespace torch diff --git a/torch/csrc/jit/codegen/onednn/README.md b/torch/csrc/jit/codegen/onednn/README.md new file mode 100644 index 00000000000..3da25117a07 --- /dev/null +++ b/torch/csrc/jit/codegen/onednn/README.md @@ -0,0 +1,108 @@ +# Pytorch - oneDNN Graph API Bridge +This integration will add the infrastructure of a new PyTorch JIT graph fuser based on [oneDNN Graph API](https://spec.oneapi.io/onednn-graph/latest/programming_model.html), which provides a flexible API for aggressive fusion. The current preview4 version supports fusion for FP32 inference. Currently, the speedup is achieved for static shapes, +although we'd soon add dynamic-shape support. When oneDNN Graph is enabled, weights are cached, as they're constant during inference. + +## Graph Optimization +We have registered optimization passes in the custom pre-passes set of PyTorch: + +1. Alias and mutation reduction + + The operators of oneDNN graph are pure functional while PyTorch has operators in in-place forms or create views for buffer sharing. + Due to the semantic gaps between the backend operators and the PyTorch operators, we have a pass to reduce mutation with best effort at the beginning. + +2. Graph passing + + With a PyTorch TorchScript graph, the integration maps PyTorch operators on the graph to the corresponding oneDNN Graph operators to form a backend graph. + +3. Partitioning + + The backend selects regions to be fused in the graph and returns a list of partitions. Each partition corresponds to a set of fused operators. + +4. Graph rewriting + + The original PyTorch JIT graph will be re-written based on the partitions returned from the backend. The operators in one partition will be grouped together to form a JIT operator, referred to as a oneDNN Graph fusion group. + +5. Layout propagation + + This pass is to eliminate unnecessary layout conversions at partition boundaries. We set different formats to the output of a partition so that the backend could perform layout conversion internally. When `ANY` is set, the layout at boundaries will be fully decided by the backend. Otherwise, the backend should follow the layout set by PyTorch. Currently, we set `ANY` layout for a tensor that's an output of a oneDNN Graph partition, and an input to another. + +## Graph Executor +During runtime execution of a (re-written) PyTorch JIT graph, oneDNN graph partitions will be dispatched to the oneDNN graph JIT variadic Operator. +Inside the oneDNN graph JIT Op, input PyTorch tensors of each partition will be mapped to oneDNN graph tensors. The partition will then be [compiled](https://spec.oneapi.io/onednn-graph/latest/programming_model.html#partition) and [executed](https://spec.oneapi.io/onednn-graph/latest/programming_model.html#compiled-partition). 
The output oneDNN graph tensor will be mapped back to PyTorch tensors to be fed to the next operator on the PyTorch JIT graph. + + +## Tests + +```bash +pytest test/test_jit_llga_fuser.py +``` + +## Quick Start + +A simple cascaded Conv-Relu example is provided in test. Please consider enabling log outputs to familiarize yourself with the whole pipeline: + +**Mutation Removal -> Prepare Binary -> Defer Size Check -> Graph Fuser -> Layout Propagation -> Type Guard -> Kernel Execution** + +oneDNN Graph was formerly known as LLGA (Low Level Graph API), +and thus LLGA in the codebase corresponds to oneDNN Graph. + +```bash +DNNL_VERBOSE=1 PYTORCH_JIT_LOG_LEVEL=">>graph_helper:>>graph_fuser:>>kernel:>>interface" python -u test/test_jit_llga_fuser.py -k test_conv2d_eltwise +``` + +## Codebase structure + +Most of the source code is placed in + +```bash +torch/csrc/jit/codegen/onednn/* +``` + +Tensor related code is located at + +```bash +torch/csrc/jit/codegen/onednn/LlgaTensorImpl.h +torch/csrc/jit/codegen/onednn/LlgaTensorImpl.cpp +``` + +CMake files where bridge code is included: + +```bash +caffe2/CMakeLists.txt +``` + +CMake files where oneDNN Graph submodule are included: + +```bash +third_party/ideep/mkl-dnn +cmake/public/mkldnn.cmake +cmake/Modules/FindMKLDNN.cmake +cmake/Dependencies.cmake +``` + +To map another op to oneDNN Graph, you should add an entry for it in in createOperator in torch/csrc/jit/codegen/onednn/graph_helper.cpp. +If it has an inplace variant, you should add it in the lambda being passed to RemoveTensorMutation in +torch/csrc/jit/codegen/onednn/interface.cpp. You might also want to add it to canFuseNode in torch/csrc/jit/codegen/onednn/register_interface.cpp. + +## How to use + + +```python +# enable oneDNN graph fusion globally +torch.jit.enable_onednn_fusion(True) + +# define the model +def MyModel(torch.nn.Module): + ... 
+ +# construct the model +model = MyModel(…) +with torch.no_grad(): + model.eval() + model = torch.jit.trace(model, torch.rand(args.batch_size, 3, 224, 224)) + +# run the model +with torch.no_grad(): + # oneDNN graph fusion will be trigerred during runtime + output = model(images) +``` diff --git a/torch/csrc/jit/codegen/onednn/defer_size_check.cpp b/torch/csrc/jit/codegen/onednn/defer_size_check.cpp new file mode 100644 index 00000000000..28266a80859 --- /dev/null +++ b/torch/csrc/jit/codegen/onednn/defer_size_check.cpp @@ -0,0 +1,87 @@ +#include +#include + +namespace torch { +namespace jit { +namespace fuser { +namespace onednn { + +class SizeCheckMover { + private: + Block* block_; + std::shared_ptr graph_; + + public: + SizeCheckMover(Block* block, std::shared_ptr graph) + : block_(block), graph_(std::move(graph)) {} + + bool analyzeNode(Node* node, AliasDb& aliasDb) { + // + // %b = addmm(%a) + // %sz = aten::size(%b) + // %c = relu(%b) + // => + // %b = addmm(%a) + // %c = relu(%b) + // %sz = aten::size(%c) + // ^-- move size check after relu as it preserves input shape + // + if (!node->matches("aten::size(Tensor self) -> int[]")) + return false; + + auto* input = node->input(0); + auto& uses = input->uses(); + bool onlyUsedByShapePreserveOp = + uses.size() > 1 && std::all_of(uses.begin(), uses.end(), [&](auto& u) { + if (u.user == node) { + return true; + } + // match with shape-preserving unary ops in + // tensorexpr_elementwise_set that's defined in + // torch/csrc/jit/runtime/symbolic_shape_registry_util.cpp + OperatorMap schemaMap = get_tensorexpr_elementwise_set(); + c10::optional mapping = + schemaMap.find(u.user->getOperator()); + return mapping == "unary"; + }); + + if (!onlyUsedByShapePreserveOp) + return false; + + for (const auto& use : uses) { + if (use.user == node) + continue; + auto shapePreserveOp = use.user; + if (aliasDb.moveAfterTopologicallyValid(node, shapePreserveOp)) { + node->replaceInputWith(input, shapePreserveOp->output(0)); + return true; + } + } + + return false; + } + + void run() { + bool changed = true; + while (changed) { + changed = false; + AliasDb aliasDb(graph_); + for (Node* node : block_->nodes()) { + changed |= analyzeNode(node, aliasDb); + } + } + + for (Node* node : block_->nodes()) + for (Block* subBlock : node->blocks()) + SizeCheckMover(subBlock, graph_).run(); + } +}; + +void DeferSizeCheck(std::shared_ptr& graph) { + SizeCheckMover(graph->block(), graph).run(); +} + +} // namespace onednn +} // namespace fuser +} // namespace jit +} // namespace torch diff --git a/torch/csrc/jit/codegen/onednn/defer_size_check.h b/torch/csrc/jit/codegen/onednn/defer_size_check.h new file mode 100644 index 00000000000..6e31cf202d3 --- /dev/null +++ b/torch/csrc/jit/codegen/onednn/defer_size_check.h @@ -0,0 +1,15 @@ +#pragma once + +#include + +namespace torch { +namespace jit { +namespace fuser { +namespace onednn { + +void DeferSizeCheck(std::shared_ptr& graph); + +} // namespace onednn +} // namespace fuser +} // namespace jit +} // namespace torch diff --git a/torch/csrc/jit/codegen/onednn/graph_fuser.cpp b/torch/csrc/jit/codegen/onednn/graph_fuser.cpp new file mode 100644 index 00000000000..2a956362688 --- /dev/null +++ b/torch/csrc/jit/codegen/onednn/graph_fuser.cpp @@ -0,0 +1,31 @@ +#include +#include +#include +#include +#include +#include + +namespace torch { +namespace jit { +namespace fuser { +namespace onednn { + +void CreateLlgaSubgraphs(std::shared_ptr& graph) { + AliasDb db(graph); + GraphRewriter graphRewriter(graph->block(), graph, db); 
+ // We maintain alias db correctness in-place while building up the LLGA + // subgraphs, however it is difficult to preserve correctness when + // un-inlining autodiff subgraphs. We first recursively construct all + // subgraphs and then recursively cleanup & unmerge the small subgraphs + graphRewriter.buildupSubgraphs(); + graphRewriter.cleanupSubgraphs(); + // Run CSE globally onceto eliminate duplicates that may have occurred + // while inlining subgraphs. + EliminateCommonSubexpression(graph); + EliminateDeadCode(graph); +} + +} // namespace onednn +} // namespace fuser +} // namespace jit +} // namespace torch diff --git a/torch/csrc/jit/codegen/onednn/graph_fuser.h b/torch/csrc/jit/codegen/onednn/graph_fuser.h new file mode 100644 index 00000000000..ee83edc68fc --- /dev/null +++ b/torch/csrc/jit/codegen/onednn/graph_fuser.h @@ -0,0 +1,53 @@ +#pragma once + +#include +#include + +namespace torch { +namespace jit { +namespace fuser { +namespace onednn { + +struct WorkBlock : public std::pair { + using pair::pair; + + Node* begin() { + return this->first; + } + Node* end() { + return this->second; + } +}; + +class GraphRewriter { + public: + GraphRewriter(Block* block, std::shared_ptr graph, AliasDb& aliasDb) + : block_(block), + graph_(std::move(graph)), + aliasDb_(aliasDb), + llgaHelper_(graph_) {} + + void cleanupSubgraphs(); + void buildupSubgraphs(); + + private: + Block* block_; + std::shared_ptr graph_; + AliasDb& aliasDb_; + LlgaGraphHelper llgaHelper_; + std::vector buildWorkBlocks(); + std::pair scanNode( + Node* consumer, + graph_node_list::iterator workblock_begin); + c10::optional tryMerge(Node* consumer, Node* producer); +}; + +// This pass creates the subgraphs for oneDNN Graph Fusion Nodes. +// Its code-structure has been vastly inspired from +// torch/csrc/jit/passes/create_autodiff_subgraphs.cpp +void CreateLlgaSubgraphs(std::shared_ptr& graph); + +} // namespace onednn +} // namespace fuser +} // namespace jit +} // namespace torch diff --git a/torch/csrc/jit/codegen/onednn/graph_helper.cpp b/torch/csrc/jit/codegen/onednn/graph_helper.cpp new file mode 100644 index 00000000000..4b37054fc97 --- /dev/null +++ b/torch/csrc/jit/codegen/onednn/graph_helper.cpp @@ -0,0 +1,557 @@ +#include +#include + +#include +#include +#include + +namespace torch { +namespace jit { +namespace fuser { +namespace onednn { + +using opkind = dnnl::graph::op::kind; + +void fixConvOptionalBias(Node* node) { + if (node->namedInput("bias")->mustNotBeNone() == false) { + // Replace non-existent optional bias with const None + auto g = node->owningGraph(); + auto n = g->createNone(); + auto v = n->insertBefore(node)->output(); + node->replaceInput(2, v); + } +} + +c10::optional getDimensions(Value* v) { + if (v->type()->isSubtypeOf(TensorType::get())) { + return v->type()->cast()->sizes().size(); + } else { + return c10::nullopt; + } +} + +// PyTorch ops that can't otherwise be mapped to oneDNN Graph ops are mapped as +// Wildcards instead. They make the integration code with PyTorch simpler by +// passing every op to the oneDNN Graph library in the add_op call - +// no need to check beforehand whether the op is supported by oneDNN Graph or +// not oneDNN Graph ops separated by wildcards don't end up in the same +// partition. 
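To illustrate the comment above, a minimal sketch of the partition-splitting behaviour; the op choice is ours: `torch.sin` has no entry in `createOperator`, so it is mapped to a Wildcard:

```python
import torch
import torch.nn as nn

class M(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(8, 8, 3, padding=1)
        self.conv2 = nn.Conv2d(8, 8, 3, padding=1)

    def forward(self, x):
        # sin has no oneDNN Graph mapping, so it becomes a Wildcard op
        return self.conv2(torch.sin(self.conv1(x)))

torch.jit.enable_onednn_fusion(True)
m = M().eval()
x = torch.rand(1, 8, 16, 16)
with torch.no_grad():
    traced = torch.jit.freeze(torch.jit.trace(m, x))
    traced(x)
    traced(x)
    # The two convolutions should land in separate prim::oneDNNFusionGroup
    # nodes: the Wildcard in between keeps them out of a single partition.
    print(traced.graph_for(x))
```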
+Operator makeWildcardOp(Node* node) { + auto o = Operator(node, opkind::Wildcard); + // wildcard op contains only topology info + for (size_t i = 0; i < node->inputs().size(); i++) { + o.setInput(i); + } + for (size_t i = 0; i < node->outputs().size(); i++) { + o.setOutput(i); + } + return o; +} + +// If we don't meet a certain condition to map a PyTorch op to a oneDNN Graph +// op, then we create a wildcard op corresponding to that PyTorch op instead. +#define REQUIRE(cond) \ + if (!(cond)) { \ + GRAPH_DEBUG("Unsupported condition " #cond "\n"); \ + return makeWildcardOp(node); \ + } + +Operator makeEltwiseOp(Node* node, opkind kind) { + return Operator(node, kind).setInput(0).setOutput(0); +} + +Operator makeBinaryOp(Node* node, opkind kind) { + REQUIRE( + node->input(0)->type()->isSubtypeOf(TensorType::get()) && + node->input(1)->type()->isSubtypeOf(TensorType::get())) + return Operator(node, kind).setInput(0, 1).setOutput(0); +} + +// Map a PyTorch op to its corresponding oneDNN Graph op. +// If mapping isn't possible, then create a wildcard op instead. +// The mapping is done as per oneDNN Graph op schema defined in +// third_party/ideep/mkl-dnn/src/interface/op_def.hpp. +Operator createOperator(Node* node) { + switch (node->kind()) { + case aten::conv2d: { + fixConvOptionalBias(node); + return Operator(node, opkind::Convolution) + .setInput(0, 1, 2) + .setOutput(0) + .setAttr("strides", Operator::Ints, 3) + .setAttr("pads_begin", Operator::Ints, 4) + .setAttr("pads_end", Operator::Ints, 4) + .setAttr("dilations", Operator::Ints, 5) + .setAttr("groups", Operator::Int, 6) + .setAttr("filter_format", std::string("OIX")); + } + + case aten::_convolution: { + bool transposed = toIValue(node->namedInput("transposed"))->toBool(); + REQUIRE(!transposed); + + return Operator(node, opkind::Convolution) + .setInput(0, 1, 2) + .setOutput(0) + .setAttr("strides", Operator::Ints, 3) + .setAttr("pads_begin", Operator::Ints, 4) + .setAttr("pads_end", Operator::Ints, 4) + .setAttr("dilations", Operator::Ints, 5) + .setAttr("groups", Operator::Int, 8) + .setAttr("filter_format", std::string("OIX")); + } + + case aten::batch_norm: { + auto training = toIValue(node->namedInput("training")); + REQUIRE( + training.has_value()); // cannot get training status in script mode + REQUIRE(!training->toBool()); // TODO: support bn training + return Operator(node, opkind::BatchNormInference) + .setInput(0, 1, 2, 3, 4) + .setOutput(0) + .setAttr("epsilon", Operator::Float, 7); + } + + case aten::layer_norm: { + auto normalized_shape = toIValue(node->namedInput("normalized_shape")); + REQUIRE(normalized_shape->toIntList().size() == 1); + return Operator(node, opkind::LayerNorm) + .setInput(0, 2, 3) + .setOutput(0) + .setAttr("epsilon", Operator::Float, 4) + .setAttr("keep_stats", false); + } + + case aten::addmm: { + auto alpha = toIValue(node->namedInput("alpha")); + auto beta = toIValue(node->namedInput("beta")); + REQUIRE( + alpha.has_value() && beta.has_value() && (alpha->toDouble() == 1.0) && + (beta->toDouble() == 1.0)); + return Operator(node, opkind::MatMul).setInput(1, 2, 0).setOutput(0); + } + + case aten::add: + return makeBinaryOp(node, opkind::Add); + + case aten::mul: + return makeBinaryOp(node, opkind::Multiply); + + case aten::tanh: + return makeEltwiseOp(node, opkind::Tanh); + + case aten::relu: + return makeEltwiseOp(node, opkind::ReLU); + + case aten::elu: + return makeEltwiseOp(node, opkind::Elu) + .setAttr("alpha", Operator::Float, 1); + + case aten::sigmoid: + return makeEltwiseOp(node, 
opkind::Sigmoid); + case aten::gelu: + return makeEltwiseOp(node, opkind::GELU); + + case aten::sqrt: + return makeEltwiseOp(node, opkind::Sqrt); + + case aten::abs: + return makeEltwiseOp(node, opkind::Abs); + + case aten::square: + return makeEltwiseOp(node, opkind::Square); + + case aten::hardtanh: + return makeEltwiseOp(node, opkind::HardTanh) + .setAttr("min", Operator::Float, 1) + .setAttr("max", Operator::Float, 2); + + case aten::relu6: + return makeEltwiseOp(node, opkind::HardTanh) + .setAttr("min", 0.f) + .setAttr("max", 6.f); + + case aten::softmax: { + auto axis = toIValue(node->namedInput("dim"))->toInt(); + return Operator(node, opkind::SoftMax) + .setInput(0) + .setOutput(0) + .setAttr("axis", axis); + } + + case aten::cat: { + auto o = Operator(node, opkind::Concat); + REQUIRE( + node->namedInput("tensors")->node()->kind() == prim::ListConstruct); + REQUIRE(node->namedInput("tensors")->uses().size() == 1); + REQUIRE(node->namedInput("dim")->node()->kind() == prim::Constant); + // aten::cat needs a special handling since it takes a Tensor[] as input. + // We set the inputs of ListConstruct as the inputs of cat. + // + // Pytorch IR: LLGA sees: + // %a %b %c %dim %a %b %c + // \ | / | \ | / + // prim::ListConstruct prim::Constant llga::Concat[axis=%dim] + // \ / + // aten::cat + auto listConstruct = node->input(0)->node(); + for (auto input : listConstruct->inputs()) + o.setInputValue(input); + return o.setOutput(0).setAttr("axis", Operator::Int, 1); + } + + case aten::max_pool2d: { + REQUIRE( + node->namedInput("kernel_size")->node()->kind() == prim::Constant); + + auto rounding_type = + toIValue(node->namedInput("ceil_mode"))->toBool() ? "ceil" : "floor"; + return Operator(node, opkind::MaxPool) + .setInput(0) + .setOutput(0) + .setAttr("kernel", Operator::Ints, 1) + .setAttr("strides", Operator::Ints, 2) + .setAttr("pads_begin", Operator::Ints, 3) + .setAttr("pads_end", Operator::Ints, 3) + .setAttr("dilations", Operator::Ints, 4) + .setAttr("rounding_type", std::string(rounding_type)); + } + + case aten::avg_pool2d: { + // TODO: do we need add checks for all Constants? + REQUIRE( + node->namedInput("kernel_size")->node()->kind() == prim::Constant); + auto rounding_type = + toIValue(node->namedInput("ceil_mode"))->toBool() ? 
"ceil" : "floor"; + auto divisor_override = toIValue(node->namedInput("divisor_override")); + REQUIRE(divisor_override->isNone()); + return Operator(node, opkind::AvgPool) + .setInput(0) + .setOutput(0) + .setAttr("kernel", Operator::Ints, 1) + .setAttr("strides", Operator::Ints, 2) + .setAttr("pads_begin", Operator::Ints, 3) + .setAttr("pads_end", Operator::Ints, 3) + .setAttr("exclude_pad", !Operator::Bool(node, 5)) + .setAttr("rounding_type", std::string(rounding_type)); + } + + case aten::matmul: { + auto dim0 = getDimensions(node->namedInput("self")).value_or(-1); + auto dim1 = getDimensions(node->namedInput("other")).value_or(-1); + // TODO: support all shape combinations + REQUIRE( + (dim0 == 2 && dim1 == 2) || (dim0 == 4 && dim1 == 4) || + (dim0 == 3 && dim1 == 2)); + } // fall through + case aten::mm: { + return Operator(node, opkind::MatMul).setInput(0, 1).setOutput(0); + } + + case aten::linear: { + return Operator(node, opkind::MatMul) + .setInput(0, 1, 2) + .setOutput(0) + .setAttr("transpose_b", true); + } + + default: + return makeWildcardOp(node); + } +} + +dnnl::graph::op createLlgaOp(Node* node) { + return createOperator(node).llgaOp(); +} + +bool isSupported(Node* node) { + return createOperator(node).kind() != opkind::Wildcard; +}; + +DeviceType inferDeviceFromValue(Value* v) { + auto tt = v->type()->cast(); + if (!tt) { + return at::kCPU; + } + auto device = tt->device(); + if (!device) { + return at::kCPU; + } + return device->type(); +} + +DeviceType inferDevice(const std::shared_ptr& graph) { + auto dt = inferDeviceFromValue(graph->inputs()[0]); + TORCH_CHECK( + std::all_of( + graph->inputs().begin(), + graph->inputs().end(), + [dt](Value* v) { return inferDeviceFromValue(v) == dt; }), + "All inputs must have the same deive type"); + return dt; +} + +dnnl::graph::engine::kind getLlgaEngineKind(DeviceType type) { + switch (type) { + case DeviceType::CPU: + return dnnl::graph::engine::kind::cpu; + default: + TORCH_CHECK(false, "Not support device type ", type); + } +} + +void mayAddListConstructIntoConcatPartition( + Node* n, + OpPartitionMap& opToOwningPartition) { + // Since prim::ListConstruct is not visible to the LLGA, + // it will not be in any partition returned from partfuseritioning results. + // We need rewrite opToOwningPartition to make the prim::ListConstruct to be + // 'virtually' in the same partition with the aten::cat, so that + // prim::ListConstruct can be fused into the fusion group by graph fuser. + // We emphasize on 'virtually' because get_num_ops() for cat's partition + // would still return 1. + if (n->kind() == aten::cat && opToOwningPartition.has(n)) { + auto listConstrcut = n->namedInput("tensors")->node(); + auto partitionId = opToOwningPartition.get(n); + opToOwningPartition.add(listConstrcut, partitionId); + } +} + +// Verify that input tensors are compatible with oneDNN Graph. 
+// Scalars would be converted to 1-D tensors later anyway, +// but they shouldn't be complex-double +// If this check fails, convert op to wildcard +bool checkInputCompatibility(Node* node) { + auto allInputs = node->inputs(); + for (auto input : allInputs) { + c10::IValue inputIValue = toIValue(input); + if (inputIValue.isTensor()) { + const at::Tensor& tensor = inputIValue.toTensor(); + if (tensor.device() != at::kCPU) { + return false; + } + auto dtype = tensor.scalar_type(); + if ((dtype != at::ScalarType::Float) && (dtype != at::ScalarType::Long)) { + return false; + } + } else if (inputIValue.isScalar()) { + if (inputIValue.isComplexDouble()) { + return false; + } + } + } + return true; +} + +LlgaGraphHelper::LlgaGraphHelper( + const std::shared_ptr& graph, + dnnl::graph::partition::policy policy) { + auto deviceType = inferDevice(graph); + auto engineKind = getLlgaEngineKind(deviceType); + dnnl::graph::graph g{engineKind}; + + GRAPH_DEBUG("Constructing LLGA graph"); + // TODO: select nodes in top-level block for now + for (auto* node : graph->block()->nodes()) { + auto op = createLlgaOp(node); + auto kindOfNode = node->kind(); + if (checkInputCompatibility(node)) { + g.add_op(op); + GRAPH_DEBUG(" Added node ", kindOfNode.toQualString()); + } else { + GRAPH_DEBUG("The backend failed to add node ", kindOfNode.toQualString()); + g.add_op(makeWildcardOp(node).llgaOp()); + } + + for (Value* input : node->inputs()) { + tensorIdToValue_.emplace(input->unique(), input); + } + } + + GRAPH_DEBUG("Get Partitions"); + std::vector partitions = g.get_partitions(policy); + // excluded unsupported Wildcard partitions + for (auto& partition : partitions) { + if (partition.is_supported()) { + partitions_.push_back(partition); + } + } + + GRAPH_DEBUG(" Got #partitions: ", partitions_.size()); + for (size_t partId = 0; partId < partitions_.size(); partId++) { + for (auto opId : partitions_[partId].get_ops()) { + opToOwningPartition_.add(opId, partId); + } + } + + // Scanning the graph again for post processing + for (auto* node : graph->block()->nodes()) { + mayAddListConstructIntoConcatPartition(node, opToOwningPartition_); + } +} + +bool LlgaGraphHelper::isLlgaSubgraph(const Node* node) { + return node->hasAttribute(attr::Subgraph) && + node->kind() == prim::oneDNNFusionGroup; +} + +bool LlgaGraphHelper::shouldMerge(Node* toMerge, Node* subgraph) { + TORCH_CHECK( + isLlgaSubgraph(subgraph), + "The consumer node does not contain a subgraph"); + if (!shouldConsiderForMerge(toMerge)) { + return false; + } + return opToOwningPartition_.get(toMerge) == + opToOwningPartition_.get(subgraph); +} + +// Except for conv & GEMMs, which should always be handled by oneDNN Graph, +// only use single-op partitions for ops unsupported by NNC, or ops +// that oneDNN executes faster. prim::ListConstruct is an exception, since +// we simply want to fuse it with cat. 
+bool isBetterSuitedForLLGA(NodeKind kindOfOp) { + return ( + (kindOfOp == aten::layer_norm) || (kindOfOp == aten::avg_pool2d) || + (kindOfOp == aten::matmul) || (kindOfOp == aten::max_pool2d) || + (kindOfOp == aten::conv2d) || (kindOfOp == aten::_convolution) || + (kindOfOp == aten::mm) || (kindOfOp == aten::linear) || + (kindOfOp == aten::cat) || (kindOfOp == prim::ListConstruct)); +} + +bool LlgaGraphHelper::checkForSingleOpPartition(Node* node) { + if (opToOwningPartition_.has(node)) { + auto partitionId = opToOwningPartition_.get(node); + if (partitions_[partitionId].get_ops_num() == 1) { + auto kindOfNode = node->kind(); + return isBetterSuitedForLLGA(kindOfNode); + } else { + // multi-op partition + return true; + } + } else { + // this op isn't present in any partition + return false; + } +} + +bool LlgaGraphHelper::shouldConsiderForMerge(Node* node) { + // if we're already in the process of merging + if (isLlgaSubgraph(node)) { + return true; + } + return checkForSingleOpPartition(node); +} + +Node* LlgaGraphHelper::createSingletonSubgraph(Node* n, AliasDb& aliasDb) { + auto partitionId = opToOwningPartition_.get(n); + GRAPH_DEBUG( + "Creating FusionGroup_", partitionId, " for ", n->kind().toQualString()); + auto group = SubgraphUtils::createSingletonSubgraphAndUpdateAliasing( + n, prim::oneDNNFusionGroup, aliasDb); + opToOwningPartition_.add(group, partitionId); + LlgaNodeWrapper(group).initOutputLayouts(); + return group; +} + +void LlgaGraphHelper::mergeNodeIntoSubgraph( + Node* toMerge, + Node* subgraphNode, + AliasDb& aliasDb) { + if (isLlgaSubgraph(toMerge)) { + GRAPH_DEBUG( + "Merging ", + toMerge->kind().toQualString(), + "_", + opToOwningPartition_.get(toMerge), + " into ", + subgraphNode->kind().toQualString(), + "_", + opToOwningPartition_.get(subgraphNode)); + } else { + GRAPH_DEBUG( + "Merging ", + toMerge->kind().toQualString(), + " into ", + subgraphNode->kind().toQualString(), + "_", + opToOwningPartition_.get(subgraphNode)); + } + + SubgraphUtils::mergeNodeIntoSubgraphAndUpdateAliasing( + toMerge, subgraphNode, aliasDb); +} + +void LlgaGraphHelper::unmergeIfAnyNodeIsMissing(Node* subgraphNode) { + TORCH_CHECK(isLlgaSubgraph(subgraphNode), "Cannot unmerge a non-LLGA node"); + + auto partitionId = opToOwningPartition_.get(subgraphNode); + auto expectOpNum = partitions_[partitionId].get_ops_num(); + auto actualOpNum = countSupportedOps(subgraphNode->g(attr::Subgraph)); + + if (expectOpNum != actualOpNum) { + GRAPH_DEBUG( + "Unmerging FusionGroup_", + partitionId, + ". 
Expected ", + expectOpNum, + " ops, but got ", + actualOpNum, + " ops."); + SubgraphUtils::unmergeSubgraph(subgraphNode); + } +} + +size_t LlgaGraphHelper::countSupportedOps( + const std::shared_ptr& graph) const { + // TODO: count nodes in top-level block for now + size_t cnt = 0; + for (auto* node : graph->block()->nodes()) { + auto nodeKind = node->kind(); + if ((nodeKind != prim::Constant) && (nodeKind != prim::ListConstruct)) { + cnt++; + } + } + return cnt; +} + +std::vector LlgaGraphHelper::getPartitions() const { + return partitions_; +} + +std::map LlgaGraphHelper::getTensorIdToValue() const { + return tensorIdToValue_; +} + +LlgaNodeWrapper::LlgaNodeWrapper(const Node* node) + : n(const_cast(node)) { // NOLINT + TORCH_CHECK( + LlgaGraphHelper::isLlgaSubgraph(n), "Cannot wrap a non-LLGA fusion node"); +} + +void LlgaNodeWrapper::setOpaqueLayout(size_t offset) { + TORCH_CHECK(offset < n->outputs().size(), "Invalid output offset ", offset); + auto& layouts = + const_cast&>(n->is(attr::output_layouts)); // NOLINT + layouts.at(offset) = 1; +} + +bool LlgaNodeWrapper::useOpaqueLayout(size_t offset) const { + TORCH_CHECK(offset < n->outputs().size(), "Invalid output offset ", offset); + return n->is(attr::output_layouts)[offset] == 1; +} + +void LlgaNodeWrapper::initOutputLayouts() { + if (n->hasAttribute(attr::output_layouts)) { + return; + } + + // Init all output layouts as undef + std::vector layouts(n->outputs().size(), 0); + n->is_(attr::output_layouts, layouts); +} + +} // namespace onednn +} // namespace fuser +} // namespace jit +} // namespace torch diff --git a/torch/csrc/jit/codegen/onednn/graph_helper.h b/torch/csrc/jit/codegen/onednn/graph_helper.h new file mode 100644 index 00000000000..969f3cdc0ef --- /dev/null +++ b/torch/csrc/jit/codegen/onednn/graph_helper.h @@ -0,0 +1,95 @@ +#pragma once + +#include +#include +#include + +namespace torch { +namespace jit { +namespace fuser { +namespace onednn { + +struct OpPartitionMap { + void add(uint64_t opId, uint64_t partitionId) { + opmap_[opId] = partitionId; + } + void add(Node* n, uint64_t partitionId) { + add(Operator::getId(n), partitionId); + } + bool has(uint64_t opId) { + return opmap_.count(opId) > 0; + } + bool has(Node* n) { + return has(Operator::getId(n)); + } + uint64_t get(uint64_t opId) { + return opmap_[opId]; + } + uint64_t get(Node* n) { + auto opId = Operator::getId(n); + TORCH_CHECK( + has(opId), + "Node ", + n->kind().toQualString(), + " does not belong to any LLGA partition"); + return get(opId); + } + + private: + std::unordered_map opmap_; +}; + +class LlgaGraphHelper { + public: + LlgaGraphHelper( + const std::shared_ptr& graph, + dnnl::graph::partition::policy policy = + dnnl::graph::partition::policy::fusion); + + bool shouldMerge(Node* toMerge, Node* subgraph); + + bool shouldConsiderForMerge(Node* node); + + bool checkForSingleOpPartition(Node* node); + + Node* createSingletonSubgraph(Node* n, AliasDb& db); + + void mergeNodeIntoSubgraph(Node* toMerge, Node* subgraphNode, AliasDb& db); + + void unmergeIfAnyNodeIsMissing(Node* subgraphNode); + + static bool isLlgaSubgraph(const Node* node); + + std::vector getPartitions() const; + + std::map getTensorIdToValue() const; + + private: + size_t countSupportedOps(const std::shared_ptr& graph) const; + + OpPartitionMap opToOwningPartition_; + std::vector partitions_; + std::map + tensorIdToValue_; // map from tensorId to torch::jit::Value +}; + +class LlgaNodeWrapper { + public: + LlgaNodeWrapper(const Node* node); + + void setOpaqueLayout(size_t offset); 
+ + bool useOpaqueLayout(size_t offset) const; + + friend class LlgaGraphHelper; + + private: + void initOutputLayouts(); + + Node* n; +}; + +} // namespace onednn +} // namespace fuser +} // namespace jit +} // namespace torch diff --git a/torch/csrc/jit/codegen/onednn/graph_rewriter.cpp b/torch/csrc/jit/codegen/onednn/graph_rewriter.cpp new file mode 100644 index 00000000000..c91ff9b3917 --- /dev/null +++ b/torch/csrc/jit/codegen/onednn/graph_rewriter.cpp @@ -0,0 +1,144 @@ +#include +#include +#include +#include +#include +#include + +namespace torch { +namespace jit { +namespace fuser { +namespace onednn { + +void GraphRewriter::cleanupSubgraphs() { + auto curNode = *block_->nodes().rbegin(); + while (curNode != *block_->nodes().rend()) { + // Save the previous node, since we might delete `curNode` in the next block + auto prevNode = curNode->prev(); + if (llgaHelper_.isLlgaSubgraph(curNode)) { + // Unmerge the subgraph if we didn't get all nodes of a partition + // into the subgraph due to a failed alias check + llgaHelper_.unmergeIfAnyNodeIsMissing(curNode); + } + curNode = prevNode; + } + for (Node* n : block_->nodes()) { + for (Block* b : n->blocks()) { + GraphRewriter(b, graph_, aliasDb_).cleanupSubgraphs(); + } + } +} + +void GraphRewriter::buildupSubgraphs() { + // We need to run the rewriter multiple times in order to get all merge + // opportunities. This is because moveBeforeTopologicalValid may reorder + // nodes to be AFTER the current iteration point. In order to properly + // consider those nodes for merging, we need to run the pass until no changes + // have been made. + // + // Example: + // c = f(a, b) + // d = f(c) + // e = f(d) <- iter is here, moving upward + // After c.moveBeforeTopologicallyValid(e), we have: + // c = f(a, b) + // e = f(d) <- iter still here + // d = f(c) <- this was the node moved to the other side. + // see [workblocks] + auto workblocks = buildWorkBlocks(); + for (auto& workblock : workblocks) { + bool any_changed = true; + while (any_changed) { + any_changed = false; + auto workblock_end = workblock.end()->reverseIterator(); + auto workblock_begin = workblock.begin()->reverseIterator(); + for (auto it = workblock_end; it != workblock_begin;) { + bool changed = false; + std::tie(it, changed) = scanNode(*it, workblock_begin); + any_changed |= changed; + } + } + } + + // Construct Subgraphs Recursively + for (Node* n : block_->nodes()) { + for (auto subBlock : n->blocks()) { + GraphRewriter(subBlock, graph_, aliasDb_).buildupSubgraphs(); + } + } +} + +std::vector GraphRewriter::buildWorkBlocks() { + // [workblocks] + // The IR has many nodes which can never be reordered around, such as a + // prim::Bailout. If a node N is surrounded by two nodes which cannot be + // reordered, A and B, then a fusion group that is created from N + // can only contain nodes from (A, B). The nodes from A to B represent one + // work block for the subgraph rewriter to work on.
By creating these up + // front, we avoid retraversing the whole graph block any time scanNode + // returns + Node* end_bound_node = block_->return_node(); + Node* curr = end_bound_node->prev(); + std::vector worklist; + while (curr != block_->param_node()) { + // cannot reorder around side-effectful nodes + if (curr->hasSideEffects()) { + worklist.emplace_back(curr, end_bound_node); + end_bound_node = curr; + } + curr = curr->prev(); + } + worklist.emplace_back(curr, end_bound_node); + return worklist; +} + +std::pair GraphRewriter::scanNode( + Node* consumer, + graph_node_list::iterator workblock_begin) { + GRAPH_DEBUG("Scanning ", consumer->kind().toQualString()); + if (llgaHelper_.shouldConsiderForMerge(consumer)) { + if (!llgaHelper_.isLlgaSubgraph(consumer)) { + consumer = llgaHelper_.createSingletonSubgraph(consumer, aliasDb_); + } + // Iterate through the workblock to merge nodes of the + // same partition determined by the LLGA graph helper. + // Nodes like B and C below do not share a common input but belong to the + // same partition, and thus we cannot scan only the input nodes + // to find merging opportunities. Instead, we have to scan through + // the whole workblock, which might lead to O(N^2) accesses in the worst case + // A + // + - - / - \ - - + + // | B C | + // | | | | + // | D E | + // + - - \ - / - - + + // F + auto prev = ++consumer->reverseIterator(); + for (auto it = prev; it != workblock_begin; it++) { + if (auto group = tryMerge(consumer, *it)) { + // we successfully merged, so the new group's `inputs` may have + // changed. So rescan the new group for more merging opportunities. + return std::make_pair(group.value()->reverseIterator(), true); + } + } + } + return std::make_pair(++consumer->reverseIterator(), false); +} + +// Try to merge `producer` into `consumer`. If successful, this destroys +// `producer` and returns the `consumer` group. +c10::optional GraphRewriter::tryMerge(Node* consumer, Node* producer) { + AT_ASSERT(llgaHelper_.isLlgaSubgraph(consumer)); + bool canMerge = llgaHelper_.shouldMerge(producer, consumer) && + aliasDb_.moveBeforeTopologicallyValid(producer, consumer); + if (!canMerge) { + return c10::nullopt; + } + llgaHelper_.mergeNodeIntoSubgraph(producer, consumer, aliasDb_); + return consumer; +} + +} // namespace onednn +} // namespace fuser +} // namespace jit +} // namespace torch diff --git a/torch/csrc/jit/codegen/onednn/guard_shape.cpp b/torch/csrc/jit/codegen/onednn/guard_shape.cpp new file mode 100644 index 00000000000..ee595b5c8d7 --- /dev/null +++ b/torch/csrc/jit/codegen/onednn/guard_shape.cpp @@ -0,0 +1,45 @@ +#include + +#include +#include +#include +#include + +namespace torch { +namespace jit { +namespace fuser { +namespace onednn { + +//! [ Note -- prepareFusionGroupAndGuardOutputs implementation ] +//! shamelessly copying code from NNC (tensorexpr_fuser) with very little +//! modification, original code at: +//! `torch/csrc/jit/passes/tensorexpr_fuser.cpp:prepareFusionGroupAndGuardOutputs` +//! +//! We assume that LLGA does not have operators +//! that depend on the content of the tensor.
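+//! Under that assumption, guarding on input types alone is sufficient: +//! prepareFusionGroupAndGuardOutputs() below wraps each prim::oneDNNFusionGroup +//! with a prim::oneDNNFusionGuard type check via insertTypeGuard().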
+void prepareFusionGroupAndGuardOutputs(Block* block) { + std::vector fusion_groups; + for (Node* n : block->nodes()) { + for (Block* b : n->blocks()) { + prepareFusionGroupAndGuardOutputs(b); + } + if (n->kind() == prim::oneDNNFusionGroup) { + fusion_groups.push_back(n); + } + } + for (Node* fusion_group : fusion_groups) { + // TODO: add further optimization pass to removeOutputsUsedOnlyInSize, + // refer to + // `torch/csrc/jit/passes/tensorexpr_fuser.cpp:removeOutputsUsedOnlyInSize` + // removeOutputsUsedOnlyInSize(fusion_group); + insertTypeGuard( + fusion_group, + [](const TensorTypePtr& t) { return t; }, + prim::oneDNNFusionGuard); + } +} + +} // namespace onednn +} // namespace fuser +} // namespace jit +} // namespace torch diff --git a/torch/csrc/jit/codegen/onednn/guard_shape.h b/torch/csrc/jit/codegen/onednn/guard_shape.h new file mode 100644 index 00000000000..46f8a396a16 --- /dev/null +++ b/torch/csrc/jit/codegen/onednn/guard_shape.h @@ -0,0 +1,15 @@ +#pragma once + +#include + +namespace torch { +namespace jit { +namespace fuser { +namespace onednn { + +void prepareFusionGroupAndGuardOutputs(Block* block); + +} // namespace onednn +} // namespace fuser +} // namespace jit +} // namespace torch diff --git a/torch/csrc/jit/codegen/onednn/interface.cpp b/torch/csrc/jit/codegen/onednn/interface.cpp new file mode 100644 index 00000000000..ef525f99e2c --- /dev/null +++ b/torch/csrc/jit/codegen/onednn/interface.cpp @@ -0,0 +1,172 @@ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace torch { +namespace jit { +namespace fuser { +namespace onednn { + +void fuseGraph(std::shared_ptr& g) { + // Follow the process of the tensorexpr_fuser in profiling mode: + // Remove prim::profile nodes and embed the profile info directly in the + // IR in value types to avoid breaking the fusion patterns. + // Will add shape guard after LLGA optimization passes and + // wipe the tensor type information from the IR, so that it's not + // accidentally used by any other pass. + + // We rely on the shape specialization and shape guard to ensure the validity + // of the cached compilation in the kernel, thus only support profiling mode. + // TODO: add check on oneDNNFusionGroup to ensure allShapesAreKnown on nodes + // to fuse: torch/csrc/jit/passes/tensorexpr_fuser.cpp: allShapesAreKnown + if (getProfilingMode()) { + GRAPH_DUMP( + "Before RemoveProfileNodesAndSpecializeTypes. Beginning of LLGA " + "optimization pass", + g); + RemoveProfileNodesAndSpecializeTypes(g); + GRAPH_DUMP( + "After RemoveProfileNodesAndSpecializeTypes. Before mutation removal", + g); + + RemoveTensorMutation(g, [](Node* nodeToFunctionalize) { + static std::unordered_set supportedOps = { + aten::add_, + aten::mul_, + aten::tanh_, + aten::elu_, + aten::relu_, + aten::relu6_, + aten::gelu_, + aten::sqrt_, + aten::sigmoid_, + aten::hardtanh_, + aten::abs_, + aten::square_, + }; + return supportedOps.count(nodeToFunctionalize->kind()) != 0; + }); + RemoveListMutation(g); + GRAPH_DUMP("After mutation removal. Before PrepareBinaryForLLGA", g); + PrepareBinaryForLLGA(g); + GRAPH_DUMP("After PrepareBinaryForLLGA. Before DeferSizeCheck", g); + DeferSizeCheck(g); + GRAPH_DUMP("After DeferSizeCheck. Before CreateLlgaSubgraphs", g); + CreateLlgaSubgraphs(g); + GRAPH_DUMP("After CreateLlgaSubgraphs. Before PropagateLayout", g); + PropagateLayout(g); + GRAPH_DUMP( + "After PropagateLayout. 
Before prepareFusionGroupAndGuardOutputs", g); + + // Add shape guard for profiling mode and wipe the tensor type information + // from the IR + prepareFusionGroupAndGuardOutputs(g->block()); + GRAPH_DUMP( + "After prepareFusionGroupAndGuardOutputs. Before " + "RemoveTensorTypeSpecializations", + g); + RemoveTensorTypeSpecializations(g); + GRAPH_DUMP( + "After RemoveTensorTypeSpecializations. End of LLGA optimization pass", + g); + } +} + +} // namespace onednn +} // namespace fuser + +Operation createLlgaKernel(const Node* node) { + auto kernel = std::make_shared(node); + return [kernel](Stack* stack) { + RECORD_FUNCTION(kernel->debugName(), std::vector()); + kernel->run(*stack); + return 0; + }; +} + +RegisterOperators oneDNNFusionGroupOp({ + torch::jit::Operator( + prim::oneDNNFusionGroup, + createLlgaKernel, + AliasAnalysisKind::INTERNAL_SPECIAL_CASE), +}); + +// Currently, we convert some scalar inputs, such as the second argument of +// binary ops to a 1D tensor. Other scalar inputs are prim::Constant nodes. +// But if we have any scalar inputs to guard in the future, some logic here +// would have to be changed. +Operation createLlgaGuardKernel(const Node* node) { + return [node](Stack* stack) { +#ifdef GRAPH_DEBUG_ENABLED + GRAPH_DEBUG("Guarding node: ", node->kind().toQualString()); +#endif + std::vector types = node->tys(attr::types); + const auto num_inputs = types.size(); +#ifdef GRAPH_DEBUG_ENABLED + GRAPH_DEBUG("num_inputs to guard: ", num_inputs); +#endif + for (size_t i = 0; i < num_inputs; i++) { +#ifdef GRAPH_DEBUG_ENABLED + GRAPH_DEBUG("checking input ", i); +#endif + auto& input = peek(stack, i, num_inputs); + const c10::TensorTypePtr& guard_tensor_type = + types[i]->cast(); + + if (!input.isTensor()) { +#ifdef GRAPH_DEBUG_ENABLED + GRAPH_DEBUG("input ", i, " is not a tensor, return false"); +#endif + push(stack, IValue(false)); + return; + } + const at::Tensor& tensor = input.toTensor(); + + // If input tensor is of mkldnn, it's originated from an upstream + // LLGA partition that has passed the check on input shapes. + // It is valid to continue here as long as the output shapes from + // oneDNN graph partitions are determined by the input shapes. 
+ if (tensor.is_mkldnn()) { +#ifdef GRAPH_DEBUG_ENABLED + GRAPH_DEBUG("input ", i, " is_mkldnn, continue"); +#endif + continue; + } + + if (!guard_tensor_type->matchTensor(tensor)) { +#ifdef GRAPH_DEBUG_ENABLED + GRAPH_DEBUG("input ", i, " check failed, return false"); +#endif + push(stack, IValue(false)); + return; + } + } +#ifdef GRAPH_DEBUG_ENABLED + GRAPH_DEBUG("all check done, return true"); +#endif + push(stack, IValue(true)); + return; + }; +} + +RegisterOperators oneDNNGuardOp({ + torch::jit::Operator( + prim::oneDNNFusionGuard, + createLlgaGuardKernel, + AliasAnalysisKind::FROM_SCHEMA), +}); +} // namespace jit +} // namespace torch diff --git a/torch/csrc/jit/codegen/onednn/interface.h b/torch/csrc/jit/codegen/onednn/interface.h new file mode 100644 index 00000000000..e591c1c3b58 --- /dev/null +++ b/torch/csrc/jit/codegen/onednn/interface.h @@ -0,0 +1,62 @@ +#pragma once +#include +#include +#include + +namespace torch { +namespace jit { +namespace fuser { +namespace onednn { + +static std::atomic onednn_enabled{false}; + +std::atomic& getLlgaEnabled() { + return onednn_enabled; +} + +C10_EXPORT void fuseGraph(std::shared_ptr& g); + +} // namespace onednn +} // namespace fuser + +struct C10_EXPORT RegisterLlgaFuseGraph + : public PassManager { + static bool setEnabled(bool enabled) { + TORCH_CHECK( + AT_MKLDNN_ENABLED(), + "Running oneDNN Graph fuser is only supported with MKLDNN builds."); + bool oldState = fuser::onednn::getLlgaEnabled(); + fuser::onednn::getLlgaEnabled() = enabled; + if (enabled) { + registerPass(fuser::onednn::fuseGraph); + } else { + clearPass(); + } + return oldState; + } + + static bool isEnabled() { + return fuser::onednn::getLlgaEnabled(); + } + + // override PassManager::registerPass to register pre-pass + static bool registerPass(GraphPass p) { + if (!isRegistered()) { + passID(registerPrePass(std::move(p)), true); + isRegistered(true); + return false; + } + return true; + } + + // override PassManager::clearPass to clear pre-pass + static void clearPass() { + if (isRegistered()) { + clearPrePass(passID()); + isRegistered(true); + } + } +}; + +} // namespace jit +} // namespace torch diff --git a/torch/csrc/jit/codegen/onednn/kernel.cpp b/torch/csrc/jit/codegen/onednn/kernel.cpp new file mode 100644 index 00000000000..a740b8e3a6d --- /dev/null +++ b/torch/csrc/jit/codegen/onednn/kernel.cpp @@ -0,0 +1,257 @@ +#include +#include + +#include +#include + +namespace torch { +namespace jit { +namespace fuser { +namespace onednn { + +using namespace dnnl::graph; + +LlgaKernel::LlgaKernel(const Node* fusionNode) + : fusionNode_(fusionNode), + graph_(fusionNode->g(attr::Subgraph)), + nGraphInputs_(graph_->inputs().size()), + nOutputs_(graph_->outputs().size()), + debugName_(genDebugName()) { + // TODO: This is a workaround to recreate the partitions here. + // The ideal way is to use the partition serialization API (not available from + // LLGA now) to carry a serialized string representation from graph rewrite + // and deserialize it here. 
+ auto llgaGraphHelper = LlgaGraphHelper(graph_); + auto partitions = llgaGraphHelper.getPartitions(); + tensorIdToValue_ = llgaGraphHelper.getTensorIdToValue(); + TORCH_CHECK( + partitions.size() == 1, + "LLGA subgraph should contain only one partition"); + partition_ = partitions[0]; + nPartitionInputs_ = partition_.get_in_ports().size(); +#ifdef GRAPH_DEBUG_ENABLED + GRAPH_DEBUG("Initialized ", debugName(), "\n", graph_->toString()); +#endif +} + +bool LlgaKernel::useOpaqueLayout(size_t offset) const { + return LlgaNodeWrapper(fusionNode_).useOpaqueLayout(offset); +} + +void LlgaKernel::initializeConstantInputs() { + for (auto& lt : partition_.get_in_ports()) { + auto inputId = lt.get_id(); + if (initializedInputIds_.find(inputId) == initializedInputIds_.end()) { + TORCH_CHECK( + tensorIdToValue_.count(inputId) > 0, + "inputs with inputId ", + inputId, + " is missing"); + auto* value = tensorIdToValue_[inputId]; + + TORCH_CHECK( + value->node()->kind() == prim::Constant && + value->type()->cast(), + "inputs with inputId ", + inputId, + " should be a Constant tensor"); + constantValues_.emplace_back(value); + + auto const_tensor = toIValue(value)->toTensor(); + constantInputs_.emplace_back(const_tensor); + } + } +} + +std::map LlgaKernel::initializeTensorIdToOccurence() const { + std::map tensorIdToOccurence; + for (auto& lt : partition_.get_in_ports()) { + auto inputId = lt.get_id(); + std::map::iterator it(tensorIdToOccurence.find(inputId)); + if (it != tensorIdToOccurence.end()) { + it->second++; + } else { + tensorIdToOccurence[inputId] = 1; + } + } + return tensorIdToOccurence; +} + +ArgSpecs LlgaKernel::initializeInputSpecs(const TensorArgs& inputs) { + ArgSpecs inputSpecs; + inputSpecs.reserve(nPartitionInputs_); + GRAPH_DEBUG("Initializing graph input logical tensors"); + std::map tensorIdToOccurence = + initializeTensorIdToOccurence(); + for (size_t i = 0; i < nGraphInputs_; i++) { + auto spec = ArgSpec(graph_->inputs()[i]).supplementTensorInfo(inputs[i]); + initializedInputIds_.insert(spec.tid()); + int64_t occurence = tensorIdToOccurence[spec.tid()]; + inputSpecs.insert(inputSpecs.end(), occurence, spec); + runArgsIdx_.insert(runArgsIdx_.end(), occurence, i); + } + GRAPH_DEBUG("Initializing constant input tensors"); + initializeConstantInputs(); + + TORCH_CHECK( + inputSpecs.size() + constantValues_.size() == nPartitionInputs_, + "Partition inputs are missing"); + GRAPH_DEBUG( + "Concatenating constant input logical tensors to graph input " + "logical tensors"); + for (Value* constant_value : constantValues_) { + ArgSpec constantInputSpec(constant_value); + inputSpecs.emplace_back(constantInputSpec); + constantLogicalTensors_.emplace_back(constantInputSpec.logical_tensor()); + } + return inputSpecs; +} + +ArgSpecs LlgaKernel::initializeOutputSpecs() const { + ArgSpecs outputSpecs; + outputSpecs.reserve(nOutputs_); + for (size_t i = 0; i < nOutputs_; i++) { + auto spec = ArgSpec(graph_->outputs()[i]); + if (useOpaqueLayout(i)) { + spec = spec.any(); + } + outputSpecs.emplace_back(spec); + } + return outputSpecs; +} + +std::tuple LlgaKernel::prepareRunArgs( + const TensorArgs& inputs, + TensorArgs& outputs) const { + RunArgs runInputs, runOutputs; + auto numInputs = runArgsIdx_.size(); + for (size_t i = 0; i < numInputs; i++) { + auto spec = inputSpecs_[i]; + auto input = inputs[runArgsIdx_[i]]; + runInputs.push_back( + {spec.logical_tensor(), Engine::getEngine(), input.data_ptr()}); + } + auto numConstantInputs = constantInputs_.size(); + for (size_t i = 0; i < 
numConstantInputs; i++) { + // constantInputSpecs are placed after graphInputSpecs + auto constantInputSpecIdx = nGraphInputs_ + i; + auto constantInputSpec = inputSpecs_[constantInputSpecIdx]; + runInputs.push_back( + {constantLogicalTensors_[i], + Engine::getEngine(), + constantInputs_[i].data_ptr()}); + } + + for (size_t i = 0; i < nOutputs_; i++) { + auto spec = outputSpecs_[i]; + auto opt = c10::TensorOptions(spec.aten_scalar_type()).device(device_); + + if (spec.reuses_input_tensor()) { + auto inputTensor = inputs[spec.get_input_tensor_index()]; + outputs.push_back(inputTensor); + runOutputs.push_back( + {spec.logical_tensor(), Engine::getEngine(), inputTensor.data_ptr()}); + } else if (spec.is_opaque()) { + auto tensor = empty_llga(spec, opt); + outputs.push_back(tensor); + runOutputs.push_back(llga_from_aten_tensor(tensor)); + } else { + auto tensor = at::empty_strided(spec.sizes(), spec.strides(), opt); + outputs.push_back(tensor); + runOutputs.push_back( + {spec.logical_tensor(), Engine::getEngine(), tensor.data_ptr()}); + } + } + + return std::make_tuple(runInputs, runOutputs); +} + +compiled_partition LlgaKernel::compile(const partition& partition) { + auto inputs = fmap(inputSpecs_, toLogicalTensor); + auto outputs = fmap(outputSpecs_, toLogicalTensor); + auto compilation = partition.compile(inputs, outputs, Engine::getEngine()); + + // Since layouts of opaque outputs would be known after compilation, + // we need to query them out from compilation and update outputSpecs + for (size_t i = 0; i < nOutputs_; i++) { + auto tid = outputSpecs_[i].tid(); + outputSpecs_[i] = compilation.query_logical_tensor(tid); + } + + // Build static mapping from output id to input offset + // in accordance with available inplace options + for (auto&& option : compilation.get_inplace_ports()) { + size_t inputId = option.first; + size_t outputId = option.second; + auto inputSpecIter = + std::find_if(inputSpecs_.begin(), inputSpecs_.end(), [&](auto& spec) { + return spec.tid() == inputId; + }); + TORCH_CHECK(inputSpecIter != inputSpecs_.end(), "In-place input not found"); + auto inputOffset = inputSpecIter - inputSpecs_.begin(); + auto outputSpecIter = + std::find_if(outputSpecs_.begin(), outputSpecs_.end(), [&](auto& spec) { + return spec.tid() == outputId; + }); + auto outputOffset = outputSpecIter - outputSpecs_.begin(); + outputSpecs_[outputOffset].set_compute_inplace(); + outputSpecs_[outputOffset].set_input_tensor_index(inputOffset); + } + + return compilation; +} + +void LlgaKernel::run(Stack& stack) { +#ifdef GRAPH_DEBUG_ENABLED + GRAPH_DEBUG("In ", debugName(), "\n"); +#endif + + // Grab input values from stack + auto stackInputs = last(stack, nGraphInputs_); + auto inputs = fmap(stackInputs, [&](const IValue& v) { + TORCH_CHECK( + v.isTensor(), "Stack values for LLGA partition must be Tensor type"); + return v.toTensor(); + }); + + // Even in case of concurrent threads, the kernel would be initialized once. 
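+ // std::call_once guarantees the initialization lambda below runs exactly once; + // concurrent callers of run() block until inputSpecs_, outputSpecs_ and + // compilation_ have been populated.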
+ // TODO: Try not using an atomic lock + std::call_once( + initialized_flag, + [&](const TensorArgs& inputs) { + GRAPH_DEBUG("Initializing input logical tensors"); + inputSpecs_ = initializeInputSpecs(inputs); + GRAPH_DEBUG("Initializing output logical tensors"); + outputSpecs_ = initializeOutputSpecs(); + GRAPH_DEBUG("Compiling partition"); + compilation_ = compile(partition_); + is_initialized_ = true; + }, + inputs); +#ifdef GRAPH_DEBUG_ENABLED + GRAPH_DEBUG("Preparing runtime tensors"); +#endif + TensorArgs outputs; + RunArgs runInputs, runOutputs; + std::tie(runInputs, runOutputs) = prepareRunArgs(inputs, outputs); +#ifdef GRAPH_DEBUG_ENABLED + GRAPH_DEBUG("Executing partition"); +#endif + compilation_.execute(Stream::getStream(), runInputs, runOutputs); +#ifdef GRAPH_DEBUG_ENABLED + GRAPH_DEBUG("Partition executed"); +#endif + + // Update the stack. + drop(stack, nGraphInputs_); + for (auto& o : outputs) + push_one(stack, std::move(o)); +#ifdef GRAPH_DEBUG_ENABLED + GRAPH_DEBUG("Stack updated"); +#endif +} + +} // namespace onednn +} // namespace fuser +} // namespace jit +} // namespace torch diff --git a/torch/csrc/jit/codegen/onednn/kernel.h b/torch/csrc/jit/codegen/onednn/kernel.h new file mode 100644 index 00000000000..a9c7b24ad8c --- /dev/null +++ b/torch/csrc/jit/codegen/onednn/kernel.h @@ -0,0 +1,93 @@ +#pragma once + +#include + +#include +#include +#include +#include +#include + +namespace torch { +namespace jit { +namespace fuser { +namespace onednn { + +using ArgSpec = LlgaTensorDesc; +using ArgSpecs = std::vector; +using RunArg = dnnl::graph::tensor; +using RunArgs = std::vector; +using TensorArgs = std::vector; + +class LlgaKernel { + public: + explicit LlgaKernel(const Node* fusionNode); + + void run(Stack& stack); + + void initialize(const TensorArgs& inputs); + + const std::string& debugName() const { + return debugName_; + } + + private: + bool useOpaqueLayout(size_t offset) const; + + // PyTorch copies constants into the subgraph instead of referencing them. + // Constant inputs to the partition are no longer in graph->inputs(). + // We need to use the tid retrieved from the partition to find the missing + // constant inputs. + void initializeConstantInputs(); + + ArgSpecs initializeInputSpecs(const TensorArgs& inputs); + + ArgSpecs initializeOutputSpecs() const; + + dnnl::graph::compiled_partition compile( + const dnnl::graph::partition& partition); + + std::map initializeTensorIdToOccurence() const; + + std::tuple prepareRunArgs( + const TensorArgs& inputs, + TensorArgs& outputs) const; + + static std::string genDebugName() { + static size_t debugId = 0; + return "LlgaPartition_" + std::to_string(debugId++); + } + + static dnnl::graph::logical_tensor toLogicalTensor(const ArgSpec& s) { + return s.logical_tensor(); + } + + at::Device device_ = at::kCPU; + const Node* fusionNode_; + std::shared_ptr graph_; + int64_t nGraphInputs_ = 0; // number of inputs to graph_ on the IR + int64_t nOutputs_ = 0; + std::map tensorIdToValue_; + std::vector runArgsIdx_; + dnnl::graph::partition partition_; + // nPartitionInputs_ is the actual number of inputs to partition_ of graph_ + // needed by the backend.
+ // nPartitionInputs_ = nGraphInputs_ + constantInputs_.size() since Constant + // inputs are copied to the inside of the subgraph + int64_t nPartitionInputs_; + dnnl::graph::compiled_partition compilation_; + std::set initializedInputIds_; + std::vector constantValues_; + TensorArgs constantInputs_; + ArgSpecs inputSpecs_; + ArgSpecs outputSpecs_; + std::vector constantLogicalTensors_; + std::string debugName_; + std::once_flag initialized_flag; + bool is_initialized_ = false; +}; + +} // namespace onednn +} // namespace fuser +} // namespace jit +} // namespace torch diff --git a/torch/csrc/jit/codegen/onednn/layout_propagation.cpp b/torch/csrc/jit/codegen/onednn/layout_propagation.cpp new file mode 100644 index 00000000000..448e1cf8588 --- /dev/null +++ b/torch/csrc/jit/codegen/onednn/layout_propagation.cpp @@ -0,0 +1,44 @@ +#include +#include + +namespace torch { +namespace jit { +namespace fuser { +namespace onednn { + +void LayoutPropagation(Node* n) { + if (!LlgaGraphHelper::isLlgaSubgraph(n)) + return; + + for (auto input : n->inputs()) { + auto prev = input->node(); + auto offset = input->offset(); + if (LlgaGraphHelper::isLlgaSubgraph(prev)) { + bool useOpaqueLayout = true; + for (auto& use : input->uses()) { + if (!LlgaGraphHelper::isLlgaSubgraph(use.user)) { + useOpaqueLayout = false; + break; + } + } + if (useOpaqueLayout) { + LlgaNodeWrapper(prev).setOpaqueLayout(offset); + } + } + } +} + +void LayoutPropagation(at::ArrayRef blocks) { + for (Block* block : blocks) + for (Node* node : block->nodes()) + LayoutPropagation(node); +} + +void PropagateLayout(const std::shared_ptr& graph) { + LayoutPropagation(graph->block()); +} + +} // namespace onednn +} // namespace fuser +} // namespace jit +} // namespace torch diff --git a/torch/csrc/jit/codegen/onednn/layout_propagation.h b/torch/csrc/jit/codegen/onednn/layout_propagation.h new file mode 100644 index 00000000000..5e48a097cd4 --- /dev/null +++ b/torch/csrc/jit/codegen/onednn/layout_propagation.h @@ -0,0 +1,15 @@ +#pragma once + +#include + +namespace torch { +namespace jit { +namespace fuser { +namespace onednn { + +void PropagateLayout(const std::shared_ptr& graph); + +} // namespace onednn +} // namespace fuser +} // namespace jit +} // namespace torch diff --git a/torch/csrc/jit/codegen/onednn/operator.h b/torch/csrc/jit/codegen/onednn/operator.h new file mode 100644 index 00000000000..2affa9398ff --- /dev/null +++ b/torch/csrc/jit/codegen/onednn/operator.h @@ -0,0 +1,103 @@ +#pragma once + +#include +#include +#include + +namespace torch { +namespace jit { +namespace fuser { +namespace onednn { + +class Operator { + public: + Operator(const Node* node, dnnl::graph::op::kind kind) + : n(node), o(getId(node), kind, node->kind().toQualString()), k(kind) { + setAttr("data_format", std::string("NCX")); + } + + Operator& setInputValue(Value* v) { + if (v->mustNotBeNone()) + o.add_input(createLogicalTensor(v)); + return *this; + } + + Operator& setInput(size_t offset) { + return setInputValue(n->input(offset)); + } + + template + Operator& setInput(size_t offset, Ts... other) { + setInput(offset); + return setInput(other...); + } + + Operator& setOutputValue(Value* v) { + if (v->mustNotBeNone()) + o.add_output(createLogicalTensor(v)); + return *this; + } + + Operator& setOutput(size_t offset) { + return setOutputValue(n->output(offset)); + } + + template + Operator& setOutput(size_t offset, Ts... 
other) { + setOutput(offset); + return setOutput(other...); + } + + template + Operator& setAttr(std::string name, Attr&& attr) { + o.set_attr(name, std::forward(attr)); + return *this; + } + + template + Operator& setAttr(std::string name, const F& fn, size_t offset) { + return setAttr(name, fn(n, offset)); + } + + static std::vector Ints(const Node* node, size_t offset) { + return toIValue(node->input(offset))->toIntVector(); + } + + static int64_t Int(const Node* node, size_t offset) { + return toIValue(node->input(offset))->toInt(); + } + + static float Float(const Node* node, size_t offset) { + return static_cast(toIValue(node->input(offset))->toDouble()); + } + + static bool Bool(const Node* node, size_t offset) { + return toIValue(node->input(offset))->toBool(); + } + + static uint64_t getId(const Node* node) { + return reinterpret_cast(node); // cast node address as op id + } + + dnnl::graph::op::kind kind() const { + return k; + } + + dnnl::graph::op llgaOp() const { + return o; + } + + private: + dnnl::graph::logical_tensor createLogicalTensor(Value* value) const { + return LlgaTensorDesc(value).logical_tensor(); + } + + const Node* n; + dnnl::graph::op o; + dnnl::graph::op::kind k; +}; + +} // namespace onednn +} // namespace fuser +} // namespace jit +} // namespace torch diff --git a/torch/csrc/jit/codegen/onednn/prepare_binary.cpp b/torch/csrc/jit/codegen/onednn/prepare_binary.cpp new file mode 100644 index 00000000000..5704f598eb4 --- /dev/null +++ b/torch/csrc/jit/codegen/onednn/prepare_binary.cpp @@ -0,0 +1,106 @@ +#include +#include +#include + +namespace torch { +namespace jit { +namespace fuser { +namespace onednn { + +bool compareConstValue(Value* v, double d) { + auto ival = toIValue(v); + return ival.has_value() && + ((ival->isInt() && static_cast(ival->toInt()) == d) || + (ival->isDouble() && ival->toDouble() == d)); +} + +void mayConvertScalarInputToTensor(Node* node) { + // We do not handle binary ops with two scalar inputs, + // and we assume scalar is always at the second place. 
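+ // For example, %r = aten::add(%x, 42, 1) has its scalar argument 42 replaced + // below by a 1-D Float tensor ([42.]) so that LLGA sees a tensor-tensor add.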
+ if (node->input(0)->type()->isSubtypeOf(TensorType::get()) && + (node->input(1)->type()->isSubtypeOf(FloatType::get()) || + node->input(1)->type()->isSubtypeOf(IntType::get()))) { + auto scalar = node->input(1); + WithInsertPoint guard(node); + auto g = node->owningGraph(); + // 42 : Scalar --> tensor(42.0) : Float([]) + auto t = g->insert( + aten::as_tensor, {scalar}, {{"dtype", at::ScalarType::Float}}); + // add dim & stride info to IR + c10::optional t_dim = 1; + auto target_type = TensorTypePtr( + TensorType::create(at::ScalarType::Float, at::kCPU, t_dim, false)); + target_type = target_type->withSizes({1}); + t->setType(target_type); + + // tensor(42.0) : Float([]) --> tensor([42.0]) : Float([1]) + auto unsqueezed = g->insert(aten::unsqueeze, {t, 0}); + unsqueezed->setType(target_type); + node->replaceInput(1, unsqueezed); + } +} + +static void ConvertScalarToTensor(Block* block) { + for (auto node : block->nodes()) { + for (auto sub : node->blocks()) { + ConvertScalarToTensor(sub); + } + + if (node->kind() == aten::add || node->kind() == aten::mul) { + mayConvertScalarInputToTensor(node); + } + } +} + +void mayDecomposeAdd(Node* node) { + if (toIValue(node->namedInput("alpha")).has_value()) { + auto alphaEqualsOne = compareConstValue(node->namedInput("alpha"), 1.0); + if (!alphaEqualsOne) { + WithInsertPoint guard(node); + auto g = node->owningGraph(); + auto mul = g->insert( + aten::mul, {node->namedInput("other"), node->namedInput("alpha")}); + node->replaceInput(1, mul); + auto one = g->insertConstant(1.0); + node->replaceInput(2, one); + } + } +} + +static void DecomposeFusedAdd(Block* block) { + for (auto node : block->nodes()) { + for (auto sub : node->blocks()) { + DecomposeFusedAdd(sub); + } + + if (node->kind() == aten::add) { + mayDecomposeAdd(node); + } + } +} + +static void EliminateIdentityMulAdd(Block* block) { + for (auto node : block->nodes()) { + for (auto sub : node->blocks()) { + EliminateIdentityMulAdd(sub); + } + + if ((node->kind() == aten::add && compareConstValue(node->input(1), 0.0)) || + (node->kind() == aten::mul && compareConstValue(node->input(1), 1.0))) { + node->output()->replaceAllUsesWith(node->namedInput("self")); + } + } +} + +void PrepareBinaryForLLGA(const std::shared_ptr& graph) { + DecomposeFusedAdd(graph->block()); + EliminateIdentityMulAdd(graph->block()); + EliminateDeadCode(graph); + // ConvertScalarToTensor must be placed after EliminateIdentityMulAdd + ConvertScalarToTensor(graph->block()); +} + +} // namespace onednn +} // namespace fuser +} // namespace jit +} // namespace torch diff --git a/torch/csrc/jit/codegen/onednn/prepare_binary.h b/torch/csrc/jit/codegen/onednn/prepare_binary.h new file mode 100644 index 00000000000..d7f90002e8f --- /dev/null +++ b/torch/csrc/jit/codegen/onednn/prepare_binary.h @@ -0,0 +1,26 @@ +#pragma once + +#include + +namespace torch { +namespace jit { +namespace fuser { +namespace onednn { + +// Prepare binary ops for LLGA +// +// The pass does the following: +// +// - Convert scalar input of aten::add and aten::mul into Float tensor with +// dimension [1] +// +// - Decompose fused add into aten::mul + aten::add when alpha != 1.0 +// +// - Eliminate identity add/mul, i.e., tensor + 0, tensor * 1 +// +void PrepareBinaryForLLGA(const std::shared_ptr& graph); + +} // namespace onednn +} // namespace fuser +} // namespace jit +} // namespace torch diff --git a/torch/csrc/jit/codegen/onednn/register_interface.cpp b/torch/csrc/jit/codegen/onednn/register_interface.cpp new file mode 100644 index 
00000000000..62400944cc9 --- /dev/null +++ b/torch/csrc/jit/codegen/onednn/register_interface.cpp @@ -0,0 +1,54 @@ +#include + +namespace torch { +namespace jit { +namespace fuser { +namespace onednn { + +bool canFuseNode(const Node* node) { + switch (node->kind()) { + case aten::conv2d: + case aten::_convolution: + case aten::batch_norm: + case aten::layer_norm: + case aten::add: + case aten::mul: + case aten::tanh: + case aten::relu: + case aten::elu: + case aten::sigmoid: + case aten::gelu: + case aten::sqrt: + case aten::abs: + case aten::square: + case aten::hardtanh: + case aten::relu6: + case aten::softmax: + case aten::max_pool2d: + case aten::avg_pool2d: + case aten::matmul: + case aten::mm: + case aten::linear: + case aten::addmm: + return true; + + default: + return false; + } +} + +namespace { +class RegisterInterface { + public: + RegisterInterface() { + RegisterProfilingNode(canFuseNode); + } +}; + +static RegisterInterface register_interface_; +} // namespace + +} // namespace onednn +} // namespace fuser +} // namespace jit +} // namespace torch diff --git a/torch/csrc/jit/ir/alias_analysis.cpp b/torch/csrc/jit/ir/alias_analysis.cpp index ba26ff157c4..b88fa36d645 100644 --- a/torch/csrc/jit/ir/alias_analysis.cpp +++ b/torch/csrc/jit/ir/alias_analysis.cpp @@ -623,6 +623,7 @@ void AliasDb::analyzeImpl(Node* node) { return analyzeLoop(node); case prim::FusionGroup: case prim::CudaFusionGroup: + case prim::oneDNNFusionGroup: case prim::FunctionalGraph: case prim::DifferentiableGraph: case prim::FallbackGraph: diff --git a/torch/csrc/jit/ir/ir.cpp b/torch/csrc/jit/ir/ir.cpp index 19c7a0c7745..ee51e5c29e7 100644 --- a/torch/csrc/jit/ir/ir.cpp +++ b/torch/csrc/jit/ir/ir.cpp @@ -508,6 +508,7 @@ void Node::lint() const { break; case prim::FusionGroup: case prim::CudaFusionGroup: + case prim::oneDNNFusionGroup: checkSameDevice(this); // TODO: Typecheck the parameters g(attr::Subgraph)->lint(); diff --git a/torch/csrc/jit/passes/inline_autodiff_subgraphs.cpp b/torch/csrc/jit/passes/inline_autodiff_subgraphs.cpp index c5d91391f43..f8d63e87f07 100644 --- a/torch/csrc/jit/passes/inline_autodiff_subgraphs.cpp +++ b/torch/csrc/jit/passes/inline_autodiff_subgraphs.cpp @@ -21,7 +21,8 @@ bool canRunWithAutograd(Node* node) { } return kind != prim::FusionGroup && kind != prim::CudaFusionGroup && kind != prim::TypeCheck && kind != prim::TensorExprGroup && - kind != prim::CudaFusionGuard && (kind.is_aten() || kind.is_prim()); + kind != prim::CudaFusionGuard && kind != prim::oneDNNFusionGroup && + kind != prim::oneDNNFusionGuard && (kind.is_aten() || kind.is_prim()); } namespace { diff --git a/torch/csrc/jit/passes/onednn_graph_fuser.h b/torch/csrc/jit/passes/onednn_graph_fuser.h new file mode 100644 index 00000000000..aeb79470b01 --- /dev/null +++ b/torch/csrc/jit/passes/onednn_graph_fuser.h @@ -0,0 +1,64 @@ +#pragma once + +#include +#include + +#include + +namespace torch { +namespace jit { +namespace fuser { +namespace onednn { + +static std::atomic onednn_enabled{true}; + +static std::atomic& getLlgaEnabled() { + return onednn_enabled; +} + +TORCH_API void fuseGraph(std::shared_ptr& g); + +} // namespace onednn +} // namespace fuser + +struct C10_EXPORT RegisterLlgaFuseGraph + : public PassManager { + static bool setEnabled(bool enabled) { + TORCH_CHECK( + AT_MKLDNN_ENABLED(), + "Running oneDNN Graph fuser is only supported with MKLDNN builds."); + bool oldState = fuser::onednn::getLlgaEnabled(); + fuser::onednn::getLlgaEnabled() = enabled; + if (enabled) { + 
registerPass(fuser::onednn::fuseGraph); + } else { + clearPass(); + } + return oldState; + } + + static bool isEnabled() { + return fuser::onednn::getLlgaEnabled(); + } + + // override PassManager::registerPass to register pre-pass + static bool registerPass(GraphPass p) { + if (!isRegistered()) { + passID(registerPrePass(std::move(p)), true); + isRegistered(true); + return false; + } + return true; + } + + // override PassManager::clearPass to clear pre-pass + static void clearPass() { + if (isRegistered()) { + clearPrePass(passID()); + isRegistered(true); + } + } +}; + +} // namespace jit +} // namespace torch diff --git a/torch/csrc/jit/python/init.cpp b/torch/csrc/jit/python/init.cpp index 29188f33765..d8a7d806172 100644 --- a/torch/csrc/jit/python/init.cpp +++ b/torch/csrc/jit/python/init.cpp @@ -8,6 +8,9 @@ #include #include #include +#if (!defined(FBCODE_CAFFE2) && defined(BUILD_ONEDNN_GRAPH)) +#include +#endif #include #include #include @@ -663,6 +666,10 @@ void initJITBindings(PyObject* module) { return oldState; }) .def("_jit_nvfuser_enabled", &RegisterCudaFuseGraph::isRegistered) +#if (!defined(FBCODE_CAFFE2) && defined(BUILD_ONEDNN_GRAPH)) + .def("_jit_set_llga_enabled", &RegisterLlgaFuseGraph::setEnabled) + .def("_jit_llga_enabled", &RegisterLlgaFuseGraph::isEnabled) +#endif .def( "_jit_nvfuser_set_comparison_callback", [](bool run_fallback, py::function fn) { diff --git a/torch/csrc/jit/runtime/operator.cpp b/torch/csrc/jit/runtime/operator.cpp index 373196c112d..d005d1b100b 100644 --- a/torch/csrc/jit/runtime/operator.cpp +++ b/torch/csrc/jit/runtime/operator.cpp @@ -246,6 +246,8 @@ bool printerHasSpecialCaseFor(Symbol sym) { prim::StaticSubgraph, // optimization pass adds it prim::ConstantMKLDNNTensor, // optimization pass adds it prim::BroadcastMKLDNNTensors, // optimization pass adds it + prim::oneDNNFusionGroup, // optimization pass adds it + prim::oneDNNFusionGuard, // optimization pass adds it prim::StaticRuntimeCopyOuts, // used in SR only prim::Load, // used in interpreter only prim::MMTreeReduce, // used as an optimization @@ -282,6 +284,7 @@ bool aliasAnalysisHasSpecialCaseFor(Symbol symbol) { prim::Loop, prim::FusionGroup, prim::CudaFusionGroup, + prim::oneDNNFusionGroup, prim::DifferentiableGraph, prim::TensorExprGroup, prim::TensorExprDynamicGroup, diff --git a/torch/jit/__init__.py b/torch/jit/__init__.py index f3a66bd9d1d..5f3ab73324e 100644 --- a/torch/jit/__init__.py +++ b/torch/jit/__init__.py @@ -229,7 +229,19 @@ def _hide_source_ranges() -> Iterator[None]: finally: torch._C.Graph.set_global_print_source_ranges(old_enable_source_ranges) # type: ignore[attr-defined] -# dont expose Any, TODO: define `__all__` +def enable_onednn_fusion(enabled: bool): + """ + Enables or disables onednn JIT fusion based on the parameter `enabled`. + """ + + torch._C._jit_set_llga_enabled(enabled) + +def onednn_fusion_enabled(): + """ + Returns whether onednn JIT fusion is enabled + """ + return torch._C._jit_llga_enabled() + del Any if not torch._C._jit_init():
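For context (not part of the patch): a minimal sketch of how the new Python-side switch could be exercised, assuming a build with BUILD_ONEDNN_GRAPH enabled; the model, input shape, and warm-up count below are illustrative assumptions. Because the LLGA pass only runs under the profiling executor, a few warm-up iterations are needed before prim::oneDNNFusionGroup nodes appear in the optimized graph.

import torch

torch.jit.enable_onednn_fusion(True)          # registers the LLGA pre-pass
assert torch.jit.onednn_fusion_enabled()

model = torch.nn.Sequential(torch.nn.Linear(64, 64), torch.nn.ReLU()).eval()
scripted = torch.jit.script(model)
x = torch.rand(8, 64)

with torch.no_grad():
    for _ in range(3):    # warm-up runs let the profiling executor record shapes
        scripted(x)
    y = scripted(x)       # later calls may dispatch to prim::oneDNNFusionGroup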