Add ability to configure model input dtype (#14659)

* Add input type for dtype

* Add ability to manually enable TRT execution provider

* Formatting
This commit is contained in:
Nicolas Mowen
2024-10-29 09:28:05 -06:00
committed by GitHub
parent abd22d2566
commit 4e25bebdd0
4 changed files with 44 additions and 7 deletions

View File

@@ -12,7 +12,11 @@ from setproctitle import setproctitle
import frigate.util as util
from frigate.detectors import create_detector
from frigate.detectors.detector_config import BaseDetectorConfig, InputTensorEnum
from frigate.detectors.detector_config import (
BaseDetectorConfig,
InputDTypeEnum,
InputTensorEnum,
)
from frigate.detectors.plugins.rocm import DETECTOR_KEY as ROCM_DETECTOR_KEY
from frigate.util.builtin import EventsPerSecond, load_labels
from frigate.util.image import SharedMemoryFrameManager
@@ -55,12 +59,15 @@ class LocalObjectDetector(ObjectDetector):
self.input_transform = tensor_transform(
detector_config.model.input_tensor
)
self.dtype = detector_config.model.input_dtype
else:
self.input_transform = None
self.dtype = InputDTypeEnum.int
self.detect_api = create_detector(detector_config)
def detect(self, tensor_input, threshold=0.4):
def detect(self, tensor_input: np.ndarray, threshold=0.4):
detections = []
raw_detections = self.detect_raw(tensor_input)
@@ -77,9 +84,13 @@ class LocalObjectDetector(ObjectDetector):
self.fps.update()
return detections
def detect_raw(self, tensor_input):
def detect_raw(self, tensor_input: np.ndarray):
if self.input_transform:
tensor_input = np.transpose(tensor_input, self.input_transform)
if self.dtype == InputDTypeEnum.float:
tensor_input = tensor_input.astype(np.float32)
return self.detect_api.detect_raw(tensor_input=tensor_input)