Initial commit
This commit is contained in:
@@ -0,0 +1,163 @@
|
||||
"""Model registry. VRAM tier drives default. User can override."""
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Literal
|
||||
|
||||
Kind = Literal["image", "video"]
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ModelSpec:
|
||||
id: str
|
||||
label: str
|
||||
repo: str
|
||||
kind: Kind
|
||||
min_vram_gb: float
|
||||
nsfw_capable: bool
|
||||
download_gb: float = 0.0
|
||||
notes: str = ""
|
||||
|
||||
|
||||
IMAGE_MODELS: list[ModelSpec] = [
|
||||
ModelSpec(
|
||||
id="sdxl-turbo",
|
||||
label="SDXL Turbo (fast, low VRAM)",
|
||||
repo="stabilityai/sdxl-turbo",
|
||||
kind="image",
|
||||
min_vram_gb=4.0,
|
||||
nsfw_capable=False,
|
||||
download_gb=7.0,
|
||||
notes="1-4 step sampling. Censored base. Good for safe content on weak GPUs.",
|
||||
),
|
||||
ModelSpec(
|
||||
id="pony-xl",
|
||||
label="Pony Diffusion XL v6 (NSFW, anime/realistic)",
|
||||
repo="AstraliteHeart/pony-diffusion-v6",
|
||||
kind="image",
|
||||
min_vram_gb=8.0,
|
||||
nsfw_capable=True,
|
||||
download_gb=7.0,
|
||||
notes="Use score_9, score_8_up tags. Strong NSFW capability.",
|
||||
),
|
||||
ModelSpec(
|
||||
id="illustrious-xl",
|
||||
label="Illustrious XL v0.1 (NSFW, illustration)",
|
||||
repo="OnomaAIResearch/Illustrious-xl-early-release-v0",
|
||||
kind="image",
|
||||
min_vram_gb=10.0,
|
||||
nsfw_capable=True,
|
||||
download_gb=7.0,
|
||||
notes="High-detail anime/illustration. NSFW-capable.",
|
||||
),
|
||||
ModelSpec(
|
||||
id="sdxl-base",
|
||||
label="SDXL Base 1.0 (versatile)",
|
||||
repo="stabilityai/stable-diffusion-xl-base-1.0",
|
||||
kind="image",
|
||||
min_vram_gb=8.0,
|
||||
nsfw_capable=False,
|
||||
download_gb=7.0,
|
||||
notes="General-purpose. Censored.",
|
||||
),
|
||||
ModelSpec(
|
||||
id="flux-schnell",
|
||||
label="FLUX.1 Schnell (high quality, 12GB+)",
|
||||
repo="black-forest-labs/FLUX.1-schnell",
|
||||
kind="image",
|
||||
min_vram_gb=12.0,
|
||||
nsfw_capable=False,
|
||||
download_gb=24.0,
|
||||
notes="State-of-the-art prompt adherence. 4-step.",
|
||||
),
|
||||
]
|
||||
|
||||
VIDEO_MODELS: list[ModelSpec] = [
|
||||
ModelSpec(
|
||||
id="ltx-video",
|
||||
label="LTX-Video 0.9 (fast, 8GB+)",
|
||||
repo="Lightricks/LTX-Video",
|
||||
kind="video",
|
||||
min_vram_gb=8.0,
|
||||
nsfw_capable=True,
|
||||
download_gb=18.0,
|
||||
notes="2-5 second clips at 24fps. Fast generation.",
|
||||
),
|
||||
ModelSpec(
|
||||
id="cogvideox-5b",
|
||||
label="CogVideoX 5B (12GB+)",
|
||||
repo="THUDM/CogVideoX-5b",
|
||||
kind="video",
|
||||
min_vram_gb=12.0,
|
||||
nsfw_capable=False,
|
||||
download_gb=20.0,
|
||||
notes="6 second clips. Good motion quality.",
|
||||
),
|
||||
ModelSpec(
|
||||
id="wan-2-1",
|
||||
label="Wan 2.1 T2V 1.3B (16GB+)",
|
||||
repo="Wan-AI/Wan2.1-T2V-1.3B-Diffusers",
|
||||
kind="video",
|
||||
min_vram_gb=14.0,
|
||||
nsfw_capable=True,
|
||||
download_gb=16.0,
|
||||
notes="Top-tier open video model. Needs diffusers>=0.32.",
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
def default_for_tier(tier: str, kind: Kind) -> ModelSpec | None:
|
||||
pool = IMAGE_MODELS if kind == "image" else VIDEO_MODELS
|
||||
tier_min = {"cpu": 0.0, "low": 0.0, "mid": 8.0, "high": 10.0, "ultra": 12.0}
|
||||
target = tier_min.get(tier, 0.0)
|
||||
if tier == "cpu" and kind == "video":
|
||||
return None
|
||||
if tier in ("cpu", "low") and kind == "image":
|
||||
return next(m for m in pool if m.id == "sdxl-turbo")
|
||||
if tier == "mid" and kind == "image":
|
||||
return next(m for m in pool if m.id == "pony-xl")
|
||||
if tier == "high" and kind == "image":
|
||||
return next(m for m in pool if m.id == "illustrious-xl")
|
||||
if tier == "ultra" and kind == "image":
|
||||
return next(m for m in pool if m.id == "illustrious-xl")
|
||||
if tier in ("low",) and kind == "video":
|
||||
return None
|
||||
if tier == "mid" and kind == "video":
|
||||
return next(m for m in pool if m.id == "ltx-video")
|
||||
if tier == "high" and kind == "video":
|
||||
return next(m for m in pool if m.id == "ltx-video")
|
||||
if tier == "ultra" and kind == "video":
|
||||
return next(m for m in pool if m.id == "wan-2-1")
|
||||
return next((m for m in pool if m.min_vram_gb <= target), None)
|
||||
|
||||
|
||||
def list_for_tier(tier: str, kind: Kind) -> list[ModelSpec]:
|
||||
pool = IMAGE_MODELS if kind == "image" else VIDEO_MODELS
|
||||
tier_max = {"cpu": 4.0, "low": 6.0, "mid": 10.0, "high": 14.0, "ultra": 1000.0}
|
||||
cap = tier_max.get(tier, 1000.0)
|
||||
out = [m for m in pool if m.min_vram_gb <= cap + 2.0]
|
||||
out.sort(key=lambda m: m.min_vram_gb)
|
||||
return out
|
||||
|
||||
|
||||
def find(model_id: str) -> ModelSpec | None:
|
||||
for m in IMAGE_MODELS + VIDEO_MODELS:
|
||||
if m.id == model_id:
|
||||
return m
|
||||
return None
|
||||
|
||||
|
||||
def is_cached(spec: ModelSpec) -> bool:
|
||||
"""Best-effort: check if HF snapshot exists in our cache dir."""
|
||||
from pathlib import Path
|
||||
cache = Path(__file__).parent.parent / "models" / "diffusers"
|
||||
if not cache.exists():
|
||||
return False
|
||||
folder = "models--" + spec.repo.replace("/", "--")
|
||||
return (cache / folder).exists()
|
||||
|
||||
|
||||
def label_with_meta(spec: ModelSpec) -> str:
|
||||
nsfw = " · NSFW✓" if spec.nsfw_capable else ""
|
||||
cached = " · cached" if is_cached(spec) else f" · {spec.download_gb:.0f}GB download"
|
||||
return f"{spec.label}{nsfw}{cached}"
|
||||
Reference in New Issue
Block a user