"""Memory strategy selection. CUDA supports cpu_offload; DirectML/CPU do not.
Apply per-pipeline based on backend + VRAM tier. All paths reduce peak VRAM
without breaking on non-CUDA devices.
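
Typical use (a sketch; any diffusers pipeline works, the loader call is
illustrative):

    pipe = DiffusionPipeline.from_pretrained(...)
    apply_memory_strategy(pipe)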
"""

from __future__ import annotations

from .device import get_device, hardware_info
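
# Shape of the hardware_info() dict this module relies on (illustrative values;
# the real mapping comes from .device and may carry extra keys):
#   {"backend": "cuda" | "rocm" | "directml" | "mps" | "cpu", "vram_gb": 8.0}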


def apply_memory_strategy(pipe) -> None:
    """Apply the VRAM-saving knobs that match the active backend."""
    info = hardware_info()
    backend = info["backend"]
    vram = info["vram_gb"]

    # Always safe: VAE slicing/tiling work on any device and cut peak VRAM
    # during decode. Recent diffusers releases prefer calling these on the VAE
    # directly; the pipeline-level helpers are deprecated.
    vae = getattr(pipe, "vae", None)
    if vae is not None:
        for fn in ("enable_slicing", "enable_tiling"):
            if hasattr(vae, fn):
                try:
                    getattr(vae, fn)()
                except Exception:
                    pass
    if hasattr(pipe, "enable_attention_slicing"):
        try:
            pipe.enable_attention_slicing()
        except Exception:
            pass
    if backend in ("cuda", "rocm"):
        # ROCm builds of torch expose the CUDA API, so accelerate's offload
        # hooks work the same way on both. Offload only when VRAM is tight;
        # otherwise keep the whole pipeline on-device for speed.
        if vram < 10:
            try:
                pipe.enable_sequential_cpu_offload()
                return
            except Exception:
                pass
            try:
                pipe.enable_model_cpu_offload()
                return
            except Exception:
                pass
        pipe.to(get_device())
        return
    if backend == "directml":
        # DirectML lacks accelerate hook support, so move the whole pipeline
        # to the device; the slicing enabled above keeps peak usage in check.
        try:
            pipe.to(get_device())
        except Exception:
            # Some pipelines have components that won't move cleanly; fall
            # back to CPU rather than crash.
            pipe.to("cpu")
        return
    if backend == "mps":
        # Apple Silicon shares unified memory between CPU and GPU, and
        # accelerate's sequential offload has spotty MPS support, so rely on
        # the slicing/tiling enabled above. On the tight memory tier, try
        # model offload, which keeps modules on CPU and moves each to MPS
        # only while it runs.
        if vram < 12:
            try:
                pipe.enable_model_cpu_offload(device="mps")
                return
            except TypeError:
                # Older diffusers versions don't accept a device argument.
                try:
                    pipe.enable_model_cpu_offload()
                    return
                except Exception:
                    pass
            except Exception:
                pass
        try:
            pipe.to(get_device())
        except Exception:
            pipe.to("cpu")
        return
    # CPU backend: nothing else to do.
    pipe.to("cpu")