voice-changer/server/voice_changer/RVC/inferencer/OnnxRVCInferencer.py

85 lines
3.0 KiB
Python
Raw Normal View History

2023-05-02 14:57:12 +03:00
import torch
import onnxruntime
2023-05-31 08:30:35 +03:00
from const import EnumInferenceTypes
2023-05-29 11:34:35 +03:00
from voice_changer.RVC.deviceManager.DeviceManager import DeviceManager
2023-05-02 14:57:12 +03:00
from voice_changer.RVC.inferencer.Inferencer import Inferencer
import numpy as np
2023-05-03 07:14:00 +03:00
class OnnxRVCInferencer(Inferencer):
    """RVC inferencer backed by an ONNX Runtime session.

    Wraps an exported RVC voice-conversion model: loads it with the
    execution providers selected by ``DeviceManager`` and runs inference
    by feeding numpy arrays converted from the caller's torch tensors.
    """

    def loadModel(self, file: str, gpu: int, inferencerTypeVersion: str | None = None):
        """Load the ONNX model at ``file`` and prepare it for inference.

        Args:
            file: Path to the ONNX model file.
            gpu: GPU index forwarded to ``DeviceManager`` to pick
                execution providers (semantics defined by DeviceManager).
            inferencerTypeVersion: Optional model-export version tag
                (e.g. "v1.1", "v2.1", "v2.2"); controls output
                post-processing in ``infer``.

        Returns:
            self, to allow fluent chaining.
        """
        self.setProps(EnumInferenceTypes.onnxRVC, file, True, gpu)
        (
            onnxProviders,
            onnxProviderOptions,
        ) = DeviceManager.get_instance().getOnnxExecutionProvider(gpu)

        onnx_session = onnxruntime.InferenceSession(
            file, providers=onnxProviders, provider_options=onnxProviderOptions
        )

        # Infer precision from the model itself: a float32 first input
        # means full precision; anything else (e.g. tensor(float16)) is
        # treated as half precision.
        first_input_type = onnx_session.get_inputs()[0].type
        self.isHalf = first_input_type != "tensor(float)"

        self.model = onnx_session
        self.inferencerTypeVersion = inferencerTypeVersion
        return self

    def infer(
        self,
        feats: torch.Tensor,
        pitch_length: torch.Tensor,
        pitch: torch.Tensor,
        pitchf: torch.Tensor,
        sid: torch.Tensor,
        convert_length: int | None,
    ) -> torch.Tensor:
        """Run one inference pass and return the generated audio.

        Args:
            feats: Feature tensor (converted to float16/float32 to match
                the model's precision).
            pitch_length: Length tensor fed to the model as ``p_len``.
            pitch: Integer pitch tensor; must not be None.
            pitchf: Float pitch tensor; must not be None.
            sid: Speaker-id tensor.
            convert_length: Unused here; kept for interface compatibility
                with other Inferencer implementations.

        Returns:
            A CPU torch tensor of audio samples clipped to [-1.0, 1.0].

        Raises:
            RuntimeError: If ``pitch`` or ``pitchf`` is None.
        """
        if pitch is None or pitchf is None:
            raise RuntimeError("[Voice Changer] Pitch or Pitchf is not found.")

        # The half and full precision paths differ only in the dtype of
        # "feats"; build a single input dict instead of duplicating the
        # whole run() call.
        feats_dtype = np.float16 if self.isHalf else np.float32
        audio1 = self.model.run(
            ["audio"],
            {
                "feats": feats.cpu().numpy().astype(feats_dtype),
                "p_len": pitch_length.cpu().numpy().astype(np.int64),
                "pitch": pitch.cpu().numpy().astype(np.int64),
                "pitchf": pitchf.cpu().numpy().astype(np.float32),
                "sid": sid.cpu().numpy().astype(np.int64),
            },
        )

        if self.inferencerTypeVersion in ("v1.1", "v2.1", "v2.2"):
            # Newer exports already return the audio as the first output.
            res = audio1[0]
        else:
            # Older exports wrap the audio in extra batch/channel dims.
            res = np.array(audio1)[0][0, 0]
        res = np.clip(res, -1.0, 1.0)
        return torch.tensor(res)

    def getInferencerInfo(self):
        """Return the base inferencer info extended with the active ONNX
        execution providers of the loaded session."""
        inferencer = super().getInferencerInfo()
        inferencer["onnxExecutionProvider"] = self.model.get_providers()
        return inferencer