KawAI/backends/device.py

"""Torch device selection. Wraps cuda / directml / cpu behind one helper."""
from __future__ import annotations
import json
from functools import lru_cache
from pathlib import Path
ROOT = Path(__file__).parent.parent
HARDWARE_CACHE = ROOT / "config.local.json"


@lru_cache(maxsize=1)
def hardware_info() -> dict:
    if HARDWARE_CACHE.exists():
        return json.loads(HARDWARE_CACHE.read_text())
    from . import hardware
    info = hardware.detect()
    return {
        "vendor": info.vendor,
        "backend": info.backend,
        "device_name": info.device_name,
        "vram_gb": info.vram_gb,
        "tier": info.tier,
    }
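

# Illustrative shape of config.local.json (example values are hypothetical;
# the keys mirror the dict returned by hardware_info() above):
#
#     {"vendor": "nvidia", "backend": "cuda",
#      "device_name": "GeForce RTX 3080", "vram_gb": 10, "tier": "high"}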


@lru_cache(maxsize=1)
def get_device():
    import torch
    backend = hardware_info()["backend"]
    # ROCm builds of torch expose the cuda namespace.
    if backend in ("cuda", "rocm") and torch.cuda.is_available():
        return torch.device("cuda")
    if backend == "directml":
        try:
            import torch_directml
            return torch_directml.device()
        except ImportError:
            pass
    if backend == "mps" and getattr(torch.backends, "mps", None) and torch.backends.mps.is_available():
        return torch.device("mps")
    # No usable accelerator; fall back to CPU.
    return torch.device("cpu")


def torch_dtype():
    import torch
    backend = hardware_info()["backend"]
    if backend == "cpu":
        return torch.float32
    # MPS supports fp16 for diffusers; bf16 has gaps. Stick with fp16.
    return torch.float16
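

if __name__ == "__main__":
    # Quick smoke test (illustrative addition, not part of the original module):
    # report what this helper picks on the current machine.
    info = hardware_info()
    print(f"backend: {info['backend']}")
    print(f"device:  {get_device()}")
    print(f"dtype:   {torch_dtype()}")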