#pragma once #include #include #include namespace torch::jit { struct Graph; struct propagation_error : std::exception {}; class PropertyPropBase { // Used for both Shape Propagation and Dtype/Device Propagation public: explicit PropertyPropBase(std::shared_ptr graph) : graph_(std::move(graph)) {} virtual ~PropertyPropBase() = default; void propagateBlock(Block* block, bool insert_expands = true); // insert_expands is used for shape inference void processIf(Node* node); void processLoop(Node* node); protected: virtual void propagateNode(Node* node, bool insert_expands = true) = 0; void setUnshapedType(Value* o); void setUnshapedType(Node* node); std::shared_ptr graph_; }; TORCH_API void EraseShapeInformation(const std::shared_ptr& graph); TORCH_API void PropagateInputShapes(const std::shared_ptr& graph); TORCH_API bool mergeTypes( ArrayRef lhs, ArrayRef rhs, ArrayRef outputs); } // namespace torch::jit