strict cuda check

This commit is contained in:
wataru 2023-06-05 02:02:48 +09:00
parent b3b8a347e1
commit d6731f4adf

View File

@ -30,7 +30,8 @@ class DeviceManager(object):
def getOnnxExecutionProvider(self, gpu: int):
availableProviders = onnxruntime.get_available_providers()
if gpu >= 0 and "CUDAExecutionProvider" in availableProviders:
devNum = torch.cuda.device_count()
if gpu >= 0 and "CUDAExecutionProvider" in availableProviders and devNum > 0:
return ["CUDAExecutionProvider"], [{"device_id": gpu}]
elif gpu >= 0 and "DmlExecutionProvider" in availableProviders:
return ["DmlExecutionProvider"], [{}]