pytorch/torch/csrc/jit/codegen/cuda/scheduler.h
Jane Xu 533cb9530e Introducing TORCH_CUDA_CPP_API and TORCH_CUDA_CU_API to the code (#50627)
Summary:
Sub-step of my attempt to split up the torch_cuda library, as it is huge. Please look at https://github.com/pytorch/pytorch/issues/49050 for details on the split and which files are in which target.

This PR introduces two new macros, TORCH_CUDA_CPP_API and TORCH_CUDA_CU_API, for Windows DLL export/import purposes. Both are defined as TORCH_CUDA_API for the time being.
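As a rough sketch, "defined as TORCH_CUDA_API for the time being" can be thought of as the following (the real definitions live in the export-macro headers and may be written differently):

    // Both new macros currently alias the existing TORCH_CUDA_API, so
    // Windows DLL import/export behavior is unchanged until the split lands.
    #define TORCH_CUDA_CPP_API TORCH_CUDA_API
    #define TORCH_CUDA_CU_API TORCH_CUDA_API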

Pull Request resolved: https://github.com/pytorch/pytorch/pull/50627

Reviewed By: mruberry

Differential Revision: D25955441

Pulled By: janeyx99

fbshipit-source-id: ff226026833b8fb2fb7c77df6f2d6c824f006869
2021-01-21 19:09:11 -08:00

#pragma once
#include <ATen/core/ivalue.h>
#include <torch/csrc/jit/codegen/cuda/executor_launch_params.h>
#include <torch/csrc/jit/codegen/cuda/fusion.h>
#include <torch/csrc/jit/codegen/cuda/ir_all_nodes.h>
namespace torch {
namespace jit {
namespace fuser {
namespace cuda {
// Returns true if the given fusion could be scheduled, false otherwise.
TORCH_CUDA_CU_API bool scheduleFusion(
    Fusion* fusion,
    const at::ArrayRef<c10::IValue> inputs);
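// A rough usage sketch (hypothetical caller code, not part of this header;
// assumes `fusion` has been built and `inputs` holds the runtime IValues):
//
//   if (scheduleFusion(&fusion, inputs)) {
//     // the fusion IR has been scheduled and can be compiled
//   }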
// Parameters the reduction heuristic generates to describe the optimal
// schedule. Warning: the equality operator is intended for caching the kernel
// associated with these reduction parameters. It does not check whether the
// launch parameters are equivalent!
struct ReductionParams {
  // Reducing innermost dimension?
  bool fastest_dim = true;
  // Reduce across the block?
  bool cross_block = false;
  // Reduce across the grid?
  bool cross_grid = false;
  // Perform multiple reductions per block?
  bool mul_reds_per_blk = false;
  // Unrolling factor
  int loop_unroll = 4;

  LaunchParams lparams;

  // Warning: Does not check launch parameters!
  bool operator==(const ReductionParams& other) const {
    bool attr_equal = other.fastest_dim == fastest_dim &&
        other.cross_block == cross_block && other.cross_grid == cross_grid &&
        other.mul_reds_per_blk == mul_reds_per_blk &&
        other.loop_unroll == loop_unroll;
    return attr_equal;
  }
};
// Warning: Hash is not based on launch parameters!
class ReductionParamsHash {
 public:
  size_t operator()(const ReductionParams& rp) const {
    constexpr size_t bits = sizeof(std::size_t) * 8;
    size_t attr_hash = static_cast<size_t>(rp.fastest_dim) << (bits - 1) |
        static_cast<size_t>(rp.cross_block) << (bits - 2) |
        static_cast<size_t>(rp.cross_grid) << (bits - 3) |
        static_cast<size_t>(rp.mul_reds_per_blk) << (bits - 4);
    return attr_hash;
  }
};
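// A minimal sketch of the intended caching pattern (illustrative only; the
// mapped type `CompiledKernel` is a placeholder, not a real type in this
// codebase):
//
//   std::unordered_map<ReductionParams, CompiledKernel, ReductionParamsHash>
//       kernel_cache;
//   auto it = kernel_cache.find(rp); // lookup uses the hash above and
//                                    // operator==, so launch parameters
//                                    // are ignored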
TORCH_CUDA_CU_API c10::optional<ReductionParams> getReductionHeuristics(
    Fusion* fusion,
    const at::ArrayRef<c10::IValue>& fusion_inputs,
    TensorView* red_tv);
TORCH_CUDA_CU_API void scheduleReduction(
    Fusion* fusion,
    const ReductionParams& rparams,
    TensorView* red_tv,
    std::vector<TensorView*> outs_of_red);
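// A rough end-to-end sketch for the reduction path (hypothetical caller code;
// `red_tv` is the reduction TensorView and `outs_of_red` its dependent
// outputs):
//
//   auto rparams = getReductionHeuristics(&fusion, fusion_inputs, red_tv);
//   if (rparams.has_value()) {
//     scheduleReduction(&fusion, rparams.value(), red_tv, outs_of_red);
//     // rparams->lparams then provides the launch configuration
//   }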
} // namespace cuda
} // namespace fuser
} // namespace jit
} // namespace torch