forked from pinae/UnsharpDetector
-
Notifications
You must be signed in to change notification settings - Fork 0
/
inferencing_list.py
78 lines (68 loc) · 2.35 KB
/
inferencing_list.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
# -*- coding: utf-8 -*-
from __future__ import division, print_function, unicode_literals
from generic_list_model import GenericListModel
from classified_image_datatype import ClassifiedImageBundle
from threading import Thread
from queue import Queue, Empty
from inference import load_model
import numpy as np
def inferencer(work_queue):
    """Worker loop that classifies image bundles pulled from *work_queue*.

    Intended to run as the target of a background thread. Items are taken
    from the queue one at a time and handled by type:

    * ``False`` stops the loop and returns.
    * ``True`` is a no-op; the worker keeps waiting for the next item.
    * A ``ClassifiedImageBundle`` is marked as in progress, run through the
      model, and has its classification set from the first prediction row.
    * Anything else is ignored.

    ``task_done()`` is called exactly once for every item taken from the
    queue, including the stop sentinel, so the queue's unfinished-task
    count stays balanced.

    Args:
        work_queue: a ``queue.Queue`` shared with the producer.
    """
    while True:
        data = work_queue.get()
        try:
            if isinstance(data, bool):
                if not data:
                    # False is the shutdown sentinel.
                    return
                # True carries no work; wait for the next item.
                continue
            if not isinstance(data, ClassifiedImageBundle):
                # Unknown payloads are skipped instead of crashing the worker.
                continue

            data.set_progress()
            pixels = data.get_np_array()
            # NOTE(review): the model is rebuilt for every bundle; if
            # load_model is side-effect free, caching by shape may be
            # worthwhile - confirm against the inference module.
            model = load_model(pixels.shape)
            # Scaled by 1/255 - presumably 8-bit image data; confirm.
            prediction = model.predict(np.array([pixels / 255]), batch_size=1)
            print(prediction[0])
            data.set_classification(prediction[0])
        finally:
            # Runs on return, continue and on errors alike: one task_done()
            # per get().
            work_queue.task_done()
class InferencingList(GenericListModel):
    """List model that hands its undecided items to a background inferencer.

    A worker thread running ``inferencer`` is started on construction and
    fed through ``work_queue``. ``queued_bundles`` tracks which items have
    already been put on the queue.
    """

    def __init__(self, *args):
        """Set up the queue and bookkeeping list, then start the worker."""
        super().__init__(*args)
        self.work_queue = Queue()
        self.queued_bundles = []
        worker = Thread(target=inferencer, args=(self.work_queue,))
        self.inferencer_thread = worker
        worker.start()

    def stop_worker_thread(self):
        """Drop pending work, send the ``False`` sentinel and join the worker."""
        self.clear_queue()
        self.work_queue.put(False)
        self.inferencer_thread.join()

    def update_queue(self):
        """Bring the work queue in line with the current list contents.

        If any queued bundle is no longer in the list or is no longer
        undecided, the whole queue is flushed. Afterwards every undecided
        list item that is not yet queued is enqueued and its animation
        started.
        """
        # A list (not a generator) so every queued bundle is inspected.
        stale_flags = [
            not (bundle in self.list and bundle.is_undecided())
            for bundle in self.queued_bundles
        ]
        if any(stale_flags):
            self.clear_queue()

        for bundle in self.list:
            if not bundle.is_undecided() or bundle in self.queued_bundles:
                continue
            self.work_queue.put(bundle)
            self.queued_bundles.append(bundle)
            bundle.ani.start()

    def clear_queue(self):
        """Remove everything still waiting in the queue and reset tracking."""
        pending = self.work_queue
        while not pending.empty():
            try:
                pending.get_nowait()
            except Empty:
                # Another consumer emptied the queue in the meantime.
                break
            else:
                pending.task_done()
        self.queued_bundles = []

    def append(self, item):
        """Append *item* via the base class, then refresh the queue."""
        super().append(item)
        self.update_queue()

    def data_changed(self, item):
        """Forward the change to the base class, then refresh the queue."""
        super().data_changed(item)
        self.update_queue()

    def clear(self):
        """Clear the list via the base class, then refresh the queue."""
        super().clear()
        self.update_queue()