Initial commit
"""Gradio UI. Three tabs: Image, Video, System. Auto-picks defaults from detected hardware."""

from __future__ import annotations

import json
import random
import time
from pathlib import Path

import gradio as gr

from backends import models, refiner, safety
from backends.device import hardware_info

ROOT = Path(__file__).parent
OUTPUTS = ROOT / "outputs"
OUTPUTS.mkdir(exist_ok=True)
CONFIG = json.loads((ROOT / "config.json").read_text())
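# Shape of config.json as read by this module (keys grounded in the lookups
# further down; the values shown are illustrative, not the project's actual defaults):
#
#   {
#     "refiner": {"use_ollama": true},
#     "image_defaults": {"negative_prompt": "", "width": 1024, "height": 1024,
#                        "steps": 30, "guidance": 7.0},
#     "video_defaults": {"width": 768, "height": 512, "num_frames": 121,
#                        "fps": 24, "steps": 30, "guidance": 3.0},
#     "ui": {"host": "127.0.0.1", "port": 7860, "open_browser": true}
#   }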


def _hw_summary() -> str:
    hw = hardware_info()
    return (
        f"**{hw['device_name']}** — {hw['vendor'].upper()} via {hw['backend']} "
        f"— {hw['vram_gb']:.1f} GB VRAM — tier `{hw['tier']}`"
    )
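# Rendered example of the summary line (hardware values are illustrative, not detected):
#   **NVIDIA GeForce RTX 4090** — NVIDIA via cuda — 24.0 GB VRAM — tier `high`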


def _models_status_md() -> str:
    rows = ["| Model | Kind | Min VRAM | Download | Status |", "|---|---|---|---|---|"]
    for m in models.IMAGE_MODELS + models.VIDEO_MODELS:
        status = "cached" if models.is_cached(m) else "not downloaded"
        rows.append(
            f"| {m.label} | {m.kind} | {m.min_vram_gb:.0f} GB | {m.download_gb:.0f} GB | {status} |"
        )
    return "### Models\n\n" + "\n".join(rows)


def _model_choices(kind: str) -> tuple[list[tuple[str, str]], str | None]:
    hw = hardware_info()
    available = models.list_for_tier(hw["tier"], kind)
    choices = [(models.label_with_meta(m), m.id) for m in available]
    default = models.default_for_tier(hw["tier"], kind)
    return choices, (default.id if default else None)
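# The (label, value) tuples returned above map directly onto gr.Dropdown: recent
# Gradio releases accept choices as (display label, underlying value) pairs, so the
# dropdown shows label_with_meta(m) while the callbacks receive m.id.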


def gen_image(
    prompt: str,
    negative_prompt: str,
    model_id: str,
    width: int,
    height: int,
    steps: int,
    guidance: float,
    seed: int,
    auto_refine: bool,
):
    if not prompt.strip():
        raise gr.Error("Empty prompt.")

    chk = safety.check_prompt(prompt)
    if not chk.allowed:
        raise gr.Error(chk.reason)

    spec = models.find(model_id)
    if spec and not models.is_cached(spec):
        gr.Info(f"Downloading {spec.label} (~{spec.download_gb:.0f} GB) on first use. Watch terminal for progress.")

    refined = prompt
    if auto_refine:
        refined = refiner.refine(prompt, use_ollama=CONFIG["refiner"]["use_ollama"])

    # Deferred import: the heavy image backend is loaded only when a generation runs.
    from backends import image_sdxl

    # A negative seed means "pick a random seed"; the value actually used is reported back.
    seed_val = None if seed < 0 else seed
    if seed_val is None:
        seed_val = random.randint(0, 2**31 - 1)

    img = image_sdxl.generate(
        prompt=refined,
        negative_prompt=negative_prompt,
        model_id=model_id,
        width=int(width),
        height=int(height),
        steps=int(steps),
        guidance=float(guidance),
        seed=seed_val,
    )

    img_chk = safety.check_image(img)
    if not img_chk.allowed:
        raise gr.Error(img_chk.reason)

    out_path = OUTPUTS / f"img_{int(time.time())}_{seed_val}.png"
    img.save(out_path)
    info = f"Seed: {seed_val}\n\nPrompt used:\n{refined}"
    return img, info
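# Illustrative programmatic call, bypassing the UI (the model id below is a
# placeholder; real ids come from the backends.models registry):
#
#   img, info = gen_image(
#       prompt="a watercolor fox in the rain",
#       negative_prompt="",
#       model_id="<image-model-id>",
#       width=1024, height=1024, steps=30, guidance=7.0,
#       seed=-1, auto_refine=False,
#   )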


def gen_video(
    prompt: str,
    negative_prompt: str,
    model_id: str,
    width: int,
    height: int,
    num_frames: int,
    fps: int,
    steps: int,
    guidance: float,
    seed: int,
    auto_refine: bool,
):
    if not prompt.strip():
        raise gr.Error("Empty prompt.")

    chk = safety.check_prompt(prompt)
    if not chk.allowed:
        raise gr.Error(chk.reason)

    spec = models.find(model_id)
    if spec and not models.is_cached(spec):
        gr.Info(f"Downloading {spec.label} (~{spec.download_gb:.0f} GB) on first use. Watch terminal for progress.")

    refined = prompt
    if auto_refine:
        refined = refiner.refine(prompt, use_ollama=CONFIG["refiner"]["use_ollama"])

    from backends import video_ltx

    seed_val = None if seed < 0 else seed
    if seed_val is None:
        seed_val = random.randint(0, 2**31 - 1)

    path = video_ltx.generate(
        prompt=refined,
        negative_prompt=negative_prompt,
        model_id=model_id,
        width=int(width),
        height=int(height),
        num_frames=int(num_frames),
        fps=int(fps),
        steps=int(steps),
        guidance=float(guidance),
        seed=seed_val,
    )
    info = f"Seed: {seed_val}\n\nPrompt used:\n{refined}"
    return path, info


def build_ui() -> gr.Blocks:
    img_choices, img_default = _model_choices("image")
    vid_choices, vid_default = _model_choices("video")
    img_def = CONFIG["image_defaults"]
    vid_def = CONFIG["video_defaults"]

    with gr.Blocks(title="Kawai", analytics_enabled=False) as ui:
        gr.Markdown("# Kawai\nLocal AI image and video generator.")
        gr.Markdown(_hw_summary())

        with gr.Tabs():
            with gr.Tab("Image"):
                with gr.Row():
                    with gr.Column(scale=2):
                        i_prompt = gr.Textbox(label="Prompt", lines=3, placeholder="Describe what you want...")
                        i_neg = gr.Textbox(label="Negative prompt", lines=2, value=img_def["negative_prompt"])
                        i_refine = gr.Checkbox(label="Auto-refine prompt with local LLM", value=True)
                        i_model = gr.Dropdown(choices=img_choices, value=img_default, label="Model")
                        with gr.Row():
                            i_w = gr.Slider(512, 1536, value=img_def["width"], step=64, label="Width")
                            i_h = gr.Slider(512, 1536, value=img_def["height"], step=64, label="Height")
                        with gr.Row():
                            i_steps = gr.Slider(1, 80, value=img_def["steps"], step=1, label="Steps")
                            i_guidance = gr.Slider(0.0, 15.0, value=img_def["guidance"], step=0.1, label="Guidance")
                        i_seed = gr.Number(value=-1, label="Seed (-1 = random)", precision=0)
                        i_btn = gr.Button("Generate", variant="primary")
                    with gr.Column(scale=3):
                        i_out = gr.Image(label="Output", type="pil")
                        i_info = gr.Textbox(label="Info", lines=6, interactive=False)
                i_btn.click(
                    gen_image,
                    inputs=[i_prompt, i_neg, i_model, i_w, i_h, i_steps, i_guidance, i_seed, i_refine],
                    outputs=[i_out, i_info],
                )

            with gr.Tab("Video"):
                if not vid_choices:
                    gr.Markdown("**Video disabled** — detected hardware lacks VRAM for any video model.")
                else:
                    with gr.Row():
                        with gr.Column(scale=2):
                            v_prompt = gr.Textbox(label="Prompt", lines=3)
                            v_neg = gr.Textbox(label="Negative prompt", lines=2, value="")
                            v_refine = gr.Checkbox(label="Auto-refine prompt with local LLM", value=True)
                            v_model = gr.Dropdown(choices=vid_choices, value=vid_default, label="Model")
                            with gr.Row():
                                v_w = gr.Slider(384, 1024, value=vid_def["width"], step=32, label="Width")
                                v_h = gr.Slider(256, 1024, value=vid_def["height"], step=32, label="Height")
                            with gr.Row():
                                v_frames = gr.Slider(17, 161, value=vid_def["num_frames"], step=8, label="Frames")
                                v_fps = gr.Slider(8, 30, value=vid_def["fps"], step=1, label="FPS")
                            with gr.Row():
                                v_steps = gr.Slider(10, 60, value=vid_def["steps"], step=1, label="Steps")
                                v_guidance = gr.Slider(0.0, 10.0, value=vid_def["guidance"], step=0.1, label="Guidance")
                            v_seed = gr.Number(value=-1, label="Seed (-1 = random)", precision=0)
                            v_btn = gr.Button("Generate", variant="primary")
                        with gr.Column(scale=3):
                            v_out = gr.Video(label="Output")
                            v_info = gr.Textbox(label="Info", lines=6, interactive=False)
                    v_btn.click(
                        gen_video,
                        inputs=[v_prompt, v_neg, v_model, v_w, v_h, v_frames, v_fps, v_steps, v_guidance, v_seed, v_refine],
                        outputs=[v_out, v_info],
                    )

            with gr.Tab("System"):
                gr.Markdown(_hw_summary())
                gr.Markdown(_models_status_md())
                gr.Markdown(
                    "**Output folder:** `outputs/`\n\n"
                    "**Models cache:** `models/diffusers/`\n\n"
                    "**Prompt refiner:** Ollama with `dolphin-llama3:8b` if running, else GPT-2 fallback.\n\n"
                    "Install Ollama: https://ollama.com/ then `ollama pull dolphin-llama3`.\n\n"
                    "**Safety:** CSAM-gated only (prompt keyword gate + face age check on nude outputs). All other content allowed.\n\n"
                    "**Note:** First use of a model triggers download (7–24 GB). Keep this terminal open during download."
                )

    return ui


def run() -> None:
    ui = build_ui()
    ui.queue().launch(
        server_name=CONFIG["ui"]["host"],
        server_port=CONFIG["ui"]["port"],
        inbrowser=CONFIG["ui"]["open_browser"],
    )


if __name__ == "__main__":
    run()