164 lines
5.1 KiB
Python
164 lines
5.1 KiB
Python
"""Model registry. VRAM tier drives default. User can override."""
|
|
from __future__ import annotations
|
|
|
|
from dataclasses import dataclass
|
|
from typing import Literal
|
|
|
|
# The two media categories a registry entry can belong to.
Kind = Literal["image", "video"]


@dataclass(frozen=True)
class ModelSpec:
    """Immutable description of one model known to the registry.

    Attributes:
        id: Short identifier that registry lookups compare against.
        label: Human-readable display name.
        repo: Repository path in ``owner/name`` form.
        kind: ``"image"`` or ``"video"``.
        min_vram_gb: VRAM requirement in GB used when filtering by tier.
        nsfw_capable: Whether the model is flagged as NSFW-capable.
        download_gb: Download size in GB; defaults to ``0.0``.
        notes: Free-form usage notes; defaults to an empty string.
    """

    # Required fields (no defaults) come first; their order is part of the
    # positional constructor signature.
    id: str
    label: str
    repo: str
    kind: Kind
    min_vram_gb: float
    nsfw_capable: bool

    # Optional metadata.
    download_gb: float = 0.0
    notes: str = ""
|
|
|
|
|
|
# Image-generation models known to the registry. Lookups that take the first
# match walk this list in the order written here.
IMAGE_MODELS: list[ModelSpec] = [
    # --- 4 GB ---
    ModelSpec(
        id="sdxl-turbo",
        label="SDXL Turbo (fast, low VRAM)",
        repo="stabilityai/sdxl-turbo",
        kind="image",
        min_vram_gb=4.0,
        nsfw_capable=False,
        download_gb=7.0,
        notes="1-4 step sampling. Censored base. Good for safe content on weak GPUs.",
    ),
    # --- 8 GB ---
    ModelSpec(
        id="pony-xl",
        label="Pony Diffusion XL v6 (NSFW, anime/realistic)",
        repo="AstraliteHeart/pony-diffusion-v6",
        kind="image",
        min_vram_gb=8.0,
        nsfw_capable=True,
        download_gb=7.0,
        notes="Use score_9, score_8_up tags. Strong NSFW capability.",
    ),
    # --- 10 GB ---
    ModelSpec(
        id="illustrious-xl",
        label="Illustrious XL v0.1 (NSFW, illustration)",
        repo="OnomaAIResearch/Illustrious-xl-early-release-v0",
        kind="image",
        min_vram_gb=10.0,
        nsfw_capable=True,
        download_gb=7.0,
        notes="High-detail anime/illustration. NSFW-capable.",
    ),
    # --- 8 GB ---
    ModelSpec(
        id="sdxl-base",
        label="SDXL Base 1.0 (versatile)",
        repo="stabilityai/stable-diffusion-xl-base-1.0",
        kind="image",
        min_vram_gb=8.0,
        nsfw_capable=False,
        download_gb=7.0,
        notes="General-purpose. Censored.",
    ),
    # --- 12 GB ---
    ModelSpec(
        id="flux-schnell",
        label="FLUX.1 Schnell (high quality, 12GB+)",
        repo="black-forest-labs/FLUX.1-schnell",
        kind="image",
        min_vram_gb=12.0,
        nsfw_capable=False,
        download_gb=24.0,
        notes="State-of-the-art prompt adherence. 4-step.",
    ),
]
|
|
|
|
# Video-generation models known to the registry. Lookups that take the first
# match walk this list in the order written here.
VIDEO_MODELS: list[ModelSpec] = [
    # --- 8 GB ---
    ModelSpec(
        id="ltx-video",
        label="LTX-Video 0.9 (fast, 8GB+)",
        repo="Lightricks/LTX-Video",
        kind="video",
        min_vram_gb=8.0,
        nsfw_capable=True,
        download_gb=18.0,
        notes="2-5 second clips at 24fps. Fast generation.",
    ),
    # --- 12 GB ---
    ModelSpec(
        id="cogvideox-5b",
        label="CogVideoX 5B (12GB+)",
        repo="THUDM/CogVideoX-5b",
        kind="video",
        min_vram_gb=12.0,
        nsfw_capable=False,
        download_gb=20.0,
        notes="6 second clips. Good motion quality.",
    ),
    # --- 14 GB ---
    # NOTE(review): the label advertises "16GB+" while min_vram_gb is 14.0;
    # confirm which figure is intended before relying on either.
    ModelSpec(
        id="wan-2-1",
        label="Wan 2.1 T2V 1.3B (16GB+)",
        repo="Wan-AI/Wan2.1-T2V-1.3B-Diffusers",
        kind="video",
        min_vram_gb=14.0,
        nsfw_capable=True,
        download_gb=16.0,
        notes="Top-tier open video model. Needs diffusers>=0.32.",
    ),
]
|
|
|
|
|
|
# Hard-coded default model id for each recognised VRAM tier, per media kind.
# A value of ``None`` means "deliberately offer no default for this tier".
_IMAGE_DEFAULT_IDS = {
    "cpu": "sdxl-turbo",
    "low": "sdxl-turbo",
    "mid": "pony-xl",
    "high": "illustrious-xl",
    "ultra": "illustrious-xl",
}
_VIDEO_DEFAULT_IDS = {
    "cpu": None,
    "low": None,
    "mid": "ltx-video",
    "high": "ltx-video",
    "ultra": "wan-2-1",
}


def default_for_tier(tier: str, kind: Kind) -> ModelSpec | None:
    """Return the default model for a VRAM tier and media kind.

    Recognised tiers are ``"cpu"``, ``"low"``, ``"mid"``, ``"high"`` and
    ``"ultra"``. For those, the default is a fixed model id looked up in the
    image or video list; video has no default on ``"cpu"`` and ``"low"``.

    Args:
        tier: VRAM tier name.
        kind: ``"image"`` or ``"video"``. Any other value selects the video
            list and skips the fixed defaults.

    Returns:
        The matching ``ModelSpec``, or ``None`` when the tier has no default,
        when the configured default id is not present in the list, or when
        the fallback search finds nothing.
    """
    if kind == "image":
        pool = IMAGE_MODELS
        defaults = _IMAGE_DEFAULT_IDS
    else:
        pool = VIDEO_MODELS
        # Only a genuine "video" request uses the fixed video defaults; any
        # other kind value goes straight to the fallback search below.
        defaults = _VIDEO_DEFAULT_IDS if kind == "video" else {}

    if tier in defaults:
        wanted_id = defaults[tier]
        if wanted_id is None:
            return None
        # Supplying a default to next() keeps a stale id from surfacing as a
        # stray StopIteration; the caller simply gets "no default".
        return next((m for m in pool if m.id == wanted_id), None)

    # Fallback for unrecognised tier/kind combinations: first model in list
    # order whose VRAM requirement fits the tier's nominal floor.
    tier_min = {"cpu": 0.0, "low": 0.0, "mid": 8.0, "high": 10.0, "ultra": 12.0}
    target = tier_min.get(tier, 0.0)
    return next((m for m in pool if m.min_vram_gb <= target), None)
|
|
|
|
|
|
def list_for_tier(tier: str, kind: Kind) -> list[ModelSpec]:
    """List the models of *kind* that fit a VRAM tier, lowest requirement first.

    Each tier has a VRAM cap in GB; a model is included when its
    ``min_vram_gb`` does not exceed that cap plus 2 GB. An unrecognised
    tier uses a cap of 1000, which admits every model in the list.

    Args:
        tier: VRAM tier name.
        kind: ``"image"`` selects the image list; anything else selects video.

    Returns:
        A new list sorted by ascending ``min_vram_gb``.
    """
    if kind == "image":
        candidates = IMAGE_MODELS
    else:
        candidates = VIDEO_MODELS

    ceilings = {"cpu": 4.0, "low": 6.0, "mid": 10.0, "high": 14.0, "ultra": 1000.0}
    limit = ceilings.get(tier, 1000.0) + 2.0

    return sorted(
        (spec for spec in candidates if spec.min_vram_gb <= limit),
        key=lambda spec: spec.min_vram_gb,
    )
|
|
|
|
|
|
def find(model_id: str) -> ModelSpec | None:
    """Look up a model by id across the image and video lists.

    Args:
        model_id: Identifier to compare against each spec's ``id``.

    Returns:
        The first matching ``ModelSpec`` (image models are searched before
        video models), or ``None`` when no id matches.
    """
    every_model = IMAGE_MODELS + VIDEO_MODELS
    return next((spec for spec in every_model if spec.id == model_id), None)
|
|
|
|
|
|
def is_cached(spec: ModelSpec) -> bool:
    """Best-effort check for a downloaded copy of *spec* in the local cache.

    The cache directory is ``models/diffusers`` two levels above this file.
    The model counts as cached when that directory contains a folder named
    ``models--`` followed by the repo path with ``/`` replaced by ``--``.
    Only the folder's existence is tested, not its contents.

    Args:
        spec: Model whose ``repo`` is checked.

    Returns:
        ``True`` if the folder exists, ``False`` otherwise (including when
        the cache directory itself is missing).
    """
    from pathlib import Path

    cache_root = Path(__file__).parent.parent / "models" / "diffusers"
    if cache_root.exists():
        repo_dir = "models--" + spec.repo.replace("/", "--")
        return (cache_root / repo_dir).exists()
    return False
|
|
|
|
|
|
def label_with_meta(spec: ModelSpec) -> str:
    """Build a display label for *spec* with status suffixes appended.

    The result is the spec's label, followed by an NSFW marker when the
    model is flagged NSFW-capable, followed by either a "cached" marker or
    the download size rounded to whole GB.

    Args:
        spec: Model to describe.

    Returns:
        The decorated label string.
    """
    suffix = ""
    if spec.nsfw_capable:
        suffix += " · NSFW✓"
    if is_cached(spec):
        suffix += " · cached"
    else:
        suffix += f" · {spec.download_gb:.0f}GB download"
    return f"{spec.label}{suffix}"
|