voice-changer/server/voice_changer/RVC/embedder/EmbedderManager.py

51 lines
2.1 KiB
Python
Raw Normal View History

2023-05-02 06:11:00 +03:00
from torch import device
from const import EnumEmbedderTypes
from voice_changer.RVC.embedder.Embedder import Embedder
from voice_changer.RVC.embedder.FairseqContentvec import FairseqContentvec
from voice_changer.RVC.embedder.FairseqHubert import FairseqHubert
from voice_changer.RVC.embedder.FairseqHubertJp import FairseqHubertJp
class EmbedderManager:
    """Keep one shared embedder on the class and hand it out on request.

    The embedder is stored in the class attribute ``currentEmbedder`` and is
    replaced only when none has been loaded yet or when the stored one
    reports, via ``matchCondition``, that it does not fit the request.
    """

    # Embedder produced by the most recent load; None until the first load.
    currentEmbedder: Embedder | None = None

    @classmethod
    def getEmbedder(
        cls, embederType: EnumEmbedderTypes, file: str, isHalf: bool, dev: device
    ) -> Embedder:
        """Return the shared embedder, loading a new one when required.

        Args:
            embederType: Requested embedder type, forwarded to
                ``matchCondition`` and ``loadEmbedder``.
            file: Model file, forwarded to ``matchCondition`` and
                ``loadEmbedder``.
            isHalf: Forwarded to ``setHalf`` on reuse or to ``loadEmbedder``
                on reload (presumably a half-precision switch; confirm
                against ``Embedder``).
            dev: Torch device, forwarded to ``setDevice`` on reuse or to
                ``loadEmbedder`` on reload.

        Returns:
            The embedder stored in ``currentEmbedder`` after this call.
        """
        cached = cls.currentEmbedder

        # Reuse path: only an explicit ``False`` from matchCondition forces a
        # reload; any other result keeps the stored embedder and just
        # re-applies the requested device and half setting.
        if (
            cached is not None
            and cached.matchCondition(embederType, file) is not False
        ):
            cached.setDevice(dev)
            cached.setHalf(isHalf)
            return cls.currentEmbedder

        # Reload path: report why a new embedder is being built, then build it.
        if cached is None:
            print("[Voice Changer] generate new embedder. (no embedder)")
        else:
            print("[Voice Changer] generate new embedder. (not match)")
        cls.currentEmbedder = cls.loadEmbedder(embederType, file, isHalf, dev)
        return cls.currentEmbedder

    @classmethod
    def loadEmbedder(
        cls, embederType: EnumEmbedderTypes, file: str, isHalf: bool, dev: device
    ) -> Embedder:
        """Instantiate the embedder class for ``embederType`` and load ``file``.

        ``embederType`` may be an ``EnumEmbedderTypes`` member or that
        member's ``.value``. A type that matches none of the known members
        falls back to ``FairseqHubert``.

        Args:
            embederType: Requested embedder type (enum member or its value).
            file: Model file passed to ``loadModel``.
            isHalf: Passed to ``loadModel`` as its third argument.
            dev: Torch device passed to ``loadModel`` as its second argument.

        Returns:
            Whatever ``loadModel(file, dev, isHalf)`` returns on the chosen
            embedder instance.
        """
        # Checked in order; the first member equal to embederType (directly
        # or through its value) decides which implementation is used.
        known_types = (
            (EnumEmbedderTypes.hubert, FairseqHubert),
            (EnumEmbedderTypes.hubert_jp, FairseqHubertJp),
            (EnumEmbedderTypes.contentvec, FairseqContentvec),
        )

        chosen = FairseqHubert  # default when nothing matches
        for member, implementation in known_types:
            if embederType == member or embederType == member.value:
                chosen = implementation
                break

        return chosen().loadModel(file, dev, isHalf)