Add ability to filter based on search type (#13641)

This commit is contained in:
Josh Hawkins
2024-09-09 14:45:19 -05:00
committed by GitHub
parent 03ff3e639f
commit cae11cbb86
5 changed files with 113 additions and 28 deletions

View File

@@ -274,7 +274,7 @@ def event_ids():
@EventBp.route("/events/search")
def events_search():
query = request.args.get("query", type=str)
search_type = request.args.get("search_type", "text", type=str)
search_type = request.args.get("search_type", "thumbnail,description", type=str)
include_thumbnails = request.args.get("include_thumbnails", default=1, type=int)
limit = request.args.get("limit", 50, type=int)
@@ -358,7 +358,7 @@ def events_search():
thumb_ids = {}
desc_ids = {}
if search_type == "thumbnail":
if search_type == "similarity":
# Grab the ids of events that match the thumbnail image embeddings
try:
search_event: Event = Event.get(Event.id == query)
@@ -386,29 +386,34 @@ def events_search():
)
)
else:
thumb_result = context.embeddings.thumbnail.query(
query_texts=[query],
n_results=limit,
where=where,
)
# Do a rudimentary normalization of the difference in distances returned by CLIP and MiniLM.
thumb_ids = dict(
zip(
thumb_result["ids"][0],
context.thumb_stats.normalize(thumb_result["distances"][0]),
search_types = search_type.split(",")
if "thumbnail" in search_types:
thumb_result = context.embeddings.thumbnail.query(
query_texts=[query],
n_results=limit,
where=where,
)
)
desc_result = context.embeddings.description.query(
query_texts=[query],
n_results=limit,
where=where,
)
desc_ids = dict(
zip(
desc_result["ids"][0],
context.desc_stats.normalize(desc_result["distances"][0]),
# Do a rudimentary normalization of the difference in distances returned by CLIP and MiniLM.
thumb_ids = dict(
zip(
thumb_result["ids"][0],
context.thumb_stats.normalize(thumb_result["distances"][0]),
)
)
if "description" in search_types:
desc_result = context.embeddings.description.query(
query_texts=[query],
n_results=limit,
where=where,
)
desc_ids = dict(
zip(
desc_result["ids"][0],
context.desc_stats.normalize(desc_result["distances"][0]),
)
)
)
results = {}
for event_id in thumb_ids.keys() | desc_ids: