embedder type - file mapping

This commit is contained in:
wataru 2023-05-04 22:46:42 +09:00
parent 15dae8f245
commit f48319c350
3 changed files with 19 additions and 31 deletions

View File

@ -41,7 +41,7 @@ import torch
import traceback
import faiss
from const import UPLOAD_DIR, EnumEmbedderTypes
from const import UPLOAD_DIR
from voice_changer.RVC.Pipeline import Pipeline
@ -73,6 +73,7 @@ class RVC:
self.settings.f0Detector
)
self.params = params
EmbedderManager.initialize(params)
print("RVC initialization: ", params)
def loadModel(self, props: LoadModelParams):
@ -106,20 +107,7 @@ class RVC:
inferencerFilename = (
modelSlot.onnxModelFile if modelSlot.isONNX else modelSlot.pyTorchModelFile
)
# ファイル名特定(embedder)
if modelSlot.embedder == EnumEmbedderTypes.hubert:
emmbedderFilename = self.params.hubert_base
elif modelSlot.embedder == EnumEmbedderTypes.contentvec:
# emmbedderFilename = self.params.content_vec_500
emmbedderFilename = self.params.hubert_base
elif modelSlot.embedder == EnumEmbedderTypes.hubert_jp:
emmbedderFilename = self.params.hubert_base_jp
else:
raise RuntimeError(
"[Voice Changer] Exception loading embedder failed. unknwon type:",
modelSlot.embedder,
)
# Inferencer 生成
try:
inferencer = InferencerManager.getInferencer(
@ -136,7 +124,7 @@ class RVC:
try:
embedder = EmbedderManager.getEmbedder(
modelSlot.embedder,
emmbedderFilename,
# emmbedderFilename,
half,
dev,
)

View File

@ -45,7 +45,7 @@ class Embedder(Protocol):
self.model = self.model.to(self.dev)
return self
def matchCondition(self, embedderType: EnumEmbedderTypes, file: str) -> bool:
def matchCondition(self, embedderType: EnumEmbedderTypes) -> bool:
# Check Type
if self.embedderType != embedderType:
print(
@ -55,14 +55,5 @@ class Embedder(Protocol):
)
return False
# Check File Path
if self.file != file:
print(
"[Voice Changer] embeder file is not match",
self.file,
file,
)
return False
else:
return True

View File

@ -5,21 +5,27 @@ from voice_changer.RVC.embedder.Embedder import Embedder
from voice_changer.RVC.embedder.FairseqContentvec import FairseqContentvec
from voice_changer.RVC.embedder.FairseqHubert import FairseqHubert
from voice_changer.RVC.embedder.FairseqHubertJp import FairseqHubertJp
from voice_changer.utils.VoiceChangerParams import VoiceChangerParams
class EmbedderManager:
currentEmbedder: Embedder | None = None
params: VoiceChangerParams
@classmethod
def initialize(cls, params: VoiceChangerParams):
cls.params = params
@classmethod
def getEmbedder(
cls, embederType: EnumEmbedderTypes, file: str, isHalf: bool, dev: device
cls, embederType: EnumEmbedderTypes, isHalf: bool, dev: device
) -> Embedder:
if cls.currentEmbedder is None:
print("[Voice Changer] generate new embedder. (no embedder)")
cls.currentEmbedder = cls.loadEmbedder(embederType, file, isHalf, dev)
elif cls.currentEmbedder.matchCondition(embederType, file) is False:
cls.currentEmbedder = cls.loadEmbedder(embederType, isHalf, dev)
elif cls.currentEmbedder.matchCondition(embederType) is False:
print("[Voice Changer] generate new embedder. (not match)")
cls.currentEmbedder = cls.loadEmbedder(embederType, file, isHalf, dev)
cls.currentEmbedder = cls.loadEmbedder(embederType, isHalf, dev)
else:
cls.currentEmbedder.setDevice(dev)
cls.currentEmbedder.setHalf(isHalf)
@ -29,22 +35,25 @@ class EmbedderManager:
@classmethod
def loadEmbedder(
cls, embederType: EnumEmbedderTypes, file: str, isHalf: bool, dev: device
cls, embederType: EnumEmbedderTypes, isHalf: bool, dev: device
) -> Embedder:
if (
embederType == EnumEmbedderTypes.hubert
or embederType == EnumEmbedderTypes.hubert.value
):
file = cls.params.hubert_base
return FairseqHubert().loadModel(file, dev, isHalf)
elif (
embederType == EnumEmbedderTypes.hubert_jp
or embederType == EnumEmbedderTypes.hubert_jp.value
):
file = cls.params.hubert_base_jp
return FairseqHubertJp().loadModel(file, dev, isHalf)
elif (
embederType == EnumEmbedderTypes.contentvec
or embederType == EnumEmbedderTypes.contentvec.value
):
file = cls.params.hubert_base
return FairseqContentvec().loadModel(file, dev, isHalf)
else:
return FairseqHubert().loadModel(file, dev, isHalf)