# voice-changer/server/voice_changer/RVC/embedder/whisper/whisper.py

# The live code in this module is the load_model helper below; everything
# else in this file is commented-out experimentation kept for reference.

import torch

from .model import ModelDimensions, Whisper

# Imports used only by the commented-out experiments at the bottom of the file:
# import json
# import numpy as np
# import onnx
# import onnxruntime
# from onnxsim import simplify
# from whisper_ppg.model import Whisper, ModelDimensions
# from whisper_ppg_custom._LightWhisper import LightWhisper
# from whisper_ppg_custom.Timer import Timer2
# from whisper_ppg_custom.whisper_ppg.audio import load_audio, pad_or_trim, log_mel_spectrogram
# from whisper_ppg_custom.whisper_ppg.model import Whisper, ModelDimensions
# from easy_vc_dev.utils.whisper.audio import load_audio, pad_or_trim


def load_model(path: str) -> Whisper:
    """Load a Whisper checkpoint from `path` onto the CPU and return the model."""
device = "cpu"
checkpoint = torch.load(path, map_location=device)
dims = ModelDimensions(**checkpoint["dims"])
model = Whisper(dims)
model.load_state_dict(checkpoint["model_state_dict"])
model = model.to(device)
return model
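
# ---------------------------------------------------------------------------
# Everything below is commented-out development experimentation (PPG
# extraction with the Whisper encoder, ONNX export, and ONNX inference).
# None of it runs in production; it is kept for reference.
# ---------------------------------------------------------------------------

# Minimal usage sketch for load_model ("tiny.pt" is an illustrative checkpoint
# name, following the OpenAI Whisper checkpoints used in the experiments):
#
#   model = load_model("tiny.pt")
#   print(model.dims)

# pred_ppg: extracts a phonetic posteriorgram (PPG) from a wav file with the
# Whisper encoder and saves it via np.save. One encoder frame covers 320 input
# samples at 16 kHz, hence the `audln // 320` length used to crop the output.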
# def pred_ppg(whisper: Whisper, wavPath: str, ppgPath: str):
# print("pred")
# # whisper = load_model("base.pt") # "base": "https://openaipublic.azureedge.net/main/whisper/models/ed3a0b6b1c0edf879ad9b11b1af5a0e6ab5db9205f891f668f8b0e6c6326e34e/base.pt"
# audio = load_audio(wavPath)
# audln = audio.shape[0]
# ppgln = audln // 320
# print("audio.shape1", audio.shape, audio.shape[0] / 16000)
# audio = pad_or_trim(audio)
# audio = audio[:400000]
# print("audio.shape2", audio.shape)
# print(f"whisper.device {whisper.device}")
# for i in range(5):
#         with Timer2("mainProcess timer", True) as t:
# mel = log_mel_spectrogram(audio).to(whisper.device)
# with torch.no_grad():
# ppg = whisper.encoder(mel.unsqueeze(0)).squeeze().data.cpu().float().numpy()
# print("ppg.shape", ppg.shape)
# ppg = ppg[:ppgln,]
# print(ppg.shape)
# np.save(ppgPath, ppg, allow_pickle=False)
# t.record("fin")
# print("res", ppg)
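
# pred_ppg_onnx: times encoder inference through the ONNX export
# ("wencoder_sim.onnx", produced by export_encoder below). Note that in
# onnxruntime the thread counts and execution mode passed here as provider
# options are normally set on onnxruntime.SessionOptions instead.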
# def pred_ppg_onnx(wavPath, ppgPath):
# print("pred")
# # whisper = load_model("base.pt") # "base": "https://openaipublic.azureedge.net/main/whisper/models/ed3a0b6b1c0edf879ad9b11b1af5a0e6ab5db9205f891f668f8b0e6c6326e34e/base.pt"
# whisper = load_model("tiny.pt")
# audio = load_audio(wavPath)
# # audln = audio.shape[0]
# # ppgln = audln // 320
# print("audio.shape1", audio.shape, audio.shape[0] / 16000)
# audio = pad_or_trim(audio)
# audio = audio[:1000]
# print("audio.shape2", audio.shape)
# print(f"whisper.device {whisper.device}")
# onnx_session = onnxruntime.InferenceSession(
# "wencoder_sim.onnx",
# providers=["CPUExecutionProvider"],
# provider_options=[
# {
# "intra_op_num_threads": 8,
# "execution_mode": onnxruntime.ExecutionMode.ORT_PARALLEL,
# "inter_op_num_threads": 8,
# }
# ],
# )
# for i in range(5):
#         with Timer2("mainProcess timer", True) as t:
# mel = log_mel_spectrogram(audio).to(whisper.device).unsqueeze(0)
# onnx_res = onnx_session.run(
# ["ppg"],
# {
# "mel": mel.cpu().numpy(),
# },
# )
# t.record("fin")
# print("onnx_res", onnx_res)
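
# export_encoder: exports the Whisper encoder to ONNX with a dynamic time axis
# (axis 2 of "mel"), simplifies the graph with onnxsim, and embeds VC_CLIENT
# metadata in the simplified model before saving "wencoder_sim.onnx".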
# def export_encoder(wavPath, ppgPath):
# print("pred")
# # whisper = load_model("base.pt") # "base": "https://openaipublic.azureedge.net/main/whisper/models/ed3a0b6b1c0edf879ad9b11b1af5a0e6ab5db9205f891f668f8b0e6c6326e34e/base.pt"
# whisper = load_model("tiny.pt")
# audio = load_audio(wavPath)
# # audln = audio.shape[0]
# # ppgln = audln // 320
# print("audio.shape1", audio.shape, audio.shape[0] / 16000)
# audio = pad_or_trim(audio)
# print("audio.shape2", audio.shape)
# print(f"whisper.device {whisper.device}")
# mel = log_mel_spectrogram(audio).to(whisper.device).unsqueeze(0)
# input_names = ["mel"]
# output_names = ["ppg"]
# torch.onnx.export(
# whisper.encoder,
# (mel,),
# "wencoder.onnx",
# dynamic_axes={
# "mel": [2],
# },
# do_constant_folding=False,
# opset_version=17,
# verbose=False,
# input_names=input_names,
# output_names=output_names,
# )
# metadata = {
# "application": "VC_CLIENT",
# "version": "2.1",
# }
# model_onnx2 = onnx.load("wencoder.onnx")
# model_simp, check = simplify(model_onnx2)
# meta = model_simp.metadata_props.add()
# meta.key = "metadata"
# meta.value = json.dumps(metadata)
# onnx.save(model_simp, "wencoder_sim.onnx")
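
# pred_ppg_onnx_w: like pred_ppg_onnx, but crops the mel spectrogram along the
# time axis first, exercising the dynamic "mel" axis of the exported encoder
# with shorter inputs.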
# def pred_ppg_onnx_w(wavPath, ppgPath):
# print("pred")
# audio = load_audio(wavPath)
# print("audio.shape1", audio.shape, audio.shape[0] / 16000)
# audio = pad_or_trim(audio)
# print("audio.shape2", audio.shape)
# onnx_session = onnxruntime.InferenceSession(
# "wencoder_sim.onnx",
# providers=["CPUExecutionProvider"],
# provider_options=[
# {
# "intra_op_num_threads": 8,
# "execution_mode": onnxruntime.ExecutionMode.ORT_PARALLEL,
# "inter_op_num_threads": 8,
# }
# ],
# )
# for i in range(5):
#         with Timer2("mainProcess timer", True) as t:
# mel = log_mel_spectrogram(audio).to("cpu").unsqueeze(0)
# # mel = mel[:, :, 1500:]
# mel = mel[:, :, 2500:]
# # mel[0, 79, 1499] = 0.1
# print("x.shape", mel.shape)
# onnx_res = onnx_session.run(
# ["ppg"],
# {
# "mel": mel.cpu().numpy(),
# },
# )
# t.record("fin")
# print("onnx_res", onnx_res)
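
# export_wrapped_encoder: same export pipeline as export_encoder, but wraps
# the model in LightWhisper and enables constant folding during export.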
# def export_wrapped_encoder(wavPath, ppgPath):
# print("pred")
# whisper = LightWhisper("tiny.pt")
# audio = load_audio(wavPath)
# # audln = audio.shape[0]
# # ppgln = audln // 320
# print("audio.shape1", audio.shape, audio.shape[0] / 16000)
# audio = pad_or_trim(audio)
# print("audio.shape2", audio.shape)
# mel = log_mel_spectrogram(audio).to("cpu").unsqueeze(0)
# mel = mel[:, :, 1500:]
# input_names = ["mel"]
# output_names = ["ppg"]
# torch.onnx.export(
# whisper,
# (mel,),
# "wencoder.onnx",
# dynamic_axes={
# "mel": [2],
# },
# do_constant_folding=True,
# opset_version=17,
# verbose=False,
# input_names=input_names,
# output_names=output_names,
# )
# metadata = {
# "application": "VC_CLIENT",
# "version": "2.1",
# }
# model_onnx2 = onnx.load("wencoder.onnx")
# model_simp, check = simplify(model_onnx2)
# meta = model_simp.metadata_props.add()
# meta.key = "metadata"
# meta.value = json.dumps(metadata)
# onnx.save(model_simp, "wencoder_sim.onnx")
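
# Hedged usage sketch for the exported model (assumes "wencoder_sim.onnx"
# exists and onnxruntime is installed; the 80 mel bins and the "mel"/"ppg"
# names follow the export code above):
#
#   import numpy as np
#   import onnxruntime
#
#   sess = onnxruntime.InferenceSession("wencoder_sim.onnx", providers=["CPUExecutionProvider"])
#   mel = np.zeros((1, 80, 3000), dtype=np.float32)  # (batch, n_mels, frames)
#   (ppg,) = sess.run(["ppg"], {"mel": mel})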