"""SDXL-family image generation. Loaded lazily, cached by repo id.""" from __future__ import annotations from functools import lru_cache from pathlib import Path from typing import Optional from PIL import Image from .device import get_device, torch_dtype from .memory import apply_memory_strategy from .models import ModelSpec, find ROOT = Path(__file__).parent.parent MODELS_DIR = ROOT / "models" @lru_cache(maxsize=1) def _load_pipeline(repo: str, kind: str): import torch from diffusers import ( AutoPipelineForText2Image, DiffusionPipeline, FluxPipeline, ) device = get_device() dtype = torch_dtype() cache_dir = str(MODELS_DIR / "diffusers") if "FLUX" in repo: pipe = FluxPipeline.from_pretrained(repo, torch_dtype=dtype, cache_dir=cache_dir) elif "turbo" in repo.lower(): pipe = AutoPipelineForText2Image.from_pretrained( repo, torch_dtype=dtype, variant="fp16", cache_dir=cache_dir ) else: pipe = DiffusionPipeline.from_pretrained( repo, torch_dtype=dtype, use_safetensors=True, cache_dir=cache_dir ) # Disable built-in safety checker. We use our own filter (CSAM-only). if hasattr(pipe, "safety_checker"): pipe.safety_checker = None if hasattr(pipe, "requires_safety_checker"): pipe.requires_safety_checker = False apply_memory_strategy(pipe) return pipe def generate( prompt: str, negative_prompt: str = "", model_id: str = "pony-xl", width: int = 1024, height: int = 1024, steps: int = 30, guidance: float = 7.0, seed: Optional[int] = None, ) -> Image.Image: spec: ModelSpec | None = find(model_id) if spec is None or spec.kind != "image": raise ValueError(f"Unknown image model: {model_id}") pipe = _load_pipeline(spec.repo, spec.kind) import torch generator = None if seed is not None: try: generator = torch.Generator(device=get_device()).manual_seed(int(seed)) except RuntimeError: generator = torch.Generator().manual_seed(int(seed)) if "turbo" in spec.repo.lower(): steps = max(1, min(steps, 4)) guidance = 0.0 if "FLUX" in spec.repo: guidance = 0.0 out = pipe( prompt=prompt, negative_prompt=negative_prompt or None, width=width, height=height, num_inference_steps=steps, guidance_scale=guidance, generator=generator, ) return out.images[0]