#include <torch/csrc/jit/codegen/cuda/arith.h>

#include <c10/util/Exception.h>
#include <c10/util/irange.h>
#include <torch/csrc/jit/codegen/cuda/ir_all_nodes.h>
#include <torch/csrc/jit/codegen/cuda/type.h>

namespace torch {
namespace jit {
namespace fuser {
namespace cuda {

namespace {

// Returns a new scalar Val of ValType vtype with DataType dtype.
Val* newScalar(ValType vtype, DataType dtype) {
  switch (vtype) {
    case (ValType::NamedScalar):
    case (ValType::Scalar):
      switch (dtype) {
        case DataType::Bool:
          return new Bool();
        case DataType::Float:
          return new Float();
        case DataType::Half:
          return new Half();
        case DataType::Int:
          return new Int();
        default:
          break;
      }
    default:
      break;
  }

  TORCH_CHECK(
      false,
      "Was expecting a scalar type, but received ValType: ",
      vtype,
      " with DataType:",
      dtype);
}

// Creates an output TensorView for an elementwise op: its root domain matches
// the non-reduction root domains of the TensorView inputs, and a dimension
// stays a broadcast only if it is broadcast in every input.
TensorView* newOutputTV(const std::vector<Val*>& vals, DataType dtype) {
  std::vector<TensorView*> tvs;
  for (auto val : vals)
    if (val->getValType() == ValType::TensorView)
      tvs.push_back(val->as<TensorView>());
  TORCH_CHECK(
      !tvs.empty(),
      "Tried to create new output TensorView but received empty list.");

  std::vector<IterDomain*> out_domain(
      TensorDomain::noReductions(tvs[0]->getRootDomain()).size(), nullptr);

  for (auto tv : tvs) {
    auto dom = TensorDomain::noReductions(tv->getRootDomain());
    TORCH_INTERNAL_ASSERT(
        dom.size() == out_domain.size(),
        "Invalid tensor view found while producing an output, it has ",
        dom.size(),
        " dimensions but expected ",
        out_domain.size());
    for (const auto i : c10::irange(dom.size())) {
      if (out_domain[i] != nullptr)
        continue;
      if (dom[i]->isBroadcast())
        continue;
      out_domain[i] = new IterDomain(dom[i]->start(), dom[i]->extent());
    }
  }

  for (const auto dim_i : c10::irange(out_domain.size())) {
    if (out_domain[dim_i] == nullptr) {
      IterType itype = IterType::BroadcastWithoutStride;
      for (const auto tv : tvs) {
        auto dim = TensorDomain::noReductions(tv->getRootDomain())[dim_i];
        // If there's an unresolved bcast dim and it came from a strided dim,
        // assume output of it should be strided too
        if (dim->getIterType() == IterType::BroadcastWithStride) {
          itype = IterType::BroadcastWithStride;
          break;
        }
      }
      out_domain[dim_i] =
          new IterDomain(new Int(0), new Int(1), ParallelType::Serial, itype);
    }
  }

  return new TensorView(
      new TensorDomain(out_domain, std::vector<bool>(out_domain.size(), true)),
      dtype);
}

// Broadcasts lower-rank TensorView inputs up to the highest rank found in
// vals by prepending broadcast dimensions; scalars and full-rank tensors are
// passed through unchanged.
std::vector<Val*> maybeBroadcast(const std::vector<Val*>& vals) {
  std::vector<Val*> out_vals(vals.size(), nullptr);
  size_t n_dims = 0;
  for (auto val : vals) {
    if (val->getValType().value() == ValType::TensorView) {
      n_dims = std::max(
          n_dims,
          TensorDomain::noReductions(val->as<TensorView>()->getRootDomain())
              .size());
    }
  }

  for (const auto i : c10::irange(vals.size())) {
    if (vals[i]->getValType().value() == ValType::TensorView) {
      auto tv = vals[i]->as<TensorView>();
      size_t tv_dims = TensorDomain::noReductions(tv->getRootDomain()).size();
      if (tv_dims < n_dims) {
        std::vector<bool> bcast_flags(n_dims, false);
        for (size_t j = 0; j < n_dims - tv_dims; j++) {
          bcast_flags[j] = true;
        }
        out_vals[i] = broadcast(tv, bcast_flags);
      } else {
        out_vals[i] = vals[i];
      }
    } else {
      out_vals[i] = vals[i];
    }
  }
  return out_vals;
}

// Creates the output Val for an op over vals, promoting both the ValType
// (scalar vs. TensorView) and the DataType across all inputs.
Val* newOutputVal(const std::vector<Val*>& vals) {
  ValType out_vtype = vals[0]->getValType().value();
  DataType out_dtype = vals[0]->getDataType().value();

  for (auto val : vals) {
    TORCH_CHECK(val->isVal(), "Invalid statement found during promotion.");
    TORCH_CHECK(
        val->getDataType().value() != DataType::Null,
        "Invalid datatype found during promotion.");
    out_vtype = promote_type(out_vtype, val->getValType().value());
    out_dtype = promote_type(out_dtype, val->getDataType().value());
  }

  if (out_vtype == ValType::TensorView)
    return newOutputTV(vals, out_dtype);

  return newScalar(out_vtype, out_dtype);
}

// Creates a new Val of the same ValType as val but with the given DataType.
Val* newValLike(Val* val, DataType dtype) {
  TORCH_CHECK(
      val->isVal(),
"Invalid statement provided to create new value."); TORCH_CHECK( dtype != DataType::Null, "Invalid datatype provided for new value."); ValType vtype = val->getValType().value(); if (vtype == ValType::TensorView) return newOutputTV({val}, dtype); return newScalar(vtype, dtype); } } // namespace Val* castOp(DataType dtype, Val* v1) { if (v1->getDataType().value() == dtype) return v1; if (cast_func_str(std::make_pair(v1->getDataType().value(), dtype)) == c10::nullopt) { TORCH_CHECK( false, "Illegal Cast value from DataType: ", v1->getDataType().value(), " to DataType: ", dtype); } Val* out = newValLike(v1, dtype); new UnaryOp(UnaryOpType::Cast, out, v1); return out; } TensorView* castOp(DataType dtype, TensorView* v1) { return castOp(dtype, v1->as())->as(); } // UNARY OPERATIONS Val* unaryOp(UnaryOpType type, Val* v1) { Val* out = newOutputVal({v1}); new UnaryOp(type, out, v1); return out; } TensorView* unaryOp(UnaryOpType type, TensorView* v1) { return unaryOp(type, v1->as())->as(); } Val* neg(Val* v) { return unaryOp(UnaryOpType::Neg, v); } TensorView* neg(TensorView* v) { return unaryOp(UnaryOpType::Neg, v); } // BINARY OPERATIONS namespace { // Helper function to reduce repetitive code template TensorView* arithOpOverloads(Val* (*func)(Val*, Val*), T1* v1, T2* v2) { return func(v1->template as(), v2->template as()) ->template as(); } template TensorView* arithOpOverloads(BinaryOpType type, T1* v1, T2* v2) { return binaryOp(type, v1->template as(), v2->template as()) ->template as(); } template TensorView* arithOpOverloads( Val* (*func)(Val*, Val*, Val*), T1* v1, T2* v2, T3* v3) { auto vals = maybeBroadcast({v1, v2, v3}); return func( vals[0]->template as(), vals[1]->template as(), vals[2]->template as()) ->template as(); } template TensorView* arithOpOverloads( Val* (*func)(Val*, Val*, Val*, Val*), T1* v1, T2* v2, T3* v3, T4* v4) { auto vals = maybeBroadcast({v1, v2, v3, v4}); return func( vals[0]->template as(), vals[1]->template as(), vals[2]->template as(), vals[3]->template as()) ->template as(); } } // namespace TORCH_CUDA_CU_API Val* binaryOp(BinaryOpType type, Val* v1, Val* v2) { auto vals = maybeBroadcast({v1, v2}); Val* out = newOutputVal({vals[0], vals[1]}); if (is_logical_op(type)) { if (out->getDataType().value() != DataType::Bool) out = newValLike(out, DataType::Bool); } else if (type >= BinaryOpType::Mod) { if (out->getDataType().value() != DataType::Int) out = newValLike(out, DataType::Int); } new BinaryOp(type, out, vals[0], vals[1]); return out; } TensorView* binaryOp(BinaryOpType type, TensorView* v1, Val* v2) { return arithOpOverloads(type, v1, v2); } TensorView* binaryOp(BinaryOpType type, Val* v1, TensorView* v2) { return arithOpOverloads(type, v1, v2); } TensorView* binaryOp(BinaryOpType type, TensorView* v1, TensorView* v2) { return arithOpOverloads(type, v1, v2); } // add Val* add(Val* v1, Val* v2) { return binaryOp(BinaryOpType::Add, v1, v2); } TensorView* add(TensorView* v1, Val* v2) { return arithOpOverloads(add, v1, v2); } TensorView* add(Val* v1, TensorView* v2) { return arithOpOverloads(add, v1, v2); } TensorView* add(TensorView* v1, TensorView* v2) { return arithOpOverloads(add, v1, v2); } // sub Val* sub(Val* v1, Val* v2) { return binaryOp(BinaryOpType::Sub, v1, v2); } TensorView* sub(TensorView* v1, Val* v2) { return arithOpOverloads(sub, v1, v2); } TensorView* sub(Val* v1, TensorView* v2) { return arithOpOverloads(sub, v1, v2); } TensorView* sub(TensorView* v1, TensorView* v2) { return arithOpOverloads(sub, v1, v2); } // mul Val* mul(Val* v1, Val* v2) { 
  return binaryOp(BinaryOpType::Mul, v1, v2);
}
TensorView* mul(TensorView* v1, Val* v2) {
  return arithOpOverloads(mul, v1, v2);
}
TensorView* mul(Val* v1, TensorView* v2) {
  return arithOpOverloads(mul, v1, v2);
}
TensorView* mul(TensorView* v1, TensorView* v2) {
  return arithOpOverloads(mul, v1, v2);
}

// div
Val* div(Val* v1, Val* v2) {
  return binaryOp(BinaryOpType::Div, v1, v2);
}
TensorView* div(TensorView* v1, Val* v2) {
  return arithOpOverloads(div, v1, v2);
}
TensorView* div(Val* v1, TensorView* v2) {
  return arithOpOverloads(div, v1, v2);
}
TensorView* div(TensorView* v1, TensorView* v2) {
  return arithOpOverloads(div, v1, v2);
}

// mod
Val* mod(Val* v1, Val* v2) {
  return binaryOp(BinaryOpType::Mod, v1, v2);
}
TensorView* mod(TensorView* v1, Val* v2) {
  return arithOpOverloads(mod, v1, v2);
}
TensorView* mod(Val* v1, TensorView* v2) {
  return arithOpOverloads(mod, v1, v2);
}
TensorView* mod(TensorView* v1, TensorView* v2) {
  return arithOpOverloads(mod, v1, v2);
}

// lt
Val* lt(Val* v1, Val* v2) {
  return binaryOp(BinaryOpType::LT, v1, v2);
}
TensorView* lt(TensorView* v1, Val* v2) {
  return arithOpOverloads(lt, v1, v2);
}
TensorView* lt(Val* v1, TensorView* v2) {
  return arithOpOverloads(lt, v1, v2);
}
TensorView* lt(TensorView* v1, TensorView* v2) {
  return arithOpOverloads(lt, v1, v2);
}

// eq
Val* eq(Val* v1, Val* v2) {
  return binaryOp(BinaryOpType::Eq, v1, v2);
}
TensorView* eq(TensorView* v1, Val* v2) {
  return arithOpOverloads(eq, v1, v2);
}
TensorView* eq(Val* v1, TensorView* v2) {
  return arithOpOverloads(eq, v1, v2);
}
TensorView* eq(TensorView* v1, TensorView* v2) {
  return arithOpOverloads(eq, v1, v2);
}

// ceilDiv
Val* ceilDiv(Val* v1, Val* v2) {
  return binaryOp(BinaryOpType::CeilDiv, v1, v2);
}
TensorView* ceilDiv(TensorView* v1, Val* v2) {
  return arithOpOverloads(ceilDiv, v1, v2);
}
TensorView* ceilDiv(Val* v1, TensorView* v2) {
  return arithOpOverloads(ceilDiv, v1, v2);
}
TensorView* ceilDiv(TensorView* v1, TensorView* v2) {
  return arithOpOverloads(ceilDiv, v1, v2);
}

// andOp
Val* andOp(Val* v1, Val* v2) {
  TORCH_CHECK(
      v1->getDataType().value() == DataType::Bool,
      "Input1 should be of type bool, not ",
      v1->getDataType().value());
  TORCH_CHECK(
      v2->getDataType().value() == DataType::Bool,
      "Input2 should be of type bool, not ",
      v2->getDataType().value());
  return binaryOp(BinaryOpType::And, v1, v2);
}
TensorView* andOp(TensorView* v1, Val* v2) {
  return arithOpOverloads(andOp, v1, v2);
}
TensorView* andOp(Val* v1, TensorView* v2) {
  return arithOpOverloads(andOp, v1, v2);
}
TensorView* andOp(TensorView* v1, TensorView* v2) {
  return arithOpOverloads(andOp, v1, v2);
}

// REDUCTION OPERATIONS

// TODO: How do we adjust this so we can reduce to a single scalar value?
static TensorView* newForReduction(
    TensorView* tv,
    const std::vector<unsigned int>& axes) {
  auto orig_domain = TensorDomain::noReductions(tv->getRootDomain());
  std::set<unsigned int> axes_set(axes.begin(), axes.end());

  std::vector<IterDomain*> new_domain;

  TORCH_INTERNAL_ASSERT(
      !axes_set.empty(),
      "Asked for output of reduction, but no reduction axis provided.");

  TORCH_INTERNAL_ASSERT(
      (*(axes_set.rbegin())) < orig_domain.size(),
      "Error setting up reduction, reduction axis is outside nDims. "
      "Keep in mind reductions are relative to root domains, not modified views.");

  auto axis_iter = axes_set.begin();
  for (const auto dim : c10::irange(orig_domain.size())) {
    bool isReduction = false;
    if (axis_iter != axes_set.end() && *axis_iter == dim) {
      isReduction = true;
      axis_iter++;
    }

    const IterDomain* id = orig_domain[dim];

    TORCH_CHECK(
        !(isReduction && id->isBroadcast()),
        "Cannot reduce an axis that is marked as broadcasted as it has an undetermined size. Tried to reduce ID = ",
        id,
        " of tensor ",
        tv);

    new_domain.push_back(new IterDomain(
        id->start(),
        id->extent(),
        ParallelType::Serial,
        isReduction ? IterType::Reduction : id->getIterType()));
  }

  TensorDomain* td =
      new TensorDomain(new_domain, std::vector<bool>(new_domain.size(), true));

  return new TensorView(td, tv->getDataType().value());
}

TensorView* reductionOp(
    BinaryOpType reduction_op_type,
    const std::vector<int>& axes,
    Val* init,
    TensorView* tv) {
  TORCH_CHECK(
      init->isConstScalar(),
      "Cannot create a reduction operation where the initial value is not a const scalar.");

  TORCH_CHECK(
      TensorDomain::sameAs(tv->getRootDomain(), tv->domain()->domain()),
      "Reducing a tensor once it's gone under transformations is not permitted at this time. Please set reductions before calling split/merge/computeAt.");

  TORCH_CHECK(tv->nDims() > 0, "Tried to reduce a 0-dim tensor");

  TORCH_CHECK(axes.size() > 0, "No reduction axis specified");

  std::vector<unsigned int> uint_axes;
  for (int axis : axes) {
    if (axis < 0)
      axis += int(tv->nDims());

    TORCH_CHECK(
        axis >= 0 && (unsigned int)axis < tv->nDims(),
        "Reduction on invalid axis, received: ",
        axis,
        " however tensor view only has ",
        tv->nDims(),
        " dims.");

    uint_axes.push_back((unsigned int)axis);
  }

  TensorView* out = newForReduction(tv, uint_axes);
  if (init->getDataType().value() != tv->getDataType().value())
    init = castOp(tv->getDataType().value(), init);
  new ReductionOp(reduction_op_type, init, out, tv);
  return out;
}

TensorView* sum(TensorView* v1, const std::vector<int>& axes) {
  // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
  Val* init;
  switch (v1->getDataType().value()) {
    case (DataType::Float):
      init = new Float(0.0);
      break;
    case (DataType::Int):
      init = new Int(0);
      break;
    default:
      TORCH_CHECK(
          false,
          "Could not generate a sum op for tensor with type: ",
          v1->getDataType().value());
  }

  return reductionOp(BinaryOpType::Add, axes, init, v1);
}

TensorView* broadcast(
    TensorView* inp,
    const std::vector<bool>& is_broadcast_dim) {
  auto nBCastDims = is_broadcast_dim.size();
  // Validate is_broadcast_dim
  unsigned int n_broadcasts = 0;
  for (auto ent : is_broadcast_dim)
    if (ent)
      n_broadcasts++;
  TORCH_CHECK(
      nBCastDims - n_broadcasts ==
          TensorDomain::noReductions(inp->getRootDomain()).size(),
      "Invalid broadcast, number of false entries in is_broadcast_dim expected to be ",
      TensorDomain::noReductions(inp->getRootDomain()).size(),
      " but received ",
      nBCastDims - n_broadcasts);

  if (n_broadcasts == 0) {
    auto identity = unaryOp(UnaryOpType::Set, inp);
    TORCH_INTERNAL_ASSERT(
        identity->getValType().value() == ValType::TensorView,
        "Expected identity op, but didn't get a TensorView back.");
    return identity->as<TensorView>();
  }

  std::vector<IterDomain*> out_domain;
  auto inp_domain = TensorDomain::noReductions(inp->getRootDomain());
  size_t iinp = 0, ibdim = 0;
  while (ibdim < is_broadcast_dim.size()) {
    if (is_broadcast_dim[ibdim]) {
      out_domain.push_back(new IterDomain(
          new Int(0),
          new Int(1),
          ParallelType::Serial,
          IterType::BroadcastWithoutStride));
    } else {
      // Don't propagate reduction IDs through arith ops.
      out_domain.push_back(inp_domain[iinp]);
      iinp++;
    }
    ibdim++;
  }

  TensorView* out_tensor = new TensorView(
      new TensorDomain(out_domain, std::vector<bool>(out_domain.size(), true)),
      inp->getDataType().value());
  new BroadcastOp(out_tensor, inp);
  return out_tensor;
}

// COMPOUND OPERATIONS

// add_alpha
Val* add_alpha(Val* v1, Val* v2, Val* s) {
  TORCH_CHECK(
      s->getValType().value() == ValType::Scalar,
      "Alpha value should be a Scalar Valtype and not ",
      s->getValType().value());

  auto vals = maybeBroadcast({v1, v2, s});
  Val* intrm = binaryOp(BinaryOpType::Mul, vals[1], vals[2]);
  return binaryOp(BinaryOpType::Add, vals[0], intrm);
}
TensorView* add_alpha(TensorView* v1, Val* v2, Val* v3) {
  return arithOpOverloads(add_alpha, v1, v2, v3);
}
TensorView* add_alpha(Val* v1, TensorView* v2, Val* v3) {
  return arithOpOverloads(add_alpha, v1, v2, v3);
}
TensorView* add_alpha(TensorView* v1, TensorView* v2, Val* v3) {
  return arithOpOverloads(add_alpha, v1, v2, v3);
}

// sub_alpha
Val* sub_alpha(Val* v1, Val* v2, Val* s) {
  TORCH_CHECK(
      s->getValType().value() == ValType::Scalar,
      "Alpha value should be a Scalar Valtype and not ",
      s->getValType().value());

  auto vals = maybeBroadcast({v1, v2, s});
  Val* intrm = binaryOp(BinaryOpType::Mul, vals[1], vals[2]);
  return binaryOp(BinaryOpType::Sub, vals[0], intrm);
}
TensorView* sub_alpha(TensorView* v1, Val* v2, Val* v3) {
  return arithOpOverloads(sub_alpha, v1, v2, v3);
}
TensorView* sub_alpha(Val* v1, TensorView* v2, Val* v3) {
  return arithOpOverloads(sub_alpha, v1, v2, v3);
}
TensorView* sub_alpha(TensorView* v1, TensorView* v2, Val* v3) {
  return arithOpOverloads(sub_alpha, v1, v2, v3);
}

// lerp
TORCH_CUDA_CU_API Val* lerp(Val* start, Val* end, Val* weight) {
  auto vals = maybeBroadcast({start, end, weight});
  Val* intrm1 = binaryOp(BinaryOpType::Sub, vals[1], vals[0]);
  Val* intrm2 = binaryOp(BinaryOpType::Mul, vals[2], intrm1);
  return binaryOp(BinaryOpType::Add, vals[0], intrm2);
}
TensorView* lerp(TensorView* v1, Val* v2, Val* v3) {
  return arithOpOverloads(lerp, v1, v2, v3);
}
TensorView* lerp(Val* v1, TensorView* v2, Val* v3) {
  return arithOpOverloads(lerp, v1, v2, v3);
}
TensorView* lerp(Val* v1, Val* v2, TensorView* v3) {
  return arithOpOverloads(lerp, v1, v2, v3);
}
TensorView* lerp(TensorView* v1, TensorView* v2, Val* v3) {
  return arithOpOverloads(lerp, v1, v2, v3);
}
TensorView* lerp(TensorView* v1, Val* v2, TensorView* v3) {
  return arithOpOverloads(lerp, v1, v2, v3);
}
TensorView* lerp(Val* v1, TensorView* v2, TensorView* v3) {
  return arithOpOverloads(lerp, v1, v2, v3);
}
TensorView* lerp(TensorView* v1, TensorView* v2, TensorView* v3) {
  return arithOpOverloads(lerp, v1, v2, v3);
}

// addcmul
Val* addcmul(Val* v1, Val* v2, Val* v3, Val* s) {
  TORCH_CHECK(
      s->getValType().value() == ValType::Scalar,
      "Alpha value should be a Scalar Valtype and not ",
      s->getValType().value());

  auto vals = maybeBroadcast({v1, v2, v3, s});
  Val* intrm1 = binaryOp(BinaryOpType::Mul, vals[2], vals[3]);
  Val* intrm2 = binaryOp(BinaryOpType::Mul, vals[1], intrm1);
  return binaryOp(BinaryOpType::Add, vals[0], intrm2);
}
TensorView* addcmul(TensorView* v1, Val* v2, Val* v3, Val* v4) {
  return arithOpOverloads(addcmul, v1, v2, v3, v4);
}
TensorView* addcmul(Val* v1, TensorView* v2, Val* v3, Val* v4) {
  return arithOpOverloads(addcmul, v1, v2, v3, v4);
}
TensorView* addcmul(Val* v1, Val* v2, TensorView* v3, Val* v4) {
  return arithOpOverloads(addcmul, v1, v2, v3, v4);
}
TensorView* addcmul(TensorView* v1, TensorView* v2, Val* v3, Val* v4) {
  return arithOpOverloads(addcmul, v1, v2, v3, v4);
}
TensorView*
addcmul(TensorView* v1, Val* v2, TensorView* v3, Val* v4) {
  return arithOpOverloads(addcmul, v1, v2, v3, v4);
}
TensorView* addcmul(Val* v1, TensorView* v2, TensorView* v3, Val* v4) {
  return arithOpOverloads(addcmul, v1, v2, v3, v4);
}
TensorView* addcmul(TensorView* v1, TensorView* v2, TensorView* v3, Val* v4) {
  return arithOpOverloads(addcmul, v1, v2, v3, v4);
}

// TERNARY OPERATIONS

// where
Val* where(Val* c, Val* v1, Val* v2) {
  TORCH_CHECK(
      c->getDataType().value() == DataType::Bool,
      "Condition should be of DataType Bool, not ",
      c->getDataType().value());

  auto vals = maybeBroadcast({c, v1, v2});
  Val* out = newOutputVal({vals[1], vals[2]});
  new TernaryOp(TernaryOpType::Where, out, vals[0], vals[1], vals[2]);
  return out;
}
TensorView* where(TensorView* v1, Val* v2, Val* v3) {
  return arithOpOverloads(where, v1, v2, v3);
}
TensorView* where(Val* v1, TensorView* v2, Val* v3) {
  return arithOpOverloads(where, v1, v2, v3);
}
TensorView* where(Val* v1, Val* v2, TensorView* v3) {
  return arithOpOverloads(where, v1, v2, v3);
}
TensorView* where(TensorView* v1, TensorView* v2, Val* v3) {
  return arithOpOverloads(where, v1, v2, v3);
}
TensorView* where(TensorView* v1, Val* v2, TensorView* v3) {
  return arithOpOverloads(where, v1, v2, v3);
}
TensorView* where(Val* v1, TensorView* v2, TensorView* v3) {
  return arithOpOverloads(where, v1, v2, v3);
}
TensorView* where(TensorView* v1, TensorView* v2, TensorView* v3) {
  return arithOpOverloads(where, v1, v2, v3);
}

// threshold
Val* threshold(Val* in, Val* thresh, Val* value) {
  TORCH_CHECK(
      in->getDataType().value() == thresh->getDataType().value() &&
          in->getDataType().value() == value->getDataType().value(),
      "All input DataType values should match the input ",
      in->getDataType().value());

  TORCH_CHECK(
      thresh->getValType().value() == ValType::Scalar &&
          value->getValType().value() == ValType::Scalar,
      "Thresh and Value values should be Scalars");

  Val* out = newOutputVal({in});

  new TernaryOp(TernaryOpType::Threshold, out, in, thresh, value);
  return out;
}

TensorView* threshold(TensorView* in, Val* thresh, Val* value) {
  return threshold(in->as<Val>(), thresh, value)->as<TensorView>();
}

// clamp
Val* clamp(Val* in, Val* min_val, Val* max_val) {
  TORCH_CHECK(
      in->getDataType().value() == min_val->getDataType().value() &&
          in->getDataType().value() == max_val->getDataType().value(),
      "All input DataType values should match the input ",
      in->getDataType().value());

  TORCH_CHECK(
      min_val->getValType().value() == ValType::Scalar &&
          max_val->getValType().value() == ValType::Scalar,
      "Min and Max values should be Scalars");

  Val* out = newOutputVal({in});

  new TernaryOp(TernaryOpType::Clamp, out, in, min_val, max_val);
  return out;
}

TensorView* clamp(TensorView* in, Val* min_val, Val* max_val) {
  return clamp(in->as<Val>(), min_val, max_val)->as<TensorView>();
}

} // namespace cuda
} // namespace fuser
} // namespace jit
} // namespace torch
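
// Illustrative usage sketch (comment only, not part of this translation
// unit): how the ops defined above are typically composed when building a
// fusion. FusionGuard, Fusion::addInput/addOutput, and the construction of
// the input TensorView come from elsewhere in the fuser; `tv0` is assumed to
// be a 2-D float TensorView registered with the fusion.
//
//   Fusion fusion;
//   FusionGuard fg(&fusion);
//   fusion.addInput(tv0);
//   TensorView* tv1 = add(tv0, new Float(1.0)); // elementwise add, scalar rhs
//   TensorView* tv2 = sum(tv1, {1});            // reduce over axis 1
//   fusion.addOutput(tv2);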