forked from Github/frigate
Semantic Search for Detections (#11899)
* Initial re-implementation of semantic search
* put docker-compose back and make reindex match docs
* remove debug code and fix import
* fix docs
* manually build pysqlite3 as binaries are only available for x86-64
* update comment in build_pysqlite3.sh
* only embed objects
* better error handling when genai fails
* ask ollama to pull requested model at startup
* update ollama docs
* address some PR review comments
* fix lint
* use IPC to write description, update docs for reindex
* remove gemini-pro-vision from docs as it will be unavailable soon
* fix OpenAI doc available models
* fix api error in gemini and metadata for embeddings

committed by Nicolas Mowen
parent f4f3cfa911
commit 36cbffcc5e
@@ -1,9 +1,12 @@
 import faulthandler
+import sys
 import threading
 
 from flask import cli
 
-from frigate.app import FrigateApp
+# Hotswap the sqlite3 module for Chroma compatibility
+__import__("pysqlite3")
+sys.modules["sqlite3"] = sys.modules.pop("pysqlite3")
 
 faulthandler.enable()
 
@@ -12,6 +15,8 @@ threading.current_thread().name = "frigate"
 cli.show_server_banner = lambda *x: None
 
 if __name__ == "__main__":
+    from frigate.app import FrigateApp
+
     frigate_app = FrigateApp()
 
     frigate_app.start()
@@ -454,6 +454,7 @@ def logs(service: str):
         "frigate": "/dev/shm/logs/frigate/current",
         "go2rtc": "/dev/shm/logs/go2rtc/current",
         "nginx": "/dev/shm/logs/nginx/current",
+        "chroma": "/dev/shm/logs/chroma/current",
     }
     service_location = log_locations.get(service)
 
@@ -22,11 +22,11 @@ from pydantic import ValidationError
 from frigate.api.app import create_app
 from frigate.api.auth import hash_password
 from frigate.comms.config_updater import ConfigPublisher
-from frigate.comms.detections_updater import DetectionProxy
 from frigate.comms.dispatcher import Communicator, Dispatcher
 from frigate.comms.inter_process import InterProcessCommunicator
 from frigate.comms.mqtt import MqttClient
 from frigate.comms.ws import WebSocketClient
+from frigate.comms.zmq_proxy import ZmqProxy
 from frigate.config import FrigateConfig
 from frigate.const import (
     CACHE_DIR,
@@ -37,6 +37,8 @@ from frigate.const import (
     MODEL_CACHE_DIR,
     RECORD_DIR,
 )
+from frigate.embeddings import manage_embeddings
+from frigate.embeddings.embeddings import Embeddings
 from frigate.events.audio import listen_to_audio
 from frigate.events.cleanup import EventCleanup
 from frigate.events.external import ExternalEventProcessor
@@ -316,7 +318,21 @@ class FrigateApp:
         self.review_segment_process = review_segment_process
         review_segment_process.start()
         self.processes["review_segment"] = review_segment_process.pid or 0
-        logger.info(f"Recording process started: {review_segment_process.pid}")
+        logger.info(f"Review process started: {review_segment_process.pid}")
+
+    def init_embeddings_manager(self) -> None:
+        # Create a client for other processes to use
+        self.embeddings = Embeddings()
+        embedding_process = mp.Process(
+            target=manage_embeddings,
+            name="embeddings_manager",
+            args=(self.config,),
+        )
+        embedding_process.daemon = True
+        self.embedding_process = embedding_process
+        embedding_process.start()
+        self.processes["embeddings"] = embedding_process.pid or 0
+        logger.info(f"Embedding process started: {embedding_process.pid}")
 
     def bind_database(self) -> None:
         """Bind db to the main process."""
@@ -362,7 +378,7 @@ class FrigateApp:
     def init_inter_process_communicator(self) -> None:
         self.inter_process_communicator = InterProcessCommunicator()
         self.inter_config_updater = ConfigPublisher()
-        self.inter_detection_proxy = DetectionProxy()
+        self.inter_zmq_proxy = ZmqProxy()
 
     def init_web_server(self) -> None:
         self.flask_app = create_app(
@@ -678,6 +694,7 @@ class FrigateApp:
         self.init_onvif()
         self.init_recording_manager()
         self.init_review_segment_manager()
+        self.init_embeddings_manager()
         self.init_go2rtc()
         self.bind_database()
         self.check_db_data_migrations()
@@ -797,7 +814,7 @@ class FrigateApp:
         # Stop Communicators
         self.inter_process_communicator.stop()
         self.inter_config_updater.stop()
-        self.inter_detection_proxy.stop()
+        self.inter_zmq_proxy.stop()
 
         while len(self.detection_shms) > 0:
             shm = self.detection_shms.pop()
@@ -1,14 +1,9 @@
 """Facilitates communication between processes."""
 
-import threading
 from enum import Enum
 from typing import Optional
 
-import zmq
-
-SOCKET_CONTROL = "inproc://control.detections_updater"
-SOCKET_PUB = "ipc:///tmp/cache/detect_pub"
-SOCKET_SUB = "ipc:///tmp/cache/detect_sub"
+from .zmq_proxy import Publisher, Subscriber
 
 
 class DetectionTypeEnum(str, Enum):
@@ -18,85 +13,31 @@ class DetectionTypeEnum(str, Enum):
     audio = "audio"
 
 
-class DetectionProxyRunner(threading.Thread):
-    def __init__(self, context: zmq.Context[zmq.Socket]) -> None:
-        threading.Thread.__init__(self)
-        self.name = "detection_proxy"
-        self.context = context
-
-    def run(self) -> None:
-        """Run the proxy."""
-        control = self.context.socket(zmq.REP)
-        control.connect(SOCKET_CONTROL)
-        incoming = self.context.socket(zmq.XSUB)
-        incoming.bind(SOCKET_PUB)
-        outgoing = self.context.socket(zmq.XPUB)
-        outgoing.bind(SOCKET_SUB)
-
-        zmq.proxy_steerable(
-            incoming, outgoing, None, control
-        )  # blocking, will unblock when terminate message is received
-        incoming.close()
-        outgoing.close()
-
-
-class DetectionProxy:
-    """Proxies video and audio detections."""
-
-    def __init__(self) -> None:
-        self.context = zmq.Context()
-        self.control = self.context.socket(zmq.REQ)
-        self.control.bind(SOCKET_CONTROL)
-        self.runner = DetectionProxyRunner(self.context)
-        self.runner.start()
-
-    def stop(self) -> None:
-        self.control.send("TERMINATE".encode())  # tell the proxy to stop
-        self.runner.join()
-        self.context.destroy()
-
-
-class DetectionPublisher:
+class DetectionPublisher(Publisher):
     """Simplifies receiving video and audio detections."""
 
+    topic_base = "detection/"
+
     def __init__(self, topic: DetectionTypeEnum) -> None:
-        self.topic = topic
-        self.context = zmq.Context()
-        self.socket = self.context.socket(zmq.PUB)
-        self.socket.connect(SOCKET_PUB)
-
-    def send_data(self, payload: any) -> None:
-        """Publish detection."""
-        self.socket.send_string(self.topic.value, flags=zmq.SNDMORE)
-        self.socket.send_json(payload)
-
-    def stop(self) -> None:
-        self.socket.close()
-        self.context.destroy()
+        topic = topic.value
+        super().__init__(topic)
 
 
-class DetectionSubscriber:
+class DetectionSubscriber(Subscriber):
     """Simplifies receiving video and audio detections."""
 
+    topic_base = "detection/"
+
     def __init__(self, topic: DetectionTypeEnum) -> None:
-        self.context = zmq.Context()
-        self.socket = self.context.socket(zmq.SUB)
-        self.socket.setsockopt_string(zmq.SUBSCRIBE, topic.value)
-        self.socket.connect(SOCKET_SUB)
+        topic = topic.value
+        super().__init__(topic)
 
-    def get_data(self, timeout: float = None) -> Optional[tuple[str, any]]:
-        """Returns detections or None if no update."""
-        try:
-            has_update, _, _ = zmq.select([self.socket], [], [], timeout)
+    def check_for_update(
+        self, timeout: float = None
+    ) -> Optional[tuple[DetectionTypeEnum, any]]:
+        return super().check_for_update(timeout)
 
-            if has_update:
-                topic = DetectionTypeEnum[self.socket.recv_string(flags=zmq.NOBLOCK)]
-                return (topic, self.socket.recv_json())
-        except zmq.ZMQError:
-            pass
-
-        return (None, None)
-
-    def stop(self) -> None:
-        self.socket.close()
-        self.context.destroy()
+    def _return_object(self, topic: str, payload: any) -> any:
+        if payload is None:
+            return (None, None)
+        return (DetectionTypeEnum[topic[len(self.topic_base) :]], payload)
@@ -14,9 +14,10 @@ from frigate.const import (
     INSERT_PREVIEW,
     REQUEST_REGION_GRID,
     UPDATE_CAMERA_ACTIVITY,
+    UPDATE_EVENT_DESCRIPTION,
     UPSERT_REVIEW_SEGMENT,
 )
-from frigate.models import Previews, Recordings, ReviewSegment
+from frigate.models import Event, Previews, Recordings, ReviewSegment
 from frigate.ptz.onvif import OnvifCommandEnum, OnvifController
 from frigate.types import PTZMetricsTypes
 from frigate.util.object import get_camera_regions_grid
@@ -128,6 +129,10 @@ class Dispatcher:
             ).execute()
         elif topic == UPDATE_CAMERA_ACTIVITY:
             self.camera_activity = payload
+        elif topic == UPDATE_EVENT_DESCRIPTION:
+            event: Event = Event.get(Event.id == payload["id"])
+            event.data["description"] = payload["description"]
+            event.save()
         elif topic == "onConnect":
             camera_status = self.camera_activity.copy()
 
@@ -1,100 +1,51 @@
 """Facilitates communication between processes."""
 
-import zmq
-
 from frigate.events.types import EventStateEnum, EventTypeEnum
 
-SOCKET_PUSH_PULL = "ipc:///tmp/cache/events"
-SOCKET_PUSH_PULL_END = "ipc:///tmp/cache/events_ended"
+from .zmq_proxy import Publisher, Subscriber
 
 
-class EventUpdatePublisher:
+class EventUpdatePublisher(Publisher):
     """Publishes events (objects, audio, manual)."""
 
+    topic_base = "event/"
+
     def __init__(self) -> None:
-        self.context = zmq.Context()
-        self.socket = self.context.socket(zmq.PUSH)
-        self.socket.connect(SOCKET_PUSH_PULL)
+        super().__init__("update")
 
     def publish(
         self, payload: tuple[EventTypeEnum, EventStateEnum, str, dict[str, any]]
     ) -> None:
-        """There is no communication back to the processes."""
-        self.socket.send_json(payload)
-
-    def stop(self) -> None:
-        self.socket.close()
-        self.context.destroy()
+        super().publish(payload)
 
 
-class EventUpdateSubscriber:
+class EventUpdateSubscriber(Subscriber):
     """Receives event updates."""
 
+    topic_base = "event/"
+
     def __init__(self) -> None:
-        self.context = zmq.Context()
-        self.socket = self.context.socket(zmq.PULL)
-        self.socket.bind(SOCKET_PUSH_PULL)
-
-    def check_for_update(
-        self, timeout=1
-    ) -> tuple[EventTypeEnum, EventStateEnum, str, dict[str, any]]:
-        """Returns events or None if no update."""
-        try:
-            has_update, _, _ = zmq.select([self.socket], [], [], timeout)
-
-            if has_update:
-                return self.socket.recv_json()
-        except zmq.ZMQError:
-            pass
-
-        return None
-
-    def stop(self) -> None:
-        self.socket.close()
-        self.context.destroy()
+        super().__init__("update")
 
 
-class EventEndPublisher:
+class EventEndPublisher(Publisher):
     """Publishes events that have ended."""
 
+    topic_base = "event/"
+
     def __init__(self) -> None:
-        self.context = zmq.Context()
-        self.socket = self.context.socket(zmq.PUSH)
-        self.socket.connect(SOCKET_PUSH_PULL_END)
+        super().__init__("finalized")
 
     def publish(
        self, payload: tuple[EventTypeEnum, EventStateEnum, str, dict[str, any]]
     ) -> None:
-        """There is no communication back to the processes."""
-        self.socket.send_json(payload)
-
-    def stop(self) -> None:
-        self.socket.close()
-        self.context.destroy()
+        super().publish(payload)
 
 
-class EventEndSubscriber:
+class EventEndSubscriber(Subscriber):
     """Receives events that have ended."""
 
+    topic_base = "event/"
+
     def __init__(self) -> None:
-        self.context = zmq.Context()
-        self.socket = self.context.socket(zmq.PULL)
-        self.socket.bind(SOCKET_PUSH_PULL_END)
-
-    def check_for_update(
-        self, timeout=1
-    ) -> tuple[EventTypeEnum, EventStateEnum, str, dict[str, any]]:
-        """Returns events ended or None if no update."""
-        try:
-            has_update, _, _ = zmq.select([self.socket], [], [], timeout)
-
-            if has_update:
-                return self.socket.recv_json()
-        except zmq.ZMQError:
-            pass
-
-        return None
-
-    def stop(self) -> None:
-        self.socket.close()
-        self.context.destroy()
+        super().__init__("finalized")
frigate/comms/zmq_proxy.py (new file, 100 lines)
@@ -0,0 +1,100 @@
"""Facilitates communication over zmq proxy."""

import threading
from typing import Optional

import zmq

SOCKET_PUB = "ipc:///tmp/cache/proxy_pub"
SOCKET_SUB = "ipc:///tmp/cache/proxy_sub"


class ZmqProxyRunner(threading.Thread):
    def __init__(self, context: zmq.Context[zmq.Socket]) -> None:
        threading.Thread.__init__(self)
        self.name = "detection_proxy"
        self.context = context

    def run(self) -> None:
        """Run the proxy."""
        incoming = self.context.socket(zmq.XSUB)
        incoming.bind(SOCKET_PUB)
        outgoing = self.context.socket(zmq.XPUB)
        outgoing.bind(SOCKET_SUB)

        # Blocking: this will unblock (via exception) when we destroy the context.
        # The incoming and outgoing sockets will be closed automatically
        # when the context is destroyed as well.
        try:
            zmq.proxy(incoming, outgoing)
        except zmq.ZMQError:
            pass


class ZmqProxy:
    """Proxies video and audio detections."""

    def __init__(self) -> None:
        self.context = zmq.Context()
        self.runner = ZmqProxyRunner(self.context)
        self.runner.start()

    def stop(self) -> None:
        # destroying the context will tell the proxy to stop
        self.context.destroy()
        self.runner.join()


class Publisher:
    """Publishes messages."""

    topic_base: str = ""

    def __init__(self, topic: str = "") -> None:
        self.topic = f"{self.topic_base}{topic}"
        self.context = zmq.Context()
        self.socket = self.context.socket(zmq.PUB)
        self.socket.connect(SOCKET_PUB)

    def publish(self, payload: any, sub_topic: str = "") -> None:
        """Publish message."""
        self.socket.send_string(f"{self.topic}{sub_topic}", flags=zmq.SNDMORE)
        self.socket.send_json(payload)

    def stop(self) -> None:
        self.socket.close()
        self.context.destroy()


class Subscriber:
    """Receives messages."""

    topic_base: str = ""

    def __init__(self, topic: str = "") -> None:
        self.topic = f"{self.topic_base}{topic}"
        self.context = zmq.Context()
        self.socket = self.context.socket(zmq.SUB)
        self.socket.setsockopt_string(zmq.SUBSCRIBE, self.topic)
        self.socket.connect(SOCKET_SUB)

    def check_for_update(self, timeout: float = 1) -> Optional[tuple[str, any]]:
        """Returns message or None if no update."""
        try:
            has_update, _, _ = zmq.select([self.socket], [], [], timeout)

            if has_update:
                topic = self.socket.recv_string(flags=zmq.NOBLOCK)
                payload = self.socket.recv_json()
                return self._return_object(topic, payload)
        except zmq.ZMQError:
            pass

        return self._return_object("", None)

    def stop(self) -> None:
        self.socket.close()
        self.context.destroy()

    def _return_object(self, topic: str, payload: any) -> any:
        return payload
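To make the proxy pattern concrete, here is a minimal sketch (not part of the commit) of a round trip through the Publisher/Subscriber pair above. The "example/" topic is invented, and the ipc:///tmp/cache endpoints are assumed to exist, as they do inside the Frigate container; create /tmp/cache first when trying this locally.

import time

from frigate.comms.zmq_proxy import Publisher, Subscriber, ZmqProxy

proxy = ZmqProxy()                  # XSUB/XPUB forwarder thread
sub = Subscriber(topic="example/")
pub = Publisher(topic="example/")
time.sleep(0.5)                     # allow the SUB subscription to propagate

pub.publish({"hello": "world"})
print(sub.check_for_update(timeout=1))  # {'hello': 'world'}, or None if dropped

pub.stop()
sub.stop()
proxy.stop()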
@@ -730,6 +730,38 @@ class ReviewConfig(FrigateBaseModel):
     )
 
 
+class SemanticSearchConfig(FrigateBaseModel):
+    enabled: bool = Field(default=True, title="Enable semantic search.")
+    reindex: Optional[bool] = Field(
+        default=False, title="Reindex all detections on startup."
+    )
+
+
+class GenAIProviderEnum(str, Enum):
+    openai = "openai"
+    gemini = "gemini"
+    ollama = "ollama"
+
+
+class GenAIConfig(FrigateBaseModel):
+    enabled: bool = Field(default=False, title="Enable GenAI.")
+    provider: GenAIProviderEnum = Field(
+        default=GenAIProviderEnum.openai, title="GenAI provider."
+    )
+    base_url: Optional[str] = Field(None, title="Provider base url.")
+    api_key: Optional[str] = Field(None, title="Provider API key.")
+    model: str = Field(default="gpt-4o", title="GenAI model.")
+    prompt: str = Field(
+        default="Describe the {label} in the sequence of images with as much detail as possible. Do not describe the background.",
+        title="Default caption prompt.",
+    )
+    object_prompts: Dict[str, str] = Field(default={}, title="Object specific prompts.")
+
+
+class GenAICameraConfig(FrigateBaseModel):
+    enabled: bool = Field(default=False, title="Enable GenAI for camera.")
+
+
 class AudioConfig(FrigateBaseModel):
     enabled: bool = Field(default=False, title="Enable audio events.")
     max_not_heard: int = Field(
@@ -1011,6 +1043,9 @@ class CameraConfig(FrigateBaseModel):
     review: ReviewConfig = Field(
         default_factory=ReviewConfig, title="Review configuration."
     )
+    genai: GenAICameraConfig = Field(
+        default_factory=GenAICameraConfig, title="Generative AI configuration."
+    )
     audio: AudioConfig = Field(
         default_factory=AudioConfig, title="Audio events configuration."
     )
@@ -1363,6 +1398,12 @@ class FrigateConfig(FrigateBaseModel):
     review: ReviewConfig = Field(
         default_factory=ReviewConfig, title="Review configuration."
     )
+    semantic_search: SemanticSearchConfig = Field(
+        default_factory=SemanticSearchConfig, title="Semantic search configuration."
+    )
+    genai: GenAIConfig = Field(
+        default_factory=GenAIConfig, title="Generative AI configuration."
+    )
     audio: AudioConfig = Field(
         default_factory=AudioConfig, title="Global Audio events configuration."
     )
@@ -1397,6 +1438,10 @@ class FrigateConfig(FrigateBaseModel):
             config.mqtt.user = config.mqtt.user.format(**FRIGATE_ENV_VARS)
             config.mqtt.password = config.mqtt.password.format(**FRIGATE_ENV_VARS)
 
+        # GenAI substitution
+        if config.genai.api_key:
+            config.genai.api_key = config.genai.api_key.format(**FRIGATE_ENV_VARS)
+
         # set default min_score for object attributes
         for attribute in ALL_ATTRIBUTE_LABELS:
             if not config.objects.filters.get(attribute):
@@ -1418,6 +1463,7 @@ class FrigateConfig(FrigateBaseModel):
                 "live": ...,
                 "objects": ...,
                 "review": ...,
+                "genai": {"enabled"},
                 "motion": ...,
                 "detect": ...,
                 "ffmpeg": ...,
@@ -81,6 +81,7 @@ REQUEST_REGION_GRID = "request_region_grid"
 UPSERT_REVIEW_SEGMENT = "upsert_review_segment"
 CLEAR_ONGOING_REVIEW_SEGMENTS = "clear_ongoing_review_segments"
 UPDATE_CAMERA_ACTIVITY = "update_camera_activity"
+UPDATE_EVENT_DESCRIPTION = "update_event_description"
 
 # Stats Values
frigate/embeddings/__init__.py (new file, 67 lines)
@@ -0,0 +1,67 @@
"""ChromaDB embeddings database."""

import logging
import multiprocessing as mp
import signal
import sys
import threading
from types import FrameType
from typing import Optional

from playhouse.sqliteq import SqliteQueueDatabase
from setproctitle import setproctitle

from frigate.config import FrigateConfig
from frigate.models import Event
from frigate.util.services import listen

logger = logging.getLogger(__name__)


def manage_embeddings(config: FrigateConfig) -> None:
    # Only initialize embeddings if semantic search is enabled
    if not config.semantic_search.enabled:
        return

    stop_event = mp.Event()

    def receiveSignal(signalNumber: int, frame: Optional[FrameType]) -> None:
        stop_event.set()

    signal.signal(signal.SIGTERM, receiveSignal)
    signal.signal(signal.SIGINT, receiveSignal)

    threading.current_thread().name = "process:embeddings_manager"
    setproctitle("frigate.embeddings_manager")
    listen()

    # Configure Frigate DB
    db = SqliteQueueDatabase(
        config.database.path,
        pragmas={
            "auto_vacuum": "FULL",  # Does not defragment database
            "cache_size": -512 * 1000,  # 512MB of cache
            "synchronous": "NORMAL",  # Safe when using WAL https://www.sqlite.org/pragma.html#pragma_synchronous
        },
        timeout=max(60, 10 * len([c for c in config.cameras.values() if c.enabled])),
    )
    models = [Event]
    db.bind(models)

    # Hotswap the sqlite3 module for Chroma compatibility
    __import__("pysqlite3")
    sys.modules["sqlite3"] = sys.modules.pop("pysqlite3")
    from .embeddings import Embeddings
    from .maintainer import EmbeddingMaintainer

    embeddings = Embeddings()

    # Check if we need to re-index events
    if config.semantic_search.reindex:
        embeddings.reindex()

    maintainer = EmbeddingMaintainer(
        config,
        stop_event,
    )
    maintainer.start()
frigate/embeddings/embeddings.py (new file, 122 lines)
@@ -0,0 +1,122 @@
"""ChromaDB embeddings database."""

import base64
import io
import logging
import time

import numpy as np
from chromadb import Collection
from chromadb import HttpClient as ChromaClient
from chromadb.config import Settings
from PIL import Image
from playhouse.shortcuts import model_to_dict

from frigate.models import Event

from .functions.clip import ClipEmbedding
from .functions.minilm_l6_v2 import MiniLMEmbedding

logger = logging.getLogger(__name__)


def get_metadata(event: Event) -> dict:
    """Extract valid event metadata."""
    event_dict = model_to_dict(event)
    return (
        {
            k: v
            for k, v in event_dict.items()
            if k not in ["id", "thumbnail"]
            and v is not None
            and isinstance(v, (str, int, float, bool))
        }
        | {
            k: v
            for k, v in event_dict["data"].items()
            if k not in ["description"]
            and v is not None
            and isinstance(v, (str, int, float, bool))
        }
        | {
            # Metadata search doesn't support $contains
            # and an event can have multiple zones, so
            # we need to create a key for each zone
            f"{k}_{x}": True
            for k, v in event_dict.items()
            if isinstance(v, list) and len(v) > 0
            for x in v
            if isinstance(x, str)
        }
    )


class Embeddings:
    """ChromaDB embeddings database."""

    def __init__(self) -> None:
        self.client: ChromaClient = ChromaClient(
            host="127.0.0.1",
            settings=Settings(anonymized_telemetry=False),
        )

    @property
    def thumbnail(self) -> Collection:
        return self.client.get_or_create_collection(
            name="event_thumbnail", embedding_function=ClipEmbedding()
        )

    @property
    def description(self) -> Collection:
        return self.client.get_or_create_collection(
            name="event_description", embedding_function=MiniLMEmbedding()
        )

    def reindex(self) -> None:
        """Reindex all event embeddings."""
        logger.info("Indexing event embeddings...")
        self.client.reset()

        st = time.time()

        thumbnails = {"ids": [], "images": [], "metadatas": []}
        descriptions = {"ids": [], "documents": [], "metadatas": []}

        # Parenthesize both comparisons: peewee's `|` binds tighter than `==`
        events = Event.select().where(
            ((Event.has_clip == True) | (Event.has_snapshot == True))
            & Event.thumbnail.is_null(False)
        )

        event: Event
        for event in events.iterator():
            metadata = get_metadata(event)
            thumbnail = base64.b64decode(event.thumbnail)
            img = np.array(Image.open(io.BytesIO(thumbnail)).convert("RGB"))
            thumbnails["ids"].append(event.id)
            thumbnails["images"].append(img)
            thumbnails["metadatas"].append(metadata)
            if event.data.get("description") is not None:
                descriptions["ids"].append(event.id)
                descriptions["documents"].append(event.data["description"])
                descriptions["metadatas"].append(metadata)

        if len(thumbnails["ids"]) > 0:
            self.thumbnail.upsert(
                images=thumbnails["images"],
                metadatas=thumbnails["metadatas"],
                ids=thumbnails["ids"],
            )

        if len(descriptions["ids"]) > 0:
            self.description.upsert(
                documents=descriptions["documents"],
                metadatas=descriptions["metadatas"],
                ids=descriptions["ids"],
            )

        logger.info(
            "Embedded %d thumbnails and %d descriptions in %s seconds",
            len(thumbnails["ids"]),
            len(descriptions["ids"]),
            time.time() - st,
        )
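As a usage sketch (not part of the commit), the collections above can be queried once populated. This assumes the Chroma server is reachable on 127.0.0.1 as in the Frigate container, and relies on CLIP being cross-modal, so a text query is embedded and matched against the stored thumbnail image embeddings; query() is standard chromadb Collection API.

from frigate.embeddings.embeddings import Embeddings

embeddings = Embeddings()
results = embeddings.thumbnail.query(
    query_texts=["a red car in the driveway"],
    n_results=10,
)
print(results["ids"][0])  # nearest event ids, best match first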
frigate/embeddings/functions/clip.py (new file, 63 lines)
@@ -0,0 +1,63 @@
"""CLIP Embeddings for Frigate."""

import os
from typing import Tuple, Union

import onnxruntime as ort
from chromadb import EmbeddingFunction, Embeddings
from chromadb.api.types import (
    Documents,
    Images,
    is_document,
    is_image,
)
from onnx_clip import OnnxClip

from frigate.const import MODEL_CACHE_DIR


class Clip(OnnxClip):
    """Override load models to download to cache directory."""

    @staticmethod
    def _load_models(
        model: str,
        silent: bool,
    ) -> Tuple[ort.InferenceSession, ort.InferenceSession]:
        """
        These models are a part of the container. Treat them as such.
        """
        if model == "ViT-B/32":
            IMAGE_MODEL_FILE = "clip_image_model_vitb32.onnx"
            TEXT_MODEL_FILE = "clip_text_model_vitb32.onnx"
        elif model == "RN50":
            IMAGE_MODEL_FILE = "clip_image_model_rn50.onnx"
            TEXT_MODEL_FILE = "clip_text_model_rn50.onnx"
        else:
            raise ValueError(f"Unexpected model {model}. No `.onnx` file found.")

        models = []
        for model_file in [IMAGE_MODEL_FILE, TEXT_MODEL_FILE]:
            path = os.path.join(MODEL_CACHE_DIR, "clip", model_file)
            models.append(OnnxClip._load_model(path, silent))

        return models[0], models[1]


class ClipEmbedding(EmbeddingFunction):
    """Embedding function for CLIP model used in Chroma."""

    def __init__(self, model: str = "ViT-B/32"):
        """Initialize CLIP Embedding function."""
        self.model = Clip(model)

    def __call__(self, input: Union[Documents, Images]) -> Embeddings:
        embeddings: Embeddings = []
        for item in input:
            if is_image(item):
                result = self.model.get_image_embeddings([item])
                embeddings.append(result[0, :].tolist())
            elif is_document(item):
                result = self.model.get_text_embeddings([item])
                embeddings.append(result[0, :].tolist())
        return embeddings
frigate/embeddings/functions/minilm_l6_v2.py (new file, 11 lines)
@@ -0,0 +1,11 @@
"""Embedding function for ONNX MiniLM-L6 model used in Chroma."""

from chromadb.utils.embedding_functions import ONNXMiniLM_L6_V2

from frigate.const import MODEL_CACHE_DIR


class MiniLMEmbedding(ONNXMiniLM_L6_V2):
    """Override DOWNLOAD_PATH to download to cache directory."""

    DOWNLOAD_PATH = f"{MODEL_CACHE_DIR}/all-MiniLM-L6-v2"
frigate/embeddings/maintainer.py (new file, 197 lines)
@@ -0,0 +1,197 @@
"""Maintain embeddings in Chroma."""

import base64
import io
import logging
import threading
from multiprocessing.synchronize import Event as MpEvent
from typing import Optional

import cv2
import numpy as np
from peewee import DoesNotExist
from PIL import Image

from frigate.comms.events_updater import EventEndSubscriber, EventUpdateSubscriber
from frigate.comms.inter_process import InterProcessRequestor
from frigate.config import FrigateConfig
from frigate.const import UPDATE_EVENT_DESCRIPTION
from frigate.events.types import EventTypeEnum
from frigate.genai import get_genai_client
from frigate.models import Event
from frigate.util.image import SharedMemoryFrameManager, calculate_region

from .embeddings import Embeddings, get_metadata

logger = logging.getLogger(__name__)


class EmbeddingMaintainer(threading.Thread):
    """Handle embedding queue and post event updates."""

    def __init__(
        self,
        config: FrigateConfig,
        stop_event: MpEvent,
    ) -> None:
        threading.Thread.__init__(self)
        self.name = "embeddings_maintainer"
        self.config = config
        self.embeddings = Embeddings()
        self.event_subscriber = EventUpdateSubscriber()
        self.event_end_subscriber = EventEndSubscriber()
        self.frame_manager = SharedMemoryFrameManager()
        # create communication for updating event descriptions
        self.requestor = InterProcessRequestor()
        self.stop_event = stop_event
        self.tracked_events = {}
        self.genai_client = get_genai_client(config.genai)

    def run(self) -> None:
        """Maintain a Chroma vector database for semantic search."""
        while not self.stop_event.is_set():
            self._process_updates()
            self._process_finalized()

        self.event_subscriber.stop()
        self.event_end_subscriber.stop()
        self.requestor.stop()
        logger.info("Exiting embeddings maintenance...")

    def _process_updates(self) -> None:
        """Process event updates"""
        update = self.event_subscriber.check_for_update()

        if update is None:
            return

        source_type, _, camera, data = update

        if not camera or source_type != EventTypeEnum.tracked_object:
            return

        camera_config = self.config.cameras[camera]
        if data["id"] not in self.tracked_events:
            self.tracked_events[data["id"]] = []

        # Create our own thumbnail based on the bounding box and the frame time
        try:
            frame_id = f"{camera}{data['frame_time']}"
            yuv_frame = self.frame_manager.get(frame_id, camera_config.frame_shape_yuv)
            data["thumbnail"] = self._create_thumbnail(yuv_frame, data["box"])
            self.tracked_events[data["id"]].append(data)
            self.frame_manager.close(frame_id)
        except FileNotFoundError:
            pass

    def _process_finalized(self) -> None:
        """Process the end of an event."""
        while True:
            ended = self.event_end_subscriber.check_for_update()

            if ended is None:
                break

            event_id, camera, updated_db = ended
            camera_config = self.config.cameras[camera]

            if updated_db:
                try:
                    event: Event = Event.get(Event.id == event_id)
                except DoesNotExist:
                    continue

                # Skip the event if not an object
                if event.data.get("type") != "object":
                    continue

                # Extract valid event metadata
                metadata = get_metadata(event)
                thumbnail = base64.b64decode(event.thumbnail)

                # Embed the thumbnail
                self._embed_thumbnail(event_id, thumbnail, metadata)

                if (
                    camera_config.genai.enabled
                    and self.genai_client is not None
                    and event.data.get("description") is None
                ):
                    # Generate the description. Call happens in a thread since it is network bound.
                    threading.Thread(
                        target=self._embed_description,
                        name=f"_embed_description_{event.id}",
                        daemon=True,
                        args=(
                            event,
                            [
                                data["thumbnail"]
                                for data in self.tracked_events[event_id]
                            ]
                            if len(self.tracked_events.get(event_id, [])) > 0
                            else [thumbnail],
                            metadata,
                        ),
                    ).start()

            # Delete tracked events based on the event_id
            if event_id in self.tracked_events:
                del self.tracked_events[event_id]

    def _create_thumbnail(self, yuv_frame, box, height=500) -> Optional[bytes]:
        """Return jpg thumbnail of a region of the frame."""
        frame = cv2.cvtColor(yuv_frame, cv2.COLOR_YUV2BGR_I420)
        region = calculate_region(
            frame.shape, box[0], box[1], box[2], box[3], height, multiplier=1.4
        )
        frame = frame[region[1] : region[3], region[0] : region[2]]
        width = int(height * frame.shape[1] / frame.shape[0])
        frame = cv2.resize(frame, dsize=(width, height), interpolation=cv2.INTER_AREA)
        ret, jpg = cv2.imencode(".jpg", frame, [int(cv2.IMWRITE_JPEG_QUALITY), 100])

        if ret:
            return jpg.tobytes()

        return None

    def _embed_thumbnail(self, event_id: str, thumbnail: bytes, metadata: dict) -> None:
        """Embed the thumbnail for an event."""

        # Encode the thumbnail
        img = np.array(Image.open(io.BytesIO(thumbnail)).convert("RGB"))
        self.embeddings.thumbnail.upsert(
            images=[img],
            metadatas=[metadata],
            ids=[event_id],
        )

    def _embed_description(
        self, event: Event, thumbnails: list[bytes], metadata: dict
    ) -> None:
        """Embed the description for an event."""

        description = self.genai_client.generate_description(thumbnails, metadata)

        if description is None:
            logger.debug("Failed to generate description for %s", event.id)
            return

        # fire and forget description update
        self.requestor.send_data(
            UPDATE_EVENT_DESCRIPTION,
            {"id": event.id, "description": description},
        )

        # Encode the description
        self.embeddings.description.upsert(
            documents=[description],
            metadatas=[metadata],
            ids=[event.id],
        )

        logger.debug(
            "Generated description for %s (%d images): %s",
            event.id,
            len(thumbnails),
            description,
        )
@@ -223,7 +223,7 @@ class AudioEventMaintainer(threading.Thread):
             audio_detections.append(label)
 
         # send audio detection data
-        self.detection_publisher.send_data(
+        self.detection_publisher.publish(
             (
                 self.config.name,
                 datetime.datetime.now().timestamp(),
@@ -10,6 +10,7 @@ from pathlib import Path
 
 from frigate.config import FrigateConfig
 from frigate.const import CLIPS_DIR
+from frigate.embeddings.embeddings import Embeddings
 from frigate.models import Event, Timeline
 
 logger = logging.getLogger(__name__)
@@ -26,6 +27,7 @@ class EventCleanup(threading.Thread):
         self.name = "event_cleanup"
         self.config = config
         self.stop_event = stop_event
+        self.embeddings = Embeddings()
        self.camera_keys = list(self.config.cameras.keys())
         self.removed_camera_labels: list[str] = None
         self.camera_labels: dict[str, dict[str, any]] = {}
@@ -197,9 +199,20 @@ class EventCleanup(threading.Thread):
             self.expire(EventCleanupType.snapshots)
 
             # drop events from db where has_clip and has_snapshot are false
-            delete_query = Event.delete().where(
-                Event.has_clip == False, Event.has_snapshot == False
+            events = (
+                Event.select()
+                .where(Event.has_clip == False, Event.has_snapshot == False)
+                .iterator()
             )
-            delete_query.execute()
+            events_to_delete = [e.id for e in events]
+            if len(events_to_delete) > 0:
+                chunk_size = 50
+                for i in range(0, len(events_to_delete), chunk_size):
+                    chunk = events_to_delete[i : i + chunk_size]
+                    Event.delete().where(Event.id << chunk).execute()
+
+                    if self.config.semantic_search.enabled:
+                        self.embeddings.thumbnail.delete(ids=chunk)
+                        self.embeddings.description.delete(ids=chunk)
 
         logger.info("Exiting event cleanup...")
@@ -86,7 +86,7 @@ class ExternalEventProcessor:
 
         if source_type == "api":
             self.event_camera[event_id] = camera
-            self.detection_updater.send_data(
+            self.detection_updater.publish(
                 (
                     camera,
                     now,
@@ -115,7 +115,7 @@ class ExternalEventProcessor:
         )
 
         if event_id in self.event_camera:
-            self.detection_updater.send_data(
+            self.detection_updater.publish(
                 (
                     self.event_camera[event_id],
                     end_time,
@@ -237,7 +237,7 @@ class EventProcessor(threading.Thread):
 
         if event_type == EventStateEnum.end:
             del self.events_in_process[event_data["id"]]
-            self.event_end_publisher.publish((event_data["id"], camera))
+            self.event_end_publisher.publish((event_data["id"], camera, updated_db))
 
     def handle_external_detection(
         self, event_type: EventStateEnum, event_data: Event
frigate/genai/__init__.py (new file, 63 lines)
@@ -0,0 +1,63 @@
"""Generative AI module for Frigate."""

import importlib
import os
from typing import Optional

from frigate.config import GenAIConfig, GenAIProviderEnum

PROVIDERS = {}


def register_genai_provider(key: GenAIProviderEnum):
    """Register a GenAI provider."""

    def decorator(cls):
        PROVIDERS[key] = cls
        return cls

    return decorator


class GenAIClient:
    """Generative AI client for Frigate."""

    def __init__(self, genai_config: GenAIConfig, timeout: int = 60) -> None:
        self.genai_config: GenAIConfig = genai_config
        self.timeout = timeout
        self.provider = self._init_provider()

    def generate_description(
        self, thumbnails: list[bytes], metadata: dict[str, any]
    ) -> Optional[str]:
        """Generate a description for the frame."""
        prompt = self.genai_config.object_prompts.get(
            metadata["label"], self.genai_config.prompt
        ).format(**metadata)
        return self._send(prompt, thumbnails)

    def _init_provider(self):
        """Initialize the client."""
        return None

    def _send(self, prompt: str, images: list[bytes]) -> Optional[str]:
        """Submit a request to the provider."""
        return None


def get_genai_client(genai_config: GenAIConfig) -> Optional[GenAIClient]:
    """Get the GenAI client."""
    if genai_config.enabled:
        load_providers()
        provider = PROVIDERS.get(genai_config.provider)
        if provider:
            return provider(genai_config)
    return None


def load_providers():
    package_dir = os.path.dirname(__file__)
    for filename in os.listdir(package_dir):
        if filename.endswith(".py") and filename != "__init__.py":
            module_name = f"frigate.genai.{filename[:-3]}"
            importlib.import_module(module_name)
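As a sketch of the plug-in pattern above (not part of the commit): a provider is a GenAIClient subclass registered under an enum key, as the three real providers below do. EchoClient is hypothetical and reuses the ollama key purely for illustration; a real provider would add its own GenAIProviderEnum member.

from typing import Optional

from frigate.config import GenAIProviderEnum
from frigate.genai import GenAIClient, register_genai_provider


@register_genai_provider(GenAIProviderEnum.ollama)  # a real provider adds its own key
class EchoClient(GenAIClient):
    """Toy provider that returns the prompt instead of calling a model."""

    def _init_provider(self):
        return object()  # a real provider returns its API client here

    def _send(self, prompt: str, images: list[bytes]) -> Optional[str]:
        return f"{prompt} ({len(images)} images attached)"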
frigate/genai/gemini.py (new file, 49 lines)
@@ -0,0 +1,49 @@
"""Gemini Provider for Frigate AI."""

from typing import Optional

import google.generativeai as genai
from google.api_core.exceptions import GoogleAPICallError

from frigate.config import GenAIProviderEnum
from frigate.genai import GenAIClient, register_genai_provider


@register_genai_provider(GenAIProviderEnum.gemini)
class GeminiClient(GenAIClient):
    """Generative AI client for Frigate using Gemini."""

    provider: genai.GenerativeModel

    def _init_provider(self):
        """Initialize the client."""
        genai.configure(api_key=self.genai_config.api_key)
        return genai.GenerativeModel(self.genai_config.model)

    def _send(self, prompt: str, images: list[bytes]) -> Optional[str]:
        """Submit a request to Gemini."""
        data = [
            {
                "mime_type": "image/jpeg",
                "data": img,
            }
            for img in images
        ] + [prompt]
        try:
            response = self.provider.generate_content(
                data,
                generation_config=genai.types.GenerationConfig(
                    candidate_count=1,
                ),
                request_options=genai.types.RequestOptions(
                    timeout=self.timeout,
                ),
            )
        except GoogleAPICallError:
            return None
        try:
            description = response.text.strip()
        except ValueError:
            # No description was generated
            return None
        return description
frigate/genai/ollama.py (new file, 41 lines)
@@ -0,0 +1,41 @@
"""Ollama Provider for Frigate AI."""

import logging
from typing import Optional

from httpx import TimeoutException
from ollama import Client as ApiClient
from ollama import ResponseError

from frigate.config import GenAIProviderEnum
from frigate.genai import GenAIClient, register_genai_provider

logger = logging.getLogger(__name__)


@register_genai_provider(GenAIProviderEnum.ollama)
class OllamaClient(GenAIClient):
    """Generative AI client for Frigate using Ollama."""

    provider: ApiClient

    def _init_provider(self):
        """Initialize the client."""
        client = ApiClient(host=self.genai_config.base_url, timeout=self.timeout)
        response = client.pull(self.genai_config.model)
        if response["status"] != "success":
            logger.error("Failed to pull %s model from Ollama", self.genai_config.model)
            return None
        return client

    def _send(self, prompt: str, images: list[bytes]) -> Optional[str]:
        """Submit a request to Ollama."""
        try:
            result = self.provider.generate(
                self.genai_config.model,
                prompt,
                images=images,
            )
            return result["response"].strip()
        except (TimeoutException, ResponseError):
            return None
frigate/genai/openai.py (new file, 51 lines)
@@ -0,0 +1,51 @@
"""OpenAI Provider for Frigate AI."""

import base64
from typing import Optional

from httpx import TimeoutException
from openai import OpenAI

from frigate.config import GenAIProviderEnum
from frigate.genai import GenAIClient, register_genai_provider


@register_genai_provider(GenAIProviderEnum.openai)
class OpenAIClient(GenAIClient):
    """Generative AI client for Frigate using OpenAI."""

    provider: OpenAI

    def _init_provider(self):
        """Initialize the client."""
        return OpenAI(api_key=self.genai_config.api_key)

    def _send(self, prompt: str, images: list[bytes]) -> Optional[str]:
        """Submit a request to OpenAI."""
        encoded_images = [base64.b64encode(image).decode("utf-8") for image in images]
        try:
            result = self.provider.chat.completions.create(
                model=self.genai_config.model,
                messages=[
                    {
                        "role": "user",
                        "content": [
                            {
                                "type": "image_url",
                                "image_url": {
                                    "url": f"data:image/jpeg;base64,{image}",
                                    "detail": "low",
                                },
                            }
                            for image in encoded_images
                        ]
                        + [prompt],
                    },
                ],
                timeout=self.timeout,
            )
        except TimeoutException:
            return None
        if len(result.choices) > 0:
            return result.choices[0].message.content.strip()
        return None
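To show how these providers are selected and invoked (not part of the commit), here is a hedged sketch mirroring how EmbeddingMaintainer obtains its client via get_genai_client. The base_url, model name, and thumb.jpg are placeholders, not defaults shipped by this commit, and a running Ollama server is assumed.

from frigate.config import GenAIConfig, GenAIProviderEnum
from frigate.genai import get_genai_client

config = GenAIConfig(
    enabled=True,
    provider=GenAIProviderEnum.ollama,
    base_url="http://localhost:11434",
    model="llava",
)

# OllamaClient pulls the requested model at startup; returns None on failure
client = get_genai_client(config)
if client:
    with open("thumb.jpg", "rb") as f:
        print(client.generate_description([f.read()], {"label": "person"}))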
@@ -1187,7 +1187,7 @@ class TrackedObjectProcessor(threading.Thread):
         ]
 
         # publish info on this frame
-        self.detection_publisher.send_data(
+        self.detection_publisher.publish(
             (
                 camera,
                 frame_time,
@@ -1274,7 +1274,7 @@ class TrackedObjectProcessor(threading.Thread):
             if not update:
                 break
 
-            event_id, camera = update
+            event_id, camera, _ = update
             self.camera_states[camera].finished(event_id)
 
         self.requestor.stop()
@@ -80,7 +80,7 @@ def output_frames(
     websocket_thread.start()
 
     while not stop_event.is_set():
-        (topic, data) = detection_subscriber.get_data(timeout=1)
+        (topic, data) = detection_subscriber.check_for_update(timeout=1)
 
         if not topic:
             continue
@@ -134,7 +134,7 @@ def output_frames(
             move_preview_frames("clips")
 
     while True:
-        (topic, data) = detection_subscriber.get_data(timeout=0)
+        (topic, data) = detection_subscriber.check_for_update(timeout=0)
 
         if not topic:
             break
@@ -470,7 +470,7 @@ class RecordingMaintainer(threading.Thread):
         stale_frame_count_threshold = 10
         # empty the object recordings info queue
         while True:
-            (topic, data) = self.detection_subscriber.get_data(
+            (topic, data) = self.detection_subscriber.check_for_update(
                 timeout=QUEUE_READ_TIMEOUT
             )
 
@@ -424,7 +424,7 @@ class ReviewSegmentMaintainer(threading.Thread):
                 camera_name = updated_topic.rpartition("/")[-1]
                 self.config.cameras[camera_name].record = updated_record_config
 
-            (topic, data) = self.detection_subscriber.get_data(timeout=1)
+            (topic, data) = self.detection_subscriber.check_for_update(timeout=1)
 
             if not topic:
                 continue