Initial commit
This commit is contained in:
@@ -0,0 +1,92 @@
|
||||
"""SDXL-family image generation. Loaded lazily, cached by repo id."""
|
||||
from __future__ import annotations
|
||||
|
||||
from functools import lru_cache
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
from PIL import Image
|
||||
|
||||
from .device import get_device, torch_dtype
|
||||
from .memory import apply_memory_strategy
|
||||
from .models import ModelSpec, find
|
||||
|
||||
ROOT = Path(__file__).parent.parent
|
||||
MODELS_DIR = ROOT / "models"
|
||||
|
||||
|
||||
@lru_cache(maxsize=1)
|
||||
def _load_pipeline(repo: str, kind: str):
|
||||
import torch
|
||||
from diffusers import (
|
||||
AutoPipelineForText2Image,
|
||||
DiffusionPipeline,
|
||||
FluxPipeline,
|
||||
)
|
||||
|
||||
device = get_device()
|
||||
dtype = torch_dtype()
|
||||
cache_dir = str(MODELS_DIR / "diffusers")
|
||||
|
||||
if "FLUX" in repo:
|
||||
pipe = FluxPipeline.from_pretrained(repo, torch_dtype=dtype, cache_dir=cache_dir)
|
||||
elif "turbo" in repo.lower():
|
||||
pipe = AutoPipelineForText2Image.from_pretrained(
|
||||
repo, torch_dtype=dtype, variant="fp16", cache_dir=cache_dir
|
||||
)
|
||||
else:
|
||||
pipe = DiffusionPipeline.from_pretrained(
|
||||
repo, torch_dtype=dtype, use_safetensors=True, cache_dir=cache_dir
|
||||
)
|
||||
|
||||
# Disable built-in safety checker. We use our own filter (CSAM-only).
|
||||
if hasattr(pipe, "safety_checker"):
|
||||
pipe.safety_checker = None
|
||||
if hasattr(pipe, "requires_safety_checker"):
|
||||
pipe.requires_safety_checker = False
|
||||
|
||||
apply_memory_strategy(pipe)
|
||||
return pipe
|
||||
|
||||
|
||||
def generate(
|
||||
prompt: str,
|
||||
negative_prompt: str = "",
|
||||
model_id: str = "pony-xl",
|
||||
width: int = 1024,
|
||||
height: int = 1024,
|
||||
steps: int = 30,
|
||||
guidance: float = 7.0,
|
||||
seed: Optional[int] = None,
|
||||
) -> Image.Image:
|
||||
spec: ModelSpec | None = find(model_id)
|
||||
if spec is None or spec.kind != "image":
|
||||
raise ValueError(f"Unknown image model: {model_id}")
|
||||
|
||||
pipe = _load_pipeline(spec.repo, spec.kind)
|
||||
|
||||
import torch
|
||||
|
||||
generator = None
|
||||
if seed is not None:
|
||||
try:
|
||||
generator = torch.Generator(device=get_device()).manual_seed(int(seed))
|
||||
except RuntimeError:
|
||||
generator = torch.Generator().manual_seed(int(seed))
|
||||
|
||||
if "turbo" in spec.repo.lower():
|
||||
steps = max(1, min(steps, 4))
|
||||
guidance = 0.0
|
||||
if "FLUX" in spec.repo:
|
||||
guidance = 0.0
|
||||
|
||||
out = pipe(
|
||||
prompt=prompt,
|
||||
negative_prompt=negative_prompt or None,
|
||||
width=width,
|
||||
height=height,
|
||||
num_inference_steps=steps,
|
||||
guidance_scale=guidance,
|
||||
generator=generator,
|
||||
)
|
||||
return out.images[0]
|
||||
Reference in New Issue
Block a user