Refactor with a working false positive test

This commit is contained in:
Blake Blackshear
2020-08-22 07:05:20 -05:00
parent a8556a729b
commit ea4ecae27c
9 changed files with 272 additions and 168 deletions

View File

@@ -2,6 +2,7 @@ import os
import datetime
import hashlib
import multiprocessing as mp
from abc import ABC, abstractmethod
import numpy as np
import pyarrow.plasma as plasma
import tflite_runtime.interpreter as tflite
@@ -27,8 +28,18 @@ def load_labels(path, encoding='utf-8'):
else:
return {index: line.strip() for index, line in enumerate(lines)}
class ObjectDetector():
def __init__(self):
class ObjectDetector(ABC):
@abstractmethod
def detect(self, tensor_input, threshold = .4):
pass
class LocalObjectDetector(ObjectDetector):
def __init__(self, labels=None):
if labels is None:
self.labels = {}
else:
self.labels = load_labels(labels)
edge_tpu_delegate = None
try:
edge_tpu_delegate = load_delegate('libedgetpu.so.1.0', {"device": "usb"})
@@ -53,6 +64,21 @@ class ObjectDetector():
self.tensor_input_details = self.interpreter.get_input_details()
self.tensor_output_details = self.interpreter.get_output_details()
def detect(self, tensor_input, threshold=.4):
detections = []
raw_detections = self.detect_raw(tensor_input)
for d in raw_detections:
if d[1] < threshold:
break
detections.append((
self.labels[int(d[0])],
float(d[1]),
(d[2], d[3], d[4], d[5])
))
return detections
def detect_raw(self, tensor_input):
self.interpreter.set_tensor(self.tensor_input_details[0]['index'], tensor_input)
self.interpreter.invoke()
@@ -70,7 +96,7 @@ def run_detector(detection_queue, avg_speed, start):
print(f"Starting detection process: {os.getpid()}")
listen()
plasma_client = plasma.connect("/tmp/plasma")
object_detector = ObjectDetector()
object_detector = LocalObjectDetector()
while True:
object_id_str = detection_queue.get()