#pragma once

#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/passes/symbolic_shape_runtime_fusion.h>
#include <torch/csrc/jit/passes/utils/subgraph_utils.h>
#include <torch/csrc/jit/runtime/interpreter.h>
#include <torch/csrc/jit/tensorexpr/analysis.h>
#include <torch/csrc/jit/tensorexpr/codegen.h>
#include <torch/csrc/jit/tensorexpr/lowerings.h>
#include <torch/csrc/jit/tensorexpr/tensor.h>

namespace torch::jit::tensorexpr {

struct SmallSizeTPairHash {
 public:
  std::size_t operator()(const std::pair<size_t, size_t>& x) const {
    // hashing input index and then dim index
    return x.first * 128 + x.second;
  }
};

// Returns true if the TE fuser supports this conv2d.
bool conv2dIsSupportedJit(const Node* node);
// Returns true if the TE fuser supports this conv2d with mkldnn prepacked
// conv.
bool mkldnnPrepackedConvIsSupportedJit(const Node* node);
// Returns true if the TE _convolution node is Conv2d.
bool isConv2d(const Node* node);
// Returns true if the TE fuser supports this matmul.
bool matmulIsSupported(const Node* node);

template <typename T>
inline std::vector<int64_t> bufferSizes(const T& t) {
  std::vector<int64_t> sizes;
  for (size_t i = 0; i < t->ndim(); i++) {
    sizes.push_back(*intValue(t->dim(i)));
  }
  return sizes;
}

// Get the dimensions of a value.
std::vector<ExprHandle> valueShape(const ArgValue& v);

// If v is a tensor, broadcast it to match the shape of axes, or return
// directly if v is a constant.
ExprHandle tensorOrConstant(
    const ArgValue& v,
    const std::vector<ExprHandle>& axes);

int64_t normalizeAndCheckIndex(int64_t idx, int64_t list_size);

ExprHandle broadcast(const BufHandle& b, const std::vector<ExprHandle>& axes);

ExprHandle constant(const ArgValue& v);

std::vector<ExprHandle> computeIndicesToBroadcast(
    const std::vector<ExprHandle>& outputAxes,
    const std::vector<ExprHandle>& inputSizes);

inline std::string getArgValueName(const ArgValue& a) {
  if (std::holds_alternative<tensorexpr::BufHandle>(a)) {
    return "BufHandle";
  } else if (std::holds_alternative<tensorexpr::VarHandle>(a)) {
    return "VarHandle";
  } else if (std::holds_alternative<double>(a)) {
    return "double";
  } else if (std::holds_alternative<int64_t>(a)) {
    return "int64_t";
  } else if (std::holds_alternative<bool>(a)) {
    return "bool";
  } else if (std::holds_alternative<BufList>(a)) {
    return "BufList";
  } else if (std::holds_alternative<DoubleList>(a)) {
    return "DoubleList";
  } else if (std::holds_alternative<IntList>(a)) {
    return "IntList";
  } else if (std::holds_alternative<ArgNone>(a)) {
    return "None";
  } else {
    throw std::runtime_error("ArgValue type not handled in string conversion");
  }
}

template <class T>
std::vector<T> convertVecArgValue(const std::vector<ArgValue>& v) {
  std::vector<T> res;
  for (auto& x : v) {
    auto val = std::get_if<T>(&x);
    if (val) {
      res.push_back(*val);
    } else {
      throw std::runtime_error(
          "vector type not homogeneous - found " + getArgValueName(x) +
          ", expected " + getArgValueName(v[0]));
    }
  }
  return res;
}

class TORCH_API TensorExprKernel {
  struct ConstantDescr {
    BufPtr buf;
    // Only one of ptr and node is used at a time
    // 1) ptr for the constant tensors
    // 2) node for the constant custom class objects
    void* ptr = nullptr;
    Node* node = nullptr;
  };

 public:
  // Constructor Params:
  //  * subgraph
  //      - the graph that needs to be compiled.
  //  * kernel_func_name
  //      - the name that should be used for the generated kernel.
  //  * custom_lowerings
  //      - map that represents custom lowering definitions for a set of ops.
  //  * symbolic_shape_inputs
  //      - a list of symbolic graph inputs that represent the symbolic dims of
  //        the input tensors.
  //  * pre_alloc
  //      - a flag to control pre-allocation of buffers.
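  //
  // A minimal usage sketch (not part of this header; `subgraph`, `a`, and `b`
  // are placeholders for a fusible graph and its input tensors):
  //
  //   std::shared_ptr<Graph> subgraph = ...; // e.g. produced by the TE fuser
  //   TensorExprKernel kernel(subgraph);
  //   std::vector<at::Tensor> inputs = {a, b};
  //   Stack stack = c10::fmap<IValue>(inputs);
  //   kernel.run(stack);
  //   at::Tensor result = std::move(stack.back()).toTensor();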
  explicit TensorExprKernel(
      const std::shared_ptr<Graph>& subgraph,
      std::string kernel_func_name,
      std::unordered_map<c10::Symbol, NNCLoweringFunction> custom_lowerings =
          {},
      std::vector<int64_t> symbolic_shape_inputs = {},
      bool pre_alloc = false,
      std::unordered_map<
          const torch::jit::Value*,
          std::vector<torch::jit::StrideInput>> symbolic_strides = {});

  explicit TensorExprKernel(
      const std::shared_ptr<Graph>& subgraph,
      std::unordered_map<c10::Symbol, NNCLoweringFunction> custom_lowerings =
          {},
      std::vector<int64_t> symbolic_shape_inputs = {},
      bool pre_alloc = false,
      std::unordered_map<
          const torch::jit::Value*,
          std::vector<torch::jit::StrideInput>> symbolic_strides = {})
      : TensorExprKernel(
            subgraph,
            SubgraphUtils::generateNameForGraph(subgraph),
            std::move(custom_lowerings),
            std::move(symbolic_shape_inputs),
            pre_alloc,
            std::move(symbolic_strides)) {}

  void run(Stack& stack) const;
  void runFast(
      const std::vector<void*>& inputs,
      const std::vector<void*>& outputs) const;
  // Expected format of stack:
  //  ... <outputs> <inputs>
  // i.e., output IValues must be below the input IValues in the stack.
  void runWithAllocatedOutputs(Stack& stack) const;

  void fallback(Stack& stack) const {
    InterpreterState(code_).run(stack);
  }
  void recompile();

  StmtPtr getCodeGenStmt();

  std::string getCodeText(const std::string& attr = "") {
    return codegen_->getCodeText(attr);
  }

  const std::shared_ptr<Graph> graph() {
    return graph_;
  }

  const std::vector<ConstantDescr>& getConstantDescriptors() const {
    return constants_;
  }

  const std::vector<CodeGen::BufferArg>& getBufferArgs() const {
    return bufferArgs_;
  }

  const std::string& getKernelName() const {
    return (codegen_ ? codegen_->kernel_func_name() : kernel_func_name_);
  }

  const std::vector<int64_t>& getSymbolicShapeInputs() const {
    return symbolic_shape_inputs_;
  }

 private:
  enum BackendType {
    kUninitialized,
    kSimpleIREval,
    kLLVMCodeGen,
    kCudaCodeGen,
    kBlockCodeGen,
  };

  enum MemoryLayoutPolicy {
    kContiguous,
    kChannelsLastNdContiguous,
  };

  void compile();
  void genInputDebugNames();
  void runKernel(Stack& stack) const;

  std::vector<ExprHandle> sizesForValue(const torch::jit::Value* v);

  // These functions broadcast shape and also store a `hasBroadcast_` variable.
  std::vector<ExprHandle> broadcastShapesMut(
      const std::vector<ExprHandle>& a,
      const std::vector<ExprHandle>& b);
  std::vector<ExprHandle> broadcastShapesMut(
      std::vector<std::vector<ExprHandle>> shapes);

  ArgValue toArg(const torch::jit::Value* v) const;
  ExprHandle constant(const torch::jit::Value* v);

  Tensor computeValue(const torch::jit::Value* v);

  void bindConstant(const torch::jit::Value* v);

  StmtPtr transformLoops(BackendType backendType, StmtPtr st);

  std::string getCodeGenName(BackendType backendType);

  void getStaticOutputSizesAndStrides(
      const at::ArrayRef<IValue>& inputs,
      std::vector<std::vector<int64_t>>* static_sizes,
      std::vector<std::vector<int64_t>>* static_strides) const;

  std::vector<CodeGen::CallArg> prepareRunArgs(
      const at::ArrayRef<IValue>& inputs,
      std::vector<at::Tensor>& outputs) const;
  BackendType inferBackendTypeFromDevice(at::Device device);

  Tensor bindInput(const torch::jit::Value* input);
  BlockPtr bindAllInputs();

  // Deduce the memory layout policy to be propagated within
  // the NNC fusion group. The memory layout policy can be `kContiguous`
  // or `kChannelsLastNdContiguous`.
  //    `kContiguous`: Always convert non-contiguous input tensors and
  //        internal buffers to contiguous.
  //    `kChannelsLastNdContiguous`: Always convert the input tensors and
  //        internal buffers to channels-last contiguous.
  // Currently, the rule is simple:
  //    If all the input and output tensors of the NNC fusion group are
  //    channels-last contiguous, the policy is `kChannelsLastNdContiguous`.
  //    Otherwise, it is always `kContiguous`.
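  //
  // Sketch of that rule only (illustrative, not the actual implementation,
  // which inspects graph value types rather than concrete tensors;
  // `fusion_io` is a hypothetical collection of the group's input/output
  // tensors):
  //
  //   bool all_channels_last = std::all_of(
  //       fusion_io.begin(), fusion_io.end(), [](const at::Tensor& t) {
  //         return t.is_contiguous(at::MemoryFormat::ChannelsLast) ||
  //             t.is_contiguous(at::MemoryFormat::ChannelsLast3d);
  //       });
  //   memory_layout_policy_ = all_channels_last
  //       ? MemoryLayoutPolicy::kChannelsLastNdContiguous
  //       : MemoryLayoutPolicy::kContiguous;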
  void deduceMemoryLayoutPolicy();

  Tensor convertSymbolicOutputToCorrectStrides(torch::jit::Value* v);
  Tensor convertStaticShapeOutputToCorrectStrides(torch::jit::Value* v);
  Tensor convertSymbolicOutputToCorrectStrides(
      const std::vector<ExprHandle>& sizes,
      const std::vector<size_t>& sorted_stride_indices_descending,
      const std::vector<ExprPtr>& strides,
      BufPtr& buf);

  NNCLoweringFunction getCustomLoweringFor(c10::Symbol op) const;
  std::unordered_map<c10::Symbol, NNCLoweringFunction> getCustomLowerings()
      const {
    return custom_lowerings_;
  }

  // Allocate memory for intermediate buffers at compile time.
  // Specifically, we pre-allocate memory for intermediate buffers with static
  // size and manage these buffers in the way we manage JIT constant tensors:
  // push the buf args onto the stack so NNC IR can access them at runtime.
  std::vector<BufPtr> preAllocIntermediateBufs(
      const std::vector<BufPtr>& interm_bufs);

  struct UnpackedTensorOptions {
    std::optional<c10::ScalarType> dtype;
    std::optional<c10::Layout> layout;
    std::optional<c10::Device> device;
    std::optional<bool> pinned_memory;

    UnpackedTensorOptions(const c10::TensorOptions& opts)
        : dtype(c10::optTypeMetaToScalarType(opts.dtype_opt())),
          layout(opts.layout_opt()),
          device(opts.device_opt()),
          pinned_memory(opts.pinned_memory_opt()) {}
  };

  ExprHandle getVarForShape(const c10::ShapeSymbol& ss);
  std::vector<ExprHandle> computeInputTensorDims(
      const torch::jit::Value* input);
  ExprHandle getStrideArg(size_t tensor_input, size_t stride_index);
  std::vector<ExprHandle> sizesFromSymbolicShape(
      const c10::SymbolicShape& shape);
  std::vector<ExprHandle> getInputStrides(
      const torch::jit::Value* input,
      const std::vector<ExprHandle>& inputTensorDims);
  std::vector<torch::jit::StrideInput>& getSymbolicStrideDesc(
      const torch::jit::Value* value);

  // Apply optimizations to the graph owned by the current fusion group, like
  // concatenation optimization, post-op fusion, and other graph-level
  // optimizations.
  void optimizeOwningGraph();

  int64_t nInputs_ = 0;
  int64_t nOutputs_ = 0;
  std::vector<CodeGen::BufferArg> bufferArgs_;
  std::vector<std::vector<int64_t>> tensorOutputSizes_;
  std::vector<std::vector<int64_t>> tensorOutputStrides_;
  std::vector<torch::jit::StrideInput> tensorOutputStrideDesc_;
  std::vector<bool> isOutputScalar_;
  std::vector<UnpackedTensorOptions> tensorOutputTensorOptions_;
  std::unordered_set<BufPtr> bufOutputs_;
  std::unordered_set<BufPtr> bufsToBeParallelized_;
  std::unordered_map<const torch::jit::Value*, BufPtr> bufs_;
  std::unordered_map<const torch::jit::Value*, VarHandle> scalars_;
  std::unordered_map<const torch::jit::Value*, std::string> input_name_map_;
  std::unique_ptr<CodeGen> codegen_;
  at::Device device_ = at::kCPU;
  std::shared_ptr<Graph> graph_;
  Code code_;
  bool allow_fallback_{false};
  bool use_fallback_{false};
  bool hasRandom_{false};
  bool hasBroadcast_{false};
  std::unordered_map<const torch::jit::Value*, std::vector<ExprHandle>>
      known_sizes_;

  std::vector<std::vector<ExprHandle>> tensorOutputSymbolicSizes_;
  // A map from ShapeSymbol.value() to the corresponding Var.
  std::unordered_map<int64_t, VarHandle> shapeSymbolToVar_;
  std::unordered_map<ExprPtr, size_t> shapeSymbolInputPos_;
  // List of values corresponding to the ShapeSymbols that are inputs to the
  // kernel being compiled. The order of these values corresponds to the order
  // of the symbolic inputs at the end of the list of inputs to the kernel.
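  //
  // For example (illustrative layout only), a kernel with two tensor inputs
  // and two symbolic dims sees a run-time stack laid out as
  //
  //   [tensor0, tensor1, sym_dim0, sym_dim1]
  //
  // where each sym_dim entry is an int64_t IValue holding the concrete value
  // of the corresponding ShapeSymbol for that invocation.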
  std::vector<int64_t> symbolic_shape_inputs_;
  bool has_symbolic_shapes_{false};

  std::vector<at::Tensor> unpacked_constant_tensors_;
  std::vector<ConstantDescr> constants_;

  std::unordered_map<c10::Symbol, NNCLoweringFunction> custom_lowerings_;
  StmtPtr stmt_ = nullptr;
  bool pre_alloc_{false};
  std::string kernel_func_name_;

  // <index of input in the stack, stride index of that tensor> pairs for the
  // strides that will be appended as codegen args
  std::vector<std::pair<size_t, size_t>> input_stride_args_;
  // map from <input index, tensor dimension index> to stride as arg VarHandle
  std::unordered_map<std::pair<size_t, size_t>, VarHandle, SmallSizeTPairHash>
      strideArgToVar_;
  std::unordered_map<
      const torch::jit::Value*,
      std::vector<torch::jit::StrideInput>>
      symbolic_strides_;

  // Memory layout to be propagated within the fusion group
  MemoryLayoutPolicy memory_layout_policy_ = MemoryLayoutPolicy::kContiguous;
};

TORCH_API int& getTECudaPointwiseLoopLevels();
TORCH_API int& getTECudaPointwiseBlockCount();
TORCH_API int& getTECudaPointwiseBlockSize();
TORCH_API bool& getTEGenerateBlockCode();
TORCH_API bool& getTEMustUseLLVMOnCPU();
TORCH_API bool fallbackAllowed();
TORCH_API bool setFallbackAllowed(bool value);
TORCH_API bool& getCatWoConditionals();
TORCH_API bool& getOptConditionals();

TORCH_API std::optional<at::Device> pickDeviceType(
    const at::ArrayRef<torch::jit::Value*>& inputs);

bool isContiguous(
    const torch::jit::Value* v,
    at::MemoryFormat memory_format = at::MemoryFormat::Contiguous);

} // namespace torch::jit::tensorexpr
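// Usage sketch for the configuration accessors above (illustrative only):
// they return mutable references to process-wide settings, so a caller that
// wants a temporary override can save and restore the previous value, e.g.
//
//   int& blockSize = torch::jit::tensorexpr::getTECudaPointwiseBlockSize();
//   int saved = blockSize;
//   blockSize = 256; // request 256-thread blocks for CUDA pointwise kernels
//   /* ... compile and run a TensorExprKernel ... */
//   blockSize = saved; // restore the previous global setting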