"""GPU and VRAM detection. Returns vendor + tier used to pick torch wheel and default models."""

from __future__ import annotations

import json
import platform
import subprocess
import sys
from dataclasses import dataclass
from typing import Literal

Vendor = Literal["nvidia", "amd", "intel", "cpu"]
Backend = Literal["cuda", "directml", "cpu"]
Tier = Literal["cpu", "low", "mid", "high", "ultra"]


@dataclass
class HardwareInfo:
    vendor: Vendor
    backend: Backend
    device_name: str
    vram_gb: float
    tier: Tier
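

# For example, a machine with an RTX 3080 (10 GB) would typically resolve, under
# the tier thresholds defined below, to:
#   HardwareInfo(vendor="nvidia", backend="cuda",
#                device_name="NVIDIA GeForce RTX 3080", vram_gb=10.0, tier="high")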


def _run(cmd: list[str]) -> str:
    """Run a command and return stripped stdout, or "" if the tool is missing or hangs."""
    try:
        out = subprocess.run(cmd, capture_output=True, text=True, timeout=10, check=False)
        return out.stdout.strip()
    except (FileNotFoundError, subprocess.TimeoutExpired):
        return ""


def _detect_nvidia() -> tuple[str, float] | None:
    out = _run(["nvidia-smi", "--query-gpu=name,memory.total", "--format=csv,noheader,nounits"])
    if not out:
        return None
    first = out.splitlines()[0]
    parts = [p.strip() for p in first.split(",", 1)]
    if len(parts) != 2:
        return None
    name, mem = parts
    try:
        return name, float(mem) / 1024.0  # memory.total is reported in MiB
    except ValueError:  # an unhealthy driver can print "[N/A]" instead of a number
        return None
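

# Illustrative only: with a healthy driver, the query in _detect_nvidia prints one
# CSV line per GPU, e.g. "NVIDIA GeForce RTX 3080, 10240" (name, total MiB).
# Only the first line is parsed, so secondary GPUs are ignored.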


def _detect_dxgi() -> list[tuple[str, float, str]]:
    """Enumerate video adapters via PowerShell CIM (Win32_VideoController).

    Returns a list of (name, vram_gb, vendor_hint) tuples.
    """
    if platform.system() != "Windows":
        return []
    ps = (
        "Get-CimInstance Win32_VideoController | "
        "Select-Object Name, AdapterRAM | "
        "ConvertTo-Json -Compress"
    )
    out = _run(["powershell", "-NoProfile", "-Command", ps])
    if not out:
        return []
    try:
        data = json.loads(out)
    except json.JSONDecodeError:
        return []
    if isinstance(data, dict):  # single adapter: ConvertTo-Json emits a bare object
        data = [data]
    results: list[tuple[str, float, str]] = []
    for entry in data:
        name = (entry.get("Name") or "").strip()
        ram = entry.get("AdapterRAM") or 0
        # AdapterRAM is a 32-bit field, so Win32_VideoController caps it at 4 GB
        # on many systems. Trust the value here; detect() compensates below.
        vram_gb = float(ram) / (1024 ** 3) if ram else 0.0
        low = name.lower()
        if "nvidia" in low or "geforce" in low or "rtx" in low or "gtx" in low:
            vendor = "nvidia"
        elif "amd" in low or "radeon" in low or "rx " in low:
            vendor = "amd"
        elif "intel" in low or "arc" in low or "iris" in low:
            vendor = "intel"
        else:
            vendor = "unknown"
        if name:
            results.append((name, vram_gb, vendor))
    return results
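

# Example shape of the PowerShell output _detect_dxgi parses (values illustrative):
#   [{"Name":"NVIDIA GeForce RTX 3080","AdapterRAM":4293918720},
#    {"Name":"Microsoft Basic Display Adapter","AdapterRAM":0}]
# With a single adapter, ConvertTo-Json emits a bare object rather than a one-item
# array, which is why the dict case is normalized to a list above.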


def _vram_tier(vram_gb: float) -> Tier:
    if vram_gb < 1:
        return "cpu"
    if vram_gb < 6:
        return "low"
    if vram_gb < 10:
        return "mid"
    if vram_gb < 14:
        return "high"
    return "ultra"
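

# Worked examples for these thresholds: a 4 GB GTX 1650 lands in "low", an 8 GB
# RTX 3060 Ti in "mid", a 12 GB RTX 3060 in "high", and a 24 GB RTX 4090 in "ultra".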


def detect() -> HardwareInfo:
    nv = _detect_nvidia()
    if nv:
        name, vram = nv
        return HardwareInfo("nvidia", "cuda", name, vram, _vram_tier(vram))

    # Prefer the discrete adapter (highest VRAM), skipping software fallbacks.
    adapters = [
        a for a in _detect_dxgi()
        if "basic" not in a[0].lower() and "microsoft" not in a[0].lower()
    ]
    if adapters:
        adapters.sort(key=lambda a: a[1], reverse=True)
        name, vram, hint = adapters[0]
        # AdapterRAM is unreliable above 4 GB. If it reads as ~4 GB on a card
        # family known to ship with more, bump to a conservative guess.
        if vram <= 4.1 and any(k in name.lower() for k in ("rx 6", "rx 7", "arc a", "arc b")):
            vram = 8.0
        if hint in ("amd", "intel"):
            return HardwareInfo(hint, "directml", name, vram, _vram_tier(vram))
        if hint == "nvidia":
            # nvidia-smi is missing but the card is NVIDIA: the driver may be
            # broken, so fall back to DirectML rather than CUDA.
            return HardwareInfo("nvidia", "directml", name, vram, _vram_tier(vram))

    return HardwareInfo("cpu", "cpu", platform.processor() or "CPU", 0.0, "cpu")


def directml_supported() -> bool:
    """torch-directml ships wheels for Python 3.10 and 3.11 only. The launcher
    pins the venv to 3.11, so this is True inside the running app."""
    return sys.version_info[:2] in {(3, 10), (3, 11)}


def torch_install_args(info: HardwareInfo) -> list[str]:
    """Return uv-pip install args for the right PyTorch build.

    The launcher pins the venv to Python 3.11, so torch-directml wheels are
    always available on the AMD/Intel paths. Other paths use the latest stable torch.
    """
    if info.backend == "cuda":
        return [
            "torch",
            "torchvision",
            "--index-url",
            "https://download.pytorch.org/whl/cu124",
        ]
    if info.backend == "directml":
        # torch-directml currently pins to torch 2.4.x, so match it here.
        return [
            "torch>=2.4,<2.5",
            "torchvision>=0.19,<0.20",
            "torch-directml>=0.2.5",
        ]
    return [
        "torch",
        "torchvision",
        "--index-url",
        "https://download.pytorch.org/whl/cpu",
    ]
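

# Sketch of how a launcher might consume this; the exact "uv" invocation is an
# assumption about the surrounding tooling, not something this module defines:
#   info = detect()
#   subprocess.run(["uv", "pip", "install", *torch_install_args(info)], check=True)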


if __name__ == "__main__":
    info = detect()
    print(f"Vendor:  {info.vendor}")
    print(f"Backend: {info.backend}")
    print(f"Device:  {info.device_name}")
    print(f"VRAM:    {info.vram_gb:.1f} GB")
    print(f"Tier:    {info.tier}")