Merge branch 'main' of https://git.azuze.fr/kawa/KawAI
@@ -28,8 +28,17 @@ def hardware_info() -> dict:
def get_device():
    import torch
    backend = hardware_info()["backend"]
    # ROCm builds of torch expose the cuda namespace.
    if backend in ("cuda", "rocm") and torch.cuda.is_available():
        return torch.device("cuda")
    if backend == "directml":
        try:
@@ -37,8 +46,16 @@ def get_device():
            return torch_directml.device()
        except ImportError:
            pass
    if backend == "mps" and getattr(torch.backends, "mps", None) and torch.backends.mps.is_available():
        return torch.device("mps")
    return torch.device("cpu")

@@ -47,5 +64,12 @@ def torch_dtype():
    backend = hardware_info()["backend"]
    if backend == "cpu":
        return torch.float32
    # MPS supports fp16 for diffusers; bf16 has gaps. Stick with fp16.
    return torch.float16
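Taken together, get_device() and torch_dtype() are presumably consumed when a pipeline is built. A minimal sketch of that call site (the diffusers pipeline class and model id are placeholders, not from this diff):

    from diffusers import DiffusionPipeline  # hypothetical consumer

    pipe = DiffusionPipeline.from_pretrained(
        "some/model-id",            # placeholder model id
        torch_dtype=torch_dtype(),  # fp16 on accelerators, fp32 on CPU, per above
    )
    pipe = pipe.to(get_device())    # cuda, mps, directml device, or cpu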
@@ -1,14 +1,32 @@
"""GPU and VRAM detection. Returns vendor + tier used to pick torch wheel and default models."""
from __future__ import annotations

import platform
import subprocess
from dataclasses import dataclass
from typing import Literal

Vendor = Literal["nvidia", "amd", "intel", "apple", "cpu"]
Backend = Literal["cuda", "rocm", "directml", "mps", "cpu"]
SUPPORTED_BACKENDS: tuple[str, ...] = ("auto", "cuda", "rocm", "directml", "mps", "cpu")

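SUPPORTED_BACKENDS reads like the validation set for a user-facing override; a sketch of such a check (the function name is hypothetical, not part of this diff):

    def resolve_backend(choice: str) -> str:
        # "auto" defers to detect(); anything else must name a concrete backend.
        if choice not in SUPPORTED_BACKENDS:
            raise ValueError(f"backend must be one of {SUPPORTED_BACKENDS}, got {choice!r}")
        return choice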
@dataclass
@@ -20,6 +38,10 @@ class HardwareInfo:
    tier: Literal["cpu", "low", "mid", "high", "ultra"]


def _detect_mac_gpu() -> tuple[str, float] | None:
    """Apple Silicon: report chip name + unified memory (proxy for VRAM).
    Intel Mac: returns None (no GPU acceleration path)."""
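The function body is elided by the diff. One plausible shape, assuming it shells out to system_profiler and sysctl (the JSON keys and helper name are assumptions, not taken from this commit):

    import json

    def _detect_mac_gpu_sketch() -> tuple[str, float] | None:
        if platform.machine() != "arm64":
            return None  # Intel Mac: no MPS path
        try:
            data = json.loads(_run(["system_profiler", "SPDisplaysDataType", "-json"]))
            name = data["SPDisplaysDataType"][0].get("sppci_model", "Apple GPU")
            # Unified memory doubles as the VRAM budget on Apple Silicon.
            vram_gb = int(_run(["sysctl", "-n", "hw.memsize"])) / 1024**3
        except Exception:
            return None
        return name, vram_gb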
@@ -44,6 +66,11 @@ def _detect_mac_gpu() -> tuple[str, float] | None:
    return name, vram_gb


def _run(cmd: list[str]) -> str:
    try:
        out = subprocess.run(cmd, capture_output=True, text=True, timeout=10, check=False)
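The hunk cuts _run off mid-body; the natural completion given the signature (my assumption about the elided lines):

    def _run(cmd: list[str]) -> str:
        try:
            out = subprocess.run(cmd, capture_output=True, text=True, timeout=10, check=False)
            return out.stdout or ""
        except (OSError, subprocess.TimeoutExpired):
            # Tool missing or hung: callers treat "" as "no info".
            return ""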
@@ -100,6 +127,10 @@ def _detect_dxgi() -> list[tuple[str, float, str]]:
    return results


def _detect_linux_gpus() -> list[tuple[str, float, str]]:
    """Enumerate Linux GPUs. Returns list of (name, vram_gb, vendor_hint)."""
    if platform.system() != "Linux":
@@ -160,6 +191,11 @@ def _vendor_from_backend(backend: str) -> Vendor:
    }.get(backend, "cpu")  # type: ignore[return-value]


def _vram_tier(vram_gb: float) -> Literal["cpu", "low", "mid", "high", "ultra"]:
    if vram_gb < 1:
        return "cpu"
@@ -172,6 +208,10 @@ def _vram_tier(vram_gb: float) -> Literal["cpu", "low", "mid", "high", "ultra"]:
    return "ultra"


def detect(force_backend: str | None = None, force_vendor: str | None = None) -> HardwareInfo:
    """Auto-detect hardware. If force_backend is set (cuda/rocm/directml/cpu), skip detection
    for that decision but still try to discover device name + VRAM for tier sizing."""
@@ -227,10 +267,21 @@ def detect(force_backend: str | None = None, force_vendor: str | None = None) -> HardwareInfo:
        name, vram = mac_gpu
        return HardwareInfo("apple", "mps", name, vram, _vram_tier(vram))

    nv = _detect_nvidia()
    if nv:
        name, vram = nv
        return HardwareInfo("nvidia", "cuda", name, vram, _vram_tier(vram))

    if is_windows and win_adapters:
        win_adapters.sort(key=lambda a: a[1], reverse=True)
        name, vram, hint = win_adapters[0]
@@ -254,6 +305,26 @@ def detect(force_backend: str | None = None, force_vendor: str | None = None) -> HardwareInfo:
        # No good Intel-on-Linux torch path here; default to CPU.
        return HardwareInfo("cpu", "cpu", f"Intel GPU (no backend) — {name}", 0.0, "cpu")

    return HardwareInfo("cpu", "cpu", platform.processor() or "CPU", 0.0, "cpu")

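What detect() yields, for orientation (field names other than backend and tier are inferred from the constructor order above, so treat them as assumptions):

    info = detect()
    # Constructor order suggests HardwareInfo(vendor, backend, name, vram_gb, tier).
    print(info.backend, info.tier)  # e.g. "cuda ultra" on a high-VRAM NVIDIA box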
@@ -277,6 +348,10 @@ def torch_install_args(info: HardwareInfo) -> list[str]:
            "--index-url",
            "https://download.pytorch.org/whl/cu124",
        ]
    if info.backend == "rocm":
        # ROCm wheels are Linux-only. Index pinned to a stable ROCm release line.
        return [
@@ -285,6 +360,11 @@ def torch_install_args(info: HardwareInfo) -> list[str]:
            "--index-url",
            "https://download.pytorch.org/whl/rocm6.1",
        ]
    if info.backend == "directml":
        # torch-directml currently pins to torch 2.4.x. Match it.
        return [
@@ -292,12 +372,21 @@ def torch_install_args(info: HardwareInfo) -> list[str]:
            "torchvision>=0.19,<0.20",
            "torch-directml>=0.2.5",
        ]
    if info.backend == "mps":
        # Default PyPI torch wheel ships MPS support on macOS arm64. No custom index.
        return ["torch", "torchvision"]
    # CPU. macOS uses default PyPI wheels (no /whl/cpu index for darwin).
    if platform.system() == "Darwin":
        return ["torch", "torchvision"]
    return [
        "torch",
        "torchvision",

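These argument lists are presumably handed straight to an installer; a sketch of that step (running pip through the current interpreter is my assumption, not shown in this commit):

    import subprocess
    import sys

    def install_torch(info: HardwareInfo) -> None:
        args = torch_install_args(info)
        # e.g. ["torch", "torchvision", "--index-url", "https://download.pytorch.org/whl/rocm6.1"]
        subprocess.run([sys.executable, "-m", "pip", "install", *args], check=True)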
@@ -30,9 +30,20 @@ def apply_memory_strategy(pipe) -> None:
    except Exception:
        pass

    if backend in ("cuda", "rocm"):
        # ROCm builds expose the cuda API, so accelerate offload hooks work the same way.
        # Offload only if VRAM tight.
        if vram < 10:
            try:
                pipe.enable_sequential_cpu_offload()
@@ -57,6 +68,10 @@ def apply_memory_strategy(pipe) -> None:
            pipe.to("cpu")
            return

    if backend == "mps":
        # Apple Silicon shares unified memory with CPU. accelerate's sequential offload
        # has spotty MPS support; rely on slicing/tiling already enabled above.
@@ -79,5 +94,10 @@ def apply_memory_strategy(pipe) -> None:
        pipe.to("cpu")
        return

    # CPU
    pipe.to("cpu")
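End to end, the strategy is applied once after the pipeline loads; a sketch reusing the placeholders from the earlier example:

    from diffusers import DiffusionPipeline  # hypothetical consumer, as above

    pipe = DiffusionPipeline.from_pretrained("some/model-id", torch_dtype=torch_dtype())
    apply_memory_strategy(pipe)  # chooses offload vs. slicing from backend + VRAM
    image = pipe("a test prompt").images[0]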