-
Notifications
You must be signed in to change notification settings - Fork 3
/
vis_test.py
58 lines (53 loc) · 2.05 KB
/
vis_test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
import os
import time
import datetime
#import mmcv
import cv2 as cv
import json
import numpy as np
import pycocotools.mask as maskutil
import pycocotools.coco as COCO
from itertools import groupby
from skimage import measure,draw,data
from PIL import Image
import matplotlib.pyplot as plt
def get_index(image_id, load_dict):
    """Collect the annotations that belong to one image.

    Args:
        image_id: Id compared (with ``==``) against each annotation's
            ``'image_id'`` field.
        load_dict: Parsed dataset dict with an ``'annotations'`` list; every
            entry must provide ``'image_id'`` and ``'category_id'`` keys.

    Returns:
        A ``(seg_list, label_list)`` tuple. ``seg_list`` holds the positions
        of the matching entries inside ``load_dict['annotations']``, and
        ``label_list`` holds their ``'category_id'`` values in the same order.
    """
    seg_list = []
    label_list = []
    for idx, ann in enumerate(load_dict['annotations']):
        if image_id == ann['image_id']:
            seg_list.append(idx)
            label_list.append(ann['category_id'])
    return seg_list, label_list
def get_color(class_id):
    """Return the grey-level intensity used to draw objects of ``class_id``.

    Each class id is scaled by a fixed step of 50, so different classes
    render at different intensities.

    NOTE(review): the result only fits a uint8 pixel (<= 255) for
    class_id <= 5; larger ids would need a different mapping.
    """
    intensity_step = 50
    return intensity_step * class_id
def _segmentation_to_mask(segmentation, height, width):
    """Decode one annotation's ``segmentation`` into a 2-D binary mask.

    Polygon lists are rasterised and ALL polygons of the object are merged
    (objects split into several pieces keep every piece). RLE dicts are
    accepted too: uncompressed ``counts`` lists are converted first, and
    compressed RLE is decoded directly. This mirrors the logic of
    ``pycocotools.coco.COCO.annToRLE``.

    Returns:
        A ``(height, width)`` array that is non-zero inside the object.
    """
    if isinstance(segmentation, dict):
        if isinstance(segmentation['counts'], list):
            rle = maskutil.frPyObjects(segmentation, height, width)
        else:
            rle = segmentation
    else:
        # frPyObjects gives one RLE per polygon; merge them into one object.
        rle = maskutil.merge(maskutil.frPyObjects(segmentation, height, width))
    # decode() of a single RLE dict returns a 2-D (height, width) array.
    return maskutil.decode(rle)


with open('train_restriction.json', 'r') as f:
    load_dict = json.load(f)
annotations = load_dict['annotations']

# Images are expected to be named "<image_id>.<ext>" inside ./restricted.
for im_path in sorted(os.listdir('restricted')):
    stem, _ext = os.path.splitext(im_path)
    try:
        image_id = int(stem)
    except ValueError:
        print('skipping non-numeric file name:', im_path)
        continue
    im = cv.imread(os.path.join('restricted', im_path))
    if im is None:  # cv.imread signals unreadable / non-image files with None
        print('skipping unreadable image:', im_path)
        continue
    height, width = im.shape[:2]
    print(im.shape)
    print(im_path)
    print(image_id)

    seg_list, label_list = get_index(image_id, load_dict)
    # int32 canvas so class colours never wrap around like a uint8 sum would;
    # overlapping objects show the colour of the later annotation instead of
    # a meaningless sum.
    final_mask = np.zeros((height, width), np.int32)
    for seg_idx, category_id in zip(seg_list, label_list):
        segmentation = annotations[seg_idx]['segmentation']
        if not segmentation:  # nothing to draw for an empty segmentation
            continue
        mask = _segmentation_to_mask(segmentation, height, width)
        final_mask[mask > 0] = get_color(category_id)
    plt.imshow(final_mask)  # show final mask
    plt.show()