add labelmap customization to the config (fixes #507)

This commit is contained in:
Blake Blackshear
2021-07-07 22:57:19 -05:00
parent a3853af47a
commit 92aa16c556
7 changed files with 111 additions and 16 deletions

View File

@@ -259,6 +259,7 @@ class FrigateApp:
name,
config,
model_shape,
self.config.model.merged_labelmap,
self.detection_queue,
self.detection_out_events[name],
self.detected_frames_queue,

View File

@@ -13,6 +13,7 @@ from pydantic.fields import PrivateAttr
import yaml
from frigate.const import BASE_DIR, RECORD_DIR, CACHE_DIR
from frigate.edgetpu import load_labels
from frigate.util import create_mask, deep_merge
logger = logging.getLogger(__name__)
@@ -615,6 +616,22 @@ class DatabaseConfig(BaseModel):
class ModelConfig(BaseModel):
width: int = Field(default=320, title="Object detection model input width.")
height: int = Field(default=320, title="Object detection model input height.")
labelmap: Dict[int, str] = Field(
default_factory=dict, title="Labelmap customization."
)
_merged_labelmap: Optional[Dict[int, str]] = PrivateAttr()
@property
def merged_labelmap(self) -> Dict[int, str]:
return self._merged_labelmap
def __init__(self, **config):
super().__init__(**config)
self._merged_labelmap = {
**load_labels("/labelmap.txt"),
**config.get("labelmap", {}),
}
class LogLevelEnum(str, Enum):

View File

@@ -231,7 +231,7 @@ class EdgeTPUProcess:
class RemoteObjectDetector:
def __init__(self, name, labels, detection_queue, event, model_shape):
self.labels = load_labels(labels)
self.labels = labels
self.name = name
self.fps = EventsPerSecond()
self.detection_queue = detection_queue

View File

@@ -503,6 +503,86 @@ class TestConfig(unittest.TestCase):
runtime_config = frigate_config.runtime_config
assert round(runtime_config.cameras["back"].motion.contour_area) == 99
def test_merge_labelmap(self):
config = {
"mqtt": {"host": "mqtt"},
"model": {"labelmap": {7: "truck"}},
"cameras": {
"back": {
"ffmpeg": {
"inputs": [
{
"path": "rtsp://10.0.0.1:554/video",
"roles": ["detect"],
},
]
},
"height": 1080,
"width": 1920,
}
},
}
frigate_config = FrigateConfig(**config)
assert config == frigate_config.dict(exclude_unset=True)
runtime_config = frigate_config.runtime_config
assert runtime_config.model.merged_labelmap[7] == "truck"
def test_default_labelmap_empty(self):
config = {
"mqtt": {"host": "mqtt"},
"cameras": {
"back": {
"ffmpeg": {
"inputs": [
{
"path": "rtsp://10.0.0.1:554/video",
"roles": ["detect"],
},
]
},
"height": 1080,
"width": 1920,
}
},
}
frigate_config = FrigateConfig(**config)
assert config == frigate_config.dict(exclude_unset=True)
runtime_config = frigate_config.runtime_config
assert runtime_config.model.merged_labelmap[0] == "person"
def test_default_labelmap(self):
config = {
"mqtt": {"host": "mqtt"},
"model": {"width": 320, "height": 320},
"cameras": {
"back": {
"ffmpeg": {
"inputs": [
{
"path": "rtsp://10.0.0.1:554/video",
"roles": ["detect"],
},
]
},
"height": 1080,
"width": 1920,
}
},
}
frigate_config = FrigateConfig(**config)
assert config == frigate_config.dict(exclude_unset=True)
runtime_config = frigate_config.runtime_config
assert runtime_config.model.merged_labelmap[0] == "person"
if __name__ == "__main__":
unittest.main(verbosity=2)

View File

@@ -318,6 +318,7 @@ def track_camera(
name,
config: CameraConfig,
model_shape,
labelmap,
detection_queue,
result_connection,
detected_objects_queue,
@@ -344,7 +345,7 @@ def track_camera(
motion_detector = MotionDetector(frame_shape, config.motion)
object_detector = RemoteObjectDetector(
name, "/labelmap.txt", detection_queue, result_connection, model_shape
name, labelmap, detection_queue, result_connection, model_shape
)
object_tracker = ObjectTracker(config.detect)