bugfxi: onnx for cpu

This commit is contained in:
wataru 2023-04-14 16:38:08 +09:00
parent 1c8d891a7d
commit b41de02624
2 changed files with 12 additions and 8 deletions

View File

@ -6,9 +6,8 @@ providers = ["CPUExecutionProvider"]
class ModelWrapper:
def __init__(self, onnx_model, is_half):
def __init__(self, onnx_model):
self.onnx_model = onnx_model
self.is_half = is_half
# ort_options = onnxruntime.SessionOptions()
# ort_options.intra_op_num_threads = 8
@ -17,6 +16,11 @@ class ModelWrapper:
providers=providers
)
# input_info = s
first_input_type = self.onnx_session.get_inputs()[0].type
if first_input_type == "tensor(float)":
self.is_half = False
else:
self.is_half = True
def set_providers(self, providers, provider_options=[{}]):
self.onnx_session.set_providers(providers=providers, provider_options=provider_options)
@ -45,11 +49,11 @@ class ModelWrapper:
audio1 = self.onnx_session.run(
["audio"],
{
"feats": feats.cpu().numpy().astype(np.float16),
"p_len": p_len.cpu().numpy(),
"pitch": pitch.cpu().numpy(),
"pitchf": pitchf.cpu().numpy(),
"sid": sid.cpu().numpy(),
"feats": feats.cpu().numpy().astype(np.float32),
"p_len": p_len.cpu().numpy().astype(np.int64),
"pitch": pitch.cpu().numpy().astype(np.int64),
"pitchf": pitchf.cpu().numpy().astype(np.float32),
"sid": sid.cpu().numpy().astype(np.int64),
})
return torch.tensor(np.array(audio1))

View File

@ -113,7 +113,7 @@ class RVC:
# ONNXモデル生成
if onnx_model_file != None:
self.onnx_session = ModelWrapper(onnx_model_file, is_half=self.is_half)
self.onnx_session = ModelWrapper(onnx_model_file)
return self.get_info()
def update_settings(self, key: str, val: any):