-
Notifications
You must be signed in to change notification settings - Fork 0
/
run.py
176 lines (146 loc) · 6.71 KB
/
run.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
"""run_infer.py

Usage:
  run_infer.py [options] [--help] <command> [<args>...]
  run_infer.py --version
  run_infer.py (-h | --help)

Options:
  -h --help                   Show this help message.
  --version                   Show version.
  --gpu=<id>                  GPU list. [default: 0]
  --nr_types=<n>              Number of nuclei types to predict. [default: 0]
  --type_info_path=<path>     Path to a JSON file defining the mapping between type id, type name,
                              and expected overlaid color. [default: '']
  --model_path=<path>         Path to saved checkpoint.
  --model_mode=<mode>         Original HoVer-Net or the reduced version used for PanNuke and MoNuSAC,
                              'original' or 'fast'. [default: fast]
  --nr_inference_workers=<n>  Number of workers during inference. [default: 8]
  --nr_post_proc_workers=<n>  Number of workers during post-processing. [default: 16]
  --batch_size=<n>            Batch size per 1 GPU. [default: 32]

The two command modes are `tile` and `wsi`, which enter the corresponding inference mode
    tile    run the inference on tile
    wsi     run the inference on wsi

Use `run_infer.py <command> --help` to show their options and usage.
"""
# docopt specification for the `tile` sub-command; it is parsed at runtime by
# docopt(sub_cli_dict[sub_cmd], ...) in the `__main__` block.
# The layout is significant to docopt: the usage pattern must be indented
# under its header, sections are separated by a blank line, and every option
# is separated from its description by at least two spaces (otherwise the
# `[default: ...]` values are not picked up).
tile_cli = """
Arguments for processing tiles.

usage:
    tile (--input_dir=<path>) (--output_dir=<path>) \
         [--draw_dot] [--save_qupath] [--save_raw_map] [--mem_usage=<n>]

options:
    --input_dir=<path>     Path to input data directory. Assumes the files are not nested within directory.
    --output_dir=<path>    Path to output directory..
    --mem_usage=<n>        Declare how much memory (physical + swap) should be used for caching.
                           By default it will load as many tiles as possible till reaching the
                           declared limit. [default: 0.2]
    --draw_dot             To draw nuclei centroid on overlay. [default: False]
    --save_qupath          To optionally output QuPath v0.2.3 compatible format. [default: False]
    --save_raw_map         To save raw prediction or not. [default: False]
"""
# docopt specification for the `wsi` sub-command; it is parsed at runtime by
# docopt(sub_cli_dict[sub_cmd], ...) in the `__main__` block.
# The layout is significant to docopt: the usage pattern must be indented
# under its header, sections are separated by a blank line, and every option
# is separated from its description by at least two spaces.
# The mask directory option is documented under the same name the usage
# pattern and the `__main__` block use (`--input_mask_dir`). It deliberately
# carries no `[default: ...]` so that docopt yields None when it is omitted.
wsi_cli = """
Arguments for processing wsi

usage:
    wsi (--input_dir=<path>) (--output_dir=<path>) [--proc_mag=<n>] \
        [--cache_path=<path>] [--input_mask_dir=<path>] \
        [--ambiguous_size=<n>] [--chunk_shape=<n>] [--tile_shape=<n>] \
        [--save_thumb] [--save_mask]

options:
    --input_dir=<path>        Path to input data directory. Assumes the files are not nested within directory.
    --output_dir=<path>       Path to output directory.
    --cache_path=<path>       Path for cache. Should be placed on SSD with at least 100GB. [default: cache]
    --input_mask_dir=<path>   Path to directory containing tissue masks.
                              Should have the same name as corresponding WSIs.
    --proc_mag=<n>            Magnification level (objective power) used for WSI processing. [default: 40]
    --ambiguous_size=<int>    Define ambiguous region along tiling grid to perform re-post processing. [default: 128]
    --chunk_shape=<n>         Shape of chunk for processing. [default: 10000]
    --tile_shape=<n>          Shape of tiles for processing. [default: 2048]
    --save_thumb              To save thumb. [default: False]
    --save_mask               To save mask. [default: False]
"""
import torch
import logging
import os
import copy
from hovernet.misc.utils import log_info
from docopt import docopt
# NOTE(review): the return value of this call is discarded, so it has no
# visible effect in this file. It also executes at import time, i.e. before
# the `__main__` block assigns os.environ['CUDA_VISIBLE_DEVICES'] -- confirm
# that this ordering is intended and does not affect which GPUs are visible.
torch.cuda.is_available()
#-------------------------------------------------------------------------------------------------------
def build_method_args(args):
    """Build the keyword arguments used to construct an ``InferManager``.

    Args:
        args (dict): Top-level command line options keyed without the leading
            ``--``. The keys ``nr_types``, ``model_mode``, ``model_path`` and
            ``type_info_path`` are read.

    Returns:
        dict: ``{'method': {'model_args': {'nr_types', 'mode'}, 'model_path'},
        'type_info_path'}``. ``nr_types`` is ``None`` unless the parsed value
        is positive; ``type_info_path`` is ``None`` when no path was supplied.

    Raises:
        ValueError: If ``nr_types`` cannot be converted to ``int``.
    """
    nr_types = int(args['nr_types'])
    if nr_types <= 0:
        nr_types = None

    # docopt returns the text inside `[default: ...]` verbatim, so the
    # documented default of '' arrives as the two-character string "''"
    # instead of an empty string. Treat both spellings as "not supplied".
    type_info_path = args['type_info_path']
    if type_info_path in ('', "''"):
        type_info_path = None

    return {
        'method': {
            'model_args': {
                'nr_types': nr_types,
                'mode': args['model_mode'],
            },
            'model_path': args['model_path'],
        },
        'type_info_path': type_info_path,
    }


def build_run_args(args, sub_cmd, sub_args, nr_gpus):
    """Build the run-time argument dict handed to the inference manager.

    Args:
        args (dict): Top-level options keyed without the leading ``--``. The
            keys ``batch_size``, ``nr_inference_workers``,
            ``nr_post_proc_workers`` and ``model_mode`` are read.
        sub_cmd (str): ``'tile'`` or ``'wsi'``; selects which sub-command
            options are merged in. Any other value adds nothing.
        sub_args (dict): Options of the sub-command, keyed without ``--``.
        nr_gpus (int): Number of GPUs; the per-GPU batch size is multiplied
            by it (so the resulting batch size is 0 when ``nr_gpus`` is 0).

    Returns:
        dict: Batch size, worker counts, patch input/output shapes and the
        converted sub-command options.

    Raises:
        ValueError: If a numeric option cannot be converted.
    """
    run_args = {
        'batch_size': int(args['batch_size']) * nr_gpus,
        'nr_inference_workers': int(args['nr_inference_workers']),
        'nr_post_proc_workers': int(args['nr_post_proc_workers']),
    }

    # Patch shapes depend on the model variant; every mode other than 'fast'
    # gets the second pair of values.
    if args['model_mode'] == 'fast':
        run_args['patch_input_shape'] = 256
        run_args['patch_output_shape'] = 164
    else:
        run_args['patch_input_shape'] = 270
        run_args['patch_output_shape'] = 80

    if sub_cmd == 'tile':
        run_args.update({
            'input_dir': sub_args['input_dir'],
            'output_dir': sub_args['output_dir'],
            'mem_usage': float(sub_args['mem_usage']),
            'draw_dot': sub_args['draw_dot'],
            'save_qupath': sub_args['save_qupath'],
            'save_raw_map': sub_args['save_raw_map'],
        })
    elif sub_cmd == 'wsi':
        run_args.update({
            'input_dir': sub_args['input_dir'],
            'output_dir': sub_args['output_dir'],
            'input_mask_dir': sub_args['input_mask_dir'],
            'cache_path': sub_args['cache_path'],
            'proc_mag': int(sub_args['proc_mag']),
            'ambiguous_size': int(sub_args['ambiguous_size']),
            'chunk_shape': int(sub_args['chunk_shape']),
            'tile_shape': int(sub_args['tile_shape']),
            'save_thumb': sub_args['save_thumb'],
            'save_mask': sub_args['save_mask'],
        })
    return run_args


if __name__ == '__main__':
    # Imported here because only the script entry point needs it; the bare
    # `exit()` helper comes from the `site` module and is meant for
    # interactive sessions, so `sys.exit()` is used instead.
    import sys

    sub_cli_dict = {'tile': tile_cli, 'wsi': wsi_cli}

    # First pass: parse only the top-level options. `options_first=True`
    # leaves everything after <command> untouched in '<args>'.
    args = docopt(__doc__, help=False, options_first=True,
                  version='HoVer-Net Pytorch Inference v1.0')
    sub_cmd = args.pop('<command>')
    sub_cmd_args = args.pop('<args>')

    if args['--help'] or sub_cmd is None:
        # Show the sub-command help when a known sub-command was named,
        # otherwise fall back to the top-level help.
        print(sub_cli_dict.get(sub_cmd, __doc__))
        sys.exit()

    if sub_cmd not in sub_cli_dict:
        # Fail with a readable message instead of a KeyError traceback.
        print(__doc__)
        sys.exit("Unknown command '%s'; expected one of: %s"
                 % (sub_cmd, ', '.join(sub_cli_dict)))

    # Second pass: parse the options that belong to the sub-command.
    sub_args = docopt(sub_cli_dict[sub_cmd], argv=sub_cmd_args, help=True)

    args.pop('--version')
    gpu_list = args.pop('--gpu')
    os.environ['CUDA_VISIBLE_DEVICES'] = gpu_list

    nr_gpus = torch.cuda.device_count()
    log_info('Detect #GPUS: %d' % nr_gpus)

    # Drop the leading dashes so the options can be used as plain keys.
    args = {k.replace('--', ''): v for k, v in args.items()}
    sub_args = {k.replace('--', ''): v for k, v in sub_args.items()}
    if args['model_path'] is None:
        raise Exception('A model path must be supplied as an argument with --model_path.')

    method_args = build_method_args(args)
    run_args = build_run_args(args, sub_cmd, sub_args, nr_gpus)

    # Each mode ships its own InferManager; import only the one that is used.
    if sub_cmd == 'tile':
        from hovernet.infer.tile import InferManager
        infer = InferManager(**method_args)
        infer.process_file_list(run_args)
    else:
        from hovernet.infer.wsi import InferManager
        infer = InferManager(**method_args)
        infer.process_wsi_list(run_args)