#include #include #include #include #include #include #include #include #include #include namespace torch { namespace jit { namespace fuser { namespace cuda { namespace { // Will return a new value of type val with the DataType dtype. Val* newScalar(ValType vtype, DataType dtype) { switch (vtype) { case (ValType::NamedScalar): case (ValType::Scalar): switch (dtype) { case DataType::Bool: return IrBuilder::create(); case DataType::Double: case DataType::Float: case DataType::Half: case DataType::BFloat16: return IrBuilder::create(); case DataType::Int32: case DataType::Int: return IrBuilder::create(); case DataType::ComplexFloat: case DataType::ComplexDouble: return IrBuilder::create(); default: break; } default: break; } TORCH_CHECK( false, "Cannot handle ValType: ", vtype, " with DataType: ", dtype, " in newScalar."); } TensorView* newOutputTV(const std::vector& vals, DataType dtype) { std::vector tvs; for (auto val : vals) if (val->getValType() == ValType::TensorView) tvs.push_back(val->as()); TORCH_CHECK( !tvs.empty(), "Tried to create new output TensorView but received empty list."); std::vector out_domain( TensorDomain::noReductions(tvs[0]->getMaybeRFactorDomain()).size(), nullptr); // For the start and stop offsets, take the maximum of input axes. // For now, the offsets of both start and stop are always integer // constant, so we can statically compute them. It is unclear // whether we would need to support dynamic offsetting, e.g., // shifting by a dynamic offset. std::vector start_offsets(out_domain.size(), 0); std::vector stop_offsets(out_domain.size(), 0); std::vector extent_vals(out_domain.size(), nullptr); std::vector iter_types(out_domain.size(), IterType::Iteration); for (auto tv : tvs) { auto dom = TensorDomain::noReductions(tv->getMaybeRFactorDomain()); TORCH_INTERNAL_ASSERT( dom.size() == out_domain.size(), "Invalid tensor view found while producing an output, it has ", dom.size(), " dimensions but expected ", out_domain.size()); for (const auto i : c10::irange(dom.size())) { if (dom[i]->isBroadcast()) { continue; } if (extent_vals[i] == nullptr) { extent_vals[i] = dom[i]->extent(); iter_types[i] = dom[i]->getIterType(); } auto start_offset = dom[i]->start()->as(); auto stop_offset = dom[i]->stopOffset()->as(); // Currently, start is always constant TORCH_INTERNAL_ASSERT( start_offset->isConst(), "Invalid IterDomain start: ", start_offset); TORCH_INTERNAL_ASSERT( stop_offset->isConst(), "Invalid IterDomain stop offset: ", stop_offset); start_offsets[i] = std::max(start_offsets[i], start_offset->value().value()); stop_offsets[i] = std::max(stop_offsets[i], stop_offset->value().value()); } } for (const auto dim_i : c10::irange(out_domain.size())) { if (extent_vals[dim_i] != nullptr) { out_domain[dim_i] = IrBuilder::create( IrBuilder::create(start_offsets[dim_i]), extent_vals[dim_i], IrBuilder::create(stop_offsets[dim_i]), ParallelType::Serial, iter_types[dim_i]); } else { IterType itype = IterType::BroadcastWithoutStride; for (const auto tv : tvs) { auto dim = TensorDomain::noReductions(tv->getMaybeRFactorDomain())[dim_i]; // If there's an unresolved bcast dim and it came from a strided dim, // assume output of it should be strided too if (dim->getIterType() == IterType::BroadcastWithStride) { itype = IterType::BroadcastWithStride; break; } } out_domain[dim_i] = IrBuilder::create( FusionGuard::getCurFusion()->zeroVal(), FusionGuard::getCurFusion()->oneVal(), ParallelType::Serial, itype); } } return IrBuilder::create( IrBuilder::create( out_domain,
std::vector(out_domain.size(), true)), dtype); } std::vector maybeBroadcast(const std::vector& vals) { std::vector out_vals(vals.size(), nullptr); size_t n_dims = 0; for (auto val : vals) { if (val->getValType().value() == ValType::TensorView) { n_dims = std::max( n_dims, TensorDomain::noReductions( val->as()->getMaybeRFactorDomain()) .size()); } } for (const auto i : c10::irange(vals.size())) { if (vals[i]->getValType().value() == ValType::TensorView) { auto tv = vals[i]->as(); size_t tv_dims = TensorDomain::noReductions(tv->getMaybeRFactorDomain()).size(); if (tv_dims < n_dims) { std::vector bcast_flags(n_dims, false); for (const auto j : c10::irange(n_dims - tv_dims)) { bcast_flags[j] = true; } out_vals[i] = broadcast(tv, bcast_flags); } else { out_vals[i] = vals[i]; } } else { out_vals[i] = vals[i]; } } return out_vals; } Val* newValLike(Val* val, DataType dtype) { TORCH_CHECK( dtype != DataType::Null, "Invalid datatype provided for new value."); const ValType vtype = val->getValType().value(); if (vtype == ValType::TensorView) return newOutputTV({val}, dtype); return newScalar(vtype, dtype); } } // namespace Val* castOp(DataType dtype, Val* v1) { if (v1->getDataType().value() == dtype) { return set(v1); } if (cast_func_str(std::make_pair(v1->getDataType().value(), dtype)) == c10::nullopt) { TORCH_CHECK( false, "Illegal Cast value from DataType: ", v1->getDataType().value(), " to DataType: ", dtype); } Val* out = newValLike(v1, dtype); IrBuilder::create(UnaryOpType::Cast, out, v1); return out; } TensorView* castOp(DataType dtype, TensorView* v1) { return castOp(dtype, v1->as())->as(); } Val* unaryOp(UnaryOpType type, Val* v1) { TORCH_INTERNAL_ASSERT( type != UnaryOpType::Address, "The reference operator & is not accessible in the Fusion IR"); // TODO: We should add the following, but we need to go through schedulers // and make sure all calls to "fusion->inputs" includes the output of RandLike // // If rand like, there isn't a real dependency on the input value, so map it // to a dummy scalar. 
if // // (type == UnaryOpType::RandLike) { // v1 = new NamedScalar("__rnd", v1->getDataType().value()); // } Val* out = newValLike(v1, v1->getDataType().value()); IrBuilder::create(type, out, v1); return out; } TensorView* unaryOp(UnaryOpType type, TensorView* v1) { return unaryOp(type, v1->as())->as(); } Val* unaryOp(UnaryOpType type, Val* v1, const TypePromotionConfig& config) { auto casted_v1 = promoteValues(config, {v1}).front(); return unaryOp(type, casted_v1); } TensorView* unaryOp( UnaryOpType type, TensorView* v1, const TypePromotionConfig& config) { auto casted_v1 = promoteValues(config, {v1}).front(); return unaryOp(type, casted_v1)->as(); } // UNARY OPERATIONS #define NVFUSER_DEFINE_UNARY_OP(op_name, op_type) \ Val* op_name(Val* v) { \ return unaryOp(UnaryOpType::op_type, v); \ } \ TensorView* op_name(TensorView* tv) { \ return unaryOp(UnaryOpType::op_type, tv); \ } NVFUSER_DEFINE_UNARY_OP(set, Set) NVFUSER_DEFINE_UNARY_OP(randlike, RandLike) NVFUSER_DEFINE_UNARY_OP(notOp, Not) NVFUSER_DEFINE_UNARY_OP(ceil, Ceil) NVFUSER_DEFINE_UNARY_OP(floor, Floor) NVFUSER_DEFINE_UNARY_OP(frac, Frac) NVFUSER_DEFINE_UNARY_OP(neg, Neg) NVFUSER_DEFINE_UNARY_OP(relu, Relu) NVFUSER_DEFINE_UNARY_OP(round, Round) NVFUSER_DEFINE_UNARY_OP(silu, Silu) NVFUSER_DEFINE_UNARY_OP(trunc, Trunc) #undef NVFUSER_DEFINE_UNARY_OP // The output of abs(complex_tensor) are real numbers Val* abs(Val* v) { if (v->getDataType() == DataType::ComplexDouble) { Val* out = newValLike(v, DataType::Double); IrBuilder::create(UnaryOpType::Abs, out, v); return out; } if (v->getDataType() == DataType::ComplexFloat) { Val* out = newValLike(v, DataType::Float); IrBuilder::create(UnaryOpType::Abs, out, v); return out; } return unaryOp(UnaryOpType::Abs, v); } TensorView* abs(TensorView* tv) { return abs(tv->as())->as(); } // UNARY FLOAT CAST OPERATIONS #define NVFUSER_DEFINE_UNARY_FLOAT_OP(op_name, op_type) \ Val* op_name(Val* v) { \ return unaryOp(UnaryOpType::op_type, v, TypePromotion::float_op_config); \ } \ TensorView* op_name(TensorView* tv) { \ return unaryOp(UnaryOpType::op_type, tv, TypePromotion::float_op_config); \ } NVFUSER_DEFINE_UNARY_FLOAT_OP(acos, Acos) NVFUSER_DEFINE_UNARY_FLOAT_OP(asin, Asin) NVFUSER_DEFINE_UNARY_FLOAT_OP(atan, Atan) NVFUSER_DEFINE_UNARY_FLOAT_OP(atanh, Atanh) NVFUSER_DEFINE_UNARY_FLOAT_OP(cos, Cos) NVFUSER_DEFINE_UNARY_FLOAT_OP(cosh, Cosh) NVFUSER_DEFINE_UNARY_FLOAT_OP(exp, Exp) NVFUSER_DEFINE_UNARY_FLOAT_OP(expm1, Expm1) NVFUSER_DEFINE_UNARY_FLOAT_OP(erf, Erf) NVFUSER_DEFINE_UNARY_FLOAT_OP(erfc, Erfc) NVFUSER_DEFINE_UNARY_FLOAT_OP(lgamma, Lgamma) NVFUSER_DEFINE_UNARY_FLOAT_OP(log, Log) NVFUSER_DEFINE_UNARY_FLOAT_OP(log10, Log10) NVFUSER_DEFINE_UNARY_FLOAT_OP(log1p, Log1p) NVFUSER_DEFINE_UNARY_FLOAT_OP(log2, Log2) NVFUSER_DEFINE_UNARY_FLOAT_OP(reciprocal, Reciprocal) NVFUSER_DEFINE_UNARY_FLOAT_OP(rsqrt, Rsqrt) NVFUSER_DEFINE_UNARY_FLOAT_OP(sigmoid, Sigmoid) NVFUSER_DEFINE_UNARY_FLOAT_OP(sin, Sin) NVFUSER_DEFINE_UNARY_FLOAT_OP(sinh, Sinh) NVFUSER_DEFINE_UNARY_FLOAT_OP(sqrt, Sqrt) NVFUSER_DEFINE_UNARY_FLOAT_OP(tan, Tan) NVFUSER_DEFINE_UNARY_FLOAT_OP(tanh, Tanh) #undef NVFUSER_DEFINE_UNARY_FLOAT_OP // BINARY OPERATIONS namespace { // Helper function to reduce repetitive code template TensorView* arithOpOverloads(Val* (*func)(Val*, Val*), T1* v1, T2* v2) { return func(v1->template as(), v2->template as()) ->template as(); } template TensorView* arithOpOverloads( BinaryOpType type, T1* v1, T2* v2, DataType common_dtype) { return binaryOp( type, v1->template as(), v2->template as(), common_dtype) 
->template as(); } template TensorView* arithOpOverloads( Val* (*func)(Val*, Val*, Val*), T1* v1, T2* v2, T3* v3) { auto vals = maybeBroadcast({v1, v2, v3}); return func( vals[0]->template as(), vals[1]->template as(), vals[2]->template as()) ->template as(); } template TensorView* arithOpOverloads( Val* (*func)(Val*, Val*, Val*, Val*), T1* v1, T2* v2, T3* v3, T4* v4) { auto vals = maybeBroadcast({v1, v2, v3, v4}); return func( vals[0]->template as(), vals[1]->template as(), vals[2]->template as(), vals[3]->template as()) ->template as(); } // Output type promotion logic for binary operators DataType getOutputType( BinaryOpType op_type, Val* v1, Val* v2, DataType common_dtype) { if (isLogicalOp(op_type)) { return DataType::Bool; } else if (common_dtype == DataType::Null) { return promote_type(v1->getDataType().value(), v2->getDataType().value()); } else { return common_dtype; } } } // namespace Val* binaryOp(BinaryOpType type, Val* v1, Val* v2, DataType common_dtype) { const auto out_dtype = getOutputType(type, v1, v2, common_dtype); const auto out_vtype = promote_type(v1->getValType().value(), v2->getValType().value()); auto vals = maybeBroadcast({v1, v2}); Val* out = nullptr; if (out_vtype == ValType::TensorView) { out = newOutputTV(vals, out_dtype); } else { out = newScalar(out_vtype, out_dtype); } IrBuilder::create(type, out, vals[0], vals[1]); return out; } TensorView* binaryOp( BinaryOpType type, TensorView* v1, Val* v2, DataType common_dtype) { return arithOpOverloads(type, v1, v2, common_dtype); } TensorView* binaryOp( BinaryOpType type, Val* v1, TensorView* v2, DataType common_dtype) { return arithOpOverloads(type, v1, v2, common_dtype); } TensorView* binaryOp( BinaryOpType type, TensorView* v1, TensorView* v2, DataType common_dtype) { return arithOpOverloads(type, v1, v2, common_dtype); } Val* binaryOp( BinaryOpType type, Val* v1, Val* v2, const TypePromotionConfig& config) { std::vector operands = {v1, v2}; auto common_dtype = computeTypes(config, operands); auto casted_values = promoteValues(operands, common_dtype); return binaryOp( type, casted_values.front(), casted_values.back(), common_dtype); } TensorView* binaryOp( BinaryOpType type, TensorView* v1, Val* v2, const TypePromotionConfig& config) { std::vector operands = {v1, v2}; auto common_dtype = computeTypes(config, operands); auto casted_values = promoteValues(operands, common_dtype); return binaryOp( type, casted_values.front()->as(), casted_values.back(), common_dtype); } TensorView* binaryOp( BinaryOpType type, Val* v1, TensorView* v2, const TypePromotionConfig& config) { std::vector operands = {v1, v2}; auto common_dtype = computeTypes(config, operands); auto casted_values = promoteValues(operands, common_dtype); return binaryOp( type, casted_values.front(), casted_values.back()->as(), common_dtype); } TensorView* binaryOp( BinaryOpType type, TensorView* v1, TensorView* v2, const TypePromotionConfig& config) { std::vector operands = {v1, v2}; auto common_dtype = computeTypes(config, operands); auto casted_values = promoteValues(operands, common_dtype); return binaryOp( type, casted_values.front()->as(), casted_values.back()->as(), common_dtype); } #define NVFUSER_DEFINE_BINARY_FLOAT_OP(op_name, op_type) \ Val* op_name(Val* v1, Val* v2) { \ return binaryOp( \ BinaryOpType::op_type, v1, v2, TypePromotion::float_op_config); \ } \ TensorView* op_name(TensorView* v1, Val* v2) { \ return binaryOp( \ BinaryOpType::op_type, v1, v2, TypePromotion::float_op_config); \ } \ TensorView* op_name(Val* v1, TensorView* v2) { \ 
return binaryOp( \ BinaryOpType::op_type, v1, v2, TypePromotion::float_op_config); \ } \ TensorView* op_name(TensorView* v1, TensorView* v2) { \ return binaryOp( \ BinaryOpType::op_type, v1, v2, TypePromotion::float_op_config); \ } NVFUSER_DEFINE_BINARY_FLOAT_OP(div, Div) NVFUSER_DEFINE_BINARY_FLOAT_OP(atan2, Atan2) #undef NVFUSER_DEFINE_BINARY_FLOAT_OP #define NVFUSER_DEFINE_BINARY_CAST_OP(op_name, op_type) \ Val* op_name(Val* v1, Val* v2) { \ return binaryOp( \ BinaryOpType::op_type, v1, v2, TypePromotion::default_op_config); \ } \ TensorView* op_name(TensorView* v1, Val* v2) { \ return binaryOp( \ BinaryOpType::op_type, v1, v2, TypePromotion::default_op_config); \ } \ TensorView* op_name(Val* v1, TensorView* v2) { \ return binaryOp( \ BinaryOpType::op_type, v1, v2, TypePromotion::default_op_config); \ } \ TensorView* op_name(TensorView* v1, TensorView* v2) { \ return binaryOp( \ BinaryOpType::op_type, v1, v2, TypePromotion::default_op_config); \ } // Integer binary ops NVFUSER_DEFINE_BINARY_CAST_OP(mod, Mod) NVFUSER_DEFINE_BINARY_CAST_OP(ceilDiv, CeilDiv) NVFUSER_DEFINE_BINARY_CAST_OP(add, Add) NVFUSER_DEFINE_BINARY_CAST_OP(fmod, Fmod) NVFUSER_DEFINE_BINARY_CAST_OP(mul, Mul) NVFUSER_DEFINE_BINARY_CAST_OP(pow, Pow) NVFUSER_DEFINE_BINARY_CAST_OP(remainder, Remainder) NVFUSER_DEFINE_BINARY_CAST_OP(sub, Sub) NVFUSER_DEFINE_BINARY_CAST_OP(lshift, Lshift) NVFUSER_DEFINE_BINARY_CAST_OP(rshift, Rshift) NVFUSER_DEFINE_BINARY_CAST_OP(andOp, And) NVFUSER_DEFINE_BINARY_CAST_OP(orOp, Or) NVFUSER_DEFINE_BINARY_CAST_OP(xorOp, Xor) #undef NVFUSER_DEFINE_BINARY_CAST_OP #define NVFUSER_DEFINE_BINARY_COMPARE_OP(op_name, op_type) \ Val* op_name(Val* v1, Val* v2) { \ return binaryOp( \ BinaryOpType::op_type, v1, v2, TypePromotion::comparison_op_config); \ } \ TensorView* op_name(TensorView* v1, Val* v2) { \ return binaryOp( \ BinaryOpType::op_type, v1, v2, TypePromotion::comparison_op_config); \ } \ TensorView* op_name(Val* v1, TensorView* v2) { \ return binaryOp( \ BinaryOpType::op_type, v1, v2, TypePromotion::comparison_op_config); \ } \ TensorView* op_name(TensorView* v1, TensorView* v2) { \ return binaryOp( \ BinaryOpType::op_type, v1, v2, TypePromotion::comparison_op_config); \ } // Logical binary ops NVFUSER_DEFINE_BINARY_COMPARE_OP(eq, Eq) NVFUSER_DEFINE_BINARY_COMPARE_OP(ge, GE) NVFUSER_DEFINE_BINARY_COMPARE_OP(gt, GT) NVFUSER_DEFINE_BINARY_COMPARE_OP(le, LE) NVFUSER_DEFINE_BINARY_COMPARE_OP(lt, LT) NVFUSER_DEFINE_BINARY_COMPARE_OP(ne, NE) #undef NVFUSER_DEFINE_BINARY_COMPARE_OP // REDUCTION OPERATIONS // TODO: How do we adjust this so we can reduce to a single scalar value? static TensorView* newForReduction( TensorView* tv, const std::vector& axes, DataType data_type = DataType::Null) { auto orig_domain = TensorDomain::noReductions(tv->getMaybeRFactorDomain()); std::set axes_set(axes.begin(), axes.end()); std::vector new_domain; TORCH_INTERNAL_ASSERT( !axes_set.empty(), "Asked for output of reduction, but no reduction axis provided."); TORCH_INTERNAL_ASSERT( (*(axes_set.rbegin())) < orig_domain.size(), "Error setting up reduction, reduction axis (", *(axes_set.rbegin()), ") is outside nDims (", orig_domain.size(), "). 
Keep in mind reductions are relative to root domains, not modified views."); auto axis_iter = axes_set.begin(); for (const auto dim : c10::irange(orig_domain.size())) { bool isReduction = false; if (axis_iter != axes_set.end() && *axis_iter == dim) { isReduction = true; axis_iter++; } const IterDomain* id = orig_domain[dim]; TORCH_CHECK( !(isReduction && id->isBroadcast() && !id->isImplicitBroadcast()), "Cannot reduce an axis that is marked as broadcasted as it has an undetermined size. Tried to reduce ID = ", id, " of tensor ", tv); new_domain.push_back(IrBuilder::create( id->start(), id->extent(), id->stopOffset(), ParallelType::Serial, isReduction ? IterType::Reduction : id->getIterType())); } TensorDomain* td = IrBuilder::create( new_domain, std::vector(new_domain.size(), true)); data_type = data_type == DataType::Null ? tv->getDataType().value() : data_type; return IrBuilder::create(td, data_type); } TensorView* reductionOp( BinaryOpType reduction_op_type, const std::vector& axes, Val* init, TensorView* tv, bool keep_dim /*=false*/) { TORCH_CHECK( init->isConstScalar(), "Cannot create a reduction operation where the initial value is not a const scalar."); TORCH_CHECK( TensorDomain::sameAs(tv->getMaybeRFactorDomain(), tv->domain()->domain()), "Reducing a tensor once it's gone under transformations is not permitted at this time. Please set reductions before calling split/merge/computeAt."); TORCH_CHECK(tv->nDims() > 0, "Tried to reduce a 0-dim tensor"); TORCH_CHECK(axes.size() > 0, "No reduction axis specified"); std::vector uint_axes; const int ndims = tv->domain()->noReductions().size(); for (int axis : axes) { if (axis < 0) { axis += ndims; } TORCH_CHECK( axis >= 0 && axis < ndims, "Reduction on invalid axis, received: ", axis, " however tensor view only has ", ndims, " non-reduction dims."); uint_axes.push_back((unsigned int)axis); } TensorView* out = newForReduction(tv, uint_axes); const auto out_type = out->getDataType().value(); const auto init_type = init->getDataType().value(); TORCH_CHECK( (isFloatingPointType(out_type) && isFloatingPointType(init_type)) || (isComplexType(out_type) && isComplexType(init_type)) || (isIntegralType(out_type) && isIntegralType(init_type)) || (isBooleanType(out_type) && isBooleanType(init_type)), "Types should match for reduction ops but received: ", out_type, " and ", init_type); IrBuilder::create(reduction_op_type, init, out, tv); if (keep_dim) { auto tv_root = TensorDomain::noReductions(tv->getMaybeRFactorDomain()); std::vector is_broadcast(tv_root.size(), false); for (auto axis : uint_axes) { is_broadcast.at(axis) = true; } out = broadcast(out, is_broadcast); } return out; } TensorView* sum( TensorView* v1, const std::vector& axes, bool keep_dim /*=false*/) { Val* init = nullptr; auto dtype = v1->getDataType().value(); if (isFloatingPointType(dtype)) { init = IrBuilder::create(0.0); } else if (isComplexType(dtype)) { init = IrBuilder::create(c10::complex(0.0, 0.0)); } else if (isIntegralType(dtype)) { init = FusionGuard::getCurFusion()->zeroVal(); } else if (isBooleanType(dtype)) { v1 = castOp(DataType::Int, v1); init = FusionGuard::getCurFusion()->zeroVal(); } else { TORCH_CHECK( false, "Could not generate a sum op for tensor with type: ", v1->getDataType().value()); } return reductionOp(BinaryOpType::Add, axes, init, v1, keep_dim); } TensorView* max( TensorView* v1, const std::vector& axes, bool keep_dim /*=false*/) { Val* init = nullptr; switch (v1->getDataType().value()) { case (DataType::Double): init = 
IrBuilder::create(-std::numeric_limits::infinity()); break; case (DataType::Float): init = IrBuilder::create(-std::numeric_limits::infinity()); break; case (DataType::Int): init = IrBuilder::create(std::numeric_limits::lowest()); break; case (DataType::Int32): init = IrBuilder::create(std::numeric_limits::lowest()); break; case (DataType::Bool): init = IrBuilder::create(false); break; default: TORCH_CHECK( false, "Could not generate a max op for tensor with type: ", v1->getDataType().value()); } return reductionOp(BinaryOpType::Max, axes, init, v1, keep_dim); } TensorView* min( TensorView* v1, const std::vector& axes, bool keep_dim /*=false*/) { Val* init = nullptr; switch (v1->getDataType().value()) { case (DataType::Double): init = IrBuilder::create(std::numeric_limits::infinity()); break; case (DataType::Float): init = IrBuilder::create(std::numeric_limits::infinity()); break; case (DataType::Int): init = IrBuilder::create(std::numeric_limits::max()); break; case (DataType::Int32): init = IrBuilder::create(std::numeric_limits::max()); break; case (DataType::Bool): init = IrBuilder::create(true); break; default: TORCH_CHECK( false, "Could not generate a min op for tensor with type: ", v1->getDataType().value()); } return reductionOp(BinaryOpType::Min, axes, init, v1, keep_dim); } TensorView* broadcast( TensorView* inp, const std::vector& is_broadcast_dim) { auto nBCastDims = is_broadcast_dim.size(); // Validate is_broadcast_dim unsigned int n_broadcasts = 0; for (auto ent : is_broadcast_dim) if (ent) n_broadcasts++; TORCH_CHECK( nBCastDims - n_broadcasts == TensorDomain::noReductions(inp->getMaybeRFactorDomain()).size(), "Invalid broadcast, number of false entries in is_broadcast_dim expected to be ", TensorDomain::noReductions(inp->getMaybeRFactorDomain()).size(), " but received ", nBCastDims - n_broadcasts); if (n_broadcasts == 0) { auto identity = set(inp); TORCH_INTERNAL_ASSERT( identity->getValType().value() == ValType::TensorView, "Expected identity op, but didn't get a TensorView back."); return identity->as(); } std::vector out_domain; // Don't propagate reduction IDs through arith ops. auto inp_domain = TensorDomain::noReductions(inp->getMaybeRFactorDomain()); size_t iinp = 0, ibdim = 0; while (ibdim < is_broadcast_dim.size()) { if (is_broadcast_dim[ibdim]) { out_domain.push_back(IrBuilder::create( FusionGuard::getCurFusion()->zeroVal(), FusionGuard::getCurFusion()->oneVal(), ParallelType::Serial, IterType::BroadcastWithoutStride)); } else { out_domain.push_back(IrBuilder::create( inp_domain[iinp]->start(), inp_domain[iinp]->extent(), inp_domain[iinp]->stopOffset(), inp_domain[iinp]->getParallelType(), inp_domain[iinp]->getIterType())); iinp++; } ibdim++; } TensorView* out_tensor = IrBuilder::create( IrBuilder::create( out_domain, std::vector(out_domain.size(), true)), inp->getDataType().value()); IrBuilder::create(out_tensor, inp, is_broadcast_dim); return out_tensor; } WelfordResult Welford( TensorView* tv, const std::vector& axes, TensorView* init_avg, TensorView* init_var, Int* init_N) { TORCH_CHECK( TensorDomain::sameAs(tv->getRootDomain(), tv->domain()->domain()), "Reducing a tensor once it's gone under transformations is not permitted at this time. 
Please set reductions before calling split/merge/computeAt."); TORCH_CHECK(tv->nDims() > 0, "Tried to reduce a 0-dim tensor"); TORCH_CHECK(axes.size() > 0, "No reduction axis specified"); if (init_N == nullptr) { init_N = FusionGuard::getCurFusion()->zeroVal(); } // Initial values for welford op are tensors, so their dims have to match the // output dim, // i.e. original_dims - dims_to_be_reduced Val* init_avg_val = nullptr; Val* init_var_val = nullptr; if (!init_N->isZeroInt()) { TORCH_CHECK( init_avg != nullptr && init_var != nullptr && init_N != nullptr, "welford op: all init values need to be provided"); TORCH_CHECK( (axes.size() + init_avg->getRootDomain().size()) == tv->getRootDomain().size(), "welford op: initial tensor mismatch"); TORCH_CHECK( (axes.size() + init_var->getRootDomain().size()) == tv->getRootDomain().size(), "welford op: initial tensor mismatch"); init_avg_val = init_avg; init_var_val = init_var; } else { init_avg_val = IrBuilder::create(0); init_var_val = IrBuilder::create(0); } // Check and collect reduction axes std::vector uint_axes; const int ndims = tv->domain()->noReductions().size(); for (int axis : axes) { if (axis < 0) { axis += ndims; } TORCH_CHECK( axis >= 0 && axis < ndims, "Reduction on invalid axis, received: ", axis, " however tensor view only has ", ndims, " non-reduction dims."); uint_axes.push_back((unsigned int)axis); } // Create tensor outputs TensorView* out_avg = newForReduction(tv, uint_axes); TensorView* out_var = newForReduction(tv, uint_axes); TensorView* out_N = newForReduction(tv, uint_axes, DataType::Index); IrBuilder::create( out_avg, out_var, out_N, /*out var/avg/count */ init_avg_val, init_var_val, init_N, /*init var/avg/count */ tv, nullptr, FusionGuard::getCurFusion()->oneVal()); /*in var/avg/count */ return WelfordResult(out_avg, out_var, out_N); } WelfordResult::WelfordResult( TensorView* in_avg, TensorView* in_var_sum, TensorView* in_n) : avg(in_avg), var_sum(in_var_sum), n(in_n) { TORCH_INTERNAL_ASSERT(avg->definition()->sameAs(var_sum->definition())); TORCH_INTERNAL_ASSERT(avg->definition()->sameAs(n->definition())); } WelfordResult WelfordResult::rFactor(const std::vector& axes) { auto o_tv = avg->definition()->as()->out()->as(); return o_tv->rFactor(axes, avg, var_sum, n); } TensorView* transpose( TensorView* inp, const std::unordered_map& old2new) { auto inp_domain = TensorDomain::noReductions(inp->getMaybeRFactorDomain()); std::vector out_domain(inp_domain.size()); auto new2old = ir_utils::normalizeOld2New(old2new, inp_domain.size()); for (const auto i : c10::irange(out_domain.size())) { auto in_id = inp_domain[new2old[i]]; out_domain[i] = in_id->clone(); } TensorView* out_tensor = IrBuilder::create( IrBuilder::create( out_domain, std::vector(out_domain.size(), true)), inp->getDataType().value()); IrBuilder::create(out_tensor, inp, new2old); return out_tensor; } // COMPOUND OPERATIONS // add_alpha Val* add_alpha(Val* v1, Val* v2, Val* s) { TORCH_CHECK( s->getValType().value() == ValType::Scalar, "Alpha value should be a Scalar Valtype and not ", s->getValType().value()); auto vals = maybeBroadcast({v1, v2, s}); Val* intrm = mul(vals[1], vals[2]); return add(vals[0], intrm); } TensorView* add_alpha(TensorView* v1, Val* v2, Val* v3) { return arithOpOverloads(add_alpha, v1, v2, v3); } TensorView* add_alpha(Val* v1, TensorView* v2, Val* v3) { return arithOpOverloads(add_alpha, v1, v2, v3); } TensorView* add_alpha(TensorView* v1, TensorView* v2, Val* v3) { return arithOpOverloads(add_alpha, v1, v2, v3); } // sub_alpha Val* 
sub_alpha(Val* v1, Val* v2, Val* s) { TORCH_CHECK( s->getValType().value() == ValType::Scalar, "Alpha value should be a Scalar Valtype and not ", s->getValType().value()); auto vals = maybeBroadcast({v1, v2, s}); Val* intrm = mul(vals[1], vals[2]); return sub(vals[0], intrm); } TensorView* sub_alpha(TensorView* v1, Val* v2, Val* v3) { return arithOpOverloads(sub_alpha, v1, v2, v3); } TensorView* sub_alpha(Val* v1, TensorView* v2, Val* v3) { return arithOpOverloads(sub_alpha, v1, v2, v3); } TensorView* sub_alpha(TensorView* v1, TensorView* v2, Val* v3) { return arithOpOverloads(sub_alpha, v1, v2, v3); } // lerp Val* lerp(Val* start, Val* end, Val* weight) { auto vals = maybeBroadcast({start, end, weight}); Val* intrm1 = sub(vals[1], vals[0]); Val* intrm2 = mul(vals[2], intrm1); return add(vals[0], intrm2); } TensorView* lerp(TensorView* v1, Val* v2, Val* v3) { return arithOpOverloads(lerp, v1, v2, v3); } TensorView* lerp(Val* v1, TensorView* v2, Val* v3) { return arithOpOverloads(lerp, v1, v2, v3); } TensorView* lerp(Val* v1, Val* v2, TensorView* v3) { return arithOpOverloads(lerp, v1, v2, v3); } TensorView* lerp(TensorView* v1, TensorView* v2, Val* v3) { return arithOpOverloads(lerp, v1, v2, v3); } TensorView* lerp(TensorView* v1, Val* v2, TensorView* v3) { return arithOpOverloads(lerp, v1, v2, v3); } TensorView* lerp(Val* v1, TensorView* v2, TensorView* v3) { return arithOpOverloads(lerp, v1, v2, v3); } TensorView* lerp(TensorView* v1, TensorView* v2, TensorView* v3) { return arithOpOverloads(lerp, v1, v2, v3); } // addcmul Val* addcmul(Val* v1, Val* v2, Val* v3, Val* s) { TORCH_CHECK( s->getValType().value() == ValType::Scalar, "Alpha value should be a Scalar Valtype and not ", s->getValType().value()); auto vals = maybeBroadcast({v1, v2, v3, s}); Val* intrm1 = mul(vals[2], vals[3]); Val* intrm2 = mul(vals[1], intrm1); return add(vals[0], intrm2); } TensorView* addcmul(TensorView* v1, Val* v2, Val* v3, Val* v4) { return arithOpOverloads(addcmul, v1, v2, v3, v4); } TensorView* addcmul(Val* v1, TensorView* v2, Val* v3, Val* v4) { return arithOpOverloads(addcmul, v1, v2, v3, v4); } TensorView* addcmul(Val* v1, Val* v2, TensorView* v3, Val* v4) { return arithOpOverloads(addcmul, v1, v2, v3, v4); } TensorView* addcmul(TensorView* v1, TensorView* v2, Val* v3, Val* v4) { return arithOpOverloads(addcmul, v1, v2, v3, v4); } TensorView* addcmul(TensorView* v1, Val* v2, TensorView* v3, Val* v4) { return arithOpOverloads(addcmul, v1, v2, v3, v4); } TensorView* addcmul(Val* v1, TensorView* v2, TensorView* v3, Val* v4) { return arithOpOverloads(addcmul, v1, v2, v3, v4); } TensorView* addcmul(TensorView* v1, TensorView* v2, TensorView* v3, Val* v4) { return arithOpOverloads(addcmul, v1, v2, v3, v4); } // TERNARY OPERATIONS // where (c ? 
v1 : v2) Val* where(Val* c, Val* v1, Val* v2) { TORCH_CHECK( c->getDataType().value() == DataType::Bool, "Condition should be of DataType Bool, not ", c->getDataType().value()); auto casted_values = promoteValues(TypePromotion::default_op_config, {v1, v2}); v1 = casted_values[0]; v2 = casted_values[1]; TORCH_CHECK(c->getDataType().value() == DataType::Bool); auto out_dtype = promote_type(v1->getDataType().value(), v2->getDataType().value()); auto out_vtype = promote_type(v1->getValType().value(), v2->getValType().value()); auto vals = maybeBroadcast({c, v1, v2}); Val* out = nullptr; if (out_vtype == ValType::TensorView) { out = newOutputTV(vals, out_dtype); } else { out = newScalar(out_vtype, out_dtype); } IrBuilder::create( TernaryOpType::Where, out, vals[0], vals[1], vals[2]); return out; } TensorView* where(TensorView* v1, Val* v2, Val* v3) { return arithOpOverloads(where, v1, v2, v3); } TensorView* where(Val* v1, TensorView* v2, Val* v3) { return arithOpOverloads(where, v1, v2, v3); } TensorView* where(Val* v1, Val* v2, TensorView* v3) { return arithOpOverloads(where, v1, v2, v3); } TensorView* where(TensorView* v1, TensorView* v2, Val* v3) { return arithOpOverloads(where, v1, v2, v3); } TensorView* where(TensorView* v1, Val* v2, TensorView* v3) { return arithOpOverloads(where, v1, v2, v3); } TensorView* where(Val* v1, TensorView* v2, TensorView* v3) { return arithOpOverloads(where, v1, v2, v3); } TensorView* where(TensorView* v1, TensorView* v2, TensorView* v3) { return arithOpOverloads(where, v1, v2, v3); } // TERNARY OPERATIONS Val* threshold(Val* in, Val* thresh, Val* value) { TORCH_CHECK( (thresh->getValType().value() == ValType::Scalar || thresh->getValType().value() == ValType::NamedScalar) && (value->getValType().value() == ValType::Scalar || value->getValType().value() == ValType::NamedScalar), "For Threshold operation: Thresh and Value values should be Scalars."); thresh = optionalCast(in->getDataType().value(), thresh); value = optionalCast(in->getDataType().value(), value); Val* out = newValLike(in, in->getDataType().value()); IrBuilder::create( TernaryOpType::Threshold, out, in, thresh, value); return out; } TensorView* threshold(TensorView* in, Val* thresh, Val* value) { return threshold(in->as(), thresh, value)->as(); } Val* clamp(Val* in, Val* min_val, Val* max_val) { TORCH_CHECK( (min_val->getValType().value() == ValType::Scalar || min_val->getValType().value() == ValType::NamedScalar) && (max_val->getValType().value() == ValType::Scalar || max_val->getValType().value() == ValType::NamedScalar), "For Clamp operation: Min and Max values should be Scalars."); min_val = optionalCast(in->getDataType().value(), min_val); max_val = optionalCast(in->getDataType().value(), max_val); Val* out = newValLike(in, in->getDataType().value()); IrBuilder::create(TernaryOpType::Clamp, out, in, min_val, max_val); return out; } TensorView* clamp(TensorView* in, Val* min_val, Val* max_val) { return clamp(in->as(), min_val, max_val)->as(); } // sum_to operator TensorView* sum_to(TensorView* in, const std::vector& sum_to_size) { const auto& root = TensorDomain::noReductions(in->getMaybeRFactorDomain()); TORCH_CHECK( root.size() >= sum_to_size.size(), "sum_to: Error trying to reduce", in, "into a shape of size", sum_to_size.size()); // If no reduction is needed sum_to returns the input tv TensorView* out = in; const int64_t leading_dims = root.size() - sum_to_size.size(); // Generate reduction axes for leading dims std::vector reduce_dims(leading_dims); std::iota(reduce_dims.begin(), 
reduce_dims.end(), 0); // Generate reduction axes for dims within sum_to_size std::vector inner_red_dims(sum_to_size.size(), false); bool reduction_within_shape = false; // Reduce rest of the dims with keep_dim for (const auto i : c10::irange(leading_dims, root.size())) { if (sum_to_size[i - leading_dims]->isOneInt() && !root[i]->extent()->isOneInt()) { inner_red_dims[i - leading_dims] = true; reduce_dims.push_back(i); reduction_within_shape = true; } } // Reduction step if (!reduce_dims.empty()) { out = sum(in, reduce_dims); } // Broadcast back reduced dims within shape if (reduction_within_shape) { out = broadcast(out, inner_red_dims); } return out; } TensorView* sum_to(TensorView* in, const std::vector& sum_to_size) { const auto& root = TensorDomain::noReductions(in->getMaybeRFactorDomain()); TORCH_CHECK( root.size() >= sum_to_size.size(), "sum_to: Error trying to reduce", in, "into a shape of size", sum_to_size.size()); // If no reduction is needed sum_to returns the input tv TensorView* out = in; const int64_t leading_dims = root.size() - sum_to_size.size(); // Generate reduction axes for leading dims std::vector reduce_dims(leading_dims); std::iota(reduce_dims.begin(), reduce_dims.end(), 0); // Generate reduction axes for dims within sum_to_size std::vector inner_red_dims(sum_to_size.size(), false); bool reduction_within_shape = false; // Reduce rest of the dims with keep_dim for (const auto i : c10::irange(leading_dims, root.size())) { if (sum_to_size[i - leading_dims] == 1 && !root[i]->extent()->isOneInt()) { inner_red_dims[i - leading_dims] = true; reduce_dims.push_back(i); reduction_within_shape = true; } } // Reduction step if (!reduce_dims.empty()) { out = sum(in, reduce_dims); } // Broadcast back reduced dims within shape if (reduction_within_shape) { out = broadcast(out, inner_red_dims); } return out; } TensorView* shift(TensorView* inp, const std::vector& offsets, bool pad) { // When pad is false, no padding is given. When it is true, padding // sizes are set so that output domains have the same extents as // input domains. 
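// Illustrative sketch only (not part of the original sources, and assuming
// the usual out[i] = in[i - offset] convention for shift): for a 1-D input
// tv of extent N,
//
//   auto padded   = shift(tv, {1}, /*pad=*/true);   // out[0] padded, extent
//                                                   // stays N
//   auto unpadded = shift(tv, {1}, /*pad=*/false);  // no padding; the valid
//                                                   // region starts at 1
//
// The pad_width computed below simply mirrors |offset| per axis when padding
// is requested.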
std::vector pad_width(offsets.size(), 0); if (pad) { for (const auto i : c10::irange(offsets.size())) { pad_width[i] = std::abs(offsets[i]); } } return shift(inp, offsets, pad_width); } TensorView* shift( TensorView* inp, const std::vector& offsets, const std::vector& pad_width_param) { auto inp_dom = TensorDomain::noReductions(inp->getRootDomain()); const auto ndims = inp_dom.size(); auto pad_width = pad_width_param; // Default padding is set so that the extent is kept unchanged if (pad_width.empty()) { pad_width = offsets; for (auto& p : pad_width) { p = std::abs(p); } } TORCH_CHECK( ndims == offsets.size(), "Invalid shift offsets, number of entries in offsets expected to be ", ndims, " but received ", offsets.size()); TORCH_CHECK( ndims == pad_width.size(), "Invalid padding width list, number of entries in pad_width expected to be ", ndims, " but received ", pad_width.size()); std::for_each(pad_width.begin(), pad_width.end(), [](const auto& pad) { TORCH_CHECK(pad >= 0, "Padding width must be >= 0: ", pad); }); TensorView* out = nullptr; std::vector out_dom; for (const auto i : c10::irange(ndims)) { const auto inp_axis = inp_dom[i]; const auto offset = offsets[i]; const auto pad = pad_width[i]; if (offset == 0) { out_dom.push_back(inp_axis->clone()); continue; } Int* current_start_offset = dynamic_cast(inp_axis->start()); TORCH_INTERNAL_ASSERT( current_start_offset != nullptr && current_start_offset->isConst(), "Invalid IterDomain start value:", current_start_offset); Int* current_stop_offset = dynamic_cast(inp_axis->stopOffset()); TORCH_INTERNAL_ASSERT( current_stop_offset != nullptr && current_stop_offset->isConst(), "Invalid IterDomain stop offset value:", current_stop_offset); const auto cur_start_offset_value = current_start_offset->value().value(); const auto cur_stop_offset_value = current_stop_offset->value().value(); int64_t out_start_offset = 0; int64_t out_stop_offset = 0; if (offset > 0) { // shift to right; extent remains the same, start and stop // positions are moved right out_start_offset = cur_start_offset_value + offset - pad; out_stop_offset = std::max(cur_stop_offset_value - offset, int64_t(0)); // If pad > offset, the extent of the output ID could be larger than the // input, and the start offset of the output domain could become // negative, which is not supported. TORCH_CHECK( out_start_offset >= 0, "Invalid shift offset and padding. Padding must not be larger than the absolute extent of shift offset. Padding: ", pad, ". Shift: ", offset, "."); } else { // shift to left; extent remains the same, start and stop // positions are moved left out_start_offset = std::max(cur_start_offset_value + offset, int64_t(0)); out_stop_offset = cur_stop_offset_value - offset - pad; // Similar to the above case where offset is positive, if pad > // -offset (note offset is negative), the extent of the output // ID could be larger than the input, and the stop offset of the // output domain could become negative. TORCH_CHECK( out_stop_offset >= 0, "Invalid shift offset and padding. Padding must not be larger than the absolute extent of shift offset. Padding: ", pad, ". 
Shift: ", offset, "."); } out_dom.push_back(IrBuilder::create( IrBuilder::create(out_start_offset), inp_axis->extent(), IrBuilder::create(out_stop_offset), ParallelType::Serial, inp_axis->getIterType())); } out = IrBuilder::create( IrBuilder::create( out_dom, std::vector(out_dom.size(), true)), inp->getDataType().value()); IrBuilder::create(out, inp, offsets, pad_width); return out; } namespace { // Return a new TensorDomain with given root domains. Apply // strides if necessary. With non-unit strides, strided domains become an // rfactor domain. TensorDomain* generateTensorDomainWithStrides( const std::vector& root_domains, const std::vector& strides, bool skip_unit_stride) { std::vector strided_domains; // If strides are just unit strides, don't apply striding if (strides.empty() || (skip_unit_stride && std::all_of( strides.begin(), strides.end(), [](int s) { return s == 1; }))) { return IrBuilder::create( root_domains, std::vector(root_domains.size(), true)); } for (const auto i : c10::irange(root_domains.size())) { auto root_dom = root_domains.at(i); if (i >= strides.size() || (skip_unit_stride && strides[i] == 1)) { strided_domains.push_back(root_dom); continue; } // Split the root domain by the stride auto split_out = root_dom->stridedSplit(strides[i]); strided_domains.push_back(split_out.first); strided_domains.push_back(split_out.second); } auto contig_vector_size = strided_domains.size(); auto strided_td = IrBuilder::create( root_domains, strided_domains, strided_domains, std::vector(contig_vector_size, true)); return strided_td; } } // namespace TensorView* gather( TensorView* inp, const std::vector& window_shape, const std::vector>& pad_width, const std::vector& strides, bool trim_out_of_bounds) { auto inp_dom = TensorDomain::noReductions(inp->getMaybeRFactorDomain()); const auto ndims = inp_dom.size(); TORCH_CHECK( ndims == window_shape.size(), "Invalid window shape: number of entries expected to be ", ndims, " but received ", window_shape.size()); std::for_each(window_shape.begin(), window_shape.end(), [](const auto& w) { TORCH_CHECK(w > 0, "Window size must be > 0: ", w); }); TORCH_CHECK( ndims == pad_width.size(), "Invalid pad width: number of entries expected to be ", ndims, " but received ", pad_width.size()); std::for_each(pad_width.begin(), pad_width.end(), [](const auto& p) { TORCH_CHECK( p.size() == 2, "Each entry of pad_width must have two non-negative integers."); std::for_each(p.begin(), p.end(), [](const auto& p_left_or_right) { TORCH_CHECK( p_left_or_right >= 0, "Padding must be >= 0: ", p_left_or_right); }); }); TORCH_CHECK( strides.empty() || ndims == strides.size(), "Invalid strides: number of entries expected to be ", ndims, " but received ", strides.size()); std::for_each(strides.begin(), strides.end(), [](const auto& s) { TORCH_CHECK(s > 0, "Stride must be > 0: ", s); }); std::vector out_root_domains; std::vector out_gather_dom; for (const auto i : c10::irange(ndims)) { const auto inp_axis = inp_dom[i]; const auto window_dim = window_shape[i]; const auto pad_left = pad_width[i][0]; const auto pad_right = pad_width[i][1]; // This may be over-conservative TORCH_INTERNAL_ASSERT(inp_axis->start()->isZeroInt()); const auto inp_stop_offset = inp_axis->stopOffset()->getInt(); TORCH_INTERNAL_ASSERT( inp_stop_offset.has_value(), "Dynamic stop offset not supported: ", inp_axis); const auto extent_adjustment = window_dim - 1 - pad_left - pad_right; TORCH_CHECK( extent_adjustment >= 0, "Invalid gather window and padding as output extent would be larger than input.", " 
Window: ", window_dim, ". Padding left: ", pad_left, ". Padding right: ", pad_right); const auto out_stop_offset = inp_stop_offset.value() + extent_adjustment; out_root_domains.push_back(IrBuilder::create( FusionGuard::getCurFusion()->zeroVal(), inp_axis->extent(), IrBuilder::create(out_stop_offset), ParallelType::Serial, inp_axis->getIterType())); // create a new axis for the gathered domain out_gather_dom.push_back(IrBuilder::create( FusionGuard::getCurFusion()->zeroVal(), IrBuilder::create(window_dim), ParallelType::Serial, IterType::Gather)); } out_root_domains.insert( out_root_domains.end(), out_gather_dom.begin(), out_gather_dom.end()); TensorDomain* out_td = nullptr; if (trim_out_of_bounds) { // If no stride vector is given, just use stride 1. It does not do // any striding effect, but out-of-bounds values are trimmed. auto s = strides.empty() ? std::vector(ndims, 1) : strides; out_td = generateTensorDomainWithStrides(out_root_domains, strides, false); } else { out_td = generateTensorDomainWithStrides(out_root_domains, strides, true); } auto out_tv = IrBuilder::create(out_td, inp->getDataType().value()); IrBuilder::create(out_tv, inp, window_shape, pad_width); return out_tv; } namespace { //! Create new output for mma static TensorView* newForMma( TensorView* tv_a, TensorView* tv_b, const std::vector& axes, DataType data_type = DataType::Float) { auto orig_domain_a = TensorDomain::noReductions(tv_a->getMaybeRFactorDomain()); auto orig_domain_b = TensorDomain::noReductions(tv_b->getMaybeRFactorDomain()); TORCH_INTERNAL_ASSERT( orig_domain_a.size() == orig_domain_b.size(), "MMA op: need matching dim input"); std::set axes_set(axes.begin(), axes.end()); std::vector new_domain; TORCH_INTERNAL_ASSERT( !axes_set.empty(), "Asked for ouput of reduction, but no reduction axis provided."); TORCH_INTERNAL_ASSERT( (*(axes_set.rbegin())) < orig_domain_a.size(), "Error setting up reduction, reduction axis (", *(axes_set.rbegin()), ") is outside nDims (", orig_domain_a.size(), "). Keep in mind reductions are relative to root domains, not modified views."); auto axis_iter = axes_set.begin(); for (const auto dim : c10::irange(orig_domain_a.size())) { bool isReduction = false; if (axis_iter != axes_set.end() && *axis_iter == dim) { isReduction = true; axis_iter++; } const IterDomain* id = orig_domain_a[dim]->isBroadcast() ? orig_domain_b[dim] : orig_domain_a[dim]; TORCH_CHECK( !(isReduction && id->isBroadcast() && !id->isImplicitBroadcast()), "Cannot reduce an axis that is marked as broadcasted as it has an undetermined size. Tried to reduce ID = ", id, " of tensor ", tv_a, "and", tv_b); new_domain.push_back(IrBuilder::create( id->start(), id->extent(), id->stopOffset(), ParallelType::Serial, isReduction ? IterType::Reduction : id->getIterType())); } TensorDomain* td = IrBuilder::create( new_domain, std::vector(new_domain.size(), true)); return IrBuilder::create(td, data_type); } } // namespace TensorView* fusedMultiplySum( TensorView* tv_a, TensorView* tv_b, const std::vector& axes, Val* init) { if (init == nullptr) { init = IrBuilder::create(0); } // TODO: // We will want to support initialize and rfactor with // mma as well, for maybe fusing bias in prolog. // TODO: check init type if given a tv, // not supported currently though. 
TORCH_CHECK( init->isConstScalar(), "Cannot create a reduction operation where the initial value is not a const scalar."); // TODO: // Validate axis relationships between a and b TORCH_CHECK(tv_a->nDims() > 0, "Tried to reduce a 0-dim tensor"); // TODO: // Add tf32 and other mma data types // Add fallback path for non-mma data types. TORCH_CHECK(tv_a->getDataType().value() == DataType::Half); TORCH_CHECK(tv_b->getDataType().value() == DataType::Half); TORCH_CHECK(axes.size() > 0, "No reduction axis specified"); // TODO: // will lift this in a follow up when we have a // more generic axes matching. TORCH_CHECK( axes.size() == 1, "Single axis reduction only for mma op instantiation."); std::vector uint_axes; const int ndims = tv_a->domain()->noReductions().size(); for (int axis : axes) { if (axis < 0) { axis += ndims; } TORCH_CHECK( axis >= 0 && axis < ndims, "Reduction on invalid axis, received: ", axis, " however tensor view only has ", ndims, " non-reduction dims."); uint_axes.push_back((unsigned int)axis); } TensorView* out = newForMma(tv_a, tv_b, uint_axes); IrBuilder::create(out, tv_a, tv_b, init); return out; } } // namespace cuda } // namespace fuser } // namespace jit } // namespace torch