forked from Github/frigate
Use sqlite-vec extension instead of chromadb for embeddings (#14163)
* swap sqlite_vec for chroma in requirements * load sqlite_vec in embeddings manager * remove chroma and revamp Embeddings class for sqlite_vec * manual minilm onnx inference * remove chroma in clip model * migrate api from chroma to sqlite_vec * migrate event cleanup from chroma to sqlite_vec * migrate embedding maintainer from chroma to sqlite_vec * genai description for sqlite_vec * load sqlite_vec in main thread db * extend the SqliteQueueDatabase class and use peewee db.execute_sql * search with Event type for similarity * fix similarity search * install and add comment about transformers * fix normalization * add id filter * clean up * clean up * fully remove chroma and add transformers env var * readd uvicorn for fastapi * readd tokenizer parallelism env var * remove chroma from docs * remove chroma from UI * try removing custom pysqlite3 build * hard code limit * optimize queries * revert explore query * fix query * keep building pysqlite3 * single pass fetch and process * remove unnecessary re-embed * update deps * move SqliteVecQueueDatabase to db directory * make search thumbnail take up full size of results box * improve typing * improve model downloading and add status screen * daemon downloading thread * catch case when semantic search is disabled * fix typing * build sqlite_vec from source * resolve conflict * file permissions * try build deps * remove sources * sources * fix thread start * include git in build * reorder embeddings after detectors are started * build with sqlite amalgamation * non-platform specific * use wget instead of curl * remove unzip -d * remove sqlite_vec from requirements and load the compiled version * fix build * avoid race in db connection * add scale_factor and bias to description zscore normalization
This commit is contained in:
@@ -384,12 +384,12 @@ def vainfo():
|
||||
|
||||
@router.get("/logs/{service}", tags=[Tags.logs])
|
||||
def logs(
|
||||
service: str = Path(enum=["frigate", "nginx", "go2rtc", "chroma"]),
|
||||
service: str = Path(enum=["frigate", "nginx", "go2rtc"]),
|
||||
download: Optional[str] = None,
|
||||
start: Optional[int] = 0,
|
||||
end: Optional[int] = None,
|
||||
):
|
||||
"""Get logs for the requested service (frigate/nginx/go2rtc/chroma)"""
|
||||
"""Get logs for the requested service (frigate/nginx/go2rtc)"""
|
||||
|
||||
def download_logs(service_location: str):
|
||||
try:
|
||||
@@ -408,7 +408,6 @@ def logs(
|
||||
"frigate": "/dev/shm/logs/frigate/current",
|
||||
"go2rtc": "/dev/shm/logs/go2rtc/current",
|
||||
"nginx": "/dev/shm/logs/nginx/current",
|
||||
"chroma": "/dev/shm/logs/chroma/current",
|
||||
}
|
||||
service_location = log_locations.get(service)
|
||||
|
||||
|
||||
@@ -1,8 +1,6 @@
|
||||
"""Event apis."""
|
||||
|
||||
import base64
|
||||
import datetime
|
||||
import io
|
||||
import logging
|
||||
import os
|
||||
from functools import reduce
|
||||
@@ -10,12 +8,10 @@ from pathlib import Path
|
||||
from urllib.parse import unquote
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
from fastapi import APIRouter, Request
|
||||
from fastapi.params import Depends
|
||||
from fastapi.responses import JSONResponse
|
||||
from peewee import JOIN, DoesNotExist, fn, operator
|
||||
from PIL import Image
|
||||
from playhouse.shortcuts import model_to_dict
|
||||
|
||||
from frigate.api.defs.events_body import (
|
||||
@@ -39,7 +35,6 @@ from frigate.const import (
|
||||
CLIPS_DIR,
|
||||
)
|
||||
from frigate.embeddings import EmbeddingsContext
|
||||
from frigate.embeddings.embeddings import get_metadata
|
||||
from frigate.models import Event, ReviewSegment, Timeline
|
||||
from frigate.object_processing import TrackedObject
|
||||
from frigate.util.builtin import get_tz_modifiers
|
||||
@@ -411,16 +406,12 @@ def events_search(request: Request, params: EventsSearchQueryParams = Depends())
|
||||
event_filters = []
|
||||
|
||||
if cameras != "all":
|
||||
camera_list = cameras.split(",")
|
||||
event_filters.append((Event.camera << camera_list))
|
||||
event_filters.append((Event.camera << cameras.split(",")))
|
||||
|
||||
if labels != "all":
|
||||
label_list = labels.split(",")
|
||||
event_filters.append((Event.label << label_list))
|
||||
event_filters.append((Event.label << labels.split(",")))
|
||||
|
||||
if zones != "all":
|
||||
# use matching so events with multiple zones
|
||||
# still match on a search where any zone matches
|
||||
zone_clauses = []
|
||||
filtered_zones = zones.split(",")
|
||||
|
||||
@@ -431,8 +422,7 @@ def events_search(request: Request, params: EventsSearchQueryParams = Depends())
|
||||
for zone in filtered_zones:
|
||||
zone_clauses.append((Event.zones.cast("text") % f'*"{zone}"*'))
|
||||
|
||||
zone_clause = reduce(operator.or_, zone_clauses)
|
||||
event_filters.append((zone_clause))
|
||||
event_filters.append((reduce(operator.or_, zone_clauses)))
|
||||
|
||||
if after:
|
||||
event_filters.append((Event.start_time > after))
|
||||
@@ -441,13 +431,11 @@ def events_search(request: Request, params: EventsSearchQueryParams = Depends())
|
||||
event_filters.append((Event.start_time < before))
|
||||
|
||||
if time_range != DEFAULT_TIME_RANGE:
|
||||
# get timezone arg to ensure browser times are used
|
||||
tz_name = params.timezone
|
||||
hour_modifier, minute_modifier, _ = get_tz_modifiers(tz_name)
|
||||
|
||||
times = time_range.split(",")
|
||||
time_after = times[0]
|
||||
time_before = times[1]
|
||||
time_after, time_before = times
|
||||
|
||||
start_hour_fun = fn.strftime(
|
||||
"%H:%M",
|
||||
@@ -470,132 +458,113 @@ def events_search(request: Request, params: EventsSearchQueryParams = Depends())
|
||||
event_filters.append((start_hour_fun > time_after))
|
||||
event_filters.append((start_hour_fun < time_before))
|
||||
|
||||
if event_filters:
|
||||
filtered_event_ids = (
|
||||
Event.select(Event.id)
|
||||
.where(reduce(operator.and_, event_filters))
|
||||
.tuples()
|
||||
.iterator()
|
||||
)
|
||||
event_ids = [event_id[0] for event_id in filtered_event_ids]
|
||||
|
||||
if not event_ids:
|
||||
return JSONResponse(content=[]) # No events to search on
|
||||
else:
|
||||
event_ids = []
|
||||
|
||||
# Build the Chroma where clause based on the event IDs
|
||||
where = {"id": {"$in": event_ids}} if event_ids else {}
|
||||
|
||||
thumb_ids = {}
|
||||
desc_ids = {}
|
||||
|
||||
# Perform semantic search
|
||||
search_results = {}
|
||||
if search_type == "similarity":
|
||||
# Grab the ids of events that match the thumbnail image embeddings
|
||||
try:
|
||||
search_event: Event = Event.get(Event.id == event_id)
|
||||
except DoesNotExist:
|
||||
return JSONResponse(
|
||||
content=(
|
||||
{
|
||||
"success": False,
|
||||
"message": "Event not found",
|
||||
}
|
||||
),
|
||||
content={
|
||||
"success": False,
|
||||
"message": "Event not found",
|
||||
},
|
||||
status_code=404,
|
||||
)
|
||||
thumbnail = base64.b64decode(search_event.thumbnail)
|
||||
img = np.array(Image.open(io.BytesIO(thumbnail)).convert("RGB"))
|
||||
thumb_result = context.embeddings.thumbnail.query(
|
||||
query_images=[img],
|
||||
n_results=limit,
|
||||
where=where,
|
||||
)
|
||||
|
||||
thumb_result = context.embeddings.search_thumbnail(search_event)
|
||||
thumb_ids = dict(
|
||||
zip(
|
||||
thumb_result["ids"][0],
|
||||
context.thumb_stats.normalize(thumb_result["distances"][0]),
|
||||
[result[0] for result in thumb_result],
|
||||
context.thumb_stats.normalize([result[1] for result in thumb_result]),
|
||||
)
|
||||
)
|
||||
search_results = {
|
||||
event_id: {"distance": distance, "source": "thumbnail"}
|
||||
for event_id, distance in thumb_ids.items()
|
||||
}
|
||||
else:
|
||||
search_types = search_type.split(",")
|
||||
|
||||
if "thumbnail" in search_types:
|
||||
thumb_result = context.embeddings.thumbnail.query(
|
||||
query_texts=[query],
|
||||
n_results=limit,
|
||||
where=where,
|
||||
)
|
||||
# Do a rudimentary normalization of the difference in distances returned by CLIP and MiniLM.
|
||||
thumb_result = context.embeddings.search_thumbnail(query)
|
||||
thumb_ids = dict(
|
||||
zip(
|
||||
thumb_result["ids"][0],
|
||||
context.thumb_stats.normalize(thumb_result["distances"][0]),
|
||||
[result[0] for result in thumb_result],
|
||||
context.thumb_stats.normalize(
|
||||
[result[1] for result in thumb_result]
|
||||
),
|
||||
)
|
||||
)
|
||||
search_results.update(
|
||||
{
|
||||
event_id: {"distance": distance, "source": "thumbnail"}
|
||||
for event_id, distance in thumb_ids.items()
|
||||
}
|
||||
)
|
||||
|
||||
if "description" in search_types:
|
||||
desc_result = context.embeddings.description.query(
|
||||
query_texts=[query],
|
||||
n_results=limit,
|
||||
where=where,
|
||||
)
|
||||
desc_result = context.embeddings.search_description(query)
|
||||
desc_ids = dict(
|
||||
zip(
|
||||
desc_result["ids"][0],
|
||||
context.desc_stats.normalize(desc_result["distances"][0]),
|
||||
[result[0] for result in desc_result],
|
||||
context.desc_stats.normalize([result[1] for result in desc_result]),
|
||||
)
|
||||
)
|
||||
for event_id, distance in desc_ids.items():
|
||||
if (
|
||||
event_id not in search_results
|
||||
or distance < search_results[event_id]["distance"]
|
||||
):
|
||||
search_results[event_id] = {
|
||||
"distance": distance,
|
||||
"source": "description",
|
||||
}
|
||||
|
||||
results = {}
|
||||
for event_id in thumb_ids.keys() | desc_ids:
|
||||
min_distance = min(
|
||||
i
|
||||
for i in (thumb_ids.get(event_id), desc_ids.get(event_id))
|
||||
if i is not None
|
||||
)
|
||||
results[event_id] = {
|
||||
"distance": min_distance,
|
||||
"source": "thumbnail"
|
||||
if min_distance == thumb_ids.get(event_id)
|
||||
else "description",
|
||||
}
|
||||
|
||||
if not results:
|
||||
if not search_results:
|
||||
return JSONResponse(content=[])
|
||||
|
||||
# Get the event data
|
||||
events = (
|
||||
Event.select(*selected_columns)
|
||||
.join(
|
||||
ReviewSegment,
|
||||
JOIN.LEFT_OUTER,
|
||||
on=(fn.json_extract(ReviewSegment.data, "$.detections").contains(Event.id)),
|
||||
)
|
||||
.where(Event.id << list(results.keys()))
|
||||
.dicts()
|
||||
.iterator()
|
||||
# Fetch events in a single query
|
||||
events_query = Event.select(*selected_columns).join(
|
||||
ReviewSegment,
|
||||
JOIN.LEFT_OUTER,
|
||||
on=(fn.json_extract(ReviewSegment.data, "$.detections").contains(Event.id)),
|
||||
)
|
||||
events = list(events)
|
||||
|
||||
events = [
|
||||
{k: v for k, v in event.items() if k != "data"}
|
||||
| {
|
||||
"data": {
|
||||
k: v
|
||||
for k, v in event["data"].items()
|
||||
if k in ["type", "score", "top_score", "description"]
|
||||
}
|
||||
}
|
||||
| {
|
||||
"search_distance": results[event["id"]]["distance"],
|
||||
"search_source": results[event["id"]]["source"],
|
||||
}
|
||||
for event in events
|
||||
]
|
||||
events = sorted(events, key=lambda x: x["search_distance"])[:limit]
|
||||
# Apply filters, if any
|
||||
if event_filters:
|
||||
events_query = events_query.where(reduce(operator.and_, event_filters))
|
||||
|
||||
return JSONResponse(content=events)
|
||||
# If we did a similarity search, limit events to those in search_results
|
||||
if search_results:
|
||||
events_query = events_query.where(Event.id << list(search_results.keys()))
|
||||
|
||||
# Fetch events and process them in a single pass
|
||||
processed_events = []
|
||||
for event in events_query.dicts():
|
||||
processed_event = {k: v for k, v in event.items() if k != "data"}
|
||||
processed_event["data"] = {
|
||||
k: v
|
||||
for k, v in event["data"].items()
|
||||
if k in ["type", "score", "top_score", "description"]
|
||||
}
|
||||
|
||||
if event["id"] in search_results:
|
||||
processed_event["search_distance"] = search_results[event["id"]]["distance"]
|
||||
processed_event["search_source"] = search_results[event["id"]]["source"]
|
||||
|
||||
processed_events.append(processed_event)
|
||||
|
||||
# Sort by search distance if search_results are available, otherwise by start_time
|
||||
if search_results:
|
||||
processed_events.sort(key=lambda x: x.get("search_distance", float("inf")))
|
||||
else:
|
||||
processed_events.sort(key=lambda x: x["start_time"], reverse=True)
|
||||
|
||||
# Limit the number of events returned
|
||||
processed_events = processed_events[:limit]
|
||||
|
||||
return JSONResponse(content=processed_events)
|
||||
|
||||
|
||||
@router.get("/events/summary")
|
||||
@@ -975,10 +944,9 @@ def set_description(
|
||||
# If semantic search is enabled, update the index
|
||||
if request.app.frigate_config.semantic_search.enabled:
|
||||
context: EmbeddingsContext = request.app.embeddings
|
||||
context.embeddings.description.upsert(
|
||||
documents=[new_description],
|
||||
metadatas=[get_metadata(event)],
|
||||
ids=[event_id],
|
||||
context.embeddings.upsert_description(
|
||||
event_id=event_id,
|
||||
description=new_description,
|
||||
)
|
||||
|
||||
response_message = (
|
||||
@@ -1065,8 +1033,8 @@ def delete_event(request: Request, event_id: str):
|
||||
# If semantic search is enabled, update the index
|
||||
if request.app.frigate_config.semantic_search.enabled:
|
||||
context: EmbeddingsContext = request.app.embeddings
|
||||
context.embeddings.thumbnail.delete(ids=[event_id])
|
||||
context.embeddings.description.delete(ids=[event_id])
|
||||
context.embeddings.delete_thumbnail(id=[event_id])
|
||||
context.embeddings.delete_description(id=[event_id])
|
||||
return JSONResponse(
|
||||
content=({"success": True, "message": "Event " + event_id + " deleted"}),
|
||||
status_code=200,
|
||||
|
||||
@@ -12,7 +12,6 @@ import psutil
|
||||
import uvicorn
|
||||
from peewee_migrate import Router
|
||||
from playhouse.sqlite_ext import SqliteExtDatabase
|
||||
from playhouse.sqliteq import SqliteQueueDatabase
|
||||
|
||||
import frigate.util as util
|
||||
from frigate.api.auth import hash_password
|
||||
@@ -38,6 +37,7 @@ from frigate.const import (
|
||||
MODEL_CACHE_DIR,
|
||||
RECORD_DIR,
|
||||
)
|
||||
from frigate.db.sqlitevecq import SqliteVecQueueDatabase
|
||||
from frigate.embeddings import EmbeddingsContext, manage_embeddings
|
||||
from frigate.events.audio import AudioProcessor
|
||||
from frigate.events.cleanup import EventCleanup
|
||||
@@ -88,6 +88,7 @@ class FrigateApp:
|
||||
self.camera_metrics: dict[str, CameraMetrics] = {}
|
||||
self.ptz_metrics: dict[str, PTZMetrics] = {}
|
||||
self.processes: dict[str, int] = {}
|
||||
self.embeddings: Optional[EmbeddingsContext] = None
|
||||
self.region_grids: dict[str, list[list[dict[str, int]]]] = {}
|
||||
self.config = config
|
||||
|
||||
@@ -220,11 +221,8 @@ class FrigateApp:
|
||||
|
||||
def init_embeddings_manager(self) -> None:
|
||||
if not self.config.semantic_search.enabled:
|
||||
self.embeddings = None
|
||||
return
|
||||
|
||||
# Create a client for other processes to use
|
||||
self.embeddings = EmbeddingsContext()
|
||||
embedding_process = util.Process(
|
||||
target=manage_embeddings,
|
||||
name="embeddings_manager",
|
||||
@@ -239,7 +237,7 @@ class FrigateApp:
|
||||
def bind_database(self) -> None:
|
||||
"""Bind db to the main process."""
|
||||
# NOTE: all db accessing processes need to be created before the db can be bound to the main process
|
||||
self.db = SqliteQueueDatabase(
|
||||
self.db = SqliteVecQueueDatabase(
|
||||
self.config.database.path,
|
||||
pragmas={
|
||||
"auto_vacuum": "FULL", # Does not defragment database
|
||||
@@ -249,6 +247,7 @@ class FrigateApp:
|
||||
timeout=max(
|
||||
60, 10 * len([c for c in self.config.cameras.values() if c.enabled])
|
||||
),
|
||||
load_vec_extension=self.config.semantic_search.enabled,
|
||||
)
|
||||
models = [
|
||||
Event,
|
||||
@@ -274,6 +273,11 @@ class FrigateApp:
|
||||
|
||||
migrate_exports(self.config.ffmpeg, list(self.config.cameras.keys()))
|
||||
|
||||
def init_embeddings_client(self) -> None:
|
||||
if self.config.semantic_search.enabled:
|
||||
# Create a client for other processes to use
|
||||
self.embeddings = EmbeddingsContext(self.db)
|
||||
|
||||
def init_external_event_processor(self) -> None:
|
||||
self.external_event_processor = ExternalEventProcessor(self.config)
|
||||
|
||||
@@ -464,7 +468,7 @@ class FrigateApp:
|
||||
self.event_processor.start()
|
||||
|
||||
def start_event_cleanup(self) -> None:
|
||||
self.event_cleanup = EventCleanup(self.config, self.stop_event)
|
||||
self.event_cleanup = EventCleanup(self.config, self.stop_event, self.db)
|
||||
self.event_cleanup.start()
|
||||
|
||||
def start_record_cleanup(self) -> None:
|
||||
@@ -576,13 +580,14 @@ class FrigateApp:
|
||||
self.init_onvif()
|
||||
self.init_recording_manager()
|
||||
self.init_review_segment_manager()
|
||||
self.init_embeddings_manager()
|
||||
self.init_go2rtc()
|
||||
self.bind_database()
|
||||
self.check_db_data_migrations()
|
||||
self.init_inter_process_communicator()
|
||||
self.init_dispatcher()
|
||||
self.start_detectors()
|
||||
self.init_embeddings_manager()
|
||||
self.init_embeddings_client()
|
||||
self.start_video_output_processor()
|
||||
self.start_ptz_autotracker()
|
||||
self.init_historical_regions()
|
||||
|
||||
@@ -16,10 +16,12 @@ from frigate.const import (
|
||||
REQUEST_REGION_GRID,
|
||||
UPDATE_CAMERA_ACTIVITY,
|
||||
UPDATE_EVENT_DESCRIPTION,
|
||||
UPDATE_MODEL_STATE,
|
||||
UPSERT_REVIEW_SEGMENT,
|
||||
)
|
||||
from frigate.models import Event, Previews, Recordings, ReviewSegment
|
||||
from frigate.ptz.onvif import OnvifCommandEnum, OnvifController
|
||||
from frigate.types import ModelStatusTypesEnum
|
||||
from frigate.util.object import get_camera_regions_grid
|
||||
from frigate.util.services import restart_frigate
|
||||
|
||||
@@ -83,6 +85,7 @@ class Dispatcher:
|
||||
comm.subscribe(self._receive)
|
||||
|
||||
self.camera_activity = {}
|
||||
self.model_state = {}
|
||||
|
||||
def _receive(self, topic: str, payload: str) -> Optional[Any]:
|
||||
"""Handle receiving of payload from communicators."""
|
||||
@@ -144,6 +147,14 @@ class Dispatcher:
|
||||
"event_update",
|
||||
json.dumps({"id": event.id, "description": event.data["description"]}),
|
||||
)
|
||||
elif topic == UPDATE_MODEL_STATE:
|
||||
model = payload["model"]
|
||||
state = payload["state"]
|
||||
self.model_state[model] = ModelStatusTypesEnum[state]
|
||||
self.publish("model_state", json.dumps(self.model_state))
|
||||
elif topic == "modelState":
|
||||
model_state = self.model_state.copy()
|
||||
self.publish("model_state", json.dumps(model_state))
|
||||
elif topic == "onConnect":
|
||||
camera_status = self.camera_activity.copy()
|
||||
|
||||
|
||||
@@ -84,6 +84,7 @@ UPSERT_REVIEW_SEGMENT = "upsert_review_segment"
|
||||
CLEAR_ONGOING_REVIEW_SEGMENTS = "clear_ongoing_review_segments"
|
||||
UPDATE_CAMERA_ACTIVITY = "update_camera_activity"
|
||||
UPDATE_EVENT_DESCRIPTION = "update_event_description"
|
||||
UPDATE_MODEL_STATE = "update_model_state"
|
||||
|
||||
# Stats Values
|
||||
|
||||
|
||||
23
frigate/db/sqlitevecq.py
Normal file
23
frigate/db/sqlitevecq.py
Normal file
@@ -0,0 +1,23 @@
|
||||
import sqlite3
|
||||
|
||||
from playhouse.sqliteq import SqliteQueueDatabase
|
||||
|
||||
|
||||
class SqliteVecQueueDatabase(SqliteQueueDatabase):
|
||||
def __init__(self, *args, load_vec_extension: bool = False, **kwargs) -> None:
|
||||
self.load_vec_extension: bool = load_vec_extension
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
# no extension necessary, sqlite will load correctly for each platform
|
||||
self.sqlite_vec_path = "/usr/local/lib/vec0"
|
||||
|
||||
def _connect(self, *args, **kwargs) -> sqlite3.Connection:
|
||||
conn: sqlite3.Connection = super()._connect(*args, **kwargs)
|
||||
if self.load_vec_extension:
|
||||
self._load_vec_extension(conn)
|
||||
return conn
|
||||
|
||||
def _load_vec_extension(self, conn: sqlite3.Connection) -> None:
|
||||
conn.enable_load_extension(True)
|
||||
conn.load_extension(self.sqlite_vec_path)
|
||||
conn.enable_load_extension(False)
|
||||
@@ -1,18 +1,19 @@
|
||||
"""ChromaDB embeddings database."""
|
||||
"""SQLite-vec embeddings database."""
|
||||
|
||||
import json
|
||||
import logging
|
||||
import multiprocessing as mp
|
||||
import os
|
||||
import signal
|
||||
import threading
|
||||
from types import FrameType
|
||||
from typing import Optional
|
||||
|
||||
from playhouse.sqliteq import SqliteQueueDatabase
|
||||
from setproctitle import setproctitle
|
||||
|
||||
from frigate.config import FrigateConfig
|
||||
from frigate.const import CONFIG_DIR
|
||||
from frigate.db.sqlitevecq import SqliteVecQueueDatabase
|
||||
from frigate.models import Event
|
||||
from frigate.util.services import listen
|
||||
|
||||
@@ -41,7 +42,7 @@ def manage_embeddings(config: FrigateConfig) -> None:
|
||||
listen()
|
||||
|
||||
# Configure Frigate DB
|
||||
db = SqliteQueueDatabase(
|
||||
db = SqliteVecQueueDatabase(
|
||||
config.database.path,
|
||||
pragmas={
|
||||
"auto_vacuum": "FULL", # Does not defragment database
|
||||
@@ -49,17 +50,19 @@ def manage_embeddings(config: FrigateConfig) -> None:
|
||||
"synchronous": "NORMAL", # Safe when using WAL https://www.sqlite.org/pragma.html#pragma_synchronous
|
||||
},
|
||||
timeout=max(60, 10 * len([c for c in config.cameras.values() if c.enabled])),
|
||||
load_vec_extension=True,
|
||||
)
|
||||
models = [Event]
|
||||
db.bind(models)
|
||||
|
||||
embeddings = Embeddings()
|
||||
embeddings = Embeddings(db)
|
||||
|
||||
# Check if we need to re-index events
|
||||
if config.semantic_search.reindex:
|
||||
embeddings.reindex()
|
||||
|
||||
maintainer = EmbeddingMaintainer(
|
||||
db,
|
||||
config,
|
||||
stop_event,
|
||||
)
|
||||
@@ -67,14 +70,14 @@ def manage_embeddings(config: FrigateConfig) -> None:
|
||||
|
||||
|
||||
class EmbeddingsContext:
|
||||
def __init__(self):
|
||||
self.embeddings = Embeddings()
|
||||
def __init__(self, db: SqliteVecQueueDatabase):
|
||||
self.embeddings = Embeddings(db)
|
||||
self.thumb_stats = ZScoreNormalization()
|
||||
self.desc_stats = ZScoreNormalization()
|
||||
self.desc_stats = ZScoreNormalization(scale_factor=2.5, bias=0.5)
|
||||
|
||||
# load stats from disk
|
||||
try:
|
||||
with open(f"{CONFIG_DIR}/.search_stats.json", "r") as f:
|
||||
with open(os.path.join(CONFIG_DIR, ".search_stats.json"), "r") as f:
|
||||
data = json.loads(f.read())
|
||||
self.thumb_stats.from_dict(data["thumb_stats"])
|
||||
self.desc_stats.from_dict(data["desc_stats"])
|
||||
@@ -87,5 +90,5 @@ class EmbeddingsContext:
|
||||
"thumb_stats": self.thumb_stats.to_dict(),
|
||||
"desc_stats": self.desc_stats.to_dict(),
|
||||
}
|
||||
with open(f"{CONFIG_DIR}/.search_stats.json", "w") as f:
|
||||
f.write(json.dumps(contents))
|
||||
with open(os.path.join(CONFIG_DIR, ".search_stats.json"), "w") as f:
|
||||
json.dump(contents, f)
|
||||
|
||||
@@ -1,37 +1,23 @@
|
||||
"""ChromaDB embeddings database."""
|
||||
"""SQLite-vec embeddings database."""
|
||||
|
||||
import base64
|
||||
import io
|
||||
import logging
|
||||
import sys
|
||||
import struct
|
||||
import time
|
||||
from typing import List, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
from playhouse.shortcuts import model_to_dict
|
||||
|
||||
from frigate.comms.inter_process import InterProcessRequestor
|
||||
from frigate.const import UPDATE_MODEL_STATE
|
||||
from frigate.db.sqlitevecq import SqliteVecQueueDatabase
|
||||
from frigate.models import Event
|
||||
from frigate.types import ModelStatusTypesEnum
|
||||
|
||||
# Squelch posthog logging
|
||||
logging.getLogger("chromadb.telemetry.product.posthog").setLevel(logging.CRITICAL)
|
||||
|
||||
# Hot-swap the sqlite3 module for Chroma compatibility
|
||||
try:
|
||||
from chromadb import Collection
|
||||
from chromadb import HttpClient as ChromaClient
|
||||
from chromadb.config import Settings
|
||||
|
||||
from .functions.clip import ClipEmbedding
|
||||
from .functions.minilm_l6_v2 import MiniLMEmbedding
|
||||
except RuntimeError:
|
||||
__import__("pysqlite3")
|
||||
sys.modules["sqlite3"] = sys.modules.pop("pysqlite3")
|
||||
from chromadb import Collection
|
||||
from chromadb import HttpClient as ChromaClient
|
||||
from chromadb.config import Settings
|
||||
|
||||
from .functions.clip import ClipEmbedding
|
||||
from .functions.minilm_l6_v2 import MiniLMEmbedding
|
||||
from .functions.clip import ClipEmbedding
|
||||
from .functions.minilm_l6_v2 import MiniLMEmbedding
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -67,34 +53,198 @@ def get_metadata(event: Event) -> dict:
|
||||
)
|
||||
|
||||
|
||||
def serialize(vector: List[float]) -> bytes:
|
||||
"""Serializes a list of floats into a compact "raw bytes" format"""
|
||||
return struct.pack("%sf" % len(vector), *vector)
|
||||
|
||||
|
||||
def deserialize(bytes_data: bytes) -> List[float]:
|
||||
"""Deserializes a compact "raw bytes" format into a list of floats"""
|
||||
return list(struct.unpack("%sf" % (len(bytes_data) // 4), bytes_data))
|
||||
|
||||
|
||||
class Embeddings:
|
||||
"""ChromaDB embeddings database."""
|
||||
"""SQLite-vec embeddings database."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.client: ChromaClient = ChromaClient(
|
||||
host="127.0.0.1",
|
||||
settings=Settings(anonymized_telemetry=False),
|
||||
def __init__(self, db: SqliteVecQueueDatabase) -> None:
|
||||
self.db = db
|
||||
self.requestor = InterProcessRequestor()
|
||||
|
||||
# Create tables if they don't exist
|
||||
self._create_tables()
|
||||
|
||||
models = [
|
||||
"sentence-transformers/all-MiniLM-L6-v2-model.onnx",
|
||||
"sentence-transformers/all-MiniLM-L6-v2-tokenizer",
|
||||
"clip-clip_image_model_vitb32.onnx",
|
||||
"clip-clip_text_model_vitb32.onnx",
|
||||
]
|
||||
|
||||
for model in models:
|
||||
self.requestor.send_data(
|
||||
UPDATE_MODEL_STATE,
|
||||
{
|
||||
"model": model,
|
||||
"state": ModelStatusTypesEnum.not_downloaded,
|
||||
},
|
||||
)
|
||||
|
||||
self.clip_embedding = ClipEmbedding(
|
||||
preferred_providers=["CPUExecutionProvider"]
|
||||
)
|
||||
self.minilm_embedding = MiniLMEmbedding(
|
||||
preferred_providers=["CPUExecutionProvider"],
|
||||
)
|
||||
|
||||
@property
|
||||
def thumbnail(self) -> Collection:
|
||||
return self.client.get_or_create_collection(
|
||||
name="event_thumbnail", embedding_function=ClipEmbedding()
|
||||
def _create_tables(self):
|
||||
# Create vec0 virtual table for thumbnail embeddings
|
||||
self.db.execute_sql("""
|
||||
CREATE VIRTUAL TABLE IF NOT EXISTS vec_thumbnails USING vec0(
|
||||
id TEXT PRIMARY KEY,
|
||||
thumbnail_embedding FLOAT[512]
|
||||
);
|
||||
""")
|
||||
|
||||
# Create vec0 virtual table for description embeddings
|
||||
self.db.execute_sql("""
|
||||
CREATE VIRTUAL TABLE IF NOT EXISTS vec_descriptions USING vec0(
|
||||
id TEXT PRIMARY KEY,
|
||||
description_embedding FLOAT[384]
|
||||
);
|
||||
""")
|
||||
|
||||
def upsert_thumbnail(self, event_id: str, thumbnail: bytes):
|
||||
# Convert thumbnail bytes to PIL Image
|
||||
image = Image.open(io.BytesIO(thumbnail)).convert("RGB")
|
||||
# Generate embedding using CLIP
|
||||
embedding = self.clip_embedding([image])[0]
|
||||
|
||||
self.db.execute_sql(
|
||||
"""
|
||||
INSERT OR REPLACE INTO vec_thumbnails(id, thumbnail_embedding)
|
||||
VALUES(?, ?)
|
||||
""",
|
||||
(event_id, serialize(embedding)),
|
||||
)
|
||||
|
||||
@property
|
||||
def description(self) -> Collection:
|
||||
return self.client.get_or_create_collection(
|
||||
name="event_description",
|
||||
embedding_function=MiniLMEmbedding(
|
||||
preferred_providers=["CPUExecutionProvider"]
|
||||
),
|
||||
return embedding
|
||||
|
||||
def upsert_description(self, event_id: str, description: str):
|
||||
# Generate embedding using MiniLM
|
||||
embedding = self.minilm_embedding([description])[0]
|
||||
|
||||
self.db.execute_sql(
|
||||
"""
|
||||
INSERT OR REPLACE INTO vec_descriptions(id, description_embedding)
|
||||
VALUES(?, ?)
|
||||
""",
|
||||
(event_id, serialize(embedding)),
|
||||
)
|
||||
|
||||
return embedding
|
||||
|
||||
def delete_thumbnail(self, event_ids: List[str]) -> None:
|
||||
ids = ",".join(["?" for _ in event_ids])
|
||||
self.db.execute_sql(
|
||||
f"DELETE FROM vec_thumbnails WHERE id IN ({ids})", event_ids
|
||||
)
|
||||
|
||||
def delete_description(self, event_ids: List[str]) -> None:
|
||||
ids = ",".join(["?" for _ in event_ids])
|
||||
self.db.execute_sql(
|
||||
f"DELETE FROM vec_descriptions WHERE id IN ({ids})", event_ids
|
||||
)
|
||||
|
||||
def search_thumbnail(
|
||||
self, query: Union[Event, str], event_ids: List[str] = None
|
||||
) -> List[Tuple[str, float]]:
|
||||
if query.__class__ == Event:
|
||||
cursor = self.db.execute_sql(
|
||||
"""
|
||||
SELECT thumbnail_embedding FROM vec_thumbnails WHERE id = ?
|
||||
""",
|
||||
[query.id],
|
||||
)
|
||||
|
||||
row = cursor.fetchone() if cursor else None
|
||||
|
||||
if row:
|
||||
query_embedding = deserialize(
|
||||
row[0]
|
||||
) # Deserialize the thumbnail embedding
|
||||
else:
|
||||
# If no embedding found, generate it and return it
|
||||
thumbnail = base64.b64decode(query.thumbnail)
|
||||
query_embedding = self.upsert_thumbnail(query.id, thumbnail)
|
||||
else:
|
||||
query_embedding = self.clip_embedding([query])[0]
|
||||
|
||||
sql_query = """
|
||||
SELECT
|
||||
id,
|
||||
distance
|
||||
FROM vec_thumbnails
|
||||
WHERE thumbnail_embedding MATCH ?
|
||||
AND k = 100
|
||||
"""
|
||||
|
||||
# Add the IN clause if event_ids is provided and not empty
|
||||
# this is the only filter supported by sqlite-vec as of 0.1.3
|
||||
# but it seems to be broken in this version
|
||||
if event_ids:
|
||||
sql_query += " AND id IN ({})".format(",".join("?" * len(event_ids)))
|
||||
|
||||
# order by distance DESC is not implemented in this version of sqlite-vec
|
||||
# when it's implemented, we can use cosine similarity
|
||||
sql_query += " ORDER BY distance"
|
||||
|
||||
parameters = (
|
||||
[serialize(query_embedding)] + event_ids
|
||||
if event_ids
|
||||
else [serialize(query_embedding)]
|
||||
)
|
||||
|
||||
results = self.db.execute_sql(sql_query, parameters).fetchall()
|
||||
|
||||
return results
|
||||
|
||||
def search_description(
|
||||
self, query_text: str, event_ids: List[str] = None
|
||||
) -> List[Tuple[str, float]]:
|
||||
query_embedding = self.minilm_embedding([query_text])[0]
|
||||
|
||||
# Prepare the base SQL query
|
||||
sql_query = """
|
||||
SELECT
|
||||
id,
|
||||
distance
|
||||
FROM vec_descriptions
|
||||
WHERE description_embedding MATCH ?
|
||||
AND k = 100
|
||||
"""
|
||||
|
||||
# Add the IN clause if event_ids is provided and not empty
|
||||
# this is the only filter supported by sqlite-vec as of 0.1.3
|
||||
# but it seems to be broken in this version
|
||||
if event_ids:
|
||||
sql_query += " AND id IN ({})".format(",".join("?" * len(event_ids)))
|
||||
|
||||
# order by distance DESC is not implemented in this version of sqlite-vec
|
||||
# when it's implemented, we can use cosine similarity
|
||||
sql_query += " ORDER BY distance"
|
||||
|
||||
parameters = (
|
||||
[serialize(query_embedding)] + event_ids
|
||||
if event_ids
|
||||
else [serialize(query_embedding)]
|
||||
)
|
||||
|
||||
results = self.db.execute_sql(sql_query, parameters).fetchall()
|
||||
|
||||
return results
|
||||
|
||||
def reindex(self) -> None:
|
||||
"""Reindex all event embeddings."""
|
||||
logger.info("Indexing event embeddings...")
|
||||
self.client.reset()
|
||||
|
||||
st = time.time()
|
||||
totals = {
|
||||
@@ -115,37 +265,14 @@ class Embeddings:
|
||||
)
|
||||
|
||||
while len(events) > 0:
|
||||
thumbnails = {"ids": [], "images": [], "metadatas": []}
|
||||
descriptions = {"ids": [], "documents": [], "metadatas": []}
|
||||
|
||||
event: Event
|
||||
for event in events:
|
||||
metadata = get_metadata(event)
|
||||
thumbnail = base64.b64decode(event.thumbnail)
|
||||
img = np.array(Image.open(io.BytesIO(thumbnail)).convert("RGB"))
|
||||
thumbnails["ids"].append(event.id)
|
||||
thumbnails["images"].append(img)
|
||||
thumbnails["metadatas"].append(metadata)
|
||||
self.upsert_thumbnail(event.id, thumbnail)
|
||||
totals["thumb"] += 1
|
||||
if description := event.data.get("description", "").strip():
|
||||
descriptions["ids"].append(event.id)
|
||||
descriptions["documents"].append(description)
|
||||
descriptions["metadatas"].append(metadata)
|
||||
|
||||
if len(thumbnails["ids"]) > 0:
|
||||
totals["thumb"] += len(thumbnails["ids"])
|
||||
self.thumbnail.upsert(
|
||||
images=thumbnails["images"],
|
||||
metadatas=thumbnails["metadatas"],
|
||||
ids=thumbnails["ids"],
|
||||
)
|
||||
|
||||
if len(descriptions["ids"]) > 0:
|
||||
totals["desc"] += len(descriptions["ids"])
|
||||
self.description.upsert(
|
||||
documents=descriptions["documents"],
|
||||
metadatas=descriptions["metadatas"],
|
||||
ids=descriptions["ids"],
|
||||
)
|
||||
totals["desc"] += 1
|
||||
self.upsert_description(event.id, description)
|
||||
|
||||
current_page += 1
|
||||
events = (
|
||||
|
||||
@@ -1,35 +1,59 @@
|
||||
"""CLIP Embeddings for Frigate."""
|
||||
|
||||
import errno
|
||||
import logging
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Tuple, Union
|
||||
from typing import List, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
import onnxruntime as ort
|
||||
import requests
|
||||
from chromadb import EmbeddingFunction, Embeddings
|
||||
from chromadb.api.types import (
|
||||
Documents,
|
||||
Images,
|
||||
is_document,
|
||||
is_image,
|
||||
)
|
||||
from onnx_clip import OnnxClip
|
||||
from onnx_clip import OnnxClip, Preprocessor, Tokenizer
|
||||
from PIL import Image
|
||||
|
||||
from frigate.const import MODEL_CACHE_DIR
|
||||
from frigate.const import MODEL_CACHE_DIR, UPDATE_MODEL_STATE
|
||||
from frigate.types import ModelStatusTypesEnum
|
||||
from frigate.util.downloader import ModelDownloader
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class Clip(OnnxClip):
|
||||
"""Override load models to download to cache directory."""
|
||||
"""Override load models to use pre-downloaded models from cache directory."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model: str = "ViT-B/32",
|
||||
batch_size: Optional[int] = None,
|
||||
providers: List[str] = ["CPUExecutionProvider"],
|
||||
):
|
||||
"""
|
||||
Instantiates the model and required encoding classes.
|
||||
|
||||
Args:
|
||||
model: The model to utilize. Currently ViT-B/32 and RN50 are
|
||||
allowed.
|
||||
batch_size: If set, splits the lists in `get_image_embeddings`
|
||||
and `get_text_embeddings` into batches of this size before
|
||||
passing them to the model. The embeddings are then concatenated
|
||||
back together before being returned. This is necessary when
|
||||
passing large amounts of data (perhaps ~100 or more).
|
||||
"""
|
||||
allowed_models = ["ViT-B/32", "RN50"]
|
||||
if model not in allowed_models:
|
||||
raise ValueError(f"`model` must be in {allowed_models}. Got {model}.")
|
||||
if model == "ViT-B/32":
|
||||
self.embedding_size = 512
|
||||
elif model == "RN50":
|
||||
self.embedding_size = 1024
|
||||
self.image_model, self.text_model = self._load_models(model, providers)
|
||||
self._tokenizer = Tokenizer()
|
||||
self._preprocessor = Preprocessor()
|
||||
self._batch_size = batch_size
|
||||
|
||||
@staticmethod
|
||||
def _load_models(
|
||||
model: str,
|
||||
silent: bool,
|
||||
) -> Tuple[ort.InferenceSession, ort.InferenceSession]:
|
||||
providers: List[str],
|
||||
) -> tuple[ort.InferenceSession, ort.InferenceSession]:
|
||||
"""
|
||||
These models are a part of the container. Treat as as such.
|
||||
Load models from cache directory.
|
||||
"""
|
||||
if model == "ViT-B/32":
|
||||
IMAGE_MODEL_FILE = "clip_image_model_vitb32.onnx"
|
||||
@@ -43,64 +67,100 @@ class Clip(OnnxClip):
|
||||
models = []
|
||||
for model_file in [IMAGE_MODEL_FILE, TEXT_MODEL_FILE]:
|
||||
path = os.path.join(MODEL_CACHE_DIR, "clip", model_file)
|
||||
models.append(Clip._load_model(path, silent))
|
||||
models.append(Clip._load_model(path, providers))
|
||||
|
||||
return models[0], models[1]
|
||||
|
||||
@staticmethod
|
||||
def _load_model(path: str, silent: bool):
|
||||
providers = ["CPUExecutionProvider"]
|
||||
|
||||
try:
|
||||
if os.path.exists(path):
|
||||
return ort.InferenceSession(path, providers=providers)
|
||||
else:
|
||||
raise FileNotFoundError(
|
||||
errno.ENOENT,
|
||||
os.strerror(errno.ENOENT),
|
||||
path,
|
||||
)
|
||||
except Exception:
|
||||
s3_url = f"https://lakera-clip.s3.eu-west-1.amazonaws.com/{os.path.basename(path)}"
|
||||
if not silent:
|
||||
logging.info(
|
||||
f"The model file ({path}) doesn't exist "
|
||||
f"or it is invalid. Downloading it from the public S3 "
|
||||
f"bucket: {s3_url}." # noqa: E501
|
||||
)
|
||||
|
||||
# Download from S3
|
||||
# Saving to a temporary file first to avoid corrupting the file
|
||||
temporary_filename = Path(path).with_name(os.path.basename(path) + ".part")
|
||||
|
||||
# Create any missing directories in the path
|
||||
temporary_filename.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
with requests.get(s3_url, stream=True) as r:
|
||||
r.raise_for_status()
|
||||
with open(temporary_filename, "wb") as f:
|
||||
for chunk in r.iter_content(chunk_size=8192):
|
||||
f.write(chunk)
|
||||
f.flush()
|
||||
# Finally move the temporary file to the correct location
|
||||
temporary_filename.rename(path)
|
||||
def _load_model(path: str, providers: List[str]):
|
||||
if os.path.exists(path):
|
||||
return ort.InferenceSession(path, providers=providers)
|
||||
else:
|
||||
logger.warning(f"CLIP model file {path} not found.")
|
||||
return None
|
||||
|
||||
|
||||
class ClipEmbedding(EmbeddingFunction):
|
||||
"""Embedding function for CLIP model used in Chroma."""
|
||||
class ClipEmbedding:
|
||||
"""Embedding function for CLIP model."""
|
||||
|
||||
def __init__(self, model: str = "ViT-B/32"):
|
||||
"""Initialize CLIP Embedding function."""
|
||||
self.model = Clip(model)
|
||||
def __init__(
|
||||
self,
|
||||
model: str = "ViT-B/32",
|
||||
silent: bool = False,
|
||||
preferred_providers: List[str] = ["CPUExecutionProvider"],
|
||||
):
|
||||
self.model_name = model
|
||||
self.silent = silent
|
||||
self.preferred_providers = preferred_providers
|
||||
self.model_files = self._get_model_files()
|
||||
self.model = None
|
||||
|
||||
def __call__(self, input: Union[Documents, Images]) -> Embeddings:
|
||||
embeddings: Embeddings = []
|
||||
self.downloader = ModelDownloader(
|
||||
model_name="clip",
|
||||
download_path=os.path.join(MODEL_CACHE_DIR, "clip"),
|
||||
file_names=self.model_files,
|
||||
download_func=self._download_model,
|
||||
silent=self.silent,
|
||||
)
|
||||
self.downloader.ensure_model_files()
|
||||
|
||||
def _get_model_files(self):
|
||||
if self.model_name == "ViT-B/32":
|
||||
return ["clip_image_model_vitb32.onnx", "clip_text_model_vitb32.onnx"]
|
||||
elif self.model_name == "RN50":
|
||||
return ["clip_image_model_rn50.onnx", "clip_text_model_rn50.onnx"]
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unexpected model {self.model_name}. No `.onnx` file found."
|
||||
)
|
||||
|
||||
def _download_model(self, path: str):
|
||||
s3_url = (
|
||||
f"https://lakera-clip.s3.eu-west-1.amazonaws.com/{os.path.basename(path)}"
|
||||
)
|
||||
try:
|
||||
ModelDownloader.download_from_url(s3_url, path, self.silent)
|
||||
self.downloader.requestor.send_data(
|
||||
UPDATE_MODEL_STATE,
|
||||
{
|
||||
"model": f"{self.model_name}-{os.path.basename(path)}",
|
||||
"state": ModelStatusTypesEnum.downloaded,
|
||||
},
|
||||
)
|
||||
except Exception:
|
||||
self.downloader.requestor.send_data(
|
||||
UPDATE_MODEL_STATE,
|
||||
{
|
||||
"model": f"{self.model_name}-{os.path.basename(path)}",
|
||||
"state": ModelStatusTypesEnum.error,
|
||||
},
|
||||
)
|
||||
|
||||
def _load_model(self):
|
||||
if self.model is None:
|
||||
self.downloader.wait_for_download()
|
||||
self.model = Clip(self.model_name, providers=self.preferred_providers)
|
||||
|
||||
def __call__(self, input: Union[List[str], List[Image.Image]]) -> List[np.ndarray]:
|
||||
self._load_model()
|
||||
if (
|
||||
self.model is None
|
||||
or self.model.image_model is None
|
||||
or self.model.text_model is None
|
||||
):
|
||||
logger.info(
|
||||
"CLIP model is not fully loaded. Please wait for the download to complete."
|
||||
)
|
||||
return []
|
||||
|
||||
embeddings = []
|
||||
for item in input:
|
||||
if is_image(item):
|
||||
if isinstance(item, Image.Image):
|
||||
result = self.model.get_image_embeddings([item])
|
||||
embeddings.append(result[0, :].tolist())
|
||||
elif is_document(item):
|
||||
embeddings.append(result[0])
|
||||
elif isinstance(item, str):
|
||||
result = self.model.get_text_embeddings([item])
|
||||
embeddings.append(result[0, :].tolist())
|
||||
embeddings.append(result[0])
|
||||
else:
|
||||
raise ValueError(f"Unsupported input type: {type(item)}")
|
||||
return embeddings
|
||||
|
||||
@@ -1,11 +1,107 @@
|
||||
"""Embedding function for ONNX MiniLM-L6 model used in Chroma."""
|
||||
import logging
|
||||
import os
|
||||
from typing import List
|
||||
|
||||
from chromadb.utils.embedding_functions import ONNXMiniLM_L6_V2
|
||||
import numpy as np
|
||||
import onnxruntime as ort
|
||||
|
||||
from frigate.const import MODEL_CACHE_DIR
|
||||
# importing this without pytorch or others causes a warning
|
||||
# https://github.com/huggingface/transformers/issues/27214
|
||||
# suppressed by setting env TRANSFORMERS_NO_ADVISORY_WARNINGS=1
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
from frigate.const import MODEL_CACHE_DIR, UPDATE_MODEL_STATE
|
||||
from frigate.types import ModelStatusTypesEnum
|
||||
from frigate.util.downloader import ModelDownloader
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class MiniLMEmbedding(ONNXMiniLM_L6_V2):
|
||||
"""Override DOWNLOAD_PATH to download to cache directory."""
|
||||
class MiniLMEmbedding:
|
||||
"""Embedding function for ONNX MiniLM-L6 model."""
|
||||
|
||||
DOWNLOAD_PATH = f"{MODEL_CACHE_DIR}/all-MiniLM-L6-v2"
|
||||
MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"
|
||||
IMAGE_MODEL_FILE = "model.onnx"
|
||||
TOKENIZER_FILE = "tokenizer"
|
||||
|
||||
def __init__(self, preferred_providers=["CPUExecutionProvider"]):
|
||||
self.preferred_providers = preferred_providers
|
||||
self.tokenizer = None
|
||||
self.session = None
|
||||
|
||||
self.downloader = ModelDownloader(
|
||||
model_name=self.MODEL_NAME,
|
||||
download_path=self.DOWNLOAD_PATH,
|
||||
file_names=[self.IMAGE_MODEL_FILE, self.TOKENIZER_FILE],
|
||||
download_func=self._download_model,
|
||||
)
|
||||
self.downloader.ensure_model_files()
|
||||
|
||||
def _download_model(self, path: str):
|
||||
try:
|
||||
if os.path.basename(path) == self.IMAGE_MODEL_FILE:
|
||||
s3_url = f"https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2/resolve/main/onnx/{self.IMAGE_MODEL_FILE}"
|
||||
ModelDownloader.download_from_url(s3_url, path)
|
||||
elif os.path.basename(path) == self.TOKENIZER_FILE:
|
||||
logger.info("Downloading MiniLM tokenizer")
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
self.MODEL_NAME, clean_up_tokenization_spaces=True
|
||||
)
|
||||
tokenizer.save_pretrained(path)
|
||||
|
||||
self.downloader.requestor.send_data(
|
||||
UPDATE_MODEL_STATE,
|
||||
{
|
||||
"model": f"{self.MODEL_NAME}-{os.path.basename(path)}",
|
||||
"state": ModelStatusTypesEnum.downloaded,
|
||||
},
|
||||
)
|
||||
except Exception:
|
||||
self.downloader.requestor.send_data(
|
||||
UPDATE_MODEL_STATE,
|
||||
{
|
||||
"model": f"{self.MODEL_NAME}-{os.path.basename(path)}",
|
||||
"state": ModelStatusTypesEnum.error,
|
||||
},
|
||||
)
|
||||
|
||||
def _load_model_and_tokenizer(self):
|
||||
if self.tokenizer is None or self.session is None:
|
||||
self.downloader.wait_for_download()
|
||||
self.tokenizer = self._load_tokenizer()
|
||||
self.session = self._load_model(
|
||||
os.path.join(self.DOWNLOAD_PATH, self.IMAGE_MODEL_FILE),
|
||||
self.preferred_providers,
|
||||
)
|
||||
|
||||
def _load_tokenizer(self):
|
||||
tokenizer_path = os.path.join(self.DOWNLOAD_PATH, self.TOKENIZER_FILE)
|
||||
return AutoTokenizer.from_pretrained(
|
||||
tokenizer_path, clean_up_tokenization_spaces=True
|
||||
)
|
||||
|
||||
def _load_model(self, path: str, providers: List[str]):
|
||||
if os.path.exists(path):
|
||||
return ort.InferenceSession(path, providers=providers)
|
||||
else:
|
||||
logger.warning(f"MiniLM model file {path} not found.")
|
||||
return None
|
||||
|
||||
def __call__(self, texts: List[str]) -> List[np.ndarray]:
|
||||
self._load_model_and_tokenizer()
|
||||
|
||||
if self.session is None or self.tokenizer is None:
|
||||
logger.error("MiniLM model or tokenizer is not loaded.")
|
||||
return []
|
||||
|
||||
inputs = self.tokenizer(
|
||||
texts, padding=True, truncation=True, return_tensors="np"
|
||||
)
|
||||
input_names = [input.name for input in self.session.get_inputs()]
|
||||
onnx_inputs = {name: inputs[name] for name in input_names if name in inputs}
|
||||
|
||||
outputs = self.session.run(None, onnx_inputs)
|
||||
embeddings = outputs[0].mean(axis=1)
|
||||
|
||||
return [embedding for embedding in embeddings]
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
"""Maintain embeddings in Chroma."""
|
||||
"""Maintain embeddings in SQLite-vec."""
|
||||
|
||||
import base64
|
||||
import io
|
||||
import logging
|
||||
import os
|
||||
import threading
|
||||
@@ -11,7 +10,7 @@ from typing import Optional
|
||||
import cv2
|
||||
import numpy as np
|
||||
from peewee import DoesNotExist
|
||||
from PIL import Image
|
||||
from playhouse.sqliteq import SqliteQueueDatabase
|
||||
|
||||
from frigate.comms.event_metadata_updater import (
|
||||
EventMetadataSubscriber,
|
||||
@@ -26,7 +25,7 @@ from frigate.genai import get_genai_client
|
||||
from frigate.models import Event
|
||||
from frigate.util.image import SharedMemoryFrameManager, calculate_region
|
||||
|
||||
from .embeddings import Embeddings, get_metadata
|
||||
from .embeddings import Embeddings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -36,13 +35,14 @@ class EmbeddingMaintainer(threading.Thread):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
db: SqliteQueueDatabase,
|
||||
config: FrigateConfig,
|
||||
stop_event: MpEvent,
|
||||
) -> None:
|
||||
threading.Thread.__init__(self)
|
||||
self.name = "embeddings_maintainer"
|
||||
self.config = config
|
||||
self.embeddings = Embeddings()
|
||||
self.embeddings = Embeddings(db)
|
||||
self.event_subscriber = EventUpdateSubscriber()
|
||||
self.event_end_subscriber = EventEndSubscriber()
|
||||
self.event_metadata_subscriber = EventMetadataSubscriber(
|
||||
@@ -56,7 +56,7 @@ class EmbeddingMaintainer(threading.Thread):
|
||||
self.genai_client = get_genai_client(config.genai)
|
||||
|
||||
def run(self) -> None:
|
||||
"""Maintain a Chroma vector database for semantic search."""
|
||||
"""Maintain a SQLite-vec database for semantic search."""
|
||||
while not self.stop_event.is_set():
|
||||
self._process_updates()
|
||||
self._process_finalized()
|
||||
@@ -117,12 +117,11 @@ class EmbeddingMaintainer(threading.Thread):
|
||||
if event.data.get("type") != "object":
|
||||
continue
|
||||
|
||||
# Extract valid event metadata
|
||||
metadata = get_metadata(event)
|
||||
# Extract valid thumbnail
|
||||
thumbnail = base64.b64decode(event.thumbnail)
|
||||
|
||||
# Embed the thumbnail
|
||||
self._embed_thumbnail(event_id, thumbnail, metadata)
|
||||
self._embed_thumbnail(event_id, thumbnail)
|
||||
|
||||
if (
|
||||
camera_config.genai.enabled
|
||||
@@ -183,7 +182,6 @@ class EmbeddingMaintainer(threading.Thread):
|
||||
args=(
|
||||
event,
|
||||
embed_image,
|
||||
metadata,
|
||||
),
|
||||
).start()
|
||||
|
||||
@@ -219,25 +217,16 @@ class EmbeddingMaintainer(threading.Thread):
|
||||
|
||||
return None
|
||||
|
||||
def _embed_thumbnail(self, event_id: str, thumbnail: bytes, metadata: dict) -> None:
|
||||
def _embed_thumbnail(self, event_id: str, thumbnail: bytes) -> None:
|
||||
"""Embed the thumbnail for an event."""
|
||||
self.embeddings.upsert_thumbnail(event_id, thumbnail)
|
||||
|
||||
# Encode the thumbnail
|
||||
img = np.array(Image.open(io.BytesIO(thumbnail)).convert("RGB"))
|
||||
self.embeddings.thumbnail.upsert(
|
||||
images=[img],
|
||||
metadatas=[metadata],
|
||||
ids=[event_id],
|
||||
)
|
||||
|
||||
def _embed_description(
|
||||
self, event: Event, thumbnails: list[bytes], metadata: dict
|
||||
) -> None:
|
||||
def _embed_description(self, event: Event, thumbnails: list[bytes]) -> None:
|
||||
"""Embed the description for an event."""
|
||||
camera_config = self.config.cameras[event.camera]
|
||||
|
||||
description = self.genai_client.generate_description(
|
||||
camera_config, thumbnails, metadata
|
||||
camera_config, thumbnails, event.label
|
||||
)
|
||||
|
||||
if not description:
|
||||
@@ -251,11 +240,7 @@ class EmbeddingMaintainer(threading.Thread):
|
||||
)
|
||||
|
||||
# Encode the description
|
||||
self.embeddings.description.upsert(
|
||||
documents=[description],
|
||||
metadatas=[metadata],
|
||||
ids=[event.id],
|
||||
)
|
||||
self.embeddings.upsert_description(event.id, description)
|
||||
|
||||
logger.debug(
|
||||
"Generated description for %s (%d images): %s",
|
||||
@@ -276,7 +261,6 @@ class EmbeddingMaintainer(threading.Thread):
|
||||
logger.error(f"GenAI not enabled for camera {event.camera}")
|
||||
return
|
||||
|
||||
metadata = get_metadata(event)
|
||||
thumbnail = base64.b64decode(event.thumbnail)
|
||||
|
||||
logger.debug(
|
||||
@@ -315,4 +299,4 @@ class EmbeddingMaintainer(threading.Thread):
|
||||
)
|
||||
)
|
||||
|
||||
self._embed_description(event, embed_image, metadata)
|
||||
self._embed_description(event, embed_image)
|
||||
|
||||
@@ -4,12 +4,15 @@ import math
|
||||
|
||||
|
||||
class ZScoreNormalization:
|
||||
"""Running Z-score normalization for search distance."""
|
||||
|
||||
def __init__(self):
|
||||
def __init__(self, scale_factor: float = 1.0, bias: float = 0.0):
|
||||
"""Initialize with optional scaling and bias adjustments."""
|
||||
"""scale_factor adjusts the magnitude of each score"""
|
||||
"""bias will artificially shift the entire distribution upwards"""
|
||||
self.n = 0
|
||||
self.mean = 0
|
||||
self.m2 = 0
|
||||
self.scale_factor = scale_factor
|
||||
self.bias = bias
|
||||
|
||||
@property
|
||||
def variance(self):
|
||||
@@ -23,7 +26,10 @@ class ZScoreNormalization:
|
||||
self._update(distances)
|
||||
if self.stddev == 0:
|
||||
return distances
|
||||
return [(x - self.mean) / self.stddev for x in distances]
|
||||
return [
|
||||
(x - self.mean) / self.stddev * self.scale_factor + self.bias
|
||||
for x in distances
|
||||
]
|
||||
|
||||
def _update(self, distances: list[float]):
|
||||
for x in distances:
|
||||
|
||||
@@ -8,6 +8,8 @@ from enum import Enum
|
||||
from multiprocessing.synchronize import Event as MpEvent
|
||||
from pathlib import Path
|
||||
|
||||
from playhouse.sqliteq import SqliteQueueDatabase
|
||||
|
||||
from frigate.config import FrigateConfig
|
||||
from frigate.const import CLIPS_DIR
|
||||
from frigate.embeddings.embeddings import Embeddings
|
||||
@@ -22,16 +24,19 @@ class EventCleanupType(str, Enum):
|
||||
|
||||
|
||||
class EventCleanup(threading.Thread):
|
||||
def __init__(self, config: FrigateConfig, stop_event: MpEvent):
|
||||
def __init__(
|
||||
self, config: FrigateConfig, stop_event: MpEvent, db: SqliteQueueDatabase
|
||||
):
|
||||
super().__init__(name="event_cleanup")
|
||||
self.config = config
|
||||
self.stop_event = stop_event
|
||||
self.db = db
|
||||
self.camera_keys = list(self.config.cameras.keys())
|
||||
self.removed_camera_labels: list[str] = None
|
||||
self.camera_labels: dict[str, dict[str, any]] = {}
|
||||
|
||||
if self.config.semantic_search.enabled:
|
||||
self.embeddings = Embeddings()
|
||||
self.embeddings = Embeddings(self.db)
|
||||
|
||||
def get_removed_camera_labels(self) -> list[Event]:
|
||||
"""Get a list of distinct labels for removed cameras."""
|
||||
@@ -229,15 +234,8 @@ class EventCleanup(threading.Thread):
|
||||
Event.delete().where(Event.id << chunk).execute()
|
||||
|
||||
if self.config.semantic_search.enabled:
|
||||
for collection in [
|
||||
self.embeddings.thumbnail,
|
||||
self.embeddings.description,
|
||||
]:
|
||||
existing_ids = collection.get(ids=chunk, include=[])["ids"]
|
||||
if existing_ids:
|
||||
collection.delete(ids=existing_ids)
|
||||
logger.debug(
|
||||
f"Deleted {len(existing_ids)} embeddings from {collection.__class__.__name__}"
|
||||
)
|
||||
self.embeddings.delete_description(chunk)
|
||||
self.embeddings.delete_thumbnail(chunk)
|
||||
logger.debug(f"Deleted {len(events_to_delete)} embeddings")
|
||||
|
||||
logger.info("Exiting event cleanup...")
|
||||
|
||||
@@ -31,12 +31,12 @@ class GenAIClient:
|
||||
self,
|
||||
camera_config: CameraConfig,
|
||||
thumbnails: list[bytes],
|
||||
metadata: dict[str, any],
|
||||
label: str,
|
||||
) -> Optional[str]:
|
||||
"""Generate a description for the frame."""
|
||||
prompt = camera_config.genai.object_prompts.get(
|
||||
metadata["label"], camera_config.genai.prompt
|
||||
).format(**metadata)
|
||||
label, camera_config.genai.prompt
|
||||
)
|
||||
return self._send(prompt, thumbnails)
|
||||
|
||||
def _init_provider(self):
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
from enum import Enum
|
||||
from typing import TypedDict
|
||||
|
||||
from frigate.camera import CameraMetrics
|
||||
@@ -11,3 +12,10 @@ class StatsTrackingTypes(TypedDict):
|
||||
latest_frigate_version: str
|
||||
last_updated: int
|
||||
processes: dict[str, int]
|
||||
|
||||
|
||||
class ModelStatusTypesEnum(str, Enum):
|
||||
not_downloaded = "not_downloaded"
|
||||
downloading = "downloading"
|
||||
downloaded = "downloaded"
|
||||
error = "error"
|
||||
|
||||
123
frigate/util/downloader.py
Normal file
123
frigate/util/downloader.py
Normal file
@@ -0,0 +1,123 @@
|
||||
import logging
|
||||
import os
|
||||
import threading
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Callable, List
|
||||
|
||||
import requests
|
||||
|
||||
from frigate.comms.inter_process import InterProcessRequestor
|
||||
from frigate.const import UPDATE_MODEL_STATE
|
||||
from frigate.types import ModelStatusTypesEnum
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class FileLock:
|
||||
def __init__(self, path):
|
||||
self.path = path
|
||||
self.lock_file = f"{path}.lock"
|
||||
|
||||
def acquire(self):
|
||||
parent_dir = os.path.dirname(self.lock_file)
|
||||
os.makedirs(parent_dir, exist_ok=True)
|
||||
|
||||
while True:
|
||||
try:
|
||||
with open(self.lock_file, "x"):
|
||||
return
|
||||
except FileExistsError:
|
||||
time.sleep(0.1)
|
||||
|
||||
def release(self):
|
||||
try:
|
||||
os.remove(self.lock_file)
|
||||
except FileNotFoundError:
|
||||
pass
|
||||
|
||||
|
||||
class ModelDownloader:
|
||||
def __init__(
|
||||
self,
|
||||
model_name: str,
|
||||
download_path: str,
|
||||
file_names: List[str],
|
||||
download_func: Callable[[str], None],
|
||||
silent: bool = False,
|
||||
):
|
||||
self.model_name = model_name
|
||||
self.download_path = download_path
|
||||
self.file_names = file_names
|
||||
self.download_func = download_func
|
||||
self.silent = silent
|
||||
self.requestor = InterProcessRequestor()
|
||||
self.download_thread = None
|
||||
self.download_complete = threading.Event()
|
||||
|
||||
def ensure_model_files(self):
|
||||
for file in self.file_names:
|
||||
self.requestor.send_data(
|
||||
UPDATE_MODEL_STATE,
|
||||
{
|
||||
"model": f"{self.model_name}-{file}",
|
||||
"state": ModelStatusTypesEnum.downloading,
|
||||
},
|
||||
)
|
||||
self.download_thread = threading.Thread(
|
||||
target=self._download_models,
|
||||
name=f"_download_model_{self.model_name}",
|
||||
daemon=True,
|
||||
)
|
||||
self.download_thread.start()
|
||||
|
||||
def _download_models(self):
|
||||
for file_name in self.file_names:
|
||||
path = os.path.join(self.download_path, file_name)
|
||||
lock = FileLock(path)
|
||||
|
||||
if not os.path.exists(path):
|
||||
lock.acquire()
|
||||
try:
|
||||
if not os.path.exists(path):
|
||||
self.download_func(path)
|
||||
finally:
|
||||
lock.release()
|
||||
|
||||
self.requestor.send_data(
|
||||
UPDATE_MODEL_STATE,
|
||||
{
|
||||
"model": f"{self.model_name}-{file_name}",
|
||||
"state": ModelStatusTypesEnum.downloaded,
|
||||
},
|
||||
)
|
||||
|
||||
self.download_complete.set()
|
||||
|
||||
@staticmethod
|
||||
def download_from_url(url: str, save_path: str, silent: bool = False):
|
||||
temporary_filename = Path(save_path).with_name(
|
||||
os.path.basename(save_path) + ".part"
|
||||
)
|
||||
temporary_filename.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
if not silent:
|
||||
logger.info(f"Downloading model file from: {url}")
|
||||
|
||||
try:
|
||||
with requests.get(url, stream=True, allow_redirects=True) as r:
|
||||
r.raise_for_status()
|
||||
with open(temporary_filename, "wb") as f:
|
||||
for chunk in r.iter_content(chunk_size=8192):
|
||||
f.write(chunk)
|
||||
|
||||
temporary_filename.rename(save_path)
|
||||
except Exception as e:
|
||||
logger.error(f"Error downloading model: {str(e)}")
|
||||
raise
|
||||
|
||||
if not silent:
|
||||
logger.info(f"Downloading complete: {url}")
|
||||
|
||||
def wait_for_download(self):
|
||||
self.download_complete.wait()
|
||||
Reference in New Issue
Block a user