mirror of
https://github.com/immich-app/immich.git
synced 2025-11-29 08:30:03 +09:00
feat(ml): ML on Rockchip NPUs (#15241)
This commit is contained in:
@@ -8,6 +8,7 @@ from typing import Any, ClassVar
|
||||
from huggingface_hub import snapshot_download
|
||||
|
||||
import ann.ann
|
||||
import app.sessions.rknn as rknn
|
||||
from app.sessions.ort import OrtSession
|
||||
|
||||
from ..config import clean_name, log, settings
|
||||
@@ -66,12 +67,17 @@ class InferenceModel(ABC):
|
||||
pass
|
||||
|
||||
def _download(self) -> None:
|
||||
ignore_patterns = [] if self.model_format == ModelFormat.ARMNN else ["*.armnn"]
|
||||
ignored_patterns: dict[ModelFormat, list[str]] = {
|
||||
ModelFormat.ONNX: ["*.armnn", "*.rknn"],
|
||||
ModelFormat.ARMNN: ["*.rknn"],
|
||||
ModelFormat.RKNN: ["*.armnn"],
|
||||
}
|
||||
|
||||
snapshot_download(
|
||||
f"immich-app/{clean_name(self.model_name)}",
|
||||
cache_dir=self.cache_dir,
|
||||
local_dir=self.cache_dir,
|
||||
ignore_patterns=ignore_patterns,
|
||||
ignore_patterns=ignored_patterns.get(self.model_format, []),
|
||||
)
|
||||
|
||||
def _load(self) -> ModelSession:
|
||||
@@ -108,17 +114,25 @@ class InferenceModel(ABC):
|
||||
session: ModelSession = AnnSession(model_path)
|
||||
case ".onnx":
|
||||
session = OrtSession(model_path)
|
||||
case ".rknn":
|
||||
session = rknn.RknnSession(model_path)
|
||||
case _:
|
||||
raise ValueError(f"Unsupported model file type: {model_path.suffix}")
|
||||
return session
|
||||
|
||||
def model_path_for_format(self, model_format: ModelFormat) -> Path:
|
||||
model_path_prefix = rknn.model_prefix if model_format == ModelFormat.RKNN else None
|
||||
if model_path_prefix:
|
||||
return self.model_dir / model_path_prefix / f"model.{model_format}"
|
||||
return self.model_dir / f"model.{model_format}"
|
||||
|
||||
@property
|
||||
def model_dir(self) -> Path:
|
||||
return self.cache_dir / self.model_type.value
|
||||
|
||||
@property
|
||||
def model_path(self) -> Path:
|
||||
return self.model_dir / f"model.{self.model_format}"
|
||||
return self.model_path_for_format(self.model_format)
|
||||
|
||||
@property
|
||||
def model_task(self) -> ModelTask:
|
||||
@@ -155,4 +169,9 @@ class InferenceModel(ABC):
|
||||
|
||||
@property
|
||||
def _model_format_default(self) -> ModelFormat:
|
||||
return ModelFormat.ARMNN if ann.ann.is_available and settings.ann else ModelFormat.ONNX
|
||||
if rknn.is_available:
|
||||
return ModelFormat.RKNN
|
||||
elif ann.ann.is_available and settings.ann:
|
||||
return ModelFormat.ARMNN
|
||||
else:
|
||||
return ModelFormat.ONNX
|
||||
|
||||
Reference in New Issue
Block a user