Mirror of https://github.com/w-okada/voice-changer.git (synced 2025-01-23 13:35:12 +03:00)
Refactoring
- embedder type from enum to Literal
This commit is contained in:
parent
89a06b65ec
commit
0bd7bfafc5
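
For context, the pattern applied throughout this commit looks roughly like the sketch below: the EnumEmbedderTypes enum is replaced by an EmbedderType Literal alias, so call sites compare plain strings while the type checker still rejects unknown values. The validate_embedder helper is not part of the repository; it is a hypothetical illustration of how typing.get_args can recover the allowed values for a runtime check, since a Literal alias (unlike an Enum) performs no validation on its own.

from typing import Literal, TypeAlias, get_args

# Old style (removed by this commit): a dedicated Enum.
# class EnumEmbedderTypes(Enum):
#     hubert = "hubert_base"
#     contentvec = "contentvec"
#     hubert_jp = "hubert-base-japanese"

# New style: a Literal alias over the same string values.
EmbedderType: TypeAlias = Literal[
    "hubert_base",
    "contentvec",
    "hubert-base-japanese",
]

def validate_embedder(name: str) -> EmbedderType:
    # Hypothetical helper: Literal only gives static checking, so a runtime
    # guard has to enumerate the allowed values itself via get_args().
    if name not in get_args(EmbedderType):
        raise RuntimeError(f"unknown embedder: {name}")
    return name  # type: ignore[return-value]

embedder: EmbedderType = "hubert_base"   # accepted by mypy/pyright
# embedder = "hubert"                    # rejected by the type checker
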
@@ -90,7 +90,6 @@ voiceChangerParams = VoiceChangerParams(
     nsf_hifigan=args.nsf_hifigan,
     crepe_onnx_full=args.crepe_onnx_full,
     crepe_onnx_tiny=args.crepe_onnx_tiny,
     sample_mode=args.sample_mode,
 )

@@ -54,11 +54,11 @@ def getFrontendPath():
     return frontend_path


 # "hubert_base", "contentvec", "distilhubert"
-class EnumEmbedderTypes(Enum):
-    hubert = "hubert_base"
-    contentvec = "contentvec"
-    hubert_jp = "hubert-base-japanese"
+EmbedderType: TypeAlias = Literal[
+    "hubert_base",
+    "contentvec",
+    "hubert-base-japanese"
+]


 class EnumInferenceTypes(Enum):
@@ -1,5 +1,5 @@
 from typing import TypeAlias, Union
-from const import MAX_SLOT_NUM, EnumInferenceTypes, EnumEmbedderTypes, VoiceChangerType
+from const import MAX_SLOT_NUM, EnumInferenceTypes, EmbedderType, VoiceChangerType

 from dataclasses import dataclass, asdict, field

@@ -34,7 +34,7 @@ class RVCModelSlot(ModelSlot):
     embOutputLayer: int = 9
     useFinalProj: bool = True
     deprecated: bool = False
-    embedder: str = EnumEmbedderTypes.hubert.value
+    embedder: EmbedderType = "hubert_base"

     sampleId: str = ""
     speakers: dict = field(default_factory=lambda: {0: "target"})
@@ -166,7 +166,8 @@ def _downloadSamples(samples: list[ModelSamples], sampleModelIds: list[Tuple[str
     slotInfo = modelSlotManager.get_slot_info(targetSlotIndex)
     if slotInfo.voiceChangerType == "RVC":
         if slotInfo.isONNX:
-            RVCModelSlotGenerator._setInfoByONNX(slotInfo)
+            slotInfo = RVCModelSlotGenerator._setInfoByONNX(slotInfo)
         else:
-            RVCModelSlotGenerator._setInfoByPytorch(slotInfo)
+            slotInfo = RVCModelSlotGenerator._setInfoByPytorch(slotInfo)

     modelSlotManager.save_model_slot(targetSlotIndex, slotInfo)
@@ -1,6 +1,6 @@
 import os
-from const import EnumEmbedderTypes, EnumInferenceTypes
+from const import EnumInferenceTypes
 from dataclasses import asdict
 import torch
 import onnxruntime
 import json
@@ -38,6 +38,8 @@ class RVCModelSlotGenerator(ModelSlotGenerator):
         config_len = len(cpt["config"])
         version = cpt.get("version", "v1")

+        slot = RVCModelSlot(**asdict(slot))
+
         if version == "voras_beta":
             slot.f0 = True if cpt["f0"] == 1 else False
             slot.modelType = EnumInferenceTypes.pyTorchVoRASbeta.value
@@ -49,12 +51,12 @@ class RVCModelSlotGenerator(ModelSlotGenerator):
             if slot.embedder.endswith("768"):
                 slot.embedder = slot.embedder[:-3]

-            if slot.embedder == EnumEmbedderTypes.hubert.value:
-                slot.embedder = EnumEmbedderTypes.hubert.value
-            elif slot.embedder == EnumEmbedderTypes.contentvec.value:
-                slot.embedder = EnumEmbedderTypes.contentvec.value
-            elif slot.embedder == EnumEmbedderTypes.hubert_jp.value:
-                slot.embedder = EnumEmbedderTypes.hubert_jp.value
+            # if slot.embedder == "hubert":
+            #     slot.embedder = "hubert"
+            # elif slot.embedder == "contentvec":
+            #     slot.embedder = "contentvec"
+            # elif slot.embedder == "hubert_jp":
+            #     slot.embedder = "hubert_jp"
             else:
                 raise RuntimeError("[Voice Changer][setInfoByONNX] unknown embedder")

@@ -67,14 +69,14 @@ class RVCModelSlotGenerator(ModelSlotGenerator):
                 slot.embChannels = 256
                 slot.embOutputLayer = 9
                 slot.useFinalProj = True
-                slot.embedder = EnumEmbedderTypes.hubert.value
+                slot.embedder = "hubert_base"
                 print("[Voice Changer] Official Model(pyTorch) : v1")
             else:
                 slot.modelType = EnumInferenceTypes.pyTorchRVCv2.value if slot.f0 else EnumInferenceTypes.pyTorchRVCv2Nono.value
                 slot.embChannels = 768
                 slot.embOutputLayer = 12
                 slot.useFinalProj = False
-                slot.embedder = EnumEmbedderTypes.hubert.value
+                slot.embedder = "hubert_base"
                 print("[Voice Changer] Official Model(pyTorch) : v2")

         else:
@@ -104,24 +106,18 @@ class RVCModelSlotGenerator(ModelSlotGenerator):
                 for k, v in cpt["speaker_info"].items():
                     slot.speakers[int(k)] = str(v)

-            # if slot.embedder == EnumEmbedderTypes.hubert.value:
-            #     slot.embedder = EnumEmbedderTypes.hubert
-            # elif slot.embedder == EnumEmbedderTypes.contentvec.value:
-            #     slot.embedder = EnumEmbedderTypes.contentvec
-            # elif slot.embedder == EnumEmbedderTypes.hubert_jp.value:
-            #     slot.embedder = EnumEmbedderTypes.hubert_jp
-            # else:
-            #     raise RuntimeError("[Voice Changer][setInfoByONNX] unknown embedder")

         slot.samplingRate = cpt["config"][-1]

         del cpt

         return slot

     @classmethod
     def _setInfoByONNX(cls, slot: ModelSlot):
         tmp_onnx_session = onnxruntime.InferenceSession(slot.modelFile, providers=["CPUExecutionProvider"])
         modelmeta = tmp_onnx_session.get_modelmeta()
         try:
             slot = RVCModelSlot(**asdict(slot))
             metadata = json.loads(modelmeta.custom_metadata_map["metadata"])

             # slot.modelType = metadata["modelType"]
@@ -144,17 +140,9 @@ class RVCModelSlotGenerator(ModelSlotGenerator):
             print(f"[Voice Changer] ONNX Model: ch:{slot.embChannels}, L:{slot.embOutputLayer}, FP:{slot.useFinalProj}")

             if "embedder" not in metadata:
-                slot.embedder = EnumEmbedderTypes.hubert.value
+                slot.embedder = "hubert_base"
             else:
                 slot.embedder = metadata["embedder"]
-                # elif metadata["embedder"] == EnumEmbedderTypes.hubert.value:
-                #     slot.embedder = EnumEmbedderTypes.hubert
-                # elif metadata["embedder"] == EnumEmbedderTypes.contentvec.value:
-                #     slot.embedder = EnumEmbedderTypes.contentvec
-                # elif metadata["embedder"] == EnumEmbedderTypes.hubert_jp.value:
-                #     slot.embedder = EnumEmbedderTypes.hubert_jp
-                # else:
-                #     raise RuntimeError("[Voice Changer][setInfoByONNX] unknown embedder")

             slot.f0 = metadata["f0"]
             slot.modelType = EnumInferenceTypes.onnxRVC.value if slot.f0 else EnumInferenceTypes.onnxRVCNono.value
@@ -164,7 +152,7 @@ class RVCModelSlotGenerator(ModelSlotGenerator):
         except Exception as e:
             slot.modelType = EnumInferenceTypes.onnxRVC.value
             slot.embChannels = 256
-            slot.embedder = EnumEmbedderTypes.hubert.value
+            slot.embedder = "hubert_base"
             slot.f0 = True
             slot.samplingRate = 48000
             slot.deprecated = True
@@ -175,3 +163,4 @@ class RVCModelSlotGenerator(ModelSlotGenerator):
             print("[Voice Changer] ############## !!!! CAUTION !!!! ####################")

         del tmp_onnx_session
+        return slot

@@ -3,11 +3,11 @@ from typing import Any, Protocol
 import torch
 from torch import device

-from const import EnumEmbedderTypes
+from const import EmbedderType


 class Embedder(Protocol):
-    embedderType: EnumEmbedderTypes = EnumEmbedderTypes.hubert
+    embedderType: EmbedderType = "hubert_base"
     file: str
     isHalf: bool = True
     dev: device
@@ -24,7 +24,7 @@ class Embedder(Protocol):

     def getEmbedderInfo(self):
         return {
-            "embedderType": self.embedderType.value,
+            "embedderType": self.embedderType,
             "file": self.file,
             "isHalf": self.isHalf,
             "devType": self.dev.type,
@@ -33,7 +33,7 @@ class Embedder(Protocol):

     def setProps(
         self,
-        embedderType: EnumEmbedderTypes,
+        embedderType: EmbedderType,
         file: str,
         dev: device,
         isHalf: bool = True,
@@ -56,7 +56,7 @@ class Embedder(Protocol):
         self.model = self.model.to(self.dev)
         return self

-    def matchCondition(self, embedderType: EnumEmbedderTypes) -> bool:
+    def matchCondition(self, embedderType: EmbedderType) -> bool:
         # Check Type
         if self.embedderType != embedderType:
             print(

@@ -1,6 +1,6 @@
 from torch import device

-from const import EnumEmbedderTypes
+from const import EmbedderType
 from voice_changer.RVC.embedder.Embedder import Embedder
 from voice_changer.RVC.embedder.FairseqContentvec import FairseqContentvec
 from voice_changer.RVC.embedder.FairseqHubert import FairseqHubert
@@ -18,7 +18,7 @@ class EmbedderManager:

     @classmethod
     def getEmbedder(
-        cls, embederType: EnumEmbedderTypes, isHalf: bool, dev: device
+        cls, embederType: EmbedderType, isHalf: bool, dev: device
     ) -> Embedder:
         if cls.currentEmbedder is None:
             print("[Voice Changer] generate new embedder. (no embedder)")
@@ -35,24 +35,15 @@ class EmbedderManager:

     @classmethod
     def loadEmbedder(
-        cls, embederType: EnumEmbedderTypes, isHalf: bool, dev: device
+        cls, embederType: EmbedderType, isHalf: bool, dev: device
     ) -> Embedder:
-        if (
-            embederType == EnumEmbedderTypes.hubert
-            or embederType == EnumEmbedderTypes.hubert.value
-        ):
+        if embederType == "hubert_base":
             file = cls.params.hubert_base
             return FairseqHubert().loadModel(file, dev, isHalf)
-        elif (
-            embederType == EnumEmbedderTypes.hubert_jp
-            or embederType == EnumEmbedderTypes.hubert_jp.value
-        ):
+        elif embederType == "hubert-base-japanese":
             file = cls.params.hubert_base_jp
             return FairseqHubertJp().loadModel(file, dev, isHalf)
-        elif (
-            embederType == EnumEmbedderTypes.contentvec
-            or embederType == EnumEmbedderTypes.contentvec.value
-        ):
+        elif embederType == "contentvec":
             file = cls.params.hubert_base
             return FairseqContentvec().loadModel(file, dev, isHalf)
         else:

@@ -1,5 +1,4 @@
 from torch import device
-from const import EnumEmbedderTypes
 from voice_changer.RVC.embedder.Embedder import Embedder
 from voice_changer.RVC.embedder.FairseqHubert import FairseqHubert

@@ -7,5 +6,5 @@ from voice_changer.RVC.embedder.FairseqHubert import FairseqHubert
 class FairseqContentvec(FairseqHubert):
     def loadModel(self, file: str, dev: device, isHalf: bool = True) -> Embedder:
         super().loadModel(file, dev, isHalf)
-        super().setProps(EnumEmbedderTypes.contentvec, file, dev, isHalf)
+        super().setProps("contentvec", file, dev, isHalf)
         return self

@@ -1,13 +1,12 @@
 import torch
 from torch import device
-from const import EnumEmbedderTypes
 from voice_changer.RVC.embedder.Embedder import Embedder
 from fairseq import checkpoint_utils


 class FairseqHubert(Embedder):
     def loadModel(self, file: str, dev: device, isHalf: bool = True) -> Embedder:
-        super().setProps(EnumEmbedderTypes.hubert, file, dev, isHalf)
+        super().setProps("hubert_base", file, dev, isHalf)

         models, saved_cfg, task = checkpoint_utils.load_model_ensemble_and_task(
             [file],

@@ -1,5 +1,4 @@
 from torch import device
-from const import EnumEmbedderTypes
 from voice_changer.RVC.embedder.Embedder import Embedder
 from voice_changer.RVC.embedder.FairseqHubert import FairseqHubert

@@ -7,5 +6,5 @@ from voice_changer.RVC.embedder.FairseqHubert import FairseqHubert
 class FairseqHubertJp(FairseqHubert):
     def loadModel(self, file: str, dev: device, isHalf: bool = True) -> Embedder:
         super().loadModel(file, dev, isHalf)
-        super().setProps(EnumEmbedderTypes.hubert_jp, file, dev, isHalf)
+        super().setProps("hubert-base-japanese", file, dev, isHalf)
         return self

@@ -31,7 +31,6 @@ class InferencerManager:
         file: str,
         gpu: int,
     ) -> Inferencer:
-        print("inferencerTypeinferencerTypeinferencerTypeinferencerType", inferencerType)
         if inferencerType == EnumInferenceTypes.pyTorchRVC or inferencerType == EnumInferenceTypes.pyTorchRVC.value:
             return RVCInferencer().loadModel(file, gpu)
         elif inferencerType == EnumInferenceTypes.pyTorchRVCNono or inferencerType == EnumInferenceTypes.pyTorchRVCNono.value: