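"""Mel-spectrogram extraction and NSF-HiFiGAN vocoding helpers: a Vocoder
front-end plus lazily loaded NSF-HiFiGAN back-ends (natural-log and log10
mel variants)."""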
import torch
from torchaudio.transforms import Resample

from ..nsf_hifigan.nvSTFT import STFT  # type: ignore
from ..nsf_hifigan.models import load_model, load_config  # type: ignore


class Vocoder:
    """Front-end that selects an NSF-HiFiGAN back-end and handles
    resampling, mel extraction, and inference around it."""

    def __init__(self, vocoder_type, vocoder_ckpt, device=None):
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = device

        if vocoder_type == "nsf-hifigan":
            self.vocoder = NsfHifiGAN(vocoder_ckpt, device=device)
        elif vocoder_type == "nsf-hifigan-log10":
            self.vocoder = NsfHifiGANLog10(vocoder_ckpt, device=device)
        else:
            raise ValueError(f" [x] Unknown vocoder: {vocoder_type}")

        self.resample_kernel = {}  # one cached Resample kernel per source sample rate
        self.vocoder_sample_rate = self.vocoder.sample_rate()
        self.vocoder_hop_size = self.vocoder.hop_size()
        self.dimension = self.vocoder.dimension()

    def extract(self, audio, sample_rate, keyshift=0):
        # Resample to the vocoder's native sample rate, reusing a cached
        # kernel per source sample rate.
        if sample_rate == self.vocoder_sample_rate:
            audio_res = audio
        else:
            key_str = str(sample_rate)
            if key_str not in self.resample_kernel:
                self.resample_kernel[key_str] = Resample(sample_rate, self.vocoder_sample_rate, lowpass_filter_width=128).to(self.device)
            audio_res = self.resample_kernel[key_str](audio)

        # Extract the mel spectrogram.
        mel = self.vocoder.extract(audio_res, keyshift=keyshift)  # B, n_frames, bins
        return mel

    def infer(self, mel, f0):
        # Trim f0 to the mel frame count and drop the trailing channel axis.
        f0 = f0[:, : mel.size(1), 0]  # B, n_frames
        audio = self.vocoder(mel, f0)
        return audio
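

# Minimal usage sketch (illustrative, not part of the original module):
# the checkpoint path, sample rate, and pitch value below are assumptions.
#
#     vocoder = Vocoder("nsf-hifigan", "checkpoints/nsf_hifigan/model")
#     audio = torch.randn(1, 44100)                  # B, n_samples
#     mel = vocoder.extract(audio, 44100)            # B, n_frames, bins
#     f0 = torch.full((1, mel.size(1), 1), 220.0)    # B, n_frames, 1
#     wav = vocoder.infer(mel, f0)                   # vocoded waveform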


class NsfHifiGAN(torch.nn.Module):
    """NSF-HiFiGAN back-end: the config is read eagerly, but the network
    itself is loaded lazily on the first forward() call."""

    def __init__(self, model_path, device=None):
        super().__init__()
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = device
        self.model_path = model_path
        self.model = None  # loaded lazily in forward()
        self.h = load_config(model_path)
        self.stft = STFT(self.h.sampling_rate, self.h.num_mels, self.h.n_fft, self.h.win_size, self.h.hop_size, self.h.fmin, self.h.fmax)

    def sample_rate(self):
        return self.h.sampling_rate

    def hop_size(self):
        return self.h.hop_size

    def dimension(self):
        return self.h.num_mels

    def extract(self, audio, keyshift=0):
        mel = self.stft.get_mel(audio, keyshift=keyshift).transpose(1, 2)  # B, n_frames, bins
        return mel

    def forward(self, mel, f0):
        # Lazily load the network weights on first use.
        if self.model is None:
            print("| Load HifiGAN: ", self.model_path)
            self.model, self.h = load_model(self.model_path, device=self.device)
        with torch.no_grad():
            c = mel.transpose(1, 2)  # B, bins, n_frames
            audio = self.model(c, f0)
            return audio
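
# Frame/hop relationship (general NSF-HiFiGAN property, stated as a rough
# guide): audio of n_samples at the model's rate yields approximately
# n_frames = n_samples / hop_size mel frames, and the generator upsamples
# them back to about n_frames * hop_size output samples.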


class NsfHifiGANLog10(NsfHifiGAN):
    """Variant for checkpoints trained on log10 mels: rescales the
    natural-log mel input by log10(e) before vocoding."""

    def forward(self, mel, f0):
        if self.model is None:
            print("| Load HifiGAN: ", self.model_path)
            self.model, self.h = load_model(self.model_path, device=self.device)
        with torch.no_grad():
            # 0.434294 ~= log10(e): converts natural-log mels to log10 mels.
            c = 0.434294 * mel.transpose(1, 2)
            audio = self.model(c, f0)
            return audio
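
# Sanity check for the log10 scaling (illustrative):
#     x = torch.rand(3) + 0.1
#     assert torch.allclose(torch.log10(x), 0.434294 * torch.log(x), atol=1e-5)
# since log10(x) = ln(x) * log10(e) and log10(e) = 1 / ln(10) ~= 0.4342945.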