Mirror of https://github.com/immich-app/immich.git, synced 2025-11-24 11:20:40 +09:00
feat(ml): composable ml (#9973)
* modularize model classes
* various fixes
* expose port
* change response
* round coordinates
* simplify preload
* update server
* simplify interface simplify
* update tests
* composable endpoint
* cleanup fixes remove unnecessary interface support text input, cleanup
* ew camelcase
* update server server fixes fix typing
* ml fixes update locustfile fixes
* cleaner response
* better repo response
* update tests formatting and typing rename
* undo compose change
* linting fix type actually fix typing
* stricter typing fix detection-only response no need for defaultdict
* update spec file update api linting
* update e2e
* unnecessary dimension
* remove commented code
* remove duplicate code
* remove unused imports
* add batch dim
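For context on what "composable" means here: instead of one model per request, the new /predict form carries a single entries JSON field describing every requested task, plus an image file or a text field. A hypothetical client call might look like the sketch below; the port, the task/type strings ("search", "facial-recognition", "visual", ...), and the model names are illustrative assumptions (they depend on how ModelTask and ModelType serialize), not values taken from this diff.

import orjson
import requests  # illustration only; any HTTP client works

entries = {
    # task -> type -> model entry, matching the parsing in get_entries below
    "search": {"visual": {"modelName": "ViT-B-32__openai"}},
    "facial-recognition": {
        "detection": {"modelName": "buffalo_l", "options": {"minScore": 0.7}},
        "recognition": {"modelName": "buffalo_l"},
    },
}
with open("photo.jpg", "rb") as f:
    resp = requests.post(
        "http://localhost:3003/predict",  # assumed default immich-ml port
        data={"entries": orjson.dumps(entries).decode()},
        files={"image": f},
    )
print(resp.json())

Per the handler changes below, the response maps each task to its model output and, for image inputs, also carries imageHeight and imageWidth.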
@@ -6,22 +6,34 @@ import threading
 import time
 from concurrent.futures import ThreadPoolExecutor
 from contextlib import asynccontextmanager
+from functools import partial
 from typing import Any, AsyncGenerator, Callable, Iterator
 from zipfile import BadZipFile
 
 import orjson
-from fastapi import Depends, FastAPI, Form, HTTPException, UploadFile
+from fastapi import Depends, FastAPI, File, Form, HTTPException
 from fastapi.responses import ORJSONResponse
 from onnxruntime.capi.onnxruntime_pybind11_state import InvalidProtobuf, NoSuchFile
+from PIL.Image import Image
+from pydantic import ValidationError
 from starlette.formparsers import MultiPartParser
 
+from app.models import get_model_deps
 from app.models.base import InferenceModel
+from app.models.transforms import decode_pil
 
 from .config import PreloadModelData, log, settings
 from .models.cache import ModelCache
 from .schemas import (
+    InferenceEntries,
+    InferenceEntry,
+    InferenceResponse,
     MessageResponse,
+    ModelIdentity,
+    ModelTask,
     ModelType,
+    PipelineRequest,
+    T,
     TextResponse,
 )
@@ -63,12 +75,21 @@ async def lifespan(_: FastAPI) -> AsyncGenerator[None, None]:
     gc.collect()
 
 
-async def preload_models(preload_models: PreloadModelData) -> None:
-    log.info(f"Preloading models: {preload_models}")
-    if preload_models.clip is not None:
-        await load(await model_cache.get(preload_models.clip, ModelType.CLIP))
-    if preload_models.facial_recognition is not None:
-        await load(await model_cache.get(preload_models.facial_recognition, ModelType.FACIAL_RECOGNITION))
+async def preload_models(preload: PreloadModelData) -> None:
+    log.info(f"Preloading models: {preload}")
+    if preload.clip is not None:
+        model = await model_cache.get(preload.clip, ModelType.TEXTUAL, ModelTask.SEARCH)
+        await load(model)
+
+        model = await model_cache.get(preload.clip, ModelType.VISUAL, ModelTask.SEARCH)
+        await load(model)
+
+    if preload.facial_recognition is not None:
+        model = await model_cache.get(preload.facial_recognition, ModelType.DETECTION, ModelTask.FACIAL_RECOGNITION)
+        await load(model)
+
+        model = await model_cache.get(preload.facial_recognition, ModelType.RECOGNITION, ModelTask.FACIAL_RECOGNITION)
+        await load(model)
 
 
 def update_state() -> Iterator[None]:
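Worth noting in this hunk: a single configured checkpoint now warms two separately cached models, keyed by (name, type, task), instead of one entry keyed by a coarse ModelType.CLIP or ModelType.FACIAL_RECOGNITION. A standalone sketch of that fan-out; the enum string values and the key format are assumptions for illustration, not taken from this diff.

from enum import Enum

class ModelTask(str, Enum):
    SEARCH = "search"
    FACIAL_RECOGNITION = "facial-recognition"

class ModelType(str, Enum):
    TEXTUAL = "textual"
    VISUAL = "visual"
    DETECTION = "detection"
    RECOGNITION = "recognition"

# One preload name resolves to two cache entries per pipeline:
def cache_keys(name: str, pairs: list[tuple[ModelType, ModelTask]]) -> list[str]:
    return [f"{name}:{type_.value}:{task.value}" for type_, task in pairs]

print(cache_keys("ViT-B-32__openai", [(ModelType.TEXTUAL, ModelTask.SEARCH), (ModelType.VISUAL, ModelTask.SEARCH)]))
print(cache_keys("buffalo_l", [(ModelType.DETECTION, ModelTask.FACIAL_RECOGNITION), (ModelType.RECOGNITION, ModelTask.FACIAL_RECOGNITION)]))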
@@ -81,6 +102,27 @@ def update_state() -> Iterator[None]:
         active_requests -= 1
 
 
+def get_entries(entries: str = Form()) -> InferenceEntries:
+    try:
+        request: PipelineRequest = orjson.loads(entries)
+        without_deps: list[InferenceEntry] = []
+        with_deps: list[InferenceEntry] = []
+        for task, types in request.items():
+            for type, entry in types.items():
+                parsed: InferenceEntry = {
+                    "name": entry["modelName"],
+                    "task": task,
+                    "type": type,
+                    "options": entry.get("options", {}),
+                }
+                dep = get_model_deps(parsed["name"], type, task)
+                (with_deps if dep else without_deps).append(parsed)
+        return without_deps, with_deps
+    except (orjson.JSONDecodeError, ValidationError, KeyError, AttributeError) as e:
+        log.error(f"Invalid request format: {e}")
+        raise HTTPException(422, "Invalid request format.")
+
+
 app = FastAPI(lifespan=lifespan)
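To make the parsing above concrete: get_entries flattens the nested task -> type -> entry JSON and partitions the entries by whether the model consumes another model's output. A standalone sketch with get_model_deps stubbed (the real dependency metadata lives in app.models; the stub, task names, and model name are assumptions):

import orjson

def get_model_deps_stub(name: str, type: str, task: str) -> list[str]:
    # pretend only the recognizer depends on another model's output
    return ["detection"] if type == "recognition" else []

form_field = orjson.dumps({
    "facial-recognition": {
        "detection": {"modelName": "buffalo_l", "options": {"minScore": 0.7}},
        "recognition": {"modelName": "buffalo_l"},
    }
}).decode()

without_deps, with_deps = [], []
for task, types in orjson.loads(form_field).items():
    for type, entry in types.items():
        parsed = {"name": entry["modelName"], "task": task, "type": type, "options": entry.get("options", {})}
        (with_deps if get_model_deps_stub(parsed["name"], type, task) else without_deps).append(parsed)

print([e["type"] for e in without_deps])  # ['detection']
print([e["type"] for e in with_deps])     # ['recognition']

run_inference in the next hunk relies on this split: the dependency-free batch is gathered first, so its outputs are already available when the dependent batch runs.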
@@ -96,42 +138,63 @@ def ping() -> str:
 
 @app.post("/predict", dependencies=[Depends(update_state)])
 async def predict(
-    model_name: str = Form(alias="modelName"),
-    model_type: ModelType = Form(alias="modelType"),
-    options: str = Form(default="{}"),
+    entries: InferenceEntries = Depends(get_entries),
+    image: bytes | None = File(default=None),
     text: str | None = Form(default=None),
-    image: UploadFile | None = None,
 ) -> Any:
     if image is not None:
-        inputs: str | bytes = await image.read()
+        inputs: Image | str = await run(lambda: decode_pil(image))
     elif text is not None:
         inputs = text
     else:
         raise HTTPException(400, "Either image or text must be provided")
-    try:
-        kwargs = orjson.loads(options)
-    except orjson.JSONDecodeError:
-        raise HTTPException(400, f"Invalid options JSON: {options}")
-
-    model = await load(await model_cache.get(model_name, model_type, ttl=settings.model_ttl, **kwargs))
-    model.configure(**kwargs)
-    outputs = await run(model.predict, inputs)
-    return ORJSONResponse(outputs)
+    response = await run_inference(inputs, entries)
+    return ORJSONResponse(response)
 
 
-async def run(func: Callable[..., Any], inputs: Any) -> Any:
+async def run_inference(payload: Image | str, entries: InferenceEntries) -> InferenceResponse:
+    outputs: dict[ModelIdentity, Any] = {}
+    response: InferenceResponse = {}
+
+    async def _run_inference(entry: InferenceEntry) -> None:
+        model = await model_cache.get(entry["name"], entry["type"], entry["task"], ttl=settings.model_ttl)
+        inputs = [payload]
+        for dep in model.depends:
+            try:
+                inputs.append(outputs[dep])
+            except KeyError:
+                message = f"Task {entry['task']} of type {entry['type']} depends on output of {dep}"
+                raise HTTPException(400, message)
+        model = await load(model)
+        output = await run(model.predict, *inputs, **entry["options"])
+        outputs[model.identity] = output
+        response[entry["task"]] = output
+
+    without_deps, with_deps = entries
+    await asyncio.gather(*[_run_inference(entry) for entry in without_deps])
+    if with_deps:
+        await asyncio.gather(*[_run_inference(entry) for entry in with_deps])
+    if isinstance(payload, Image):
+        response["imageHeight"], response["imageWidth"] = payload.height, payload.width
+
+    return response
+
+
+async def run(func: Callable[..., T], *args: Any, **kwargs: Any) -> T:
     if thread_pool is None:
-        return func(inputs)
-    return await asyncio.get_running_loop().run_in_executor(thread_pool, func, inputs)
+        return func(*args, **kwargs)
+    partial_func = partial(func, *args, **kwargs)
+    return await asyncio.get_running_loop().run_in_executor(thread_pool, partial_func)
 
 
 async def load(model: InferenceModel) -> InferenceModel:
     if model.loaded:
         return model
 
-    def _load(model: InferenceModel) -> None:
+    def _load(model: InferenceModel) -> InferenceModel:
         with lock:
             model.load()
+        return model
 
     try:
         await run(_load, model)
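One detail in the reworked run helper deserves a callout: loop.run_in_executor only forwards positional arguments, so wrapping the call in functools.partial is what lets model.predict receive per-entry keyword options. A self-contained sketch of the same pattern (the pool size and the pow example are arbitrary):

import asyncio
from concurrent.futures import ThreadPoolExecutor
from functools import partial
from typing import Any, Callable, TypeVar

T = TypeVar("T")
thread_pool: ThreadPoolExecutor | None = ThreadPoolExecutor(4)

async def run(func: Callable[..., T], *args: Any, **kwargs: Any) -> T:
    if thread_pool is None:
        return func(*args, **kwargs)
    # partial() bundles args and kwargs together, since run_in_executor
    # itself cannot pass keyword arguments through to func
    return await asyncio.get_running_loop().run_in_executor(thread_pool, partial(func, *args, **kwargs))

async def main() -> None:
    print(await run(pow, 2, 10))  # 1024, computed off the event loop

asyncio.run(main())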