feat(ml): add more search models (#11468)

* update export code

* add uuid glob, sort model names

* add new models to ml, sort names

* add new models to server, sort by dims and name

* typo in name

* update export dependencies

* onnx save function

* format
This commit is contained in:
Mert
2024-07-31 00:34:45 -04:00
committed by GitHub
parent 2423bb36c4
commit 41580696c7
9 changed files with 3804 additions and 2923 deletions

View File

@@ -1,3 +1,4 @@
import os
import tempfile
import warnings
from pathlib import Path
@@ -8,7 +9,6 @@ from transformers import AutoTokenizer
from .openclip import OpenCLIPModelConfig
from .openclip import to_onnx as openclip_to_onnx
from .optimize import optimize
from .util import get_model_path
_MCLIP_TO_OPENCLIP = {
@@ -23,18 +23,20 @@ def to_onnx(
model_name: str,
output_dir_visual: Path | str,
output_dir_textual: Path | str,
) -> None:
) -> tuple[Path, Path]:
textual_path = get_model_path(output_dir_textual)
with tempfile.TemporaryDirectory() as tmpdir:
model = MultilingualCLIP.from_pretrained(model_name, cache_dir=tmpdir)
model = MultilingualCLIP.from_pretrained(model_name, cache_dir=os.environ.get("CACHE_DIR", tmpdir))
AutoTokenizer.from_pretrained(model_name).save_pretrained(output_dir_textual)
model.eval()
for param in model.parameters():
param.requires_grad_(False)
export_text_encoder(model, textual_path)
openclip_to_onnx(_MCLIP_TO_OPENCLIP[model_name], output_dir_visual)
optimize(textual_path)
visual_path, _ = openclip_to_onnx(_MCLIP_TO_OPENCLIP[model_name], output_dir_visual)
assert visual_path is not None, "Visual model export failed"
return visual_path, textual_path
def export_text_encoder(model: MultilingualCLIP, output_path: Path | str) -> None:
@@ -58,10 +60,10 @@ def export_text_encoder(model: MultilingualCLIP, output_path: Path | str) -> Non
args,
output_path.as_posix(),
input_names=["input_ids", "attention_mask"],
output_names=["text_embedding"],
output_names=["embedding"],
opset_version=17,
dynamic_axes={
"input_ids": {0: "batch_size", 1: "sequence_length"},
"attention_mask": {0: "batch_size", 1: "sequence_length"},
},
# dynamic_axes={
# "input_ids": {0: "batch_size", 1: "sequence_length"},
# "attention_mask": {0: "batch_size", 1: "sequence_length"},
# },
)