voice-changer/server/voice_changer/RVC/export2onnx.py

137 lines
4.5 KiB
Python
Raw Normal View History

2023-05-03 11:12:40 +03:00
import os
import json
2023-04-13 02:00:28 +03:00
import torch
from onnxsim import simplify
import onnx
2023-05-03 11:12:40 +03:00
from const import TMP_DIR, EnumInferenceTypes
from voice_changer.RVC.ModelSlot import ModelSlot
2023-05-03 12:47:14 +03:00
from voice_changer.RVC.deviceManager.DeviceManager import DeviceManager
2023-04-13 02:00:28 +03:00
2023-04-28 02:46:34 +03:00
from voice_changer.RVC.onnx.SynthesizerTrnMs256NSFsid_ONNX import (
SynthesizerTrnMs256NSFsid_ONNX,
)
from voice_changer.RVC.onnx.SynthesizerTrnMs256NSFsid_nono_ONNX import (
SynthesizerTrnMs256NSFsid_nono_ONNX,
)
from voice_changer.RVC.onnx.SynthesizerTrnMsNSFsidNono_webui_ONNX import (
SynthesizerTrnMsNSFsidNono_webui_ONNX,
)
from voice_changer.RVC.onnx.SynthesizerTrnMsNSFsid_webui_ONNX import (
SynthesizerTrnMsNSFsid_webui_ONNX,
)
2023-04-13 02:00:28 +03:00
2023-05-03 12:47:14 +03:00
def export2onnx(gpu: int, modelSlot: ModelSlot):
    """Export the PyTorch model referenced by ``modelSlot`` to ONNX.

    Two files are written into ``TMP_DIR``: ``<stem>.onnx`` (the raw
    ``torch.onnx.export`` output) and ``<stem>_simple.onnx`` (the
    onnx-simplifier result with VC-client metadata attached as a JSON string).

    Args:
        gpu: Device id queried via ``DeviceManager.getDeviceMemory``. If it
            reports memory > 0, the export is traced in float16 on
            ``cuda:<gpu>``; otherwise it is traced in float32 on the CPU.
        modelSlot: Slot whose ``pyTorchModelFile``, ``modelType``,
            ``samplingRate``, ``f0``, ``embChannels`` and ``embedder`` are read.

    Returns:
        The file name (not the full path) of the simplified ONNX model
        written inside ``TMP_DIR``.
    """
    pyTorchModelFile = modelSlot.pyTorchModelFile
    stem = os.path.splitext(os.path.basename(pyTorchModelFile))[0]
    output_file = stem + ".onnx"
    output_file_simple = stem + "_simple.onnx"
    output_path = os.path.join(TMP_DIR, output_file)
    output_path_simple = os.path.join(TMP_DIR, output_file_simple)
    metadata = {
        "application": "VC_CLIENT",
        "version": "2",
        # EnumInferenceTypes members cannot be JSON-serialized as-is, so the
        # enum's .value is stored instead.
        "modelType": modelSlot.modelType.value,
        "samplingRate": modelSlot.samplingRate,
        "f0": modelSlot.f0,
        "embChannels": modelSlot.embChannels,
        "embedder": modelSlot.embedder.value,
    }
    gpu_memory = DeviceManager.get_instance().getDeviceMemory(gpu)
    print(f"[Voice Changer] exporting onnx... gpu_id:{gpu} gpu_mem:{gpu_memory}")

    # Half precision is only used when the selected device reports memory.
    is_half = gpu_memory > 0
    if not is_half:
        print(
            "[Voice Changer] Warning!!! onnx export with float32. maybe size is doubled."
        )
    # Forward the gpu id so tracing happens on the same device whose memory
    # was checked above (previously it was always cuda:0).
    _export2onnx(
        pyTorchModelFile, output_path, output_path_simple, is_half, metadata, gpu=gpu
    )
    return output_file_simple


def _build_onnx_synthesizer(cpt, model_type, is_half):
    """Instantiate the ONNX-export synthesizer wrapper matching ``model_type``.

    RVC-style checkpoints are unpacked positionally from ``cpt["config"]``;
    WebUI-style checkpoints are unpacked as keywords from ``cpt["params"]``.
    Only the f0 ("non-nono") variants receive ``is_half``.

    Raises:
        ValueError: if ``model_type`` is not one of the four supported
            ``EnumInferenceTypes`` values.
    """
    if model_type == EnumInferenceTypes.pyTorchRVC.value:
        return SynthesizerTrnMs256NSFsid_ONNX(*cpt["config"], is_half=is_half)
    if model_type == EnumInferenceTypes.pyTorchWebUI.value:
        return SynthesizerTrnMsNSFsid_webui_ONNX(**cpt["params"], is_half=is_half)
    if model_type == EnumInferenceTypes.pyTorchRVCNono.value:
        return SynthesizerTrnMs256NSFsid_nono_ONNX(*cpt["config"])
    if model_type == EnumInferenceTypes.pyTorchWebUINono.value:
        return SynthesizerTrnMsNSFsidNono_webui_ONNX(**cpt["params"])
    # Previously this only printed and then crashed later with UnboundLocalError.
    raise ValueError(f"[Voice Changer] unknown model type for onnx export: {model_type}")


def _make_dummy_inputs(metadata, is_half, dev):
    """Build the tracing inputs for ``torch.onnx.export``.

    Returns:
        ``(input_names, inputs)`` where ``inputs`` is a tuple of tensors on
        ``dev`` in the same order as ``input_names``. ``feats`` is float16 when
        ``is_half`` else float32; ``pitchf`` is always float32; ``pitch``,
        ``p_len`` and ``sid`` are int64. Pitch inputs are included only when
        ``metadata["f0"]`` is truthy.
    """
    frames = 2192  # dummy frame count; this axis is exported as dynamic
    feats_dtype = torch.float16 if is_half else torch.float32
    # Zeros instead of uninitialized HalfTensor/FloatTensor storage so the
    # traced inputs are deterministic.
    feats = torch.zeros(
        1, frames, metadata["embChannels"], dtype=feats_dtype, device=dev
    )
    p_len = torch.tensor([frames], dtype=torch.long, device=dev)
    sid = torch.tensor([0], dtype=torch.long, device=dev)

    # Truthiness (not `is True`) so an int flag such as 1 also selects the
    # pitch-conditioned signature.
    if metadata["f0"]:
        pitch = torch.zeros(1, frames, dtype=torch.int64, device=dev)
        pitchf = torch.zeros(1, frames, dtype=torch.float32, device=dev)
        return (
            ["feats", "p_len", "pitch", "pitchf", "sid"],
            (feats, p_len, pitch, pitchf, sid),
        )
    return ["feats", "p_len", "sid"], (feats, p_len, sid)


def _export2onnx(
    input_model, output_model, output_model_simple, is_half, metadata, gpu: int = 0
):
    """Trace a PyTorch checkpoint to ONNX, simplify it and attach metadata.

    Args:
        input_model: Path of the PyTorch checkpoint; its ``"weight"`` entry
            and ``"config"``/``"params"`` entry are read.
        output_model: Path for the raw ``torch.onnx.export`` output.
        output_model_simple: Path for the simplified model with metadata.
        is_half: Trace in float16 on CUDA when True, float32 on CPU otherwise.
        metadata: Dict with at least ``modelType``, ``f0`` and ``embChannels``;
            it is JSON-encoded into the model's ``metadata_props`` under the
            key ``"metadata"``.
        gpu: CUDA device index used when ``is_half`` is True (default 0,
            matching the previous hard-coded behavior).

    Raises:
        ValueError: if ``metadata["modelType"]`` is not supported.
    """
    # NOTE(review): torch.load unpickles the file; only load trusted models.
    cpt = torch.load(input_model, map_location="cpu")
    dev = torch.device("cuda", index=gpu) if is_half else torch.device("cpu")

    net_g_onnx = _build_onnx_synthesizer(cpt, metadata["modelType"], is_half)
    net_g_onnx.eval().to(dev)
    net_g_onnx.load_state_dict(cpt["weight"], strict=False)
    if is_half:
        net_g_onnx = net_g_onnx.half()

    input_names, inputs = _make_dummy_inputs(metadata, is_half, dev)
    # Only declare dynamic axes for inputs that actually exist; torch warns
    # about keys that are not valid input/output names.
    dynamic_axes = {
        name: [1] for name in ("feats", "pitch", "pitchf") if name in input_names
    }
    torch.onnx.export(
        net_g_onnx,
        inputs,
        output_model,
        dynamic_axes=dynamic_axes,
        do_constant_folding=False,
        opset_version=17,
        verbose=False,
        input_names=input_names,
        output_names=["audio"],
    )

    model_simp, check = simplify(onnx.load(output_model))
    if not check:
        # Best effort: keep saving, but surface that validation failed.
        print("[Voice Changer] Warning!!! simplified onnx model could not be validated.")
    meta = model_simp.metadata_props.add()
    meta.key = "metadata"
    meta.value = json.dumps(metadata)
    onnx.save(model_simp, output_model_simple)