forked from Github/frigate
Convert detectors to factory pattern, ability to set different model for each detector (#4635)
* refactor detectors * move create_detector and DetectorTypeEnum * fixed code formatting * add detector model config models * fix detector unit tests * adjust SharedMemory size to largest detector model shape * fix detector model config defaults * enable auto-discovery of detectors * simplify config * simplify config changes further * update detectors docs; detect detector configs dynamic * add suggested changes * remove custom detector doc * fix grammar, adjust device defaults
This commit is contained in:
@@ -9,7 +9,7 @@ from typing import Dict, List, Optional, Tuple, Union
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
import yaml
|
||||
from pydantic import BaseModel, Extra, Field, validator
|
||||
from pydantic import BaseModel, Extra, Field, validator, parse_obj_as
|
||||
from pydantic.fields import PrivateAttr
|
||||
|
||||
from frigate.const import (
|
||||
@@ -32,8 +32,15 @@ from frigate.ffmpeg_presets import (
|
||||
parse_preset_output_record,
|
||||
parse_preset_output_rtmp,
|
||||
)
|
||||
from frigate.detectors import (
|
||||
PixelFormatEnum,
|
||||
InputTensorEnum,
|
||||
ModelConfig,
|
||||
DetectorConfig,
|
||||
)
|
||||
from frigate.version import VERSION
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# TODO: Identify what the default format to display timestamps is
|
||||
@@ -52,18 +59,6 @@ class FrigateBaseModel(BaseModel):
|
||||
extra = Extra.forbid
|
||||
|
||||
|
||||
class DetectorTypeEnum(str, Enum):
|
||||
edgetpu = "edgetpu"
|
||||
openvino = "openvino"
|
||||
cpu = "cpu"
|
||||
|
||||
|
||||
class DetectorConfig(FrigateBaseModel):
|
||||
type: DetectorTypeEnum = Field(default=DetectorTypeEnum.cpu, title="Detector Type")
|
||||
device: str = Field(default="usb", title="Device Type")
|
||||
num_threads: int = Field(default=3, title="Number of detection threads")
|
||||
|
||||
|
||||
class UIConfig(FrigateBaseModel):
|
||||
use_experimental: bool = Field(default=False, title="Experimental UI")
|
||||
|
||||
@@ -725,57 +720,6 @@ class DatabaseConfig(FrigateBaseModel):
|
||||
)
|
||||
|
||||
|
||||
class PixelFormatEnum(str, Enum):
|
||||
rgb = "rgb"
|
||||
bgr = "bgr"
|
||||
yuv = "yuv"
|
||||
|
||||
|
||||
class InputTensorEnum(str, Enum):
|
||||
nchw = "nchw"
|
||||
nhwc = "nhwc"
|
||||
|
||||
|
||||
class ModelConfig(FrigateBaseModel):
|
||||
path: Optional[str] = Field(title="Custom Object detection model path.")
|
||||
labelmap_path: Optional[str] = Field(title="Label map for custom object detector.")
|
||||
width: int = Field(default=320, title="Object detection model input width.")
|
||||
height: int = Field(default=320, title="Object detection model input height.")
|
||||
labelmap: Dict[int, str] = Field(
|
||||
default_factory=dict, title="Labelmap customization."
|
||||
)
|
||||
input_tensor: InputTensorEnum = Field(
|
||||
default=InputTensorEnum.nhwc, title="Model Input Tensor Shape"
|
||||
)
|
||||
input_pixel_format: PixelFormatEnum = Field(
|
||||
default=PixelFormatEnum.rgb, title="Model Input Pixel Color Format"
|
||||
)
|
||||
_merged_labelmap: Optional[Dict[int, str]] = PrivateAttr()
|
||||
_colormap: Dict[int, Tuple[int, int, int]] = PrivateAttr()
|
||||
|
||||
@property
|
||||
def merged_labelmap(self) -> Dict[int, str]:
|
||||
return self._merged_labelmap
|
||||
|
||||
@property
|
||||
def colormap(self) -> Dict[int, Tuple[int, int, int]]:
|
||||
return self._colormap
|
||||
|
||||
def __init__(self, **config):
|
||||
super().__init__(**config)
|
||||
|
||||
self._merged_labelmap = {
|
||||
**load_labels(config.get("labelmap_path", "/labelmap.txt")),
|
||||
**config.get("labelmap", {}),
|
||||
}
|
||||
|
||||
cmap = plt.cm.get_cmap("tab10", len(self._merged_labelmap.keys()))
|
||||
|
||||
self._colormap = {}
|
||||
for key, val in self._merged_labelmap.items():
|
||||
self._colormap[val] = tuple(int(round(255 * c)) for c in cmap(key)[:3])
|
||||
|
||||
|
||||
class LogLevelEnum(str, Enum):
|
||||
debug = "debug"
|
||||
info = "info"
|
||||
@@ -890,7 +834,7 @@ class FrigateConfig(FrigateBaseModel):
|
||||
default_factory=ModelConfig, title="Detection model configuration."
|
||||
)
|
||||
detectors: Dict[str, DetectorConfig] = Field(
|
||||
default={name: DetectorConfig(**d) for name, d in DEFAULT_DETECTORS.items()},
|
||||
default=DEFAULT_DETECTORS,
|
||||
title="Detector hardware configuration.",
|
||||
)
|
||||
logger: LoggerConfig = Field(
|
||||
@@ -1032,6 +976,33 @@ class FrigateConfig(FrigateBaseModel):
|
||||
# generate the ffmpeg commands
|
||||
camera_config.create_ffmpeg_cmds()
|
||||
config.cameras[name] = camera_config
|
||||
|
||||
for key, detector in config.detectors.items():
|
||||
detector_config: DetectorConfig = parse_obj_as(DetectorConfig, detector)
|
||||
if detector_config.model is None:
|
||||
detector_config.model = config.model
|
||||
else:
|
||||
model = detector_config.model
|
||||
schema = ModelConfig.schema()["properties"]
|
||||
if (
|
||||
model.width != schema["width"]["default"]
|
||||
or model.height != schema["height"]["default"]
|
||||
or model.labelmap_path is not None
|
||||
or model.labelmap is not {}
|
||||
or model.input_tensor != schema["input_tensor"]["default"]
|
||||
or model.input_pixel_format
|
||||
!= schema["input_pixel_format"]["default"]
|
||||
):
|
||||
logger.warning(
|
||||
"Customizing more than a detector model path is unsupported."
|
||||
)
|
||||
merged_model = deep_merge(
|
||||
detector_config.model.dict(exclude_unset=True),
|
||||
config.model.dict(exclude_unset=True),
|
||||
)
|
||||
detector_config.model = ModelConfig.parse_obj(merged_model)
|
||||
config.detectors[key] = detector_config
|
||||
|
||||
return config
|
||||
|
||||
@validator("cameras")
|
||||
|
||||
Reference in New Issue
Block a user