Demonstration of the v3 model

nadare 2023-06-18 22:23:50 +09:00
parent 3bed1af628
commit 0936b02287
11 changed files with 2291 additions and 2 deletions

View File

@@ -65,6 +65,7 @@ class EnumInferenceTypes(Enum):
pyTorchRVCNono = "pyTorchRVCNono"
pyTorchRVCv2 = "pyTorchRVCv2"
pyTorchRVCv2Nono = "pyTorchRVCv2Nono"
pyTorchRVCv3 = "pyTorchRVCv3"
pyTorchWebUI = "pyTorchWebUI"
pyTorchWebUINono = "pyTorchWebUINono"
onnxRVC = "onnxRVC"

View File

@@ -11,7 +11,33 @@ def _setInfoByPytorch(slot: ModelSlot):
cpt = torch.load(slot.modelFile, map_location="cpu")
config_len = len(cpt["config"])
if config_len == 18:
if cpt["version"] == "v3":
slot.f0 = True if cpt["f0"] == 1 else False
slot.modelType = EnumInferenceTypes.pyTorchRVCv3.value
slot.embChannels = cpt["config"][17]
slot.embOutputLayer = (
cpt["embedder_output_layer"] if "embedder_output_layer" in cpt else 9
)
if slot.embChannels == 256:
slot.useFinalProj = True
else:
slot.useFinalProj = False
slot.embedder = cpt["embedder_name"]
if slot.embedder.endswith("768"):
slot.embedder = slot.embedder[:-3]
if slot.embedder == EnumEmbedderTypes.hubert.value:
slot.embedder = EnumEmbedderTypes.hubert.value
elif slot.embedder == EnumEmbedderTypes.contentvec.value:
slot.embedder = EnumEmbedderTypes.contentvec.value
elif slot.embedder == EnumEmbedderTypes.hubert_jp.value:
slot.embedder = EnumEmbedderTypes.hubert_jp.value
print("nadare v3 loaded")
else:
raise RuntimeError("[Voice Changer][setInfoByPytorch] unknown embedder")
elif config_len == 18:
# Original RVC
slot.f0 = True if cpt["f0"] == 1 else False
version = cpt.get("version", "v1")
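
For orientation, a minimal sketch of the checkpoint fields the new v3 branch reads. The key names are taken from the code above; the file name and concrete values are illustrative assumptions, not part of this commit.

import torch

cpt = torch.load("rvc_v3_demo.pth", map_location="cpu")  # hypothetical file name
assert cpt["version"] == "v3" and len(cpt["config"]) == 18
print(cpt["f0"])                            # 1 -> slot.f0 = True
print(cpt["config"][17])                    # embChannels; 256 -> slot.useFinalProj = True
print(cpt.get("embedder_output_layer", 9))  # falls back to 9 when the key is absent
print(cpt["embedder_name"])                 # e.g. "hubert768"; a trailing "768" is stripped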

View File

@@ -8,7 +8,7 @@ from voice_changer.RVC.inferencer.RVCInferencerv2 import RVCInferencerv2
from voice_changer.RVC.inferencer.RVCInferencerv2Nono import RVCInferencerv2Nono
from voice_changer.RVC.inferencer.WebUIInferencer import WebUIInferencer
from voice_changer.RVC.inferencer.WebUIInferencerNono import WebUIInferencerNono
from voice_changer.RVC.inferencer.RVCInferencerv3 import RVCInferencerv3
class InferencerManager:
currentInferencer: Inferencer | None = None
@@ -37,6 +37,8 @@ class InferencerManager:
return RVCInferencerNono().loadModel(file, gpu)
elif inferencerType == EnumInferenceTypes.pyTorchRVCv2 or inferencerType == EnumInferenceTypes.pyTorchRVCv2.value:
return RVCInferencerv2().loadModel(file, gpu)
elif inferencerType == EnumInferenceTypes.pyTorchRVCv3 or inferencerType == EnumInferenceTypes.pyTorchRVCv3.value:
return RVCInferencerv3().loadModel(file, gpu)
elif inferencerType == EnumInferenceTypes.pyTorchRVCv2Nono or inferencerType == EnumInferenceTypes.pyTorchRVCv2Nono.value:
return RVCInferencerv2Nono().loadModel(file, gpu)
elif inferencerType == EnumInferenceTypes.pyTorchWebUI or inferencerType == EnumInferenceTypes.pyTorchWebUI.value:

View File

@@ -0,0 +1,40 @@
import torch
from torch import device
from const import EnumInferenceTypes
from voice_changer.RVC.inferencer.Inferencer import Inferencer
from voice_changer.RVC.deviceManager.DeviceManager import DeviceManager
from .model_v3.models import SynthesizerTrnMs256NSFSid
class RVCInferencerv3(Inferencer):
def loadModel(self, file: str, gpu: device):
print("nadare v3 load start")
super().setProps(EnumInferenceTypes.pyTorchRVCv3, file, True, gpu)
dev = DeviceManager.get_instance().getDevice(gpu)
isHalf = False # DeviceManager.get_instance().halfPrecisionAvailable(gpu)
cpt = torch.load(file, map_location="cpu")
model = SynthesizerTrnMs256NSFSid(**cpt["params"])
model.eval()
model.load_state_dict(cpt["weight"], strict=False)
model = model.to(dev)
if isHalf:
model = model.half()
self.model = model
print("load model comprete")
return self
def infer(
self,
feats: torch.Tensor,
pitch_length: torch.Tensor,
pitch: torch.Tensor,
pitchf: torch.Tensor,
sid: torch.Tensor,
) -> torch.Tensor:
return self.model.infer(feats, pitch_length, pitch, pitchf, sid)
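
A minimal usage sketch of the new inferencer. The checkpoint name, feature dimension and sequence length are illustrative assumptions; tensors are assumed to live on the device selected by DeviceManager.

import torch

inferencer = RVCInferencerv3().loadModel("rvc_v3_demo.pth", gpu=0)  # hypothetical checkpoint
feats = torch.randn(1, 200, 768)         # content features [batch, frames, emb_channels]
feat_lengths = torch.tensor([200])
pitch = torch.randint(1, 255, (1, 200))  # coarse pitch ids (the pitch embedding has 256 rows)
pitchf = torch.full((1, 200), 220.0)     # frame-level f0 in Hz for the NSF source
sid = torch.tensor([0])                  # speaker id
audio, _, _ = inferencer.infer(feats, feat_lengths, pitch, pitchf, sid)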

View File

@@ -0,0 +1,343 @@
import math
import torch
from torch import nn
from torch.nn import functional as F
from torch.nn.utils import remove_weight_norm, weight_norm
from . import commons
from .modules import LayerNorm, LoRALinear1d
class Encoder(nn.Module):
def __init__(
self,
hidden_channels,
filter_channels,
gin_channels,
n_heads,
n_layers,
kernel_size=1,
p_dropout=0.0,
window_size=25,
**kwargs
):
super().__init__()
self.hidden_channels = hidden_channels
self.filter_channels = filter_channels
self.n_heads = n_heads
self.n_layers = n_layers
self.kernel_size = kernel_size
self.p_dropout = p_dropout
self.window_size = window_size
self.drop = nn.Dropout(p_dropout)
self.attn_layers = nn.ModuleList()
self.norm_layers_1 = nn.ModuleList()
self.ffn_layers = nn.ModuleList()
self.norm_layers_2 = nn.ModuleList()
for i in range(self.n_layers):
self.attn_layers.append(
MultiHeadAttention(
hidden_channels,
hidden_channels,
gin_channels,
n_heads,
p_dropout=p_dropout,
window_size=window_size,
)
)
self.norm_layers_1.append(LayerNorm(hidden_channels))
self.ffn_layers.append(
FFN(
hidden_channels,
hidden_channels,
filter_channels,
gin_channels,
kernel_size,
p_dropout=p_dropout,
)
)
self.norm_layers_2.append(LayerNorm(hidden_channels))
def forward(self, x, x_mask, g):
attn_mask = x_mask.unsqueeze(2) * x_mask.unsqueeze(-1)
x = x * x_mask
for i in range(self.n_layers):
y = self.attn_layers[i](x, x, g, attn_mask)
y = self.drop(y)
x = self.norm_layers_1[i](x + y)
y = self.ffn_layers[i](x, x_mask, g)
y = self.drop(y)
x = self.norm_layers_2[i](x + y)
x = x * x_mask
return x
def remove_weight_norm(self):
for l in self.attn_layers:
l.remove_weight_norm()
for l in self.ffn_layers:
l.remove_weight_norm()
class MultiHeadAttention(nn.Module):
def __init__(
self,
channels,
out_channels,
gin_channels,
n_heads,
p_dropout=0.0,
window_size=None,
heads_share=False,
block_length=None,
proximal_bias=False,
proximal_init=False,
):
super().__init__()
assert channels % n_heads == 0
self.channels = channels
self.out_channels = out_channels
self.n_heads = n_heads
self.p_dropout = p_dropout
self.window_size = window_size
self.heads_share = heads_share
self.block_length = block_length
self.proximal_bias = proximal_bias
self.proximal_init = proximal_init
self.attn = None
self.k_channels = channels // n_heads
self.conv_q = LoRALinear1d(channels, channels, gin_channels, 2)
self.conv_k = LoRALinear1d(channels, channels, gin_channels, 2)
self.conv_v = LoRALinear1d(channels, channels, gin_channels, 2)
self.conv_qkw = weight_norm(nn.Conv1d(channels, channels, 5, 1, groups=channels, padding=2))
self.conv_vw = weight_norm(nn.Conv1d(channels, channels, 5, 1, groups=channels, padding=2))
self.conv_o = LoRALinear1d(channels, out_channels, gin_channels, 2)
self.drop = nn.Dropout(p_dropout)
if window_size is not None:
n_heads_rel = 1 if heads_share else n_heads
rel_stddev = self.k_channels**-0.5
self.emb_rel_k = nn.Parameter(
torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels)
* rel_stddev
)
self.emb_rel_v = nn.Parameter(
torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels)
* rel_stddev
)
def forward(self, x, c, g, attn_mask=None):
q = self.conv_qkw(self.conv_q(x, g))
k = self.conv_qkw(self.conv_k(c, g))
v = self.conv_vw(self.conv_v(c, g))
x, self.attn = self.attention(q, k, v, mask=attn_mask)
x = self.conv_o(x, g)
return x
def attention(self, query, key, value, mask=None):
# reshape [b, d, t] -> [b, n_h, t, d_k]
b, d, t_s, t_t = (*key.size(), query.size(2))
query = query.view(b, self.n_heads, self.k_channels, t_t).transpose(2, 3)
key = key.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3)
value = value.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3)
scores = torch.matmul(query / math.sqrt(self.k_channels), key.transpose(-2, -1))
if self.window_size is not None:
assert (
t_s == t_t
), "Relative attention is only available for self-attention."
key_relative_embeddings = self._get_relative_embeddings(self.emb_rel_k, t_s)
rel_logits = self._matmul_with_relative_keys(
query / math.sqrt(self.k_channels), key_relative_embeddings
)
scores_local = self._relative_position_to_absolute_position(rel_logits)
scores = scores + scores_local
if self.proximal_bias:
assert t_s == t_t, "Proximal bias is only available for self-attention."
scores = scores + self._attention_bias_proximal(t_s).to(
device=scores.device, dtype=scores.dtype
)
if mask is not None:
scores = scores.masked_fill(mask == 0, -1e4)
if self.block_length is not None:
assert (
t_s == t_t
), "Local attention is only available for self-attention."
block_mask = (
torch.ones_like(scores)
.triu(-self.block_length)
.tril(self.block_length)
)
scores = scores.masked_fill(block_mask == 0, -1e4)
p_attn = F.softmax(scores, dim=-1) # [b, n_h, t_t, t_s]
p_attn = self.drop(p_attn)
output = torch.matmul(p_attn, value)
if self.window_size is not None:
relative_weights = self._absolute_position_to_relative_position(p_attn)
value_relative_embeddings = self._get_relative_embeddings(
self.emb_rel_v, t_s
)
output = output + self._matmul_with_relative_values(
relative_weights, value_relative_embeddings
)
output = (
output.transpose(2, 3).contiguous().view(b, d, t_t)
) # [b, n_h, t_t, d_k] -> [b, d, t_t]
return output, p_attn
def _matmul_with_relative_values(self, x, y):
"""
x: [b, h, l, m]
y: [h or 1, m, d]
ret: [b, h, l, d]
"""
ret = torch.matmul(x, y.unsqueeze(0))
return ret
def _matmul_with_relative_keys(self, x, y):
"""
x: [b, h, l, d]
y: [h or 1, m, d]
ret: [b, h, l, m]
"""
ret = torch.matmul(x, y.unsqueeze(0).transpose(-2, -1))
return ret
def _get_relative_embeddings(self, relative_embeddings, length):
max_relative_position = 2 * self.window_size + 1
# Pad first before slice to avoid using cond ops.
pad_length = max(length - (self.window_size + 1), 0)
slice_start_position = max((self.window_size + 1) - length, 0)
slice_end_position = slice_start_position + 2 * length - 1
if pad_length > 0:
padded_relative_embeddings = F.pad(
relative_embeddings,
commons.convert_pad_shape([[0, 0], [pad_length, pad_length], [0, 0]]),
)
else:
padded_relative_embeddings = relative_embeddings
used_relative_embeddings = padded_relative_embeddings[
:, slice_start_position:slice_end_position
]
return used_relative_embeddings
def _relative_position_to_absolute_position(self, x):
"""
x: [b, h, l, 2*l-1]
ret: [b, h, l, l]
"""
batch, heads, length, _ = x.size()
# Concat columns of pad to shift from relative to absolute indexing.
x = F.pad(x, commons.convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, 1]]))
# Concat extra elements so that the shape adds up to (len+1, 2*len-1).
x_flat = x.view([batch, heads, length * 2 * length])
x_flat = F.pad(
x_flat, commons.convert_pad_shape([[0, 0], [0, 0], [0, length - 1]])
)
# Reshape and slice out the padded elements.
x_final = x_flat.view([batch, heads, length + 1, 2 * length - 1])[
:, :, :length, length - 1 :
]
return x_final
def _absolute_position_to_relative_position(self, x):
"""
x: [b, h, l, l]
ret: [b, h, l, 2*l-1]
"""
batch, heads, length, _ = x.size()
# pad along column
x = F.pad(
x, commons.convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, length - 1]])
)
x_flat = x.view([batch, heads, length**2 + length * (length - 1)])
# add zeros at the beginning so that the elements are skewed after the reshape
x_flat = F.pad(x_flat, commons.convert_pad_shape([[0, 0], [0, 0], [length, 0]]))
x_final = x_flat.view([batch, heads, length, 2 * length])[:, :, :, 1:]
return x_final
def _attention_bias_proximal(self, length):
"""Bias for self-attention to encourage attention to close positions.
Args:
length: an integer scalar.
Returns:
a Tensor with shape [1, 1, length, length]
"""
r = torch.arange(length, dtype=torch.float32)
diff = torch.unsqueeze(r, 0) - torch.unsqueeze(r, 1)
return torch.unsqueeze(torch.unsqueeze(-torch.log1p(torch.abs(diff)), 0), 0)
def remove_weight_norm(self):
self.conv_q.remove_weight_norm()
self.conv_k.remove_weight_norm()
self.conv_v.remove_weight_norm()
self.conv_o.remove_weight_norm()
remove_weight_norm(self.conv_qkw)
remove_weight_norm(self.conv_vw)
class FFN(nn.Module):
def __init__(
self,
in_channels,
out_channels,
filter_channels,
gin_channels,
kernel_size,
p_dropout=0.0,
activation=None,
causal=False,
):
super().__init__()
self.in_channels = in_channels
self.out_channels = out_channels
self.filter_channels = filter_channels
self.kernel_size = kernel_size
self.p_dropout = p_dropout
self.activation = activation
self.causal = causal
self.conv_1 = LoRALinear1d(in_channels, filter_channels, gin_channels, 2)
self.conv_2 = LoRALinear1d(filter_channels, out_channels, gin_channels, 2)
self.drop = nn.Dropout(p_dropout)
def forward(self, x, x_mask, g):
x = self.conv_1(x * x_mask, g)
if self.activation == "gelu":
x = x * torch.sigmoid(1.702 * x)
else:
x = torch.relu(x)
x = self.drop(x)
x = self.conv_2(x * x_mask, g)
return x * x_mask
def _causal_padding(self, x):
if self.kernel_size == 1:
return x
pad_l = self.kernel_size - 1
pad_r = 0
padding = [[0, 0], [0, 0], [pad_l, pad_r]]
x = F.pad(x, commons.convert_pad_shape(padding))
return x
def _same_padding(self, x):
if self.kernel_size == 1:
return x
pad_l = (self.kernel_size - 1) // 2
pad_r = self.kernel_size // 2
padding = [[0, 0], [0, 0], [pad_l, pad_r]]
x = F.pad(x, commons.convert_pad_shape(padding))
return x
def remove_weight_norm(self):
self.conv_1.remove_weight_norm()
self.conv_2.remove_weight_norm()
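
Compared with the stock VITS attention, every projection above is a LoRALinear1d conditioned on a speaker embedding g, so the encoder expects g in its forward call. A minimal shape check; the hyperparameters (192/768/256 channels, 2 heads, 6 layers) are illustrative assumptions.

import torch

enc = Encoder(hidden_channels=192, filter_channels=768, gin_channels=256, n_heads=2, n_layers=6)
x = torch.randn(1, 192, 100)   # [batch, hidden, frames]
x_mask = torch.ones(1, 1, 100)
g = torch.randn(1, 256, 1)     # per-utterance speaker embedding
y = enc(x, x_mask, g)          # -> [1, 192, 100]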

View File

@@ -0,0 +1,165 @@
import math
import torch
from torch.nn import functional as F
def init_weights(m, mean=0.0, std=0.01):
classname = m.__class__.__name__
if classname.find("Conv") != -1:
m.weight.data.normal_(mean, std)
def get_padding(kernel_size, dilation=1):
return int((kernel_size * dilation - dilation) / 2)
def convert_pad_shape(pad_shape):
l = pad_shape[::-1]
pad_shape = [item for sublist in l for item in sublist]
return pad_shape
def kl_divergence(m_p, logs_p, m_q, logs_q):
"""KL(P||Q)"""
kl = (logs_q - logs_p) - 0.5
kl += (
0.5 * (torch.exp(2.0 * logs_p) + ((m_p - m_q) ** 2)) * torch.exp(-2.0 * logs_q)
)
return kl
def rand_gumbel(shape):
"""Sample from the Gumbel distribution, protect from overflows."""
uniform_samples = torch.rand(shape) * 0.99998 + 0.00001
return -torch.log(-torch.log(uniform_samples))
def rand_gumbel_like(x):
g = rand_gumbel(x.size()).to(dtype=x.dtype, device=x.device)
return g
def slice_segments(x, ids_str, segment_size=4):
ret = torch.zeros_like(x[:, :, :segment_size])
for i in range(x.size(0)):
idx_str = ids_str[i]
idx_end = idx_str + segment_size
r = x[i, :, idx_str:idx_end]
ret[i, :, :r.size(1)] = r
return ret
def slice_segments2(x, ids_str, segment_size=4):
ret = torch.zeros_like(x[:, :segment_size])
for i in range(x.size(0)):
idx_str = ids_str[i]
idx_end = idx_str + segment_size
r = x[i, idx_str:idx_end]
ret[i, :r.size(0)] = r
return ret
def rand_slice_segments(x, x_lengths, segment_size=4, ids_str=None):
b, d, t = x.size()
if ids_str is None:
ids_str = torch.zeros([b]).to(device=x.device, dtype=x_lengths.dtype)
ids_str_max = torch.maximum(torch.zeros_like(x_lengths).to(device=x_lengths.device, dtype=x_lengths.dtype), x_lengths - segment_size + 1 - ids_str)
ids_str += (torch.rand([b]).to(device=x.device) * ids_str_max).to(dtype=torch.long)
ret = slice_segments(x, ids_str, segment_size)
return ret, ids_str
def get_timing_signal_1d(length, channels, min_timescale=1.0, max_timescale=1.0e4):
position = torch.arange(length, dtype=torch.float)
num_timescales = channels // 2
log_timescale_increment = math.log(float(max_timescale) / float(min_timescale)) / (
num_timescales - 1
)
inv_timescales = min_timescale * torch.exp(
torch.arange(num_timescales, dtype=torch.float) * -log_timescale_increment
)
scaled_time = position.unsqueeze(0) * inv_timescales.unsqueeze(1)
signal = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], 0)
signal = F.pad(signal, [0, 0, 0, channels % 2])
signal = signal.view(1, channels, length)
return signal
def add_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4):
b, channels, length = x.size()
signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale)
return x + signal.to(dtype=x.dtype, device=x.device)
def cat_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4, axis=1):
b, channels, length = x.size()
signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale)
return torch.cat([x, signal.to(dtype=x.dtype, device=x.device)], axis)
def subsequent_mask(length):
mask = torch.tril(torch.ones(length, length)).unsqueeze(0).unsqueeze(0)
return mask
@torch.jit.script
def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels):
n_channels_int = n_channels[0]
in_act = input_a + input_b
t_act = torch.tanh(in_act[:, :n_channels_int, :])
s_act = torch.sigmoid(in_act[:, n_channels_int:, :])
acts = t_act * s_act
return acts
def convert_pad_shape(pad_shape):
l = pad_shape[::-1]
pad_shape = [item for sublist in l for item in sublist]
return pad_shape
def shift_1d(x):
x = F.pad(x, convert_pad_shape([[0, 0], [0, 0], [1, 0]]))[:, :, :-1]
return x
def sequence_mask(length, max_length=None):
if max_length is None:
max_length = length.max()
x = torch.arange(max_length, dtype=length.dtype, device=length.device)
return x.unsqueeze(0) < length.unsqueeze(1)
def generate_path(duration, mask):
"""
duration: [b, 1, t_x]
mask: [b, 1, t_y, t_x]
"""
b, _, t_y, t_x = mask.shape
cum_duration = torch.cumsum(duration, -1)
cum_duration_flat = cum_duration.view(b * t_x)
path = sequence_mask(cum_duration_flat, t_y).to(mask.dtype)
path = path.view(b, t_x, t_y)
path = path - F.pad(path, convert_pad_shape([[0, 0], [1, 0], [0, 0]]))[:, :-1]
path = path.unsqueeze(1).transpose(2, 3) * mask
return path
def clip_grad_value_(parameters, clip_value, norm_type=2):
if isinstance(parameters, torch.Tensor):
parameters = [parameters]
parameters = list(filter(lambda p: p.grad is not None, parameters))
norm_type = float(norm_type)
if clip_value is not None:
clip_value = float(clip_value)
total_norm = 0
for p in parameters:
param_norm = p.grad.data.norm(norm_type)
total_norm += param_norm.item() ** norm_type
if clip_value is not None:
p.grad.data.clamp_(min=-clip_value, max=clip_value)
total_norm = total_norm ** (1.0 / norm_type)
return total_norm
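
Two of these helpers are easy to misread, so a short behavioural sketch (shapes chosen arbitrarily):

import torch

# convert_pad_shape flattens a per-dimension pad list into the order F.pad expects (last dim first)
convert_pad_shape([[0, 0], [0, 0], [2, 3]])  # -> [2, 3, 0, 0, 0, 0]

# rand_slice_segments picks one random training window per example and returns the start indices,
# which slice_segments2 then reuses to cut the matching 2-D tensor (e.g. pitchf)
x = torch.randn(2, 192, 400)                 # [batch, channels, frames]
x_lengths = torch.tensor([400, 350])
segments, ids_str = rand_slice_segments(x, x_lengths, segment_size=32)  # [2, 192, 32]
pitchf = torch.randn(2, 400)
pitchf_seg = slice_segments2(pitchf, ids_str, segment_size=32)          # [2, 32]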

View File

@@ -0,0 +1,71 @@
from typing import *
from pydantic import BaseModel
class TrainConfigTrain(BaseModel):
log_interval: int
seed: int
epochs: int
learning_rate: float
betas: List[float]
eps: float
batch_size: int
fp16_run: bool
lr_decay: float
segment_size: int
init_lr_ratio: int
warmup_epochs: int
c_mel: int
c_kl: float
class TrainConfigData(BaseModel):
max_wav_value: float
sampling_rate: int
filter_length: int
hop_length: int
win_length: int
n_mel_channels: int
mel_fmin: float
mel_fmax: Any
class TrainConfigModel(BaseModel):
inter_channels: int
hidden_channels: int
filter_channels: int
n_heads: int
n_layers: int
kernel_size: int
p_dropout: int
resblock: str
resblock_kernel_sizes: List[int]
resblock_dilation_sizes: List[List[int]]
upsample_rates: List[int]
upsample_initial_channel: int
upsample_kernel_sizes: List[int]
use_spectral_norm: bool
gin_channels: int
emb_channels: int
spk_embed_dim: int
class TrainConfig(BaseModel):
version: Literal["v1", "v2"] = "v2"
train: TrainConfigTrain
data: TrainConfigData
model: TrainConfigModel
class DatasetMetaItem(BaseModel):
gt_wav: str
co256: str
f0: Optional[str]
f0nsf: Optional[str]
speaker_id: int
class DatasetMetadata(BaseModel):
files: Dict[str, DatasetMetaItem]
# mute: DatasetMetaItem
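
A minimal sketch of validating a training configuration with these models; the file name and layout are assumptions, while the field names follow the classes above.

import json

with open("config.json") as f:   # hypothetical config path
    raw = json.load(f)
config = TrainConfig(**raw)      # validates the train / data / model sections
print(config.version, config.data.sampling_rate, config.model.emb_channels)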

View File

@@ -0,0 +1,522 @@
import math
import os
import sys
import numpy as np
import torch
from torch import nn
from torch.nn import Conv1d, Conv2d, ConvTranspose1d
from torch.nn import functional as F
from torch.nn.utils import remove_weight_norm, spectral_norm, weight_norm
from . import attentions, commons, modules
from .commons import get_padding, init_weights
from .modules import (CausalConvTranspose1d, ConvNext2d, DilatedCausalConv1d,
LoRALinear1d, ResBlock1, WaveConv1D)
parent_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.append(parent_dir)
class TextEncoder(nn.Module):
def __init__(
self,
out_channels: int,
hidden_channels: int,
filter_channels: int,
emb_channels: int,
gin_channels: int,
n_heads: int,
n_layers: int,
kernel_size: int,
p_dropout: int,
f0: bool = True,
):
super().__init__()
self.out_channels = out_channels
self.hidden_channels = hidden_channels
self.filter_channels = filter_channels
self.emb_channels = emb_channels
self.n_heads = n_heads
self.n_layers = n_layers
self.kernel_size = kernel_size
self.p_dropout = p_dropout
self.emb_phone = nn.Linear(emb_channels, hidden_channels)
self.lrelu = nn.LeakyReLU(0.1, inplace=True)
if f0:
self.emb_pitch = nn.Embedding(256, hidden_channels) # pitch 256
self.emb_g = nn.Conv1d(gin_channels, hidden_channels, 1)
self.encoder = attentions.Encoder(
hidden_channels, filter_channels, gin_channels, n_heads, n_layers, kernel_size, p_dropout
)
self.proj = nn.Conv1d(hidden_channels, out_channels, 1)
def forward(self, phone, pitch, lengths, g):
if pitch is None:
x = self.emb_phone(phone)
else:
x = self.emb_phone(phone) + self.emb_pitch(pitch)
x = torch.transpose(x, 1, -1) # [b, h, t]
x_mask = torch.unsqueeze(commons.sequence_mask(lengths, x.size(2)), 1).to(
x.dtype
)
x = self.encoder(x * x_mask, x_mask, g)
x = self.proj(x)
return x, None, x_mask
class SineGen(torch.nn.Module):
"""Definition of sine generator
SineGen(samp_rate, harmonic_num = 0,
sine_amp = 0.1, noise_std = 0.003,
voiced_threshold = 0,
flag_for_pulse=False)
samp_rate: sampling rate in Hz
harmonic_num: number of harmonic overtones (default 0)
sine_amp: amplitude of sine waveform (default 0.1)
noise_std: std of Gaussian noise (default 0.003)
voiced_threshold: F0 threshold for U/V classification (default 0)
flag_for_pulse: this SineGen is used inside PulseGen (default False)
Note: when flag_for_pulse is True, the first time step of a voiced
segment is always sin(np.pi) or cos(0)
"""
def __init__(
self,
samp_rate,
harmonic_num=0,
sine_amp=0.1,
noise_std=0.003,
voiced_threshold=0,
flag_for_pulse=False,
):
super(SineGen, self).__init__()
self.sine_amp = sine_amp
self.noise_std = noise_std
self.harmonic_num = harmonic_num
self.dim = self.harmonic_num + 1
self.sampling_rate = samp_rate
self.voiced_threshold = voiced_threshold
def _f02uv(self, f0):
# generate uv signal
uv = torch.ones_like(f0)
uv = uv * (f0 > self.voiced_threshold)
return uv
def forward(self, f0, upp):
"""sine_tensor, uv = forward(f0)
input F0: tensor(batchsize=1, length, dim=1)
f0 for unvoiced steps should be 0
output sine_tensor: tensor(batchsize=1, length, dim)
output uv: tensor(batchsize=1, length, 1)
"""
with torch.no_grad():
f0 = f0[:, None].transpose(1, 2)
f0_buf = torch.zeros(f0.shape[0], f0.shape[1], self.dim, device=f0.device)
# fundamental component
f0_buf[:, :, 0] = f0[:, :, 0]
for idx in np.arange(self.harmonic_num):
f0_buf[:, :, idx + 1] = f0_buf[:, :, 0] * (
idx + 2
) # idx + 2: the (idx+1)-th overtone, (idx+2)-th harmonic
rad_values = (f0_buf / self.sampling_rate) % 1  # the % 1 means the n_har harmonic products cannot be optimized away afterwards
rand_ini = torch.rand(
f0_buf.shape[0], f0_buf.shape[2], device=f0_buf.device
)
rand_ini[:, 0] = 0
rad_values[:, 0, :] = rad_values[:, 0, :] + rand_ini
tmp_over_one = torch.cumsum(rad_values, 1)  # % 1  # taking % 1 here would mean the cumsum below could no longer be optimized
tmp_over_one *= upp
tmp_over_one = F.interpolate(
tmp_over_one.transpose(2, 1),
scale_factor=upp,
mode="linear",
align_corners=True,
).transpose(2, 1)
rad_values = F.interpolate(
rad_values.transpose(2, 1), scale_factor=upp, mode="nearest"
).transpose(
2, 1
) #######
tmp_over_one %= 1
tmp_over_one_idx = (tmp_over_one[:, 1:, :] - tmp_over_one[:, :-1, :]) < 0
cumsum_shift = torch.zeros_like(rad_values)
cumsum_shift[:, 1:, :] = tmp_over_one_idx * -1.0
sine_waves = torch.sin(
torch.cumsum(rad_values + cumsum_shift, dim=1) * 2 * np.pi
)
sine_waves = sine_waves * self.sine_amp
uv = self._f02uv(f0)
uv = F.interpolate(
uv.transpose(2, 1), scale_factor=upp, mode="nearest"
).transpose(2, 1)
noise_amp = uv * self.noise_std + (1 - uv) * self.sine_amp / 3
noise = noise_amp * torch.randn_like(sine_waves)
sine_waves = sine_waves * uv + noise
return sine_waves, uv, noise
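
A small sketch of the sine source on its own: frame-level f0 goes in, a sample-level stack of the fundamental plus harmonics comes out. The sampling rate, f0 value and upsampling factor are illustrative.

import torch

gen = SineGen(samp_rate=48000, harmonic_num=16)
f0 = torch.full((1, 100), 220.0)    # frame-level f0 in Hz, 0 for unvoiced frames
sine, uv, noise = gen(f0, upp=480)  # each frame expands to 480 samples
print(sine.shape)                   # torch.Size([1, 48000, 17]) = fundamental + 16 harmonics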
class SourceModuleHnNSF(torch.nn.Module):
"""SourceModule for hn-nsf
SourceModule(sampling_rate, harmonic_num=0, sine_amp=0.1,
add_noise_std=0.003, voiced_threshod=0)
sampling_rate: sampling_rate in Hz
harmonic_num: number of harmonic above F0 (default: 0)
sine_amp: amplitude of sine source signal (default: 0.1)
add_noise_std: std of additive Gaussian noise (default: 0.003)
note that amplitude of noise in unvoiced is decided
by sine_amp
voiced_threshold: threshold to set U/V given F0 (default: 0)
Sine_source, noise_source = SourceModuleHnNSF(F0_sampled)
F0_sampled (batchsize, length, 1)
Sine_source (batchsize, length, 1)
noise_source (batchsize, length, 1)
uv (batchsize, length, 1)
"""
def __init__(
self,
sampling_rate,
gin_channels,
harmonic_num=0,
sine_amp=0.1,
add_noise_std=0.003,
voiced_threshod=0,
is_half=True,
):
super(SourceModuleHnNSF, self).__init__()
self.sine_amp = sine_amp
self.noise_std = add_noise_std
self.is_half = is_half
# to produce sine waveforms
self.l_sin_gen = SineGen(
sampling_rate, harmonic_num, sine_amp, add_noise_std, voiced_threshod
)
# to merge source harmonics into a single excitation
self.l_linear = torch.nn.Conv1d(gin_channels, harmonic_num + 1, 1)
self.l_tanh = torch.nn.Tanh()
def forward(self, x, upp=None):
sine_wavs, uv, _ = self.l_sin_gen(x, upp)
sine_raw = torch.transpose(sine_wavs, 1, 2).to(device=x.device, dtype=x.dtype)
return sine_raw, None, None # noise, uv
class GeneratorNSF(torch.nn.Module):
def __init__(
self,
initial_channel,
resblock,
resblock_kernel_sizes,
resblock_dilation_sizes,
upsample_rates,
upsample_initial_channel,
upsample_kernel_sizes,
gin_channels,
sr,
harmonic_num=16,
is_half=False,
):
super(GeneratorNSF, self).__init__()
self.num_kernels = len(resblock_kernel_sizes)
self.num_upsamples = len(upsample_rates)
self.upsample_rates = upsample_rates
self.f0_upsamp = torch.nn.Upsample(scale_factor=np.prod(upsample_rates))
self.m_source = SourceModuleHnNSF(
sampling_rate=sr, gin_channels=gin_channels, harmonic_num=harmonic_num, is_half=is_half
)
self.gpre = Conv1d(gin_channels, initial_channel, 1)
self.conv_pre = ResBlock1(initial_channel, upsample_initial_channel, gin_channels, [7] * 5, [1] * 5, [1, 2, 4, 8, 1], 1, 2)
self.ups = nn.ModuleList()
self.resblocks = nn.ModuleList()
c_cur = upsample_initial_channel
for i, u in enumerate(upsample_rates):
c_pre = c_cur
c_cur = c_pre // 2
self.ups.append(
CausalConvTranspose1d(
c_pre,
c_pre,
kernel_rate=3,
stride=u,
groups=c_pre,
)
)
self.resblocks.append(ResBlock1(c_pre, c_cur, gin_channels, [11] * 5, [1] * 5, [1, 2, 4, 8, 1], 1, r=2))
self.conv_post = DilatedCausalConv1d(c_cur, 1, 5, stride=1, groups=1, dilation=1, bias=False)
self.noise_convs = nn.ModuleList()
self.noise_pre = LoRALinear1d(1 + harmonic_num, c_pre, gin_channels, r=2+harmonic_num)
for i, u in enumerate(upsample_rates[::-1]):
c_pre = c_pre * 2
c_cur = c_cur * 2
if i + 1 < len(upsample_rates):
self.noise_convs.append(DilatedCausalConv1d(c_cur, c_pre, kernel_size=u*3, stride=u, groups=c_cur, dilation=1))
else:
self.noise_convs.append(DilatedCausalConv1d(c_cur, initial_channel, kernel_size=u*3, stride=u, groups=math.gcd(c_cur, initial_channel), dilation=1))
self.upp = np.prod(upsample_rates)
def forward(self, x, x_mask, f0f, g):
har_source, noi_source, uv = self.m_source(f0f, self.upp)
har_source = self.noise_pre(har_source, g)
x_sources = [har_source]
for c in self.noise_convs:
har_source = c(har_source)
x_sources.append(har_source)
x = x + x_sources[-1]
x = x + self.gpre(g)
x = self.conv_pre(x, x_mask, g)
for i, u in enumerate(self.upsample_rates):
x_mask = torch.repeat_interleave(x_mask, u, 2)
x = F.leaky_relu(x, modules.LRELU_SLOPE)
x = self.ups[i](x)
x = self.resblocks[i](x + x_sources[-i-2], x_mask, g)
x = F.leaky_relu(x)
x = self.conv_post(x)
if x_mask is not None:
x *= x_mask
x = torch.tanh(x)
return x
def remove_weight_norm(self):
for l in self.ups:
l.remove_weight_norm()  # CausalConvTranspose1d manages its own weight norm
for l in self.resblocks:
l.remove_weight_norm()
self.noise_pre.remove_weight_norm()  # LoRALinear1d and DilatedCausalConv1d wrap their own weight-normed convs
self.conv_post.remove_weight_norm()
sr2sr = {
"32k": 32000,
"40k": 40000,
"48k": 48000,
}
class SynthesizerTrnMs256NSFSid(nn.Module):
def __init__(
self,
spec_channels,
segment_size,
inter_channels,
hidden_channels,
filter_channels,
n_heads,
n_layers,
kernel_size,
p_dropout,
resblock,
resblock_kernel_sizes,
resblock_dilation_sizes,
upsample_rates,
upsample_initial_channel,
upsample_kernel_sizes,
spk_embed_dim,
gin_channels,
emb_channels,
sr,
**kwargs
):
super().__init__()
if isinstance(sr, str):
sr = sr2sr[sr]
self.spec_channels = spec_channels
self.inter_channels = inter_channels
self.hidden_channels = hidden_channels
self.filter_channels = filter_channels
self.n_heads = n_heads
self.n_layers = n_layers
self.kernel_size = kernel_size
self.p_dropout = p_dropout
self.resblock = resblock
self.resblock_kernel_sizes = resblock_kernel_sizes
self.resblock_dilation_sizes = resblock_dilation_sizes
self.upsample_rates = upsample_rates
self.upsample_initial_channel = upsample_initial_channel
self.upsample_kernel_sizes = upsample_kernel_sizes
self.segment_size = segment_size
self.gin_channels = gin_channels
self.emb_channels = emb_channels
self.sr = sr
# self.hop_length = hop_length#
self.spk_embed_dim = spk_embed_dim
self.emb_pitch = nn.Embedding(256, emb_channels) # pitch 256
self.dec = GeneratorNSF(
emb_channels,
resblock,
resblock_kernel_sizes,
resblock_dilation_sizes,
upsample_rates,
upsample_initial_channel,
upsample_kernel_sizes,
gin_channels=gin_channels,
sr=sr,
)
self.emb_g = nn.Embedding(self.spk_embed_dim, gin_channels)
print(
"gin_channels:",
gin_channels,
"self.spk_embed_dim:",
self.spk_embed_dim,
"emb_channels:",
emb_channels,
)
def remove_weight_norm(self):
self.dec.remove_weight_norm()
def forward(
self, phone, phone_lengths, pitch, pitchf, y, y_lengths, ds
):  # here ds is the speaker id, shape [bs, 1]
# print(1,pitch.shape)#[bs,t]
g = self.emb_g(ds).unsqueeze(-1)  # [b, 256, 1]; the trailing 1 is broadcast over t
# m_p, _, x_mask = self.enc_p(phone, pitch, phone_lengths, g)
# z, m_q, logs_q, y_mask = self.enc_q(y, y_lengths, g=g)
# z_p = self.flow(m_p * x_mask, x_mask, g=g)
x = phone + self.emb_pitch(pitch)
x = torch.transpose(x, 1, -1) # [b, h, t]
x_mask = torch.unsqueeze(commons.sequence_mask(phone_lengths, x.size(2)), 1).to(
phone.dtype
)
m_p_slice, ids_slice = commons.rand_slice_segments(
x, phone_lengths, self.segment_size
)
# print(-1,pitchf.shape,ids_slice,self.segment_size,self.hop_length,self.segment_size//self.hop_length)
pitchf = commons.slice_segments2(pitchf, ids_slice, self.segment_size)
mask_slice = commons.slice_segments(x_mask, ids_slice, self.segment_size)
# print(-2,pitchf.shape,z_slice.shape)
o = self.dec(m_p_slice, mask_slice, pitchf, g)
return o, ids_slice, x_mask, g
def infer(self, phone, phone_lengths, pitch, nsff0, sid, max_len=None):
g = self.emb_g(sid).unsqueeze(-1)
x = phone + self.emb_pitch(pitch)
x = torch.transpose(x, 1, -1)
x_mask = torch.unsqueeze(commons.sequence_mask(phone_lengths, x.size(2)), 1).to(
phone.dtype
)
o = self.dec((x * x_mask)[:, :, :max_len], x_mask, nsff0, g)
return o, x_mask, (None, None, None, None)
class DiscriminatorS(torch.nn.Module):
def __init__(
self,
hidden_channels: int,
filter_channels: int,
gin_channels: int,
n_heads: int,
n_layers: int,
kernel_size: int,
p_dropout: int,
):
super(DiscriminatorS, self).__init__()
self.convs = WaveConv1D(2, hidden_channels, gin_channels, [10, 7, 7, 7, 5, 3, 3], [5, 4, 4, 4, 3, 2, 2], [1] * 7, hidden_channels // 2, False)
self.encoder = attentions.Encoder(
hidden_channels, filter_channels, gin_channels, n_heads, n_layers//2, kernel_size, p_dropout
)
self.cross = weight_norm(torch.nn.Conv1d(gin_channels, hidden_channels, 1, 1))
self.conv_post = weight_norm(torch.nn.Conv1d(hidden_channels, 1, 3, 1, padding=get_padding(5, 1)))
def forward(self, x, g):
x = self.convs(x)
x_mask = torch.ones([x.shape[0], 1, x.shape[2]], device=x.device, dtype=x.dtype)
x = self.encoder(x, x_mask, g)
fmap = [x]
x = x + x * self.cross(g)
y = self.conv_post(x)
return y, fmap
class DiscriminatorP(torch.nn.Module):
def __init__(self, period, gin_channels, upsample_rates, final_dim=256, use_spectral_norm=False):
super(DiscriminatorP, self).__init__()
self.period = period
self.use_spectral_norm = use_spectral_norm
self.init_kernel_size = upsample_rates[-1] * 3
norm_f = weight_norm if use_spectral_norm == False else spectral_norm
N = len(upsample_rates)
self.init_conv = norm_f(Conv2d(1, final_dim // (2 ** (N - 1)), (self.init_kernel_size, 1), (upsample_rates[-1], 1)))
self.convs = nn.ModuleList()
for i, u in enumerate(upsample_rates[::-1][1:], start=1):
self.convs.append(
ConvNext2d(
final_dim // (2 ** (N - i)),
final_dim // (2 ** (N - i - 1)),
gin_channels,
(u*3, 1),
(u, 1),
4,
r=2
)
)
self.conv_post = weight_norm(Conv2d(final_dim, 1, (3, 1), (1, 1)))
def forward(self, x, g):
fmap = []
# 1d to 2d
b, c, t = x.shape
if t % self.period != 0: # pad first
n_pad = self.period - (t % self.period)
x = F.pad(x, (n_pad, 0), "reflect")
t = t + n_pad
x = x.view(b, c, t // self.period, self.period)
x = torch.flip(x, dims=[2])
x = F.pad(x, [0, 0, 0, self.init_kernel_size - 1], mode="constant")
x = self.init_conv(x)
x = F.leaky_relu(x, modules.LRELU_SLOPE)
x = torch.flip(x, dims=[2])
for i, l in enumerate(self.convs):
x = l(x, g)
if i >= 1:
fmap.append(x)
x = F.pad(x, [0, 0, 2, 0], mode="constant")
x = self.conv_post(x)
x = torch.flatten(x, 1, -1)
return x, fmap
class MultiPeriodDiscriminator(torch.nn.Module):
def __init__(self, upsample_rates, gin_channels, periods=[2, 3, 5, 7, 11, 17], **kwargs):
super(MultiPeriodDiscriminator, self).__init__()
# discs = [DiscriminatorS(hidden_channels, filter_channels, gin_channels, n_heads, n_layers, kernel_size, p_dropout)]
discs = [
DiscriminatorP(i, gin_channels, upsample_rates, use_spectral_norm=False) for i in periods
]
self.ups = np.prod(upsample_rates)
self.discriminators = nn.ModuleList(discs)
def forward(self, y, y_hat, g):
fmap_rs = []
fmap_gs = []
y_d_rs = []
y_d_gs = []
for i, d in enumerate(self.discriminators):
y_d_r, fmap_r = d(y, g)
y_d_g, fmap_g = d(y_hat, g)
# for j in range(len(fmap_r)):
# print(i,j,y.shape,y_hat.shape,fmap_r[j].shape,fmap_g[j].shape)
y_d_rs.append(y_d_r)
y_d_gs.append(y_d_g)
fmap_rs.append(fmap_r)
fmap_gs.append(fmap_g)
return y_d_rs, y_d_gs, fmap_rs, fmap_gs
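
Unlike the upstream HiFi-GAN discriminators, each DiscriminatorP here also takes the speaker embedding g. A minimal forward-pass sketch; the upsample rates and slice length are illustrative assumptions.

import torch

mpd = MultiPeriodDiscriminator(upsample_rates=[10, 6, 2, 2, 2], gin_channels=256)
y = torch.randn(1, 1, 9600)      # real waveform slice
y_hat = torch.randn(1, 1, 9600)  # generated waveform slice
g = torch.randn(1, 256, 1)
y_d_rs, y_d_gs, fmap_rs, fmap_gs = mpd(y, y_hat, g)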

View File

@@ -0,0 +1,626 @@
import math
import torch
from torch import nn
from torch.nn import Conv1d, Conv2d
from torch.nn import functional as F
from torch.nn.utils import remove_weight_norm, spectral_norm, weight_norm
from . import commons, modules
from .commons import get_padding, init_weights
from .transforms import piecewise_rational_quadratic_transform
LRELU_SLOPE = 0.1
class LayerNorm(nn.Module):
def __init__(self, channels, eps=1e-5):
super().__init__()
self.channels = channels
self.eps = eps
self.gamma = nn.Parameter(torch.ones(channels))
self.beta = nn.Parameter(torch.zeros(channels))
def forward(self, x):
x = x.transpose(1, -1)
x = F.layer_norm(x, (self.channels,), self.gamma, self.beta, self.eps)
return x.transpose(1, -1)
class ConvReluNorm(nn.Module):
def __init__(
self,
in_channels,
hidden_channels,
out_channels,
kernel_size,
n_layers,
p_dropout,
):
super().__init__()
self.in_channels = in_channels
self.hidden_channels = hidden_channels
self.out_channels = out_channels
self.kernel_size = kernel_size
self.n_layers = n_layers
self.p_dropout = p_dropout
assert n_layers > 1, "Number of layers should be larger than 1."
self.conv_layers = nn.ModuleList()
self.norm_layers = nn.ModuleList()
self.conv_layers.append(
nn.Conv1d(
in_channels, hidden_channels, kernel_size, padding=kernel_size // 2
)
)
self.norm_layers.append(LayerNorm(hidden_channels))
self.relu_drop = nn.Sequential(nn.ReLU(), nn.Dropout(p_dropout))
for _ in range(n_layers - 1):
self.conv_layers.append(
nn.Conv1d(
hidden_channels,
hidden_channels,
kernel_size,
padding=kernel_size // 2,
)
)
self.norm_layers.append(LayerNorm(hidden_channels))
self.proj = nn.Conv1d(hidden_channels, out_channels, 1)
self.proj.weight.data.zero_()
self.proj.bias.data.zero_()
def forward(self, x, x_mask):
x_org = x
for i in range(self.n_layers):
x = self.conv_layers[i](x * x_mask)
x = self.norm_layers[i](x)
x = self.relu_drop(x)
x = x_org + self.proj(x)
return x * x_mask
class DDSConv(nn.Module):
"""
Dilated and Depth-Separable Convolution
"""
def __init__(self, channels, kernel_size, n_layers, p_dropout=0.0):
super().__init__()
self.channels = channels
self.kernel_size = kernel_size
self.n_layers = n_layers
self.p_dropout = p_dropout
self.drop = nn.Dropout(p_dropout)
self.convs_sep = nn.ModuleList()
self.convs_1x1 = nn.ModuleList()
self.norms_1 = nn.ModuleList()
self.norms_2 = nn.ModuleList()
for i in range(n_layers):
dilation = kernel_size**i
padding = (kernel_size * dilation - dilation) // 2
self.convs_sep.append(
nn.Conv1d(
channels,
channels,
kernel_size,
groups=channels,
dilation=dilation,
padding=padding,
)
)
self.convs_1x1.append(nn.Conv1d(channels, channels, 1))
self.norms_1.append(LayerNorm(channels))
self.norms_2.append(LayerNorm(channels))
def forward(self, x, x_mask, g=None):
if g is not None:
x = x + g
for i in range(self.n_layers):
y = self.convs_sep[i](x * x_mask)
y = self.norms_1[i](y)
y = F.gelu(y)
y = self.convs_1x1[i](y)
y = self.norms_2[i](y)
y = F.gelu(y)
y = self.drop(y)
x = x + y
return x * x_mask
class WN(torch.nn.Module):
def __init__(
self,
hidden_channels,
kernel_size,
dilation_rate,
n_layers,
gin_channels=0,
p_dropout=0,
):
super(WN, self).__init__()
assert kernel_size % 2 == 1
self.hidden_channels = hidden_channels
self.kernel_size = (kernel_size,)
self.dilation_rate = dilation_rate
self.n_layers = n_layers
self.gin_channels = gin_channels
self.p_dropout = p_dropout
self.in_layers = torch.nn.ModuleList()
self.res_skip_layers = torch.nn.ModuleList()
self.drop = nn.Dropout(p_dropout)
if gin_channels != 0:
cond_layer = torch.nn.Conv1d(
gin_channels, 2 * hidden_channels * n_layers, 1
)
self.cond_layer = torch.nn.utils.weight_norm(cond_layer, name="weight")
for i in range(n_layers):
dilation = dilation_rate**i
padding = int((kernel_size * dilation - dilation) / 2)
in_layer = torch.nn.Conv1d(
hidden_channels,
2 * hidden_channels,
kernel_size,
dilation=dilation,
padding=padding,
)
in_layer = torch.nn.utils.weight_norm(in_layer, name="weight")
self.in_layers.append(in_layer)
# last one is not necessary
if i < n_layers - 1:
res_skip_channels = 2 * hidden_channels
else:
res_skip_channels = hidden_channels
res_skip_layer = torch.nn.Conv1d(hidden_channels, res_skip_channels, 1)
res_skip_layer = torch.nn.utils.weight_norm(res_skip_layer, name="weight")
self.res_skip_layers.append(res_skip_layer)
def forward(self, x, x_mask, g=None, **kwargs):
output = torch.zeros_like(x)
n_channels_tensor = torch.IntTensor([self.hidden_channels])
if g is not None:
g = self.cond_layer(g)
for i in range(self.n_layers):
x_in = self.in_layers[i](x)
if g is not None:
cond_offset = i * 2 * self.hidden_channels
g_l = g[:, cond_offset : cond_offset + 2 * self.hidden_channels, :]
else:
g_l = torch.zeros_like(x_in)
acts = commons.fused_add_tanh_sigmoid_multiply(x_in, g_l, n_channels_tensor)
acts = self.drop(acts)
res_skip_acts = self.res_skip_layers[i](acts)
if i < self.n_layers - 1:
res_acts = res_skip_acts[:, : self.hidden_channels, :]
x = (x + res_acts) * x_mask
output = output + res_skip_acts[:, self.hidden_channels :, :]
else:
output = output + res_skip_acts
return output * x_mask
def remove_weight_norm(self):
if self.gin_channels != 0:
torch.nn.utils.remove_weight_norm(self.cond_layer)
for l in self.in_layers:
torch.nn.utils.remove_weight_norm(l)
for l in self.res_skip_layers:
torch.nn.utils.remove_weight_norm(l)
class DilatedCausalConv1d(nn.Module):
def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, groups=1, dilation=1, bias=True):
super(DilatedCausalConv1d, self).__init__()
self.kernel_size = kernel_size
self.dilation = dilation
self.stride = stride
self.conv = weight_norm(nn.Conv1d(in_channels, out_channels, kernel_size, stride=stride, groups=groups, dilation=dilation, bias=bias))
init_weights(self.conv)
def forward(self, x):
x = torch.flip(x, [2])
x = F.pad(x, [0, (self.kernel_size - 1) * self.dilation], mode="constant", value=0.)
size = x.shape[2] // self.stride
x = self.conv(x)[:, :, :size]
x = torch.flip(x, [2])
return x
def remove_weight_norm(self):
remove_weight_norm(self.conv)
class CausalConvTranspose1d(nn.Module):
"""
When padding = 0 and dilation = 1:
Lout = (Lin - 1) * stride + kernel_rate * stride + output_padding
Lout = Lin * stride + (kernel_rate - 1) * stride + output_padding
so output_padding is not needed.
"""
def __init__(self, in_channels, out_channels, kernel_rate=3, stride=1, groups=1):
super(CausalConvTranspose1d, self).__init__()
kernel_size = kernel_rate * stride
self.trim_size = (kernel_rate - 1) * stride
self.conv = weight_norm(nn.ConvTranspose1d(in_channels, out_channels, kernel_size, stride=stride, groups=groups))
def forward(self, x):
x = self.conv(x)
return x[:, :, :-self.trim_size]
def remove_weight_norm(self):
remove_weight_norm(self.conv)
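
A quick length check of the causal transposed convolution: after trimming (kernel_rate - 1) * stride samples, the output length is exactly Lin * stride, matching the docstring. Channel counts are illustrative.

import torch

up = CausalConvTranspose1d(in_channels=4, out_channels=4, kernel_rate=3, stride=5)
x = torch.randn(1, 4, 100)
print(up(x).shape)  # torch.Size([1, 4, 500])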
class LoRALinear1d(nn.Module):
def __init__(self, in_channels, out_channels, info_channels, r):
super().__init__()
self.in_channels = in_channels
self.out_channels = out_channels
self.info_channels = info_channels
self.r = r
self.main_fc = weight_norm(nn.Conv1d(in_channels, out_channels, 1))
self.adapter_in = nn.Conv1d(info_channels, in_channels * r, 1)
self.adapter_out = nn.Conv1d(info_channels, out_channels * r, 1)
nn.init.normal_(self.adapter_in.weight.data, 0, 0.01)
nn.init.constant_(self.adapter_out.weight.data, 1e-6)
init_weights(self.main_fc)
self.adapter_in = weight_norm(self.adapter_in)
self.adapter_out = weight_norm(self.adapter_out)
def forward(self, x, g):
a_in = self.adapter_in(g).view(-1, self.in_channels, self.r)
a_out = self.adapter_out(g).view(-1, self.r, self.out_channels)
x = self.main_fc(x) + torch.einsum("brl,brc->bcl", torch.einsum("bcl,bcr->brl", x, a_in), a_out)
return x
def remove_weight_norm(self):
remove_weight_norm(self.main_fc)
remove_weight_norm(self.adapter_in)
remove_weight_norm(self.adapter_out)
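
LoRALinear1d is the conditioning primitive used throughout this model: a weight-normed 1x1 convolution plus a rank-r correction whose two factors are generated from the speaker embedding g. A minimal sketch; channel sizes are illustrative.

import torch

lin = LoRALinear1d(in_channels=192, out_channels=192, info_channels=256, r=2)
x = torch.randn(1, 192, 100)
g = torch.randn(1, 256, 1)  # per-utterance conditioning vector
y = lin(x, g)               # main_fc(x) plus a rank-2, g-dependent term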
class LoRALinear2d(nn.Module):
def __init__(self, in_channels, out_channels, info_channels, r):
super().__init__()
self.in_channels = in_channels
self.out_channels = out_channels
self.info_channels = info_channels
self.r = r
self.main_fc = weight_norm(nn.Conv2d(in_channels, out_channels, (1, 1), (1, 1)))
self.adapter_in = nn.Conv1d(info_channels, in_channels * r, 1)
self.adapter_out = nn.Conv1d(info_channels, out_channels * r, 1)
nn.init.normal_(self.adapter_in.weight.data, 0, 0.01)
nn.init.constant_(self.adapter_out.weight.data, 1e-6)
self.adapter_in = weight_norm(self.adapter_in)
self.adapter_out = weight_norm(self.adapter_out)
def forward(self, x, g):
a_in = self.adapter_in(g).view(-1, self.in_channels, self.r)
a_out = self.adapter_out(g).view(-1, self.r, self.out_channels)
x = self.main_fc(x) + torch.einsum("brhw,brc->bchw", torch.einsum("bchw,bcr->brhw", x, a_in), a_out)
return x
def remove_weight_norm(self):
remove_weight_norm(self.main_fc)
remove_weight_norm(self.adapter_in)
remove_weight_norm(self.adapter_out)
class WaveConv1D(torch.nn.Module):
def __init__(self, in_channels, out_channels, gin_channels, kernel_sizes, strides, dilations, extend_ratio, r, use_spectral_norm=False):
super(WaveConv1D, self).__init__()
norm_f = weight_norm if use_spectral_norm == False else spectral_norm
inner_channels = int(in_channels * extend_ratio)
self.convs = []
# self.norms = []
self.convs.append(LoRALinear1d(in_channels, inner_channels, gin_channels, r))
for i, (k, s, d) in enumerate(zip(kernel_sizes, strides, dilations), start=1):
self.convs.append(norm_f(Conv1d(inner_channels, inner_channels, k, s, dilation=d, groups=inner_channels, padding=get_padding(k, d))))
if i < len(kernel_sizes):
self.convs.append(norm_f(Conv1d(inner_channels, inner_channels, 1, 1)))
else:
self.convs.append(norm_f(Conv1d(inner_channels, out_channels, 1, 1)))
self.convs = nn.ModuleList(self.convs)
def forward(self, x, g, x_mask=None):
for i, l in enumerate(self.convs):
if i % 2:
x_ = l(x)
else:
x_ = l(x, g)
x = F.leaky_relu(x_, modules.LRELU_SLOPE)
if x_mask is not None:
x *= x_mask
return x
def remove_weight_norm(self):
for i, c in enumerate(self.convs):
if i % 2:
remove_weight_norm(c)
else:
c.remove_weight_norm()
class MBConv2d(torch.nn.Module):
"""
Causal MBConv2D
"""
def __init__(self, in_channels, out_channels, gin_channels, kernel_size, stride, extend_ratio, r, use_spectral_norm=False):
super(MBConv2d, self).__init__()
norm_f = weight_norm if use_spectral_norm == False else spectral_norm
inner_channels = int(in_channels * extend_ratio)
self.kernel_size = kernel_size
self.pre_pointwise = LoRALinear2d(in_channels, inner_channels, gin_channels, r=r)
self.depthwise = norm_f(Conv2d(inner_channels, inner_channels, kernel_size, stride, groups=inner_channels))
self.post_pointwise = LoRALinear2d(inner_channels, out_channels, gin_channels, r=r)
def forward(self, x, g):
x = self.pre_pointwise(x, g)
x = F.pad(x, [0, 0, self.kernel_size[0] - 1, 0], mode="constant")
x = self.depthwise(x)
x = self.post_pointwise(x, g)
return x
class ConvNext2d(torch.nn.Module):
"""
Causal ConvNext Block
stride = 1 only
"""
def __init__(self, in_channels, out_channels, gin_channels, kernel_size, stride, extend_ratio, r, use_spectral_norm=False):
super(ConvNext2d, self).__init__()
norm_f = weight_norm if use_spectral_norm == False else spectral_norm
inner_channels = int(in_channels * extend_ratio)
self.kernel_size = kernel_size
self.dwconv = norm_f(Conv2d(in_channels, in_channels, kernel_size, stride, groups=in_channels))
self.pwconv1 = LoRALinear2d(in_channels, inner_channels, gin_channels, r=r)
self.pwconv2 = LoRALinear2d(inner_channels, out_channels, gin_channels, r=r)
self.act = nn.GELU()
self.norm = LayerNorm(in_channels)
def forward(self, x, g):
x = F.pad(x, [0, 0, self.kernel_size[0] - 1, 0], mode="constant")
x = self.dwconv(x)
x = self.norm(x)
x = self.pwconv1(x, g)
x = F.leaky_relu(x, modules.LRELU_SLOPE)
x = self.pwconv2(x, g)
x = F.leaky_relu(x, modules.LRELU_SLOPE)
return x
def remove_weight_norm(self):
remove_weight_norm(self.dwconv)
class SqueezeExcitation1D(torch.nn.Module):
def __init__(self, input_channels, squeeze_channels, gin_channels, use_spectral_norm=False):
norm_f = weight_norm if use_spectral_norm == False else spectral_norm
super(SqueezeExcitation1D, self).__init__()
self.fc1 = LoRALinear1d(input_channels, squeeze_channels, gin_channels, 2)
self.fc2 = LoRALinear1d(squeeze_channels, input_channels, gin_channels, 2)
def _scale(self, x, x_mask, g):
x_length = torch.sum(x_mask, dim=2, keepdim=True)
x_length = torch.maximum(x_length, torch.ones_like(x_length))
scale = torch.sum(x * x_mask, dim=2, keepdim=True) / x_length
scale = self.fc1(scale, g)
scale = F.leaky_relu(scale, modules.LRELU_SLOPE)
scale = self.fc2(scale, g)
return torch.sigmoid(scale)
def forward(self, x, x_mask, g):
scale = self._scale(x, x_mask, g)
return scale * x
def remove_weight_norm(self):
self.fc1.remove_weight_norm()
self.fc2.remove_weight_norm()
class ResBlock1(torch.nn.Module):
def __init__(self, in_channels, out_channels, gin_channels, kernel_sizes, strides, dilations, extend_ratio, r):
super(ResBlock1, self).__init__()
norm_f = weight_norm
inner_channels = int(in_channels * extend_ratio)
self.dconvs = nn.ModuleList()
self.pconvs = nn.ModuleList()
# self.ses = nn.ModuleList()
self.norms = nn.ModuleList()
self.init_conv = LoRALinear1d(in_channels, inner_channels, gin_channels, r)
for i, (k, s, d) in enumerate(zip(kernel_sizes, strides, dilations)):
self.norms.append(LayerNorm(inner_channels))
self.dconvs.append(DilatedCausalConv1d(inner_channels, inner_channels, k, stride=s, dilation=d, groups=inner_channels))
if i < len(kernel_sizes) - 1:
self.pconvs.append(LoRALinear1d(inner_channels, inner_channels, gin_channels, r))
self.out_conv = LoRALinear1d(inner_channels, out_channels, gin_channels, r)
init_weights(self.init_conv)
init_weights(self.out_conv)
def forward(self, x, x_mask, g):
x *= x_mask
x = self.init_conv(x, g)
for i in range(len(self.dconvs)):
x *= x_mask
x = self.norms[i](x)
x_ = self.dconvs[i](x)
x_ = F.leaky_relu(x_, modules.LRELU_SLOPE)
if i < len(self.dconvs) - 1:
x = x + self.pconvs[i](x_, g)
x = self.out_conv(x_, g)
return x
def remove_weight_norm(self):
for c in self.dconvs:
c.remove_weight_norm()
for c in self.pconvs:
c.remove_weight_norm()
self.init_conv.remove_weight_norm()
self.out_conv.remove_weight_norm()
class Log(nn.Module):
def forward(self, x, x_mask, reverse=False, **kwargs):
if not reverse:
y = torch.log(torch.clamp_min(x, 1e-5)) * x_mask
logdet = torch.sum(-y, [1, 2])
return y, logdet
else:
x = torch.exp(x) * x_mask
return x
class Flip(nn.Module):
def forward(self, x, *args, reverse=False, **kwargs):
x = torch.flip(x, [1])
if not reverse:
logdet = torch.zeros(x.size(0)).to(dtype=x.dtype, device=x.device)
return x, logdet
else:
return x
class ElementwiseAffine(nn.Module):
def __init__(self, channels):
super().__init__()
self.channels = channels
self.m = nn.Parameter(torch.zeros(channels, 1))
self.logs = nn.Parameter(torch.zeros(channels, 1))
def forward(self, x, x_mask, reverse=False, **kwargs):
if not reverse:
y = self.m + torch.exp(self.logs) * x
y = y * x_mask
logdet = torch.sum(self.logs * x_mask, [1, 2])
return y, logdet
else:
x = (x - self.m) * torch.exp(-self.logs) * x_mask
return x
class ResidualCouplingLayer(nn.Module):
def __init__(
self,
channels,
hidden_channels,
kernel_size,
dilation_rate,
n_layers,
p_dropout=0,
gin_channels=0,
mean_only=False,
):
assert channels % 2 == 0, "channels should be divisible by 2"
super().__init__()
self.channels = channels
self.hidden_channels = hidden_channels
self.kernel_size = kernel_size
self.dilation_rate = dilation_rate
self.n_layers = n_layers
self.half_channels = channels // 2
self.mean_only = mean_only
self.pre = nn.Conv1d(self.half_channels, hidden_channels, 1)
self.enc = WN(
hidden_channels,
kernel_size,
dilation_rate,
n_layers,
p_dropout=p_dropout,
gin_channels=gin_channels,
)
self.post = nn.Conv1d(hidden_channels, self.half_channels * (2 - mean_only), 1)
self.post.weight.data.zero_()
self.post.bias.data.zero_()
def forward(self, x, x_mask, g=None, reverse=False):
x0, x1 = torch.split(x, [self.half_channels] * 2, 1)
h = self.pre(x0) * x_mask
h = self.enc(h, x_mask, g=g)
stats = self.post(h) * x_mask
if not self.mean_only:
m, logs = torch.split(stats, [self.half_channels] * 2, 1)
else:
m = stats
logs = torch.zeros_like(m)
if not reverse:
x1 = m + x1 * torch.exp(logs) * x_mask
x = torch.cat([x0, x1], 1)
logdet = torch.sum(logs, [1, 2])
return x, logdet
else:
x1 = (x1 - m) * torch.exp(-logs) * x_mask
x = torch.cat([x0, x1], 1)
return x
def remove_weight_norm(self):
self.enc.remove_weight_norm()
class ConvFlow(nn.Module):
def __init__(
self,
in_channels,
filter_channels,
kernel_size,
n_layers,
num_bins=10,
tail_bound=5.0,
):
super().__init__()
self.in_channels = in_channels
self.filter_channels = filter_channels
self.kernel_size = kernel_size
self.n_layers = n_layers
self.num_bins = num_bins
self.tail_bound = tail_bound
self.half_channels = in_channels // 2
self.pre = nn.Conv1d(self.half_channels, filter_channels, 1)
self.convs = DDSConv(filter_channels, kernel_size, n_layers, p_dropout=0.0)
self.proj = nn.Conv1d(
filter_channels, self.half_channels * (num_bins * 3 - 1), 1
)
self.proj.weight.data.zero_()
self.proj.bias.data.zero_()
def forward(self, x, x_mask, g=None, reverse=False):
x0, x1 = torch.split(x, [self.half_channels] * 2, 1)
h = self.pre(x0)
h = self.convs(h, x_mask, g=g)
h = self.proj(h) * x_mask
b, c, t = x0.shape
h = h.reshape(b, c, -1, t).permute(0, 1, 3, 2) # [b, cx?, t] -> [b, c, t, ?]
unnormalized_widths = h[..., : self.num_bins] / math.sqrt(self.filter_channels)
unnormalized_heights = h[..., self.num_bins : 2 * self.num_bins] / math.sqrt(
self.filter_channels
)
unnormalized_derivatives = h[..., 2 * self.num_bins :]
x1, logabsdet = piecewise_rational_quadratic_transform(
x1,
unnormalized_widths,
unnormalized_heights,
unnormalized_derivatives,
inverse=reverse,
tails="linear",
tail_bound=self.tail_bound,
)
x = torch.cat([x0, x1], 1) * x_mask
logdet = torch.sum(logabsdet * x_mask, [1, 2])
if not reverse:
return x, logdet
else:
return x

View File

@@ -0,0 +1,207 @@
import numpy as np
import torch
from torch.nn import functional as F
DEFAULT_MIN_BIN_WIDTH = 1e-3
DEFAULT_MIN_BIN_HEIGHT = 1e-3
DEFAULT_MIN_DERIVATIVE = 1e-3
def piecewise_rational_quadratic_transform(
inputs,
unnormalized_widths,
unnormalized_heights,
unnormalized_derivatives,
inverse=False,
tails=None,
tail_bound=1.0,
min_bin_width=DEFAULT_MIN_BIN_WIDTH,
min_bin_height=DEFAULT_MIN_BIN_HEIGHT,
min_derivative=DEFAULT_MIN_DERIVATIVE,
):
if tails is None:
spline_fn = rational_quadratic_spline
spline_kwargs = {}
else:
spline_fn = unconstrained_rational_quadratic_spline
spline_kwargs = {"tails": tails, "tail_bound": tail_bound}
outputs, logabsdet = spline_fn(
inputs=inputs,
unnormalized_widths=unnormalized_widths,
unnormalized_heights=unnormalized_heights,
unnormalized_derivatives=unnormalized_derivatives,
inverse=inverse,
min_bin_width=min_bin_width,
min_bin_height=min_bin_height,
min_derivative=min_derivative,
**spline_kwargs
)
return outputs, logabsdet
def searchsorted(bin_locations, inputs, eps=1e-6):
bin_locations[..., -1] += eps
return torch.sum(inputs[..., None] >= bin_locations, dim=-1) - 1
def unconstrained_rational_quadratic_spline(
inputs,
unnormalized_widths,
unnormalized_heights,
unnormalized_derivatives,
inverse=False,
tails="linear",
tail_bound=1.0,
min_bin_width=DEFAULT_MIN_BIN_WIDTH,
min_bin_height=DEFAULT_MIN_BIN_HEIGHT,
min_derivative=DEFAULT_MIN_DERIVATIVE,
):
inside_interval_mask = (inputs >= -tail_bound) & (inputs <= tail_bound)
outside_interval_mask = ~inside_interval_mask
outputs = torch.zeros_like(inputs)
logabsdet = torch.zeros_like(inputs)
if tails == "linear":
unnormalized_derivatives = F.pad(unnormalized_derivatives, pad=(1, 1))
constant = np.log(np.exp(1 - min_derivative) - 1)
unnormalized_derivatives[..., 0] = constant
unnormalized_derivatives[..., -1] = constant
outputs[outside_interval_mask] = inputs[outside_interval_mask]
logabsdet[outside_interval_mask] = 0
else:
raise RuntimeError("{} tails are not implemented.".format(tails))
(
outputs[inside_interval_mask],
logabsdet[inside_interval_mask],
) = rational_quadratic_spline(
inputs=inputs[inside_interval_mask],
unnormalized_widths=unnormalized_widths[inside_interval_mask, :],
unnormalized_heights=unnormalized_heights[inside_interval_mask, :],
unnormalized_derivatives=unnormalized_derivatives[inside_interval_mask, :],
inverse=inverse,
left=-tail_bound,
right=tail_bound,
bottom=-tail_bound,
top=tail_bound,
min_bin_width=min_bin_width,
min_bin_height=min_bin_height,
min_derivative=min_derivative,
)
return outputs, logabsdet
def rational_quadratic_spline(
inputs,
unnormalized_widths,
unnormalized_heights,
unnormalized_derivatives,
inverse=False,
left=0.0,
right=1.0,
bottom=0.0,
top=1.0,
min_bin_width=DEFAULT_MIN_BIN_WIDTH,
min_bin_height=DEFAULT_MIN_BIN_HEIGHT,
min_derivative=DEFAULT_MIN_DERIVATIVE,
):
if torch.min(inputs) < left or torch.max(inputs) > right:
raise ValueError("Input to a transform is not within its domain")
num_bins = unnormalized_widths.shape[-1]
if min_bin_width * num_bins > 1.0:
raise ValueError("Minimal bin width too large for the number of bins")
if min_bin_height * num_bins > 1.0:
raise ValueError("Minimal bin height too large for the number of bins")
widths = F.softmax(unnormalized_widths, dim=-1)
widths = min_bin_width + (1 - min_bin_width * num_bins) * widths
cumwidths = torch.cumsum(widths, dim=-1)
cumwidths = F.pad(cumwidths, pad=(1, 0), mode="constant", value=0.0)
cumwidths = (right - left) * cumwidths + left
cumwidths[..., 0] = left
cumwidths[..., -1] = right
widths = cumwidths[..., 1:] - cumwidths[..., :-1]
derivatives = min_derivative + F.softplus(unnormalized_derivatives)
heights = F.softmax(unnormalized_heights, dim=-1)
heights = min_bin_height + (1 - min_bin_height * num_bins) * heights
cumheights = torch.cumsum(heights, dim=-1)
cumheights = F.pad(cumheights, pad=(1, 0), mode="constant", value=0.0)
cumheights = (top - bottom) * cumheights + bottom
cumheights[..., 0] = bottom
cumheights[..., -1] = top
heights = cumheights[..., 1:] - cumheights[..., :-1]
if inverse:
bin_idx = searchsorted(cumheights, inputs)[..., None]
else:
bin_idx = searchsorted(cumwidths, inputs)[..., None]
input_cumwidths = cumwidths.gather(-1, bin_idx)[..., 0]
input_bin_widths = widths.gather(-1, bin_idx)[..., 0]
input_cumheights = cumheights.gather(-1, bin_idx)[..., 0]
delta = heights / widths
input_delta = delta.gather(-1, bin_idx)[..., 0]
input_derivatives = derivatives.gather(-1, bin_idx)[..., 0]
input_derivatives_plus_one = derivatives[..., 1:].gather(-1, bin_idx)[..., 0]
input_heights = heights.gather(-1, bin_idx)[..., 0]
if inverse:
a = (inputs - input_cumheights) * (
input_derivatives + input_derivatives_plus_one - 2 * input_delta
) + input_heights * (input_delta - input_derivatives)
b = input_heights * input_derivatives - (inputs - input_cumheights) * (
input_derivatives + input_derivatives_plus_one - 2 * input_delta
)
c = -input_delta * (inputs - input_cumheights)
discriminant = b.pow(2) - 4 * a * c
assert (discriminant >= 0).all()
root = (2 * c) / (-b - torch.sqrt(discriminant))
outputs = root * input_bin_widths + input_cumwidths
theta_one_minus_theta = root * (1 - root)
denominator = input_delta + (
(input_derivatives + input_derivatives_plus_one - 2 * input_delta)
* theta_one_minus_theta
)
derivative_numerator = input_delta.pow(2) * (
input_derivatives_plus_one * root.pow(2)
+ 2 * input_delta * theta_one_minus_theta
+ input_derivatives * (1 - root).pow(2)
)
logabsdet = torch.log(derivative_numerator) - 2 * torch.log(denominator)
return outputs, -logabsdet
else:
theta = (inputs - input_cumwidths) / input_bin_widths
theta_one_minus_theta = theta * (1 - theta)
numerator = input_heights * (
input_delta * theta.pow(2) + input_derivatives * theta_one_minus_theta
)
denominator = input_delta + (
(input_derivatives + input_derivatives_plus_one - 2 * input_delta)
* theta_one_minus_theta
)
outputs = input_cumheights + numerator / denominator
derivative_numerator = input_delta.pow(2) * (
input_derivatives_plus_one * theta.pow(2)
+ 2 * input_delta * theta_one_minus_theta
+ input_derivatives * (1 - theta).pow(2)
)
logabsdet = torch.log(derivative_numerator) - 2 * torch.log(denominator)
return outputs, logabsdet
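# --- Illustrative sketch (not part of the original module): a forward/inverse
# round trip through rational_quadratic_spline on its default [0, 1] domain.
# The shapes are assumptions chosen for the example: one spline per input
# point, `num_bins` bins and therefore `num_bins + 1` knot derivatives.
def _rqs_roundtrip_check(num_points: int = 16, num_bins: int = 8) -> float:
    inputs = torch.rand(num_points)                    # values inside [0, 1)
    widths = torch.randn(num_points, num_bins)
    heights = torch.randn(num_points, num_bins)
    derivs = torch.randn(num_points, num_bins + 1)     # one derivative per knot
    y, logdet = rational_quadratic_spline(inputs, widths, heights, derivs, inverse=False)
    x, inv_logdet = rational_quadratic_spline(y, widths, heights, derivs, inverse=True)
    # the inverse log-determinant should be the negative of the forward one
    return max((x - inputs).abs().max().item(), (logdet + inv_logdet).abs().max().item())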

View File

@@ -0,0 +1,286 @@
import glob
import logging
import os
import shutil
import socket
import sys
import ffmpeg
import matplotlib
import matplotlib.pylab as plt
import numpy as np
import torch
from scipy.io.wavfile import read
from torch.nn import functional as F
from modules.shared import ROOT_DIR
from .config import TrainConfig
matplotlib.use("Agg")
logging.getLogger("matplotlib").setLevel(logging.WARNING)
logging.basicConfig(stream=sys.stdout, level=logging.DEBUG)
logger = logging
class AWP:
"""
Fast AWP
https://www.kaggle.com/code/junkoda/fast-awp
"""
def __init__(self, model, optimizer, *, adv_param='weight',
adv_lr=0.01, adv_eps=0.01):
self.model = model
self.optimizer = optimizer
self.adv_param = adv_param
self.adv_lr = adv_lr
self.adv_eps = adv_eps
self.backup = {}
def perturb(self):
"""
Perturb model parameters for AWP gradient
Call before loss and loss.backward()
"""
self._save() # save model parameters
self._attack_step() # perturb weights
def _attack_step(self):
e = 1e-6
for name, param in self.model.named_parameters():
if param.requires_grad and param.grad is not None and self.adv_param in name:
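                # optimizer.state[param]["exp_avg"] assumes an Adam-style optimizer
                # that has already tracked a first-moment estimate for this parameter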
grad = self.optimizer.state[param]['exp_avg']
norm_grad = torch.norm(grad)
norm_data = torch.norm(param.detach())
if norm_grad != 0 and not torch.isnan(norm_grad):
# Set lower and upper limit in change
limit_eps = self.adv_eps * param.detach().abs()
param_min = param.data - limit_eps
param_max = param.data + limit_eps
# Perturb along gradient
# w += (adv_lr * |w| / |grad|) * grad
param.data.add_(grad, alpha=(self.adv_lr * (norm_data + e) / (norm_grad + e)))
# Apply the limit to the change
param.data.clamp_(param_min, param_max)
def _save(self):
for name, param in self.model.named_parameters():
if param.requires_grad and param.grad is not None and self.adv_param in name:
if name not in self.backup:
self.backup[name] = param.clone().detach()
else:
self.backup[name].copy_(param.data)
def restore(self):
"""
        Restore the model parameters to their saved values; AWP does not permanently perturb the weights, it only perturbs the gradients used for the update.
Call after loss.backward(), before optimizer.step()
"""
for name, param in self.model.named_parameters():
if name in self.backup:
param.data.copy_(self.backup[name])
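# --- Illustrative sketch (not part of the original file): the call order the
# AWP docstrings above describe, wrapped in a single helper. `compute_loss` is
# a hypothetical callable supplied by the caller; AWP reads the optimizer's
# "exp_avg" state, so this assumes an Adam-style optimizer that has already
# taken at least one step.
def _awp_training_step(model, optimizer, awp, batch, compute_loss):
    optimizer.zero_grad()
    loss = compute_loss(model, batch)
    loss.backward()        # plain gradients
    awp.perturb()          # save weights, then push them along the first-moment direction
    adv_loss = compute_loss(model, batch)
    adv_loss.backward()    # accumulate adversarial gradients on the perturbed weights
    awp.restore()          # put the original weights back before stepping
    optimizer.step()
    return loss.detach(), adv_loss.detach()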
def load_audio(file: str, sr):
try:
# https://github.com/openai/whisper/blob/main/whisper/audio.py#L26
# This launches a subprocess to decode audio while down-mixing and resampling as necessary.
# Requires the ffmpeg CLI and `ffmpeg-python` package to be installed.
file = (
file.strip(" ").strip('"').strip("\n").strip('"').strip(" ")
        )  # Strip spaces, quotes, and newlines that are often copied along with the path
out, _ = (
ffmpeg.input(file, threads=0)
.output("-", format="f32le", acodec="pcm_f32le", ac=1, ar=sr)
.run(cmd=["ffmpeg", "-nostdin"], capture_stdout=True, capture_stderr=True)
)
except Exception as e:
raise RuntimeError(f"Failed to load audio: {e}")
return np.frombuffer(out, np.float32).flatten()
def find_empty_port():
s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
s.bind(("", 0))
s.listen(1)
port = s.getsockname()[1]
s.close()
return port
def load_checkpoint(checkpoint_path, model, optimizer=None, load_opt=1):
assert os.path.isfile(checkpoint_path)
checkpoint_dict = torch.load(checkpoint_path, map_location="cpu")
saved_state_dict = checkpoint_dict["model"]
if hasattr(model, "module"):
state_dict = model.module.state_dict()
else:
state_dict = model.state_dict()
new_state_dict = {}
    for k, v in state_dict.items():  # iterate over the shapes the current model expects
try:
new_state_dict[k] = saved_state_dict[k]
if saved_state_dict[k].shape != state_dict[k].shape:
print(
f"shape-{k}-mismatch|need-{state_dict[k].shape}|get-{saved_state_dict[k].shape}"
)
                if saved_state_dict[k].dim() == 2:  # NOTE: check whether this is OK
                    # for embedder inputs of 256 <==> 768 dims
                    # this allows continuing training from the original pretrained checkpoints even when the embedder outputs 768-dim features, etc.
if saved_state_dict[k].dtype == torch.half:
new_state_dict[k] = (
F.interpolate(
saved_state_dict[k].float().unsqueeze(0).unsqueeze(0),
size=state_dict[k].shape,
mode="bilinear",
)
.half()
.squeeze(0)
.squeeze(0)
)
else:
new_state_dict[k] = (
F.interpolate(
saved_state_dict[k].unsqueeze(0).unsqueeze(0),
size=state_dict[k].shape,
mode="bilinear",
)
.squeeze(0)
.squeeze(0)
)
print(
"interpolated new_state_dict",
k,
"from",
saved_state_dict[k].shape,
"to",
new_state_dict[k].shape,
)
else:
raise KeyError
except Exception as e:
# print(traceback.format_exc())
print(f"{k} is not in the checkpoint")
print("error: %s" % e)
            new_state_dict[k] = v  # keep the model's own randomly initialized value
if hasattr(model, "module"):
model.module.load_state_dict(new_state_dict, strict=False)
else:
model.load_state_dict(new_state_dict, strict=False)
print("Loaded model weights")
epoch = checkpoint_dict["epoch"]
learning_rate = checkpoint_dict["learning_rate"]
if optimizer is not None and load_opt == 1:
optimizer.load_state_dict(checkpoint_dict["optimizer"])
print("Loaded checkpoint '{}' (epoch {})".format(checkpoint_path, epoch))
return model, optimizer, learning_rate, epoch
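# --- Illustrative sketch (not part of the original file): resuming from the most
# recent generator checkpoint. `log_dir` is a hypothetical directory holding
# G_*.pth files written by save_state.
def _resume_from_latest(log_dir, model, optimizer):
    ckpt_path = latest_checkpoint_path(log_dir, "G_*.pth")
    if ckpt_path is None:
        return model, optimizer, None, 0  # nothing to resume from
    return load_checkpoint(ckpt_path, model, optimizer, load_opt=1)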
def save_state(model, optimizer, learning_rate, epoch, checkpoint_path):
print(
"Saving model and optimizer state at epoch {} to {}".format(
epoch, checkpoint_path
)
)
if hasattr(model, "module"):
state_dict = model.module.state_dict()
else:
state_dict = model.state_dict()
torch.save(
{
"model": state_dict,
"epoch": epoch,
"optimizer": optimizer.state_dict(),
"learning_rate": learning_rate,
},
checkpoint_path,
)
def summarize(
writer,
global_step,
scalars={},
histograms={},
images={},
audios={},
audio_sampling_rate=22050,
):
for k, v in scalars.items():
writer.add_scalar(k, v, global_step)
for k, v in histograms.items():
writer.add_histogram(k, v, global_step)
for k, v in images.items():
writer.add_image(k, v, global_step, dataformats="HWC")
for k, v in audios.items():
writer.add_audio(k, v, global_step, audio_sampling_rate)
def latest_checkpoint_path(dir_path, regex="G_*.pth"):
filelist = glob.glob(os.path.join(dir_path, regex))
if len(filelist) == 0:
return None
filelist.sort(key=lambda f: int("".join(filter(str.isdigit, f))))
filepath = filelist[-1]
return filepath
def plot_spectrogram_to_numpy(spectrogram):
fig, ax = plt.subplots(figsize=(10, 2))
im = ax.imshow(spectrogram, aspect="auto", origin="lower", interpolation="none")
plt.colorbar(im, ax=ax)
plt.xlabel("Frames")
plt.ylabel("Channels")
plt.tight_layout()
fig.canvas.draw()
data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep="")
data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
plt.close()
return data
def plot_alignment_to_numpy(alignment, info=None):
fig, ax = plt.subplots(figsize=(6, 4))
im = ax.imshow(
alignment.transpose(), aspect="auto", origin="lower", interpolation="none"
)
fig.colorbar(im, ax=ax)
xlabel = "Decoder timestep"
if info is not None:
xlabel += "\n\n" + info
plt.xlabel(xlabel)
plt.ylabel("Encoder timestep")
plt.tight_layout()
fig.canvas.draw()
data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep="")
data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
plt.close()
return data
def load_wav_to_torch(full_path):
sampling_rate, data = read(full_path)
return torch.FloatTensor(data.astype(np.float32)), sampling_rate
def load_config(training_dir: str, sample_rate: int, emb_channels: int):
if emb_channels == 256:
config_path = os.path.join(ROOT_DIR, "configs", f"{sample_rate}.json")
else:
config_path = os.path.join(
ROOT_DIR, "configs", f"{sample_rate}-{emb_channels}.json"
)
config_save_path = os.path.join(training_dir, "config.json")
shutil.copyfile(config_path, config_save_path)
return TrainConfig.parse_file(config_save_path)
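# --- Illustrative sketch (not part of the original file): picking the config that
# matches a 768-dim embedder at a 40 kHz sampling rate. The training directory
# and the specific values are hypothetical; load_config copies the matching
# "<sample_rate>-<emb_channels>.json" into the training dir and parses it.
def _load_config_example():
    return load_config("./logs/example-run", sample_rate=40000, emb_channels=768)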