mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
Summary: Fix another simplification edge case: a Cond statement where one branch is nullptr and the other is a zero-statement block. This happens mostly with an `if` that has no `else` branch, when all statements inside the `if` are removed (e.g. via inlining or simplification). A common case is SplitWithMask -> ComputeInline. Pull Request resolved: https://github.com/pytorch/pytorch/pull/39754 Differential Revision: D21962987 Pulled By: nickgg fbshipit-source-id: 2461415466fbbab88d2329061f90fcfdfa85e243
345 lines
15 KiB
C++
345 lines
15 KiB
C++
#pragma once
|
|
|
|
/**
|
|
* See README.md for instructions on how to add a new test.
|
|
*/
|
|
#include <c10/macros/Export.h>
|
|
#include <torch/csrc/WindowsTorchApiMacro.h>
|
|
|
|
namespace torch {
namespace jit {

// X-macro lists of the C++ tensor-expression tests.
//
// Each list takes a function-like macro `_` and applies it to every test
// name as `_(Name)`. This header expands the lists with
// DECLARE_TENSOREXPR_TEST (further down) to declare one `void test<Name>();`
// per entry. Presumably the test driver expands the same lists to register
// and run the tests; TODO confirm against the callers.
//
// Maintenance notes:
//  * Every entry except the last must end with a line-continuation backslash.
//    A blank or stray line inside a list ends the #define early and leaves
//    the remaining `_(...)` entries as invalid tokens at namespace scope.
//  * Do not put `//` comments inside a list. Line splicing happens before
//    comment removal, so such a comment ending in a backslash would swallow
//    the next entry. Use `/* ... */` if a comment is ever needed there.
//  * See README.md for instructions on how to add a new test.

// Tests declared unconditionally, independent of the build configuration.
#define TH_FORALL_TENSOREXPR_TESTS(_) \
  _(ExprBasicValueTest) \
  _(ExprBasicValueTest02) \
  _(ExprLetTest01) \
  _(ExprLetStmtTest01) \
  _(ExprLetTest02) \
  _(ExprIntTest) \
  _(ExprFloatTest) \
  _(ExprByteTest) \
  _(ExprCharTest) \
  _(ExprShortTest) \
  _(ExprLongTest) \
  _(ExprHalfTest) \
  _(ExprDoubleTest) \
  _(ExprVectorAdd01) \
  _(ExprCompareSelectEQ) \
  _(ExprSubstitute01) \
  _(ExprMath01) \
  _(ExprUnaryMath01) \
  _(ExprBinaryMath01) \
  _(ExprDynamicShapeAdd) \
  _(ExprBitwiseOps) \
  _(IRPrinterBasicValueTest) \
  _(IRPrinterBasicValueTest02) \
  _(IRPrinterCastTest) \
  _(IRPrinterFunctionName) \
  _(ExprSimple01) \
  _(ExprLower01) \
  _(ExprSimple02) \
  _(ExprSplitWithTail) \
  _(ExprSplitWithTailNone) \
  _(ExprSplitWithMask01) \
  _(ScheduleBroadcastAddBuffer) \
  _(ScheduleFunctionCall01) \
  _(ScheduleInlineFunc01) \
  _(ScheduleFuserStyle) \
  _(ScheduleFuserThreeArg) \
  _(ScheduleDynamicShape2D) \
  _(ReduceSum1D) \
  _(ReduceSum2D) \
  _(ReduceSum3D) \
  _(ReduceSum10D) \
  _(ReduceProduct) \
  _(ReduceMax) \
  _(ReduceMinCustomInitializer) \
  _(ReduceAnyAll) \
  _(ReduceMatmul2D) \
  _(ReduceRfactorLike) \
  _(ReduceRfactor) \
  _(Reduce3DRfactor) \
  _(Reduce3DRfactor2) \
  _(Reduce3DRfactor3) \
  _(Reduce3DRfactorWithOuter) \
  _(Reduce3DRfactorRepeated) \
  _(ReduceRfactorInsertionPoint) \
  _(Reduce3DRfactorInsertionPoint) \
  _(ReduceRepeatedInternalRfactor) \
  _(ReduceSplitTail) \
  _(ReduceSplitNoTail) \
  _(ReduceOverSplitTail) \
  _(ReduceSplitMask) \
  _(ReduceSplitNoMask) \
  _(ReduceOverSplitMask) \
  _(ReduceSplitRfactor) \
  _(ReduceOverSplitRfactor) \
  _(SplitReduceAxis) \
  _(SplitNonReduceAxis) \
  _(TypeTest01) \
  _(TypePropagation) \
  _(Cond01) \
  _(IfThenElse01) \
  _(IfThenElse02) \
  _(ATen_cast_Float) \
  _(ATennegInt) \
  _(ATennegFloat) \
  _(ATenaddInt) \
  _(ATenaddFloat) \
  _(ATensubInt) \
  _(ATensubFloat) \
  _(ATenlerp) \
  _(ATenaddcmulInt) \
  _(ATenaddcmulFloat) \
  _(ATenmulInt) \
  _(ATenmulFloat) \
  _(ATendivInt) \
  _(ATendivFloat) \
  _(ATenmaxInt) \
  _(ATenmaxFloat) \
  _(ATenminInt) \
  _(ATenminFloat) \
  _(ATen_sigmoid_backward) \
  _(ATen_tanh_backward) \
  _(ATenreciprocal) \
  _(ATenreluInt) \
  _(ATenreluFloat) \
  _(ATenlogFloat) \
  _(ATenlog10Float) \
  _(ATenlog2Float) \
  _(ATenexpFloat) \
  _(ATenerfFloat) \
  _(ATencosFloat) \
  _(ATeneqInt) \
  _(ATengeInt) \
  _(ATengtInt) \
  _(ATenleInt) \
  _(ATenltInt) \
  _(ConstantFoldSimple) \
  _(ConstantFoldTwoLayer) \
  _(ConstantFoldShifts) \
  _(ConstantFoldBitwise) \
  _(ConstantFoldMultiOp) \
  _(ConstantFoldMinMax) \
  _(ConstantFoldIntrinsics) \
  _(ConstantFoldWithVar) \
  _(UnFoldableExpr) \
  _(HashSimple) \
  _(HashEquivalence) \
  _(HashEquivalenceAfterFolding) \
  _(HashDifferenceTypes) \
  _(HashLargeExpression) \
  _(HashForLoopOptions) \
  _(SimplifyAdd) \
  _(SimplifySub) \
  _(SimplifyMultiLayer) \
  _(SimplifyMultiTerm) \
  _(SimplifyCasts) \
  _(SimplifyEliminatesNoOps) \
  _(SimplifyMultiVar) \
  _(SimplifyEliminatesVar) \
  _(SimplifyAdds) \
  _(SimplifyMuls) \
  _(SimplifySubs) \
  _(SimplifyMultiOp) \
  _(SimplifyManyOps) \
  _(SimplifyFactorization) \
  _(SimplifyFactorizeUneven) \
  _(SimplifyDeeperTerms) \
  _(SimplifyDeeperDifference) \
  _(SimplifyFoldComplexDifference) \
  _(SimplifyIfComponents) \
  _(SimplifyOpaqueTerms) \
  _(SimplifyWontReorderFloat) \
  _(SimplifyRoundModPattern) \
  _(SimplifyRoundModPatternFactorization) \
  _(SimplifyRoundModPatternMultivar) \
  _(SimplifyDivisionScalarFactorization) \
  _(SimplifyConstantBranches) \
  _(SimplifyConstantCond) \
  _(SimplifyEliminateEmptyCond) \
  _(SimplifyEliminateZeroLengthFor) \
  _(SimplifyOneLoopFor) \
  _(SimplifyForWontLoseLoopOptions) \
  _(SimplifyMultilevelFor) \
  _(SimplifyForCleansUp) \
  _(SimplifyEliminateEmptyFor) \
  _(SimplifyFlattenBlock) \
  _(SimplifyEliminateZeroLengthAlloc) \
  _(StmtClone) \
  _(BoundsInference_1) \
  _(BoundsInference_2) \
  _(BoundsInference_3) \
  _(BoundsInference_4) \
  _(BoundsInference_5) \
  _(BoundsInference_6) \
  _(LoopNestComputeAt_1) \
  _(LoopNestComputeAt_2) \
  _(LoopNestComputeAt_3) \
  _(LoopNestComputeAt_4) \
  _(LoopNestReorderAxis1) \
  _(LoopNestReorderPartialAxes) \
  _(LoopNestReorderInternalAxis) \
  _(LoopNestReorderEnclosingAxis) \
  _(LoopNestReorderSameAxis) \
  _(LoopNestReorderExtraStatements) \
  _(LoopNestReorderLongStringOfPreOrphans) \
  _(LoopNestReorderLongStringOfPostOrphans) \
  _(LoopNestReorderLongStringFull) \
  _(LoopNestReorderInternalLoopNest) \
  _(OuterLoopVectorization) \
  _(Kernel_1) \
  _(Kernel_2) \
  _(Kernel_3) \
  _(FuserPass_1) \
  _(FuserPass_2)

// LLVM-backend tests (names prefixed with LLVM). This header declares them
// only when TORCH_ENABLE_LLVM is defined.
#define TH_FORALL_TENSOREXPR_TESTS_LLVM(_) \
  _(LLVMByteImmTest) \
  _(LLVMCharImmTest) \
  _(LLVMShortImmTest) \
  _(LLVMIntImmTest) \
  _(LLVMLongImmTest) \
  _(LLVMFloatImmTest) \
  _(LLVMDoubleImmTest) \
  _(LLVMHalfImmTest) \
  _(LLVMByteAddTest) \
  _(LLVMCharAddTest) \
  _(LLVMShortAddTest) \
  _(LLVMIntAddTest) \
  _(LLVMLongAddTest) \
  _(LLVMFloatAddTest) \
  _(LLVMDoubleAddTest) \
  _(LLVMHalfAddTest) \
  _(LLVMByteSubTest) \
  _(LLVMCharSubTest) \
  _(LLVMShortSubTest) \
  _(LLVMIntSubTest) \
  _(LLVMLongSubTest) \
  _(LLVMFloatSubTest) \
  _(LLVMDoubleSubTest) \
  _(LLVMHalfSubTest) \
  _(LLVMByteMulTest) \
  _(LLVMCharMulTest) \
  _(LLVMShortMulTest) \
  _(LLVMIntMulTest) \
  _(LLVMLongMulTest) \
  _(LLVMFloatMulTest) \
  _(LLVMDoubleMulTest) \
  _(LLVMHalfMulTest) \
  _(LLVMByteDivTest) \
  _(LLVMCharDivTest) \
  _(LLVMShortDivTest) \
  _(LLVMIntDivTest) \
  _(LLVMLongDivTest) \
  _(LLVMFloatDivTest) \
  _(LLVMDoubleDivTest) \
  _(LLVMHalfDivTest) \
  _(LLVMIntToFloatCastTest) \
  _(LLVMFloatToIntCastTest) \
  _(LLVMIntToLongCastTest) \
  _(LLVMByteToCharCastTest) \
  _(LLVMHalfToLongCastTest) \
  _(LLVMByteToDoubleCastTest) \
  _(LLVMLetTest01) \
  _(LLVMLetTest02) \
  _(LLVMLetTestMultitype) \
  _(LLVMBufferTest) \
  _(LLVMBlockTest) \
  _(LLVMLoadStoreTest) \
  _(LLVMVecLoadStoreTest) \
  _(LLVMVecFloat_acosLane4Test) \
  _(LLVMVecFloat_asinLane4Test) \
  _(LLVMVecFloat_atanLane4Test) \
  _(LLVMVecFloat_coshLane4Test) \
  _(LLVMVecFloat_sinhLane4Test) \
  _(LLVMVecFloat_tanhLane4Test) \
  _(LLVMVecFloat_erfLane4Test) \
  _(LLVMVecFloat_erfcLane4Test) \
  _(LLVMVecFloat_expm1Lane4Test) \
  _(LLVMVecFloat_lgammaLane4Test) \
  _(LLVMVecFloat_acosLane8Test) \
  _(LLVMVecFloat_asinLane8Test) \
  _(LLVMVecFloat_atanLane8Test) \
  _(LLVMVecFloat_coshLane8Test) \
  _(LLVMVecFloat_sinhLane8Test) \
  _(LLVMVecFloat_tanhLane8Test) \
  _(LLVMVecFloat_erfLane8Test) \
  _(LLVMVecFloat_erfcLane8Test) \
  _(LLVMVecFloat_expm1Lane8Test) \
  _(LLVMVecFloat_lgammaLane8Test) \
  _(LLVMVecDouble_acosLane2Test) \
  _(LLVMVecDouble_asinLane2Test) \
  _(LLVMVecDouble_atanLane2Test) \
  _(LLVMVecDouble_coshLane2Test) \
  _(LLVMVecDouble_sinhLane2Test) \
  _(LLVMVecDouble_tanhLane2Test) \
  _(LLVMVecDouble_erfLane2Test) \
  _(LLVMVecDouble_erfcLane2Test) \
  _(LLVMVecDouble_expm1Lane2Test) \
  _(LLVMVecDouble_lgammaLane2Test) \
  _(LLVMVecDouble_acosLane4Test) \
  _(LLVMVecDouble_asinLane4Test) \
  _(LLVMVecDouble_atanLane4Test) \
  _(LLVMVecDouble_coshLane4Test) \
  _(LLVMVecDouble_sinhLane4Test) \
  _(LLVMVecDouble_tanhLane4Test) \
  _(LLVMVecDouble_erfLane4Test) \
  _(LLVMVecDouble_erfcLane4Test) \
  _(LLVMVecDouble_expm1Lane4Test) \
  _(LLVMVecDouble_lgammaLane4Test) \
  _(LLVMMemcpyTest) \
  _(LLVMBzeroTest) \
  _(LLVMElemwiseAdd) \
  _(LLVMElemwiseAddFloat) \
  _(LLVMElemwiseLog10Float) \
  _(LLVMElemwiseMaxInt) \
  _(LLVMElemwiseMinInt) \
  _(LLVMElemwiseMaxNumFloat) \
  _(LLVMElemwiseMaxNumNaNFloat) \
  _(LLVMElemwiseMinNumFloat) \
  _(LLVMElemwiseMinNumNaNFloat) \
  _(LLVMCompareSelectIntEQ) \
  _(LLVMCompareSelectFloatEQ) \
  _(LLVMStoreFloat) \
  _(LLVMSimpleMath01) \
  _(LLVMComputeMul) \
  _(LLVMBroadcastAdd) \
  _(LLVMBitwiseOps) \
  _(LLVMDynamicShapeAdd) \
  _(LLVMBindDynamicShapeAdd) \
  _(LLVMTensorDynamicShapeAdd) \
  _(LLVMDynamicShape2D) \
  _(LLVMEmptyStmt) \
  _(LLVMEliminatedStmt) \
  _(LLVMIfThenElseTest) \
  _(LLVMVectorizerLoadStoreTest) \
  _(LLVMSimpleReduction) \
  _(LLVMRFactorReduction) \
  _(LLVMRFactorVectorizedReduction)

// CUDA tests (names prefixed with Cuda). This header declares them only when
// USE_CUDA is defined.
#define TH_FORALL_TENSOREXPR_TESTS_CUDA(_) \
  _(CudaTestVectorAdd01) \
  _(CudaTestVectorAdd02) \
  _(CudaDynamicShape2D) \
  _(CudaDynamicShapeSplit) \
  _(CudaOneBlockOneThreadGlobalReduce1) \
  _(CudaOneBlockMultiThreadGlobalReduce1) \
  _(CudaNoThreadIdxWrite_1) \
  _(CudaSharedMemReduce_1) \
  _(CudaLocalMemReduce_1) \
  _(CudaTestRand01)

// Expands one list entry `Name` into the declaration `void testName();`.
// The definitions are not in this header; they are expected elsewhere.
#define DECLARE_TENSOREXPR_TEST(name) void test##name();
TH_FORALL_TENSOREXPR_TESTS(DECLARE_TENSOREXPR_TEST)
#ifdef TORCH_ENABLE_LLVM
TH_FORALL_TENSOREXPR_TESTS_LLVM(DECLARE_TENSOREXPR_TEST)
#endif
#ifdef USE_CUDA
TH_FORALL_TENSOREXPR_TESTS_CUDA(DECLARE_TENSOREXPR_TEST)
#endif
// Undefine the helper so it does not leak into files that include this header.
#undef DECLARE_TENSOREXPR_TEST

} // namespace jit
} // namespace torch
|