Migrate pydantic to V2 (#10142)

* Run pydantic migration tool

* Finish removing deprecated functions

* Formatting

* Fix movement weights type

* Fix movement weight test

* Fix config checks

* formatting

* fix typing

* formatting

* Fix

* Fix serialization issues

* Formatting

* fix model namespace warnings

* Update formatting

* Format go2rtc file

* Cleanup migrations

* Fix warnings

* Don't include null values in config json

* Formatting

* Fix test

---------

Co-authored-by: Josh Hawkins <32435876+hawkeye217@users.noreply.github.com>
This commit is contained in:
Nicolas Mowen
2024-02-29 16:10:13 -07:00
committed by GitHub
parent a1424bad6c
commit cb30450060
17 changed files with 209 additions and 149 deletions

View File

@@ -6,11 +6,19 @@ import logging
import os
from enum import Enum
from pathlib import Path
from typing import Dict, List, Optional, Tuple, Union
from typing import Any, Dict, List, Optional, Tuple, Union
import matplotlib.pyplot as plt
import numpy as np
from pydantic import BaseModel, Extra, Field, parse_obj_as, validator
from pydantic import (
BaseModel,
ConfigDict,
Field,
TypeAdapter,
ValidationInfo,
field_serializer,
field_validator,
)
from pydantic.fields import PrivateAttr
from frigate.const import (
@@ -66,8 +74,7 @@ DEFAULT_TIME_LAPSE_FFMPEG_ARGS = "-vf setpts=0.04*PTS -r 30"
class FrigateBaseModel(BaseModel):
class Config:
extra = Extra.forbid
model_config = ConfigDict(extra="forbid", protected_namespaces=())
class LiveModeEnum(str, Enum):
@@ -93,7 +100,7 @@ class UIConfig(FrigateBaseModel):
live_mode: LiveModeEnum = Field(
default=LiveModeEnum.mse, title="Default Live Mode."
)
timezone: Optional[str] = Field(title="Override UI timezone.")
timezone: Optional[str] = Field(default=None, title="Override UI timezone.")
use_experimental: bool = Field(default=False, title="Experimental UI")
time_format: TimeFormatEnum = Field(
default=TimeFormatEnum.browser, title="Override UI time format."
@@ -135,16 +142,17 @@ class MqttConfig(FrigateBaseModel):
topic_prefix: str = Field(default="frigate", title="MQTT Topic Prefix")
client_id: str = Field(default="frigate", title="MQTT Client ID")
stats_interval: int = Field(default=60, title="MQTT Camera Stats Interval")
user: Optional[str] = Field(title="MQTT Username")
password: Optional[str] = Field(title="MQTT Password")
tls_ca_certs: Optional[str] = Field(title="MQTT TLS CA Certificates")
tls_client_cert: Optional[str] = Field(title="MQTT TLS Client Certificate")
tls_client_key: Optional[str] = Field(title="MQTT TLS Client Key")
tls_insecure: Optional[bool] = Field(title="MQTT TLS Insecure")
user: Optional[str] = Field(None, title="MQTT Username")
password: Optional[str] = Field(None, title="MQTT Password", validate_default=True)
tls_ca_certs: Optional[str] = Field(None, title="MQTT TLS CA Certificates")
tls_client_cert: Optional[str] = Field(None, title="MQTT TLS Client Certificate")
tls_client_key: Optional[str] = Field(None, title="MQTT TLS Client Key")
tls_insecure: Optional[bool] = Field(None, title="MQTT TLS Insecure")
@validator("password", pre=True, always=True)
def validate_password(cls, v, values):
if (v is None) != (values["user"] is None):
@field_validator("password")
def user_requires_pass(cls, v, info: ValidationInfo):
print(f"doing a check where {v} is None and {info.data['user']} is None")
if (v is None) != (info.data["user"] is None):
raise ValueError("Password must be provided with username.")
return v
@@ -186,18 +194,19 @@ class PtzAutotrackConfig(FrigateBaseModel):
title="Internal value used for PTZ movements based on the speed of your camera's motor.",
)
enabled_in_config: Optional[bool] = Field(
title="Keep track of original state of autotracking."
None, title="Keep track of original state of autotracking."
)
@validator("movement_weights", pre=True)
@field_validator("movement_weights", mode="before")
@classmethod
def validate_weights(cls, v):
if v is None:
return None
if isinstance(v, str):
weights = list(map(float, v.split(",")))
weights = list(map(str, map(float, v.split(","))))
elif isinstance(v, list):
weights = [float(val) for val in v]
weights = [str(float(val)) for val in v]
else:
raise ValueError("Invalid type for movement_weights")
@@ -210,8 +219,8 @@ class PtzAutotrackConfig(FrigateBaseModel):
class OnvifConfig(FrigateBaseModel):
host: str = Field(default="", title="Onvif Host")
port: int = Field(default=8000, title="Onvif Port")
user: Optional[str] = Field(title="Onvif Username")
password: Optional[str] = Field(title="Onvif Password")
user: Optional[str] = Field(None, title="Onvif Username")
password: Optional[str] = Field(None, title="Onvif Password")
autotracking: PtzAutotrackConfig = Field(
default_factory=PtzAutotrackConfig,
title="PTZ auto tracking config.",
@@ -242,6 +251,7 @@ class EventsConfig(FrigateBaseModel):
title="List of required zones to be entered in order to save the event.",
)
objects: Optional[List[str]] = Field(
None,
title="List of objects to be detected in order to save the event.",
)
retain: RetainConfig = Field(
@@ -296,7 +306,7 @@ class RecordConfig(FrigateBaseModel):
default_factory=RecordPreviewConfig, title="Recording Preview Config"
)
enabled_in_config: Optional[bool] = Field(
title="Keep track of original state of recording."
None, title="Keep track of original state of recording."
)
@@ -324,8 +334,17 @@ class MotionConfig(FrigateBaseModel):
title="Delay for updating MQTT with no motion detected.",
)
enabled_in_config: Optional[bool] = Field(
title="Keep track of original state of motion detection."
None, title="Keep track of original state of motion detection."
)
raw_mask: Union[str, List[str]] = ""
@field_serializer("mask", when_used="json")
def serialize_mask(self, value: Any, info):
return self.raw_mask
@field_serializer("raw_mask", when_used="json")
def serialize_raw_mask(self, value: Any, info):
return None
class RuntimeMotionConfig(MotionConfig):
@@ -348,19 +367,25 @@ class RuntimeMotionConfig(MotionConfig):
super().__init__(**config)
def dict(self, **kwargs):
ret = super().dict(**kwargs)
ret = super().model_dump(**kwargs)
if "mask" in ret:
ret["mask"] = ret["raw_mask"]
ret.pop("raw_mask")
return ret
class Config:
arbitrary_types_allowed = True
extra = Extra.ignore
@field_serializer("mask", when_used="json")
def serialize_mask(self, value: Any, info):
return self.raw_mask
@field_serializer("raw_mask", when_used="json")
def serialize_raw_mask(self, value: Any, info):
return None
model_config = ConfigDict(arbitrary_types_allowed=True, extra="ignore")
class StationaryMaxFramesConfig(FrigateBaseModel):
default: Optional[int] = Field(title="Default max frames.", ge=1)
default: Optional[int] = Field(None, title="Default max frames.", ge=1)
objects: Dict[str, int] = Field(
default_factory=dict, title="Object specific max frames."
)
@@ -368,10 +393,12 @@ class StationaryMaxFramesConfig(FrigateBaseModel):
class StationaryConfig(FrigateBaseModel):
interval: Optional[int] = Field(
None,
title="Frame interval for checking stationary objects.",
gt=0,
)
threshold: Optional[int] = Field(
None,
title="Number of frames without a position change for an object to be considered stationary",
ge=1,
)
@@ -382,17 +409,21 @@ class StationaryConfig(FrigateBaseModel):
class DetectConfig(FrigateBaseModel):
height: Optional[int] = Field(title="Height of the stream for the detect role.")
width: Optional[int] = Field(title="Width of the stream for the detect role.")
height: Optional[int] = Field(
None, title="Height of the stream for the detect role."
)
width: Optional[int] = Field(None, title="Width of the stream for the detect role.")
fps: int = Field(
default=5, title="Number of frames per second to process through detection."
)
enabled: bool = Field(default=True, title="Detection Enabled.")
min_initialized: Optional[int] = Field(
title="Minimum number of consecutive hits for an object to be initialized by the tracker."
None,
title="Minimum number of consecutive hits for an object to be initialized by the tracker.",
)
max_disappeared: Optional[int] = Field(
title="Maximum number of frames the object can dissapear before detection ends."
None,
title="Maximum number of frames the object can dissapear before detection ends.",
)
stationary: StationaryConfig = Field(
default_factory=StationaryConfig,
@@ -426,8 +457,18 @@ class FilterConfig(FrigateBaseModel):
default=0.5, title="Minimum detection confidence for object to be counted."
)
mask: Optional[Union[str, List[str]]] = Field(
None,
title="Detection area polygon mask for this filter configuration.",
)
raw_mask: Union[str, List[str]] = ""
@field_serializer("mask", when_used="json")
def serialize_mask(self, value: Any, info):
return self.raw_mask
@field_serializer("raw_mask", when_used="json")
def serialize_raw_mask(self, value: Any, info):
return None
class AudioFilterConfig(FrigateBaseModel):
@@ -440,8 +481,8 @@ class AudioFilterConfig(FrigateBaseModel):
class RuntimeFilterConfig(FilterConfig):
mask: Optional[np.ndarray]
raw_mask: Optional[Union[str, List[str]]]
mask: Optional[np.ndarray] = None
raw_mask: Optional[Union[str, List[str]]] = None
def __init__(self, **config):
mask = config.get("mask")
@@ -453,15 +494,13 @@ class RuntimeFilterConfig(FilterConfig):
super().__init__(**config)
def dict(self, **kwargs):
ret = super().dict(**kwargs)
ret = super().model_dump(**kwargs)
if "mask" in ret:
ret["mask"] = ret["raw_mask"]
ret.pop("raw_mask")
return ret
class Config:
arbitrary_types_allowed = True
extra = Extra.ignore
model_config = ConfigDict(arbitrary_types_allowed=True, extra="ignore")
# this uses the base model because the color is an extra attribute
@@ -531,9 +570,11 @@ class AudioConfig(FrigateBaseModel):
listen: List[str] = Field(
default=DEFAULT_LISTEN_AUDIO, title="Audio to listen for."
)
filters: Optional[Dict[str, AudioFilterConfig]] = Field(title="Audio filters.")
filters: Optional[Dict[str, AudioFilterConfig]] = Field(
None, title="Audio filters."
)
enabled_in_config: Optional[bool] = Field(
title="Keep track of original state of audio detection."
None, title="Keep track of original state of audio detection."
)
num_threads: int = Field(default=2, title="Number of detection threads", ge=1)
@@ -660,7 +701,8 @@ class CameraInput(FrigateBaseModel):
class CameraFfmpegConfig(FfmpegConfig):
inputs: List[CameraInput] = Field(title="Camera inputs.")
@validator("inputs")
@field_validator("inputs")
@classmethod
def validate_roles(cls, v):
roles = [role for i in v for role in i.roles]
roles_set = set(roles)
@@ -690,7 +732,7 @@ class SnapshotsConfig(FrigateBaseModel):
default_factory=list,
title="List of required zones to be entered in order to save a snapshot.",
)
height: Optional[int] = Field(title="Snapshot image height.")
height: Optional[int] = Field(None, title="Snapshot image height.")
retain: RetainConfig = Field(
default_factory=RetainConfig, title="Snapshot retention."
)
@@ -727,7 +769,7 @@ class TimestampStyleConfig(FrigateBaseModel):
format: str = Field(default=DEFAULT_TIME_FORMAT, title="Timestamp format.")
color: ColorConfig = Field(default_factory=ColorConfig, title="Timestamp color.")
thickness: int = Field(default=2, title="Timestamp thickness.")
effect: Optional[TimestampEffectEnum] = Field(title="Timestamp effect.")
effect: Optional[TimestampEffectEnum] = Field(None, title="Timestamp effect.")
class CameraMqttConfig(FrigateBaseModel):
@@ -755,8 +797,7 @@ class CameraLiveConfig(FrigateBaseModel):
class RestreamConfig(BaseModel):
class Config:
extra = Extra.allow
model_config = ConfigDict(extra="allow")
class CameraUiConfig(FrigateBaseModel):
@@ -767,7 +808,7 @@ class CameraUiConfig(FrigateBaseModel):
class CameraConfig(FrigateBaseModel):
name: Optional[str] = Field(title="Camera name.", regex=REGEX_CAMERA_NAME)
name: Optional[str] = Field(None, title="Camera name.", pattern=REGEX_CAMERA_NAME)
enabled: bool = Field(default=True, title="Enable camera.")
ffmpeg: CameraFfmpegConfig = Field(title="FFmpeg configuration for the camera.")
best_image_timeout: int = Field(
@@ -775,6 +816,7 @@ class CameraConfig(FrigateBaseModel):
title="How long to wait for the image with the highest confidence score.",
)
webui_url: Optional[str] = Field(
None,
title="URL to visit the camera directly from system page",
)
zones: Dict[str, ZoneConfig] = Field(
@@ -798,7 +840,9 @@ class CameraConfig(FrigateBaseModel):
audio: AudioConfig = Field(
default_factory=AudioConfig, title="Audio events configuration."
)
motion: Optional[MotionConfig] = Field(title="Motion detection configuration.")
motion: Optional[MotionConfig] = Field(
None, title="Motion detection configuration."
)
detect: DetectConfig = Field(
default_factory=DetectConfig, title="Object detection configuration."
)
@@ -983,7 +1027,7 @@ def verify_valid_live_stream_name(
"""Verify that a restream exists to use for live view."""
if (
camera_config.live.stream_name
not in frigate_config.go2rtc.dict().get("streams", {}).keys()
not in frigate_config.go2rtc.model_dump().get("streams", {}).keys()
):
return ValueError(
f"No restream with name {camera_config.live.stream_name} exists for camera {camera_config.name}."
@@ -1108,7 +1152,7 @@ class FrigateConfig(FrigateBaseModel):
default_factory=AudioConfig, title="Global Audio events configuration."
)
motion: Optional[MotionConfig] = Field(
title="Global motion detection configuration."
None, title="Global motion detection configuration."
)
detect: DetectConfig = Field(
default_factory=DetectConfig, title="Global object tracking configuration."
@@ -1121,7 +1165,7 @@ class FrigateConfig(FrigateBaseModel):
def runtime_config(self, plus_api: PlusApi = None) -> FrigateConfig:
"""Merge camera config with globals."""
config = self.copy(deep=True)
config = self.model_copy(deep=True)
# MQTT user/password substitutions
if config.mqtt.user or config.mqtt.password:
@@ -1140,7 +1184,7 @@ class FrigateConfig(FrigateBaseModel):
config.ffmpeg.hwaccel_args = auto_detect_hwaccel()
# Global config to propagate down to camera level
global_config = config.dict(
global_config = config.model_dump(
include={
"audio": ...,
"birdseye": ...,
@@ -1157,8 +1201,10 @@ class FrigateConfig(FrigateBaseModel):
)
for name, camera in config.cameras.items():
merged_config = deep_merge(camera.dict(exclude_unset=True), global_config)
camera_config: CameraConfig = CameraConfig.parse_obj(
merged_config = deep_merge(
camera.model_dump(exclude_unset=True), global_config
)
camera_config: CameraConfig = CameraConfig.model_validate(
{"name": name, **merged_config}
)
@@ -1203,7 +1249,7 @@ class FrigateConfig(FrigateBaseModel):
)
# Default min_initialized configuration
min_initialized = camera_config.detect.fps / 2
min_initialized = int(camera_config.detect.fps / 2)
if camera_config.detect.min_initialized is None:
camera_config.detect.min_initialized = min_initialized
@@ -1267,7 +1313,7 @@ class FrigateConfig(FrigateBaseModel):
# Set runtime filter to create masks
camera_config.objects.filters[object] = RuntimeFilterConfig(
frame_shape=camera_config.frame_shape,
**filter.dict(exclude_unset=True),
**filter.model_dump(exclude_unset=True),
)
# Convert motion configuration
@@ -1279,7 +1325,7 @@ class FrigateConfig(FrigateBaseModel):
camera_config.motion = RuntimeMotionConfig(
frame_shape=camera_config.frame_shape,
raw_mask=camera_config.motion.mask,
**camera_config.motion.dict(exclude_unset=True),
**camera_config.motion.model_dump(exclude_unset=True),
)
camera_config.motion.enabled_in_config = camera_config.motion.enabled
@@ -1309,12 +1355,16 @@ class FrigateConfig(FrigateBaseModel):
config.model.check_and_load_plus_model(plus_api)
for key, detector in config.detectors.items():
detector_config: DetectorConfig = parse_obj_as(DetectorConfig, detector)
adapter = TypeAdapter(DetectorConfig)
model_dict = (
detector if isinstance(detector, dict) else detector.model_dump()
)
detector_config: DetectorConfig = adapter.validate_python(model_dict)
if detector_config.model is None:
detector_config.model = config.model
else:
model = detector_config.model
schema = ModelConfig.schema()["properties"]
schema = ModelConfig.model_json_schema()["properties"]
if (
model.width != schema["width"]["default"]
or model.height != schema["height"]["default"]
@@ -1328,8 +1378,8 @@ class FrigateConfig(FrigateBaseModel):
"Customizing more than a detector model path is unsupported."
)
merged_model = deep_merge(
detector_config.model.dict(exclude_unset=True),
config.model.dict(exclude_unset=True),
detector_config.model.model_dump(exclude_unset=True),
config.model.model_dump(exclude_unset=True),
)
if "path" not in merged_model:
@@ -1338,7 +1388,7 @@ class FrigateConfig(FrigateBaseModel):
elif detector_config.type == "edgetpu":
merged_model["path"] = "/edgetpu_model.tflite"
detector_config.model = ModelConfig.parse_obj(merged_model)
detector_config.model = ModelConfig.model_validate(merged_model)
detector_config.model.check_and_load_plus_model(
plus_api, detector_config.type
)
@@ -1347,7 +1397,8 @@ class FrigateConfig(FrigateBaseModel):
return config
@validator("cameras")
@field_validator("cameras")
@classmethod
def ensure_zones_and_cameras_have_different_names(cls, v: Dict[str, CameraConfig]):
zones = [zone for camera in v.values() for zone in camera.zones.keys()]
for zone in zones:
@@ -1365,9 +1416,9 @@ class FrigateConfig(FrigateBaseModel):
elif config_file.endswith(".json"):
config = json.loads(raw_config)
return cls.parse_obj(config)
return cls.model_validate(config)
@classmethod
def parse_raw(cls, raw_config):
config = load_config_with_no_duplicates(raw_config)
return cls.parse_obj(config)
return cls.model_validate(config)