diff --git a/server/voice_changer/RVC/RVC.py b/server/voice_changer/RVC/RVC.py index 606bde8d..91e52244 100644 --- a/server/voice_changer/RVC/RVC.py +++ b/server/voice_changer/RVC/RVC.py @@ -41,7 +41,7 @@ import torch import traceback import faiss -from const import UPLOAD_DIR, EnumEmbedderTypes +from const import UPLOAD_DIR from voice_changer.RVC.Pipeline import Pipeline @@ -73,6 +73,7 @@ class RVC: self.settings.f0Detector ) self.params = params + EmbedderManager.initialize(params) print("RVC initialization: ", params) def loadModel(self, props: LoadModelParams): @@ -106,20 +107,7 @@ class RVC: inferencerFilename = ( modelSlot.onnxModelFile if modelSlot.isONNX else modelSlot.pyTorchModelFile ) - # ファイル名特定(embedder) - if modelSlot.embedder == EnumEmbedderTypes.hubert: - emmbedderFilename = self.params.hubert_base - elif modelSlot.embedder == EnumEmbedderTypes.contentvec: - # emmbedderFilename = self.params.content_vec_500 - emmbedderFilename = self.params.hubert_base - elif modelSlot.embedder == EnumEmbedderTypes.hubert_jp: - emmbedderFilename = self.params.hubert_base_jp - else: - raise RuntimeError( - "[Voice Changer] Exception loading embedder failed. unknwon type:", - modelSlot.embedder, - ) - + # Inferencer 生成 try: inferencer = InferencerManager.getInferencer( @@ -136,7 +124,7 @@ class RVC: try: embedder = EmbedderManager.getEmbedder( modelSlot.embedder, - emmbedderFilename, + # emmbedderFilename, half, dev, ) diff --git a/server/voice_changer/RVC/embedder/Embedder.py b/server/voice_changer/RVC/embedder/Embedder.py index 728f1e7a..be8d1c49 100644 --- a/server/voice_changer/RVC/embedder/Embedder.py +++ b/server/voice_changer/RVC/embedder/Embedder.py @@ -45,7 +45,7 @@ class Embedder(Protocol): self.model = self.model.to(self.dev) return self - def matchCondition(self, embedderType: EnumEmbedderTypes, file: str) -> bool: + def matchCondition(self, embedderType: EnumEmbedderTypes) -> bool: # Check Type if self.embedderType != embedderType: print( @@ -55,14 +55,5 @@ class Embedder(Protocol): ) return False - # Check File Path - if self.file != file: - print( - "[Voice Changer] embeder file is not match", - self.file, - file, - ) - return False - else: return True diff --git a/server/voice_changer/RVC/embedder/EmbedderManager.py b/server/voice_changer/RVC/embedder/EmbedderManager.py index b6d6d100..dfcaa920 100644 --- a/server/voice_changer/RVC/embedder/EmbedderManager.py +++ b/server/voice_changer/RVC/embedder/EmbedderManager.py @@ -5,21 +5,27 @@ from voice_changer.RVC.embedder.Embedder import Embedder from voice_changer.RVC.embedder.FairseqContentvec import FairseqContentvec from voice_changer.RVC.embedder.FairseqHubert import FairseqHubert from voice_changer.RVC.embedder.FairseqHubertJp import FairseqHubertJp +from voice_changer.utils.VoiceChangerParams import VoiceChangerParams class EmbedderManager: currentEmbedder: Embedder | None = None + params: VoiceChangerParams + + @classmethod + def initialize(cls, params: VoiceChangerParams): + cls.params = params @classmethod def getEmbedder( - cls, embederType: EnumEmbedderTypes, file: str, isHalf: bool, dev: device + cls, embederType: EnumEmbedderTypes, isHalf: bool, dev: device ) -> Embedder: if cls.currentEmbedder is None: print("[Voice Changer] generate new embedder. (no embedder)") - cls.currentEmbedder = cls.loadEmbedder(embederType, file, isHalf, dev) - elif cls.currentEmbedder.matchCondition(embederType, file) is False: + cls.currentEmbedder = cls.loadEmbedder(embederType, isHalf, dev) + elif cls.currentEmbedder.matchCondition(embederType) is False: print("[Voice Changer] generate new embedder. (not match)") - cls.currentEmbedder = cls.loadEmbedder(embederType, file, isHalf, dev) + cls.currentEmbedder = cls.loadEmbedder(embederType, isHalf, dev) else: cls.currentEmbedder.setDevice(dev) cls.currentEmbedder.setHalf(isHalf) @@ -29,22 +35,25 @@ class EmbedderManager: @classmethod def loadEmbedder( - cls, embederType: EnumEmbedderTypes, file: str, isHalf: bool, dev: device + cls, embederType: EnumEmbedderTypes, isHalf: bool, dev: device ) -> Embedder: if ( embederType == EnumEmbedderTypes.hubert or embederType == EnumEmbedderTypes.hubert.value ): + file = cls.params.hubert_base return FairseqHubert().loadModel(file, dev, isHalf) elif ( embederType == EnumEmbedderTypes.hubert_jp or embederType == EnumEmbedderTypes.hubert_jp.value ): + file = cls.params.hubert_base_jp return FairseqHubertJp().loadModel(file, dev, isHalf) elif ( embederType == EnumEmbedderTypes.contentvec or embederType == EnumEmbedderTypes.contentvec.value ): + file = cls.params.hubert_base return FairseqContentvec().loadModel(file, dev, isHalf) else: return FairseqHubert().loadModel(file, dev, isHalf)