# KawAI/backends/image_sdxl.py
"""SDXL-family image generation. Loaded lazily, cached by repo id."""
from __future__ import annotations
from functools import lru_cache
from pathlib import Path
from typing import Optional
from PIL import Image
from .device import get_device, torch_dtype
from .memory import apply_memory_strategy
from .models import ModelSpec, find

ROOT = Path(__file__).parent.parent
MODELS_DIR = ROOT / "models"


@lru_cache(maxsize=1)
def _load_pipeline(repo: str, kind: str):
    """Instantiate and cache a diffusers pipeline for ``repo``.

    With ``maxsize=1`` only the most recently used pipeline stays cached,
    so switching models lets the previous pipeline be garbage-collected.
    """
    import torch
    from diffusers import (
        AutoPipelineForText2Image,
        DiffusionPipeline,
        FluxPipeline,
    )

    device = get_device()
    dtype = torch_dtype()
    cache_dir = str(MODELS_DIR / "diffusers")
    if "FLUX" in repo:
        pipe = FluxPipeline.from_pretrained(repo, torch_dtype=dtype, cache_dir=cache_dir)
    elif "turbo" in repo.lower():
        pipe = AutoPipelineForText2Image.from_pretrained(
            repo, torch_dtype=dtype, variant="fp16", cache_dir=cache_dir
        )
    else:
        pipe = DiffusionPipeline.from_pretrained(
            repo, torch_dtype=dtype, use_safetensors=True, cache_dir=cache_dir
        )
    # Disable the built-in safety checker; we run our own (CSAM-only) filter.
    if hasattr(pipe, "safety_checker"):
        pipe.safety_checker = None
    if hasattr(pipe, "requires_safety_checker"):
        pipe.requires_safety_checker = False
    apply_memory_strategy(pipe)
    return pipe
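

# ``lru_cache(maxsize=1)`` evicts the previous pipeline when a different model
# is loaded, but its VRAM is only reclaimed once Python garbage-collects it.
# A minimal, hypothetical helper for forcing that eagerly (not part of the
# original module; the CUDA call is guarded so other devices are unaffected):
def unload_pipelines() -> None:
    """Drop the cached pipeline and hand freed memory back to the driver."""
    import gc

    import torch

    _load_pipeline.cache_clear()  # forget the cached pipeline entry
    gc.collect()                  # collect the now-unreferenced pipeline
    if torch.cuda.is_available():
        torch.cuda.empty_cache()  # release cached CUDA allocator blocks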


def generate(
    prompt: str,
    negative_prompt: str = "",
    model_id: str = "pony-xl",
    width: int = 1024,
    height: int = 1024,
    steps: int = 30,
    guidance: float = 7.0,
    seed: Optional[int] = None,
) -> Image.Image:
    """Generate one image from ``prompt`` using the registered ``model_id``."""
    spec: ModelSpec | None = find(model_id)
    if spec is None or spec.kind != "image":
        raise ValueError(f"Unknown image model: {model_id}")
    pipe = _load_pipeline(spec.repo, spec.kind)

    import torch

    generator = None
    if seed is not None:
        try:
            generator = torch.Generator(device=get_device()).manual_seed(int(seed))
        except RuntimeError:
            # Some backends (e.g. MPS) reject device-bound generators; fall
            # back to a CPU generator so seeding still works.
            generator = torch.Generator().manual_seed(int(seed))

    # Turbo models are few-step distilled checkpoints: clamp the step count
    # to their 1-4 range and disable guidance. FLUX runs without guidance too.
    if "turbo" in spec.repo.lower():
        steps = max(1, min(steps, 4))
        guidance = 0.0
    if "FLUX" in spec.repo:
        guidance = 0.0

    out = pipe(
        prompt=prompt,
        negative_prompt=negative_prompt or None,
        width=width,
        height=height,
        num_inference_steps=steps,
        guidance_scale=guidance,
        generator=generator,
    )
    return out.images[0]
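

# Minimal usage sketch (not part of the original module). It assumes the
# model registry in ``models.py`` contains an image entry with id "pony-xl",
# as the default argument above suggests.
if __name__ == "__main__":
    image = generate(
        prompt="a lighthouse at dusk, oil painting",
        negative_prompt="blurry, low quality",
        steps=30,
        seed=42,
    )
    image.save("out.png")  # PIL infers the PNG format from the extension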