Refactoring

- embedder type from enum to Literal
This commit is contained in:
w-okada 2023-07-07 03:56:45 +09:00
parent 89a06b65ec
commit 0bd7bfafc5
11 changed files with 42 additions and 66 deletions

View File

@ -90,7 +90,6 @@ voiceChangerParams = VoiceChangerParams(
nsf_hifigan=args.nsf_hifigan,
crepe_onnx_full=args.crepe_onnx_full,
crepe_onnx_tiny=args.crepe_onnx_tiny,
sample_mode=args.sample_mode,
)

View File

@ -54,11 +54,11 @@ def getFrontendPath():
return frontend_path
# "hubert_base", "contentvec", "distilhubert"
class EnumEmbedderTypes(Enum):
hubert = "hubert_base"
contentvec = "contentvec"
hubert_jp = "hubert-base-japanese"
EmbedderType: TypeAlias = Literal[
"hubert_base",
"contentvec",
"hubert-base-japanese"
]
class EnumInferenceTypes(Enum):

View File

@ -1,5 +1,5 @@
from typing import TypeAlias, Union
from const import MAX_SLOT_NUM, EnumInferenceTypes, EnumEmbedderTypes, VoiceChangerType
from const import MAX_SLOT_NUM, EnumInferenceTypes, EmbedderType, VoiceChangerType
from dataclasses import dataclass, asdict, field
@ -34,7 +34,7 @@ class RVCModelSlot(ModelSlot):
embOutputLayer: int = 9
useFinalProj: bool = True
deprecated: bool = False
embedder: str = EnumEmbedderTypes.hubert.value
embedder: EmbedderType = "hubert_base"
sampleId: str = ""
speakers: dict = field(default_factory=lambda: {0: "target"})

View File

@ -166,7 +166,8 @@ def _downloadSamples(samples: list[ModelSamples], sampleModelIds: list[Tuple[str
slotInfo = modelSlotManager.get_slot_info(targetSlotIndex)
if slotInfo.voiceChangerType == "RVC":
if slotInfo.isONNX:
RVCModelSlotGenerator._setInfoByONNX(slotInfo)
slotInfo = RVCModelSlotGenerator._setInfoByONNX(slotInfo)
else:
RVCModelSlotGenerator._setInfoByPytorch(slotInfo)
slotInfo = RVCModelSlotGenerator._setInfoByPytorch(slotInfo)
modelSlotManager.save_model_slot(targetSlotIndex, slotInfo)

View File

@ -1,6 +1,6 @@
import os
from const import EnumEmbedderTypes, EnumInferenceTypes
from const import EnumInferenceTypes
from dataclasses import asdict
import torch
import onnxruntime
import json
@ -38,6 +38,8 @@ class RVCModelSlotGenerator(ModelSlotGenerator):
config_len = len(cpt["config"])
version = cpt.get("version", "v1")
slot = RVCModelSlot(**asdict(slot))
if version == "voras_beta":
slot.f0 = True if cpt["f0"] == 1 else False
slot.modelType = EnumInferenceTypes.pyTorchVoRASbeta.value
@ -49,12 +51,12 @@ class RVCModelSlotGenerator(ModelSlotGenerator):
if slot.embedder.endswith("768"):
slot.embedder = slot.embedder[:-3]
if slot.embedder == EnumEmbedderTypes.hubert.value:
slot.embedder = EnumEmbedderTypes.hubert.value
elif slot.embedder == EnumEmbedderTypes.contentvec.value:
slot.embedder = EnumEmbedderTypes.contentvec.value
elif slot.embedder == EnumEmbedderTypes.hubert_jp.value:
slot.embedder = EnumEmbedderTypes.hubert_jp.value
# if slot.embedder == "hubert":
# slot.embedder = "hubert"
# elif slot.embedder == "contentvec":
# slot.embedder = "contentvec"
# elif slot.embedder == "hubert_jp":
# slot.embedder = "hubert_jp"
else:
raise RuntimeError("[Voice Changer][setInfoByONNX] unknown embedder")
@ -67,14 +69,14 @@ class RVCModelSlotGenerator(ModelSlotGenerator):
slot.embChannels = 256
slot.embOutputLayer = 9
slot.useFinalProj = True
slot.embedder = EnumEmbedderTypes.hubert.value
slot.embedder = "hubert_base"
print("[Voice Changer] Official Model(pyTorch) : v1")
else:
slot.modelType = EnumInferenceTypes.pyTorchRVCv2.value if slot.f0 else EnumInferenceTypes.pyTorchRVCv2Nono.value
slot.embChannels = 768
slot.embOutputLayer = 12
slot.useFinalProj = False
slot.embedder = EnumEmbedderTypes.hubert.value
slot.embedder = "hubert_base"
print("[Voice Changer] Official Model(pyTorch) : v2")
else:
@ -104,24 +106,18 @@ class RVCModelSlotGenerator(ModelSlotGenerator):
for k, v in cpt["speaker_info"].items():
slot.speakers[int(k)] = str(v)
# if slot.embedder == EnumEmbedderTypes.hubert.value:
# slot.embedder = EnumEmbedderTypes.hubert
# elif slot.embedder == EnumEmbedderTypes.contentvec.value:
# slot.embedder = EnumEmbedderTypes.contentvec
# elif slot.embedder == EnumEmbedderTypes.hubert_jp.value:
# slot.embedder = EnumEmbedderTypes.hubert_jp
# else:
# raise RuntimeError("[Voice Changer][setInfoByONNX] unknown embedder")
slot.samplingRate = cpt["config"][-1]
del cpt
return slot
@classmethod
def _setInfoByONNX(cls, slot: ModelSlot):
tmp_onnx_session = onnxruntime.InferenceSession(slot.modelFile, providers=["CPUExecutionProvider"])
modelmeta = tmp_onnx_session.get_modelmeta()
try:
slot = RVCModelSlot(**asdict(slot))
metadata = json.loads(modelmeta.custom_metadata_map["metadata"])
# slot.modelType = metadata["modelType"]
@ -144,17 +140,9 @@ class RVCModelSlotGenerator(ModelSlotGenerator):
print(f"[Voice Changer] ONNX Model: ch:{slot.embChannels}, L:{slot.embOutputLayer}, FP:{slot.useFinalProj}")
if "embedder" not in metadata:
slot.embedder = EnumEmbedderTypes.hubert.value
slot.embedder = "hubert_base"
else:
slot.embedder = metadata["embedder"]
# elif metadata["embedder"] == EnumEmbedderTypes.hubert.value:
# slot.embedder = EnumEmbedderTypes.hubert
# elif metadata["embedder"] == EnumEmbedderTypes.contentvec.value:
# slot.embedder = EnumEmbedderTypes.contentvec
# elif metadata["embedder"] == EnumEmbedderTypes.hubert_jp.value:
# slot.embedder = EnumEmbedderTypes.hubert_jp
# else:
# raise RuntimeError("[Voice Changer][setInfoByONNX] unknown embedder")
slot.f0 = metadata["f0"]
slot.modelType = EnumInferenceTypes.onnxRVC.value if slot.f0 else EnumInferenceTypes.onnxRVCNono.value
@ -164,7 +152,7 @@ class RVCModelSlotGenerator(ModelSlotGenerator):
except Exception as e:
slot.modelType = EnumInferenceTypes.onnxRVC.value
slot.embChannels = 256
slot.embedder = EnumEmbedderTypes.hubert.value
slot.embedder = "hubert_base"
slot.f0 = True
slot.samplingRate = 48000
slot.deprecated = True
@ -175,3 +163,4 @@ class RVCModelSlotGenerator(ModelSlotGenerator):
print("[Voice Changer] ############## !!!! CAUTION !!!! ####################")
del tmp_onnx_session
return slot

View File

@ -3,11 +3,11 @@ from typing import Any, Protocol
import torch
from torch import device
from const import EnumEmbedderTypes
from const import EmbedderType
class Embedder(Protocol):
embedderType: EnumEmbedderTypes = EnumEmbedderTypes.hubert
embedderType: EmbedderType = "hubert_base"
file: str
isHalf: bool = True
dev: device
@ -24,7 +24,7 @@ class Embedder(Protocol):
def getEmbedderInfo(self):
return {
"embedderType": self.embedderType.value,
"embedderType": self.embedderType,
"file": self.file,
"isHalf": self.isHalf,
"devType": self.dev.type,
@ -33,7 +33,7 @@ class Embedder(Protocol):
def setProps(
self,
embedderType: EnumEmbedderTypes,
embedderType: EmbedderType,
file: str,
dev: device,
isHalf: bool = True,
@ -56,7 +56,7 @@ class Embedder(Protocol):
self.model = self.model.to(self.dev)
return self
def matchCondition(self, embedderType: EnumEmbedderTypes) -> bool:
def matchCondition(self, embedderType: EmbedderType) -> bool:
# Check Type
if self.embedderType != embedderType:
print(

View File

@ -1,6 +1,6 @@
from torch import device
from const import EnumEmbedderTypes
from const import EmbedderType
from voice_changer.RVC.embedder.Embedder import Embedder
from voice_changer.RVC.embedder.FairseqContentvec import FairseqContentvec
from voice_changer.RVC.embedder.FairseqHubert import FairseqHubert
@ -18,7 +18,7 @@ class EmbedderManager:
@classmethod
def getEmbedder(
cls, embederType: EnumEmbedderTypes, isHalf: bool, dev: device
cls, embederType: EmbedderType, isHalf: bool, dev: device
) -> Embedder:
if cls.currentEmbedder is None:
print("[Voice Changer] generate new embedder. (no embedder)")
@ -35,24 +35,15 @@ class EmbedderManager:
@classmethod
def loadEmbedder(
cls, embederType: EnumEmbedderTypes, isHalf: bool, dev: device
cls, embederType: EmbedderType, isHalf: bool, dev: device
) -> Embedder:
if (
embederType == EnumEmbedderTypes.hubert
or embederType == EnumEmbedderTypes.hubert.value
):
if embederType == "hubert_base":
file = cls.params.hubert_base
return FairseqHubert().loadModel(file, dev, isHalf)
elif (
embederType == EnumEmbedderTypes.hubert_jp
or embederType == EnumEmbedderTypes.hubert_jp.value
):
elif embederType == "hubert-base-japanese":
file = cls.params.hubert_base_jp
return FairseqHubertJp().loadModel(file, dev, isHalf)
elif (
embederType == EnumEmbedderTypes.contentvec
or embederType == EnumEmbedderTypes.contentvec.value
):
elif embederType == "contentvec":
file = cls.params.hubert_base
return FairseqContentvec().loadModel(file, dev, isHalf)
else:

View File

@ -1,5 +1,4 @@
from torch import device
from const import EnumEmbedderTypes
from voice_changer.RVC.embedder.Embedder import Embedder
from voice_changer.RVC.embedder.FairseqHubert import FairseqHubert
@ -7,5 +6,5 @@ from voice_changer.RVC.embedder.FairseqHubert import FairseqHubert
class FairseqContentvec(FairseqHubert):
def loadModel(self, file: str, dev: device, isHalf: bool = True) -> Embedder:
super().loadModel(file, dev, isHalf)
super().setProps(EnumEmbedderTypes.contentvec, file, dev, isHalf)
super().setProps("contentvec", file, dev, isHalf)
return self

View File

@ -1,13 +1,12 @@
import torch
from torch import device
from const import EnumEmbedderTypes
from voice_changer.RVC.embedder.Embedder import Embedder
from fairseq import checkpoint_utils
class FairseqHubert(Embedder):
def loadModel(self, file: str, dev: device, isHalf: bool = True) -> Embedder:
super().setProps(EnumEmbedderTypes.hubert, file, dev, isHalf)
super().setProps("hubert_base", file, dev, isHalf)
models, saved_cfg, task = checkpoint_utils.load_model_ensemble_and_task(
[file],

View File

@ -1,5 +1,4 @@
from torch import device
from const import EnumEmbedderTypes
from voice_changer.RVC.embedder.Embedder import Embedder
from voice_changer.RVC.embedder.FairseqHubert import FairseqHubert
@ -7,5 +6,5 @@ from voice_changer.RVC.embedder.FairseqHubert import FairseqHubert
class FairseqHubertJp(FairseqHubert):
def loadModel(self, file: str, dev: device, isHalf: bool = True) -> Embedder:
super().loadModel(file, dev, isHalf)
super().setProps(EnumEmbedderTypes.hubert_jp, file, dev, isHalf)
super().setProps("hubert-base-japanese", file, dev, isHalf)
return self

View File

@ -31,7 +31,6 @@ class InferencerManager:
file: str,
gpu: int,
) -> Inferencer:
print("inferencerTypeinferencerTypeinferencerTypeinferencerType", inferencerType)
if inferencerType == EnumInferenceTypes.pyTorchRVC or inferencerType == EnumInferenceTypes.pyTorchRVC.value:
return RVCInferencer().loadModel(file, gpu)
elif inferencerType == EnumInferenceTypes.pyTorchRVCNono or inferencerType == EnumInferenceTypes.pyTorchRVCNono.value: