Initial commit
This commit is contained in:
+14
@@ -0,0 +1,14 @@
|
|||||||
|
venv/
|
||||||
|
.tools/
|
||||||
|
__pycache__/
|
||||||
|
*.pyc
|
||||||
|
*.pyo
|
||||||
|
models/*
|
||||||
|
!models/.gitkeep
|
||||||
|
outputs/
|
||||||
|
config.local.json
|
||||||
|
.env
|
||||||
|
*.safetensors
|
||||||
|
*.ckpt
|
||||||
|
*.bin
|
||||||
|
*.pth
|
||||||
@@ -0,0 +1,61 @@
|
|||||||
|
# Kawai
|
||||||
|
|
||||||
|
Local AI image and video generator. Simple UI. NSFW-capable. Auto GPU detection (Nvidia / AMD / Intel / Apple Silicon / CPU).
|
||||||
|
|
||||||
|
## Quick start
|
||||||
|
|
||||||
|
```
|
||||||
|
python launcher.py
|
||||||
|
```
|
||||||
|
|
||||||
|
Run with **any Python** you have installed. The launcher bootstraps `uv` and uses it to fetch a clean Python 3.11 runtime + venv, then installs the right PyTorch build for your GPU. Nothing about your system Python is touched.
|
||||||
|
|
||||||
|
Works on Windows, Linux, and macOS.
|
||||||
|
|
||||||
|
First run takes a few minutes (uv install + Python 3.11 download + torch + dependencies). Subsequent runs start instantly.
|
||||||
|
|
||||||
|
### Force a specific backend
|
||||||
|
|
||||||
|
Auto-detect picks one of `cuda` (NVIDIA), `rocm` (AMD on Linux), `directml` (AMD/Intel on Windows), `mps` (Apple Silicon), or `cpu`. To override:
|
||||||
|
|
||||||
|
```
|
||||||
|
python launcher.py --backend cuda # force CUDA wheel
|
||||||
|
python launcher.py --backend rocm # AMD on Linux (ROCm)
|
||||||
|
python launcher.py --backend directml # AMD/Intel on Windows
|
||||||
|
python launcher.py --backend mps # Apple Silicon (macOS)
|
||||||
|
python launcher.py --backend cpu # CPU only
|
||||||
|
python launcher.py --reinstall # wipe install marker, re-detect, reinstall torch
|
||||||
|
```
|
||||||
|
|
||||||
|
`--vendor {nvidia,amd,intel,cpu}` is available too if you need to pair (e.g. `--backend directml --vendor intel`). Override is persisted in `config.local.json` and survives relaunches until you pass `--backend` again or `--reinstall`.
|
||||||
|
|
||||||
|
### What the launcher does
|
||||||
|
|
||||||
|
1. Installs `uv` to `.tools/` if not present.
|
||||||
|
2. Creates `venv/` with Python 3.11 (uv downloads the interpreter on demand).
|
||||||
|
3. Detects GPU (Nvidia / AMD / Intel / Apple Silicon / CPU) and installs matching PyTorch wheel.
|
||||||
|
4. Installs latest `diffusers`, `transformers`, etc.
|
||||||
|
5. Opens browser UI at `http://127.0.0.1:7860`.
|
||||||
|
|
||||||
|
### Reset
|
||||||
|
|
||||||
|
Delete `venv/` and `.tools/` to force a clean reinstall.
|
||||||
|
|
||||||
|
## Hardware tiers
|
||||||
|
|
||||||
|
| VRAM | Default image model | Default video model |
|
||||||
|
|------|--------------------|---------------------|
|
||||||
|
| 4 GB | SDXL Turbo (fp16) | disabled |
|
||||||
|
| 8 GB | Pony Diffusion XL | LTX-Video (fp8) |
|
||||||
|
| 12 GB | Illustrious XL | LTX-Video (fp16) |
|
||||||
|
| 16 GB+ | Illustrious XL + refiner | Wan 2.1 |
|
||||||
|
|
||||||
|
User can override defaults in UI.
|
||||||
|
|
||||||
|
## Safety
|
||||||
|
|
||||||
|
CSAM detection on all outputs (NudeNet age classifier + hash check). All other content allowed: NSFW, gore, violence.
|
||||||
|
|
||||||
|
## Status
|
||||||
|
|
||||||
|
Windows + Linux + macOS. AMD on Linux uses ROCm; AMD/Intel on Windows use DirectML; Apple Silicon uses MPS. Intel Macs run on CPU only (no GPU acceleration path).
|
||||||
@@ -0,0 +1,239 @@
|
|||||||
|
"""Gradio UI. Two tabs: Image, Video. Auto-picks defaults from detected hardware."""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
import random
|
||||||
|
import time
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import gradio as gr
|
||||||
|
|
||||||
|
from backends import models, refiner, safety
|
||||||
|
from backends.device import hardware_info
|
||||||
|
|
||||||
|
ROOT = Path(__file__).parent
|
||||||
|
OUTPUTS = ROOT / "outputs"
|
||||||
|
OUTPUTS.mkdir(exist_ok=True)
|
||||||
|
CONFIG = json.loads((ROOT / "config.json").read_text())
|
||||||
|
|
||||||
|
|
||||||
|
def _hw_summary() -> str:
    """Return a one-line Markdown summary of the detected hardware."""
    info = hardware_info()
    head = f"**{info['device_name']}** — {info['vendor'].upper()} via {info['backend']} "
    tail = f"— {info['vram_gb']:.1f} GB VRAM — tier `{info['tier']}`"
    return head + tail
|
||||||
|
|
||||||
|
|
||||||
|
def _models_status_md() -> str:
    """Render a Markdown table of every known model and whether it is cached locally."""
    # Header row + separator row required by Markdown table syntax.
    rows = ["| Model | Kind | Min VRAM | Download | Status |", "|---|---|---|---|---|"]
    for m in models.IMAGE_MODELS + models.VIDEO_MODELS:
        status = "cached" if models.is_cached(m) else "not downloaded"
        rows.append(
            f"| {m.label} | {m.kind} | {m.min_vram_gb:.0f} GB | {m.download_gb:.0f} GB | {status} |"
        )
    return "### Models\n\n" + "\n".join(rows)
|
||||||
|
|
||||||
|
|
||||||
|
def _model_choices(kind: str) -> tuple[list[tuple[str, str]], str | None]:
    """Return (dropdown choices, default model id) for the given kind ("image"/"video").

    Choices are limited to models that fit the detected hardware tier; the
    default id is None when the tier has no default for this kind.
    """
    hw = hardware_info()
    tier = hw["tier"]
    choices: list[tuple[str, str]] = []
    for m in models.list_for_tier(tier, kind):
        choices.append((models.label_with_meta(m), m.id))
    default = models.default_for_tier(tier, kind)
    if default:
        return choices, default.id
    return choices, None
|
||||||
|
|
||||||
|
|
||||||
|
def gen_image(
    prompt: str,
    negative_prompt: str,
    model_id: str,
    width: int,
    height: int,
    steps: int,
    guidance: float,
    seed: int,
    auto_refine: bool,
):
    """Image-tab handler: validate, optionally refine the prompt, generate, filter, save.

    Returns (PIL image, info string) for the Gradio outputs. Raises gr.Error
    (shown in the UI) on empty prompt, a blocked prompt, or a blocked output.
    seed < 0 means "pick a random seed"; the seed actually used is reported
    in the info string and embedded in the output filename.
    """
    if not prompt.strip():
        raise gr.Error("Empty prompt.")

    # Prompt-level safety gate runs before any model work.
    chk = safety.check_prompt(prompt)
    if not chk.allowed:
        raise gr.Error(chk.reason)

    spec = models.find(model_id)
    if spec and not models.is_cached(spec):
        # First use of a model triggers a multi-GB download in the backend.
        gr.Info(f"Downloading {spec.label} (~{spec.download_gb:.0f} GB) on first use. Watch terminal for progress.")

    refined = prompt
    if auto_refine:
        refined = refiner.refine(prompt, use_ollama=CONFIG["refiner"]["use_ollama"])

    # Imported lazily so the UI starts without loading torch/diffusers.
    from backends import image_sdxl

    seed_val = None if seed < 0 else seed
    if seed_val is None:
        seed_val = random.randint(0, 2**31 - 1)

    img = image_sdxl.generate(
        prompt=refined,
        negative_prompt=negative_prompt,
        model_id=model_id,
        width=int(width),
        height=int(height),
        steps=int(steps),
        guidance=float(guidance),
        seed=seed_val,
    )

    # Output-level safety gate: block before anything is shown or saved.
    img_chk = safety.check_image(img)
    if not img_chk.allowed:
        raise gr.Error(img_chk.reason)

    out_path = OUTPUTS / f"img_{int(time.time())}_{seed_val}.png"
    img.save(out_path)
    info = f"Seed: {seed_val}\n\nPrompt used:\n{refined}"
    return img, info
|
||||||
|
|
||||||
|
|
||||||
|
def gen_video(
    prompt: str,
    negative_prompt: str,
    model_id: str,
    width: int,
    height: int,
    num_frames: int,
    fps: int,
    steps: int,
    guidance: float,
    seed: int,
    auto_refine: bool,
):
    """Video-tab handler: validate, optionally refine the prompt, then generate.

    Returns (video file path, info string) for the Gradio outputs. Raises
    gr.Error on empty or blocked prompt. seed < 0 means "pick a random seed".
    NOTE(review): unlike gen_image, there is no post-generation safety check
    on the rendered video here — confirm that is intentional.
    """
    if not prompt.strip():
        raise gr.Error("Empty prompt.")

    # Prompt-level safety gate runs before any model work.
    chk = safety.check_prompt(prompt)
    if not chk.allowed:
        raise gr.Error(chk.reason)

    spec = models.find(model_id)
    if spec and not models.is_cached(spec):
        # First use of a model triggers a multi-GB download in the backend.
        gr.Info(f"Downloading {spec.label} (~{spec.download_gb:.0f} GB) on first use. Watch terminal for progress.")

    refined = prompt
    if auto_refine:
        refined = refiner.refine(prompt, use_ollama=CONFIG["refiner"]["use_ollama"])

    # Imported lazily so the UI starts without loading torch/diffusers.
    from backends import video_ltx

    seed_val = None if seed < 0 else seed
    if seed_val is None:
        seed_val = random.randint(0, 2**31 - 1)

    path = video_ltx.generate(
        prompt=refined,
        negative_prompt=negative_prompt,
        model_id=model_id,
        width=int(width),
        height=int(height),
        num_frames=int(num_frames),
        fps=int(fps),
        steps=int(steps),
        guidance=float(guidance),
        seed=seed_val,
    )
    info = f"Seed: {seed_val}\n\nPrompt used:\n{refined}"
    return path, info
|
||||||
|
|
||||||
|
|
||||||
|
def build_ui() -> gr.Blocks:
    """Assemble the Gradio Blocks app: Image tab, Video tab, System tab.

    Model choices and defaults are tier-dependent (from hardware detection);
    slider/checkbox defaults come from config.json. The Video tab degrades to
    a notice when no video model fits the detected VRAM.
    """
    img_choices, img_default = _model_choices("image")
    vid_choices, vid_default = _model_choices("video")
    img_def = CONFIG["image_defaults"]
    vid_def = CONFIG["video_defaults"]

    with gr.Blocks(title="Kawai", analytics_enabled=False) as ui:
        gr.Markdown("# Kawai\nLocal AI image and video generator.")
        gr.Markdown(_hw_summary())

        with gr.Tabs():
            with gr.Tab("Image"):
                with gr.Row():
                    # Left column: all generation controls.
                    with gr.Column(scale=2):
                        i_prompt = gr.Textbox(label="Prompt", lines=3, placeholder="Describe what you want...")
                        i_neg = gr.Textbox(label="Negative prompt", lines=2, value=img_def["negative_prompt"])
                        i_refine = gr.Checkbox(label="Auto-refine prompt with local LLM", value=True)
                        i_model = gr.Dropdown(choices=img_choices, value=img_default, label="Model")
                        with gr.Row():
                            i_w = gr.Slider(512, 1536, value=img_def["width"], step=64, label="Width")
                            i_h = gr.Slider(512, 1536, value=img_def["height"], step=64, label="Height")
                        with gr.Row():
                            i_steps = gr.Slider(1, 80, value=img_def["steps"], step=1, label="Steps")
                            i_guidance = gr.Slider(0.0, 15.0, value=img_def["guidance"], step=0.1, label="Guidance")
                        i_seed = gr.Number(value=-1, label="Seed (-1 = random)", precision=0)
                        i_btn = gr.Button("Generate", variant="primary")
                    # Right column: result preview + run info.
                    with gr.Column(scale=3):
                        i_out = gr.Image(label="Output", type="pil")
                        i_info = gr.Textbox(label="Info", lines=6, interactive=False)
                # Input order must match gen_image's parameter order.
                i_btn.click(
                    gen_image,
                    inputs=[i_prompt, i_neg, i_model, i_w, i_h, i_steps, i_guidance, i_seed, i_refine],
                    outputs=[i_out, i_info],
                )

            with gr.Tab("Video"):
                if not vid_choices:
                    # No video model fits the detected VRAM tier.
                    gr.Markdown("**Video disabled** — detected hardware lacks VRAM for any video model.")
                else:
                    with gr.Row():
                        with gr.Column(scale=2):
                            v_prompt = gr.Textbox(label="Prompt", lines=3)
                            v_neg = gr.Textbox(label="Negative prompt", lines=2, value="")
                            v_refine = gr.Checkbox(label="Auto-refine prompt with local LLM", value=True)
                            v_model = gr.Dropdown(choices=vid_choices, value=vid_default, label="Model")
                            with gr.Row():
                                v_w = gr.Slider(384, 1024, value=vid_def["width"], step=32, label="Width")
                                v_h = gr.Slider(256, 1024, value=vid_def["height"], step=32, label="Height")
                            with gr.Row():
                                v_frames = gr.Slider(17, 161, value=vid_def["num_frames"], step=8, label="Frames")
                                v_fps = gr.Slider(8, 30, value=vid_def["fps"], step=1, label="FPS")
                            with gr.Row():
                                v_steps = gr.Slider(10, 60, value=vid_def["steps"], step=1, label="Steps")
                                v_guidance = gr.Slider(0.0, 10.0, value=vid_def["guidance"], step=0.1, label="Guidance")
                            v_seed = gr.Number(value=-1, label="Seed (-1 = random)", precision=0)
                            v_btn = gr.Button("Generate", variant="primary")
                        with gr.Column(scale=3):
                            v_out = gr.Video(label="Output")
                            v_info = gr.Textbox(label="Info", lines=6, interactive=False)
                    # Input order must match gen_video's parameter order.
                    v_btn.click(
                        gen_video,
                        inputs=[v_prompt, v_neg, v_model, v_w, v_h, v_frames, v_fps, v_steps, v_guidance, v_seed, v_refine],
                        outputs=[v_out, v_info],
                    )

            with gr.Tab("System"):
                # Static diagnostics: hardware summary, model cache table, usage notes.
                gr.Markdown(_hw_summary())
                gr.Markdown(_models_status_md())
                gr.Markdown(
                    "**Output folder:** `outputs/`\n\n"
                    "**Models cache:** `models/diffusers/`\n\n"
                    "**Prompt refiner:** Ollama with `dolphin-llama3:8b` if running, else GPT-2 fallback.\n\n"
                    "Install Ollama: https://ollama.com/ then `ollama pull dolphin-llama3`.\n\n"
                    "**Safety:** CSAM-gated only (prompt keyword gate + face age check on nude outputs). All other content allowed.\n\n"
                    "**Note:** First use of a model triggers download (7–24 GB). Keep this terminal open during download."
                )

    return ui
|
||||||
|
|
||||||
|
|
||||||
|
def run() -> None:
    """Build the UI and serve it with the host/port/browser settings from config.json."""
    ui_cfg = CONFIG["ui"]
    app = build_ui()
    app.queue().launch(
        server_name=ui_cfg["host"],
        server_port=ui_cfg["port"],
        inbrowser=ui_cfg["open_browser"],
    )
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
run()
|
||||||
@@ -0,0 +1,51 @@
|
|||||||
|
"""Torch device selection. Wraps cuda / directml / cpu behind one helper."""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
from functools import lru_cache
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
ROOT = Path(__file__).parent.parent
|
||||||
|
HARDWARE_CACHE = ROOT / "config.local.json"
|
||||||
|
|
||||||
|
|
||||||
|
@lru_cache(maxsize=1)
def hardware_info() -> dict:
    """Return the hardware description dict (vendor/backend/device_name/vram_gb/tier).

    Cached once per process. A persisted config.local.json takes precedence
    over live detection — this is how backend overrides survive relaunches.
    NOTE(review): the cache file is returned as-is with no key validation;
    assumes it always holds the full five-key dict — confirm the writer
    never stores a partial override. Callers also share one mutable dict
    because of the lru_cache; they must not mutate it.
    """
    if HARDWARE_CACHE.exists():
        return json.loads(HARDWARE_CACHE.read_text())
    # Imported lazily so reading the cache file skips detection cost entirely.
    from . import hardware
    info = hardware.detect()
    return {
        "vendor": info.vendor,
        "backend": info.backend,
        "device_name": info.device_name,
        "vram_gb": info.vram_gb,
        "tier": info.tier,
    }
|
||||||
|
|
||||||
|
|
||||||
|
@lru_cache(maxsize=1)
def get_device():
    """Resolve the torch device object for the selected backend, once per process.

    Falls back to CPU whenever the preferred backend is unavailable at
    runtime (e.g. torch-directml not installed, MPS not built).
    """
    import torch
    backend = hardware_info()["backend"]
    # ROCm builds of torch expose the cuda namespace.
    if backend in ("cuda", "rocm") and torch.cuda.is_available():
        return torch.device("cuda")
    if backend == "directml":
        try:
            import torch_directml
            return torch_directml.device()
        except ImportError:
            # torch-directml missing: silently degrade to CPU below.
            pass
    # getattr guard: older torch builds lack the mps backend attribute entirely.
    if backend == "mps" and getattr(torch.backends, "mps", None) and torch.backends.mps.is_available():
        return torch.device("mps")
    return torch.device("cpu")
|
||||||
|
|
||||||
|
|
||||||
|
def torch_dtype():
    """Pick the compute dtype for the active backend.

    CPU runs fp32; every accelerated backend uses fp16. (MPS supports fp16
    for diffusers while bf16 has gaps, so fp16 is used there too.)
    """
    import torch

    on_cpu = hardware_info()["backend"] == "cpu"
    return torch.float32 if on_cpu else torch.float16
|
||||||
@@ -0,0 +1,315 @@
|
|||||||
|
"""GPU and VRAM detection. Returns vendor + tier used to pick torch wheel and default models."""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import platform
|
||||||
|
import subprocess
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Literal
|
||||||
|
|
||||||
|
Vendor = Literal["nvidia", "amd", "intel", "apple", "cpu"]
|
||||||
|
Backend = Literal["cuda", "rocm", "directml", "mps", "cpu"]
|
||||||
|
SUPPORTED_BACKENDS: tuple[str, ...] = ("auto", "cuda", "rocm", "directml", "mps", "cpu")
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class HardwareInfo:
    """Resolved hardware description used to pick the torch wheel and default models."""

    vendor: Vendor  # GPU maker, or "apple"/"cpu"
    backend: Backend  # torch backend to install and run on
    device_name: str  # human-readable device label for the UI
    vram_gb: float  # detected VRAM (0.0 when unknown)
    tier: Literal["cpu", "low", "mid", "high", "ultra"]  # bucket from _vram_tier()
|
||||||
|
|
||||||
|
|
||||||
|
def _detect_mac_gpu() -> tuple[str, float] | None:
    """Apple Silicon: report chip name + unified memory (proxy for VRAM).

    Intel Mac: returns None (no GPU acceleration path).
    """
    if platform.system() != "Darwin":
        return None
    machine = platform.machine().lower()
    if machine not in ("arm64", "aarch64"):
        # Intel Mac — no MPS path.
        return None
    name = "Apple Silicon"
    # sysctl gives the marketing name (e.g. "Apple M2 Pro"); keep the
    # generic label if the probe fails.
    brand = _run(["sysctl", "-n", "machdep.cpu.brand_string"])
    if brand:
        name = brand.strip()
    mem_raw = _run(["sysctl", "-n", "hw.memsize"])  # total RAM in bytes
    mem_gb = 0.0
    if mem_raw:
        try:
            mem_gb = int(mem_raw) / (1024 ** 3)
        except ValueError:
            mem_gb = 0.0
    # Unified memory: GPU can address ~75% in practice. Use that as VRAM proxy.
    vram_gb = mem_gb * 0.75 if mem_gb else 0.0
    return name, vram_gb
|
||||||
|
|
||||||
|
|
||||||
|
def _run(cmd: list[str]) -> str:
|
||||||
|
try:
|
||||||
|
out = subprocess.run(cmd, capture_output=True, text=True, timeout=10, check=False)
|
||||||
|
return out.stdout.strip()
|
||||||
|
except (FileNotFoundError, subprocess.TimeoutExpired):
|
||||||
|
return ""
|
||||||
|
|
||||||
|
|
||||||
|
def _detect_nvidia() -> tuple[str, float] | None:
    """Query nvidia-smi for the first GPU's name and VRAM in GB.

    Returns None when nvidia-smi is absent, prints nothing, or emits a line
    this parser does not understand. Previously a comma-less line crashed the
    unpack and a non-numeric memory field (nvidia-smi reports "N/A" /
    "[Insufficient Permissions]" in some configurations) raised ValueError;
    both now degrade to None so detection can fall through to other probes.
    """
    out = _run(["nvidia-smi", "--query-gpu=name,memory.total", "--format=csv,noheader,nounits"])
    if not out:
        return None
    first = out.splitlines()[0]
    name, sep, mem = first.partition(",")
    if not sep:
        # Unexpected format — treat as "no usable NVIDIA info".
        return None
    try:
        # memory.total is reported in MiB with --format=nounits.
        vram_gb = float(mem.strip()) / 1024.0
    except ValueError:
        return None
    return name.strip(), vram_gb
|
||||||
|
|
||||||
|
|
||||||
|
def _detect_dxgi() -> list[tuple[str, float, str]]:
    """Enumerate all DXGI adapters via PowerShell. Returns list of (name, vram_gb, vendor_hint).

    vendor_hint is one of "nvidia", "amd", "intel", "unknown", guessed from
    the adapter name. Returns [] on non-Windows or on any probe/parse failure.
    """
    if platform.system() != "Windows":
        return []
    ps = (
        "Get-CimInstance Win32_VideoController | "
        "Select-Object Name, AdapterRAM | "
        "ConvertTo-Json -Compress"
    )
    out = _run(["powershell", "-NoProfile", "-Command", ps])
    if not out:
        return []
    import json
    try:
        data = json.loads(out)
    except json.JSONDecodeError:
        return []
    # ConvertTo-Json emits a bare object (not a list) for a single adapter.
    if isinstance(data, dict):
        data = [data]
    results: list[tuple[str, float, str]] = []
    for entry in data:
        name = (entry.get("Name") or "").strip()
        ram = entry.get("AdapterRAM") or 0
        # Win32_VideoController caps AdapterRAM at 4GB on many systems. Trust value but flag below.
        vram_gb = float(ram) / (1024 ** 3) if ram else 0.0
        low = name.lower()
        if "nvidia" in low or "geforce" in low or "rtx" in low or "gtx" in low:
            vendor = "nvidia"
        elif "amd" in low or "radeon" in low or "rx " in low:
            vendor = "amd"
        elif "intel" in low or "arc" in low or "iris" in low:
            vendor = "intel"
        else:
            vendor = "unknown"
        if name:
            results.append((name, vram_gb, vendor))
    return results
|
||||||
|
|
||||||
|
|
||||||
|
def _detect_linux_gpus() -> list[tuple[str, float, str]]:
    """Enumerate Linux GPUs. Returns list of (name, vram_gb, vendor_hint).

    Tries rocm-smi first (accurate AMD VRAM); if that yields anything, lspci
    is skipped entirely. The lspci fallback gives vendor hints only, with
    vram_gb fixed at 0.0. Returns [] on non-Linux.
    """
    if platform.system() != "Linux":
        return []
    results: list[tuple[str, float, str]] = []

    # AMD via rocm-smi (most accurate VRAM)
    rocm = _run(["rocm-smi", "--showproductname", "--showmeminfo", "vram", "--json"])
    if rocm:
        import json as _json
        try:
            data = _json.loads(rocm)
            # Top-level keys are per-card entries; skip any non-dict metadata.
            for k, v in data.items():
                if not isinstance(v, dict):
                    continue
                name = v.get("Card series") or v.get("Card model") or v.get("Device Name") or k
                vram_bytes = 0
                # Key name varies across rocm-smi versions; try the known spellings.
                for key in ("VRAM Total Memory (B)", "vram total memory (b)", "VRAM Total"):
                    if key in v:
                        try:
                            vram_bytes = int(v[key])
                            break
                        except (TypeError, ValueError):
                            pass
                vram_gb = float(vram_bytes) / (1024 ** 3) if vram_bytes else 0.0
                results.append((str(name).strip(), vram_gb, "amd"))
        except (ValueError, KeyError):
            # Unparseable rocm-smi output: fall back to lspci below.
            pass

    if results:
        return results

    # Fallback: lspci for vendor hints (no reliable VRAM number)
    lspci = _run(["bash", "-c", "lspci -nn | grep -Ei 'vga|3d|display'"])
    for line in lspci.splitlines():
        low = line.lower()
        if "nvidia" in low:
            vendor = "nvidia"
        elif "amd" in low or "advanced micro" in low or "ati " in low:
            vendor = "amd"
        elif "intel" in low:
            vendor = "intel"
        else:
            continue
        # Extract device name after the colon
        name = line.split(":", 2)[-1].strip() if ":" in line else line.strip()
        results.append((name, 0.0, vendor))
    return results
|
||||||
|
|
||||||
|
|
||||||
|
def _vendor_from_backend(backend: str) -> Vendor:
|
||||||
|
return {
|
||||||
|
"cuda": "nvidia",
|
||||||
|
"rocm": "amd",
|
||||||
|
"directml": "amd",
|
||||||
|
"mps": "apple",
|
||||||
|
"cpu": "cpu",
|
||||||
|
}.get(backend, "cpu") # type: ignore[return-value]
|
||||||
|
|
||||||
|
|
||||||
|
def _vram_tier(vram_gb: float) -> Literal["cpu", "low", "mid", "high", "ultra"]:
|
||||||
|
if vram_gb < 1:
|
||||||
|
return "cpu"
|
||||||
|
if vram_gb < 6:
|
||||||
|
return "low"
|
||||||
|
if vram_gb < 10:
|
||||||
|
return "mid"
|
||||||
|
if vram_gb < 14:
|
||||||
|
return "high"
|
||||||
|
return "ultra"
|
||||||
|
|
||||||
|
|
||||||
|
def detect(force_backend: str | None = None, force_vendor: str | None = None) -> HardwareInfo:
    """Auto-detect hardware. If force_backend is set (cuda/rocm/directml/cpu), skip detection
    for that decision but still try to discover device name + VRAM for tier sizing.

    Raises RuntimeError when the forced backend is impossible on this OS.
    Auto priority: Apple Silicon MPS > NVIDIA (nvidia-smi) > Windows DXGI
    adapter > Linux GPU (rocm-smi/lspci) > plain CPU.
    """
    is_linux = platform.system() == "Linux"
    is_windows = platform.system() == "Windows"
    is_mac = platform.system() == "Darwin"

    # Validate forced backend against platform
    if force_backend == "directml" and not is_windows:
        raise RuntimeError(
            "DirectML backend is Windows-only. On Linux use --backend rocm (AMD), "
            "--backend cuda (NVIDIA), or --backend cpu. On macOS use --backend mps."
        )
    if force_backend == "mps" and not is_mac:
        raise RuntimeError("MPS backend is macOS-only (Apple Silicon).")
    if force_backend == "rocm" and is_mac:
        raise RuntimeError("ROCm is not available on macOS. Use --backend mps or --backend cpu.")
    if force_backend == "cuda" and is_mac:
        raise RuntimeError("CUDA is not available on macOS. Use --backend mps or --backend cpu.")

    # Gather best-effort device info regardless of force.
    nv = _detect_nvidia() if not is_mac else None
    win_adapters = _detect_dxgi() if is_windows else []
    # Drop software adapters ("Microsoft Basic Display Adapter" etc.).
    win_adapters = [a for a in win_adapters if "basic" not in a[0].lower() and "microsoft" not in a[0].lower()]
    lin_gpus = _detect_linux_gpus() if is_linux else []
    mac_gpu = _detect_mac_gpu() if is_mac else None

    def _best_for(vendor: str) -> tuple[str, float] | None:
        # Pick the highest-VRAM adapter matching the requested vendor.
        if vendor == "nvidia" and nv:
            return nv
        pool = [(n, v) for n, v, h in (win_adapters + lin_gpus) if h == vendor]
        if not pool:
            return None
        pool.sort(key=lambda x: x[1], reverse=True)
        return pool[0]

    if force_backend and force_backend != "auto":
        # Forced path: backend is decided; only name/VRAM discovery remains.
        vendor: Vendor = force_vendor or _vendor_from_backend(force_backend)  # type: ignore[assignment]
        if force_backend == "mps":
            match = mac_gpu
        else:
            match = _best_for(vendor) or (nv if force_backend == "cuda" else None)
        if match:
            name, vram = match
        else:
            # No matching device found, but user forced this backend — proceed with unknown VRAM.
            name = f"{vendor.upper()} (forced)"
            vram = 0.0
        return HardwareInfo(vendor, force_backend, name, vram, _vram_tier(vram))  # type: ignore[arg-type]

    # Auto path
    if is_mac and mac_gpu:
        name, vram = mac_gpu
        return HardwareInfo("apple", "mps", name, vram, _vram_tier(vram))

    if nv:
        name, vram = nv
        return HardwareInfo("nvidia", "cuda", name, vram, _vram_tier(vram))

    if is_windows and win_adapters:
        win_adapters.sort(key=lambda a: a[1], reverse=True)
        name, vram, hint = win_adapters[0]
        # AdapterRAM caps at 4GB for many cards. Bump if name suggests modern discrete.
        if vram <= 4.1 and any(k in name.lower() for k in ("rx 6", "rx 7", "arc a", "arc b")):
            vram = 8.0
        if hint in ("amd", "intel"):
            return HardwareInfo(hint, "directml", name, vram, _vram_tier(vram))
        if hint == "nvidia":
            # NVIDIA visible via DXGI but nvidia-smi failed above: use DirectML.
            return HardwareInfo("nvidia", "directml", name, vram, _vram_tier(vram))

    if is_linux and lin_gpus:
        lin_gpus.sort(key=lambda a: a[1], reverse=True)
        name, vram, hint = lin_gpus[0]
        if hint == "amd":
            return HardwareInfo("amd", "rocm", name, vram, _vram_tier(vram))
        if hint == "nvidia":
            # nvidia-smi failed but card is nvidia: drivers likely missing. Fall through to CPU.
            return HardwareInfo("cpu", "cpu", f"NVIDIA driver missing — {name}", 0.0, "cpu")
        if hint == "intel":
            # No good Intel-on-Linux torch path here; default to CPU.
            return HardwareInfo("cpu", "cpu", f"Intel GPU (no backend) — {name}", 0.0, "cpu")

    return HardwareInfo("cpu", "cpu", platform.processor() or "CPU", 0.0, "cpu")
|
||||||
|
|
||||||
|
|
||||||
|
def directml_supported() -> bool:
    """True when the running interpreter can use torch-directml wheels.

    torch-directml ships wheels only for Python 3.10 / 3.11; the launcher
    pins the venv to 3.11, so this is True in the running app.
    """
    import sys

    major, minor = sys.version_info[:2]
    return major == 3 and minor in (10, 11)
|
||||||
|
|
||||||
|
|
||||||
|
def torch_install_args(info: HardwareInfo) -> list[str]:
    """Return uv-pip install args for the right PyTorch build.

    Launcher pins venv to Python 3.11, so torch-directml wheels are always
    available for AMD/Intel paths. Latest stable torch is used elsewhere.
    """
    backend = info.backend
    if backend == "directml":
        # torch-directml currently pins to torch 2.4.x. Match it.
        return ["torch>=2.4,<2.5", "torchvision>=0.19,<0.20", "torch-directml>=0.2.5"]
    if backend == "mps":
        # Default PyPI torch wheel ships MPS support on macOS arm64. No custom index.
        return ["torch", "torchvision"]
    if backend == "cuda":
        index = "https://download.pytorch.org/whl/cu124"
    elif backend == "rocm":
        # ROCm wheels are Linux-only. Index pinned to a stable ROCm release line.
        index = "https://download.pytorch.org/whl/rocm6.1"
    elif platform.system() == "Darwin":
        # CPU on macOS uses default PyPI wheels (no /whl/cpu index for darwin).
        return ["torch", "torchvision"]
    else:
        index = "https://download.pytorch.org/whl/cpu"
    return ["torch", "torchvision", "--index-url", index]
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
info = detect()
|
||||||
|
print(f"Vendor: {info.vendor}")
|
||||||
|
print(f"Backend: {info.backend}")
|
||||||
|
print(f"Device: {info.device_name}")
|
||||||
|
print(f"VRAM: {info.vram_gb:.1f} GB")
|
||||||
|
print(f"Tier: {info.tier}")
|
||||||
@@ -0,0 +1,92 @@
|
|||||||
|
"""SDXL-family image generation. Loaded lazily, cached by repo id."""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from functools import lru_cache
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from PIL import Image
|
||||||
|
|
||||||
|
from .device import get_device, torch_dtype
|
||||||
|
from .memory import apply_memory_strategy
|
||||||
|
from .models import ModelSpec, find
|
||||||
|
|
||||||
|
ROOT = Path(__file__).parent.parent
|
||||||
|
MODELS_DIR = ROOT / "models"
|
||||||
|
|
||||||
|
|
||||||
|
@lru_cache(maxsize=1)
def _load_pipeline(repo: str, kind: str):
    """Load a diffusers pipeline for *repo*, apply safety/memory setup, and cache it.

    maxsize=1 keeps exactly one pipeline resident: switching models evicts
    the previous one instead of accumulating VRAM. *kind* only participates
    in the cache key; it is not otherwise used in the body.
    """
    import torch
    from diffusers import (
        AutoPipelineForText2Image,
        DiffusionPipeline,
        FluxPipeline,
    )

    # NOTE(review): `device` and the torch import above appear unused in this
    # function (placement is delegated to apply_memory_strategy) — confirm
    # before removing.
    device = get_device()
    dtype = torch_dtype()
    cache_dir = str(MODELS_DIR / "diffusers")

    # Pipeline class is inferred from the repo id naming convention.
    if "FLUX" in repo:
        pipe = FluxPipeline.from_pretrained(repo, torch_dtype=dtype, cache_dir=cache_dir)
    elif "turbo" in repo.lower():
        pipe = AutoPipelineForText2Image.from_pretrained(
            repo, torch_dtype=dtype, variant="fp16", cache_dir=cache_dir
        )
    else:
        pipe = DiffusionPipeline.from_pretrained(
            repo, torch_dtype=dtype, use_safetensors=True, cache_dir=cache_dir
        )

    # Disable built-in safety checker. We use our own filter (CSAM-only).
    if hasattr(pipe, "safety_checker"):
        pipe.safety_checker = None
    if hasattr(pipe, "requires_safety_checker"):
        pipe.requires_safety_checker = False

    # Device placement + VRAM-saving knobs for the detected backend.
    apply_memory_strategy(pipe)
    return pipe
|
||||||
|
|
||||||
|
|
||||||
|
def generate(
    prompt: str,
    negative_prompt: str = "",
    model_id: str = "pony-xl",
    width: int = 1024,
    height: int = 1024,
    steps: int = 30,
    guidance: float = 7.0,
    seed: Optional[int] = None,
) -> Image.Image:
    """Run one text-to-image generation and return the PIL image.

    Raises ValueError for an unknown or non-image model id. seed=None means
    non-deterministic output (no generator is passed to the pipeline).
    Step/guidance values are clamped for turbo models and guidance is forced
    to 0 for FLUX repos, regardless of the caller's values.
    """
    spec: ModelSpec | None = find(model_id)
    if spec is None or spec.kind != "image":
        raise ValueError(f"Unknown image model: {model_id}")

    pipe = _load_pipeline(spec.repo, spec.kind)

    import torch

    generator = None
    if seed is not None:
        try:
            generator = torch.Generator(device=get_device()).manual_seed(int(seed))
        except RuntimeError:
            # Some devices reject a device-bound Generator; fall back to CPU RNG.
            generator = torch.Generator().manual_seed(int(seed))

    if "turbo" in spec.repo.lower():
        # Turbo models are distilled for 1-4 steps and no CFG.
        steps = max(1, min(steps, 4))
        guidance = 0.0
    if "FLUX" in spec.repo:
        guidance = 0.0

    out = pipe(
        prompt=prompt,
        negative_prompt=negative_prompt or None,
        width=width,
        height=height,
        num_inference_steps=steps,
        guidance_scale=guidance,
        generator=generator,
    )
    return out.images[0]
|
||||||
@@ -0,0 +1,83 @@
|
|||||||
|
"""Memory strategy selection. CUDA supports cpu_offload; DirectML/CPU do not.
|
||||||
|
|
||||||
|
Apply per-pipeline based on backend + VRAM tier. All paths reduce peak VRAM
|
||||||
|
without breaking on non-CUDA devices.
|
||||||
|
"""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from .device import get_device, hardware_info
|
||||||
|
|
||||||
|
|
||||||
|
def apply_memory_strategy(pipe) -> None:
    """Apply VRAM-saving knobs that match the active backend.

    Always enables VAE slicing/tiling and attention slicing (safe on any
    device), then picks an offload/placement strategy per backend:

    - cuda/rocm: sequential CPU offload when VRAM < 10 GB, else model-level
      offload, else plain ``.to(device)``.
    - directml: no accelerate hook support — move the whole pipe, CPU fallback.
    - mps: model-level offload when unified memory < 12 GB, else move to the
      device (CPU fallback).
    - anything else: CPU.

    Every knob is best-effort: failures are swallowed and the next fallback
    is tried, so this never raises for an unsupported pipeline.
    """
    info = hardware_info()
    backend = info["backend"]
    vram = info["vram_gb"]

    # Always-safe: VAE tiling/slicing work on any device. Cuts peak VRAM during decode.
    # Newer diffusers (>=0.32) prefers calling on the VAE directly.
    vae = getattr(pipe, "vae", None)
    if vae is not None:
        for fn in ("enable_slicing", "enable_tiling"):
            if hasattr(vae, fn):
                try:
                    getattr(vae, fn)()
                except Exception:
                    # Older diffusers/VAE variants may reject these; ignore.
                    pass
    if hasattr(pipe, "enable_attention_slicing"):
        try:
            pipe.enable_attention_slicing()
        except Exception:
            pass

    if backend in ("cuda", "rocm"):
        # ROCm builds expose the cuda API, so accelerate offload hooks work the same way.
        # Offload only if VRAM tight.
        if vram < 10:
            try:
                pipe.enable_sequential_cpu_offload()
                return
            except Exception:
                pass
        try:
            pipe.enable_model_cpu_offload()
            return
        except Exception:
            pass
        # Offload unavailable (old accelerate/diffusers): run fully on-device.
        pipe.to(get_device())
        return

    if backend == "directml":
        # DirectML lacks accelerate hook support. Move whole pipe to device.
        # Slicing already enabled above keeps peak in check.
        try:
            pipe.to(get_device())
        except Exception:
            # Some pipes have components that won't move cleanly; fall back to CPU.
            pipe.to("cpu")
        return

    if backend == "mps":
        # Apple Silicon shares unified memory with CPU. accelerate's sequential offload
        # has spotty MPS support; rely on slicing/tiling already enabled above.
        # Tight memory tier: keep on CPU and let model_cpu_offload move chunks if available.
        if vram < 12:
            try:
                pipe.enable_model_cpu_offload(device="mps")
                return
            except TypeError:
                # Older diffusers: no ``device`` kwarg — retry the bare call.
                try:
                    pipe.enable_model_cpu_offload()
                    return
                except Exception:
                    pass
            except Exception:
                pass
        try:
            pipe.to(get_device())
        except Exception:
            pipe.to("cpu")
        return

    # CPU
    pipe.to("cpu")
|
||||||
@@ -0,0 +1,163 @@
|
|||||||
|
"""Model registry. VRAM tier drives default. User can override."""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Literal
|
||||||
|
|
||||||
|
Kind = Literal["image", "video"]
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
class ModelSpec:
    """Immutable description of one registry entry (image or video model)."""

    id: str                   # stable key used by the UI/CLI and lookups
    label: str                # human-facing display name
    repo: str                 # Hugging Face repo id
    kind: Kind                # "image" or "video"
    min_vram_gb: float        # rough VRAM floor required to run the model
    nsfw_capable: bool        # whether the checkpoint can produce NSFW output
    download_gb: float = 0.0  # approximate download size (UI hint)
    notes: str = ""           # free-form usage tips shown to the user
|
||||||
|
|
||||||
|
|
||||||
|
# Image model registry, roughly ordered by VRAM requirement.
IMAGE_MODELS: list[ModelSpec] = [
    ModelSpec(
        id="sdxl-turbo",
        label="SDXL Turbo (fast, low VRAM)",
        repo="stabilityai/sdxl-turbo",
        kind="image",
        min_vram_gb=4.0,
        nsfw_capable=False,
        download_gb=7.0,
        notes="1-4 step sampling. Censored base. Good for safe content on weak GPUs.",
    ),
    ModelSpec(
        id="pony-xl",
        label="Pony Diffusion XL v6 (NSFW, anime/realistic)",
        repo="AstraliteHeart/pony-diffusion-v6",
        kind="image",
        min_vram_gb=8.0,
        nsfw_capable=True,
        download_gb=7.0,
        notes="Use score_9, score_8_up tags. Strong NSFW capability.",
    ),
    ModelSpec(
        id="illustrious-xl",
        label="Illustrious XL v0.1 (NSFW, illustration)",
        repo="OnomaAIResearch/Illustrious-xl-early-release-v0",
        kind="image",
        min_vram_gb=10.0,
        nsfw_capable=True,
        download_gb=7.0,
        notes="High-detail anime/illustration. NSFW-capable.",
    ),
    ModelSpec(
        id="sdxl-base",
        label="SDXL Base 1.0 (versatile)",
        repo="stabilityai/stable-diffusion-xl-base-1.0",
        kind="image",
        min_vram_gb=8.0,
        nsfw_capable=False,
        download_gb=7.0,
        notes="General-purpose. Censored.",
    ),
    ModelSpec(
        id="flux-schnell",
        label="FLUX.1 Schnell (high quality, 12GB+)",
        repo="black-forest-labs/FLUX.1-schnell",
        kind="image",
        min_vram_gb=12.0,
        nsfw_capable=False,
        download_gb=24.0,
        notes="State-of-the-art prompt adherence. 4-step.",
    ),
]
|
||||||
|
|
||||||
|
# Video model registry, roughly ordered by VRAM requirement.
VIDEO_MODELS: list[ModelSpec] = [
    ModelSpec(
        id="ltx-video",
        label="LTX-Video 0.9 (fast, 8GB+)",
        repo="Lightricks/LTX-Video",
        kind="video",
        min_vram_gb=8.0,
        nsfw_capable=True,
        download_gb=18.0,
        notes="2-5 second clips at 24fps. Fast generation.",
    ),
    ModelSpec(
        id="cogvideox-5b",
        label="CogVideoX 5B (12GB+)",
        repo="THUDM/CogVideoX-5b",
        kind="video",
        min_vram_gb=12.0,
        nsfw_capable=False,
        download_gb=20.0,
        notes="6 second clips. Good motion quality.",
    ),
    ModelSpec(
        id="wan-2-1",
        label="Wan 2.1 T2V 1.3B (16GB+)",
        repo="Wan-AI/Wan2.1-T2V-1.3B-Diffusers",
        kind="video",
        min_vram_gb=14.0,
        nsfw_capable=True,
        download_gb=16.0,
        notes="Top-tier open video model. Needs diffusers>=0.32.",
    ),
]
|
||||||
|
|
||||||
|
|
||||||
|
def default_for_tier(tier: str, kind: Kind) -> ModelSpec | None:
    """Return the default model for a hardware *tier*, or None when that
    tier cannot reasonably run *kind* generation.

    The previous if-cascade repeated one branch per (tier, kind) pair and
    was easy to let drift; the choices now live in a single lookup table.
    Unknown tiers fall back to the cheapest model whose VRAM floor fits the
    tier's minimum (effectively None for unrecognized tiers, matching the
    old behavior).
    """
    pool = IMAGE_MODELS if kind == "image" else VIDEO_MODELS

    # Curated picks per (tier, kind). None = no viable default on that tier.
    picks: dict[tuple[str, Kind], str | None] = {
        ("cpu", "image"): "sdxl-turbo",
        ("low", "image"): "sdxl-turbo",
        ("mid", "image"): "pony-xl",
        ("high", "image"): "illustrious-xl",
        ("ultra", "image"): "illustrious-xl",
        ("cpu", "video"): None,
        ("low", "video"): None,
        ("mid", "video"): "ltx-video",
        ("high", "video"): "ltx-video",
        ("ultra", "video"): "wan-2-1",
    }
    if (tier, kind) in picks:
        chosen = picks[(tier, kind)]
        if chosen is None:
            return None
        # Registry ids are hard-coded above; a miss here is a programming
        # error and raising StopIteration (as before) is acceptable.
        return next(m for m in pool if m.id == chosen)

    # Unknown tier: first model whose VRAM floor fits the tier minimum.
    tier_min = {"cpu": 0.0, "low": 0.0, "mid": 8.0, "high": 10.0, "ultra": 12.0}
    target = tier_min.get(tier, 0.0)
    return next((m for m in pool if m.min_vram_gb <= target), None)
|
||||||
|
|
||||||
|
|
||||||
|
def list_for_tier(tier: str, kind: Kind) -> list[ModelSpec]:
    """Models that plausibly fit *tier* (with a 2 GB grace margin), cheapest first."""
    candidates = IMAGE_MODELS if kind == "image" else VIDEO_MODELS
    caps = {"cpu": 4.0, "low": 6.0, "mid": 10.0, "high": 14.0, "ultra": 1000.0}
    budget = caps.get(tier, 1000.0) + 2.0
    return sorted(
        (m for m in candidates if m.min_vram_gb <= budget),
        key=lambda m: m.min_vram_gb,
    )
|
||||||
|
|
||||||
|
|
||||||
|
def find(model_id: str) -> ModelSpec | None:
    """Look up a model spec by id across both registries; None when absent."""
    every_model = IMAGE_MODELS + VIDEO_MODELS
    return next((m for m in every_model if m.id == model_id), None)
|
||||||
|
|
||||||
|
|
||||||
|
def is_cached(spec: ModelSpec) -> bool:
    """Best-effort: check if HF snapshot exists in our cache dir."""
    from pathlib import Path

    # Hugging Face cache layout: models--{org}--{name} inside the snapshot dir.
    snapshot = (
        Path(__file__).parent.parent
        / "models"
        / "diffusers"
        / ("models--" + spec.repo.replace("/", "--"))
    )
    return snapshot.exists()
|
||||||
|
|
||||||
|
|
||||||
|
def label_with_meta(spec: ModelSpec) -> str:
    """Human-facing label annotated with NSFW capability and cache/download status."""
    pieces = [spec.label]
    if spec.nsfw_capable:
        pieces.append(" · NSFW✓")
    if is_cached(spec):
        pieces.append(" · cached")
    else:
        pieces.append(f" · {spec.download_gb:.0f}GB download")
    return "".join(pieces)
|
||||||
@@ -0,0 +1,82 @@
|
|||||||
|
"""Prompt refiner. Tries Ollama (local LLM) first, falls back to GPT-2 expansion.
|
||||||
|
|
||||||
|
Ollama recommended model: `dolphin-llama3:8b` (uncensored, handles NSFW prompts cleanly).
|
||||||
|
Install: https://ollama.com/ then `ollama pull dolphin-llama3`.
|
||||||
|
"""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
from functools import lru_cache
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import requests
|
||||||
|
|
||||||
|
# Local Ollama daemon's one-shot (non-chat) generation endpoint.
OLLAMA_URL = "http://127.0.0.1:11434/api/generate"
# Default refiner model; uncensored so explicit prompts get rewritten, not refused.
DEFAULT_OLLAMA_MODEL = "dolphin-llama3:8b"

# Instruction preamble prepended to every refine request sent to Ollama.
SYSTEM = (
    "You are a prompt engineer for a text-to-image diffusion model. "
    "Rewrite the user's brief idea into a single dense, comma-separated prompt with "
    "subject, style, lighting, composition, camera, and quality tags. "
    "Keep all explicit details from the user. Do not refuse, lecture, or moralize. "
    "Do not add safety disclaimers. Output only the rewritten prompt, nothing else."
)
|
||||||
|
|
||||||
|
|
||||||
|
def _ollama_available(model: str) -> bool:
    """True when a local Ollama daemon is up and has *model* pulled.

    Matching is tag-insensitive: "dolphin-llama3" matches "dolphin-llama3:8b".
    """
    wanted = model.split(":")[0]
    try:
        resp = requests.get("http://127.0.0.1:11434/api/tags", timeout=2)
        if resp.status_code != 200:
            return False
        installed = resp.json().get("models", [])
    except requests.RequestException:
        # Daemon down, timed out, or returned garbage — treat as unavailable.
        return False
    return any(entry.get("name", "").startswith(wanted) for entry in installed)
|
||||||
|
|
||||||
|
|
||||||
|
def _refine_ollama(prompt: str, model: str) -> Optional[str]:
    """Ask Ollama to rewrite *prompt*. None on any HTTP failure or empty reply."""
    request_body = {
        "model": model,
        "prompt": f"{SYSTEM}\n\nUser idea: {prompt}\n\nRewritten prompt:",
        "stream": False,
        # Short, mildly creative completions: enough for one dense prompt line.
        "options": {"temperature": 0.7, "num_predict": 200},
    }
    try:
        resp = requests.post(OLLAMA_URL, json=request_body, timeout=60)
        if resp.status_code != 200:
            return None
        rewritten = resp.json().get("response", "").strip()
    except requests.RequestException:
        return None
    return rewritten or None
|
||||||
|
|
||||||
|
|
||||||
|
@lru_cache(maxsize=1)
def _gpt2():
    """Build (once, lazily) a GPT-2 text-generation pipeline for prompt expansion."""
    # Local import: transformers is heavy and only needed on the fallback path.
    from transformers import pipeline
    return pipeline("text-generation", model="gpt2", max_new_tokens=60)
|
||||||
|
|
||||||
|
|
||||||
|
def _refine_gpt2(prompt: str) -> str:
|
||||||
|
seed = (
|
||||||
|
f"{prompt}, highly detailed, sharp focus, professional, cinematic lighting, "
|
||||||
|
f"intricate details, masterpiece, best quality"
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
gen = _gpt2()
|
||||||
|
out = gen(seed, num_return_sequences=1, do_sample=True, temperature=0.7)
|
||||||
|
text = out[0]["generated_text"].split("\n")[0]
|
||||||
|
return text.strip()
|
||||||
|
except Exception:
|
||||||
|
return seed
|
||||||
|
|
||||||
|
|
||||||
|
def refine(prompt: str, use_ollama: bool = True, ollama_model: str = DEFAULT_OLLAMA_MODEL) -> str:
    """Rewrite *prompt* via Ollama when available, else GPT-2 expansion.

    Whitespace-only input is returned stripped, untouched by either refiner.
    """
    cleaned = prompt.strip()
    if not cleaned:
        return cleaned
    if use_ollama and _ollama_available(ollama_model):
        rewritten = _refine_ollama(cleaned, ollama_model)
        if rewritten:
            return rewritten
    # Ollama missing or returned nothing usable — fall back to local GPT-2.
    return _refine_gpt2(cleaned)
|
||||||
@@ -0,0 +1,185 @@
|
|||||||
|
"""Safety filter: block CSAM only. All other content (NSFW, gore, violence) allowed.
|
||||||
|
|
||||||
|
Layers:
|
||||||
|
1. Prompt keyword gate: rejects child-term + sexual-term combinations.
|
||||||
|
2. Output check: NudeNet detects nudity. If nudity present, run face detection
|
||||||
|
(MTCNN) + age classifier (ViT) on every face. Block if any face is classified
|
||||||
|
as a minor with high confidence.
|
||||||
|
|
||||||
|
This is best-effort. The user is legally responsible for use of generated content.
|
||||||
|
"""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import re
|
||||||
|
import tempfile
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from functools import lru_cache
|
||||||
|
from typing import Iterable
|
||||||
|
|
||||||
|
from PIL import Image
|
||||||
|
|
||||||
|
# --- prompt keyword gate ---------------------------------------------------
|
||||||
|
|
||||||
|
# Regex fragments matched case-insensitively against the raw prompt.
_CHILD_TERMS = [
    # Generic minor nouns.
    r"\b(child|children|kid|kids|minor|underage|under-?age|preteen|pre-?teen)\b",
    r"\b(toddler|infant|baby|babies)\b",
    # Explicit ages 0-17 ("12 yo", "15 year old", ...).
    r"\b(\d|0?[0-9]|1[0-7])\s*(yo|y/o|year[- ]?old)\b",
    r"\bloli(con)?\b",
    r"\bshota(con)?\b",
    r"\bcp\b",
]
_SEXUAL_TERMS = [
    r"\b(nude|naked|nsfw|porn|sex|sexual|sexy|erotic|explicit)\b",
    r"\b(penis|vagina|breast|nipple|genital|cum|orgasm)\b",
    r"\b(intercourse|fellatio|cunnilingus|masturbat)\w*\b",
    r"\b(rape|molest)\w*\b",
]

# Compiled once at import. The gate fires only when BOTH match (see check_prompt).
_CHILD_RE = re.compile("|".join(_CHILD_TERMS), re.IGNORECASE)
_SEX_RE = re.compile("|".join(_SEXUAL_TERMS), re.IGNORECASE)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class SafetyResult:
    """Outcome of a safety check."""

    allowed: bool     # True when the content may proceed
    reason: str = ""  # human-readable block reason; empty when allowed
|
||||||
|
|
||||||
|
|
||||||
|
def check_prompt(prompt: str) -> SafetyResult:
    """Reject prompts that pair minor-related terms with sexual terms (CSAM gate)."""
    mentions_minor = _CHILD_RE.search(prompt) is not None
    mentions_sexual = _SEX_RE.search(prompt) is not None
    if mentions_minor and mentions_sexual:
        return SafetyResult(False, "Prompt blocked: combines minor and sexual terms (CSAM gate).")
    return SafetyResult(True)
|
||||||
|
|
||||||
|
|
||||||
|
# --- nudity detection ------------------------------------------------------
|
||||||
|
|
||||||
|
# NudeNet detection classes that count as explicit nudity for the minor gate.
_NUDITY_LABELS = {
    "FEMALE_BREAST_EXPOSED",
    "FEMALE_GENITALIA_EXPOSED",
    "MALE_GENITALIA_EXPOSED",
    "BUTTOCKS_EXPOSED",
    "ANUS_EXPOSED",
}
|
||||||
|
|
||||||
|
|
||||||
|
@lru_cache(maxsize=1)
|
||||||
|
def _nudenet():
|
||||||
|
try:
|
||||||
|
from nudenet import NudeDetector
|
||||||
|
return NudeDetector()
|
||||||
|
except Exception:
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def _has_nudity(detections: Iterable[dict]) -> bool:
|
||||||
|
for d in detections:
|
||||||
|
label = d.get("class") or d.get("label") or ""
|
||||||
|
score = float(d.get("score", 0.0))
|
||||||
|
if label in _NUDITY_LABELS and score >= 0.5:
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
# --- face detection + age classification -----------------------------------
|
||||||
|
|
||||||
|
@lru_cache(maxsize=1)
|
||||||
|
def _mtcnn():
|
||||||
|
try:
|
||||||
|
from facenet_pytorch import MTCNN
|
||||||
|
import torch
|
||||||
|
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||||
|
return MTCNN(keep_all=True, device=device, post_process=False, min_face_size=40)
|
||||||
|
except Exception:
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
@lru_cache(maxsize=1)
|
||||||
|
def _age_classifier():
|
||||||
|
try:
|
||||||
|
from transformers import pipeline
|
||||||
|
return pipeline(
|
||||||
|
"image-classification",
|
||||||
|
model="nateraw/vit-age-classifier",
|
||||||
|
top_k=3,
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
# Buckets reported by nateraw/vit-age-classifier.
|
||||||
|
# Conservative minor set. "10-19" includes some adults — treat as minor only on
|
||||||
|
# strong confidence to limit false positives on young-looking adults.
|
||||||
|
_HARD_MINOR = {"0-2", "3-9"}
|
||||||
|
_SOFT_MINOR = {"10-19"}
|
||||||
|
|
||||||
|
|
||||||
|
def _faces(img: Image.Image):
    """Detect faces with MTCNN and return confident (>=0.9), reasonably sized crops."""
    detector = _mtcnn()
    if detector is None:
        return []
    try:
        boxes, probs = detector.detect(img)
    except Exception:
        return []
    if boxes is None:
        return []
    if probs is None:
        probs = [None] * len(boxes)

    crops = []
    for box, confidence in zip(boxes, probs):
        if confidence is None or float(confidence) < 0.9:
            continue
        x1, y1, x2, y2 = (int(max(0, coord)) for coord in box)
        # Skip tiny detections: age classification on sub-30px faces is noise.
        if x2 - x1 < 30 or y2 - y1 < 30:
            continue
        crops.append(img.crop((x1, y1, x2, y2)))
    return crops
|
||||||
|
|
||||||
|
|
||||||
|
def _is_minor_face(face_img: Image.Image) -> tuple[bool, str]:
    """Classify one face crop's age bucket. Returns (is_minor, reason)."""
    classifier = _age_classifier()
    if classifier is None:
        return False, ""
    try:
        predictions = classifier(face_img)
    except Exception:
        return False, ""
    # predictions is list[dict(label, score)] sorted by score desc.
    best = predictions[0] if predictions else None
    if not best:
        return False, ""
    bucket = best["label"]
    confidence = float(best["score"])
    # Hard buckets need only moderate confidence; the teen bucket needs strong
    # confidence to avoid flagging young-looking adults.
    if bucket in _HARD_MINOR and confidence >= 0.55:
        return True, f"minor face detected ({bucket}, conf={confidence:.2f})"
    if bucket in _SOFT_MINOR and confidence >= 0.85:
        return True, f"likely minor face ({bucket}, conf={confidence:.2f})"
    return False, ""
|
||||||
|
|
||||||
|
|
||||||
|
def check_image(img: Image.Image) -> SafetyResult:
    """Block if (nudity present) AND (any face classified as minor).

    Best-effort: when the nudity detector is unavailable or fails, the image
    is allowed — the prompt gate remains the primary defense.
    """
    det = _nudenet()
    if det is None:
        # No nudity detector available — fall through. Prompt gate is primary defense.
        return SafetyResult(True)

    import os

    # NudeDetector wants a file path. Close the handle before PIL writes to it
    # (required on Windows), and always delete the temp PNG afterwards —
    # delete=False without a cleanup leaked one file per checked image.
    tmp = tempfile.NamedTemporaryFile(suffix=".png", delete=False)
    tmp.close()
    try:
        img.save(tmp.name)
        results = det.detect(tmp.name)
    except Exception:
        return SafetyResult(True)
    finally:
        try:
            os.unlink(tmp.name)
        except OSError:
            pass

    if not _has_nudity(results):
        return SafetyResult(True)

    # Nudity present: scan every confident face for minor classification.
    for face in _faces(img):
        is_minor, reason = _is_minor_face(face)
        if is_minor:
            return SafetyResult(
                False,
                f"Output blocked: nudity + {reason}. Image discarded.",
            )
    return SafetyResult(True)
|
||||||
@@ -0,0 +1,118 @@
|
|||||||
|
"""Video generation. LTX-Video / CogVideoX / Wan via diffusers pipelines."""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from functools import lru_cache
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from .device import get_device, hardware_info, torch_dtype
|
||||||
|
from .memory import apply_memory_strategy
|
||||||
|
from .models import ModelSpec, find
|
||||||
|
|
||||||
|
# Package-relative paths: models/ holds HF caches, outputs/ the rendered clips.
ROOT = Path(__file__).parent.parent
MODELS_DIR = ROOT / "models"
OUTPUTS = ROOT / "outputs"
# Created eagerly at import time so generate() can always write into it.
OUTPUTS.mkdir(exist_ok=True)
|
||||||
|
|
||||||
|
|
||||||
|
def _import_pipeline_for(repo: str):
    """Pick the right pipeline class. Wan needs diffusers>=0.32 with WanPipeline.

    Imports are deliberately local: the heavy diffusers import is paid only
    when a pipeline is actually loaded, and optional classes (WanPipeline)
    don't break module import on older diffusers.
    """
    if "LTX" in repo:
        from diffusers import LTXPipeline
        return LTXPipeline
    if "CogVideoX" in repo:
        from diffusers import CogVideoXPipeline
        return CogVideoXPipeline
    if "Wan" in repo:
        try:
            from diffusers import WanPipeline
            return WanPipeline
        except ImportError as e:
            # Surface an actionable message instead of a bare ImportError.
            raise RuntimeError(
                "Wan 2.1 needs diffusers>=0.32. Run: pip install -U diffusers"
            ) from e
    if "Hunyuan" in repo:
        from diffusers import HunyuanVideoPipeline
        return HunyuanVideoPipeline
    # Unknown repo: let diffusers auto-detect the pipeline class.
    from diffusers import DiffusionPipeline
    return DiffusionPipeline
|
||||||
|
|
||||||
|
|
||||||
|
@lru_cache(maxsize=1)
def _load_pipeline(repo: str):
    """Load (and cache) the diffusers pipeline for *repo*.

    maxsize=1: only one video pipeline stays resident — switching models
    evicts the previous one instead of exhausting RAM/VRAM.
    """
    dtype = torch_dtype()
    cache_dir = str(MODELS_DIR / "diffusers")

    PipelineCls = _import_pipeline_for(repo)
    pipe = PipelineCls.from_pretrained(repo, torch_dtype=dtype, cache_dir=cache_dir)

    # Backend-appropriate offload/placement before first use.
    apply_memory_strategy(pipe)
    return pipe
|
||||||
|
|
||||||
|
|
||||||
|
def _model_kwargs(repo: str, base: dict) -> dict:
|
||||||
|
"""Adjust kwargs per-pipeline. Some pipes don't accept width/height or have different param names."""
|
||||||
|
out = dict(base)
|
||||||
|
if "Wan" in repo:
|
||||||
|
# Wan accepts height/width but expects multiples of 16.
|
||||||
|
out["height"] = (out["height"] // 16) * 16
|
||||||
|
out["width"] = (out["width"] // 16) * 16
|
||||||
|
if "LTX" in repo:
|
||||||
|
# LTX needs 32-multiple resolutions.
|
||||||
|
out["height"] = (out["height"] // 32) * 32
|
||||||
|
out["width"] = (out["width"] // 32) * 32
|
||||||
|
if "CogVideoX" in repo:
|
||||||
|
# CogVideoX has fixed 720x480 default; clamp.
|
||||||
|
out["height"] = min(out["height"], 480)
|
||||||
|
out["width"] = min(out["width"], 720)
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
def generate(
    prompt: str,
    negative_prompt: str = "",
    model_id: str = "ltx-video",
    width: int = 704,
    height: int = 480,
    num_frames: int = 73,
    fps: int = 24,
    steps: int = 30,
    guidance: float = 3.0,
    seed: Optional[int] = None,
) -> str:
    """Generate a video clip and return the path of the written .mp4.

    Raises ValueError when *model_id* is not a registered video model.
    """
    spec: ModelSpec | None = find(model_id)
    if spec is None or spec.kind != "video":
        raise ValueError(f"Unknown video model: {model_id}")

    pipe = _load_pipeline(spec.repo)

    # Local imports: torch/diffusers are heavy, pay for them only on use.
    import torch
    from diffusers.utils import export_to_video

    generator = None
    if seed is not None:
        try:
            generator = torch.Generator(device=get_device()).manual_seed(int(seed))
        except RuntimeError:
            # Backend rejects device-bound generators; fall back to CPU RNG.
            generator = torch.Generator().manual_seed(int(seed))

    base = dict(
        prompt=prompt,
        negative_prompt=negative_prompt or None,
        width=int(width),
        height=int(height),
        num_frames=int(num_frames),
        num_inference_steps=int(steps),
        guidance_scale=float(guidance),
        generator=generator,
    )
    # Per-model resolution snapping/clamping, then drop unset (None) kwargs.
    kwargs = _model_kwargs(spec.repo, base)
    kwargs = {k: v for k, v in kwargs.items() if v is not None}

    out = pipe(**kwargs)
    frames = out.frames[0]

    import time
    # Timestamped filename keeps outputs unique across runs.
    path = OUTPUTS / f"video_{int(time.time())}.mp4"
    export_to_video(frames, str(path), fps=int(fps))
    return str(path)
|
||||||
+31
@@ -0,0 +1,31 @@
|
|||||||
|
{
|
||||||
|
"ui": {
|
||||||
|
"host": "127.0.0.1",
|
||||||
|
"port": 7860,
|
||||||
|
"open_browser": true,
|
||||||
|
"theme": "default"
|
||||||
|
},
|
||||||
|
"refiner": {
|
||||||
|
"use_ollama": true,
|
||||||
|
"ollama_model": "dolphin-llama3:8b",
|
||||||
|
"auto_refine_threshold_words": 8
|
||||||
|
},
|
||||||
|
"image_defaults": {
|
||||||
|
"width": 1024,
|
||||||
|
"height": 1024,
|
||||||
|
"steps": 30,
|
||||||
|
"guidance": 7.0,
|
||||||
|
"negative_prompt": "blurry, low quality, watermark, text, signature, deformed, bad anatomy, extra limbs"
|
||||||
|
},
|
||||||
|
"video_defaults": {
|
||||||
|
"width": 704,
|
||||||
|
"height": 480,
|
||||||
|
"num_frames": 73,
|
||||||
|
"fps": 24,
|
||||||
|
"steps": 30,
|
||||||
|
"guidance": 3.0
|
||||||
|
},
|
||||||
|
"safety": {
|
||||||
|
"csam_gate": true
|
||||||
|
}
|
||||||
|
}
|
||||||
+314
@@ -0,0 +1,314 @@
|
|||||||
|
"""First-run launcher.
|
||||||
|
|
||||||
|
Bootstraps a clean Python 3.11 environment via `uv`, regardless of the system
|
||||||
|
Python the user invoked us with. Keeps the user on a single supported runtime
|
||||||
|
while the AI ecosystem stabilizes around newer Python versions.
|
||||||
|
|
||||||
|
uv install strategy (in order):
|
||||||
|
1. Direct binary download from GitHub releases (no shell, no admin).
|
||||||
|
2. PowerShell / shell installer.
|
||||||
|
3. pip fallback, invoked as `python -m uv`.
|
||||||
|
"""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import io
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
import platform
|
||||||
|
import shutil
|
||||||
|
import subprocess
|
||||||
|
import sys
|
||||||
|
import urllib.request
|
||||||
|
import zipfile
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
ROOT = Path(__file__).parent.resolve()       # repo root (directory holding launcher.py)
VENV_DIR = ROOT / "venv"                     # managed virtualenv location
MARKER = VENV_DIR / ".kawai_ready"           # sentinel: install completed successfully
HARDWARE_CACHE = ROOT / "config.local.json"  # persisted hardware-detection payload
UV_CACHE_DIR = ROOT / ".tools"               # project-local home for the uv binary
PYTHON_TARGET = "3.11"                       # runtime uv provisions into the venv
|
||||||
|
|
||||||
|
|
||||||
|
def venv_python() -> Path:
    """Path of the venv's python executable for the current OS."""
    subdir, exe = ("Scripts", "python.exe") if os.name == "nt" else ("bin", "python")
    return VENV_DIR / subdir / exe
|
||||||
|
|
||||||
|
|
||||||
|
# --- uv discovery / install ------------------------------------------------
|
||||||
|
|
||||||
|
def _local_uv_path() -> Path:
    """Location of the project-local uv binary (platform-appropriate name)."""
    binary_name = "uv.exe" if os.name == "nt" else "uv"
    return UV_CACHE_DIR / binary_name
|
||||||
|
|
||||||
|
|
||||||
|
def _uv_argv() -> list[str] | None:
    """Return argv prefix to invoke uv. None if not available.

    Preference order: project-local binary, binary on PATH, pip-installed
    module run as ``python -m uv``.
    """
    bundled = _local_uv_path()
    if bundled.exists():
        return [str(bundled)]
    on_path = shutil.which("uv")
    if on_path:
        return [on_path]
    # Installed via pip into current Python (works as module).
    module_argv = [sys.executable, "-m", "uv"]
    try:
        subprocess.check_output(
            [*module_argv, "--version"],
            stderr=subprocess.STDOUT,
            timeout=10,
        )
    except (subprocess.CalledProcessError, FileNotFoundError, subprocess.TimeoutExpired):
        return None
    return module_argv
|
||||||
|
|
||||||
|
|
||||||
|
def _uv_release_asset() -> str | None:
|
||||||
|
"""Pick the GitHub release asset name for this OS+arch."""
|
||||||
|
machine = platform.machine().lower()
|
||||||
|
arch_win = "x86_64" if machine in ("amd64", "x86_64") else "aarch64" if "arm" in machine else None
|
||||||
|
if os.name == "nt" and arch_win:
|
||||||
|
return f"uv-{arch_win}-pc-windows-msvc.zip"
|
||||||
|
if sys.platform == "darwin":
|
||||||
|
arch = "aarch64" if "arm" in machine else "x86_64"
|
||||||
|
return f"uv-{arch}-apple-darwin.tar.gz"
|
||||||
|
if sys.platform.startswith("linux"):
|
||||||
|
arch = "aarch64" if "aarch64" in machine or "arm64" in machine else "x86_64"
|
||||||
|
return f"uv-{arch}-unknown-linux-gnu.tar.gz"
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def _download_uv() -> bool:
    """Download uv binary directly from GitHub releases. Most reliable path.

    Returns True only when the binary lands on disk at _local_uv_path().
    Any network or archive failure returns False so the caller can try the
    next install strategy.
    """
    asset = _uv_release_asset()
    if asset is None:
        # Unsupported OS/arch combination.
        return False
    url = f"https://github.com/astral-sh/uv/releases/latest/download/{asset}"
    UV_CACHE_DIR.mkdir(parents=True, exist_ok=True)
    print(f"[kawai] Downloading uv from {url}")
    try:
        with urllib.request.urlopen(url, timeout=60) as resp:
            data = resp.read()
    except Exception as e:
        print(f"[kawai] download failed: {e}")
        return False

    try:
        if asset.endswith(".zip"):
            # Windows asset: pull the uv executable out of the zip in memory.
            with zipfile.ZipFile(io.BytesIO(data)) as z:
                for name in z.namelist():
                    if name.endswith("uv.exe") or name.endswith("/uv"):
                        target = _local_uv_path()
                        target.write_bytes(z.read(name))
                        if os.name != "nt":
                            target.chmod(0o755)
                        return target.exists()
        else:
            # macOS/Linux asset: gzipped tarball.
            import tarfile
            with tarfile.open(fileobj=io.BytesIO(data), mode="r:gz") as t:
                for member in t.getmembers():
                    if member.name.endswith("/uv") or member.name == "uv":
                        f = t.extractfile(member)
                        if f is None:
                            continue
                        target = _local_uv_path()
                        target.write_bytes(f.read())
                        target.chmod(0o755)
                        return target.exists()
    except Exception as e:
        print(f"[kawai] extract failed: {e}")
        return False
    # Archive contained no uv binary.
    return False
|
||||||
|
|
||||||
|
|
||||||
|
def _install_uv_via_shell() -> bool:
    """Use astral.sh installer scripts. Often blocked on locked-down systems."""
    UV_CACHE_DIR.mkdir(parents=True, exist_ok=True)
    # Steer the installer into the project-local tools dir, no PATH edits.
    env = os.environ.copy()
    env["UV_INSTALL_DIR"] = str(UV_CACHE_DIR)
    env["UV_NO_MODIFY_PATH"] = "1"

    if os.name == "nt":
        installer_cmd = [
            "powershell",
            "-NoProfile",
            "-ExecutionPolicy", "Bypass",
            "-Command",
            "irm https://astral.sh/uv/install.ps1 | iex",
        ]
    else:
        installer_cmd = ["bash", "-c", "curl -LsSf https://astral.sh/uv/install.sh | sh"]

    try:
        subprocess.check_call(installer_cmd, env=env)
    except (subprocess.CalledProcessError, FileNotFoundError):
        return False
    return _local_uv_path().exists()
|
||||||
|
|
||||||
|
|
||||||
|
def _install_uv_via_pip() -> bool:
    """pip-install uv into the invoking interpreter, then confirm `-m uv` works."""
    print("[kawai] Installing uv via pip...")
    pip_cmd = [sys.executable, "-m", "pip", "install", "--user", "--upgrade", "uv"]
    try:
        subprocess.check_call(pip_cmd)
    except subprocess.CalledProcessError:
        return False
    # Verify it's invokable as a module.
    verify_cmd = [sys.executable, "-m", "uv", "--version"]
    try:
        subprocess.check_output(verify_cmd, stderr=subprocess.STDOUT, timeout=10)
    except Exception:
        return False
    return True
|
||||||
|
|
||||||
|
|
||||||
|
def _ensure_uv() -> list[str]:
    """Locate uv, installing it if necessary. Returns the argv prefix to run it.

    Tries, in order: an existing install, direct binary download, the shell
    installer, then pip. Raises RuntimeError when every path fails.

    Fix over the previous version: a "successful" download whose binary still
    cannot be resolved no longer short-circuits with an empty argv list — it
    falls through to the next install strategy instead.
    """
    argv = _uv_argv()
    if argv:
        return argv

    print("[kawai] Installing uv...")
    # Re-resolve after each strategy so a half-broken install falls through
    # to the next one instead of handing callers unusable argv.
    for install in (_download_uv, _install_uv_via_shell, _install_uv_via_pip):
        if install():
            argv = _uv_argv()
            if argv:
                return argv

    raise RuntimeError(
        "Failed to install uv. Install manually from https://astral.sh/uv "
        "and rerun this launcher."
    )
|
||||||
|
|
||||||
|
|
||||||
|
# --- venv + deps -----------------------------------------------------------
|
||||||
|
|
||||||
|
def _create_venv(uv: list[str]) -> None:
    """Create the project venv via uv (fetching Python if needed); no-op when present."""
    if venv_python().exists():
        return
    print(f"[kawai] Creating venv with Python {PYTHON_TARGET} (uv will download it if needed)...")
    create_cmd = [*uv, "venv", str(VENV_DIR), "--python", PYTHON_TARGET]
    subprocess.check_call(create_cmd)
|
||||||
|
|
||||||
|
|
||||||
|
def _uv_pip(uv: list[str], args: list[str]) -> None:
    """Run ``uv pip install`` against the project venv with *args* appended."""
    print(f"[kawai] uv pip install {' '.join(args)}")
    subprocess.check_call(
        [*uv, "pip", "install", "--python", str(venv_python()), *args]
    )
|
||||||
|
|
||||||
|
|
||||||
|
def detect_and_install(
    uv: list[str],
    force_backend: str | None = None,
    force_vendor: str | None = None,
) -> dict:
    """Detect the GPU backend, install torch + requirements, persist the result.

    Writes the detection payload to config.local.json and touches the venv
    marker so subsequent launches take the fast path. Returns the payload.
    """
    # Make the repo importable so the in-repo detection module can run
    # before the venv's site-packages exist.
    sys.path.insert(0, str(ROOT))
    from backends import hardware

    info = hardware.detect(force_backend=force_backend, force_vendor=force_vendor)
    forced_note = " (forced)" if force_backend and force_backend != "auto" else ""
    print(
        f"[kawai] Backend: {info.backend}{forced_note} | "
        f"{info.vendor} / {info.device_name} / {info.vram_gb:.1f} GB / tier={info.tier}"
    )

    # Torch wheel first (backend-specific index), then the shared requirements.
    _uv_pip(uv, hardware.torch_install_args(info))
    _uv_pip(uv, ["-r", str(ROOT / "requirements.txt")])

    payload = {
        "vendor": info.vendor,
        "backend": info.backend,
        "device_name": info.device_name,
        "vram_gb": info.vram_gb,
        "tier": info.tier,
        "forced": bool(force_backend and force_backend != "auto"),
    }
    HARDWARE_CACHE.write_text(json.dumps(payload, indent=2))
    # Marker gates the skip-install fast path on relaunch.
    MARKER.write_text("ok")
    return payload
|
||||||
|
|
||||||
|
|
||||||
|
def already_in_venv() -> bool:
    """Return True when the running interpreter is the project venv's Python."""
    try:
        current = Path(sys.executable).resolve()
        target = venv_python().resolve()
    except OSError:
        # resolve() can fail on broken/unreadable paths; treat as "not in venv".
        return False
    return current == target
|
||||||
|
|
||||||
|
|
||||||
|
def relaunch_in_venv(forwarded_args: list[str]) -> None:
    """Hand execution over to the launcher running under the venv Python.

    On POSIX the current process is replaced via os.execv; on Windows a
    child process is spawned instead, because os.execv mangles argv
    entries containing spaces in paths.
    """
    print("[kawai] Relaunching inside venv...")
    interpreter = str(venv_python())
    command = [interpreter, str(ROOT / "launcher.py"), *forwarded_args]
    if os.name != "nt":
        # POSIX: replace this process outright; never returns.
        os.execv(interpreter, command)
    # Windows: run a child and mirror its exit status.
    sys.exit(subprocess.run(command).returncode)
|
||||||
|
|
||||||
|
|
||||||
|
def _build_arg_parser() -> argparse.ArgumentParser:
|
||||||
|
p = argparse.ArgumentParser(
|
||||||
|
prog="kawai",
|
||||||
|
description="Local AI image/video generator. Auto-detects GPU; pass --backend to override.",
|
||||||
|
)
|
||||||
|
p.add_argument(
|
||||||
|
"--backend",
|
||||||
|
choices=["auto", "cuda", "rocm", "directml", "mps", "cpu"],
|
||||||
|
default="auto",
|
||||||
|
help=(
|
||||||
|
"Force torch backend. cuda=NVIDIA, rocm=AMD on Linux, directml=AMD/Intel on Windows, "
|
||||||
|
"mps=Apple Silicon, cpu=fallback. Default: auto-detect."
|
||||||
|
),
|
||||||
|
)
|
||||||
|
p.add_argument(
|
||||||
|
"--vendor",
|
||||||
|
choices=["nvidia", "amd", "intel", "apple", "cpu"],
|
||||||
|
default=None,
|
||||||
|
help="Override detected vendor (rarely needed; useful when pairing --backend directml with intel).",
|
||||||
|
)
|
||||||
|
p.add_argument(
|
||||||
|
"--reinstall",
|
||||||
|
action="store_true",
|
||||||
|
help="Force re-detect and reinstall torch (clears the install marker).",
|
||||||
|
)
|
||||||
|
return p
|
||||||
|
|
||||||
|
|
||||||
|
def main() -> None:
    """Entry point: bootstrap uv + venv + torch, then launch (or relaunch) the app."""
    args = _build_arg_parser().parse_args()
    forced = None if args.backend == "auto" else args.backend

    # --reinstall wipes the marker so the install path below runs again.
    if args.reinstall and MARKER.exists():
        print("[kawai] --reinstall: clearing install marker")
        MARKER.unlink()

    if already_in_venv():
        # Already inside the venv: ensure deps are installed, then run the app.
        if not MARKER.exists():
            detect_and_install(_ensure_uv(), force_backend=forced, force_vendor=args.vendor)
        from app import run

        run()
        return

    # Outside the venv: bootstrap everything, then hand off to the venv Python.
    uv = _ensure_uv()
    _create_venv(uv)
    if not MARKER.exists():
        detect_and_install(uv, force_backend=forced, force_vendor=args.vendor)
    relaunch_in_venv(sys.argv[1:])
|
||||||
|
|
||||||
|
|
||||||
|
# Run the launcher only when executed directly (not on import).
if __name__ == "__main__":
    main()
|
||||||
@@ -0,0 +1,26 @@
|
|||||||
|
# Core
|
||||||
|
gradio>=4.40.0
|
||||||
|
pillow>=10.0.0
|
||||||
|
numpy>=1.26.0
|
||||||
|
requests>=2.31.0
|
||||||
|
tqdm>=4.66.0
|
||||||
|
huggingface_hub>=0.24.0
|
||||||
|
|
||||||
|
# Diffusion (latest available; pinned floors only)
|
||||||
|
diffusers>=0.32.0
|
||||||
|
transformers>=4.45.0
|
||||||
|
accelerate>=0.34.0
|
||||||
|
safetensors>=0.4.5
|
||||||
|
sentencepiece>=0.2.0
|
||||||
|
protobuf>=4.25.0
|
||||||
|
peft>=0.13.0
|
||||||
|
|
||||||
|
# Video export
|
||||||
|
imageio[ffmpeg]>=2.34.0
|
||||||
|
opencv-python-headless>=4.10.0
|
||||||
|
|
||||||
|
# Safety
|
||||||
|
nudenet>=3.4.0
|
||||||
|
facenet-pytorch>=2.5.3
|
||||||
|
|
||||||
|
# Torch installed separately by launcher with vendor-specific wheel
|
||||||
Reference in New Issue
Block a user