voice-changer/server/voice_changer/RVC/inferencer/OnnxRVCInferencer.py

102 lines
3.4 KiB
Python
Raw Normal View History

2023-05-02 14:57:12 +03:00
import torch
from torch import device
import onnxruntime
from const import EnumInferenceTypes
from voice_changer.RVC.inferencer.Inferencer import Inferencer
import numpy as np
# Execution providers handed to onnxruntime.InferenceSession when the model is
# first loaded (see OnnxRVCInferencer.loadModel). Only the CPU provider is
# listed here; the session's providers are replaced afterwards via setDevice.
providers = ["CPUExecutionProvider"]
2023-05-03 07:14:00 +03:00
class OnnxRVCInferencer(Inferencer):
    """RVC inferencer that runs the model through an ONNX Runtime session.

    The session is stored in ``self.model``. Whether the model expects
    half-precision features is stored in ``self.isHalf`` and is detected
    from the model's first input when it is loaded.
    """

    def loadModel(self, file: str, dev: device, isHalf: bool = True):
        """Create the ONNX Runtime session for *file* and bind it to *dev*.

        Args:
            file: Path of the ONNX model to load.
            dev: Torch device describing where inference should run.
            isHalf: Forwarded to ``setProps``. ``self.isHalf`` is then
                overwritten with the precision detected from the model.

        Returns:
            This inferencer, so calls can be chained.
        """
        super().setProps(EnumInferenceTypes.onnxRVC, file, dev, isHalf)

        onnx_session = onnxruntime.InferenceSession(file, providers=providers)

        # The precision is fixed by the exported graph, so derive it from the
        # declared type of the first input instead of trusting the argument:
        # anything other than a float32 tensor is treated as half-precision.
        first_input_type = onnx_session.get_inputs()[0].type
        self.isHalf = first_input_type != "tensor(float)"

        self.model = onnx_session
        self.setDevice(dev)
        return self

    def infer(
        self,
        feats: torch.Tensor,
        pitch_length: torch.Tensor,
        pitch: torch.Tensor,
        pitchf: torch.Tensor,
        sid: torch.Tensor,
    ) -> torch.Tensor:
        """Run the ONNX model and return its ``audio`` output as a tensor.

        All inputs are moved to the CPU and converted to NumPy arrays before
        being fed to the session. ``feats`` is cast to float16 when
        ``self.isHalf`` is set and to float32 otherwise; ``pitchf`` is always
        float32 and the remaining inputs are int64.

        Raises:
            RuntimeError: If ``pitch`` or ``pitchf`` is ``None``.
        """
        if pitch is None or pitchf is None:
            raise RuntimeError("[Voice Changer] Pitch or Pitchf is not found.")

        # Only the feature dtype depends on the model precision.
        feats_dtype = np.float16 if self.isHalf else np.float32

        outputs = self.model.run(
            ["audio"],
            {
                "feats": feats.cpu().numpy().astype(feats_dtype),
                "p_len": pitch_length.cpu().numpy().astype(np.int64),
                "pitch": pitch.cpu().numpy().astype(np.int64),
                "pitchf": pitchf.cpu().numpy().astype(np.float32),
                "sid": sid.cpu().numpy().astype(np.int64),
            },
        )

        # ``run`` returns a list with one array per requested output; it is
        # stacked into a single array before conversion to a tensor.
        return torch.tensor(np.array(outputs))

    def setHalf(self, isHalf: bool):
        """Record the half-precision flag used by ``infer``.

        NOTE(review): this only changes the dtype ``infer`` casts ``feats``
        to; the loaded session itself is not modified. Setting a value that
        differs from the one detected in ``loadModel`` presumably makes the
        session reject its input -- confirm with callers.
        """
        self.isHalf = isHalf

    def setDevice(self, dev: device):
        """Select the session's execution provider from a torch device.

        ``cuda`` devices use the CUDA provider; every other device type uses
        the CPU provider.

        Returns:
            This inferencer, so calls can be chained.
        """
        if dev.type == "cuda":
            # torch.device("cuda") carries no index (None); fall back to the
            # first GPU rather than handing None to the provider options.
            device_id = 0 if dev.index is None else dev.index
            self.model.set_providers(
                providers=["CUDAExecutionProvider"],
                provider_options=[{"device_id": device_id}],
            )
        else:
            self.model.set_providers(providers=["CPUExecutionProvider"])
        return self

    def setDirectMLEnable(self, enable: bool):
        """Switch the session to or away from the DirectML provider.

        Prints a notice and leaves the session untouched when the installed
        onnxruntime does not offer ``DmlExecutionProvider``. Otherwise the
        session uses DirectML followed by CPU when *enable* is true, and CPU
        only when it is false.
        """
        if "DmlExecutionProvider" not in onnxruntime.get_available_providers():
            print("[Voice Changer] DML is not available.")
            return

        if enable:
            self.model.set_providers(
                providers=["DmlExecutionProvider", "CPUExecutionProvider"]
            )
        else:
            self.model.set_providers(providers=["CPUExecutionProvider"])