Python Interface for Jiterator
This PR allows users to author CUDA kernels in Python.
```python
import torch
from torch.cuda.jiterator import create_jit_fn

code_string = "template <typename T> T my_kernel(T x, T y, T alpha) { return -x * y + x - y + alpha; }"
jitted_fn = create_jit_fn(code_string, alpha=0)

a = torch.rand(3, device='cuda')
b = torch.rand(3, device='cuda')
result = jitted_fn(a, b, alpha=1.0)
```
Limitations:
- Only elementwise kernels are supported
- 1 to 8 tensor inputs (kernels with no tensor inputs, e.g. factory methods, are not supported)
- Input tensors must live on a CUDA device
- CPU scalars are not supported
- kwargs must be pre-declared when calling create_jit_fn
- kwargs must be convertible to at::Scalar, i.e. one of float64, int64_t, or bool (complex is not supported for now); see the sketch below
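The kwargs contract deserves a concrete illustration. Below is a minimal sketch based on the Python-side checks this PR adds in torch/cuda/jiterator.py: every kwarg must be declared with a default when the function is created, can then be overridden per call, and an undeclared kwarg raises a KeyError. The kernel string and variable names are illustrative only.

```python
import torch
from torch.cuda.jiterator import create_jit_fn

# alpha must be pre-declared here (with a default); it can be overridden per call.
code_string = "template <typename T> T my_kernel(T x, T y, T alpha) { return x + alpha * y; }"
fn = create_jit_fn(code_string, alpha=1.0)

a = torch.rand(3, device='cuda')
b = torch.rand(3, device='cuda')

fn(a, b)             # uses the declared default alpha=1.0
fn(a, b, alpha=2.0)  # overrides the pre-declared kwarg

try:
    fn(a, b, beta=3.0)  # beta was never declared in create_jit_fn
except KeyError as e:
    print(e)  # "beta is not declared in function definition"
```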
TODOs:
- [x] consolidate union and c10::variant implementation
- [x] plug into existing op testing framework
- [ ] rename files, place files in the right folder
- [ ] place util functions in the right file
- [x] enforce assumptions in the python interface, e.g. <=8 inputs, kwargs types
- [x] Add user-facing documentation
Pull Request resolved: https://github.com/pytorch/pytorch/pull/76394
Approved by: https://github.com/mruberry
This commit is contained in:
parent: 6c615a21a0
commit: 8b6a78f39f
aten/src/ATen/cuda/jiterator.cu (new file, 345 lines)

@@ -0,0 +1,345 @@
```cpp
#include <ATen/jit_macros.h>

#if AT_USE_JITERATOR()

#include <c10/cuda/CUDAGuard.h>
#include <ATen/cuda/jiterator.h>
#include <ATen/cuda/jiterator_impl.h>

#include <iostream>
#include <utility>
#include <chrono>

namespace at {
namespace native {

static inline void launch_jitted_vectorized_kernel_dynamic(
    const std::string& name, TensorIteratorBase& iter,
    DeviceIndex dev_idx, int64_t N, const std::string& f, void* data_ptr,
    const std::vector<at::Scalar>& extra_args) {
  TORCH_INTERNAL_ASSERT(N > 0 && N <= std::numeric_limits<int32_t>::max());
  // N is still int64_t for the computation, but it's always safe to cast result to int
  const uint32_t grid = (N + block_work_size() - 1) / block_work_size();

  const int vec_size = jitted_can_vectorize_up_to(iter);
  bool vectorized = vec_size > 1;

  // Different kernels are compiled depending on what we're vectorizing up to (1, 2 or 4 elements)
  // fn_ptr is set to the appropriate function based on the vec size and GPU used
  // TODO: Memory use can probably be optimized by re-using kernels across GPUs with
  // the same compute capability

  int nTensors = iter.ntensors();
  const at::ScalarType common_dtype = iter.common_dtype();
  std::string f_inputs_type_str = at::cuda::jit::typeName(common_dtype);
  std::string compute_type_str = at::cuda::jit::typeName(toOpMathType(common_dtype));
  std::string result_type_str = at::cuda::jit::typeName(common_dtype);
  c10::SmallVector<std::string> extra_args_types = get_extra_args_typenames(extra_args);

  // The cache key includes all the parameters to generate_code + vec_size + dev_idx
  std::stringstream ss;
  ss << nTensors << f;
  ss << f_inputs_type_str << compute_type_str << result_type_str;
  ss << static_cast<int>(at::cuda::jit::BinaryFuncVariant::NoScalar);
  ss << extra_args_types;
  ss << vec_size;
  // DeviceIndex, e.g. int8_t, is not treated as a number by the stream, cast to int as a workaround
  ss << static_cast<int>(dev_idx);
  const std::string cache_key = ss.str();

  static std::mutex _jiterator_mutex;
  static std::unordered_map<std::string, at::cuda::jit::NvrtcFunction> fns;
  at::cuda::jit::NvrtcFunction* fn_ptr = &fns[cache_key];

  if (!fn_ptr->function) {
    const std::lock_guard<std::mutex> lock{_jiterator_mutex};
    if (!fn_ptr->function) { // cache miss!
      // Generates program
      auto code = at::cuda::jit::generate_code(nTensors, f, name,
                                               f_inputs_type_str, compute_type_str, result_type_str,
                                               /*contiguous=*/true, /*dynamic_casting=*/false,
                                               at::cuda::jit::BinaryFuncVariant::NoScalar,
                                               extra_args_types,
                                               vectorized, vec_size);
      std::string kernel_name = vectorized ? name + "_vectorized" + std::to_string(vec_size) : name;
      // Acquires the program
      *fn_ptr = at::cuda::jit::jit_pwise_function(code, kernel_name);
    }
  }

  // size of `extra_args` is unknown at compile-time
  auto extra_args_size = extra_args.size();

  float scalar_val = 0;

  if (vectorized) {
    // pack args for kernel launch
    constexpr int kernel_args = 3;
    auto args = std::make_unique<void*[]>(kernel_args + extra_args_size);
    args[0] = static_cast<void*>(&N);
    args[1] = data_ptr;
    args[2] = static_cast<void*>(&scalar_val);

    for (const auto i : c10::irange(extra_args_size)) {
      // since 3 slots are already filled in `args`
      args[i + 3] = const_cast<void*>(extra_args[i].data_ptr());
    }
    at::cuda::jit::launch_jitted_pwise_function(*fn_ptr, args.get(), {grid, 1u, 1u}, {num_threads(), 1u, 1u});
  } else {
    TrivialOffsetCalculatorVariant input_offset_calculator(iter);
    void* ic_ptr = input_offset_calculator.data_ptr();
    auto oc = TrivialOffsetCalculator<1>();
    auto l = memory::LoadWithoutCast();
    auto s = memory::StoreWithoutCast();

    // pack args for kernel launch
    constexpr int kernel_args = 7;
    auto args = std::make_unique<void*[]>(kernel_args + extra_args_size);
    args[0] = static_cast<void*>(&N);
    args[1] = data_ptr;
    args[2] = ic_ptr;
    args[3] = static_cast<void*>(&oc);
    args[4] = static_cast<void*>(&l);
    args[5] = static_cast<void*>(&s);
    args[6] = static_cast<void*>(&scalar_val);

    for (const auto i : c10::irange(extra_args_size)) {
      // since 7 slots are already filled in `args`
      args[i + 7] = const_cast<void*>(extra_args[i].data_ptr());
    }

    at::cuda::jit::launch_jitted_pwise_function(*fn_ptr, args.get(), {grid, 1u, 1u}, {num_threads(), 1u, 1u});
  }
}

static inline void launch_jitted_unrolled_kernel_dynamic(
    const std::string& name, TensorIteratorBase& iter,
    DeviceIndex dev_idx, int64_t N, const std::string& f, void* data_ptr,
    void* ic_ptr, void* oc_ptr, void* l_ptr, void* s_ptr, bool contiguous, bool dynamic_casting,
    const std::vector<at::Scalar>& extra_args) {

  TORCH_INTERNAL_ASSERT(N > 0 && N <= std::numeric_limits<int32_t>::max());
  // casting result to int is always safe, intermediate is int64 and won't overflow
  const uint32_t grid = (N + block_work_size() - 1) / block_work_size();

  int nTensors = iter.ntensors();
  const at::ScalarType common_dtype = iter.common_dtype();
  std::string f_inputs_type_str = at::cuda::jit::typeName(common_dtype);
  std::string compute_type_str = at::cuda::jit::typeName(toOpMathType(common_dtype));
  std::string result_type_str = at::cuda::jit::typeName(common_dtype);
  c10::SmallVector<std::string> extra_args_types = get_extra_args_typenames(extra_args);

  // The cache key includes all the parameters to generate_code + dev_idx
  std::stringstream ss;
  ss << nTensors << f;
  ss << f_inputs_type_str << compute_type_str << result_type_str;
  ss << contiguous << dynamic_casting;
  ss << static_cast<int>(at::cuda::jit::BinaryFuncVariant::NoScalar);
  ss << extra_args_types;
  ss << dev_idx;
  const std::string cache_key = ss.str();

  static std::mutex _jiterator_mutex;
  static std::unordered_map<std::string, at::cuda::jit::NvrtcFunction> fns;

  at::cuda::jit::NvrtcFunction* fn_ptr = &fns[cache_key];
  if (!fn_ptr->function) {
    const std::lock_guard<std::mutex> lock{_jiterator_mutex};
    if (!fn_ptr->function) {
      auto code = at::cuda::jit::generate_code(nTensors, f, name,
                                               f_inputs_type_str, compute_type_str, result_type_str,
                                               contiguous, dynamic_casting,
                                               at::cuda::jit::BinaryFuncVariant::NoScalar,
                                               extra_args_types);
      *fn_ptr = at::cuda::jit::jit_pwise_function(code, name);
    }
  }

  float scalar_val = 0;

  // pack args for kernel launch
  constexpr int kernel_args = 7;
  auto extra_args_size = extra_args.size();
  auto args = std::make_unique<void*[]>(kernel_args + extra_args_size);
  args[0] = static_cast<void*>(&N);
  args[1] = data_ptr;
  args[2] = ic_ptr;
  args[3] = oc_ptr;
  args[4] = l_ptr;
  args[5] = s_ptr;
  args[6] = static_cast<void*>(&scalar_val);

  for (const auto i : c10::irange(extra_args_size)) {
    // since 7 slots are already filled in `args`
    args[i + 7] = const_cast<void*>(extra_args[i].data_ptr());
  }

  at::cuda::jit::launch_jitted_pwise_function(*fn_ptr, args.get(), {grid, 1u, 1u}, {num_threads(), 1u, 1u});
}

void jitted_gpu_kernel_dynamic_impl(
    const std::string& kernel_name,
    TensorIteratorBase& iter,
    const std::string& f,
    const bool dynamic_casting,
    const std::vector<at::Scalar>& extra_args) {

  TORCH_INTERNAL_ASSERT(iter.can_use_32bit_indexing());
  TORCH_INTERNAL_ASSERT(iter.noutputs() == 1);
  TORCH_INTERNAL_ASSERT(iter.ninputs() <= 8);

  ArrayVariant data(iter);
  void* data_ptr = data.data_ptr();

  int64_t numel = iter.numel();
  bool contiguous = iter.is_contiguous();

  // Decides which of 4 kernel types to launch
  // Variations are:
  //   - Case 1: no dynamic casting and contiguous
  //   - Case 2: no dynamic casting and noncontiguous
  //   - Case 3: dynamic casting and contiguous
  //   - Case 4: dynamic casting and noncontiguous
  // These cases align with the non-jitted CUDALoops.cuh cases in gpu_kernel_impl

  if (!dynamic_casting) {
    if (contiguous) {
      // Case 1: no dynamic casting and contiguous
      launch_jitted_vectorized_kernel_dynamic(kernel_name, iter,
          iter.device().index(), numel, f, data_ptr, extra_args);
      return;
    }

    // Case 2: no dynamic casting and noncontiguous
    OffsetCalculatorVariant input_offset_calculator(iter);
    void* ic_ptr = input_offset_calculator.data_ptr();
    auto output_offset_calculator = make_output_offset_calculator(iter);
    void* oc_ptr = static_cast<void*>(&output_offset_calculator);

    auto loader = memory::LoadWithoutCast();
    auto storer = memory::StoreWithoutCast();
    void* l_ptr = static_cast<void*>(&loader);
    void* s_ptr = static_cast<void*>(&storer);

    launch_jitted_unrolled_kernel_dynamic(
        kernel_name, iter, iter.device().index(), numel, f, data_ptr,
        ic_ptr, oc_ptr, l_ptr, s_ptr, contiguous, dynamic_casting, extra_args);

    return;
  }

  // Cases 3 and 4 are handled below
  // Both require construction of a storer (this asserts 1 output) and one or more loaders

  // Creates load casts from inputs (note offset indexing into the iterator's 1...n tensors)
  LoadWithCastVariant loader(iter);
  void* l_ptr = loader.data_ptr();

  // Creates store cast to output (the zeroth tensor in TensorIterator)
  auto storer = memory::StoreWithCast(iter.dtype(0));
  void* s_ptr = static_cast<void*>(&storer);

  if (contiguous) {
    // Case 3: dynamic casting and contiguous
    TrivialOffsetCalculatorVariant input_offset_calculator(iter);
    void* ic_ptr = input_offset_calculator.data_ptr();

    auto output_offset_calculator = TrivialOffsetCalculator<1>();
    void* oc_ptr = static_cast<void*>(&output_offset_calculator);

    launch_jitted_unrolled_kernel_dynamic(
        kernel_name, iter, iter.device().index(), numel, f, data_ptr,
        ic_ptr, oc_ptr, l_ptr, s_ptr, contiguous, dynamic_casting, extra_args);
    return;
  }

  // Case 4: dynamic casting and noncontiguous
  OffsetCalculatorVariant input_offset_calculator(iter);
  void* ic_ptr = input_offset_calculator.data_ptr();

  auto output_offset_calculator = make_output_offset_calculator(iter);
  void* oc_ptr = static_cast<void*>(&output_offset_calculator);

  launch_jitted_unrolled_kernel_dynamic(
      kernel_name, iter, iter.device().index(), numel, f, data_ptr,
      ic_ptr, oc_ptr, l_ptr, s_ptr, contiguous, dynamic_casting, extra_args);
}

// Entrypoint for dynamic version of jitted GPU kernels, which accepts dynamic number of inputs
// and arbitrary types of input and extra args. This dynamic version is needed for jiterator with python interface,
// since the kernel definition is unknown at the compilation time.
// Similarly, launch_jitted_vectorized_kernel_dynamic and launch_jitted_unrolled_kernel_dynamic are created
// to handle arbitrary functions defined in python user code.
// For templated version, see note [Jiterator] in JitLoops.cuh for more details
void jitted_gpu_kernel_dynamic(
    const std::string& kernel_name,
    TensorIteratorBase& iter,
    const std::string& f,
    const std::vector<at::Scalar>& extra_args) {

  // TODO: much of preamble is common to both jitted_gpu_kernel and gpu_kernel
  // Maybe it could be refactored?
  for (int arg = 0; arg < iter.ntensors(); arg++) {
    TORCH_INTERNAL_ASSERT(
        iter.device(arg).is_cuda(),
        "argument ", arg, ": expected a CUDA device but found ", iter.device(arg));
  }

  if (iter.numel() == 0) {
    return;
  }

  if (!iter.can_use_32bit_indexing()) {
    for (auto& sub_iter : iter.with_32bit_indexing()) {
      jitted_gpu_kernel_dynamic(kernel_name, sub_iter, f, extra_args);
    }
    return;
  }

  // Computes if dynamic casting is needed
  // Dynamic casting is needed if an input's or output's dtype differs from the common dtype
  bool needs_dynamic_casting = false;
  const at::ScalarType common_dtype = iter.common_dtype();
  for (auto i = 0; i < iter.ntensors(); ++i) {
    if (iter.dtype(i) != common_dtype) {
      needs_dynamic_casting = true;
      break;
    }
  }

  jitted_gpu_kernel_dynamic_impl(kernel_name, iter, f, needs_dynamic_casting, extra_args);
}

} // namespace native

namespace cuda {

at::Tensor CompileAndLaunchKernel(
    const std::string& code_string,
    const std::string& kernel_name,
    const std::vector<at::Tensor>& tensors,
    const std::vector<at::Scalar>& extra_args) {

  Tensor output;
  TensorIteratorConfig config;
  config
      .set_check_mem_overlap(true)
      .allow_cpu_scalars(false)
      .promote_inputs_to_common_dtype(true)
      .cast_common_dtype_to_outputs(true)
      .enforce_safe_casting_to_output(true)
      .check_all_same_device(true)
      .add_owned_output(output);
  for (const auto& t : tensors) {
    config.add_input(t);
  }
  TensorIterator iter = config.build();

  CUDAGuard guard(iter.device());
  at::native::jitted_gpu_kernel_dynamic(kernel_name, iter, code_string, extra_args);

  return iter.output();
}

}} // namespace at::cuda

#endif // AT_USE_JITERATOR()
```
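As a reader's aid for the caching logic above: compiled kernels are keyed on the code string, tensor count, dtypes, vectorization width, and device index, so the first call for a given combination pays the NVRTC compilation cost and later calls reuse the cached function. The timing sketch below is not part of the PR; the kernel string is illustrative and the timings are machine-dependent.

```python
import time
import torch
from torch.cuda.jiterator import _create_jit_fn as create_jit_fn

fn = create_jit_fn("template <typename T> T my_kernel(T x, T y) { return x + y; }")

def timed_call(a, b):
    # Synchronize so the measurement covers compilation + launch, not just enqueueing.
    torch.cuda.synchronize()
    start = time.perf_counter()
    out = fn(a, b)
    torch.cuda.synchronize()
    return out, time.perf_counter() - start

a = torch.rand(1024, device='cuda')
b = torch.rand(1024, device='cuda')

_, first = timed_call(a, b)    # cache miss: NVRTC compiles the float32 kernel
_, second = timed_call(a, b)   # cache hit: reuses the compiled function
_, other = timed_call(a.double(), b.double())  # new common dtype, new cache key, new compile
print(f"first: {first:.4f}s, cached: {second:.6f}s, float64 first: {other:.4f}s")
```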
aten/src/ATen/cuda/jiterator.h (new file, 35 lines)

@@ -0,0 +1,35 @@
```cpp
#pragma once
#include <ATen/jit_macros.h>

#if AT_USE_JITERATOR()

#include <c10/macros/Export.h>
#include <ATen/core/Tensor.h>

#include <string>
#include <vector>

namespace at {
namespace cuda {

TORCH_CUDA_CPP_API at::Tensor CompileAndLaunchKernel(
    const std::string& code_string,
    const std::string& kernel_name,
    const std::vector<at::Tensor>& tensors,
    const std::vector<at::Scalar>& extra_args);

}} // namespace at::cuda

#else

namespace at { namespace cuda {
TORCH_CUDA_CPP_API at::Tensor CompileAndLaunchKernel(
    const std::string& code_string,
    const std::string& kernel_name,
    const std::vector<at::Tensor>& tensors,
    const std::vector<at::Scalar>& extra_args) {
  TORCH_CHECK(false, "Jiterator is not supported on ROCm");
}
}} // namespace at::cuda

#endif // AT_USE_JITERATOR()
```
aten/src/ATen/cuda/jiterator_impl.h (new file, 208 lines)

@@ -0,0 +1,208 @@
```cpp
#pragma once
#include <ATen/jit_macros.h>

#if AT_USE_JITERATOR()

#include <c10/util/variant.h>
#include <ATen/native/TensorIterator.h>
#include <ATen/cuda/detail/OffsetCalculator.cuh>
#include <ATen/native/cuda/jit_utils.h>
#include <ATen/native/cuda/MemoryAccess.cuh>
#include <ATen/native/cuda/JitLoops.cuh>

#include <string>
#include <vector>

namespace at {
namespace native {

constexpr int NUM_INPUTS = 8;

#define AT_FOR_8_INPUTS(_) \
  _(1)                     \
  _(2)                     \
  _(3)                     \
  _(4)                     \
  _(5)                     \
  _(6)                     \
  _(7)                     \
  _(8)

c10::SmallVector<std::string> get_extra_args_typenames(const std::vector<at::Scalar>& extra_args) {
  c10::SmallVector<std::string> args_typenames(extra_args.size());
  for (auto i = 0; i < extra_args.size(); ++i) {
    args_typenames[i] = at::cuda::jit::typeName(extra_args[i].type());
  }
  return args_typenames;
}

int can_vectorize_up_to(at::ScalarType type, char* pointer) {
  switch(type) {
#define DEFINE_CASE(ctype, scalartype) \
    case ScalarType::scalartype : return memory::can_vectorize_up_to<ctype>(pointer);

    AT_FORALL_SCALAR_TYPES_WITH_COMPLEX(DEFINE_CASE)
#undef DEFINE_CASE

    default: TORCH_INTERNAL_ASSERT(false, "Unrecognized ScalarType: ", type);
  }
}

// jitted version of the above
// See Note [Jiterator], this relies on the assumptions enumerated there
int jitted_can_vectorize_up_to(const TensorIteratorBase& iter) {
  const at::ScalarType common_dtype = iter.common_dtype();
  const at::ScalarType result_dtype = common_dtype;

  // Deals with output
  int result = can_vectorize_up_to(result_dtype, static_cast<char*>(iter.data_ptr(0)));

  // Incorporates input(s)
  for (auto i = 1; i < iter.ntensors(); ++i) {
    result = std::min<int>(result, can_vectorize_up_to(common_dtype, static_cast<char*>(iter.data_ptr(i))));
  }

  return result;
}

template<int N>
static std::unique_ptr<OffsetCalculator<N>> make_unique_input_offset_calculator(const TensorIteratorBase& iter) {
  // array size can not be 0, this happens when N == 0
  constexpr int array_size = std::max<int>(N, 1);
  TORCH_INTERNAL_ASSERT(N == iter.ntensors() - iter.noutputs());
  std::array<const int64_t*, array_size> strides;
  int64_t element_sizes[array_size];
  for (int i = 0; i < N; i++) {
    strides[i] = iter.strides(i + iter.noutputs()).data();
    element_sizes[i] = iter.element_size(i + iter.noutputs());
  }
  return std::make_unique<OffsetCalculator<N>>(iter.ndim(), iter.shape().data(), strides.data(), element_sizes);
}

struct OffsetCalculatorVariant {
#define DEFINE_CASE(index) std::unique_ptr<OffsetCalculator<index>>,
  using OffsetCalculatorTypes = c10::variant<
    AT_FOR_8_INPUTS(DEFINE_CASE)
  >;
#undef DEFINE_CASE

  OffsetCalculatorVariant(const TensorIteratorBase& iter) {
    int arity = iter.ninputs();
    switch(arity) {
#define DEFINE_CASE(index) \
      case index : v = make_unique_input_offset_calculator<index>(iter); break;

      AT_FOR_8_INPUTS(DEFINE_CASE)
#undef DEFINE_CASE
      default:
        TORCH_CHECK(false, "OffsetCalculatorVariant is not implemented for ninputs = ", arity);
    }
  }

  void* data_ptr() {
    return c10::visit([](auto & v){ return static_cast<void*>(v.get()); }, v);
  }

 private:
  OffsetCalculatorTypes v;
};

struct ArrayVariant {
  // notice: This would produce c10::variant<at::detail::Array<char*, 2...9>>
#define DEFINE_CASE(index) at::detail::Array<char*, index + 1>,
  using ArrayTypes = c10::variant<
    AT_FOR_8_INPUTS(DEFINE_CASE)
  >;
#undef DEFINE_CASE

  ArrayVariant(const TensorIteratorBase& iter) {
    int arity = iter.ninputs();
    // This assumes that jiterator kernels only have 1 output
    switch(arity) {
#define DEFINE_CASE(index) \
      case index: array = at::detail::Array<char*, index + 1>{}; break;

      AT_FOR_8_INPUTS(DEFINE_CASE)
#undef DEFINE_CASE

      default:
        TORCH_CHECK(false, "ArrayVariant is not implemented for ninputs = ", arity);
    }

    c10::visit([&](auto& a) {
      for (auto i = 0; i < arity + 1; ++i) {
        a[i] = (char*)iter.data_ptr(i);
      }
    }, array);
  }

  void* data_ptr() {
    return c10::visit([](auto & a){ return static_cast<void*>(&a); }, array);
  }

 private:
  ArrayTypes array;
};

struct TrivialOffsetCalculatorVariant {
#define DEFINE_CASE(index) TrivialOffsetCalculator<index>,
  using TrivialOffsetCalculatorTypes = c10::variant<
    AT_FOR_8_INPUTS(DEFINE_CASE)
  >;
#undef DEFINE_CASE

  TrivialOffsetCalculatorVariant(const TensorIteratorBase& iter) {
    int arity = iter.ninputs();
    switch(arity) {
#define DEFINE_CASE(index) \
      case index: v = TrivialOffsetCalculator<index>(); break;

      AT_FOR_8_INPUTS(DEFINE_CASE)
#undef DEFINE_CASE

      default:
        TORCH_CHECK(false, "TrivialOffsetCalculatorVariant is not implemented for ninputs = ", arity);
    }
  }

  void* data_ptr() {
    return c10::visit([](auto & v){ return static_cast<void*>(&v); }, v);
  }

 private:
  TrivialOffsetCalculatorTypes v;
};

struct LoadWithCastVariant {
#define DEFINE_CASE(index) std::unique_ptr<memory::LoadWithCast<index>>,
  using LoadWithCastPtr = c10::variant<
    AT_FOR_8_INPUTS(DEFINE_CASE)
  >;
#undef DEFINE_CASE

  LoadWithCastVariant(const TensorIteratorBase& iter) {
    int arity = iter.ninputs();
    switch(arity) {
#define DEFINE_CASE(index) \
      case index: v = std::make_unique<memory::LoadWithCast<index>>(iter); break;

      AT_FOR_8_INPUTS(DEFINE_CASE)
#undef DEFINE_CASE

      default:
        TORCH_CHECK(false, "LoadWithCastVariant is not implemented for ninputs = ", arity);
    }
  }

  void* data_ptr() {
    return c10::visit([](auto & v){ return static_cast<void*>(v.get()); }, v);
  }

 private:
  LoadWithCastPtr v;
};

}} // namespace at::native

#endif // AT_USE_JITERATOR()
```
@@ -116,6 +116,15 @@ struct LoadWithCast {
```cpp
    }
  }

  LoadWithCast(const TensorIteratorBase& iter) {
    assert(iter.ninputs() == N);
    #pragma unroll
    for (auto i = 0; i < N; ++i) {
      this->dtypes[i] = iter.dtype(i + 1);
      element_sizes[i] = c10::elementSize(iter.dtype(i + 1));
    }
  }

  template<typename scalar_t>
  __device__ scalar_t load(char *base_ptr, uint32_t offset, int arg) {
    void *ptr = base_ptr + element_sizes[arg] * offset;
```
@@ -7,7 +7,6 @@
```cpp
#include <ATen/jit_macros.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/detail/OffsetCalculator.cuh>
#include <ATen/cuda/nvrtc_stub/ATenNVRTC.h>
#include <ATen/code_template.h>
#include <ATen/native/cuda/jit_utils.h>
#include <ATen/cuda/llvm_jit_strings.h>
```
@@ -8,6 +8,7 @@
```cpp
#include <c10/util/irange.h>
#include <ATen/jit_macros.h>
#include <ATen/cuda/detail/LazyNVRTC.h>
#include <ATen/cuda/nvrtc_stub/ATenNVRTC.h>

namespace at { namespace cuda { namespace jit {
```
@@ -85,6 +86,9 @@ AT_FORALL_SCALAR_TYPES(TYPE_NAME_FN)
```cpp
// JIT uses std::complex directly, because nvRTC compile programs
// with -default-device, so there is no such issue like:
//   "std::sin(complex) is __host__ only"
template <> inline std::string typeName<bool>(){
  return "bool";
}
template <> inline std::string typeName<c10::complex<at::Half>>(){
  return "std::complex<at::Half>";
}
```
@@ -101,4 +105,20 @@ template <> inline std::string typeName<at::BFloat16>(){
```cpp
  return "at::BFloat16";
}

#define TYPE_NAME_CASE(ctype, scalartype) \
  case ScalarType::scalartype: return std::string(#ctype);
inline std::string typeName(ScalarType t) {
  switch (t) {
    AT_FORALL_SCALAR_TYPES(TYPE_NAME_CASE)
    case ScalarType::Bool : return "bool";
    case ScalarType::Half : return "at::Half";
    case ScalarType::BFloat16 : return "at::BFloat16";
    case ScalarType::ComplexFloat : return "std::complex<float>";
    case ScalarType::ComplexDouble : return "std::complex<double>";
    default:
      TORCH_CHECK(false, "invalid type for jiterator");
  }
}
#undef TYPE_NAME_CASE

}}} // namespace at::cuda::jit
```
@@ -74,6 +74,10 @@ class C10_API Scalar {
```cpp
  template <typename T>
  T to() const = delete;

  const void* data_ptr() const {
    return static_cast<const void*>(&v);
  }

#undef DEFINE_ACCESSOR
  bool isFloatingPoint() const {
    return Tag::HAS_d == tag;
```
@@ -280,7 +280,21 @@ namespace std {
```diff
 #define C10_MPARK_BUILTIN_UNREACHABLE
 #endif

-#if __has_builtin(__type_pack_element)
+// NOTE [nvcc bug workaround]
+//
+// The original line `typename Front = lib::type_pack_element_t<0, Ts...>,`
+// throws the following compiler error on nvcc:
+// ```
+// c10/util/variant.h(2367): error: parameter pack "Ts" was referenced but not
+// expanded
+// ```
+// As a workaround, we skip defining C10_MPARK_TYPE_PACK_ELEMENT for nvcc
+// compiler
+//
+// See the following issues for more context:
+// https://github.com/pytorch/extension-cpp/issues/58
+// https://github.com/mpark/variant/issues/77
+#if __has_builtin(__type_pack_element) && !defined(__CUDACC__)
 #define C10_MPARK_TYPE_PACK_ELEMENT
 #endif
```
@@ -123,3 +123,11 @@ NVIDIA Tools Extension (NVTX)
```rst
    nvtx.mark
    nvtx.range_push
    nvtx.range_pop

Jiterator (beta)
-----------------------------
.. autosummary::
    :toctree: generated
    :nosignatures:

    jiterator._create_jit_fn
```
test/test_jiterator.py (new file, 132 lines)

@@ -0,0 +1,132 @@
```python
# Owner(s): ["module: cuda"]

import torch
from torch.cuda.jiterator import _create_jit_fn as create_jit_fn
import sys
from itertools import product
from torch.testing._internal.common_utils import TestCase, parametrize, run_tests, TEST_CUDA
from torch.testing._internal.common_dtype import all_types_and_complex_and
from torch.testing._internal.common_device_type import (
    skipCUDAIfRocm, skipCUDAIf, instantiate_device_type_tests, dtypes, toleranceOverride, tol)
from torch.testing._internal.common_cuda import _get_torch_cuda_version

if not TEST_CUDA:
    print('CUDA not available, skipping tests', file=sys.stderr)
    TestCase = object  # noqa: F811


code_string = "template <typename T> T my_fused_kernel(T x, T y, T alpha, T beta) { return alpha * x + beta * y; }"
jitted_fn = create_jit_fn(code_string, alpha=1, beta=1)

def ref_fn(x, y, alpha=1, beta=1):
    return alpha * x + beta * y

class TestPythonJiterator(TestCase):
    @skipCUDAIfRocm
    @parametrize("shape_strides", [
        (([3, 3], [3, 1]), ([3, 3], [3, 1])),  # contiguous
    ])
    @dtypes(*product(all_types_and_complex_and(torch.half, torch.bfloat16),
                     all_types_and_complex_and(torch.half, torch.bfloat16)))
    def test_all_dtype_contiguous(self, device, dtypes, shape_strides):
        a_buffer = torch.rand(9, device=device).mul(10).type(dtypes[0])
        b_buffer = torch.rand(9, device=device).mul(10).type(dtypes[1])

        a = a_buffer.as_strided(*shape_strides[0])
        b = b_buffer.as_strided(*shape_strides[1])

        expected = ref_fn(a, b)
        result = jitted_fn(a, b)

        self.assertEqual(expected, result)

    @skipCUDAIfRocm
    # See https://github.com/pytorch/pytorch/pull/76394#issuecomment-1118018287 for details
    @skipCUDAIf(_get_torch_cuda_version() < (11, 6), "On cuda 11.3, nvrtcCompileProgram is taking too long to "
                "compile jiterator generated kernels for non-contiguous input that requires dynamic-casting.")
    @parametrize("shape_strides", [
        (([3, 3], [1, 3]), ([3, 1], [1, 3])),  # non-contiguous
    ])
    @dtypes(*product(all_types_and_complex_and(torch.half, torch.bfloat16),
                     all_types_and_complex_and(torch.half, torch.bfloat16)))
    def test_all_dtype_noncontiguous(self, device, dtypes, shape_strides):
        a_buffer = torch.rand(9, device=device).mul(10).type(dtypes[0])
        b_buffer = torch.rand(9, device=device).mul(10).type(dtypes[1])

        a = a_buffer.as_strided(*shape_strides[0])
        b = b_buffer.as_strided(*shape_strides[1])

        expected = ref_fn(a, b)
        result = jitted_fn(a, b)

        self.assertEqual(expected, result)

    @skipCUDAIfRocm
    @dtypes(torch.float, torch.double, torch.float16, torch.bfloat16)
    @parametrize("alpha", [-1, 2.0, None])
    @parametrize("beta", [3, -4.2, None])
    @toleranceOverride({torch.float16 : tol(atol=1e-2, rtol=1e-3)})
    def test_extra_args(self, device, dtype, alpha, beta):
        a = torch.rand(3, device=device).mul(10).type(dtype)
        b = torch.rand(3, device=device).mul(10).type(dtype)

        extra_args = {}
        if alpha is not None:
            extra_args["alpha"] = alpha
        if beta is not None:
            extra_args["beta"] = beta

        expected = ref_fn(a, b, **extra_args)
        result = jitted_fn(a, b, **extra_args)

        self.assertEqual(expected, result)

    @skipCUDAIfRocm
    def test_bool_extra_args(self, device):
        code_string = "template <typename T> T conditional(T x, T mask, bool is_train) { return is_train ? x * mask : x; }"
        jitted_fn = create_jit_fn(code_string, is_train=False)

        def ref_fn(x, mask, is_train):
            return x * mask if is_train else x

        a = torch.rand(3, device=device)
        b = torch.rand(3, device=device)

        expected = ref_fn(a, b, is_train=True)
        result = jitted_fn(a, b, is_train=True)
        self.assertEqual(expected, result)

    @skipCUDAIfRocm
    @parametrize("num_inputs", list(range(1, 9)))
    def test_various_num_inputs(self, num_inputs):
        inputs = []
        for i in range(num_inputs):
            inputs.append(torch.rand(3, device='cuda').mul(10))

        input_string = ",".join([f"T i{i}" for i in range(num_inputs)])
        function_body = "+".join([f"i{i}" for i in range(num_inputs)])
        code_string = f"template <typename T> T my_kernel({input_string}) {{ return {function_body}; }}"
        jitted_fn = create_jit_fn(code_string)

        def ref_fn(*inputs):
            return torch.sum(torch.stack(inputs), dim=0)

        expected = ref_fn(*inputs)
        result = jitted_fn(*inputs)

        self.assertEqual(expected, result)

    @skipCUDAIfRocm
    @parametrize("code_string", [
        "template <typename T> T my _kernel(T x) { return x; }",
        "template <typename T> Tmy_kernel(T x) { return x; }",
    ])
    def test_invalid_function_name(self, code_string):
        with self.assertRaises(Exception):
            jitted_fn = create_jit_fn(code_string)


instantiate_device_type_tests(TestPythonJiterator, globals(), only_for="cuda")

if __name__ == '__main__':
    run_tests()
```
@@ -67,7 +67,7 @@ def check_file(filename: str) -> Optional[LintMessage]:
```diff
             name="Trailing newline",
             original=None,
             replacement=None,
-            description="Trailing newline found. Run `lintunner --take NEWLINE -a` to apply changes.",
+            description="Trailing newline found. Run `lintrunner --take NEWLINE -a` to apply changes.",
         )

     else:
```
@@ -103,7 +103,7 @@ def check_file(filename: str) -> Optional[LintMessage]:
```diff
             name="Trailing newline",
             original=original,
             replacement=original.rstrip("\n") + "\n",
-            description="Trailing newline found. Run `lintunner --take NEWLINE -a` to apply changes.",
+            description="Trailing newline found. Run `lintrunner --take NEWLINE -a` to apply changes.",
         )
```
@@ -924,6 +924,10 @@ def _cuda_memorySnapshot() -> List[Dict[str, Any]]: ...
```python
def _cuda_lock_mutex() -> None: ...
def _cuda_unlock_mutex() -> None: ...
def _cuda_canDeviceAccessPeer(device: _int, peer_device: _int) -> _bool: ...
def _cuda_jiterator_compile_and_launch_kernel(code_string: str,
                                              kernel_name: str,
                                              tensors: Tuple,
                                              kwargs: Dict[str, Union[_int, _float, _bool]]) -> Tensor: ...
def _nccl_version() -> _int: ...
def _nccl_unique_id() -> bytes: ...
def _nccl_init_rank(nranks: _int, comm_id: bytes, rank: _int) -> object: ...
```
@@ -6,6 +6,7 @@
```cpp
#include <ATen/cuda/CachingHostAllocator.h>
#include <ATen/cuda/Sleep.h>
#include <ATen/cuda/detail/CUDAHooks.h>
#include <ATen/cuda/jiterator.h>
#ifdef USE_NCCL
#include <torch/csrc/cuda/python_nccl.h>
#endif
```
@@ -217,6 +218,71 @@ PyObject * THCPModule_cudaCachingAllocator_raw_alloc(PyObject *_unused, PyObject
```cpp
  END_HANDLE_TH_ERRORS
}

// Unpack a PyObject to at::Scalar, throw an exception if it fails
at::Scalar as_scalar(PyObject* arg) {
  // Zero-dim tensors are converted to Scalars as-is. Note this doesn't currently
  // handle most NumPy scalar types except np.float64.
  if (THPVariable_Check(arg)) {
    return THPVariable_Unpack(arg).item();
  }

  if (THPUtils_checkLong(arg)) {
    return at::Scalar(static_cast<int64_t>(THPUtils_unpackLong(arg)));
  }

  if (PyBool_Check(arg)) {
    return at::Scalar(THPUtils_unpackBool(arg));
  }

  if (PyComplex_Check(arg)) {
    return at::Scalar(THPUtils_unpackComplexDouble(arg));
  }
  return at::Scalar(THPUtils_unpackDouble(arg));
}

// Entrypoint for the callable created by torch.cuda.jiterator
// See jiterator.py for more details
PyObject * THCPModule_cudaJiteratorCompileAndLaunchKernel(PyObject *_unused, PyObject *args){
  HANDLE_TH_ERRORS

  PyObject* code_string_o = nullptr;
  PyObject* kernel_name_o = nullptr;
  PyObject* tensors_o = nullptr;
  PyObject* kwargs_o = nullptr;
  if(!PyArg_ParseTuple(args, "OOO|O", &code_string_o, &kernel_name_o, &tensors_o, &kwargs_o)) {
    return nullptr;
  }

  std::string code_string = THPUtils_unpackString(code_string_o);
  std::string kernel_name = THPUtils_unpackString(kernel_name_o);

  THPUtils_assert(PyTuple_Check(tensors_o), "tensors argument is expected to "
      "be a tuple, but got %s", THPUtils_typename(tensors_o));
  Py_ssize_t num_tensors = PyTuple_GET_SIZE(tensors_o);

  std::vector<at::Tensor> tensors;
  for(const auto i : c10::irange(num_tensors)) {
    PyObject *_tensor = PyTuple_GET_ITEM(tensors_o, i);
    THPUtils_assert(THPVariable_Check(_tensor), "element %d of tensors "
        "tuple is not a Tensor", i);

    tensors.emplace_back(THPVariable_Unpack(_tensor));
  }

  std::vector<at::Scalar> extra_args;
  PyObject *key = nullptr;
  PyObject *value = nullptr;
  Py_ssize_t pos = 0;
  while (PyDict_Next(kwargs_o, &pos, &key, &value)) {
    extra_args.emplace_back(as_scalar(value));
  }

  at::Tensor output = at::cuda::CompileAndLaunchKernel(code_string, kernel_name, tensors, extra_args);

  return THPVariable_Wrap(output);
  END_HANDLE_TH_ERRORS
}

PyObject * THCPModule_cudaCachingAllocator_raw_delete(PyObject *_unused, PyObject *obj){
  HANDLE_TH_ERRORS
  void* mem_ptr = PyLong_AsVoidPtr(obj);
```
@@ -597,6 +663,7 @@ static struct PyMethodDef _THCPModule_methods[] = {
```cpp
  {"_cuda_unlock_mutex", THCPModule_cudaUnlockMutex, METH_NOARGS, nullptr},
  {"_cuda_set_sync_debug_mode", THCPModule_cudaSetSyncDebugMode, METH_O, nullptr},
  {"_cuda_get_sync_debug_mode", THCPModule_cudaGetSyncDebugMode, METH_NOARGS, nullptr},
  {"_cuda_jiterator_compile_and_launch_kernel", THCPModule_cudaJiteratorCompileAndLaunchKernel, METH_VARARGS, nullptr},
#ifdef USE_NCCL
  {"_nccl_version", THCPModule_nccl_version, METH_NOARGS, nullptr},
  {"_nccl_unique_id", THCPModule_nccl_unique_id, METH_NOARGS, nullptr},
```
@@ -767,3 +767,4 @@ from . import sparse
```python
from . import profiler
from . import nvtx
from . import amp
from . import jiterator
```
torch/cuda/jiterator.py (new file, 104 lines)

@@ -0,0 +1,104 @@
```python
import torch
from torch import Tensor
from typing import Callable, List

import re

__all__ : List[str] = []

class _CodeParser:
    def __init__(self, code_string: str):
        optional_ws = r"\s*"
        required_ws = r"\s+"
        template_params = r"(?P<template_params>\<.+\>)"
        return_type = r"(?P<return_type>\w+)"
        function_name = r"(?P<function_name>\w+)"
        function_params = r"(?P<function_params>\(.+\))"
        function_body = r"(?P<function_body>\{.+\})"

        pattern = \
            optional_ws \
            + "template" \
            + optional_ws + template_params \
            + optional_ws + return_type \
            + required_ws + function_name \
            + optional_ws + function_params \
            + optional_ws + function_body \
            + optional_ws

        result = re.match(pattern, code_string, re.DOTALL)  # DOTALL for matching multiline

        if result is None:
            raise Exception(f"Couldn't parse code, please check correctness:\n {code_string}")

        self.template_params = result["template_params"]
        self.return_type = result["return_type"]
        self.function_name = result["function_name"]
        self.function_params = result["function_params"]
        self.function_body = result["function_body"]


def _create_jit_fn(code_string: str, **kwargs) -> Callable:
    """
    Create a jiterator-generated CUDA kernel for an elementwise op.

    The code string has to be a valid CUDA function that describes the computation for a single element. The code
    string has to follow the C++ template pattern, as shown in the example below. This function will be inlined
    into the elementwise kernel template, and compiled on the fly. The compiled kernel will be cached in memory,
    as well as in a local temp dir.

    Jiterator-generated kernels accept noncontiguous tensors, and support broadcasting and type promotion.

    Args:
        code_string (string): CUDA code string to be compiled by jiterator.
        kwargs (Dict, optional): Keyword arguments for the generated function

    Examples:
        >>> code_string = "template <typename T> T my_kernel(T x, T y, T alpha) { return -x + alpha * y; }"
        >>> jitted_fn = create_jit_fn(code_string, alpha=1.0)
        >>> a = torch.rand(3, device='cuda')
        >>> b = torch.rand(3, device='cuda')
        >>> # invoke jitted function like a regular python function
        >>> result = jitted_fn(a, b, alpha=3.14)

    .. warning::
        This API is in beta and may change in future releases.

    .. warning::
        Jiterator only supports up to 8 tensor inputs.

    .. warning::
        All input tensors must live on a CUDA device.

    """
    class JittedFunction:
        def __init__(self, code_string: str, **kwargs):
            self.code_string = code_string

            parsed_code = _CodeParser(code_string)
            self.kernel_name = parsed_code.function_name

            self.kwargs_dict = kwargs
            self.is_cuda_available = torch.cuda.is_available()

        def __call__(self, *tensors: Tensor, **kwargs):
            # Jiterator follows torch.cuda's lazy initialization behavior:
            # defer checking CUDA's availability to function invocation time
            assert self.is_cuda_available, "Jiterator is only supported on CUDA GPUs, no CUDA GPUs are available."

            assert len(tensors) <= 8, "jiterator only supports up to 8 tensor inputs."

            expanded_kwargs = self.kwargs_dict.copy()
            for key, value in kwargs.items():
                if key in self.kwargs_dict:
                    expanded_kwargs[key] = value
                else:
                    raise KeyError(f"{key} is not declared in function definition")

            return torch._C._cuda_jiterator_compile_and_launch_kernel(
                self.code_string,
                self.kernel_name,
                tensors,
                expanded_kwargs)

    return JittedFunction(code_string, **kwargs)
```
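Since the docstring above states that jiterator-generated kernels accept noncontiguous tensors and support broadcasting and type promotion (backed by the TensorIteratorConfig settings in jiterator.cu), a small sketch of that behavior may help. The kernel string, shapes, and dtypes below are illustrative, not taken from the PR.

```python
import torch
from torch.cuda.jiterator import _create_jit_fn

fn = _create_jit_fn(
    "template <typename T> T my_kernel(T x, T y, T alpha) { return -x + alpha * y; }",
    alpha=1.0)

x = torch.rand(3, 1, device='cuda')                    # float32, shape (3, 1)
y = torch.arange(4, device='cuda', dtype=torch.int64)  # int64, shape (4,)

out = fn(x, y, alpha=2.0)
print(out.shape)  # torch.Size([3, 4]) -- broadcast like a regular binary op
print(out.dtype)  # torch.float32 -- inputs promoted to a common dtype
```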
@@ -2299,6 +2299,34 @@ def sample_inputs_elementwise_binary(op, device, dtype, requires_grad, **kwargs)
```python
    )


def sample_inputs_jiterator(op, device, dtype, requires_grad, **kwargs):
    make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)

    shapes = (
        ((), ()),
        ((S,), ()),
        ((S, 1), (S,)),
        ((M, S), ()),
        ((S, M, S), (M, S)),
        ((S, M, S), (S, M, S)),
        ((M, 1, S), (M, S)),
        ((M, 1, S), (1, M, S)),
        ((0, 1, 3), (0, 10, 3))
    )

    num_inputs = kwargs.get('num_inputs')
    sample_kwargs = kwargs.get('sample_kwargs', {})

    for shape_lhs, shape_rhs in shapes:
        lhs = make_arg(shape_lhs)

        args = []
        for i in range(num_inputs - 1):
            args.append(make_arg(shape_rhs))
        broadcasts_input = (shape_lhs != torch.broadcast_shapes(shape_lhs, shape_rhs))

        yield SampleInput(lhs, args=tuple(args), kwargs=sample_kwargs, broadcasts_input=broadcasts_input)

# The base reference input generation for elementwise binary operations
def _reference_inputs_elementwise_binary(op, device, dtype, requires_grad, **kwargs):
    yield from op.sample_inputs_func(op, device, dtype, requires_grad, **kwargs)
```
@@ -16371,6 +16399,95 @@ op_db: List[OpInfo] = [
```python
        # Can't find schemas for this operator for some reason
        DecorateInfo(unittest.expectedFailure, 'TestOperatorSignatures', 'test_get_torch_func_signature_exhaustive'),
    )),
    # Following tests are for jiterator's python interface
    # Jiterator can be used to author elementwise CUDA kernels
    # jiterator._create_jit_fn returns a callable that behaves like a regular pytorch op
    # See create_jit_fn in jiterator.py for more information
    UnaryUfuncInfo(
        'jiterator_unary',
        op=torch.cuda.jiterator._create_jit_fn("template <typename T> T unary(T x) { return x * x + x; }"),
        ref=lambda x: x * x + x,
        dtypes=all_types_and_complex_and(torch.bfloat16, torch.float16, torch.bool),
        supports_out=False,
        supports_autograd=False,  # jiterator ops don't have backward defined
        decorators=[
            onlyCUDA,
            skipCUDAIfRocm,
            DecorateInfo(toleranceOverride({torch.float16: tol(atol=1e-02, rtol=1e-02)}),
                         'TestUnaryUfuncs', 'test_reference_numerics_extremal'),
            DecorateInfo(toleranceOverride({torch.float16: tol(atol=1e-02, rtol=1e-02)}),
                         'TestUnaryUfuncs', 'test_reference_numerics_hard'),
            DecorateInfo(toleranceOverride({torch.float16: tol(atol=1e-02, rtol=1e-02)}),
                         'TestUnaryUfuncs', 'test_reference_numerics_normal'),
            DecorateInfo(toleranceOverride({torch.float16: tol(atol=1e-02, rtol=1e-02)}),
                         'TestUnaryUfuncs', 'test_reference_numerics_small'),
        ],
        skips=(
            # Jiterator ops don't support neg or conj view
            DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_view'),
            DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_conj_view'),
            DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_conj_view'),
            # Jiterator ops don't support CompositeCompliantTensor
            # Following test should expectedFailure, but it's causing cascading failures in CUDA, thus skipped
            DecorateInfo(unittest.skip("skip"), 'TestCompositeCompliance', 'test_operator'),
            # Skip reference_numerics tests for bool type, as the defined function doesn't work for bool
            DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_extremal',
                         dtypes=[torch.bool]),
            DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_hard',
                         dtypes=[torch.bool]),
            DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_normal',
                         dtypes=[torch.bool]),
            # Expected failure: torch.jiterator_unary is not a valid op
            DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'),
        )
    ),
    BinaryUfuncInfo(
        'jiterator_binary',
        op=torch.cuda.jiterator._create_jit_fn(
            "template <typename T> T binary(T x, T y, T alpha) { return x + alpha * y; }", alpha=1),
        ref=lambda input, other, *, alpha=1: np.add(input, other) if alpha == 1 \
            else np.add(input, np.multiply(alpha, other)),
        dtypes=all_types_and_complex_and(torch.bfloat16, torch.float16, torch.bool),
        sample_inputs_func=partial(sample_inputs_jiterator, num_inputs=2, alpha=-3.14),
        supports_out=False,
        supports_autograd=False,  # jiterator ops don't have backward defined
        supports_rhs_python_scalar=False,
        decorators=[onlyCUDA, skipCUDAIfRocm],
        skips=(
            # Jiterator ops don't support neg or conj view
            DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_view'),
            DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_conj_view'),
            DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_conj_view'),
            # Jiterator ops don't support CompositeCompliantTensor
            # Following test should expectedFailure, but it's causing cascading failures in CUDA, thus skipped
            DecorateInfo(unittest.skip("skip"), 'TestCompositeCompliance', 'test_operator'),
            # Expected failure: torch.jiterator_binary is not a valid op
            DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'),
        )
    ),
    OpInfo(
        'jiterator_4inputs_with_extra_args',
        op=torch.cuda.jiterator._create_jit_fn(
            "template <typename T> T binary(T i0, T i1, T i2, T i3, T alpha, T beta) { return alpha * i0 + beta * i1 + i2 + i3; }",
            alpha=1, beta=1),
        ref=lambda i0, i1, i2, i3, *, alpha=1, beta=1: alpha * i0 + beta * i1 + i2 + i3,
        dtypes=all_types_and_complex_and(torch.bfloat16, torch.float16, torch.bool),
        sample_inputs_func=partial(sample_inputs_jiterator, num_inputs=4, alpha=3.14, beta=-4.20),
        supports_out=False,
        supports_autograd=False,  # jiterator ops don't have backward defined
        decorators=[onlyCUDA, skipCUDAIfRocm],
        skips=(
            # Jiterator ops don't support neg or conj view
            DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_view'),
            DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_conj_view'),
            DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_conj_view'),
            # Jiterator ops don't support CompositeCompliantTensor
            # Following test should expectedFailure, but it's causing cascading failures in CUDA, thus skipped
            DecorateInfo(unittest.skip("skip"), 'TestCompositeCompliance', 'test_operator'),
            # Expected failure: torch.jiterator_4inputs_with_extra_args is not a valid op
            DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'),
        )
    ),
    # `torch.norm` has multiple code paths depending on the value of `p`.
    # These paths have different dtype support. Also JIT supports,
    # most variants but not all of them. So we split the OpInfo entries,
```