forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path loop_unrolling.h
36 lines (26 loc) · 1006 Bytes
/
loop_unrolling.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
#pragma once
#include <torch/csrc/jit/ir/ir.h>
namespace torch {
namespace jit {
// return true if graph is modified
TORCH_API bool UnrollLoops(std::shared_ptr<Graph>& graph);
// Only unrolls constant loops. Will unroll them regardless of loop block size
TORCH_API bool UnrollConstantLoops(std::shared_ptr<Graph>& graph);
TORCH_API Node* PeelLoop(Node* n, size_t times);
// return true if graph is modified
TORCH_API bool PeelProfilingLoops(const std::shared_ptr<Graph>& graph);
struct TORCH_API LoopsPeeler {
LoopsPeeler(std::function<bool(Node* n)> callback, size_t num_iterations = 1)
: callback_(std::move(callback)), num_iterations_(num_iterations) {}
bool run(const std::shared_ptr<Graph>& graph);
private:
void collectLoop(Node* n);
void collectLoops(Block* block);
void peelLoops();
std::function<bool(Node* n)> callback_ = nullptr;
Node* in_loop_ = nullptr;
std::list<Node*> loops_to_peel_;
size_t num_iterations_ = 1;
};
} // namespace jit
} // namespace torch