forked from Github/frigate
Refactor with a working false positive test
This commit is contained in:
@@ -2,6 +2,7 @@ import os
|
||||
import datetime
|
||||
import hashlib
|
||||
import multiprocessing as mp
|
||||
from abc import ABC, abstractmethod
|
||||
import numpy as np
|
||||
import pyarrow.plasma as plasma
|
||||
import tflite_runtime.interpreter as tflite
|
||||
@@ -27,8 +28,18 @@ def load_labels(path, encoding='utf-8'):
|
||||
else:
|
||||
return {index: line.strip() for index, line in enumerate(lines)}
|
||||
|
||||
class ObjectDetector():
|
||||
def __init__(self):
|
||||
class ObjectDetector(ABC):
|
||||
@abstractmethod
|
||||
def detect(self, tensor_input, threshold = .4):
|
||||
pass
|
||||
|
||||
class LocalObjectDetector(ObjectDetector):
|
||||
def __init__(self, labels=None):
|
||||
if labels is None:
|
||||
self.labels = {}
|
||||
else:
|
||||
self.labels = load_labels(labels)
|
||||
|
||||
edge_tpu_delegate = None
|
||||
try:
|
||||
edge_tpu_delegate = load_delegate('libedgetpu.so.1.0', {"device": "usb"})
|
||||
@@ -53,6 +64,21 @@ class ObjectDetector():
|
||||
self.tensor_input_details = self.interpreter.get_input_details()
|
||||
self.tensor_output_details = self.interpreter.get_output_details()
|
||||
|
||||
def detect(self, tensor_input, threshold=.4):
|
||||
detections = []
|
||||
|
||||
raw_detections = self.detect_raw(tensor_input)
|
||||
|
||||
for d in raw_detections:
|
||||
if d[1] < threshold:
|
||||
break
|
||||
detections.append((
|
||||
self.labels[int(d[0])],
|
||||
float(d[1]),
|
||||
(d[2], d[3], d[4], d[5])
|
||||
))
|
||||
return detections
|
||||
|
||||
def detect_raw(self, tensor_input):
|
||||
self.interpreter.set_tensor(self.tensor_input_details[0]['index'], tensor_input)
|
||||
self.interpreter.invoke()
|
||||
@@ -70,7 +96,7 @@ def run_detector(detection_queue, avg_speed, start):
|
||||
print(f"Starting detection process: {os.getpid()}")
|
||||
listen()
|
||||
plasma_client = plasma.connect("/tmp/plasma")
|
||||
object_detector = ObjectDetector()
|
||||
object_detector = LocalObjectDetector()
|
||||
|
||||
while True:
|
||||
object_id_str = detection_queue.get()
|
||||
|
||||
Reference in New Issue
Block a user