-
Notifications
You must be signed in to change notification settings - Fork 6
/
Copy pathserver.py
161 lines (134 loc) · 5.05 KB
/
server.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
import heapq
import itertools
import logging
import queue
import time
from threading import Lock
from time import sleep

from flask import Flask, jsonify, abort, make_response, request
from flask_cors import CORS

from utils import base64tocv2, intersection_over_union
from utils import vote_cliques, get_cliques
# Raise werkzeug's logger to ERROR so its routine (sub-ERROR) request logging is suppressed.
log = logging.getLogger('werkzeug')
log.setLevel(logging.ERROR)
# Per-model input queues, indexed by the values of `models`; extended by init().
frame_buffers = []
# Model name -> index into frame_buffers / model_results; filled by init().
models = {}
# Counter incremented for every frame accepted by /detect_objects.
frameId = 0
# Per-model result queues (indexed like frame_buffers); assigned by init().
model_results = None
reload_queue = None # assigned by init(); /reload_models puts a list of model names on it
# Assigned by init(); /reload_models blocks on one message from it after a reload
# request — presumably the workers' "ready again" signal, TODO confirm.
inf_ready_queue = None
# Min-heap of timestamps pushed by record_fps(); get_fps_stats() counts the last second.
fps_stats = []
app = Flask(__name__)
# Enable Flask-CORS for the whole app with its default settings.
CORS(app)
# Serializes /reload_models requests.
reload_model_lock = Lock()
def detect(image, model, frame_id, conf=0.2, iou=0.45, mode="parallel"):
    """Queue a frame for one model, dropping an old frame if the buffer is full.

    The frame goes onto ``frame_buffers[models[model]]``. When that buffer is
    full, one entry (the oldest, for a FIFO queue) is discarded first, so the
    subsequent ``put`` does not have to wait for the consumer.

    Args:
        image: decoded frame (whatever ``base64tocv2`` produced).
        model: model name; must be a key of ``models``.
        frame_id: identifier stored alongside the frame.
        conf: confidence threshold, coerced with ``float``.
        iou: IoU threshold, coerced with ``float``.
        mode: execution-mode string forwarded to the worker unchanged.

    Raises:
        KeyError: if ``model`` is not a registered model name.
        ValueError / TypeError: if ``conf`` or ``iou`` cannot be converted to float.
    """
    buffer = frame_buffers[models[model]]
    if buffer.full():
        try:
            buffer.get(False)
        except queue.Empty:
            # A consumer emptied the buffer between full() and get(); there is
            # nothing left to drop, so just enqueue the new frame.
            pass
    buffer.put((frame_id, image, float(conf), float(iou), mode))
def init(api_results, MODELS_IN_USE, frameBuffers, adminQueue, infReadyQueue):
    """Wire the server's module-level state to the inference side.

    Args:
        api_results: list of per-model result queues, stored as ``model_results``.
        MODELS_IN_USE: sequence whose items' first element is a model name; the
            item's position becomes that model's index in ``models``.
        frameBuffers: per-model input queues appended to ``frame_buffers``.
        adminQueue: queue stored as ``reload_queue``.
        infReadyQueue: queue stored as ``inf_ready_queue``.
    """
    global models, frame_buffers, model_results, reload_queue, inf_ready_queue
    reload_queue, inf_ready_queue = adminQueue, infReadyQueue
    models.update((entry[0], position) for position, entry in enumerate(MODELS_IN_USE))
    frame_buffers.extend(frameBuffers)
    model_results = api_results
def get_fps_stats():
    """Evict timestamps older than one second from ``fps_stats``; return how many remain."""
    while fps_stats:
        oldest = fps_stats[0]
        if oldest >= time.time() - 1:
            break
        heapq.heappop(fps_stats)
    return len(fps_stats)
def record_fps():
    """Push the current wall-clock timestamp onto the ``fps_stats`` heap."""
    timestamp = time.time()
    heapq.heappush(fps_stats, timestamp)
@app.route('/detect_objects', methods=['POST'])
def detect_objects():
    """Accept one frame and queue it for every requested model.

    Expected JSON body:
        image: the frame, decoded with ``base64tocv2``.
        models: list of ``{'model': name, 'conf': ..., 'iou': ...}`` entries;
            every ``name`` must have been registered by ``init``.
        mode: execution mode passed to each model (required when ``models``
            is non-empty).

    Each accepted request increments ``frameId``; detections are collected
    separately through ``/detect_objects_response``.

    Returns:
        An empty JSON object with status 201. Aborts with 400 when the body is
        missing or malformed, before any frame is queued.
    """
    global frameId
    payload = request.json
    if not payload or 'image' not in payload or 'models' not in payload:
        abort(400)
    requested = payload['models']
    if not isinstance(requested, list):
        abort(400)
    if requested and 'mode' not in payload:
        abort(400)
    # Validate every entry before touching any queue, so a bad entry cannot
    # leave the frame queued for only some of the models.
    for entry in requested:
        if (not isinstance(entry, dict)
                or 'conf' not in entry
                or 'iou' not in entry
                or entry.get('model') not in models):
            abort(400)
    frameId += 1
    image = base64tocv2(payload['image'])
    for entry in requested:
        detect(image, entry['model'], frameId, entry['conf'], entry['iou'], payload['mode'])
    return jsonify({}), 201
@app.route('/detect_objects_response', methods=['GET'])
def detect_objects_response():
    """Collect the next available detections for the requested models.

    Query parameters:
        models: comma-separated model names. Names not registered by ``init``
            (including empty ones) are ignored, as in ``/reload_models``.

    For each model this waits up to one second on its ``model_results`` queue.
    A model with no result in time is reported with an empty list. If the last
    result read reports execution mode ``'ensemble'``, the per-model lists are
    merged by ``get_cliques``/``vote_cliques`` into a single ``'all'`` list.

    Returns:
        JSON mapping each model name (or ``'all'``) to a list of
        ``{'bbox': [xmin, ymin, width, height], 'class': ..., 'score': ...}``
        dicts, plus ``'fps'`` (number of such responses in the last second, or
        ``None`` when no known model was requested); status 201.
    """
    requested = request.args.get('models', "")
    # Skip unknown names instead of failing the whole request with a KeyError.
    model_names = [m for m in requested.split(",") if m in models] if requested else []
    response = {}
    execution_mode = None
    for model_name in model_names:
        objects_detected = []
        try:
            # t_iou is not used any more: it only fed a former NMS-based merge
            # that clique voting replaced.
            objects, execution_mode, _t_iou = model_results[models[model_name]].get(timeout=1)[1:]
            for obj in objects:
                objects_detected.append({
                    'bbox': [obj.xmin, obj.ymin, obj.xmax - obj.xmin, obj.ymax - obj.ymin],
                    'class': obj.name,
                    'score': float(obj.confidence),
                })
        except queue.Empty:
            pass  # no result within the timeout: report no detections for this model
        except Exception:
            # Stay best-effort like before, but make malformed results visible.
            app.logger.exception("Could not read detection results for model %r", model_name)
        response[model_name] = objects_detected
    if execution_mode == 'ensemble':
        # Ensemble detection: majority voting over graph cliques. Per the
        # original design note, a clique's prediction is its highest-scoring
        # box (the actual rule lives in utils.vote_cliques).
        cliques, graph = get_cliques(response)
        predictions = vote_cliques(cliques, graph)
        response = {'all': []}
        for pred in predictions:
            # Clients only need the merged box, not the model that produced it.
            pred.pop('model', None)
            response['all'].append(pred)
    if model_names:
        record_fps()
        response["fps"] = get_fps_stats()
    else:
        response["fps"] = None
    return jsonify(response), 201
@app.route('/reload_models', methods=['GET'])
def reload_models():
    """Ask the inference workers to reload the given models.

    Query parameters:
        models: comma-separated model names; names not registered by ``init``
            are ignored.

    Under ``reload_model_lock`` this does the following:
        - sends the accepted names on ``reload_queue``;
        - discards every frame and result still queued;
        - blocks until one message arrives on ``inf_ready_queue``;
        - waits a further 0.5 s.

    Returns:
        JSON list of the model names sent for reloading with status 201, or an
        empty JSON list with status 400 when no known model was named.
    """
    requested = request.args.get('models', "")
    model_names = [m for m in requested.split(",") if m in models] if requested else []
    if not model_names:
        # A Flask view must return a response; returning None here was a 500.
        return jsonify(model_names), 400
    with reload_model_lock:
        reload_queue.put(model_names)
        for q in frame_buffers + model_results:
            # Drain without blocking: with empty()+get(), a worker could take
            # the last item in between and get() would hang this request while
            # it holds the lock.
            try:
                while True:
                    q.get_nowait()
            except queue.Empty:
                pass
        # Presumably the workers' "reload finished" signal — TODO confirm.
        inf_ready_queue.get()
        sleep(0.5)
    return jsonify(model_names), 201
@app.route('/shutdown', methods=['POST'])
def shutdown():
    """Discard all queued frames and results, then stop the server if possible.

    Drains every queue in ``frame_buffers`` and ``model_results`` and calls the
    ``werkzeug.server.shutdown`` hook when the WSGI environ provides one.

    Returns:
        An empty body with status 200.
    """
    shutdown_hook = request.environ.get('werkzeug.server.shutdown')
    for q in frame_buffers + model_results:
        # Drain without blocking: with empty()+get(), a consumer could take the
        # last item in between and get() would hang this request.
        try:
            while True:
                q.get_nowait()
        except queue.Empty:
            pass
    # NOTE(review): Werkzeug 2.1 removed the 'werkzeug.server.shutdown' environ
    # key; on newer versions this drains the queues but does not stop the
    # server. Confirm the deployed Werkzeug version.
    if shutdown_hook is not None:
        shutdown_hook()
    return "", 200
@app.errorhandler(404)
def not_found(error):
    """Answer unknown routes with a JSON error body and status 404."""
    body = jsonify({'error': 'Not found'})
    return make_response(body, 404)