Refactoring

- Change embedder type from Enum to Literal
This commit is contained in:
w-okada 2023-07-07 03:56:45 +09:00
parent 89a06b65ec
commit 0bd7bfafc5
11 changed files with 42 additions and 66 deletions

View File

@ -90,7 +90,6 @@ voiceChangerParams = VoiceChangerParams(
nsf_hifigan=args.nsf_hifigan, nsf_hifigan=args.nsf_hifigan,
crepe_onnx_full=args.crepe_onnx_full, crepe_onnx_full=args.crepe_onnx_full,
crepe_onnx_tiny=args.crepe_onnx_tiny, crepe_onnx_tiny=args.crepe_onnx_tiny,
sample_mode=args.sample_mode, sample_mode=args.sample_mode,
) )

View File

@ -54,11 +54,11 @@ def getFrontendPath():
return frontend_path return frontend_path
# "hubert_base", "contentvec", "distilhubert" EmbedderType: TypeAlias = Literal[
class EnumEmbedderTypes(Enum): "hubert_base",
hubert = "hubert_base" "contentvec",
contentvec = "contentvec" "hubert-base-japanese"
hubert_jp = "hubert-base-japanese" ]
class EnumInferenceTypes(Enum): class EnumInferenceTypes(Enum):

View File

@ -1,5 +1,5 @@
from typing import TypeAlias, Union from typing import TypeAlias, Union
from const import MAX_SLOT_NUM, EnumInferenceTypes, EnumEmbedderTypes, VoiceChangerType from const import MAX_SLOT_NUM, EnumInferenceTypes, EmbedderType, VoiceChangerType
from dataclasses import dataclass, asdict, field from dataclasses import dataclass, asdict, field
@ -34,7 +34,7 @@ class RVCModelSlot(ModelSlot):
embOutputLayer: int = 9 embOutputLayer: int = 9
useFinalProj: bool = True useFinalProj: bool = True
deprecated: bool = False deprecated: bool = False
embedder: str = EnumEmbedderTypes.hubert.value embedder: EmbedderType = "hubert_base"
sampleId: str = "" sampleId: str = ""
speakers: dict = field(default_factory=lambda: {0: "target"}) speakers: dict = field(default_factory=lambda: {0: "target"})

View File

@ -166,7 +166,8 @@ def _downloadSamples(samples: list[ModelSamples], sampleModelIds: list[Tuple[str
slotInfo = modelSlotManager.get_slot_info(targetSlotIndex) slotInfo = modelSlotManager.get_slot_info(targetSlotIndex)
if slotInfo.voiceChangerType == "RVC": if slotInfo.voiceChangerType == "RVC":
if slotInfo.isONNX: if slotInfo.isONNX:
RVCModelSlotGenerator._setInfoByONNX(slotInfo) slotInfo = RVCModelSlotGenerator._setInfoByONNX(slotInfo)
else: else:
RVCModelSlotGenerator._setInfoByPytorch(slotInfo) slotInfo = RVCModelSlotGenerator._setInfoByPytorch(slotInfo)
modelSlotManager.save_model_slot(targetSlotIndex, slotInfo) modelSlotManager.save_model_slot(targetSlotIndex, slotInfo)

View File

@ -1,6 +1,6 @@
import os import os
from const import EnumEmbedderTypes, EnumInferenceTypes from const import EnumInferenceTypes
from dataclasses import asdict
import torch import torch
import onnxruntime import onnxruntime
import json import json
@ -38,6 +38,8 @@ class RVCModelSlotGenerator(ModelSlotGenerator):
config_len = len(cpt["config"]) config_len = len(cpt["config"])
version = cpt.get("version", "v1") version = cpt.get("version", "v1")
slot = RVCModelSlot(**asdict(slot))
if version == "voras_beta": if version == "voras_beta":
slot.f0 = True if cpt["f0"] == 1 else False slot.f0 = True if cpt["f0"] == 1 else False
slot.modelType = EnumInferenceTypes.pyTorchVoRASbeta.value slot.modelType = EnumInferenceTypes.pyTorchVoRASbeta.value
@ -49,12 +51,12 @@ class RVCModelSlotGenerator(ModelSlotGenerator):
if slot.embedder.endswith("768"): if slot.embedder.endswith("768"):
slot.embedder = slot.embedder[:-3] slot.embedder = slot.embedder[:-3]
if slot.embedder == EnumEmbedderTypes.hubert.value: # if slot.embedder == "hubert":
slot.embedder = EnumEmbedderTypes.hubert.value # slot.embedder = "hubert"
elif slot.embedder == EnumEmbedderTypes.contentvec.value: # elif slot.embedder == "contentvec":
slot.embedder = EnumEmbedderTypes.contentvec.value # slot.embedder = "contentvec"
elif slot.embedder == EnumEmbedderTypes.hubert_jp.value: # elif slot.embedder == "hubert_jp":
slot.embedder = EnumEmbedderTypes.hubert_jp.value # slot.embedder = "hubert_jp"
else: else:
raise RuntimeError("[Voice Changer][setInfoByONNX] unknown embedder") raise RuntimeError("[Voice Changer][setInfoByONNX] unknown embedder")
@ -67,14 +69,14 @@ class RVCModelSlotGenerator(ModelSlotGenerator):
slot.embChannels = 256 slot.embChannels = 256
slot.embOutputLayer = 9 slot.embOutputLayer = 9
slot.useFinalProj = True slot.useFinalProj = True
slot.embedder = EnumEmbedderTypes.hubert.value slot.embedder = "hubert_base"
print("[Voice Changer] Official Model(pyTorch) : v1") print("[Voice Changer] Official Model(pyTorch) : v1")
else: else:
slot.modelType = EnumInferenceTypes.pyTorchRVCv2.value if slot.f0 else EnumInferenceTypes.pyTorchRVCv2Nono.value slot.modelType = EnumInferenceTypes.pyTorchRVCv2.value if slot.f0 else EnumInferenceTypes.pyTorchRVCv2Nono.value
slot.embChannels = 768 slot.embChannels = 768
slot.embOutputLayer = 12 slot.embOutputLayer = 12
slot.useFinalProj = False slot.useFinalProj = False
slot.embedder = EnumEmbedderTypes.hubert.value slot.embedder = "hubert_base"
print("[Voice Changer] Official Model(pyTorch) : v2") print("[Voice Changer] Official Model(pyTorch) : v2")
else: else:
@ -104,24 +106,18 @@ class RVCModelSlotGenerator(ModelSlotGenerator):
for k, v in cpt["speaker_info"].items(): for k, v in cpt["speaker_info"].items():
slot.speakers[int(k)] = str(v) slot.speakers[int(k)] = str(v)
# if slot.embedder == EnumEmbedderTypes.hubert.value:
# slot.embedder = EnumEmbedderTypes.hubert
# elif slot.embedder == EnumEmbedderTypes.contentvec.value:
# slot.embedder = EnumEmbedderTypes.contentvec
# elif slot.embedder == EnumEmbedderTypes.hubert_jp.value:
# slot.embedder = EnumEmbedderTypes.hubert_jp
# else:
# raise RuntimeError("[Voice Changer][setInfoByONNX] unknown embedder")
slot.samplingRate = cpt["config"][-1] slot.samplingRate = cpt["config"][-1]
del cpt del cpt
return slot
@classmethod @classmethod
def _setInfoByONNX(cls, slot: ModelSlot): def _setInfoByONNX(cls, slot: ModelSlot):
tmp_onnx_session = onnxruntime.InferenceSession(slot.modelFile, providers=["CPUExecutionProvider"]) tmp_onnx_session = onnxruntime.InferenceSession(slot.modelFile, providers=["CPUExecutionProvider"])
modelmeta = tmp_onnx_session.get_modelmeta() modelmeta = tmp_onnx_session.get_modelmeta()
try: try:
slot = RVCModelSlot(**asdict(slot))
metadata = json.loads(modelmeta.custom_metadata_map["metadata"]) metadata = json.loads(modelmeta.custom_metadata_map["metadata"])
# slot.modelType = metadata["modelType"] # slot.modelType = metadata["modelType"]
@ -144,17 +140,9 @@ class RVCModelSlotGenerator(ModelSlotGenerator):
print(f"[Voice Changer] ONNX Model: ch:{slot.embChannels}, L:{slot.embOutputLayer}, FP:{slot.useFinalProj}") print(f"[Voice Changer] ONNX Model: ch:{slot.embChannels}, L:{slot.embOutputLayer}, FP:{slot.useFinalProj}")
if "embedder" not in metadata: if "embedder" not in metadata:
slot.embedder = EnumEmbedderTypes.hubert.value slot.embedder = "hubert_base"
else: else:
slot.embedder = metadata["embedder"] slot.embedder = metadata["embedder"]
# elif metadata["embedder"] == EnumEmbedderTypes.hubert.value:
# slot.embedder = EnumEmbedderTypes.hubert
# elif metadata["embedder"] == EnumEmbedderTypes.contentvec.value:
# slot.embedder = EnumEmbedderTypes.contentvec
# elif metadata["embedder"] == EnumEmbedderTypes.hubert_jp.value:
# slot.embedder = EnumEmbedderTypes.hubert_jp
# else:
# raise RuntimeError("[Voice Changer][setInfoByONNX] unknown embedder")
slot.f0 = metadata["f0"] slot.f0 = metadata["f0"]
slot.modelType = EnumInferenceTypes.onnxRVC.value if slot.f0 else EnumInferenceTypes.onnxRVCNono.value slot.modelType = EnumInferenceTypes.onnxRVC.value if slot.f0 else EnumInferenceTypes.onnxRVCNono.value
@ -164,7 +152,7 @@ class RVCModelSlotGenerator(ModelSlotGenerator):
except Exception as e: except Exception as e:
slot.modelType = EnumInferenceTypes.onnxRVC.value slot.modelType = EnumInferenceTypes.onnxRVC.value
slot.embChannels = 256 slot.embChannels = 256
slot.embedder = EnumEmbedderTypes.hubert.value slot.embedder = "hubert_base"
slot.f0 = True slot.f0 = True
slot.samplingRate = 48000 slot.samplingRate = 48000
slot.deprecated = True slot.deprecated = True
@ -175,3 +163,4 @@ class RVCModelSlotGenerator(ModelSlotGenerator):
print("[Voice Changer] ############## !!!! CAUTION !!!! ####################") print("[Voice Changer] ############## !!!! CAUTION !!!! ####################")
del tmp_onnx_session del tmp_onnx_session
return slot

View File

@ -3,11 +3,11 @@ from typing import Any, Protocol
import torch import torch
from torch import device from torch import device
from const import EnumEmbedderTypes from const import EmbedderType
class Embedder(Protocol): class Embedder(Protocol):
embedderType: EnumEmbedderTypes = EnumEmbedderTypes.hubert embedderType: EmbedderType = "hubert_base"
file: str file: str
isHalf: bool = True isHalf: bool = True
dev: device dev: device
@ -24,7 +24,7 @@ class Embedder(Protocol):
def getEmbedderInfo(self): def getEmbedderInfo(self):
return { return {
"embedderType": self.embedderType.value, "embedderType": self.embedderType,
"file": self.file, "file": self.file,
"isHalf": self.isHalf, "isHalf": self.isHalf,
"devType": self.dev.type, "devType": self.dev.type,
@ -33,7 +33,7 @@ class Embedder(Protocol):
def setProps( def setProps(
self, self,
embedderType: EnumEmbedderTypes, embedderType: EmbedderType,
file: str, file: str,
dev: device, dev: device,
isHalf: bool = True, isHalf: bool = True,
@ -56,7 +56,7 @@ class Embedder(Protocol):
self.model = self.model.to(self.dev) self.model = self.model.to(self.dev)
return self return self
def matchCondition(self, embedderType: EnumEmbedderTypes) -> bool: def matchCondition(self, embedderType: EmbedderType) -> bool:
# Check Type # Check Type
if self.embedderType != embedderType: if self.embedderType != embedderType:
print( print(

View File

@ -1,6 +1,6 @@
from torch import device from torch import device
from const import EnumEmbedderTypes from const import EmbedderType
from voice_changer.RVC.embedder.Embedder import Embedder from voice_changer.RVC.embedder.Embedder import Embedder
from voice_changer.RVC.embedder.FairseqContentvec import FairseqContentvec from voice_changer.RVC.embedder.FairseqContentvec import FairseqContentvec
from voice_changer.RVC.embedder.FairseqHubert import FairseqHubert from voice_changer.RVC.embedder.FairseqHubert import FairseqHubert
@ -18,7 +18,7 @@ class EmbedderManager:
@classmethod @classmethod
def getEmbedder( def getEmbedder(
cls, embederType: EnumEmbedderTypes, isHalf: bool, dev: device cls, embederType: EmbedderType, isHalf: bool, dev: device
) -> Embedder: ) -> Embedder:
if cls.currentEmbedder is None: if cls.currentEmbedder is None:
print("[Voice Changer] generate new embedder. (no embedder)") print("[Voice Changer] generate new embedder. (no embedder)")
@ -35,24 +35,15 @@ class EmbedderManager:
@classmethod @classmethod
def loadEmbedder( def loadEmbedder(
cls, embederType: EnumEmbedderTypes, isHalf: bool, dev: device cls, embederType: EmbedderType, isHalf: bool, dev: device
) -> Embedder: ) -> Embedder:
if ( if embederType == "hubert_base":
embederType == EnumEmbedderTypes.hubert
or embederType == EnumEmbedderTypes.hubert.value
):
file = cls.params.hubert_base file = cls.params.hubert_base
return FairseqHubert().loadModel(file, dev, isHalf) return FairseqHubert().loadModel(file, dev, isHalf)
elif ( elif embederType == "hubert-base-japanese":
embederType == EnumEmbedderTypes.hubert_jp
or embederType == EnumEmbedderTypes.hubert_jp.value
):
file = cls.params.hubert_base_jp file = cls.params.hubert_base_jp
return FairseqHubertJp().loadModel(file, dev, isHalf) return FairseqHubertJp().loadModel(file, dev, isHalf)
elif ( elif embederType == "contentvec":
embederType == EnumEmbedderTypes.contentvec
or embederType == EnumEmbedderTypes.contentvec.value
):
file = cls.params.hubert_base file = cls.params.hubert_base
return FairseqContentvec().loadModel(file, dev, isHalf) return FairseqContentvec().loadModel(file, dev, isHalf)
else: else:

View File

@ -1,5 +1,4 @@
from torch import device from torch import device
from const import EnumEmbedderTypes
from voice_changer.RVC.embedder.Embedder import Embedder from voice_changer.RVC.embedder.Embedder import Embedder
from voice_changer.RVC.embedder.FairseqHubert import FairseqHubert from voice_changer.RVC.embedder.FairseqHubert import FairseqHubert
@ -7,5 +6,5 @@ from voice_changer.RVC.embedder.FairseqHubert import FairseqHubert
class FairseqContentvec(FairseqHubert): class FairseqContentvec(FairseqHubert):
def loadModel(self, file: str, dev: device, isHalf: bool = True) -> Embedder: def loadModel(self, file: str, dev: device, isHalf: bool = True) -> Embedder:
super().loadModel(file, dev, isHalf) super().loadModel(file, dev, isHalf)
super().setProps(EnumEmbedderTypes.contentvec, file, dev, isHalf) super().setProps("contentvec", file, dev, isHalf)
return self return self

View File

@ -1,13 +1,12 @@
import torch import torch
from torch import device from torch import device
from const import EnumEmbedderTypes
from voice_changer.RVC.embedder.Embedder import Embedder from voice_changer.RVC.embedder.Embedder import Embedder
from fairseq import checkpoint_utils from fairseq import checkpoint_utils
class FairseqHubert(Embedder): class FairseqHubert(Embedder):
def loadModel(self, file: str, dev: device, isHalf: bool = True) -> Embedder: def loadModel(self, file: str, dev: device, isHalf: bool = True) -> Embedder:
super().setProps(EnumEmbedderTypes.hubert, file, dev, isHalf) super().setProps("hubert_base", file, dev, isHalf)
models, saved_cfg, task = checkpoint_utils.load_model_ensemble_and_task( models, saved_cfg, task = checkpoint_utils.load_model_ensemble_and_task(
[file], [file],

View File

@ -1,5 +1,4 @@
from torch import device from torch import device
from const import EnumEmbedderTypes
from voice_changer.RVC.embedder.Embedder import Embedder from voice_changer.RVC.embedder.Embedder import Embedder
from voice_changer.RVC.embedder.FairseqHubert import FairseqHubert from voice_changer.RVC.embedder.FairseqHubert import FairseqHubert
@ -7,5 +6,5 @@ from voice_changer.RVC.embedder.FairseqHubert import FairseqHubert
class FairseqHubertJp(FairseqHubert): class FairseqHubertJp(FairseqHubert):
def loadModel(self, file: str, dev: device, isHalf: bool = True) -> Embedder: def loadModel(self, file: str, dev: device, isHalf: bool = True) -> Embedder:
super().loadModel(file, dev, isHalf) super().loadModel(file, dev, isHalf)
super().setProps(EnumEmbedderTypes.hubert_jp, file, dev, isHalf) super().setProps("hubert-base-japanese", file, dev, isHalf)
return self return self

View File

@ -31,7 +31,6 @@ class InferencerManager:
file: str, file: str,
gpu: int, gpu: int,
) -> Inferencer: ) -> Inferencer:
print("inferencerTypeinferencerTypeinferencerTypeinferencerType", inferencerType)
if inferencerType == EnumInferenceTypes.pyTorchRVC or inferencerType == EnumInferenceTypes.pyTorchRVC.value: if inferencerType == EnumInferenceTypes.pyTorchRVC or inferencerType == EnumInferenceTypes.pyTorchRVC.value:
return RVCInferencer().loadModel(file, gpu) return RVCInferencer().loadModel(file, gpu)
elif inferencerType == EnumInferenceTypes.pyTorchRVCNono or inferencerType == EnumInferenceTypes.pyTorchRVCNono.value: elif inferencerType == EnumInferenceTypes.pyTorchRVCNono or inferencerType == EnumInferenceTypes.pyTorchRVCNono.value: