Semantic Search API (#12105)

* initial event search api implementation

* fix lint

* fix tests

* move chromadb imports and pysqlite hotswap to fix tests

* remove unused import

* switch default limit to 50

* fix events accidently pulling inside chroma results loop
This commit is contained in:
Jason Hunter
2024-06-23 09:13:02 -04:00
committed by Nicolas Mowen
parent 36cbffcc5e
commit 9e825811f2
8 changed files with 359 additions and 23 deletions

View File

@@ -23,6 +23,7 @@ from frigate.api.preview import PreviewBp
from frigate.api.review import ReviewBp
from frigate.config import FrigateConfig
from frigate.const import CONFIG_DIR
from frigate.embeddings import EmbeddingsContext
from frigate.events.external import ExternalEventProcessor
from frigate.models import Event, Timeline
from frigate.plus import PlusApi
@@ -52,6 +53,7 @@ bp.register_blueprint(AuthBp)
def create_app(
frigate_config,
database: SqliteQueueDatabase,
embeddings: EmbeddingsContext,
detected_frames_processor,
storage_maintainer: StorageMaintainer,
onvif: OnvifController,
@@ -79,6 +81,7 @@ def create_app(
database.close()
app.frigate_config = frigate_config
app.embeddings = embeddings
app.detected_frames_processor = detected_frames_processor
app.storage_maintainer = storage_maintainer
app.onvif = onvif

View File

@@ -1,5 +1,7 @@
"""Event apis."""
import base64
import io
import logging
import os
from datetime import datetime
@@ -8,6 +10,7 @@ from pathlib import Path
from urllib.parse import unquote
import cv2
import numpy as np
from flask import (
Blueprint,
current_app,
@@ -15,13 +18,16 @@ from flask import (
make_response,
request,
)
from peewee import DoesNotExist, fn, operator
from peewee import JOIN, DoesNotExist, fn, operator
from PIL import Image
from playhouse.shortcuts import model_to_dict
from frigate.const import (
CLIPS_DIR,
)
from frigate.models import Event, Timeline
from frigate.embeddings import EmbeddingsContext
from frigate.embeddings.embeddings import get_metadata
from frigate.models import Event, ReviewSegment, Timeline
from frigate.object_processing import TrackedObject
from frigate.util.builtin import get_tz_modifiers
@@ -245,6 +251,189 @@ def events():
return jsonify(list(events))
@EventBp.route("/events/search")
def events_search():
query = request.args.get("query", type=str)
search_type = request.args.get("search_type", "text", type=str)
include_thumbnails = request.args.get("include_thumbnails", default=1, type=int)
limit = request.args.get("limit", 50, type=int)
# Filters
cameras = request.args.get("cameras", "all", type=str)
labels = request.args.get("labels", "all", type=str)
zones = request.args.get("zones", "all", type=str)
after = request.args.get("after", type=float)
before = request.args.get("before", type=float)
if not query:
return make_response(
jsonify(
{
"success": False,
"message": "A search query must be supplied",
}
),
400,
)
if not current_app.frigate_config.semantic_search.enabled:
return make_response(
jsonify(
{
"success": False,
"message": "Semantic search is not enabled",
}
),
400,
)
context: EmbeddingsContext = current_app.embeddings
selected_columns = [
Event.id,
Event.camera,
Event.label,
Event.sub_label,
Event.zones,
Event.start_time,
Event.end_time,
Event.data,
ReviewSegment.thumb_path,
]
if include_thumbnails:
selected_columns.append(Event.thumbnail)
# Build the where clause for the embeddings query
embeddings_filters = []
if cameras != "all":
camera_list = cameras.split(",")
embeddings_filters.append({"camera": {"$in": camera_list}})
if labels != "all":
label_list = labels.split(",")
embeddings_filters.append({"label": {"$in": label_list}})
if zones != "all":
filtered_zones = zones.split(",")
zone_filters = [{f"zones_{zone}": {"$eq": True}} for zone in filtered_zones]
if len(zone_filters) > 1:
embeddings_filters.append({"$or": zone_filters})
else:
embeddings_filters.append(zone_filters[0])
if after:
embeddings_filters.append({"start_time": {"$gt": after}})
if before:
embeddings_filters.append({"start_time": {"$lt": before}})
where = None
if len(embeddings_filters) > 1:
where = {"$and": embeddings_filters}
elif len(embeddings_filters) == 1:
where = embeddings_filters[0]
thumb_ids = {}
desc_ids = {}
if search_type == "thumbnail":
# Grab the ids of events that match the thumbnail image embeddings
try:
search_event: Event = Event.get(Event.id == query)
except DoesNotExist:
return make_response(
jsonify(
{
"success": False,
"message": "Event not found",
}
),
404,
)
thumbnail = base64.b64decode(search_event.thumbnail)
img = np.array(Image.open(io.BytesIO(thumbnail)).convert("RGB"))
thumb_result = context.embeddings.thumbnail.query(
query_images=[img],
n_results=limit,
where=where,
)
thumb_ids = dict(zip(thumb_result["ids"][0], thumb_result["distances"][0]))
else:
thumb_result = context.embeddings.thumbnail.query(
query_texts=[query],
n_results=limit,
where=where,
)
# Do a rudimentary normalization of the difference in distances returned by CLIP and MiniLM.
thumb_ids = dict(
zip(
thumb_result["ids"][0],
context.thumb_stats.normalize(thumb_result["distances"][0]),
)
)
desc_result = context.embeddings.description.query(
query_texts=[query],
n_results=limit,
where=where,
)
desc_ids = dict(
zip(
desc_result["ids"][0],
context.desc_stats.normalize(desc_result["distances"][0]),
)
)
results = {}
for event_id in thumb_ids.keys() | desc_ids:
min_distance = min(
i
for i in (thumb_ids.get(event_id), desc_ids.get(event_id))
if i is not None
)
results[event_id] = {
"distance": min_distance,
"source": "thumbnail"
if min_distance == thumb_ids.get(event_id)
else "description",
}
if not results:
return jsonify([])
# Get the event data
events = (
Event.select(*selected_columns)
.join(
ReviewSegment,
JOIN.LEFT_OUTER,
on=(fn.json_extract(ReviewSegment.data, "$.detections").contains(Event.id)),
)
.where(Event.id << list(results.keys()))
.dicts()
.iterator()
)
events = list(events)
events = [
{k: v for k, v in event.items() if k != "data"}
| {
k: v
for k, v in event["data"].items()
if k in ["type", "score", "top_score", "description"]
}
| {
"search_distance": results[event["id"]]["distance"],
"search_source": results[event["id"]]["source"],
}
for event in events
]
events = sorted(events, key=lambda x: x["search_distance"])[:limit]
return jsonify(events)
@EventBp.route("/events/summary")
def events_summary():
tz_name = request.args.get("timezone", default="utc", type=str)
@@ -604,6 +793,52 @@ def set_sub_label(id):
)
@EventBp.route("/events/<id>/description", methods=("POST",))
def set_description(id):
try:
event: Event = Event.get(Event.id == id)
except DoesNotExist:
return make_response(
jsonify({"success": False, "message": "Event " + id + " not found"}), 404
)
json: dict[str, any] = request.get_json(silent=True) or {}
new_description = json.get("description")
if new_description is None or len(new_description) == 0:
return make_response(
jsonify(
{
"success": False,
"message": "description cannot be empty",
}
),
400,
)
event.data["description"] = new_description
event.save()
# If semantic search is enabled, update the index
if current_app.frigate_config.semantic_search.enabled:
context: EmbeddingsContext = current_app.embeddings
context.embeddings.description.upsert(
documents=[new_description],
metadatas=[get_metadata(event)],
ids=[id],
)
return make_response(
jsonify(
{
"success": True,
"message": "Event " + id + " description set to " + new_description,
}
),
200,
)
@EventBp.route("/events/<id>", methods=("DELETE",))
def delete_event(id):
try:
@@ -625,6 +860,11 @@ def delete_event(id):
event.delete_instance()
Timeline.delete().where(Timeline.source_id == id).execute()
# If semantic search is enabled, update the index
if current_app.frigate_config.semantic_search.enabled:
context: EmbeddingsContext = current_app.embeddings
context.embeddings.thumbnail.delete(ids=[id])
context.embeddings.description.delete(ids=[id])
return make_response(
jsonify({"success": True, "message": "Event " + id + " deleted"}), 200
)