forked from Github/frigate
Migrate pydantic to V2 (#10142)
* Run pydantic migration tool * Finish removing deprecated functions * Formatting * Fix movement weights type * Fix movement weight test * Fix config checks * formatting * fix typing * formatting * Fix * Fix serialization issues * Formatting * fix model namespace warnings * Update formatting * Format go2rtc file * Cleanup migrations * Fix warnings * Don't include null values in config json * Formatting * Fix test --------- Co-authored-by: Josh Hawkins <32435876+hawkeye217@users.noreply.github.com>
This commit is contained in:
@@ -6,11 +6,19 @@ import logging
|
||||
import os
|
||||
from enum import Enum
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional, Tuple, Union
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
from pydantic import BaseModel, Extra, Field, parse_obj_as, validator
|
||||
from pydantic import (
|
||||
BaseModel,
|
||||
ConfigDict,
|
||||
Field,
|
||||
TypeAdapter,
|
||||
ValidationInfo,
|
||||
field_serializer,
|
||||
field_validator,
|
||||
)
|
||||
from pydantic.fields import PrivateAttr
|
||||
|
||||
from frigate.const import (
|
||||
@@ -66,8 +74,7 @@ DEFAULT_TIME_LAPSE_FFMPEG_ARGS = "-vf setpts=0.04*PTS -r 30"
|
||||
|
||||
|
||||
class FrigateBaseModel(BaseModel):
|
||||
class Config:
|
||||
extra = Extra.forbid
|
||||
model_config = ConfigDict(extra="forbid", protected_namespaces=())
|
||||
|
||||
|
||||
class LiveModeEnum(str, Enum):
|
||||
@@ -93,7 +100,7 @@ class UIConfig(FrigateBaseModel):
|
||||
live_mode: LiveModeEnum = Field(
|
||||
default=LiveModeEnum.mse, title="Default Live Mode."
|
||||
)
|
||||
timezone: Optional[str] = Field(title="Override UI timezone.")
|
||||
timezone: Optional[str] = Field(default=None, title="Override UI timezone.")
|
||||
use_experimental: bool = Field(default=False, title="Experimental UI")
|
||||
time_format: TimeFormatEnum = Field(
|
||||
default=TimeFormatEnum.browser, title="Override UI time format."
|
||||
@@ -135,16 +142,17 @@ class MqttConfig(FrigateBaseModel):
|
||||
topic_prefix: str = Field(default="frigate", title="MQTT Topic Prefix")
|
||||
client_id: str = Field(default="frigate", title="MQTT Client ID")
|
||||
stats_interval: int = Field(default=60, title="MQTT Camera Stats Interval")
|
||||
user: Optional[str] = Field(title="MQTT Username")
|
||||
password: Optional[str] = Field(title="MQTT Password")
|
||||
tls_ca_certs: Optional[str] = Field(title="MQTT TLS CA Certificates")
|
||||
tls_client_cert: Optional[str] = Field(title="MQTT TLS Client Certificate")
|
||||
tls_client_key: Optional[str] = Field(title="MQTT TLS Client Key")
|
||||
tls_insecure: Optional[bool] = Field(title="MQTT TLS Insecure")
|
||||
user: Optional[str] = Field(None, title="MQTT Username")
|
||||
password: Optional[str] = Field(None, title="MQTT Password", validate_default=True)
|
||||
tls_ca_certs: Optional[str] = Field(None, title="MQTT TLS CA Certificates")
|
||||
tls_client_cert: Optional[str] = Field(None, title="MQTT TLS Client Certificate")
|
||||
tls_client_key: Optional[str] = Field(None, title="MQTT TLS Client Key")
|
||||
tls_insecure: Optional[bool] = Field(None, title="MQTT TLS Insecure")
|
||||
|
||||
@validator("password", pre=True, always=True)
|
||||
def validate_password(cls, v, values):
|
||||
if (v is None) != (values["user"] is None):
|
||||
@field_validator("password")
|
||||
def user_requires_pass(cls, v, info: ValidationInfo):
|
||||
print(f"doing a check where {v} is None and {info.data['user']} is None")
|
||||
if (v is None) != (info.data["user"] is None):
|
||||
raise ValueError("Password must be provided with username.")
|
||||
return v
|
||||
|
||||
@@ -186,18 +194,19 @@ class PtzAutotrackConfig(FrigateBaseModel):
|
||||
title="Internal value used for PTZ movements based on the speed of your camera's motor.",
|
||||
)
|
||||
enabled_in_config: Optional[bool] = Field(
|
||||
title="Keep track of original state of autotracking."
|
||||
None, title="Keep track of original state of autotracking."
|
||||
)
|
||||
|
||||
@validator("movement_weights", pre=True)
|
||||
@field_validator("movement_weights", mode="before")
|
||||
@classmethod
|
||||
def validate_weights(cls, v):
|
||||
if v is None:
|
||||
return None
|
||||
|
||||
if isinstance(v, str):
|
||||
weights = list(map(float, v.split(",")))
|
||||
weights = list(map(str, map(float, v.split(","))))
|
||||
elif isinstance(v, list):
|
||||
weights = [float(val) for val in v]
|
||||
weights = [str(float(val)) for val in v]
|
||||
else:
|
||||
raise ValueError("Invalid type for movement_weights")
|
||||
|
||||
@@ -210,8 +219,8 @@ class PtzAutotrackConfig(FrigateBaseModel):
|
||||
class OnvifConfig(FrigateBaseModel):
|
||||
host: str = Field(default="", title="Onvif Host")
|
||||
port: int = Field(default=8000, title="Onvif Port")
|
||||
user: Optional[str] = Field(title="Onvif Username")
|
||||
password: Optional[str] = Field(title="Onvif Password")
|
||||
user: Optional[str] = Field(None, title="Onvif Username")
|
||||
password: Optional[str] = Field(None, title="Onvif Password")
|
||||
autotracking: PtzAutotrackConfig = Field(
|
||||
default_factory=PtzAutotrackConfig,
|
||||
title="PTZ auto tracking config.",
|
||||
@@ -242,6 +251,7 @@ class EventsConfig(FrigateBaseModel):
|
||||
title="List of required zones to be entered in order to save the event.",
|
||||
)
|
||||
objects: Optional[List[str]] = Field(
|
||||
None,
|
||||
title="List of objects to be detected in order to save the event.",
|
||||
)
|
||||
retain: RetainConfig = Field(
|
||||
@@ -296,7 +306,7 @@ class RecordConfig(FrigateBaseModel):
|
||||
default_factory=RecordPreviewConfig, title="Recording Preview Config"
|
||||
)
|
||||
enabled_in_config: Optional[bool] = Field(
|
||||
title="Keep track of original state of recording."
|
||||
None, title="Keep track of original state of recording."
|
||||
)
|
||||
|
||||
|
||||
@@ -324,8 +334,17 @@ class MotionConfig(FrigateBaseModel):
|
||||
title="Delay for updating MQTT with no motion detected.",
|
||||
)
|
||||
enabled_in_config: Optional[bool] = Field(
|
||||
title="Keep track of original state of motion detection."
|
||||
None, title="Keep track of original state of motion detection."
|
||||
)
|
||||
raw_mask: Union[str, List[str]] = ""
|
||||
|
||||
@field_serializer("mask", when_used="json")
|
||||
def serialize_mask(self, value: Any, info):
|
||||
return self.raw_mask
|
||||
|
||||
@field_serializer("raw_mask", when_used="json")
|
||||
def serialize_raw_mask(self, value: Any, info):
|
||||
return None
|
||||
|
||||
|
||||
class RuntimeMotionConfig(MotionConfig):
|
||||
@@ -348,19 +367,25 @@ class RuntimeMotionConfig(MotionConfig):
|
||||
super().__init__(**config)
|
||||
|
||||
def dict(self, **kwargs):
|
||||
ret = super().dict(**kwargs)
|
||||
ret = super().model_dump(**kwargs)
|
||||
if "mask" in ret:
|
||||
ret["mask"] = ret["raw_mask"]
|
||||
ret.pop("raw_mask")
|
||||
return ret
|
||||
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
extra = Extra.ignore
|
||||
@field_serializer("mask", when_used="json")
|
||||
def serialize_mask(self, value: Any, info):
|
||||
return self.raw_mask
|
||||
|
||||
@field_serializer("raw_mask", when_used="json")
|
||||
def serialize_raw_mask(self, value: Any, info):
|
||||
return None
|
||||
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True, extra="ignore")
|
||||
|
||||
|
||||
class StationaryMaxFramesConfig(FrigateBaseModel):
|
||||
default: Optional[int] = Field(title="Default max frames.", ge=1)
|
||||
default: Optional[int] = Field(None, title="Default max frames.", ge=1)
|
||||
objects: Dict[str, int] = Field(
|
||||
default_factory=dict, title="Object specific max frames."
|
||||
)
|
||||
@@ -368,10 +393,12 @@ class StationaryMaxFramesConfig(FrigateBaseModel):
|
||||
|
||||
class StationaryConfig(FrigateBaseModel):
|
||||
interval: Optional[int] = Field(
|
||||
None,
|
||||
title="Frame interval for checking stationary objects.",
|
||||
gt=0,
|
||||
)
|
||||
threshold: Optional[int] = Field(
|
||||
None,
|
||||
title="Number of frames without a position change for an object to be considered stationary",
|
||||
ge=1,
|
||||
)
|
||||
@@ -382,17 +409,21 @@ class StationaryConfig(FrigateBaseModel):
|
||||
|
||||
|
||||
class DetectConfig(FrigateBaseModel):
|
||||
height: Optional[int] = Field(title="Height of the stream for the detect role.")
|
||||
width: Optional[int] = Field(title="Width of the stream for the detect role.")
|
||||
height: Optional[int] = Field(
|
||||
None, title="Height of the stream for the detect role."
|
||||
)
|
||||
width: Optional[int] = Field(None, title="Width of the stream for the detect role.")
|
||||
fps: int = Field(
|
||||
default=5, title="Number of frames per second to process through detection."
|
||||
)
|
||||
enabled: bool = Field(default=True, title="Detection Enabled.")
|
||||
min_initialized: Optional[int] = Field(
|
||||
title="Minimum number of consecutive hits for an object to be initialized by the tracker."
|
||||
None,
|
||||
title="Minimum number of consecutive hits for an object to be initialized by the tracker.",
|
||||
)
|
||||
max_disappeared: Optional[int] = Field(
|
||||
title="Maximum number of frames the object can dissapear before detection ends."
|
||||
None,
|
||||
title="Maximum number of frames the object can dissapear before detection ends.",
|
||||
)
|
||||
stationary: StationaryConfig = Field(
|
||||
default_factory=StationaryConfig,
|
||||
@@ -426,8 +457,18 @@ class FilterConfig(FrigateBaseModel):
|
||||
default=0.5, title="Minimum detection confidence for object to be counted."
|
||||
)
|
||||
mask: Optional[Union[str, List[str]]] = Field(
|
||||
None,
|
||||
title="Detection area polygon mask for this filter configuration.",
|
||||
)
|
||||
raw_mask: Union[str, List[str]] = ""
|
||||
|
||||
@field_serializer("mask", when_used="json")
|
||||
def serialize_mask(self, value: Any, info):
|
||||
return self.raw_mask
|
||||
|
||||
@field_serializer("raw_mask", when_used="json")
|
||||
def serialize_raw_mask(self, value: Any, info):
|
||||
return None
|
||||
|
||||
|
||||
class AudioFilterConfig(FrigateBaseModel):
|
||||
@@ -440,8 +481,8 @@ class AudioFilterConfig(FrigateBaseModel):
|
||||
|
||||
|
||||
class RuntimeFilterConfig(FilterConfig):
|
||||
mask: Optional[np.ndarray]
|
||||
raw_mask: Optional[Union[str, List[str]]]
|
||||
mask: Optional[np.ndarray] = None
|
||||
raw_mask: Optional[Union[str, List[str]]] = None
|
||||
|
||||
def __init__(self, **config):
|
||||
mask = config.get("mask")
|
||||
@@ -453,15 +494,13 @@ class RuntimeFilterConfig(FilterConfig):
|
||||
super().__init__(**config)
|
||||
|
||||
def dict(self, **kwargs):
|
||||
ret = super().dict(**kwargs)
|
||||
ret = super().model_dump(**kwargs)
|
||||
if "mask" in ret:
|
||||
ret["mask"] = ret["raw_mask"]
|
||||
ret.pop("raw_mask")
|
||||
return ret
|
||||
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
extra = Extra.ignore
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True, extra="ignore")
|
||||
|
||||
|
||||
# this uses the base model because the color is an extra attribute
|
||||
@@ -531,9 +570,11 @@ class AudioConfig(FrigateBaseModel):
|
||||
listen: List[str] = Field(
|
||||
default=DEFAULT_LISTEN_AUDIO, title="Audio to listen for."
|
||||
)
|
||||
filters: Optional[Dict[str, AudioFilterConfig]] = Field(title="Audio filters.")
|
||||
filters: Optional[Dict[str, AudioFilterConfig]] = Field(
|
||||
None, title="Audio filters."
|
||||
)
|
||||
enabled_in_config: Optional[bool] = Field(
|
||||
title="Keep track of original state of audio detection."
|
||||
None, title="Keep track of original state of audio detection."
|
||||
)
|
||||
num_threads: int = Field(default=2, title="Number of detection threads", ge=1)
|
||||
|
||||
@@ -660,7 +701,8 @@ class CameraInput(FrigateBaseModel):
|
||||
class CameraFfmpegConfig(FfmpegConfig):
|
||||
inputs: List[CameraInput] = Field(title="Camera inputs.")
|
||||
|
||||
@validator("inputs")
|
||||
@field_validator("inputs")
|
||||
@classmethod
|
||||
def validate_roles(cls, v):
|
||||
roles = [role for i in v for role in i.roles]
|
||||
roles_set = set(roles)
|
||||
@@ -690,7 +732,7 @@ class SnapshotsConfig(FrigateBaseModel):
|
||||
default_factory=list,
|
||||
title="List of required zones to be entered in order to save a snapshot.",
|
||||
)
|
||||
height: Optional[int] = Field(title="Snapshot image height.")
|
||||
height: Optional[int] = Field(None, title="Snapshot image height.")
|
||||
retain: RetainConfig = Field(
|
||||
default_factory=RetainConfig, title="Snapshot retention."
|
||||
)
|
||||
@@ -727,7 +769,7 @@ class TimestampStyleConfig(FrigateBaseModel):
|
||||
format: str = Field(default=DEFAULT_TIME_FORMAT, title="Timestamp format.")
|
||||
color: ColorConfig = Field(default_factory=ColorConfig, title="Timestamp color.")
|
||||
thickness: int = Field(default=2, title="Timestamp thickness.")
|
||||
effect: Optional[TimestampEffectEnum] = Field(title="Timestamp effect.")
|
||||
effect: Optional[TimestampEffectEnum] = Field(None, title="Timestamp effect.")
|
||||
|
||||
|
||||
class CameraMqttConfig(FrigateBaseModel):
|
||||
@@ -755,8 +797,7 @@ class CameraLiveConfig(FrigateBaseModel):
|
||||
|
||||
|
||||
class RestreamConfig(BaseModel):
|
||||
class Config:
|
||||
extra = Extra.allow
|
||||
model_config = ConfigDict(extra="allow")
|
||||
|
||||
|
||||
class CameraUiConfig(FrigateBaseModel):
|
||||
@@ -767,7 +808,7 @@ class CameraUiConfig(FrigateBaseModel):
|
||||
|
||||
|
||||
class CameraConfig(FrigateBaseModel):
|
||||
name: Optional[str] = Field(title="Camera name.", regex=REGEX_CAMERA_NAME)
|
||||
name: Optional[str] = Field(None, title="Camera name.", pattern=REGEX_CAMERA_NAME)
|
||||
enabled: bool = Field(default=True, title="Enable camera.")
|
||||
ffmpeg: CameraFfmpegConfig = Field(title="FFmpeg configuration for the camera.")
|
||||
best_image_timeout: int = Field(
|
||||
@@ -775,6 +816,7 @@ class CameraConfig(FrigateBaseModel):
|
||||
title="How long to wait for the image with the highest confidence score.",
|
||||
)
|
||||
webui_url: Optional[str] = Field(
|
||||
None,
|
||||
title="URL to visit the camera directly from system page",
|
||||
)
|
||||
zones: Dict[str, ZoneConfig] = Field(
|
||||
@@ -798,7 +840,9 @@ class CameraConfig(FrigateBaseModel):
|
||||
audio: AudioConfig = Field(
|
||||
default_factory=AudioConfig, title="Audio events configuration."
|
||||
)
|
||||
motion: Optional[MotionConfig] = Field(title="Motion detection configuration.")
|
||||
motion: Optional[MotionConfig] = Field(
|
||||
None, title="Motion detection configuration."
|
||||
)
|
||||
detect: DetectConfig = Field(
|
||||
default_factory=DetectConfig, title="Object detection configuration."
|
||||
)
|
||||
@@ -983,7 +1027,7 @@ def verify_valid_live_stream_name(
|
||||
"""Verify that a restream exists to use for live view."""
|
||||
if (
|
||||
camera_config.live.stream_name
|
||||
not in frigate_config.go2rtc.dict().get("streams", {}).keys()
|
||||
not in frigate_config.go2rtc.model_dump().get("streams", {}).keys()
|
||||
):
|
||||
return ValueError(
|
||||
f"No restream with name {camera_config.live.stream_name} exists for camera {camera_config.name}."
|
||||
@@ -1108,7 +1152,7 @@ class FrigateConfig(FrigateBaseModel):
|
||||
default_factory=AudioConfig, title="Global Audio events configuration."
|
||||
)
|
||||
motion: Optional[MotionConfig] = Field(
|
||||
title="Global motion detection configuration."
|
||||
None, title="Global motion detection configuration."
|
||||
)
|
||||
detect: DetectConfig = Field(
|
||||
default_factory=DetectConfig, title="Global object tracking configuration."
|
||||
@@ -1121,7 +1165,7 @@ class FrigateConfig(FrigateBaseModel):
|
||||
|
||||
def runtime_config(self, plus_api: PlusApi = None) -> FrigateConfig:
|
||||
"""Merge camera config with globals."""
|
||||
config = self.copy(deep=True)
|
||||
config = self.model_copy(deep=True)
|
||||
|
||||
# MQTT user/password substitutions
|
||||
if config.mqtt.user or config.mqtt.password:
|
||||
@@ -1140,7 +1184,7 @@ class FrigateConfig(FrigateBaseModel):
|
||||
config.ffmpeg.hwaccel_args = auto_detect_hwaccel()
|
||||
|
||||
# Global config to propagate down to camera level
|
||||
global_config = config.dict(
|
||||
global_config = config.model_dump(
|
||||
include={
|
||||
"audio": ...,
|
||||
"birdseye": ...,
|
||||
@@ -1157,8 +1201,10 @@ class FrigateConfig(FrigateBaseModel):
|
||||
)
|
||||
|
||||
for name, camera in config.cameras.items():
|
||||
merged_config = deep_merge(camera.dict(exclude_unset=True), global_config)
|
||||
camera_config: CameraConfig = CameraConfig.parse_obj(
|
||||
merged_config = deep_merge(
|
||||
camera.model_dump(exclude_unset=True), global_config
|
||||
)
|
||||
camera_config: CameraConfig = CameraConfig.model_validate(
|
||||
{"name": name, **merged_config}
|
||||
)
|
||||
|
||||
@@ -1203,7 +1249,7 @@ class FrigateConfig(FrigateBaseModel):
|
||||
)
|
||||
|
||||
# Default min_initialized configuration
|
||||
min_initialized = camera_config.detect.fps / 2
|
||||
min_initialized = int(camera_config.detect.fps / 2)
|
||||
if camera_config.detect.min_initialized is None:
|
||||
camera_config.detect.min_initialized = min_initialized
|
||||
|
||||
@@ -1267,7 +1313,7 @@ class FrigateConfig(FrigateBaseModel):
|
||||
# Set runtime filter to create masks
|
||||
camera_config.objects.filters[object] = RuntimeFilterConfig(
|
||||
frame_shape=camera_config.frame_shape,
|
||||
**filter.dict(exclude_unset=True),
|
||||
**filter.model_dump(exclude_unset=True),
|
||||
)
|
||||
|
||||
# Convert motion configuration
|
||||
@@ -1279,7 +1325,7 @@ class FrigateConfig(FrigateBaseModel):
|
||||
camera_config.motion = RuntimeMotionConfig(
|
||||
frame_shape=camera_config.frame_shape,
|
||||
raw_mask=camera_config.motion.mask,
|
||||
**camera_config.motion.dict(exclude_unset=True),
|
||||
**camera_config.motion.model_dump(exclude_unset=True),
|
||||
)
|
||||
camera_config.motion.enabled_in_config = camera_config.motion.enabled
|
||||
|
||||
@@ -1309,12 +1355,16 @@ class FrigateConfig(FrigateBaseModel):
|
||||
config.model.check_and_load_plus_model(plus_api)
|
||||
|
||||
for key, detector in config.detectors.items():
|
||||
detector_config: DetectorConfig = parse_obj_as(DetectorConfig, detector)
|
||||
adapter = TypeAdapter(DetectorConfig)
|
||||
model_dict = (
|
||||
detector if isinstance(detector, dict) else detector.model_dump()
|
||||
)
|
||||
detector_config: DetectorConfig = adapter.validate_python(model_dict)
|
||||
if detector_config.model is None:
|
||||
detector_config.model = config.model
|
||||
else:
|
||||
model = detector_config.model
|
||||
schema = ModelConfig.schema()["properties"]
|
||||
schema = ModelConfig.model_json_schema()["properties"]
|
||||
if (
|
||||
model.width != schema["width"]["default"]
|
||||
or model.height != schema["height"]["default"]
|
||||
@@ -1328,8 +1378,8 @@ class FrigateConfig(FrigateBaseModel):
|
||||
"Customizing more than a detector model path is unsupported."
|
||||
)
|
||||
merged_model = deep_merge(
|
||||
detector_config.model.dict(exclude_unset=True),
|
||||
config.model.dict(exclude_unset=True),
|
||||
detector_config.model.model_dump(exclude_unset=True),
|
||||
config.model.model_dump(exclude_unset=True),
|
||||
)
|
||||
|
||||
if "path" not in merged_model:
|
||||
@@ -1338,7 +1388,7 @@ class FrigateConfig(FrigateBaseModel):
|
||||
elif detector_config.type == "edgetpu":
|
||||
merged_model["path"] = "/edgetpu_model.tflite"
|
||||
|
||||
detector_config.model = ModelConfig.parse_obj(merged_model)
|
||||
detector_config.model = ModelConfig.model_validate(merged_model)
|
||||
detector_config.model.check_and_load_plus_model(
|
||||
plus_api, detector_config.type
|
||||
)
|
||||
@@ -1347,7 +1397,8 @@ class FrigateConfig(FrigateBaseModel):
|
||||
|
||||
return config
|
||||
|
||||
@validator("cameras")
|
||||
@field_validator("cameras")
|
||||
@classmethod
|
||||
def ensure_zones_and_cameras_have_different_names(cls, v: Dict[str, CameraConfig]):
|
||||
zones = [zone for camera in v.values() for zone in camera.zones.keys()]
|
||||
for zone in zones:
|
||||
@@ -1365,9 +1416,9 @@ class FrigateConfig(FrigateBaseModel):
|
||||
elif config_file.endswith(".json"):
|
||||
config = json.loads(raw_config)
|
||||
|
||||
return cls.parse_obj(config)
|
||||
return cls.model_validate(config)
|
||||
|
||||
@classmethod
|
||||
def parse_raw(cls, raw_config):
|
||||
config = load_config_with_no_duplicates(raw_config)
|
||||
return cls.parse_obj(config)
|
||||
return cls.model_validate(config)
|
||||
|
||||
Reference in New Issue
Block a user