embedder type - file mapping

This commit is contained in:
wataru 2023-05-04 22:46:42 +09:00
parent 15dae8f245
commit f48319c350
3 changed files with 19 additions and 31 deletions

View File

@ -41,7 +41,7 @@ import torch
import traceback import traceback
import faiss import faiss
from const import UPLOAD_DIR, EnumEmbedderTypes from const import UPLOAD_DIR
from voice_changer.RVC.Pipeline import Pipeline from voice_changer.RVC.Pipeline import Pipeline
@ -73,6 +73,7 @@ class RVC:
self.settings.f0Detector self.settings.f0Detector
) )
self.params = params self.params = params
EmbedderManager.initialize(params)
print("RVC initialization: ", params) print("RVC initialization: ", params)
def loadModel(self, props: LoadModelParams): def loadModel(self, props: LoadModelParams):
@ -106,20 +107,7 @@ class RVC:
inferencerFilename = ( inferencerFilename = (
modelSlot.onnxModelFile if modelSlot.isONNX else modelSlot.pyTorchModelFile modelSlot.onnxModelFile if modelSlot.isONNX else modelSlot.pyTorchModelFile
) )
# ファイル名特定(embedder)
if modelSlot.embedder == EnumEmbedderTypes.hubert:
emmbedderFilename = self.params.hubert_base
elif modelSlot.embedder == EnumEmbedderTypes.contentvec:
# emmbedderFilename = self.params.content_vec_500
emmbedderFilename = self.params.hubert_base
elif modelSlot.embedder == EnumEmbedderTypes.hubert_jp:
emmbedderFilename = self.params.hubert_base_jp
else:
raise RuntimeError(
"[Voice Changer] Exception loading embedder failed. unknwon type:",
modelSlot.embedder,
)
# Inferencer 生成 # Inferencer 生成
try: try:
inferencer = InferencerManager.getInferencer( inferencer = InferencerManager.getInferencer(
@ -136,7 +124,7 @@ class RVC:
try: try:
embedder = EmbedderManager.getEmbedder( embedder = EmbedderManager.getEmbedder(
modelSlot.embedder, modelSlot.embedder,
emmbedderFilename, # emmbedderFilename,
half, half,
dev, dev,
) )

View File

@ -45,7 +45,7 @@ class Embedder(Protocol):
self.model = self.model.to(self.dev) self.model = self.model.to(self.dev)
return self return self
def matchCondition(self, embedderType: EnumEmbedderTypes, file: str) -> bool: def matchCondition(self, embedderType: EnumEmbedderTypes) -> bool:
# Check Type # Check Type
if self.embedderType != embedderType: if self.embedderType != embedderType:
print( print(
@ -55,14 +55,5 @@ class Embedder(Protocol):
) )
return False return False
# Check File Path
if self.file != file:
print(
"[Voice Changer] embeder file is not match",
self.file,
file,
)
return False
else: else:
return True return True

View File

@ -5,21 +5,27 @@ from voice_changer.RVC.embedder.Embedder import Embedder
from voice_changer.RVC.embedder.FairseqContentvec import FairseqContentvec from voice_changer.RVC.embedder.FairseqContentvec import FairseqContentvec
from voice_changer.RVC.embedder.FairseqHubert import FairseqHubert from voice_changer.RVC.embedder.FairseqHubert import FairseqHubert
from voice_changer.RVC.embedder.FairseqHubertJp import FairseqHubertJp from voice_changer.RVC.embedder.FairseqHubertJp import FairseqHubertJp
from voice_changer.utils.VoiceChangerParams import VoiceChangerParams
class EmbedderManager: class EmbedderManager:
currentEmbedder: Embedder | None = None currentEmbedder: Embedder | None = None
params: VoiceChangerParams
@classmethod
def initialize(cls, params: VoiceChangerParams):
cls.params = params
@classmethod @classmethod
def getEmbedder( def getEmbedder(
cls, embederType: EnumEmbedderTypes, file: str, isHalf: bool, dev: device cls, embederType: EnumEmbedderTypes, isHalf: bool, dev: device
) -> Embedder: ) -> Embedder:
if cls.currentEmbedder is None: if cls.currentEmbedder is None:
print("[Voice Changer] generate new embedder. (no embedder)") print("[Voice Changer] generate new embedder. (no embedder)")
cls.currentEmbedder = cls.loadEmbedder(embederType, file, isHalf, dev) cls.currentEmbedder = cls.loadEmbedder(embederType, isHalf, dev)
elif cls.currentEmbedder.matchCondition(embederType, file) is False: elif cls.currentEmbedder.matchCondition(embederType) is False:
print("[Voice Changer] generate new embedder. (not match)") print("[Voice Changer] generate new embedder. (not match)")
cls.currentEmbedder = cls.loadEmbedder(embederType, file, isHalf, dev) cls.currentEmbedder = cls.loadEmbedder(embederType, isHalf, dev)
else: else:
cls.currentEmbedder.setDevice(dev) cls.currentEmbedder.setDevice(dev)
cls.currentEmbedder.setHalf(isHalf) cls.currentEmbedder.setHalf(isHalf)
@ -29,22 +35,25 @@ class EmbedderManager:
@classmethod @classmethod
def loadEmbedder( def loadEmbedder(
cls, embederType: EnumEmbedderTypes, file: str, isHalf: bool, dev: device cls, embederType: EnumEmbedderTypes, isHalf: bool, dev: device
) -> Embedder: ) -> Embedder:
if ( if (
embederType == EnumEmbedderTypes.hubert embederType == EnumEmbedderTypes.hubert
or embederType == EnumEmbedderTypes.hubert.value or embederType == EnumEmbedderTypes.hubert.value
): ):
file = cls.params.hubert_base
return FairseqHubert().loadModel(file, dev, isHalf) return FairseqHubert().loadModel(file, dev, isHalf)
elif ( elif (
embederType == EnumEmbedderTypes.hubert_jp embederType == EnumEmbedderTypes.hubert_jp
or embederType == EnumEmbedderTypes.hubert_jp.value or embederType == EnumEmbedderTypes.hubert_jp.value
): ):
file = cls.params.hubert_base_jp
return FairseqHubertJp().loadModel(file, dev, isHalf) return FairseqHubertJp().loadModel(file, dev, isHalf)
elif ( elif (
embederType == EnumEmbedderTypes.contentvec embederType == EnumEmbedderTypes.contentvec
or embederType == EnumEmbedderTypes.contentvec.value or embederType == EnumEmbedderTypes.contentvec.value
): ):
file = cls.params.hubert_base
return FairseqContentvec().loadModel(file, dev, isHalf) return FairseqContentvec().loadModel(file, dev, isHalf)
else: else:
return FairseqHubert().loadModel(file, dev, isHalf) return FairseqHubert().loadModel(file, dev, isHalf)