mirror of
https://github.com/immich-app/immich.git
synced 2025-11-17 12:52:38 +09:00
feat(ml): add more search models (#11468)
* update export code * add uuid glob, sort model names * add new models to ml, sort names * add new models to server, sort by dims and name * typo in name * update export dependencies * onnx save function * format
This commit is contained in:
@@ -1,3 +1,4 @@
|
||||
import os
|
||||
import tempfile
|
||||
import warnings
|
||||
from pathlib import Path
|
||||
@@ -8,7 +9,6 @@ from transformers import AutoTokenizer
|
||||
|
||||
from .openclip import OpenCLIPModelConfig
|
||||
from .openclip import to_onnx as openclip_to_onnx
|
||||
from .optimize import optimize
|
||||
from .util import get_model_path
|
||||
|
||||
_MCLIP_TO_OPENCLIP = {
|
||||
@@ -23,18 +23,20 @@ def to_onnx(
|
||||
model_name: str,
|
||||
output_dir_visual: Path | str,
|
||||
output_dir_textual: Path | str,
|
||||
) -> None:
|
||||
) -> tuple[Path, Path]:
|
||||
textual_path = get_model_path(output_dir_textual)
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
model = MultilingualCLIP.from_pretrained(model_name, cache_dir=tmpdir)
|
||||
model = MultilingualCLIP.from_pretrained(model_name, cache_dir=os.environ.get("CACHE_DIR", tmpdir))
|
||||
AutoTokenizer.from_pretrained(model_name).save_pretrained(output_dir_textual)
|
||||
|
||||
model.eval()
|
||||
for param in model.parameters():
|
||||
param.requires_grad_(False)
|
||||
|
||||
export_text_encoder(model, textual_path)
|
||||
openclip_to_onnx(_MCLIP_TO_OPENCLIP[model_name], output_dir_visual)
|
||||
optimize(textual_path)
|
||||
visual_path, _ = openclip_to_onnx(_MCLIP_TO_OPENCLIP[model_name], output_dir_visual)
|
||||
assert visual_path is not None, "Visual model export failed"
|
||||
return visual_path, textual_path
|
||||
|
||||
|
||||
def export_text_encoder(model: MultilingualCLIP, output_path: Path | str) -> None:
|
||||
@@ -58,10 +60,10 @@ def export_text_encoder(model: MultilingualCLIP, output_path: Path | str) -> Non
|
||||
args,
|
||||
output_path.as_posix(),
|
||||
input_names=["input_ids", "attention_mask"],
|
||||
output_names=["text_embedding"],
|
||||
output_names=["embedding"],
|
||||
opset_version=17,
|
||||
dynamic_axes={
|
||||
"input_ids": {0: "batch_size", 1: "sequence_length"},
|
||||
"attention_mask": {0: "batch_size", 1: "sequence_length"},
|
||||
},
|
||||
# dynamic_axes={
|
||||
# "input_ids": {0: "batch_size", 1: "sequence_length"},
|
||||
# "attention_mask": {0: "batch_size", 1: "sequence_length"},
|
||||
# },
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user