Initial commit

This commit is contained in:
2026-05-04 09:38:56 +02:00
commit 7be72a3650
16 changed files with 1525 additions and 0 deletions
+47
View File
@@ -0,0 +1,47 @@
"""Torch device selection. Wraps cuda / directml / cpu behind one helper."""
from __future__ import annotations
import json
from functools import lru_cache
from pathlib import Path
ROOT = Path(__file__).parent.parent
HARDWARE_CACHE = ROOT / "config.local.json"
@lru_cache(maxsize=1)
def hardware_info() -> dict:
    """Return hardware metadata as a dict (cached for the process lifetime).

    Prefers the on-disk cache at ``HARDWARE_CACHE`` (config.local.json);
    falls back to live detection via the sibling ``hardware`` module when
    the cache is absent, unreadable, or corrupt.

    Returns:
        dict with keys ``vendor``, ``backend``, ``device_name``,
        ``vram_gb``, ``tier`` (the cache file is assumed to use the same
        schema — not validated here).
    """
    if HARDWARE_CACHE.exists():
        try:
            # Explicit encoding: the platform default (e.g. cp1252 on
            # Windows, a likely target given directml support) can
            # mis-decode a UTF-8 JSON file.
            return json.loads(HARDWARE_CACHE.read_text(encoding="utf-8"))
        except (OSError, json.JSONDecodeError):
            # Corrupt or unreadable cache: fall through to live detection
            # instead of crashing every caller of this helper.
            pass
    from . import hardware  # local import keeps module import cheap
    info = hardware.detect()
    # NOTE(review): the detected result is not written back to
    # HARDWARE_CACHE — confirm whether persisting it is intended.
    return {
        "vendor": info.vendor,
        "backend": info.backend,
        "device_name": info.device_name,
        "vram_gb": info.vram_gb,
        "tier": info.tier,
    }
@lru_cache(maxsize=1)
def get_device():
    """Return the torch device matching the detected backend.

    Resolution order: CUDA (if the backend is ``cuda`` and torch reports
    it available), then DirectML (if ``torch_directml`` is importable),
    otherwise CPU.
    """
    import torch

    preferred = hardware_info()["backend"]

    if preferred == "cuda" and torch.cuda.is_available():
        return torch.device("cuda")

    if preferred == "directml":
        try:
            import torch_directml
        except ImportError:
            # DirectML requested but the package is missing: fall back
            # to CPU rather than failing.
            pass
        else:
            return torch_directml.device()

    return torch.device("cpu")
def torch_dtype():
    """Return the default float dtype: float32 on CPU, float16 otherwise.

    Half precision is used for every accelerated backend (cuda/directml).
    """
    import torch

    on_cpu = hardware_info()["backend"] == "cpu"
    return torch.float32 if on_cpu else torch.float16