voice-changer/server/voice_changer/RVC/ModelSlotGenerator.py

148 lines
4.8 KiB
Python
Raw Normal View History

2023-05-02 14:57:12 +03:00
from const import EnumEmbedderTypes, EnumInferenceTypes
from voice_changer.RVC.ModelSlot import ModelSlot
import torch
import onnxruntime
import json
2023-05-14 22:24:58 +03:00
import os
2023-05-02 14:57:12 +03:00
2023-05-14 22:24:58 +03:00
def generateModelSlot_(params):
    """Build a ModelSlot directly from an in-memory parameter dict.

    Args:
        params: dict whose ``"files"`` mapping must contain ``"rvcModel"`` and
            may contain ``"rvcFeature"`` and ``"rvcIndex"``. An optional
            ``"trans"`` entry is stored as ``defaultTrans`` (0 when absent).

    Returns:
        A ModelSlot whose file paths are taken verbatim from ``params`` and
        whose model details are filled in by _setInfoByONNX (for ``.onnx``
        model files) or _setInfoByPytorch (for anything else).
    """
    slot = ModelSlot()
    files = params["files"]

    slot.modelFile = files["rvcModel"]
    # Optional companion files fall back to None when not supplied.
    slot.featureFile = files.get("rvcFeature")
    slot.indexFile = files.get("rvcIndex")
    slot.defaultTrans = params.get("trans", 0)

    # The file extension alone decides which loader inspects the model.
    slot.isONNX = slot.modelFile.endswith(".onnx")
    fillInfo = _setInfoByONNX if slot.isONNX else _setInfoByPytorch
    fillInfo(slot)
    return slot
2023-05-14 22:24:58 +03:00
def generateModelSlot(slotDir: str):
    """Build a ModelSlot from a slot directory on disk.

    ``slotDir`` is expected to hold a ``params.json`` file (same layout as the
    dict accepted by generateModelSlot_) next to the model files it lists.
    Only the basename of each recorded file path is kept and joined onto
    ``slotDir``, so directory components stored in ``params.json`` are ignored.

    Args:
        slotDir: path of the slot directory.

    Returns:
        The populated ModelSlot, or an empty ``ModelSlot()`` when ``slotDir``
        does not exist or contains no ``params.json`` file.
    """
    modelSlot = ModelSlot()
    paramFile = os.path.join(slotDir, "params.json")
    # A missing directory and a directory without params.json are both
    # treated as an empty slot instead of raising FileNotFoundError.
    if not os.path.isfile(paramFile):
        return modelSlot

    with open(paramFile, "r") as f:
        params = json.load(f)
    files = params["files"]

    def _inSlot(key):
        # Resolve an optional file entry inside slotDir; None when absent.
        if key not in files:
            return None
        return os.path.join(slotDir, os.path.basename(files[key]))

    # rvcModel is mandatory: a missing entry still raises KeyError.
    modelSlot.modelFile = os.path.join(slotDir, os.path.basename(files["rvcModel"]))
    modelSlot.featureFile = _inSlot("rvcFeature")
    modelSlot.indexFile = _inSlot("rvcIndex")
    modelSlot.defaultTrans = params["trans"] if "trans" in params else 0

    modelSlot.isONNX = modelSlot.modelFile.endswith(".onnx")
    if modelSlot.isONNX:
        _setInfoByONNX(modelSlot)
    else:
        _setInfoByPytorch(modelSlot)
    return modelSlot
2023-05-08 19:01:20 +03:00
def _setInfoByPytorch(slot: ModelSlot):
    """Fill *slot* with model information read from a PyTorch checkpoint.

    Loads ``slot.modelFile`` on the CPU and sets ``f0``, ``modelType``,
    ``embChannels``, ``embedder`` and ``samplingRate`` (the last ``config``
    entry).

    Checkpoints whose ``config`` has 18 entries become pyTorchRVC /
    pyTorchRVCNono with 256 embedding channels and the hubert embedder. Any
    other config length becomes pyTorchWebUI / pyTorchWebUINono. In that case
    the channel count comes from ``config[17]`` and the embedder from
    ``embedder_name``, with a trailing ``"768"`` suffix stripped.

    Raises:
        RuntimeError: if a WebUI-style checkpoint names an unsupported embedder.
    """
    # NOTE(review): torch.load unpickles the file, which can execute arbitrary
    # code embedded in it. Only load checkpoints from trusted sources.
    cpt = torch.load(slot.modelFile, map_location="cpu")
    config = cpt["config"]
    slot.f0 = bool(cpt["f0"] == 1)

    if len(config) == 18:
        slot.modelType = (
            EnumInferenceTypes.pyTorchRVC
            if slot.f0
            else EnumInferenceTypes.pyTorchRVCNono
        )
        slot.embChannels = 256
        slot.embedder = EnumEmbedderTypes.hubert
    else:
        slot.modelType = (
            EnumInferenceTypes.pyTorchWebUI
            if slot.f0
            else EnumInferenceTypes.pyTorchWebUINono
        )
        slot.embChannels = config[17]
        embedderName = cpt["embedder_name"]
        if embedderName.endswith("768"):
            embedderName = embedderName[:-3]

        supported = (
            EnumEmbedderTypes.hubert,
            EnumEmbedderTypes.contentvec,
            EnumEmbedderTypes.hubert_jp,
        )
        embedder = next((e for e in supported if e.value == embedderName), None)
        if embedder is None:
            raise RuntimeError(
                f"[Voice Changer][setInfoByPytorch] unknown embedder: {embedderName}"
            )
        slot.embedder = embedder

    slot.samplingRate = config[-1]
    del cpt
2023-05-08 19:01:20 +03:00
def _setInfoByONNX(slot: ModelSlot):
    """Fill *slot* with model information read from an ONNX model's metadata.

    Opens ``slot.modelFile`` with a CPU-only onnxruntime session and parses
    the JSON string stored under ``"metadata"`` in the model's custom metadata
    map. Sets ``embChannels``, ``embedder``, ``f0``, ``modelType``,
    ``samplingRate`` and ``deprecated`` on *slot*. A missing ``"embedder"``
    entry means hubert.

    If the metadata is absent or cannot be parsed, the model is treated as a
    deprecated export. In that case fixed defaults are applied (onnxRVC,
    256 channels, hubert, f0 enabled, 48000 Hz), ``deprecated`` is set to
    True and a warning is printed.

    Raises:
        RuntimeError: if the metadata names an unsupported embedder.
    """
    tmp_onnx_session = onnxruntime.InferenceSession(
        slot.modelFile, providers=["CPUExecutionProvider"]
    )
    try:
        modelmeta = tmp_onnx_session.get_modelmeta()
        try:
            metadata = json.loads(modelmeta.custom_metadata_map["metadata"])
            embChannels = metadata["embChannels"]
            embedderName = (
                metadata["embedder"]
                if "embedder" in metadata
                else EnumEmbedderTypes.hubert.value
            )
            f0 = metadata["f0"]
            samplingRate = metadata["samplingRate"]
        except Exception as e:
            # Deliberate best-effort fallback for exports without usable metadata.
            slot.modelType = EnumInferenceTypes.onnxRVC
            slot.embChannels = 256
            slot.embedder = EnumEmbedderTypes.hubert
            slot.f0 = True
            slot.samplingRate = 48000
            slot.deprecated = True
            print("[Voice Changer] setInfoByONNX", e)
            print("[Voice Changer] ############## !!!! CAUTION !!!! ####################")
            print("[Voice Changer] This onnx file is deprecated. Please regenerate the onnx file.")
            print("[Voice Changer] ############## !!!! CAUTION !!!! ####################")
            return

        # Validated outside the parsing try-block so an unsupported embedder is
        # reported instead of being silently replaced by the deprecated defaults.
        supported = (
            EnumEmbedderTypes.hubert,
            EnumEmbedderTypes.contentvec,
            EnumEmbedderTypes.hubert_jp,
        )
        embedder = next((m for m in supported if m.value == embedderName), None)
        if embedder is None:
            raise RuntimeError(
                f"[Voice Changer][setInfoByONNX] unknown embedder: {embedderName}"
            )

        slot.embChannels = embChannels
        slot.embedder = embedder
        slot.f0 = f0
        slot.modelType = (
            EnumInferenceTypes.onnxRVC if slot.f0 else EnumInferenceTypes.onnxRVCNono
        )
        slot.samplingRate = samplingRate
        slot.deprecated = False
    finally:
        del tmp_onnx_session