mirror of
https://github.com/w-okada/voice-changer.git
synced 2025-01-24 22:15:02 +03:00
58 lines
1.8 KiB
Python
58 lines
1.8 KiB
Python
|
import numpy as np
|
||
|
|
||
|
import torch
|
||
|
import torch.nn as nn
|
||
|
import torchaudio
|
||
|
from torch.nn import functional as F
|
||
|
from .core import upsample
|
||
|
|
||
|
class SSSLoss(nn.Module):
    """
    Single-scale Spectral Loss.

    Compares the magnitude spectrograms of a reference and a predicted signal
    at a single FFT size. The loss is a spectral-convergence term (relative
    Frobenius-norm error of the spectrograms) plus ``alpha`` times the L1
    distance between the log-magnitudes.

    Args:
        n_fft: FFT size handed to the spectrogram transform.
        alpha: Weight of the log-magnitude L1 term.
        overlap: Fraction of ``n_fft`` shared by consecutive frames; the hop
            length is ``int(n_fft * (1 - overlap))``, so the default of 0
            means the hop equals ``n_fft``. Must be less than 1.
        eps: Constant added to both spectrograms before the norms and the log
            are taken, so neither the division nor the log sees a zero.

    Raises:
        ValueError: If ``overlap`` is 1 or more, since no positive hop length
            can be derived from it.
    """

    def __init__(self, n_fft=111, alpha=1.0, overlap=0, eps=1e-7):
        super().__init__()
        if overlap >= 1:
            raise ValueError(f"overlap must be less than 1, got {overlap}")
        self.n_fft = n_fft
        self.alpha = alpha
        self.eps = eps
        # Integer truncation can produce 0 for a small n_fft combined with a
        # large overlap; a zero hop is not usable by the STFT, so keep the
        # hop at one sample or more.
        self.hop_length = max(1, int(n_fft * (1 - overlap)))
        # power=1 gives magnitudes (not power); center=False disables padding.
        self.spec = torchaudio.transforms.Spectrogram(
            n_fft=self.n_fft,
            hop_length=self.hop_length,
            power=1,
            normalized=True,
            center=False,
        )

    def forward(self, x_true, x_pred):
        """
        Return the spectral loss between ``x_true`` and ``x_pred``.

        Both inputs go through the same spectrogram transform. The norms are
        taken over dims 1 and 2 of the result, so the spectrograms must be
        3-D -- presumably (batch, freq, frames) from (batch, samples) input;
        confirm against the caller.
        """
        S_true = self.spec(x_true) + self.eps
        S_pred = self.spec(x_pred) + self.eps

        # Per-item relative error of the spectrograms, averaged over dim 0.
        diff_norm = torch.linalg.norm(S_true - S_pred, dim=(1, 2))
        sum_norm = torch.linalg.norm(S_true + S_pred, dim=(1, 2))
        converge_term = torch.mean(diff_norm / sum_norm)

        log_term = F.l1_loss(S_true.log(), S_pred.log())

        return converge_term + self.alpha * log_term
|
||
|
|
||
|
|
||
|
class RSSLoss(nn.Module):
    """
    Random-scale Spectral Loss.

    Holds one ``SSSLoss`` per FFT size in ``range(fft_min, fft_max)`` and, on
    every forward call, averages the losses of ``n_scale`` FFT sizes drawn at
    random from that range.

    Args:
        fft_min: Smallest FFT size (inclusive).
        fft_max: Upper bound of the FFT sizes (exclusive).
        n_scale: Number of FFT sizes sampled per forward call.
        alpha: Forwarded to every ``SSSLoss``.
        overlap: Forwarded to every ``SSSLoss``.
        eps: Forwarded to every ``SSSLoss``.
        device: Device each ``SSSLoss`` is moved to on construction.

    Raises:
        ValueError: If ``fft_max <= fft_min`` (no FFT size to sample) or
            ``n_scale < 1`` (the average would divide by zero).
    """

    def __init__(self, fft_min, fft_max, n_scale, alpha=1.0, overlap=0, eps=1e-7, device='cuda'):
        super().__init__()
        # Fail here rather than on the first forward call, where an empty
        # range breaks torch.randint and n_scale == 0 divides by zero.
        if fft_max <= fft_min:
            raise ValueError(
                f"fft_max must be greater than fft_min, got fft_min={fft_min}, fft_max={fft_max}"
            )
        if n_scale < 1:
            raise ValueError(f"n_scale must be at least 1, got {n_scale}")
        self.fft_min = fft_min
        self.fft_max = fft_max
        self.n_scale = n_scale
        # A plain dict does not register the sub-losses as submodules, so
        # each one is moved to `device` explicitly here.
        self.lossdict = {
            n_fft: SSSLoss(n_fft, alpha, overlap, eps).to(device)
            for n_fft in range(fft_min, fft_max)
        }

    def forward(self, x_pred, x_true):
        """
        Return the mean single-scale loss over ``n_scale`` random FFT sizes.

        Note the argument order is (x_pred, x_true); the arguments are
        swapped when handed to the per-scale loss, which is called as
        ``loss(x_true, x_pred)``.
        """
        value = 0.
        # The upper bound of randint is exclusive, matching the keys of lossdict.
        sampled = torch.randint(self.fft_min, self.fft_max, (self.n_scale,)).tolist()
        for n_fft in sampled:
            value += self.lossdict[n_fft](x_true, x_pred)
        return value / self.n_scale
|
||
|
|
||
|
|
||
|
|