voice-changer/server/voice_changer/RVC/deviceManager/DeviceManager.py

84 lines
2.6 KiB
Python
Raw Normal View History

2023-05-02 19:11:03 +03:00
import torch
2023-05-29 11:34:35 +03:00
import onnxruntime
2023-05-02 19:11:03 +03:00
class DeviceManager(object):
    """Process-wide helper that picks torch devices and ONNX Runtime providers.

    Obtain the shared instance via ``DeviceManager.get_instance()``.
    """

    _instance = None
    # When True, halfPrecisionAvailable() always reports False.
    forceTensor: bool = False

    @classmethod
    def get_instance(cls):
        """Return the shared instance, creating it on first use."""
        if cls._instance is None:
            cls._instance = cls()
        return cls._instance

    def __init__(self):
        # Number of CUDA devices visible to torch when the manager is created.
        self.gpu_num = torch.cuda.device_count()
        # Older torch builds have no ``torch.backends.mps``, so probe with getattr first.
        self.mps_enabled: bool = (
            getattr(torch.backends, "mps", None) is not None
            and torch.backends.mps.is_available()
        )

    def getDevice(self, id: int) -> torch.device:
        """Return the torch device to use for GPU index ``id``.

        A negative ``id``, or a machine with neither CUDA nor MPS, yields the CPU.
        Otherwise MPS is preferred when available, else ``cuda:<id>``.
        The CUDA index is not range-checked here.
        """
        if id < 0 or (self.gpu_num == 0 and self.mps_enabled is False):
            dev = torch.device("cpu")
        elif self.mps_enabled:
            dev = torch.device("mps")
        else:
            dev = torch.device("cuda", index=id)
        return dev

    def getOnnxExecutionProvider(self, gpu: int):
        """Return ``(providers, provider_options)`` lists for an ONNX Runtime session.

        Priority: CUDA (when requested, available in onnxruntime and a CUDA
        device exists), then DirectML (when requested and available), then CPU.
        """
        availableProviders = onnxruntime.get_available_providers()
        if gpu >= 0 and "CUDAExecutionProvider" in availableProviders and self.gpu_num > 0:
            return ["CUDAExecutionProvider"], [{"device_id": gpu}]
        elif gpu >= 0 and "DmlExecutionProvider" in availableProviders:
            return ["DmlExecutionProvider"], [{}]
        else:
            # NOTE(review): these keys look like SessionOptions attributes rather
            # than CPUExecutionProvider options; confirm they actually take effect.
            return ["CPUExecutionProvider"], [
                {
                    "intra_op_num_threads": 8,
                    "execution_mode": onnxruntime.ExecutionMode.ORT_PARALLEL,
                    "inter_op_num_threads": 8,
                }
            ]

    def setForceTensor(self, forceTensor: bool) -> None:
        """Set the flag that disables half precision in halfPrecisionAvailable()."""
        self.forceTensor = forceTensor

    def halfPrecisionAvailable(self, id: int) -> bool:
        """Return True if fp16 inference should be used on CUDA device ``id``.

        Returns False in these cases:
        - no CUDA device exists, or ``id`` is out of range;
        - ``forceTensor`` is set;
        - the GPU name matches a known-problematic model;
        - the compute capability is below 7;
        - querying the device fails.
        """
        if self.gpu_num == 0:
            return False
        if id < 0 or id >= self.gpu_num:
            return False
        if self.forceTensor:
            return False

        try:
            gpuName = torch.cuda.get_device_name(id).upper()
            # Name-based denylist of GPUs excluded from half precision.
            if (
                ("16" in gpuName and "V100" not in gpuName)
                or "P40" in gpuName
                or "1070" in gpuName
                or "1080" in gpuName
            ):
                return False
            cap = torch.cuda.get_device_capability(id)
        except Exception as e:
            print(e)
            return False

        # Compute capability >= 7 is generally said to support half precision,
        # but there are exceptions (e.g. T500).
        if cap[0] < 7:
            return False

        return True

    def getDeviceMemory(self, id: int) -> int:
        """Return total memory in bytes of CUDA device ``id``, or 0 if it cannot be queried."""
        try:
            return torch.cuda.get_device_properties(id).total_memory
        except Exception:
            return 0