"""SDXL-family image generation. Loaded lazily, cached by repo id."""

from __future__ import annotations

from functools import lru_cache
from pathlib import Path
from typing import Optional

from PIL import Image

from .device import get_device, torch_dtype
from .memory import apply_memory_strategy
from .models import ModelSpec, find

ROOT = Path(__file__).parent.parent
MODELS_DIR = ROOT / "models"


@lru_cache(maxsize=1)
def _load_pipeline(repo: str, kind: str):
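    # torch and diffusers are imported lazily so importing this module stays
    # cheap; with maxsize=1 only the most recently used pipeline stays
    # resident, so switching models evicts the previous one from the cache.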
    import torch
    from diffusers import (
        AutoPipelineForText2Image,
        DiffusionPipeline,
        FluxPipeline,
    )

    device = get_device()
    dtype = torch_dtype()
    cache_dir = str(MODELS_DIR / "diffusers")
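
    # Pick the pipeline class that matches the model family.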
    if "FLUX" in repo:
        pipe = FluxPipeline.from_pretrained(repo, torch_dtype=dtype, cache_dir=cache_dir)
    elif "turbo" in repo.lower():
        pipe = AutoPipelineForText2Image.from_pretrained(
            repo, torch_dtype=dtype, variant="fp16", cache_dir=cache_dir
        )
    else:
        pipe = DiffusionPipeline.from_pretrained(
            repo, torch_dtype=dtype, use_safetensors=True, cache_dir=cache_dir
        )

    # Disable built-in safety checker. We use our own filter (CSAM-only).
    if hasattr(pipe, "safety_checker"):
        pipe.safety_checker = None
    if hasattr(pipe, "requires_safety_checker"):
        pipe.requires_safety_checker = False

    apply_memory_strategy(pipe)
    return pipe


def generate(
    prompt: str,
    negative_prompt: str = "",
    model_id: str = "pony-xl",
    width: int = 1024,
    height: int = 1024,
    steps: int = 30,
    guidance: float = 7.0,
    seed: Optional[int] = None,
) -> Image.Image:
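    """Generate a single image from `prompt` using the model registered as `model_id`."""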
    spec: ModelSpec | None = find(model_id)
    if spec is None or spec.kind != "image":
        raise ValueError(f"Unknown image model: {model_id}")

    pipe = _load_pipeline(spec.repo, spec.kind)

    import torch
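
    # Seed a generator for reproducible output; fall back to a CPU generator
    # when torch.Generator does not support the current device.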
    generator = None
    if seed is not None:
        try:
            generator = torch.Generator(device=get_device()).manual_seed(int(seed))
        except RuntimeError:
            generator = torch.Generator().manual_seed(int(seed))
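
    # Distilled checkpoints: turbo variants are tuned for 1-4 steps without
    # classifier-free guidance, and guidance is likewise dropped for FLUX here.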
    if "turbo" in spec.repo.lower():
        steps = max(1, min(steps, 4))
        guidance = 0.0
    if "FLUX" in spec.repo:
        guidance = 0.0

    out = pipe(
        prompt=prompt,
        negative_prompt=negative_prompt or None,
        width=width,
        height=height,
        num_inference_steps=steps,
        guidance_scale=guidance,
        generator=generator,
    )
    return out.images[0]
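

# Minimal usage sketch (assumes the default "pony-xl" entry exists in .models;
# adjust model_id to whatever the registry actually provides):
#
#   img = generate("a watercolor landscape at dawn", steps=20, seed=42)
#   img.save("out.png")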