"""Torch device selection. Wraps cuda / directml / cpu behind one helper."""
from __future__ import annotations

import json
from functools import lru_cache
from pathlib import Path

ROOT = Path(__file__).parent.parent
HARDWARE_CACHE = ROOT / "config.local.json"


@lru_cache(maxsize=1)
def hardware_info() -> dict:
    """Return the hardware description used to pick a torch backend.

    If ``HARDWARE_CACHE`` exists, its JSON content is returned as-is.
    Otherwise the sibling ``hardware`` module is imported and
    ``hardware.detect()`` is called; the result is flattened into a dict
    with the keys ``vendor``, ``backend``, ``device_name``, ``vram_gb``
    and ``tier``. The detected result is not written back to
    ``HARDWARE_CACHE`` by this function.

    The result is memoized for the lifetime of the process, so every
    caller receives the same dict object; treat it as read-only.

    Raises:
        json.JSONDecodeError: if the cache file is not valid JSON.
    """
    # EAFP: read directly instead of exists()+read, which can race with
    # the file being removed between the two calls.
    try:
        cached_text = HARDWARE_CACHE.read_text(encoding="utf-8")
    except FileNotFoundError:
        cached_text = None
    if cached_text is not None:
        return json.loads(cached_text)

    # Imported lazily so that a present cache file avoids running the
    # detection module at all.
    from . import hardware

    info = hardware.detect()
    return {
        "vendor": info.vendor,
        "backend": info.backend,
        "device_name": info.device_name,
        "vram_gb": info.vram_gb,
        "tier": info.tier,
    }


@lru_cache(maxsize=1)
def get_device():
    """Return the torch device matching the configured backend.

    Resolution order:

    * ``"cuda"`` backend and ``torch.cuda.is_available()`` ->
      ``torch.device("cuda")``.
    * ``"directml"`` backend and ``torch_directml`` importable ->
      ``torch_directml.device()``.
    * anything else, or a failed attempt above -> ``torch.device("cpu")``.

    The result is memoized for the lifetime of the process.

    Raises:
        KeyError: if the hardware info has no ``"backend"`` entry.
    """
    import torch

    backend = hardware_info()["backend"]

    if backend == "cuda" and torch.cuda.is_available():
        return torch.device("cuda")

    if backend == "directml":
        try:
            import torch_directml
        except ImportError:
            # Best effort: DirectML was configured but the package is not
            # installed, so fall through to the CPU device.
            pass
        else:
            return torch_directml.device()

    return torch.device("cpu")


def torch_dtype():
    """Return the floating-point dtype to use on the selected device.

    ``torch.float32`` when ``get_device()`` resolved to the CPU,
    ``torch.float16`` for any other device.

    The decision is based on the device actually returned by
    ``get_device()`` rather than on the configured backend name, so that
    a CPU fallback (CUDA unavailable, ``torch_directml`` missing) is not
    paired with half precision.
    """
    import torch

    if get_device().type == "cpu":
        return torch.float32
    return torch.float16