2023-05-02 06:11:00 +03:00
|
|
|
from torch import device
|
|
|
|
|
2023-07-06 21:56:45 +03:00
|
|
|
from const import EmbedderType
|
2023-05-02 06:11:00 +03:00
|
|
|
from voice_changer.RVC.embedder.Embedder import Embedder
|
|
|
|
from voice_changer.RVC.embedder.FairseqContentvec import FairseqContentvec
|
|
|
|
from voice_changer.RVC.embedder.FairseqHubert import FairseqHubert
|
|
|
|
from voice_changer.RVC.embedder.FairseqHubertJp import FairseqHubertJp
|
2023-07-07 21:18:23 +03:00
|
|
|
from voice_changer.RVC.embedder.OnnxContentvec import OnnxContentvec
|
2023-05-04 16:46:42 +03:00
|
|
|
from voice_changer.utils.VoiceChangerParams import VoiceChangerParams
|
2023-05-02 06:11:00 +03:00
|
|
|
|
|
|
|
|
|
|
|
class EmbedderManager:
    """Class-level holder that loads and caches the RVC content embedder.

    All state lives on the class itself (``currentEmbedder``, ``params``) and
    every method is a classmethod, so the class is used without instantiation.
    ``initialize`` must be called before ``getEmbedder``/``loadEmbedder``,
    because loading reads model file paths from ``cls.params``.
    """

    # Embedder returned by the most recent getEmbedder call; None before the first call.
    currentEmbedder: Embedder | None = None
    # Source of model file paths (hubert_base, hubert_base_jp, content_vec_500_onnx).
    params: VoiceChangerParams

    @classmethod
    def initialize(cls, params: VoiceChangerParams):
        """Store ``params``, which supplies the embedder model file paths."""
        cls.params = params

    @classmethod
    def getEmbedder(
        cls, embederType: EmbedderType, isHalf: bool, dev: device
    ) -> Embedder:
        """Load an embedder for ``embederType`` and remember it as current.

        A new embedder is loaded on every call, even when the cached one
        already matches ``embederType``. The branches below only select which
        diagnostic message is printed.

        Args:
            embederType: Which embedder to load (see ``loadEmbedder``).
            isHalf: Half-precision flag, forwarded to the Fairseq loaders.
            dev: Torch device to load the model onto.

        Returns:
            The newly loaded embedder (also stored in ``cls.currentEmbedder``).
        """
        if cls.currentEmbedder is None:
            print("[Voice Changer] generate new embedder. (no embedder)")
        elif cls.currentEmbedder.matchCondition(embederType) is False:
            print("[Voice Changer] generate new embedder. (not match)")
        else:
            # NOTE: reusing the cached instance would require switching its
            # device/precision in place (setDevice/setHalf). That path is not
            # used here, so even a matching embedder is reloaded.
            print("[Voice Changer] generate new embedder. (anyway)")
        cls.currentEmbedder = cls.loadEmbedder(embederType, isHalf, dev)
        return cls.currentEmbedder

    @classmethod
    def loadEmbedder(
        cls, embederType: EmbedderType, isHalf: bool, dev: device
    ) -> Embedder:
        """Construct and load a new embedder for ``embederType``.

        - ``"hubert_base"``: ONNX ContentVec, falling back to Fairseq HuBERT.
        - ``"hubert-base-japanese"``: Fairseq HuBERT-JP from ``hubert_base_jp``.
        - ``"contentvec"``: ONNX ContentVec, falling back to Fairseq ContentVec.
        - anything else: Fairseq HuBERT from ``hubert_base``.

        Args:
            embederType: Embedder identifier.
            isHalf: Half-precision flag for the Fairseq loaders (the ONNX
                loader is not given this flag).
            dev: Torch device to load the model onto.

        Returns:
            The loaded embedder.
        """
        if embederType == "hubert_base":
            return cls._loadOnnxContentvecOr(FairseqHubert, isHalf, dev)
        elif embederType == "hubert-base-japanese":
            file = cls.params.hubert_base_jp
            return FairseqHubertJp().loadModel(file, dev, isHalf)
        elif embederType == "contentvec":
            return cls._loadOnnxContentvecOr(FairseqContentvec, isHalf, dev)
        else:
            # Unrecognized type: default to the Fairseq HuBERT base model.
            # `file` must be bound on this path too; otherwise loadModel
            # would raise UnboundLocalError.
            file = cls.params.hubert_base
            return FairseqHubert().loadModel(file, dev, isHalf)

    @classmethod
    def _loadOnnxContentvecOr(
        cls, fallbackCls: type[Embedder], isHalf: bool, dev: device
    ) -> Embedder:
        """Load the ONNX ContentVec model, or ``fallbackCls`` from hubert_base on failure.

        Any exception from the ONNX path is printed and swallowed so that the
        Fairseq checkpoint can be used instead (deliberate best-effort).
        """
        try:
            file = cls.params.content_vec_500_onnx
            return OnnxContentvec().loadModel(file, dev)
        except Exception as e:
            print(e)
            file = cls.params.hubert_base
            return fallbackCls().loadModel(file, dev, isHalf)