Fix bad merge
This commit is contained in:
@@ -28,17 +28,8 @@ def hardware_info() -> dict:
|
||||
def get_device():
|
||||
import torch
|
||||
backend = hardware_info()["backend"]
|
||||
<<<<<<< HEAD
|
||||
# ROCm builds of torch expose the cuda namespace.
|
||||
if backend in ("cuda", "rocm") and torch.cuda.is_available():
|
||||
=======
|
||||
<<<<<<< HEAD
|
||||
# ROCm builds of torch expose the cuda namespace.
|
||||
if backend in ("cuda", "rocm") and torch.cuda.is_available():
|
||||
=======
|
||||
if backend == "cuda" and torch.cuda.is_available():
|
||||
>>>>>>> refs/remotes/azuze/main
|
||||
>>>>>>> 965a3d97c6dae38fa25174559b1ea0f3050788f9
|
||||
return torch.device("cuda")
|
||||
if backend == "directml":
|
||||
try:
|
||||
@@ -46,16 +37,8 @@ def get_device():
|
||||
return torch_directml.device()
|
||||
except ImportError:
|
||||
pass
|
||||
<<<<<<< HEAD
|
||||
if backend == "mps" and getattr(torch.backends, "mps", None) and torch.backends.mps.is_available():
|
||||
return torch.device("mps")
|
||||
=======
|
||||
<<<<<<< HEAD
|
||||
if backend == "mps" and getattr(torch.backends, "mps", None) and torch.backends.mps.is_available():
|
||||
return torch.device("mps")
|
||||
=======
|
||||
>>>>>>> refs/remotes/azuze/main
|
||||
>>>>>>> 965a3d97c6dae38fa25174559b1ea0f3050788f9
|
||||
return torch.device("cpu")
|
||||
|
||||
|
||||
@@ -64,12 +47,5 @@ def torch_dtype():
|
||||
backend = hardware_info()["backend"]
|
||||
if backend == "cpu":
|
||||
return torch.float32
|
||||
<<<<<<< HEAD
|
||||
# MPS supports fp16 for diffusers; bf16 has gaps. Stick with fp16.
|
||||
=======
|
||||
<<<<<<< HEAD
|
||||
# MPS supports fp16 for diffusers; bf16 has gaps. Stick with fp16.
|
||||
=======
|
||||
>>>>>>> refs/remotes/azuze/main
|
||||
>>>>>>> 965a3d97c6dae38fa25174559b1ea0f3050788f9
|
||||
return torch.float16
|
||||
|
||||
Reference in New Issue
Block a user