forked from Github/frigate
Add ability to configure model input dtype (#14659)
* Add input type for dtype * Add ability to manually enable TRT execution provider * Formatting
This commit is contained in:
@@ -12,7 +12,11 @@ from setproctitle import setproctitle
|
||||
|
||||
import frigate.util as util
|
||||
from frigate.detectors import create_detector
|
||||
from frigate.detectors.detector_config import BaseDetectorConfig, InputTensorEnum
|
||||
from frigate.detectors.detector_config import (
|
||||
BaseDetectorConfig,
|
||||
InputDTypeEnum,
|
||||
InputTensorEnum,
|
||||
)
|
||||
from frigate.detectors.plugins.rocm import DETECTOR_KEY as ROCM_DETECTOR_KEY
|
||||
from frigate.util.builtin import EventsPerSecond, load_labels
|
||||
from frigate.util.image import SharedMemoryFrameManager
|
||||
@@ -55,12 +59,15 @@ class LocalObjectDetector(ObjectDetector):
|
||||
self.input_transform = tensor_transform(
|
||||
detector_config.model.input_tensor
|
||||
)
|
||||
|
||||
self.dtype = detector_config.model.input_dtype
|
||||
else:
|
||||
self.input_transform = None
|
||||
self.dtype = InputDTypeEnum.int
|
||||
|
||||
self.detect_api = create_detector(detector_config)
|
||||
|
||||
def detect(self, tensor_input, threshold=0.4):
|
||||
def detect(self, tensor_input: np.ndarray, threshold=0.4):
|
||||
detections = []
|
||||
|
||||
raw_detections = self.detect_raw(tensor_input)
|
||||
@@ -77,9 +84,13 @@ class LocalObjectDetector(ObjectDetector):
|
||||
self.fps.update()
|
||||
return detections
|
||||
|
||||
def detect_raw(self, tensor_input):
|
||||
def detect_raw(self, tensor_input: np.ndarray):
|
||||
if self.input_transform:
|
||||
tensor_input = np.transpose(tensor_input, self.input_transform)
|
||||
|
||||
if self.dtype == InputDTypeEnum.float:
|
||||
tensor_input = tensor_input.astype(np.float32)
|
||||
|
||||
return self.detect_api.detect_raw(tensor_input=tensor_input)
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user