merge master

This commit is contained in:
wataru 2023-06-26 01:06:23 +09:00
commit 21b9a9fe24
11 changed files with 1545 additions and 30 deletions

View File

@ -66,6 +66,7 @@ class EnumInferenceTypes(Enum):
pyTorchRVCv2Nono = "pyTorchRVCv2Nono"
pyTorchWebUI = "pyTorchWebUI"
pyTorchWebUINono = "pyTorchWebUINono"
pyTorchVoRASbeta = "pyTorchVoRASbeta"
onnxRVC = "onnxRVC"
onnxRVCNono = "onnxRVCNono"

View File

@ -36,8 +36,31 @@ class RVCModelSlotGenerator(ModelSlotGenerator):
def _setInfoByPytorch(cls, slot: ModelSlot):
cpt = torch.load(slot.modelFile, map_location="cpu")
config_len = len(cpt["config"])
print(cpt["version"])
if cpt["version"] == "voras_beta":
slot.f0 = True if cpt["f0"] == 1 else False
slot.modelType = EnumInferenceTypes.pyTorchVoRASbeta.value
slot.embChannels = 768
slot.embOutputLayer = (
cpt["embedder_output_layer"] if "embedder_output_layer" in cpt else 9
)
slot.useFinalProj = False
if config_len == 18:
slot.embedder = cpt["embedder_name"]
if slot.embedder.endswith("768"):
slot.embedder = slot.embedder[:-3]
if slot.embedder == EnumEmbedderTypes.hubert.value:
slot.embedder = EnumEmbedderTypes.hubert.value
elif slot.embedder == EnumEmbedderTypes.contentvec.value:
slot.embedder = EnumEmbedderTypes.contentvec.value
elif slot.embedder == EnumEmbedderTypes.hubert_jp.value:
slot.embedder = EnumEmbedderTypes.hubert_jp.value
else:
raise RuntimeError("[Voice Changer][setInfoByONNX] unknown embedder")
elif config_len == 18:
# Original RVC
slot.f0 = True if cpt["f0"] == 1 else False
version = cpt.get("version", "v1")

View File

@ -8,7 +8,7 @@ from voice_changer.RVC.inferencer.RVCInferencerv2 import RVCInferencerv2
from voice_changer.RVC.inferencer.RVCInferencerv2Nono import RVCInferencerv2Nono
from voice_changer.RVC.inferencer.WebUIInferencer import WebUIInferencer
from voice_changer.RVC.inferencer.WebUIInferencerNono import WebUIInferencerNono
from voice_changer.RVC.inferencer.VorasInferencebeta import VoRASInferencer
class InferencerManager:
currentInferencer: Inferencer | None = None
@ -37,6 +37,8 @@ class InferencerManager:
return RVCInferencerNono().loadModel(file, gpu)
elif inferencerType == EnumInferenceTypes.pyTorchRVCv2 or inferencerType == EnumInferenceTypes.pyTorchRVCv2.value:
return RVCInferencerv2().loadModel(file, gpu)
elif inferencerType == EnumInferenceTypes.pyTorchVoRASbeta or inferencerType == EnumInferenceTypes.pyTorchVoRASbeta.value:
return VoRASInferencer().loadModel(file, gpu)
elif inferencerType == EnumInferenceTypes.pyTorchRVCv2Nono or inferencerType == EnumInferenceTypes.pyTorchRVCv2Nono.value:
return RVCInferencerv2Nono().loadModel(file, gpu)
elif inferencerType == EnumInferenceTypes.pyTorchWebUI or inferencerType == EnumInferenceTypes.pyTorchWebUI.value:

View File

@ -0,0 +1,39 @@
import torch
from torch import device
from const import EnumInferenceTypes
from voice_changer.RVC.inferencer.Inferencer import Inferencer
from voice_changer.RVC.deviceManager.DeviceManager import DeviceManager
from .voras_beta.models import Synthesizer
class VoRASInferencer(Inferencer):
def loadModel(self, file: str, gpu: device):
super().setProps(EnumInferenceTypes.pyTorchVoRASbeta, file, False, gpu)
dev = DeviceManager.get_instance().getDevice(gpu)
self.isHalf = False # DeviceManager.get_instance().halfPrecisionAvailable(gpu)
cpt = torch.load(file, map_location="cpu")
model = Synthesizer(**cpt["params"])
model.eval()
model.load_state_dict(cpt["weight"], strict=False)
model.remove_weight_norm()
model.change_speaker(0)
model = model.to(dev)
self.model = model
print("load model comprete")
return self
def infer(
self,
feats: torch.Tensor,
pitch_length: torch.Tensor,
pitch: torch.Tensor,
pitchf: torch.Tensor,
sid: torch.Tensor,
) -> torch.Tensor:
return self.model.infer(feats, pitch_length, pitch, pitchf, sid)

View File

@ -0,0 +1,165 @@
import math
import torch
from torch.nn import functional as F
def init_weights(m, mean=0.0, std=0.01):
classname = m.__class__.__name__
if classname.find("Conv") != -1:
m.weight.data.normal_(mean, std)
def get_padding(kernel_size, dilation=1):
return int((kernel_size * dilation - dilation) / 2)
def convert_pad_shape(pad_shape):
l = pad_shape[::-1]
pad_shape = [item for sublist in l for item in sublist]
return pad_shape
def kl_divergence(m_p, logs_p, m_q, logs_q):
"""KL(P||Q)"""
kl = (logs_q - logs_p) - 0.5
kl += (
0.5 * (torch.exp(2.0 * logs_p) + ((m_p - m_q) ** 2)) * torch.exp(-2.0 * logs_q)
)
return kl
def rand_gumbel(shape):
"""Sample from the Gumbel distribution, protect from overflows."""
uniform_samples = torch.rand(shape) * 0.99998 + 0.00001
return -torch.log(-torch.log(uniform_samples))
def rand_gumbel_like(x):
g = rand_gumbel(x.size()).to(dtype=x.dtype, device=x.device)
return g
def slice_segments(x, ids_str, segment_size=4):
ret = torch.zeros_like(x[:, :, :segment_size])
for i in range(x.size(0)):
idx_str = ids_str[i]
idx_end = idx_str + segment_size
r = x[i, :, idx_str:idx_end]
ret[i, :, :r.size(1)] = r
return ret
def slice_segments2(x, ids_str, segment_size=4):
ret = torch.zeros_like(x[:, :segment_size])
for i in range(x.size(0)):
idx_str = ids_str[i]
idx_end = idx_str + segment_size
r = x[i, idx_str:idx_end]
ret[i, :r.size(0)] = r
return ret
def rand_slice_segments(x, x_lengths, segment_size=4, ids_str=None):
b, d, t = x.size()
if ids_str is None:
ids_str = torch.zeros([b]).to(device=x.device, dtype=x_lengths.dtype)
ids_str_max = torch.maximum(torch.zeros_like(x_lengths).to(device=x_lengths.device ,dtype=x_lengths.dtype), x_lengths - segment_size + 1 - ids_str)
ids_str += (torch.rand([b]).to(device=x.device) * ids_str_max).to(dtype=torch.long)
ret = slice_segments(x, ids_str, segment_size)
return ret, ids_str
def get_timing_signal_1d(length, channels, min_timescale=1.0, max_timescale=1.0e4):
position = torch.arange(length, dtype=torch.float)
num_timescales = channels // 2
log_timescale_increment = math.log(float(max_timescale) / float(min_timescale)) / (
num_timescales - 1
)
inv_timescales = min_timescale * torch.exp(
torch.arange(num_timescales, dtype=torch.float) * -log_timescale_increment
)
scaled_time = position.unsqueeze(0) * inv_timescales.unsqueeze(1)
signal = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], 0)
signal = F.pad(signal, [0, 0, 0, channels % 2])
signal = signal.view(1, channels, length)
return signal
def add_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4):
b, channels, length = x.size()
signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale)
return x + signal.to(dtype=x.dtype, device=x.device)
def cat_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4, axis=1):
b, channels, length = x.size()
signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale)
return torch.cat([x, signal.to(dtype=x.dtype, device=x.device)], axis)
def subsequent_mask(length):
mask = torch.tril(torch.ones(length, length)).unsqueeze(0).unsqueeze(0)
return mask
@torch.jit.script
def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels):
n_channels_int = n_channels[0]
in_act = input_a + input_b
t_act = torch.tanh(in_act[:, :n_channels_int, :])
s_act = torch.sigmoid(in_act[:, n_channels_int:, :])
acts = t_act * s_act
return acts
def convert_pad_shape(pad_shape):
l = pad_shape[::-1]
pad_shape = [item for sublist in l for item in sublist]
return pad_shape
def shift_1d(x):
x = F.pad(x, convert_pad_shape([[0, 0], [0, 0], [1, 0]]))[:, :, :-1]
return x
def sequence_mask(length, max_length=None):
if max_length is None:
max_length = length.max()
x = torch.arange(max_length, dtype=length.dtype, device=length.device)
return x.unsqueeze(0) < length.unsqueeze(1)
def generate_path(duration, mask):
"""
duration: [b, 1, t_x]
mask: [b, 1, t_y, t_x]
"""
b, _, t_y, t_x = mask.shape
cum_duration = torch.cumsum(duration, -1)
cum_duration_flat = cum_duration.view(b * t_x)
path = sequence_mask(cum_duration_flat, t_y).to(mask.dtype)
path = path.view(b, t_x, t_y)
path = path - F.pad(path, convert_pad_shape([[0, 0], [1, 0], [0, 0]]))[:, :-1]
path = path.unsqueeze(1).transpose(2, 3) * mask
return path
def clip_grad_value_(parameters, clip_value, norm_type=2):
if isinstance(parameters, torch.Tensor):
parameters = [parameters]
parameters = list(filter(lambda p: p.grad is not None, parameters))
norm_type = float(norm_type)
if clip_value is not None:
clip_value = float(clip_value)
total_norm = 0
for p in parameters:
param_norm = p.grad.data.norm(norm_type)
total_norm += param_norm.item() ** norm_type
if clip_value is not None:
p.grad.data.clamp_(min=-clip_value, max=clip_value)
total_norm = total_norm ** (1.0 / norm_type)
return total_norm

View File

@ -0,0 +1,61 @@
from typing import *
from pydantic import BaseModel
class TrainConfigTrain(BaseModel):
log_interval: int
seed: int
epochs: int
learning_rate: float
betas: List[float]
eps: float
batch_size: int
fp16_run: bool
lr_decay: float
segment_size: int
init_lr_ratio: int
warmup_epochs: int
c_mel: int
c_kl: float
class TrainConfigData(BaseModel):
max_wav_value: float
sampling_rate: int
filter_length: int
hop_length: int
win_length: int
n_mel_channels: int
mel_fmin: float
mel_fmax: Any
class TrainConfigModel(BaseModel):
emb_channels: int
inter_channels: int
n_layers: int
upsample_rates: List[int]
use_spectral_norm: bool
gin_channels: int
spk_embed_dim: int
class TrainConfig(BaseModel):
version: Literal["voras"] = "voras"
train: TrainConfigTrain
data: TrainConfigData
model: TrainConfigModel
class DatasetMetaItem(BaseModel):
gt_wav: str
co256: str
f0: Optional[str]
f0nsf: Optional[str]
speaker_id: int
class DatasetMetadata(BaseModel):
files: Dict[str, DatasetMetaItem]
# mute: DatasetMetaItem

View File

@ -0,0 +1,238 @@
import math
import os
import sys
import numpy as np
import torch
from torch import nn
from torch.nn import Conv2d
from torch.nn import functional as F
from torch.nn.utils import remove_weight_norm, spectral_norm, weight_norm
from . import commons, modules
from .commons import get_padding
from .modules import (ConvNext2d, HarmonicEmbedder, IMDCTSymExpHead,
LoRALinear1d, SnakeFilter, WaveBlock)
parent_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.append(parent_dir)
sr2sr = {
"24k": 24000,
"32k": 32000,
"40k": 40000,
"48k": 48000,
}
class GeneratorVoras(torch.nn.Module):
def __init__(
self,
emb_channels,
inter_channels,
gin_channels,
n_layers,
sr,
hop_length,
):
super(GeneratorVoras, self).__init__()
self.n_layers = n_layers
self.emb_pitch = HarmonicEmbedder(768, inter_channels, gin_channels, 16, 15) # # pitch 256
self.plinear = LoRALinear1d(inter_channels, inter_channels, gin_channels, r=8)
self.glinear = weight_norm(nn.Conv1d(gin_channels, inter_channels, 1))
self.resblocks = nn.ModuleList()
self.init_linear = LoRALinear1d(emb_channels, inter_channels, gin_channels, r=4)
for _ in range(self.n_layers):
self.resblocks.append(WaveBlock(inter_channels, gin_channels, [9] * 2, [1] * 2, [1, 9], 2, r=4))
self.head = IMDCTSymExpHead(inter_channels, gin_channels, hop_length, padding="center", sample_rate=sr)
self.post = SnakeFilter(4, 8, 9, 2, eps=1e-5)
def forward(self, x, pitchf, x_mask, g):
x = self.init_linear(x, g) + self.plinear(self.emb_pitch(pitchf, g), g) + self.glinear(g)
for i in range(self.n_layers):
x = self.resblocks[i](x, x_mask, g)
x = x * x_mask
x = self.head(x, g)
x = self.post(x)
return torch.tanh(x)
def remove_weight_norm(self):
self.plinear.remove_weight_norm()
remove_weight_norm(self.glinear)
for l in self.resblocks:
l.remove_weight_norm()
self.init_linear.remove_weight_norm()
self.head.remove_weight_norm()
self.post.remove_weight_norm()
def fix_speaker(self, g):
self.plinear.fix_speaker(g)
self.init_linear.fix_speaker(g)
for l in self.resblocks:
l.fix_speaker(g)
self.head.fix_speaker(g)
def unfix_speaker(self, g):
self.plinear.unfix_speaker(g)
self.init_linear.unfix_speaker(g)
for l in self.resblocks:
l.unfix_speaker(g)
self.head.unfix_speaker(g)
class Synthesizer(nn.Module):
def __init__(
self,
segment_size,
n_fft,
hop_length,
inter_channels,
n_layers,
spk_embed_dim,
gin_channels,
emb_channels,
sr,
**kwargs
):
super().__init__()
if type(sr) == type("strr"):
sr = sr2sr[sr]
self.segment_size = segment_size
self.n_fft = n_fft
self.hop_length = hop_length
self.inter_channels = inter_channels
self.n_layers = n_layers
self.spk_embed_dim = spk_embed_dim
self.gin_channels = gin_channels
self.emb_channels = emb_channels
self.sr = sr
self.dec = GeneratorVoras(
emb_channels,
inter_channels,
gin_channels,
n_layers,
sr,
hop_length
)
self.emb_g = nn.Embedding(self.spk_embed_dim, gin_channels)
print(
"gin_channels:",
gin_channels,
"self.spk_embed_dim:",
self.spk_embed_dim,
"emb_channels:",
emb_channels,
)
self.speaker = None
def remove_weight_norm(self):
self.dec.remove_weight_norm()
def change_speaker(self, sid: int):
if self.speaker is not None:
g = self.emb_g(torch.from_numpy(np.array(self.speaker))).unsqueeze(-1)
self.dec.unfix_speaker(g)
g = self.emb_g(torch.from_numpy(np.array(sid))).unsqueeze(-1)
self.dec.fix_speaker(g)
self.speaker = sid
def forward(
self, phone, phone_lengths, pitch, pitchf, ds
):
g = self.emb_g(ds).unsqueeze(-1)
x = torch.transpose(phone, 1, -1)
x_mask = torch.unsqueeze(commons.sequence_mask(phone_lengths, x.size(2)), 1).to(phone.dtype)
x_slice, ids_slice = commons.rand_slice_segments(
x, phone_lengths, self.segment_size
)
pitchf_slice = commons.slice_segments2(pitchf, ids_slice, self.segment_size)
mask_slice = commons.slice_segments(x_mask, ids_slice, self.segment_size)
o = self.dec(x_slice, pitchf_slice, mask_slice, g)
return o, ids_slice, x_mask, g
def infer(self, phone, phone_lengths, pitch, nsff0, sid, max_len=None):
g = self.emb_g(sid).unsqueeze(-1)
x = torch.transpose(phone, 1, -1)
x_mask = torch.unsqueeze(commons.sequence_mask(phone_lengths, x.size(2)), 1).to(phone.dtype)
o = self.dec((x * x_mask)[:, :, :max_len], nsff0, x_mask, g)
return o, x_mask, (None, None, None, None)
class DiscriminatorP(torch.nn.Module):
def __init__(self, period, gin_channels, upsample_rates, final_dim=256, use_spectral_norm=False):
super(DiscriminatorP, self).__init__()
self.period = period
self.use_spectral_norm = use_spectral_norm
self.init_kernel_size = upsample_rates[-1] * 3
norm_f = weight_norm if use_spectral_norm == False else spectral_norm
N = len(upsample_rates)
self.init_conv = norm_f(Conv2d(1, final_dim // (2 ** (N - 1)), (self.init_kernel_size, 1), (upsample_rates[-1], 1)))
self.convs = nn.ModuleList()
for i, u in enumerate(upsample_rates[::-1][1:], start=1):
self.convs.append(
ConvNext2d(
final_dim // (2 ** (N - i)),
final_dim // (2 ** (N - i - 1)),
gin_channels,
(u*3, 1),
(u, 1),
4,
r=2 + i//2
)
)
self.conv_post = weight_norm(Conv2d(final_dim, 1, (3, 1), (1, 1)))
def forward(self, x, g):
fmap = []
# 1d to 2d
b, c, t = x.shape
if t % self.period != 0: # pad first
n_pad = self.period - (t % self.period)
x = F.pad(x, (n_pad, 0), "reflect")
t = t + n_pad
x = x.view(b, c, t // self.period, self.period)
x = torch.flip(x, dims=[2])
x = F.pad(x, [0, 0, 0, self.init_kernel_size - 1], mode="constant")
x = self.init_conv(x)
x = F.leaky_relu(x, modules.LRELU_SLOPE)
x = torch.flip(x, dims=[2])
fmap.append(x)
for i, l in enumerate(self.convs):
x = l(x, g)
fmap.append(x)
x = F.pad(x, [0, 0, 2, 0], mode="constant")
x = self.conv_post(x)
x = torch.flatten(x, 1, -1)
return x, fmap
class MultiPeriodDiscriminator(torch.nn.Module):
def __init__(self, upsample_rates, gin_channels, periods=[2, 3, 5, 7, 11, 17], **kwargs):
super(MultiPeriodDiscriminator, self).__init__()
discs = [
DiscriminatorP(i, gin_channels, upsample_rates, use_spectral_norm=False) for i in periods
]
self.ups = np.prod(upsample_rates)
self.discriminators = nn.ModuleList(discs)
def forward(self, y, y_hat, g):
fmap_rs = []
fmap_gs = []
y_d_rs = []
y_d_gs = []
for d in self.discriminators:
y_d_r, fmap_r = d(y, g)
y_d_g, fmap_g = d(y_hat, g)
y_d_rs.append(y_d_r)
y_d_gs.append(y_d_g)
fmap_rs.append(fmap_r)
fmap_gs.append(fmap_g)
return y_d_rs, y_d_gs, fmap_rs, fmap_gs

View File

@ -0,0 +1,496 @@
import math
import numpy as np
import scipy
import torch
from torch import nn
from torch.nn import Conv1d, Conv2d
from torch.nn import functional as F
from torch.nn.utils import remove_weight_norm, spectral_norm, weight_norm
from torchaudio.functional.functional import _hz_to_mel, _mel_to_hz
from . import commons, modules
from .commons import get_padding, init_weights
from .transforms import piecewise_rational_quadratic_transform
LRELU_SLOPE = 0.1
class HarmonicEmbedder(nn.Module):
def __init__(self, num_embeddings, embedding_dim, gin_channels, num_head, num_harmonic=0, f0_min=50., f0_max=1100., device="cuda"):
super(HarmonicEmbedder, self).__init__()
self.embedding_dim = embedding_dim
self.num_head = num_head
self.num_harmonic = num_harmonic
f0_mel_min = np.log(1 + f0_min / 700)
f0_mel_max = np.log(1 + f0_max * (1 + num_harmonic) / 700)
self.sequence = torch.from_numpy(np.linspace(f0_mel_min, f0_mel_max, num_embeddings-2))
self.emb_layer = torch.nn.Embedding(num_embeddings, embedding_dim)
self.linear_q = Conv1d(gin_channels, num_head * (1 + num_harmonic), 1)
self.weight = None
def forward(self, x, g):
b, l = x.size()
non_zero = (x != 0.).to(dtype=torch.long).unsqueeze(1)
mel = torch.log(1 + x / 700).unsqueeze(1)
harmonies = torch.arange(1 + self.num_harmonic, device=x.device, dtype=x.dtype).view(1, 1 + self.num_harmonic, 1) + 1.
ix = torch.searchsorted(self.sequence.to(x.device), mel * harmonies).to(x.device) + 1
ix = ix * non_zero
emb = self.emb_layer(ix).transpose(1, 3).reshape(b, self.num_head, self.embedding_dim // self.num_head, 1 + self.num_harmonic, l)
if self.weight is None:
weight = torch.nn.functional.softmax(self.linear_q(g).reshape(b, self.num_head, 1, 1 + self.num_harmonic, 1), 3)
else:
weight = self.weight
res = torch.sum(emb * weight, dim=3).reshape(b, self.embedding_dim, l)
return res
def fix_speaker(self, g):
self.weight = torch.nn.functional.softmax(self.linear_q(g).reshape(1, self.num_head, 1, 1 + self.num_harmonic, 1), 3)
def unfix_speaker(self, g):
self.weight = None
class LayerNorm(nn.Module):
def __init__(self, channels, eps=1e-5):
super().__init__()
self.channels = channels
self.eps = eps
self.gamma = nn.Parameter(torch.ones(channels))
self.beta = nn.Parameter(torch.zeros(channels))
def forward(self, x):
x = x.transpose(1, -1)
x = F.layer_norm(x, (self.channels,), self.gamma, self.beta, self.eps)
return x.transpose(1, -1)
class DilatedCausalConv1d(nn.Module):
def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, groups=1, dilation=1, bias=True):
super(DilatedCausalConv1d, self).__init__()
self.kernel_size = kernel_size
self.dilation = dilation
self.stride = stride
self.conv = weight_norm(nn.Conv1d(in_channels, out_channels, kernel_size, stride=stride, groups=groups, dilation=dilation, bias=bias))
def forward(self, x):
x = torch.flip(x, [2])
x = F.pad(x, [0, (self.kernel_size - 1) * self.dilation], mode="constant", value=0.)
size = x.shape[2] // self.stride
x = self.conv(x)[:, :, :size]
x = torch.flip(x, [2])
return x
def remove_weight_norm(self):
remove_weight_norm(self.conv)
class CausalConvTranspose1d(nn.Module):
"""
padding = 0, dilation = 1のとき
Lout = (Lin - 1) * stride + kernel_rate * stride + output_padding
Lout = Lin * stride + (kernel_rate - 1) * stride + output_padding
output_paddingいらないね
"""
def __init__(self, in_channels, out_channels, kernel_rate=3, stride=1, groups=1):
super(CausalConvTranspose1d, self).__init__()
kernel_size = kernel_rate * stride
self.trim_size = (kernel_rate - 1) * stride
self.conv = weight_norm(nn.ConvTranspose1d(in_channels, out_channels, kernel_size, stride=stride, groups=groups))
def forward(self, x):
x = self.conv(x)
return x[:, :, :-self.trim_size]
def remove_weight_norm(self):
remove_weight_norm(self.conv)
class LoRALinear1d(nn.Module):
def __init__(self, in_channels, out_channels, info_channels, r):
super().__init__()
self.in_channels = in_channels
self.out_channels = out_channels
self.info_channels = info_channels
self.r = r
self.main_fc = weight_norm(nn.Conv1d(in_channels, out_channels, 1))
self.adapter_in = nn.Conv1d(info_channels, in_channels * r, 1)
self.adapter_out = nn.Conv1d(info_channels, out_channels * r, 1)
nn.init.normal_(self.adapter_in.weight.data, 0, 0.01)
nn.init.constant_(self.adapter_out.weight.data, 1e-6)
self.adapter_in = weight_norm(self.adapter_in)
self.adapter_out = weight_norm(self.adapter_out)
self.speaker_fixed = False
def forward(self, x, g):
x_ = self.main_fc(x)
if not self.speaker_fixed:
a_in = self.adapter_in(g).view(-1, self.in_channels, self.r)
a_out = self.adapter_out(g).view(-1, self.r, self.out_channels)
l = torch.einsum("brl,brc->bcl", torch.einsum("bcl,bcr->brl", x, a_in), a_out)
x_ = x_ + l
return x_
def remove_weight_norm(self):
remove_weight_norm(self.main_fc)
remove_weight_norm(self.adapter_in)
remove_weight_norm(self.adapter_out)
def fix_speaker(self, g):
self.speaker_fixed = True
a_in = self.adapter_in(g).view(-1, self.in_channels, self.r)
a_out = self.adapter_out(g).view(-1, self.r, self.out_channels)
weight = torch.einsum("bir,bro->oi", a_in, a_out).unsqueeze(2)
self.main_fc.weight.data.add_(weight)
def unfix_speaker(self, g):
self.speaker_fixed = False
a_in = self.adapter_in(g).view(-1, self.in_channels, self.r)
a_out = self.adapter_out(g).view(-1, self.r, self.out_channels)
weight = torch.einsum("bir,bro->oi", a_in, a_out).unsqueeze(2)
self.main_fc.weight.data.sub_(weight)
class LoRALinear2d(nn.Module):
def __init__(self, in_channels, out_channels, info_channels, r):
super().__init__()
self.in_channels = in_channels
self.out_channels = out_channels
self.info_channels = info_channels
self.r = r
self.main_fc = weight_norm(nn.Conv2d(in_channels, out_channels, (1, 1), (1, 1)))
self.adapter_in = nn.Conv1d(info_channels, in_channels * r, 1)
self.adapter_out = nn.Conv1d(info_channels, out_channels * r, 1)
nn.init.normal_(self.adapter_in.weight.data, 0, 0.01)
nn.init.constant_(self.adapter_out.weight.data, 1e-6)
self.adapter_in = weight_norm(self.adapter_in)
self.adapter_out = weight_norm(self.adapter_out)
self.speaker_fixed = False
def forward(self, x, g):
x_ = self.main_fc(x)
if not self.speaker_fixed:
a_in = self.adapter_in(g).view(-1, self.in_channels, self.r)
a_out = self.adapter_out(g).view(-1, self.r, self.out_channels)
l = torch.einsum("brhw,brc->bchw", torch.einsum("bchw,bcr->brhw", x, a_in), a_out)
x_ = x_ + l
return x_
def remove_weight_norm(self):
remove_weight_norm(self.main_fc)
remove_weight_norm(self.adapter_in)
remove_weight_norm(self.adapter_out)
def fix_speaker(self, g):
a_in = self.adapter_in(g).view(-1, self.in_channels, self.r)
a_out = self.adapter_out(g).view(-1, self.r, self.out_channels)
weight = torch.einsum("bir,bro->oi", a_in, a_out).unsqueeze(2).unsqueeze(3)
self.main_fc.weight.data.add_(weight)
def unfix_speaker(self, g):
a_in = self.adapter_in(g).view(-1, self.in_channels, self.r)
a_out = self.adapter_out(g).view(-1, self.r, self.out_channels)
weight = torch.einsum("bir,bro->oi", a_in, a_out).unsqueeze(2).unsqueeze(3)
self.main_fc.weight.data.sub_(weight)
class MBConv2d(torch.nn.Module):
"""
Causal MBConv2D
"""
def __init__(self, in_channels, out_channels, gin_channels, kernel_size, stride, extend_ratio, r, use_spectral_norm=False):
super(MBConv2d, self).__init__()
norm_f = weight_norm if use_spectral_norm == False else spectral_norm
inner_channels = int(in_channels * extend_ratio)
self.kernel_size = kernel_size
self.pwconv1 = LoRALinear2d(in_channels, inner_channels, gin_channels, r=r)
self.dwconv = norm_f(Conv2d(inner_channels, inner_channels, kernel_size, stride, groups=inner_channels))
self.pwconv2 = LoRALinear2d(inner_channels, out_channels, gin_channels, r=r)
self.pwnorm = LayerNorm(in_channels)
self.dwnorm = LayerNorm(inner_channels)
def forward(self, x, g):
x = self.pwnorm(x)
x = self.pwconv1(x, g)
x = F.pad(x, [0, 0, self.kernel_size[0] - 1, 0], mode="constant")
x = self.dwnorm(x)
x = self.dwconv(x)
x = F.leaky_relu(x, modules.LRELU_SLOPE)
x = self.pwconv2(x, g)
x = F.leaky_relu(x, modules.LRELU_SLOPE)
return x
class ConvNext2d(torch.nn.Module):
"""
Causal ConvNext Block
stride = 1 only
"""
def __init__(self, in_channels, out_channels, gin_channels, kernel_size, stride, extend_ratio, r, use_spectral_norm=False):
super(ConvNext2d, self).__init__()
norm_f = weight_norm if use_spectral_norm == False else spectral_norm
inner_channels = int(in_channels * extend_ratio)
self.kernel_size = kernel_size
self.dwconv = norm_f(Conv2d(in_channels, in_channels, kernel_size, stride, groups=in_channels))
self.pwconv1 = LoRALinear2d(in_channels, inner_channels, gin_channels, r=r)
self.pwconv2 = LoRALinear2d(inner_channels, out_channels, gin_channels, r=r)
self.act = nn.GELU()
self.norm = LayerNorm(in_channels)
def forward(self, x, g):
x = F.pad(x, [0, 0, self.kernel_size[0] - 1, 0], mode="constant")
x = self.dwconv(x)
x = self.norm(x)
x = self.pwconv1(x, g)
x = self.act(x)
x = self.pwconv2(x, g)
x = self.act(x)
return x
def remove_weight_norm(self):
remove_weight_norm(self.dwconv)
class WaveBlock(torch.nn.Module):
def __init__(self, inner_channels, gin_channels, kernel_sizes, strides, dilations, extend_rate, r):
super(WaveBlock, self).__init__()
norm_f = weight_norm
extend_channels = int(inner_channels * extend_rate)
self.dconvs = nn.ModuleList()
self.p1convs = nn.ModuleList()
self.p2convs = nn.ModuleList()
self.norms = nn.ModuleList()
self.act = nn.GELU()
# self.ses = nn.ModuleList()
# self.norms = []
for i, (k, s, d) in enumerate(zip(kernel_sizes, strides, dilations)):
self.dconvs.append(DilatedCausalConv1d(inner_channels, inner_channels, k, stride=s, dilation=d, groups=inner_channels))
self.p1convs.append(LoRALinear1d(inner_channels, extend_channels, gin_channels, r))
self.p2convs.append(LoRALinear1d(extend_channels, inner_channels, gin_channels, r))
self.norms.append(LayerNorm(inner_channels))
def forward(self, x, x_mask, g):
x *= x_mask
for i in range(len(self.dconvs)):
residual = x.clone()
x = self.dconvs[i](x)
x = self.norms[i](x)
x *= x_mask
x = self.p1convs[i](x, g)
x = self.act(x)
x = self.p2convs[i](x, g)
x = residual + x
return x
def remove_weight_norm(self):
for c in self.dconvs:
c.remove_weight_norm()
for c in self.p1convs:
c.remove_weight_norm()
for c in self.p2convs:
c.remove_weight_norm()
def fix_speaker(self, g):
for c in self.p1convs:
c.fix_speaker(g)
for c in self.p2convs:
c.fix_speaker(g)
def unfix_speaker(self, g):
for c in self.p1convs:
c.unfix_speaker(g)
for c in self.p2convs:
c.unfix_speaker(g)
class SnakeFilter(torch.nn.Module):
"""
Adaptive filter using snakebeta
"""
def __init__(self, channels, groups, kernel_size, num_layers, eps=1e-6):
super(SnakeFilter, self).__init__()
self.eps = eps
self.num_layers = num_layers
inner_channels = channels * groups
self.init_conv = DilatedCausalConv1d(1, inner_channels, kernel_size)
self.dconvs = torch.nn.ModuleList()
self.pconvs = torch.nn.ModuleList()
self.post_conv = DilatedCausalConv1d(inner_channels+1, 1, kernel_size, bias=False)
for i in range(self.num_layers):
self.dconvs.append(DilatedCausalConv1d(inner_channels, inner_channels, kernel_size, stride=1, groups=inner_channels, dilation=kernel_size ** (i + 1)))
self.pconvs.append(weight_norm(Conv1d(inner_channels, inner_channels, 1, groups=groups)))
self.snake_alpha = torch.nn.Parameter(torch.zeros(inner_channels), requires_grad=True)
self.snake_beta = torch.nn.Parameter(torch.zeros(inner_channels), requires_grad=True)
def forward(self, x):
y = x.clone()
x = self.init_conv(x)
for i in range(self.num_layers):
# snake activation
x = self.dconvs[i](x)
x = self.pconvs[i](x)
x = x + (1.0 / torch.clip(self.snake_beta.unsqueeze(0).unsqueeze(-1), min=self.eps)) * torch.pow(torch.sin(x * self.snake_alpha.unsqueeze(0).unsqueeze(-1)), 2)
x = torch.cat([x, y], 1)
x = self.post_conv(x)
return x
def remove_weight_norm(self):
self.init_conv.remove_weight_norm()
for c in self.dconvs:
c.remove_weight_norm()
for c in self.pconvs:
remove_weight_norm(c)
self.post_conv.remove_weight_norm()
"""
https://github.com/charactr-platform/vocos/blob/main/vocos/heads.py
"""
class FourierHead(nn.Module):
"""Base class for inverse fourier modules."""
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Args:
x (Tensor): Input tensor of shape (B, L, H), where B is the batch size,
L is the sequence length, and H denotes the model dimension.
Returns:
Tensor: Reconstructed time-domain audio signal of shape (B, T), where T is the length of the output signal.
"""
raise NotImplementedError("Subclasses must implement the forward method.")
class IMDCT(nn.Module):
"""
Inverse Modified Discrete Cosine Transform (IMDCT) module.
Args:
frame_len (int): Length of the MDCT frame.
padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same".
"""
def __init__(self, frame_len: int, padding: str = "same"):
super().__init__()
if padding not in ["center", "same"]:
raise ValueError("Padding must be 'center' or 'same'.")
self.padding = padding
self.frame_len = frame_len * 2
N = frame_len
n0 = (N + 1) / 2
window = torch.from_numpy(scipy.signal.cosine(N * 2)).float()
self.register_buffer("window", window)
pre_twiddle = torch.exp(1j * torch.pi * n0 * torch.arange(N * 2) / N)
post_twiddle = torch.exp(1j * torch.pi * (torch.arange(N * 2) + n0) / (N * 2))
self.register_buffer("pre_twiddle", torch.view_as_real(pre_twiddle))
self.register_buffer("post_twiddle", torch.view_as_real(post_twiddle))
def forward(self, X: torch.Tensor) -> torch.Tensor:
"""
Apply the Inverse Modified Discrete Cosine Transform (IMDCT) to the input MDCT coefficients.
Args:
X (Tensor): Input MDCT coefficients of shape (B, N, L), where B is the batch size,
L is the number of frames, and N is the number of frequency bins.
Returns:
Tensor: Reconstructed audio waveform of shape (B, T), where T is the length of the audio.
"""
X = X.transpose(1, 2)
B, L, N = X.shape
Y = torch.zeros((B, L, N * 2), dtype=X.dtype, device=X.device)
Y[..., :N] = X
Y[..., N:] = -1 * torch.conj(torch.flip(X, dims=(-1,)))
y = torch.fft.ifft(Y * torch.view_as_complex(self.pre_twiddle).expand(Y.shape), dim=-1)
y = torch.real(y * torch.view_as_complex(self.post_twiddle).expand(y.shape)) * np.sqrt(N) * np.sqrt(2)
result = y * self.window.expand(y.shape)
output_size = (1, (L + 1) * N)
audio = torch.nn.functional.fold(
result.transpose(1, 2),
output_size=output_size,
kernel_size=(1, self.frame_len),
stride=(1, self.frame_len // 2),
)[:, 0, 0, :]
if self.padding == "center":
pad = self.frame_len // 2
elif self.padding == "same":
pad = self.frame_len // 4
else:
raise ValueError("Padding must be 'center' or 'same'.")
audio = audio[:, pad:-pad]
return audio.unsqueeze(1)
class IMDCTSymExpHead(FourierHead):
"""
IMDCT Head module for predicting MDCT coefficients with symmetric exponential function
Args:
dim (int): Hidden dimension of the model.
mdct_frame_len (int): Length of the MDCT frame.
padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same".
sample_rate (int, optional): The sample rate of the audio. If provided, the last layer will be initialized
based on perceptual scaling. Defaults to None.
clip_audio (bool, optional): Whether to clip the audio output within the range of [-1.0, 1.0]. Defaults to False.
"""
def __init__(
self, dim: int, gin_channels: int, mdct_frame_len: int, padding: str = "same", sample_rate: int = 24000,
):
super().__init__()
out_dim = mdct_frame_len
self.dconv = DilatedCausalConv1d(dim, dim, 5, 1, dim, 1)
self.pconv1 = LoRALinear1d(dim, dim * 2, gin_channels, 2)
self.pconv2 = LoRALinear1d(dim * 2, out_dim, gin_channels, 2)
self.act = torch.nn.GELU()
self.imdct = IMDCT(frame_len=mdct_frame_len, padding=padding)
if sample_rate is not None:
# optionally init the last layer following mel-scale
m_max = _hz_to_mel(sample_rate // 2)
m_pts = torch.linspace(0, m_max, out_dim)
f_pts = _mel_to_hz(m_pts)
scale = 1 - (f_pts / f_pts.max())
with torch.no_grad():
self.pconv2.main_fc.weight.mul_(scale.view(-1, 1, 1))
def forward(self, x: torch.Tensor, g: torch.Tensor) -> torch.Tensor:
"""
Forward pass of the IMDCTSymExpHead module.
Args:
x (Tensor): Input tensor of shape (B, L, H), where B is the batch size,
L is the sequence length, and H denotes the model dimension.
Returns:
Tensor: Reconstructed time-domain audio signal of shape (B, T), where T is the length of the output signal.
"""
x = self.dconv(x)
x = self.pconv1(x, g)
x = self.act(x)
x = self.pconv2(x, g)
x = symexp(x)
x = torch.clip(x, min=-1e2, max=1e2) # safeguard to prevent excessively large magnitudes
audio = self.imdct(x)
return audio
def remove_weight_norm(self):
self.dconv.remove_weight_norm()
self.pconv1.remove_weight_norm()
self.pconv2.remove_weight_norm()
def fix_speaker(self, g):
self.pconv1.fix_speaker(g)
self.pconv2.fix_speaker(g)
def unfix_speaker(self, g):
self.pconv1.unfix_speaker(g)
self.pconv2.unfix_speaker(g)
def symexp(x: torch.Tensor) -> torch.Tensor:
return torch.sign(x) * (torch.exp(x.abs()) - 1)

View File

@ -0,0 +1,207 @@
import numpy as np
import torch
from torch.nn import functional as F
DEFAULT_MIN_BIN_WIDTH = 1e-3
DEFAULT_MIN_BIN_HEIGHT = 1e-3
DEFAULT_MIN_DERIVATIVE = 1e-3
def piecewise_rational_quadratic_transform(
inputs,
unnormalized_widths,
unnormalized_heights,
unnormalized_derivatives,
inverse=False,
tails=None,
tail_bound=1.0,
min_bin_width=DEFAULT_MIN_BIN_WIDTH,
min_bin_height=DEFAULT_MIN_BIN_HEIGHT,
min_derivative=DEFAULT_MIN_DERIVATIVE,
):
if tails is None:
spline_fn = rational_quadratic_spline
spline_kwargs = {}
else:
spline_fn = unconstrained_rational_quadratic_spline
spline_kwargs = {"tails": tails, "tail_bound": tail_bound}
outputs, logabsdet = spline_fn(
inputs=inputs,
unnormalized_widths=unnormalized_widths,
unnormalized_heights=unnormalized_heights,
unnormalized_derivatives=unnormalized_derivatives,
inverse=inverse,
min_bin_width=min_bin_width,
min_bin_height=min_bin_height,
min_derivative=min_derivative,
**spline_kwargs
)
return outputs, logabsdet
def searchsorted(bin_locations, inputs, eps=1e-6):
bin_locations[..., -1] += eps
return torch.sum(inputs[..., None] >= bin_locations, dim=-1) - 1
def unconstrained_rational_quadratic_spline(
inputs,
unnormalized_widths,
unnormalized_heights,
unnormalized_derivatives,
inverse=False,
tails="linear",
tail_bound=1.0,
min_bin_width=DEFAULT_MIN_BIN_WIDTH,
min_bin_height=DEFAULT_MIN_BIN_HEIGHT,
min_derivative=DEFAULT_MIN_DERIVATIVE,
):
inside_interval_mask = (inputs >= -tail_bound) & (inputs <= tail_bound)
outside_interval_mask = ~inside_interval_mask
outputs = torch.zeros_like(inputs)
logabsdet = torch.zeros_like(inputs)
if tails == "linear":
unnormalized_derivatives = F.pad(unnormalized_derivatives, pad=(1, 1))
constant = np.log(np.exp(1 - min_derivative) - 1)
unnormalized_derivatives[..., 0] = constant
unnormalized_derivatives[..., -1] = constant
outputs[outside_interval_mask] = inputs[outside_interval_mask]
logabsdet[outside_interval_mask] = 0
else:
raise RuntimeError("{} tails are not implemented.".format(tails))
(
outputs[inside_interval_mask],
logabsdet[inside_interval_mask],
) = rational_quadratic_spline(
inputs=inputs[inside_interval_mask],
unnormalized_widths=unnormalized_widths[inside_interval_mask, :],
unnormalized_heights=unnormalized_heights[inside_interval_mask, :],
unnormalized_derivatives=unnormalized_derivatives[inside_interval_mask, :],
inverse=inverse,
left=-tail_bound,
right=tail_bound,
bottom=-tail_bound,
top=tail_bound,
min_bin_width=min_bin_width,
min_bin_height=min_bin_height,
min_derivative=min_derivative,
)
return outputs, logabsdet
def rational_quadratic_spline(
inputs,
unnormalized_widths,
unnormalized_heights,
unnormalized_derivatives,
inverse=False,
left=0.0,
right=1.0,
bottom=0.0,
top=1.0,
min_bin_width=DEFAULT_MIN_BIN_WIDTH,
min_bin_height=DEFAULT_MIN_BIN_HEIGHT,
min_derivative=DEFAULT_MIN_DERIVATIVE,
):
if torch.min(inputs) < left or torch.max(inputs) > right:
raise ValueError("Input to a transform is not within its domain")
num_bins = unnormalized_widths.shape[-1]
if min_bin_width * num_bins > 1.0:
raise ValueError("Minimal bin width too large for the number of bins")
if min_bin_height * num_bins > 1.0:
raise ValueError("Minimal bin height too large for the number of bins")
widths = F.softmax(unnormalized_widths, dim=-1)
widths = min_bin_width + (1 - min_bin_width * num_bins) * widths
cumwidths = torch.cumsum(widths, dim=-1)
cumwidths = F.pad(cumwidths, pad=(1, 0), mode="constant", value=0.0)
cumwidths = (right - left) * cumwidths + left
cumwidths[..., 0] = left
cumwidths[..., -1] = right
widths = cumwidths[..., 1:] - cumwidths[..., :-1]
derivatives = min_derivative + F.softplus(unnormalized_derivatives)
heights = F.softmax(unnormalized_heights, dim=-1)
heights = min_bin_height + (1 - min_bin_height * num_bins) * heights
cumheights = torch.cumsum(heights, dim=-1)
cumheights = F.pad(cumheights, pad=(1, 0), mode="constant", value=0.0)
cumheights = (top - bottom) * cumheights + bottom
cumheights[..., 0] = bottom
cumheights[..., -1] = top
heights = cumheights[..., 1:] - cumheights[..., :-1]
if inverse:
bin_idx = searchsorted(cumheights, inputs)[..., None]
else:
bin_idx = searchsorted(cumwidths, inputs)[..., None]
input_cumwidths = cumwidths.gather(-1, bin_idx)[..., 0]
input_bin_widths = widths.gather(-1, bin_idx)[..., 0]
input_cumheights = cumheights.gather(-1, bin_idx)[..., 0]
delta = heights / widths
input_delta = delta.gather(-1, bin_idx)[..., 0]
input_derivatives = derivatives.gather(-1, bin_idx)[..., 0]
input_derivatives_plus_one = derivatives[..., 1:].gather(-1, bin_idx)[..., 0]
input_heights = heights.gather(-1, bin_idx)[..., 0]
if inverse:
a = (inputs - input_cumheights) * (
input_derivatives + input_derivatives_plus_one - 2 * input_delta
) + input_heights * (input_delta - input_derivatives)
b = input_heights * input_derivatives - (inputs - input_cumheights) * (
input_derivatives + input_derivatives_plus_one - 2 * input_delta
)
c = -input_delta * (inputs - input_cumheights)
discriminant = b.pow(2) - 4 * a * c
assert (discriminant >= 0).all()
root = (2 * c) / (-b - torch.sqrt(discriminant))
outputs = root * input_bin_widths + input_cumwidths
theta_one_minus_theta = root * (1 - root)
denominator = input_delta + (
(input_derivatives + input_derivatives_plus_one - 2 * input_delta)
* theta_one_minus_theta
)
derivative_numerator = input_delta.pow(2) * (
input_derivatives_plus_one * root.pow(2)
+ 2 * input_delta * theta_one_minus_theta
+ input_derivatives * (1 - root).pow(2)
)
logabsdet = torch.log(derivative_numerator) - 2 * torch.log(denominator)
return outputs, -logabsdet
else:
theta = (inputs - input_cumwidths) / input_bin_widths
theta_one_minus_theta = theta * (1 - theta)
numerator = input_heights * (
input_delta * theta.pow(2) + input_derivatives * theta_one_minus_theta
)
denominator = input_delta + (
(input_derivatives + input_derivatives_plus_one - 2 * input_delta)
* theta_one_minus_theta
)
outputs = input_cumheights + numerator / denominator
derivative_numerator = input_delta.pow(2) * (
input_derivatives_plus_one * theta.pow(2)
+ 2 * input_delta * theta_one_minus_theta
+ input_derivatives * (1 - theta).pow(2)
)
logabsdet = torch.log(derivative_numerator) - 2 * torch.log(denominator)
return outputs, logabsdet

View File

@ -0,0 +1,286 @@
import glob
import logging
import os
import shutil
import socket
import sys
import ffmpeg
import matplotlib
import matplotlib.pylab as plt
import numpy as np
import torch
from scipy.io.wavfile import read
from torch.nn import functional as F
from modules.shared import ROOT_DIR
from .config import TrainConfig
matplotlib.use("Agg")
logging.getLogger("matplotlib").setLevel(logging.WARNING)
logging.basicConfig(stream=sys.stdout, level=logging.DEBUG)
logger = logging
class AWP:
"""
Fast AWP
https://www.kaggle.com/code/junkoda/fast-awp
"""
def __init__(self, model, optimizer, *, adv_param='weight',
adv_lr=0.01, adv_eps=0.01):
self.model = model
self.optimizer = optimizer
self.adv_param = adv_param
self.adv_lr = adv_lr
self.adv_eps = adv_eps
self.backup = {}
def perturb(self):
"""
Perturb model parameters for AWP gradient
Call before loss and loss.backward()
"""
self._save() # save model parameters
self._attack_step() # perturb weights
def _attack_step(self):
e = 1e-6
for name, param in self.model.named_parameters():
if param.requires_grad and param.grad is not None and self.adv_param in name:
grad = self.optimizer.state[param]['exp_avg']
norm_grad = torch.norm(grad)
norm_data = torch.norm(param.detach())
if norm_grad != 0 and not torch.isnan(norm_grad):
# Set lower and upper limit in change
limit_eps = self.adv_eps * param.detach().abs()
param_min = param.data - limit_eps
param_max = param.data + limit_eps
# Perturb along gradient
# w += (adv_lr * |w| / |grad|) * grad
param.data.add_(grad, alpha=(self.adv_lr * (norm_data + e) / (norm_grad + e)))
# Apply the limit to the change
param.data.clamp_(param_min, param_max)
def _save(self):
for name, param in self.model.named_parameters():
if param.requires_grad and param.grad is not None and self.adv_param in name:
if name not in self.backup:
self.backup[name] = param.clone().detach()
else:
self.backup[name].copy_(param.data)
def restore(self):
"""
Restore model parameter to correct position; AWP do not perturbe weights, it perturb gradients
Call after loss.backward(), before optimizer.step()
"""
for name, param in self.model.named_parameters():
if name in self.backup:
param.data.copy_(self.backup[name])
def load_audio(file: str, sr):
try:
# https://github.com/openai/whisper/blob/main/whisper/audio.py#L26
# This launches a subprocess to decode audio while down-mixing and resampling as necessary.
# Requires the ffmpeg CLI and `ffmpeg-python` package to be installed.
file = (
file.strip(" ").strip('"').strip("\n").strip('"').strip(" ")
) # Prevent small white copy path head and tail with spaces and " and return
out, _ = (
ffmpeg.input(file, threads=0)
.output("-", format="f32le", acodec="pcm_f32le", ac=1, ar=sr)
.run(cmd=["ffmpeg", "-nostdin"], capture_stdout=True, capture_stderr=True)
)
except Exception as e:
raise RuntimeError(f"Failed to load audio: {e}")
return np.frombuffer(out, np.float32).flatten()
def find_empty_port():
s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
s.bind(("", 0))
s.listen(1)
port = s.getsockname()[1]
s.close()
return port
def load_checkpoint(checkpoint_path, model, optimizer=None, load_opt=1):
assert os.path.isfile(checkpoint_path)
checkpoint_dict = torch.load(checkpoint_path, map_location="cpu")
saved_state_dict = checkpoint_dict["model"]
if hasattr(model, "module"):
state_dict = model.module.state_dict()
else:
state_dict = model.state_dict()
new_state_dict = {}
for k, v in state_dict.items(): # 模型需要的shape
try:
new_state_dict[k] = saved_state_dict[k]
if saved_state_dict[k].shape != state_dict[k].shape:
print(
f"shape-{k}-mismatch|need-{state_dict[k].shape}|get-{saved_state_dict[k].shape}"
)
if saved_state_dict[k].dim() == 2: # NOTE: check is this ok?
# for embedded input 256 <==> 768
# this achieves we can continue training from original's pretrained checkpoints when using embedder that 768-th dim output etc.
if saved_state_dict[k].dtype == torch.half:
new_state_dict[k] = (
F.interpolate(
saved_state_dict[k].float().unsqueeze(0).unsqueeze(0),
size=state_dict[k].shape,
mode="bilinear",
)
.half()
.squeeze(0)
.squeeze(0)
)
else:
new_state_dict[k] = (
F.interpolate(
saved_state_dict[k].unsqueeze(0).unsqueeze(0),
size=state_dict[k].shape,
mode="bilinear",
)
.squeeze(0)
.squeeze(0)
)
print(
"interpolated new_state_dict",
k,
"from",
saved_state_dict[k].shape,
"to",
new_state_dict[k].shape,
)
else:
raise KeyError
except Exception as e:
# print(traceback.format_exc())
print(f"{k} is not in the checkpoint")
print("error: %s" % e)
new_state_dict[k] = v # 模型自带的随机值
if hasattr(model, "module"):
model.module.load_state_dict(new_state_dict, strict=False)
else:
model.load_state_dict(new_state_dict, strict=False)
print("Loaded model weights")
epoch = checkpoint_dict["epoch"]
learning_rate = checkpoint_dict["learning_rate"]
if optimizer is not None and load_opt == 1:
optimizer.load_state_dict(checkpoint_dict["optimizer"])
print("Loaded checkpoint '{}' (epoch {})".format(checkpoint_path, epoch))
return model, optimizer, learning_rate, epoch
def save_state(model, optimizer, learning_rate, epoch, checkpoint_path):
print(
"Saving model and optimizer state at epoch {} to {}".format(
epoch, checkpoint_path
)
)
if hasattr(model, "module"):
state_dict = model.module.state_dict()
else:
state_dict = model.state_dict()
torch.save(
{
"model": state_dict,
"epoch": epoch,
"optimizer": optimizer.state_dict(),
"learning_rate": learning_rate,
},
checkpoint_path,
)
def summarize(
writer,
global_step,
scalars={},
histograms={},
images={},
audios={},
audio_sampling_rate=22050,
):
for k, v in scalars.items():
writer.add_scalar(k, v, global_step)
for k, v in histograms.items():
writer.add_histogram(k, v, global_step)
for k, v in images.items():
writer.add_image(k, v, global_step, dataformats="HWC")
for k, v in audios.items():
writer.add_audio(k, v, global_step, audio_sampling_rate)
def latest_checkpoint_path(dir_path, regex="G_*.pth"):
filelist = glob.glob(os.path.join(dir_path, regex))
if len(filelist) == 0:
return None
filelist.sort(key=lambda f: int("".join(filter(str.isdigit, f))))
filepath = filelist[-1]
return filepath
def plot_spectrogram_to_numpy(spectrogram):
fig, ax = plt.subplots(figsize=(10, 2))
im = ax.imshow(spectrogram, aspect="auto", origin="lower", interpolation="none")
plt.colorbar(im, ax=ax)
plt.xlabel("Frames")
plt.ylabel("Channels")
plt.tight_layout()
fig.canvas.draw()
data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep="")
data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
plt.close()
return data
def plot_alignment_to_numpy(alignment, info=None):
fig, ax = plt.subplots(figsize=(6, 4))
im = ax.imshow(
alignment.transpose(), aspect="auto", origin="lower", interpolation="none"
)
fig.colorbar(im, ax=ax)
xlabel = "Decoder timestep"
if info is not None:
xlabel += "\n\n" + info
plt.xlabel(xlabel)
plt.ylabel("Encoder timestep")
plt.tight_layout()
fig.canvas.draw()
data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep="")
data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
plt.close()
return data
def load_wav_to_torch(full_path):
sampling_rate, data = read(full_path)
return torch.FloatTensor(data.astype(np.float32)), sampling_rate
def load_config(training_dir: str, sample_rate: int, emb_channels: int):
if emb_channels == 256:
config_path = os.path.join(ROOT_DIR, "configs", f"{sample_rate}.json")
else:
config_path = os.path.join(
ROOT_DIR, "configs", f"{sample_rate}-{emb_channels}.json"
)
config_save_path = os.path.join(training_dir, "config.json")
shutil.copyfile(config_path, config_save_path)
return TrainConfig.parse_file(config_save_path)

View File

@ -3,6 +3,7 @@ from typing import Any
import math
import torch
import torch.nn.functional as F
from torch.cuda.amp import autocast
from Exceptions import (
DeviceCannotSupportHalfPrecisionException,
DeviceChangingException,
@ -118,10 +119,6 @@ class Pipeline(object):
# tensor型調整
feats = audio_pad
if self.isHalf is True:
feats = feats.half()
else:
feats = feats.float()
if feats.dim() == 2: # double channels
feats = feats.mean(-1)
assert feats.dim() == 1, feats.dim()
@ -129,19 +126,20 @@ class Pipeline(object):
# embedding
padding_mask = torch.BoolTensor(feats.shape).to(self.device).fill_(False)
try:
feats = self.embedder.extractFeatures(feats, embOutputLayer, useFinalProj)
if torch.isnan(feats).all():
raise DeviceCannotSupportHalfPrecisionException()
except RuntimeError as e:
if "HALF" in e.__str__().upper():
raise HalfPrecisionChangingException()
elif "same device" in e.__str__():
raise DeviceChangingException()
else:
raise e
if protect < 0.5 and search_index:
feats0 = feats.clone()
with autocast(enabled=self.isHalf):
try:
feats = self.embedder.extractFeatures(feats, embOutputLayer, useFinalProj)
if torch.isnan(feats).all():
raise DeviceCannotSupportHalfPrecisionException()
except RuntimeError as e:
if "HALF" in e.__str__().upper():
raise HalfPrecisionChangingException()
elif "same device" in e.__str__():
raise DeviceChangingException()
else:
raise e
if protect < 0.5 and search_index:
feats0 = feats.clone()
# Index - feature抽出
# if self.index is not None and self.feature is not None and index_rate != 0:
@ -167,10 +165,8 @@ class Pipeline(object):
# recover silient font
npy = np.concatenate([np.zeros([npyOffset, npy.shape[1]]).astype("float32"), npy])
if self.isHalf is True:
npy = npy.astype("float16")
feats = torch.from_numpy(npy).unsqueeze(0).to(self.device) * index_rate + (1 - index_rate) * feats
feats = F.interpolate(feats.permute(0, 2, 1), scale_factor=2).permute(0, 2, 1)
if protect < 0.5 and search_index:
feats0 = F.interpolate(feats0.permute(0, 2, 1), scale_factor=2).permute(0, 2, 1)
@ -207,14 +203,15 @@ class Pipeline(object):
# 推論実行
try:
with torch.no_grad():
audio1 = (
torch.clip(
self.inferencer.infer(feats, p_len, pitch, pitchf, sid)[0][0, 0].to(dtype=torch.float32),
-1.0,
1.0,
)
* 32767.5
).data.to(dtype=torch.int16)
with autocast(enabled=self.isHalf):
audio1 = (
torch.clip(
self.inferencer.infer(feats, p_len, pitch, pitchf, sid)[0][0, 0].to(dtype=torch.float32),
-1.0,
1.0,
)
* 32767.5
).data.to(dtype=torch.int16)
except RuntimeError as e:
if "HALF" in e.__str__().upper():
print("11", e)