Files
KawAI/backends/device.py
T
2026-05-04 09:38:56 +02:00

48 lines
1.2 KiB
Python

"""Torch device selection. Wraps cuda / directml / cpu behind one helper."""
from __future__ import annotations
import json
from functools import lru_cache
from pathlib import Path
# Package root: this file lives at <root>/backends/device.py, so two .parent
# hops land on the package directory.
ROOT = Path(__file__).parent.parent
# Optional pre-detected hardware profile; when this file exists, hardware_info()
# reads it instead of running live detection.
HARDWARE_CACHE = ROOT / "config.local.json"
@lru_cache(maxsize=1)
def hardware_info() -> dict:
    """Return the hardware profile as a dict, memoized for the process.

    Prefers the on-disk cache (``config.local.json``); falls back to live
    detection via the sibling ``hardware`` module on a cache miss.

    Returns:
        dict with keys ``vendor``, ``backend``, ``device_name``, ``vram_gb``,
        ``tier`` (cache-file contents are returned as-is when present).
    """
    if HARDWARE_CACHE.exists():
        # JSON is UTF-8 by spec; be explicit so the platform locale default
        # (e.g. cp1252 on Windows) cannot mangle non-ASCII device names.
        return json.loads(HARDWARE_CACHE.read_text(encoding="utf-8"))
    # Imported lazily: detection may probe GPUs and is only needed on a miss.
    from . import hardware
    info = hardware.detect()
    # NOTE(review): the freshly detected result is NOT written back to
    # HARDWARE_CACHE here — presumably another component persists it; confirm.
    return {
        "vendor": info.vendor,
        "backend": info.backend,
        "device_name": info.device_name,
        "vram_gb": info.vram_gb,
        "tier": info.tier,
    }
@lru_cache(maxsize=1)
def get_device():
    """Return the torch device matching the detected backend (memoized).

    Resolution order: CUDA when detected and available, DirectML when its
    package is installed, otherwise CPU as the universal fallback.
    """
    import torch

    backend = hardware_info()["backend"]
    if backend == "cuda":
        if torch.cuda.is_available():
            return torch.device("cuda")
    elif backend == "directml":
        try:
            import torch_directml
        except ImportError:
            pass  # package missing — fall through to CPU
        else:
            return torch_directml.device()
    # Fallback when no accelerated backend is usable.
    return torch.device("cpu")
def torch_dtype():
    """Pick the compute dtype: float32 on CPU, float16 on any GPU backend."""
    import torch

    on_cpu = hardware_info()["backend"] == "cpu"
    # Half precision is only worthwhile (and well-supported) on accelerators.
    return torch.float32 if on_cpu else torch.float16