#pragma once

#include <unordered_map>

#include <torch/csrc/jit/codegen/onednn/LlgaTensorImpl.h>
#include <torch/csrc/jit/codegen/onednn/graph_helper.h>

#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/passes/pass_manager.h>
#include <torch/csrc/jit/runtime/custom_operator.h>
#include <torch/csrc/jit/runtime/profiling_record.h>

#include <c10/util/CallOnce.h>

namespace torch::jit::fuser::onednn {

using ArgSpec = LlgaTensorDesc;
using ArgSpecs = std::vector<ArgSpec>;
using RunArg = dnnl::graph::tensor;
using RunArgs = std::vector<RunArg>;
using TensorArgs = std::vector<at::Tensor>;

class LlgaKernel {
 public:
  explicit LlgaKernel(const Node* fusionNode);

  void run(Stack& stack);

  void initialize(const TensorArgs& inputs);

  const std::string& debugName() const {
    return debugName_;
  }

 private:
  bool useOpaqueLayout(size_t offset) const;

  // PyTorch copies constants inside the subgraph instead of referencing them.
  // Constant inputs to the partition are therefore no longer in
  // graph->inputs(); we need to use the tid retrieved from the partition to
  // find the missing constant inputs.
  void initializeConstantInputs();

  ArgSpecs initializeInputSpecs(const TensorArgs& inputs);

  ArgSpecs initializeOutputSpecs() const;

  dnnl::graph::compiled_partition compile(
      const dnnl::graph::partition& partition);

  std::map<size_t, int64_t> initializeTensorIdToOccurence() const;

  std::tuple<RunArgs, RunArgs> prepareRunArgs(
      const TensorArgs& inputs,
      TensorArgs& outputs) const;

  static std::string genDebugName() {
    static size_t debugId = 0;
    return "LlgaPartition_" + std::to_string(debugId++);
  }

  static dnnl::graph::logical_tensor toLogicalTensor(const ArgSpec& s) {
    return s.logical_tensor();
  }

  at::Device device_ = at::kCPU;
  const Node* fusionNode_;
  std::shared_ptr<Graph> graph_;
  int64_t nGraphInputs_ = 0; // number of inputs to graph_ on the IR
  int64_t nOutputs_ = 0;
  std::map<size_t, Value*> tensorIdToValue_;
  std::vector<size_t> runArgsIdx_;
  dnnl::graph::partition partition_;
  // nPartitionInputs_ is the actual number of inputs to partition_ of graph_
  // needed by the backend: nPartitionInputs_ = nGraphInputs_ +
  // constantInputs_.size(), since constant inputs are copied into the
  // subgraph.
  int64_t nPartitionInputs_;
  dnnl::graph::compiled_partition compilation_;
  std::set<size_t> initializedInputIds_;
  std::vector<Value*> constantValues_;
  TensorArgs constantInputs_;
  ArgSpecs inputSpecs_;
  ArgSpecs outputSpecs_;
  std::vector<dnnl::graph::logical_tensor> constantLogicalTensors_;
  std::string debugName_;
  c10::once_flag initialized_flag;
  bool is_initialized_ = false;
};

} // namespace torch::jit::fuser::onednn
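
// Usage sketch (illustrative only; not part of this header). It shows how a
// fused-operator implementation might drive LlgaKernel: construct it once per
// fusion node, then call run() on each invocation, which pops the graph
// inputs from the interpreter stack and pushes the outputs back. The
// `createLlgaKernel` name and the Operation-returning shape below are
// assumptions for illustration, not declarations from this header; only the
// constructor and run() come from the class above.
//
//   Operation createLlgaKernel(const Node* node) {
//     auto kernel = std::make_shared<LlgaKernel>(node);
//     return [kernel](Stack& stack) {
//       // The first call lazily builds the input/output specs and compiles
//       // the partition (guarded by initialized_flag); subsequent calls
//       // reuse compilation_.
//       kernel->run(stack);
//     };
//   }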