Initial commit
This commit is contained in:
@@ -0,0 +1,47 @@
|
||||
"""Torch device selection. Wraps cuda / directml / cpu behind one helper."""
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from functools import lru_cache
|
||||
from pathlib import Path
|
||||
|
||||
ROOT = Path(__file__).parent.parent
|
||||
HARDWARE_CACHE = ROOT / "config.local.json"
|
||||
|
||||
|
||||
@lru_cache(maxsize=1)
|
||||
def hardware_info() -> dict:
|
||||
if HARDWARE_CACHE.exists():
|
||||
return json.loads(HARDWARE_CACHE.read_text())
|
||||
from . import hardware
|
||||
info = hardware.detect()
|
||||
return {
|
||||
"vendor": info.vendor,
|
||||
"backend": info.backend,
|
||||
"device_name": info.device_name,
|
||||
"vram_gb": info.vram_gb,
|
||||
"tier": info.tier,
|
||||
}
|
||||
|
||||
|
||||
@lru_cache(maxsize=1)
|
||||
def get_device():
|
||||
import torch
|
||||
backend = hardware_info()["backend"]
|
||||
if backend == "cuda" and torch.cuda.is_available():
|
||||
return torch.device("cuda")
|
||||
if backend == "directml":
|
||||
try:
|
||||
import torch_directml
|
||||
return torch_directml.device()
|
||||
except ImportError:
|
||||
pass
|
||||
return torch.device("cpu")
|
||||
|
||||
|
||||
def torch_dtype():
|
||||
import torch
|
||||
backend = hardware_info()["backend"]
|
||||
if backend == "cpu":
|
||||
return torch.float32
|
||||
return torch.float16
|
||||
@@ -0,0 +1,156 @@
|
||||
"""GPU and VRAM detection. Returns vendor + tier used to pick torch wheel and default models."""
|
||||
from __future__ import annotations
|
||||
|
||||
import ctypes
|
||||
import platform
|
||||
import subprocess
|
||||
from dataclasses import dataclass
|
||||
from typing import Literal
|
||||
|
||||
Vendor = Literal["nvidia", "amd", "intel", "cpu"]
|
||||
Backend = Literal["cuda", "directml", "cpu"]
|
||||
|
||||
|
||||
@dataclass
|
||||
class HardwareInfo:
|
||||
vendor: Vendor
|
||||
backend: Backend
|
||||
device_name: str
|
||||
vram_gb: float
|
||||
tier: Literal["cpu", "low", "mid", "high", "ultra"]
|
||||
|
||||
|
||||
def _run(cmd: list[str]) -> str:
|
||||
try:
|
||||
out = subprocess.run(cmd, capture_output=True, text=True, timeout=10, check=False)
|
||||
return out.stdout.strip()
|
||||
except (FileNotFoundError, subprocess.TimeoutExpired):
|
||||
return ""
|
||||
|
||||
|
||||
def _detect_nvidia() -> tuple[str, float] | None:
|
||||
out = _run(["nvidia-smi", "--query-gpu=name,memory.total", "--format=csv,noheader,nounits"])
|
||||
if not out:
|
||||
return None
|
||||
first = out.splitlines()[0]
|
||||
name, mem = [p.strip() for p in first.split(",", 1)]
|
||||
return name, float(mem) / 1024.0
|
||||
|
||||
|
||||
def _detect_dxgi() -> list[tuple[str, float, str]]:
|
||||
"""Enumerate all DXGI adapters via PowerShell. Returns list of (name, vram_gb, vendor_hint)."""
|
||||
if platform.system() != "Windows":
|
||||
return []
|
||||
ps = (
|
||||
"Get-CimInstance Win32_VideoController | "
|
||||
"Select-Object Name, AdapterRAM | "
|
||||
"ConvertTo-Json -Compress"
|
||||
)
|
||||
out = _run(["powershell", "-NoProfile", "-Command", ps])
|
||||
if not out:
|
||||
return []
|
||||
import json
|
||||
try:
|
||||
data = json.loads(out)
|
||||
except json.JSONDecodeError:
|
||||
return []
|
||||
if isinstance(data, dict):
|
||||
data = [data]
|
||||
results: list[tuple[str, float, str]] = []
|
||||
for entry in data:
|
||||
name = (entry.get("Name") or "").strip()
|
||||
ram = entry.get("AdapterRAM") or 0
|
||||
# Win32_VideoController caps AdapterRAM at 4GB on many systems. Trust value but flag below.
|
||||
vram_gb = float(ram) / (1024 ** 3) if ram else 0.0
|
||||
low = name.lower()
|
||||
if "nvidia" in low or "geforce" in low or "rtx" in low or "gtx" in low:
|
||||
vendor = "nvidia"
|
||||
elif "amd" in low or "radeon" in low or "rx " in low:
|
||||
vendor = "amd"
|
||||
elif "intel" in low or "arc" in low or "iris" in low:
|
||||
vendor = "intel"
|
||||
else:
|
||||
vendor = "unknown"
|
||||
if name:
|
||||
results.append((name, vram_gb, vendor))
|
||||
return results
|
||||
|
||||
|
||||
def _vram_tier(vram_gb: float) -> Literal["cpu", "low", "mid", "high", "ultra"]:
|
||||
if vram_gb < 1:
|
||||
return "cpu"
|
||||
if vram_gb < 6:
|
||||
return "low"
|
||||
if vram_gb < 10:
|
||||
return "mid"
|
||||
if vram_gb < 14:
|
||||
return "high"
|
||||
return "ultra"
|
||||
|
||||
|
||||
def detect() -> HardwareInfo:
|
||||
nv = _detect_nvidia()
|
||||
if nv:
|
||||
name, vram = nv
|
||||
return HardwareInfo("nvidia", "cuda", name, vram, _vram_tier(vram))
|
||||
|
||||
adapters = _detect_dxgi()
|
||||
# Prefer discrete (highest VRAM) non-basic adapter
|
||||
adapters = [a for a in adapters if "basic" not in a[0].lower() and "microsoft" not in a[0].lower()]
|
||||
if adapters:
|
||||
adapters.sort(key=lambda a: a[1], reverse=True)
|
||||
name, vram, hint = adapters[0]
|
||||
# AdapterRAM is unreliable for >4GB cards. If exactly 4GB and modern AMD/Intel card name, bump.
|
||||
if vram <= 4.1 and any(k in name.lower() for k in ("rx 6", "rx 7", "arc a", "arc b")):
|
||||
vram = 8.0 # conservative guess
|
||||
if hint in ("amd", "intel"):
|
||||
return HardwareInfo(hint, "directml", name, vram, _vram_tier(vram))
|
||||
if hint == "nvidia":
|
||||
# nvidia-smi missing but card is nvidia: drivers may be broken, fall through to directml
|
||||
return HardwareInfo("nvidia", "directml", name, vram, _vram_tier(vram))
|
||||
|
||||
return HardwareInfo("cpu", "cpu", platform.processor() or "CPU", 0.0, "cpu")
|
||||
|
||||
|
||||
def directml_supported() -> bool:
|
||||
"""torch-directml ships wheels for Python 3.10 / 3.11. The launcher pins the
|
||||
venv to 3.11, so this is True in the running app."""
|
||||
import sys
|
||||
return sys.version_info[:2] in {(3, 10), (3, 11)}
|
||||
|
||||
|
||||
def torch_install_args(info: HardwareInfo) -> list[str]:
|
||||
"""Return uv-pip install args for the right PyTorch build.
|
||||
|
||||
Launcher pins venv to Python 3.11, so torch-directml wheels are always available
|
||||
for AMD/Intel paths. Latest stable torch is used elsewhere.
|
||||
"""
|
||||
if info.backend == "cuda":
|
||||
return [
|
||||
"torch",
|
||||
"torchvision",
|
||||
"--index-url",
|
||||
"https://download.pytorch.org/whl/cu124",
|
||||
]
|
||||
if info.backend == "directml":
|
||||
# torch-directml currently pins to torch 2.4.x. Match it.
|
||||
return [
|
||||
"torch>=2.4,<2.5",
|
||||
"torchvision>=0.19,<0.20",
|
||||
"torch-directml>=0.2.5",
|
||||
]
|
||||
return [
|
||||
"torch",
|
||||
"torchvision",
|
||||
"--index-url",
|
||||
"https://download.pytorch.org/whl/cpu",
|
||||
]
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
info = detect()
|
||||
print(f"Vendor: {info.vendor}")
|
||||
print(f"Backend: {info.backend}")
|
||||
print(f"Device: {info.device_name}")
|
||||
print(f"VRAM: {info.vram_gb:.1f} GB")
|
||||
print(f"Tier: {info.tier}")
|
||||
@@ -0,0 +1,92 @@
|
||||
"""SDXL-family image generation. Loaded lazily, cached by repo id."""
|
||||
from __future__ import annotations
|
||||
|
||||
from functools import lru_cache
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
from PIL import Image
|
||||
|
||||
from .device import get_device, torch_dtype
|
||||
from .memory import apply_memory_strategy
|
||||
from .models import ModelSpec, find
|
||||
|
||||
ROOT = Path(__file__).parent.parent
|
||||
MODELS_DIR = ROOT / "models"
|
||||
|
||||
|
||||
@lru_cache(maxsize=1)
|
||||
def _load_pipeline(repo: str, kind: str):
|
||||
import torch
|
||||
from diffusers import (
|
||||
AutoPipelineForText2Image,
|
||||
DiffusionPipeline,
|
||||
FluxPipeline,
|
||||
)
|
||||
|
||||
device = get_device()
|
||||
dtype = torch_dtype()
|
||||
cache_dir = str(MODELS_DIR / "diffusers")
|
||||
|
||||
if "FLUX" in repo:
|
||||
pipe = FluxPipeline.from_pretrained(repo, torch_dtype=dtype, cache_dir=cache_dir)
|
||||
elif "turbo" in repo.lower():
|
||||
pipe = AutoPipelineForText2Image.from_pretrained(
|
||||
repo, torch_dtype=dtype, variant="fp16", cache_dir=cache_dir
|
||||
)
|
||||
else:
|
||||
pipe = DiffusionPipeline.from_pretrained(
|
||||
repo, torch_dtype=dtype, use_safetensors=True, cache_dir=cache_dir
|
||||
)
|
||||
|
||||
# Disable built-in safety checker. We use our own filter (CSAM-only).
|
||||
if hasattr(pipe, "safety_checker"):
|
||||
pipe.safety_checker = None
|
||||
if hasattr(pipe, "requires_safety_checker"):
|
||||
pipe.requires_safety_checker = False
|
||||
|
||||
apply_memory_strategy(pipe)
|
||||
return pipe
|
||||
|
||||
|
||||
def generate(
|
||||
prompt: str,
|
||||
negative_prompt: str = "",
|
||||
model_id: str = "pony-xl",
|
||||
width: int = 1024,
|
||||
height: int = 1024,
|
||||
steps: int = 30,
|
||||
guidance: float = 7.0,
|
||||
seed: Optional[int] = None,
|
||||
) -> Image.Image:
|
||||
spec: ModelSpec | None = find(model_id)
|
||||
if spec is None or spec.kind != "image":
|
||||
raise ValueError(f"Unknown image model: {model_id}")
|
||||
|
||||
pipe = _load_pipeline(spec.repo, spec.kind)
|
||||
|
||||
import torch
|
||||
|
||||
generator = None
|
||||
if seed is not None:
|
||||
try:
|
||||
generator = torch.Generator(device=get_device()).manual_seed(int(seed))
|
||||
except RuntimeError:
|
||||
generator = torch.Generator().manual_seed(int(seed))
|
||||
|
||||
if "turbo" in spec.repo.lower():
|
||||
steps = max(1, min(steps, 4))
|
||||
guidance = 0.0
|
||||
if "FLUX" in spec.repo:
|
||||
guidance = 0.0
|
||||
|
||||
out = pipe(
|
||||
prompt=prompt,
|
||||
negative_prompt=negative_prompt or None,
|
||||
width=width,
|
||||
height=height,
|
||||
num_inference_steps=steps,
|
||||
guidance_scale=guidance,
|
||||
generator=generator,
|
||||
)
|
||||
return out.images[0]
|
||||
@@ -0,0 +1,60 @@
|
||||
"""Memory strategy selection. CUDA supports cpu_offload; DirectML/CPU do not.
|
||||
|
||||
Apply per-pipeline based on backend + VRAM tier. All paths reduce peak VRAM
|
||||
without breaking on non-CUDA devices.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
from .device import get_device, hardware_info
|
||||
|
||||
|
||||
def apply_memory_strategy(pipe) -> None:
|
||||
"""Apply VRAM-saving knobs that match the active backend."""
|
||||
info = hardware_info()
|
||||
backend = info["backend"]
|
||||
vram = info["vram_gb"]
|
||||
|
||||
# Always-safe: VAE tiling/slicing work on any device. Cuts peak VRAM during decode.
|
||||
# Newer diffusers (>=0.32) prefers calling on the VAE directly.
|
||||
vae = getattr(pipe, "vae", None)
|
||||
if vae is not None:
|
||||
for fn in ("enable_slicing", "enable_tiling"):
|
||||
if hasattr(vae, fn):
|
||||
try:
|
||||
getattr(vae, fn)()
|
||||
except Exception:
|
||||
pass
|
||||
if hasattr(pipe, "enable_attention_slicing"):
|
||||
try:
|
||||
pipe.enable_attention_slicing()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
if backend == "cuda":
|
||||
# Offload only if VRAM tight. cpu_offload is CUDA-only via accelerate hooks.
|
||||
if vram < 10:
|
||||
try:
|
||||
pipe.enable_sequential_cpu_offload()
|
||||
return
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
pipe.enable_model_cpu_offload()
|
||||
return
|
||||
except Exception:
|
||||
pass
|
||||
pipe.to(get_device())
|
||||
return
|
||||
|
||||
if backend == "directml":
|
||||
# DirectML lacks accelerate hook support. Move whole pipe to device.
|
||||
# Slicing already enabled above keeps peak in check.
|
||||
try:
|
||||
pipe.to(get_device())
|
||||
except Exception:
|
||||
# Some pipes have components that won't move cleanly; fall back to CPU.
|
||||
pipe.to("cpu")
|
||||
return
|
||||
|
||||
# CPU
|
||||
pipe.to("cpu")
|
||||
@@ -0,0 +1,163 @@
|
||||
"""Model registry. VRAM tier drives default. User can override."""
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Literal
|
||||
|
||||
Kind = Literal["image", "video"]
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ModelSpec:
|
||||
id: str
|
||||
label: str
|
||||
repo: str
|
||||
kind: Kind
|
||||
min_vram_gb: float
|
||||
nsfw_capable: bool
|
||||
download_gb: float = 0.0
|
||||
notes: str = ""
|
||||
|
||||
|
||||
IMAGE_MODELS: list[ModelSpec] = [
|
||||
ModelSpec(
|
||||
id="sdxl-turbo",
|
||||
label="SDXL Turbo (fast, low VRAM)",
|
||||
repo="stabilityai/sdxl-turbo",
|
||||
kind="image",
|
||||
min_vram_gb=4.0,
|
||||
nsfw_capable=False,
|
||||
download_gb=7.0,
|
||||
notes="1-4 step sampling. Censored base. Good for safe content on weak GPUs.",
|
||||
),
|
||||
ModelSpec(
|
||||
id="pony-xl",
|
||||
label="Pony Diffusion XL v6 (NSFW, anime/realistic)",
|
||||
repo="AstraliteHeart/pony-diffusion-v6",
|
||||
kind="image",
|
||||
min_vram_gb=8.0,
|
||||
nsfw_capable=True,
|
||||
download_gb=7.0,
|
||||
notes="Use score_9, score_8_up tags. Strong NSFW capability.",
|
||||
),
|
||||
ModelSpec(
|
||||
id="illustrious-xl",
|
||||
label="Illustrious XL v0.1 (NSFW, illustration)",
|
||||
repo="OnomaAIResearch/Illustrious-xl-early-release-v0",
|
||||
kind="image",
|
||||
min_vram_gb=10.0,
|
||||
nsfw_capable=True,
|
||||
download_gb=7.0,
|
||||
notes="High-detail anime/illustration. NSFW-capable.",
|
||||
),
|
||||
ModelSpec(
|
||||
id="sdxl-base",
|
||||
label="SDXL Base 1.0 (versatile)",
|
||||
repo="stabilityai/stable-diffusion-xl-base-1.0",
|
||||
kind="image",
|
||||
min_vram_gb=8.0,
|
||||
nsfw_capable=False,
|
||||
download_gb=7.0,
|
||||
notes="General-purpose. Censored.",
|
||||
),
|
||||
ModelSpec(
|
||||
id="flux-schnell",
|
||||
label="FLUX.1 Schnell (high quality, 12GB+)",
|
||||
repo="black-forest-labs/FLUX.1-schnell",
|
||||
kind="image",
|
||||
min_vram_gb=12.0,
|
||||
nsfw_capable=False,
|
||||
download_gb=24.0,
|
||||
notes="State-of-the-art prompt adherence. 4-step.",
|
||||
),
|
||||
]
|
||||
|
||||
VIDEO_MODELS: list[ModelSpec] = [
|
||||
ModelSpec(
|
||||
id="ltx-video",
|
||||
label="LTX-Video 0.9 (fast, 8GB+)",
|
||||
repo="Lightricks/LTX-Video",
|
||||
kind="video",
|
||||
min_vram_gb=8.0,
|
||||
nsfw_capable=True,
|
||||
download_gb=18.0,
|
||||
notes="2-5 second clips at 24fps. Fast generation.",
|
||||
),
|
||||
ModelSpec(
|
||||
id="cogvideox-5b",
|
||||
label="CogVideoX 5B (12GB+)",
|
||||
repo="THUDM/CogVideoX-5b",
|
||||
kind="video",
|
||||
min_vram_gb=12.0,
|
||||
nsfw_capable=False,
|
||||
download_gb=20.0,
|
||||
notes="6 second clips. Good motion quality.",
|
||||
),
|
||||
ModelSpec(
|
||||
id="wan-2-1",
|
||||
label="Wan 2.1 T2V 1.3B (16GB+)",
|
||||
repo="Wan-AI/Wan2.1-T2V-1.3B-Diffusers",
|
||||
kind="video",
|
||||
min_vram_gb=14.0,
|
||||
nsfw_capable=True,
|
||||
download_gb=16.0,
|
||||
notes="Top-tier open video model. Needs diffusers>=0.32.",
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
def default_for_tier(tier: str, kind: Kind) -> ModelSpec | None:
|
||||
pool = IMAGE_MODELS if kind == "image" else VIDEO_MODELS
|
||||
tier_min = {"cpu": 0.0, "low": 0.0, "mid": 8.0, "high": 10.0, "ultra": 12.0}
|
||||
target = tier_min.get(tier, 0.0)
|
||||
if tier == "cpu" and kind == "video":
|
||||
return None
|
||||
if tier in ("cpu", "low") and kind == "image":
|
||||
return next(m for m in pool if m.id == "sdxl-turbo")
|
||||
if tier == "mid" and kind == "image":
|
||||
return next(m for m in pool if m.id == "pony-xl")
|
||||
if tier == "high" and kind == "image":
|
||||
return next(m for m in pool if m.id == "illustrious-xl")
|
||||
if tier == "ultra" and kind == "image":
|
||||
return next(m for m in pool if m.id == "illustrious-xl")
|
||||
if tier in ("low",) and kind == "video":
|
||||
return None
|
||||
if tier == "mid" and kind == "video":
|
||||
return next(m for m in pool if m.id == "ltx-video")
|
||||
if tier == "high" and kind == "video":
|
||||
return next(m for m in pool if m.id == "ltx-video")
|
||||
if tier == "ultra" and kind == "video":
|
||||
return next(m for m in pool if m.id == "wan-2-1")
|
||||
return next((m for m in pool if m.min_vram_gb <= target), None)
|
||||
|
||||
|
||||
def list_for_tier(tier: str, kind: Kind) -> list[ModelSpec]:
|
||||
pool = IMAGE_MODELS if kind == "image" else VIDEO_MODELS
|
||||
tier_max = {"cpu": 4.0, "low": 6.0, "mid": 10.0, "high": 14.0, "ultra": 1000.0}
|
||||
cap = tier_max.get(tier, 1000.0)
|
||||
out = [m for m in pool if m.min_vram_gb <= cap + 2.0]
|
||||
out.sort(key=lambda m: m.min_vram_gb)
|
||||
return out
|
||||
|
||||
|
||||
def find(model_id: str) -> ModelSpec | None:
|
||||
for m in IMAGE_MODELS + VIDEO_MODELS:
|
||||
if m.id == model_id:
|
||||
return m
|
||||
return None
|
||||
|
||||
|
||||
def is_cached(spec: ModelSpec) -> bool:
|
||||
"""Best-effort: check if HF snapshot exists in our cache dir."""
|
||||
from pathlib import Path
|
||||
cache = Path(__file__).parent.parent / "models" / "diffusers"
|
||||
if not cache.exists():
|
||||
return False
|
||||
folder = "models--" + spec.repo.replace("/", "--")
|
||||
return (cache / folder).exists()
|
||||
|
||||
|
||||
def label_with_meta(spec: ModelSpec) -> str:
|
||||
nsfw = " · NSFW✓" if spec.nsfw_capable else ""
|
||||
cached = " · cached" if is_cached(spec) else f" · {spec.download_gb:.0f}GB download"
|
||||
return f"{spec.label}{nsfw}{cached}"
|
||||
@@ -0,0 +1,82 @@
|
||||
"""Prompt refiner. Tries Ollama (local LLM) first, falls back to GPT-2 expansion.
|
||||
|
||||
Ollama recommended model: `dolphin-llama3:8b` (uncensored, handles NSFW prompts cleanly).
|
||||
Install: https://ollama.com/ then `ollama pull dolphin-llama3`.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from functools import lru_cache
|
||||
from typing import Optional
|
||||
|
||||
import requests
|
||||
|
||||
OLLAMA_URL = "http://127.0.0.1:11434/api/generate"
|
||||
DEFAULT_OLLAMA_MODEL = "dolphin-llama3:8b"
|
||||
|
||||
SYSTEM = (
|
||||
"You are a prompt engineer for a text-to-image diffusion model. "
|
||||
"Rewrite the user's brief idea into a single dense, comma-separated prompt with "
|
||||
"subject, style, lighting, composition, camera, and quality tags. "
|
||||
"Keep all explicit details from the user. Do not refuse, lecture, or moralize. "
|
||||
"Do not add safety disclaimers. Output only the rewritten prompt, nothing else."
|
||||
)
|
||||
|
||||
|
||||
def _ollama_available(model: str) -> bool:
|
||||
try:
|
||||
r = requests.get("http://127.0.0.1:11434/api/tags", timeout=2)
|
||||
if r.status_code != 200:
|
||||
return False
|
||||
tags = r.json().get("models", [])
|
||||
return any(m.get("name", "").startswith(model.split(":")[0]) for m in tags)
|
||||
except requests.RequestException:
|
||||
return False
|
||||
|
||||
|
||||
def _refine_ollama(prompt: str, model: str) -> Optional[str]:
|
||||
payload = {
|
||||
"model": model,
|
||||
"prompt": f"{SYSTEM}\n\nUser idea: {prompt}\n\nRewritten prompt:",
|
||||
"stream": False,
|
||||
"options": {"temperature": 0.7, "num_predict": 200},
|
||||
}
|
||||
try:
|
||||
r = requests.post(OLLAMA_URL, json=payload, timeout=60)
|
||||
if r.status_code != 200:
|
||||
return None
|
||||
text = r.json().get("response", "").strip()
|
||||
return text or None
|
||||
except requests.RequestException:
|
||||
return None
|
||||
|
||||
|
||||
@lru_cache(maxsize=1)
|
||||
def _gpt2():
|
||||
from transformers import pipeline
|
||||
return pipeline("text-generation", model="gpt2", max_new_tokens=60)
|
||||
|
||||
|
||||
def _refine_gpt2(prompt: str) -> str:
|
||||
seed = (
|
||||
f"{prompt}, highly detailed, sharp focus, professional, cinematic lighting, "
|
||||
f"intricate details, masterpiece, best quality"
|
||||
)
|
||||
try:
|
||||
gen = _gpt2()
|
||||
out = gen(seed, num_return_sequences=1, do_sample=True, temperature=0.7)
|
||||
text = out[0]["generated_text"].split("\n")[0]
|
||||
return text.strip()
|
||||
except Exception:
|
||||
return seed
|
||||
|
||||
|
||||
def refine(prompt: str, use_ollama: bool = True, ollama_model: str = DEFAULT_OLLAMA_MODEL) -> str:
|
||||
prompt = prompt.strip()
|
||||
if not prompt:
|
||||
return prompt
|
||||
if use_ollama and _ollama_available(ollama_model):
|
||||
result = _refine_ollama(prompt, ollama_model)
|
||||
if result:
|
||||
return result
|
||||
return _refine_gpt2(prompt)
|
||||
@@ -0,0 +1,185 @@
|
||||
"""Safety filter: block CSAM only. All other content (NSFW, gore, violence) allowed.
|
||||
|
||||
Layers:
|
||||
1. Prompt keyword gate: rejects child-term + sexual-term combinations.
|
||||
2. Output check: NudeNet detects nudity. If nudity present, run face detection
|
||||
(MTCNN) + age classifier (ViT) on every face. Block if any face is classified
|
||||
as a minor with high confidence.
|
||||
|
||||
This is best-effort. The user is legally responsible for use of generated content.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
import tempfile
|
||||
from dataclasses import dataclass
|
||||
from functools import lru_cache
|
||||
from typing import Iterable
|
||||
|
||||
from PIL import Image
|
||||
|
||||
# --- prompt keyword gate ---------------------------------------------------
|
||||
|
||||
_CHILD_TERMS = [
|
||||
r"\b(child|children|kid|kids|minor|underage|under-?age|preteen|pre-?teen)\b",
|
||||
r"\b(toddler|infant|baby|babies)\b",
|
||||
r"\b(\d|0?[0-9]|1[0-7])\s*(yo|y/o|year[- ]?old)\b",
|
||||
r"\bloli(con)?\b",
|
||||
r"\bshota(con)?\b",
|
||||
r"\bcp\b",
|
||||
]
|
||||
_SEXUAL_TERMS = [
|
||||
r"\b(nude|naked|nsfw|porn|sex|sexual|sexy|erotic|explicit)\b",
|
||||
r"\b(penis|vagina|breast|nipple|genital|cum|orgasm)\b",
|
||||
r"\b(intercourse|fellatio|cunnilingus|masturbat)\w*\b",
|
||||
r"\b(rape|molest)\w*\b",
|
||||
]
|
||||
|
||||
_CHILD_RE = re.compile("|".join(_CHILD_TERMS), re.IGNORECASE)
|
||||
_SEX_RE = re.compile("|".join(_SEXUAL_TERMS), re.IGNORECASE)
|
||||
|
||||
|
||||
@dataclass
|
||||
class SafetyResult:
|
||||
allowed: bool
|
||||
reason: str = ""
|
||||
|
||||
|
||||
def check_prompt(prompt: str) -> SafetyResult:
|
||||
if _CHILD_RE.search(prompt) and _SEX_RE.search(prompt):
|
||||
return SafetyResult(False, "Prompt blocked: combines minor and sexual terms (CSAM gate).")
|
||||
return SafetyResult(True)
|
||||
|
||||
|
||||
# --- nudity detection ------------------------------------------------------
|
||||
|
||||
_NUDITY_LABELS = {
|
||||
"FEMALE_BREAST_EXPOSED",
|
||||
"FEMALE_GENITALIA_EXPOSED",
|
||||
"MALE_GENITALIA_EXPOSED",
|
||||
"BUTTOCKS_EXPOSED",
|
||||
"ANUS_EXPOSED",
|
||||
}
|
||||
|
||||
|
||||
@lru_cache(maxsize=1)
|
||||
def _nudenet():
|
||||
try:
|
||||
from nudenet import NudeDetector
|
||||
return NudeDetector()
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
|
||||
def _has_nudity(detections: Iterable[dict]) -> bool:
|
||||
for d in detections:
|
||||
label = d.get("class") or d.get("label") or ""
|
||||
score = float(d.get("score", 0.0))
|
||||
if label in _NUDITY_LABELS and score >= 0.5:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
# --- face detection + age classification -----------------------------------
|
||||
|
||||
@lru_cache(maxsize=1)
|
||||
def _mtcnn():
|
||||
try:
|
||||
from facenet_pytorch import MTCNN
|
||||
import torch
|
||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
return MTCNN(keep_all=True, device=device, post_process=False, min_face_size=40)
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
|
||||
@lru_cache(maxsize=1)
|
||||
def _age_classifier():
|
||||
try:
|
||||
from transformers import pipeline
|
||||
return pipeline(
|
||||
"image-classification",
|
||||
model="nateraw/vit-age-classifier",
|
||||
top_k=3,
|
||||
)
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
|
||||
# Buckets reported by nateraw/vit-age-classifier.
|
||||
# Conservative minor set. "10-19" includes some adults — treat as minor only on
|
||||
# strong confidence to limit false positives on young-looking adults.
|
||||
_HARD_MINOR = {"0-2", "3-9"}
|
||||
_SOFT_MINOR = {"10-19"}
|
||||
|
||||
|
||||
def _faces(img: Image.Image):
|
||||
mtcnn = _mtcnn()
|
||||
if mtcnn is None:
|
||||
return []
|
||||
try:
|
||||
boxes, probs = mtcnn.detect(img)
|
||||
except Exception:
|
||||
return []
|
||||
if boxes is None:
|
||||
return []
|
||||
if probs is None:
|
||||
probs = [None] * len(boxes)
|
||||
out = []
|
||||
for box, prob in zip(boxes, probs):
|
||||
if prob is None or float(prob) < 0.9:
|
||||
continue
|
||||
x1, y1, x2, y2 = [int(max(0, v)) for v in box]
|
||||
if x2 - x1 < 30 or y2 - y1 < 30:
|
||||
continue
|
||||
out.append(img.crop((x1, y1, x2, y2)))
|
||||
return out
|
||||
|
||||
|
||||
def _is_minor_face(face_img: Image.Image) -> tuple[bool, str]:
|
||||
clf = _age_classifier()
|
||||
if clf is None:
|
||||
return False, ""
|
||||
try:
|
||||
preds = clf(face_img)
|
||||
except Exception:
|
||||
return False, ""
|
||||
# preds is list[dict(label, score)] sorted by score desc
|
||||
top = preds[0] if preds else None
|
||||
if not top:
|
||||
return False, ""
|
||||
label = top["label"]
|
||||
score = float(top["score"])
|
||||
if label in _HARD_MINOR and score >= 0.55:
|
||||
return True, f"minor face detected ({label}, conf={score:.2f})"
|
||||
if label in _SOFT_MINOR and score >= 0.85:
|
||||
return True, f"likely minor face ({label}, conf={score:.2f})"
|
||||
return False, ""
|
||||
|
||||
|
||||
def check_image(img: Image.Image) -> SafetyResult:
|
||||
"""Block if (nudity present) AND (any face classified as minor)."""
|
||||
det = _nudenet()
|
||||
if det is None:
|
||||
# No nudity detector available — fall through. Prompt gate is primary defense.
|
||||
return SafetyResult(True)
|
||||
|
||||
try:
|
||||
with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as f:
|
||||
img.save(f.name)
|
||||
results = det.detect(f.name)
|
||||
except Exception:
|
||||
return SafetyResult(True)
|
||||
|
||||
if not _has_nudity(results):
|
||||
return SafetyResult(True)
|
||||
|
||||
faces = _faces(img)
|
||||
for face in faces:
|
||||
is_minor, reason = _is_minor_face(face)
|
||||
if is_minor:
|
||||
return SafetyResult(
|
||||
False,
|
||||
f"Output blocked: nudity + {reason}. Image discarded.",
|
||||
)
|
||||
return SafetyResult(True)
|
||||
@@ -0,0 +1,118 @@
|
||||
"""Video generation. LTX-Video / CogVideoX / Wan via diffusers pipelines."""
|
||||
from __future__ import annotations
|
||||
|
||||
from functools import lru_cache
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
from .device import get_device, hardware_info, torch_dtype
|
||||
from .memory import apply_memory_strategy
|
||||
from .models import ModelSpec, find
|
||||
|
||||
ROOT = Path(__file__).parent.parent
|
||||
MODELS_DIR = ROOT / "models"
|
||||
OUTPUTS = ROOT / "outputs"
|
||||
OUTPUTS.mkdir(exist_ok=True)
|
||||
|
||||
|
||||
def _import_pipeline_for(repo: str):
|
||||
"""Pick the right pipeline class. Wan needs diffusers>=0.32 with WanPipeline."""
|
||||
if "LTX" in repo:
|
||||
from diffusers import LTXPipeline
|
||||
return LTXPipeline
|
||||
if "CogVideoX" in repo:
|
||||
from diffusers import CogVideoXPipeline
|
||||
return CogVideoXPipeline
|
||||
if "Wan" in repo:
|
||||
try:
|
||||
from diffusers import WanPipeline
|
||||
return WanPipeline
|
||||
except ImportError as e:
|
||||
raise RuntimeError(
|
||||
"Wan 2.1 needs diffusers>=0.32. Run: pip install -U diffusers"
|
||||
) from e
|
||||
if "Hunyuan" in repo:
|
||||
from diffusers import HunyuanVideoPipeline
|
||||
return HunyuanVideoPipeline
|
||||
from diffusers import DiffusionPipeline
|
||||
return DiffusionPipeline
|
||||
|
||||
|
||||
@lru_cache(maxsize=1)
|
||||
def _load_pipeline(repo: str):
|
||||
dtype = torch_dtype()
|
||||
cache_dir = str(MODELS_DIR / "diffusers")
|
||||
|
||||
PipelineCls = _import_pipeline_for(repo)
|
||||
pipe = PipelineCls.from_pretrained(repo, torch_dtype=dtype, cache_dir=cache_dir)
|
||||
|
||||
apply_memory_strategy(pipe)
|
||||
return pipe
|
||||
|
||||
|
||||
def _model_kwargs(repo: str, base: dict) -> dict:
|
||||
"""Adjust kwargs per-pipeline. Some pipes don't accept width/height or have different param names."""
|
||||
out = dict(base)
|
||||
if "Wan" in repo:
|
||||
# Wan accepts height/width but expects multiples of 16.
|
||||
out["height"] = (out["height"] // 16) * 16
|
||||
out["width"] = (out["width"] // 16) * 16
|
||||
if "LTX" in repo:
|
||||
# LTX needs 32-multiple resolutions.
|
||||
out["height"] = (out["height"] // 32) * 32
|
||||
out["width"] = (out["width"] // 32) * 32
|
||||
if "CogVideoX" in repo:
|
||||
# CogVideoX has fixed 720x480 default; clamp.
|
||||
out["height"] = min(out["height"], 480)
|
||||
out["width"] = min(out["width"], 720)
|
||||
return out
|
||||
|
||||
|
||||
def generate(
|
||||
prompt: str,
|
||||
negative_prompt: str = "",
|
||||
model_id: str = "ltx-video",
|
||||
width: int = 704,
|
||||
height: int = 480,
|
||||
num_frames: int = 73,
|
||||
fps: int = 24,
|
||||
steps: int = 30,
|
||||
guidance: float = 3.0,
|
||||
seed: Optional[int] = None,
|
||||
) -> str:
|
||||
spec: ModelSpec | None = find(model_id)
|
||||
if spec is None or spec.kind != "video":
|
||||
raise ValueError(f"Unknown video model: {model_id}")
|
||||
|
||||
pipe = _load_pipeline(spec.repo)
|
||||
|
||||
import torch
|
||||
from diffusers.utils import export_to_video
|
||||
|
||||
generator = None
|
||||
if seed is not None:
|
||||
try:
|
||||
generator = torch.Generator(device=get_device()).manual_seed(int(seed))
|
||||
except RuntimeError:
|
||||
generator = torch.Generator().manual_seed(int(seed))
|
||||
|
||||
base = dict(
|
||||
prompt=prompt,
|
||||
negative_prompt=negative_prompt or None,
|
||||
width=int(width),
|
||||
height=int(height),
|
||||
num_frames=int(num_frames),
|
||||
num_inference_steps=int(steps),
|
||||
guidance_scale=float(guidance),
|
||||
generator=generator,
|
||||
)
|
||||
kwargs = _model_kwargs(spec.repo, base)
|
||||
kwargs = {k: v for k, v in kwargs.items() if v is not None}
|
||||
|
||||
out = pipe(**kwargs)
|
||||
frames = out.frames[0]
|
||||
|
||||
import time
|
||||
path = OUTPUTS / f"video_{int(time.time())}.mp4"
|
||||
export_to_video(frames, str(path), fps=int(fps))
|
||||
return str(path)
|
||||
Reference in New Issue
Block a user