pytorch/torch/csrc/jit/codegen/cuda/lower2device.cpp
Christian Sarofeen b9b4f05abf [nvFuser] Working towards reductions, codegen improvements (#40864)
Summary:
Basic reduction fusion is now working, and the code generator has been improved to approach the performance of eager-mode reductions. Pointwise-reduction fusions are coming soon, implemented in a way that should prevent the possibility of hitting regressions. We are also working on performant softmax kernels in the code generator, which may be our next fusion target.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/40864

Reviewed By: ngimel

Differential Revision: D22392877

Pulled By: soumith

fbshipit-source-id: 457448a807d628b1035f6d90bc0abe8a87bf8447
2020-07-06 14:52:49 -07:00

50 lines
1.5 KiB
C++

#include <torch/csrc/jit/codegen/cuda/fusion.h>
#include <torch/csrc/jit/codegen/cuda/ir_iostream.h>
#include <torch/csrc/jit/codegen/cuda/lower_index.h>
#include <torch/csrc/jit/codegen/cuda/lower_loops.h>
#include <torch/csrc/jit/codegen/cuda/lower_thread_predicate.h>
#include <torch/csrc/jit/codegen/cuda/lower_unroll.h>
#include <torch/csrc/jit/codegen/cuda/lower_utils.h>
#include <torch/csrc/jit/codegen/cuda/lower_validation.h>
#include <torch/csrc/jit/codegen/cuda/lower2device.h>
namespace torch {
namespace jit {
namespace fuser {
// Runs the lowering passes over fusion_ and returns the resulting expressions.
// This function does not print anything itself; printKernel() feeds its result
// to an IRPrinter.
//
// Pass order: PrepareForLowering -> ThreadPredicates::compute ->
// LoopNestGenerator::getLoopNest -> UnrollPass::runPass ->
// IndexLowering::getIndexedExprs. Each pass consumes the previous pass's
// output.
std::vector<Expr*> GPULower::getLoweredExprs() {
  // Scoped guard built from fusion_ for the duration of lowering.
  FusionGuard fusion_scope(fusion_);

  // Validate the fusion and apply minor modifications ahead of code
  // generation.
  PrepareForLowering(fusion_);

  // The thread predicates are shared by the loop-nest and unroll passes.
  auto thread_predicates = ThreadPredicates::compute(fusion_);

  // Build the loop nest from the fusion's expressions.
  auto fusion_exprs = fusion_->exprs(true, false, true);
  auto nested = LoopNestGenerator::getLoopNest(
      fusion_, fusion_exprs, thread_predicates);

  // Forward the loop nest through the unroll pass, then lower indexing.
  auto unrolled = UnrollPass::runPass(fusion_, nested, thread_predicates);
  return IndexLowering::getIndexedExprs(fusion_, unrolled);
}
std::ostream& GPULower::printKernel(
std::ostream& os,
const std::string& kernel_name) {
FusionGuard fg(fusion_);
auto exprs = getLoweredExprs();
IRPrinter irp(os);
irp.printKernel(exprs, kernel_name);
return os;
}
} // namespace fuser
} // namespace jit
} // namespace torch