-
Notifications
You must be signed in to change notification settings - Fork 34
/
Copy pathrun.py
66 lines (51 loc) · 1.93 KB
/
run.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
from lib.config import args, cfg
def run_mesh_extract():
    """Extract a mesh from the trained SDF network and write it to disk.

    Builds the network (moved to GPU via ``.cuda()``), loads the checkpoint
    selected by ``cfg.trained_model_dir`` / ``cfg.resume`` / ``cfg.test.epoch``,
    extracts a mesh from ``network.model.sdf_net``, post-processes it with
    ``refuse`` (given the test data loader) and ``transform`` (given the test
    dataset's ``scale`` and ``offset``), and writes the result to
    ``args.output_mesh`` with Open3D.

    Raises:
        ValueError: if ``args.output_mesh`` is empty or unset. This is checked
            before any imports, model loading or extraction so that a missing
            output path fails fast.
    """
    # Validate up front with a real exception. The previous `assert` ran only
    # after the expensive extraction and is removed entirely under `python -O`.
    if not args.output_mesh:
        raise ValueError('args.output_mesh must be set to write the extracted mesh')

    # Heavy/project imports are kept local so other run_* modes don't pay for them.
    from lib.datasets import make_data_loader
    from lib.networks import make_network
    from lib.utils.mesh_utils import extract_mesh, refuse, transform
    from lib.utils.net_utils import load_network
    import open3d as o3d

    network = make_network(cfg).cuda()
    load_network(
        network,
        cfg.trained_model_dir,
        resume=cfg.resume,
        epoch=cfg.test.epoch
    )
    network.eval()

    data_loader = make_data_loader(cfg, is_train=False)
    mesh = extract_mesh(network.model.sdf_net)
    mesh = refuse(mesh, data_loader)
    mesh = transform(mesh, cfg.test_dataset.scale, cfg.test_dataset.offset)
    o3d.io.write_triangle_mesh(args.output_mesh, mesh)
def print_result(result_dict):
    """Print one ``name: value`` line per metric in *result_dict*.

    Each name is left-aligned in a 7-character field (longer names are not
    truncated). Each value is formatted as a float with 3 decimal places.
    An empty mapping prints nothing.
    """
    for metric_name, metric_value in result_dict.items():
        line = '{:7s}: {:1.3f}'.format(metric_name, metric_value)
        print(line)
def run_evaluate():
    """Extract a mesh from the trained network and evaluate it against ground truth.

    Builds the network (moved to GPU via ``.cuda()``) and loads the checkpoint
    selected by ``cfg.trained_model_dir`` / ``cfg.resume`` / ``cfg.test.epoch``.
    It then extracts a mesh from ``network.model.sdf_net`` and post-processes
    it with ``refuse`` and ``transform`` in the same way as mesh extraction.
    If ``args.output_mesh`` is set, the mesh is also written there. Finally the
    mesh is compared with ``<data_root>/<scene>/gt.obj`` using the evaluator
    from ``make_evaluator(cfg)``, and the returned metrics are printed via
    ``print_result``.

    Raises:
        FileNotFoundError: if the ground-truth mesh file does not exist. This
            is checked before the network is built, so a bad path fails fast.
    """
    import os

    gt_path = f'{cfg.test_dataset.data_root}/{cfg.test_dataset.scene}/gt.obj'
    # Check early: Open3D's read_triangle_mesh reports an unreadable file as a
    # warning and returns an empty mesh rather than raising. Without this check
    # the evaluator would silently score against nothing after the expensive
    # extraction step.
    if not os.path.isfile(gt_path):
        raise FileNotFoundError(f'ground-truth mesh not found: {gt_path}')

    from lib.datasets import make_data_loader
    from lib.evaluators import make_evaluator
    from lib.networks import make_network
    from lib.utils.mesh_utils import extract_mesh, refuse, transform
    from lib.utils.net_utils import load_network
    import open3d as o3d

    network = make_network(cfg).cuda()
    load_network(
        network,
        cfg.trained_model_dir,
        resume=cfg.resume,
        epoch=cfg.test.epoch
    )
    network.eval()

    data_loader = make_data_loader(cfg, is_train=False)
    evaluator = make_evaluator(cfg)
    mesh = extract_mesh(network.model.sdf_net)
    mesh = refuse(mesh, data_loader)
    mesh = transform(mesh, cfg.test_dataset.scale, cfg.test_dataset.offset)

    # Saving the mesh is optional here (unlike run_mesh_extract). Truthiness
    # treats both '' and None as "not requested".
    if args.output_mesh:
        o3d.io.write_triangle_mesh(args.output_mesh, mesh)

    mesh_gt = o3d.io.read_triangle_mesh(gt_path)
    evaluate_result = evaluator.evaluate(mesh, mesh_gt)
    print_result(evaluate_result)
if __name__ == '__main__':
    # Dispatch to the module-level `run_<type>` function named by `args.type`.
    # Build an explicit table so an unknown type produces a readable error that
    # lists the valid choices, instead of a bare KeyError from globals().
    runners = {
        name[len('run_'):]: fn
        for name, fn in globals().items()
        if name.startswith('run_') and callable(fn)
    }
    try:
        runner = runners[args.type]
    except KeyError:
        raise SystemExit(
            f"unknown run type {args.type!r}; "
            f"expected one of: {', '.join(sorted(runners))}"
        ) from None
    runner()