bugfix: mps for rmvpe

This commit is contained in:
w-okada 2023-07-23 23:01:35 +09:00
parent 7ab7dba5c8
commit a8a392b20d
8 changed files with 24 additions and 34 deletions

View File

@ -18,10 +18,11 @@ class DeviceManager(object):
)
def getDevice(self, id: int):
if id < 0 or (self.gpu_num == 0 and self.mps_enabled is False):
dev = torch.device("cpu")
elif self.mps_enabled:
dev = torch.device("mps")
if id < 0 or self.gpu_num == 0:
if self.mps_enabled is False:
dev = torch.device("cpu")
else:
dev = torch.device("mps")
else:
dev = torch.device("cuda", index=id)
return dev
@ -51,6 +52,6 @@ class DeviceManager(object):
try:
return torch.cuda.get_device_properties(id).total_memory
# except Exception as e:
except:
except: # NOQA
# print(e)
return 0

View File

@ -3,21 +3,19 @@ import torch
import numpy as np
from const import PitchExtractorType
from voice_changer.DiffusionSVC.pitchExtractor.PitchExtractor import PitchExtractor
from voice_changer.RVC.deviceManager.DeviceManager import DeviceManager
from voice_changer.utils.VoiceChangerModel import AudioInOut
class CrepePitchExtractor(PitchExtractor):
def __init__(self):
def __init__(self, gpu: int):
super().__init__()
self.pitchExtractorType: PitchExtractorType = "crepe"
self.f0_min = 50
self.f0_max = 1100
self.uv_interp = True
if torch.cuda.is_available():
self.device = torch.device("cuda:" + str(torch.cuda.current_device()))
else:
self.device = torch.device("cpu")
self.device = DeviceManager.get_instance().getDevice(gpu)
def extract(self, audio: AudioInOut, sr: int, block_size: int, model_sr: int, pitch, f0_up_key, silence_front=0):
hop_size = block_size * sr / model_sr

View File

@ -33,7 +33,7 @@ class PitchExtractorManager(Protocol):
elif pitchExtractorType == "dio":
return DioPitchExtractor()
elif pitchExtractorType == "crepe":
return CrepePitchExtractor()
return CrepePitchExtractor(gpu)
elif pitchExtractorType == "crepe_tiny":
return CrepeOnnxPitchExtractor(pitchExtractorType, cls.params.crepe_onnx_tiny, gpu)
elif pitchExtractorType == "crepe_full":

View File

@ -5,6 +5,7 @@ from const import PitchExtractorType
from voice_changer.DiffusionSVC.pitchExtractor.PitchExtractor import PitchExtractor
from voice_changer.DiffusionSVC.pitchExtractor.rmvpe.rmvpe import RMVPE
from scipy.ndimage import zoom
from voice_changer.RVC.deviceManager.DeviceManager import DeviceManager
from voice_changer.utils.VoiceChangerModel import AudioInOut
@ -18,10 +19,7 @@ class RMVPEPitchExtractor(PitchExtractor):
self.f0_max = 1100
self.uv_interp = True
self.input_sr = -1
if torch.cuda.is_available() and gpu >= 0:
self.device = torch.device("cuda:" + str(torch.cuda.current_device()))
else:
self.device = torch.device("cpu")
self.device = DeviceManager.get_instance().getDevice(gpu)
self.rmvpe = RMVPE(model_path=file, is_half=False, device=self.device)

View File

@ -20,10 +20,11 @@ class DeviceManager(object):
)
def getDevice(self, id: int):
if id < 0 or (self.gpu_num == 0 and self.mps_enabled is False):
dev = torch.device("cpu")
elif self.mps_enabled:
dev = torch.device("mps")
if id < 0 or self.gpu_num == 0:
if self.mps_enabled is False:
dev = torch.device("cpu")
else:
dev = torch.device("mps")
else:
if id < self.gpu_num:
dev = torch.device("cuda", index=id)

View File

@ -1,23 +1,19 @@
import torchcrepe
import torch
import numpy as np
from const import PitchExtractorType
from voice_changer.RVC.deviceManager.DeviceManager import DeviceManager
from voice_changer.RVC.pitchExtractor.PitchExtractor import PitchExtractor
class CrepePitchExtractor(PitchExtractor):
def __init__(self):
def __init__(self, gpu: int):
super().__init__()
self.pitchExtractorType: PitchExtractorType = "crepe"
if torch.cuda.is_available():
self.device = torch.device("cuda:" + str(torch.cuda.current_device()))
else:
self.device = torch.device("cpu")
self.device = DeviceManager.get_instance().getDevice(gpu)
def extract(self, audio, pitchf, f0_up_key, sr, window, silence_front=0):
n_frames = int(len(audio) // window) + 1
start_frame = int(silence_front * sr / window)
real_silence_front = start_frame * window / sr

View File

@ -33,7 +33,7 @@ class PitchExtractorManager(Protocol):
elif pitchExtractorType == "dio":
return DioPitchExtractor()
elif pitchExtractorType == "crepe":
return CrepePitchExtractor()
return CrepePitchExtractor(gpu)
elif pitchExtractorType == "crepe_tiny":
return CrepeOnnxPitchExtractor(pitchExtractorType, cls.params.crepe_onnx_tiny, gpu)
elif pitchExtractorType == "crepe_full":

View File

@ -1,8 +1,8 @@
import torch
import numpy as np
from const import PitchExtractorType
from voice_changer.DiffusionSVC.pitchExtractor.PitchExtractor import PitchExtractor
from voice_changer.DiffusionSVC.pitchExtractor.rmvpe.rmvpe import RMVPE
from voice_changer.RVC.deviceManager.DeviceManager import DeviceManager
class RMVPEPitchExtractor(PitchExtractor):
@ -17,11 +17,7 @@ class RMVPEPitchExtractor(PitchExtractor):
self.uv_interp = True
self.input_sr = -1
if torch.cuda.is_available() and gpu >= 0:
self.device = torch.device("cuda:" + str(torch.cuda.current_device()))
else:
self.device = torch.device("cpu")
self.device = DeviceManager.get_instance().getDevice(gpu)
self.rmvpe = RMVPE(model_path=file, is_half=False, device=self.device)
def extract(self, audio, pitchf, f0_up_key, sr, window, silence_front=0):