mirror of
https://github.com/immich-app/immich.git
synced 2025-12-01 22:09:46 +09:00
chore(ml): added testing and github workflow (#2969)
* added testing * github action for python, made mypy happy * formatted with black * minor fixes and styling * test model cache * cache test dependencies * narrowed model cache tests * moved endpoint tests to their own class * cleaned up fixtures * formatting * removed unused dep
This commit is contained in:
@@ -5,7 +5,7 @@ from pathlib import Path
|
||||
from shutil import rmtree
|
||||
from typing import Any
|
||||
|
||||
from onnxruntime.capi.onnxruntime_pybind11_state import InvalidProtobuf
|
||||
from onnxruntime.capi.onnxruntime_pybind11_state import InvalidProtobuf # type: ignore
|
||||
|
||||
from ..config import get_cache_dir
|
||||
from ..schemas import ModelType
|
||||
@@ -14,15 +14,9 @@ from ..schemas import ModelType
|
||||
class InferenceModel(ABC):
|
||||
_model_type: ModelType
|
||||
|
||||
def __init__(
|
||||
self, model_name: str, cache_dir: Path | None = None, **model_kwargs
|
||||
) -> None:
|
||||
def __init__(self, model_name: str, cache_dir: Path | str | None = None, **model_kwargs: Any) -> None:
|
||||
self.model_name = model_name
|
||||
self._cache_dir = (
|
||||
cache_dir
|
||||
if cache_dir is not None
|
||||
else get_cache_dir(model_name, self.model_type)
|
||||
)
|
||||
self._cache_dir = Path(cache_dir) if cache_dir is not None else get_cache_dir(model_name, self.model_type)
|
||||
|
||||
try:
|
||||
self.load(**model_kwargs)
|
||||
@@ -51,12 +45,8 @@ class InferenceModel(ABC):
|
||||
self._cache_dir = cache_dir
|
||||
|
||||
@classmethod
|
||||
def from_model_type(
|
||||
cls, model_type: ModelType, model_name, **model_kwargs
|
||||
) -> InferenceModel:
|
||||
subclasses = {
|
||||
subclass._model_type: subclass for subclass in cls.__subclasses__()
|
||||
}
|
||||
def from_model_type(cls, model_type: ModelType, model_name: str, **model_kwargs: Any) -> InferenceModel:
|
||||
subclasses = {subclass._model_type: subclass for subclass in cls.__subclasses__()}
|
||||
if model_type not in subclasses:
|
||||
raise ValueError(f"Unsupported model type: {model_type}")
|
||||
|
||||
@@ -66,8 +56,6 @@ class InferenceModel(ABC):
|
||||
if not self.cache_dir.exists():
|
||||
return
|
||||
elif not rmtree.avoids_symlink_attacks:
|
||||
raise RuntimeError(
|
||||
"Attempted to clear cache, but rmtree is not safe on this platform."
|
||||
)
|
||||
raise RuntimeError("Attempted to clear cache, but rmtree is not safe on this platform.")
|
||||
|
||||
rmtree(self.cache_dir)
|
||||
|
||||
Reference in New Issue
Block a user