-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathinfer_plants.py
89 lines (60 loc) · 2.85 KB
/
infer_plants.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
#import ipywidgets as widgets
import matplotlib.pyplot as plt
import cv2
import sys
import os
import torch
from tqdm import tnrange
import pylab
from adaptis.inference.adaptis_sampling import get_panoptic_segmentation
from adaptis.inference.prediction_model import AdaptISPrediction
from adaptis.data.plants import PlantsDataset
from adaptis.model.cityscapes.models import get_cityscapes_model
# Prefer the GPU; fall back to the CPU so the script can at least start on
# machines without CUDA.
# NOTE(review): whether every adaptis inference path works on CPU is not
# visible from this file — confirm before relying on the fallback.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
dataset_path = '/home/fftai/working/pytorch/adaptis.pytorch-master/custom_dataset/custom_dataset/'
weights_path = '/home/fftai/working/pytorch/adaptis.pytorch-master/experiments/plants/000/checkpoints/proposals_last_checkpoint.pth'

# Validation split of the plants dataset, loaded together with its
# segmentation targets.
dataset = PlantsDataset(dataset_path, split='val', with_segmentation=True)
model = get_cityscapes_model(num_classes=6, norm_layer=torch.nn.BatchNorm2d,
                             backbone='resnet50', with_proposals=True)
pmodel = AdaptISPrediction(model, dataset, device)
# map_location restores the checkpoint tensors onto `device` even if they were
# saved from a different device (e.g. a GPU checkpoint on a CPU-only host).
checkpoint = torch.load(weights_path, map_location=device)
pmodel.net.load_state_dict(checkpoint['model_state'])

# Keyword arguments forwarded to get_panoptic_segmentation for the
# 'proposals' sampling algorithm.
proposals_sampling_params = {
    'thresh1': 0.4,
    'thresh2': 0.5,
    'ithresh': 0.3,
    'fl_prob': 0.10,
    'fl_eps': 0.003,
    'fl_blur': 2,
    'max_iters': 100
}

# dataset_path already ends with '/', so this yields the same path as spelling
# it out in full.
image_path = os.path.join(dataset_path, 'val', 'stn1_syn006_pkg000_0_1_rep_rgb.png')
image = cv2.imread(image_path)
if image is None:
    # cv2.imread reports an unreadable/missing file by returning None instead
    # of raising; fail here with a clear message rather than inside cvtColor.
    raise FileNotFoundError('Could not read image: {}'.format(image_path))
# OpenCV decodes to BGR channel order; convert to RGB for the model/display.
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
pred = get_panoptic_segmentation(pmodel, image,
                                 sampling_algorithm='proposals',
                                 use_flip=True, **proposals_sampling_params)
# `pred` is a dict (indexed by key elsewhere in this file), so show its
# per-pixel instance map rather than handing the dict itself to imshow.
pylab.imshow(pred['instances_mask'])
def show(ix):
    """Plot instance number ``ix`` highlighted over the input image.

    Relies on the module-level ``pred`` (its ``'instances_mask'`` entry) and
    ``image`` produced earlier in this script. The selected instance's pixels
    are added at half intensity on top of the image scaled to [0, 0.5].
    """
    import pylab
    pylab.figure(figsize=(20, 10))
    instance_hits = (pred['instances_mask'] == ix).astype('float32')
    highlight = instance_hits[..., None] * 0.5
    dimmed_image = image.astype('float32') / 255 / 2
    pylab.imshow(highlight + dimmed_image)
#widgets.interact(show, ix=widgets.BoundedIntText(min=0, max=len(pred['masks']), value=0))
from adaptis.coco.panoptic_metric import PQStat, pq_compute, print_pq_stat
def test_model(pmodel, dataset,
               sampling_algorithm, sampling_params,
               use_flip=False, cut_radius=-1):
    """Evaluate panoptic quality (PQ) of ``pmodel`` over every sample of ``dataset``.

    Each sample is segmented with ``get_panoptic_segmentation``; both the
    prediction and the ground-truth sample are converted to COCO panoptic
    format by the dataset, and the per-sample statistics are accumulated into
    a single ``PQStat`` that is reported via ``print_pq_stat`` at the end.

    Args:
        pmodel: prediction wrapper passed straight to get_panoptic_segmentation.
        dataset: dataset providing ``__len__``, ``get_sample``,
            ``convert_to_coco_format`` and ``_generate_coco_categories``.
        sampling_algorithm: sampling algorithm name forwarded to
            get_panoptic_segmentation (e.g. ``'proposals'``).
        sampling_params: extra keyword arguments for that algorithm.
        use_flip: forwarded to get_panoptic_segmentation.
        cut_radius: forwarded to get_panoptic_segmentation.

    Returns:
        The accumulated PQStat, so callers can use the metrics
        programmatically in addition to the printed report.
    """
    # tnrange is a deprecated alias for tqdm.notebook.trange, which needs the
    # Jupyter/ipywidgets frontend; tqdm.auto uses the notebook bar inside
    # Jupyter and a console bar when this script runs from a terminal.
    from tqdm.auto import trange

    pq_stat = PQStat()
    # Map category id -> category record, the form pq_compute/print_pq_stat take.
    categories = {x['id']: x for x in dataset._generate_coco_categories()}
    for indx in trange(len(dataset)):
        sample = dataset.get_sample(indx)
        pred = get_panoptic_segmentation(pmodel, sample['image'],
                                         sampling_algorithm=sampling_algorithm,
                                         use_flip=use_flip, cut_radius=cut_radius,
                                         **sampling_params)
        coco_sample = dataset.convert_to_coco_format(sample)
        coco_pred = dataset.convert_to_coco_format(pred)
        pq_stat = pq_compute(pq_stat, coco_pred, coco_sample, categories)
    print_pq_stat(pq_stat, categories)
    return pq_stat
# Score the whole validation split with the same 'proposals' sampling setup
# (and horizontal-flip flag) used for the single-image preview above.
test_model(
    pmodel,
    dataset,
    sampling_algorithm='proposals',
    sampling_params=proposals_sampling_params,
    use_flip=True,
)