"""GPU and VRAM detection. Returns vendor + tier used to pick torch wheel and default models."""
|
|
from __future__ import annotations
|
|
|
|
<<<<<<< HEAD
|
|
=======
|
|
<<<<<<< HEAD
|
|
=======
|
|
import ctypes
|
|
>>>>>>> refs/remotes/azuze/main
|
|
>>>>>>> 965a3d97c6dae38fa25174559b1ea0f3050788f9
|
|
import platform
|
|
import subprocess
|
|
from dataclasses import dataclass
|
|
from typing import Literal
|
|
|
|

Vendor = Literal["nvidia", "amd", "intel", "apple", "cpu"]
Backend = Literal["cuda", "rocm", "directml", "mps", "cpu"]
SUPPORTED_BACKENDS: tuple[str, ...] = ("auto", "cuda", "rocm", "directml", "mps", "cpu")


@dataclass
class HardwareInfo:
    vendor: Vendor
    backend: Backend
    device_name: str
    vram_gb: float
    tier: Literal["cpu", "low", "mid", "high", "ultra"]
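
# Illustrative example of a detected value (hypothetical numbers):
#   HardwareInfo(vendor="nvidia", backend="cuda",
#                device_name="NVIDIA GeForce RTX 4070", vram_gb=12.0, tier="high")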


def _detect_mac_gpu() -> tuple[str, float] | None:
    """Apple Silicon: report chip name + unified memory (proxy for VRAM).

    Intel Mac: returns None (no GPU acceleration path).
    """
    if platform.system() != "Darwin":
        return None
    machine = platform.machine().lower()
    if machine not in ("arm64", "aarch64"):
        return None
    name = "Apple Silicon"
    brand = _run(["sysctl", "-n", "machdep.cpu.brand_string"])
    if brand:
        name = brand.strip()
    mem_raw = _run(["sysctl", "-n", "hw.memsize"])
    mem_gb = 0.0
    if mem_raw:
        try:
            mem_gb = int(mem_raw) / (1024 ** 3)
        except ValueError:
            mem_gb = 0.0
    # Unified memory: the GPU can address roughly 75% in practice. Use that as the VRAM proxy.
    vram_gb = mem_gb * 0.75 if mem_gb else 0.0
    return name, vram_gb


def _run(cmd: list[str]) -> str:
    try:
        out = subprocess.run(cmd, capture_output=True, text=True, timeout=10, check=False)
        return out.stdout.strip()
    except (FileNotFoundError, subprocess.TimeoutExpired):
        return ""


def _detect_nvidia() -> tuple[str, float] | None:
    out = _run(["nvidia-smi", "--query-gpu=name,memory.total", "--format=csv,noheader,nounits"])
    if not out:
        return None
    first = out.splitlines()[0]
    if "," not in first:
        return None  # nvidia-smi printed an error banner instead of CSV
    name, mem = [p.strip() for p in first.split(",", 1)]
    try:
        return name, float(mem) / 1024.0  # memory.total is reported in MiB
    except ValueError:
        return None
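
# The query above yields one CSV line per GPU with memory in MiB, e.g. (illustrative):
#   NVIDIA GeForce RTX 4090, 24564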


def _detect_dxgi() -> list[tuple[str, float, str]]:
    """Enumerate all DXGI adapters via PowerShell. Returns a list of (name, vram_gb, vendor_hint)."""
    if platform.system() != "Windows":
        return []
    ps = (
        "Get-CimInstance Win32_VideoController | "
        "Select-Object Name, AdapterRAM | "
        "ConvertTo-Json -Compress"
    )
    out = _run(["powershell", "-NoProfile", "-Command", ps])
    if not out:
        return []
    import json

    try:
        data = json.loads(out)
    except json.JSONDecodeError:
        return []
    if isinstance(data, dict):
        data = [data]
    results: list[tuple[str, float, str]] = []
    for entry in data:
        name = (entry.get("Name") or "").strip()
        ram = entry.get("AdapterRAM") or 0
        # Win32_VideoController caps AdapterRAM at 4GB on many systems. Trust the value but flag it below.
        vram_gb = float(ram) / (1024 ** 3) if ram else 0.0
        low = name.lower()
        if "nvidia" in low or "geforce" in low or "rtx" in low or "gtx" in low:
            vendor = "nvidia"
        elif "amd" in low or "radeon" in low or "rx " in low:
            vendor = "amd"
        elif "intel" in low or "arc" in low or "iris" in low:
            vendor = "intel"
        else:
            vendor = "unknown"
        if name:
            results.append((name, vram_gb, vendor))
    return results
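
# Note: ConvertTo-Json emits a bare JSON object rather than a one-element array when
# only a single adapter exists, which is why the dict case is wrapped into a list above.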


def _detect_linux_gpus() -> list[tuple[str, float, str]]:
    """Enumerate Linux GPUs. Returns a list of (name, vram_gb, vendor_hint)."""
    if platform.system() != "Linux":
        return []
    results: list[tuple[str, float, str]] = []

    # AMD via rocm-smi (most accurate VRAM)
    rocm = _run(["rocm-smi", "--showproductname", "--showmeminfo", "vram", "--json"])
    if rocm:
        import json as _json

        try:
            data = _json.loads(rocm)
            for k, v in data.items():
                if not isinstance(v, dict):
                    continue
                name = v.get("Card series") or v.get("Card model") or v.get("Device Name") or k
                vram_bytes = 0
                for key in ("VRAM Total Memory (B)", "vram total memory (b)", "VRAM Total"):
                    if key in v:
                        try:
                            vram_bytes = int(v[key])
                            break
                        except (TypeError, ValueError):
                            pass
                vram_gb = float(vram_bytes) / (1024 ** 3) if vram_bytes else 0.0
                results.append((str(name).strip(), vram_gb, "amd"))
        except (ValueError, KeyError, AttributeError):
            pass

    if results:
        return results

    # Fallback: lspci for vendor hints (no reliable VRAM number)
    lspci = _run(["bash", "-c", "lspci -nn | grep -Ei 'vga|3d|display'"])
    for line in lspci.splitlines():
        low = line.lower()
        if "nvidia" in low:
            vendor = "nvidia"
        elif "amd" in low or "advanced micro" in low or "ati " in low:
            vendor = "amd"
        elif "intel" in low:
            vendor = "intel"
        else:
            continue
        # Extract the device name after the colon
        name = line.split(":", 2)[-1].strip() if ":" in line else line.strip()
        results.append((name, 0.0, vendor))
    return results
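
# rocm-smi --json output is keyed per device (roughly {"card0": {...}, ...}), and the
# exact VRAM field name has varied across ROCm releases, hence the tolerant key lookup above.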


def _vendor_from_backend(backend: str) -> Vendor:
    return {
        "cuda": "nvidia",
        "rocm": "amd",
        "directml": "amd",  # DirectML also covers Intel; AMD is the more common case
        "mps": "apple",
        "cpu": "cpu",
    }.get(backend, "cpu")  # type: ignore[return-value]


def _vram_tier(vram_gb: float) -> Literal["cpu", "low", "mid", "high", "ultra"]:
    if vram_gb < 1:
        return "cpu"
    if vram_gb < 6:
        return "low"
    if vram_gb < 10:
        return "mid"
    if vram_gb < 14:
        return "high"
    return "ultra"


def detect(force_backend: str | None = None, force_vendor: str | None = None) -> HardwareInfo:
    """Auto-detect hardware. If force_backend is set (cuda/rocm/directml/mps/cpu), skip
    detection for that decision but still try to discover the device name + VRAM for
    tier sizing.
    """
    is_linux = platform.system() == "Linux"
    is_windows = platform.system() == "Windows"
    is_mac = platform.system() == "Darwin"

    # Validate the forced backend against the platform.
    if force_backend == "directml" and not is_windows:
        raise RuntimeError(
            "DirectML backend is Windows-only. On Linux use --backend rocm (AMD), "
            "--backend cuda (NVIDIA), or --backend cpu. On macOS use --backend mps."
        )
    if force_backend == "mps" and not is_mac:
        raise RuntimeError("MPS backend is macOS-only (Apple Silicon).")
    if force_backend == "rocm" and is_mac:
        raise RuntimeError("ROCm is not available on macOS. Use --backend mps or --backend cpu.")
    if force_backend == "cuda" and is_mac:
        raise RuntimeError("CUDA is not available on macOS. Use --backend mps or --backend cpu.")

    # Gather best-effort device info regardless of force.
    nv = _detect_nvidia() if not is_mac else None
    win_adapters = _detect_dxgi() if is_windows else []
    win_adapters = [a for a in win_adapters if "basic" not in a[0].lower() and "microsoft" not in a[0].lower()]
    lin_gpus = _detect_linux_gpus() if is_linux else []
    mac_gpu = _detect_mac_gpu() if is_mac else None

    def _best_for(vendor: str) -> tuple[str, float] | None:
        if vendor == "nvidia" and nv:
            return nv
        pool = [(n, v) for n, v, h in (win_adapters + lin_gpus) if h == vendor]
        if not pool:
            return None
        pool.sort(key=lambda x: x[1], reverse=True)
        return pool[0]

    if force_backend and force_backend != "auto":
        vendor: Vendor = force_vendor or _vendor_from_backend(force_backend)  # type: ignore[assignment]
        if force_backend == "mps":
            match = mac_gpu
        else:
            match = _best_for(vendor) or (nv if force_backend == "cuda" else None)
        if match:
            name, vram = match
        else:
            # No matching device found, but the user forced this backend: proceed with unknown VRAM.
            name = f"{vendor.upper()} (forced)"
            vram = 0.0
        return HardwareInfo(vendor, force_backend, name, vram, _vram_tier(vram))  # type: ignore[arg-type]

    # Auto path
    if is_mac and mac_gpu:
        name, vram = mac_gpu
        return HardwareInfo("apple", "mps", name, vram, _vram_tier(vram))

    if nv:
        name, vram = nv
        return HardwareInfo("nvidia", "cuda", name, vram, _vram_tier(vram))

    if is_windows and win_adapters:
        win_adapters.sort(key=lambda a: a[1], reverse=True)
        name, vram, hint = win_adapters[0]
        # AdapterRAM caps at 4GB for many cards. Bump if the name suggests a modern discrete GPU.
        if vram <= 4.1 and any(k in name.lower() for k in ("rx 6", "rx 7", "arc a", "arc b")):
            vram = 8.0  # conservative guess
        if hint in ("amd", "intel"):
            return HardwareInfo(hint, "directml", name, vram, _vram_tier(vram))
        if hint == "nvidia":
            # nvidia-smi missing but the card is NVIDIA: drivers may be broken; DirectML still works.
            return HardwareInfo("nvidia", "directml", name, vram, _vram_tier(vram))

    if is_linux and lin_gpus:
        lin_gpus.sort(key=lambda a: a[1], reverse=True)
        name, vram, hint = lin_gpus[0]
        if hint == "amd":
            return HardwareInfo("amd", "rocm", name, vram, _vram_tier(vram))
        if hint == "nvidia":
            # nvidia-smi failed but the card is NVIDIA: drivers are likely missing. Fall back to CPU.
            return HardwareInfo("cpu", "cpu", f"NVIDIA driver missing — {name}", 0.0, "cpu")
        if hint == "intel":
            # No good Intel-on-Linux torch path here; default to CPU.
            return HardwareInfo("cpu", "cpu", f"Intel GPU (no backend) — {name}", 0.0, "cpu")

    return HardwareInfo("cpu", "cpu", platform.processor() or "CPU", 0.0, "cpu")


def directml_supported() -> bool:
    """torch-directml ships wheels for Python 3.10 / 3.11. The launcher pins the
    venv to 3.11, so this is True in the running app.
    """
    import sys

    return sys.version_info[:2] in {(3, 10), (3, 11)}


def torch_install_args(info: HardwareInfo) -> list[str]:
    """Return uv-pip install args for the right PyTorch build.

    The launcher pins the venv to Python 3.11, so torch-directml wheels are always
    available for the AMD/Intel paths. The latest stable torch is used elsewhere.
    """
    if info.backend == "cuda":
        return [
            "torch",
            "torchvision",
            "--index-url",
            "https://download.pytorch.org/whl/cu124",
        ]
    if info.backend == "rocm":
        # ROCm wheels are Linux-only. The index is pinned to a stable ROCm release line.
        return [
            "torch",
            "torchvision",
            "--index-url",
            "https://download.pytorch.org/whl/rocm6.1",
        ]
    if info.backend == "directml":
        # torch-directml currently pins to torch 2.4.x. Match it.
        return [
            "torch>=2.4,<2.5",
            "torchvision>=0.19,<0.20",
            "torch-directml>=0.2.5",
        ]
    if info.backend == "mps":
        # The default PyPI torch wheel ships MPS support on macOS arm64. No custom index needed.
        return ["torch", "torchvision"]
    # CPU. macOS uses the default PyPI wheels (there is no /whl/cpu index for darwin).
    if platform.system() == "Darwin":
        return ["torch", "torchvision"]
    return [
        "torch",
        "torchvision",
        "--index-url",
        "https://download.pytorch.org/whl/cpu",
    ]
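
# A minimal sketch of how a launcher might consume these args (the exact install
# command is an assumption; adapt to the real launcher):
#
#   import subprocess
#   info = detect()
#   subprocess.run(["uv", "pip", "install", *torch_install_args(info)], check=True)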


if __name__ == "__main__":
    info = detect()
    print(f"Vendor: {info.vendor}")
    print(f"Backend: {info.backend}")
    print(f"Device: {info.device_name}")
    print(f"VRAM: {info.vram_gb:.1f} GB")
    print(f"Tier: {info.tier}")