2026-05-04 10:03:38 +02:00
5 changed files with 236 additions and 0 deletions
@@ -28,8 +28,17 @@ def hardware_info() -> dict:
def get_device():
    import torch
    backend = hardware_info()["backend"]
    # ROCm builds of torch expose the cuda namespace.
    if backend in ("cuda", "rocm") and torch.cuda.is_available():
        return torch.device("cuda")
    if backend == "directml":
        try:
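Side note on the ROCm comment above: ROCm wheels of PyTorch reuse the `cuda` namespace, so `torch.cuda.is_available()` answers true on AMD GPUs as well. A minimal sketch of how a `hardware_info()` implementation could tell the two builds apart (illustrative only, not part of this diff):

```python
import torch

def is_rocm_build() -> bool:
    # torch.version.hip is a version string on ROCm wheels and None on CUDA wheels.
    return getattr(torch.version, "hip", None) is not None
```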
@@ -37,8 +46,16 @@ def get_device():
            return torch_directml.device()
        except ImportError:
            pass
    if backend == "mps" and getattr(torch.backends, "mps", None) and torch.backends.mps.is_available():
        return torch.device("mps")
    return torch.device("cpu")
@@ -47,5 +64,12 @@ def torch_dtype():
    backend = hardware_info()["backend"]
    if backend == "cpu":
        return torch.float32
    # MPS supports fp16 for diffusers; bf16 has gaps. Stick with fp16.
    return torch.float16
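The comment above ties the dtype choice to diffusers, so downstream use presumably looks something like this (a sketch; the pipeline class and checkpoint id are examples, not taken from this commit):

```python
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",  # example checkpoint, not from this repo
    torch_dtype=torch_dtype(),         # float32 on cpu, float16 elsewhere
).to(get_device())
```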