#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include

namespace torch {
namespace jit {

typedef Value JitValue;
typedef Node JitOp;

namespace fuser {
namespace cuda {

constexpr auto kNumUnaryOps = 10;
constexpr auto kNumUnaryFloatOps = 23;
constexpr auto kNumUnaryIsOps = 6;
constexpr auto kNumBinaryFloatOps = 3;
constexpr auto kNumBinaryComparisonOps = 12;
constexpr auto kNumBinaryCastOps = 19;
constexpr auto kNumBinaryOpsWithAlpha = 6;
constexpr auto kNumLerpOps = 2;
constexpr auto kNumLayernormFwd = 2;
constexpr auto kNumBatchnormFwd = 3;
constexpr auto kNumBatchnormBwd = 2;
constexpr auto kNumInstancenormFwd = 1;
constexpr auto kNumSumToSize = 2;
constexpr auto kNumAutocastOps = 2;
constexpr auto kNumAliasDimOps = 2;
constexpr auto kNumViewOps = 2;
constexpr auto kNumVarOps = 2;
constexpr auto kNumSoftmaxFwd = 2;
constexpr auto kNumSoftmaxBwd = 2;
constexpr auto kNumAminAmaxOps = 2;

namespace {

#define REGISTER_PARSE_RULE(op, func_body, ...)                                \
  registerParseRule(                                                           \
      op,                                                                      \
      [](const Node* node, std::unordered_map<size_t, ValueHolder>& value_map) \
          -> void func_body,                                                   \
      __VA_ARGS__)

const auto& reductionSizeAttr = Symbol::attr("profiled_reduction_size");
const auto& viewSizeAttr = Symbol::attr("profiled_view_size");
const auto& intListAttr = Symbol::attr("profiled_int_list");
const auto& intAttr = Symbol::attr("profiled_int");
const auto& boolListAttr = Symbol::attr("profiled_bool_list");
const auto& boolAttr = Symbol::attr("profiled_bool");
const auto& strAttr = Symbol::attr("profiled_str");
const auto& ivalAttr = Symbol::attr("profiled_ival");
const auto& profileFailedAttr = Symbol::attr("profile_failed");

typedef Val* CgValue;
typedef Expr* CgOp;

Val* castTensoToDtype(CgValue self, JitValue* cast_val) {
  auto cast_ival = toIValue(cast_val);
  // we need a static type for the cast
  TORCH_INTERNAL_ASSERT(cast_ival.has_value());
  if (cast_ival->isInt()) {
    auto dtype = cast_ival->toScalarType();

    // We want to keep our internal fusion math in FP32.
    // Shape inference will continue to propagate the right
    // type to outputs unchanged.
    if (dtype == at::ScalarType::Half || dtype == at::ScalarType::BFloat16) {
      dtype = at::ScalarType::Float;
    }

    return castOp(aten_to_data_type(dtype), self);
  } else {
    TORCH_INTERNAL_ASSERT(
        cast_ival->isNone(),
        "unrecognized dtype option, expect 'int' but got: ",
        cast_ival->tagKind());

    // return a copy if dtype is `None`
    return set(self);
  }
}

bool isReductionNonCompatibleTensor(
    const std::shared_ptr<c10::TensorType>& tensor_type) {
  return is_zero_dim_tensor(tensor_type) || is_zero_sized_tensor(tensor_type);
}

bool isInputNonSizeZeroTensor(const Node* node) {
  for (const auto& val : node->inputs()) {
    auto tensor_type = val->type()->cast<TensorType>();
    if (tensor_type && is_zero_sized_tensor(tensor_type)) {
      return false;
    }
  }
  return true;
}

bool isScalarTypeCompatible(const Node* node, size_t offset) {
  auto val = node->input(offset);
  // return true if the scalar type is not specified
  if (val->type()->isSubtypeOf(static_cast<c10::TypePtr>(NoneType::get()))) {
    return true;
  }
  // return false if it's a runtime value
  if (val->node()->kind() != prim::Constant) {
    return false;
  }
  auto dtype = toIValue(val)->toScalarType();

  // we do NOT support half math type yet
  if (dtype == at::ScalarType::Half || dtype == at::ScalarType::BFloat16) {
    return false;
  }
  return true;
}

// Note [ Permutation Bookkeeping and Propagation in Parser ]
//
// The goal in supporting permutation propagation in the parser is to:
//   1. resolve conflicts and propagate permutations;
//   2.
bookkeeping of permutation on existing tensors; // // The requirement right now is that all parsing rules should support // non-permuted inputs, some binary operations support inputs with arbitrary // permutation, a few operations support special inputs. // In case where "wrong" inputs are fed to an operation, we should transpose // it to proper supported permutation. This allows us to progressively expand // permutation support. // Currently we bind all permuted codegen Val in `ValueHolder`. This saves // unnecessary transpose (not sure if it actually helps) since we can reuse // permuted tensors. // // Parsing rule pattern: // a. ops that only support non-permuted inputs (e.g. sum) // // // Specifying `MemoryFormat::Contiguous` here to force all inputs to be in // // `Contiguous` // auto [format, self] = getConsistentValues( // MemoryFormat::Contiguous, // value_map[node->inputs()[0]->unique()]); // // ... use self // // b. format agnostic ops (e.g. PW unary/binary op like aten::add) // // // getConsistentValues -> return target format and copies of operands in // // the same format // auto [format, lhs, rhs] = getConsistentValues( // c10::nullopt, // value_map[node->inputs()[0]->unique()], // value_map[node->inputs()[1]->unique()]); // // // compute out // auto out = binaryOp(op_mapping[node->kind()], lhs, rhs); // // specify `format` for out when adding it to `value_map_` // value_map.emplace(node->output()->unique(), ValueHolder(out, format)); // // c. ops that supports special permutation. e.g. aten::batch_norm with // channels-last inputs. struct MemoryFormat { // indices of dimensions with increasing stride. std::vector permuted_order_; // permutation_ encodes `permuted_order_` by concatenating all elements, with // the exception for unpermuted tensor, where we special case permutation_ to // be 0. // // e.g. for an channels-last tensor, permutation_ would be (n-1)123...(n-2); // Note: we are omitting the leading '0' when applicable, and apparently this // encoding only works with rank < 10 // see [ Note: MemoryFormat and Stride Order ] size_t permutation_ = 0; // default to non-permuted tensor MemoryFormat() = default; // [ Note: MemoryFormat and Stride Order ] // stride_order is extracted from // `TensorType::stride_properties()::stride_index_`, it describes the // index of axes from fastest to slowest. // or a 4d tensor, if we have stride_order = {x0, x1, x2, x3}, The i-th // fastest dimension would be stride_order[i]. // // Look at comment for c10::Stride in aten/src/ATen/core/jit_type.h // // eg0. for rank 4 non-permuted tensor, stride_order would be {3, 2, 1, 0}, it // means the fastest dimension is axis-3. the next one would be 2, e.t.c.. So // it's a non-permuted tensor. // it should be encoded as permutation_ = 3210 (we special case it to 0) // // eg1. for rank 4 channels-last tensor, stride_order would be {1, 3, 2, 0}, // it means the fastest dimension is axis-1. the next one would be 3, and then // 2, and then 0. So this is a channels last tensor (NCHW). // it will be encoded as permutation_ = 1320 // // eg2. for a rank 4 permuted tensor, stride_order can be {0, 3, 2, 1} // it will be encoded as permutation_ = 321 (omitting leading '0') void setPermutation(const std::vector& stride_order) { int rank = stride_order.size(); TORCH_INTERNAL_ASSERT( rank <= 10, "MemoryFormat for permutation only supports rank <= 10"); // storing stride_order in `permuted_order` for a simpler life, so we don't // have to decode `permutation_` when we want to apply/restore permutation_. 
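    // A worked trace of the encoding loop below, to make the digit packing
    // concrete: for stride_order = {1, 3, 2, 0} (rank-4 channels-last),
    // permutation_ accumulates 1 -> 13 -> 132 -> 1320, and has_permutation
    // is set at i == 0 because stride_order[0] != rank - 1 - 0. For the
    // contiguous order {3, 2, 1, 0} the loop accumulates 3210, but every
    // entry equals rank - 1 - i, so has_permutation stays false and
    // permutation_ is reset to 0, i.e. the non-permuted special case.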
permuted_order_ = stride_order; bool has_permutation = false; permutation_ = 0; for (const auto i : c10::irange(rank)) { permutation_ = permutation_ * 10 + stride_order[i]; if (!has_permutation && stride_order[i] != rank - 1 - i) { has_permutation = true; } } // special case permutation_ to reflect non-permuted tensor if (!has_permutation) { permutation_ = 0; } } // returns the stride order for given MemoryFormat encoding permutation_ // // see details for encoding in [ Note: MemoryFormat and Stride Order ] std::vector toStrideOrder() const { std::vector stride_order; // return empty vector for no permutation if (hasPermutation()) { // be generous with reserved space stride_order.reserve(10); bool encountered_zero = false; size_t permutation = permutation_; while (permutation != 0) { int order = static_cast(permutation % 10); permutation /= 10; if (order == 0) { encountered_zero = true; } stride_order.push_back(order); } if (!encountered_zero) { // in case leading '0' is omitted, push it back stride_order.push_back(0); } // since we use push_back, our stride_order is reversed. std::reverse(stride_order.begin(), stride_order.end()); } return stride_order; } // returns c10::nullopt when it's not safe to broadcast current permutation to // rank c10::optional broadcastToRank(size_t rank) const { auto ret = Contiguous(); if (hasPermutation()) { auto stride_order = toStrideOrder(); auto cur_rank = stride_order.size(); // no op for (cur_rank == 0) || (cur_rank == rank) if (cur_rank < rank) { // broadcasting to hight rank can be done by: // 1. incrementing all existing stride order by rank_diff; // 2. push back decrementing elements starting with rank_diff; // where rank_diff = rank - cur_rank // // see [ Note: MemoryFormat and Stride Order] // e.g. // taking broadcasted bias for channels last as an example // stride_order = {0, 2, 1} broadcasted to rank == 4 would give us // rank_diff = 4 - 3 = 1 // take step 1 -> {1, 3, 2} // take step 2 -> {1, 3, 2, 0} int rank_diff = static_cast(rank - cur_rank); for (auto& val : stride_order) { val += rank_diff; } for (int i = rank_diff - 1; i >= 0; i--) { stride_order.push_back(i); } } else if (cur_rank > rank) { // shrink permutation to lower rank. We can simply discard higher rank // stride order when they are not permuted to lower rank bit, because in // those instance we can't obey broadcasting semantics while preserving // permutation. We check for stride order and ensure that the lower // `rank` bits are all permuted within the lower rank. Afterwards, we // update stride_order by decrement each entry by rank_diff to reflect // correct stride order. // // see [ Note: MemoryFormat and Stride Order] // e.g. for rank 4 channels last {1, 3, 2, 0}: // 1. format can safely shrink to rank 3, since any@{1, 3, 2} >= // (4-3); We ditch last (4-3) rank and decrement each element by (4-1) // that gives us {0, 2, 1}; // 2. but when we shrink it to rank 2, we have {1, 3} where 1 < (4-2) // and it can't be handled, we return c10::nullopt. int collapsed_ranks = static_cast(cur_rank - rank); for (size_t i = 0; i < rank; i++) { if (stride_order[i] < collapsed_ranks) { // illegal collapsing, return c10::nullopt return c10::nullopt; } // update collapsed stride_order stride_order[i] -= collapsed_ranks; } // discard higher rank stride order. 
stride_order.resize(rank); } ret.setPermutation(stride_order); } return ret; } // returns non-permuted format static MemoryFormat Contiguous() { return MemoryFormat(); } bool hasPermutation() const { return permutation_ != 0; } bool isChannelsLast() const { int rank = permuted_order_.size(); if (rank > 2 && permuted_order_[0] == 1 && permuted_order_[rank - 1] == 0) { for (const auto i : c10::irange(rank - 2)) { if (permuted_order_[i + 1] != rank - 1 - i) { return false; } } return true; } return false; } // returns transpose map to achieve permutation on non-permuted tensor // note: used for aten::permute API and codegen tranpose API std::vector apply() const { std::vector ret; if (hasPermutation()) { ret.resize(permuted_order_.size()); std::copy(permuted_order_.rbegin(), permuted_order_.rend(), ret.begin()); } return ret; } // returns transpose map to restore back to non-permuted tensor // note: used for aten::permute API and codegen transpose API std::vector restore() const { std::vector ret; if (hasPermutation()) { int rank = permuted_order_.size(); ret.resize(rank); for (const auto i : c10::irange(rank)) { ret[permuted_order_[i]] = rank - 1 - i; } } return ret; } }; struct MemoryCompare { bool operator()(const MemoryFormat& format0, const MemoryFormat& format1) const { return format0.permutation_ < format1.permutation_; } }; typedef std::map MemoryFormatMap; MemoryFormat operator+(const MemoryFormat& a, const MemoryFormat& b) { // Note: TensorIterator logic uses first input to dominate output MemoryFormat // so instead of `a.permutation_ >= b.permutation_ ? a : b;`, we use: return a; }; //! ValueHolder is holds multiple copies in different permutation `MemoryFormat` //! of a tensor view. This mainly serves two purposes: //! //! 1. reuse permuted tensor views among consumers //! 2. bookkeeping for permuted tensor views in input/output tensors //! //! refer to Note [ Permutation Bookkeeping and Propagation in Parser ] class ValueHolder { public: // checks if given Val in target format exists. bool hasValue(const MemoryFormat& format) const { return vals_.count(format) != 0; } // returns Val in target format. CgValue value(const MemoryFormat& format) const { auto iter_val = vals_.find(format); TORCH_INTERNAL_ASSERT( iter_val != vals_.end(), "accessing non existing c_last_value()"); return iter_val->second; } // returns Val in target format if it exists, otherwise, transpose an existing // copy and add that to bookkeeping. CgValue maybeConvertValue(const MemoryFormat& format) { auto cur_rank = rank(); // scalar (tensor) where cur_rank == 0, memory format doesn't carry meaning // and should just return the value as-is. 
same for non-tensor where // cur_rank == -1 if (cur_rank <= 0) { return std::get<1>(getEntry()); } MemoryFormat format_s; CgValue value_s = nullptr; std::tie(format_s, value_s) = getEntry(); auto opt_format_d = format.broadcastToRank(static_cast(cur_rank)); TORCH_INTERNAL_ASSERT( opt_format_d.has_value(), "maybeConvertValue requested for illegal permutation"); MemoryFormat format_d = opt_format_d.value(); auto iter_val = vals_.find(format_d); if (iter_val != vals_.end()) { return iter_val->second; } auto val = convertValue(format_d, format_s, value_s); vals_[format_d] = val; return val; } int rank() const { if (!is_tensor_view_) { return -1; } else { auto v = std::get<1>(getEntry()); TORCH_INTERNAL_ASSERT( v->isA(), "can only access rank of TensorView"); return static_cast(v->as()->nDims()); } } // TODO: delete this and update accessor for value_map(_) ValueHolder() { TORCH_INTERNAL_ASSERT(false, "can't default constructor ValueHolder"); } ValueHolder(CgValue val, MemoryFormat format = MemoryFormat()) { vals_[format] = val; if (val->isA()) { is_tensor_view_ = true; } } // returns the MemoryFormat and codegen Val with the highest precedence among // existing copies. std::tuple getEntry() const { TORCH_CHECK(!vals_.empty(), "ValueHolder::getEntry() on empty vals_"); // return the last entry, this allows us to prioritize permuted (e.g. // channels-last) tensor over non-permuted tensors return *vals_.rbegin(); } // TODO: code cleaning in parser so we don't need these. // returns Val*, keeping them here just so we have less code change. CgValue operator*() const { return std::get<1>(getEntry()); } CgValue operator->() const { return std::get<1>(getEntry()); } operator CgValue() const { return std::get<1>(getEntry()); } private: // helper function to convert value_s @ format_s to format_d CgValue convertValue( MemoryFormat format_d, MemoryFormat format_s, CgValue value_s) { TORCH_INTERNAL_ASSERT( value_s->isA(), "cannot convert non-TensorView"); auto tv = value_s->as(); // TODO: we could probably merge the two if it has perf impact on generated // kernel // restore source permutation if (format_s.hasPermutation()) { tv = permute(tv, format_s.restore()); } // apply destination permutation if (format_d.hasPermutation()) { tv = permute(tv, format_d.apply()); } return tv; } private: // container to hold all copies of value in different MemoryFormat // std::unordered_map vals_; MemoryFormatMap vals_; // identify scalar Val bool is_tensor_view_ = false; }; template auto iterate(Func f, ValueHolder& val) { return f(val); } template auto iterate(Func f, ValueHolder& val, Values&... vals) { return f(val, iterate(f, vals...)); } // iterate through all vals and return the output MemoryFormat and copies of // vals. // 1. When `forced_format == c10::nullopt`, target MemoryFormat returns the // format of the first val in `vals`, this is to achieve a coherent // behavior as with eager TensorIterator; // 2. The target can be overwritten vias specifying `forced_format`. // // Note: take `Values&` by reference, since `maybeConvertValue` needs to modify // the entry and we want that to be updated in `value_map_` template std::pair> getConsistentValues( c10::optional forced_format, Values&... 
vals) { MemoryFormat format; if (forced_format.has_value()) { format = forced_format.value(); } else { // check for identical nDim on vals auto rank_func = [](const ValueHolder& val, int rank = 0) { int v_rank = val.rank(); v_rank = std::max(0, v_rank); if (rank == 0) { return v_rank; } else if (v_rank == 0) { return rank; } else if (rank == -1 || v_rank != rank) { return -1; } return rank; }; int rank = iterate(rank_func, vals...); // TODO: this is not needed as we are only using the first val // only apply permutation when all inputs are of identical rank, since // permutation could have changed semantics among broadcasted tensors. // Consider pointwise operation between two tensor [N, C, H, W] + [H, W] if (rank > 0) { auto format_func = [](const ValueHolder& val, MemoryFormat f = MemoryFormat::Contiguous()) { return std::get<0>(val.getEntry()) + f; }; format = iterate(format_func, vals...); } else { format = MemoryFormat::Contiguous(); } } auto convert_func = [format]( ValueHolder& val, std::list list_val = {}) { list_val.push_front(val.maybeConvertValue(format)); return list_val; }; auto list_val = iterate(convert_func, vals...); return std::make_pair(format, list_val); } // iterate through all vals and return the output MemoryFormat and copies of // vals. // 1. When `forced_format == c10::nullopt`, target MemoryFormat returns the // format of the first val in `vals`, this is to achieve a coherent // behavior as with eager TensorIterator; // 2. The target can be overwritten vias specifying `forced_format`. // // Note: take `Values&` by reference, since `maybeConvertValue` needs to modify // the entry and we want that to be updated in `value_map_` template std::pair> getPWFormatValues( c10::optional forced_format, Values&... vals) { MemoryFormat format; if (forced_format.has_value()) { format = forced_format.value(); } else { // get maximum rank on vals std::vector formats; std::vector ranks; auto max_rank_func = [&ranks](const ValueHolder& val, int rank = 0) { int v_rank = val.rank(); ranks.push_back(v_rank); return std::max(rank, v_rank); }; int max_rank = iterate(max_rank_func, vals...); // going through all permutation, keeping consistency with TensorIterator // behavior and the first tensor with highest rank dictates output // permutation auto format_func = [&formats, &max_rank]( const ValueHolder& val, MemoryFormat f = MemoryFormat::Contiguous()) { auto cur_format = std::get<0>(val.getEntry()); formats.push_back(cur_format); return val.rank() == max_rank ? cur_format : f; }; format = iterate(format_func, vals...); // we need to do pair-wise comparison to ensure that all permutation are // compatible since permutation could have changed semantics among // broadcasted tensors. Consider pointwise operation between three tensor // [N, C, H, W] + [C, H, W] + [H, W] for (size_t i = 0; i < formats.size() && format.hasPermutation(); i++) { for (size_t j = 0; j < formats.size(); j++) { // don't compare scalar tensor or scalar if (ranks[i] <= 0 || ranks[j] <= 0 || i == j) { continue; } size_t lower_rank = std::min(ranks[i], ranks[j]); auto i_format = formats[i].broadcastToRank(lower_rank); auto j_format = formats[j].broadcastToRank(lower_rank); // breaks permutation if any: // 1. i_format can't be broadcasted to lower_rank; // 2. 
j_format can't be broadcasted to lower_rank; if (!i_format.has_value() || !j_format.has_value()) { format = MemoryFormat::Contiguous(); } } } } auto convert_func = [format]( ValueHolder& val, std::list list_val = {}) { list_val.push_front(val.maybeConvertValue(format)); return list_val; }; auto list_val = iterate(convert_func, vals...); return std::make_pair(format, list_val); } typedef void ( *ParseFuncPtr)(const Node*, std::unordered_map&); typedef bool (*MergeQueryFuncPtr)(const Node*); // TODO: add a mutex to make it thread safe. class IrParser { enum class OperatorType { ElementWise, Reduction, ReductionToSize, Normalization }; typedef OperatorType (*OperatorTypeFuncPtr)(const Node*); class RegistrationEntry { public: RegistrationEntry( ParseFuncPtr parse_f, MergeQueryFuncPtr merge_f = nullptr, OperatorTypeFuncPtr type_f = nullptr) : parse_f_(parse_f), merge_f_(merge_f), type_f_(type_f) {} void parse( const Node* node, std::unordered_map& values) const { parse_f_(node, values); } bool isCompatible(const Node* node) const { if (merge_f_ == nullptr) { return true; } return merge_f_(node); } bool isType(const Node* node, OperatorType type) const { auto n_type = type_f_ == nullptr ? OperatorType::ElementWise : type_f_(node); return n_type == type; } private: ParseFuncPtr parse_f_; MergeQueryFuncPtr merge_f_; OperatorTypeFuncPtr type_f_; }; public: // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init) IrParser(std::shared_ptr graph) : graph_(std::move(graph)) { initRegistry(); } std::unique_ptr parse() { auto fusion = std::make_unique(); // NOLINTNEXTLINE(cppcoreguidelines-init-variables) FusionGuard fg(fusion.get()); auto block = graph_->block(); std::unordered_map permuted_tensors; // register all inputs; for (auto val : block->inputs()) { TORCH_INTERNAL_ASSERT( registerValue(val), "Failure when register value: ", *(val->node()), " with type: ", val->type()->repr_str()); MemoryFormat format; Val* operand = nullptr; std::tie(format, operand) = value_map_[val->unique()].getEntry(); fusion->addInput(operand); // mark input tensor as permuted; if (format.hasPermutation()) { permuted_tensors.insert({operand, format}); } auto opt_dtype = operand->getDataType(); // computation promotion, we cast fp16 or bf16 inputs to fp32 and use // promoted type in the computation. if (opt_dtype.has_value() && (opt_dtype.value() == DataType::Half || opt_dtype.value() == DataType::BFloat16)) { Val* promoted_val = castOp(DataType::Float, operand); value_map_[val->unique()] = ValueHolder(promoted_val, format); } } // compose nodes in topo order; for (const JitOp* node : block->nodes()) { processJitNode(node); } // mark output; for (auto jit_output : block->outputs()) { MemoryFormat format; Val* operand = nullptr; std::tie(format, operand) = value_map_[jit_output->unique()].getEntry(); TensorView* out = operand->as(); // demote output dtype to be match PyTorch JIT graph. auto tensor_type = jit_output->type()->cast(); TORCH_INTERNAL_ASSERT( tensor_type, "output of fusion group is not TensorType."); if (tensor_type->scalarType().has_value()) { out = optionalCastStrict( aten_to_data_type(*tensor_type->scalarType()), out) ->as(); } if (out->isFusionOutput()) { // TODO: This is wasted memory bandwidth, we need to copy since we can't // output a tensor twice. 
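        // Descriptive note (inferred from the isFusionOutput() check above):
        // the fusion cannot register the same TensorView as an output twice,
        // so when two JIT outputs alias one value we materialize a fresh
        // copy with set() and register the copy instead.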
out = set(out); } fusion->addOutput(out); // mark output tensor as permuted; if (format.hasPermutation()) { permuted_tensors.insert({out, format}); } } for (const auto& i : c10::irange(fusion->inputs().size())) { const auto& entry = permuted_tensors.find(fusion->inputs()[i]); if (entry != permuted_tensors.end()) { fusion->setPermutationOnInput(i, entry->second.apply()); } } for (const auto& i : c10::irange(fusion->outputs().size())) { const auto& entry = permuted_tensors.find(fusion->outputs()[i]); if (entry != permuted_tensors.end()) { fusion->setPermutationOnOutput(i, entry->second.restore()); } } return fusion; } static bool lookupInSymbolSet(const Node* node) { initRegistry(); std::lock_guard lock(parser_mutex_); return parser_symbol_set_.count(node->kind()) != 0; } // return nullptr if entry does not exist static const RegistrationEntry* lookupInRegistry(const Node* node) { std::lock_guard lock(parser_mutex_); if (parser_skip_set_.count(node->kind()) != 0) { return nullptr; } // we need to use maybeSchema for nodes like prim::Constant, which doesn't // have a schema auto schema_ptr = node->maybeSchema(); if (schema_ptr != nullptr) { // search cached entry first auto cache_it = cached_registry_lookup_.find(schema_ptr); if (cache_it != cached_registry_lookup_.end()) { return cache_it->second; } else { // match signature auto schema_str = canonicalSchemaString(*schema_ptr); auto iter = jit_operator_registry_.find(schema_str); if (iter != jit_operator_registry_.end()) { // update cache entry cached_registry_lookup_.insert(cache_it, {schema_ptr, &iter->second}); return &iter->second; } } } return nullptr; } static bool querySkipSymbolSet(c10::Symbol symbol, bool flip) { initRegistry(); std::lock_guard lock(parser_mutex_); // no need to init registry here (unlike `lookupInSymbolSet`, as // `parser_skip_set_` is not initialized via initialization bool ret = parser_skip_set_.count(symbol) != 0; if (flip) { if (ret) { parser_skip_set_.erase(symbol); } else { parser_skip_set_.insert(symbol); } } return ret; } static void initRegistry() { c10::call_once(once_flag_, []() { std::lock_guard lock(parser_mutex_); registerJitOperator(); }); } static bool canParseNode(const Node* node) { initRegistry(); // match signature. auto schema_ptr = node->maybeSchema(); if (schema_ptr == nullptr) { return false; } auto reg_entry = lookupInRegistry(node); return reg_entry != nullptr && reg_entry->isCompatible(node); } static bool isReductionToSizeNode(const Node* node) { initRegistry(); auto reg_entry = lookupInRegistry(node); return reg_entry != nullptr && reg_entry->isType(node, OperatorType::ReductionToSize); } static bool isReductionNode(const Node* node) { initRegistry(); auto reg_entry = lookupInRegistry(node); return reg_entry != nullptr && (reg_entry->isType(node, OperatorType::Reduction) || reg_entry->isType(node, OperatorType::ReductionToSize)); } static bool isNormalizationNode(const Node* node) { initRegistry(); auto reg_entry = lookupInRegistry(node); return reg_entry != nullptr && reg_entry->isType(node, OperatorType::Normalization); } static bool isElementWiseNode(const Node* node) { initRegistry(); auto reg_entry = lookupInRegistry(node); return reg_entry != nullptr && reg_entry->isType(node, OperatorType::ElementWise); } // TODO: is_reduction is too hacky here. we should categorize operation types // based on their memory accessing pattern, which would affect fusion // strategy and partition logic. 
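  // Sketch of how registration and lookup fit together, mirroring the
  // registrations further below (illustrative only; the body placeholder is
  // not real code):
  //
  //   auto op = getOperatorForLiteral(
  //       "aten::add(Tensor self, Tensor other, *, Scalar alpha) -> Tensor");
  //   REGISTER_PARSE_RULE(op, { /* build codegen IR into value_map */ },
  //       isInputNonSizeZeroTensor, nullptr);
  //
  // registerParseRule() keys the RegistrationEntry on
  // canonicalSchemaString(op->schema()). lookupInRegistry() later resolves a
  // Node via node->maybeSchema(), consults cached_registry_lookup_, and falls
  // back to matching the canonical schema string in jit_operator_registry_.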
static void registerParseRule( std::shared_ptr& op, ParseFuncPtr parse_fn, MergeQueryFuncPtr merge_query_fn = nullptr, OperatorTypeFuncPtr type_fn = nullptr) { auto op_name = op->schema().name(); parser_symbol_set_.insert(c10::Symbol::fromQualString(op_name)); // We blindly attempt to profile the inplace version of supported op, this // is to ensure that in-place removal in fusion partition would have the // profile information for them readily available after the pass. parser_symbol_set_.insert(c10::Symbol::fromQualString(op_name + '_')); jit_operator_registry_.emplace( std::piecewise_construct, std::forward_as_tuple(canonicalSchemaString(op->schema())), std::forward_as_tuple(parse_fn, merge_query_fn, type_fn)); } private: static void registerJitOperator() { // Register parse-function for each JIT operator; // This is a one-time look up, our hash registry indexes on the pointer in // OperatorRegistry. std::array BinaryOpWithAlpha = { "aten::add(Tensor self, Tensor other, *, Scalar alpha) -> Tensor", "aten::add(Tensor self, Scalar other, Scalar alpha) -> Tensor", "aten::sub(Tensor self, Tensor other, *, Scalar alpha) -> Tensor", "aten::sub(Tensor self, Scalar other, Scalar alpha) -> Tensor", "aten::rsub(Tensor self, Tensor other, *, Scalar alpha) -> Tensor", "aten::rsub(Tensor self, Scalar other, Scalar alpha) -> Tensor"}; for (auto signature : BinaryOpWithAlpha) { auto ptr_op = getOperatorForLiteral(signature); REGISTER_PARSE_RULE( ptr_op, { using BinaryOpWithAlphaType = Val* (*)(Val*, Val*, Val*); static std::unordered_map< Symbol, std::pair> op_mapping( {{aten::add, std::make_pair( BinaryOpType::Add, static_cast(&add_alpha))}, {aten::sub, std::make_pair( BinaryOpType::Sub, static_cast(&sub_alpha))}, {aten::rsub, std::make_pair( BinaryOpType::Sub, static_cast(&sub_alpha))}}); // TODO: handle scaling factor when it's not constant 1; MemoryFormat format; std::list list_val; std::tie(format, list_val) = getPWFormatValues( c10::nullopt, value_map[node->inputs()[0]->unique()], value_map[node->inputs()[1]->unique()]); auto lhs = list_val.front(); list_val.pop_front(); auto rhs = list_val.front(); list_val.pop_front(); Val* alpha = value_map[node->inputs()[2]->unique()]; auto out = alpha->isOneInt() ? binaryOp( op_mapping[node->kind()].first, node->kind() == aten::rsub ? rhs : lhs, node->kind() == aten::rsub ? lhs : rhs, TypePromotion::default_op_config) : (node->kind() == aten::rsub ? 
op_mapping[node->kind()].second(rhs, lhs, alpha) : op_mapping[node->kind()].second(lhs, rhs, alpha)); value_map.emplace( node->output()->unique(), ValueHolder(out, format)); }, isInputNonSizeZeroTensor, nullptr); } std::array BinaryFloatOp = { "aten::div(Tensor self, Tensor other) -> Tensor", "aten::div(Tensor self, Scalar other) -> Tensor", "aten::atan2(Tensor self, Tensor other) -> Tensor"}; for (auto signature : BinaryFloatOp) { auto ptr_op = getOperatorForLiteral(signature); REGISTER_PARSE_RULE( ptr_op, { static std::unordered_map op_mapping( {{aten::div, BinaryOpType::Div}, {aten::atan2, BinaryOpType::Atan2}}); MemoryFormat format; std::list list_val; std::tie(format, list_val) = getPWFormatValues( c10::nullopt, value_map[node->inputs()[0]->unique()], value_map[node->inputs()[1]->unique()]); auto lhs = list_val.front(); list_val.pop_front(); auto rhs = list_val.front(); list_val.pop_front(); auto out = binaryOp( op_mapping[node->kind()], lhs, rhs, TypePromotion::float_op_config); value_map.emplace( node->output()->unique(), ValueHolder(out, format)); }, isInputNonSizeZeroTensor, nullptr); } std::array BinaryCastOp = { "aten::mul(Tensor self, Tensor other) -> Tensor", "aten::mul(Tensor self, Scalar other) -> Tensor", "aten::max(Tensor self, Tensor other) -> Tensor", "aten::min(Tensor self, Tensor other) -> Tensor", "aten::pow(Tensor self, Tensor exponent) -> Tensor", "aten::pow(Tensor self, Scalar exponent) -> Tensor", "aten::pow(Scalar self, Tensor exponent) -> Tensor", "aten::remainder(Tensor self, Tensor other) -> Tensor", "aten::fmod(Tensor self, Tensor other) -> Tensor", "aten::bitwise_and(Tensor self, Tensor other) -> Tensor", "aten::__and__(Tensor self, Tensor other) -> Tensor", "aten::bitwise_or(Tensor self, Tensor other) -> Tensor", "aten::__or__(Tensor self, Tensor other) -> Tensor", "aten::bitwise_xor(Tensor self, Tensor other) -> Tensor", "aten::__xor__(Tensor self, Tensor other) -> Tensor", "aten::bitwise_left_shift(Tensor self, Tensor other) -> Tensor", "aten::__lshift__(Tensor self, Tensor other) -> Tensor", "aten::bitwise_right_shift(Tensor self, Tensor other) -> Tensor", "aten::__rshift__(Tensor self, Tensor other) -> Tensor"}; for (auto signature : BinaryCastOp) { auto ptr_op = getOperatorForLiteral(signature); REGISTER_PARSE_RULE( ptr_op, { static std::unordered_map op_mapping( {{aten::mul, BinaryOpType::Mul}, {aten::min, BinaryOpType::Min}, {aten::max, BinaryOpType::Max}, {aten::pow, BinaryOpType::Pow}, {aten::remainder, BinaryOpType::Remainder}, {aten::fmod, BinaryOpType::Fmod}, {aten::bitwise_and, BinaryOpType::And}, {aten::__and__, BinaryOpType::And}, {aten::bitwise_or, BinaryOpType::Or}, {aten::__or__, BinaryOpType::Or}, {aten::bitwise_xor, BinaryOpType::Xor}, {aten::__xor__, BinaryOpType::Xor}, {aten::bitwise_left_shift, BinaryOpType::Lshift}, {aten::__lshift__, BinaryOpType::Lshift}, {aten::bitwise_right_shift, BinaryOpType::Rshift}, {aten::__rshift__, BinaryOpType::Rshift}}); MemoryFormat format; std::list list_val; std::tie(format, list_val) = getPWFormatValues( c10::nullopt, value_map[node->inputs()[0]->unique()], value_map[node->inputs()[1]->unique()]); auto lhs = list_val.front(); list_val.pop_front(); auto rhs = list_val.front(); list_val.pop_front(); auto out = binaryOp( op_mapping[node->kind()], lhs, rhs, TypePromotion::default_op_config); value_map.emplace( node->output()->unique(), ValueHolder(out, format)); }, isInputNonSizeZeroTensor, nullptr); } std::array BinaryOp = { "aten::eq(Tensor self, Tensor other) -> Tensor", "aten::eq(Tensor self, Scalar 
other) -> Tensor", "aten::ne(Tensor self, Tensor other) -> Tensor", "aten::ne(Tensor self, Scalar other) -> Tensor", "aten::ge(Tensor self, Tensor other) -> Tensor", "aten::ge(Tensor self, Scalar other) -> Tensor", "aten::gt(Tensor self, Tensor other) -> Tensor", "aten::gt(Tensor self, Scalar other) -> Tensor", "aten::le(Tensor self, Tensor other) -> Tensor", "aten::le(Tensor self, Scalar other) -> Tensor", "aten::lt(Tensor self, Tensor other) -> Tensor", "aten::lt(Tensor self, Scalar other) -> Tensor"}; for (auto signature : BinaryOp) { auto ptr_op = getOperatorForLiteral(signature); REGISTER_PARSE_RULE( ptr_op, { static std::unordered_map op_mapping( {{aten::lt, BinaryOpType::LT}, {aten::le, BinaryOpType::LE}, {aten::gt, BinaryOpType::GT}, {aten::ge, BinaryOpType::GE}, {aten::ne, BinaryOpType::NE}, {aten::eq, BinaryOpType::Eq}}); MemoryFormat format; std::list list_val; std::tie(format, list_val) = getPWFormatValues( c10::nullopt, value_map[node->inputs()[0]->unique()], value_map[node->inputs()[1]->unique()]); auto lhs = list_val.front(); list_val.pop_front(); auto rhs = list_val.front(); list_val.pop_front(); auto out = binaryOp( op_mapping[node->kind()], lhs, rhs, TypePromotion::comparison_op_config); value_map.emplace( node->output()->unique(), ValueHolder(out, format)); }, isInputNonSizeZeroTensor, nullptr); } std::array UnaryOp = { "aten::abs(Tensor self) -> Tensor", "aten::bitwise_not(Tensor self) -> Tensor", "aten::ceil(Tensor self) -> Tensor", "aten::floor(Tensor self) -> Tensor", "aten::frac(Tensor self) -> Tensor", "aten::neg(Tensor self) -> Tensor", "aten::relu(Tensor self) -> Tensor", "aten::round(Tensor self) -> Tensor", "aten::silu(Tensor self) -> Tensor", "aten::trunc(Tensor self) -> Tensor", }; for (auto signature : UnaryOp) { auto ptr_op = getOperatorForLiteral(signature); REGISTER_PARSE_RULE( ptr_op, { static std::unordered_map op_mapping({ {aten::abs, UnaryOpType::Abs}, {aten::bitwise_not, UnaryOpType::Not}, {aten::ceil, UnaryOpType::Ceil}, {aten::floor, UnaryOpType::Floor}, {aten::frac, UnaryOpType::Frac}, {aten::neg, UnaryOpType::Neg}, {aten::relu, UnaryOpType::Relu}, {aten::round, UnaryOpType::Round}, {aten::silu, UnaryOpType::Silu}, {aten::trunc, UnaryOpType::Trunc}, }); MemoryFormat format; std::list list_val; std::tie(format, list_val) = getConsistentValues( c10::nullopt, value_map[node->inputs()[0]->unique()]); auto operand = list_val.front(); list_val.pop_front(); auto out = unaryOp(op_mapping[node->kind()], operand); value_map.emplace( node->output()->unique(), ValueHolder(out, format)); }, isInputNonSizeZeroTensor, nullptr); } std::array UnaryFloatOp = { "aten::log(Tensor self) -> Tensor", "aten::log10(Tensor self) -> Tensor", "aten::log1p(Tensor self) -> Tensor", "aten::log2(Tensor self) -> Tensor", "aten::lgamma(Tensor self) -> Tensor", "aten::exp(Tensor self) -> Tensor", "aten::expm1(Tensor self) -> Tensor", "aten::erf(Tensor self) -> Tensor", "aten::erfc(Tensor self) -> Tensor", "aten::cos(Tensor self) -> Tensor", "aten::acos(Tensor self) -> Tensor", "aten::cosh(Tensor self) -> Tensor", "aten::sin(Tensor self) -> Tensor", "aten::asin(Tensor self) -> Tensor", "aten::sinh(Tensor self) -> Tensor", "aten::tan(Tensor self) -> Tensor", "aten::atan(Tensor self) -> Tensor", "aten::tanh(Tensor self) -> Tensor", "aten::atanh(Tensor self) -> Tensor", "aten::sqrt(Tensor self) -> Tensor", "aten::rsqrt(Tensor self) -> Tensor", "aten::reciprocal(Tensor self) -> Tensor", "aten::sigmoid(Tensor self) -> Tensor"}; for (auto signature : UnaryFloatOp) { auto ptr_op = 
getOperatorForLiteral(signature); REGISTER_PARSE_RULE( ptr_op, { static std::unordered_map op_mapping({ {aten::log, UnaryOpType::Log}, {aten::log10, UnaryOpType::Log10}, {aten::log1p, UnaryOpType::Log1p}, {aten::log2, UnaryOpType::Log2}, {aten::lgamma, UnaryOpType::Lgamma}, {aten::exp, UnaryOpType::Exp}, {aten::expm1, UnaryOpType::Expm1}, {aten::erf, UnaryOpType::Erf}, {aten::erfc, UnaryOpType::Erfc}, {aten::cos, UnaryOpType::Cos}, {aten::acos, UnaryOpType::Acos}, {aten::cosh, UnaryOpType::Cosh}, {aten::sin, UnaryOpType::Sin}, {aten::asin, UnaryOpType::Asin}, {aten::sinh, UnaryOpType::Sinh}, {aten::tan, UnaryOpType::Tan}, {aten::tanh, UnaryOpType::Tanh}, {aten::atan, UnaryOpType::Atan}, {aten::atanh, UnaryOpType::Atanh}, {aten::sqrt, UnaryOpType::Sqrt}, {aten::rsqrt, UnaryOpType::Rsqrt}, {aten::reciprocal, UnaryOpType::Reciprocal}, {aten::sigmoid, UnaryOpType::Sigmoid}, }); MemoryFormat format; std::list list_val; std::tie(format, list_val) = getConsistentValues( c10::nullopt, value_map[node->inputs()[0]->unique()]); auto operand = list_val.front(); list_val.pop_front(); auto out = unaryOp( op_mapping[node->kind()], operand, TypePromotion::float_op_config); value_map.emplace( node->output()->unique(), ValueHolder(out, format)); }, isInputNonSizeZeroTensor, nullptr); } std::array UnaryIsOp = { "aten::isfinite(Tensor self) -> Tensor", "aten::isinf(Tensor self) -> Tensor", "aten::isnan(Tensor self) -> Tensor", "aten::isneginf(Tensor self) -> Tensor", "aten::isposinf(Tensor self) -> Tensor", "aten::isreal(Tensor self) -> Tensor"}; for (auto signature : UnaryIsOp) { auto ptr_op = getOperatorForLiteral(signature); REGISTER_PARSE_RULE( ptr_op, { static std::unordered_map op_mapping({ {aten::isfinite, UnaryOpType::IsFinite}, {aten::isinf, UnaryOpType::IsInf}, {aten::isnan, UnaryOpType::IsNan}, {aten::isneginf, UnaryOpType::IsNegInf}, {aten::isposinf, UnaryOpType::IsPosInf}, {aten::isreal, UnaryOpType::IsReal}, }); MemoryFormat format; std::list list_val; std::tie(format, list_val) = getConsistentValues( c10::nullopt, value_map[node->inputs()[0]->unique()]); auto operand = list_val.front(); list_val.pop_front(); auto out = unaryIsOp(op_mapping[node->kind()], operand); value_map.emplace( node->output()->unique(), ValueHolder(out, format)); }, isInputNonSizeZeroTensor, nullptr); } { auto ptr_op = getOperatorForLiteral( "aten::rand_like(Tensor self, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? 
memory_format=None) -> Tensor"); REGISTER_PARSE_RULE( ptr_op, { MemoryFormat format; std::list list_val; std::tie(format, list_val) = getConsistentValues( c10::nullopt, value_map[node->inputs()[0]->unique()]); auto operand = list_val.front(); list_val.pop_front(); if (!node->input(3)->type()->isSubtypeOf( static_cast(NoneType::get()))) { auto device = constant_as(node->input(3)); TORCH_INTERNAL_ASSERT( device.has_value() && device->is_cuda(), "rand_like in nvfuser is not on cuda device"); auto input_tensor_type = node->input(0)->type()->cast(); // device->index() == -1 indicating that we don't change device // index if (device->index() != -1 && input_tensor_type) { auto input_device = input_tensor_type->device(); // we expect device index to be consistent with input and it // should have already been handled by partition TORCH_INTERNAL_ASSERT( !input_device.has_value() || input_device->index() == device->index(), "rand_like in nvfuser is not on cuda device"); } } auto out = randlike(operand); value_map.emplace( node->output()->unique(), ValueHolder(out, format)); }, [](const Node* node) -> bool { if (!isInputNonSizeZeroTensor(node)) { return false; } if (!node->input(1)->type()->isSubtypeOf( static_cast(NoneType::get())) || !node->input(2)->type()->isSubtypeOf( static_cast(NoneType::get())) || !node->input(5)->type()->isSubtypeOf( static_cast(NoneType::get()))) { return false; } return true; }, nullptr); } { auto ptr_op = getOperatorForLiteral( "aten::softplus(Tensor self, Scalar beta, Scalar threshold) -> Tensor"); REGISTER_PARSE_RULE( ptr_op, { MemoryFormat format; std::list list_val; std::tie(format, list_val) = getConsistentValues( c10::nullopt, value_map[node->inputs()[0]->unique()]); auto operand = list_val.front()->as(); list_val.pop_front(); auto& beta = value_map[node->inputs()[1]->unique()]; auto& threshold = value_map[node->inputs()[2]->unique()]; auto out = softplus(operand, beta, threshold); value_map.emplace( node->output()->unique(), ValueHolder(out, format)); }, isInputNonSizeZeroTensor, nullptr); } { auto ptr_op = getOperatorForLiteral( "aten::threshold(Tensor self, Scalar threshold, Scalar value) -> Tensor"); REGISTER_PARSE_RULE( ptr_op, { MemoryFormat format; std::list list_val; std::tie(format, list_val) = getConsistentValues( c10::nullopt, value_map[node->inputs()[0]->unique()]); auto operand = list_val.front(); list_val.pop_front(); auto& th = value_map[node->inputs()[1]->unique()]; auto& value = value_map[node->inputs()[2]->unique()]; auto out = threshold(operand, th, value); value_map.emplace( node->output()->unique(), ValueHolder(out, format)); }, isInputNonSizeZeroTensor, nullptr); } { // LTC uses threshold_backward for relu_backward auto ptr_op = getOperatorForLiteral( "aten::threshold_backward(Tensor grad_output, Tensor self, Scalar threshold) -> Tensor"); REGISTER_PARSE_RULE( ptr_op, { MemoryFormat format; std::list list_val; std::tie(format, list_val) = getPWFormatValues( c10::nullopt, value_map[node->inputs()[0]->unique()], value_map[node->inputs()[1]->unique()]); auto grad_output = list_val.front(); list_val.pop_front(); auto input = list_val.front(); auto& threshold = value_map[node->inputs()[2]->unique()]; auto comparison = binaryOp( BinaryOpType::GT, input, threshold, TypePromotion::comparison_op_config); auto mask = castOp(input->getDataType().value(), comparison); auto out = mul(grad_output, mask); value_map.emplace( node->output()->unique(), ValueHolder(out, format)); }, isInputNonSizeZeroTensor, nullptr); } { auto ptr_op = getOperatorForLiteral( 
"aten::clamp(Tensor self, Scalar? min, Scalar? max) -> Tensor"); REGISTER_PARSE_RULE( ptr_op, { MemoryFormat format; std::list list_val; std::tie(format, list_val) = getConsistentValues( c10::nullopt, value_map[node->inputs()[0]->unique()]); auto operand = list_val.front(); list_val.pop_front(); Val* min = value_map.count(node->inputs()[1]->unique()) != 0 ? *value_map[node->inputs()[1]->unique()] : nullptr; Val* max = value_map.count(node->inputs()[2]->unique()) != 0 ? *value_map[node->inputs()[2]->unique()] : nullptr; Val* out = clamp(operand, min, max); value_map.emplace( node->output()->unique(), ValueHolder(out, format)); }, isInputNonSizeZeroTensor, nullptr); } { auto ptr_op = getOperatorForLiteral( "aten::where(Tensor condition, Tensor self, Tensor other) -> Tensor"); REGISTER_PARSE_RULE( ptr_op, { MemoryFormat format; std::list list_val; std::tie(format, list_val) = getPWFormatValues( c10::nullopt, value_map[node->inputs()[0]->unique()], value_map[node->inputs()[1]->unique()], value_map[node->inputs()[2]->unique()]); auto condition = list_val.front(); list_val.pop_front(); auto x = list_val.front(); list_val.pop_front(); auto y = list_val.front(); list_val.pop_front(); auto out = where(condition, x, y); value_map.emplace( node->output()->unique(), ValueHolder(out, format)); }, isInputNonSizeZeroTensor, nullptr); } { std::array LerpOp = { "aten::lerp(Tensor self, Tensor end, Scalar weight) -> Tensor", "aten::lerp(Tensor self, Tensor end, Tensor weight) -> Tensor"}; for (auto signature : LerpOp) { auto ptr_op = getOperatorForLiteral(signature); REGISTER_PARSE_RULE( ptr_op, { MemoryFormat format; std::list list_val; std::tie(format, list_val) = getPWFormatValues( c10::nullopt, value_map[node->inputs()[0]->unique()], value_map[node->inputs()[1]->unique()], value_map[node->inputs()[2]->unique()]); auto self = list_val.front(); list_val.pop_front(); auto end = list_val.front(); list_val.pop_front(); auto weight = list_val.front(); list_val.pop_front(); auto out = lerp(self, end, weight); value_map.emplace( node->output()->unique(), ValueHolder(out, format)); }, isInputNonSizeZeroTensor, nullptr); } } { auto ptr_op = getOperatorForLiteral( "aten::addcmul(Tensor self, Tensor tensor1, Tensor tensor2, *, Scalar value=1) -> Tensor"); REGISTER_PARSE_RULE( ptr_op, { MemoryFormat format; std::list list_val; std::tie(format, list_val) = getPWFormatValues( c10::nullopt, value_map[node->inputs()[0]->unique()], value_map[node->inputs()[1]->unique()], value_map[node->inputs()[2]->unique()], value_map[node->inputs()[3]->unique()]); auto self = list_val.front(); list_val.pop_front(); auto tensor1 = list_val.front(); list_val.pop_front(); auto tensor2 = list_val.front(); list_val.pop_front(); auto value = list_val.front(); list_val.pop_front(); auto out = addcmul(self, tensor1, tensor2, value); value_map.emplace( node->output()->unique(), ValueHolder(out, format)); }, isInputNonSizeZeroTensor, nullptr); } { auto ptr_op = getOperatorForLiteral( "aten::native_dropout(Tensor input, float p, bool? 
train) -> (Tensor, Tensor)"); REGISTER_PARSE_RULE( ptr_op, { MemoryFormat format; std::list list_val; std::tie(format, list_val) = getConsistentValues( c10::nullopt, value_map[node->inputs()[0]->unique()], value_map[node->inputs()[1]->unique()]); auto input = list_val.front(); list_val.pop_front(); auto prob = list_val.front(); list_val.pop_front(); auto train = constant_as(node->input(2)); TORCH_INTERNAL_ASSERT( train.has_value(), "dropout needs constant `train` flag"); if (train.value()) { auto result = dropout(input->as(), prob); value_map.emplace( node->output(0)->unique(), ValueHolder(result.output, format)); value_map.emplace( node->output(1)->unique(), ValueHolder(result.mask, format)); } else { value_map.emplace(node->output(0)->unique(), input); value_map.emplace( node->output(1)->unique(), ValueHolder(TensorViewBuilder().build(), format)); } }, [](const Node* node) -> bool { if (!isInputNonSizeZeroTensor(node)) { return false; } if (node->inputs()[2]->node()->kind() != prim::Constant) { return false; } return true; }, nullptr); } { auto ptr_op = getOperatorForLiteral( "aten::dropout(Tensor input, float p, bool train) -> Tensor"); REGISTER_PARSE_RULE( ptr_op, { MemoryFormat format; std::list list_val; std::tie(format, list_val) = getConsistentValues( c10::nullopt, value_map[node->inputs()[0]->unique()], value_map[node->inputs()[1]->unique()]); auto input = list_val.front(); list_val.pop_front(); auto prob = list_val.front(); list_val.pop_front(); auto train = constant_as(node->input(2)); TORCH_INTERNAL_ASSERT( train.has_value(), "dropout needs constant `train` flag"); if (train.value()) { auto result = dropout(input->as(), prob); value_map.emplace( node->output()->unique(), ValueHolder(result.output, format)); } else { value_map.emplace( node->output()->unique(), ValueHolder(input, format)); } }, [](const Node* node) -> bool { if (!isInputNonSizeZeroTensor(node)) { return false; } if (node->inputs()[2]->node()->kind() != prim::Constant) { return false; } return true; }, nullptr); } { auto ptr_op = getOperatorForLiteral( "aten::native_dropout_backward(Tensor grad_output, Tensor mask, float scale) -> Tensor"); REGISTER_PARSE_RULE( ptr_op, { MemoryFormat format; std::list list_val; std::tie(format, list_val) = getPWFormatValues( c10::nullopt, value_map[node->inputs()[0]->unique()], value_map[node->inputs()[1]->unique()], value_map[node->inputs()[2]->unique()]); auto grad = list_val.front(); list_val.pop_front(); auto mask = list_val.front(); list_val.pop_front(); auto scale = list_val.front(); list_val.pop_front(); auto output = dropout_backward( grad->as(), mask->as(), scale); value_map.emplace( node->output()->unique(), ValueHolder(output, format)); }, isInputNonSizeZeroTensor, nullptr); } { std::array InstanceNormFwd = { "aten::instance_norm(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? 
running_var, bool use_input_stats, float momentum, float eps, bool cudnn_enabled) -> Tensor"}; for (auto signature : InstanceNormFwd) { auto ptr_op = getOperatorForLiteral(signature); REGISTER_PARSE_RULE( ptr_op, { // TODO: handle channels last MemoryFormat format; std::list list_val; std::tie(format, list_val) = getConsistentValues( MemoryFormat::Contiguous(), value_map[node->inputs()[0]->unique()]); auto input_t = list_val.front(); list_val.pop_front(); auto input = input_t->as(); TensorView* weight = nullptr; if (!node->input(1)->type()->isSubtypeOf( static_cast(NoneType::get()))) { weight = value_map[node->input(1)->unique()]->as(); } TensorView* bias = nullptr; if (!node->input(2)->type()->isSubtypeOf( static_cast(NoneType::get()))) { bias = value_map[node->input(2)->unique()]->as(); } TensorView* running_mean = nullptr; if (!node->input(3)->type()->isSubtypeOf( static_cast(NoneType::get()))) { running_mean = value_map[node->input(3)->unique()]->as(); } TensorView* running_var = nullptr; if (!node->input(4)->type()->isSubtypeOf( static_cast(NoneType::get()))) { running_var = value_map[node->input(4)->unique()]->as(); } // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) auto use_input_stats = constant_as(node->input(5)); TORCH_INTERNAL_ASSERT( use_input_stats.has_value(), "The use_input_stats (bool) parameter is required."); const bool kUseInputStats = use_input_stats.value(); Val* momentum_ptr = nullptr; // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) if (auto momentum = constant_as(node->input(6))) { momentum_ptr = IrBuilder::create(momentum.value()); } else { // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) momentum_ptr = value_map[node->input(6)->unique()]; } Val* eps_ptr = nullptr; // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) if (auto eps = constant_as(node->input(7))) { eps_ptr = IrBuilder::create(eps.value()); } else { // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) eps_ptr = value_map[node->input(7)->unique()]; } auto result = instance_norm( input, weight, bias, running_mean, running_var, kUseInputStats, momentum_ptr, eps_ptr); if (node->kind() == c10::Symbol::fromQualString("aten::instance_norm")) { value_map.emplace(node->output()->unique(), result.output); } }, [](const Node* node) -> bool { if (isReductionNonCompatibleTensor( node->input(0)->type()->cast())) { return false; } return true; }, [](const Node* node) -> OperatorType { return OperatorType::Normalization; }); } } { std::array BatchNormFwd = { "aten::_batch_norm_impl_index(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float momentum, float eps, bool cudnn_enabled) -> (Tensor, Tensor, Tensor, Tensor, int)", "aten::native_batch_norm(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float momentum, float eps) -> (Tensor, Tensor, Tensor)", "aten::batch_norm(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? 
running_var, bool training, float momentum, float eps, bool cudnn_enabled) -> Tensor"}; for (auto signature : BatchNormFwd) { auto ptr_op = getOperatorForLiteral(signature); REGISTER_PARSE_RULE( ptr_op, { MemoryFormat format; Val* operand = nullptr; std::tie(format, operand) = value_map[node->input(0)->unique()].getEntry(); if (format.hasPermutation() && !format.isChannelsLast()) { format = MemoryFormat::Contiguous(); operand = value_map[node->input(0)->unique()].maybeConvertValue( format); } auto input = operand->as(); TensorView* weight = nullptr; if (!node->input(1)->type()->isSubtypeOf( static_cast(NoneType::get()))) { weight = value_map[node->input(1)->unique()]->as(); } TensorView* bias = nullptr; if (!node->input(2)->type()->isSubtypeOf( static_cast(NoneType::get()))) { bias = value_map[node->input(2)->unique()]->as(); } // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) auto training = constant_as(node->input(5)); TORCH_INTERNAL_ASSERT( training.has_value(), "The training (bool) parameter is required."); const bool kTraining = training.value(); TensorView* running_mean = nullptr; if (!node->input(3)->type()->isSubtypeOf( static_cast(NoneType::get()))) { running_mean = value_map[node->input(3)->unique()]->as(); } TensorView* running_var = nullptr; if (!node->input(4)->type()->isSubtypeOf( static_cast(NoneType::get()))) { running_var = value_map[node->input(4)->unique()]->as(); } Val* momentum_ptr = nullptr; // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) if (auto momentum = constant_as(node->input(6))) { momentum_ptr = IrBuilder::create(momentum.value()); } else { // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) momentum_ptr = value_map[node->input(6)->unique()]; } Val* eps_ptr = nullptr; // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) if (auto eps = constant_as(node->input(7))) { eps_ptr = IrBuilder::create(eps.value()); } else { // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) eps_ptr = value_map[node->input(7)->unique()]; } auto result = batch_norm( input, weight, bias, running_mean, running_var, kTraining, momentum_ptr, eps_ptr, format.isChannelsLast()); if (node->kind() == c10::Symbol::fromQualString("aten::native_batch_norm") || node->kind() == c10::Symbol::fromQualString( "aten::_batch_norm_impl_index")) { // TODO: output 3 & 4 are not created // we are not creating these outputs because codegen // currently lacks the support. value_map.emplace( node->output(0)->unique(), ValueHolder(result.output, format)); value_map.emplace(node->output(1)->unique(), result.mean); value_map.emplace(node->output(2)->unique(), result.invstd); } else if ( node->kind() == c10::Symbol::fromQualString("aten::batch_norm")) { value_map.emplace( node->output()->unique(), ValueHolder(result.output, format)); } }, [](const Node* node) -> bool { if (isReductionNonCompatibleTensor( node->input(0)->type()->cast())) { return false; } if (node->input(5)->node()->kind() != prim::Constant) { return false; } return true; }, [](const Node* node) -> OperatorType { return OperatorType::Normalization; }); } } { std::array BatchNormBwd = { "aten::_batch_norm_impl_index_backward(int impl_index, Tensor input, Tensor grad_output, Tensor? weight, Tensor? running_mean, Tensor? running_var, Tensor? save_mean, Tensor? save_var_transform, bool train, float eps, bool[3] output_mask, Tensor reservedSpace) -> (Tensor, Tensor, Tensor)", "aten::native_batch_norm_backward(Tensor grad_out, Tensor input, Tensor? weight, Tensor? running_mean, Tensor? running_var, Tensor? save_mean, Tensor? 
save_invstd, bool train, float eps, bool[3] output_mask) -> (Tensor, Tensor, Tensor)"}; for (auto signature : BatchNormBwd) { auto ptr_op = getOperatorForLiteral(signature); REGISTER_PARSE_RULE( ptr_op, { JitValue* ts_input = nullptr; JitValue* ts_grad_output; JitValue* ts_weight = nullptr; JitValue* ts_r_mean = nullptr; JitValue* ts_r_var = nullptr; JitValue* ts_save_mean = nullptr; JitValue* ts_save_invstd = nullptr; JitValue* ts_train = nullptr; JitValue* ts_eps = nullptr; JitValue* ts_mask = nullptr; if (node->kind() == c10::Symbol::fromQualString( "aten::_batch_norm_impl_index_backward")) { ts_input = node->input(1); ts_grad_output = node->input(2); ts_weight = node->input(3); ts_r_mean = node->input(4); ts_r_var = node->input(5); ts_save_mean = node->input(6); ts_save_invstd = node->input(7); ts_train = node->input(8); ts_eps = node->input(9); ts_mask = node->input(10); } else if ( node->kind() == c10::Symbol::fromQualString( "aten::native_batch_norm_backward")) { ts_grad_output = node->input(0); ts_input = node->input(1); ts_weight = node->input(2); ts_r_mean = node->input(3); ts_r_var = node->input(4); ts_save_mean = node->input(5); ts_save_invstd = node->input(6); ts_train = node->input(7); ts_eps = node->input(8); ts_mask = node->input(9); } else { TORCH_INTERNAL_ASSERT( false, "Forgot to register the key for BN variation: ", node->kind().toDisplayString()); } // discard impl_index and reservedSpace since we don't use them MemoryFormat format; std::list list_val; std::tie(format, list_val) = getConsistentValues( c10::nullopt, value_map[ts_input->unique()], value_map[ts_grad_output->unique()]); if (format.hasPermutation() && !format.isChannelsLast()) { std::tie(format, list_val) = getConsistentValues( MemoryFormat::Contiguous(), value_map[ts_input->unique()], value_map[ts_grad_output->unique()]); } auto operand0 = list_val.front(); list_val.pop_front(); auto operand1 = list_val.front(); list_val.pop_front(); auto input = operand0->as(); auto grad_out = operand1->as(); TensorView* weight = nullptr; if (!ts_weight->type()->isSubtypeOf( static_cast(NoneType::get()))) { weight = value_map[ts_weight->unique()]->as(); } TensorView* running_mean = nullptr; if (!ts_r_mean->type()->isSubtypeOf( static_cast(NoneType::get()))) { running_mean = value_map[ts_r_mean->unique()]->as(); } TensorView* running_var = nullptr; if (!ts_r_var->type()->isSubtypeOf( static_cast(NoneType::get()))) { running_var = value_map[ts_r_var->unique()]->as(); } TensorView* save_mean = nullptr; // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) if (!ts_save_mean->type()->isSubtypeOf( static_cast(NoneType::get()))) { // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) save_mean = value_map[ts_save_mean->unique()]->as(); } TensorView* save_invstd = nullptr; // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) if (!ts_save_invstd->type()->isSubtypeOf( static_cast(NoneType::get()))) { save_invstd = // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) value_map[ts_save_invstd->unique()]->as(); } // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) auto training = constant_as(ts_train); TORCH_INTERNAL_ASSERT( training.has_value(), "The training (bool) parameter is required."); const bool kTraining = training.value(); // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) Val* eps_ptr = nullptr; // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) if (auto eps = constant_as(ts_eps)) { eps_ptr = IrBuilder::create(eps.value()); } else { // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) 
eps_ptr = value_map[ts_eps->unique()]; } // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) auto out_mask_list = constant_as>(ts_mask); TORCH_INTERNAL_ASSERT( out_mask_list.has_value(), "output mask for batch_norm_backward"); std::vector output_mask; for (const auto value : out_mask_list->vec()) { output_mask.emplace_back(static_cast(value)); } // TODO: merge this loop below. if (kTraining) { TORCH_INTERNAL_ASSERT( save_mean != nullptr && save_invstd != nullptr, "When training=True, save_mean and save_invstd are required."); } else { // TODO: this is not a legit assumption? Can't we run with // track_running_stats == false && training == false // which should just run through the case above. TORCH_INTERNAL_ASSERT( running_mean != nullptr && running_var != nullptr, "When training=False, running_mean and running_invstd are required."); } auto grads = batch_norm_backward( input, grad_out, weight, running_mean, running_var, save_mean, save_invstd, kTraining, eps_ptr, output_mask, format.isChannelsLast()); if (output_mask[0]) { TORCH_INTERNAL_ASSERT(grads.grad_input != nullptr); value_map.emplace( node->output(0)->unique(), ValueHolder(grads.grad_input, format)); } else { TORCH_INTERNAL_ASSERT(grads.grad_input == nullptr); value_map.emplace( node->output(0)->unique(), ValueHolder(TensorViewBuilder().build(), format)); } if (output_mask[1]) { TORCH_INTERNAL_ASSERT(grads.grad_weight != nullptr); value_map.emplace(node->output(1)->unique(), grads.grad_weight); } else { TORCH_INTERNAL_ASSERT(grads.grad_weight == nullptr); value_map.emplace( node->output(1)->unique(), TensorViewBuilder().build()); } if (output_mask[2]) { TORCH_INTERNAL_ASSERT(grads.grad_bias != nullptr); value_map.emplace(node->output(2)->unique(), grads.grad_bias); } else { TORCH_INTERNAL_ASSERT(grads.grad_bias == nullptr); value_map.emplace( node->output(2)->unique(), TensorViewBuilder().build()); } }, [](const Node* node) -> bool { if (isReductionNonCompatibleTensor( node->input(1)->type()->cast())) { return false; } if (node->kind() == c10::Symbol::fromQualString( "aten::_batch_norm_impl_index_backward")) { if (node->inputs()[8]->node()->kind() != prim::Constant) { return false; } if (node->inputs()[10]->node()->kind() != prim::Constant) { return false; } } else if ( node->kind() == c10::Symbol::fromQualString( "aten::native_batch_norm_backward")) { if (node->inputs()[7]->node()->kind() != prim::Constant) { return false; } if (node->inputs()[9]->node()->kind() != prim::Constant) { return false; } } else { TORCH_INTERNAL_ASSERT( false, "Forgot to update profiled constant check for", node->kind().toDisplayString()); } return true; }, [](const Node* node) -> OperatorType { return OperatorType::Normalization; }); } } { std::array LayerNormFwd = { "aten::native_layer_norm(Tensor input, int[] normalized_shape, Tensor? weight, Tensor? bias, float eps) -> (Tensor, Tensor, Tensor)", "aten::layer_norm(Tensor input, int[] normalized_shape, Tensor? weight=None, Tensor? 
bias=None, float eps=1e-05, bool cudnn_enable=True) -> Tensor"}; for (auto signature : LayerNormFwd) { auto ptr_op = getOperatorForLiteral(signature); REGISTER_PARSE_RULE( ptr_op, { MemoryFormat format; std::list list_val; std::tie(format, list_val) = getConsistentValues( MemoryFormat::Contiguous(), value_map[node->inputs()[0]->unique()]); auto input_t = list_val.front(); list_val.pop_front(); auto input = input_t->as(); auto norm_shape_optional = constant_as>(node->input(1)); TORCH_INTERNAL_ASSERT( norm_shape_optional.has_value(), "The Normalized_Shape list is required."); auto norm_shape = norm_shape_optional->vec(); TensorView* weight = nullptr; if (!node->input(2)->type()->isSubtypeOf( static_cast(NoneType::get()))) { weight = value_map[node->input(2)->unique()]->as(); } TensorView* bias = nullptr; if (!node->input(3)->type()->isSubtypeOf( static_cast(NoneType::get()))) { bias = value_map[node->input(3)->unique()]->as(); } Val* eps_ptr = nullptr; if (auto eps = constant_as(node->input(4))) { eps_ptr = IrBuilder::create(eps.value()); } else { eps_ptr = value_map[node->input(4)->unique()]; } auto result = layer_norm(input, norm_shape, weight, bias, eps_ptr); if (node->kind() == c10::Symbol::fromQualString("aten::native_layer_norm")) { value_map.emplace(node->output(0)->unique(), result.output); value_map.emplace(node->output(1)->unique(), result.mean); value_map.emplace(node->output(2)->unique(), result.invstd); } else if ( node->kind() == c10::Symbol::fromQualString("aten::layer_norm")) { value_map.emplace(node->output()->unique(), result.output); } }, // TODO: #ProfileIValue List should update this [](const Node* node) -> bool { if (isReductionNonCompatibleTensor( node->input(0)->type()->cast())) { return false; } if (node->inputs()[1]->node()->kind() != prim::Constant) { return false; } return true; }, [](const Node* node) -> OperatorType { return OperatorType::Normalization; }); } } { auto ptr_op = getOperatorForLiteral( "aten::native_layer_norm_backward(Tensor grad_out, Tensor input, int[] normalized_shape, Tensor mean, Tensor rstd, Tensor? weight, Tensor? 
bias, bool[3] output_mask) -> (Tensor, Tensor, Tensor)"); REGISTER_PARSE_RULE( ptr_op, { MemoryFormat format; std::list list_val; std::tie(format, list_val) = getConsistentValues( MemoryFormat::Contiguous(), value_map[node->inputs()[0]->unique()], value_map[node->inputs()[1]->unique()]); auto grad_out_t = list_val.front(); list_val.pop_front(); auto input_t = list_val.front(); list_val.pop_front(); auto grad_out = grad_out_t->as(); auto input = input_t->as(); auto norm_shape_optional = constant_as>(node->input(2)); TORCH_INTERNAL_ASSERT( norm_shape_optional.has_value(), "The Normalized_Shape list is required."); auto norm_shape = norm_shape_optional->vec(); auto mean = value_map[node->input(3)->unique()]->as(); auto rstd = value_map[node->input(4)->unique()]->as(); TensorView* weight = nullptr; // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) if (!node->input(5)->type()->isSubtypeOf( static_cast(NoneType::get()))) { // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) weight = value_map[node->input(5)->unique()]->as(); } TensorView* bias = nullptr; // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) if (!node->input(6)->type()->isSubtypeOf( static_cast(NoneType::get()))) { // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) bias = value_map[node->input(6)->unique()]->as(); } // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) auto output_mask_optional = constant_as>(node->input(7)); TORCH_INTERNAL_ASSERT( output_mask_optional.has_value(), "output mask for layer_norm_backward"); std::vector output_mask = output_mask_optional->vec(); auto grad = layer_norm_backward( grad_out, input, norm_shape, mean, rstd, weight, bias, output_mask); if (output_mask[0]) { TORCH_INTERNAL_ASSERT(grad.grad_input != nullptr); value_map.emplace(node->output(0)->unique(), grad.grad_input); } else { TORCH_INTERNAL_ASSERT(grad.grad_input == nullptr); value_map.emplace( node->output(0)->unique(), TensorViewBuilder().build()); } if (output_mask[1] && weight != nullptr) { TORCH_INTERNAL_ASSERT(grad.grad_weight != nullptr); value_map.emplace(node->output(1)->unique(), grad.grad_weight); } else { TORCH_INTERNAL_ASSERT(grad.grad_weight == nullptr); value_map.emplace( node->output(1)->unique(), TensorViewBuilder().build()); } if (output_mask[2] && bias != nullptr) { TORCH_INTERNAL_ASSERT(grad.grad_bias != nullptr); value_map.emplace(node->output(2)->unique(), grad.grad_bias); } else { TORCH_INTERNAL_ASSERT(grad.grad_bias == nullptr); value_map.emplace( node->output(2)->unique(), TensorViewBuilder().build()); } }, // TODO: #ProfileIValue List should update this [](const Node* node) -> bool { if (isReductionNonCompatibleTensor( node->input(0)->type()->cast())) { return false; } if (node->inputs()[2]->node()->kind() != prim::Constant) { return false; } if (node->inputs()[7]->node()->kind() != prim::Constant) { return false; } return true; }, [](const Node* node) -> OperatorType { return OperatorType::Normalization; }); } { std::array SoftmaxFwd = { "aten::softmax.int(Tensor self, int dim, ScalarType? dtype=None) -> Tensor", "aten::log_softmax.int(Tensor self, int dim, ScalarType? 
dtype=None) -> Tensor"}; for (auto signature : SoftmaxFwd) { auto ptr_op = getOperatorForLiteral(signature); REGISTER_PARSE_RULE( ptr_op, { MemoryFormat format; std::list list_val; std::tie(format, list_val) = getConsistentValues( MemoryFormat::Contiguous(), value_map[node->inputs()[0]->unique()]); auto input_t = list_val.front(); list_val.pop_front(); auto input = input_t->as(); auto dim_value = constant_as(node->input(1)); TORCH_INTERNAL_ASSERT( dim_value.has_value(), "dim in softmax is not valid"); auto data_type = DataType::Null; if (const auto opt_ivalue = toIValue(node->input(2))) { if (!opt_ivalue.value().isNone()) { data_type = aten_to_data_type(opt_ivalue->toScalarType()); } } input = (data_type != DataType::Null) ? optionalCastStrict(data_type, input)->as() : input; bool is_log_softmax = node->kind() == c10::Symbol::fromQualString("aten::log_softmax"); auto output = (is_log_softmax) ? log_softmax(input, dim_value.value()) : softmax(input, dim_value.value()); value_map.emplace(node->output()->unique(), output); }, [](const Node* node) -> bool { if (isReductionNonCompatibleTensor( node->input(0)->type()->cast())) { return false; } if (node->inputs()[1]->node()->kind() != prim::Constant) { return false; } if (!isScalarTypeCompatible(node, 2)) { return false; } return true; }, [](const Node* node) -> OperatorType { return OperatorType::Normalization; }); } } { // LTC uses this op for softmax auto ptr_op = getOperatorForLiteral( "aten::_softmax(Tensor self, int dim, bool half_to_float) -> Tensor"); REGISTER_PARSE_RULE( ptr_op, { MemoryFormat format; std::list list_val; std::tie(format, list_val) = getConsistentValues( MemoryFormat::Contiguous(), value_map[node->inputs()[0]->unique()]); auto input_t = list_val.front(); list_val.pop_front(); auto input = input_t->as(); auto dim_value = constant_as(node->input(1)); TORCH_INTERNAL_ASSERT( dim_value.has_value(), "dim in softmax is not valid"); auto output = softmax(input, dim_value.value()); value_map.emplace(node->output()->unique(), output); }, [](const Node* node) -> bool { if (isReductionNonCompatibleTensor( node->input(0)->type()->cast())) { return false; } if (node->inputs()[1]->node()->kind() != prim::Constant) { return false; } if (node->inputs()[2]->node()->kind() != prim::Constant) { return false; } else { const auto half_to_float = constant_as(node->input(2)); TORCH_INTERNAL_ASSERT( half_to_float.has_value(), "Bool half_to_float is not valid"); auto input_tensor_type = node->input(0)->type()->cast(); if (half_to_float.value() && input_tensor_type->scalarType() != at::ScalarType::Half) { return false; } } return true; }, [](const Node* node) -> OperatorType { return OperatorType::Normalization; }); } { std::array SoftmaxBwd = { "aten::_log_softmax_backward_data(Tensor grad_output, Tensor output, int dim, ScalarType input_dtype) -> Tensor", "aten::_softmax_backward_data(Tensor grad_output, Tensor output, int dim, ScalarType input_dtype) -> Tensor"}; for (auto signature : SoftmaxBwd) { auto ptr_op = getOperatorForLiteral(signature); REGISTER_PARSE_RULE( ptr_op, { MemoryFormat format; std::list list_val; std::tie(format, list_val) = getConsistentValues( MemoryFormat::Contiguous(), value_map[node->inputs()[0]->unique()], value_map[node->inputs()[1]->unique()]); auto grad_output_t = list_val.front(); list_val.pop_front(); auto grad_output = grad_output_t->as(); auto output_t = list_val.front(); list_val.pop_front(); auto output = output_t->as(); auto dim_value = constant_as(node->input(2)); TORCH_INTERNAL_ASSERT( 
dim_value.has_value(), "dim in softmax is not valid"); // input_dtype here is ignored! type_inference handles it bool is_log_softmax = node->kind() == c10::Symbol::fromQualString( "aten::_log_softmax_backward_data"); auto grad_input = (is_log_softmax) ? log_softmax_backward(grad_output, output, dim_value.value()) : softmax_backward(grad_output, output, dim_value.value()); value_map.emplace(node->output()->unique(), grad_input); }, [](const Node* node) -> bool { if (isReductionNonCompatibleTensor( node->input(0)->type()->cast())) { return false; } if (node->inputs()[2]->node()->kind() != prim::Constant) { return false; } if (node->inputs()[3]->node()->kind() != prim::Constant) { return false; } return true; }, [](const Node* node) -> OperatorType { return OperatorType::Normalization; }); } } { std::array Variance = { "aten::var.dim(Tensor self, int[1]? dim, bool unbiased=True, bool keepdim=False) -> Tensor", "aten::std.dim(Tensor self, int[1]? dim, bool unbiased=True, bool keepdim=False) -> Tensor"}; for (auto signature : Variance) { auto ptr_op = getOperatorForLiteral(signature); REGISTER_PARSE_RULE( ptr_op, { MemoryFormat format; std::list list_val; std::tie(format, list_val) = getConsistentValues( MemoryFormat::Contiguous(), value_map[node->inputs()[0]->unique()]); auto input_t = list_val.front(); list_val.pop_front(); auto input = input_t->as(); bool is_variance = node->kind() == c10::Symbol::fromQualString("aten::var"); auto dims_list = constant_as>(node->input(1)); TORCH_INTERNAL_ASSERT( dims_list.has_value(), "Cannot fuse with dynamic axes"); std::vector dims; if (!dims_list->empty()) { for (const auto dim : dims_list->vec()) { dims.emplace_back(static_cast(dim)); } } else { dims.resize(input->as()->nDims()); std::iota(dims.begin(), dims.end(), 0); } auto unbiased = constant_as(node->input(2)); TORCH_INTERNAL_ASSERT( unbiased.has_value(), "Cannot fuse with dynamic unbiased"); auto keepdim = constant_as(node->input(3)); TORCH_INTERNAL_ASSERT( keepdim.has_value(), "Cannot fuse with dynamic keepdim"); auto output = (is_variance) ? variance(input, dims, unbiased.value(), keepdim.value()) : standard_deviation( input, dims, unbiased.value(), keepdim.value()); value_map.emplace(node->output()->unique(), output); }, [](const Node* node) -> bool { if (isReductionNonCompatibleTensor( node->input(0)->type()->cast())) { return false; } return true; }, [](const Node* node) -> OperatorType { return OperatorType::Normalization; }); } } { auto ptr_op = getOperatorForLiteral( "aten::sum.dim_IntList(Tensor self, int[1]? dim, bool keepdim=False, *, int? 
dtype=None) -> (Tensor)"); REGISTER_PARSE_RULE( ptr_op, { // TODO: support channels last in sum MemoryFormat format; std::list list_val; std::tie(format, list_val) = getConsistentValues( MemoryFormat::Contiguous(), value_map[node->inputs()[0]->unique()]); auto self = list_val.front(); list_val.pop_front(); auto dims_list = constant_as>(node->input(1)); TORCH_INTERNAL_ASSERT( dims_list.has_value(), "aten::sum cannot be fused with dynamic axes"); std::vector dims; if (!dims_list->empty()) { for (const auto dim : dims_list->vec()) { dims.emplace_back(static_cast(dim)); } } else { dims.resize(self->as()->nDims()); std::iota(dims.begin(), dims.end(), 0); } auto keepdim = constant_as(node->input(2)); TORCH_INTERNAL_ASSERT( keepdim.has_value(), "aten::sum cannot be fused with dynamic keepdim"); auto out = sum(self->as(), dims, keepdim.value()); value_map.emplace(node->output()->unique(), out); }, [](const Node* node) -> bool { if (isReductionNonCompatibleTensor( node->input(0)->type()->cast())) { return false; } // TODO: support cast of output types if (!node->inputs()[3]->type()->isSubtypeOf( static_cast(NoneType::get()))) { // We can only handle output as half, float, and double; if (const auto opt_ivalue = toIValue(node->input(3))) { const auto scalar_type = opt_ivalue->toScalarType(); if (!at::isFloatingType(scalar_type)) { return false; } } } // we don't support dynamic reduction axes; if (node->inputs()[1]->node()->kind() != prim::Constant) { return false; } // we don't support dynamic keepdim yet; if (node->inputs()[2]->node()->kind() != prim::Constant) { return false; } return true; }, [](const Node* node) -> OperatorType { return OperatorType::Reduction; }); } { auto ptr_op = getOperatorForLiteral( "aten::mean.dim(Tensor self, int[1]? dim, bool keepdim=False, *, ScalarType? 
dtype=None) -> Tensor"); REGISTER_PARSE_RULE( ptr_op, { MemoryFormat format; std::list list_val; std::tie(format, list_val) = getConsistentValues( MemoryFormat::Contiguous(), value_map[node->inputs()[0]->unique()]); auto operand = list_val.front(); list_val.pop_front(); auto self = operand->as(); auto dims_list = constant_as>(node->input(1)); TORCH_INTERNAL_ASSERT( dims_list.has_value(), "aten::mean cannot be fused with dynamic axes"); std::vector dims; if (!dims_list->empty()) { for (const auto dim : dims_list->vec()) { dims.emplace_back(static_cast(dim)); } } else { dims.resize(self->as()->nDims()); std::iota(dims.begin(), dims.end(), 0); } auto keepdim = constant_as(node->input(2)); TORCH_INTERNAL_ASSERT( keepdim.has_value(), "aten::mean cannot be fused with dynamic keepdim"); auto o_sum = sum(self, dims, keepdim.value()); Val* num_features = IrBuilder::create(1); for (auto axis : dims) { if (axis < 0) { axis += int(self->nDims()); } num_features = mul(num_features, self->domain()->domain()[axis]->extent()); } auto out = div(o_sum, num_features); value_map.emplace(node->output()->unique(), out); }, [](const Node* node) -> bool { if (isReductionNonCompatibleTensor( node->input(0)->type()->cast())) { return false; } // TODO: support cast of output types if (!node->inputs()[3]->type()->isSubtypeOf( static_cast(NoneType::get()))) { // We can only handle output as half, float, and double; if (const auto opt_ivalue = toIValue(node->input(3))) { const auto scalar_type = opt_ivalue->toScalarType(); if (!at::isFloatingType(scalar_type)) { return false; } } } // we don't support dynamic reduction axes; if (node->inputs()[1]->node()->kind() != prim::Constant) { return false; } // we don't support dynamic keepdim yet; if (node->inputs()[2]->node()->kind() != prim::Constant) { return false; } return true; }, [](const Node* node) -> OperatorType { return OperatorType::Reduction; }); } { std::array SumToSize = { "aten::_grad_sum_to_size(Tensor(a) self, int[]? size) -> Tensor(a)", "aten::sum_to_size(Tensor self, int[] size) -> Tensor"}; for (auto signature : SumToSize) { auto ptr_op = getOperatorForLiteral(signature); REGISTER_PARSE_RULE( ptr_op, { MemoryFormat format; std::list list_val; std::tie(format, list_val) = getConsistentValues( MemoryFormat::Contiguous(), value_map[node->inputs()[0]->unique()]); auto self = list_val.front(); list_val.pop_front(); auto size_to = constant_as>(node->input(1)); TORCH_INTERNAL_ASSERT( size_to.has_value(), "aten::sum cannot be fused with dynamic axes"); if (!size_to->empty()) { auto input = self->as(); auto out = sum_to(input, size_to->vec()); // this copy is not necessary, but making copy avoids tricky // computational graph where no-op could be challenging. if (out == input) { out = set(input); } value_map.emplace(node->output()->unique(), out); } else { // We are introducing alias here! 
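// With an empty `size_to` there is nothing to reduce, so instead of building
// a fresh tensor we map the output to the exact same codegen value as the
// input; consumers of this node simply see the producer's value.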
value_map.emplace(node->output()->unique(), self); } }, [](const Node* node) -> bool { if (isReductionNonCompatibleTensor( node->input(0)->type()->cast())) { return false; } // we don't support dynamic reduction axes; if (node->inputs()[1]->node()->kind() != prim::Constant) { return false; } return true; }, [](const Node* node) -> OperatorType { auto size_to = constant_as>(node->input(1)); // technically size_to->empty() should never occur, as specialized // _grad_sum_to_size should have been removed by optimization pass if (size_to->empty()) { return OperatorType::ElementWise; } else { return OperatorType::ReductionToSize; } }); } } { std::array AutocastOps = { "aten::_autocast_to_reduced_precision(Tensor(a) self, bool cuda_enabled, bool cpu_enabled, ScalarType cuda_dtype, ScalarType cpu_dtype) -> Tensor(a)", "aten::_autocast_to_full_precision(Tensor(a) self, bool cuda_enabled, bool cpu_enabled) -> Tensor(a)"}; for (auto signature : AutocastOps) { auto ptr_op = getOperatorForLiteral(signature); REGISTER_PARSE_RULE( ptr_op, { MemoryFormat format; std::list list_val; std::tie(format, list_val) = getConsistentValues( c10::nullopt, value_map[node->inputs()[0]->unique()]); auto self = list_val.front(); list_val.pop_front(); auto out = set(self); value_map.emplace( node->output()->unique(), ValueHolder(out, format)); }, isInputNonSizeZeroTensor, nullptr); } } { auto ptr_op = getOperatorForLiteral( "aten::_to_copy(Tensor self, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, bool non_blocking=False, MemoryFormat? memory_format=None) -> Tensor"); REGISTER_PARSE_RULE( ptr_op, { MemoryFormat format; std::list list_val; std::tie(format, list_val) = getConsistentValues( c10::nullopt, value_map[node->inputs()[0]->unique()]); auto self = list_val.front(); list_val.pop_front(); auto out = castTensoToDtype(self, node->input(1)); value_map.emplace( node->output()->unique(), ValueHolder(out, format)); }, [](const Node* node) -> bool { if (!isInputNonSizeZeroTensor(node)) { return false; } if (node->inputs()[1]->node()->kind() != prim::Constant) { return false; } // we do not support explicit memory_format on output if (!node->inputs()[2]->type()->isSubtypeOf( static_cast(NoneType::get()))) { return false; } // we do not support explicit memory_format on output if (!node->inputs()[3]->type()->isSubtypeOf( static_cast(NoneType::get()))) { return false; } // we do not support explicit memory_format on output if (!node->inputs()[4]->type()->isSubtypeOf( static_cast(NoneType::get()))) { return false; } // we do not support explicit memory_format on output if (!node->inputs()[6]->type()->isSubtypeOf( static_cast(NoneType::get()))) { return false; } return true; }, nullptr); } // Limiting aten::to implementation to only change the dtype of a tensor { auto ptr_op = getOperatorForLiteral( "aten::to.dtype(Tensor self, ScalarType dtype, bool non_blocking=False, bool copy=False, MemoryFormat? 
memory_format=None) -> Tensor"); REGISTER_PARSE_RULE( ptr_op, { MemoryFormat format; std::list list_val; std::tie(format, list_val) = getConsistentValues( c10::nullopt, value_map[node->inputs()[0]->unique()]); auto self = list_val.front(); list_val.pop_front(); auto out = castTensoToDtype(self, node->input(1)); value_map.emplace( node->output()->unique(), ValueHolder(out, format)); }, [](const Node* node) -> bool { if (!isInputNonSizeZeroTensor(node)) { return false; } if (node->inputs()[1]->node()->kind() != prim::Constant) { return false; } // we do not support explicit memory_format on output if (!node->inputs()[4]->type()->isSubtypeOf( static_cast(NoneType::get()))) { return false; } return true; }, nullptr); } { auto ptr_op = getOperatorForLiteral( "aten::type_as(Tensor self, Tensor other) -> Tensor"); REGISTER_PARSE_RULE( ptr_op, { MemoryFormat format; std::list list_val; std::tie(format, list_val) = getConsistentValues( c10::nullopt, value_map[node->inputs()[0]->unique()]); auto self = list_val.front(); list_val.pop_front(); // TODO: switch to PyTorch dtype as it's closer to truth. // For now, reality is that PyTorch IR profiling information could // be missing even with profiling executor, due to upstream // transformations between profiling runs to fusion pass. auto opt_dtype = value_map[node->inputs()[1]->unique()]->getDataType(); TORCH_INTERNAL_ASSERT(opt_dtype.has_value()); auto out = castOp(opt_dtype.value(), self); value_map.emplace( node->output()->unique(), ValueHolder(out, format)); }, isInputNonSizeZeroTensor, nullptr); } { // We are not fusing `linear` yet, because we can't codegen efficient gemm // However, we still need this here, so PE would insert profile node for // this node. // During fusion pass, We decompose linear into gemm + elementwise. auto ptr_op = getOperatorForLiteral( "aten::linear(Tensor input, Tensor weight, Tensor? bias=None) -> Tensor"); REGISTER_PARSE_RULE( ptr_op, { // this entry is created so we do profile input tensors; TORCH_INTERNAL_ASSERT(false, "not implemented yet"); }, [](const Node* node) -> bool { // We only profile `linear` layer but not fusing it. return false; }); } { auto ptr_op = getOperatorForLiteral( "prim::add_optional(Tensor(a) input, Tensor? 
bias) -> Tensor(a)"); REGISTER_PARSE_RULE( ptr_op, { // this entry is created so we do profile input tensors; if (node->input(1)->type()->isSubtypeOf( static_cast(NoneType::get()))) { // forwarding the value; value_map.emplace( node->output()->unique(), value_map[node->inputs()[0]->unique()]); } else { MemoryFormat format; std::list list_val; std::tie(format, list_val) = getPWFormatValues( c10::nullopt, value_map[node->inputs()[0]->unique()], value_map[node->inputs()[1]->unique()]); auto lhs = list_val.front(); list_val.pop_front(); auto rhs = list_val.front(); list_val.pop_front(); auto out = binaryOp( BinaryOpType::Add, lhs, rhs, TypePromotion::default_op_config); value_map.emplace( node->output()->unique(), ValueHolder(out, format)); } }, isInputNonSizeZeroTensor, nullptr); } { auto ptr_op = getOperatorForLiteral( "aten::leaky_relu(Tensor self, Scalar negative_slope=0.01) -> Tensor"); REGISTER_PARSE_RULE( ptr_op, { MemoryFormat format; std::list list_val; std::tie(format, list_val) = getConsistentValues( c10::nullopt, value_map[node->inputs()[0]->unique()]); auto self = list_val.front()->as(); list_val.pop_front(); Val* negative_slope = value_map[node->inputs()[1]->unique()]; auto out = leaky_relu(self, negative_slope); value_map.emplace( node->output()->unique(), ValueHolder(out, format)); }, [](const Node* node) -> bool { if (!isInputNonSizeZeroTensor(node)) { return false; } return true; }, nullptr); } { auto ptr_op = getOperatorForLiteral( "aten::gelu(Tensor self, *, str approximate='none') -> Tensor"); REGISTER_PARSE_RULE( ptr_op, { MemoryFormat format; std::list list_val; std::tie(format, list_val) = getConsistentValues( c10::nullopt, value_map[node->inputs()[0]->unique()]); auto self = list_val.front()->as(); list_val.pop_front(); auto approximate = constant_as(node->input(1)); TORCH_INTERNAL_ASSERT( approximate.has_value(), "The approximate parameter is required."); const auto kTanhGelu = at::native::get_gelutype_enum(approximate.value()) == at::native::GeluType::Tanh; auto out = (kTanhGelu) ? tanh_gelu(self) : gelu(self); value_map.emplace( node->output()->unique(), ValueHolder(out, format)); }, [](const Node* node) -> bool { if (!isInputNonSizeZeroTensor(node)) { return false; } if (node->input(1)->node()->kind() != prim::Constant) { return false; } return true; }, nullptr); } { auto ptr_op = getOperatorForLiteral( "aten::gelu_backward(Tensor grad_output, Tensor self, *, str approximate='none') -> Tensor"); REGISTER_PARSE_RULE( ptr_op, { MemoryFormat format; std::list list_val; std::tie(format, list_val) = getPWFormatValues( c10::nullopt, value_map[node->inputs()[0]->unique()], value_map[node->inputs()[1]->unique()]); auto grad_out = list_val.front()->as(); list_val.pop_front(); auto self = list_val.front()->as(); list_val.pop_front(); auto approximate = constant_as(node->input(2)); TORCH_INTERNAL_ASSERT( approximate.has_value(), "The approximate parameter is required."); const auto kTanhGelu = at::native::get_gelutype_enum(approximate.value()) == at::native::GeluType::Tanh; auto grad_in = (kTanhGelu) ? 
tanh_gelu_backward(grad_out, self) : gelu_backward(grad_out, self); value_map.emplace( node->output()->unique(), ValueHolder(grad_in, format)); }, [](const Node* node) -> bool { if (!isInputNonSizeZeroTensor(node)) { return false; } if (node->input(2)->node()->kind() != prim::Constant) { return false; } return true; }, nullptr); } { auto ptr_op = getOperatorForLiteral( "aten::tanh_backward(Tensor grad_output, Tensor output) -> Tensor"); REGISTER_PARSE_RULE( ptr_op, { MemoryFormat format; std::list list_val; std::tie(format, list_val) = getPWFormatValues( c10::nullopt, value_map[node->inputs()[0]->unique()], value_map[node->inputs()[1]->unique()]); auto grad_out = list_val.front()->as(); list_val.pop_front(); auto self = list_val.front()->as(); list_val.pop_front(); auto grad_in = tanh_backward(grad_out, self); value_map.emplace( node->output()->unique(), ValueHolder(grad_in, format)); }, isInputNonSizeZeroTensor, nullptr); } { std::array BinaryFloatOp = { "aten::amax(Tensor self, int[1] dim=[], bool keepdim=False) -> Tensor", "aten::amin(Tensor self, int[1] dim=[], bool keepdim=False) -> Tensor"}; for (auto signature : BinaryFloatOp) { auto ptr_op = getOperatorForLiteral(signature); REGISTER_PARSE_RULE( ptr_op, { MemoryFormat format; std::list list_val; std::tie(format, list_val) = getConsistentValues( MemoryFormat::Contiguous(), value_map[node->inputs()[0]->unique()]); auto self = list_val.front(); list_val.pop_front(); auto dims_list = constant_as>(node->input(1)); TORCH_INTERNAL_ASSERT( dims_list.has_value(), "aten::amax/amin cannot be fused with dynamic axes"); std::vector dims; if (!dims_list->empty()) { for (const auto dim : dims_list->vec()) { dims.emplace_back(static_cast(dim)); } } else { dims.resize(self->as()->nDims()); std::iota(dims.begin(), dims.end(), 0); } auto keepdim = constant_as(node->input(2)); TORCH_INTERNAL_ASSERT( keepdim.has_value(), "aten::amax/amin cannot be fused with dynamic keepdim"); TensorView* out = nullptr; if (node->kind() == c10::Symbol::fromQualString("aten::amax")) { out = max(self->as(), dims, keepdim.value()); } else if ( node->kind() == c10::Symbol::fromQualString("aten::amin")) { out = min(self->as(), dims, keepdim.value()); } else { TORCH_INTERNAL_ASSERT( false, "unrecognized operation in aten::amax/amin"); } value_map.emplace(node->output()->unique(), out); }, [](const Node* node) -> bool { if (isReductionNonCompatibleTensor( node->input(0)->type()->cast())) { return false; } // we don't support dynamic reduction axes; if (node->inputs()[1]->node()->kind() != prim::Constant) { return false; } // we don't support dynamic keepdim yet; if (node->inputs()[2]->node()->kind() != prim::Constant) { return false; } return true; }, [](const Node* node) -> OperatorType { return OperatorType::Reduction; }); } } { std::array ViewOps = { "prim::reshape_copy(Tensor self, int[] shape) -> Tensor", "prim::view_copy(Tensor self, int[] size) -> Tensor"}; for (auto signature : ViewOps) { auto ptr_op = getOperatorForLiteral(signature); REGISTER_PARSE_RULE( ptr_op, { auto self_value = node->inputs()[0]; MemoryFormat format; std::list list_val; std::tie(format, list_val) = getConsistentValues( MemoryFormat::Contiguous(), value_map[self_value->unique()]); auto self = list_val.front()->as(); list_val.pop_front(); auto self_type = self_value->type()->cast(); TORCH_INTERNAL_ASSERT(self_type != nullptr); auto self_sizes = getTensorSizes(self_type); auto view_sizes = constant_as>(node->input(1)); TORCH_INTERNAL_ASSERT( view_sizes.has_value(), "The size parameter is 
required."); auto output = view(self, self_sizes, view_sizes->vec()); value_map.emplace(node->output()->unique(), output); }, [](const Node* node) -> bool { auto self_value = node->inputs()[0]; auto tensor_type = self_value->type()->cast(); if (tensor_type == nullptr) { return false; } if (!tensor_type->sizes().concrete_sizes().has_value()) { // Shape information for input tensor is required. return false; } if (!isInputNonSizeZeroTensor(node)) { return false; } // Reject fusing node if view_sizes contains an inferred dimension auto view_sizes = constant_as>(node->input(1)); if (!view_sizes.has_value()) { // The size parameter is required. return false; } for (auto axis_size : view_sizes->vec()) { if (axis_size == -1) { return false; } } return true; }, nullptr); } } { auto flatten_op = getOperatorForLiteral( "prim::flatten_copy(Tensor self, int start_dim, int end_dim) -> Tensor"); REGISTER_PARSE_RULE( flatten_op, { auto self_value = node->inputs()[0]; MemoryFormat format; std::list list_val; std::tie(format, list_val) = getConsistentValues( MemoryFormat::Contiguous(), value_map[self_value->unique()]); auto self = list_val.front()->as(); list_val.pop_front(); auto start_dim_value = constant_as(node->input(1)); TORCH_INTERNAL_ASSERT( start_dim_value.has_value(), "start_dim is not valid"); auto end_dim_value = constant_as(node->input(2)); TORCH_INTERNAL_ASSERT( end_dim_value.has_value(), "end_dim is not valid"); TensorView* output = flatten(self, start_dim_value.value(), end_dim_value.value()); value_map.emplace(node->output()->unique(), output); }, [](const Node* node) -> bool { // we don't support dynamic start_dim; if (node->inputs()[1]->node()->kind() != prim::Constant) { return false; } // we don't support dynamic end_dim yet; if (node->inputs()[2]->node()->kind() != prim::Constant) { return false; } return true; }, nullptr); } { auto ptr_op = getOperatorForLiteral("prim::squeeze_copy(Tensor self) -> Tensor"); REGISTER_PARSE_RULE( ptr_op, { auto self_value = node->inputs()[0]; MemoryFormat format; std::list list_val; std::tie(format, list_val) = getConsistentValues( MemoryFormat::Contiguous(), value_map[self_value->unique()]); auto self = list_val.front()->as(); list_val.pop_front(); auto self_type = self_value->type()->cast(); TORCH_INTERNAL_ASSERT(self_type != nullptr); auto self_sizes = getTensorSizes(self_type); TensorView* output = nullptr; if (self_sizes.empty()) { // squeeze on scalar tensor should just return itself; output = set(self); } else { output = squeeze(self, self_sizes); } value_map.emplace(node->output()->unique(), output); }, [](const Node* node) -> bool { // Shape information for input tensor is required. 
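// squeeze() has to decide which size-1 dimensions to drop when the fusion is
// defined, so the input must carry concrete sizes; tensors with only symbolic
// shape information are rejected below.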
auto self_value = node->inputs()[0]; auto tensor_type = self_value->type()->cast(); if (tensor_type == nullptr) { return false; } if (!isInputNonSizeZeroTensor(node)) { return false; } return tensor_type->sizes().concrete_sizes().has_value(); }, nullptr); } { std::array AliasOpWithDim = { "prim::squeeze_copy.dim(Tensor self, int dim) -> Tensor", "prim::unsqueeze_copy(Tensor self, int dim) -> Tensor"}; for (auto signature : AliasOpWithDim) { auto ptr_op = getOperatorForLiteral(signature); REGISTER_PARSE_RULE( ptr_op, { auto self_value = node->inputs()[0]; MemoryFormat format; std::list list_val; std::tie(format, list_val) = getConsistentValues( MemoryFormat::Contiguous(), value_map[node->inputs()[0]->unique()]); auto self = list_val.front()->as(); list_val.pop_front(); auto dim_value = constant_as(node->input(1)); TORCH_INTERNAL_ASSERT(dim_value.has_value(), "dim is not valid"); TensorView* output = nullptr; if (node->kind() == prim::unsqueeze_copy) { output = unsqueeze(self, dim_value.value()); } else { auto self_type = self_value->type()->cast(); TORCH_INTERNAL_ASSERT(self_type != nullptr); auto self_sizes = getTensorSizes(self_type); if (self_sizes.empty()) { // squeeze on scalar tensor should just return itself; output = set(self); } else { output = squeeze(self, self_sizes, dim_value.value()); } } value_map.emplace(node->output()->unique(), output); }, [](const Node* node) -> bool { // Shape information for input tensor is required. auto self_value = node->inputs()[0]; auto tensor_type = self_value->type()->cast(); if (tensor_type == nullptr) { return false; } if (!isInputNonSizeZeroTensor(node)) { return false; } if (node->input(1)->node()->kind() != prim::Constant) { return false; } auto optional_sizes = tensor_type->sizes().concrete_sizes(); return tensor_type->sizes().concrete_sizes().has_value(); }, nullptr); } } { auto ptr_op = getOperatorForLiteral( "prim::expand_as_copy(Tensor self, Tensor other) -> Tensor"); REGISTER_PARSE_RULE( ptr_op, { MemoryFormat format; std::list list_val; std::tie(format, list_val) = getPWFormatValues( c10::nullopt, value_map[node->inputs()[0]->unique()], value_map[node->inputs()[1]->unique()]); auto self = list_val.front()->as(); list_val.pop_front(); auto other = list_val.front()->as(); list_val.pop_front(); auto output = expand_as(self, other); value_map.emplace( node->output()->unique(), ValueHolder(output, format)); }, [](const Node* node) -> bool { if (!isInputNonSizeZeroTensor(node)) { return false; } return true; }, nullptr); } { auto ptr_op = getOperatorForLiteral( "prim::expand_copy(Tensor self, int[] size, *, bool implicit=False) -> Tensor"); REGISTER_PARSE_RULE( ptr_op, { auto self_value = node->inputs()[0]; MemoryFormat format; std::list list_val; std::tie(format, list_val) = getConsistentValues( MemoryFormat::Contiguous(), value_map[self_value->unique()]); auto self = list_val.front()->as(); list_val.pop_front(); auto expand_sizes = constant_as>(node->input(1)); TORCH_INTERNAL_ASSERT( expand_sizes.has_value(), "The size parameter is required."); std::vector expand_sizes_vec; for (const int64_t& size : expand_sizes.value()) { expand_sizes_vec.push_back(IrBuilder::create(size)); } // TODO: we should be able to support dynamic expand values auto output = expand(self, expand_sizes_vec); value_map.emplace(node->output()->unique(), output); }, [](const Node* node) -> bool { if (!isInputNonSizeZeroTensor(node)) { return false; } // expand_sizes needs to be constant auto expand_sizes = constant_as>(node->input(1)); if (!expand_sizes.has_value()) { 
return false; } return true; }, nullptr); } } void processJitNode(const JitOp* node) { if (node->kind() == prim::Constant) { // partition doesn't take constant node explicitly, but it does and copy // constant into subgraph. So we need to register constants in codegen IR; for (auto output : node->outputs()) { TORCH_INTERNAL_ASSERT( registerScalar(output), "registration of output failed at index ", output->offset(), " for node ", *node); } } else { auto reg_entry = lookupInRegistry(node); TORCH_INTERNAL_ASSERT( reg_entry != nullptr, "CudaFusionGroup Parser doesn't handle node: ", canonicalSchemaString(node->schema())); reg_entry->parse(node, value_map_); } } bool registerValue(const JitValue* val) { return registerInputTensor(val) || registerScalar(val); } bool registerScalar(const JitValue* val) { if (val->type()->isSubtypeOf( static_cast(ComplexType::get()))) { CgValue cg_val = nullptr; if (auto ival = constant_as>(val)) { cg_val = IrBuilder::create(ival.value()); } else { cg_val = IrBuilder::create(); } value_map_.emplace(val->unique(), cg_val); return true; } else if (val->type()->isSubtypeOf( static_cast(FloatType::get()))) { // NOLINTNEXTLINE(cppcoreguidelines-init-variables) CgValue cg_val; if (auto ival = constant_as(val)) { cg_val = IrBuilder::create(ival.value()); } else { cg_val = IrBuilder::create(); } value_map_.emplace(val->unique(), cg_val); return true; } else if (val->type()->isSubtypeOf( static_cast(IntType::get()))) { // NOLINTNEXTLINE(cppcoreguidelines-init-variables) CgValue cg_val; if (auto ival = constant_as(val)) { cg_val = IrBuilder::create(ival.value()); } else { cg_val = IrBuilder::create(); } value_map_.emplace(val->unique(), cg_val); return true; } else if (val->type()->isSubtypeOf( static_cast(BoolType::get()))) { // NOLINTNEXTLINE(cppcoreguidelines-init-variables) CgValue cg_val; if (auto ival = constant_as(val)) { cg_val = IrBuilder::create(ival.value()); } else { cg_val = IrBuilder::create(); } value_map_.emplace(val->unique(), cg_val); return true; } else if ( val->type()->isSubtypeOf( static_cast(StringType::get())) || val->type()->isSubtypeOf( static_cast(DeviceObjType::get())) || val->type()->isSubtypeOf(static_cast(NoneType::get()))) { // TODO: should we consider adding support for NoneType; // Note: String/Device scalars are only used in parsing rules, do not // register string with codegen IR. return true; } else if (val->type()->cast()) { // TODO: we don't support list type in codegen yet; // This is a WAR to allow axes of reduction to be passed as constant list; // We simply ignore conversion if the scalar value is a constant; auto ivalue = toIValue(val); TORCH_INTERNAL_ASSERT( ivalue.has_value(), "List[T] is not supported as an argument by NvFuser. 
Use a Constant List.");
    return true;
  }
  return false;
}

bool registerInputTensor(const JitValue* val) {
  // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
  CgValue cg_val;
  // Don't register if we don't support the type
  if (auto tensor_type = val->type()->cast<TensorType>()) {
    if (!tensor_type->scalarType().has_value()) {
      return false;
    }
    if (aten_to_data_type(tensor_type->scalarType().value()) ==
        DataType::Null) {
      return false;
    }
    // check for NHWC contiguous tensor
    TORCH_CHECK(tensor_type->dim().has_value(), "rank missing");
    const auto n_dim = tensor_type->dim().value();

    MemoryFormat format;
    std::vector<int> stride_index;
    for (const auto i : c10::irange(n_dim)) {
      const auto& stride_property_i = tensor_type->stride_properties()[i];
      if (stride_property_i->stride_index_.has_value()) {
        stride_index.emplace_back(stride_property_i->stride_index_.value());
      }
    }

    // only set permutation when all stride_index are available
    if (stride_index.size() == n_dim) {
      format.setPermutation(stride_index);
    }

    // construct permuted tensor_type
    if (format.hasPermutation()) {
      auto opt_s_vec = tensor_type->symbolic_sizes().sizes();
      TORCH_CHECK(opt_s_vec.has_value(), "missing rank of symbolic sizes");
      std::vector<c10::ShapeSymbol> s_vec = opt_s_vec.value();
      // apply permutation
      auto permutation = format.apply();
      for (auto new_axis : c10::irange(permutation.size())) {
        auto old_axis = permutation.at(new_axis);
        s_vec[new_axis] = opt_s_vec.value()[old_axis];
      }

      // copying stride properties because we need to permute it
      auto opt_stride_vec = tensor_type->stride_properties().sizes();
      TORCH_CHECK(opt_stride_vec.has_value(), "missing stride properties");
      auto nhwc_stride_vec = opt_stride_vec.value();
      // Make tensor contiguous after permutation.
      // Note that we are only updating stride_properties.stride_index, since
      // contiguous_ and stride_ value should remain the same after
      // permutation
      for (const auto i : c10::irange(n_dim)) {
        nhwc_stride_vec[i]->stride_index_ = n_dim - i - 1;
      }

      tensor_type = c10::TensorType::create(
          tensor_type->scalarType(),
          tensor_type->device(),
          s_vec,
          nhwc_stride_vec,
          tensor_type->requires_grad(),
          tensor_type->undefined());
    }

    cg_val = IrBuilder::create<TensorView>(tensor_type);
    if (is_cpu_scalar(*tensor_type)) {
      cg_val->as<TensorView>()->setCpuScalar(true);
    }
    value_map_.emplace(val->unique(), ValueHolder(cg_val, format));
    return true;
  }
  return false;
}

std::shared_ptr<Graph> graph_;

// maps from JitValue::unique() to fusion Val;
std::unordered_map<size_t, ValueHolder> value_map_;

static std::unordered_set<Symbol> parser_symbol_set_;
static std::unordered_set<Symbol> parser_skip_set_;
static std::mutex parser_mutex_;

// parsing rule registry.
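// Each REGISTER_PARSE_RULE call above adds a RegistrationEntry (parse
// function, optional fusibility check, optional OperatorType query) to
// `jit_operator_registry_`; `cached_registry_lookup_` below memoizes
// schema-based lookups into that registry so repeated queries during
// partitioning do not have to re-scan it.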
static std::unordered_map jit_operator_registry_; // NOLINT // pointing cached entry stored in `jit_operator_registry_` static std::unordered_map cached_registry_lookup_; // NOLINT // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) static c10::once_flag once_flag_; }; std::unordered_set IrParser::parser_symbol_set_; // NOLINT std::unordered_set IrParser::parser_skip_set_; // NOLINT std::mutex IrParser::parser_mutex_; std::unordered_map IrParser::jit_operator_registry_; // NOLINT std::unordered_map IrParser::cached_registry_lookup_; // NOLINT // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) c10::once_flag IrParser::once_flag_; ProfileIValueOp* insertProfileIValueOp( Node* node, size_t offset, ProfilingRecord* pr) { auto in_val = node->input(offset); auto pn = pr->createProfileIValueNode(in_val); pn->insertBefore(node); node->replaceInput(offset, pn->output()); return pn; } void profileReductionSize(ProfilingRecord* pr, Node* node, size_t offset) { auto pn = insertProfileIValueOp(node, offset, pr); const auto ivalue_profiler = [pr, pn](Stack& stack) { std::lock_guard lock(pr->mutex_); // TODO: we don't care about merging multiple profiling runs as we don't // support it at all; int64_t frame_id = 0; pop(stack, frame_id); IValue value; pop(stack, value); std::vector size_vec; if (value.isIntList()) { size_vec = value.toIntVector(); } else if (value.isNone()) { size_vec.clear(); } else { TORCH_INTERNAL_ASSERT( false, "profileReductionSize does not support data type: ", value.tagKind()); } // We stop profiling when it has failed if (!pn->hasAttribute(profileFailedAttr)) { if (!pn->hasAttribute(reductionSizeAttr)) { pn->is_(reductionSizeAttr, size_vec); } else { auto profiled_ints = pn->is(reductionSizeAttr); if (profiled_ints.size() != size_vec.size() || !std::equal( profiled_ints.begin(), profiled_ints.end(), size_vec.begin())) { TORCH_WARN_ONCE( __FUNCTION__, " sees varying value in profiling, ignoring and this should be handled by GUARD logic"); pn->s_(profileFailedAttr, "varying profile values"); pn->removeAttribute(reductionSizeAttr); } } } else { TORCH_INTERNAL_ASSERT( !pn->hasAttribute(reductionSizeAttr), "profiled attribute should have been removed when profiling is marked as failed"); } push(stack, value); }; pn->setCallback(ivalue_profiler); } void profileViewSize(ProfilingRecord* pr, Node* node, size_t offset) { auto pn = insertProfileIValueOp(node, offset, pr); const auto ivalue_profiler = [pr, pn](Stack& stack) { std::lock_guard lock(pr->mutex_); // TODO: we don't care about merging multiple profiling runs as we don't // support it at all; int64_t frame_id = 0; pop(stack, frame_id); IValue value; pop(stack, value); TORCH_INTERNAL_ASSERT( value.isIntList(), "profiling seeing the wrong data type"); if (!pn->hasAttribute(profileFailedAttr)) { if (!pn->hasAttribute(viewSizeAttr)) { pn->is_(viewSizeAttr, value.toIntVector()); } else { auto profiled_ints = pn->is(viewSizeAttr); auto input_ints = value.toIntList(); if (profiled_ints.size() != input_ints.size() || !std::equal( profiled_ints.begin(), profiled_ints.end(), input_ints.begin())) { TORCH_WARN_ONCE( __FUNCTION__, " sees varying value in profiling, ignoring and this should be handled by GUARD logic"); pn->s_(profileFailedAttr, "varying profile values"); pn->removeAttribute(viewSizeAttr); } } } else { TORCH_INTERNAL_ASSERT( !pn->hasAttribute(viewSizeAttr), "profiled attribute should have been removed when profiling is marked as failed"); } push(stack, value); }; 
pn->setCallback(ivalue_profiler); } void profileIntList(ProfilingRecord* pr, Node* node, size_t offset) { auto pn = insertProfileIValueOp(node, offset, pr); const auto ivalue_profiler = [pr, pn](Stack& stack) { std::lock_guard lock(pr->mutex_); // TODO: we don't care about merging multiple profiling runs as we don't // support it at all; int64_t frame_id = 0; pop(stack, frame_id); IValue value; pop(stack, value); TORCH_INTERNAL_ASSERT( value.isIntList(), "profiling seeing the wrong data type"); if (!pn->hasAttribute(profileFailedAttr)) { if (!pn->hasAttribute(intListAttr)) { pn->is_(intListAttr, value.toIntVector()); } else { auto profiled_ints = pn->is(intListAttr); auto input_ints = value.toIntList(); if (profiled_ints.size() != input_ints.size() || !std::equal( profiled_ints.begin(), profiled_ints.end(), input_ints.begin())) { TORCH_WARN_ONCE( __FUNCTION__, " sees varying value in profiling, ignoring and this should be handled by GUARD logic"); pn->s_(profileFailedAttr, "varying profile values"); pn->removeAttribute(intListAttr); } } } else { TORCH_INTERNAL_ASSERT( !pn->hasAttribute(intListAttr), "profiled attribute should have been removed when profiling is marked as failed"); } push(stack, value); }; pn->setCallback(ivalue_profiler); } void profileString(ProfilingRecord* pr, Node* node, size_t offset) { auto pn = insertProfileIValueOp(node, offset, pr); const auto ivalue_profiler = [pr, pn](Stack& stack) { std::lock_guard lock(pr->mutex_); // TODO: we don't care about merging multiple profiling runs as we don't // support it at all; int64_t frame_id = 0; pop(stack, frame_id); IValue value; pop(stack, value); TORCH_INTERNAL_ASSERT( value.isString(), "profiling seeing the wrong data type"); if (!pn->hasAttribute(profileFailedAttr)) { if (!pn->hasAttribute(strAttr)) { pn->s_(strAttr, value.toStringRef()); } else { const auto& profiled_str = pn->s(strAttr); const auto& input_str = value.toStringRef(); if (input_str != profiled_str) { TORCH_WARN_ONCE( __FUNCTION__, " sees varying value in profiling, ignoring and this should be handled by GUARD logic"); pn->s_(profileFailedAttr, "varying profile values"); pn->removeAttribute(strAttr); } } } else { TORCH_INTERNAL_ASSERT( !pn->hasAttribute(strAttr), "profiled attribute should have been removed when profiling is marked as failed"); } push(stack, value); }; pn->setCallback(ivalue_profiler); } void profileBool(ProfilingRecord* pr, Node* node, size_t offset) { auto pn = insertProfileIValueOp(node, offset, pr); const auto ivalue_profiler = [pr, pn](Stack& stack) { std::lock_guard lock(pr->mutex_); // TODO: we don't care about merging multiple profiling runs as we don't // support it at all; int64_t frame_id = 0; pop(stack, frame_id); IValue value; pop(stack, value); TORCH_INTERNAL_ASSERT( value.isBool(), "profiling seeing the wrong data type"); if (!pn->hasAttribute(profileFailedAttr)) { if (!pn->hasAttribute(boolAttr)) { pn->i_(boolAttr, value.toBool()); } else { auto profiled_bool = pn->i(boolAttr); auto input_bool = value.toBool(); if (input_bool != profiled_bool) { TORCH_WARN_ONCE( __FUNCTION__, " sees varying value in profiling, ignoring and this should be handled by GUARD logic"); pn->s_(profileFailedAttr, "varying profile values"); pn->removeAttribute(boolAttr); } } } else { TORCH_INTERNAL_ASSERT( !pn->hasAttribute(boolAttr), "profiled attribute should have been removed when profiling is marked as failed"); } push(stack, value); }; pn->setCallback(ivalue_profiler); } void profileInt(ProfilingRecord* pr, Node* node, size_t offset) { auto pn 
= insertProfileIValueOp(node, offset, pr); const auto ivalue_profiler = [pr, pn](Stack& stack) { std::lock_guard lock(pr->mutex_); // TODO: we don't care about merging multiple profiling runs as we don't // support it at all; int64_t frame_id = 0; pop(stack, frame_id); IValue value; pop(stack, value); TORCH_INTERNAL_ASSERT( value.isInt(), "profiling seeing the wrong data type"); if (!pn->hasAttribute(profileFailedAttr)) { if (!pn->hasAttribute(intAttr)) { pn->i_(intAttr, value.toInt()); } else { auto profiled_int = pn->i(intAttr); auto input_int = value.toInt(); if (input_int != profiled_int) { TORCH_WARN_ONCE( __FUNCTION__, " sees varying value in profiling, ignoring and this should be handled by GUARD logic"); pn->s_(profileFailedAttr, "varying profile values"); pn->removeAttribute(intAttr); } } } else { TORCH_INTERNAL_ASSERT( !pn->hasAttribute(intAttr), "profiled attribute should have been removed when profiling is marked as failed"); } push(stack, value); }; pn->setCallback(ivalue_profiler); } // profile ivalue, used for optional arguments void profileIval(ProfilingRecord* pr, Node* node, size_t offset) { auto pn = insertProfileIValueOp(node, offset, pr); const auto ivalue_profiler = [pr, pn](Stack& stack) { std::lock_guard lock(pr->mutex_); // TODO: we don't care about merging multiple profiling runs as we don't // support it at all; int64_t frame_id = 0; pop(stack, frame_id); IValue value; pop(stack, value); if (!pn->hasAttribute(profileFailedAttr)) { if (!pn->hasAttribute(ivalAttr)) { pn->ival_(ivalAttr, value); } else { auto profiled_ival = pn->ival(ivalAttr); if (value != profiled_ival) { TORCH_WARN_ONCE( __FUNCTION__, " sees varying value in profiling, ignoring and this should be handled by GUARD logic"); pn->s_(profileFailedAttr, "varying profile values"); pn->removeAttribute(ivalAttr); } } } else { TORCH_INTERNAL_ASSERT( !pn->hasAttribute(ivalAttr), "profiled attribute should have been removed when profiling is marked as failed"); } push(stack, value); }; pn->setCallback(ivalue_profiler); } void profileBoolList(ProfilingRecord* pr, Node* node, size_t offset) { auto pn = insertProfileIValueOp(node, offset, pr); const auto ivalue_profiler = [pr, pn](Stack& stack) { std::lock_guard lock(pr->mutex_); // TODO: we don't care about merging multiple profiling runs as we don't // support it at all; int64_t frame_id = 0; pop(stack, frame_id); IValue value; pop(stack, value); TORCH_INTERNAL_ASSERT( value.isBoolList(), "profiling seeing the wrong data type"); if (!pn->hasAttribute(profileFailedAttr)) { if (!pn->hasAttribute(boolListAttr)) { auto list = value.toBoolList(); std::vector val(list.begin(), list.end()); pn->is_(boolListAttr, val); } else { auto profiled_ints = pn->is(boolListAttr); auto input_bools = value.toBoolList(); if (profiled_ints.size() != input_bools.size() || !std::equal( input_bools.begin(), input_bools.end(), profiled_ints.begin())) { TORCH_WARN_ONCE( __FUNCTION__, " sees varying value in profiling, ignoring and this should be handled by GUARD logic"); pn->s_(profileFailedAttr, "varying profile values"); pn->removeAttribute(boolListAttr); } } } else { TORCH_INTERNAL_ASSERT( !pn->hasAttribute(boolListAttr), "profiled attribute should have been removed when profiling is marked as failed"); } push(stack, value); }; pn->setCallback(ivalue_profiler); } bool anyInBlock( const Block* block, const std::function& fn) { for (auto node : block->nodes()) { if (fn(node)) { return true; } for (auto block : node->blocks()) { if (anyInBlock(block, fn)) { return true; } } } return 
false;
}

} // namespace

bool hasReductionNode(const Block* block) {
  return anyInBlock(block, isReductionNode);
}

bool isReductionNode(const Node* node) {
  return IrParser::isReductionNode(node);
}

bool isReductionToSizeNode(const Node* node) {
  return IrParser::isReductionToSizeNode(node);
}

bool hasNormalizationNode(const Block* block) {
  return anyInBlock(block, isNormalizationNode);
}

bool isNormalizationNode(const Node* node) {
  return IrParser::isNormalizationNode(node);
}

bool isElementWiseNode(const Node* node) {
  return IrParser::isElementWiseNode(node);
}

bool isNodeParsible(const Node* node) {
  return IrParser::canParseNode(node);
}

bool shouldProfileNode(const Node* node) {
  return IrParser::lookupInSymbolSet(node);
}

bool skipNodeKind(const std::string& symbol_str, bool flip) {
  return IrParser::querySkipSymbolSet(
      c10::Symbol::fromQualString(symbol_str), flip);
}

bool insertProfileIValue(ProfilingRecord* pr, Node* node, size_t offset) {
  // is skip constant necessary?
  if (node->input(offset)->node()->kind() == prim::Constant) {
    return false;
  }

  static auto dropout_schema =
      getOperatorForLiteral(
          "aten::dropout(Tensor input, float p, bool train) -> Tensor")
          ->schema();
  static auto native_dropout_schema =
      getOperatorForLiteral(
          "aten::native_dropout(Tensor input, float p, bool? train) -> (Tensor, Tensor)")
          ->schema();
  if (node->matches(dropout_schema) || node->matches(native_dropout_schema)) {
    switch (offset) {
      // argument 2: Is training?
      case 2:
        profileBool(pr, node, offset);
        break;
      default:
        return false;
    }
    return true;
  }

  static auto amax_schema =
      getOperatorForLiteral(
          "aten::amax(Tensor self, int[1] dim=[], bool keepdim=False) -> Tensor")
          ->schema();
  static auto amin_schema =
      getOperatorForLiteral(
          "aten::amin(Tensor self, int[1] dim=[], bool keepdim=False) -> Tensor")
          ->schema();
  if (node->matches(amax_schema) || node->matches(amin_schema)) {
    switch (offset) {
      // argument 1: reduction axes;
      case 1:
        profileIntList(pr, node, offset);
        break;
      // argument 2: keepdim;
      case 2:
        profileBool(pr, node, offset);
        break;
      default:
        return false;
    }
    return true;
  }

  static auto reduction_operator_schema =
      getOperatorForLiteral(
          "aten::sum.dim_IntList(Tensor self, int[1]? dim, bool keepdim=False, *, int? dtype=None) -> (Tensor)")
          ->schema();
  if (node->matches(reduction_operator_schema)) {
    switch (offset) {
      // argument 1: reduction axes;
      case 1:
        profileIntList(pr, node, offset);
        break;
      // argument 2: keepdim;
      case 2:
        profileBool(pr, node, offset);
        break;
      default:
        return false;
    }
    return true;
  }

  static auto sum_to_size_schema =
      getOperatorForLiteral(
          "aten::sum_to_size(Tensor self, int[] size) -> Tensor")
          ->schema();
  static auto grad_sum_to_size_schema =
      getOperatorForLiteral(
          "aten::_grad_sum_to_size(Tensor(a) self, int[]? size) -> Tensor(a)")
          ->schema();
  if (node->matches(sum_to_size_schema) ||
      node->matches(grad_sum_to_size_schema)) {
    switch (offset) {
      // argument 1: reduction sizes;
      case 1:
        // TODO(profile_size): double check optional[size]?
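        // profileReductionSize (defined above) records the observed size list
        // on the profile node and treats a None argument as an empty vector;
        // the parser later uses that list to pick ReductionToSize versus a
        // plain element-wise copy for sum_to_size / _grad_sum_to_size.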
profileReductionSize(pr, node, offset); break; default: return false; } return true; } static auto reshape_schema = getOperatorForLiteral("aten::reshape(Tensor self, int[] shape) -> Tensor") ->schema(); static auto reshape_copy_schema = getOperatorForLiteral( "prim::reshape_copy(Tensor self, int[] shape) -> Tensor") ->schema(); static auto view_schema = getOperatorForLiteral("aten::view(Tensor self, int[] size) -> Tensor") ->schema(); static auto view_copy_schema = getOperatorForLiteral( "prim::view_copy(Tensor self, int[] size) -> Tensor") ->schema(); if (node->matches(reshape_schema) || node->matches(reshape_copy_schema) || node->matches(view_schema) || node->matches(view_copy_schema)) { switch (offset) { // argument 1: new tensor size; case 1: profileViewSize(pr, node, offset); break; default: return false; } return true; } static auto flatten_schema1 = getOperatorForLiteral( "aten::flatten.using_ints(Tensor self, int start_dim=0, int end_dim=-1) -> Tensor") ->schema(); static auto flatten_schema2 = getOperatorForLiteral( "prim::flatten_copy(Tensor self, int start_dim, int end_dim) -> Tensor") ->schema(); if (node->matches(flatten_schema1) || node->matches(flatten_schema2)) { switch (offset) { // argument 1: start_dim; // argument 2: end_dim; case 1: case 2: profileInt(pr, node, offset); break; default: return false; } return true; } static auto squeeze_dim_schema = getOperatorForLiteral( "prim::squeeze_copy.dim(Tensor self, int dim) -> Tensor") ->schema(); static auto unsqueeze_schema = getOperatorForLiteral( "prim::unsqueeze_copy(Tensor self, int dim) -> Tensor") ->schema(); if (node->matches(squeeze_dim_schema) || node->matches(unsqueeze_schema)) { switch (offset) { // argument 1: unsqueeze dim; case 1: profileInt(pr, node, offset); break; default: return false; } return true; } static auto batch_norm_impl_index_schema = getOperatorForLiteral( "aten::_batch_norm_impl_index(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float momentum, float eps, bool cudnn_enabled) -> (Tensor, Tensor, Tensor, Tensor, int)") ->schema(); static auto native_batch_norm_schema = getOperatorForLiteral( "aten::native_batch_norm(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float momentum, float eps) -> (Tensor, Tensor, Tensor)") ->schema(); static auto batch_norm_schema = getOperatorForLiteral( "aten::batch_norm(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float momentum, float eps, bool cudnn_enabled) -> Tensor") ->schema(); static auto instance_norm_schema = getOperatorForLiteral( "aten::instance_norm(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? 
  static auto batch_norm_impl_index_schema =
      getOperatorForLiteral(
          "aten::_batch_norm_impl_index(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float momentum, float eps, bool cudnn_enabled) -> (Tensor, Tensor, Tensor, Tensor, int)")
          ->schema();
  static auto native_batch_norm_schema =
      getOperatorForLiteral(
          "aten::native_batch_norm(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float momentum, float eps) -> (Tensor, Tensor, Tensor)")
          ->schema();
  static auto batch_norm_schema =
      getOperatorForLiteral(
          "aten::batch_norm(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float momentum, float eps, bool cudnn_enabled) -> Tensor")
          ->schema();
  static auto instance_norm_schema =
      getOperatorForLiteral(
          "aten::instance_norm(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool use_input_stats, float momentum, float eps, bool cudnn_enabled) -> Tensor")
          ->schema();
  if (node->matches(native_batch_norm_schema) ||
      node->matches(batch_norm_impl_index_schema) ||
      node->matches(batch_norm_schema) ||
      node->matches(instance_norm_schema)) {
    switch (offset) {
      // argument 5: training;
      case 5:
        profileBool(pr, node, offset);
        break;
      default:
        return false;
    }
    return true;
  }

  static auto gelu_schema =
      getOperatorForLiteral(
          "aten::gelu(Tensor self, *, str approximate='none') -> Tensor")
          ->schema();
  if (node->matches(gelu_schema)) {
    switch (offset) {
      // argument 1: approximate;
      case 1:
        profileString(pr, node, offset);
        break;
      default:
        return false;
    }
    return true;
  }

  static auto gelu_backward_schema =
      getOperatorForLiteral(
          "aten::gelu_backward(Tensor grad_output, Tensor self, *, str approximate='none') -> Tensor")
          ->schema();
  if (node->matches(gelu_backward_schema)) {
    switch (offset) {
      // argument 2: approximate;
      case 2:
        profileString(pr, node, offset);
        break;
      default:
        return false;
    }
    return true;
  }

  static auto native_layer_norm_schema =
      getOperatorForLiteral(
          "aten::native_layer_norm(Tensor input, int[] normalized_shape, Tensor? weight, Tensor? bias, float eps) -> (Tensor, Tensor, Tensor)")
          ->schema();
  static auto layer_norm_schema =
      getOperatorForLiteral(
          "aten::layer_norm(Tensor input, int[] normalized_shape, Tensor? weight=None, Tensor? bias=None, float eps=1e-05, bool cudnn_enable=True) -> Tensor")
          ->schema();
  if (node->matches(native_layer_norm_schema) ||
      node->matches(layer_norm_schema)) {
    switch (offset) {
      case 1:
        profileIntList(pr, node, offset);
        break;
      default:
        return false;
    }
    return true;
  }

  static auto batch_norm_impl_index_backward_schema =
      getOperatorForLiteral(
          "aten::_batch_norm_impl_index_backward(int impl_index, Tensor input, Tensor grad_output, Tensor? weight, Tensor? running_mean, Tensor? running_var, Tensor? save_mean, Tensor? save_var_transform, bool train, float eps, bool[3] output_mask, Tensor reservedSpace) -> (Tensor, Tensor, Tensor)")
          ->schema();
  if (node->matches(batch_norm_impl_index_backward_schema)) {
    switch (offset) {
      // TODO: guard impl_index, but I think that's not needed;
      // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
      case 8: // argument 8: training;
        profileBool(pr, node, offset);
        break;
      // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
      case 10:
        profileBoolList(pr, node, offset);
        break;
      default:
        return false;
    }
    return true;
  }

  static auto batch_norm_backward_schema =
      getOperatorForLiteral(
          "aten::native_batch_norm_backward(Tensor grad_out, Tensor input, Tensor? weight, Tensor? running_mean, Tensor? running_var, Tensor? save_mean, Tensor? save_invstd, bool train, float eps, bool[3] output_mask) -> (Tensor, Tensor, Tensor)")
          ->schema();
  if (node->matches(batch_norm_backward_schema)) {
    switch (offset) {
      // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
      case 7: // argument 7: training;
        profileBool(pr, node, offset);
        break;
      // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
      case 9:
        profileBoolList(pr, node, offset);
        break;
      default:
        return false;
    }
    return true;
  }
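  // Note (added comment): both batch-norm backward variants above profile the
  // same two pieces of information, the `train` flag and the `output_mask`
  // bool list; only the offsets differ because _batch_norm_impl_index_backward
  // carries an extra leading `impl_index` argument. A hedged sketch of how a
  // consumer might read such a profiled flag back off the node (the attribute
  // name and accessors here are illustrative assumptions, not taken from this
  // file):
  //
  //   const auto trainAttr = Symbol::attr("profiled_bool");
  //   if (node->hasAttribute(trainAttr)) {
  //     const bool is_training = node->i(trainAttr) != 0;
  //     // specialize parsing / codegen on is_training
  //   }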
  static auto native_layer_norm_backward_schema =
      getOperatorForLiteral(
          "aten::native_layer_norm_backward(Tensor grad_out, Tensor input, int[] normalized_shape, Tensor mean, Tensor rstd, Tensor? weight, Tensor? bias, bool[3] output_mask) -> (Tensor, Tensor, Tensor)")
          ->schema();
  if (node->matches(native_layer_norm_backward_schema)) {
    switch (offset) {
      case 2:
        profileIntList(pr, node, offset);
        break;
      // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
      case 7:
        profileBoolList(pr, node, offset);
        break;
      default:
        return false;
    }
    return true;
  }

  static auto to_copy_schema =
      getOperatorForLiteral(
          "aten::_to_copy(Tensor self, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, bool non_blocking=False, MemoryFormat? memory_format=None) -> Tensor")
          ->schema();
  if (node->matches(to_copy_schema)) {
    switch (offset) {
      case 1:
        profileInt(pr, node, offset);
        return true;
      default:
        return false;
    }
  }

  static auto to_dtype_schema =
      getOperatorForLiteral(
          "aten::to.dtype(Tensor self, ScalarType dtype, bool non_blocking=False, bool copy=False, MemoryFormat? memory_format=None) -> Tensor")
          ->schema();
  if (node->matches(to_dtype_schema)) {
    switch (offset) {
      case 1:
        profileInt(pr, node, offset);
        return true;
      default:
        return false;
    }
  }

  static auto log_softmax_data_schema =
      getOperatorForLiteral(
          "aten::log_softmax.int(Tensor self, int dim, ScalarType? dtype=None) -> Tensor")
          ->schema();
  static auto softmax_data_schema =
      getOperatorForLiteral(
          "aten::softmax.int(Tensor self, int dim, ScalarType? dtype=None) -> Tensor")
          ->schema();
  if (node->matches(log_softmax_data_schema) ||
      node->matches(softmax_data_schema)) {
    switch (offset) {
      case 2:
        profileIval(pr, node, offset);
        return true;
      default:
        return false;
    }
  }

  static auto log_softmax_backward_data_schema =
      getOperatorForLiteral(
          "aten::_log_softmax_backward_data(Tensor grad_output, Tensor output, int dim, ScalarType input_dtype) -> Tensor")
          ->schema();
  static auto softmax_backward_data_schema =
      getOperatorForLiteral(
          "aten::_softmax_backward_data(Tensor grad_output, Tensor output, int dim, ScalarType input_dtype) -> Tensor")
          ->schema();
  if (node->matches(log_softmax_backward_data_schema) ||
      node->matches(softmax_backward_data_schema)) {
    switch (offset) {
      case 2:
        profileInt(pr, node, offset);
        return true;
      case 3:
        profileInt(pr, node, offset);
        return true;
      default:
        return false;
    }
  }

  return false;
}

void insertProfileNodesForCUDAFuser_(Block* block, ProfilingRecord* pr) {
  for (const auto& n : block->nodes()) {
    for (const auto offset : c10::irange(n->inputs().size())) {
      insertProfileIValue(pr, n, offset);
    }

    for (auto ib : n->blocks()) {
      insertProfileNodesForCUDAFuser_(ib, pr);
    }
  }
}

void InsertProfileNodes(ProfilingRecord* pr) {
  insertProfileNodesForCUDAFuser_(pr->profiled_graph_->block(), pr);
}

std::unique_ptr<Fusion> parseJitIR(const std::shared_ptr<Graph>& graph) {
  FUSER_PERF_SCOPE("parseJitIR");

  IrParser parser(graph);
  return parser.parse();
}

} // namespace cuda
} // namespace fuser
} // namespace jit
} // namespace torch