Semantic Search for Detections (#11899)

* Initial re-implementation of semantic search

* put docker-compose back and make reindex match docs

* remove debug code and fix import

* fix docs

* manually build pysqlite3 as binaries are only available for x86-64

* update comment in build_pysqlite3.sh

* only embed objects

* better error handling when genai fails

* ask ollama to pull requested model at startup

* update ollama docs

* address some PR review comments

* fix lint

* use IPC to write description, update docs for reindex

* remove gemini-pro-vision from docs as it will be unavailable soon

* fix OpenAI doc available models

* fix api error in gemini and metadata for embeddings
This commit is contained in:
Jason Hunter
2024-06-21 17:30:19 -04:00
committed by Nicolas Mowen
parent f4f3cfa911
commit 36cbffcc5e
48 changed files with 1244 additions and 166 deletions

View File

@@ -1,9 +1,12 @@
import faulthandler
import sys
import threading
from flask import cli
from frigate.app import FrigateApp
# Hotsawp the sqlite3 module for Chroma compatibility
__import__("pysqlite3")
sys.modules["sqlite3"] = sys.modules.pop("pysqlite3")
faulthandler.enable()
@@ -12,6 +15,8 @@ threading.current_thread().name = "frigate"
cli.show_server_banner = lambda *x: None
if __name__ == "__main__":
from frigate.app import FrigateApp
frigate_app = FrigateApp()
frigate_app.start()

View File

@@ -454,6 +454,7 @@ def logs(service: str):
"frigate": "/dev/shm/logs/frigate/current",
"go2rtc": "/dev/shm/logs/go2rtc/current",
"nginx": "/dev/shm/logs/nginx/current",
"chroma": "/dev/shm/logs/chroma/current",
}
service_location = log_locations.get(service)

View File

@@ -22,11 +22,11 @@ from pydantic import ValidationError
from frigate.api.app import create_app
from frigate.api.auth import hash_password
from frigate.comms.config_updater import ConfigPublisher
from frigate.comms.detections_updater import DetectionProxy
from frigate.comms.dispatcher import Communicator, Dispatcher
from frigate.comms.inter_process import InterProcessCommunicator
from frigate.comms.mqtt import MqttClient
from frigate.comms.ws import WebSocketClient
from frigate.comms.zmq_proxy import ZmqProxy
from frigate.config import FrigateConfig
from frigate.const import (
CACHE_DIR,
@@ -37,6 +37,8 @@ from frigate.const import (
MODEL_CACHE_DIR,
RECORD_DIR,
)
from frigate.embeddings import manage_embeddings
from frigate.embeddings.embeddings import Embeddings
from frigate.events.audio import listen_to_audio
from frigate.events.cleanup import EventCleanup
from frigate.events.external import ExternalEventProcessor
@@ -316,7 +318,21 @@ class FrigateApp:
self.review_segment_process = review_segment_process
review_segment_process.start()
self.processes["review_segment"] = review_segment_process.pid or 0
logger.info(f"Recording process started: {review_segment_process.pid}")
logger.info(f"Review process started: {review_segment_process.pid}")
def init_embeddings_manager(self) -> None:
# Create a client for other processes to use
self.embeddings = Embeddings()
embedding_process = mp.Process(
target=manage_embeddings,
name="embeddings_manager",
args=(self.config,),
)
embedding_process.daemon = True
self.embedding_process = embedding_process
embedding_process.start()
self.processes["embeddings"] = embedding_process.pid or 0
logger.info(f"Embedding process started: {embedding_process.pid}")
def bind_database(self) -> None:
"""Bind db to the main process."""
@@ -362,7 +378,7 @@ class FrigateApp:
def init_inter_process_communicator(self) -> None:
self.inter_process_communicator = InterProcessCommunicator()
self.inter_config_updater = ConfigPublisher()
self.inter_detection_proxy = DetectionProxy()
self.inter_zmq_proxy = ZmqProxy()
def init_web_server(self) -> None:
self.flask_app = create_app(
@@ -678,6 +694,7 @@ class FrigateApp:
self.init_onvif()
self.init_recording_manager()
self.init_review_segment_manager()
self.init_embeddings_manager()
self.init_go2rtc()
self.bind_database()
self.check_db_data_migrations()
@@ -797,7 +814,7 @@ class FrigateApp:
# Stop Communicators
self.inter_process_communicator.stop()
self.inter_config_updater.stop()
self.inter_detection_proxy.stop()
self.inter_zmq_proxy.stop()
while len(self.detection_shms) > 0:
shm = self.detection_shms.pop()

View File

@@ -1,14 +1,9 @@
"""Facilitates communication between processes."""
import threading
from enum import Enum
from typing import Optional
import zmq
SOCKET_CONTROL = "inproc://control.detections_updater"
SOCKET_PUB = "ipc:///tmp/cache/detect_pub"
SOCKET_SUB = "ipc:///tmp/cache/detect_sub"
from .zmq_proxy import Publisher, Subscriber
class DetectionTypeEnum(str, Enum):
@@ -18,85 +13,31 @@ class DetectionTypeEnum(str, Enum):
audio = "audio"
class DetectionProxyRunner(threading.Thread):
def __init__(self, context: zmq.Context[zmq.Socket]) -> None:
threading.Thread.__init__(self)
self.name = "detection_proxy"
self.context = context
def run(self) -> None:
"""Run the proxy."""
control = self.context.socket(zmq.REP)
control.connect(SOCKET_CONTROL)
incoming = self.context.socket(zmq.XSUB)
incoming.bind(SOCKET_PUB)
outgoing = self.context.socket(zmq.XPUB)
outgoing.bind(SOCKET_SUB)
zmq.proxy_steerable(
incoming, outgoing, None, control
) # blocking, will unblock terminate message is received
incoming.close()
outgoing.close()
class DetectionProxy:
"""Proxies video and audio detections."""
def __init__(self) -> None:
self.context = zmq.Context()
self.control = self.context.socket(zmq.REQ)
self.control.bind(SOCKET_CONTROL)
self.runner = DetectionProxyRunner(self.context)
self.runner.start()
def stop(self) -> None:
self.control.send("TERMINATE".encode()) # tell the proxy to stop
self.runner.join()
self.context.destroy()
class DetectionPublisher:
class DetectionPublisher(Publisher):
"""Simplifies receiving video and audio detections."""
topic_base = "detection/"
def __init__(self, topic: DetectionTypeEnum) -> None:
self.topic = topic
self.context = zmq.Context()
self.socket = self.context.socket(zmq.PUB)
self.socket.connect(SOCKET_PUB)
def send_data(self, payload: any) -> None:
"""Publish detection."""
self.socket.send_string(self.topic.value, flags=zmq.SNDMORE)
self.socket.send_json(payload)
def stop(self) -> None:
self.socket.close()
self.context.destroy()
topic = topic.value
super().__init__(topic)
class DetectionSubscriber:
class DetectionSubscriber(Subscriber):
"""Simplifies receiving video and audio detections."""
topic_base = "detection/"
def __init__(self, topic: DetectionTypeEnum) -> None:
self.context = zmq.Context()
self.socket = self.context.socket(zmq.SUB)
self.socket.setsockopt_string(zmq.SUBSCRIBE, topic.value)
self.socket.connect(SOCKET_SUB)
topic = topic.value
super().__init__(topic)
def get_data(self, timeout: float = None) -> Optional[tuple[str, any]]:
"""Returns detections or None if no update."""
try:
has_update, _, _ = zmq.select([self.socket], [], [], timeout)
def check_for_update(
self, timeout: float = None
) -> Optional[tuple[DetectionTypeEnum, any]]:
return super().check_for_update(timeout)
if has_update:
topic = DetectionTypeEnum[self.socket.recv_string(flags=zmq.NOBLOCK)]
return (topic, self.socket.recv_json())
except zmq.ZMQError:
pass
return (None, None)
def stop(self) -> None:
self.socket.close()
self.context.destroy()
def _return_object(self, topic: str, payload: any) -> any:
if payload is None:
return (None, None)
return (DetectionTypeEnum[topic[len(self.topic_base) :]], payload)

View File

@@ -14,9 +14,10 @@ from frigate.const import (
INSERT_PREVIEW,
REQUEST_REGION_GRID,
UPDATE_CAMERA_ACTIVITY,
UPDATE_EVENT_DESCRIPTION,
UPSERT_REVIEW_SEGMENT,
)
from frigate.models import Previews, Recordings, ReviewSegment
from frigate.models import Event, Previews, Recordings, ReviewSegment
from frigate.ptz.onvif import OnvifCommandEnum, OnvifController
from frigate.types import PTZMetricsTypes
from frigate.util.object import get_camera_regions_grid
@@ -128,6 +129,10 @@ class Dispatcher:
).execute()
elif topic == UPDATE_CAMERA_ACTIVITY:
self.camera_activity = payload
elif topic == UPDATE_EVENT_DESCRIPTION:
event: Event = Event.get(Event.id == payload["id"])
event.data["description"] = payload["description"]
event.save()
elif topic == "onConnect":
camera_status = self.camera_activity.copy()

View File

@@ -1,100 +1,51 @@
"""Facilitates communication between processes."""
import zmq
from frigate.events.types import EventStateEnum, EventTypeEnum
SOCKET_PUSH_PULL = "ipc:///tmp/cache/events"
SOCKET_PUSH_PULL_END = "ipc:///tmp/cache/events_ended"
from .zmq_proxy import Publisher, Subscriber
class EventUpdatePublisher:
class EventUpdatePublisher(Publisher):
"""Publishes events (objects, audio, manual)."""
topic_base = "event/"
def __init__(self) -> None:
self.context = zmq.Context()
self.socket = self.context.socket(zmq.PUSH)
self.socket.connect(SOCKET_PUSH_PULL)
super().__init__("update")
def publish(
self, payload: tuple[EventTypeEnum, EventStateEnum, str, dict[str, any]]
) -> None:
"""There is no communication back to the processes."""
self.socket.send_json(payload)
def stop(self) -> None:
self.socket.close()
self.context.destroy()
super().publish(payload)
class EventUpdateSubscriber:
class EventUpdateSubscriber(Subscriber):
"""Receives event updates."""
topic_base = "event/"
def __init__(self) -> None:
self.context = zmq.Context()
self.socket = self.context.socket(zmq.PULL)
self.socket.bind(SOCKET_PUSH_PULL)
def check_for_update(
self, timeout=1
) -> tuple[EventTypeEnum, EventStateEnum, str, dict[str, any]]:
"""Returns events or None if no update."""
try:
has_update, _, _ = zmq.select([self.socket], [], [], timeout)
if has_update:
return self.socket.recv_json()
except zmq.ZMQError:
pass
return None
def stop(self) -> None:
self.socket.close()
self.context.destroy()
super().__init__("update")
class EventEndPublisher:
class EventEndPublisher(Publisher):
"""Publishes events that have ended."""
topic_base = "event/"
def __init__(self) -> None:
self.context = zmq.Context()
self.socket = self.context.socket(zmq.PUSH)
self.socket.connect(SOCKET_PUSH_PULL_END)
super().__init__("finalized")
def publish(
self, payload: tuple[EventTypeEnum, EventStateEnum, str, dict[str, any]]
) -> None:
"""There is no communication back to the processes."""
self.socket.send_json(payload)
def stop(self) -> None:
self.socket.close()
self.context.destroy()
super().publish(payload)
class EventEndSubscriber:
class EventEndSubscriber(Subscriber):
"""Receives events that have ended."""
topic_base = "event/"
def __init__(self) -> None:
self.context = zmq.Context()
self.socket = self.context.socket(zmq.PULL)
self.socket.bind(SOCKET_PUSH_PULL_END)
def check_for_update(
self, timeout=1
) -> tuple[EventTypeEnum, EventStateEnum, str, dict[str, any]]:
"""Returns events ended or None if no update."""
try:
has_update, _, _ = zmq.select([self.socket], [], [], timeout)
if has_update:
return self.socket.recv_json()
except zmq.ZMQError:
pass
return None
def stop(self) -> None:
self.socket.close()
self.context.destroy()
super().__init__("finalized")

100
frigate/comms/zmq_proxy.py Normal file
View File

@@ -0,0 +1,100 @@
"""Facilitates communication over zmq proxy."""
import threading
from typing import Optional
import zmq
SOCKET_PUB = "ipc:///tmp/cache/proxy_pub"
SOCKET_SUB = "ipc:///tmp/cache/proxy_sub"
class ZmqProxyRunner(threading.Thread):
def __init__(self, context: zmq.Context[zmq.Socket]) -> None:
threading.Thread.__init__(self)
self.name = "detection_proxy"
self.context = context
def run(self) -> None:
"""Run the proxy."""
incoming = self.context.socket(zmq.XSUB)
incoming.bind(SOCKET_PUB)
outgoing = self.context.socket(zmq.XPUB)
outgoing.bind(SOCKET_SUB)
# Blocking: This will unblock (via exception) when we destroy the context
# The incoming and outgoing sockets will be closed automatically
# when the context is destroyed as well.
try:
zmq.proxy(incoming, outgoing)
except zmq.ZMQError:
pass
class ZmqProxy:
"""Proxies video and audio detections."""
def __init__(self) -> None:
self.context = zmq.Context()
self.runner = ZmqProxyRunner(self.context)
self.runner.start()
def stop(self) -> None:
# destroying the context will tell the proxy to stop
self.context.destroy()
self.runner.join()
class Publisher:
"""Publishes messages."""
topic_base: str = ""
def __init__(self, topic: str = "") -> None:
self.topic = f"{self.topic_base}{topic}"
self.context = zmq.Context()
self.socket = self.context.socket(zmq.PUB)
self.socket.connect(SOCKET_PUB)
def publish(self, payload: any, sub_topic: str = "") -> None:
"""Publish message."""
self.socket.send_string(f"{self.topic}{sub_topic}", flags=zmq.SNDMORE)
self.socket.send_json(payload)
def stop(self) -> None:
self.socket.close()
self.context.destroy()
class Subscriber:
"""Receives messages."""
topic_base: str = ""
def __init__(self, topic: str = "") -> None:
self.topic = f"{self.topic_base}{topic}"
self.context = zmq.Context()
self.socket = self.context.socket(zmq.SUB)
self.socket.setsockopt_string(zmq.SUBSCRIBE, self.topic)
self.socket.connect(SOCKET_SUB)
def check_for_update(self, timeout: float = 1) -> Optional[tuple[str, any]]:
"""Returns message or None if no update."""
try:
has_update, _, _ = zmq.select([self.socket], [], [], timeout)
if has_update:
topic = self.socket.recv_string(flags=zmq.NOBLOCK)
payload = self.socket.recv_json()
return self._return_object(topic, payload)
except zmq.ZMQError:
pass
return self._return_object("", None)
def stop(self) -> None:
self.socket.close()
self.context.destroy()
def _return_object(self, topic: str, payload: any) -> any:
return payload

View File

@@ -730,6 +730,38 @@ class ReviewConfig(FrigateBaseModel):
)
class SemanticSearchConfig(FrigateBaseModel):
enabled: bool = Field(default=True, title="Enable semantic search.")
reindex: Optional[bool] = Field(
default=False, title="Reindex all detections on startup."
)
class GenAIProviderEnum(str, Enum):
openai = "openai"
gemini = "gemini"
ollama = "ollama"
class GenAIConfig(FrigateBaseModel):
enabled: bool = Field(default=False, title="Enable GenAI.")
provider: GenAIProviderEnum = Field(
default=GenAIProviderEnum.openai, title="GenAI provider."
)
base_url: Optional[str] = Field(None, title="Provider base url.")
api_key: Optional[str] = Field(None, title="Provider API key.")
model: str = Field(default="gpt-4o", title="GenAI model.")
prompt: str = Field(
default="Describe the {label} in the sequence of images with as much detail as possible. Do not describe the background.",
title="Default caption prompt.",
)
object_prompts: Dict[str, str] = Field(default={}, title="Object specific prompts.")
class GenAICameraConfig(FrigateBaseModel):
enabled: bool = Field(default=False, title="Enable GenAI for camera.")
class AudioConfig(FrigateBaseModel):
enabled: bool = Field(default=False, title="Enable audio events.")
max_not_heard: int = Field(
@@ -1011,6 +1043,9 @@ class CameraConfig(FrigateBaseModel):
review: ReviewConfig = Field(
default_factory=ReviewConfig, title="Review configuration."
)
genai: GenAICameraConfig = Field(
default_factory=GenAICameraConfig, title="Generative AI configuration."
)
audio: AudioConfig = Field(
default_factory=AudioConfig, title="Audio events configuration."
)
@@ -1363,6 +1398,12 @@ class FrigateConfig(FrigateBaseModel):
review: ReviewConfig = Field(
default_factory=ReviewConfig, title="Review configuration."
)
semantic_search: SemanticSearchConfig = Field(
default_factory=SemanticSearchConfig, title="Semantic search configuration."
)
genai: GenAIConfig = Field(
default_factory=GenAIConfig, title="Generative AI configuration."
)
audio: AudioConfig = Field(
default_factory=AudioConfig, title="Global Audio events configuration."
)
@@ -1397,6 +1438,10 @@ class FrigateConfig(FrigateBaseModel):
config.mqtt.user = config.mqtt.user.format(**FRIGATE_ENV_VARS)
config.mqtt.password = config.mqtt.password.format(**FRIGATE_ENV_VARS)
# GenAI substitution
if config.genai.api_key:
config.genai.api_key = config.genai.api_key.format(**FRIGATE_ENV_VARS)
# set default min_score for object attributes
for attribute in ALL_ATTRIBUTE_LABELS:
if not config.objects.filters.get(attribute):
@@ -1418,6 +1463,7 @@ class FrigateConfig(FrigateBaseModel):
"live": ...,
"objects": ...,
"review": ...,
"genai": {"enabled"},
"motion": ...,
"detect": ...,
"ffmpeg": ...,

View File

@@ -81,6 +81,7 @@ REQUEST_REGION_GRID = "request_region_grid"
UPSERT_REVIEW_SEGMENT = "upsert_review_segment"
CLEAR_ONGOING_REVIEW_SEGMENTS = "clear_ongoing_review_segments"
UPDATE_CAMERA_ACTIVITY = "update_camera_activity"
UPDATE_EVENT_DESCRIPTION = "update_event_description"
# Stats Values

View File

@@ -0,0 +1,67 @@
"""ChromaDB embeddings database."""
import logging
import multiprocessing as mp
import signal
import sys
import threading
from types import FrameType
from typing import Optional
from playhouse.sqliteq import SqliteQueueDatabase
from setproctitle import setproctitle
from frigate.config import FrigateConfig
from frigate.models import Event
from frigate.util.services import listen
logger = logging.getLogger(__name__)
def manage_embeddings(config: FrigateConfig) -> None:
# Only initialize embeddings if semantic search is enabled
if not config.semantic_search.enabled:
return
stop_event = mp.Event()
def receiveSignal(signalNumber: int, frame: Optional[FrameType]) -> None:
stop_event.set()
signal.signal(signal.SIGTERM, receiveSignal)
signal.signal(signal.SIGINT, receiveSignal)
threading.current_thread().name = "process:embeddings_manager"
setproctitle("frigate.embeddings_manager")
listen()
# Configure Frigate DB
db = SqliteQueueDatabase(
config.database.path,
pragmas={
"auto_vacuum": "FULL", # Does not defragment database
"cache_size": -512 * 1000, # 512MB of cache
"synchronous": "NORMAL", # Safe when using WAL https://www.sqlite.org/pragma.html#pragma_synchronous
},
timeout=max(60, 10 * len([c for c in config.cameras.values() if c.enabled])),
)
models = [Event]
db.bind(models)
# Hotsawp the sqlite3 module for Chroma compatibility
__import__("pysqlite3")
sys.modules["sqlite3"] = sys.modules.pop("pysqlite3")
from .embeddings import Embeddings
from .maintainer import EmbeddingMaintainer
embeddings = Embeddings()
# Check if we need to re-index events
if config.semantic_search.reindex:
embeddings.reindex()
maintainer = EmbeddingMaintainer(
config,
stop_event,
)
maintainer.start()

View File

@@ -0,0 +1,122 @@
"""ChromaDB embeddings database."""
import base64
import io
import logging
import time
import numpy as np
from chromadb import Collection
from chromadb import HttpClient as ChromaClient
from chromadb.config import Settings
from PIL import Image
from playhouse.shortcuts import model_to_dict
from frigate.models import Event
from .functions.clip import ClipEmbedding
from .functions.minilm_l6_v2 import MiniLMEmbedding
logger = logging.getLogger(__name__)
def get_metadata(event: Event) -> dict:
"""Extract valid event metadata."""
event_dict = model_to_dict(event)
return (
{
k: v
for k, v in event_dict.items()
if k not in ["id", "thumbnail"]
and v is not None
and isinstance(v, (str, int, float, bool))
}
| {
k: v
for k, v in event_dict["data"].items()
if k not in ["description"]
and v is not None
and isinstance(v, (str, int, float, bool))
}
| {
# Metadata search doesn't support $contains
# and an event can have multiple zones, so
# we need to create a key for each zone
f"{k}_{x}": True
for k, v in event_dict.items()
if isinstance(v, list) and len(v) > 0
for x in v
if isinstance(x, str)
}
)
class Embeddings:
"""ChromaDB embeddings database."""
def __init__(self) -> None:
self.client: ChromaClient = ChromaClient(
host="127.0.0.1",
settings=Settings(anonymized_telemetry=False),
)
@property
def thumbnail(self) -> Collection:
return self.client.get_or_create_collection(
name="event_thumbnail", embedding_function=ClipEmbedding()
)
@property
def description(self) -> Collection:
return self.client.get_or_create_collection(
name="event_description", embedding_function=MiniLMEmbedding()
)
def reindex(self) -> None:
"""Reindex all event embeddings."""
logger.info("Indexing event embeddings...")
self.client.reset()
st = time.time()
thumbnails = {"ids": [], "images": [], "metadatas": []}
descriptions = {"ids": [], "documents": [], "metadatas": []}
events = Event.select().where(
(Event.has_clip == True | Event.has_snapshot == True)
& Event.thumbnail.is_null(False)
)
event: Event
for event in events.iterator():
metadata = get_metadata(event)
thumbnail = base64.b64decode(event.thumbnail)
img = np.array(Image.open(io.BytesIO(thumbnail)).convert("RGB"))
thumbnails["ids"].append(event.id)
thumbnails["images"].append(img)
thumbnails["metadatas"].append(metadata)
if event.data.get("description") is not None:
descriptions["ids"].append(event.id)
descriptions["documents"].append(event.data["description"])
descriptions["metadatas"].append(metadata)
if len(thumbnails["ids"]) > 0:
self.thumbnail.upsert(
images=thumbnails["images"],
metadatas=thumbnails["metadatas"],
ids=thumbnails["ids"],
)
if len(descriptions["ids"]) > 0:
self.description.upsert(
documents=descriptions["documents"],
metadatas=descriptions["metadatas"],
ids=descriptions["ids"],
)
logger.info(
"Embedded %d thumbnails and %d descriptions in %s seconds",
len(thumbnails["ids"]),
len(descriptions["ids"]),
time.time() - st,
)

View File

@@ -0,0 +1,63 @@
"""CLIP Embeddings for Frigate."""
import os
from typing import Tuple, Union
import onnxruntime as ort
from chromadb import EmbeddingFunction, Embeddings
from chromadb.api.types import (
Documents,
Images,
is_document,
is_image,
)
from onnx_clip import OnnxClip
from frigate.const import MODEL_CACHE_DIR
class Clip(OnnxClip):
"""Override load models to download to cache directory."""
@staticmethod
def _load_models(
model: str,
silent: bool,
) -> Tuple[ort.InferenceSession, ort.InferenceSession]:
"""
These models are a part of the container. Treat as as such.
"""
if model == "ViT-B/32":
IMAGE_MODEL_FILE = "clip_image_model_vitb32.onnx"
TEXT_MODEL_FILE = "clip_text_model_vitb32.onnx"
elif model == "RN50":
IMAGE_MODEL_FILE = "clip_image_model_rn50.onnx"
TEXT_MODEL_FILE = "clip_text_model_rn50.onnx"
else:
raise ValueError(f"Unexpected model {model}. No `.onnx` file found.")
models = []
for model_file in [IMAGE_MODEL_FILE, TEXT_MODEL_FILE]:
path = os.path.join(MODEL_CACHE_DIR, "clip", model_file)
models.append(OnnxClip._load_model(path, silent))
return models[0], models[1]
class ClipEmbedding(EmbeddingFunction):
"""Embedding function for CLIP model used in Chroma."""
def __init__(self, model: str = "ViT-B/32"):
"""Initialize CLIP Embedding function."""
self.model = Clip(model)
def __call__(self, input: Union[Documents, Images]) -> Embeddings:
embeddings: Embeddings = []
for item in input:
if is_image(item):
result = self.model.get_image_embeddings([item])
embeddings.append(result[0, :].tolist())
elif is_document(item):
result = self.model.get_text_embeddings([item])
embeddings.append(result[0, :].tolist())
return embeddings

View File

@@ -0,0 +1,11 @@
"""Embedding function for ONNX MiniLM-L6 model used in Chroma."""
from chromadb.utils.embedding_functions import ONNXMiniLM_L6_V2
from frigate.const import MODEL_CACHE_DIR
class MiniLMEmbedding(ONNXMiniLM_L6_V2):
"""Override DOWNLOAD_PATH to download to cache directory."""
DOWNLOAD_PATH = f"{MODEL_CACHE_DIR}/all-MiniLM-L6-v2"

View File

@@ -0,0 +1,197 @@
"""Maintain embeddings in Chroma."""
import base64
import io
import logging
import threading
from multiprocessing.synchronize import Event as MpEvent
from typing import Optional
import cv2
import numpy as np
from peewee import DoesNotExist
from PIL import Image
from frigate.comms.events_updater import EventEndSubscriber, EventUpdateSubscriber
from frigate.comms.inter_process import InterProcessRequestor
from frigate.config import FrigateConfig
from frigate.const import UPDATE_EVENT_DESCRIPTION
from frigate.events.types import EventTypeEnum
from frigate.genai import get_genai_client
from frigate.models import Event
from frigate.util.image import SharedMemoryFrameManager, calculate_region
from .embeddings import Embeddings, get_metadata
logger = logging.getLogger(__name__)
class EmbeddingMaintainer(threading.Thread):
"""Handle embedding queue and post event updates."""
def __init__(
self,
config: FrigateConfig,
stop_event: MpEvent,
) -> None:
threading.Thread.__init__(self)
self.name = "embeddings_maintainer"
self.config = config
self.embeddings = Embeddings()
self.event_subscriber = EventUpdateSubscriber()
self.event_end_subscriber = EventEndSubscriber()
self.frame_manager = SharedMemoryFrameManager()
# create communication for updating event descriptions
self.requestor = InterProcessRequestor()
self.stop_event = stop_event
self.tracked_events = {}
self.genai_client = get_genai_client(config.genai)
def run(self) -> None:
"""Maintain a Chroma vector database for semantic search."""
while not self.stop_event.is_set():
self._process_updates()
self._process_finalized()
self.event_subscriber.stop()
self.event_end_subscriber.stop()
self.requestor.stop()
logger.info("Exiting embeddings maintenance...")
def _process_updates(self) -> None:
"""Process event updates"""
update = self.event_subscriber.check_for_update()
if update is None:
return
source_type, _, camera, data = update
if not camera or source_type != EventTypeEnum.tracked_object:
return
camera_config = self.config.cameras[camera]
if data["id"] not in self.tracked_events:
self.tracked_events[data["id"]] = []
# Create our own thumbnail based on the bounding box and the frame time
try:
frame_id = f"{camera}{data['frame_time']}"
yuv_frame = self.frame_manager.get(frame_id, camera_config.frame_shape_yuv)
data["thumbnail"] = self._create_thumbnail(yuv_frame, data["box"])
self.tracked_events[data["id"]].append(data)
self.frame_manager.close(frame_id)
except FileNotFoundError:
pass
def _process_finalized(self) -> None:
"""Process the end of an event."""
while True:
ended = self.event_end_subscriber.check_for_update()
if ended == None:
break
event_id, camera, updated_db = ended
camera_config = self.config.cameras[camera]
if updated_db:
try:
event: Event = Event.get(Event.id == event_id)
except DoesNotExist:
continue
# Skip the event if not an object
if event.data.get("type") != "object":
continue
# Extract valid event metadata
metadata = get_metadata(event)
thumbnail = base64.b64decode(event.thumbnail)
# Embed the thumbnail
self._embed_thumbnail(event_id, thumbnail, metadata)
if (
camera_config.genai.enabled
and self.genai_client is not None
and event.data.get("description") is None
):
# Generate the description. Call happens in a thread since it is network bound.
threading.Thread(
target=self._embed_description,
name=f"_embed_description_{event.id}",
daemon=True,
args=(
event,
[
data["thumbnail"]
for data in self.tracked_events[event_id]
]
if len(self.tracked_events.get(event_id, [])) > 0
else [thumbnail],
metadata,
),
).start()
# Delete tracked events based on the event_id
if event_id in self.tracked_events:
del self.tracked_events[event_id]
def _create_thumbnail(self, yuv_frame, box, height=500) -> Optional[bytes]:
"""Return jpg thumbnail of a region of the frame."""
frame = cv2.cvtColor(yuv_frame, cv2.COLOR_YUV2BGR_I420)
region = calculate_region(
frame.shape, box[0], box[1], box[2], box[3], height, multiplier=1.4
)
frame = frame[region[1] : region[3], region[0] : region[2]]
width = int(height * frame.shape[1] / frame.shape[0])
frame = cv2.resize(frame, dsize=(width, height), interpolation=cv2.INTER_AREA)
ret, jpg = cv2.imencode(".jpg", frame, [int(cv2.IMWRITE_JPEG_QUALITY), 100])
if ret:
return jpg.tobytes()
return None
def _embed_thumbnail(self, event_id: str, thumbnail: bytes, metadata: dict) -> None:
"""Embed the thumbnail for an event."""
# Encode the thumbnail
img = np.array(Image.open(io.BytesIO(thumbnail)).convert("RGB"))
self.embeddings.thumbnail.upsert(
images=[img],
metadatas=[metadata],
ids=[event_id],
)
def _embed_description(
self, event: Event, thumbnails: list[bytes], metadata: dict
) -> None:
"""Embed the description for an event."""
description = self.genai_client.generate_description(thumbnails, metadata)
if description is None:
logger.debug("Failed to generate description for %s", event.id)
return
# fire and forget description update
self.requestor.send_data(
UPDATE_EVENT_DESCRIPTION,
{"id": event.id, "description": description},
)
# Encode the description
self.embeddings.description.upsert(
documents=[description],
metadatas=[metadata],
ids=[event.id],
)
logger.debug(
"Generated description for %s (%d images): %s",
event.id,
len(thumbnails),
description,
)

View File

@@ -223,7 +223,7 @@ class AudioEventMaintainer(threading.Thread):
audio_detections.append(label)
# send audio detection data
self.detection_publisher.send_data(
self.detection_publisher.publish(
(
self.config.name,
datetime.datetime.now().timestamp(),

View File

@@ -10,6 +10,7 @@ from pathlib import Path
from frigate.config import FrigateConfig
from frigate.const import CLIPS_DIR
from frigate.embeddings.embeddings import Embeddings
from frigate.models import Event, Timeline
logger = logging.getLogger(__name__)
@@ -26,6 +27,7 @@ class EventCleanup(threading.Thread):
self.name = "event_cleanup"
self.config = config
self.stop_event = stop_event
self.embeddings = Embeddings()
self.camera_keys = list(self.config.cameras.keys())
self.removed_camera_labels: list[str] = None
self.camera_labels: dict[str, dict[str, any]] = {}
@@ -197,9 +199,20 @@ class EventCleanup(threading.Thread):
self.expire(EventCleanupType.snapshots)
# drop events from db where has_clip and has_snapshot are false
delete_query = Event.delete().where(
Event.has_clip == False, Event.has_snapshot == False
events = (
Event.select()
.where(Event.has_clip == False, Event.has_snapshot == False)
.iterator()
)
delete_query.execute()
events_to_delete = [e.id for e in events]
if len(events_to_delete) > 0:
chunk_size = 50
for i in range(0, len(events_to_delete), chunk_size):
chunk = events_to_delete[i : i + chunk_size]
Event.delete().where(Event.id << chunk).execute()
if self.config.semantic_search.enabled:
self.embeddings.thumbnail.delete(ids=chunk)
self.embeddings.description.delete(ids=chunk)
logger.info("Exiting event cleanup...")

View File

@@ -86,7 +86,7 @@ class ExternalEventProcessor:
if source_type == "api":
self.event_camera[event_id] = camera
self.detection_updater.send_data(
self.detection_updater.publish(
(
camera,
now,
@@ -115,7 +115,7 @@ class ExternalEventProcessor:
)
if event_id in self.event_camera:
self.detection_updater.send_data(
self.detection_updater.publish(
(
self.event_camera[event_id],
end_time,

View File

@@ -237,7 +237,7 @@ class EventProcessor(threading.Thread):
if event_type == EventStateEnum.end:
del self.events_in_process[event_data["id"]]
self.event_end_publisher.publish((event_data["id"], camera))
self.event_end_publisher.publish((event_data["id"], camera, updated_db))
def handle_external_detection(
self, event_type: EventStateEnum, event_data: Event

63
frigate/genai/__init__.py Normal file
View File

@@ -0,0 +1,63 @@
"""Generative AI module for Frigate."""
import importlib
import os
from typing import Optional
from frigate.config import GenAIConfig, GenAIProviderEnum
PROVIDERS = {}
def register_genai_provider(key: GenAIProviderEnum):
"""Register a GenAI provider."""
def decorator(cls):
PROVIDERS[key] = cls
return cls
return decorator
class GenAIClient:
"""Generative AI client for Frigate."""
def __init__(self, genai_config: GenAIConfig, timeout: int = 60) -> None:
self.genai_config: GenAIConfig = genai_config
self.timeout = timeout
self.provider = self._init_provider()
def generate_description(
self, thumbnails: list[bytes], metadata: dict[str, any]
) -> Optional[str]:
"""Generate a description for the frame."""
prompt = self.genai_config.object_prompts.get(
metadata["label"], self.genai_config.prompt
).format(**metadata)
return self._send(prompt, thumbnails)
def _init_provider(self):
"""Initialize the client."""
return None
def _send(self, prompt: str, images: list[bytes]) -> Optional[str]:
"""Submit a request to the provider."""
return None
def get_genai_client(genai_config: GenAIConfig) -> Optional[GenAIClient]:
"""Get the GenAI client."""
if genai_config.enabled:
load_providers()
provider = PROVIDERS.get(genai_config.provider)
if provider:
return provider(genai_config)
return None
def load_providers():
package_dir = os.path.dirname(__file__)
for filename in os.listdir(package_dir):
if filename.endswith(".py") and filename != "__init__.py":
module_name = f"frigate.genai.{filename[:-3]}"
importlib.import_module(module_name)

49
frigate/genai/gemini.py Normal file
View File

@@ -0,0 +1,49 @@
"""Gemini Provider for Frigate AI."""
from typing import Optional
import google.generativeai as genai
from google.api_core.exceptions import GoogleAPICallError
from frigate.config import GenAIProviderEnum
from frigate.genai import GenAIClient, register_genai_provider
@register_genai_provider(GenAIProviderEnum.gemini)
class GeminiClient(GenAIClient):
"""Generative AI client for Frigate using Gemini."""
provider: genai.GenerativeModel
def _init_provider(self):
"""Initialize the client."""
genai.configure(api_key=self.genai_config.api_key)
return genai.GenerativeModel(self.genai_config.model)
def _send(self, prompt: str, images: list[bytes]) -> Optional[str]:
"""Submit a request to Gemini."""
data = [
{
"mime_type": "image/jpeg",
"data": img,
}
for img in images
] + [prompt]
try:
response = self.provider.generate_content(
data,
generation_config=genai.types.GenerationConfig(
candidate_count=1,
),
request_options=genai.types.RequestOptions(
timeout=self.timeout,
),
)
except GoogleAPICallError:
return None
try:
description = response.text.strip()
except ValueError:
# No description was generated
return None
return description

41
frigate/genai/ollama.py Normal file
View File

@@ -0,0 +1,41 @@
"""Ollama Provider for Frigate AI."""
import logging
from typing import Optional
from httpx import TimeoutException
from ollama import Client as ApiClient
from ollama import ResponseError
from frigate.config import GenAIProviderEnum
from frigate.genai import GenAIClient, register_genai_provider
logger = logging.getLogger(__name__)
@register_genai_provider(GenAIProviderEnum.ollama)
class OllamaClient(GenAIClient):
"""Generative AI client for Frigate using Ollama."""
provider: ApiClient
def _init_provider(self):
"""Initialize the client."""
client = ApiClient(host=self.genai_config.base_url, timeout=self.timeout)
response = client.pull(self.genai_config.model)
if response["status"] != "success":
logger.error("Failed to pull %s model from Ollama", self.genai_config.model)
return None
return client
def _send(self, prompt: str, images: list[bytes]) -> Optional[str]:
"""Submit a request to Ollama"""
try:
result = self.provider.generate(
self.genai_config.model,
prompt,
images=images,
)
return result["response"].strip()
except (TimeoutException, ResponseError):
return None

51
frigate/genai/openai.py Normal file
View File

@@ -0,0 +1,51 @@
"""OpenAI Provider for Frigate AI."""
import base64
from typing import Optional
from httpx import TimeoutException
from openai import OpenAI
from frigate.config import GenAIProviderEnum
from frigate.genai import GenAIClient, register_genai_provider
@register_genai_provider(GenAIProviderEnum.openai)
class OpenAIClient(GenAIClient):
"""Generative AI client for Frigate using OpenAI."""
provider: OpenAI
def _init_provider(self):
"""Initialize the client."""
return OpenAI(api_key=self.genai_config.api_key)
def _send(self, prompt: str, images: list[bytes]) -> Optional[str]:
"""Submit a request to OpenAI."""
encoded_images = [base64.b64encode(image).decode("utf-8") for image in images]
try:
result = self.provider.chat.completions.create(
model=self.genai_config.model,
messages=[
{
"role": "user",
"content": [
{
"type": "image_url",
"image_url": {
"url": f"data:image/jpeg;base64,{image}",
"detail": "low",
},
}
for image in encoded_images
]
+ [prompt],
},
],
timeout=self.timeout,
)
except TimeoutException:
return None
if len(result.choices) > 0:
return result.choices[0].message.content.strip()
return None

View File

@@ -1187,7 +1187,7 @@ class TrackedObjectProcessor(threading.Thread):
]
# publish info on this frame
self.detection_publisher.send_data(
self.detection_publisher.publish(
(
camera,
frame_time,
@@ -1274,7 +1274,7 @@ class TrackedObjectProcessor(threading.Thread):
if not update:
break
event_id, camera = update
event_id, camera, _ = update
self.camera_states[camera].finished(event_id)
self.requestor.stop()

View File

@@ -80,7 +80,7 @@ def output_frames(
websocket_thread.start()
while not stop_event.is_set():
(topic, data) = detection_subscriber.get_data(timeout=1)
(topic, data) = detection_subscriber.check_for_update(timeout=1)
if not topic:
continue
@@ -134,7 +134,7 @@ def output_frames(
move_preview_frames("clips")
while True:
(topic, data) = detection_subscriber.get_data(timeout=0)
(topic, data) = detection_subscriber.check_for_update(timeout=0)
if not topic:
break

View File

@@ -470,7 +470,7 @@ class RecordingMaintainer(threading.Thread):
stale_frame_count_threshold = 10
# empty the object recordings info queue
while True:
(topic, data) = self.detection_subscriber.get_data(
(topic, data) = self.detection_subscriber.check_for_update(
timeout=QUEUE_READ_TIMEOUT
)

View File

@@ -424,7 +424,7 @@ class ReviewSegmentMaintainer(threading.Thread):
camera_name = updated_topic.rpartition("/")[-1]
self.config.cameras[camera_name].record = updated_record_config
(topic, data) = self.detection_subscriber.get_data(timeout=1)
(topic, data) = self.detection_subscriber.check_for_update(timeout=1)
if not topic:
continue