forked from Github/frigate
Embeddings normalization fixes (#14284)
* Use cosine distance metric for vec tables
* Only apply normalization to multi modal searches
* Catch possible edge case in stddev calc
* Use sigmoid function for normalization for multi modal searches only
* Ensure we get model state on initial page load
* Only save stats for multi modal searches and only use cosine similarity for image -> image search
This commit is contained in:
@@ -473,12 +473,7 @@ def events_search(request: Request, params: EventsSearchQueryParams = Depends())
|
||||
)
|
||||
|
||||
thumb_result = context.search_thumbnail(search_event)
|
||||
thumb_ids = dict(
|
||||
zip(
|
||||
[result[0] for result in thumb_result],
|
||||
context.thumb_stats.normalize([result[1] for result in thumb_result]),
|
||||
)
|
||||
)
|
||||
thumb_ids = {result[0]: result[1] for result in thumb_result}
|
||||
search_results = {
|
||||
event_id: {"distance": distance, "source": "thumbnail"}
|
||||
for event_id, distance in thumb_ids.items()
|
||||
@@ -486,15 +481,18 @@ def events_search(request: Request, params: EventsSearchQueryParams = Depends())
|
||||
else:
|
||||
search_types = search_type.split(",")
|
||||
|
||||
# only save stats for multi-modal searches
|
||||
save_stats = "thumbnail" in search_types and "description" in search_types
|
||||
|
||||
if "thumbnail" in search_types:
|
||||
thumb_result = context.search_thumbnail(query)
|
||||
|
||||
thumb_distances = context.thumb_stats.normalize(
|
||||
[result[1] for result in thumb_result], save_stats
|
||||
)
|
||||
|
||||
thumb_ids = dict(
|
||||
zip(
|
||||
[result[0] for result in thumb_result],
|
||||
context.thumb_stats.normalize(
|
||||
[result[1] for result in thumb_result]
|
||||
),
|
||||
)
|
||||
zip([result[0] for result in thumb_result], thumb_distances)
|
||||
)
|
||||
search_results.update(
|
||||
{
|
||||
@@ -505,12 +503,13 @@ def events_search(request: Request, params: EventsSearchQueryParams = Depends())
|
||||
|
||||
if "description" in search_types:
|
||||
desc_result = context.search_description(query)
|
||||
desc_ids = dict(
|
||||
zip(
|
||||
[result[0] for result in desc_result],
|
||||
context.desc_stats.normalize([result[1] for result in desc_result]),
|
||||
)
|
||||
|
||||
desc_distances = context.desc_stats.normalize(
|
||||
[result[1] for result in desc_result], save_stats
|
||||
)
|
||||
|
||||
desc_ids = dict(zip([result[0] for result in desc_result], desc_distances))
|
||||
|
||||
for event_id, distance in desc_ids.items():
|
||||
if (
|
||||
event_id not in search_results
|
||||
|
||||
@@ -42,12 +42,12 @@ class SqliteVecQueueDatabase(SqliteQueueDatabase):
|
||||
self.execute_sql("""
|
||||
CREATE VIRTUAL TABLE IF NOT EXISTS vec_thumbnails USING vec0(
|
||||
id TEXT PRIMARY KEY,
|
||||
thumbnail_embedding FLOAT[768]
|
||||
thumbnail_embedding FLOAT[768] distance_metric=cosine
|
||||
);
|
||||
""")
|
||||
self.execute_sql("""
|
||||
CREATE VIRTUAL TABLE IF NOT EXISTS vec_descriptions USING vec0(
|
||||
id TEXT PRIMARY KEY,
|
||||
description_embedding FLOAT[768]
|
||||
description_embedding FLOAT[768] distance_metric=cosine
|
||||
);
|
||||
""")
|
||||
|
||||
@@ -20,10 +20,11 @@ class ZScoreNormalization:
|
||||
|
||||
@property
|
||||
def stddev(self):
|
||||
return math.sqrt(self.variance)
|
||||
return math.sqrt(self.variance) if self.variance > 0 else 0.0
|
||||
|
||||
def normalize(self, distances: list[float]):
|
||||
self._update(distances)
|
||||
def normalize(self, distances: list[float], save_stats: bool):
|
||||
if save_stats:
|
||||
self._update(distances)
|
||||
if self.stddev == 0:
|
||||
return distances
|
||||
return [
|
||||
|
||||
Reference in New Issue
Block a user