Embedding gpu (#14253)
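Moves embedding model inference off the hard-coded CPUExecutionProvider: GenericONNXEmbedding now resolves its ONNX Runtime providers through get_ort_providers(), which gains a requires_fp16 flag that enables TensorRT FP16 mode unless the USE_FP_16 environment variable is set to "False". The Jina text model stays on CPU via the new force_cpu flag, while the vision model can pick up a GPU execution provider.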
@@ -118,7 +118,7 @@ class Embeddings:
             },
             embedding_function=jina_text_embedding_function,
             model_type="text",
-            preferred_providers=["CPUExecutionProvider"],
+            force_cpu=True,
         )

         self.vision_embedding = GenericONNXEmbedding(
@@ -130,7 +130,6 @@ class Embeddings:
             },
             embedding_function=jina_vision_embedding_function,
             model_type="vision",
-            preferred_providers=["CPUExecutionProvider"],
         )

     def _create_tables(self):
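Net effect of the two hunks above: both models used to pin CPUExecutionProvider, but now only the text model does (via force_cpu=True), while the vision model lets get_ort_providers pick from whatever providers the installed onnxruntime build exposes. A minimal sketch of the resulting provider selection; the GPU output shown is illustrative, not guaranteed:

# Sketch: provider selection after this change, assuming frigate's
# get_ort_providers as modified in this commit.
from frigate.util.model import get_ort_providers

# Text model path: force_cpu=True short-circuits to the CPU provider.
providers, options = get_ort_providers(force_cpu=True, requires_fp16=True)
print(providers)  # ["CPUExecutionProvider"]

# Vision model path: force_cpu defaults to False, so GPU providers are
# used when available, e.g. ["TensorrtExecutionProvider", "CPUExecutionProvider"].
providers, options = get_ort_providers(requires_fp16=True)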
@@ -18,6 +18,7 @@ from transformers.utils.logging import disable_progress_bar
 from frigate.const import MODEL_CACHE_DIR, UPDATE_MODEL_STATE
 from frigate.types import ModelStatusTypesEnum
 from frigate.util.downloader import ModelDownloader
+from frigate.util.model import get_ort_providers

 warnings.filterwarnings(
     "ignore",
@@ -40,8 +41,8 @@ class GenericONNXEmbedding:
         download_urls: Dict[str, str],
         embedding_function: Callable[[List[np.ndarray]], np.ndarray],
         model_type: str,
-        preferred_providers: List[str] = ["CPUExecutionProvider"],
         tokenizer_file: Optional[str] = None,
+        force_cpu: bool = False,
     ):
         self.model_name = model_name
         self.model_file = model_file
@@ -49,7 +50,9 @@ class GenericONNXEmbedding:
         self.download_urls = download_urls
         self.embedding_function = embedding_function
         self.model_type = model_type  # 'text' or 'vision'
-        self.preferred_providers = preferred_providers
+        self.providers, self.provider_options = get_ort_providers(
+            force_cpu=force_cpu, requires_fp16=True
+        )

         self.download_path = os.path.join(MODEL_CACHE_DIR, self.model_name)
         self.tokenizer = None
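get_ort_providers returns two parallel lists, where provider_options[i] configures providers[i]; storing both on the instance means provider resolution happens once, at construction time. A sketch of the returned shape, with illustrative values for a TensorRT-enabled build:

# Illustrative return value of get_ort_providers; the exact lists
# depend on which onnxruntime build is installed.
providers = ["TensorrtExecutionProvider", "CPUExecutionProvider"]
provider_options = [
    {
        "trt_fp16_enable": True,
        "trt_engine_cache_enable": True,
        "trt_engine_cache_path": "/config/model_cache/tensorrt/ort/trt-engines",
    },
    {},  # the CPU provider takes no options here
]
assert len(providers) == len(provider_options)  # lists must stay paired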
@@ -105,8 +108,7 @@ class GenericONNXEmbedding:
         else:
             self.feature_extractor = self._load_feature_extractor()
         self.session = self._load_model(
-            os.path.join(self.download_path, self.model_file),
-            self.preferred_providers,
+            os.path.join(self.download_path, self.model_file)
         )

     def _load_tokenizer(self):
@@ -123,9 +125,11 @@
             f"{MODEL_CACHE_DIR}/{self.model_name}",
         )

-    def _load_model(self, path: str, providers: List[str]):
+    def _load_model(self, path: str):
         if os.path.exists(path):
-            return ort.InferenceSession(path, providers=providers)
+            return ort.InferenceSession(
+                path, providers=self.providers, provider_options=self.provider_options
+            )
         else:
             logger.warning(f"{self.model_name} model file {path} not found.")
             return None
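_load_model no longer takes a providers argument; it reads the paired lists resolved in __init__. onnxruntime.InferenceSession accepts provider_options as a list of dicts matched positionally to providers, which is why the two are threaded through together. A standalone sketch of the pattern; "model.onnx" is a placeholder path, not a file shipped by frigate:

# Standalone sketch of the new loading pattern.
import os
import onnxruntime as ort

def load_model(path: str, providers: list, provider_options: list):
    # provider_options[i] is matched positionally to providers[i].
    if os.path.exists(path):
        return ort.InferenceSession(
            path, providers=providers, provider_options=provider_options
        )
    return None  # caller is expected to handle a missing model file

session = load_model("model.onnx", ["CPUExecutionProvider"], [{}])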
@@ -6,7 +6,7 @@ import onnxruntime as ort


 def get_ort_providers(
-    force_cpu: bool = False, openvino_device: str = "AUTO"
+    force_cpu: bool = False, openvino_device: str = "AUTO", requires_fp16: bool = False
 ) -> tuple[list[str], list[dict[str, any]]]:
     if force_cpu:
         return (["CPUExecutionProvider"], [{}])
@@ -17,14 +17,19 @@ def get_ort_providers(
     for provider in providers:
         if provider == "TensorrtExecutionProvider":
             os.makedirs("/config/model_cache/tensorrt/ort/trt-engines", exist_ok=True)
-            options.append(
-                {
-                    "trt_timing_cache_enable": True,
-                    "trt_engine_cache_enable": True,
-                    "trt_timing_cache_path": "/config/model_cache/tensorrt/ort",
-                    "trt_engine_cache_path": "/config/model_cache/tensorrt/ort/trt-engines",
-                }
-            )
+            if not requires_fp16 or os.environ.get("USE_FP_16", "True") != "False":
+                options.append(
+                    {
+                        "trt_fp16_enable": requires_fp16,
+                        "trt_timing_cache_enable": True,
+                        "trt_engine_cache_enable": True,
+                        "trt_timing_cache_path": "/config/model_cache/tensorrt/ort",
+                        "trt_engine_cache_path": "/config/model_cache/tensorrt/ort/trt-engines",
+                    }
+                )
+            else:
+                options.append({})
         elif provider == "OpenVINOExecutionProvider":
             os.makedirs("/config/model_cache/openvino/ort", exist_ok=True)
             options.append(
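The gate above enables trt_fp16_enable for models that request it, unless USE_FP_16 is the literal string "False"; the comparison is an exact string match, so e.g. USE_FP_16=false would still leave FP16 on. The gate in isolation; fp16_allowed is a hypothetical helper name, the expression is copied from the hunk:

# The TensorRT FP16 gate in isolation.
import os

def fp16_allowed(requires_fp16: bool) -> bool:
    # True -> append the full TensorRT options (with trt_fp16_enable);
    # False -> append empty options ({}) for the TensorRT provider.
    return not requires_fp16 or os.environ.get("USE_FP_16", "True") != "False"

os.environ["USE_FP_16"] = "False"
assert not fp16_allowed(True)   # FP16 explicitly disabled by the env var
os.environ["USE_FP_16"] = "false"
assert fp16_allowed(True)       # exact-match quirk: "false" != "False"
del os.environ["USE_FP_16"]
assert fp16_allowed(True)       # default: FP16 enabled when required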