"""Memory strategy selection.

CUDA and ROCm support accelerate's cpu_offload; DirectML/CPU do not. MPS only
attempts model-level CPU offload. Apply per-pipeline based on backend + VRAM
tier. All paths aim to reduce peak VRAM without breaking on non-CUDA devices.
"""

from __future__ import annotations

import logging

from .device import get_device, hardware_info

logger = logging.getLogger(__name__)


def _try_call(target: object, method_name: str) -> bool:
    """Call ``target.<method_name>()`` on a best-effort basis.

    Args:
        target: Object expected to expose the method (a pipeline or its VAE).
        method_name: Name of the zero-argument method to invoke.

    Returns:
        True if the call completed, False if looking up or calling it raised.
        Failures are logged at debug level instead of propagating because
        every knob applied through this helper is optional.
    """
    try:
        getattr(target, method_name)()
    except Exception:
        logger.debug(
            "%s.%s() failed; continuing without it",
            type(target).__name__,
            method_name,
            exc_info=True,
        )
        return False
    return True


def _move_to_device_or_cpu(pipe) -> None:
    """Move *pipe* to the active device, falling back to CPU if that fails.

    An exception raised by the final ``pipe.to("cpu")`` is not caught.
    """
    try:
        pipe.to(get_device())
    except Exception:
        # Some pipes have components that won't move cleanly; fall back to CPU.
        logger.warning(
            "Could not move pipeline to the active device; falling back to CPU",
            exc_info=True,
        )
        pipe.to("cpu")


def _enable_mps_model_offload(pipe) -> bool:
    """Try to enable model-level CPU offload targeting the MPS device.

    Returns:
        True if an offload call completed, False otherwise.
    """
    try:
        pipe.enable_model_cpu_offload(device="mps")
    except TypeError:
        # The pipeline rejected the ``device`` keyword; retry without it.
        # NOTE(review): confirm that the no-argument form does not default to
        # a CUDA device on diffusers versions lacking the ``device`` parameter.
        return _try_call(pipe, "enable_model_cpu_offload")
    except Exception:
        logger.debug("enable_model_cpu_offload(device='mps') failed", exc_info=True)
        return False
    return True


def apply_memory_strategy(pipe) -> None:
    """Apply VRAM-saving knobs that match the active backend.

    Reads the backend name and VRAM size from ``hardware_info()`` and then:

    * enables VAE slicing/tiling and attention slicing when the pipeline
      exposes them (best-effort, on every backend);
    * ``cuda`` / ``rocm``: below 10 GB, tries sequential CPU offload, then
      model CPU offload; otherwise (or if both fail) moves the pipeline to
      the active device;
    * ``directml``: moves the pipeline to the active device, or CPU on failure;
    * ``mps``: below 12 GB, tries model CPU offload; otherwise (or if that
      fails) moves the pipeline to the active device, or CPU on failure;
    * any other backend: moves the pipeline to CPU.

    Args:
        pipe: Pipeline object to configure in place. Presumably a diffusers
            pipeline; only ``to`` and the optional ``enable_*`` methods and
            ``vae`` attribute are used here.

    Raises:
        Exception: Whatever ``pipe.to(...)`` raises on the CUDA/ROCm path or
            on the final CPU move; those calls are not guarded.
    """
    info = hardware_info()
    backend = info["backend"]
    vram = info["vram_gb"]

    # Always-safe: VAE tiling/slicing work on any device. Cuts peak VRAM during decode.
    # Newer diffusers (>=0.32) prefers calling on the VAE directly.
    vae = getattr(pipe, "vae", None)
    if vae is not None:
        for knob in ("enable_slicing", "enable_tiling"):
            if hasattr(vae, knob):
                _try_call(vae, knob)

    if hasattr(pipe, "enable_attention_slicing"):
        _try_call(pipe, "enable_attention_slicing")

    if backend in ("cuda", "rocm"):
        # ROCm builds expose the cuda API, so accelerate offload hooks work the same way.
        # Offload only if VRAM tight.
        if vram < 10:
            if _try_call(pipe, "enable_sequential_cpu_offload"):
                return
            if _try_call(pipe, "enable_model_cpu_offload"):
                return
        pipe.to(get_device())
        return

    if backend == "directml":
        # DirectML lacks accelerate hook support. Move whole pipe to device.
        # Slicing already enabled above keeps peak in check.
        _move_to_device_or_cpu(pipe)
        return

    if backend == "mps":
        # Apple Silicon shares unified memory with CPU. accelerate's sequential offload
        # has spotty MPS support; rely on slicing/tiling already enabled above.
        # Tight memory tier: keep on CPU and let model_cpu_offload move chunks if available.
        if vram < 12 and _enable_mps_model_offload(pipe):
            return
        _move_to_device_or_cpu(pipe)
        return

    # CPU
    pipe.to("cpu")