# KawAI/backends/hardware.py
"""GPU and VRAM detection. Returns vendor + tier used to pick torch wheel and default models."""
from __future__ import annotations

import ctypes  # NOTE(review): appears unused in this module; kept from the azuze/main merge
import platform
import subprocess
from dataclasses import dataclass
from typing import Literal
# Closed value sets carried by HardwareInfo. "apple"/"mps" are the Apple Silicon
# path, "rocm" is the Linux AMD path and "directml" the Windows any-vendor path.
Vendor = Literal["nvidia", "amd", "intel", "apple", "cpu"]
Backend = Literal["cuda", "rocm", "directml", "mps", "cpu"]
# Values accepted for a forced backend; "auto" means "run detection instead".
SUPPORTED_BACKENDS: tuple[str, ...] = ("auto", "cuda", "rocm", "directml", "mps", "cpu")
@dataclass
class HardwareInfo:
    """Outcome of hardware detection.

    ``backend`` drives which PyTorch build ``torch_install_args`` selects and
    ``tier`` buckets the available VRAM for picking default models.
    """

    vendor: Vendor  # GPU maker, or "cpu" when no usable GPU was found
    backend: Backend  # torch backend to install/run against
    device_name: str  # human-readable adapter or chip name
    vram_gb: float  # VRAM in GiB; 0.0 when it could not be determined
    tier: Literal["cpu", "low", "mid", "high", "ultra"]  # bucket derived from vram_gb
<<<<<<< HEAD
def _detect_mac_gpu() -> tuple[str, float] | None:
"""Apple Silicon: report chip name + unified memory (proxy for VRAM).
Intel Mac: returns None (no GPU acceleration path)."""
if platform.system() != "Darwin":
return None
machine = platform.machine().lower()
if machine not in ("arm64", "aarch64"):
return None
name = "Apple Silicon"
brand = _run(["sysctl", "-n", "machdep.cpu.brand_string"])
if brand:
name = brand.strip()
mem_raw = _run(["sysctl", "-n", "hw.memsize"])
mem_gb = 0.0
if mem_raw:
try:
mem_gb = int(mem_raw) / (1024 ** 3)
except ValueError:
mem_gb = 0.0
# Unified memory: GPU can address ~75% in practice. Use that as VRAM proxy.
vram_gb = mem_gb * 0.75 if mem_gb else 0.0
return name, vram_gb
=======
>>>>>>> refs/remotes/azuze/main
def _run(cmd: list[str]) -> str:
try:
out = subprocess.run(cmd, capture_output=True, text=True, timeout=10, check=False)
return out.stdout.strip()
except (FileNotFoundError, subprocess.TimeoutExpired):
return ""
def _detect_nvidia() -> tuple[str, float] | None:
    """Query ``nvidia-smi`` for the first NVIDIA GPU.

    Returns ``(name, vram_gb)`` taken from the first output line that parses
    as ``<name>, <memory MiB>``. Returns ``None`` when nvidia-smi is absent or
    printed nothing usable. One example is a driver error message without a
    comma, which previously crashed the tuple unpacking.
    """
    out = _run(["nvidia-smi", "--query-gpu=name,memory.total", "--format=csv,noheader,nounits"])
    for line in out.splitlines():
        # Memory is the last CSV field, so split from the right.
        name, sep, mem = line.rpartition(",")
        if not sep:
            continue
        try:
            mem_mib = float(mem.strip())
        except ValueError:
            continue  # e.g. "[N/A]" or an error line that happens to contain a comma
        return name.strip(), mem_mib / 1024.0
    return None
def _detect_dxgi() -> list[tuple[str, float, str]]:
"""Enumerate all DXGI adapters via PowerShell. Returns list of (name, vram_gb, vendor_hint)."""
if platform.system() != "Windows":
return []
ps = (
"Get-CimInstance Win32_VideoController | "
"Select-Object Name, AdapterRAM | "
"ConvertTo-Json -Compress"
)
out = _run(["powershell", "-NoProfile", "-Command", ps])
if not out:
return []
import json
try:
data = json.loads(out)
except json.JSONDecodeError:
return []
if isinstance(data, dict):
data = [data]
results: list[tuple[str, float, str]] = []
for entry in data:
name = (entry.get("Name") or "").strip()
ram = entry.get("AdapterRAM") or 0
# Win32_VideoController caps AdapterRAM at 4GB on many systems. Trust value but flag below.
vram_gb = float(ram) / (1024 ** 3) if ram else 0.0
low = name.lower()
if "nvidia" in low or "geforce" in low or "rtx" in low or "gtx" in low:
vendor = "nvidia"
elif "amd" in low or "radeon" in low or "rx " in low:
vendor = "amd"
elif "intel" in low or "arc" in low or "iris" in low:
vendor = "intel"
else:
vendor = "unknown"
if name:
results.append((name, vram_gb, vendor))
return results
<<<<<<< HEAD
def _detect_linux_gpus() -> list[tuple[str, float, str]]:
"""Enumerate Linux GPUs. Returns list of (name, vram_gb, vendor_hint)."""
if platform.system() != "Linux":
return []
results: list[tuple[str, float, str]] = []
# AMD via rocm-smi (most accurate VRAM)
rocm = _run(["rocm-smi", "--showproductname", "--showmeminfo", "vram", "--json"])
if rocm:
import json as _json
try:
data = _json.loads(rocm)
for k, v in data.items():
if not isinstance(v, dict):
continue
name = v.get("Card series") or v.get("Card model") or v.get("Device Name") or k
vram_bytes = 0
for key in ("VRAM Total Memory (B)", "vram total memory (b)", "VRAM Total"):
if key in v:
try:
vram_bytes = int(v[key])
break
except (TypeError, ValueError):
pass
vram_gb = float(vram_bytes) / (1024 ** 3) if vram_bytes else 0.0
results.append((str(name).strip(), vram_gb, "amd"))
except (ValueError, KeyError):
pass
if results:
return results
# Fallback: lspci for vendor hints (no reliable VRAM number)
lspci = _run(["bash", "-c", "lspci -nn | grep -Ei 'vga|3d|display'"])
for line in lspci.splitlines():
low = line.lower()
if "nvidia" in low:
vendor = "nvidia"
elif "amd" in low or "advanced micro" in low or "ati " in low:
vendor = "amd"
elif "intel" in low:
vendor = "intel"
else:
continue
# Extract device name after the colon
name = line.split(":", 2)[-1].strip() if ":" in line else line.strip()
results.append((name, 0.0, vendor))
return results
def _vendor_from_backend(backend: str) -> Vendor:
return {
"cuda": "nvidia",
"rocm": "amd",
"directml": "amd",
"mps": "apple",
"cpu": "cpu",
}.get(backend, "cpu") # type: ignore[return-value]
=======
>>>>>>> refs/remotes/azuze/main
def _vram_tier(vram_gb: float) -> Literal["cpu", "low", "mid", "high", "ultra"]:
if vram_gb < 1:
return "cpu"
if vram_gb < 6:
return "low"
if vram_gb < 10:
return "mid"
if vram_gb < 14:
return "high"
return "ultra"
def _check_forced_backend(backend: str, system: str) -> None:
    """Raise RuntimeError when *backend* is unknown or cannot run on *system*.

    *system* is a ``platform.system()`` value such as "Linux", "Windows" or "Darwin".
    """
    known = ("cuda", "rocm", "directml", "mps", "cpu")
    if backend not in known:
        raise RuntimeError(
            f"Unknown backend {backend!r}; expected 'auto' or one of: {', '.join(known)}."
        )
    if backend == "directml" and system != "Windows":
        raise RuntimeError(
            "DirectML backend is Windows-only. On Linux use --backend rocm (AMD), "
            "--backend cuda (NVIDIA), or --backend cpu. On macOS use --backend mps."
        )
    if backend == "mps" and system != "Darwin":
        raise RuntimeError("MPS backend is macOS-only (Apple Silicon).")
    if backend == "rocm" and system == "Darwin":
        raise RuntimeError("ROCm is not available on macOS. Use --backend mps or --backend cpu.")
    if backend == "cuda" and system == "Darwin":
        raise RuntimeError("CUDA is not available on macOS. Use --backend mps or --backend cpu.")
    if backend == "rocm" and system == "Windows":
        # The ROCm torch index used for installation only ships Linux wheels.
        raise RuntimeError(
            "ROCm backend is Linux-only. On Windows use --backend directml or --backend cpu."
        )


def _patch_capped_adapter_ram(adapter: tuple[str, float, str]) -> tuple[str, float, str]:
    """Replace a likely-capped Win32 AdapterRAM value with 8 GB for modern AMD/Intel cards."""
    name, vram, hint = adapter
    # AdapterRAM caps at ~4 GB for many cards. If the name suggests a modern
    # discrete AMD/Intel card, assume a conservative 8 GB instead.
    if vram <= 4.1 and any(k in name.lower() for k in ("rx 6", "rx 7", "arc a", "arc b")):
        vram = 8.0
    return name, vram, hint


def detect(force_backend: str | None = None, force_vendor: str | None = None) -> HardwareInfo:
    """Detect GPU vendor, torch backend, device name and VRAM tier.

    Args:
        force_backend: ``None`` or "auto" runs full auto-detection. Any other
            value ("cuda", "rocm", "directml", "mps", "cpu") is used as the
            backend as-is. Detection then only fills in the device name and
            VRAM used for tier sizing.
        force_vendor: Vendor to report alongside a forced backend. If omitted,
            "directml" takes the vendor of the largest Windows adapter, and every
            other backend uses ``_vendor_from_backend``.

    Returns:
        A ``HardwareInfo``. Auto-detection falls back to a CPU-only result
        when no usable GPU is found. A forced backend with no matching device
        is reported with 0.0 GB VRAM.

    Raises:
        RuntimeError: if *force_backend* is unknown or unusable on this OS.
    """
    system = platform.system()
    is_linux = system == "Linux"
    is_windows = system == "Windows"
    is_mac = system == "Darwin"

    forced = bool(force_backend) and force_backend != "auto"
    if forced:
        _check_forced_backend(force_backend, system)  # type: ignore[arg-type]

    # Gather best-effort device info regardless of whether a backend is forced.
    nv = None if is_mac else _detect_nvidia()
    win_adapters = [
        _patch_capped_adapter_ram(a)
        for a in (_detect_dxgi() if is_windows else [])
        if "basic" not in a[0].lower() and "microsoft" not in a[0].lower()
    ]
    lin_gpus = _detect_linux_gpus() if is_linux else []
    mac_gpu = _detect_mac_gpu() if is_mac else None

    if forced:
        match: tuple[str, float] | None
        if force_backend == "mps":
            vendor = force_vendor or _vendor_from_backend("mps")
            match = mac_gpu
        elif force_backend == "directml" and not force_vendor and win_adapters:
            # DirectML runs on any vendor's adapter, so report the largest one
            # instead of assuming AMD (which yielded "AMD (forced)" / tier "cpu").
            name, vram, hint = max(win_adapters, key=lambda a: a[1])
            vendor = hint if hint in ("nvidia", "amd", "intel") else _vendor_from_backend("directml")
            match = (name, vram)
        else:
            vendor = force_vendor or _vendor_from_backend(force_backend)  # type: ignore[arg-type]
            if vendor == "nvidia" and nv:
                match = nv
            else:
                pool = [(n, v) for n, v, h in win_adapters + lin_gpus if h == vendor]
                match = max(pool, key=lambda p: p[1], default=None)
            if match is None and force_backend == "cuda":
                match = nv
        if match:
            name, vram = match
        else:
            # No matching device found, but the user forced this backend:
            # proceed with unknown VRAM.
            name, vram = f"{vendor.upper()} (forced)", 0.0
        return HardwareInfo(vendor, force_backend, name, vram, _vram_tier(vram))  # type: ignore[arg-type]

    # Auto-detection, in order of preference.
    if mac_gpu:
        name, vram = mac_gpu
        return HardwareInfo("apple", "mps", name, vram, _vram_tier(vram))

    if nv:
        name, vram = nv
        return HardwareInfo("nvidia", "cuda", name, vram, _vram_tier(vram))

    if win_adapters:
        name, vram, hint = max(win_adapters, key=lambda a: a[1])
        if hint in ("amd", "intel", "nvidia"):
            # NVIDIA lands here only when nvidia-smi failed; DirectML still works.
            return HardwareInfo(hint, "directml", name, vram, _vram_tier(vram))

    if lin_gpus:
        # The lspci fallback reports 0.0 GB for every card, so break VRAM ties
        # in favour of AMD (usable via ROCm), then NVIDIA, then Intel iGPUs.
        rank = {"amd": 2, "nvidia": 1}
        name, vram, hint = max(lin_gpus, key=lambda g: (g[1], rank.get(g[2], 0)))
        if hint == "amd":
            return HardwareInfo("amd", "rocm", name, vram, _vram_tier(vram))
        if hint == "nvidia":
            # nvidia-smi failed but the card is NVIDIA: drivers likely missing.
            return HardwareInfo("cpu", "cpu", f"NVIDIA driver missing — {name}", 0.0, "cpu")
        if hint == "intel":
            # No Intel-on-Linux torch path here; default to CPU.
            return HardwareInfo("cpu", "cpu", f"Intel GPU (no backend) — {name}", 0.0, "cpu")

    return HardwareInfo("cpu", "cpu", platform.processor() or "CPU", 0.0, "cpu")
def directml_supported() -> bool:
    """Report whether this interpreter can use torch-directml wheels.

    torch-directml ships wheels for Python 3.10 / 3.11. The launcher pins the
    venv to 3.11, so inside the running app this is True.
    """
    import sys

    major, minor = sys.version_info[:2]
    return major == 3 and minor in (10, 11)
def torch_install_args(info: HardwareInfo) -> list[str]:
    """Return ``uv pip install`` arguments for the PyTorch build matching *info*.

    Only ``info.backend`` is read:

    - CUDA, ROCm and non-macOS CPU builds come from PyTorch's own wheel indexes.
    - DirectML pins torch to the 2.4 line that torch-directml requires. The
      launcher pins the venv to Python 3.11, so those wheels are available.
    - MPS and macOS CPU use the default PyPI wheels.
    """
    backend = info.backend
    if backend == "cuda":
        return ["torch", "torchvision", "--index-url", "https://download.pytorch.org/whl/cu124"]
    if backend == "rocm":
        # ROCm wheels are Linux-only. Index pinned to a stable ROCm release line.
        return ["torch", "torchvision", "--index-url", "https://download.pytorch.org/whl/rocm6.1"]
    if backend == "directml":
        # torch-directml currently pins to torch 2.4.x. Match it.
        return ["torch>=2.4,<2.5", "torchvision>=0.19,<0.20", "torch-directml>=0.2.5"]
    if backend == "mps" or platform.system() == "Darwin":
        # The default PyPI wheel ships MPS support on macOS arm64, and there is
        # no /whl/cpu index for macOS, so both cases use plain PyPI.
        return ["torch", "torchvision"]
    return ["torch", "torchvision", "--index-url", "https://download.pytorch.org/whl/cpu"]
if __name__ == "__main__":
    # Print a human-readable summary of the detected hardware, one field per line.
    hw = detect()
    for report_line in (
        f"Vendor: {hw.vendor}",
        f"Backend: {hw.backend}",
        f"Device: {hw.device_name}",
        f"VRAM: {hw.vram_gb:.1f} GB",
        f"Tier: {hw.tier}",
    ):
        print(report_line)