"""Torch device selection. Wraps cuda / directml / cpu behind one helper."""

from __future__ import annotations

import json
from functools import lru_cache
from pathlib import Path

ROOT = Path(__file__).parent.parent
HARDWARE_CACHE = ROOT / "config.local.json"


@lru_cache(maxsize=1)
def hardware_info() -> dict:
    if HARDWARE_CACHE.exists():
        return json.loads(HARDWARE_CACHE.read_text())
    from . import hardware

    info = hardware.detect()
    return {
        "vendor": info.vendor,
        "backend": info.backend,
        "device_name": info.device_name,
        "vram_gb": info.vram_gb,
        "tier": info.tier,
    }
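
# Illustrative note (not in the original module): hardware_info() is memoized
# in-process via lru_cache, and if config.local.json exists it is read instead of
# running live detection. A pre-seeded cache file is expected to carry the same
# keys the function returns; the values below are placeholder examples only:
#
#   {"vendor": "nvidia", "backend": "cuda", "device_name": "Example GPU",
#    "vram_gb": 12, "tier": "mid"}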


@lru_cache(maxsize=1)
def get_device():
    import torch

    backend = hardware_info()["backend"]
    # ROCm builds of torch expose the cuda namespace.
    if backend in ("cuda", "rocm") and torch.cuda.is_available():
        return torch.device("cuda")
    if backend == "directml":
        # torch_directml is an optional dependency; fall through to CPU if missing.
        try:
            import torch_directml

            return torch_directml.device()
        except ImportError:
            pass
    if backend == "mps" and getattr(torch.backends, "mps", None) and torch.backends.mps.is_available():
        return torch.device("mps")
    return torch.device("cpu")


def torch_dtype():
    import torch

    backend = hardware_info()["backend"]
    if backend == "cpu":
        return torch.float32
    # MPS supports fp16 for diffusers; bf16 has gaps. Stick with fp16.
    return torch.float16
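
# Usage sketch (illustrative; the import path below is an assumption -- adjust it
# to wherever this module actually lives in the package):
#
#   from mypkg.device import get_device, torch_dtype
#
#   device = get_device()
#   dtype = torch_dtype()
#   # e.g. with diffusers, pass the dtype at load time and move the pipeline over:
#   #   pipe = DiffusionPipeline.from_pretrained(model_id, torch_dtype=dtype)
#   #   pipe.to(device)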