Fix bad merge

This commit is contained in:
2026-05-04 10:08:18 +02:00
parent 0b670cbd24
commit c610ae5a68
5 changed files with 0 additions and 236 deletions
-24
View File
@@ -28,17 +28,8 @@ def hardware_info() -> dict:
def get_device():
import torch
backend = hardware_info()["backend"]
<<<<<<< HEAD
# ROCm builds of torch expose the cuda namespace.
if backend in ("cuda", "rocm") and torch.cuda.is_available():
=======
<<<<<<< HEAD
# ROCm builds of torch expose the cuda namespace.
if backend in ("cuda", "rocm") and torch.cuda.is_available():
=======
if backend == "cuda" and torch.cuda.is_available():
>>>>>>> refs/remotes/azuze/main
>>>>>>> 965a3d97c6dae38fa25174559b1ea0f3050788f9
return torch.device("cuda")
if backend == "directml":
try:
@@ -46,16 +37,8 @@ def get_device():
return torch_directml.device()
except ImportError:
pass
<<<<<<< HEAD
if backend == "mps" and getattr(torch.backends, "mps", None) and torch.backends.mps.is_available():
return torch.device("mps")
=======
<<<<<<< HEAD
if backend == "mps" and getattr(torch.backends, "mps", None) and torch.backends.mps.is_available():
return torch.device("mps")
=======
>>>>>>> refs/remotes/azuze/main
>>>>>>> 965a3d97c6dae38fa25174559b1ea0f3050788f9
return torch.device("cpu")
@@ -64,12 +47,5 @@ def torch_dtype():
backend = hardware_info()["backend"]
if backend == "cpu":
return torch.float32
<<<<<<< HEAD
# MPS supports fp16 for diffusers; bf16 has gaps. Stick with fp16.
=======
<<<<<<< HEAD
# MPS supports fp16 for diffusers; bf16 has gaps. Stick with fp16.
=======
>>>>>>> refs/remotes/azuze/main
>>>>>>> 965a3d97c6dae38fa25174559b1ea0f3050788f9
return torch.float16
-89
View File
@@ -1,32 +1,14 @@
"""GPU and VRAM detection. Returns vendor + tier used to pick torch wheel and default models."""
from __future__ import annotations
<<<<<<< HEAD
=======
<<<<<<< HEAD
=======
import ctypes
>>>>>>> refs/remotes/azuze/main
>>>>>>> 965a3d97c6dae38fa25174559b1ea0f3050788f9
import platform
import subprocess
from dataclasses import dataclass
from typing import Literal
<<<<<<< HEAD
Vendor = Literal["nvidia", "amd", "intel", "apple", "cpu"]
Backend = Literal["cuda", "rocm", "directml", "mps", "cpu"]
SUPPORTED_BACKENDS: tuple[str, ...] = ("auto", "cuda", "rocm", "directml", "mps", "cpu")
=======
<<<<<<< HEAD
Vendor = Literal["nvidia", "amd", "intel", "apple", "cpu"]
Backend = Literal["cuda", "rocm", "directml", "mps", "cpu"]
SUPPORTED_BACKENDS: tuple[str, ...] = ("auto", "cuda", "rocm", "directml", "mps", "cpu")
=======
Vendor = Literal["nvidia", "amd", "intel", "cpu"]
Backend = Literal["cuda", "directml", "cpu"]
>>>>>>> refs/remotes/azuze/main
>>>>>>> 965a3d97c6dae38fa25174559b1ea0f3050788f9
@dataclass
@@ -38,10 +20,6 @@ class HardwareInfo:
tier: Literal["cpu", "low", "mid", "high", "ultra"]
<<<<<<< HEAD
=======
<<<<<<< HEAD
>>>>>>> 965a3d97c6dae38fa25174559b1ea0f3050788f9
def _detect_mac_gpu() -> tuple[str, float] | None:
"""Apple Silicon: report chip name + unified memory (proxy for VRAM).
Intel Mac: returns None (no GPU acceleration path)."""
@@ -66,11 +44,6 @@ def _detect_mac_gpu() -> tuple[str, float] | None:
return name, vram_gb
<<<<<<< HEAD
=======
=======
>>>>>>> refs/remotes/azuze/main
>>>>>>> 965a3d97c6dae38fa25174559b1ea0f3050788f9
def _run(cmd: list[str]) -> str:
try:
out = subprocess.run(cmd, capture_output=True, text=True, timeout=10, check=False)
@@ -127,10 +100,6 @@ def _detect_dxgi() -> list[tuple[str, float, str]]:
return results
<<<<<<< HEAD
=======
<<<<<<< HEAD
>>>>>>> 965a3d97c6dae38fa25174559b1ea0f3050788f9
def _detect_linux_gpus() -> list[tuple[str, float, str]]:
"""Enumerate Linux GPUs. Returns list of (name, vram_gb, vendor_hint)."""
if platform.system() != "Linux":
@@ -191,11 +160,6 @@ def _vendor_from_backend(backend: str) -> Vendor:
}.get(backend, "cpu") # type: ignore[return-value]
<<<<<<< HEAD
=======
=======
>>>>>>> refs/remotes/azuze/main
>>>>>>> 965a3d97c6dae38fa25174559b1ea0f3050788f9
def _vram_tier(vram_gb: float) -> Literal["cpu", "low", "mid", "high", "ultra"]:
if vram_gb < 1:
return "cpu"
@@ -208,10 +172,6 @@ def _vram_tier(vram_gb: float) -> Literal["cpu", "low", "mid", "high", "ultra"]:
return "ultra"
<<<<<<< HEAD
=======
<<<<<<< HEAD
>>>>>>> 965a3d97c6dae38fa25174559b1ea0f3050788f9
def detect(force_backend: str | None = None, force_vendor: str | None = None) -> HardwareInfo:
"""Auto-detect hardware. If force_backend is set (cuda/rocm/directml/cpu), skip detection
for that decision but still try to discover device name + VRAM for tier sizing."""
@@ -267,21 +227,10 @@ def detect(force_backend: str | None = None, force_vendor: str | None = None) ->
name, vram = mac_gpu
return HardwareInfo("apple", "mps", name, vram, _vram_tier(vram))
<<<<<<< HEAD
=======
=======
def detect() -> HardwareInfo:
nv = _detect_nvidia()
>>>>>>> refs/remotes/azuze/main
>>>>>>> 965a3d97c6dae38fa25174559b1ea0f3050788f9
if nv:
name, vram = nv
return HardwareInfo("nvidia", "cuda", name, vram, _vram_tier(vram))
<<<<<<< HEAD
=======
<<<<<<< HEAD
>>>>>>> 965a3d97c6dae38fa25174559b1ea0f3050788f9
if is_windows and win_adapters:
win_adapters.sort(key=lambda a: a[1], reverse=True)
name, vram, hint = win_adapters[0]
@@ -305,26 +254,6 @@ def detect() -> HardwareInfo:
# No good Intel-on-Linux torch path here; default to CPU.
return HardwareInfo("cpu", "cpu", f"Intel GPU (no backend) — {name}", 0.0, "cpu")
<<<<<<< HEAD
=======
=======
adapters = _detect_dxgi()
# Prefer discrete (highest VRAM) non-basic adapter
adapters = [a for a in adapters if "basic" not in a[0].lower() and "microsoft" not in a[0].lower()]
if adapters:
adapters.sort(key=lambda a: a[1], reverse=True)
name, vram, hint = adapters[0]
# AdapterRAM is unreliable for >4GB cards. If exactly 4GB and modern AMD/Intel card name, bump.
if vram <= 4.1 and any(k in name.lower() for k in ("rx 6", "rx 7", "arc a", "arc b")):
vram = 8.0 # conservative guess
if hint in ("amd", "intel"):
return HardwareInfo(hint, "directml", name, vram, _vram_tier(vram))
if hint == "nvidia":
# nvidia-smi missing but card is nvidia: drivers may be broken, fall through to directml
return HardwareInfo("nvidia", "directml", name, vram, _vram_tier(vram))
>>>>>>> refs/remotes/azuze/main
>>>>>>> 965a3d97c6dae38fa25174559b1ea0f3050788f9
return HardwareInfo("cpu", "cpu", platform.processor() or "CPU", 0.0, "cpu")
@@ -348,10 +277,6 @@ def torch_install_args(info: HardwareInfo) -> list[str]:
"--index-url",
"https://download.pytorch.org/whl/cu124",
]
<<<<<<< HEAD
=======
<<<<<<< HEAD
>>>>>>> 965a3d97c6dae38fa25174559b1ea0f3050788f9
if info.backend == "rocm":
# ROCm wheels are Linux-only. Index pinned to a stable ROCm release line.
return [
@@ -360,11 +285,6 @@ def torch_install_args(info: HardwareInfo) -> list[str]:
"--index-url",
"https://download.pytorch.org/whl/rocm6.1",
]
<<<<<<< HEAD
=======
=======
>>>>>>> refs/remotes/azuze/main
>>>>>>> 965a3d97c6dae38fa25174559b1ea0f3050788f9
if info.backend == "directml":
# torch-directml currently pins to torch 2.4.x. Match it.
return [
@@ -372,21 +292,12 @@ def torch_install_args(info: HardwareInfo) -> list[str]:
"torchvision>=0.19,<0.20",
"torch-directml>=0.2.5",
]
<<<<<<< HEAD
=======
<<<<<<< HEAD
>>>>>>> 965a3d97c6dae38fa25174559b1ea0f3050788f9
if info.backend == "mps":
# Default PyPI torch wheel ships MPS support on macOS arm64. No custom index.
return ["torch", "torchvision"]
# CPU. macOS uses default PyPI wheels (no /whl/cpu index for darwin).
if platform.system() == "Darwin":
return ["torch", "torchvision"]
<<<<<<< HEAD
=======
=======
>>>>>>> refs/remotes/azuze/main
>>>>>>> 965a3d97c6dae38fa25174559b1ea0f3050788f9
return [
"torch",
"torchvision",
-20
View File
@@ -30,20 +30,9 @@ def apply_memory_strategy(pipe) -> None:
except Exception:
pass
<<<<<<< HEAD
if backend in ("cuda", "rocm"):
# ROCm builds expose the cuda API, so accelerate offload hooks work the same way.
# Offload only if VRAM tight.
=======
<<<<<<< HEAD
if backend in ("cuda", "rocm"):
# ROCm builds expose the cuda API, so accelerate offload hooks work the same way.
# Offload only if VRAM tight.
=======
if backend == "cuda":
# Offload only if VRAM tight. cpu_offload is CUDA-only via accelerate hooks.
>>>>>>> refs/remotes/azuze/main
>>>>>>> 965a3d97c6dae38fa25174559b1ea0f3050788f9
if vram < 10:
try:
pipe.enable_sequential_cpu_offload()
@@ -68,10 +57,6 @@ def apply_memory_strategy(pipe) -> None:
pipe.to("cpu")
return
<<<<<<< HEAD
=======
<<<<<<< HEAD
>>>>>>> 965a3d97c6dae38fa25174559b1ea0f3050788f9
if backend == "mps":
# Apple Silicon shares unified memory with CPU. accelerate's sequential offload
# has spotty MPS support; rely on slicing/tiling already enabled above.
@@ -94,10 +79,5 @@ def apply_memory_strategy(pipe) -> None:
pipe.to("cpu")
return
<<<<<<< HEAD
=======
=======
>>>>>>> refs/remotes/azuze/main
>>>>>>> 965a3d97c6dae38fa25174559b1ea0f3050788f9
# CPU
pipe.to("cpu")