"""Gradio UI. Two tabs: Image, Video. Auto-picks defaults from detected hardware."""
from __future__ import annotations
import json
import random
import time
from pathlib import Path
import gradio as gr
from backends import models, refiner, safety
from backends.device import hardware_info
ROOT = Path(__file__).parent
OUTPUTS = ROOT / "outputs"
OUTPUTS.mkdir(exist_ok=True)
CONFIG = json.loads((ROOT / "config.json").read_text())
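
# Expected shape of config.json, inferred from the keys this module reads.
# The values below are illustrative placeholders, not the project's shipped
# defaults:
# {
#     "ui": {"host": "127.0.0.1", "port": 7860, "open_browser": true},
#     "refiner": {"use_ollama": true},
#     "image_defaults": {"negative_prompt": "", "width": 1024, "height": 1024,
#                        "steps": 30, "guidance": 7.0},
#     "video_defaults": {"width": 768, "height": 512, "num_frames": 97,
#                        "fps": 24, "steps": 30, "guidance": 3.0}
# }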


def _hw_summary() -> str:
    hw = hardware_info()
    return (
        f"**{hw['device_name']}** — {hw['vendor'].upper()} via {hw['backend']} — "
        f"{hw['vram_gb']:.1f} GB VRAM — tier `{hw['tier']}`"
    )


def _models_status_md() -> str:
    rows = ["| Model | Kind | Min VRAM | Download | Status |", "|---|---|---|---|---|"]
    for m in models.IMAGE_MODELS + models.VIDEO_MODELS:
        status = "cached" if models.is_cached(m) else "not downloaded"
        rows.append(
            f"| {m.label} | {m.kind} | {m.min_vram_gb:.0f} GB | {m.download_gb:.0f} GB | {status} |"
        )
    return "### Models\n\n" + "\n".join(rows)


def _model_choices(kind: str) -> tuple[list[tuple[str, str]], str | None]:
    hw = hardware_info()
    available = models.list_for_tier(hw["tier"], kind)
    choices = [(models.label_with_meta(m), m.id) for m in available]
    default = models.default_for_tier(hw["tier"], kind)
    return choices, (default.id if default else None)
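

# Both generators follow the same pipeline: gate the prompt through the safety
# checker, optionally rewrite it with the local refiner LLM, lazily import the
# heavy diffusion backend, then generate with an explicit seed so the result is
# reproducible from the Info panel. gen_image additionally re-checks the
# rendered image before returning it.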
def gen_image(
    prompt: str,
    negative_prompt: str,
    model_id: str,
    width: int,
    height: int,
    steps: int,
    guidance: float,
    seed: int,
    auto_refine: bool,
):
    if not prompt.strip():
        raise gr.Error("Empty prompt.")
    chk = safety.check_prompt(prompt)
    if not chk.allowed:
        raise gr.Error(chk.reason)
    spec = models.find(model_id)
    if spec and not models.is_cached(spec):
        gr.Info(f"Downloading {spec.label} (~{spec.download_gb:.0f} GB) on first use. Watch terminal for progress.")
    refined = prompt
    if auto_refine:
        refined = refiner.refine(prompt, use_ollama=CONFIG["refiner"]["use_ollama"])
    # Imported lazily so the UI starts fast; torch/diffusers load on first use.
    from backends import image_sdxl
    # A negative seed means "random": draw a concrete one so it can be reported.
    seed_val = seed if seed >= 0 else random.randint(0, 2**31 - 1)
    img = image_sdxl.generate(
        prompt=refined,
        negative_prompt=negative_prompt,
        model_id=model_id,
        width=int(width),
        height=int(height),
        steps=int(steps),
        guidance=float(guidance),
        seed=seed_val,
    )
    img_chk = safety.check_image(img)
    if not img_chk.allowed:
        raise gr.Error(img_chk.reason)
    out_path = OUTPUTS / f"img_{int(time.time())}_{seed_val}.png"
    img.save(out_path)
    info = f"Seed: {seed_val}\n\nPrompt used:\n{refined}"
    return img, info


def gen_video(
    prompt: str,
    negative_prompt: str,
    model_id: str,
    width: int,
    height: int,
    num_frames: int,
    fps: int,
    steps: int,
    guidance: float,
    seed: int,
    auto_refine: bool,
):
    if not prompt.strip():
        raise gr.Error("Empty prompt.")
    chk = safety.check_prompt(prompt)
    if not chk.allowed:
        raise gr.Error(chk.reason)
    spec = models.find(model_id)
    if spec and not models.is_cached(spec):
        gr.Info(f"Downloading {spec.label} (~{spec.download_gb:.0f} GB) on first use. Watch terminal for progress.")
    refined = prompt
    if auto_refine:
        refined = refiner.refine(prompt, use_ollama=CONFIG["refiner"]["use_ollama"])
    # Imported lazily, mirroring gen_image.
    from backends import video_ltx
    seed_val = seed if seed >= 0 else random.randint(0, 2**31 - 1)
    path = video_ltx.generate(
        prompt=refined,
        negative_prompt=negative_prompt,
        model_id=model_id,
        width=int(width),
        height=int(height),
        num_frames=int(num_frames),
        fps=int(fps),
        steps=int(steps),
        guidance=float(guidance),
        seed=seed_val,
    )
    info = f"Seed: {seed_val}\n\nPrompt used:\n{refined}"
    return path, info
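

# Layout: three tabs. Image and Video share a two-column pattern (controls on
# the left, output and info on the right); the Video tab collapses to a notice
# when no video model fits the detected VRAM tier. System is read-only status.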
def build_ui() -> gr.Blocks:
    img_choices, img_default = _model_choices("image")
    vid_choices, vid_default = _model_choices("video")
    img_def = CONFIG["image_defaults"]
    vid_def = CONFIG["video_defaults"]
    with gr.Blocks(title="Kawai", analytics_enabled=False) as ui:
        gr.Markdown("# Kawai\nLocal AI image and video generator.")
        gr.Markdown(_hw_summary())
        with gr.Tabs():
            with gr.Tab("Image"):
                with gr.Row():
                    with gr.Column(scale=2):
                        i_prompt = gr.Textbox(label="Prompt", lines=3, placeholder="Describe what you want...")
                        i_neg = gr.Textbox(label="Negative prompt", lines=2, value=img_def["negative_prompt"])
                        i_refine = gr.Checkbox(label="Auto-refine prompt with local LLM", value=True)
                        i_model = gr.Dropdown(choices=img_choices, value=img_default, label="Model")
                        with gr.Row():
                            i_w = gr.Slider(512, 1536, value=img_def["width"], step=64, label="Width")
                            i_h = gr.Slider(512, 1536, value=img_def["height"], step=64, label="Height")
                        with gr.Row():
                            i_steps = gr.Slider(1, 80, value=img_def["steps"], step=1, label="Steps")
                            i_guidance = gr.Slider(0.0, 15.0, value=img_def["guidance"], step=0.1, label="Guidance")
                        i_seed = gr.Number(value=-1, label="Seed (-1 = random)", precision=0)
                        i_btn = gr.Button("Generate", variant="primary")
                    with gr.Column(scale=3):
                        i_out = gr.Image(label="Output", type="pil")
                        i_info = gr.Textbox(label="Info", lines=6, interactive=False)
                i_btn.click(
                    gen_image,
                    inputs=[i_prompt, i_neg, i_model, i_w, i_h, i_steps, i_guidance, i_seed, i_refine],
                    outputs=[i_out, i_info],
                )
with gr.Tab("Video"):
if not vid_choices:
gr.Markdown("**Video disabled** — detected hardware lacks VRAM for any video model.")
else:
with gr.Row():
with gr.Column(scale=2):
v_prompt = gr.Textbox(label="Prompt", lines=3)
v_neg = gr.Textbox(label="Negative prompt", lines=2, value="")
v_refine = gr.Checkbox(label="Auto-refine prompt with local LLM", value=True)
v_model = gr.Dropdown(choices=vid_choices, value=vid_default, label="Model")
with gr.Row():
v_w = gr.Slider(384, 1024, value=vid_def["width"], step=32, label="Width")
v_h = gr.Slider(256, 1024, value=vid_def["height"], step=32, label="Height")
with gr.Row():
v_frames = gr.Slider(17, 161, value=vid_def["num_frames"], step=8, label="Frames")
v_fps = gr.Slider(8, 30, value=vid_def["fps"], step=1, label="FPS")
with gr.Row():
v_steps = gr.Slider(10, 60, value=vid_def["steps"], step=1, label="Steps")
v_guidance = gr.Slider(0.0, 10.0, value=vid_def["guidance"], step=0.1, label="Guidance")
v_seed = gr.Number(value=-1, label="Seed (-1 = random)", precision=0)
v_btn = gr.Button("Generate", variant="primary")
with gr.Column(scale=3):
v_out = gr.Video(label="Output")
v_info = gr.Textbox(label="Info", lines=6, interactive=False)
v_btn.click(
gen_video,
inputs=[v_prompt, v_neg, v_model, v_w, v_h, v_frames, v_fps, v_steps, v_guidance, v_seed, v_refine],
outputs=[v_out, v_info],
)
with gr.Tab("System"):
gr.Markdown(_hw_summary())
gr.Markdown(_models_status_md())
gr.Markdown(
"**Output folder:** `outputs/`\n\n"
"**Models cache:** `models/diffusers/`\n\n"
"**Prompt refiner:** Ollama with `dolphin-llama3:8b` if running, else GPT-2 fallback.\n\n"
"Install Ollama: https://ollama.com/ then `ollama pull dolphin-llama3`.\n\n"
"**Safety:** CSAM-gated only (prompt keyword gate + face age check on nude outputs). All other content allowed.\n\n"
"**Note:** First use of a model triggers download (724 GB). Keep this terminal open during download."
)
return ui
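

# Entry point: build the Blocks app and serve it. queue() enables Gradio's
# request queue so long-running generations don't time out the HTTP request.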
def run() -> None:
    ui = build_ui()
    ui.queue().launch(
        server_name=CONFIG["ui"]["host"],
        server_port=CONFIG["ui"]["port"],
        inbrowser=CONFIG["ui"]["open_browser"],
    )


if __name__ == "__main__":
    run()