Merge branch 'main' of https://git.azuze.fr/kawa/KawAI
This commit is contained in:
@@ -28,8 +28,17 @@ def hardware_info() -> dict:
|
||||
def get_device():
|
||||
import torch
|
||||
backend = hardware_info()["backend"]
|
||||
<<<<<<< HEAD
|
||||
# ROCm builds of torch expose the cuda namespace.
|
||||
if backend in ("cuda", "rocm") and torch.cuda.is_available():
|
||||
=======
|
||||
<<<<<<< HEAD
|
||||
# ROCm builds of torch expose the cuda namespace.
|
||||
if backend in ("cuda", "rocm") and torch.cuda.is_available():
|
||||
=======
|
||||
if backend == "cuda" and torch.cuda.is_available():
|
||||
>>>>>>> refs/remotes/azuze/main
|
||||
>>>>>>> 965a3d97c6dae38fa25174559b1ea0f3050788f9
|
||||
return torch.device("cuda")
|
||||
if backend == "directml":
|
||||
try:
|
||||
@@ -37,8 +46,16 @@ def get_device():
|
||||
return torch_directml.device()
|
||||
except ImportError:
|
||||
pass
|
||||
<<<<<<< HEAD
|
||||
if backend == "mps" and getattr(torch.backends, "mps", None) and torch.backends.mps.is_available():
|
||||
return torch.device("mps")
|
||||
=======
|
||||
<<<<<<< HEAD
|
||||
if backend == "mps" and getattr(torch.backends, "mps", None) and torch.backends.mps.is_available():
|
||||
return torch.device("mps")
|
||||
=======
|
||||
>>>>>>> refs/remotes/azuze/main
|
||||
>>>>>>> 965a3d97c6dae38fa25174559b1ea0f3050788f9
|
||||
return torch.device("cpu")
|
||||
|
||||
|
||||
@@ -47,5 +64,12 @@ def torch_dtype():
|
||||
backend = hardware_info()["backend"]
|
||||
if backend == "cpu":
|
||||
return torch.float32
|
||||
<<<<<<< HEAD
|
||||
# MPS supports fp16 for diffusers; bf16 has gaps. Stick with fp16.
|
||||
=======
|
||||
<<<<<<< HEAD
|
||||
# MPS supports fp16 for diffusers; bf16 has gaps. Stick with fp16.
|
||||
=======
|
||||
>>>>>>> refs/remotes/azuze/main
|
||||
>>>>>>> 965a3d97c6dae38fa25174559b1ea0f3050788f9
|
||||
return torch.float16
|
||||
|
||||
Reference in New Issue
Block a user