diff --git a/BUILD.bazel b/BUILD.bazel index d9780aa23c3..ba509759adc 100644 --- a/BUILD.bazel +++ b/BUILD.bazel @@ -6,7 +6,7 @@ load("//third_party:substitution.bzl", "header_template_rule") load("//:tools/build_variables.bzl", "jit_core_sources", "libtorch_core_sources", "libtorch_cuda_sources", "libtorch_distributed_sources", "libtorch_extra_sources", "libtorch_nvfuser_generated_headers", "libtorch_nvfuser_runtime_sources", "libtorch_python_core_sources", "torch_cpp_srcs") load("//tools/rules:cu.bzl", "cu_library") load("//tools/config:defs.bzl", "if_cuda") -load("//:aten.bzl", "intern_build_aten_ops", "generate_aten") +load("//:aten.bzl", "intern_build_aten_ops", "generate_aten", "aten_ufunc_generated_cpu_sources", "aten_ufunc_generated_cpu_kernel_sources", "aten_ufunc_generated_cuda_sources") COMMON_COPTS = [ "-DHAVE_MALLOC_USABLE_SIZE=1", @@ -94,9 +94,14 @@ generated_cuda_cpp = [ generate_aten( name = "generated_aten_cpp", srcs = aten_generation_srcs, - outs = generated_cpu_cpp + generated_cuda_cpp + [ - "aten/src/ATen/Declarations.yaml", - ], + outs = ( + generated_cpu_cpp + + generated_cuda_cpp + + aten_ufunc_generated_cpu_sources("aten/src/ATen/{}") + + aten_ufunc_generated_cpu_kernel_sources("aten/src/ATen/{}") + + aten_ufunc_generated_cuda_sources("aten/src/ATen/{}") + + ["aten/src/ATen/Declarations.yaml"] + ), generator=":gen", ) @@ -301,7 +306,9 @@ filegroup( "aten/src/ATen/native/cuda/*.cu", "aten/src/ATen/native/quantized/cuda/*.cu", "aten/src/ATen/native/sparse/cuda/*.cu", - ]), + ]) + aten_ufunc_generated_cuda_sources("aten/src/ATen/{}"), + # It's a bit puzzling to me why it's not necessary to declare the + # target that generates these sources... ) header_template_rule( @@ -383,6 +390,7 @@ intern_build_aten_ops( "@fbgemm", "@mkl", ], + extra_impls = aten_ufunc_generated_cpu_kernel_sources("aten/src/ATen/{}"), ) cc_library( @@ -400,7 +408,7 @@ cc_library( ":aten_native_sparse_cpp", ":aten_native_xnnpack", ":aten_src_ATen_config", - ] + generated_cpu_cpp, + ] + generated_cpu_cpp + aten_ufunc_generated_cpu_sources("aten/src/ATen/{}"), copts = ATEN_COPTS, data = if_cuda( [":libcaffe2_nvrtc.so"], diff --git a/aten.bzl b/aten.bzl index eccdb4b4d0c..c97f22284f1 100644 --- a/aten.bzl +++ b/aten.bzl @@ -1,5 +1,6 @@ load("@bazel_skylib//lib:paths.bzl", "paths") load("@rules_cc//cc:defs.bzl", "cc_library") +load("//:tools/build_variables.bzl", "aten_ufunc_headers") CPU_CAPABILITY_NAMES = ["DEFAULT", "AVX2"] CAPABILITY_COMPILER_FLAGS = { @@ -8,8 +9,9 @@ CAPABILITY_COMPILER_FLAGS = { } PREFIX = "aten/src/ATen/native/" +EXTRA_PREFIX = "aten/src/ATen/" -def intern_build_aten_ops(copts, deps): +def intern_build_aten_ops(copts, deps, extra_impls): for cpu_capability in CPU_CAPABILITY_NAMES: srcs = [] for impl in native.glob( @@ -28,6 +30,17 @@ def intern_build_aten_ops(copts, deps): ) srcs.append(out) + for impl in extra_impls: + name = impl.replace(EXTRA_PREFIX, "") + out = EXTRA_PREFIX + name + "." + cpu_capability + ".cpp" + native.genrule( + name = name + "_" + cpu_capability + "_cp", + srcs = [impl], + outs = [out], + cmd = "cp $< $@", + ) + srcs.append(out) + cc_library( name = "ATen_CPU_" + cpu_capability, srcs = srcs, @@ -81,3 +94,32 @@ generate_aten = rule( "srcs": attr.label_list(allow_files = True), }, ) + +# copy pasted from ufunc_defs.bzl, as ufuncs_defs.bzl cannot be included +# from BUILD.bazel because it has a directory relative load, and Bazel +# always load from workspace root. 
The "correct" fix would be to move +# build_variables.bzl to the top level but I don't have time to do this at +# the moment. + +aten_ufunc_names = [ + paths.split_extension(paths.basename(h))[0] + for h in aten_ufunc_headers +] + +def aten_ufunc_generated_cpu_sources(gencode_pattern = "{}"): + return [gencode_pattern.format(name) for name in [ + "UfuncCPU_{}.cpp".format(n) + for n in aten_ufunc_names + ]] + +def aten_ufunc_generated_cpu_kernel_sources(gencode_pattern = "{}"): + return [gencode_pattern.format(name) for name in [ + "UfuncCPUKernel_{}.cpp".format(n) + for n in aten_ufunc_names + ]] + +def aten_ufunc_generated_cuda_sources(gencode_pattern = "{}"): + return [gencode_pattern.format(name) for name in [ + "UfuncCUDA_{}.cu".format(n) + for n in aten_ufunc_names + ]] diff --git a/aten/src/ATen/native/BinaryOps.cpp b/aten/src/ATen/native/BinaryOps.cpp index bdd6c87403e..437835d7a86 100644 --- a/aten/src/ATen/native/BinaryOps.cpp +++ b/aten/src/ATen/native/BinaryOps.cpp @@ -232,10 +232,9 @@ CREATE_COMPARISON_SCALAR_TENSOR_META_FUNC(ge); namespace native { -DEFINE_DISPATCH(add_stub); DEFINE_DISPATCH(add_clamp_stub); -DEFINE_DISPATCH(sub_stub); DEFINE_DISPATCH(mul_stub); +DEFINE_DISPATCH(sub_stub); DEFINE_DISPATCH(div_true_stub); DEFINE_DISPATCH(div_floor_stub); DEFINE_DISPATCH(div_trunc_stub); @@ -277,17 +276,10 @@ DEFINE_DISPATCH(xlogy_stub); DEFINE_DISPATCH(xlog1py_stub); DEFINE_DISPATCH(zeta_stub); -TORCH_IMPL_FUNC(add_out) ( - const Tensor& self, const Tensor& other, const Scalar& alpha, const Tensor& result -) { - add_stub(device_type(), *this, alpha); - TORCH_INTERNAL_ASSERT(result.scalar_type() == output().dtype()); -} - TORCH_IMPL_FUNC(sub_out) ( const Tensor& self, const Tensor& other, const Scalar& alpha, const Tensor& result ) { - sub_stub(device_type(), *this, alpha); + add_stub(device_type(), *this, -alpha); TORCH_INTERNAL_ASSERT(result.scalar_type() == output().dtype()); } diff --git a/aten/src/ATen/native/BinaryOps.h b/aten/src/ATen/native/BinaryOps.h index 4bdf587f0bd..f34f210c4e4 100644 --- a/aten/src/ATen/native/BinaryOps.h +++ b/aten/src/ATen/native/BinaryOps.h @@ -50,7 +50,9 @@ using binary_fn = void(*)(TensorIterator&); using binary_clamp_fn_alpha = void(*)(TensorIterator&, const Scalar& alpha, const Scalar& min_val, const Scalar& max_val); +// NB: codegenned DECLARE_DISPATCH(structured_binary_fn_alpha, add_stub); + DECLARE_DISPATCH(binary_clamp_fn_alpha, add_clamp_stub); DECLARE_DISPATCH(structured_binary_fn_alpha, sub_stub); DECLARE_DISPATCH(structured_binary_fn, mul_stub); diff --git a/aten/src/ATen/native/cpu/BinaryOpsKernel.cpp b/aten/src/ATen/native/cpu/BinaryOpsKernel.cpp index d383849e290..0e5db26b069 100644 --- a/aten/src/ATen/native/cpu/BinaryOpsKernel.cpp +++ b/aten/src/ATen/native/cpu/BinaryOpsKernel.cpp @@ -21,27 +21,6 @@ namespace { using namespace vec; -// Note: Undefined behavior when performing addition is intentionally -// ignored. 
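For orientation, the hand-written CPU add/sub kernels deleted below are superseded by per-operator generated sources. A sketch of what the aten.bzl helpers above expand to for the single "add" entry in aten_ufunc_headers; the results are derived from the {}-patterns in this diff and are illustrative only:

    # Starlark/Python sketch, evaluated against the definitions in aten.bzl above.
    aten_ufunc_names                                             # ["add"]
    aten_ufunc_generated_cpu_sources("aten/src/ATen/{}")         # ["aten/src/ATen/UfuncCPU_add.cpp"]
    aten_ufunc_generated_cpu_kernel_sources("aten/src/ATen/{}")  # ["aten/src/ATen/UfuncCPUKernel_add.cpp"]
    aten_ufunc_generated_cuda_sources("aten/src/ATen/{}")        # ["aten/src/ATen/UfuncCUDA_add.cu"]

intern_build_aten_ops then copies each extra_impls entry once per CPU capability (DEFAULT, AVX2), e.g. UfuncCPUKernel_add.cpp becomes UfuncCPUKernel_add.cpp.DEFAULT.cpp and UfuncCPUKernel_add.cpp.AVX2.cpp, mirroring what process_vec does for the CMake build in cmake/Codegen.cmake.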
-void add_kernel(TensorIteratorBase& iter, const Scalar& alpha_scalar) { - if (iter.dtype() == ScalarType::Bool) { - using scalar_t = bool; - auto alpha = alpha_scalar.to(); - cpu_kernel(iter, - [=](scalar_t a, scalar_t b) __ubsan_ignore_undefined__ -> scalar_t { return a + alpha * b; }); - } else { - AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(kBFloat16, kHalf, iter.dtype(), "add_cpu/sub_cpu", [&]() { - auto alpha = alpha_scalar.to(); - auto alpha_vec = Vectorized(alpha); - cpu_kernel_vec(iter, - [=](scalar_t a, scalar_t b) __ubsan_ignore_undefined__ -> scalar_t { return a + alpha * b; }, - [=](Vectorized a, Vectorized b) __ubsan_ignore_undefined__ { - return vec::fmadd(b, alpha_vec, a); - }); - }); - } -} - void add_clamp_kernel(TensorIterator& iter, const Scalar& alpha_scalar, const Scalar& min_val, const Scalar& max_val) { AT_DISPATCH_ALL_TYPES(iter.dtype(), "add_clamp_cpu", [&]() { auto alpha = alpha_scalar.to(); @@ -74,12 +53,6 @@ void atan2_kernel(TensorIteratorBase& iter) { }); } -// Note: Undefined behavior when performing subtraction is intentionally -// ignored. -void sub_kernel(TensorIteratorBase& iter, const Scalar& alpha_scalar) __ubsan_ignore_undefined__ { - add_kernel(iter, -alpha_scalar); -} - void mul_kernel(TensorIteratorBase& iter) { if (iter.dtype() == ScalarType::Bool) { cpu_kernel(iter, [=](bool a, bool b) -> bool { return a && b; }); @@ -1133,9 +1106,7 @@ void zeta_kernel(TensorIteratorBase& iter) { } // namespace -REGISTER_DISPATCH(add_stub, &add_kernel); REGISTER_DISPATCH(add_clamp_stub, &add_clamp_kernel); -REGISTER_DISPATCH(sub_stub, &sub_kernel); REGISTER_DISPATCH(mul_stub, &mul_kernel); REGISTER_DISPATCH(div_true_stub, &div_true_kernel); REGISTER_DISPATCH(div_trunc_stub, &div_trunc_kernel); diff --git a/aten/src/ATen/native/cuda/BinaryAddSubKernel.cu b/aten/src/ATen/native/cuda/BinaryAddSubKernel.cu deleted file mode 100644 index 56d6b0acd72..00000000000 --- a/aten/src/ATen/native/cuda/BinaryAddSubKernel.cu +++ /dev/null @@ -1,37 +0,0 @@ -#define TORCH_ASSERT_NO_OPERATORS -#include -#include -#include -#include -#include -#include - -// NOTE: CUDA on Windows requires that the enclosing function -// of a __device__ lambda not have internal linkage. 
- -namespace at { namespace native { - -template -struct AddFunctor { - AddFunctor(T alpha) : alpha_(alpha) {} - T alpha_; - __device__ __forceinline__ T operator()(T a, T b) const __ubsan_ignore_undefined__ { - return a + b * alpha_; - } -}; - -void add_kernel_cuda(TensorIteratorBase& iter, const Scalar& alpha_scalar) { - AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(kHalf, kBool, kBFloat16, iter.common_dtype(), "add_cuda/sub_cuda", [&]() { - using opmath_t = at::opmath_type; - opmath_gpu_kernel_with_scalars(iter, AddFunctor(alpha_scalar.to())); - }); -} - -static void sub_kernel_cuda(TensorIteratorBase& iter, const Scalar& alpha_scalar) { - add_kernel_cuda(iter, -alpha_scalar); -} - -REGISTER_DISPATCH(add_stub, &add_kernel_cuda); -REGISTER_DISPATCH(sub_stub, &sub_kernel_cuda); - -}} // namespace at::native diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml index 0b6467775e7..2fce3ebaa11 100644 --- a/aten/src/ATen/native/native_functions.yaml +++ b/aten/src/ATen/native/native_functions.yaml @@ -462,8 +462,10 @@ device_check: NoCheck # TensorIterator structured: True structured_inherits: TensorIteratorBase + ufunc_inner_loop: + Generic: add (AllAndComplex, BFloat16, Half) + ScalarOnly: add (Bool) dispatch: - CPU, CUDA: add_out SparseCPU: add_out_sparse_cpu SparseCUDA: add_out_sparse_cuda SparseCsrCPU: add_out_sparse_csr_cpu diff --git a/aten/src/ATen/native/ufunc/add.h b/aten/src/ATen/native/ufunc/add.h new file mode 100644 index 00000000000..94a776728ea --- /dev/null +++ b/aten/src/ATen/native/ufunc/add.h @@ -0,0 +1,27 @@ +#pragma once + +#include + +#if !defined(__CUDACC__) && !defined(__HIPCC__) +#include +#include +#endif + +namespace at { +namespace native { +namespace ufunc { + +template +C10_HOST_DEVICE C10_ALWAYS_INLINE T add(T self, T other, T alpha) __ubsan_ignore_undefined__ { + return self + alpha * other; +} + +#if !defined(__CUDACC__) && !defined(__HIPCC__) +using vec::Vectorized; +template +C10_ALWAYS_INLINE Vectorized add(Vectorized self, Vectorized other, Vectorized alpha) __ubsan_ignore_undefined__ { + return vec::fmadd(other, alpha, self); +} +#endif + +}}} // namespace at::native::ufunc diff --git a/aten/src/ATen/templates/UfuncCPU.cpp b/aten/src/ATen/templates/UfuncCPU.cpp new file mode 100644 index 00000000000..6b363a50890 --- /dev/null +++ b/aten/src/ATen/templates/UfuncCPU.cpp @@ -0,0 +1,19 @@ +#define TORCH_ASSERT_NO_OPERATORS + +#include +#include +#include + +namespace at { + +// NB: this is explicitly copied here (via codegen) rather than +// included via NativeFunctions.h to avoid recompiling this file when +// NativeFunctions.h changes +namespace meta { +${meta_declaration} +} + +namespace native { +${native_declaration} +${native_definitions} +}} // namespace at::native diff --git a/aten/src/ATen/templates/UfuncCPUKernel.cpp b/aten/src/ATen/templates/UfuncCPUKernel.cpp new file mode 100644 index 00000000000..0cac55664d6 --- /dev/null +++ b/aten/src/ATen/templates/UfuncCPUKernel.cpp @@ -0,0 +1,14 @@ +#define TORCH_ASSERT_NO_OPERATORS + +#include +#include +#include +#include +#include +#include +#include + +namespace at { +namespace native { +${native_definitions} +}} // namespace at::native diff --git a/aten/src/ATen/templates/UfuncCUDA.cu b/aten/src/ATen/templates/UfuncCUDA.cu new file mode 100644 index 00000000000..e75d82d9cc8 --- /dev/null +++ b/aten/src/ATen/templates/UfuncCUDA.cu @@ -0,0 +1,21 @@ +#define TORCH_ASSERT_NO_OPERATORS + +#include +#include +#include +#include +${cuda_headers} + +namespace at { + +// 
NB: this is explicitly copied here (via codegen) rather than +// included via NativeFunctions.h to avoid recompiling this file when +// NativeFunctions.h changes +namespace meta { +${meta_declaration} +} + +namespace native { +${native_declaration} +${native_definitions} +}} // namespace at::native diff --git a/cmake/Codegen.cmake b/cmake/Codegen.cmake index bb573fc35cc..d4db507a98a 100644 --- a/cmake/Codegen.cmake +++ b/cmake/Codegen.cmake @@ -150,6 +150,7 @@ if(INTERN_BUILD_ATEN_OPS) include("${CMAKE_BINARY_DIR}/aten/src/ATen/generated_${gen_type}.cmake") include("${CMAKE_BINARY_DIR}/aten/src/ATen/core_generated_${gen_type}.cmake") + include("${CMAKE_BINARY_DIR}/aten/src/ATen/cpu_vec_generated_${gen_type}.cmake") include("${CMAKE_BINARY_DIR}/aten/src/ATen/cuda_generated_${gen_type}.cmake") include("${CMAKE_BINARY_DIR}/aten/src/ATen/ops_generated_${gen_type}.cmake") @@ -161,10 +162,12 @@ if(INTERN_BUILD_ATEN_OPS) ${generated_${gen_type}} ${cuda_generated_${gen_type}} ${core_generated_${gen_type}} + ${cpu_vec_generated_${gen_type}} ${ops_generated_${gen_type}} ${CMAKE_BINARY_DIR}/aten/src/ATen/generated_${gen_type}.cmake ${CMAKE_BINARY_DIR}/aten/src/ATen/ops_generated_${gen_type}.cmake ${CMAKE_BINARY_DIR}/aten/src/ATen/core_generated_${gen_type}.cmake + ${CMAKE_BINARY_DIR}/aten/src/ATen/cpu_vec_generated_${gen_type}.cmake ${CMAKE_BINARY_DIR}/aten/src/ATen/cuda_generated_${gen_type}.cmake COMMAND ${GEN_COMMAND_${gen_type}} DEPENDS ${all_python} ${${gen_type}_templates} @@ -177,8 +180,8 @@ if(INTERN_BUILD_ATEN_OPS) # not tracked correctly in CMake. We make the libATen.so depend explicitly # on building the generated ATen files to workaround. add_custom_target(ATEN_CPU_FILES_GEN_TARGET DEPENDS - ${generated_headers} ${core_generated_headers} ${ops_generated_headers} - ${generated_sources} ${core_generated_sources} ${ops_generated_sources} + ${generated_headers} ${core_generated_headers} ${cpu_vec_generated_headers} ${ops_generated_headers} + ${generated_sources} ${core_generated_sources} ${cpu_vec_generated_sources} ${ops_generated_sources} ${generated_declarations_yaml}) add_custom_target(ATEN_CUDA_FILES_GEN_TARGET DEPENDS ${cuda_generated_headers} ${cuda_generated_sources}) @@ -260,12 +263,11 @@ if(INTERN_BUILD_ATEN_OPS) # The sources list might get reordered later based on the capabilites. 
# See NOTE [ Linking AVX and non-AVX files ] foreach(i RANGE ${NUM_CPU_CAPABILITY_NAMES}) - foreach(IMPL ${cpu_kernel_cpp_in}) - file(RELATIVE_PATH NAME "${PROJECT_SOURCE_DIR}/aten/src/ATen/" "${IMPL}") + function(process_vec NAME) list(GET CPU_CAPABILITY_NAMES ${i} CPU_CAPABILITY) set(NEW_IMPL ${CMAKE_BINARY_DIR}/aten/src/ATen/${NAME}.${CPU_CAPABILITY}.cpp) configure_file("${PROJECT_SOURCE_DIR}/cmake/IncludeSource.cpp.in" ${NEW_IMPL}) - set(cpu_kernel_cpp ${NEW_IMPL} ${cpu_kernel_cpp}) # Create list of copies + set(cpu_kernel_cpp ${NEW_IMPL} ${cpu_kernel_cpp} PARENT_SCOPE) # Create list of copies list(GET CPU_CAPABILITY_FLAGS ${i} FLAGS) if(MSVC) set(EXTRA_FLAGS "/DCPU_CAPABILITY=${CPU_CAPABILITY} /DCPU_CAPABILITY_${CPU_CAPABILITY}") @@ -284,6 +286,14 @@ if(INTERN_BUILD_ATEN_OPS) endif() endif() set_source_files_properties(${NEW_IMPL} PROPERTIES COMPILE_FLAGS "${FLAGS} ${EXTRA_FLAGS}") + endfunction() + foreach(IMPL ${cpu_kernel_cpp_in}) + file(RELATIVE_PATH NAME "${PROJECT_SOURCE_DIR}/aten/src/ATen/" "${IMPL}") + process_vec("${NAME}") + endforeach() + foreach(IMPL ${cpu_vec_generated_sources}) + file(RELATIVE_PATH NAME "${CMAKE_BINARY_DIR}/aten/src/ATen/" "${IMPL}") + process_vec("${NAME}") endforeach() endforeach() list(APPEND ATen_CPU_SRCS ${cpu_kernel_cpp}) diff --git a/tools/build_variables.bzl b/tools/build_variables.bzl index 8813544588a..b1cae2b40f0 100644 --- a/tools/build_variables.bzl +++ b/tools/build_variables.bzl @@ -1,5 +1,16 @@ +# WARNING: the contents of this file must BOTH be valid Starlark (for Buck and + +# Bazel) as well as valid Python (for our cmake build). This means that +# load() directives are not allowed (as they are not recognized by Python). +# If you want to fix this, figure out how run this file from cmake with a proper +# Starlark interpreter as part of the default OSS build process. If you need +# some nontrivial Starlark features, make a separate bzl file (remember that + +# bzl files are not exported via ShipIt by default, so you may also need to +# update PyTorch's ShipIt config) + # In both open-source and fbcode builds, these are generated into -# torch/csrc/{autgrad,jit}/generated.i +# torch/csrc/{autograd,jit}/generated.i GENERATED_CPP = [ "autograd/generated/Functions.cpp", "autograd/generated/VariableType_0.cpp", @@ -1065,6 +1076,10 @@ aten_cpu_source_codegen_list = [ "aten/src/ATen/native/cpu/AdaptiveMaxPoolKernel.cpp", ] +aten_ufunc_headers = [ + "aten/src/ATen/native/ufunc/add.h", +] + # When building lite interpreter in OSS, "aten/src/ATen/native/cpu/AdaptiveAvgPoolKernel.cpp" will go through # codegen process. The codegen version of this file, like Activation.cpp.DEFAULT.cpp, will be included # in ${cpu_kernel_cpp} in aten/src/ATen/CMakeLists.txt. As a result, in aten/src/ATen/CMakeLists.txt, diff --git a/tools/codegen/api/translate.py b/tools/codegen/api/translate.py index 591b8d75e3b..8342e80a536 100644 --- a/tools/codegen/api/translate.py +++ b/tools/codegen/api/translate.py @@ -5,7 +5,8 @@ from tools.codegen.api.types import (BaseCType, Binding, ConstRefCType, memoryFormatT, tensorOptionsT, scalarTypeT, boolT, deviceT, layoutT, optionalTensorRefT, scalarT, optionalScalarRefT, - VectorCType, longT, intArrayRefT) + VectorCType, longT, intArrayRefT, + scalar_t, opmath_t) # This file implements a small program synthesis engine that implements # conversions between one API to another. @@ -92,9 +93,34 @@ def translate( # While we're at it, do some simple forward inference, looking through # constructors. 
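To make the forward-inference rules added just below concrete, here is a toy, self-contained Python sketch of what they contribute to the translation context. The real code operates on NamedCType/Binding objects, and the template arguments were dropped in this rendering of the patch; the .to<opmath_t>() and static_cast<opmath_t>(...) targets are assumptions based on the opmath_t goal type.

    # Toy model only: a Scalar or scalar_t hypothesis also yields an opmath_t view.
    ctx = {("alpha", "const at::Scalar &"): "alpha",
           ("self", "scalar_t"): "self"}
    derived = {}
    for (name, ctype), expr in ctx.items():
        if ctype == "const at::Scalar &":
            derived[(name, "opmath_t")] = f"({expr}).to<opmath_t>()"
        elif ctype == "scalar_t":
            derived[(name, "opmath_t")] = f"static_cast<opmath_t>({expr})"
    ctx.update(derived)                # forward inference: enrich hypotheses up front
    print(ctx[("alpha", "opmath_t")])  # (alpha).to<opmath_t>()
    print(ctx[("self", "opmath_t")])   # static_cast<opmath_t>(self)

Backward inference then only has to solve the smaller opmath_t goals, which is the termination argument spelled out in the NB below.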
+ # + # NB: When should you do forward inference versus backward inference? + # The general idea: + # + # - Backward inference WHEN the goal gets smaller + # - Forward inference WHEN the hypothesis gets smaller + # + # This helps ensure termination: backward inference starts with a goal + # and tries to make it simpler and simpler until it's trivial; if the + # goal can grow in size, we blow up to a really huge goal size. + # Similarly, with forward inference we take hypotheses and decompose + # them into simpler hypotheses; if hypotheses could expand in size, + # we also have potential nontermination. (In the code below, forward + # inference is only ever carried out at a single step, but you could + # imagine repeated application of forward inference being profitable.) + # + # A good starting point in the literature for exploring more about proof + # search are these lecture notes + # https://www.cs.cmu.edu/~fp/courses/oregon-m10/04-focusing.pdf + # # TODO: My kingdom for a pattern matcher # https://www.python.org/dev/peps/pep-0634/ - # TODO: This could get us in recomputation trouble if b.expr is nontrivial + # + # TODO: This could get us in recomputation trouble if b.expr is nontrivial. + # Fix this by implementing some sort of sharing so that if multiple + # goals share the same expression, we only compute it once. This seems + # to matter in practice as compiler is often unwilling to CSE nontrivial + # expressions like scalar.to() t = b.type if isinstance(t, ConstRefCType) and isinstance(t.elem, OptionalCType) and \ isinstance(t.elem.elem, BaseCType) and str(t.elem.elem.type) == 'at::Tensor': @@ -105,10 +131,16 @@ def translate( ctx[NamedCType(t.name, BaseCType(optionalTensorRefT))] = \ f'(({b.expr}.has_value() && (*{b.expr}).defined()) ? at::OptionalTensorRef(*{b.expr}) : at::OptionalTensorRef())' + if t.type == ConstRefCType(BaseCType(scalarT)): + ctx[NamedCType(t.name, BaseCType(opmath_t))] = f'({b.expr}).to()' + if t.type == ConstRefCType(OptionalCType(BaseCType(scalarT))): ctx[NamedCType(t.name, BaseCType(optionalScalarRefT))] = \ f'({b.expr}.has_value() ? at::OptionalScalarRef(&({b.expr}.value())) : at::OptionalScalarRef())' + if t.type == BaseCType(scalar_t): + ctx[NamedCType(t.name, BaseCType(opmath_t))] = f'static_cast({b.expr})' + # Add implicit bindings if the generated code is inside a Tensor method if method: ctx[NamedCType("self", MutRefCType(BaseCType(tensorT)))] = "const_cast(*this)" @@ -129,7 +161,8 @@ Check this module for more information. ''') # A shitty backtracking search implementation. It's shitty because it - # doesn't actually do backtracing or search. In particular, if + # does backtracking via stack (bad idea!) and for the most part tries to + # avoid backtracking. In particular, if # direct=True, we won't try to do any fancy synthesis, just trivial # conversions (e.g., "T a" is OK for "const T& a"). 
So all of the # existing rules in this function simply try to solve immediately, diff --git a/tools/codegen/api/types.py b/tools/codegen/api/types.py index d269f2c7a3f..8a01b49bfb4 100644 --- a/tools/codegen/api/types.py +++ b/tools/codegen/api/types.py @@ -1,6 +1,6 @@ from tools.codegen.model import (Argument, FunctionSchema, NativeFunction, - BackendIndex, - SelfArgument, TensorOptionsArguments, BaseTy) + BackendIndex, NativeFunctionsGroup, + SelfArgument, TensorOptionsArguments, BaseTy, ScalarType) from dataclasses import dataclass from typing import Optional, Union, Sequence, TypeVar, List, Set, Dict from enum import Enum @@ -68,6 +68,27 @@ tensorOptionsT = BaseCppType('at', 'TensorOptions') typeAndSizeT = BaseCppType('torch::autograd::generated', 'TypeAndSize') tensorGeometryT = BaseCppType('at', 'TensorGeometry') +# Types representing template parameters. Technically, we probably shouldn't +# represent them this way in codegen, but it was pretty convenient. +scalar_t = BaseCppType('', 'scalar_t') +opmath_t = BaseCppType('', 'opmath_t') + +ScalarTypeToCppMapping: Dict[ScalarType, BaseCppType] = { + ScalarType.Byte: byteT, + ScalarType.Char: charT, + ScalarType.Short: shortT, + ScalarType.Int: int32T, + ScalarType.Long: longT, + ScalarType.Half: halfT, + ScalarType.Float: floatT, + ScalarType.Double: doubleT, + ScalarType.ComplexHalf: complexHalfT, + ScalarType.ComplexFloat: complexFloatT, + ScalarType.ComplexDouble: complexDoubleT, + ScalarType.Bool: boolT, + ScalarType.BFloat16: bfloat16T, +} + BaseTypeToCppMapping: Dict[BaseTy, BaseCppType] = { BaseTy.int: longT, BaseTy.float: doubleT, @@ -218,6 +239,23 @@ class TupleCType: def remove_const_ref(self) -> 'CType': return TupleCType([e.remove_const_ref() for e in self.elems]) +@dataclass(frozen=True) +class VectorizedCType: + # This template is explicitly specialized, so the only valid + # elems are those we have specializations for (e.g., float, double, ...) + # scalar_t is also a common argument here (when we are codegen in + # a templated context) + elem: BaseCType + + def cpp_type(self, *, strip_ref: bool = False) -> str: + return f'at::vec::Vectorized<{self.elem.cpp_type()}>' + + def cpp_type_registration_declarations(self) -> str: + raise NotImplementedError + + def remove_const_ref(self) -> 'CType': + return self + CType = Union[ BaseCType, OptionalCType, @@ -227,7 +265,8 @@ CType = Union[ ArrayRefCType, ArrayCType, VectorCType, - TupleCType + TupleCType, + VectorizedCType ] # A NamedCType is short for Named C++ semantic type. 
A NamedCType represents a C++ type, plus @@ -270,6 +309,14 @@ class Binding: # TODO: maybe don't represent default here default: Optional[str] = None + def rename(self, name: str) -> 'Binding': + return Binding( + name=name, + nctype=self.nctype, + argument=self.argument, + default=self.default, + ) + @property def type(self) -> str: return self.nctype.cpp_type() @@ -596,6 +643,19 @@ class FunctionalizationLambda: return FunctionalizationLambda(f, functional_op, is_reverse) +@dataclass(frozen=True) +class StructuredImplSignature: + g: NativeFunctionsGroup + name: str + + def defn(self, name: Optional[str] = None) -> str: + args_str = ', '.join(a.defn() for a in self.arguments()) + return f"TORCH_IMPL_FUNC({self.name})({args_str})" + + def arguments(self) -> List[Binding]: + return structured.impl_arguments(self.g) + + # Helper functions def kernel_signature( @@ -615,4 +675,4 @@ def kernel_signature( return NativeSignature(f.func, prefix) # Functions only, no types -from tools.codegen.api import cpp, dispatcher, native, translate, functionalization +from tools.codegen.api import cpp, dispatcher, native, translate, functionalization, structured diff --git a/tools/codegen/api/ufunc.py b/tools/codegen/api/ufunc.py new file mode 100644 index 00000000000..e6609e0b888 --- /dev/null +++ b/tools/codegen/api/ufunc.py @@ -0,0 +1,176 @@ +from tools.codegen.model import (Argument, BaseTy, BaseType, FunctionSchema, + NativeFunctionsGroup, Type, DispatchKey) + +import tools.codegen.api.types as api_types +from tools.codegen.api.types import (ArgName, BaseCType, Binding, + ConstRefCType, NamedCType, + scalarT, CType, BaseCppType) + +from tools.codegen.api import cpp, structured + +from dataclasses import dataclass +from typing import List, Optional + +def schema_kernel_name(func: FunctionSchema, dispatch_key: DispatchKey) -> str: + assert func.is_out_fn(), "ufunc.kernel_name should only be invoked on out schemas" + return f"ufunc_{func.name.name}_{dispatch_key}" + +def kernel_name(g: NativeFunctionsGroup, dispatch_key: DispatchKey) -> str: + return schema_kernel_name(g.out.func, dispatch_key) + +# Tensors are omitted (as they are stored in TensorIterator), everything else is +# passed along (technically, we can pass tensors along too, it just wastes +# argument registers) +# +# NB: used for CPU only +def dispatchstub_type(t: Type, *, binds: ArgName) -> Optional[NamedCType]: + r = cpp.valuetype_type(t, binds=binds) + if r is not None: + return r + + if t == BaseType(BaseTy.Scalar): + return NamedCType(binds, ConstRefCType(BaseCType(scalarT))) + elif t == BaseType(BaseTy.Tensor): + return None + else: + raise AssertionError(f"unrecognized type {repr(t)}") + +def opmath_type(scalar_t: BaseCppType) -> BaseCppType: + if scalar_t == api_types.scalar_t: + return api_types.opmath_t + raise NotImplementedError + +# NB: Tensors in constructor are stored in opmath_t, not scalar_t +# because Tensor in constructor = its a scalar tensor partially applied = +# it can be higher precision and we want to compute in that higher precision +# +# NB: CUDA only +def ufunctor_ctor_type(t: Type, *, binds: ArgName, scalar_t: BaseCppType) -> NamedCType: + r = cpp.valuetype_type(t, binds=binds) + if r is not None: + return r + + if t == BaseType(BaseTy.Scalar): + return NamedCType(binds, BaseCType(opmath_type(scalar_t))) + elif t == BaseType(BaseTy.Tensor): + return NamedCType(binds, BaseCType(opmath_type(scalar_t))) + else: + raise AssertionError(f"unrecognized type {repr(t)}") + +# Only Tensors ever get passed directly to 
operator() +# +# NB: CUDA only +# (Actually, this works for CPU too) +def ufunctor_apply_type(t: Type, *, binds: ArgName, scalar_t: BaseCppType) -> NamedCType: + if t == BaseType(BaseTy.Tensor): + return NamedCType(binds, BaseCType(scalar_t)) + else: + raise AssertionError(f"unrecognized type {repr(t)}") + +# The actual ufunc template function the user writes. Everything here +# is done in the computation type. compute_t is opmath_t in CUDA and scalar_t +# in CPU +def ufunc_type(t: Type, *, binds: ArgName, compute_t: CType) -> NamedCType: + r = cpp.valuetype_type(t, binds=binds) + if r is not None: + return r + + if t == BaseType(BaseTy.Scalar): + return NamedCType(binds, compute_t) + elif t == BaseType(BaseTy.Tensor): + return NamedCType(binds, compute_t) + else: + raise AssertionError(f"unrecognized type {repr(t)}") + +def ufunctor_ctor_argument(a: Argument, scalar_t: BaseCppType) -> Binding: + return Binding( + nctype=ufunctor_ctor_type(a.type, binds=a.name, scalar_t=scalar_t), + name=a.name, + default=None, + argument=a, + ) + +def ufunctor_apply_argument(a: Argument, scalar_t: BaseCppType) -> Binding: + return Binding( + nctype=ufunctor_apply_type(a.type, binds=a.name, scalar_t=scalar_t), + name=a.name, + default=None, + argument=a, + ) + +def ufunc_argument(a: Argument, compute_t: CType) -> Binding: + return Binding( + nctype=ufunc_type(a.type, binds=a.name, compute_t=compute_t), + name=a.name, + default=None, + argument=a, + ) + +@dataclass(frozen=True) +class UfunctorBindings: + ctor: List[Binding] + apply: List[Binding] + +# ufunctors are a CUDA-only concept representing functors that take some of +# their arguments on a host-side constructor, and the rest in the device-side +# apply. E.g., +# +# template +# struct CUDAFunctorOnSelf_add { +# using opmath_t = at::opmath_type; +# opmath_t other_; +# opmath_t alpha_; +# CUDAFunctorOnSelf_add(opmath_t other, opmath_t alpha) : other_(other), alpha_(alpha) {} +# __device__ scalar_t operator()(scalar_t self) { +# return ufunc::add(static_cast(self), other_, alpha_); +# } +# }; +# +# The ctor refers to the constructor CUDAFunctorOnSelf_add, while apply refers +# to the operator() definition +def ufunctor_arguments( + g: NativeFunctionsGroup, *, scalar_tensor_idx: Optional[int], scalar_t: BaseCppType +) -> UfunctorBindings: + ctor = [] + apply = [] + for a in g.functional.func.arguments.flat_non_out: + if a.type.is_tensor_like(): + if scalar_tensor_idx == 0: + # put it in the ctor anyway + ctor.append(ufunctor_ctor_argument(a, scalar_t=scalar_t)) + scalar_tensor_idx = None + else: + if scalar_tensor_idx is not None: + scalar_tensor_idx -= 1 + apply.append(ufunctor_apply_argument(a, scalar_t=scalar_t)) + else: + ctor.append(ufunctor_ctor_argument(a, scalar_t=scalar_t)) + assert scalar_tensor_idx is None + return UfunctorBindings(ctor=ctor, apply=apply) + +# ufuncs are the inner loop template functions that you wrote in ufunc/add.h +# which do the actual computation in question. E.g., +# +# template +# C10_HOST_DEVICE T add(T self, T other, T alpha) __ubsan_ignore_undefined__ { +# return self + alpha * other; +# } +# +# In this file, we refer to T as compute_t which is bound by caller +def ufunc_arguments(g: NativeFunctionsGroup, *, compute_t: CType) -> List[Binding]: + return [ufunc_argument(a, compute_t=compute_t) for a in g.functional.func.arguments.flat_non_out] + +# Stubs are the DispatchStub trampolines that CPU kernels use to get to their +# vectorized versions. 
E.g., +# +# using structured_binary_fn_alpha = void(*)(TensorIteratorBase&, const Scalar& alpha); +# DECLARE_DISPATCH(structured_binary_fn_alpha, add_stub); +def stub_arguments(g: NativeFunctionsGroup) -> List[Binding]: + # stubs drop all tensor arguments (they are implicit in the TensorIterator + # argument and keep everything else) + return [ + r + for a in g.out.func.arguments.flat_non_out + if not a.type.is_tensor_like() + for r in structured.argument(a) + ] diff --git a/tools/codegen/dest/__init__.py b/tools/codegen/dest/__init__.py index ce9265adf96..d191b8361ba 100644 --- a/tools/codegen/dest/__init__.py +++ b/tools/codegen/dest/__init__.py @@ -7,3 +7,8 @@ from .register_dispatch_key import ( gen_registration_headers as gen_registration_headers, ) from .native_functions import compute_native_function_declaration as compute_native_function_declaration +from .ufunc import ( + compute_ufunc_cuda as compute_ufunc_cuda, + compute_ufunc_cpu as compute_ufunc_cpu, + compute_ufunc_cpu_kernel as compute_ufunc_cpu_kernel +) diff --git a/tools/codegen/dest/ufunc.py b/tools/codegen/dest/ufunc.py new file mode 100644 index 00000000000..c8b92bd538e --- /dev/null +++ b/tools/codegen/dest/ufunc.py @@ -0,0 +1,477 @@ +from dataclasses import dataclass +from typing import Union, Optional, List, Tuple, Dict, Sequence +from tools.codegen.api.translate import translate +from tools.codegen.model import NativeFunctionsGroup, ScalarType, UfuncKey, DispatchKey, BaseType, BaseTy, Argument +import tools.codegen.api.ufunc as ufunc +from tools.codegen.api.ufunc import UfunctorBindings +from tools.codegen.api.types import ( + StructuredImplSignature, scalar_t, opmath_t, Binding, CType, + BaseCType, Expr, NamedCType, ScalarTypeToCppMapping, VectorizedCType +) +from tools.codegen.context import with_native_function + +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # +# +# CUDA STUFF +# +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # + +# NB: not bothering to generate dispatch stub forward declaration in header, +# we can just paste it whereever necessary + +# TODO: use BackendIndex +# dispatch_key: DispatchKey # only CPU/CUDA right now + + +# Represents functors for implementing CUDA ufuncs. +# Functors are templated by scalar_t because when USERS instantiate functors +# they are templated. 
A functor looks something like this: +# +# template +# struct CUDAFunctorOnSelf_add { +# using opmath_t = at::opmath_type; +# opmath_t other_; +# opmath_t alpha_; +# CUDAFunctorOnSelf_add(opmath_t other, opmath_t alpha) +# : other_(other), alpha_(alpha) {} +# __device__ scalar_t operator()(scalar_t self) { +# return ufunc::add(static_cast(self), other_, alpha_); +# } +# }; +# +@dataclass(frozen=True) +class UfunctorSignature: + g: NativeFunctionsGroup + scalar_tensor_idx: Optional[int] + name: str + + def arguments(self) -> UfunctorBindings: + return ufunc.ufunctor_arguments(self.g, scalar_tensor_idx=self.scalar_tensor_idx, scalar_t=scalar_t) + + def fields(self) -> List[Binding]: + # fields are renamed to have a trailing underscore, as is conventional + return [b.rename(f"{b.name}_") for b in self.arguments().ctor] + + def returns_type(self) -> CType: + # TODO: don't hardcode; return type will be inferred based on tags on + # the native function + return BaseCType(scalar_t) + + def decl_fields(self) -> str: + return "\n".join(f"{f.type} {f.name};" for f in self.fields()) + + def inline_defn_ctor(self) -> str: + args_str = ', '.join(a.decl() for a in self.arguments().ctor) + # NB: hypothetically could do this with translate but the + # transition here is very regular + init_str = ', '.join(f"{a.name}_({a.name})" for a in self.arguments().ctor) + return f"{self.name}({args_str}) : {init_str} {{}}" + + def decl_apply(self) -> str: + args_str = ', '.join(a.decl() for a in self.arguments().apply) + return f"{self.returns_type().cpp_type()} operator()({args_str}) const" + + +@dataclass(frozen=True) +class UfuncSignature: + g: NativeFunctionsGroup + name: str + compute_t: CType + + def arguments(self) -> List[Binding]: + return ufunc.ufunc_arguments(self.g, compute_t=self.compute_t) + + def call(self, ctx: Sequence[Union[Binding, Expr]]) -> str: + return f"{self.name}({', '.join(a.expr for a in translate(ctx, self.arguments()))})" + + +# steps: +# 1. take the functional signature +# 2. use api.ufunc to convert it to template signature. this establishes +# the type of the template function +# 3. use api.ufunc (II) to generate a split struct / operator() signature. +# this establish context in which we call the template signature +# +# StructuredImplSignature context +# ~> functor constructor sig +# +# Functor constructor context +# ~> functor fields sig +# +# Functor apply context (functor fields + functor apply sig) +# ~> template sig +# + +def eligible_for_binary_scalar_specialization(g: NativeFunctionsGroup) -> bool: + num_tensors = sum(1 for a in g.functional.func.arguments.flat_non_out if a.type.is_tensor_like()) + return num_tensors == 2 + +def compute_ufunc_cuda_functors(g: NativeFunctionsGroup) -> Tuple[Dict[ScalarType, Dict[UfuncKey, UfunctorSignature]], str]: + # First, build the functors. 
+ ufunctor_sigs: Dict[ScalarType, Dict[UfuncKey, UfunctorSignature]] = {} + ufunctors: List[str] = [] + loops = g.out.ufunc_inner_loop + scalar_tensor_idx_lookup = { + UfuncKey.CUDAFunctorOnSelf: 1, + UfuncKey.CUDAFunctorOnOther: 0, + UfuncKey.CUDAFunctor: None + } + if eligible_for_binary_scalar_specialization(g): + keys = [UfuncKey.CUDAFunctorOnSelf, UfuncKey.CUDAFunctorOnOther, UfuncKey.CUDAFunctor] + else: + keys = [UfuncKey.CUDAFunctor] + for k in [UfuncKey.CUDAFunctorOnSelf, UfuncKey.CUDAFunctorOnOther]: + assert k not in loops, f"cannot use {k} on non-binary function" + for k in keys: + # If the key was directly defined, skip functor codegen; we assume the + # user already done it for us + if k in loops: + ufunctor_sig = UfunctorSignature(g, scalar_tensor_idx=scalar_tensor_idx_lookup[k], name=loops[k].name) + for dtype in loops[k].supported_dtypes: + ufunctor_sigs.setdefault(dtype, {})[k] = ufunctor_sig + continue + + # Note [ScalarOnly and Generic must match names for CUDA] + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # Otherwise, look in ANY of the generic entries. For simplicity of + # codegen, both ScalarOnly and Generic are defined, the ufunc name + # must match (if they didn't match, we'd have to generate distinct + # functors per dtype, which is awful, so we're not going to do it unless + # someone really forces us to) + ufunc_name = None + supported_dtypes = set() + for lk in [UfuncKey.ScalarOnly, UfuncKey.Generic]: + if lk not in loops: + continue + if ufunc_name is None: + ufunc_name = loops[lk].name + else: + # See Note [ScalarOnly and Generic must match names for CUDA] + assert ufunc_name == loops[lk].name, "ScalarOnly and Generic must have same ufunc name" + supported_dtypes |= loops[lk].supported_dtypes + assert ufunc_name is not None + + name = f"{k}_{ufunc_name}" + ufunctor_sig = UfunctorSignature(g, scalar_tensor_idx=scalar_tensor_idx_lookup[k], name=name) + for dtype in supported_dtypes: + ufunctor_sigs.setdefault(dtype, {})[k] = ufunctor_sig + + ufunc_sig = UfuncSignature(g, name=f"ufunc::{ufunc_name}", compute_t=BaseCType(opmath_t)) + apply_ctx = ufunctor_sig.fields() + ufunctor_sig.arguments().apply + ufunctors.append(f""" +template +struct {ufunctor_sig.name} {{ + using opmath_t = at::opmath_type; + {ufunctor_sig.decl_fields()} + {ufunctor_sig.inline_defn_ctor()} + __device__ {ufunctor_sig.decl_apply()} {{ + return {ufunc_sig.call(apply_ctx)}; + }} +}}; +""") + + return ufunctor_sigs, "\n".join(ufunctors) + +@dataclass(frozen=True) +class BinaryScalarSpecializationConfig: + scalar_idx: int + ctor_tensor: str + ufunc_key: UfuncKey + +BinaryScalarSpecializationConfigs = [ + BinaryScalarSpecializationConfig( + scalar_idx=0, + ctor_tensor='self', + ufunc_key=UfuncKey.CUDAFunctorOnOther, + ), + BinaryScalarSpecializationConfig( + scalar_idx=1, + ctor_tensor='other', + ufunc_key=UfuncKey.CUDAFunctorOnSelf, + ), +] + +def compute_ufunc_cuda_dtype_body( + g: NativeFunctionsGroup, dtype: ScalarType, + inner_loops: Dict[UfuncKey, UfunctorSignature], parent_ctx: Sequence[Binding] +) -> str: + body = "using opmath_t = at::opmath_type;" + body += "if (false) {}\n" # for ease of codegen + for config in BinaryScalarSpecializationConfigs: + if config.ufunc_key not in inner_loops: + continue + ufunctor_sig = inner_loops[config.ufunc_key] + scalar_idx = config.scalar_idx + 1 + # Make a copy and at the same time widen the type (not permissible + # without copy; we don't want to mutate the input argument anyway) + ctx: List[Union[Expr, Binding]] = list(parent_ctx) 
+ ctx.append(Expr( + expr=f"iter.scalar_value({scalar_idx})", + type=NamedCType(config.ctor_tensor, BaseCType(opmath_t)), + )) + ufunctor_ctor_exprs_str = ', '.join(a.expr for a in translate(ctx, ufunctor_sig.arguments().ctor)) + + # NB: ufunctor must be allocated before iter.remove_operand is called, + # as it relies on iter + body += f"""\ +else if (iter.is_cpu_scalar({scalar_idx})) {{ + {ufunctor_sig.name} ufunctor({ufunctor_ctor_exprs_str}); + iter.remove_operand({scalar_idx}); + gpu_kernel(iter, ufunctor); +}}""" + + ufunctor_sig = inner_loops[UfuncKey.CUDAFunctor] + ufunctor_ctor_exprs_str = ', '.join(a.expr for a in translate(parent_ctx, ufunctor_sig.arguments().ctor)) + body += f""" +else {{ + gpu_kernel(iter, {ufunctor_sig.name}({ufunctor_ctor_exprs_str})); +}} + """ + return body + +@with_native_function +def compute_ufunc_cuda(g: NativeFunctionsGroup) -> str: + # First, build the functors, indexing them by dtype + ufunctor_sigs, ufunctors = compute_ufunc_cuda_functors(g) + + # Next, build the conditionals + sig = StructuredImplSignature(g, ufunc.kernel_name(g, DispatchKey.CUDA)) + dtype_cases = [] + for dtype, inner_ufunctor_sigs in ufunctor_sigs.items(): + dtype_cases.append(f""" +AT_PRIVATE_CASE_TYPE("{sig.name}", at::ScalarType::{dtype}, {ScalarTypeToCppMapping[dtype]}, + [&]() {{ + {compute_ufunc_cuda_dtype_body(g, dtype, inner_ufunctor_sigs, sig.arguments())} + }} +) +""") + + dtype_cases_str = "\n".join(dtype_cases) + + stub_sig = StubSignature(g) + + return f""" +{ufunctors} + +{stub_sig.type_defn()}; +{stub_sig.dispatch_decl()}; + +{stub_sig.kernel_defn()} {{ + at::ScalarType st = iter.common_dtype(); + RECORD_KERNEL_FUNCTION_DTYPE("{sig.name}", st); + switch (st) {{ + {dtype_cases_str} + default: + TORCH_CHECK(false, "{sig.name}", " not implemented for '", toString(st), "'"); + }} +}} +REGISTER_DISPATCH({stub_sig.name}, &{stub_sig.kernel_name}); + +{sig.defn()} {{ + {stub_sig.direct_call(sig.arguments())}; +}} +""" + +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # +# +# CPU STUFF +# +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # + +@dataclass(frozen=True) +class StubSignature: + g: NativeFunctionsGroup + + @property + def name(self) -> str: + return f"{str(self.g.functional.func.name.name)}_stub" + + @property + def kernel_name(self) -> str: + return f"{str(self.g.functional.func.name.name)}_kernel" + + @property + def type_name(self) -> str: + return f"{str(self.g.functional.func.name.name)}_fn" + + def arguments(self) -> List[Binding]: + return ufunc.stub_arguments(self.g) + + def type(self) -> str: + cpp_args = self.arguments() + return f"void(*)(TensorIteratorBase&, {', '.join(a.type for a in cpp_args)})" + + def dispatch_decl(self) -> str: + return f"DECLARE_DISPATCH({self.type_name}, {self.name})" + + def dispatch_defn(self) -> str: + return f"DEFINE_DISPATCH({self.name})" + + def kernel_defn(self) -> str: + return f"void {self.kernel_name}(TensorIteratorBase& iter, {', '.join(a.defn() for a in self.arguments())})" + + def type_defn(self) -> str: + return f"using {self.type_name} = {self.type()}" + + # must be called from context where this is TensorIteratorBase* + def call(self, ctx: Sequence[Binding]) -> str: + return f"{self.name}(device_type(), *this, {', '.join(a.expr for a in translate(ctx, self.arguments()))})" + + # used in CUDA to skip the unnecessary dynamic dispatch + def direct_call(self, ctx: Sequence[Binding]) -> str: + return f"{self.kernel_name}(*this, {', '.join(a.expr for a in 
translate(ctx, self.arguments()))})" + +@with_native_function +def compute_ufunc_cpu(g: NativeFunctionsGroup) -> str: + stub_sig = StubSignature(g) + sig = StructuredImplSignature(g, ufunc.kernel_name(g, DispatchKey.CPU)) + + return f""" +{stub_sig.type_defn()}; +{stub_sig.dispatch_decl()}; +{stub_sig.dispatch_defn()}; + +{sig.defn()} {{ + {stub_sig.call(sig.arguments())}; +}} +""" + +def compute_ufunc_cpu_dtype_body( + g: NativeFunctionsGroup, dtype: ScalarType, inner_loops: Dict[UfuncKey, UfuncSignature], + parent_ctx: Sequence[Binding] +) -> str: + assert UfuncKey.CPUScalar in inner_loops, f"{dtype}, {inner_loops.keys()}" + assert inner_loops.keys() <= {UfuncKey.CPUScalar, UfuncKey.CPUVector} + scalar_loop = inner_loops[UfuncKey.CPUScalar] + vec_loop = None + if UfuncKey.CPUVector in inner_loops: + vec_loop = inner_loops[UfuncKey.CPUVector] + + # NB: We DON'T use translate here, because translate is + # incapable of CSE'ing the scalar accesses in case it is also + # used by Vectorized; also, the unpacking here is very simple + # and only affects Scalar; everything else is implicitly captured + # by the lambda + + # Setup scalar in scope + body = [] + ctx = [] + for b in parent_ctx: + if isinstance(b.argument, Argument) and b.argument.type != BaseType(BaseTy.Scalar): + continue + body.append(f"auto _s_{b.name} = {b.name}.to();") + ctx.append(Expr(f"_s_{b.name}", NamedCType(b.nctype.name, BaseCType(scalar_t)))) + if vec_loop is not None: + for b in parent_ctx: + if isinstance(b.argument, Argument) and b.argument.type != BaseType(BaseTy.Scalar): + continue + body.append(f"auto _v_{b.name} = at::vec::Vectorized(_s_{b.name});") + ctx.append(Expr(f"_v_{b.name}", NamedCType(b.nctype.name, VectorizedCType(BaseCType(scalar_t))))) + + # Setup lambda signature + # NB: simplified version of ufunctor_arguments + scalar_bindings = [] + vec_bindings = [] + for a in g.functional.func.arguments.flat_non_out: + if not a.type.is_tensor_like(): + continue + assert a.type == BaseType(BaseTy.Tensor) + scalar_bindings.append(Binding( + name=a.name, + nctype=NamedCType(a.name, BaseCType(scalar_t)), + argument=a, + )) + if vec_loop is not None: + vec_bindings.append(Binding( + name=a.name, + nctype=NamedCType(a.name, VectorizedCType(BaseCType(scalar_t))), + argument=a, + )) + + def with_ctx(b: Sequence[Binding]) -> List[Union[Expr, Binding]]: + r: List[Union[Expr, Binding]] = [] + r.extend(ctx) + r.extend(b) + return r + + body_str = '\n'.join(body) + if vec_loop is not None: + return f""" +{body_str} +cpu_kernel_vec(iter, + [=]({', '.join(b.decl() for b in scalar_bindings)}) {{ return {scalar_loop.call(with_ctx(scalar_bindings))}; }}, + [=]({', '.join(b.decl() for b in vec_bindings)}) {{ return {vec_loop.call(with_ctx(vec_bindings))}; }} +); +""" + else: + return f""" +{body_str} +cpu_kernel(iter, + [=]({', '.join(b.decl() for b in scalar_bindings)}) {{ return {scalar_loop.call(with_ctx(scalar_bindings))}; }} +); +""" + +@with_native_function +def compute_ufunc_cpu_kernel(g: NativeFunctionsGroup) -> str: + stub_sig = StubSignature(g) + + # Reindex the ufunc by dtypes; processing generic/scalaronly as well + loops = g.out.ufunc_inner_loop + ufunc_sigs: Dict[ScalarType, Dict[UfuncKey, UfuncSignature]] = {} + for k in [UfuncKey.CPUScalar, UfuncKey.CPUVector]: + lks = [] + # ORDER MATTERS: this specifies overriding precedence + if k in loops: # should happen rarely + lks.append(k) + if UfuncKey.ScalarOnly in loops and k is UfuncKey.CPUScalar: + lks.append(UfuncKey.ScalarOnly) + if UfuncKey.Generic in loops: + 
lks.append(UfuncKey.Generic) + # TODO: don't hardcode ufunc:: namespace here, should be centralized smh + for lk in lks: + for dtype in loops[lk].supported_dtypes: + compute_t: CType + if k is UfuncKey.CPUScalar: + compute_t = BaseCType(scalar_t) + elif k is UfuncKey.CPUVector: + compute_t = VectorizedCType(BaseCType(scalar_t)) + else: + raise AssertionError() + inner_ufunc_sigs = ufunc_sigs.setdefault(dtype, {}) + if k not in inner_ufunc_sigs: + inner_ufunc_sigs[k] = UfuncSignature( + g, name=f"ufunc::{loops[lk].name}", + compute_t=compute_t + ) + + # Build the conditionals + dtype_cases = [] + for dtype, inner_ufunc_sigs in ufunc_sigs.items(): + dtype_cases.append(f""" +AT_PRIVATE_CASE_TYPE("{stub_sig.name}", at::ScalarType::{dtype}, {ScalarTypeToCppMapping[dtype]}, + [&]() {{ + {compute_ufunc_cpu_dtype_body(g, dtype, inner_ufunc_sigs, stub_sig.arguments())} + }} +) +""") + + dtype_cases_str = "\n".join(dtype_cases) + return f""" +namespace {{ + +{stub_sig.kernel_defn()} {{ + at::ScalarType st = iter.common_dtype(); + RECORD_KERNEL_FUNCTION_DTYPE("{stub_sig.name}", st); + switch (st) {{ + {dtype_cases_str} + default: + TORCH_CHECK(false, "{stub_sig.name}", " not implemented for '", toString(st), "'"); + }} +}} + +}} // anonymous namespace + +{stub_sig.type_defn()}; +{stub_sig.dispatch_decl()}; +REGISTER_DISPATCH({stub_sig.name}, &{stub_sig.kernel_name}); +""" diff --git a/tools/codegen/gen.py b/tools/codegen/gen.py index 846c02f1382..51ff8340095 100644 --- a/tools/codegen/gen.py +++ b/tools/codegen/gen.py @@ -16,6 +16,7 @@ from tools.codegen.model import (Argument, DispatchKey, FunctionSchema, TensorOptionsArguments, Type, Variant, is_cuda_dispatch_key, is_generic_dispatch_key, + is_ufunc_dispatch_key, Tag, BaseOperatorName) from tools.codegen.api.types import (Binding, CppSignature, CppSignatureGroup, DispatcherSignature, NativeSignature) @@ -111,40 +112,44 @@ _GLOBAL_PARSE_NATIVE_YAML_CACHE = {} # Parse native_functions.yaml into a sequence of NativeFunctions and Backend Indices. ParsedYaml = namedtuple('ParsedYaml', ['native_functions', 'backend_indices']) + +def parse_native_yaml_struct(es: object, path: str = "") -> ParsedYaml: + assert isinstance(es, list) + rs: List[NativeFunction] = [] + bs: Dict[DispatchKey, Dict[OperatorName, BackendMetadata]] = defaultdict(dict) + for e in es: + assert isinstance(e.get('__line__'), int), e + loc = Location(path, e['__line__']) + funcs = e.get('func') + with context(lambda: f'in {loc}:\n {funcs}'): + func, m = NativeFunction.from_yaml(e, loc) + rs.append(func) + BackendIndex.grow_index(bs, m) + error_check_native_functions(rs) + # Default dict is to prevent the codegen from barfing when we have a dispatch key that has no kernels yet. + indices: Dict[DispatchKey, BackendIndex] = defaultdict(lambda: BackendIndex( + dispatch_key=DispatchKey.Undefined, + use_out_as_primary=True, + external=False, + device_guard=False, + index={})) + for k, v in bs.items(): + # All structured in-tree operators are implemented in terms of their out operator. 
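As a concrete illustration of the dtype re-indexing in compute_ufunc_cpu_kernel above, here is a standalone Python sketch (strings stand in for the real enums, and the dtype classes are left unexpanded) showing how the add entry gives Bool a scalar-only loop while the Generic dtypes also get a vectorized loop:

    # Illustration only: mirrors the lks-precedence loop in compute_ufunc_cpu_kernel.
    loops = {
        "ScalarOnly": ("add", {"Bool"}),
        "Generic": ("add", {"AllAndComplex", "BFloat16", "Half"}),
    }
    ufunc_sigs = {}
    for k in ["CPUScalar", "CPUVector"]:
        lks = []
        if k in loops:                                  # rare: low-level key given directly
            lks.append(k)
        if "ScalarOnly" in loops and k == "CPUScalar":
            lks.append("ScalarOnly")
        if "Generic" in loops:
            lks.append("Generic")
        for lk in lks:                                  # earlier entries take precedence
            name, dtypes = loops[lk]
            for dtype in dtypes:
                ufunc_sigs.setdefault(dtype, {}).setdefault(k, f"ufunc::{name}")
    print(sorted(ufunc_sigs["Bool"]))           # ['CPUScalar']
    print(sorted(ufunc_sigs["AllAndComplex"]))  # ['CPUScalar', 'CPUVector']

This reproduces the behavior of the deleted hand-written add_kernel, which special-cased Bool with cpu_kernel rather than cpu_kernel_vec.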
+ indices[k] = BackendIndex( + dispatch_key=k, + use_out_as_primary=True, + external=False, + # Only cuda-like devices in tree require device guards + device_guard=is_cuda_dispatch_key(k), + index=v) + return ParsedYaml(rs, indices) + def parse_native_yaml(path: str) -> ParsedYaml: global _GLOBAL_PARSE_NATIVE_YAML_CACHE if path not in _GLOBAL_PARSE_NATIVE_YAML_CACHE: with open(path, 'r') as f: es = yaml.load(f, Loader=LineLoader) - assert isinstance(es, list) - rs: List[NativeFunction] = [] - bs: Dict[DispatchKey, Dict[OperatorName, BackendMetadata]] = defaultdict(dict) - for e in es: - assert isinstance(e.get('__line__'), int), e - loc = Location(path, e['__line__']) - funcs = e.get('func') - with context(lambda: f'in {loc}:\n {funcs}'): - func, m = NativeFunction.from_yaml(e, loc) - rs.append(func) - BackendIndex.grow_index(bs, m) - error_check_native_functions(rs) - # Default dict is to prevent the codegen from barfing when we have a dispatch key that has no kernels yet. - indices: Dict[DispatchKey, BackendIndex] = defaultdict(lambda: BackendIndex( - dispatch_key=DispatchKey.Undefined, - use_out_as_primary=True, - external=False, - device_guard=False, - index={})) - for k, v in bs.items(): - # All structured in-tree operators are implemented in terms of their out operator. - indices[k] = BackendIndex( - dispatch_key=k, - use_out_as_primary=True, - external=False, - # Only cuda-like devices in tree require device guards - device_guard=is_cuda_dispatch_key(k), - index=v) - _GLOBAL_PARSE_NATIVE_YAML_CACHE[path] = ParsedYaml(rs, indices) + _GLOBAL_PARSE_NATIVE_YAML_CACHE[path] = parse_native_yaml_struct(es, path=path) return _GLOBAL_PARSE_NATIVE_YAML_CACHE[path] @@ -1012,6 +1017,7 @@ def gen_aggregated_headers( *, native_functions: Sequence[NativeFunction], grouped_native_functions: Sequence[Union[NativeFunction, NativeFunctionsGroup]], + structured_native_functions: Sequence[NativeFunctionsGroup], static_dispatch_idx: Optional[BackendIndex], selector: SelectiveBuilder, backend_indices: Dict[DispatchKey, BackendIndex], @@ -1023,8 +1029,6 @@ def gen_aggregated_headers( ) -> None: # Buck doesn't support dynamic output files, so we aggregate all operator # headers into a single file - structured_native_functions = [g for g in grouped_native_functions - if isinstance(g, NativeFunctionsGroup)] cpu_fm.write('NativeMetaFunctions.h', lambda: { 'NativeMetaFunctions_includes': [], 'NativeMetaFunctions_declarations': list( @@ -1242,6 +1246,7 @@ def gen_headers( *, native_functions: Sequence[NativeFunction], grouped_native_functions: Sequence[Union[NativeFunction, NativeFunctionsGroup]], + structured_native_functions: Sequence[NativeFunctionsGroup], static_dispatch_idx: Optional[BackendIndex], selector: SelectiveBuilder, backend_indices: Dict[DispatchKey, BackendIndex], @@ -1272,6 +1277,7 @@ def gen_headers( gen_aggregated_headers( native_functions=native_functions, grouped_native_functions=grouped_native_functions, + structured_native_functions=structured_native_functions, static_dispatch_idx=static_dispatch_idx, selector=selector, backend_indices=backend_indices, @@ -1343,11 +1349,13 @@ def gen_source_files( *, native_functions: Sequence[NativeFunction], grouped_native_functions: Sequence[Union[NativeFunction, NativeFunctionsGroup]], + structured_native_functions: Sequence[NativeFunctionsGroup], static_dispatch_idx: Optional[BackendIndex], selector: SelectiveBuilder, backend_indices: Dict[DispatchKey, BackendIndex], core_fm: FileManager, cpu_fm: FileManager, + cpu_vec_fm: FileManager, cuda_fm: 
FileManager, dispatch_keys: Sequence[DispatchKey], functions_keys: Set[DispatchKey], @@ -1373,19 +1381,30 @@ def gen_source_files( if per_operator_headers: def operator_headers() -> List[str]: headers = [] - for fn in native_functions: - is_registered = backend_index.has_kernel(fn) or ( - fn.structured and dispatch_key in - (DispatchKey.Meta, DispatchKey.CompositeExplicitAutograd)) + for g in grouped_native_functions: + is_registered = False + if backend_index.has_kernel(g): + is_registered = True + # The above has_kernel test on a group will only test for + # the existence of out dispatch, because that's how + # structured kernels work. But sometimes functions can be + # grouped but not be structured, and then you need to check + # each individual piece, as they may have manual dispatch + # entries. + elif isinstance(g, NativeFunctionsGroup) and any(backend_index.has_kernel(fn) for fn in g.functions()): + is_registered = True + # TODO: this condition is a bit questionable + elif g.structured and dispatch_key in (DispatchKey.Meta, DispatchKey.CompositeExplicitAutograd): + is_registered = True if not is_registered: continue - headers.append(f"#include ") + headers.append(f"#include ") if dispatch_key == DispatchKey.CompositeExplicitAutograd: - headers.append(f"#include ") + headers.append(f"#include ") if dispatch_key in functions_keys: headers.append( - f"#include ") + f"#include ") return sorted(set(headers)) else: @@ -1439,6 +1458,39 @@ def gen_source_files( )), }) + for g in structured_native_functions: + if not g.out.ufunc_inner_loop or not is_ufunc_dispatch_key(dispatch_key): + continue + name = g.functional.func.name.name + if dispatch_key is DispatchKey.CPU: + assert fm is cpu_fm + fm.write_with_template(f'UfuncCPU_{name}.cpp', 'UfuncCPU.cpp', lambda: { + 'meta_declaration': compute_meta_function_declaration(g), + 'native_declaration': + dest.compute_native_function_declaration(g, backend_indices[dispatch_key]), + 'native_definitions': dest.compute_ufunc_cpu(g), + }) + cpu_vec_fm.write_with_template(f'UfuncCPUKernel_{name}.cpp', 'UfuncCPUKernel.cpp', lambda: { + 'name': name, + 'native_definitions': dest.compute_ufunc_cpu_kernel(g), + }) + elif dispatch_key is DispatchKey.CUDA: + cuda_headers = "#include " + if rocm: + cuda_headers = "#include " + fm.write_with_template(f'UfuncCUDA_{name}.cu', 'UfuncCUDA.cu', lambda: { + 'name': name, + 'cuda_headers': cuda_headers, + 'meta_declaration': compute_meta_function_declaration(g), + 'native_declaration': + dest.compute_native_function_declaration(g, backend_indices[dispatch_key]), + 'native_definitions': dest.compute_ufunc_cuda(g), + }) + else: + raise AssertionError(f'unrecognized {dispatch_key} for ufunc') + + del fm + # BackendSelect is generated specially def gen_backend_select() -> Dict[str, List[str]]: relevant_fns = [fn for fn in native_functions if needs_backend_select(fn, selector)] @@ -1601,6 +1653,8 @@ def main() -> None: parsed_yaml = parse_native_yaml(native_yaml_path) native_functions, backend_indices = parsed_yaml.native_functions, parsed_yaml.backend_indices grouped_native_functions = get_grouped_native_functions(native_functions) + structured_native_functions = [g for g in grouped_native_functions + if isinstance(g, NativeFunctionsGroup)] template_dir = os.path.join(options.source_path, "templates") @@ -1620,10 +1674,15 @@ def main() -> None: pathlib.Path(ops_install_dir).mkdir(parents=True, exist_ok=True) def make_file_manager(install_dir: str) -> FileManager: - return FileManager(install_dir=install_dir, 
template_dir=template_dir, dry_run=options.dry_run) + return FileManager( + install_dir=install_dir, + template_dir=template_dir, + dry_run=options.dry_run + ) core_fm = make_file_manager(core_install_dir) cpu_fm = make_file_manager(options.install_dir) + cpu_vec_fm = make_file_manager(options.install_dir) cuda_fm = make_file_manager(options.install_dir) ops_fm = make_file_manager(ops_install_dir) @@ -1661,11 +1720,13 @@ def main() -> None: gen_source_files( native_functions=native_functions, grouped_native_functions=grouped_native_functions, + structured_native_functions=structured_native_functions, static_dispatch_idx=static_dispatch_idx, selector=selector, backend_indices=backend_indices, core_fm=core_fm, cpu_fm=cpu_fm, + cpu_vec_fm=cpu_vec_fm, cuda_fm=cuda_fm, dispatch_keys=dispatch_keys, functions_keys=functions_keys, @@ -1678,6 +1739,7 @@ def main() -> None: gen_headers( native_functions=native_functions, grouped_native_functions=grouped_native_functions, + structured_native_functions=structured_native_functions, static_dispatch_idx=static_dispatch_idx, selector=selector, backend_indices=backend_indices, @@ -1703,6 +1765,7 @@ def main() -> None: for fm, prefix in [ (cpu_fm, ""), + (cpu_vec_fm, "cpu_vec_"), (core_fm, "core_"), (cuda_fm, "cuda_"), (ops_fm, "ops_"), diff --git a/tools/codegen/model.py b/tools/codegen/model.py index fab6ba3affc..3d92d92dfe0 100644 --- a/tools/codegen/model.py +++ b/tools/codegen/model.py @@ -166,6 +166,97 @@ def is_cuda_dispatch_key(dk: DispatchKey) -> bool: def is_structured_dispatch_key(dk: DispatchKey) -> bool: return dk in STRUCTURED_DISPATCH_KEYS +def is_ufunc_dispatch_key(dk: DispatchKey) -> bool: + # For now, ufunc dispatch keys coincide with structured keys + return dk in STRUCTURED_DISPATCH_KEYS + +# This is oddly named ScalarType and not DType for symmetry with C++ +class ScalarType(Enum): + Byte = auto() + Char = auto() + Short = auto() + Int = auto() + Long = auto() + Half = auto() + Float = auto() + Double = auto() + ComplexHalf = auto() + ComplexFloat = auto() + ComplexDouble = auto() + Bool = auto() + BFloat16 = auto() + + def __str__(self) -> str: + return self.name + + @staticmethod + def maybe_parse(value: str) -> Optional['ScalarType']: + for k, v in ScalarType.__members__.items(): + if k == value: + return v + return None + + @staticmethod + def parse(value: str) -> 'ScalarType': + mb_r = ScalarType.maybe_parse(value) + assert mb_r is not None, f'unknown dtype {value}' + return mb_r + + @staticmethod + def parse_set(values: str) -> Set['ScalarType']: + dtypes: Set[ScalarType] = set() + for value in values.split(', '): + if value in DTYPE_CLASSES: + dtypes.update(DTYPE_CLASSES[value]) + else: + dtypes.add(ScalarType.parse(value)) + return dtypes + + +DTYPE_CLASSES: Dict[str, Set[ScalarType]] = {} +# NB: Integral doesn't include boolean +DTYPE_CLASSES["Integral"] = { + ScalarType.Byte, ScalarType.Char, ScalarType.Int, ScalarType.Long, + ScalarType.Short +} +# NB: Floating doesn't include low precision types +DTYPE_CLASSES["Floating"] = {ScalarType.Float, ScalarType.Double} +DTYPE_CLASSES["Complex"] = {ScalarType.ComplexFloat, ScalarType.ComplexDouble} +DTYPE_CLASSES["All"] = DTYPE_CLASSES["Integral"] | DTYPE_CLASSES["Floating"] +DTYPE_CLASSES["AllAndComplex"] = DTYPE_CLASSES["All"] | DTYPE_CLASSES["Complex"] +DTYPE_CLASSES["FloatingAndComplex"] = DTYPE_CLASSES["Floating"] | DTYPE_CLASSES["Complex"] + + +# Represents the valid entries for ufunc_inner_loop in native_functions.yaml. 
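Summarizing the per-operator wiring in gen_source_files above for the one ufunc op in this patch (a sketch; the template-to-output mapping is read off the write_with_template calls, and the purpose column paraphrases dest/ufunc.py):

    # Illustration only: which file each FileManager emits for "add".
    ufunc_outputs_for_add = {
        "UfuncCPU_add.cpp":       ("UfuncCPU.cpp",       "cpu_fm",     "stub DECLARE/DEFINE_DISPATCH + TORCH_IMPL_FUNC"),
        "UfuncCPUKernel_add.cpp": ("UfuncCPUKernel.cpp", "cpu_vec_fm", "per-dtype cpu_kernel(_vec) + REGISTER_DISPATCH"),
        "UfuncCUDA_add.cu":       ("UfuncCUDA.cu",       "cuda_fm",    "functors + gpu_kernel switch + REGISTER_DISPATCH"),
    }

The cpu_vec_fm outputs land in the new cpu_vec_generated_* lists, which is why Codegen.cmake now feeds them through process_vec to get one copy per CPU capability.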
+# NB: if you add a new UfuncKey, you will need to teach tools.codegen.dest.ufunc
+# how to process it. Most logic will ignore keys it doesn't understand, so your
+# new key will get silently ignored until you hook in logic to deal with it.
+class UfuncKey(Enum):
+    # These are low level keys that represent exactly one particular
+    # instantiation of the kernel produced by codegen
+    CUDAFunctor = auto()
+    CUDAFunctorOnOther = auto()
+    CUDAFunctorOnSelf = auto()
+
+    CPUScalar = auto()
+    CPUVector = auto()
+
+    # These are the ones users will usually specify, and
+    # implicitly "fill in" the low level keys
+    ScalarOnly = auto()  # CUDA*, CPUScalar
+    Generic = auto()  # CUDA*, CPU*
+
+    def __str__(self) -> str:
+        return self.name
+
+    @staticmethod
+    def parse(value: str) -> 'UfuncKey':
+        for k, v in UfuncKey.__members__.items():
+            if k == value:
+                return v
+        raise AssertionError(f'unknown ufunc key {value}')
+
+
 class DeviceCheckType(Enum):
     NoCheck = 0
     ExactSame = 1
@@ -239,6 +330,10 @@ class NativeFunction:
     # defined. This is for conveniently reporting error messages!
     loc: 'Location'

+    # If non-empty, this kernel is subject to ufunc codegen.
+    # Sorted by ufunc_key
+    ufunc_inner_loop: Dict[UfuncKey, 'UfuncInnerLoop']
+
     # Whether or not this out functions is a "structured kernel". Structured
     # kernels are defined a little differently from normal kernels; in
     # particular, their shape checking logic is defined separately from
@@ -413,6 +508,31 @@ class NativeFunction:
                 "strictly subsumes the other. If you wanted to provide an explicit autograd " \
                 "implementation, specify CompositeExplicitAutograd; otherwise specify CompositeImplicitAutograd only"

+        raw_ufunc_inner_loop = e.pop('ufunc_inner_loop', {})
+        ufunc_inner_loop = {}
+        if isinstance(raw_ufunc_inner_loop, str):
+            ufunc_inner_loop[UfuncKey.Generic] = UfuncInnerLoop.parse(raw_ufunc_inner_loop, UfuncKey.Generic)
+        elif isinstance(raw_ufunc_inner_loop, dict):
+            for k, vo in raw_ufunc_inner_loop.items():
+                if k == '__line__':
+                    continue
+                assert isinstance(k, str), f'ufunc_inner_loop key is not a str: {k}'
+                assert isinstance(vo, str), f'ufunc_inner_loop value is not a str: {vo}'
+                ufunc_key = UfuncKey.parse(k)
+                ufunc_inner_loop[ufunc_key] = UfuncInnerLoop.parse(vo, ufunc_key)
+        else:
+            raise AssertionError(f'ufunc_inner_loop not str or dict: {raw_ufunc_inner_loop}')
+        # Program the BackendIndex for the implicit dispatch entry from ufunc
+        if ufunc_inner_loop:
+            assert structured, "ufunc must be structured"
+            for dispatch_key in STRUCTURED_DISPATCH_KEYS:
+                assert dispatch_key not in dispatch, \
+                    f"ufunc should not have explicit dispatch entry for {dispatch_key}"
+                dispatch[dispatch_key] = BackendMetadata(
+                    kernel=ufunc.schema_kernel_name(func, dispatch_key),
+                    structured=True
+                )
+
         if structured_delegate:
             # Structured functions MUST have a dispatch table
             is_abstract = True
@@ -448,6 +568,7 @@ class NativeFunction:
             structured_delegate=structured_delegate,
             structured_inherits=structured_inherits,
             precomputed=precomputed,
+            ufunc_inner_loop=ufunc_inner_loop,
             manual_kernel_registration=manual_kernel_registration,
             manual_cpp_binding=manual_cpp_binding,
             python_module=python_module,
@@ -666,7 +787,24 @@ class BackendMetadata:
     # in native_functions.yaml.
     # However, external backends like XLA can indendently toggle which ops are structured.
     structured: bool
-    #
+
+@dataclass(frozen=True)
+class UfuncInnerLoop:
+    name: str
+    supported_dtypes: Set[ScalarType]
+    # key is stored here because it affects the semantics of name,
+    # so it's helpful to have them together for further processing
+    ufunc_key: UfuncKey
+
+    @staticmethod
+    def parse(value: str, ufunc_key: UfuncKey) -> 'UfuncInnerLoop':
+        name, supported_dtypes_str = value.split(' ', 1)
+        assert supported_dtypes_str[0] == '('
+        assert supported_dtypes_str[-1] == ')'
+        supported_dtypes = set()
+        for k in supported_dtypes_str[1:-1].split(', '):
+            supported_dtypes |= ScalarType.parse_set(k)
+        return UfuncInnerLoop(name=name, supported_dtypes=supported_dtypes, ufunc_key=ufunc_key)


 # BackendIndex represents a backend.
@@ -1664,3 +1802,5 @@ class Precompute:
             replace_list.append(f'{kernel_param} -> {replacements}')

         return replace_list
+
+import tools.codegen.api.ufunc as ufunc
diff --git a/tools/test/test_codegen_model.py b/tools/test/test_codegen_model.py
new file mode 100644
index 00000000000..50ea59575be
--- /dev/null
+++ b/tools/test/test_codegen_model.py
@@ -0,0 +1,124 @@
+# Owner(s): ["module: codegen"]
+
+import expecttest
+import unittest
+import yaml
+import textwrap
+
+from tools.codegen.model import NativeFunctionsGroup, DispatchKey
+import tools.codegen.dest as dest
+import tools.codegen.gen as gen
+from tools.codegen.gen import LineLoader, parse_native_yaml_struct
+
+class TestCodegenModel(expecttest.TestCase):
+    def assertParseErrorInline(self, yaml_str: str, expect: str) -> None:
+        es = yaml.load(yaml_str, Loader=LineLoader)
+        try:
+            parse_native_yaml_struct(es)
+        except AssertionError as e:
+            # hack to strip out the context
+            msg, _ = str(e).split(' in ', 2)
+            self.assertExpectedInline('\n'.join(textwrap.wrap(msg)), expect, skip=1)
+            return
+        self.fail(msg="Did not raise when expected to")
+
+    def assertUfuncErrorInline(self, yaml_str: str, expect: str) -> None:
+        # parse a single structured group out of the yaml to g
+        es = yaml.load(yaml_str, Loader=LineLoader)
+        parsed_yaml = parse_native_yaml_struct(es)
+        native_functions, backend_indices = parsed_yaml.native_functions, parsed_yaml.backend_indices
+        grouped_native_functions = gen.get_grouped_native_functions(native_functions)
+        assert len(grouped_native_functions) == 1
+        g = grouped_native_functions[0]
+        assert isinstance(g, NativeFunctionsGroup)
+        assert g.out.ufunc_inner_loop
+        # this is not ufunc codegen per se, but it does some basic sanity tests for
+        # ufunc generation
+        gen.compute_meta_function_declaration(g)
+        dest.compute_native_function_declaration(g, backend_indices[DispatchKey.CPU])
+        dest.compute_native_function_declaration(g, backend_indices[DispatchKey.CUDA])
+        try:
+            # the real kahuna
+            dest.compute_ufunc_cpu(g)
+            dest.compute_ufunc_cpu_kernel(g)
+            dest.compute_ufunc_cuda(g)
+        except AssertionError as e:
+            # hack to strip out the context
+            msg, _ = str(e).split(' in ', 2)
+            self.assertExpectedInline('\n'.join(textwrap.wrap(msg)), expect, skip=1)
+            return
+        self.fail(msg="Did not raise when expected to")
+
+    # NB: indent is hardcoded to be two here, so format your yaml accordingly
+    binop_out = 'func: binop.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)'
+    ti_binop_out = f'''{binop_out}
+  structured: True
+  structured_inherits: TensorIteratorBase'''
+    ti_binop = '''func: binop(Tensor self, Tensor other) -> Tensor
+  structured_delegate: binop.out
+'''
+
+    ti_unop_out = '''func: unop.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+  structured: True
+  structured_inherits: TensorIteratorBase'''
+    ti_unop = '''func: unop(Tensor self) -> Tensor
+  structured_delegate: unop.out
+'''
+
+    def test_nonstructured_ufunc(self) -> None:
+        yaml_str = f'''\
+- {self.binop_out}
+  ufunc_inner_loop:
+    Generic: binop (Bool)
+'''
+        self.assertParseErrorInline(yaml_str, '''\
+ufunc must be structured''')
+
+    def test_overlapping_ufunc_and_dispatch(self) -> None:
+        yaml_str = f'''\
+- {self.ti_binop_out}
+  ufunc_inner_loop:
+    Generic: binop (Bool)
+  dispatch:
+    CPU: binop_cpu
+'''
+        self.assertParseErrorInline(yaml_str, '''\
+ufunc should not have explicit dispatch entry for CPU''')
+
+    # See https://github.com/pytorch/pytorch/pull/65851#discussion_r810238456
+    @unittest.expectedFailure
+    def test_scalaronly_shadowed(self) -> None:
+        yaml_str = f'''\
+- {self.ti_binop_out}
+  ufunc_inner_loop:
+    Generic: binop (Bool)
+    ScalarOnly: binop (Bool)
+'''
+        self.assertParseErrorInline(yaml_str, '''\
+''')
+
+    def test_conflicting_ufunc(self) -> None:
+        yaml_str = f'''\
+- {self.ti_binop_out}
+  ufunc_inner_loop:
+    Generic: binop (Bool)
+    ScalarOnly: binop_scalar (Bool)
+- {self.ti_binop}
+'''
+        self.assertUfuncErrorInline(yaml_str, '''\
+ScalarOnly and Generic must have same ufunc name''')
+
+    def test_invalid_cudafunctoronself_for_binary_op(self) -> None:
+        yaml_str = f'''\
+- {self.ti_unop_out}
+  ufunc_inner_loop:
+    Generic: unop (All)
+    CUDAFunctorOnSelf: unop_self_cuda (All)
+- {self.ti_unop}
+'''
+        self.assertUfuncErrorInline(yaml_str, '''\
+cannot use CUDAFunctorOnSelf on non-binary function''')
+
+
+if __name__ == '__main__':
+    unittest.main()
diff --git a/tools/ufunc_defs.bzl b/tools/ufunc_defs.bzl
new file mode 100644
index 00000000000..4490f05be01
--- /dev/null
+++ b/tools/ufunc_defs.bzl
@@ -0,0 +1,25 @@
+load("@bazel_skylib//lib:paths.bzl", "paths")
+load(":build_variables.bzl", "aten_ufunc_headers")
+
+aten_ufunc_names = [
+    paths.split_extension(paths.basename(h))[0]
+    for h in aten_ufunc_headers
+]
+
+def aten_ufunc_generated_cpu_sources(gencode_pattern = "{}"):
+    return [gencode_pattern.format(name) for name in [
+        "UfuncCPU_{}.cpp".format(n)
+        for n in aten_ufunc_names
+    ]]
+
+def aten_ufunc_generated_cpu_kernel_sources(gencode_pattern = "{}"):
+    return [gencode_pattern.format(name) for name in [
+        "UfuncCPUKernel_{}.cpp".format(n)
+        for n in aten_ufunc_names
+    ]]
+
+def aten_ufunc_generated_cuda_sources(gencode_pattern = "{}"):
+    return [gencode_pattern.format(name) for name in [
+        "UfuncCUDA_{}.cu".format(n)
+        for n in aten_ufunc_names
+    ]]
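+
+# Illustrative note (not part of the original file): assuming "add.h" is listed
+# in aten_ufunc_headers, aten_ufunc_generated_cpu_sources("aten/src/ATen/{}")
+# would evaluate to ["aten/src/ATen/UfuncCPU_add.cpp", ...]; the other helpers
+# follow the same pattern for the CPU kernel and CUDA variants.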