"""Memory strategy selection. CUDA supports cpu_offload; DirectML/CPU do not.
|
||||
|
||||
Apply per-pipeline based on backend + VRAM tier. All paths reduce peak VRAM
|
||||
without breaking on non-CUDA devices.
|
||||
"""

from __future__ import annotations

from .device import get_device, hardware_info


def apply_memory_strategy(pipe) -> None:
    """Apply VRAM-saving knobs that match the active backend.

    Always enables VAE tiling/slicing and attention slicing (safe on every
    device), then chooses an offload/placement strategy from the backend
    reported by ``hardware_info()`` and its VRAM tier.

    Args:
        pipe: A diffusers-style pipeline. Access is duck-typed, so any
            object exposing the relevant ``enable_*``/``to`` methods works.
    """
    info = hardware_info()
    backend = info["backend"]
    vram = info["vram_gb"]

    _enable_safe_slicing(pipe)

    if backend in ("cuda", "rocm"):
        _place_cuda(pipe, vram)
    elif backend == "directml":
        _place_directml(pipe)
    elif backend == "mps":
        _place_mps(pipe, vram)
    else:
        # Pure-CPU fallback: nothing to offload from.
        pipe.to("cpu")


def _enable_safe_slicing(pipe) -> None:
    """Enable VAE tiling/slicing + attention slicing; safe on any device.

    Cuts peak VRAM during decode. Best-effort: missing or failing knobs are
    skipped rather than aborting the whole strategy.
    """
    # Newer diffusers (>=0.32) prefers calling on the VAE directly.
    vae = getattr(pipe, "vae", None)
    if vae is not None:
        for fn in ("enable_slicing", "enable_tiling"):
            if hasattr(vae, fn):
                try:
                    getattr(vae, fn)()
                except Exception:
                    pass  # older VAEs may lack/err on these knobs
    if hasattr(pipe, "enable_attention_slicing"):
        try:
            pipe.enable_attention_slicing()
        except Exception:
            pass


def _place_cuda(pipe, vram) -> None:
    """CUDA/ROCm placement by VRAM tier.

    ROCm builds expose the cuda API, so accelerate offload hooks work the
    same way. Offload only when VRAM is actually tight: the previous
    behavior applied model offload unconditionally, slowing high-VRAM cards
    for no memory benefit.
    """
    if vram < 10:
        # Tight tier: most aggressive offload first.
        try:
            pipe.enable_sequential_cpu_offload()
            return
        except Exception:
            pass
        try:
            pipe.enable_model_cpu_offload()
            return
        except Exception:
            pass
    elif vram < 16:
        # Mid tier: model-level offload trades a little speed for headroom.
        try:
            pipe.enable_model_cpu_offload()
            return
        except Exception:
            pass
    # Ample VRAM, or offload hooks unavailable: keep whole pipe on device.
    pipe.to(get_device())


def _place_directml(pipe) -> None:
    """DirectML placement: no accelerate hook support, move the whole pipe.

    Slicing enabled earlier keeps the peak in check.
    """
    try:
        pipe.to(get_device())
    except Exception:
        # Some pipes have components that won't move cleanly; fall back to CPU.
        pipe.to("cpu")


def _place_mps(pipe, vram) -> None:
    """Apple Silicon placement.

    Unified memory is shared with the CPU and accelerate's sequential
    offload has spotty MPS support, so only model-level offload is
    attempted, and only on the tight memory tier.
    """
    if vram < 12:
        try:
            pipe.enable_model_cpu_offload(device="mps")
            return
        except TypeError:
            # Older diffusers without the ``device=`` kwarg.
            try:
                pipe.enable_model_cpu_offload()
                return
            except Exception:
                pass
        except Exception:
            pass
    try:
        pipe.to(get_device())
    except Exception:
        pipe.to("cpu")