mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Summary: Basic reduction fusion is now working, and the code generator has been improved to approach the performance of eager-mode reductions. Pointwise-reduction fusions are coming soon, implemented in a way that should prevent regressions. Performant softmax kernels in the code generator are also in progress and may be our next fusion target. Pull Request resolved: https://github.com/pytorch/pytorch/pull/40864 Reviewed By: ngimel Differential Revision: D22392877 Pulled By: soumith fbshipit-source-id: 457448a807d628b1035f6d90bc0abe8a87bf8447
50 lines
1.5 KiB
C++
50 lines
1.5 KiB
C++
#include <torch/csrc/jit/codegen/cuda/fusion.h>
|
|
#include <torch/csrc/jit/codegen/cuda/ir_iostream.h>
|
|
#include <torch/csrc/jit/codegen/cuda/lower_index.h>
|
|
#include <torch/csrc/jit/codegen/cuda/lower_loops.h>
|
|
#include <torch/csrc/jit/codegen/cuda/lower_thread_predicate.h>
|
|
#include <torch/csrc/jit/codegen/cuda/lower_unroll.h>
|
|
#include <torch/csrc/jit/codegen/cuda/lower_utils.h>
|
|
#include <torch/csrc/jit/codegen/cuda/lower_validation.h>
|
|
|
|
#include <torch/csrc/jit/codegen/cuda/lower2device.h>
|
|
|
|
namespace torch {
|
|
namespace jit {
|
|
namespace fuser {
|
|
|
|
// Traverse through the fusion and return the lowered expressions from which
// the CUDA code associated with it is printed
|
|
std::vector<Expr*> GPULower::getLoweredExprs() {
|
|
FusionGuard fg(fusion_);
|
|
|
|
// Validate and make some minor modifications in preparation to generate code.
|
|
PrepareForLowering(fusion_);
|
|
|
|
auto preds = ThreadPredicates::compute(fusion_);
|
|
|
|
// Run our passes keeping the lowered expressions and forwarding them.
|
|
auto loop_nests = LoopNestGenerator::getLoopNest(
|
|
fusion_, fusion_->exprs(true, false, true), preds);
|
|
|
|
auto unrolled_loops = UnrollPass::runPass(fusion_, loop_nests, preds);
|
|
|
|
auto indexed_loops = IndexLowering::getIndexedExprs(fusion_, unrolled_loops);
|
|
|
|
return indexed_loops;
|
|
}
|
|
|
|
std::ostream& GPULower::printKernel(
|
|
std::ostream& os,
|
|
const std::string& kernel_name) {
|
|
FusionGuard fg(fusion_);
|
|
auto exprs = getLoweredExprs();
|
|
|
|
IRPrinter irp(os);
|
|
irp.printKernel(exprs, kernel_name);
|
|
return os;
|
|
}
|
|
|
|
} // namespace fuser
|
|
} // namespace jit
|
|
} // namespace torch
|