"""Torch device selection. Wraps cuda / directml / cpu behind one helper.""" from __future__ import annotations import json from functools import lru_cache from pathlib import Path ROOT = Path(__file__).parent.parent HARDWARE_CACHE = ROOT / "config.local.json" @lru_cache(maxsize=1) def hardware_info() -> dict: if HARDWARE_CACHE.exists(): return json.loads(HARDWARE_CACHE.read_text()) from . import hardware info = hardware.detect() return { "vendor": info.vendor, "backend": info.backend, "device_name": info.device_name, "vram_gb": info.vram_gb, "tier": info.tier, } @lru_cache(maxsize=1) def get_device(): import torch backend = hardware_info()["backend"] <<<<<<< HEAD # ROCm builds of torch expose the cuda namespace. if backend in ("cuda", "rocm") and torch.cuda.is_available(): ======= <<<<<<< HEAD # ROCm builds of torch expose the cuda namespace. if backend in ("cuda", "rocm") and torch.cuda.is_available(): ======= if backend == "cuda" and torch.cuda.is_available(): >>>>>>> refs/remotes/azuze/main >>>>>>> 965a3d97c6dae38fa25174559b1ea0f3050788f9 return torch.device("cuda") if backend == "directml": try: import torch_directml return torch_directml.device() except ImportError: pass <<<<<<< HEAD if backend == "mps" and getattr(torch.backends, "mps", None) and torch.backends.mps.is_available(): return torch.device("mps") ======= <<<<<<< HEAD if backend == "mps" and getattr(torch.backends, "mps", None) and torch.backends.mps.is_available(): return torch.device("mps") ======= >>>>>>> refs/remotes/azuze/main >>>>>>> 965a3d97c6dae38fa25174559b1ea0f3050788f9 return torch.device("cpu") def torch_dtype(): import torch backend = hardware_info()["backend"] if backend == "cpu": return torch.float32 <<<<<<< HEAD # MPS supports fp16 for diffusers; bf16 has gaps. Stick with fp16. ======= <<<<<<< HEAD # MPS supports fp16 for diffusers; bf16 has gaps. Stick with fp16. ======= >>>>>>> refs/remotes/azuze/main >>>>>>> 965a3d97c6dae38fa25174559b1ea0f3050788f9 return torch.float16