commit f35e301ef7919a47d0ff4aea7efbd8997581cdf1 Author: kawa Date: Mon May 4 09:38:56 2026 +0200 Initial commit diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..7130000 --- /dev/null +++ b/.gitignore @@ -0,0 +1,14 @@ +venv/ +.tools/ +__pycache__/ +*.pyc +*.pyo +models/* +!models/.gitkeep +outputs/ +config.local.json +.env +*.safetensors +*.ckpt +*.bin +*.pth diff --git a/README.md b/README.md new file mode 100644 index 0000000..fe6bd3e --- /dev/null +++ b/README.md @@ -0,0 +1,61 @@ +# Kawai + +Local AI image and video generator. Simple UI. NSFW-capable. Auto GPU detection (Nvidia / AMD / Intel / Apple Silicon / CPU). + +## Quick start + +``` +python launcher.py +``` + +Run with **any Python** you have installed. The launcher bootstraps `uv` and uses it to fetch a clean Python 3.11 runtime + venv, then installs the right PyTorch build for your GPU. Nothing about your system Python is touched. + +Works on Windows, Linux, and macOS. + +First run takes a few minutes (uv install + Python 3.11 download + torch + dependencies). Subsequent runs start instantly. + +### Force a specific backend + +Auto-detect picks one of `cuda` (NVIDIA), `rocm` (AMD on Linux), `directml` (AMD/Intel on Windows), `mps` (Apple Silicon), or `cpu`. To override: + +``` +python launcher.py --backend cuda # force CUDA wheel +python launcher.py --backend rocm # AMD on Linux (ROCm) +python launcher.py --backend directml # AMD/Intel on Windows +python launcher.py --backend mps # Apple Silicon (macOS) +python launcher.py --backend cpu # CPU only +python launcher.py --reinstall # wipe install marker, re-detect, reinstall torch +``` + +`--vendor {nvidia,amd,intel,cpu}` is available too if you need to pair (e.g. `--backend directml --vendor intel`). Override is persisted in `config.local.json` and survives relaunches until you pass `--backend` again or `--reinstall`. + +### What the launcher does + +1. Installs `uv` to `.tools/` if not present. +2. Creates `venv/` with Python 3.11 (uv downloads the interpreter on demand). +3. Detects GPU (Nvidia / AMD / Intel / Apple Silicon / CPU) and installs matching PyTorch wheel. +4. Installs latest `diffusers`, `transformers`, etc. +5. Opens browser UI at `http://127.0.0.1:7860`. + +### Reset + +Delete `venv/` and `.tools/` to force a clean reinstall. + +## Hardware tiers + +| VRAM | Default image model | Default video model | +|------|--------------------|---------------------| +| 4 GB | SDXL Turbo (fp16) | disabled | +| 8 GB | Pony Diffusion XL | LTX-Video (fp8) | +| 12 GB | Illustrious XL | LTX-Video (fp16) | +| 16 GB+ | Illustrious XL + refiner | Wan 2.1 | + +User can override defaults in UI. + +## Safety + +CSAM detection on all outputs (NudeNet age classifier + hash check). All other content allowed: NSFW, gore, violence. + +## Status + +Windows + Linux + macOS. AMD on Linux uses ROCm; AMD/Intel on Windows use DirectML; Apple Silicon uses MPS. Intel Macs run on CPU only (no GPU acceleration path). diff --git a/app.py b/app.py new file mode 100644 index 0000000..498d533 --- /dev/null +++ b/app.py @@ -0,0 +1,239 @@ +"""Gradio UI. Two tabs: Image, Video. 
Auto-picks defaults from detected hardware.""" +from __future__ import annotations + +import json +import random +import time +from pathlib import Path + +import gradio as gr + +from backends import models, refiner, safety +from backends.device import hardware_info + +ROOT = Path(__file__).parent +OUTPUTS = ROOT / "outputs" +OUTPUTS.mkdir(exist_ok=True) +CONFIG = json.loads((ROOT / "config.json").read_text()) + + +def _hw_summary() -> str: + hw = hardware_info() + return ( + f"**{hw['device_name']}** — {hw['vendor'].upper()} via {hw['backend']} " + f"— {hw['vram_gb']:.1f} GB VRAM — tier `{hw['tier']}`" + ) + + +def _models_status_md() -> str: + rows = ["| Model | Kind | Min VRAM | Download | Status |", "|---|---|---|---|---|"] + for m in models.IMAGE_MODELS + models.VIDEO_MODELS: + status = "cached" if models.is_cached(m) else "not downloaded" + rows.append( + f"| {m.label} | {m.kind} | {m.min_vram_gb:.0f} GB | {m.download_gb:.0f} GB | {status} |" + ) + return "### Models\n\n" + "\n".join(rows) + + +def _model_choices(kind: str) -> tuple[list[tuple[str, str]], str | None]: + hw = hardware_info() + available = models.list_for_tier(hw["tier"], kind) + choices = [(models.label_with_meta(m), m.id) for m in available] + default = models.default_for_tier(hw["tier"], kind) + return choices, (default.id if default else None) + + +def gen_image( + prompt: str, + negative_prompt: str, + model_id: str, + width: int, + height: int, + steps: int, + guidance: float, + seed: int, + auto_refine: bool, +): + if not prompt.strip(): + raise gr.Error("Empty prompt.") + + chk = safety.check_prompt(prompt) + if not chk.allowed: + raise gr.Error(chk.reason) + + spec = models.find(model_id) + if spec and not models.is_cached(spec): + gr.Info(f"Downloading {spec.label} (~{spec.download_gb:.0f} GB) on first use. Watch terminal for progress.") + + refined = prompt + if auto_refine: + refined = refiner.refine(prompt, use_ollama=CONFIG["refiner"]["use_ollama"]) + + from backends import image_sdxl + + seed_val = None if seed < 0 else seed + if seed_val is None: + seed_val = random.randint(0, 2**31 - 1) + + img = image_sdxl.generate( + prompt=refined, + negative_prompt=negative_prompt, + model_id=model_id, + width=int(width), + height=int(height), + steps=int(steps), + guidance=float(guidance), + seed=seed_val, + ) + + img_chk = safety.check_image(img) + if not img_chk.allowed: + raise gr.Error(img_chk.reason) + + out_path = OUTPUTS / f"img_{int(time.time())}_{seed_val}.png" + img.save(out_path) + info = f"Seed: {seed_val}\n\nPrompt used:\n{refined}" + return img, info + + +def gen_video( + prompt: str, + negative_prompt: str, + model_id: str, + width: int, + height: int, + num_frames: int, + fps: int, + steps: int, + guidance: float, + seed: int, + auto_refine: bool, +): + if not prompt.strip(): + raise gr.Error("Empty prompt.") + + chk = safety.check_prompt(prompt) + if not chk.allowed: + raise gr.Error(chk.reason) + + spec = models.find(model_id) + if spec and not models.is_cached(spec): + gr.Info(f"Downloading {spec.label} (~{spec.download_gb:.0f} GB) on first use. 
Watch terminal for progress.") + + refined = prompt + if auto_refine: + refined = refiner.refine(prompt, use_ollama=CONFIG["refiner"]["use_ollama"]) + + from backends import video_ltx + + seed_val = None if seed < 0 else seed + if seed_val is None: + seed_val = random.randint(0, 2**31 - 1) + + path = video_ltx.generate( + prompt=refined, + negative_prompt=negative_prompt, + model_id=model_id, + width=int(width), + height=int(height), + num_frames=int(num_frames), + fps=int(fps), + steps=int(steps), + guidance=float(guidance), + seed=seed_val, + ) + info = f"Seed: {seed_val}\n\nPrompt used:\n{refined}" + return path, info + + +def build_ui() -> gr.Blocks: + img_choices, img_default = _model_choices("image") + vid_choices, vid_default = _model_choices("video") + img_def = CONFIG["image_defaults"] + vid_def = CONFIG["video_defaults"] + + with gr.Blocks(title="Kawai", analytics_enabled=False) as ui: + gr.Markdown("# Kawai\nLocal AI image and video generator.") + gr.Markdown(_hw_summary()) + + with gr.Tabs(): + with gr.Tab("Image"): + with gr.Row(): + with gr.Column(scale=2): + i_prompt = gr.Textbox(label="Prompt", lines=3, placeholder="Describe what you want...") + i_neg = gr.Textbox(label="Negative prompt", lines=2, value=img_def["negative_prompt"]) + i_refine = gr.Checkbox(label="Auto-refine prompt with local LLM", value=True) + i_model = gr.Dropdown(choices=img_choices, value=img_default, label="Model") + with gr.Row(): + i_w = gr.Slider(512, 1536, value=img_def["width"], step=64, label="Width") + i_h = gr.Slider(512, 1536, value=img_def["height"], step=64, label="Height") + with gr.Row(): + i_steps = gr.Slider(1, 80, value=img_def["steps"], step=1, label="Steps") + i_guidance = gr.Slider(0.0, 15.0, value=img_def["guidance"], step=0.1, label="Guidance") + i_seed = gr.Number(value=-1, label="Seed (-1 = random)", precision=0) + i_btn = gr.Button("Generate", variant="primary") + with gr.Column(scale=3): + i_out = gr.Image(label="Output", type="pil") + i_info = gr.Textbox(label="Info", lines=6, interactive=False) + i_btn.click( + gen_image, + inputs=[i_prompt, i_neg, i_model, i_w, i_h, i_steps, i_guidance, i_seed, i_refine], + outputs=[i_out, i_info], + ) + + with gr.Tab("Video"): + if not vid_choices: + gr.Markdown("**Video disabled** — detected hardware lacks VRAM for any video model.") + else: + with gr.Row(): + with gr.Column(scale=2): + v_prompt = gr.Textbox(label="Prompt", lines=3) + v_neg = gr.Textbox(label="Negative prompt", lines=2, value="") + v_refine = gr.Checkbox(label="Auto-refine prompt with local LLM", value=True) + v_model = gr.Dropdown(choices=vid_choices, value=vid_default, label="Model") + with gr.Row(): + v_w = gr.Slider(384, 1024, value=vid_def["width"], step=32, label="Width") + v_h = gr.Slider(256, 1024, value=vid_def["height"], step=32, label="Height") + with gr.Row(): + v_frames = gr.Slider(17, 161, value=vid_def["num_frames"], step=8, label="Frames") + v_fps = gr.Slider(8, 30, value=vid_def["fps"], step=1, label="FPS") + with gr.Row(): + v_steps = gr.Slider(10, 60, value=vid_def["steps"], step=1, label="Steps") + v_guidance = gr.Slider(0.0, 10.0, value=vid_def["guidance"], step=0.1, label="Guidance") + v_seed = gr.Number(value=-1, label="Seed (-1 = random)", precision=0) + v_btn = gr.Button("Generate", variant="primary") + with gr.Column(scale=3): + v_out = gr.Video(label="Output") + v_info = gr.Textbox(label="Info", lines=6, interactive=False) + v_btn.click( + gen_video, + inputs=[v_prompt, v_neg, v_model, v_w, v_h, v_frames, v_fps, v_steps, v_guidance, v_seed, 
v_refine], + outputs=[v_out, v_info], + ) + + with gr.Tab("System"): + gr.Markdown(_hw_summary()) + gr.Markdown(_models_status_md()) + gr.Markdown( + "**Output folder:** `outputs/`\n\n" + "**Models cache:** `models/diffusers/`\n\n" + "**Prompt refiner:** Ollama with `dolphin-llama3:8b` if running, else GPT-2 fallback.\n\n" + "Install Ollama: https://ollama.com/ then `ollama pull dolphin-llama3`.\n\n" + "**Safety:** CSAM-gated only (prompt keyword gate + face age check on nude outputs). All other content allowed.\n\n" + "**Note:** First use of a model triggers download (7–24 GB). Keep this terminal open during download." + ) + + return ui + + +def run() -> None: + ui = build_ui() + ui.queue().launch( + server_name=CONFIG["ui"]["host"], + server_port=CONFIG["ui"]["port"], + inbrowser=CONFIG["ui"]["open_browser"], + ) + + +if __name__ == "__main__": + run() diff --git a/backends/__init__.py b/backends/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/backends/device.py b/backends/device.py new file mode 100644 index 0000000..0f7ef69 --- /dev/null +++ b/backends/device.py @@ -0,0 +1,51 @@ +"""Torch device selection. Wraps cuda / directml / cpu behind one helper.""" +from __future__ import annotations + +import json +from functools import lru_cache +from pathlib import Path + +ROOT = Path(__file__).parent.parent +HARDWARE_CACHE = ROOT / "config.local.json" + + +@lru_cache(maxsize=1) +def hardware_info() -> dict: + if HARDWARE_CACHE.exists(): + return json.loads(HARDWARE_CACHE.read_text()) + from . import hardware + info = hardware.detect() + return { + "vendor": info.vendor, + "backend": info.backend, + "device_name": info.device_name, + "vram_gb": info.vram_gb, + "tier": info.tier, + } + + +@lru_cache(maxsize=1) +def get_device(): + import torch + backend = hardware_info()["backend"] + # ROCm builds of torch expose the cuda namespace. + if backend in ("cuda", "rocm") and torch.cuda.is_available(): + return torch.device("cuda") + if backend == "directml": + try: + import torch_directml + return torch_directml.device() + except ImportError: + pass + if backend == "mps" and getattr(torch.backends, "mps", None) and torch.backends.mps.is_available(): + return torch.device("mps") + return torch.device("cpu") + + +def torch_dtype(): + import torch + backend = hardware_info()["backend"] + if backend == "cpu": + return torch.float32 + # MPS supports fp16 for diffusers; bf16 has gaps. Stick with fp16. + return torch.float16 diff --git a/backends/hardware.py b/backends/hardware.py new file mode 100644 index 0000000..15b04e2 --- /dev/null +++ b/backends/hardware.py @@ -0,0 +1,315 @@ +"""GPU and VRAM detection. Returns vendor + tier used to pick torch wheel and default models.""" +from __future__ import annotations + +import platform +import subprocess +from dataclasses import dataclass +from typing import Literal + +Vendor = Literal["nvidia", "amd", "intel", "apple", "cpu"] +Backend = Literal["cuda", "rocm", "directml", "mps", "cpu"] +SUPPORTED_BACKENDS: tuple[str, ...] = ("auto", "cuda", "rocm", "directml", "mps", "cpu") + + +@dataclass +class HardwareInfo: + vendor: Vendor + backend: Backend + device_name: str + vram_gb: float + tier: Literal["cpu", "low", "mid", "high", "ultra"] + + +def _detect_mac_gpu() -> tuple[str, float] | None: + """Apple Silicon: report chip name + unified memory (proxy for VRAM). 
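+    For example, a 32 GB M-series machine is treated as roughly 24 GB of VRAM.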
+ Intel Mac: returns None (no GPU acceleration path).""" + if platform.system() != "Darwin": + return None + machine = platform.machine().lower() + if machine not in ("arm64", "aarch64"): + return None + name = "Apple Silicon" + brand = _run(["sysctl", "-n", "machdep.cpu.brand_string"]) + if brand: + name = brand.strip() + mem_raw = _run(["sysctl", "-n", "hw.memsize"]) + mem_gb = 0.0 + if mem_raw: + try: + mem_gb = int(mem_raw) / (1024 ** 3) + except ValueError: + mem_gb = 0.0 + # Unified memory: GPU can address ~75% in practice. Use that as VRAM proxy. + vram_gb = mem_gb * 0.75 if mem_gb else 0.0 + return name, vram_gb + + +def _run(cmd: list[str]) -> str: + try: + out = subprocess.run(cmd, capture_output=True, text=True, timeout=10, check=False) + return out.stdout.strip() + except (FileNotFoundError, subprocess.TimeoutExpired): + return "" + + +def _detect_nvidia() -> tuple[str, float] | None: + out = _run(["nvidia-smi", "--query-gpu=name,memory.total", "--format=csv,noheader,nounits"]) + if not out: + return None + first = out.splitlines()[0] + name, mem = [p.strip() for p in first.split(",", 1)] + return name, float(mem) / 1024.0 + + +def _detect_dxgi() -> list[tuple[str, float, str]]: + """Enumerate all DXGI adapters via PowerShell. Returns list of (name, vram_gb, vendor_hint).""" + if platform.system() != "Windows": + return [] + ps = ( + "Get-CimInstance Win32_VideoController | " + "Select-Object Name, AdapterRAM | " + "ConvertTo-Json -Compress" + ) + out = _run(["powershell", "-NoProfile", "-Command", ps]) + if not out: + return [] + import json + try: + data = json.loads(out) + except json.JSONDecodeError: + return [] + if isinstance(data, dict): + data = [data] + results: list[tuple[str, float, str]] = [] + for entry in data: + name = (entry.get("Name") or "").strip() + ram = entry.get("AdapterRAM") or 0 + # Win32_VideoController caps AdapterRAM at 4GB on many systems. Trust value but flag below. + vram_gb = float(ram) / (1024 ** 3) if ram else 0.0 + low = name.lower() + if "nvidia" in low or "geforce" in low or "rtx" in low or "gtx" in low: + vendor = "nvidia" + elif "amd" in low or "radeon" in low or "rx " in low: + vendor = "amd" + elif "intel" in low or "arc" in low or "iris" in low: + vendor = "intel" + else: + vendor = "unknown" + if name: + results.append((name, vram_gb, vendor)) + return results + + +def _detect_linux_gpus() -> list[tuple[str, float, str]]: + """Enumerate Linux GPUs. 
Returns list of (name, vram_gb, vendor_hint).""" + if platform.system() != "Linux": + return [] + results: list[tuple[str, float, str]] = [] + + # AMD via rocm-smi (most accurate VRAM) + rocm = _run(["rocm-smi", "--showproductname", "--showmeminfo", "vram", "--json"]) + if rocm: + import json as _json + try: + data = _json.loads(rocm) + for k, v in data.items(): + if not isinstance(v, dict): + continue + name = v.get("Card series") or v.get("Card model") or v.get("Device Name") or k + vram_bytes = 0 + for key in ("VRAM Total Memory (B)", "vram total memory (b)", "VRAM Total"): + if key in v: + try: + vram_bytes = int(v[key]) + break + except (TypeError, ValueError): + pass + vram_gb = float(vram_bytes) / (1024 ** 3) if vram_bytes else 0.0 + results.append((str(name).strip(), vram_gb, "amd")) + except (ValueError, KeyError): + pass + + if results: + return results + + # Fallback: lspci for vendor hints (no reliable VRAM number) + lspci = _run(["bash", "-c", "lspci -nn | grep -Ei 'vga|3d|display'"]) + for line in lspci.splitlines(): + low = line.lower() + if "nvidia" in low: + vendor = "nvidia" + elif "amd" in low or "advanced micro" in low or "ati " in low: + vendor = "amd" + elif "intel" in low: + vendor = "intel" + else: + continue + # Extract device name after the colon + name = line.split(":", 2)[-1].strip() if ":" in line else line.strip() + results.append((name, 0.0, vendor)) + return results + + +def _vendor_from_backend(backend: str) -> Vendor: + return { + "cuda": "nvidia", + "rocm": "amd", + "directml": "amd", + "mps": "apple", + "cpu": "cpu", + }.get(backend, "cpu") # type: ignore[return-value] + + +def _vram_tier(vram_gb: float) -> Literal["cpu", "low", "mid", "high", "ultra"]: + if vram_gb < 1: + return "cpu" + if vram_gb < 6: + return "low" + if vram_gb < 10: + return "mid" + if vram_gb < 14: + return "high" + return "ultra" + + +def detect(force_backend: str | None = None, force_vendor: str | None = None) -> HardwareInfo: + """Auto-detect hardware. If force_backend is set (cuda/rocm/directml/cpu), skip detection + for that decision but still try to discover device name + VRAM for tier sizing.""" + is_linux = platform.system() == "Linux" + is_windows = platform.system() == "Windows" + is_mac = platform.system() == "Darwin" + + # Validate forced backend against platform + if force_backend == "directml" and not is_windows: + raise RuntimeError( + "DirectML backend is Windows-only. On Linux use --backend rocm (AMD), " + "--backend cuda (NVIDIA), or --backend cpu. On macOS use --backend mps." + ) + if force_backend == "mps" and not is_mac: + raise RuntimeError("MPS backend is macOS-only (Apple Silicon).") + if force_backend == "rocm" and is_mac: + raise RuntimeError("ROCm is not available on macOS. Use --backend mps or --backend cpu.") + if force_backend == "cuda" and is_mac: + raise RuntimeError("CUDA is not available on macOS. Use --backend mps or --backend cpu.") + + # Gather best-effort device info regardless of force. 
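+    # Auto-detect precedence (applied further down when nothing is forced):
+    # Apple Silicon -> mps, nvidia-smi hit -> cuda, best DXGI adapter on Windows -> directml,
+    # AMD GPU on Linux -> rocm, otherwise fall back to cpu.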
+ nv = _detect_nvidia() if not is_mac else None + win_adapters = _detect_dxgi() if is_windows else [] + win_adapters = [a for a in win_adapters if "basic" not in a[0].lower() and "microsoft" not in a[0].lower()] + lin_gpus = _detect_linux_gpus() if is_linux else [] + mac_gpu = _detect_mac_gpu() if is_mac else None + + def _best_for(vendor: str) -> tuple[str, float] | None: + if vendor == "nvidia" and nv: + return nv + pool = [(n, v) for n, v, h in (win_adapters + lin_gpus) if h == vendor] + if not pool: + return None + pool.sort(key=lambda x: x[1], reverse=True) + return pool[0] + + if force_backend and force_backend != "auto": + vendor: Vendor = force_vendor or _vendor_from_backend(force_backend) # type: ignore[assignment] + if force_backend == "mps": + match = mac_gpu + else: + match = _best_for(vendor) or (nv if force_backend == "cuda" else None) + if match: + name, vram = match + else: + # No matching device found, but user forced this backend — proceed with unknown VRAM. + name = f"{vendor.upper()} (forced)" + vram = 0.0 + return HardwareInfo(vendor, force_backend, name, vram, _vram_tier(vram)) # type: ignore[arg-type] + + # Auto path + if is_mac and mac_gpu: + name, vram = mac_gpu + return HardwareInfo("apple", "mps", name, vram, _vram_tier(vram)) + + if nv: + name, vram = nv + return HardwareInfo("nvidia", "cuda", name, vram, _vram_tier(vram)) + + if is_windows and win_adapters: + win_adapters.sort(key=lambda a: a[1], reverse=True) + name, vram, hint = win_adapters[0] + # AdapterRAM caps at 4GB for many cards. Bump if name suggests modern discrete. + if vram <= 4.1 and any(k in name.lower() for k in ("rx 6", "rx 7", "arc a", "arc b")): + vram = 8.0 + if hint in ("amd", "intel"): + return HardwareInfo(hint, "directml", name, vram, _vram_tier(vram)) + if hint == "nvidia": + return HardwareInfo("nvidia", "directml", name, vram, _vram_tier(vram)) + + if is_linux and lin_gpus: + lin_gpus.sort(key=lambda a: a[1], reverse=True) + name, vram, hint = lin_gpus[0] + if hint == "amd": + return HardwareInfo("amd", "rocm", name, vram, _vram_tier(vram)) + if hint == "nvidia": + # nvidia-smi failed but card is nvidia: drivers likely missing. Fall through to CPU. + return HardwareInfo("cpu", "cpu", f"NVIDIA driver missing — {name}", 0.0, "cpu") + if hint == "intel": + # No good Intel-on-Linux torch path here; default to CPU. + return HardwareInfo("cpu", "cpu", f"Intel GPU (no backend) — {name}", 0.0, "cpu") + + return HardwareInfo("cpu", "cpu", platform.processor() or "CPU", 0.0, "cpu") + + +def directml_supported() -> bool: + """torch-directml ships wheels for Python 3.10 / 3.11. The launcher pins the + venv to 3.11, so this is True in the running app.""" + import sys + return sys.version_info[:2] in {(3, 10), (3, 11)} + + +def torch_install_args(info: HardwareInfo) -> list[str]: + """Return uv-pip install args for the right PyTorch build. + + Launcher pins venv to Python 3.11, so torch-directml wheels are always available + for AMD/Intel paths. Latest stable torch is used elsewhere. + """ + if info.backend == "cuda": + return [ + "torch", + "torchvision", + "--index-url", + "https://download.pytorch.org/whl/cu124", + ] + if info.backend == "rocm": + # ROCm wheels are Linux-only. Index pinned to a stable ROCm release line. + return [ + "torch", + "torchvision", + "--index-url", + "https://download.pytorch.org/whl/rocm6.1", + ] + if info.backend == "directml": + # torch-directml currently pins to torch 2.4.x. Match it. 
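+        # These args feed the launcher's `uv pip install --python <venv python> ...` call;
+        # the ceilings keep uv on the torch 2.4 line that torch-directml currently targets.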
+ return [ + "torch>=2.4,<2.5", + "torchvision>=0.19,<0.20", + "torch-directml>=0.2.5", + ] + if info.backend == "mps": + # Default PyPI torch wheel ships MPS support on macOS arm64. No custom index. + return ["torch", "torchvision"] + # CPU. macOS uses default PyPI wheels (no /whl/cpu index for darwin). + if platform.system() == "Darwin": + return ["torch", "torchvision"] + return [ + "torch", + "torchvision", + "--index-url", + "https://download.pytorch.org/whl/cpu", + ] + + +if __name__ == "__main__": + info = detect() + print(f"Vendor: {info.vendor}") + print(f"Backend: {info.backend}") + print(f"Device: {info.device_name}") + print(f"VRAM: {info.vram_gb:.1f} GB") + print(f"Tier: {info.tier}") diff --git a/backends/image_sdxl.py b/backends/image_sdxl.py new file mode 100644 index 0000000..68c393e --- /dev/null +++ b/backends/image_sdxl.py @@ -0,0 +1,92 @@ +"""SDXL-family image generation. Loaded lazily, cached by repo id.""" +from __future__ import annotations + +from functools import lru_cache +from pathlib import Path +from typing import Optional + +from PIL import Image + +from .device import get_device, torch_dtype +from .memory import apply_memory_strategy +from .models import ModelSpec, find + +ROOT = Path(__file__).parent.parent +MODELS_DIR = ROOT / "models" + + +@lru_cache(maxsize=1) +def _load_pipeline(repo: str, kind: str): + import torch + from diffusers import ( + AutoPipelineForText2Image, + DiffusionPipeline, + FluxPipeline, + ) + + device = get_device() + dtype = torch_dtype() + cache_dir = str(MODELS_DIR / "diffusers") + + if "FLUX" in repo: + pipe = FluxPipeline.from_pretrained(repo, torch_dtype=dtype, cache_dir=cache_dir) + elif "turbo" in repo.lower(): + pipe = AutoPipelineForText2Image.from_pretrained( + repo, torch_dtype=dtype, variant="fp16", cache_dir=cache_dir + ) + else: + pipe = DiffusionPipeline.from_pretrained( + repo, torch_dtype=dtype, use_safetensors=True, cache_dir=cache_dir + ) + + # Disable built-in safety checker. We use our own filter (CSAM-only). + if hasattr(pipe, "safety_checker"): + pipe.safety_checker = None + if hasattr(pipe, "requires_safety_checker"): + pipe.requires_safety_checker = False + + apply_memory_strategy(pipe) + return pipe + + +def generate( + prompt: str, + negative_prompt: str = "", + model_id: str = "pony-xl", + width: int = 1024, + height: int = 1024, + steps: int = 30, + guidance: float = 7.0, + seed: Optional[int] = None, +) -> Image.Image: + spec: ModelSpec | None = find(model_id) + if spec is None or spec.kind != "image": + raise ValueError(f"Unknown image model: {model_id}") + + pipe = _load_pipeline(spec.repo, spec.kind) + + import torch + + generator = None + if seed is not None: + try: + generator = torch.Generator(device=get_device()).manual_seed(int(seed)) + except RuntimeError: + generator = torch.Generator().manual_seed(int(seed)) + + if "turbo" in spec.repo.lower(): + steps = max(1, min(steps, 4)) + guidance = 0.0 + if "FLUX" in spec.repo: + guidance = 0.0 + + out = pipe( + prompt=prompt, + negative_prompt=negative_prompt or None, + width=width, + height=height, + num_inference_steps=steps, + guidance_scale=guidance, + generator=generator, + ) + return out.images[0] diff --git a/backends/memory.py b/backends/memory.py new file mode 100644 index 0000000..c794dc4 --- /dev/null +++ b/backends/memory.py @@ -0,0 +1,83 @@ +"""Memory strategy selection. CUDA supports cpu_offload; DirectML/CPU do not. + +Apply per-pipeline based on backend + VRAM tier. All paths reduce peak VRAM +without breaking on non-CUDA devices. 
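+
+Per-backend summary (mirrors apply_memory_strategy below):
+  cuda/rocm -> sequential CPU offload under 10 GB VRAM, else model CPU offload,
+               falling back to moving the pipe onto the GPU
+  directml  -> move the whole pipe to the device (no accelerate offload hooks)
+  mps       -> model CPU offload under 12 GB unified memory, else move to mps
+  cpu       -> keep everything on the CPU
+VAE slicing/tiling and attention slicing are enabled first on every path.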
+""" +from __future__ import annotations + +from .device import get_device, hardware_info + + +def apply_memory_strategy(pipe) -> None: + """Apply VRAM-saving knobs that match the active backend.""" + info = hardware_info() + backend = info["backend"] + vram = info["vram_gb"] + + # Always-safe: VAE tiling/slicing work on any device. Cuts peak VRAM during decode. + # Newer diffusers (>=0.32) prefers calling on the VAE directly. + vae = getattr(pipe, "vae", None) + if vae is not None: + for fn in ("enable_slicing", "enable_tiling"): + if hasattr(vae, fn): + try: + getattr(vae, fn)() + except Exception: + pass + if hasattr(pipe, "enable_attention_slicing"): + try: + pipe.enable_attention_slicing() + except Exception: + pass + + if backend in ("cuda", "rocm"): + # ROCm builds expose the cuda API, so accelerate offload hooks work the same way. + # Offload only if VRAM tight. + if vram < 10: + try: + pipe.enable_sequential_cpu_offload() + return + except Exception: + pass + try: + pipe.enable_model_cpu_offload() + return + except Exception: + pass + pipe.to(get_device()) + return + + if backend == "directml": + # DirectML lacks accelerate hook support. Move whole pipe to device. + # Slicing already enabled above keeps peak in check. + try: + pipe.to(get_device()) + except Exception: + # Some pipes have components that won't move cleanly; fall back to CPU. + pipe.to("cpu") + return + + if backend == "mps": + # Apple Silicon shares unified memory with CPU. accelerate's sequential offload + # has spotty MPS support; rely on slicing/tiling already enabled above. + # Tight memory tier: keep on CPU and let model_cpu_offload move chunks if available. + if vram < 12: + try: + pipe.enable_model_cpu_offload(device="mps") + return + except TypeError: + try: + pipe.enable_model_cpu_offload() + return + except Exception: + pass + except Exception: + pass + try: + pipe.to(get_device()) + except Exception: + pipe.to("cpu") + return + + # CPU + pipe.to("cpu") diff --git a/backends/models.py b/backends/models.py new file mode 100644 index 0000000..cbf9721 --- /dev/null +++ b/backends/models.py @@ -0,0 +1,163 @@ +"""Model registry. VRAM tier drives default. User can override.""" +from __future__ import annotations + +from dataclasses import dataclass +from typing import Literal + +Kind = Literal["image", "video"] + + +@dataclass(frozen=True) +class ModelSpec: + id: str + label: str + repo: str + kind: Kind + min_vram_gb: float + nsfw_capable: bool + download_gb: float = 0.0 + notes: str = "" + + +IMAGE_MODELS: list[ModelSpec] = [ + ModelSpec( + id="sdxl-turbo", + label="SDXL Turbo (fast, low VRAM)", + repo="stabilityai/sdxl-turbo", + kind="image", + min_vram_gb=4.0, + nsfw_capable=False, + download_gb=7.0, + notes="1-4 step sampling. Censored base. Good for safe content on weak GPUs.", + ), + ModelSpec( + id="pony-xl", + label="Pony Diffusion XL v6 (NSFW, anime/realistic)", + repo="AstraliteHeart/pony-diffusion-v6", + kind="image", + min_vram_gb=8.0, + nsfw_capable=True, + download_gb=7.0, + notes="Use score_9, score_8_up tags. Strong NSFW capability.", + ), + ModelSpec( + id="illustrious-xl", + label="Illustrious XL v0.1 (NSFW, illustration)", + repo="OnomaAIResearch/Illustrious-xl-early-release-v0", + kind="image", + min_vram_gb=10.0, + nsfw_capable=True, + download_gb=7.0, + notes="High-detail anime/illustration. 
NSFW-capable.", + ), + ModelSpec( + id="sdxl-base", + label="SDXL Base 1.0 (versatile)", + repo="stabilityai/stable-diffusion-xl-base-1.0", + kind="image", + min_vram_gb=8.0, + nsfw_capable=False, + download_gb=7.0, + notes="General-purpose. Censored.", + ), + ModelSpec( + id="flux-schnell", + label="FLUX.1 Schnell (high quality, 12GB+)", + repo="black-forest-labs/FLUX.1-schnell", + kind="image", + min_vram_gb=12.0, + nsfw_capable=False, + download_gb=24.0, + notes="State-of-the-art prompt adherence. 4-step.", + ), +] + +VIDEO_MODELS: list[ModelSpec] = [ + ModelSpec( + id="ltx-video", + label="LTX-Video 0.9 (fast, 8GB+)", + repo="Lightricks/LTX-Video", + kind="video", + min_vram_gb=8.0, + nsfw_capable=True, + download_gb=18.0, + notes="2-5 second clips at 24fps. Fast generation.", + ), + ModelSpec( + id="cogvideox-5b", + label="CogVideoX 5B (12GB+)", + repo="THUDM/CogVideoX-5b", + kind="video", + min_vram_gb=12.0, + nsfw_capable=False, + download_gb=20.0, + notes="6 second clips. Good motion quality.", + ), + ModelSpec( + id="wan-2-1", + label="Wan 2.1 T2V 1.3B (16GB+)", + repo="Wan-AI/Wan2.1-T2V-1.3B-Diffusers", + kind="video", + min_vram_gb=14.0, + nsfw_capable=True, + download_gb=16.0, + notes="Top-tier open video model. Needs diffusers>=0.32.", + ), +] + + +def default_for_tier(tier: str, kind: Kind) -> ModelSpec | None: + pool = IMAGE_MODELS if kind == "image" else VIDEO_MODELS + tier_min = {"cpu": 0.0, "low": 0.0, "mid": 8.0, "high": 10.0, "ultra": 12.0} + target = tier_min.get(tier, 0.0) + if tier == "cpu" and kind == "video": + return None + if tier in ("cpu", "low") and kind == "image": + return next(m for m in pool if m.id == "sdxl-turbo") + if tier == "mid" and kind == "image": + return next(m for m in pool if m.id == "pony-xl") + if tier == "high" and kind == "image": + return next(m for m in pool if m.id == "illustrious-xl") + if tier == "ultra" and kind == "image": + return next(m for m in pool if m.id == "illustrious-xl") + if tier in ("low",) and kind == "video": + return None + if tier == "mid" and kind == "video": + return next(m for m in pool if m.id == "ltx-video") + if tier == "high" and kind == "video": + return next(m for m in pool if m.id == "ltx-video") + if tier == "ultra" and kind == "video": + return next(m for m in pool if m.id == "wan-2-1") + return next((m for m in pool if m.min_vram_gb <= target), None) + + +def list_for_tier(tier: str, kind: Kind) -> list[ModelSpec]: + pool = IMAGE_MODELS if kind == "image" else VIDEO_MODELS + tier_max = {"cpu": 4.0, "low": 6.0, "mid": 10.0, "high": 14.0, "ultra": 1000.0} + cap = tier_max.get(tier, 1000.0) + out = [m for m in pool if m.min_vram_gb <= cap + 2.0] + out.sort(key=lambda m: m.min_vram_gb) + return out + + +def find(model_id: str) -> ModelSpec | None: + for m in IMAGE_MODELS + VIDEO_MODELS: + if m.id == model_id: + return m + return None + + +def is_cached(spec: ModelSpec) -> bool: + """Best-effort: check if HF snapshot exists in our cache dir.""" + from pathlib import Path + cache = Path(__file__).parent.parent / "models" / "diffusers" + if not cache.exists(): + return False + folder = "models--" + spec.repo.replace("/", "--") + return (cache / folder).exists() + + +def label_with_meta(spec: ModelSpec) -> str: + nsfw = " · NSFW✓" if spec.nsfw_capable else "" + cached = " · cached" if is_cached(spec) else f" · {spec.download_gb:.0f}GB download" + return f"{spec.label}{nsfw}{cached}" diff --git a/backends/refiner.py b/backends/refiner.py new file mode 100644 index 0000000..4d3e8db --- /dev/null +++ 
b/backends/refiner.py @@ -0,0 +1,82 @@ +"""Prompt refiner. Tries Ollama (local LLM) first, falls back to GPT-2 expansion. + +Ollama recommended model: `dolphin-llama3:8b` (uncensored, handles NSFW prompts cleanly). +Install: https://ollama.com/ then `ollama pull dolphin-llama3`. +""" +from __future__ import annotations + +import json +from functools import lru_cache +from typing import Optional + +import requests + +OLLAMA_URL = "http://127.0.0.1:11434/api/generate" +DEFAULT_OLLAMA_MODEL = "dolphin-llama3:8b" + +SYSTEM = ( + "You are a prompt engineer for a text-to-image diffusion model. " + "Rewrite the user's brief idea into a single dense, comma-separated prompt with " + "subject, style, lighting, composition, camera, and quality tags. " + "Keep all explicit details from the user. Do not refuse, lecture, or moralize. " + "Do not add safety disclaimers. Output only the rewritten prompt, nothing else." +) + + +def _ollama_available(model: str) -> bool: + try: + r = requests.get("http://127.0.0.1:11434/api/tags", timeout=2) + if r.status_code != 200: + return False + tags = r.json().get("models", []) + return any(m.get("name", "").startswith(model.split(":")[0]) for m in tags) + except requests.RequestException: + return False + + +def _refine_ollama(prompt: str, model: str) -> Optional[str]: + payload = { + "model": model, + "prompt": f"{SYSTEM}\n\nUser idea: {prompt}\n\nRewritten prompt:", + "stream": False, + "options": {"temperature": 0.7, "num_predict": 200}, + } + try: + r = requests.post(OLLAMA_URL, json=payload, timeout=60) + if r.status_code != 200: + return None + text = r.json().get("response", "").strip() + return text or None + except requests.RequestException: + return None + + +@lru_cache(maxsize=1) +def _gpt2(): + from transformers import pipeline + return pipeline("text-generation", model="gpt2", max_new_tokens=60) + + +def _refine_gpt2(prompt: str) -> str: + seed = ( + f"{prompt}, highly detailed, sharp focus, professional, cinematic lighting, " + f"intricate details, masterpiece, best quality" + ) + try: + gen = _gpt2() + out = gen(seed, num_return_sequences=1, do_sample=True, temperature=0.7) + text = out[0]["generated_text"].split("\n")[0] + return text.strip() + except Exception: + return seed + + +def refine(prompt: str, use_ollama: bool = True, ollama_model: str = DEFAULT_OLLAMA_MODEL) -> str: + prompt = prompt.strip() + if not prompt: + return prompt + if use_ollama and _ollama_available(ollama_model): + result = _refine_ollama(prompt, ollama_model) + if result: + return result + return _refine_gpt2(prompt) diff --git a/backends/safety.py b/backends/safety.py new file mode 100644 index 0000000..ea91a3c --- /dev/null +++ b/backends/safety.py @@ -0,0 +1,185 @@ +"""Safety filter: block CSAM only. All other content (NSFW, gore, violence) allowed. + +Layers: +1. Prompt keyword gate: rejects child-term + sexual-term combinations. +2. Output check: NudeNet detects nudity. If nudity present, run face detection + (MTCNN) + age classifier (ViT) on every face. Block if any face is classified + as a minor with high confidence. + +This is best-effort. The user is legally responsible for use of generated content. 
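+
+check_image() decision flow: no nudity detected -> allow; nudity but no face classified
+as a minor -> allow; nudity plus any face classified as a minor -> block and discard.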
+""" +from __future__ import annotations + +import re +import tempfile +from dataclasses import dataclass +from functools import lru_cache +from typing import Iterable + +from PIL import Image + +# --- prompt keyword gate --------------------------------------------------- + +_CHILD_TERMS = [ + r"\b(child|children|kid|kids|minor|underage|under-?age|preteen|pre-?teen)\b", + r"\b(toddler|infant|baby|babies)\b", + r"\b(\d|0?[0-9]|1[0-7])\s*(yo|y/o|year[- ]?old)\b", + r"\bloli(con)?\b", + r"\bshota(con)?\b", + r"\bcp\b", +] +_SEXUAL_TERMS = [ + r"\b(nude|naked|nsfw|porn|sex|sexual|sexy|erotic|explicit)\b", + r"\b(penis|vagina|breast|nipple|genital|cum|orgasm)\b", + r"\b(intercourse|fellatio|cunnilingus|masturbat)\w*\b", + r"\b(rape|molest)\w*\b", +] + +_CHILD_RE = re.compile("|".join(_CHILD_TERMS), re.IGNORECASE) +_SEX_RE = re.compile("|".join(_SEXUAL_TERMS), re.IGNORECASE) + + +@dataclass +class SafetyResult: + allowed: bool + reason: str = "" + + +def check_prompt(prompt: str) -> SafetyResult: + if _CHILD_RE.search(prompt) and _SEX_RE.search(prompt): + return SafetyResult(False, "Prompt blocked: combines minor and sexual terms (CSAM gate).") + return SafetyResult(True) + + +# --- nudity detection ------------------------------------------------------ + +_NUDITY_LABELS = { + "FEMALE_BREAST_EXPOSED", + "FEMALE_GENITALIA_EXPOSED", + "MALE_GENITALIA_EXPOSED", + "BUTTOCKS_EXPOSED", + "ANUS_EXPOSED", +} + + +@lru_cache(maxsize=1) +def _nudenet(): + try: + from nudenet import NudeDetector + return NudeDetector() + except Exception: + return None + + +def _has_nudity(detections: Iterable[dict]) -> bool: + for d in detections: + label = d.get("class") or d.get("label") or "" + score = float(d.get("score", 0.0)) + if label in _NUDITY_LABELS and score >= 0.5: + return True + return False + + +# --- face detection + age classification ----------------------------------- + +@lru_cache(maxsize=1) +def _mtcnn(): + try: + from facenet_pytorch import MTCNN + import torch + device = "cuda" if torch.cuda.is_available() else "cpu" + return MTCNN(keep_all=True, device=device, post_process=False, min_face_size=40) + except Exception: + return None + + +@lru_cache(maxsize=1) +def _age_classifier(): + try: + from transformers import pipeline + return pipeline( + "image-classification", + model="nateraw/vit-age-classifier", + top_k=3, + ) + except Exception: + return None + + +# Buckets reported by nateraw/vit-age-classifier. +# Conservative minor set. "10-19" includes some adults — treat as minor only on +# strong confidence to limit false positives on young-looking adults. 
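+# Thresholds applied in _is_minor_face below: hard buckets block at top-1 confidence
+# >= 0.55, the soft 10-19 bucket only at >= 0.85.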
+_HARD_MINOR = {"0-2", "3-9"} +_SOFT_MINOR = {"10-19"} + + +def _faces(img: Image.Image): + mtcnn = _mtcnn() + if mtcnn is None: + return [] + try: + boxes, probs = mtcnn.detect(img) + except Exception: + return [] + if boxes is None: + return [] + if probs is None: + probs = [None] * len(boxes) + out = [] + for box, prob in zip(boxes, probs): + if prob is None or float(prob) < 0.9: + continue + x1, y1, x2, y2 = [int(max(0, v)) for v in box] + if x2 - x1 < 30 or y2 - y1 < 30: + continue + out.append(img.crop((x1, y1, x2, y2))) + return out + + +def _is_minor_face(face_img: Image.Image) -> tuple[bool, str]: + clf = _age_classifier() + if clf is None: + return False, "" + try: + preds = clf(face_img) + except Exception: + return False, "" + # preds is list[dict(label, score)] sorted by score desc + top = preds[0] if preds else None + if not top: + return False, "" + label = top["label"] + score = float(top["score"]) + if label in _HARD_MINOR and score >= 0.55: + return True, f"minor face detected ({label}, conf={score:.2f})" + if label in _SOFT_MINOR and score >= 0.85: + return True, f"likely minor face ({label}, conf={score:.2f})" + return False, "" + + +def check_image(img: Image.Image) -> SafetyResult: + """Block if (nudity present) AND (any face classified as minor).""" + det = _nudenet() + if det is None: + # No nudity detector available — fall through. Prompt gate is primary defense. + return SafetyResult(True) + + try: + with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as f: + img.save(f.name) + results = det.detect(f.name) + except Exception: + return SafetyResult(True) + + if not _has_nudity(results): + return SafetyResult(True) + + faces = _faces(img) + for face in faces: + is_minor, reason = _is_minor_face(face) + if is_minor: + return SafetyResult( + False, + f"Output blocked: nudity + {reason}. Image discarded.", + ) + return SafetyResult(True) diff --git a/backends/video_ltx.py b/backends/video_ltx.py new file mode 100644 index 0000000..4455df5 --- /dev/null +++ b/backends/video_ltx.py @@ -0,0 +1,118 @@ +"""Video generation. LTX-Video / CogVideoX / Wan via diffusers pipelines.""" +from __future__ import annotations + +from functools import lru_cache +from pathlib import Path +from typing import Optional + +from .device import get_device, hardware_info, torch_dtype +from .memory import apply_memory_strategy +from .models import ModelSpec, find + +ROOT = Path(__file__).parent.parent +MODELS_DIR = ROOT / "models" +OUTPUTS = ROOT / "outputs" +OUTPUTS.mkdir(exist_ok=True) + + +def _import_pipeline_for(repo: str): + """Pick the right pipeline class. Wan needs diffusers>=0.32 with WanPipeline.""" + if "LTX" in repo: + from diffusers import LTXPipeline + return LTXPipeline + if "CogVideoX" in repo: + from diffusers import CogVideoXPipeline + return CogVideoXPipeline + if "Wan" in repo: + try: + from diffusers import WanPipeline + return WanPipeline + except ImportError as e: + raise RuntimeError( + "Wan 2.1 needs diffusers>=0.32. 
Run: pip install -U diffusers" + ) from e + if "Hunyuan" in repo: + from diffusers import HunyuanVideoPipeline + return HunyuanVideoPipeline + from diffusers import DiffusionPipeline + return DiffusionPipeline + + +@lru_cache(maxsize=1) +def _load_pipeline(repo: str): + dtype = torch_dtype() + cache_dir = str(MODELS_DIR / "diffusers") + + PipelineCls = _import_pipeline_for(repo) + pipe = PipelineCls.from_pretrained(repo, torch_dtype=dtype, cache_dir=cache_dir) + + apply_memory_strategy(pipe) + return pipe + + +def _model_kwargs(repo: str, base: dict) -> dict: + """Adjust kwargs per-pipeline. Some pipes don't accept width/height or have different param names.""" + out = dict(base) + if "Wan" in repo: + # Wan accepts height/width but expects multiples of 16. + out["height"] = (out["height"] // 16) * 16 + out["width"] = (out["width"] // 16) * 16 + if "LTX" in repo: + # LTX needs 32-multiple resolutions. + out["height"] = (out["height"] // 32) * 32 + out["width"] = (out["width"] // 32) * 32 + if "CogVideoX" in repo: + # CogVideoX has fixed 720x480 default; clamp. + out["height"] = min(out["height"], 480) + out["width"] = min(out["width"], 720) + return out + + +def generate( + prompt: str, + negative_prompt: str = "", + model_id: str = "ltx-video", + width: int = 704, + height: int = 480, + num_frames: int = 73, + fps: int = 24, + steps: int = 30, + guidance: float = 3.0, + seed: Optional[int] = None, +) -> str: + spec: ModelSpec | None = find(model_id) + if spec is None or spec.kind != "video": + raise ValueError(f"Unknown video model: {model_id}") + + pipe = _load_pipeline(spec.repo) + + import torch + from diffusers.utils import export_to_video + + generator = None + if seed is not None: + try: + generator = torch.Generator(device=get_device()).manual_seed(int(seed)) + except RuntimeError: + generator = torch.Generator().manual_seed(int(seed)) + + base = dict( + prompt=prompt, + negative_prompt=negative_prompt or None, + width=int(width), + height=int(height), + num_frames=int(num_frames), + num_inference_steps=int(steps), + guidance_scale=float(guidance), + generator=generator, + ) + kwargs = _model_kwargs(spec.repo, base) + kwargs = {k: v for k, v in kwargs.items() if v is not None} + + out = pipe(**kwargs) + frames = out.frames[0] + + import time + path = OUTPUTS / f"video_{int(time.time())}.mp4" + export_to_video(frames, str(path), fps=int(fps)) + return str(path) diff --git a/config.json b/config.json new file mode 100644 index 0000000..b358c05 --- /dev/null +++ b/config.json @@ -0,0 +1,31 @@ +{ + "ui": { + "host": "127.0.0.1", + "port": 7860, + "open_browser": true, + "theme": "default" + }, + "refiner": { + "use_ollama": true, + "ollama_model": "dolphin-llama3:8b", + "auto_refine_threshold_words": 8 + }, + "image_defaults": { + "width": 1024, + "height": 1024, + "steps": 30, + "guidance": 7.0, + "negative_prompt": "blurry, low quality, watermark, text, signature, deformed, bad anatomy, extra limbs" + }, + "video_defaults": { + "width": 704, + "height": 480, + "num_frames": 73, + "fps": 24, + "steps": 30, + "guidance": 3.0 + }, + "safety": { + "csam_gate": true + } +} diff --git a/launcher.py b/launcher.py new file mode 100644 index 0000000..477459f --- /dev/null +++ b/launcher.py @@ -0,0 +1,314 @@ +"""First-run launcher. + +Bootstraps a clean Python 3.11 environment via `uv`, regardless of the system +Python the user invoked us with. Keeps the user on a single supported runtime +while the AI ecosystem stabilizes around newer Python versions. + +uv install strategy (in order): +1. 
Direct binary download from GitHub releases (no shell, no admin). +2. PowerShell / shell installer. +3. pip fallback, invoked as `python -m uv`. +""" +from __future__ import annotations + +import argparse +import io +import json +import os +import platform +import shutil +import subprocess +import sys +import urllib.request +import zipfile +from pathlib import Path + +ROOT = Path(__file__).parent.resolve() +VENV_DIR = ROOT / "venv" +MARKER = VENV_DIR / ".kawai_ready" +HARDWARE_CACHE = ROOT / "config.local.json" +UV_CACHE_DIR = ROOT / ".tools" +PYTHON_TARGET = "3.11" + + +def venv_python() -> Path: + if os.name == "nt": + return VENV_DIR / "Scripts" / "python.exe" + return VENV_DIR / "bin" / "python" + + +# --- uv discovery / install ------------------------------------------------ + +def _local_uv_path() -> Path: + return UV_CACHE_DIR / ("uv.exe" if os.name == "nt" else "uv") + + +def _uv_argv() -> list[str] | None: + """Return argv prefix to invoke uv. None if not available.""" + local = _local_uv_path() + if local.exists(): + return [str(local)] + found = shutil.which("uv") + if found: + return [found] + # Installed via pip into current Python (works as module). + try: + subprocess.check_output( + [sys.executable, "-m", "uv", "--version"], + stderr=subprocess.STDOUT, + timeout=10, + ) + return [sys.executable, "-m", "uv"] + except (subprocess.CalledProcessError, FileNotFoundError, subprocess.TimeoutExpired): + return None + + +def _uv_release_asset() -> str | None: + """Pick the GitHub release asset name for this OS+arch.""" + machine = platform.machine().lower() + arch_win = "x86_64" if machine in ("amd64", "x86_64") else "aarch64" if "arm" in machine else None + if os.name == "nt" and arch_win: + return f"uv-{arch_win}-pc-windows-msvc.zip" + if sys.platform == "darwin": + arch = "aarch64" if "arm" in machine else "x86_64" + return f"uv-{arch}-apple-darwin.tar.gz" + if sys.platform.startswith("linux"): + arch = "aarch64" if "aarch64" in machine or "arm64" in machine else "x86_64" + return f"uv-{arch}-unknown-linux-gnu.tar.gz" + return None + + +def _download_uv() -> bool: + """Download uv binary directly from GitHub releases. Most reliable path.""" + asset = _uv_release_asset() + if asset is None: + return False + url = f"https://github.com/astral-sh/uv/releases/latest/download/{asset}" + UV_CACHE_DIR.mkdir(parents=True, exist_ok=True) + print(f"[kawai] Downloading uv from {url}") + try: + with urllib.request.urlopen(url, timeout=60) as resp: + data = resp.read() + except Exception as e: + print(f"[kawai] download failed: {e}") + return False + + try: + if asset.endswith(".zip"): + with zipfile.ZipFile(io.BytesIO(data)) as z: + for name in z.namelist(): + if name.endswith("uv.exe") or name.endswith("/uv"): + target = _local_uv_path() + target.write_bytes(z.read(name)) + if os.name != "nt": + target.chmod(0o755) + return target.exists() + else: + import tarfile + with tarfile.open(fileobj=io.BytesIO(data), mode="r:gz") as t: + for member in t.getmembers(): + if member.name.endswith("/uv") or member.name == "uv": + f = t.extractfile(member) + if f is None: + continue + target = _local_uv_path() + target.write_bytes(f.read()) + target.chmod(0o755) + return target.exists() + except Exception as e: + print(f"[kawai] extract failed: {e}") + return False + return False + + +def _install_uv_via_shell() -> bool: + """Use astral.sh installer scripts. 
Often blocked on locked-down systems.""" + UV_CACHE_DIR.mkdir(parents=True, exist_ok=True) + env = os.environ.copy() + env["UV_INSTALL_DIR"] = str(UV_CACHE_DIR) + env["UV_NO_MODIFY_PATH"] = "1" + try: + if os.name == "nt": + subprocess.check_call( + [ + "powershell", + "-NoProfile", + "-ExecutionPolicy", "Bypass", + "-Command", + "irm https://astral.sh/uv/install.ps1 | iex", + ], + env=env, + ) + else: + subprocess.check_call( + ["bash", "-c", "curl -LsSf https://astral.sh/uv/install.sh | sh"], + env=env, + ) + except (subprocess.CalledProcessError, FileNotFoundError): + return False + return _local_uv_path().exists() + + +def _install_uv_via_pip() -> bool: + print("[kawai] Installing uv via pip...") + try: + subprocess.check_call([sys.executable, "-m", "pip", "install", "--user", "--upgrade", "uv"]) + except subprocess.CalledProcessError: + return False + # Verify it's invokable as a module. + try: + subprocess.check_output( + [sys.executable, "-m", "uv", "--version"], + stderr=subprocess.STDOUT, + timeout=10, + ) + return True + except Exception: + return False + + +def _ensure_uv() -> list[str]: + argv = _uv_argv() + if argv: + return argv + + print("[kawai] Installing uv...") + if _download_uv(): + return _uv_argv() or [] + + if _install_uv_via_shell(): + argv = _uv_argv() + if argv: + return argv + + if _install_uv_via_pip(): + argv = _uv_argv() + if argv: + return argv + + raise RuntimeError( + "Failed to install uv. Install manually from https://astral.sh/uv " + "and rerun this launcher." + ) + + +# --- venv + deps ----------------------------------------------------------- + +def _create_venv(uv: list[str]) -> None: + if venv_python().exists(): + return + print(f"[kawai] Creating venv with Python {PYTHON_TARGET} (uv will download it if needed)...") + subprocess.check_call([*uv, "venv", str(VENV_DIR), "--python", PYTHON_TARGET]) + + +def _uv_pip(uv: list[str], args: list[str]) -> None: + cmd = [*uv, "pip", "install", "--python", str(venv_python()), *args] + print(f"[kawai] uv pip install {' '.join(args)}") + subprocess.check_call(cmd) + + +def detect_and_install( + uv: list[str], + force_backend: str | None = None, + force_vendor: str | None = None, +) -> dict: + sys.path.insert(0, str(ROOT)) + from backends import hardware + + info = hardware.detect(force_backend=force_backend, force_vendor=force_vendor) + forced_note = " (forced)" if force_backend and force_backend != "auto" else "" + print( + f"[kawai] Backend: {info.backend}{forced_note} | " + f"{info.vendor} / {info.device_name} / {info.vram_gb:.1f} GB / tier={info.tier}" + ) + + _uv_pip(uv, hardware.torch_install_args(info)) + _uv_pip(uv, ["-r", str(ROOT / "requirements.txt")]) + + payload = { + "vendor": info.vendor, + "backend": info.backend, + "device_name": info.device_name, + "vram_gb": info.vram_gb, + "tier": info.tier, + "forced": bool(force_backend and force_backend != "auto"), + } + HARDWARE_CACHE.write_text(json.dumps(payload, indent=2)) + MARKER.write_text("ok") + return payload + + +def already_in_venv() -> bool: + try: + return Path(sys.executable).resolve() == venv_python().resolve() + except OSError: + return False + + +def relaunch_in_venv(forwarded_args: list[str]) -> None: + """Re-exec the launcher inside the venv. 
Use subprocess on Windows because + os.execv mangles argv with spaces in paths.""" + print("[kawai] Relaunching inside venv...") + py = str(venv_python()) + script = str(ROOT / "launcher.py") + argv = [py, script, *forwarded_args] + if os.name == "nt": + result = subprocess.run(argv) + sys.exit(result.returncode) + else: + os.execv(py, argv) + + +def _build_arg_parser() -> argparse.ArgumentParser: + p = argparse.ArgumentParser( + prog="kawai", + description="Local AI image/video generator. Auto-detects GPU; pass --backend to override.", + ) + p.add_argument( + "--backend", + choices=["auto", "cuda", "rocm", "directml", "mps", "cpu"], + default="auto", + help=( + "Force torch backend. cuda=NVIDIA, rocm=AMD on Linux, directml=AMD/Intel on Windows, " + "mps=Apple Silicon, cpu=fallback. Default: auto-detect." + ), + ) + p.add_argument( + "--vendor", + choices=["nvidia", "amd", "intel", "apple", "cpu"], + default=None, + help="Override detected vendor (rarely needed; useful when pairing --backend directml with intel).", + ) + p.add_argument( + "--reinstall", + action="store_true", + help="Force re-detect and reinstall torch (clears the install marker).", + ) + return p + + +def main() -> None: + args = _build_arg_parser().parse_args() + forced = args.backend if args.backend != "auto" else None + + if args.reinstall and MARKER.exists(): + print("[kawai] --reinstall: clearing install marker") + MARKER.unlink() + + if already_in_venv(): + if not MARKER.exists(): + uv = _ensure_uv() + detect_and_install(uv, force_backend=forced, force_vendor=args.vendor) + from app import run + run() + return + + uv = _ensure_uv() + _create_venv(uv) + if not MARKER.exists(): + detect_and_install(uv, force_backend=forced, force_vendor=args.vendor) + relaunch_in_venv(sys.argv[1:]) + + +if __name__ == "__main__": + main() diff --git a/models/.gitkeep b/models/.gitkeep new file mode 100644 index 0000000..e69de29 diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..4c3836d --- /dev/null +++ b/requirements.txt @@ -0,0 +1,26 @@ +# Core +gradio>=4.40.0 +pillow>=10.0.0 +numpy>=1.26.0 +requests>=2.31.0 +tqdm>=4.66.0 +huggingface_hub>=0.24.0 + +# Diffusion (latest available; pinned floors only) +diffusers>=0.32.0 +transformers>=4.45.0 +accelerate>=0.34.0 +safetensors>=0.4.5 +sentencepiece>=0.2.0 +protobuf>=4.25.0 +peft>=0.13.0 + +# Video export +imageio[ffmpeg]>=2.34.0 +opencv-python-headless>=4.10.0 + +# Safety +nudenet>=3.4.0 +facenet-pytorch>=2.5.3 + +# Torch installed separately by launcher with vendor-specific wheel
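+# Ollama (optional prompt refiner) is a separate system install, not a pip package;
+# see backends/refiner.py and the README for setup.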