Python Interface for Jiterator

This PR allows users to author CUDA kernels in Python.

```
import torch
from torch.cuda.jiterator import _create_jit_fn as create_jit_fn

code_string = "template <typename T> T my_kernel(T x, T y, T alpha) { return -x * y + x - y + alpha; }"
jitted_fn = create_jit_fn(code_string, alpha=0)

a = torch.rand(3, device='cuda')
b = torch.rand(3, device='cuda')
result = jitted_fn(a, b, alpha=1.0)
```
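
The returned callable behaves like a regular elementwise PyTorch op: per the docstring added in torch/cuda/jiterator.py, it accepts noncontiguous tensors and supports broadcasting and type promotion. A minimal sketch of that behavior (the kernel and shapes here are illustrative, not taken from the PR):

```
import torch
from torch.cuda.jiterator import _create_jit_fn as create_jit_fn

code_string = "template <typename T> T mul_add(T x, T y, T alpha) { return x * y + alpha; }"
fn = create_jit_fn(code_string, alpha=0.0)

x = torch.rand(4, 1, device='cuda')                       # float32
y = torch.rand(1, 5, device='cuda', dtype=torch.float64)  # float64, broadcasts against x
out = fn(x, y, alpha=1.0)                                 # shape (4, 5), promoted to float64
```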

Limitations:
- Only supports elementwise kernels
- 1 to 8 tensor inputs (zero tensor inputs, e.g. factory functions, are not supported)
- Input tensors must reside on a CUDA device
- CPU scalars are not supported
- kwargs must be pre-declared when calling create_jit_fn (see the sketch after this list)
- kwargs must be convertible to at::Scalar, i.e. one of float64, int64_t, or bool (complex is not supported for now)
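
A minimal sketch of the kwargs contract, based on the `JittedFunction.__call__` logic added in torch/cuda/jiterator.py (the kernel name and values below are illustrative): every extra argument must be declared with a default when the function is created; a call-time value may override a declared kwarg, and an undeclared name raises a KeyError.

```
import torch
from torch.cuda.jiterator import _create_jit_fn as create_jit_fn

code_string = "template <typename T> T scaled_sum(T x, T y, T alpha) { return x + alpha * y; }"
# alpha must be declared here, with a default convertible to at::Scalar
fn = create_jit_fn(code_string, alpha=1.0)

a = torch.rand(3, device='cuda')
b = torch.rand(3, device='cuda')

fn(a, b)             # uses the declared default alpha=1.0
fn(a, b, alpha=2.5)  # overrides the declared kwarg at call time
# fn(a, b, beta=0.5) # raises KeyError: beta is not declared in function definition
```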

TODOs:
- [x] consolidate union and c10::variant implementation
- [x] plug into existing op testing framework
- [ ] rename files, place files in the right folder
- [ ] place util functions in the right file
- [x] enforce assumptions in the Python interface, e.g. <=8 inputs, kwargs types
- [x] Add user-facing documentation
Pull Request resolved: https://github.com/pytorch/pytorch/pull/76394
Approved by: https://github.com/mruberry
Authored by Sherlock Huang on 2022-05-06 18:44:28 +00:00; committed by PyTorch MergeBot
parent 6c615a21a0
commit 8b6a78f39f
16 changed files with 1071 additions and 4 deletions


@ -0,0 +1,345 @@
#include <ATen/jit_macros.h>
#if AT_USE_JITERATOR()
#include <c10/cuda/CUDAGuard.h>
#include <ATen/cuda/jiterator.h>
#include <ATen/cuda/jiterator_impl.h>
#include <iostream>
#include <utility>
#include <chrono>
namespace at {
namespace native {
static inline void launch_jitted_vectorized_kernel_dynamic(
const std::string& name, TensorIteratorBase& iter,
DeviceIndex dev_idx, int64_t N, const std::string& f, void* data_ptr,
const std::vector<at::Scalar>& extra_args) {
TORCH_INTERNAL_ASSERT(N > 0 && N <= std::numeric_limits<int32_t>::max());
// N is still int64_t for the computation, but it's always safe to cast result to int
const uint32_t grid = (N + block_work_size() - 1) / block_work_size();
const int vec_size = jitted_can_vectorize_up_to(iter);
bool vectorized = vec_size > 1;
// Different kernels are compiled depending on what we're vectorizing up to (1, 2 or 4 elements)
// fn_ptr is set to the appropriate function based on the vec size and GPU used
// TODO: Memory use can probably be optimized by re-using kernels across GPUs with
// the same compute capability
int nTensors = iter.ntensors();
const at::ScalarType common_dtype = iter.common_dtype();
std::string f_inputs_type_str = at::cuda::jit::typeName(common_dtype);
std::string compute_type_str = at::cuda::jit::typeName(toOpMathType(common_dtype));
std::string result_type_str = at::cuda::jit::typeName(common_dtype);
c10::SmallVector<std::string> extra_args_types = get_extra_args_typenames(extra_args);
// The cache key includes all the parameters to generate_code + vec_size + dev_idx
std::stringstream ss;
ss << nTensors << f;
ss << f_inputs_type_str << compute_type_str << result_type_str;
ss << static_cast<int>(at::cuda::jit::BinaryFuncVariant::NoScalar);
ss << extra_args_types;
ss << vec_size;
// DeviceIndex, e.g. int8_t, is not treated as a number by the stream, cast to int as a workaround
ss << static_cast<int>(dev_idx);
const std::string cache_key = ss.str();
static std::mutex _jiterator_mutex;
static std::unordered_map<std::string, at::cuda::jit::NvrtcFunction> fns;
at::cuda::jit::NvrtcFunction* fn_ptr = &fns[cache_key];
if (!fn_ptr->function) {
const std::lock_guard<std::mutex> lock{_jiterator_mutex};
if (!fn_ptr->function) { // cache miss!
// Generates program
auto code = at::cuda::jit::generate_code(nTensors, f, name,
f_inputs_type_str, compute_type_str, result_type_str,
/*contiguous=*/true, /*dynamic_casting=*/false,
at::cuda::jit::BinaryFuncVariant::NoScalar,
extra_args_types,
vectorized, vec_size);
std::string kernel_name = vectorized ? name + "_vectorized" + std::to_string(vec_size) : name;
// Acquires the program
*fn_ptr = at::cuda::jit::jit_pwise_function(code, kernel_name);
}
}
// size of `extra_args` is unknown at compile-time
auto extra_args_size = extra_args.size();
float scalar_val = 0;
if (vectorized) {
// pack args for kernel launch
constexpr int kernel_args = 3;
auto args = std::make_unique<void*[]>(kernel_args + extra_args_size);
args[0] = static_cast<void*>(&N);
args[1] = data_ptr;
args[2] = static_cast<void*>(&scalar_val);
for (const auto i : c10::irange(extra_args_size)) {
// since 3 slots are already filled in `args`
args[i + 3] = const_cast<void*>(extra_args[i].data_ptr());
}
at::cuda::jit::launch_jitted_pwise_function(*fn_ptr, args.get(), {grid, 1u, 1u}, {num_threads(), 1u, 1u});
} else {
TrivialOffsetCalculatorVariant input_offset_calculator(iter);
void* ic_ptr = input_offset_calculator.data_ptr();
auto oc = TrivialOffsetCalculator<1>();
auto l = memory::LoadWithoutCast();
auto s = memory::StoreWithoutCast();
// pack args for kernel launch
constexpr int kernel_args = 7;
auto args = std::make_unique<void*[]>(kernel_args + extra_args_size);
args[0] = static_cast<void*>(&N);
args[1] = data_ptr;
args[2] = ic_ptr;
args[3] = static_cast<void*>(&oc);
args[4] = static_cast<void*>(&l);
args[5] = static_cast<void*>(&s);
args[6] = static_cast<void*>(&scalar_val);
for (const auto i : c10::irange(extra_args_size)) {
// since 7 slots are already filled in `args`
args[i + 7] = const_cast<void*>(extra_args[i].data_ptr());
}
at::cuda::jit::launch_jitted_pwise_function(*fn_ptr, args.get(), {grid, 1u, 1u}, {num_threads(), 1u, 1u});
}
}
static inline void launch_jitted_unrolled_kernel_dynamic(
const std::string& name, TensorIteratorBase& iter,
DeviceIndex dev_idx, int64_t N, const std::string& f, void* data_ptr,
void* ic_ptr, void* oc_ptr, void* l_ptr, void* s_ptr, bool contiguous, bool dynamic_casting,
const std::vector<at::Scalar>& extra_args) {
TORCH_INTERNAL_ASSERT(N > 0 && N <= std::numeric_limits<int32_t>::max());
//casting result to int is always safe, intermediate is int64 and won't overflow
const uint32_t grid = (N + block_work_size() - 1) / block_work_size();
int nTensors = iter.ntensors();
const at::ScalarType common_dtype = iter.common_dtype();
std::string f_inputs_type_str = at::cuda::jit::typeName(common_dtype);
std::string compute_type_str = at::cuda::jit::typeName(toOpMathType(common_dtype));
std::string result_type_str = at::cuda::jit::typeName(common_dtype);
c10::SmallVector<std::string> extra_args_types = get_extra_args_typenames(extra_args);
// The cache key includes all the parameters to generate_code + dev_idx
std::stringstream ss;
ss << nTensors << f;
ss << f_inputs_type_str << compute_type_str << result_type_str;
ss << contiguous << dynamic_casting;
ss << static_cast<int>(at::cuda::jit::BinaryFuncVariant::NoScalar);
ss << extra_args_types;
ss << dev_idx;
const std::string cache_key = ss.str();
static std::mutex _jiterator_mutex;
static std::unordered_map<std::string, at::cuda::jit::NvrtcFunction> fns;
at::cuda::jit::NvrtcFunction* fn_ptr = &fns[cache_key];
if (!fn_ptr->function) {
const std::lock_guard<std::mutex> lock{_jiterator_mutex};
if (!fn_ptr->function) {
auto code = at::cuda::jit::generate_code(nTensors, f, name,
f_inputs_type_str, compute_type_str, result_type_str,
contiguous, dynamic_casting,
at::cuda::jit::BinaryFuncVariant::NoScalar,
extra_args_types);
*fn_ptr = at::cuda::jit::jit_pwise_function(code, name);
}
}
float scalar_val = 0;
// pack args for kernel launch
constexpr int kernel_args = 7;
auto extra_args_size = extra_args.size();
auto args = std::make_unique<void*[]>(kernel_args + extra_args_size);
args[0] = static_cast<void*>(&N);
args[1] = data_ptr;
args[2] = ic_ptr;
args[3] = oc_ptr;
args[4] = l_ptr;
args[5] = s_ptr;
args[6] = static_cast<void*>(&scalar_val);
for (const auto i : c10::irange(extra_args_size)) {
// since 7 slots are already filled in `args`
args[i + 7] = const_cast<void*>(extra_args[i].data_ptr());
}
at::cuda::jit::launch_jitted_pwise_function(*fn_ptr, args.get(), {grid, 1u, 1u}, {num_threads(), 1u, 1u});
}
void jitted_gpu_kernel_dynamic_impl(
const std::string& kernel_name,
TensorIteratorBase& iter,
const std::string& f,
const bool dynamic_casting,
const std::vector<at::Scalar>& extra_args) {
TORCH_INTERNAL_ASSERT(iter.can_use_32bit_indexing());
TORCH_INTERNAL_ASSERT(iter.noutputs() == 1);
TORCH_INTERNAL_ASSERT(iter.ninputs() <= 8);
ArrayVariant data(iter);
void* data_ptr = data.data_ptr();
int64_t numel = iter.numel();
bool contiguous = iter.is_contiguous();
// Decides which of 4 kernel types to launch
// Variations are:
// - Case 1: no dynamic casting and contiguous
// - Case 2: no dynamic casting and noncontiguous
// - Case 3: dynamic casting and contiguous
// - Case 4: dynamic casting and noncontiguous
// These cases align with the non-jitted CUDALoops.cuh cases in gpu_kernel_impl
if (!dynamic_casting) {
if (contiguous) {
// Case 1: no dynamic casting and contiguous
launch_jitted_vectorized_kernel_dynamic(kernel_name, iter,
iter.device().index(), numel, f, data_ptr, extra_args);
return;
}
// Case 2: no dynamic casting and noncontiguous
OffsetCalculatorVariant input_offset_calculator(iter);
void* ic_ptr = input_offset_calculator.data_ptr();
auto output_offset_calculator = make_output_offset_calculator(iter);
void* oc_ptr = static_cast<void*>(&output_offset_calculator);
auto loader = memory::LoadWithoutCast();
auto storer = memory::StoreWithoutCast();
void* l_ptr = static_cast<void*>(&loader);
void* s_ptr = static_cast<void*>(&storer);
launch_jitted_unrolled_kernel_dynamic(
kernel_name, iter, iter.device().index(), numel, f, data_ptr,
ic_ptr, oc_ptr, l_ptr, s_ptr, contiguous, dynamic_casting, extra_args);
return;
}
// Cases 3 and 4 are handled below
// Both require construction of a storer (this asserts 1 output) and one or more loaders
// Creates load casts from inputs (note offset indexing into the iterator's 1...n tensors)
LoadWithCastVariant loader(iter);
void* l_ptr = loader.data_ptr();
// Creates store cast to output (the zeroth tensor in TensorIterator)
auto storer = memory::StoreWithCast(iter.dtype(0));
void* s_ptr = static_cast<void*>(&storer);
if (contiguous) {
// Case 3: dynamic casting and contiguous
TrivialOffsetCalculatorVariant input_offset_calculator(iter);
void* ic_ptr = input_offset_calculator.data_ptr();
auto output_offset_calculator = TrivialOffsetCalculator<1>();
void* oc_ptr = static_cast<void*>(&output_offset_calculator);
launch_jitted_unrolled_kernel_dynamic(
kernel_name, iter, iter.device().index(), numel, f, data_ptr,
ic_ptr, oc_ptr, l_ptr, s_ptr, contiguous, dynamic_casting, extra_args);
return;
}
// Case 4: dynamic casting and noncontiguous
OffsetCalculatorVariant input_offset_calculator(iter);
void* ic_ptr = input_offset_calculator.data_ptr();
auto output_offset_calculator = make_output_offset_calculator(iter);
void* oc_ptr = static_cast<void*>(&output_offset_calculator);
launch_jitted_unrolled_kernel_dynamic(
kernel_name, iter, iter.device().index(), numel, f, data_ptr,
ic_ptr, oc_ptr, l_ptr, s_ptr, contiguous, dynamic_casting, extra_args);
}
// Entrypoint for the dynamic version of jitted GPU kernels, which accepts a dynamic number of inputs
// and arbitrary types of inputs and extra args. This dynamic version is needed for the jiterator's Python
// interface, since the kernel definition is unknown at compile time.
// Similarly, launch_jitted_vectorized_kernel_dynamic and launch_jitted_unrolled_kernel_dynamic are created
// to handle arbitrary functions defined in Python user code.
// For the templated version, see note [Jiterator] in JitLoops.cuh for more details
void jitted_gpu_kernel_dynamic(
const std::string& kernel_name,
TensorIteratorBase& iter,
const std::string& f,
const std::vector<at::Scalar>& extra_args) {
// TODO: much of preamble is common to both jitted_gpu_kernel and gpu_kernel
// Maybe it could be refactored?
for (int arg = 0; arg < iter.ntensors(); arg++) {
TORCH_INTERNAL_ASSERT(
iter.device(arg).is_cuda(),
"argument ", arg, ": expected a CUDA device but found ", iter.device(arg));
}
if (iter.numel() == 0) {
return;
}
if (!iter.can_use_32bit_indexing()) {
for (auto& sub_iter : iter.with_32bit_indexing()) {
jitted_gpu_kernel_dynamic(kernel_name, sub_iter, f, extra_args);
}
return;
}
// Computes if dynamic casting is needed
// Dynamic casting is needed if an input's or output's dtype differs from the common dtype
bool needs_dynamic_casting = false;
const at::ScalarType common_dtype = iter.common_dtype();
for (auto i = 0; i < iter.ntensors(); ++i) {
if (iter.dtype(i) != common_dtype) {
needs_dynamic_casting = true;
break;
}
}
jitted_gpu_kernel_dynamic_impl(kernel_name, iter, f, needs_dynamic_casting, extra_args);
}
} // namespace native
namespace cuda {
at::Tensor CompileAndLaunchKernel(
const std::string& code_string,
const std::string& kernel_name,
const std::vector<at::Tensor>& tensors,
const std::vector<at::Scalar>& extra_args) {
Tensor output;
TensorIteratorConfig config;
config
.set_check_mem_overlap(true)
.allow_cpu_scalars(false)
.promote_inputs_to_common_dtype(true)
.cast_common_dtype_to_outputs(true)
.enforce_safe_casting_to_output(true)
.check_all_same_device(true)
.add_owned_output(output);
for (const auto& t: tensors){
config.add_input(t);
}
TensorIterator iter = config.build();
CUDAGuard guard(iter.device());
at::native::jitted_gpu_kernel_dynamic(kernel_name, iter, code_string, extra_args);
return iter.output();
}
}} // namespace at::cuda
#endif // AT_USE_JITERATOR()
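
For orientation, the four dispatch cases above are driven entirely by the contiguity and dtypes of the tensors the Python caller passes in. A hedged sketch (input shapes chosen to mirror test/test_jiterator.py below; the kernel itself is illustrative) of hitting the vectorized path versus the unrolled, dynamically casting path:

```
import torch
from torch.cuda.jiterator import _create_jit_fn as create_jit_fn

fn = create_jit_fn("template <typename T> T my_add(T x, T y) { return x + y; }")

# Case 1: contiguous inputs with matching dtypes -> vectorized launch, no dynamic casting
a = torch.rand(3, 3, device='cuda')
b = torch.rand(3, 3, device='cuda')
out_contig = fn(a, b)

# Case 4: a noncontiguous input plus mismatched dtypes -> unrolled launch with dynamic casting
c = torch.rand(3, 3, device='cuda').t()                   # transposed view, noncontiguous
d = torch.rand(3, 3, device='cuda', dtype=torch.float16)  # differs from the common dtype
out_noncontig = fn(c, d)
```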


@ -0,0 +1,35 @@
#pragma once
#include <ATen/jit_macros.h>
#if AT_USE_JITERATOR()
#include <c10/macros/Export.h>
#include <ATen/core/Tensor.h>
#include <string>
#include <vector>
namespace at {
namespace cuda {
TORCH_CUDA_CPP_API at::Tensor CompileAndLaunchKernel(
const std::string& code_string,
const std::string& kernel_name,
const std::vector<at::Tensor>& tensors,
const std::vector<at::Scalar>& extra_args);
}} // namespace at::cuda
#else
namespace at { namespace cuda {
TORCH_CUDA_CPP_API at::Tensor CompileAndLaunchKernel(
const std::string& code_string,
const std::string& kernel_name,
const std::vector<at::Tensor>& tensors,
const std::vector<at::Scalar>& extra_args) {
TORCH_CHECK(false, "Jiterator is not supported on ROCm");
}
}} // namespace at::cuda
#endif // AT_USE_JITERATOR()


@ -0,0 +1,208 @@
#pragma once
#include <ATen/jit_macros.h>
#if AT_USE_JITERATOR()
#include <c10/util/variant.h>
#include <ATen/native/TensorIterator.h>
#include <ATen/cuda/detail/OffsetCalculator.cuh>
#include <ATen/native/cuda/jit_utils.h>
#include <ATen/native/cuda/MemoryAccess.cuh>
#include <ATen/native/cuda/JitLoops.cuh>
#include <string>
#include <vector>
namespace at {
namespace native {
constexpr int NUM_INPUTS = 8;
#define AT_FOR_8_INPUTS(_) \
_(1) \
_(2) \
_(3) \
_(4) \
_(5) \
_(6) \
_(7) \
_(8)
c10::SmallVector<std::string> get_extra_args_typenames(const std::vector<at::Scalar>& extra_args) {
c10::SmallVector<std::string> args_typenames(extra_args.size());
for (auto i = 0; i < extra_args.size(); ++i) {
args_typenames[i] = at::cuda::jit::typeName(extra_args[i].type());
}
return args_typenames;
}
int can_vectorize_up_to(at::ScalarType type, char* pointer) {
switch(type) {
#define DEFINE_CASE(ctype, scalartype) \
case ScalarType::scalartype : return memory::can_vectorize_up_to<ctype>(pointer);
AT_FORALL_SCALAR_TYPES_WITH_COMPLEX(DEFINE_CASE)
#undef DEFINE_CASE
default: TORCH_INTERNAL_ASSERT(false, "Unrecognized ScalarType: ", type);
}
}
// jitted version of the above
// See Note [Jiterator], this relies on the assumptions enumerated there
int jitted_can_vectorize_up_to(const TensorIteratorBase& iter) {
const at::ScalarType common_dtype = iter.common_dtype();
const at::ScalarType result_dtype = common_dtype;
// Deals with output
int result = can_vectorize_up_to(result_dtype, static_cast<char*>(iter.data_ptr(0)));
// Incorporates input(s)
for (auto i = 1; i < iter.ntensors(); ++i) {
result = std::min<int>(result, can_vectorize_up_to(common_dtype, static_cast<char*>(iter.data_ptr(i))));
}
return result;
}
template<int N>
static std::unique_ptr<OffsetCalculator<N>> make_unique_input_offset_calculator(const TensorIteratorBase& iter) {
// array size can not be 0, this happens when N == 0
constexpr int array_size = std::max<int>(N, 1);
TORCH_INTERNAL_ASSERT(N == iter.ntensors() - iter.noutputs());
std::array<const int64_t*, array_size> strides;
int64_t element_sizes[array_size];
for (int i = 0; i < N; i++) {
strides[i] = iter.strides(i + iter.noutputs()).data();
element_sizes[i] = iter.element_size(i + iter.noutputs());
}
return std::make_unique<OffsetCalculator<N>>(iter.ndim(), iter.shape().data(), strides.data(), element_sizes);
}
struct OffsetCalculatorVariant {
#define DEFINE_CASE(index) std::unique_ptr<OffsetCalculator<index>>,
using OffsetCalculatorTypes = c10::variant<
AT_FOR_8_INPUTS(DEFINE_CASE)
>;
#undef DEFINE_CASE
OffsetCalculatorVariant(const TensorIteratorBase& iter) {
int arity = iter.ninputs();
switch(arity) {
#define DEFINE_CASE(index) \
case index : v = make_unique_input_offset_calculator<index>(iter); break;
AT_FOR_8_INPUTS(DEFINE_CASE)
#undef DEFINE_CASE
default:
TORCH_CHECK(false, "OffsetCalculatorVariant is not implemented for ninputs = ", arity);
}
}
void* data_ptr() {
return c10::visit([](auto & v){ return static_cast<void*>(v.get()); }, v);
}
private:
OffsetCalculatorTypes v;
};
struct ArrayVariant {
// notice: This would produce c10::variant<at::detail::Array<char*, 2...9>>
#define DEFINE_CASE(index) at::detail::Array<char*, index + 1>,
using ArrayTypes = c10::variant<
AT_FOR_8_INPUTS(DEFINE_CASE)
>;
#undef DEFINE_CASE
ArrayVariant(const TensorIteratorBase& iter) {
int arity = iter.ninputs();
// This assumes that jiterator kernels only have 1 output
switch(arity) {
#define DEFINE_CASE(index) \
case index: array = at::detail::Array<char*, index + 1>{}; break;
AT_FOR_8_INPUTS(DEFINE_CASE)
#undef DEFINE_CASE
default:
TORCH_CHECK(false, "ArrayVariant is not implemented for ninputs = ", arity);
}
c10::visit([&](auto& a) {
for (auto i = 0; i < arity + 1; ++i) {
a[i] = (char*)iter.data_ptr(i);
}
}, array);
}
void* data_ptr() {
return c10::visit([](auto & a){ return static_cast<void*>(&a); }, array);
}
private:
ArrayTypes array;
};
struct TrivialOffsetCalculatorVariant {
#define DEFINE_CASE(index) TrivialOffsetCalculator<index>,
using TrivialOffsetCalculatorTypes = c10::variant<
AT_FOR_8_INPUTS(DEFINE_CASE)
>;
#undef DEFINE_CASE
TrivialOffsetCalculatorVariant(const TensorIteratorBase& iter) {
int arity = iter.ninputs();
switch(arity) {
#define DEFINE_CASE(index) \
case index: v = TrivialOffsetCalculator<index>(); break;
AT_FOR_8_INPUTS(DEFINE_CASE)
#undef DEFINE_CASE
default:
TORCH_CHECK(false, "TrivialOffsetCalculatorVariant is not implemented for ninputs = ", arity);
}
}
void* data_ptr() {
return c10::visit([](auto & v){ return static_cast<void*>(&v); }, v);
}
private:
TrivialOffsetCalculatorTypes v;
};
struct LoadWithCastVariant {
#define DEFINE_CASE(index) std::unique_ptr<memory::LoadWithCast<index>>,
using LoadWithCastPtr = c10::variant<
AT_FOR_8_INPUTS(DEFINE_CASE)
>;
#undef DEFINE_CASE
LoadWithCastVariant(const TensorIteratorBase& iter) {
int arity = iter.ninputs();
switch(arity) {
#define DEFINE_CASE(index) \
case index: v = std::make_unique<memory::LoadWithCast<index>>(iter); break;
AT_FOR_8_INPUTS(DEFINE_CASE)
#undef DEFINE_CASE
default:
TORCH_CHECK(false, "LoadWithCastVariant is not implemented for ninputs = ", arity);
}
}
void* data_ptr() {
return c10::visit([](auto & v){ return static_cast<void*>(v.get()); }, v);
}
private:
LoadWithCastPtr v;
};
}} // namespace at::native
#endif // AT_USE_JITERATOR()


@ -116,6 +116,15 @@ struct LoadWithCast {
}
}
LoadWithCast(const TensorIteratorBase& iter) {
assert(iter.ninputs() == N);
#pragma unroll
for (auto i = 0; i < N; ++i) {
this->dtypes[i] = iter.dtype(i + 1);
element_sizes[i] = c10::elementSize(iter.dtype(i + 1));
}
}
template<typename scalar_t>
__device__ scalar_t load(char *base_ptr, uint32_t offset, int arg) {
void *ptr = base_ptr + element_sizes[arg] * offset;


@ -7,7 +7,6 @@
#include <ATen/jit_macros.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/detail/OffsetCalculator.cuh>
#include <ATen/cuda/nvrtc_stub/ATenNVRTC.h>
#include <ATen/code_template.h>
#include <ATen/native/cuda/jit_utils.h>
#include <ATen/cuda/llvm_jit_strings.h>


@ -8,6 +8,7 @@
#include <c10/util/irange.h>
#include <ATen/jit_macros.h>
#include <ATen/cuda/detail/LazyNVRTC.h>
#include <ATen/cuda/nvrtc_stub/ATenNVRTC.h>
namespace at { namespace cuda { namespace jit {
@ -85,6 +86,9 @@ AT_FORALL_SCALAR_TYPES(TYPE_NAME_FN)
// JIT uses std::complex directly, because nvRTC compile programs
// with -default-device, so there is no such issue like:
// "std::sin(complex) is __host__ only"
template <> inline std::string typeName<bool>(){
return "bool";
}
template <> inline std::string typeName<c10::complex<at::Half>>(){
return "std::complex<at::Half>";
}
@ -101,4 +105,20 @@ template <> inline std::string typeName<at::BFloat16>(){
return "at::BFloat16";
}
#define TYPE_NAME_CASE(ctype, scalartype) \
case ScalarType::scalartype: return std::string(#ctype);
inline std::string typeName(ScalarType t) {
switch (t) {
AT_FORALL_SCALAR_TYPES(TYPE_NAME_CASE)
case ScalarType::Bool : return "bool";
case ScalarType::Half : return "at::Half";
case ScalarType::BFloat16 : return "at::BFloat16";
case ScalarType::ComplexFloat : return "std::complex<float>";
case ScalarType::ComplexDouble : return "std::complex<double>";
default:
TORCH_CHECK(false, "invalid type for jiterator");
}
}
#undef TYPE_NAME_CASE
}}} // namespace at::cuda::jit


@ -74,6 +74,10 @@ class C10_API Scalar {
template <typename T>
T to() const = delete;
const void* data_ptr() const {
return static_cast<const void*>(&v);
}
#undef DEFINE_ACCESSOR
bool isFloatingPoint() const {
return Tag::HAS_d == tag;


@ -280,7 +280,21 @@ namespace std {
#define C10_MPARK_BUILTIN_UNREACHABLE
#endif
#if __has_builtin(__type_pack_element)
// NOTE [nvcc bug workaround]
//
// The original line `typename Front = lib::type_pack_element_t<0, Ts...>,`
// throws the following compiler error on nvcc:
// ```
// c10/util/variant.h(2367): error: parameter pack "Ts" was referenced but not
// expanded
// ```
// As a workaround, we skip defining C10_MPARK_TYPE_PACK_ELEMENT for nvcc
// compiler
//
// See the following issues for more context:
// https://github.com/pytorch/extension-cpp/issues/58
// https://github.com/mpark/variant/issues/77
#if __has_builtin(__type_pack_element) && !defined(__CUDACC__)
#define C10_MPARK_TYPE_PACK_ELEMENT
#endif


@ -123,3 +123,11 @@ NVIDIA Tools Extension (NVTX)
nvtx.mark
nvtx.range_push
nvtx.range_pop
Jiterator (beta)
-----------------------------
.. autosummary::
:toctree: generated
:nosignatures:
jiterator._create_jit_fn

test/test_jiterator.py (new file, 132 lines)

@ -0,0 +1,132 @@
# Owner(s): ["module: cuda"]
import torch
from torch.cuda.jiterator import _create_jit_fn as create_jit_fn
import sys
from itertools import product
from torch.testing._internal.common_utils import TestCase, parametrize, run_tests, TEST_CUDA
from torch.testing._internal.common_dtype import all_types_and_complex_and
from torch.testing._internal.common_device_type import (
skipCUDAIfRocm, skipCUDAIf, instantiate_device_type_tests, dtypes, toleranceOverride, tol)
from torch.testing._internal.common_cuda import _get_torch_cuda_version
if not TEST_CUDA:
print('CUDA not available, skipping tests', file=sys.stderr)
TestCase = object # noqa: F811
code_string = "template <typename T> T my_fused_kernel(T x, T y, T alpha, T beta) { return alpha * x + beta * y; }"
jitted_fn = create_jit_fn(code_string, alpha=1, beta=1)
def ref_fn(x, y, alpha=1, beta=1):
return alpha * x + beta * y
class TestPythonJiterator(TestCase):
@skipCUDAIfRocm
@parametrize("shape_strides", [
(([3, 3], [3, 1]), ([3, 3], [3, 1])), # contiguous
])
@dtypes(*product(all_types_and_complex_and(torch.half, torch.bfloat16),
all_types_and_complex_and(torch.half, torch.bfloat16)))
def test_all_dtype_contiguous(self, device, dtypes, shape_strides):
a_buffer = torch.rand(9, device=device).mul(10).type(dtypes[0])
b_buffer = torch.rand(9, device=device).mul(10).type(dtypes[1])
a = a_buffer.as_strided(*shape_strides[0])
b = b_buffer.as_strided(*shape_strides[1])
expected = ref_fn(a, b)
result = jitted_fn(a, b)
self.assertEqual(expected, result)
@skipCUDAIfRocm
# See https://github.com/pytorch/pytorch/pull/76394#issuecomment-1118018287 for details
@skipCUDAIf(_get_torch_cuda_version() < (11, 6), "On cuda 11.3, nvrtcCompileProgram is taking too long to "
"compile jiterator generated kernels for non-contiguous input that requires dynamic-casting.")
@parametrize("shape_strides", [
(([3, 3], [1, 3]), ([3, 1], [1, 3])), # non-contiguous
])
@dtypes(*product(all_types_and_complex_and(torch.half, torch.bfloat16),
all_types_and_complex_and(torch.half, torch.bfloat16)))
def test_all_dtype_noncontiguous(self, device, dtypes, shape_strides):
a_buffer = torch.rand(9, device=device).mul(10).type(dtypes[0])
b_buffer = torch.rand(9, device=device).mul(10).type(dtypes[1])
a = a_buffer.as_strided(*shape_strides[0])
b = b_buffer.as_strided(*shape_strides[1])
expected = ref_fn(a, b)
result = jitted_fn(a, b)
self.assertEqual(expected, result)
@skipCUDAIfRocm
@dtypes(torch.float, torch.double, torch.float16, torch.bfloat16)
@parametrize("alpha", [-1, 2.0, None])
@parametrize("beta", [3, -4.2, None])
@toleranceOverride({torch.float16 : tol(atol=1e-2, rtol=1e-3)})
def test_extra_args(self, device, dtype, alpha, beta):
a = torch.rand(3, device=device).mul(10).type(dtype)
b = torch.rand(3, device=device).mul(10).type(dtype)
extra_args = {}
if alpha is not None:
extra_args["alpha"] = alpha
if beta is not None:
extra_args["beta"] = beta
expected = ref_fn(a, b, **extra_args)
result = jitted_fn(a, b, **extra_args)
self.assertEqual(expected, result)
@skipCUDAIfRocm
def test_bool_extra_args(self, device):
code_string = "template <typename T> T conditional(T x, T mask, bool is_train) { return is_train ? x * mask : x; }"
jitted_fn = create_jit_fn(code_string, is_train=False)
def ref_fn(x, mask, is_train):
return x * mask if is_train else x
a = torch.rand(3, device=device)
b = torch.rand(3, device=device)
expected = ref_fn(a, b, is_train=True)
result = jitted_fn(a, b, is_train=True)
self.assertEqual(expected, result)
@skipCUDAIfRocm
@parametrize("num_inputs", list(range(1, 9)))
def test_various_num_inputs(self, num_inputs):
inputs = []
for i in range(num_inputs):
inputs.append(torch.rand(3, device='cuda').mul(10))
input_string = ",".join([f"T i{i}" for i in range(num_inputs)])
function_body = "+".join([f"i{i}" for i in range(num_inputs)])
code_string = f"template <typename T> T my_kernel({input_string}) {{ return {function_body}; }}"
jitted_fn = create_jit_fn(code_string)
def ref_fn(*inputs):
return torch.sum(torch.stack(inputs), dim=0)
expected = ref_fn(*inputs)
result = jitted_fn(*inputs)
self.assertEqual(expected, result)
@skipCUDAIfRocm
@parametrize("code_string", [
"template <typename T> T my _kernel(T x) { return x; }",
"template <typename T> Tmy_kernel(T x) { return x; }",
])
def test_invalid_function_name(self, code_string):
with self.assertRaises(Exception):
jitted_fn = create_jit_fn(code_string)
instantiate_device_type_tests(TestPythonJiterator, globals(), only_for="cuda")
if __name__ == '__main__':
run_tests()


@ -67,7 +67,7 @@ def check_file(filename: str) -> Optional[LintMessage]:
name="testestTrailing newline",
original=None,
replacement=None,
description="Trailing newline found. Run `lintunner --take NEWLINE -a` to apply changes.",
description="Trailing newline found. Run `lintrunner --take NEWLINE -a` to apply changes.",
)
else:
@ -103,7 +103,7 @@ def check_file(filename: str) -> Optional[LintMessage]:
name="Trailing newline",
original=original,
replacement=original.rstrip("\n") + "\n",
description="Trailing newline found. Run `lintunner --take NEWLINE -a` to apply changes.",
description="Trailing newline found. Run `lintrunner --take NEWLINE -a` to apply changes.",
)


@ -924,6 +924,10 @@ def _cuda_memorySnapshot() -> List[Dict[str, Any]]: ...
def _cuda_lock_mutex() -> None: ...
def _cuda_unlock_mutex() -> None: ...
def _cuda_canDeviceAccessPeer(device: _int, peer_device: _int) -> _bool: ...
def _cuda_jiterator_compile_and_launch_kernel(code_string: str,
kernel_name: str,
tensors: Tuple,
kwargs: Dict[str, Union[_int, _float, _bool]]) -> Tensor: ...
def _nccl_version() -> _int: ...
def _nccl_unique_id() -> bytes: ...
def _nccl_init_rank(nranks: _int, comm_id: bytes, rank: _int) -> object: ...


@ -6,6 +6,7 @@
#include <ATen/cuda/CachingHostAllocator.h>
#include <ATen/cuda/Sleep.h>
#include <ATen/cuda/detail/CUDAHooks.h>
#include <ATen/cuda/jiterator.h>
#ifdef USE_NCCL
#include <torch/csrc/cuda/python_nccl.h>
#endif
@ -217,6 +218,71 @@ PyObject * THCPModule_cudaCachingAllocator_raw_alloc(PyObject *_unused, PyObject
END_HANDLE_TH_ERRORS
}
// Unpack a PyObject to at::Scalar, throw an exception if it fails
at::Scalar as_scalar(PyObject* arg) {
// Zero-dim tensors are converted to Scalars as-is. Note this doesn't currently
// handle most NumPy scalar types except np.float64.
if (THPVariable_Check(arg)) {
return THPVariable_Unpack(arg).item();
}
if (THPUtils_checkLong(arg)) {
return at::Scalar(static_cast<int64_t>(THPUtils_unpackLong(arg)));
}
if (PyBool_Check(arg)) {
return at::Scalar(THPUtils_unpackBool(arg));
}
if (PyComplex_Check(arg)) {
return at::Scalar(THPUtils_unpackComplexDouble(arg));
}
return at::Scalar(THPUtils_unpackDouble(arg));
}
// Entrypoint for the callable created by torch.cuda.jiterator
// See jiterator.py for more details
PyObject * THCPModule_cudaJiteratorCompileAndLaunchKernel(PyObject *_unused, PyObject *args){
HANDLE_TH_ERRORS
PyObject* code_string_o = nullptr;
PyObject* kernel_name_o = nullptr;
PyObject* tensors_o = nullptr;
PyObject* kwargs_o = nullptr;
if(!PyArg_ParseTuple(args, "OOO|O", &code_string_o, &kernel_name_o, &tensors_o, &kwargs_o)) {
return nullptr;
}
std::string code_string = THPUtils_unpackString(code_string_o);
std::string kernel_name = THPUtils_unpackString(kernel_name_o);
THPUtils_assert(PyTuple_Check(tensors_o), "tensors argument is expected to "
"be a tuple, but got %s", THPUtils_typename(tensors_o));
Py_ssize_t num_tensors = PyTuple_GET_SIZE(tensors_o);
std::vector<at::Tensor> tensors;
for(const auto i : c10::irange(num_tensors)) {
PyObject *_tensor = PyTuple_GET_ITEM(tensors_o, i);
THPUtils_assert(THPVariable_Check(_tensor), "element %d of tensors "
"tuple is not a Tensor", i);
tensors.emplace_back(THPVariable_Unpack(_tensor));
}
std::vector<at::Scalar> extra_args;
PyObject *key = nullptr;
PyObject *value = nullptr;
Py_ssize_t pos = 0;
while (PyDict_Next(kwargs_o, &pos, &key, &value)) {
extra_args.emplace_back(as_scalar(value));
}
at::Tensor output = at::cuda::CompileAndLaunchKernel(code_string, kernel_name, tensors, extra_args);
return THPVariable_Wrap(output);
END_HANDLE_TH_ERRORS
}
PyObject * THCPModule_cudaCachingAllocator_raw_delete(PyObject *_unused, PyObject *obj){
HANDLE_TH_ERRORS
void* mem_ptr = PyLong_AsVoidPtr(obj);
@ -597,6 +663,7 @@ static struct PyMethodDef _THCPModule_methods[] = {
{"_cuda_unlock_mutex", THCPModule_cudaUnlockMutex, METH_NOARGS, nullptr},
{"_cuda_set_sync_debug_mode", THCPModule_cudaSetSyncDebugMode, METH_O, nullptr},
{"_cuda_get_sync_debug_mode", THCPModule_cudaGetSyncDebugMode, METH_NOARGS, nullptr},
{"_cuda_jiterator_compile_and_launch_kernel", THCPModule_cudaJiteratorCompileAndLaunchKernel, METH_VARARGS, nullptr},
#ifdef USE_NCCL
{"_nccl_version", THCPModule_nccl_version, METH_NOARGS, nullptr},
{"_nccl_unique_id", THCPModule_nccl_unique_id, METH_NOARGS, nullptr},


@ -767,3 +767,4 @@ from . import sparse
from . import profiler
from . import nvtx
from . import amp
from . import jiterator

torch/cuda/jiterator.py (new file, 104 lines)

@ -0,0 +1,104 @@
import torch
from torch import Tensor
from typing import Callable, List
import re
__all__ : List[str] = []
class _CodeParser:
def __init__(self, code_string: str):
optional_ws = r"\s*"
required_ws = r"\s+"
template_params = r"(?P<template_params>\<.+\>)"
return_type = r"(?P<return_type>\w+)"
function_name = r"(?P<function_name>\w+)"
function_params = r"(?P<function_params>\(.+\))"
function_body = r"(?P<function_body>\{.+\})"
pattern = \
optional_ws \
+ "template" \
+ optional_ws + template_params \
+ optional_ws + return_type \
+ required_ws + function_name \
+ optional_ws + function_params \
+ optional_ws + function_body \
+ optional_ws
result = re.match(pattern, code_string, re.DOTALL) # DOTALL for matching multiline
if result is None:
raise Exception(f"Couldn't parse code, please check correctness:\n {code_string}")
self.template_params = result["template_params"]
self.return_type = result["return_type"]
self.function_name = result["function_name"]
self.function_params = result["function_params"]
self.function_body = result["function_body"]
def _create_jit_fn(code_string: str, **kwargs) -> Callable:
"""
Create a jiterator-generated CUDA kernel for an elementwise op.
The code string has to be a valid CUDA function that describes the computation for a single element. The code
string has to follow the C++ template pattern, as shown in the example below. This function will be inlined
into an elementwise kernel template and compiled on the fly. The compiled kernel will be cached in memory, as
well as in a local temp dir.
Jiterator-generated kernels accept noncontiguous tensors, and support broadcasting and type promotion.
Args:
code_string (string): CUDA code string to be compiled by jiterator.
kwargs (Dict, optional): Keyword arguments for generated function
Examples:
>>> code_string = "template <typename T> T my_kernel(T x, T y, T alpha) { return -x + alpha * y; }"
>>> jitted_fn = create_jit_fn(code_string, alpha=1.0)
>>> a = torch.rand(3, device='cuda')
>>> b = torch.rand(3, device='cuda')
>>> # invoke jitted function like a regular python function
>>> result = jitted_fn(a, b, alpha=3.14)
.. warning::
This API is in beta and may change in future releases.
.. warning::
Jiterator only supports up to 8 tensor inputs
.. warning::
All input tensors must be on a CUDA device
"""
class JittedFunction:
def __init__(self, code_string: str, **kwargs):
self.code_string = code_string
parsed_code = _CodeParser(code_string)
self.kernel_name = parsed_code.function_name
self.kwargs_dict = kwargs
self.is_cuda_available = torch.cuda.is_available()
def __call__(self, *tensors: Tensor, **kwargs):
# Jiterator follows torch.cuda's lazy initialization behavior
# Defer checking CUDA availability until the function is invoked
assert self.is_cuda_available, "Jiterator is only supported on CUDA GPUs, no CUDA GPUs are available."
assert len(tensors) <= 8, "jiterator only supports up to 8 tensor inputs."
expanded_kwargs = self.kwargs_dict.copy()
for key, value in kwargs.items():
if key in self.kwargs_dict:
expanded_kwargs[key] = value
else:
raise KeyError(f"{key} is not declared in function definition")
return torch._C._cuda_jiterator_compile_and_launch_kernel(
self.code_string,
self.kernel_name,
tensors,
expanded_kwargs)
return JittedFunction(code_string, **kwargs)
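
As an aside on the `_CodeParser` above: it splits the template declaration into named regex groups that the rest of the module consumes (most importantly `function_name`, which becomes the kernel name). A quick illustration of what it extracts, assuming the private helper is imported directly:

```
from torch.cuda.jiterator import _CodeParser  # private helper, not part of the public API

parsed = _CodeParser(
    "template <typename T> T my_kernel(T x, T y, T alpha) { return x + alpha * y; }")
print(parsed.template_params)  # '<typename T>'
print(parsed.return_type)      # 'T'
print(parsed.function_name)    # 'my_kernel'
print(parsed.function_params)  # '(T x, T y, T alpha)'
print(parsed.function_body)    # '{ return x + alpha * y; }'
```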


@ -2299,6 +2299,34 @@ def sample_inputs_elementwise_binary(op, device, dtype, requires_grad, **kwargs)
)
def sample_inputs_jiterator(op, device, dtype, requires_grad, **kwargs):
make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
shapes = (
((), ()),
((S,), ()),
((S, 1), (S,)),
((M, S), ()),
((S, M, S), (M, S)),
((S, M, S), (S, M, S)),
((M, 1, S), (M, S)),
((M, 1, S), (1, M, S)),
((0, 1, 3), (0, 10, 3))
)
num_inputs = kwargs.get('num_inputs')
sample_kwargs = kwargs.get('sample_kwargs', {})
for shape_lhs, shape_rhs in shapes:
lhs = make_arg(shape_lhs)
args = []
for i in range(num_inputs - 1):
args.append(make_arg(shape_rhs))
broadcasts_input = (shape_lhs != torch.broadcast_shapes(shape_lhs, shape_rhs))
yield SampleInput(lhs, args=tuple(args), kwargs=sample_kwargs, broadcasts_input=broadcasts_input)
# The base reference input generation for elementwise binary operations
def _reference_inputs_elementwise_binary(op, device, dtype, requires_grad, **kwargs):
yield from op.sample_inputs_func(op, device, dtype, requires_grad, **kwargs)
@ -16371,6 +16399,95 @@ op_db: List[OpInfo] = [
# Can't find schemas for this operator for some reason
DecorateInfo(unittest.expectedFailure, 'TestOperatorSignatures', 'test_get_torch_func_signature_exhaustive'),
)),
# The following tests are for the jiterator's Python interface
# Jiterator can be used to author elementwise CUDA kernels
# jiterator._create_jit_fn returns a callable that behaves like a regular PyTorch op
# See _create_jit_fn in jiterator.py for more information
UnaryUfuncInfo(
'jiterator_unary',
op=torch.cuda.jiterator._create_jit_fn("template <typename T> T unary(T x) { return x * x + x; }"),
ref=lambda x: x * x + x,
dtypes=all_types_and_complex_and(torch.bfloat16, torch.float16, torch.bool),
supports_out=False,
supports_autograd=False, # jiterator ops don't have backward defined
decorators=[
onlyCUDA,
skipCUDAIfRocm,
DecorateInfo(toleranceOverride({torch.float16: tol(atol=1e-02, rtol=1e-02)}),
'TestUnaryUfuncs', 'test_reference_numerics_extremal'),
DecorateInfo(toleranceOverride({torch.float16: tol(atol=1e-02, rtol=1e-02)}),
'TestUnaryUfuncs', 'test_reference_numerics_hard'),
DecorateInfo(toleranceOverride({torch.float16: tol(atol=1e-02, rtol=1e-02)}),
'TestUnaryUfuncs', 'test_reference_numerics_normal'),
DecorateInfo(toleranceOverride({torch.float16: tol(atol=1e-02, rtol=1e-02)}),
'TestUnaryUfuncs', 'test_reference_numerics_small'),
],
skips=(
# Jiterator ops don't support neg or conj views
DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_view'),
DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_conj_view'),
DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_conj_view'),
# Jiterator ops don't support CompositeCompliantTensor
# The following test should be an expectedFailure, but it causes cascading failures in CUDA, so it is skipped
DecorateInfo(unittest.skip("skip"), 'TestCompositeCompliance', 'test_operator'),
# Skip reference_numerics tests for bool type, as the defined function doesn't work for bool
DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_extremal',
dtypes=[torch.bool]),
DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_hard',
dtypes=[torch.bool]),
DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_normal',
dtypes=[torch.bool]),
# Expected failure: torch.jiterator_unary is not a valid op
DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'),
)
),
BinaryUfuncInfo(
'jiterator_binary',
op=torch.cuda.jiterator._create_jit_fn(
"template <typename T> T binary(T x, T y, T alpha) { return x + alpha * y; }", alpha=1),
ref=lambda input, other, *, alpha=1: np.add(input, other) if alpha == 1 \
else np.add(input, np.multiply(alpha, other)),
dtypes=all_types_and_complex_and(torch.bfloat16, torch.float16, torch.bool),
sample_inputs_func=partial(sample_inputs_jiterator, num_inputs=2, alpha=-3.14),
supports_out=False,
supports_autograd=False, # jiterator ops don't have backward defined
supports_rhs_python_scalar=False,
decorators=[onlyCUDA, skipCUDAIfRocm],
skips=(
# Jiterator ops don't support neg or conj views
DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_view'),
DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_conj_view'),
DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_conj_view'),
# Jiterator ops don't support CompositeCompliantTensor
# The following test should be an expectedFailure, but it causes cascading failures in CUDA, so it is skipped
DecorateInfo(unittest.skip("skip"), 'TestCompositeCompliance', 'test_operator'),
# Expected failure: torch.jiterator_binary is not a valid op
DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'),
)
),
OpInfo(
'jiterator_4inputs_with_extra_args',
op=torch.cuda.jiterator._create_jit_fn(
"template <typename T> T binary(T i0, T i1, T i2, T i3, T alpha, T beta) { return alpha * i0 + beta * i1 + i2 + i3; }",
alpha=1, beta=1),
ref=lambda i0, i1, i2, i3, *, alpha=1, beta=1: alpha * i0 + beta * i1 + i2 + i3,
dtypes=all_types_and_complex_and(torch.bfloat16, torch.float16, torch.bool),
sample_inputs_func=partial(sample_inputs_jiterator, num_inputs=4, alpha=3.14, beta=-4.20),
supports_out=False,
supports_autograd=False, # jiterator ops don't have backward defined
decorators=[onlyCUDA, skipCUDAIfRocm],
skips=(
# Jiterator ops don't support neg or conj views
DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_view'),
DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_conj_view'),
DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_conj_view'),
# Jiterator ops don't support CompositeCompliantTensor
# The following test should be an expectedFailure, but it causes cascading failures in CUDA, so it is skipped
DecorateInfo(unittest.skip("skip"), 'TestCompositeCompliance', 'test_operator'),
# Expected failure: torch.jiterator_4inputs_with_extra_args is not a valid op
DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'),
)
),
# `torch.norm` has multiple code paths depending on the value of `p`.
# These paths have different dtype support. Also JIT supports,
# most variants but not all of them. So we split the OpInfo entries,