voice-changer/server/voice_changer/RVC/ModelWrapper.py

95 lines
3.5 KiB
Python
Raw Normal View History

2023-04-07 21:11:37 +03:00
import onnxruntime
import torch
import numpy as np
import json
2023-04-07 21:11:37 +03:00
# Execution providers passed to onnxruntime.InferenceSession when a
# ModelWrapper is constructed. Only the CPU provider is enabled here.
# Accelerated providers such as "OpenVINOExecutionProvider",
# "CUDAExecutionProvider" or "DmlExecutionProvider" may be listed ahead of it
# (onnxruntime treats the list in decreasing order of precedence) when the
# installed onnxruntime build supports them.
providers = [
    "CPUExecutionProvider",
]
class ModelWrapper:
    """Thin wrapper around an onnxruntime session for an exported RVC model.

    On construction the ONNX model is loaded with the module-level
    ``providers`` list. The input precision is detected from the model's first
    input. The JSON ``"metadata"`` entry of the model's custom metadata map, if
    usable, supplies the sampling rate, the f0 flag and the embedding channel
    count.
    """

    # Embedding channel count used when the metadata does not record one.
    # NOTE(review): 256 is assumed to match models exported before the
    # "embChannels" metadata key existed -- confirm against the exporter.
    DEFAULT_EMB_CHANNELS = 256

    def __init__(self, onnx_model):
        """Load ``onnx_model`` and read its precision and metadata.

        Args:
            onnx_model: Any model argument accepted by
                ``onnxruntime.InferenceSession`` (e.g. a path or raw bytes).
        """
        self.onnx_model = onnx_model
        self.onnx_session = onnxruntime.InferenceSession(
            self.onnx_model,
            providers=providers,
        )

        # Any first input that is not float32 is treated as a half-precision
        # model, and "feats" is then fed as float16. This assumes the first
        # input is "feats" -- TODO confirm against the exporter.
        first_input_type = self.onnx_session.get_inputs()[0].type
        self.is_half = first_input_type != "tensor(float)"

        parsed = self._parse_metadata(self.onnx_session.get_modelmeta().custom_metadata_map)
        if parsed is not None:
            self.samplingRate, self.f0, self.embChannels = parsed
            print(f"[Voice Changer] Onnx metadata: sr:{self.samplingRate}, f0:{self.f0}")
        else:
            # No usable metadata: fall back to defaults. embChannels is set
            # too, so getEmbChannels() works for old models.
            self.samplingRate = -1
            self.f0 = True
            self.embChannels = self.DEFAULT_EMB_CHANNELS
            print("[Voice Changer] Onnx version is old. Please regenerate onnxfile. Fallback to default")
            print(f"[Voice Changer] Onnx metadata: sr:{self.samplingRate}, f0:{self.f0}")

    @classmethod
    def _parse_metadata(cls, custom_metadata_map):
        """Extract ``(samplingRate, f0, embChannels)`` from ONNX custom metadata.

        Args:
            custom_metadata_map: The model's custom metadata mapping. Its
                ``"metadata"`` entry is expected to be a JSON object string.

        Returns:
            A ``(samplingRate, f0, embChannels)`` tuple. Returns ``None`` when
            the entry is missing, is not valid JSON, is not a JSON object, or
            lacks ``samplingRate`` or ``f0``. A missing ``embChannels`` falls
            back to ``DEFAULT_EMB_CHANNELS`` instead of discarding the other
            values.
        """
        try:
            metadata = json.loads(custom_metadata_map["metadata"])
            samplingRate = metadata["samplingRate"]
            f0 = metadata["f0"]
        except (KeyError, TypeError, ValueError):
            # ValueError covers json.JSONDecodeError. TypeError covers a
            # non-string entry or JSON that is not an object.
            return None
        return samplingRate, f0, metadata.get("embChannels", cls.DEFAULT_EMB_CHANNELS)

    def getSamplingRate(self):
        """Return the sampling rate from the metadata, or -1 if unknown."""
        return self.samplingRate

    def getF0(self):
        """Return the metadata's f0 flag (True when no metadata was found)."""
        return self.f0

    def getEmbChannels(self):
        """Return the embedding channel count (metadata value or default)."""
        return self.embChannels

    def set_providers(self, providers, provider_options=None):
        """Switch the underlying session to ``providers``.

        Args:
            providers: Execution provider names for onnxruntime.
            provider_options: Optional list with one options dict per
                provider; onnxruntime rejects mismatched lengths. ``None``
                lets onnxruntime use each provider's default options.
        """
        self.onnx_session.set_providers(providers=providers, provider_options=provider_options)

    def get_providers(self):
        """Return the execution providers the session currently uses."""
        return self.onnx_session.get_providers()

    def _feats_dtype(self):
        """Return the numpy dtype used for the "feats" input."""
        return np.float16 if self.is_half else np.float32

    def _run(self, feeds):
        """Run the session for the "audio" output and return it as a tensor.

        ``onnx_session.run`` returns a list with one array per requested
        output. Wrapping that list in ``np.array`` adds a leading axis of
        length 1.
        """
        outputs = self.onnx_session.run(["audio"], feeds)
        return torch.tensor(np.array(outputs))

    def infer_pitchless(self, feats, p_len, sid):
        """Run a model without pitch inputs.

        Args:
            feats: Feature tensor, cast to float16 or float32 to match the
                model precision.
            p_len: Length tensor, cast to int64.
            sid: Speaker-id tensor, cast to int64.

        Returns:
            ``torch.Tensor`` holding the "audio" output, with an extra leading
            axis of length 1.
        """
        feeds = {
            "feats": feats.cpu().numpy().astype(self._feats_dtype()),
            "p_len": p_len.cpu().numpy().astype(np.int64),
            "sid": sid.cpu().numpy().astype(np.int64),
        }
        return self._run(feeds)

    def infer(self, feats, p_len, pitch, pitchf, sid):
        """Run a model that takes pitch inputs.

        Args:
            feats: Feature tensor, cast to float16 or float32 to match the
                model precision.
            p_len: Length tensor, cast to int64.
            pitch: Pitch tensor, cast to int64.
            pitchf: Pitch tensor, always cast to float32, regardless of
                model precision.
            sid: Speaker-id tensor, cast to int64.

        Returns:
            ``torch.Tensor`` holding the "audio" output, with an extra leading
            axis of length 1.
        """
        feeds = {
            "feats": feats.cpu().numpy().astype(self._feats_dtype()),
            "p_len": p_len.cpu().numpy().astype(np.int64),
            "pitch": pitch.cpu().numpy().astype(np.int64),
            "pitchf": pitchf.cpu().numpy().astype(np.float32),
            "sid": sid.cpu().numpy().astype(np.int64),
        }
        return self._run(feeds)