Fix bad merge
This commit is contained in:
@@ -1,32 +1,14 @@
|
||||
"""GPU and VRAM detection. Returns vendor + tier used to pick torch wheel and default models."""
|
||||
from __future__ import annotations
|
||||
|
||||
<<<<<<< HEAD
|
||||
=======
|
||||
<<<<<<< HEAD
|
||||
=======
|
||||
import ctypes
|
||||
>>>>>>> refs/remotes/azuze/main
|
||||
>>>>>>> 965a3d97c6dae38fa25174559b1ea0f3050788f9
|
||||
import platform
|
||||
import subprocess
|
||||
from dataclasses import dataclass
|
||||
from typing import Literal
|
||||
|
||||
<<<<<<< HEAD
|
||||
Vendor = Literal["nvidia", "amd", "intel", "apple", "cpu"]
|
||||
Backend = Literal["cuda", "rocm", "directml", "mps", "cpu"]
|
||||
SUPPORTED_BACKENDS: tuple[str, ...] = ("auto", "cuda", "rocm", "directml", "mps", "cpu")
|
||||
=======
|
||||
<<<<<<< HEAD
|
||||
Vendor = Literal["nvidia", "amd", "intel", "apple", "cpu"]
|
||||
Backend = Literal["cuda", "rocm", "directml", "mps", "cpu"]
|
||||
SUPPORTED_BACKENDS: tuple[str, ...] = ("auto", "cuda", "rocm", "directml", "mps", "cpu")
|
||||
=======
|
||||
Vendor = Literal["nvidia", "amd", "intel", "cpu"]
|
||||
Backend = Literal["cuda", "directml", "cpu"]
|
||||
>>>>>>> refs/remotes/azuze/main
|
||||
>>>>>>> 965a3d97c6dae38fa25174559b1ea0f3050788f9
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -38,10 +20,6 @@ class HardwareInfo:
|
||||
tier: Literal["cpu", "low", "mid", "high", "ultra"]
|
||||
|
||||
|
||||
<<<<<<< HEAD
|
||||
=======
|
||||
<<<<<<< HEAD
|
||||
>>>>>>> 965a3d97c6dae38fa25174559b1ea0f3050788f9
|
||||
def _detect_mac_gpu() -> tuple[str, float] | None:
|
||||
"""Apple Silicon: report chip name + unified memory (proxy for VRAM).
|
||||
Intel Mac: returns None (no GPU acceleration path)."""
|
||||
@@ -66,11 +44,6 @@ def _detect_mac_gpu() -> tuple[str, float] | None:
|
||||
return name, vram_gb
|
||||
|
||||
|
||||
<<<<<<< HEAD
|
||||
=======
|
||||
=======
|
||||
>>>>>>> refs/remotes/azuze/main
|
||||
>>>>>>> 965a3d97c6dae38fa25174559b1ea0f3050788f9
|
||||
def _run(cmd: list[str]) -> str:
|
||||
try:
|
||||
out = subprocess.run(cmd, capture_output=True, text=True, timeout=10, check=False)
|
||||
@@ -127,10 +100,6 @@ def _detect_dxgi() -> list[tuple[str, float, str]]:
|
||||
return results
|
||||
|
||||
|
||||
<<<<<<< HEAD
|
||||
=======
|
||||
<<<<<<< HEAD
|
||||
>>>>>>> 965a3d97c6dae38fa25174559b1ea0f3050788f9
|
||||
def _detect_linux_gpus() -> list[tuple[str, float, str]]:
|
||||
"""Enumerate Linux GPUs. Returns list of (name, vram_gb, vendor_hint)."""
|
||||
if platform.system() != "Linux":
|
||||
@@ -191,11 +160,6 @@ def _vendor_from_backend(backend: str) -> Vendor:
|
||||
}.get(backend, "cpu") # type: ignore[return-value]
|
||||
|
||||
|
||||
<<<<<<< HEAD
|
||||
=======
|
||||
=======
|
||||
>>>>>>> refs/remotes/azuze/main
|
||||
>>>>>>> 965a3d97c6dae38fa25174559b1ea0f3050788f9
|
||||
def _vram_tier(vram_gb: float) -> Literal["cpu", "low", "mid", "high", "ultra"]:
|
||||
if vram_gb < 1:
|
||||
return "cpu"
|
||||
@@ -208,10 +172,6 @@ def _vram_tier(vram_gb: float) -> Literal["cpu", "low", "mid", "high", "ultra"]:
|
||||
return "ultra"
|
||||
|
||||
|
||||
<<<<<<< HEAD
|
||||
=======
|
||||
<<<<<<< HEAD
|
||||
>>>>>>> 965a3d97c6dae38fa25174559b1ea0f3050788f9
|
||||
def detect(force_backend: str | None = None, force_vendor: str | None = None) -> HardwareInfo:
|
||||
"""Auto-detect hardware. If force_backend is set (cuda/rocm/directml/cpu), skip detection
|
||||
for that decision but still try to discover device name + VRAM for tier sizing."""
|
||||
@@ -267,21 +227,10 @@ def detect(force_backend: str | None = None, force_vendor: str | None = None) ->
|
||||
name, vram = mac_gpu
|
||||
return HardwareInfo("apple", "mps", name, vram, _vram_tier(vram))
|
||||
|
||||
<<<<<<< HEAD
|
||||
=======
|
||||
=======
|
||||
def detect() -> HardwareInfo:
|
||||
nv = _detect_nvidia()
|
||||
>>>>>>> refs/remotes/azuze/main
|
||||
>>>>>>> 965a3d97c6dae38fa25174559b1ea0f3050788f9
|
||||
if nv:
|
||||
name, vram = nv
|
||||
return HardwareInfo("nvidia", "cuda", name, vram, _vram_tier(vram))
|
||||
|
||||
<<<<<<< HEAD
|
||||
=======
|
||||
<<<<<<< HEAD
|
||||
>>>>>>> 965a3d97c6dae38fa25174559b1ea0f3050788f9
|
||||
if is_windows and win_adapters:
|
||||
win_adapters.sort(key=lambda a: a[1], reverse=True)
|
||||
name, vram, hint = win_adapters[0]
|
||||
@@ -305,26 +254,6 @@ def detect() -> HardwareInfo:
|
||||
# No good Intel-on-Linux torch path here; default to CPU.
|
||||
return HardwareInfo("cpu", "cpu", f"Intel GPU (no backend) — {name}", 0.0, "cpu")
|
||||
|
||||
<<<<<<< HEAD
|
||||
=======
|
||||
=======
|
||||
adapters = _detect_dxgi()
|
||||
# Prefer discrete (highest VRAM) non-basic adapter
|
||||
adapters = [a for a in adapters if "basic" not in a[0].lower() and "microsoft" not in a[0].lower()]
|
||||
if adapters:
|
||||
adapters.sort(key=lambda a: a[1], reverse=True)
|
||||
name, vram, hint = adapters[0]
|
||||
# AdapterRAM is unreliable for >4GB cards. If exactly 4GB and modern AMD/Intel card name, bump.
|
||||
if vram <= 4.1 and any(k in name.lower() for k in ("rx 6", "rx 7", "arc a", "arc b")):
|
||||
vram = 8.0 # conservative guess
|
||||
if hint in ("amd", "intel"):
|
||||
return HardwareInfo(hint, "directml", name, vram, _vram_tier(vram))
|
||||
if hint == "nvidia":
|
||||
# nvidia-smi missing but card is nvidia: drivers may be broken, fall through to directml
|
||||
return HardwareInfo("nvidia", "directml", name, vram, _vram_tier(vram))
|
||||
|
||||
>>>>>>> refs/remotes/azuze/main
|
||||
>>>>>>> 965a3d97c6dae38fa25174559b1ea0f3050788f9
|
||||
return HardwareInfo("cpu", "cpu", platform.processor() or "CPU", 0.0, "cpu")
|
||||
|
||||
|
||||
@@ -348,10 +277,6 @@ def torch_install_args(info: HardwareInfo) -> list[str]:
|
||||
"--index-url",
|
||||
"https://download.pytorch.org/whl/cu124",
|
||||
]
|
||||
<<<<<<< HEAD
|
||||
=======
|
||||
<<<<<<< HEAD
|
||||
>>>>>>> 965a3d97c6dae38fa25174559b1ea0f3050788f9
|
||||
if info.backend == "rocm":
|
||||
# ROCm wheels are Linux-only. Index pinned to a stable ROCm release line.
|
||||
return [
|
||||
@@ -360,11 +285,6 @@ def torch_install_args(info: HardwareInfo) -> list[str]:
|
||||
"--index-url",
|
||||
"https://download.pytorch.org/whl/rocm6.1",
|
||||
]
|
||||
<<<<<<< HEAD
|
||||
=======
|
||||
=======
|
||||
>>>>>>> refs/remotes/azuze/main
|
||||
>>>>>>> 965a3d97c6dae38fa25174559b1ea0f3050788f9
|
||||
if info.backend == "directml":
|
||||
# torch-directml currently pins to torch 2.4.x. Match it.
|
||||
return [
|
||||
@@ -372,21 +292,12 @@ def torch_install_args(info: HardwareInfo) -> list[str]:
|
||||
"torchvision>=0.19,<0.20",
|
||||
"torch-directml>=0.2.5",
|
||||
]
|
||||
<<<<<<< HEAD
|
||||
=======
|
||||
<<<<<<< HEAD
|
||||
>>>>>>> 965a3d97c6dae38fa25174559b1ea0f3050788f9
|
||||
if info.backend == "mps":
|
||||
# Default PyPI torch wheel ships MPS support on macOS arm64. No custom index.
|
||||
return ["torch", "torchvision"]
|
||||
# CPU. macOS uses default PyPI wheels (no /whl/cpu index for darwin).
|
||||
if platform.system() == "Darwin":
|
||||
return ["torch", "torchvision"]
|
||||
<<<<<<< HEAD
|
||||
=======
|
||||
=======
|
||||
>>>>>>> refs/remotes/azuze/main
|
||||
>>>>>>> 965a3d97c6dae38fa25174559b1ea0f3050788f9
|
||||
return [
|
||||
"torch",
|
||||
"torchvision",
|
||||
|
||||
Reference in New Issue
Block a user