Initial commit
This commit is contained in:
@@ -0,0 +1,51 @@
|
||||
"""Torch device selection. Wraps cuda / directml / cpu behind one helper."""
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from functools import lru_cache
|
||||
from pathlib import Path
|
||||
|
||||
# Project root: two directory levels up from this module file.
ROOT = Path(__file__).parent.parent
# On-disk JSON cache of hardware-detection results, read by hardware_info().
# NOTE(review): nothing in this module writes this file — presumably it is
# produced elsewhere (e.g. by the hardware detection step); confirm.
HARDWARE_CACHE = ROOT / "config.local.json"
|
||||
|
||||
|
||||
@lru_cache(maxsize=1)
def hardware_info() -> dict:
    """Return hardware details as a plain dict.

    Keys: ``vendor``, ``backend``, ``device_name``, ``vram_gb``, ``tier``.

    Prefers the on-disk cache at ``HARDWARE_CACHE`` when it exists;
    otherwise runs live detection via ``hardware.detect()``. The result is
    memoized for the lifetime of the process by ``lru_cache``.
    """
    if HARDWARE_CACHE.exists():
        # Pin UTF-8 so the read does not depend on the platform locale
        # default (JSON files are conventionally UTF-8).
        return json.loads(HARDWARE_CACHE.read_text(encoding="utf-8"))
    # Deferred import: detection may be slow or pull in heavy dependencies,
    # so only pay for it on a cache miss.
    from . import hardware

    info = hardware.detect()
    return {
        "vendor": info.vendor,
        "backend": info.backend,
        "device_name": info.device_name,
        "vram_gb": info.vram_gb,
        "tier": info.tier,
    }
|
||||
|
||||
|
||||
@lru_cache(maxsize=1)
def get_device():
    """Select the torch device for the detected backend.

    Resolution order: cuda (also covers ROCm torch builds, which expose
    the cuda namespace) -> directml -> mps -> cpu fallback. The chosen
    device is memoized for the lifetime of the process.
    """
    import torch

    backend = hardware_info()["backend"]

    # ROCm builds of torch expose the cuda namespace.
    cuda_like = backend in ("cuda", "rocm")
    if cuda_like and torch.cuda.is_available():
        return torch.device("cuda")

    if backend == "directml":
        try:
            import torch_directml

            return torch_directml.device()
        except ImportError:
            # Best effort: torch-directml is not installed, so keep going
            # and fall through to the remaining backends.
            pass

    mps = getattr(torch.backends, "mps", None)
    if backend == "mps" and mps and torch.backends.mps.is_available():
        return torch.device("mps")

    # Nothing accelerator-shaped matched; run on the CPU.
    return torch.device("cpu")
|
||||
|
||||
|
||||
def torch_dtype():
    """Return the preferred torch dtype for the active backend.

    CPU runs in float32; every accelerator backend gets float16.
    """
    import torch

    # MPS supports fp16 for diffusers; bf16 has gaps. Stick with fp16.
    backend = hardware_info()["backend"]
    return torch.float32 if backend == "cpu" else torch.float16
|
||||
Reference in New Issue
Block a user