forked from Github/frigate
add plus integration for models (#6328)
This commit is contained in:
@@ -1,11 +1,16 @@
|
||||
import hashlib
|
||||
import json
|
||||
import logging
|
||||
from enum import Enum
|
||||
import os
|
||||
from typing import Dict, List, Optional, Tuple, Union, Literal
|
||||
|
||||
|
||||
import requests
|
||||
import matplotlib.pyplot as plt
|
||||
from pydantic import BaseModel, Extra, Field, validator
|
||||
from pydantic.fields import PrivateAttr
|
||||
from frigate.plus import PlusApi
|
||||
|
||||
from frigate.util import load_labels
|
||||
|
||||
@@ -73,6 +78,45 @@ class ModelConfig(BaseModel):
|
||||
}
|
||||
self._colormap = {}
|
||||
|
||||
def check_and_load_plus_model(
|
||||
self, plus_api: PlusApi, detector: str = None
|
||||
) -> None:
|
||||
if not self.path or not self.path.startswith("plus://"):
|
||||
return
|
||||
|
||||
model_id = self.path[7:]
|
||||
self.path = f"/config/model_cache/{model_id}"
|
||||
model_info_path = f"{self.path}.json"
|
||||
|
||||
# download the model if it doesn't exist
|
||||
if not os.path.isfile(self.path):
|
||||
download_url = plus_api.get_model_download_url(model_id)
|
||||
r = requests.get(download_url)
|
||||
with open(self.path, "wb") as f:
|
||||
f.write(r.content)
|
||||
|
||||
# download the model info if it doesn't exist
|
||||
if not os.path.isfile(model_info_path):
|
||||
model_info = plus_api.get_model_info(model_id)
|
||||
with open(model_info_path, "w") as f:
|
||||
json.dump(model_info, f)
|
||||
else:
|
||||
with open(model_info_path, "r") as f:
|
||||
model_info = json.load(f)
|
||||
|
||||
if detector and detector not in model_info["supportedDetectors"]:
|
||||
raise ValueError(f"Model does not support detector type of {detector}")
|
||||
|
||||
self.width = model_info["width"]
|
||||
self.height = model_info["height"]
|
||||
self.input_tensor = model_info["inputShape"]
|
||||
self.input_pixel_format = model_info["pixelFormat"]
|
||||
self.model_type = model_info["type"]
|
||||
self._merged_labelmap = {
|
||||
**{int(key): val for key, val in model_info["labelMap"].items()},
|
||||
**self.labelmap,
|
||||
}
|
||||
|
||||
def compute_model_hash(self) -> None:
|
||||
with open(self.path, "rb") as f:
|
||||
file_hash = hashlib.md5()
|
||||
|
||||
Reference in New Issue
Block a user