from torch import device

from const import EnumEmbedderTypes
from voice_changer.RVC.embedder.Embedder import Embedder
from voice_changer.RVC.embedder.FairseqContentvec import FairseqContentvec
from voice_changer.RVC.embedder.FairseqHubert import FairseqHubert
from voice_changer.RVC.embedder.FairseqHubertJp import FairseqHubertJp
from voice_changer.utils.VoiceChangerParams import VoiceChangerParams


class EmbedderManager:
    """Class-level holder for a single, reusable embedder instance.

    All state lives on the class itself (no instances are created). The most
    recently loaded embedder is kept in ``currentEmbedder`` and handed back
    again as long as it matches the requested embedder type.
    """

    # The embedder returned by the last getEmbedder() call, or None before
    # the first load.
    currentEmbedder: Embedder | None = None
    # Annotated only (no value): initialize() must run before loadEmbedder(),
    # otherwise reading cls.params raises AttributeError.
    params: VoiceChangerParams

    @classmethod
    def initialize(cls, params: VoiceChangerParams):
        """Store the parameters whose model paths loadEmbedder() reads."""
        cls.params = params

    @classmethod
    def getEmbedder(
        cls, embederType: EnumEmbedderTypes, isHalf: bool, dev: device
    ) -> Embedder:
        """Return an embedder of ``embederType`` on ``dev``, reusing the cached one if possible.

        A new embedder is loaded when none is cached yet, or when the cached
        one's ``matchCondition(embederType)`` returns ``False``. Otherwise the
        cached embedder is moved to ``dev`` and switched to the requested
        precision via ``setDevice`` / ``setHalf``.

        Args:
            embederType: requested embedder kind (enum member or its value).
            isHalf: whether half precision is requested.
            dev: target torch device.

        Returns:
            The cached (possibly freshly loaded) embedder.
        """
        if cls.currentEmbedder is None:
            print("[Voice Changer] generate new embedder. (no embedder)")
            cls.currentEmbedder = cls.loadEmbedder(embederType, isHalf, dev)
        elif cls.currentEmbedder.matchCondition(embederType) is False:
            # `is False` (not `not`) is kept on purpose: only an explicit
            # False return triggers a reload.
            print("[Voice Changer] generate new embedder. (not match)")
            cls.currentEmbedder = cls.loadEmbedder(embederType, isHalf, dev)
        else:
            cls.currentEmbedder.setDevice(dev)
            cls.currentEmbedder.setHalf(isHalf)
        return cls.currentEmbedder

    @classmethod
    def loadEmbedder(
        cls, embederType: EnumEmbedderTypes, isHalf: bool, dev: device
    ) -> Embedder:
        """Load a new embedder of the requested type.

        ``embederType`` may be either an ``EnumEmbedderTypes`` member or its
        ``.value``. Unrecognized types fall back to HuBERT loaded from
        ``params.hubert_base``.

        Args:
            embederType: requested embedder kind (enum member or its value).
            isHalf: whether to load in half precision.
            dev: target torch device.

        Returns:
            The result of the selected embedder class's ``loadModel``.
        """
        if embederType in (EnumEmbedderTypes.hubert, EnumEmbedderTypes.hubert.value):
            file = cls.params.hubert_base
            return FairseqHubert().loadModel(file, dev, isHalf)
        if embederType in (
            EnumEmbedderTypes.hubert_jp,
            EnumEmbedderTypes.hubert_jp.value,
        ):
            file = cls.params.hubert_base_jp
            return FairseqHubertJp().loadModel(file, dev, isHalf)
        if embederType in (
            EnumEmbedderTypes.contentvec,
            EnumEmbedderTypes.contentvec.value,
        ):
            # NOTE(review): contentvec is loaded from the hubert_base path,
            # as in the original code -- confirm this is intended.
            file = cls.params.hubert_base
            return FairseqContentvec().loadModel(file, dev, isHalf)

        # Fallback: previously `file` was unbound here, so this branch
        # raised UnboundLocalError instead of loading HuBERT.
        file = cls.params.hubert_base
        return FairseqHubert().loadModel(file, dev, isHalf)