Summary: We've got quite a few things going on, preparing a push back to upstream so we don't get too desynced.

- Major refactor of transform replay. It is now far more robust and fixes bugs discovered in reductions. Preparing for extension to explicit broadcast ops, which will be the last major memory pattern for op coverage. Broadcast ops will allow us to express up to and potentially beyond norms and gemms.
- Initial runtime expression evaluator. This allows us to evaluate expressions at runtime. It will be useful for determining our grid/block layout at runtime, so we don't have to compute it manually according to the code we're trying to generate.
- Moving to int64 and double for scalar representations to match PyTorch JIT.
- Improvements in the codegen interface, where we return a Tensor-like object instead of the parent class Val.
- Add `addcmul` and `lerp` ops.
- General updates, fixes, test additions, and test improvements.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/39579
Differential Revision: D21974001
Pulled By: soumith
fbshipit-source-id: 7f7ccc91593466e948f3ce90f8f9b7fbc5c28de2
104 lines · 3.0 KiB · C++
#pragma once

#include <ATen/core/ivalue.h>
#include <ATen/cuda/CUDAContext.h>
#include <torch/csrc/WindowsTorchApiMacro.h>

#include <torch/csrc/jit/codegen/cuda/fusion.h>

/*
 * The exposed APIs in this file are used by manager.h/cpp.
 *
 * The code here handles CUDA code generation and execution from Fusion IR.
 * NVRTC is used for kernel compilation; the CUDA Driver API is used to load
 * and execute the compiled kernel.
 *
 * A stringify trick is used to unify the IO data structure for kernel
 * execution. We stringify the data structure and embed it directly in the
 * generated CUDA source to avoid a runtime search for header files.
 * The header file is included twice: once as C++ code so that host code can
 * prepare the IO data, and once for stringification.
 */
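
// As an illustration of the double-include pattern above, here is a minimal
// sketch. The wrapper macro and file names are assumptions for illustration,
// not the actual ones used by this codebase:
//
//   // io_types.h -- deliberately has no include guard; included twice.
//   IO_TYPES_WRAPPER(
//     struct TensorArg {
//       void* data;
//       int64_t ndim;
//     };
//   )
//
//   // host.cpp
//   // First include: expand to real C++ code so the host can prepare IO.
//   #define IO_TYPES_WRAPPER(...) __VA_ARGS__
//   #include "io_types.h"
//   #undef IO_TYPES_WRAPPER
//
//   // Second include: stringify the same text so it can be prepended to the
//   // generated CUDA source, avoiding a runtime header search.
//   #define IO_TYPES_WRAPPER(...) #__VA_ARGS__
//   static const char* io_types_src =
//   #include "io_types.h"
//       ;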

namespace torch {
namespace jit {
namespace fuser {
namespace cuda {

// TODO: given that KernelArgsReq is becoming complicated and not really
// hashable, we should fold it into CudaKernel.
// Interfacing object that lets a kernel report whether a given input
// configuration could/should be handled.
struct KernelArgsReq {
  virtual bool matchKernelSize(const at::ArrayRef<c10::IValue> inputs) = 0;
  virtual ~KernelArgsReq() = default;
};

// A naive pointwise kernel only requires matching dimensionality across
// input tensors.
struct NaivePWKernelArgsReq : KernelArgsReq {
  bool matchKernelSize(const at::ArrayRef<c10::IValue> inputs) override;
  std::vector<int> dims_;
};
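
// A hedged sketch of what the naive check above could look like; the real
// definition lives in the corresponding .cpp file, so this is illustrative
// only:
//
//   bool NaivePWKernelArgsReq::matchKernelSize(
//       const at::ArrayRef<c10::IValue> inputs) {
//     if (inputs.size() != dims_.size()) {
//       return false;
//     }
//     for (size_t i = 0; i < inputs.size(); i++) {
//       // Reject non-tensor inputs and tensors whose rank differs from the
//       // rank recorded when the kernel was compiled.
//       if (!inputs[i].isTensor() ||
//           inputs[i].toTensor().dim() != dims_[i]) {
//         return false;
//       }
//     }
//     return true;
//   }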

struct CudaKernel {
 public:
  CudaKernel() {
    fusion_ = std::make_unique<Fusion>();
  }

  CUmodule& getModule() {
    return module_;
  }

  CUfunction& getFunction() {
    return function_;
  }

  int16_t device_;
  CUmodule module_;
  CUfunction function_;
  int max_blocks_;
  int unroll_factor_ = 1;

  // WARNING:
  // The block and grid dimension setters below exist for testing purposes
  // only. They are not for general use and should only be used with the
  // runTestKernel() function.
  void block(unsigned int x = 1, unsigned int y = 1, unsigned int z = 1) {
    block_ = dim3(x, y, z);
  }
  void grid(unsigned int x = 1, unsigned int y = 1, unsigned int z = 1) {
    grid_ = dim3(x, y, z);
  }

  dim3 block_;
  dim3 grid_;
  bool has_random_;

  std::unique_ptr<Fusion> fusion_;
};
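
// Example of the intended test-only flow (values are hypothetical):
//
//   CudaKernel entry;
//   // ... lower a fusion into entry.fusion_ ...
//   compileKernel(&entry);
//   entry.block(128);  // 128 threads per block
//   entry.grid(64);    // 64 blocks
//   runTestKernel(&entry, inputs, outputs);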

// Compile a Fusion to a CUDA function:
// 1. JIT compilation via NVRTC to generate CUDA C++ kernel code;
// 2. the CUDA Driver API to load the compiled kernel code as function_.
TORCH_CUDA_API void compileKernel(CudaKernel* entry);
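
// Roughly the flow one would expect inside compileKernel (simplified, error
// handling omitted; kernel_src and kernel_name are placeholders, not names
// defined by this codebase):
//
//   nvrtcProgram program;
//   nvrtcCreateProgram(
//       &program, kernel_src, "kernel.cu", 0, nullptr, nullptr);
//   nvrtcCompileProgram(program, 0, nullptr);
//   size_t ptx_size;
//   nvrtcGetPTXSize(program, &ptx_size);
//   std::vector<char> ptx(ptx_size);
//   nvrtcGetPTX(program, ptx.data());
//   nvrtcDestroyProgram(&program);
//   // Load the PTX through the CUDA Driver API and grab the entry point.
//   cuModuleLoadData(&entry->getModule(), ptx.data());
//   cuModuleGetFunction(
//       &entry->getFunction(), entry->getModule(), kernel_name);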

// Run the loaded kernel through its CUfunction.
// inputs/outputs are given in the sense of a PyTorch JIT IR node. This
// function wraps the IO data structure for tensors on the host.
TORCH_CUDA_API void runKernel(
    CudaKernel* entry,
    const at::ArrayRef<c10::IValue> inputs,
    std::vector<at::Tensor> outputs);
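
// Hypothetical call site; inputs arrive from the JIT runtime as IValues:
//
//   std::vector<c10::IValue> inputs{at::randn({8, 8}, at::kCUDA)};
//   std::vector<at::Tensor> outputs{at::empty({8, 8}, at::kCUDA)};
//   runKernel(&entry, inputs, outputs);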

// Facility API to run a kernel in tests.
TORCH_CUDA_API void runTestKernel(
    CudaKernel* entry,
    const at::ArrayRef<c10::IValue> inputs,
    std::vector<at::Tensor> outputs);

} // namespace cuda
} // namespace fuser
} // namespace jit
} // namespace torch