The Cacherator (#71350)

Summary:
This PR adds a persistent filesystem cache for jitted kernels. The cache is disabled on Windows because it relies on POSIX headers.

The cache writes, by default, to `~/.cache/torch/kernels`, but the location can be controlled by setting the `PYTORCH_KERNEL_CACHE_PATH`. A separate environment variable, `USE_PYTORCH_KERNEL_CACHE`, will disable all caching logic when set to zero.

The use of a persistent fileystem cache dramatically lowers the "first call time" for an operator AFTER its has been compiled, because it skips (most of) the jit compilation process. On systems where we're compiling only to ptx that ptx still has to be just-in-time compiled by the driver API, so an additional latency of around 10 milliseconds is expected at first call time. On systems which compile to SASS the additional first call time latency is about one millisecond. This compares with times of 150 milliseconds+ for just-in-time kernel compilation.

Files in the cache use a mostly human readable string that includes an SHA1 hash of the CUDA C string used to generate them. Note that this is not an SHA1 hash of the file's contents, because the contents are the compiled ptx or SASS. No verification is done when the file is loaded to ensure the kernel is what's expected, but it's far more likely you'll be struck by a meteor than observe two file names conflict. Using SHA1 hashes to generate unique ids this way is a common practice (GitHub does it, too).

This cache design could be reused by other fusion systems and should allow us to jiterate more operations without fear of regressing the "incremental development" scenario where users are tweaking or extending programs slightly, rerunning then, and then repeating that process again and again. Without a cache each run of the program would have to recompile every jitted kernel, but with this cache we expect a negligible impact to the user experience.

cc kshitij12345, xwang233

Pull Request resolved: https://github.com/pytorch/pytorch/pull/71350

Reviewed By: ngimel

Differential Revision: D33626671

Pulled By: mruberry

fbshipit-source-id: d55df53416fbe46348623846f699f9b998e6c318
This commit is contained in:
Mike Ruberry 2022-01-17 23:50:46 -08:00 committed by Facebook GitHub Bot
parent 7b9fff90d2
commit d17f340a2e
7 changed files with 497 additions and 95 deletions

View File

@ -0,0 +1,17 @@
#pragma once
// USE_JITERATOR, controls whether we jit some elementwise kernels
// Currently unsupported on ROCm GPUs
#ifndef USE_ROCM
#define USE_JITERATOR true
#define jiterator_stringify(...) std::string(#__VA_ARGS__);
#else
// TODO: update this to become a static assertion
#define jiterator_stringify(...) std::string("USE_JITERATOR is undefined");
#endif // USE_ROCM
// BUILD_JITERATOR_WITH_CACHE, controls whether jitted kernels can be cached
// Currently unsupported on Windows
#ifndef _WIN32
#define BUILD_JITERATOR_WITH_CACHE true
#endif // _WIN32

View File

@ -33,6 +33,7 @@
#include <iostream>
#include <mutex>
#include <ATen/jit_macros.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/core/Array.h>
#include <ATen/detail/FunctionTraits.h>
@ -123,6 +124,8 @@ static inline void launch_vectorized_kernel(int64_t N, const func_t& f, array_t
}
}
// Jiterator functions are guarded behind this macro
#ifdef USE_JITERATOR
template<char const *name,
typename result_type,
typename f_inputs_type,
@ -214,7 +217,9 @@ at::opmath_type<f_inputs_type> scalar_val) {
if (!fn_ptr->function) {
const std::lock_guard<std::mutex> lock{_jiterator_mutex};
if (!fn_ptr->function) {
if (!fn_ptr->function) { // cache miss!
// Generates program
constexpr int nTensors = array_t::size();
std::string string_name{name};
std::string f_inputs_type_str = at::cuda::jit::typeName<f_inputs_type>();
@ -226,6 +231,8 @@ at::opmath_type<f_inputs_type> scalar_val) {
scalar_pos,
vectorized, vec_size);
std::string kernel_name = vectorized ? string_name + "_vectorized" + std::to_string(vec_size) : string_name;
// Acquires the program
*fn_ptr = at::cuda::jit::jit_pwise_function(code, kernel_name);
}
}
@ -262,18 +269,6 @@ at::opmath_type<f_inputs_type> scalar_val) {
at::cuda::jit::launch_jitted_pwise_function(*fn_ptr, args, grid, num_threads());
C10_CUDA_KERNEL_LAUNCH_CHECK();
}
}
template<typename func_t, typename array_t, typename inp_calc_t, typename out_calc_t, typename loader_t, typename storer_t>
static inline void launch_unrolled_kernel(int64_t N, const func_t& f, array_t data,
inp_calc_t ic, out_calc_t oc, loader_t l, storer_t s)
{
TORCH_INTERNAL_ASSERT(N > 0 && N <= std::numeric_limits<int32_t>::max());
int64_t grid = (N + block_work_size() - 1) / block_work_size();
auto stream = at::cuda::getCurrentCUDAStream();
unrolled_elementwise_kernel<func_t, array_t><<<grid, num_threads(), 0, stream>>>(N, f, data, ic, oc, l, s);
C10_CUDA_KERNEL_LAUNCH_CHECK();
}
template <char const *name, typename result_type, typename compute_type, int arity,
@ -349,6 +344,18 @@ void jitted_gpu_kernel_impl(TensorIteratorBase& iter, const std::string& f, cons
iter.device().index(), numel, f, data, input_offset_calculator,
output_offset_calculator, loader, storer, contiguous, scalar_val);
}
#endif // USE_JITERATOR
template<typename func_t, typename array_t, typename inp_calc_t, typename out_calc_t, typename loader_t, typename storer_t>
static inline void launch_unrolled_kernel(int64_t N, const func_t& f, array_t data,
inp_calc_t ic, out_calc_t oc, loader_t l, storer_t s)
{
TORCH_INTERNAL_ASSERT(N > 0 && N <= std::numeric_limits<int32_t>::max());
int64_t grid = (N + block_work_size() - 1) / block_work_size();
auto stream = at::cuda::getCurrentCUDAStream();
unrolled_elementwise_kernel<func_t, array_t><<<grid, num_threads(), 0, stream>>>(N, f, data, ic, oc, l, s);
C10_CUDA_KERNEL_LAUNCH_CHECK();
}
template <typename func_t>
void gpu_kernel_impl(TensorIteratorBase& iter, const func_t& f) {

View File

@ -1,6 +1,7 @@
#pragma once
#include <ATen/jit_macros.h>
#include <ATen/detail/FunctionTraits.h>
#include <ATen/native/TensorIterator.h>
#include <ATen/native/TensorIteratorDynamicCasting.h>
@ -79,23 +80,27 @@ __device__ inline void elementwise_kernel_helper(func_t f, policy_t policy) {
// memory access introduce regression on ROCm.
#if !defined(USE_ROCM)
#include <ATen/native/cuda/CUDALoops.cuh>
#define USE_JITERATOR
#include <ATen/native/cuda/CUDALoops.cuh>
#else
#include <ATen/native/cuda/ROCmLoops.cuh>
#include <ATen/native/cuda/ROCmLoops.cuh>
#endif
namespace at { namespace native {
#ifdef USE_JITERATOR
#ifdef USE_JITERATOR
/* Note [Jiterator]
The "jiterator" simply just-in-time compiles the same kernels that
Loops.cuh (and CUDALoops.cuh) usually build. This reduces build time,
build size, and CUDA context size.
build size, and initial CUDA context size.
By default on non-Windows systems, it also caches compiled kernels in ~/.cache/torch/kernels.
This behavior is controlled with two environment variables:
- USE_PYTORCH_KERNEL_CACHE, if set to zero then this will disable all cache use
- PYTORCH_KERNEL_CACHE_PATH, if set specifies the folder to use for cached kernels
The jiterator currently has some limitations, however. It cannot:
- handle float16, bfloat16, or complex datatypes
- handle scalar inputs
- handle math on complex datatypes
- handle kernels with scalar parameters
These improvements will likely come soon.
@ -207,7 +212,7 @@ void opmath_jitted_gpu_kernel_with_scalars(TensorIteratorBase& iter, const std::
jitted_gpu_kernel<name, return_type, f_inputs_type, 2>(iter, f);
}
}
#endif //USE_JITERATOR
#endif // USE_JITERATOR
template <typename func_t>
void gpu_kernel(TensorIteratorBase& iter, const func_t& f) {

View File

@ -1,6 +1,7 @@
#pragma once
#include <ATen/AccumulateType.h>
#include <ATen/jit_macros.h>
#include <c10/macros/Macros.h>
#include <ATen/native/cuda/jit_utils.h>
@ -110,9 +111,7 @@ static inline C10_HOST_DEVICE scalar_t calc_i0(scalar_t _x) {
// See note [Jiterator]
// TODO: elaborate in this comment on the structure of math.cuh
// The jiterator is not currently supported on ROCm
// TODO: update this to USE_JITERATOR in the future (requires refactoring the macro)
#ifndef USE_ROCM
#ifdef USE_JITERATOR
const auto ndtri_string = jiterator_stringify(
/*
@ -1381,7 +1380,7 @@ const auto erfcx_string = jiterator_stringify(
}
); // erfcx_string
#else // USE_ROCM
#else // !USE_JITERATOR -- kernels must be precompiled
template <typename scalar_t>
static inline C10_HOST_DEVICE scalar_t calc_gcd(scalar_t a_in, scalar_t b_in) {
@ -1647,7 +1646,7 @@ static inline C10_HOST_DEVICE scalar_t calc_i1e(scalar_t _x) {
return (_x < scalar_t{0.0}) ? -out : out;
}
#endif // USE_ROCM (false/true)
#endif // USE_JITERATOR (this closes the "else" branch of a if/else preprocessor directive)
} // namespace native
} // namespace at

View File

@ -1,13 +1,28 @@
#include <sstream>
#include <c10/core/ScalarType.h>
#include <c10/util/irange.h>
#include <c10/util/hash.h>
#include <c10/util/Optional.h>
#include <c10/cuda/CUDACachingAllocator.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/detail/OffsetCalculator.cuh>
#include <ATen/code_template.h>
#include <ATen/native/cuda/jit_utils.h>
#include <sstream>
#include <fstream>
#include <cstdio>
#include <iterator> // istreambuf_iterator
#include <cstdlib>
#include <string>
#if BUILD_JITERATOR_WITH_CACHE
// Uses POSIX headers, which is why these are guarded behind BUILD_JITERATOR_WITH_CACHE
// TODO: C++17 has the fileystem header, which may replace these
#include <sys/types.h>
#include <sys/stat.h> // mkdir
#include <unistd.h>
#endif // BUILD_JITERATOR_WITH_CACHE
namespace at { namespace cuda { namespace jit {
@ -510,52 +525,58 @@ const at::cuda::NVRTC& nvrtc() {
// TODO refactor so this function is usable both from jit and from aten
void codegenOutputQuery(
const cudaDeviceProp* const prop,
int& major,
int& minor,
int& cuda_major,
int& cuda_minor,
int& nvrtc_major,
int& nvrtc_minor,
bool& compile_to_sass) {
using CudaVersion = std::pair<int, int>;
CudaVersion nvrtc_version;
AT_CUDA_NVRTC_CHECK(
nvrtc().nvrtcVersion(&nvrtc_version.first, &nvrtc_version.second));
AT_CUDA_NVRTC_CHECK(nvrtc().nvrtcVersion(&nvrtc_major, &nvrtc_minor));
TORCH_CHECK(
nvrtc_version.first >= 6,
"NVRTC versions less than 6 are not supported. Is: ",
nvrtc_version.first);
nvrtc_major >= 6, "NVRTC versions less than 6 are not supported. Is: ", nvrtc_major);
// Version supported by device
// Usually any lower version works too but is less efficient
const CudaVersion dev_version = CudaVersion(prop->major, prop->minor);
using CUDAVersion = std::pair<int, int>;
const CUDAVersion nvrtc_version{nvrtc_major, nvrtc_minor};
const CUDAVersion dev_version{prop->major, prop->minor};
// Maximum version supported by the driver, cap dev_version to this
CudaVersion max_dev_version;
if (nvrtc_version.first <= 7) { // 7 supports 2-5.x
max_dev_version = CudaVersion(5, 0);
} else if (nvrtc_version.first <= 8) { // 8 supports 2-6.x
max_dev_version = CudaVersion(6, 0);
} else if (nvrtc_version.first <= 9) { // 9 supports 3-7.2
max_dev_version = CudaVersion(7, 2);
} else if (nvrtc_version.first <= 10) { // 10 supports 3-7.5
max_dev_version = CudaVersion(7, 5);
} else if (nvrtc_version == CudaVersion(11, 0)) { // 11.0 supports 3-8.0
max_dev_version = CudaVersion(8, 0);
CUDAVersion max_dev_version;
if (nvrtc_major <= 7) { // 7 supports 2-5.x
max_dev_version = CUDAVersion(5, 0);
} else if (nvrtc_major <= 8) { // 8 supports 2-6.x
max_dev_version = CUDAVersion(6, 0);
} else if (nvrtc_major <= 9) { // 9 supports 3-7.2
max_dev_version = CUDAVersion(7, 2);
} else if (nvrtc_major <= 10) { // 10 supports 3-7.5
max_dev_version = CUDAVersion(7, 5);
} else if (nvrtc_version == CUDAVersion(11, 0)) { // 11.0 supports 3-8.0
max_dev_version = CUDAVersion(8, 0);
} else {
// If the driver version is unknown (i.e. newer than this code)
// assume the driver supports this device
max_dev_version = dev_version;
}
if (dev_version > max_dev_version) {
major = max_dev_version.first;
minor = max_dev_version.second;
cuda_major = max_dev_version.first;
cuda_minor = max_dev_version.second;
// if we are clamping major/minor, sass is not compatible
compile_to_sass = false;
} else {
major = dev_version.first;
minor = dev_version.second;
cuda_major = dev_version.first;
cuda_minor = dev_version.second;
compile_to_sass = true;
}
#if defined(CUDA_VERSION) && CUDA_VERSION < 11010
// compile to sass is not allowed prior to CUDA 11.1
compile_to_sass = false;
#endif
}
//TODO another copy paste from jit, refactor so it's usable from both
// TODO: another copy paste from jit, refactor so it's usable from both
// TODO: try making the CUcontext thread local to see if that improves performance - why is this slow?
void __inline__ initializeCudaContext() {
// lazily construct context if non-existing yet;
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
@ -644,11 +665,13 @@ std::string generate_code(
"loader", std::string("LoadWithCast<" + std::to_string(nInputs) + ">"));
env.s("storer", "StoreWithCast");
}
if (contiguous) {
env.s("offset_calculator", "TrivialOffsetCalculator");
} else {
env.s("offset_calculator", "OffsetCalculator");
}
std::stringstream load_inputs;
for (int i = 0; i < nInputs; i++) {
auto i_string = std::to_string(i);
@ -663,8 +686,10 @@ std::string generate_code(
store_outputs << "s.store<" << result_type
<< ">(out[j], data[0], output_offsets[0]);\n";
env.s("store_outputs", store_outputs.str());
static auto cuda_template = at::jit::CodeTemplate(jit_common_types + jit_code_template);
return cuda_template.format(env);
const auto code = cuda_template.format(env);
return code;
}
// vectorized case
@ -697,28 +722,148 @@ std::string generate_code(
env.s("load_unrolled_inputs", load_unrolled_inputs.str());
static auto cuda_template = at::jit::CodeTemplate(jit_common_types + jit_vectorized_code_template);
return cuda_template.format(env);
const auto code = cuda_template.format(env);
return code;
}
// Compiles the kernel
#ifdef BUILD_JITERATOR_WITH_CACHE
// Acquires (possibly creating) the kernel cache directory
c10::optional<std::string> get_cache_dir() {
// If the environment variable USE_TORCH_KERNEL_CACHE is set to "0" then no persistent cache is used
const char* uptkc = std::getenv("USE_PYTORCH_KERNEL_CACHE");
const bool use_kernel_cache = (uptkc == nullptr) ? true : std::strcmp(uptkc, "0");
if (!use_kernel_cache) {
return {};
}
// Cache path comes from PYTORCH_KERNEL_CACHE_PATH, then XDG_CACHE_HOME, then HOME environment variables
std::string cache_dir;
char* ptkcp = std::getenv("PYTORCH_KERNEL_CACHE_PATH");
if (ptkcp != nullptr) {
cache_dir = std::string(ptkcp);
} else {
// USES XDG_CACHE_HOME if it's set
ptkcp = std::getenv("XDG_CACHE_HOME");
if (ptkcp != nullptr) {
cache_dir = std::string(ptkcp) + "/torch/kernels";
} else {
// Falls back to HOME/.cache
ptkcp = std::getenv("HOME");
if (ptkcp == nullptr) {
TORCH_WARN_ONCE("No PYTORCH_KERNEL_CACHE_PATH or HOME environment variable set!",
" This disables kernel caching.");
return {};
} else {
cache_dir = std::string(ptkcp) + "/.cache/torch/kernels";
}
}
}
// Creates the cache directory if it does not exist
const char* p_cache_dir = cache_dir.c_str();
const bool cache_dir_exists = (access(p_cache_dir, F_OK) == 0);
if (!cache_dir_exists) {
if (mkdir(p_cache_dir, S_IRWXU | S_IRWXG | S_IRWXO) != 0) {
TORCH_WARN_ONCE("Specified kernel cache directory could not be created! This disables kernel caching.",
" Specified directory is ", cache_dir, ".",
" This warning will appear only once per process.");
return {};
}
}
// Checks that the cache directory is readable and writable
const bool cache_dir_readable = (access(p_cache_dir, R_OK) == 0);
if (!cache_dir_readable) {
TORCH_WARN_ONCE("Specified kernel cache directory is not readable! This disables kernel caching.",
" Specified directory is ", cache_dir, ".",
" This warning will appear only once per process.");
return {};
}
const bool cache_dir_writable = (access(p_cache_dir, W_OK) == 0);
if (!cache_dir_writable) {
TORCH_WARN_ONCE("Specified kernel cache directory is not writable! This disables kernel caching.",
" Specified directory is ", cache_dir, ".",
" This warning will appear only once per process.");
return {};
}
return cache_dir;
}
#endif // BUILD_JITERATOR_WITH_CACHE
// Compiles the kernel, or acquires if from the cache if caching
NvrtcFunction jit_pwise_function(
const std::string& code,
const std::string& kernel_name) {
// Acquires device and NVRTC properties (for compile arch and occupancy calculations)
initializeCudaContext();
// Acquires CUDA and nvrtc versions and whether we're compiling to ptx or SASS
const cudaDeviceProp* prop = at::cuda::getCurrentDeviceProperties();
int major = 0, minor = 0;
int cuda_major = 0, cuda_minor = 0, nvrtc_major = 0, nvrtc_minor = 0;
bool compile_to_sass = false;
codegenOutputQuery(prop, major, minor, compile_to_sass);
at::cuda::jit::codegenOutputQuery(
prop, cuda_major, cuda_minor, nvrtc_major, nvrtc_minor, compile_to_sass);
// Objects used whether loading from the cache or jit compiling
const auto& nvrtc = at::globalContext().getNVRTC();
NvrtcFunction compiled_kernel_;
std::string name = kernel_name + "_kernel";
#ifdef BUILD_JITERATOR_WITH_CACHE
static const c10::optional<std::string> cache_dir = get_cache_dir();
std::string file_path;
if (cache_dir.has_value()) {
// Attemps to read from the cache.
// Cubin name is <kernel name>_arch<major>.<minor>_nvrtc<major>.<minor>_<ptx or sass>_<program length>_<string hash>
// Note that the SHA1 hash used in the file name is NOT the SHA1 hash of the file's contents,
// because we hash on the CUDA code, but we save the compiled ptx or sass
// Acquires SHA1 hash
c10::sha1 sha1_hash{code};
const auto hash_code = sha1_hash.str();
// Constructs file path by appending constructed cubin name to cache path
std::stringstream ss;
ss << *cache_dir << "/";
ss << kernel_name;
ss << "_arch" << cuda_major << "." << cuda_minor;
ss << "_nvrtc" << nvrtc_major << "." << nvrtc_minor;
ss << (compile_to_sass ? "_sass" : "_ptx");
ss << "_" << code.length();
ss << "_" << hash_code;
file_path = ss.str();
std::ifstream readin{file_path, std::ios::in | std::ifstream::binary};
if (readin.fail()) {
// NOTE: this does not warn because the file might not exist
// TODO: consider if this should explicilty check for the file's existence or not to throw
// an informative warning
readin.close();
} else {
// TODO: try passing the "mapped" file directly to cuModuleLoadCall instead of using an intermediate buffer
std::vector<char> buffer(std::istreambuf_iterator<char>(readin), {});
AT_CUDA_DRIVER_CHECK(nvrtc.cuModuleLoadData(&(compiled_kernel_.module), buffer.data()));
AT_CUDA_DRIVER_CHECK(
nvrtc.cuModuleGetFunction(&(compiled_kernel_.function), compiled_kernel_.module, name.c_str()));
readin.close();
return compiled_kernel_;
}
}
#endif // BUILD_JITERATOR_WITH_CACHE
// Just-in-time compiles the program
// Creates the NVRTC program
nvrtcProgram program;
const auto& nvrtc = at::globalContext().getNVRTC();
AT_CUDA_NVRTC_CHECK(nvrtc.nvrtcCreateProgram(
&program, code.c_str(), nullptr, 0, nullptr, nullptr));
// constructs nvrtc build arguments
#if defined(CUDA_VERSION) && CUDA_VERSION < 11010
// compile to sass is not allowed prior to CUDA 11.1
compile_to_sass = false;
#endif
// Constructs nvrtc build arguments
// CUDA 11.1 allows going directly to SASS (sm_) instead of PTX (compute_)
// which gives better backwards compatibility to work on older driver,
// (since older driver doesn't necessrily recognize PTX emitted by new
@ -727,23 +872,24 @@ NvrtcFunction jit_pwise_function(
// `unsupported_arch==True`), since SASS are not necessarily compatible,
// we fallback to PTX instead.
const std::string compute = std::string("--gpu-architecture=") +
(compile_to_sass ? "sm_" : "compute_") + std::to_string(major) +
std::to_string(minor);
(compile_to_sass ? "sm_" : "compute_") + std::to_string(cuda_major) +
std::to_string(cuda_minor);
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
std::vector<const char*> args = {
"--std=c++14", compute.c_str(), "-default-device"};
#ifndef NDEBUG
// Add line info to generated kernels
args.push_back("-lineinfo");
#else
// Avoid excessive register usage from assertion
args.push_back("-DNDEBUG");
#endif
// compiles and validates result
initializeCudaContext();
#ifndef NDEBUG
// Add line info to generated kernels
args.push_back("-lineinfo");
#else
// Avoid excessive register usage from assertion
args.push_back("-DNDEBUG");
#endif
const auto compilation_result =
nvrtc.nvrtcCompileProgram(program, args.size(), args.data());
// Throws an error on compilation failure
if (compilation_result != NVRTC_SUCCESS) {
size_t logsize;
AT_CUDA_NVRTC_CHECK(nvrtc.nvrtcGetProgramLogSize(program, &logsize));
@ -753,9 +899,10 @@ NvrtcFunction jit_pwise_function(
cu << log.data();
throw std::runtime_error(cu.str() + code);
}
size_t ptx_size = 0;
std::vector<char> ptx;
#if defined(CUDA_VERSION) && CUDA_VERSION >= 11010
#if defined(CUDA_VERSION) && CUDA_VERSION >= 11010
// compile_to_sass determines whether we are generating SASS or PTX, hence
// the different API.
const auto getSize = compile_to_sass
@ -764,23 +911,54 @@ NvrtcFunction jit_pwise_function(
const auto getFunc = compile_to_sass
? at::globalContext().getNVRTC().nvrtcGetCUBIN
: at::globalContext().getNVRTC().nvrtcGetPTX;
#else
#else
const auto getSize = at::globalContext().getNVRTC().nvrtcGetPTXSize;
const auto getFunc = at::globalContext().getNVRTC().nvrtcGetPTX;
#endif
AT_CUDA_NVRTC_CHECK(getSize(program, &ptx_size));
ptx.resize(ptx_size);
AT_CUDA_NVRTC_CHECK(getFunc(program, ptx.data()));
#endif
NvrtcFunction compiled_kernel_;
AT_CUDA_DRIVER_CHECK(nvrtc.cuModuleLoadData(&(compiled_kernel_.module), ptx.data()));
std::string name = kernel_name + "_kernel";
AT_CUDA_DRIVER_CHECK(
nvrtc.cuModuleGetFunction(&(compiled_kernel_.function), compiled_kernel_.module, name.c_str()));
// TODO: use guards to avoid leaking
AT_CUDA_NVRTC_CHECK(nvrtc.nvrtcDestroyProgram(&program));
AT_CUDA_NVRTC_CHECK(getSize(program, &ptx_size));
ptx.resize(ptx_size);
AT_CUDA_NVRTC_CHECK(getFunc(program, ptx.data()));
return compiled_kernel_;
AT_CUDA_DRIVER_CHECK(nvrtc.cuModuleLoadData(&(compiled_kernel_.module), ptx.data()));
AT_CUDA_DRIVER_CHECK(
nvrtc.cuModuleGetFunction(&(compiled_kernel_.function), compiled_kernel_.module, name.c_str()));
// TODO: use guards to avoid leaking
AT_CUDA_NVRTC_CHECK(nvrtc.nvrtcDestroyProgram(&program));
#ifdef BUILD_JITERATOR_WITH_CACHE
if (cache_dir.has_value()) {
// Writes the program to the cache if caching
// NOTE: Actually writes to a per-process temporary file to avoid multi-process contention.
// The temporary file is then renamed to the actual file.
// If the actual file already exists then the rename may fail or replace the actual file,
// the behavior is implementation-specific.
// Files replaced through this process should remain extant if they are being read because
// of UNIX filesystem properties, but this behavior is unverified and may require
// additional review in the future.
// TODO: In C++17 we should be able to use the filesystem header.
const auto pid = getpid();
std::stringstream tmp_file_path_ss;
tmp_file_path_ss << file_path << "_tmp_" << pid;
const std::string tmp_file_path = tmp_file_path_ss.str();
std::ofstream cubin(tmp_file_path, std::ios::out | std::ofstream::binary);
if (cubin.fail()) {
TORCH_WARN_ONCE("Failed to write temporarily kernel cache file!",
" File path was ", tmp_file_path, ".",
" This warning will only appear once per process.");
} else {
std::copy(ptx.begin(), ptx.end(), std::ostreambuf_iterator<char>(cubin));
if (std::rename(tmp_file_path.c_str(), file_path.c_str()) != 0) {
// Removes tmp file if the rename failed
std::remove(tmp_file_path.c_str());
}
}
cubin.close();
}
#endif // BUILD_JITERATOR_WITH_CACHE
return compiled_kernel_;
}
// TODO: may need/want to initialize CUDA context here (refactor into nvrtc call)

View File

@ -1,12 +1,16 @@
#pragma once
#include <iomanip>
#include <functional>
#include <vector>
#include <sstream>
#include <c10/util/ArrayRef.h>
#include <c10/util/complex.h>
#include <functional>
#include <vector>
namespace c10 {
// NOTE: hash_combine is based on implementation from Boost
// NOTE: hash_combine and SHA1 hashing is based on implementation from Boost
//
// Boost Software License - Version 1.0 - August 17th, 2003
//
@ -36,6 +40,198 @@ inline size_t hash_combine(size_t seed, size_t value) {
return seed ^ (value + 0x9e3779b9 + (seed << 6u) + (seed >> 2u));
}
// Creates the SHA1 hash of a string. A 160-bit hash.
// Based on the implementation in Boost (see notice above).
// Note that SHA1 hashes are no longer considered cryptographically
// secure, but are the standard hash for generating unique ids.
// Usage:
// // Let 'code' be a std::string
// c10::sha1 sha1_hash{code};
// const auto hash_code = sha1_hash.str();
// TODO: Compare vs OpenSSL and/or CryptoPP implementations
struct sha1 {
typedef unsigned int(digest_type)[5];
sha1(const std::string &s = "") {
if (!s.empty()) {
reset();
process_bytes(s.c_str(), s.size());
}
}
void reset() {
h_[0] = 0x67452301;
h_[1] = 0xEFCDAB89;
h_[2] = 0x98BADCFE;
h_[3] = 0x10325476;
h_[4] = 0xC3D2E1F0;
block_byte_index_ = 0;
bit_count_low = 0;
bit_count_high = 0;
}
std::string str() {
unsigned int digest[5];
get_digest(digest);
std::ostringstream buf;
for (int i = 0; i < 5; ++i) {
buf << std::hex << std::setfill('0') << std::setw(8) << digest[i];
}
return buf.str();
}
private:
unsigned int left_rotate(unsigned int x, std::size_t n) {
return (x << n) ^ (x >> (32 - n));
}
void process_block_impl() {
unsigned int w[80];
for (std::size_t i = 0; i < 16; ++i) {
w[i] = (block_[i*4 + 0] << 24);
w[i] |= (block_[i*4 + 1] << 16);
w[i] |= (block_[i*4 + 2] << 8);
w[i] |= (block_[i*4 + 3]);
}
for (std::size_t i = 16; i < 80; ++i) {
w[i] = left_rotate((w[i-3] ^ w[i-8] ^ w[i-14] ^ w[i-16]), 1);
}
unsigned int a = h_[0];
unsigned int b = h_[1];
unsigned int c = h_[2];
unsigned int d = h_[3];
unsigned int e = h_[4];
for (std::size_t i = 0; i < 80; ++i) {
unsigned int f;
unsigned int k;
if (i<20) {
f = (b & c) | (~b & d);
k = 0x5A827999;
} else if (i<40) {
f = b ^ c ^ d;
k = 0x6ED9EBA1;
} else if (i<60) {
f = (b & c) | (b & d) | (c & d);
k = 0x8F1BBCDC;
} else {
f = b ^ c ^ d;
k = 0xCA62C1D6;
}
unsigned temp = left_rotate(a, 5) + f + e + k + w[i];
e = d;
d = c;
c = left_rotate(b, 30);
b = a;
a = temp;
}
h_[0] += a;
h_[1] += b;
h_[2] += c;
h_[3] += d;
h_[4] += e;
}
void process_byte_impl(unsigned char byte) {
block_[block_byte_index_++] = byte;
if (block_byte_index_ == 64) {
block_byte_index_ = 0;
process_block_impl();
}
}
void process_byte(unsigned char byte) {
process_byte_impl(byte);
// size_t max value = 0xFFFFFFFF
//if (bit_count_low + 8 >= 0x100000000) { // would overflow
//if (bit_count_low >= 0x100000000-8) {
if (bit_count_low < 0xFFFFFFF8) {
bit_count_low += 8;
} else {
bit_count_low = 0;
if (bit_count_high <= 0xFFFFFFFE) {
++bit_count_high;
} else {
TORCH_CHECK(false, "sha1 too many bytes");
}
}
}
void process_block(void const* bytes_begin, void const* bytes_end) {
unsigned char const* begin = static_cast<unsigned char const*>(bytes_begin);
unsigned char const* end = static_cast<unsigned char const*>(bytes_end);
for(; begin != end; ++begin) {
process_byte(*begin);
}
}
void process_bytes(void const* buffer, std::size_t byte_count) {
unsigned char const* b = static_cast<unsigned char const*>(buffer);
process_block(b, b + byte_count);
}
void get_digest(digest_type& digest) {
// append the bit '1' to the message
process_byte_impl(0x80);
// append k bits '0', where k is the minimum number >= 0
// such that the resulting message length is congruent to 56 (mod 64)
// check if there is enough space for padding and bit_count
if (block_byte_index_ > 56) {
// finish this block
while (block_byte_index_ != 0) {
process_byte_impl(0);
}
// one more block
while (block_byte_index_ < 56) {
process_byte_impl(0);
}
} else {
while (block_byte_index_ < 56) {
process_byte_impl(0);
}
}
// append length of message (before pre-processing)
// as a 64-bit big-endian integer
process_byte_impl( static_cast<unsigned char>((bit_count_high>>24) & 0xFF) );
process_byte_impl( static_cast<unsigned char>((bit_count_high>>16) & 0xFF) );
process_byte_impl( static_cast<unsigned char>((bit_count_high>>8 ) & 0xFF) );
process_byte_impl( static_cast<unsigned char>((bit_count_high) & 0xFF) );
process_byte_impl( static_cast<unsigned char>((bit_count_low>>24) & 0xFF) );
process_byte_impl( static_cast<unsigned char>((bit_count_low>>16) & 0xFF) );
process_byte_impl( static_cast<unsigned char>((bit_count_low>>8 ) & 0xFF) );
process_byte_impl( static_cast<unsigned char>((bit_count_low) & 0xFF) );
// get final digest
digest[0] = h_[0];
digest[1] = h_[1];
digest[2] = h_[2];
digest[3] = h_[3];
digest[4] = h_[4];
}
unsigned int h_[5];
unsigned char block_[64];
std::size_t block_byte_index_;
std::size_t bit_count_low;
std::size_t bit_count_high;
};
////////////////////////////////////////////////////////////////////////////////
// c10::hash implementation
////////////////////////////////////////////////////////////////////////////////

View File

@ -177,7 +177,7 @@ endif()
if(BUILD_SPLIT_CUDA)
# Splitting the source files that'll be in torch_cuda between torch_cuda_cu and torch_cuda_cpp
foreach(tmp ${Caffe2_GPU_SRCS})
if("${tmp}" MATCHES "(.*aten.*\\.cu|.*(b|B)las.*|.*((s|S)olver|Register.*CUDA|Legacy|THC|TensorShapeCUDA|BatchLinearAlgebra|ReduceOps|Equal|Activation|ScanKernels|Sort|TensorTopK|TensorModeKernel|IndexKernel).*\\.cpp)" AND NOT "${tmp}" MATCHES ".*(THC((CachingHost)?Allocator|General)).*")
if("${tmp}" MATCHES "(.*aten.*\\.cu|.*(b|B)las.*|.*((s|S)olver|Register.*CUDA|Legacy|THC|TensorShapeCUDA|BatchLinearAlgebra|ReduceOps|Equal|Activation|ScanKernels|Sort|TensorTopK|TensorModeKernel|IndexKernel|jit_utils).*\\.cpp)" AND NOT "${tmp}" MATCHES ".*(THC((CachingHost)?Allocator|General)).*")
# Currently, torch_cuda_cu will have all the .cu files in aten, as well as some others that depend on those files
list(APPEND Caffe2_GPU_SRCS_CU ${tmp})
else()