Files
KawAI/backends/hardware.py
T
2026-05-04 09:38:56 +02:00

157 lines
5.0 KiB
Python

"""GPU and VRAM detection. Returns vendor + tier used to pick torch wheel and default models."""
from __future__ import annotations
import ctypes
import platform
import subprocess
from dataclasses import dataclass
from typing import Literal
# GPU vendor reported by detection; "cpu" is used when no usable GPU was found.
Vendor = Literal["nvidia", "amd", "intel", "cpu"]
# Compute backend selected for the detected device.
Backend = Literal["cuda", "directml", "cpu"]
@dataclass
class HardwareInfo:
    """Outcome of hardware detection: the chosen device and how capable it is."""

    # GPU vendor, or "cpu" when detection fell back to the CPU.
    vendor: Vendor
    # Compute backend to use with this device.
    backend: Backend
    # Human-readable adapter name (processor name for the CPU fallback).
    device_name: str
    # Video memory in GiB; 0.0 for the CPU fallback.
    vram_gb: float
    # Coarse capability bucket derived from the amount of VRAM.
    tier: Literal["cpu", "low", "mid", "high", "ultra"]
def _run(cmd: list[str]) -> str:
    """Run *cmd* and return its stripped stdout, or "" if it could not be run.

    This is a best-effort probe helper: a missing executable, any other
    OS-level launch failure, or a timeout yields an empty string instead of
    raising. The exit status is not checked (``check=False``), so whatever a
    failing command printed to stdout is still returned.

    Args:
        cmd: Program and arguments, passed to ``subprocess.run`` as a list
            (no shell is involved).

    Returns:
        The command's stdout with surrounding whitespace removed, or "".
    """
    try:
        completed = subprocess.run(
            cmd,
            capture_output=True,
            text=True,
            # Tool output may contain bytes that are invalid in the locale
            # encoding; substitute them instead of raising UnicodeDecodeError.
            errors="replace",
            timeout=10,
            check=False,
        )
    except (OSError, subprocess.TimeoutExpired):
        # OSError covers FileNotFoundError as well as PermissionError and
        # other spawn failures (e.g. a non-executable file found on PATH).
        return ""
    return completed.stdout.strip()
def _detect_nvidia() -> tuple[str, float] | None:
    """Query ``nvidia-smi`` for the first GPU's name and total memory.

    Returns:
        ``(name, vram_gb)`` for the first GPU listed, or ``None`` when
        ``nvidia-smi`` produced no output or its first line is not a
        ``name, memory`` CSV row with a numeric memory value.
    """
    out = _run(["nvidia-smi", "--query-gpu=name,memory.total", "--format=csv,noheader,nounits"])
    if not out:
        return None
    first = out.splitlines()[0]
    # Memory is the last CSV field, so split from the right; a comma inside
    # the device name then cannot shift the columns.
    name, sep, mem = first.rpartition(",")
    if not sep:
        # No comma at all: this is not a CSV row (for example an error
        # message printed by the tool), so report "no NVIDIA GPU".
        return None
    try:
        vram_mib = float(mem)
    except ValueError:
        # Non-numeric memory field (e.g. a "not available" placeholder).
        return None
    # With "nounits" the value is a bare number; it is divided by 1024 to
    # express it in GiB (assumes nvidia-smi reports MiB here).
    return name.strip(), vram_mib / 1024.0
def _vendor_hint(name: str) -> str:
    """Guess a GPU vendor ("nvidia", "amd", "intel" or "unknown") from an adapter name."""
    import re

    low = name.lower()
    if any(key in low for key in ("nvidia", "geforce", "rtx", "gtx")):
        return "nvidia"
    if any(key in low for key in ("amd", "radeon", "rx ")):
        return "amd"
    # "arc" is matched as a whole word: as a bare substring it also hits
    # unrelated names such as "Parsec Virtual Display Adapter".
    if "intel" in low or "iris" in low or re.search(r"\barc\b", low):
        return "intel"
    return "unknown"


def _detect_dxgi() -> list[tuple[str, float, str]]:
    """Enumerate Windows display adapters via PowerShell.

    Despite the function name, the adapters are read from WMI
    (``Win32_VideoController``), not from DXGI.

    Returns:
        A list of ``(name, vram_gb, vendor_hint)`` tuples, one per adapter
        that has a name. ``vendor_hint`` is "nvidia", "amd", "intel" or
        "unknown". The list is empty on non-Windows systems, when PowerShell
        yields no output, or when the output is not the expected JSON.
    """
    if platform.system() != "Windows":
        return []
    import json

    ps = (
        "Get-CimInstance Win32_VideoController | "
        "Select-Object Name, AdapterRAM | "
        "ConvertTo-Json -Compress"
    )
    out = _run(["powershell", "-NoProfile", "-Command", ps])
    if not out:
        return []
    try:
        data = json.loads(out)
    except json.JSONDecodeError:
        return []
    # A single adapter is serialized as one object rather than a list.
    if isinstance(data, dict):
        data = [data]
    if not isinstance(data, list):
        return []

    results: list[tuple[str, float, str]] = []
    for entry in data:
        if not isinstance(entry, dict):
            # Skip anything that is not an adapter record (e.g. null).
            continue
        name = str(entry.get("Name") or "").strip()
        if not name:
            continue
        ram = entry.get("AdapterRAM") or 0
        # Win32_VideoController caps AdapterRAM at 4GB on many systems, so
        # the value is only a lower bound for larger cards. Callers
        # compensate for that; here it is taken as reported.
        try:
            vram_gb = max(float(ram), 0.0) / (1024 ** 3)
        except (TypeError, ValueError):
            vram_gb = 0.0
        results.append((name, vram_gb, _vendor_hint(name)))
    return results
def _vram_tier(vram_gb: float) -> Literal["cpu", "low", "mid", "high", "ultra"]:
    """Bucket an amount of VRAM (in GiB) into a capability tier.

    Tiers and their exclusive upper bounds: "cpu" below 1, "low" below 6,
    "mid" below 10, "high" below 14; everything else is "ultra".
    """
    # Ascending exclusive ceilings; the first ceiling above the value wins.
    ceilings = ((1, "cpu"), (6, "low"), (10, "mid"), (14, "high"))
    for ceiling, label in ceilings:
        if vram_gb < ceiling:
            return label
    return "ultra"
def detect() -> HardwareInfo:
    """Detect the compute device to use and classify it.

    Detection order:
      1. An NVIDIA GPU visible to ``nvidia-smi`` -> CUDA backend.
      2. Otherwise the Windows adapter with the most VRAM whose vendor was
         recognized -> DirectML backend.
      3. Otherwise the CPU.

    Returns:
        A ``HardwareInfo`` describing vendor, backend, device name, VRAM in
        GiB and the VRAM-derived tier.
    """
    nv = _detect_nvidia()
    if nv:
        name, vram = nv
        return HardwareInfo("nvidia", "cuda", name, vram, _vram_tier(vram))

    # Keep only adapters with a recognized vendor, and drop the software
    # "basic"/"microsoft" adapters. Filtering out unknown vendors up front
    # means such an adapter can no longer win on VRAM and force the CPU
    # fallback while a usable GPU is present.
    candidates = [
        a
        for a in _detect_dxgi()
        if a[2] in ("nvidia", "amd", "intel")
        and "basic" not in a[0].lower()
        and "microsoft" not in a[0].lower()
    ]
    if candidates:
        # Prefer the discrete card, approximated as "highest VRAM"; on a tie
        # the first adapter listed wins.
        name, vram, hint = max(candidates, key=lambda a: a[1])
        # Strip trademark markers so names like "Intel(R) Arc(TM) A770
        # Graphics" still match the "arc a" pattern below.
        model = " ".join(name.lower().replace("(tm)", " ").replace("(r)", " ").split())
        # The reported VRAM is unreliable for >4GB cards. If it is at most
        # ~4GB and the name looks like a modern AMD/Intel card, bump it.
        if vram <= 4.1 and any(k in model for k in ("rx 6", "rx 7", "arc a", "arc b")):
            vram = 8.0  # heuristic guess, not a measured value
        # A "nvidia" hint here means nvidia-smi found nothing (drivers may be
        # missing or broken), so NVIDIA cards also go through DirectML.
        return HardwareInfo(hint, "directml", name, vram, _vram_tier(vram))

    return HardwareInfo("cpu", "cpu", platform.processor() or "CPU", 0.0, "cpu")
def directml_supported() -> bool:
    """Return True when the running interpreter is Python 3.10 or 3.11.

    Per the original author's note, torch-directml ships wheels for those
    two versions only, and the launcher pins the venv to 3.11, so this is
    expected to be True inside the running app.
    """
    import sys

    major, minor = sys.version_info[:2]
    return major == 3 and minor in (10, 11)
def torch_install_args(info: HardwareInfo) -> list[str]:
    """Build the uv-pip install arguments for the PyTorch build matching *info*.

    Only ``info.backend`` is consulted:
      * "cuda"     -> torch/torchvision from the cu124 wheel index.
      * "directml" -> torch/torchvision version ranges plus torch-directml.
      * otherwise  -> torch/torchvision from the CPU wheel index.

    The launcher pins the venv to Python 3.11 (per the original author's
    note), so torch-directml wheels are expected to be available on the
    AMD/Intel path.
    """
    backend = info.backend
    if backend == "directml":
        # torch-directml currently pins to torch 2.4.x. Match it.
        return ["torch>=2.4,<2.5", "torchvision>=0.19,<0.20", "torch-directml>=0.2.5"]

    # CUDA and CPU installs differ only in which wheel index they pull from.
    index_url = (
        "https://download.pytorch.org/whl/cu124"
        if backend == "cuda"
        else "https://download.pytorch.org/whl/cpu"
    )
    return ["torch", "torchvision", "--index-url", index_url]
if __name__ == "__main__":
    # Manual smoke test: print what detection reports on this machine.
    info = detect()
    report = (
        f"Vendor: {info.vendor}",
        f"Backend: {info.backend}",
        f"Device: {info.device_name}",
        f"VRAM: {info.vram_gb:.1f} GB",
        f"Tier: {info.tier}",
    )
    for line in report:
        print(line)