"""Model registry. VRAM tier drives default. User can override."""

from __future__ import annotations

from dataclasses import dataclass
from pathlib import Path
from typing import Literal

Kind = Literal["image", "video"]


@dataclass(frozen=True)
class ModelSpec:
    """Static description of one downloadable generation model."""

    id: str  # registry key used by find() and the per-tier defaults
    label: str  # human-readable name; label_with_meta() decorates it
    repo: str  # Hugging Face repo id "<org>/<name>"; also names the cache folder
    kind: Kind  # "image" or "video"; selects IMAGE_MODELS vs VIDEO_MODELS
    min_vram_gb: float  # VRAM requirement in GB, compared against tier caps
    nsfw_capable: bool  # adds the NSFW marker in label_with_meta()
    download_gb: float = 0.0  # size in GB shown when the model is not cached
    notes: str = ""  # free-form usage notes


IMAGE_MODELS: list[ModelSpec] = [
    ModelSpec(
        id="sdxl-turbo",
        label="SDXL Turbo (fast, low VRAM)",
        repo="stabilityai/sdxl-turbo",
        kind="image",
        min_vram_gb=4.0,
        nsfw_capable=False,
        download_gb=7.0,
        notes="1-4 step sampling. Censored base. Good for safe content on weak GPUs.",
    ),
    ModelSpec(
        id="pony-xl",
        label="Pony Diffusion XL v6 (NSFW, anime/realistic)",
        repo="AstraliteHeart/pony-diffusion-v6",
        kind="image",
        min_vram_gb=8.0,
        nsfw_capable=True,
        download_gb=7.0,
        notes="Use score_9, score_8_up tags. Strong NSFW capability.",
    ),
    ModelSpec(
        id="illustrious-xl",
        label="Illustrious XL v0.1 (NSFW, illustration)",
        repo="OnomaAIResearch/Illustrious-xl-early-release-v0",
        kind="image",
        min_vram_gb=10.0,
        nsfw_capable=True,
        download_gb=7.0,
        notes="High-detail anime/illustration. NSFW-capable.",
    ),
    ModelSpec(
        id="sdxl-base",
        label="SDXL Base 1.0 (versatile)",
        repo="stabilityai/stable-diffusion-xl-base-1.0",
        kind="image",
        min_vram_gb=8.0,
        nsfw_capable=False,
        download_gb=7.0,
        notes="General-purpose. Censored.",
    ),
    ModelSpec(
        id="flux-schnell",
        label="FLUX.1 Schnell (high quality, 12GB+)",
        repo="black-forest-labs/FLUX.1-schnell",
        kind="image",
        min_vram_gb=12.0,
        nsfw_capable=False,
        download_gb=24.0,
        notes="State-of-the-art prompt adherence. 4-step.",
    ),
]

VIDEO_MODELS: list[ModelSpec] = [
    ModelSpec(
        id="ltx-video",
        label="LTX-Video 0.9 (fast, 8GB+)",
        repo="Lightricks/LTX-Video",
        kind="video",
        min_vram_gb=8.0,
        nsfw_capable=True,
        download_gb=18.0,
        notes="2-5 second clips at 24fps. Fast generation.",
    ),
    ModelSpec(
        id="cogvideox-5b",
        label="CogVideoX 5B (12GB+)",
        repo="THUDM/CogVideoX-5b",
        kind="video",
        min_vram_gb=12.0,
        nsfw_capable=False,
        download_gb=20.0,
        notes="6 second clips. Good motion quality.",
    ),
    ModelSpec(
        id="wan-2-1",
        # NOTE(review): label advertises 16GB+ but min_vram_gb is 14.0 —
        # confirm which figure is intended.
        label="Wan 2.1 T2V 1.3B (16GB+)",
        repo="Wan-AI/Wan2.1-T2V-1.3B-Diffusers",
        kind="video",
        min_vram_gb=14.0,
        nsfw_capable=True,
        download_gb=16.0,
        notes="Top-tier open video model. Needs diffusers>=0.32.",
    ),
]

# Curated default model id for every known (tier, kind) pair. None means the
# tier deliberately has no default for that kind (video on cpu/low tiers).
_TIER_DEFAULTS: dict[tuple[str, str], str | None] = {
    ("cpu", "image"): "sdxl-turbo",
    ("low", "image"): "sdxl-turbo",
    ("mid", "image"): "pony-xl",
    ("high", "image"): "illustrious-xl",
    ("ultra", "image"): "illustrious-xl",
    ("cpu", "video"): None,
    ("low", "video"): None,
    ("mid", "video"): "ltx-video",
    ("high", "video"): "ltx-video",
    ("ultra", "video"): "wan-2-1",
}

# Where is_cached() looks by default: <this file's grandparent>/models/diffusers.
_DEFAULT_CACHE_DIR = Path(__file__).parent.parent / "models" / "diffusers"


def default_for_tier(tier: str, kind: Kind) -> ModelSpec | None:
    """Return the curated default model for a hardware tier.

    Args:
        tier: Hardware tier; recognised values are "cpu", "low", "mid",
            "high" and "ultra".
        kind: "image" or "video".

    Returns:
        The default ModelSpec, or None when the tier has no default for this
        kind (video on "cpu"/"low"), or when tier/kind is not recognised.
    """
    model_id = _TIER_DEFAULTS.get((tier, kind))
    if model_id is None:
        return None
    pool = IMAGE_MODELS if kind == "image" else VIDEO_MODELS
    # Default to None rather than leaking StopIteration if an id in the table
    # is ever removed from the registry.
    return next((m for m in pool if m.id == model_id), None)


def list_for_tier(tier: str, kind: Kind) -> list[ModelSpec]:
    """List the models of *kind* offered on *tier*, lowest VRAM first.

    A model is included when its ``min_vram_gb`` is at most the tier's cap
    plus 2 GB. Presumably the headroom lets users opt into models slightly
    above their tier. Unknown tiers use a 1000 GB cap, which is effectively
    unlimited. Models with equal VRAM keep their registry order.
    """
    pool = IMAGE_MODELS if kind == "image" else VIDEO_MODELS
    tier_max = {"cpu": 4.0, "low": 6.0, "mid": 10.0, "high": 14.0, "ultra": 1000.0}
    limit = tier_max.get(tier, 1000.0) + 2.0
    return sorted(
        (m for m in pool if m.min_vram_gb <= limit),
        key=lambda m: m.min_vram_gb,
    )


def find(model_id: str) -> ModelSpec | None:
    """Return the image or video spec whose ``id`` is *model_id*, else None."""
    for pool in (IMAGE_MODELS, VIDEO_MODELS):
        for spec in pool:
            if spec.id == model_id:
                return spec
    return None


def is_cached(spec: ModelSpec, cache_dir: Path | str | None = None) -> bool:
    """Best-effort check that a Hugging Face snapshot of *spec* is on disk.

    Looks for ``models--<org>--<name>/snapshots/`` under *cache_dir* and
    requires it to hold at least one snapshot. This avoids reporting an
    aborted download, which leaves only ``blobs/``, as cached. It does not
    verify that every file of the snapshot is present.

    Args:
        spec: Model to check.
        cache_dir: Cache root; defaults to ``<project>/models/diffusers``.
    """
    root = _DEFAULT_CACHE_DIR if cache_dir is None else Path(cache_dir)
    snapshots = root / ("models--" + spec.repo.replace("/", "--")) / "snapshots"
    try:
        return snapshots.is_dir() and any(snapshots.iterdir())
    except OSError:
        # Unreadable cache: treat as "not cached" rather than failing the UI.
        return False


def label_with_meta(spec: ModelSpec, cache_dir: Path | str | None = None) -> str:
    """Return the display label decorated with NSFW and cache/download info.

    *cache_dir* is forwarded to is_cached(); None uses the default location.
    """
    nsfw = " · NSFW✓" if spec.nsfw_capable else ""
    cached = (
        " · cached"
        if is_cached(spec, cache_dir)
        else f" · {spec.download_gb:.0f}GB download"
    )
    return f"{spec.label}{nsfw}{cached}"