"""GPU and VRAM detection. Returns vendor + tier used to pick torch wheel and default models.""" from __future__ import annotations <<<<<<< HEAD ======= <<<<<<< HEAD ======= import ctypes >>>>>>> refs/remotes/azuze/main >>>>>>> 965a3d97c6dae38fa25174559b1ea0f3050788f9 import platform import subprocess from dataclasses import dataclass from typing import Literal <<<<<<< HEAD Vendor = Literal["nvidia", "amd", "intel", "apple", "cpu"] Backend = Literal["cuda", "rocm", "directml", "mps", "cpu"] SUPPORTED_BACKENDS: tuple[str, ...] = ("auto", "cuda", "rocm", "directml", "mps", "cpu") ======= <<<<<<< HEAD Vendor = Literal["nvidia", "amd", "intel", "apple", "cpu"] Backend = Literal["cuda", "rocm", "directml", "mps", "cpu"] SUPPORTED_BACKENDS: tuple[str, ...] = ("auto", "cuda", "rocm", "directml", "mps", "cpu") ======= Vendor = Literal["nvidia", "amd", "intel", "cpu"] Backend = Literal["cuda", "directml", "cpu"] >>>>>>> refs/remotes/azuze/main >>>>>>> 965a3d97c6dae38fa25174559b1ea0f3050788f9 @dataclass class HardwareInfo: vendor: Vendor backend: Backend device_name: str vram_gb: float tier: Literal["cpu", "low", "mid", "high", "ultra"] <<<<<<< HEAD ======= <<<<<<< HEAD >>>>>>> 965a3d97c6dae38fa25174559b1ea0f3050788f9 def _detect_mac_gpu() -> tuple[str, float] | None: """Apple Silicon: report chip name + unified memory (proxy for VRAM). Intel Mac: returns None (no GPU acceleration path).""" if platform.system() != "Darwin": return None machine = platform.machine().lower() if machine not in ("arm64", "aarch64"): return None name = "Apple Silicon" brand = _run(["sysctl", "-n", "machdep.cpu.brand_string"]) if brand: name = brand.strip() mem_raw = _run(["sysctl", "-n", "hw.memsize"]) mem_gb = 0.0 if mem_raw: try: mem_gb = int(mem_raw) / (1024 ** 3) except ValueError: mem_gb = 0.0 # Unified memory: GPU can address ~75% in practice. Use that as VRAM proxy. vram_gb = mem_gb * 0.75 if mem_gb else 0.0 return name, vram_gb <<<<<<< HEAD ======= ======= >>>>>>> refs/remotes/azuze/main >>>>>>> 965a3d97c6dae38fa25174559b1ea0f3050788f9 def _run(cmd: list[str]) -> str: try: out = subprocess.run(cmd, capture_output=True, text=True, timeout=10, check=False) return out.stdout.strip() except (FileNotFoundError, subprocess.TimeoutExpired): return "" def _detect_nvidia() -> tuple[str, float] | None: out = _run(["nvidia-smi", "--query-gpu=name,memory.total", "--format=csv,noheader,nounits"]) if not out: return None first = out.splitlines()[0] name, mem = [p.strip() for p in first.split(",", 1)] return name, float(mem) / 1024.0 def _detect_dxgi() -> list[tuple[str, float, str]]: """Enumerate all DXGI adapters via PowerShell. Returns list of (name, vram_gb, vendor_hint).""" if platform.system() != "Windows": return [] ps = ( "Get-CimInstance Win32_VideoController | " "Select-Object Name, AdapterRAM | " "ConvertTo-Json -Compress" ) out = _run(["powershell", "-NoProfile", "-Command", ps]) if not out: return [] import json try: data = json.loads(out) except json.JSONDecodeError: return [] if isinstance(data, dict): data = [data] results: list[tuple[str, float, str]] = [] for entry in data: name = (entry.get("Name") or "").strip() ram = entry.get("AdapterRAM") or 0 # Win32_VideoController caps AdapterRAM at 4GB on many systems. Trust value but flag below. vram_gb = float(ram) / (1024 ** 3) if ram else 0.0 low = name.lower() if "nvidia" in low or "geforce" in low or "rtx" in low or "gtx" in low: vendor = "nvidia" elif "amd" in low or "radeon" in low or "rx " in low: vendor = "amd" elif "intel" in low or "arc" in low or "iris" in low: vendor = "intel" else: vendor = "unknown" if name: results.append((name, vram_gb, vendor)) return results <<<<<<< HEAD ======= <<<<<<< HEAD >>>>>>> 965a3d97c6dae38fa25174559b1ea0f3050788f9 def _detect_linux_gpus() -> list[tuple[str, float, str]]: """Enumerate Linux GPUs. Returns list of (name, vram_gb, vendor_hint).""" if platform.system() != "Linux": return [] results: list[tuple[str, float, str]] = [] # AMD via rocm-smi (most accurate VRAM) rocm = _run(["rocm-smi", "--showproductname", "--showmeminfo", "vram", "--json"]) if rocm: import json as _json try: data = _json.loads(rocm) for k, v in data.items(): if not isinstance(v, dict): continue name = v.get("Card series") or v.get("Card model") or v.get("Device Name") or k vram_bytes = 0 for key in ("VRAM Total Memory (B)", "vram total memory (b)", "VRAM Total"): if key in v: try: vram_bytes = int(v[key]) break except (TypeError, ValueError): pass vram_gb = float(vram_bytes) / (1024 ** 3) if vram_bytes else 0.0 results.append((str(name).strip(), vram_gb, "amd")) except (ValueError, KeyError): pass if results: return results # Fallback: lspci for vendor hints (no reliable VRAM number) lspci = _run(["bash", "-c", "lspci -nn | grep -Ei 'vga|3d|display'"]) for line in lspci.splitlines(): low = line.lower() if "nvidia" in low: vendor = "nvidia" elif "amd" in low or "advanced micro" in low or "ati " in low: vendor = "amd" elif "intel" in low: vendor = "intel" else: continue # Extract device name after the colon name = line.split(":", 2)[-1].strip() if ":" in line else line.strip() results.append((name, 0.0, vendor)) return results def _vendor_from_backend(backend: str) -> Vendor: return { "cuda": "nvidia", "rocm": "amd", "directml": "amd", "mps": "apple", "cpu": "cpu", }.get(backend, "cpu") # type: ignore[return-value] <<<<<<< HEAD ======= ======= >>>>>>> refs/remotes/azuze/main >>>>>>> 965a3d97c6dae38fa25174559b1ea0f3050788f9 def _vram_tier(vram_gb: float) -> Literal["cpu", "low", "mid", "high", "ultra"]: if vram_gb < 1: return "cpu" if vram_gb < 6: return "low" if vram_gb < 10: return "mid" if vram_gb < 14: return "high" return "ultra" <<<<<<< HEAD ======= <<<<<<< HEAD >>>>>>> 965a3d97c6dae38fa25174559b1ea0f3050788f9 def detect(force_backend: str | None = None, force_vendor: str | None = None) -> HardwareInfo: """Auto-detect hardware. If force_backend is set (cuda/rocm/directml/cpu), skip detection for that decision but still try to discover device name + VRAM for tier sizing.""" is_linux = platform.system() == "Linux" is_windows = platform.system() == "Windows" is_mac = platform.system() == "Darwin" # Validate forced backend against platform if force_backend == "directml" and not is_windows: raise RuntimeError( "DirectML backend is Windows-only. On Linux use --backend rocm (AMD), " "--backend cuda (NVIDIA), or --backend cpu. On macOS use --backend mps." ) if force_backend == "mps" and not is_mac: raise RuntimeError("MPS backend is macOS-only (Apple Silicon).") if force_backend == "rocm" and is_mac: raise RuntimeError("ROCm is not available on macOS. Use --backend mps or --backend cpu.") if force_backend == "cuda" and is_mac: raise RuntimeError("CUDA is not available on macOS. Use --backend mps or --backend cpu.") # Gather best-effort device info regardless of force. nv = _detect_nvidia() if not is_mac else None win_adapters = _detect_dxgi() if is_windows else [] win_adapters = [a for a in win_adapters if "basic" not in a[0].lower() and "microsoft" not in a[0].lower()] lin_gpus = _detect_linux_gpus() if is_linux else [] mac_gpu = _detect_mac_gpu() if is_mac else None def _best_for(vendor: str) -> tuple[str, float] | None: if vendor == "nvidia" and nv: return nv pool = [(n, v) for n, v, h in (win_adapters + lin_gpus) if h == vendor] if not pool: return None pool.sort(key=lambda x: x[1], reverse=True) return pool[0] if force_backend and force_backend != "auto": vendor: Vendor = force_vendor or _vendor_from_backend(force_backend) # type: ignore[assignment] if force_backend == "mps": match = mac_gpu else: match = _best_for(vendor) or (nv if force_backend == "cuda" else None) if match: name, vram = match else: # No matching device found, but user forced this backend — proceed with unknown VRAM. name = f"{vendor.upper()} (forced)" vram = 0.0 return HardwareInfo(vendor, force_backend, name, vram, _vram_tier(vram)) # type: ignore[arg-type] # Auto path if is_mac and mac_gpu: name, vram = mac_gpu return HardwareInfo("apple", "mps", name, vram, _vram_tier(vram)) <<<<<<< HEAD ======= ======= def detect() -> HardwareInfo: nv = _detect_nvidia() >>>>>>> refs/remotes/azuze/main >>>>>>> 965a3d97c6dae38fa25174559b1ea0f3050788f9 if nv: name, vram = nv return HardwareInfo("nvidia", "cuda", name, vram, _vram_tier(vram)) <<<<<<< HEAD ======= <<<<<<< HEAD >>>>>>> 965a3d97c6dae38fa25174559b1ea0f3050788f9 if is_windows and win_adapters: win_adapters.sort(key=lambda a: a[1], reverse=True) name, vram, hint = win_adapters[0] # AdapterRAM caps at 4GB for many cards. Bump if name suggests modern discrete. if vram <= 4.1 and any(k in name.lower() for k in ("rx 6", "rx 7", "arc a", "arc b")): vram = 8.0 if hint in ("amd", "intel"): return HardwareInfo(hint, "directml", name, vram, _vram_tier(vram)) if hint == "nvidia": return HardwareInfo("nvidia", "directml", name, vram, _vram_tier(vram)) if is_linux and lin_gpus: lin_gpus.sort(key=lambda a: a[1], reverse=True) name, vram, hint = lin_gpus[0] if hint == "amd": return HardwareInfo("amd", "rocm", name, vram, _vram_tier(vram)) if hint == "nvidia": # nvidia-smi failed but card is nvidia: drivers likely missing. Fall through to CPU. return HardwareInfo("cpu", "cpu", f"NVIDIA driver missing — {name}", 0.0, "cpu") if hint == "intel": # No good Intel-on-Linux torch path here; default to CPU. return HardwareInfo("cpu", "cpu", f"Intel GPU (no backend) — {name}", 0.0, "cpu") <<<<<<< HEAD ======= ======= adapters = _detect_dxgi() # Prefer discrete (highest VRAM) non-basic adapter adapters = [a for a in adapters if "basic" not in a[0].lower() and "microsoft" not in a[0].lower()] if adapters: adapters.sort(key=lambda a: a[1], reverse=True) name, vram, hint = adapters[0] # AdapterRAM is unreliable for >4GB cards. If exactly 4GB and modern AMD/Intel card name, bump. if vram <= 4.1 and any(k in name.lower() for k in ("rx 6", "rx 7", "arc a", "arc b")): vram = 8.0 # conservative guess if hint in ("amd", "intel"): return HardwareInfo(hint, "directml", name, vram, _vram_tier(vram)) if hint == "nvidia": # nvidia-smi missing but card is nvidia: drivers may be broken, fall through to directml return HardwareInfo("nvidia", "directml", name, vram, _vram_tier(vram)) >>>>>>> refs/remotes/azuze/main >>>>>>> 965a3d97c6dae38fa25174559b1ea0f3050788f9 return HardwareInfo("cpu", "cpu", platform.processor() or "CPU", 0.0, "cpu") def directml_supported() -> bool: """torch-directml ships wheels for Python 3.10 / 3.11. The launcher pins the venv to 3.11, so this is True in the running app.""" import sys return sys.version_info[:2] in {(3, 10), (3, 11)} def torch_install_args(info: HardwareInfo) -> list[str]: """Return uv-pip install args for the right PyTorch build. Launcher pins venv to Python 3.11, so torch-directml wheels are always available for AMD/Intel paths. Latest stable torch is used elsewhere. """ if info.backend == "cuda": return [ "torch", "torchvision", "--index-url", "https://download.pytorch.org/whl/cu124", ] <<<<<<< HEAD ======= <<<<<<< HEAD >>>>>>> 965a3d97c6dae38fa25174559b1ea0f3050788f9 if info.backend == "rocm": # ROCm wheels are Linux-only. Index pinned to a stable ROCm release line. return [ "torch", "torchvision", "--index-url", "https://download.pytorch.org/whl/rocm6.1", ] <<<<<<< HEAD ======= ======= >>>>>>> refs/remotes/azuze/main >>>>>>> 965a3d97c6dae38fa25174559b1ea0f3050788f9 if info.backend == "directml": # torch-directml currently pins to torch 2.4.x. Match it. return [ "torch>=2.4,<2.5", "torchvision>=0.19,<0.20", "torch-directml>=0.2.5", ] <<<<<<< HEAD ======= <<<<<<< HEAD >>>>>>> 965a3d97c6dae38fa25174559b1ea0f3050788f9 if info.backend == "mps": # Default PyPI torch wheel ships MPS support on macOS arm64. No custom index. return ["torch", "torchvision"] # CPU. macOS uses default PyPI wheels (no /whl/cpu index for darwin). if platform.system() == "Darwin": return ["torch", "torchvision"] <<<<<<< HEAD ======= ======= >>>>>>> refs/remotes/azuze/main >>>>>>> 965a3d97c6dae38fa25174559b1ea0f3050788f9 return [ "torch", "torchvision", "--index-url", "https://download.pytorch.org/whl/cpu", ] if __name__ == "__main__": info = detect() print(f"Vendor: {info.vendor}") print(f"Backend: {info.backend}") print(f"Device: {info.device_name}") print(f"VRAM: {info.vram_gb:.1f} GB") print(f"Tier: {info.tier}")