"""Memory strategy selection. CUDA supports cpu_offload; DirectML/CPU do not.
|
|
|
|
Apply per-pipeline based on backend + VRAM tier. All paths reduce peak VRAM
|
|
without breaking on non-CUDA devices.
|
|
"""
|
|
from __future__ import annotations
|
|
|
|
from .device import get_device, hardware_info
|
|
|
|
|
|
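
# Assumed contract for .device (not shown in this file): hardware_info()
# returns a dict with at least {"backend": <"cuda" | "rocm" | "directml" |
# "mps" | "cpu">, "vram_gb": <float>}; get_device() returns the matching
# device handle that pipe.to() accepts.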


def apply_memory_strategy(pipe) -> None:
    """Apply VRAM-saving knobs that match the active backend."""
    info = hardware_info()
    backend = info["backend"]
    vram = info["vram_gb"]
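
    # Every knob below is wrapped in try/except: these helpers come and go
    # across diffusers releases, and failing to apply an optimization should
    # never abort pipeline setup.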

    # Always-safe: VAE tiling/slicing work on any device and cut peak VRAM
    # during decode. Newer diffusers (>=0.32) prefers calling these on the
    # VAE directly.
    vae = getattr(pipe, "vae", None)
    if vae is not None:
        for fn in ("enable_slicing", "enable_tiling"):
            if hasattr(vae, fn):
                try:
                    getattr(vae, fn)()
                except Exception:
                    pass
    if hasattr(pipe, "enable_attention_slicing"):
        try:
            pipe.enable_attention_slicing()
        except Exception:
            pass
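
    # For reference: enable_sequential_cpu_offload() streams submodules
    # through the device one at a time (lowest peak VRAM, slowest), while
    # enable_model_cpu_offload() keeps one whole component resident at a
    # time (faster, moderate savings).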
    if backend in ("cuda", "rocm"):
        # ROCm builds expose the cuda API, so accelerate offload hooks work
        # the same way. Offload only if VRAM is tight.
        if vram < 10:
            try:
                pipe.enable_sequential_cpu_offload()
                return
            except Exception:
                pass
            try:
                pipe.enable_model_cpu_offload()
                return
            except Exception:
                pass
        pipe.to(get_device())
        return
if backend == "directml":
|
|
# DirectML lacks accelerate hook support. Move whole pipe to device.
|
|
# Slicing already enabled above keeps peak in check.
|
|
try:
|
|
pipe.to(get_device())
|
|
except Exception:
|
|
# Some pipes have components that won't move cleanly; fall back to CPU.
|
|
pipe.to("cpu")
|
|
return
|
|
|
|
if backend == "mps":
|
|
# Apple Silicon shares unified memory with CPU. accelerate's sequential offload
|
|
# has spotty MPS support; rely on slicing/tiling already enabled above.
|
|
# Tight memory tier: keep on CPU and let model_cpu_offload move chunks if available.
|
|
if vram < 12:
|
|
try:
|
|
pipe.enable_model_cpu_offload(device="mps")
|
|
return
|
|
except TypeError:
|
|
try:
|
|
pipe.enable_model_cpu_offload()
|
|
return
|
|
except Exception:
|
|
pass
|
|
except Exception:
|
|
pass
|
|
try:
|
|
pipe.to(get_device())
|
|
except Exception:
|
|
pipe.to("cpu")
|
|
return
|
|
|
|

    # CPU backend: everything already lives in system RAM.
    pipe.to("cpu")
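

# Usage sketch (assumes diffusers is installed; the model id is illustrative):
#
#     from diffusers import StableDiffusionPipeline
#     pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
#     apply_memory_strategy(pipe)
#     image = pipe("a watercolor fox").images[0]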