**Summary:** This PR contains the infrastructure of a new CUDA fuser. The fuser is based on many of the same principles as TensorExpressions and Halide, but the implementation is written from the ground up. The fusion pass itself is similar to the default CUDA fuser's, though it has undergone some refactoring and uses the new code generation infrastructure. For those interested in how the code generation in this PR works, I would recommend reviewing _test/cpp/jit/test_gpu_fusion.cpp_ as well as the long comment section at the beginning of _torch/csrc/jit/codegen/cuda/transform_replay.h_.

One of the largest differences between our approach and that of TVM/Halide is the concept of a "TensorView". At a high level, a TensorView should be thought of much like a Tensor in PyTorch: it is an N-D object that can undergo transformations which change its dimensionality. Dimensionality changes are performed through the operations split/merge/reorder/computeAt. These transformations are similar to split/fuse/reorder/compute_at in TVM; they modify how a tensor is iterated over to generate GPU code. Notably, in our scheme these transformations are applied to individual tensors and only affect how that tensor is generated (see the sketch after this summary).

**Warning:** This PR is purposefully not feature complete with the current fuser. We wanted to separate the infrastructure from the fusion capabilities. Once it lands, smaller incremental PRs will be submitted to expand the fuser's capabilities.

**Short-term goals:** parity with the current CUDA fuser (including performance):

- Dynamic shapes (no recompilation)
- Implicit handling of broadcast (broadcasted tensors are treated as tensors of the broadcasted size in the generated code)
- Dropout

**Mid-term goals:**

- Transposes fused with pointwise operations, where the transpose involves only 2 axes (across the fused operation)
- 1-D reductions fused with pointwise operations

Pull Request resolved: https://github.com/pytorch/pytorch/pull/34785

Reviewed By: ZolotukhinM

Differential Revision: D20650977

Pulled By: soumith

fbshipit-source-id: ee39c95a880e1b9822e874ed4cc180971572bf63
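A rough sketch of the TensorView scheduling style described above (illustrative only: the names and exact signatures below are assumptions, not code from this PR; the real usage is exercised in _test/cpp/jit/test_gpu_fusion.cpp_):

```cpp
// Hypothetical sketch; signatures are assumed.
// 'out' is the TensorView produced by a fused pointwise op; 'producer'
// is an intermediate TensorView feeding into it.
out->split(0, 128);              // axis 0 -> {outer, inner of extent 128}
out->reorder({{0, 1}, {1, 0}});  // swap the first two axes
producer->computeAt(out, 1);     // compute 'producer' inside out's outer loop
```

Each call changes only how `out` (or `producer`) is iterated over when code is generated; it does not change the logical values the tensor holds.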
248 lines
6.1 KiB
C++
#include <torch/csrc/jit/codegen/cuda/utils.h>

#include <c10/util/Exception.h>

#include <torch/csrc/jit/codegen/cuda/tensor_meta.h>

#include <algorithm>
#include <ostream>

namespace torch {
namespace jit {
namespace fuser {

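// NB: these helpers dereference the c10::optional metadata on
// c10::TensorType (sizes(), strides(), device(), scalarType(), dim(),
// numel()) without checking has_value(), so they are only safe to call
// on "complete" tensor types where all of that metadata is statically
// known (compare Value::isCompleteTensor() in printValue() below).
// Likewise, the scalar helpers read attr::value from the producing node
// and therefore expect Values produced by prim::Constant.
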
/*
 * Functions for printing ATen IR
 */

void printScalar(std::ostream& stream, const Value* const value) {
  if (value->node()->kind() == prim::Constant) {
    stream << "Const Scalar: ";
  } else {
    stream << "Scalar: ";
  }

  if (value->type() == FloatType::get()) {
    stream << "float ";
    const float val = value->node()->f(attr::value);
    stream << val;
  } else if (value->type() == IntType::get()) {
    stream << "int ";
    const int val = value->node()->i(attr::value);
    stream << val;
  } else {
    stream << "unknown";
  }
  stream << std::endl;
}

// Note: innermost dimension is at nDims - 1 (when nDims > 0)
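// (e.g. a contiguous tensor of sizes (2, 3) has strides (3, 1))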
void printStrides(std::ostream& stream, const c10::VaryingStrides& strides) {
  stream << "Strides=(";
  for (size_t i = 0; i < *(strides.size()); ++i) {
    stream << *(strides[i]);
    if (i != *(strides.size()) - 1) {
      stream << ", ";
    }
  }
  // Close outside the loop so 0-dim tensors still print a balanced ")"
  stream << ")";
}

void printSizes(std::ostream& stream, const c10::VaryingShape& sizes) {
  stream << "Sizes=(";
  for (size_t i = 0; i < *(sizes.size()); ++i) {
    stream << *(sizes[i]);
    if (i != *(sizes.size()) - 1) {
      stream << ", ";
    }
  }
  // Close outside the loop so 0-dim tensors still print a balanced ")"
  stream << ")";
}

void printCompleteTensor(
    std::ostream& stream,
    const std::shared_ptr<c10::TensorType>& tensor) {
  stream << "Complete Tensor: ";
  stream << *(tensor->device()) << " ";
  stream << *(tensor->scalarType()) << " ";
  stream << "nDims: " << *(tensor->dim()) << " ";
  stream << std::endl;
  printSizes(stream, tensor->sizes());
  stream << ", ";
  printStrides(stream, tensor->strides());
  stream << std::endl;
}

void printValue(std::ostream& stream, const Value* const value) {
  if (value->isCompleteTensor()) {
    printCompleteTensor(stream, value->type()->expect<TensorType>());
  } else if (value->type()->isSubtypeOf(NumberType::get())) {
    printScalar(stream, value);
  } else {
    stream << "Request to print unknown value" << std::endl;
  }
}

/*
 * Functions for acquiring devices and device types from ATen IR nodes
 */

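// Note: getFusionDevice() reads the device from the fusion's first
// output, so it assumes the node has at least one output of complete
// tensor type.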
c10::Device getFusionDevice(const Node* const fusion) {
  const std::shared_ptr<c10::TensorType> out_tensor =
      fusion->outputs()[0]->type()->expect<TensorType>();
  return *(out_tensor->device());
}

c10::DeviceType getFusionDeviceType(const Node* const node) {
  return getFusionDevice(node).type();
}

/*
 * Functions for obtaining parts of complete tensors
 */

std::vector<int64_t> extractStrides(
    const std::shared_ptr<c10::TensorType>& tensor) {
  const c10::VaryingStrides& strides = tensor->strides();
  const auto size = *(strides.size());
  std::vector<int64_t> extracted_strides;

  for (auto i = decltype(size){0}; i < size; ++i) {
    extracted_strides.push_back(*(strides[i]));
  }

  return extracted_strides;
}

std::vector<int64_t> extractSizes(
    const std::shared_ptr<c10::TensorType>& tensor) {
  const c10::VaryingShape& sizes = tensor->sizes();
  const auto size = *(sizes.size());
  std::vector<int64_t> extracted_sizes;

  for (auto i = decltype(size){0}; i < size; ++i) {
    extracted_sizes.push_back(*(sizes[i]));
  }

  return extracted_sizes;
}

c10::DeviceType getDeviceType(const std::shared_ptr<c10::TensorType>& tensor) {
  return (*(tensor->device())).type();
}

size_t getRank(const std::shared_ptr<c10::TensorType>& tensor) {
  return *(tensor->dim());
}

size_t getNumel(const std::shared_ptr<c10::TensorType>& tensor) {
  return *(tensor->numel());
}

/*
 * Functions for working with scalar Values
 */

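// Illustrative example (hypothetical value): for a graph constant such as
//   %1 : float = prim::Constant[value=0.5]()
// getFloat() on %1 yields 0.5, getInt() on %1 yields c10::nullopt, and
// getAsFloat() on %1 yields 0.5f.
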
bool isScalar(const Value* const value) {
  return value->type()->isSubtypeOf(NumberType::get());
}

c10::optional<float> getFloat(const Value* const value) {
  if (value->type() == FloatType::get()) {
    return value->node()->f(attr::value);
  }

  return c10::nullopt;
}

c10::optional<int> getInt(const Value* const value) {
  if (value->type() == IntType::get()) {
    return value->node()->i(attr::value);
  }

  return c10::nullopt;
}

float getAsFloat(const Value* const value) {
  if (value->type() == FloatType::get()) {
    return value->node()->f(attr::value);
  }
  if (value->type() == IntType::get()) {
    return static_cast<float>(value->node()->i(attr::value));
  }

  TORCH_CHECK(false, "getAsFloat() found unknown scalar type!");
}

/*
 * Functions for comparing complete tensors
 */

bool haveSameDevice(
    const std::shared_ptr<c10::TensorType>& lhs,
    const std::shared_ptr<c10::TensorType>& rhs) {
  const auto lhs_device = *(lhs->device());
  const auto rhs_device = *(rhs->device());
  return (lhs_device == rhs_device);
}

bool haveSameScalarType(
    const std::shared_ptr<c10::TensorType>& lhs,
    const std::shared_ptr<c10::TensorType>& rhs) {
  const auto lhs_scalar_type = *(lhs->scalarType());
  const auto rhs_scalar_type = *(rhs->scalarType());
  return (lhs_scalar_type == rhs_scalar_type);
}

bool haveSameSizes(
    const std::shared_ptr<c10::TensorType>& lhs,
    const std::shared_ptr<c10::TensorType>& rhs) {
  const auto& lhs_sizes = lhs->sizes();
  const auto& rhs_sizes = rhs->sizes();

  if (*(lhs_sizes.size()) != *(rhs_sizes.size())) {
    return false;
  }

  for (size_t i = 0; i < *(lhs_sizes.size()); ++i) {
    if (*(lhs_sizes[i]) != *(rhs_sizes[i])) {
      return false;
    }
  }

  return true;
}

bool haveSameStrides(
    const std::shared_ptr<c10::TensorType>& lhs,
    const std::shared_ptr<c10::TensorType>& rhs) {
  const auto& lhs_strides = lhs->strides();
  const auto& rhs_strides = rhs->strides();

  if (*(lhs_strides.size()) != *(rhs_strides.size())) {
    return false;
  }

  for (size_t i = 0; i < *(lhs_strides.size()); ++i) {
    if (*(lhs_strides[i]) != *(rhs_strides[i])) {
      return false;
    }
  }

  return true;
}

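// Note: despite the name, this also requires matching device and scalar
// type, not just matching sizes and strides.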
bool haveSameShape(
    const std::shared_ptr<c10::TensorType>& lhs,
    const std::shared_ptr<c10::TensorType>& rhs) {
  return (
      haveSameDevice(lhs, rhs) && haveSameScalarType(lhs, rhs) &&
      haveSameSizes(lhs, rhs) && haveSameStrides(lhs, rhs));
}

} // namespace fuser
} // namespace jit
} // namespace torch