Mirror of https://github.com/zebrajr/pytorch.git (synced 2025-12-07 00:21:07 +01:00).
Follows #132604.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/132753
Approved by: https://github.com/Skylion007
35 lines, 981 B, C++
#pragma once

#include <cstddef>
#include <functional>
#include <list>
#include <memory>
#include <utility>

#include <torch/csrc/jit/ir/ir.h>
|
|
|
|
namespace torch::jit {
|
|
|
|
// return true if graph is modified
|
|
TORCH_API bool UnrollLoops(std::shared_ptr<Graph>& graph);
|
|
|
|
// Only unrolls constant loops. Will unroll them regardless of loop block size
|
|
TORCH_API bool UnrollConstantLoops(std::shared_ptr<Graph>& graph);
|
|
|
|
TORCH_API Node* PeelLoop(Node* n, size_t times);
|
|
|
|
// return true if graph is modified
|
|
TORCH_API bool PeelProfilingLoops(const std::shared_ptr<Graph>& graph);
|
|
|
|
struct TORCH_API LoopsPeeler {
|
|
LoopsPeeler(std::function<bool(Node* n)> callback, size_t num_iterations = 1)
|
|
: callback_(std::move(callback)), num_iterations_(num_iterations) {}
|
|
|
|
bool run(const std::shared_ptr<Graph>& graph);
|
|
|
|
private:
|
|
void collectLoop(Node* n);
|
|
void collectLoops(Block* block);
|
|
void peelLoops();
|
|
|
|
std::function<bool(Node* n)> callback_ = nullptr;
|
|
Node* in_loop_ = nullptr;
|
|
std::list<Node*> loops_to_peel_;
|
|
size_t num_iterations_ = 1;
|
|
};
|
|
} // namespace torch::jit
|