# from whisper_ppg.model import Whisper, ModelDimensions
# from whisper_ppg_custom._LightWhisper import LightWhisper
# from whisper_ppg_custom.Timer import Timer2
# from whisper_ppg_custom.whisper_ppg.audio import load_audio, pad_or_trim, log_mel_spectrogram
# from whisper_ppg_custom.whisper_ppg.model import Whisper, ModelDimensions
import torch

# import numpy as np
# from easy_vc_dev.utils.whisper.audio import load_audio, pad_or_trim
from .model import ModelDimensions, Whisper

# import onnx
# from onnxsim import simplify
# import json
# import onnxruntime


def load_model(path) -> Whisper:
    # Rebuild the Whisper model on the CPU from a checkpoint that stores
    # its dimensions ("dims") alongside its weights ("model_state_dict").
    device = "cpu"
    checkpoint = torch.load(path, map_location=device)
    dims = ModelDimensions(**checkpoint["dims"])
    model = Whisper(dims)
    model.load_state_dict(checkpoint["model_state_dict"])
    model = model.to(device)
    return model
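
# Example usage (a minimal sketch; "tiny.pt" is an assumed local checkpoint
# saved with the "dims" and "model_state_dict" keys that load_model expects):
# whisper = load_model("tiny.pt")
# print(whisper.device)
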

# def pred_ppg(whisper: Whisper, wavPath: str, ppgPath: str):
#     print("pred")
#     # whisper = load_model("base.pt")  # "base": "https://openaipublic.azureedge.net/main/whisper/models/ed3a0b6b1c0edf879ad9b11b1af5a0e6ab5db9205f891f668f8b0e6c6326e34e/base.pt"
#     audio = load_audio(wavPath)
#     audln = audio.shape[0]
#     ppgln = audln // 320
#     print("audio.shape1", audio.shape, audio.shape[0] / 16000)
#     audio = pad_or_trim(audio)
#     audio = audio[:400000]
#     print("audio.shape2", audio.shape)
#     print(f"whisper.device {whisper.device}")
#     for i in range(5):
#         with Timer2("mainProcess timer", True) as t:
#             mel = log_mel_spectrogram(audio).to(whisper.device)
#             with torch.no_grad():
#                 ppg = whisper.encoder(mel.unsqueeze(0)).squeeze().data.cpu().float().numpy()
#                 print("ppg.shape", ppg.shape)
#                 ppg = ppg[:ppgln,]
#                 print(ppg.shape)
#                 np.save(ppgPath, ppg, allow_pickle=False)
#             t.record("fin")
#     print("res", ppg)

# def pred_ppg_onnx(wavPath, ppgPath):
#     print("pred")
#     # whisper = load_model("base.pt")  # "base": "https://openaipublic.azureedge.net/main/whisper/models/ed3a0b6b1c0edf879ad9b11b1af5a0e6ab5db9205f891f668f8b0e6c6326e34e/base.pt"
#     whisper = load_model("tiny.pt")
#     audio = load_audio(wavPath)
#     # audln = audio.shape[0]
#     # ppgln = audln // 320
#     print("audio.shape1", audio.shape, audio.shape[0] / 16000)
#     audio = pad_or_trim(audio)
#     audio = audio[:1000]
#     print("audio.shape2", audio.shape)
#     print(f"whisper.device {whisper.device}")
#     onnx_session = onnxruntime.InferenceSession(
#         "wencoder_sim.onnx",
#         providers=["CPUExecutionProvider"],
#         provider_options=[
#             {
#                 "intra_op_num_threads": 8,
#                 "execution_mode": onnxruntime.ExecutionMode.ORT_PARALLEL,
#                 "inter_op_num_threads": 8,
#             }
#         ],
#     )
#
#     for i in range(5):
#         with Timer2("mainProcess timer", True) as t:
#             mel = log_mel_spectrogram(audio).to(whisper.device).unsqueeze(0)
#             onnx_res = onnx_session.run(
#                 ["ppg"],
#                 {
#                     "mel": mel.cpu().numpy(),
#                 },
#             )
#             t.record("fin")
#     print("onnx_res", onnx_res)

# def export_encoder(wavPath, ppgPath):
#     print("pred")
#     # whisper = load_model("base.pt")  # "base": "https://openaipublic.azureedge.net/main/whisper/models/ed3a0b6b1c0edf879ad9b11b1af5a0e6ab5db9205f891f668f8b0e6c6326e34e/base.pt"
#     whisper = load_model("tiny.pt")
#     audio = load_audio(wavPath)
#     # audln = audio.shape[0]
#     # ppgln = audln // 320
#     print("audio.shape1", audio.shape, audio.shape[0] / 16000)
#     audio = pad_or_trim(audio)
#     print("audio.shape2", audio.shape)
#     print(f"whisper.device {whisper.device}")
#
#     mel = log_mel_spectrogram(audio).to(whisper.device).unsqueeze(0)
#     input_names = ["mel"]
#     output_names = ["ppg"]
#
#     torch.onnx.export(
#         whisper.encoder,
#         (mel,),
#         "wencoder.onnx",
#         dynamic_axes={
#             "mel": [2],
#         },
#         do_constant_folding=False,
#         opset_version=17,
#         verbose=False,
#         input_names=input_names,
#         output_names=output_names,
#     )
#
#     metadata = {
#         "application": "VC_CLIENT",
#         "version": "2.1",
#     }
#     model_onnx2 = onnx.load("wencoder.onnx")
#     model_simp, check = simplify(model_onnx2)
#     meta = model_simp.metadata_props.add()
#     meta.key = "metadata"
#     meta.value = json.dumps(metadata)
#     onnx.save(model_simp, "wencoder_sim.onnx")
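
# A quick sanity check for the export above (a minimal sketch, an assumption
# rather than part of the original flow; it only inspects the graph's I/O):
# import onnxruntime
# session = onnxruntime.InferenceSession("wencoder_sim.onnx", providers=["CPUExecutionProvider"])
# print("inputs:", [i.name for i in session.get_inputs()])
# print("outputs:", [o.name for o in session.get_outputs()])
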
# def pred_ppg_onnx_w(wavPath, ppgPath):
#     print("pred")
#     audio = load_audio(wavPath)
#     print("audio.shape1", audio.shape, audio.shape[0] / 16000)
#     audio = pad_or_trim(audio)
#     print("audio.shape2", audio.shape)
#     onnx_session = onnxruntime.InferenceSession(
#         "wencoder_sim.onnx",
#         providers=["CPUExecutionProvider"],
#         provider_options=[
#             {
#                 "intra_op_num_threads": 8,
#                 "execution_mode": onnxruntime.ExecutionMode.ORT_PARALLEL,
#                 "inter_op_num_threads": 8,
#             }
#         ],
#     )
#
#     for i in range(5):
#         with Timer2("mainProcess timer", True) as t:
#             mel = log_mel_spectrogram(audio).to("cpu").unsqueeze(0)
#             # mel = mel[:, :, 1500:]
#             mel = mel[:, :, 2500:]
#             # mel[0, 79, 1499] = 0.1
#
#             print("x.shape", mel.shape)
#             onnx_res = onnx_session.run(
#                 ["ppg"],
#                 {
#                     "mel": mel.cpu().numpy(),
#                 },
#             )
#             t.record("fin")
#     print("onnx_res", onnx_res)

# def export_wrapped_encoder(wavPath, ppgPath):
#     print("pred")
#     whisper = LightWhisper("tiny.pt")
#     audio = load_audio(wavPath)
#     # audln = audio.shape[0]
#     # ppgln = audln // 320
#     print("audio.shape1", audio.shape, audio.shape[0] / 16000)
#     audio = pad_or_trim(audio)
#     print("audio.shape2", audio.shape)
#
#     mel = log_mel_spectrogram(audio).to("cpu").unsqueeze(0)
#     mel = mel[:, :, 1500:]
#     input_names = ["mel"]
#     output_names = ["ppg"]
#
#     torch.onnx.export(
#         whisper,
#         (mel,),
#         "wencoder.onnx",
#         dynamic_axes={
#             "mel": [2],
#         },
#         do_constant_folding=True,
#         opset_version=17,
#         verbose=False,
#         input_names=input_names,
#         output_names=output_names,
#     )
#
#     metadata = {
#         "application": "VC_CLIENT",
#         "version": "2.1",
#     }
#     model_onnx2 = onnx.load("wencoder.onnx")
#     model_simp, check = simplify(model_onnx2)
#     meta = model_simp.metadata_props.add()
#     meta.key = "metadata"
#     meta.value = json.dumps(metadata)
#     onnx.save(model_simp, "wencoder_sim.onnx")