"""GPU and VRAM detection. Returns vendor + tier used to pick torch wheel and default models.""" from __future__ import annotations import ctypes import platform import subprocess from dataclasses import dataclass from typing import Literal Vendor = Literal["nvidia", "amd", "intel", "cpu"] Backend = Literal["cuda", "directml", "cpu"] @dataclass class HardwareInfo: vendor: Vendor backend: Backend device_name: str vram_gb: float tier: Literal["cpu", "low", "mid", "high", "ultra"] def _run(cmd: list[str]) -> str: try: out = subprocess.run(cmd, capture_output=True, text=True, timeout=10, check=False) return out.stdout.strip() except (FileNotFoundError, subprocess.TimeoutExpired): return "" def _detect_nvidia() -> tuple[str, float] | None: out = _run(["nvidia-smi", "--query-gpu=name,memory.total", "--format=csv,noheader,nounits"]) if not out: return None first = out.splitlines()[0] name, mem = [p.strip() for p in first.split(",", 1)] return name, float(mem) / 1024.0 def _detect_dxgi() -> list[tuple[str, float, str]]: """Enumerate all DXGI adapters via PowerShell. Returns list of (name, vram_gb, vendor_hint).""" if platform.system() != "Windows": return [] ps = ( "Get-CimInstance Win32_VideoController | " "Select-Object Name, AdapterRAM | " "ConvertTo-Json -Compress" ) out = _run(["powershell", "-NoProfile", "-Command", ps]) if not out: return [] import json try: data = json.loads(out) except json.JSONDecodeError: return [] if isinstance(data, dict): data = [data] results: list[tuple[str, float, str]] = [] for entry in data: name = (entry.get("Name") or "").strip() ram = entry.get("AdapterRAM") or 0 # Win32_VideoController caps AdapterRAM at 4GB on many systems. Trust value but flag below. vram_gb = float(ram) / (1024 ** 3) if ram else 0.0 low = name.lower() if "nvidia" in low or "geforce" in low or "rtx" in low or "gtx" in low: vendor = "nvidia" elif "amd" in low or "radeon" in low or "rx " in low: vendor = "amd" elif "intel" in low or "arc" in low or "iris" in low: vendor = "intel" else: vendor = "unknown" if name: results.append((name, vram_gb, vendor)) return results def _vram_tier(vram_gb: float) -> Literal["cpu", "low", "mid", "high", "ultra"]: if vram_gb < 1: return "cpu" if vram_gb < 6: return "low" if vram_gb < 10: return "mid" if vram_gb < 14: return "high" return "ultra" def detect() -> HardwareInfo: nv = _detect_nvidia() if nv: name, vram = nv return HardwareInfo("nvidia", "cuda", name, vram, _vram_tier(vram)) adapters = _detect_dxgi() # Prefer discrete (highest VRAM) non-basic adapter adapters = [a for a in adapters if "basic" not in a[0].lower() and "microsoft" not in a[0].lower()] if adapters: adapters.sort(key=lambda a: a[1], reverse=True) name, vram, hint = adapters[0] # AdapterRAM is unreliable for >4GB cards. If exactly 4GB and modern AMD/Intel card name, bump. if vram <= 4.1 and any(k in name.lower() for k in ("rx 6", "rx 7", "arc a", "arc b")): vram = 8.0 # conservative guess if hint in ("amd", "intel"): return HardwareInfo(hint, "directml", name, vram, _vram_tier(vram)) if hint == "nvidia": # nvidia-smi missing but card is nvidia: drivers may be broken, fall through to directml return HardwareInfo("nvidia", "directml", name, vram, _vram_tier(vram)) return HardwareInfo("cpu", "cpu", platform.processor() or "CPU", 0.0, "cpu") def directml_supported() -> bool: """torch-directml ships wheels for Python 3.10 / 3.11. 
The launcher pins the venv to 3.11, so this is True in the running app.""" import sys return sys.version_info[:2] in {(3, 10), (3, 11)} def torch_install_args(info: HardwareInfo) -> list[str]: """Return uv-pip install args for the right PyTorch build. Launcher pins venv to Python 3.11, so torch-directml wheels are always available for AMD/Intel paths. Latest stable torch is used elsewhere. """ if info.backend == "cuda": return [ "torch", "torchvision", "--index-url", "https://download.pytorch.org/whl/cu124", ] if info.backend == "directml": # torch-directml currently pins to torch 2.4.x. Match it. return [ "torch>=2.4,<2.5", "torchvision>=0.19,<0.20", "torch-directml>=0.2.5", ] return [ "torch", "torchvision", "--index-url", "https://download.pytorch.org/whl/cpu", ] if __name__ == "__main__": info = detect() print(f"Vendor: {info.vendor}") print(f"Backend: {info.backend}") print(f"Device: {info.device_name}") print(f"VRAM: {info.vram_gb:.1f} GB") print(f"Tier: {info.tier}")
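    # Hedged usage sketch: how a launcher might consume torch_install_args().
    # The "uv pip install" command line is an assumption about the caller's
    # tooling; this module only builds the argument list and never installs.
    print("Install: uv pip install " + " ".join(torch_install_args(info)))
    if info.backend == "directml" and not directml_supported():
        # Shouldn't happen once the launcher pins Python 3.11, but flag it
        # when the demo runs under an unpinned interpreter.
        print("Warning: no torch-directml wheel for this Python version")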