diff --git a/server/const.py b/server/const.py index bb5ac0a1..52edbbd7 100644 --- a/server/const.py +++ b/server/const.py @@ -66,6 +66,7 @@ class EnumInferenceTypes(Enum): pyTorchRVCv2Nono = "pyTorchRVCv2Nono" pyTorchWebUI = "pyTorchWebUI" pyTorchWebUINono = "pyTorchWebUINono" + pyTorchVoRASbeta = "pyTorchVoRASbeta" onnxRVC = "onnxRVC" onnxRVCNono = "onnxRVCNono" diff --git a/server/voice_changer/RVC/RVCModelSlotGenerator.py b/server/voice_changer/RVC/RVCModelSlotGenerator.py index 662e7d05..28532dc6 100644 --- a/server/voice_changer/RVC/RVCModelSlotGenerator.py +++ b/server/voice_changer/RVC/RVCModelSlotGenerator.py @@ -36,8 +36,31 @@ class RVCModelSlotGenerator(ModelSlotGenerator): def _setInfoByPytorch(cls, slot: ModelSlot): cpt = torch.load(slot.modelFile, map_location="cpu") config_len = len(cpt["config"]) + + print(cpt["version"]) + if cpt["version"] == "voras_beta": + slot.f0 = True if cpt["f0"] == 1 else False + slot.modelType = EnumInferenceTypes.pyTorchVoRASbeta.value + slot.embChannels = 768 + slot.embOutputLayer = ( + cpt["embedder_output_layer"] if "embedder_output_layer" in cpt else 9 + ) + slot.useFinalProj = False - if config_len == 18: + slot.embedder = cpt["embedder_name"] + if slot.embedder.endswith("768"): + slot.embedder = slot.embedder[:-3] + + if slot.embedder == EnumEmbedderTypes.hubert.value: + slot.embedder = EnumEmbedderTypes.hubert.value + elif slot.embedder == EnumEmbedderTypes.contentvec.value: + slot.embedder = EnumEmbedderTypes.contentvec.value + elif slot.embedder == EnumEmbedderTypes.hubert_jp.value: + slot.embedder = EnumEmbedderTypes.hubert_jp.value + else: + raise RuntimeError("[Voice Changer][setInfoByONNX] unknown embedder") + + elif config_len == 18: # Original RVC slot.f0 = True if cpt["f0"] == 1 else False version = cpt.get("version", "v1") diff --git a/server/voice_changer/RVC/inferencer/InferencerManager.py b/server/voice_changer/RVC/inferencer/InferencerManager.py index cd5f0a6e..ef4bea5e 100644 --- a/server/voice_changer/RVC/inferencer/InferencerManager.py +++ b/server/voice_changer/RVC/inferencer/InferencerManager.py @@ -8,7 +8,7 @@ from voice_changer.RVC.inferencer.RVCInferencerv2 import RVCInferencerv2 from voice_changer.RVC.inferencer.RVCInferencerv2Nono import RVCInferencerv2Nono from voice_changer.RVC.inferencer.WebUIInferencer import WebUIInferencer from voice_changer.RVC.inferencer.WebUIInferencerNono import WebUIInferencerNono - +from voice_changer.RVC.inferencer.VorasInferencebeta import VoRASInferencer class InferencerManager: currentInferencer: Inferencer | None = None @@ -37,6 +37,8 @@ class InferencerManager: return RVCInferencerNono().loadModel(file, gpu) elif inferencerType == EnumInferenceTypes.pyTorchRVCv2 or inferencerType == EnumInferenceTypes.pyTorchRVCv2.value: return RVCInferencerv2().loadModel(file, gpu) + elif inferencerType == EnumInferenceTypes.pyTorchVoRASbeta or inferencerType == EnumInferenceTypes.pyTorchVoRASbeta.value: + return VoRASInferencer().loadModel(file, gpu) elif inferencerType == EnumInferenceTypes.pyTorchRVCv2Nono or inferencerType == EnumInferenceTypes.pyTorchRVCv2Nono.value: return RVCInferencerv2Nono().loadModel(file, gpu) elif inferencerType == EnumInferenceTypes.pyTorchWebUI or inferencerType == EnumInferenceTypes.pyTorchWebUI.value: diff --git a/server/voice_changer/RVC/inferencer/VorasInferencebeta.py b/server/voice_changer/RVC/inferencer/VorasInferencebeta.py new file mode 100644 index 00000000..a5b02f40 --- /dev/null +++ b/server/voice_changer/RVC/inferencer/VorasInferencebeta.py @@ -0,0 +1,39 
@@
+import torch
+from torch import device
+
+from const import EnumInferenceTypes
+from voice_changer.RVC.inferencer.Inferencer import Inferencer
+from voice_changer.RVC.deviceManager.DeviceManager import DeviceManager
+from .voras_beta.models import Synthesizer
+
+
+class VoRASInferencer(Inferencer):
+    def loadModel(self, file: str, gpu: device):
+        super().setProps(EnumInferenceTypes.pyTorchVoRASbeta, file, False, gpu)
+
+        dev = DeviceManager.get_instance().getDevice(gpu)
+        self.isHalf = False  # DeviceManager.get_instance().halfPrecisionAvailable(gpu)
+
+        cpt = torch.load(file, map_location="cpu")
+        model = Synthesizer(**cpt["params"])
+
+        model.eval()
+        model.load_state_dict(cpt["weight"], strict=False)
+        model.remove_weight_norm()
+        model.change_speaker(0)
+
+        model = model.to(dev)
+
+        self.model = model
+        print("load model complete")
+        return self
+
+    def infer(
+        self,
+        feats: torch.Tensor,
+        pitch_length: torch.Tensor,
+        pitch: torch.Tensor,
+        pitchf: torch.Tensor,
+        sid: torch.Tensor,
+    ) -> torch.Tensor:
+        return self.model.infer(feats, pitch_length, pitch, pitchf, sid)
diff --git a/server/voice_changer/RVC/inferencer/voras_beta/commons.py b/server/voice_changer/RVC/inferencer/voras_beta/commons.py
new file mode 100644
index 00000000..79731d2b
--- /dev/null
+++ b/server/voice_changer/RVC/inferencer/voras_beta/commons.py
@@ -0,0 +1,165 @@
+import math
+
+import torch
+from torch.nn import functional as F
+
+
+def init_weights(m, mean=0.0, std=0.01):
+    classname = m.__class__.__name__
+    if classname.find("Conv") != -1:
+        m.weight.data.normal_(mean, std)
+
+
+def get_padding(kernel_size, dilation=1):
+    return int((kernel_size * dilation - dilation) / 2)
+
+
+def convert_pad_shape(pad_shape):
+    l = pad_shape[::-1]
+    pad_shape = [item for sublist in l for item in sublist]
+    return pad_shape
+
+
+def kl_divergence(m_p, logs_p, m_q, logs_q):
+    """KL(P||Q)"""
+    kl = (logs_q - logs_p) - 0.5
+    kl += (
+        0.5 * (torch.exp(2.0 * logs_p) + ((m_p - m_q) ** 2)) * torch.exp(-2.0 * logs_q)
+    )
+    return kl
+
+
+def rand_gumbel(shape):
+    """Sample from the Gumbel distribution, protect from overflows."""
+    uniform_samples = torch.rand(shape) * 0.99998 + 0.00001
+    return -torch.log(-torch.log(uniform_samples))
+
+
+def rand_gumbel_like(x):
+    g = rand_gumbel(x.size()).to(dtype=x.dtype, device=x.device)
+    return g
+
+
+def slice_segments(x, ids_str, segment_size=4):
+    ret = torch.zeros_like(x[:, :, :segment_size])
+    for i in range(x.size(0)):
+        idx_str = ids_str[i]
+        idx_end = idx_str + segment_size
+        r = x[i, :, idx_str:idx_end]
+        ret[i, :, :r.size(1)] = r
+    return ret
+
+
+def slice_segments2(x, ids_str, segment_size=4):
+    ret = torch.zeros_like(x[:, :segment_size])
+    for i in range(x.size(0)):
+        idx_str = ids_str[i]
+        idx_end = idx_str + segment_size
+        r = x[i, idx_str:idx_end]
+        ret[i, :r.size(0)] = r
+    return ret
+
+
+def rand_slice_segments(x, x_lengths, segment_size=4, ids_str=None):
+    b, d, t = x.size()
+    if ids_str is None:
+        ids_str = torch.zeros([b]).to(device=x.device, dtype=x_lengths.dtype)
+    ids_str_max = torch.maximum(torch.zeros_like(x_lengths).to(device=x_lengths.device, dtype=x_lengths.dtype), x_lengths - segment_size + 1 - ids_str)
+    ids_str += (torch.rand([b]).to(device=x.device) * ids_str_max).to(dtype=torch.long)
+    ret = slice_segments(x, ids_str, segment_size)
+    return ret, ids_str
+
+
+def get_timing_signal_1d(length, channels, min_timescale=1.0, max_timescale=1.0e4):
+    position = torch.arange(length, dtype=torch.float)
+    num_timescales = channels // 2
+    log_timescale_increment = 
math.log(float(max_timescale) / float(min_timescale)) / ( + num_timescales - 1 + ) + inv_timescales = min_timescale * torch.exp( + torch.arange(num_timescales, dtype=torch.float) * -log_timescale_increment + ) + scaled_time = position.unsqueeze(0) * inv_timescales.unsqueeze(1) + signal = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], 0) + signal = F.pad(signal, [0, 0, 0, channels % 2]) + signal = signal.view(1, channels, length) + return signal + + +def add_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4): + b, channels, length = x.size() + signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale) + return x + signal.to(dtype=x.dtype, device=x.device) + + +def cat_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4, axis=1): + b, channels, length = x.size() + signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale) + return torch.cat([x, signal.to(dtype=x.dtype, device=x.device)], axis) + + +def subsequent_mask(length): + mask = torch.tril(torch.ones(length, length)).unsqueeze(0).unsqueeze(0) + return mask + + +@torch.jit.script +def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels): + n_channels_int = n_channels[0] + in_act = input_a + input_b + t_act = torch.tanh(in_act[:, :n_channels_int, :]) + s_act = torch.sigmoid(in_act[:, n_channels_int:, :]) + acts = t_act * s_act + return acts + + +def convert_pad_shape(pad_shape): + l = pad_shape[::-1] + pad_shape = [item for sublist in l for item in sublist] + return pad_shape + + +def shift_1d(x): + x = F.pad(x, convert_pad_shape([[0, 0], [0, 0], [1, 0]]))[:, :, :-1] + return x + + +def sequence_mask(length, max_length=None): + if max_length is None: + max_length = length.max() + x = torch.arange(max_length, dtype=length.dtype, device=length.device) + return x.unsqueeze(0) < length.unsqueeze(1) + + +def generate_path(duration, mask): + """ + duration: [b, 1, t_x] + mask: [b, 1, t_y, t_x] + """ + b, _, t_y, t_x = mask.shape + cum_duration = torch.cumsum(duration, -1) + + cum_duration_flat = cum_duration.view(b * t_x) + path = sequence_mask(cum_duration_flat, t_y).to(mask.dtype) + path = path.view(b, t_x, t_y) + path = path - F.pad(path, convert_pad_shape([[0, 0], [1, 0], [0, 0]]))[:, :-1] + path = path.unsqueeze(1).transpose(2, 3) * mask + return path + + +def clip_grad_value_(parameters, clip_value, norm_type=2): + if isinstance(parameters, torch.Tensor): + parameters = [parameters] + parameters = list(filter(lambda p: p.grad is not None, parameters)) + norm_type = float(norm_type) + if clip_value is not None: + clip_value = float(clip_value) + + total_norm = 0 + for p in parameters: + param_norm = p.grad.data.norm(norm_type) + total_norm += param_norm.item() ** norm_type + if clip_value is not None: + p.grad.data.clamp_(min=-clip_value, max=clip_value) + total_norm = total_norm ** (1.0 / norm_type) + return total_norm diff --git a/server/voice_changer/RVC/inferencer/voras_beta/config.py b/server/voice_changer/RVC/inferencer/voras_beta/config.py new file mode 100644 index 00000000..ddfb4271 --- /dev/null +++ b/server/voice_changer/RVC/inferencer/voras_beta/config.py @@ -0,0 +1,61 @@ +from typing import * + +from pydantic import BaseModel + + +class TrainConfigTrain(BaseModel): + log_interval: int + seed: int + epochs: int + learning_rate: float + betas: List[float] + eps: float + batch_size: int + fp16_run: bool + lr_decay: float + segment_size: int + init_lr_ratio: int + warmup_epochs: int + c_mel: int + c_kl: float + + +class TrainConfigData(BaseModel): + 
max_wav_value: float + sampling_rate: int + filter_length: int + hop_length: int + win_length: int + n_mel_channels: int + mel_fmin: float + mel_fmax: Any + + +class TrainConfigModel(BaseModel): + emb_channels: int + inter_channels: int + n_layers: int + upsample_rates: List[int] + use_spectral_norm: bool + gin_channels: int + spk_embed_dim: int + + +class TrainConfig(BaseModel): + version: Literal["voras"] = "voras" + train: TrainConfigTrain + data: TrainConfigData + model: TrainConfigModel + + +class DatasetMetaItem(BaseModel): + gt_wav: str + co256: str + f0: Optional[str] + f0nsf: Optional[str] + speaker_id: int + + +class DatasetMetadata(BaseModel): + files: Dict[str, DatasetMetaItem] + # mute: DatasetMetaItem diff --git a/server/voice_changer/RVC/inferencer/voras_beta/models.py b/server/voice_changer/RVC/inferencer/voras_beta/models.py new file mode 100644 index 00000000..3168e590 --- /dev/null +++ b/server/voice_changer/RVC/inferencer/voras_beta/models.py @@ -0,0 +1,238 @@ +import math +import os +import sys + +import numpy as np +import torch +from torch import nn +from torch.nn import Conv2d +from torch.nn import functional as F +from torch.nn.utils import remove_weight_norm, spectral_norm, weight_norm + +from . import commons, modules +from .commons import get_padding +from .modules import (ConvNext2d, HarmonicEmbedder, IMDCTSymExpHead, + LoRALinear1d, SnakeFilter, WaveBlock) + +parent_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) +sys.path.append(parent_dir) + +sr2sr = { + "24k": 24000, + "32k": 32000, + "40k": 40000, + "48k": 48000, +} + +class GeneratorVoras(torch.nn.Module): + def __init__( + self, + emb_channels, + inter_channels, + gin_channels, + n_layers, + sr, + hop_length, + ): + super(GeneratorVoras, self).__init__() + self.n_layers = n_layers + self.emb_pitch = HarmonicEmbedder(768, inter_channels, gin_channels, 16, 15) # # pitch 256 + self.plinear = LoRALinear1d(inter_channels, inter_channels, gin_channels, r=8) + self.glinear = weight_norm(nn.Conv1d(gin_channels, inter_channels, 1)) + self.resblocks = nn.ModuleList() + self.init_linear = LoRALinear1d(emb_channels, inter_channels, gin_channels, r=4) + for _ in range(self.n_layers): + self.resblocks.append(WaveBlock(inter_channels, gin_channels, [9] * 2, [1] * 2, [1, 9], 2, r=4)) + self.head = IMDCTSymExpHead(inter_channels, gin_channels, hop_length, padding="center", sample_rate=sr) + self.post = SnakeFilter(4, 8, 9, 2, eps=1e-5) + + def forward(self, x, pitchf, x_mask, g): + x = self.init_linear(x, g) + self.plinear(self.emb_pitch(pitchf, g), g) + self.glinear(g) + for i in range(self.n_layers): + x = self.resblocks[i](x, x_mask, g) + x = x * x_mask + x = self.head(x, g) + x = self.post(x) + return torch.tanh(x) + + def remove_weight_norm(self): + self.plinear.remove_weight_norm() + remove_weight_norm(self.glinear) + for l in self.resblocks: + l.remove_weight_norm() + self.init_linear.remove_weight_norm() + self.head.remove_weight_norm() + self.post.remove_weight_norm() + + def fix_speaker(self, g): + self.plinear.fix_speaker(g) + self.init_linear.fix_speaker(g) + for l in self.resblocks: + l.fix_speaker(g) + self.head.fix_speaker(g) + + def unfix_speaker(self, g): + self.plinear.unfix_speaker(g) + self.init_linear.unfix_speaker(g) + for l in self.resblocks: + l.unfix_speaker(g) + self.head.unfix_speaker(g) + + +class Synthesizer(nn.Module): + def __init__( + self, + segment_size, + n_fft, + hop_length, + inter_channels, + n_layers, + spk_embed_dim, + gin_channels, + emb_channels, + sr, + 
**kwargs + ): + super().__init__() + if type(sr) == type("strr"): + sr = sr2sr[sr] + self.segment_size = segment_size + self.n_fft = n_fft + self.hop_length = hop_length + self.inter_channels = inter_channels + self.n_layers = n_layers + self.spk_embed_dim = spk_embed_dim + self.gin_channels = gin_channels + self.emb_channels = emb_channels + self.sr = sr + + self.dec = GeneratorVoras( + emb_channels, + inter_channels, + gin_channels, + n_layers, + sr, + hop_length + ) + + self.emb_g = nn.Embedding(self.spk_embed_dim, gin_channels) + print( + "gin_channels:", + gin_channels, + "self.spk_embed_dim:", + self.spk_embed_dim, + "emb_channels:", + emb_channels, + ) + self.speaker = None + + def remove_weight_norm(self): + self.dec.remove_weight_norm() + + def change_speaker(self, sid: int): + if self.speaker is not None: + g = self.emb_g(torch.from_numpy(np.array(self.speaker))).unsqueeze(-1) + self.dec.unfix_speaker(g) + g = self.emb_g(torch.from_numpy(np.array(sid))).unsqueeze(-1) + self.dec.fix_speaker(g) + self.speaker = sid + + def forward( + self, phone, phone_lengths, pitch, pitchf, ds + ): + g = self.emb_g(ds).unsqueeze(-1) + x = torch.transpose(phone, 1, -1) + x_mask = torch.unsqueeze(commons.sequence_mask(phone_lengths, x.size(2)), 1).to(phone.dtype) + x_slice, ids_slice = commons.rand_slice_segments( + x, phone_lengths, self.segment_size + ) + pitchf_slice = commons.slice_segments2(pitchf, ids_slice, self.segment_size) + mask_slice = commons.slice_segments(x_mask, ids_slice, self.segment_size) + o = self.dec(x_slice, pitchf_slice, mask_slice, g) + return o, ids_slice, x_mask, g + + def infer(self, phone, phone_lengths, pitch, nsff0, sid, max_len=None): + g = self.emb_g(sid).unsqueeze(-1) + x = torch.transpose(phone, 1, -1) + x_mask = torch.unsqueeze(commons.sequence_mask(phone_lengths, x.size(2)), 1).to(phone.dtype) + o = self.dec((x * x_mask)[:, :, :max_len], nsff0, x_mask, g) + return o, x_mask, (None, None, None, None) + + +class DiscriminatorP(torch.nn.Module): + def __init__(self, period, gin_channels, upsample_rates, final_dim=256, use_spectral_norm=False): + super(DiscriminatorP, self).__init__() + self.period = period + self.use_spectral_norm = use_spectral_norm + self.init_kernel_size = upsample_rates[-1] * 3 + norm_f = weight_norm if use_spectral_norm == False else spectral_norm + N = len(upsample_rates) + self.init_conv = norm_f(Conv2d(1, final_dim // (2 ** (N - 1)), (self.init_kernel_size, 1), (upsample_rates[-1], 1))) + self.convs = nn.ModuleList() + for i, u in enumerate(upsample_rates[::-1][1:], start=1): + self.convs.append( + ConvNext2d( + final_dim // (2 ** (N - i)), + final_dim // (2 ** (N - i - 1)), + gin_channels, + (u*3, 1), + (u, 1), + 4, + r=2 + i//2 + ) + ) + self.conv_post = weight_norm(Conv2d(final_dim, 1, (3, 1), (1, 1))) + + def forward(self, x, g): + fmap = [] + + # 1d to 2d + b, c, t = x.shape + if t % self.period != 0: # pad first + n_pad = self.period - (t % self.period) + x = F.pad(x, (n_pad, 0), "reflect") + t = t + n_pad + x = x.view(b, c, t // self.period, self.period) + + x = torch.flip(x, dims=[2]) + x = F.pad(x, [0, 0, 0, self.init_kernel_size - 1], mode="constant") + x = self.init_conv(x) + x = F.leaky_relu(x, modules.LRELU_SLOPE) + x = torch.flip(x, dims=[2]) + fmap.append(x) + + for i, l in enumerate(self.convs): + x = l(x, g) + fmap.append(x) + + x = F.pad(x, [0, 0, 2, 0], mode="constant") + x = self.conv_post(x) + x = torch.flatten(x, 1, -1) + + return x, fmap + + +class MultiPeriodDiscriminator(torch.nn.Module): + def __init__(self, 
upsample_rates, gin_channels, periods=[2, 3, 5, 7, 11, 17], **kwargs): + super(MultiPeriodDiscriminator, self).__init__() + + discs = [ + DiscriminatorP(i, gin_channels, upsample_rates, use_spectral_norm=False) for i in periods + ] + self.ups = np.prod(upsample_rates) + self.discriminators = nn.ModuleList(discs) + + def forward(self, y, y_hat, g): + fmap_rs = [] + fmap_gs = [] + y_d_rs = [] + y_d_gs = [] + for d in self.discriminators: + y_d_r, fmap_r = d(y, g) + y_d_g, fmap_g = d(y_hat, g) + y_d_rs.append(y_d_r) + y_d_gs.append(y_d_g) + fmap_rs.append(fmap_r) + fmap_gs.append(fmap_g) + + return y_d_rs, y_d_gs, fmap_rs, fmap_gs diff --git a/server/voice_changer/RVC/inferencer/voras_beta/modules.py b/server/voice_changer/RVC/inferencer/voras_beta/modules.py new file mode 100644 index 00000000..f6659c68 --- /dev/null +++ b/server/voice_changer/RVC/inferencer/voras_beta/modules.py @@ -0,0 +1,496 @@ +import math + +import numpy as np +import scipy +import torch +from torch import nn +from torch.nn import Conv1d, Conv2d +from torch.nn import functional as F +from torch.nn.utils import remove_weight_norm, spectral_norm, weight_norm +from torchaudio.functional.functional import _hz_to_mel, _mel_to_hz + +from . import commons, modules +from .commons import get_padding, init_weights +from .transforms import piecewise_rational_quadratic_transform + +LRELU_SLOPE = 0.1 + +class HarmonicEmbedder(nn.Module): + def __init__(self, num_embeddings, embedding_dim, gin_channels, num_head, num_harmonic=0, f0_min=50., f0_max=1100., device="cuda"): + super(HarmonicEmbedder, self).__init__() + self.embedding_dim = embedding_dim + self.num_head = num_head + self.num_harmonic = num_harmonic + + f0_mel_min = np.log(1 + f0_min / 700) + f0_mel_max = np.log(1 + f0_max * (1 + num_harmonic) / 700) + self.sequence = torch.from_numpy(np.linspace(f0_mel_min, f0_mel_max, num_embeddings-2)) + self.emb_layer = torch.nn.Embedding(num_embeddings, embedding_dim) + self.linear_q = Conv1d(gin_channels, num_head * (1 + num_harmonic), 1) + self.weight = None + + def forward(self, x, g): + b, l = x.size() + non_zero = (x != 0.).to(dtype=torch.long).unsqueeze(1) + mel = torch.log(1 + x / 700).unsqueeze(1) + harmonies = torch.arange(1 + self.num_harmonic, device=x.device, dtype=x.dtype).view(1, 1 + self.num_harmonic, 1) + 1. 
+        ix = torch.searchsorted(self.sequence.to(x.device), mel * harmonies).to(x.device) + 1
+        ix = ix * non_zero
+        emb = self.emb_layer(ix).transpose(1, 3).reshape(b, self.num_head, self.embedding_dim // self.num_head, 1 + self.num_harmonic, l)
+        if self.weight is None:
+            weight = torch.nn.functional.softmax(self.linear_q(g).reshape(b, self.num_head, 1, 1 + self.num_harmonic, 1), 3)
+        else:
+            weight = self.weight
+        res = torch.sum(emb * weight, dim=3).reshape(b, self.embedding_dim, l)
+        return res
+
+    def fix_speaker(self, g):
+        self.weight = torch.nn.functional.softmax(self.linear_q(g).reshape(1, self.num_head, 1, 1 + self.num_harmonic, 1), 3)
+
+    def unfix_speaker(self, g):
+        self.weight = None
+
+class LayerNorm(nn.Module):
+    def __init__(self, channels, eps=1e-5):
+        super().__init__()
+        self.channels = channels
+        self.eps = eps
+
+        self.gamma = nn.Parameter(torch.ones(channels))
+        self.beta = nn.Parameter(torch.zeros(channels))
+
+    def forward(self, x):
+        x = x.transpose(1, -1)
+        x = F.layer_norm(x, (self.channels,), self.gamma, self.beta, self.eps)
+        return x.transpose(1, -1)
+
+
+class DilatedCausalConv1d(nn.Module):
+    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, groups=1, dilation=1, bias=True):
+        super(DilatedCausalConv1d, self).__init__()
+        self.kernel_size = kernel_size
+        self.dilation = dilation
+        self.stride = stride
+        self.conv = weight_norm(nn.Conv1d(in_channels, out_channels, kernel_size, stride=stride, groups=groups, dilation=dilation, bias=bias))
+
+    def forward(self, x):
+        x = torch.flip(x, [2])
+        x = F.pad(x, [0, (self.kernel_size - 1) * self.dilation], mode="constant", value=0.)
+        size = x.shape[2] // self.stride
+        x = self.conv(x)[:, :, :size]
+        x = torch.flip(x, [2])
+        return x
+
+    def remove_weight_norm(self):
+        remove_weight_norm(self.conv)
+
+
+class CausalConvTranspose1d(nn.Module):
+    """
+    With padding = 0 and dilation = 1:
+
+    Lout = (Lin - 1) * stride + kernel_rate * stride + output_padding
+    Lout = Lin * stride + (kernel_rate - 1) * stride + output_padding
+    so output_padding is not needed.
+    """
+    def __init__(self, in_channels, out_channels, kernel_rate=3, stride=1, groups=1):
+        super(CausalConvTranspose1d, self).__init__()
+        kernel_size = kernel_rate * stride
+        self.trim_size = (kernel_rate - 1) * stride
+        self.conv = weight_norm(nn.ConvTranspose1d(in_channels, out_channels, kernel_size, stride=stride, groups=groups))
+
+    def forward(self, x):
+        x = self.conv(x)
+        return x[:, :, :-self.trim_size]
+
+    def remove_weight_norm(self):
+        remove_weight_norm(self.conv)
+
+
+class LoRALinear1d(nn.Module):
+    def __init__(self, in_channels, out_channels, info_channels, r):
+        super().__init__()
+        self.in_channels = in_channels
+        self.out_channels = out_channels
+        self.info_channels = info_channels
+        self.r = r
+        self.main_fc = weight_norm(nn.Conv1d(in_channels, out_channels, 1))
+        self.adapter_in = nn.Conv1d(info_channels, in_channels * r, 1)
+        self.adapter_out = nn.Conv1d(info_channels, out_channels * r, 1)
+        nn.init.normal_(self.adapter_in.weight.data, 0, 0.01)
+        nn.init.constant_(self.adapter_out.weight.data, 1e-6)
+        self.adapter_in = weight_norm(self.adapter_in)
+        self.adapter_out = weight_norm(self.adapter_out)
+        self.speaker_fixed = False
+
+    def forward(self, x, g):
+        x_ = self.main_fc(x)
+        if not self.speaker_fixed:
+            a_in = self.adapter_in(g).view(-1, self.in_channels, self.r)
+            a_out = self.adapter_out(g).view(-1, self.r, self.out_channels)
+            l = torch.einsum("brl,brc->bcl", torch.einsum("bcl,bcr->brl", x, a_in), a_out)
+            x_ = x_ + l
+        return x_
+
+    def 
remove_weight_norm(self): + remove_weight_norm(self.main_fc) + remove_weight_norm(self.adapter_in) + remove_weight_norm(self.adapter_out) + + def fix_speaker(self, g): + self.speaker_fixed = True + a_in = self.adapter_in(g).view(-1, self.in_channels, self.r) + a_out = self.adapter_out(g).view(-1, self.r, self.out_channels) + weight = torch.einsum("bir,bro->oi", a_in, a_out).unsqueeze(2) + self.main_fc.weight.data.add_(weight) + + def unfix_speaker(self, g): + self.speaker_fixed = False + a_in = self.adapter_in(g).view(-1, self.in_channels, self.r) + a_out = self.adapter_out(g).view(-1, self.r, self.out_channels) + weight = torch.einsum("bir,bro->oi", a_in, a_out).unsqueeze(2) + self.main_fc.weight.data.sub_(weight) + + +class LoRALinear2d(nn.Module): + def __init__(self, in_channels, out_channels, info_channels, r): + super().__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.info_channels = info_channels + self.r = r + self.main_fc = weight_norm(nn.Conv2d(in_channels, out_channels, (1, 1), (1, 1))) + self.adapter_in = nn.Conv1d(info_channels, in_channels * r, 1) + self.adapter_out = nn.Conv1d(info_channels, out_channels * r, 1) + nn.init.normal_(self.adapter_in.weight.data, 0, 0.01) + nn.init.constant_(self.adapter_out.weight.data, 1e-6) + self.adapter_in = weight_norm(self.adapter_in) + self.adapter_out = weight_norm(self.adapter_out) + self.speaker_fixed = False + + def forward(self, x, g): + x_ = self.main_fc(x) + if not self.speaker_fixed: + a_in = self.adapter_in(g).view(-1, self.in_channels, self.r) + a_out = self.adapter_out(g).view(-1, self.r, self.out_channels) + l = torch.einsum("brhw,brc->bchw", torch.einsum("bchw,bcr->brhw", x, a_in), a_out) + x_ = x_ + l + return x_ + + def remove_weight_norm(self): + remove_weight_norm(self.main_fc) + remove_weight_norm(self.adapter_in) + remove_weight_norm(self.adapter_out) + + def fix_speaker(self, g): + a_in = self.adapter_in(g).view(-1, self.in_channels, self.r) + a_out = self.adapter_out(g).view(-1, self.r, self.out_channels) + weight = torch.einsum("bir,bro->oi", a_in, a_out).unsqueeze(2).unsqueeze(3) + self.main_fc.weight.data.add_(weight) + + def unfix_speaker(self, g): + a_in = self.adapter_in(g).view(-1, self.in_channels, self.r) + a_out = self.adapter_out(g).view(-1, self.r, self.out_channels) + weight = torch.einsum("bir,bro->oi", a_in, a_out).unsqueeze(2).unsqueeze(3) + self.main_fc.weight.data.sub_(weight) + + +class MBConv2d(torch.nn.Module): + """ + Causal MBConv2D + """ + def __init__(self, in_channels, out_channels, gin_channels, kernel_size, stride, extend_ratio, r, use_spectral_norm=False): + super(MBConv2d, self).__init__() + norm_f = weight_norm if use_spectral_norm == False else spectral_norm + inner_channels = int(in_channels * extend_ratio) + self.kernel_size = kernel_size + self.pwconv1 = LoRALinear2d(in_channels, inner_channels, gin_channels, r=r) + self.dwconv = norm_f(Conv2d(inner_channels, inner_channels, kernel_size, stride, groups=inner_channels)) + self.pwconv2 = LoRALinear2d(inner_channels, out_channels, gin_channels, r=r) + self.pwnorm = LayerNorm(in_channels) + self.dwnorm = LayerNorm(inner_channels) + + def forward(self, x, g): + x = self.pwnorm(x) + x = self.pwconv1(x, g) + x = F.pad(x, [0, 0, self.kernel_size[0] - 1, 0], mode="constant") + x = self.dwnorm(x) + x = self.dwconv(x) + x = F.leaky_relu(x, modules.LRELU_SLOPE) + x = self.pwconv2(x, g) + x = F.leaky_relu(x, modules.LRELU_SLOPE) + return x + +class ConvNext2d(torch.nn.Module): + """ + Causal ConvNext Block + 
stride = 1 only + """ + def __init__(self, in_channels, out_channels, gin_channels, kernel_size, stride, extend_ratio, r, use_spectral_norm=False): + super(ConvNext2d, self).__init__() + norm_f = weight_norm if use_spectral_norm == False else spectral_norm + inner_channels = int(in_channels * extend_ratio) + self.kernel_size = kernel_size + self.dwconv = norm_f(Conv2d(in_channels, in_channels, kernel_size, stride, groups=in_channels)) + self.pwconv1 = LoRALinear2d(in_channels, inner_channels, gin_channels, r=r) + self.pwconv2 = LoRALinear2d(inner_channels, out_channels, gin_channels, r=r) + self.act = nn.GELU() + self.norm = LayerNorm(in_channels) + + def forward(self, x, g): + x = F.pad(x, [0, 0, self.kernel_size[0] - 1, 0], mode="constant") + x = self.dwconv(x) + x = self.norm(x) + x = self.pwconv1(x, g) + x = self.act(x) + x = self.pwconv2(x, g) + x = self.act(x) + return x + + def remove_weight_norm(self): + remove_weight_norm(self.dwconv) + + +class WaveBlock(torch.nn.Module): + def __init__(self, inner_channels, gin_channels, kernel_sizes, strides, dilations, extend_rate, r): + super(WaveBlock, self).__init__() + norm_f = weight_norm + extend_channels = int(inner_channels * extend_rate) + self.dconvs = nn.ModuleList() + self.p1convs = nn.ModuleList() + self.p2convs = nn.ModuleList() + self.norms = nn.ModuleList() + self.act = nn.GELU() + + # self.ses = nn.ModuleList() + # self.norms = [] + for i, (k, s, d) in enumerate(zip(kernel_sizes, strides, dilations)): + self.dconvs.append(DilatedCausalConv1d(inner_channels, inner_channels, k, stride=s, dilation=d, groups=inner_channels)) + self.p1convs.append(LoRALinear1d(inner_channels, extend_channels, gin_channels, r)) + self.p2convs.append(LoRALinear1d(extend_channels, inner_channels, gin_channels, r)) + self.norms.append(LayerNorm(inner_channels)) + + def forward(self, x, x_mask, g): + x *= x_mask + for i in range(len(self.dconvs)): + residual = x.clone() + x = self.dconvs[i](x) + x = self.norms[i](x) + x *= x_mask + x = self.p1convs[i](x, g) + x = self.act(x) + x = self.p2convs[i](x, g) + x = residual + x + return x + + def remove_weight_norm(self): + for c in self.dconvs: + c.remove_weight_norm() + for c in self.p1convs: + c.remove_weight_norm() + for c in self.p2convs: + c.remove_weight_norm() + + def fix_speaker(self, g): + for c in self.p1convs: + c.fix_speaker(g) + for c in self.p2convs: + c.fix_speaker(g) + + def unfix_speaker(self, g): + for c in self.p1convs: + c.unfix_speaker(g) + for c in self.p2convs: + c.unfix_speaker(g) + + +class SnakeFilter(torch.nn.Module): + """ + Adaptive filter using snakebeta + """ + def __init__(self, channels, groups, kernel_size, num_layers, eps=1e-6): + super(SnakeFilter, self).__init__() + self.eps = eps + self.num_layers = num_layers + inner_channels = channels * groups + self.init_conv = DilatedCausalConv1d(1, inner_channels, kernel_size) + self.dconvs = torch.nn.ModuleList() + self.pconvs = torch.nn.ModuleList() + self.post_conv = DilatedCausalConv1d(inner_channels+1, 1, kernel_size, bias=False) + + for i in range(self.num_layers): + self.dconvs.append(DilatedCausalConv1d(inner_channels, inner_channels, kernel_size, stride=1, groups=inner_channels, dilation=kernel_size ** (i + 1))) + self.pconvs.append(weight_norm(Conv1d(inner_channels, inner_channels, 1, groups=groups))) + self.snake_alpha = torch.nn.Parameter(torch.zeros(inner_channels), requires_grad=True) + self.snake_beta = torch.nn.Parameter(torch.zeros(inner_channels), requires_grad=True) + + def forward(self, x): + y = x.clone() + x = 
self.init_conv(x) + for i in range(self.num_layers): + # snake activation + x = self.dconvs[i](x) + x = self.pconvs[i](x) + x = x + (1.0 / torch.clip(self.snake_beta.unsqueeze(0).unsqueeze(-1), min=self.eps)) * torch.pow(torch.sin(x * self.snake_alpha.unsqueeze(0).unsqueeze(-1)), 2) + x = torch.cat([x, y], 1) + x = self.post_conv(x) + return x + + def remove_weight_norm(self): + self.init_conv.remove_weight_norm() + for c in self.dconvs: + c.remove_weight_norm() + for c in self.pconvs: + remove_weight_norm(c) + self.post_conv.remove_weight_norm() + +""" +https://github.com/charactr-platform/vocos/blob/main/vocos/heads.py +""" +class FourierHead(nn.Module): + """Base class for inverse fourier modules.""" + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Args: + x (Tensor): Input tensor of shape (B, L, H), where B is the batch size, + L is the sequence length, and H denotes the model dimension. + + Returns: + Tensor: Reconstructed time-domain audio signal of shape (B, T), where T is the length of the output signal. + """ + raise NotImplementedError("Subclasses must implement the forward method.") + + +class IMDCT(nn.Module): + """ + Inverse Modified Discrete Cosine Transform (IMDCT) module. + + Args: + frame_len (int): Length of the MDCT frame. + padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same". + """ + + def __init__(self, frame_len: int, padding: str = "same"): + super().__init__() + if padding not in ["center", "same"]: + raise ValueError("Padding must be 'center' or 'same'.") + self.padding = padding + self.frame_len = frame_len * 2 + N = frame_len + n0 = (N + 1) / 2 + window = torch.from_numpy(scipy.signal.cosine(N * 2)).float() + self.register_buffer("window", window) + + pre_twiddle = torch.exp(1j * torch.pi * n0 * torch.arange(N * 2) / N) + post_twiddle = torch.exp(1j * torch.pi * (torch.arange(N * 2) + n0) / (N * 2)) + self.register_buffer("pre_twiddle", torch.view_as_real(pre_twiddle)) + self.register_buffer("post_twiddle", torch.view_as_real(post_twiddle)) + + def forward(self, X: torch.Tensor) -> torch.Tensor: + """ + Apply the Inverse Modified Discrete Cosine Transform (IMDCT) to the input MDCT coefficients. + + Args: + X (Tensor): Input MDCT coefficients of shape (B, N, L), where B is the batch size, + L is the number of frames, and N is the number of frequency bins. + + Returns: + Tensor: Reconstructed audio waveform of shape (B, T), where T is the length of the audio. + """ + X = X.transpose(1, 2) + B, L, N = X.shape + Y = torch.zeros((B, L, N * 2), dtype=X.dtype, device=X.device) + Y[..., :N] = X + Y[..., N:] = -1 * torch.conj(torch.flip(X, dims=(-1,))) + y = torch.fft.ifft(Y * torch.view_as_complex(self.pre_twiddle).expand(Y.shape), dim=-1) + y = torch.real(y * torch.view_as_complex(self.post_twiddle).expand(y.shape)) * np.sqrt(N) * np.sqrt(2) + result = y * self.window.expand(y.shape) + output_size = (1, (L + 1) * N) + audio = torch.nn.functional.fold( + result.transpose(1, 2), + output_size=output_size, + kernel_size=(1, self.frame_len), + stride=(1, self.frame_len // 2), + )[:, 0, 0, :] + + if self.padding == "center": + pad = self.frame_len // 2 + elif self.padding == "same": + pad = self.frame_len // 4 + else: + raise ValueError("Padding must be 'center' or 'same'.") + + audio = audio[:, pad:-pad] + return audio.unsqueeze(1) + + +class IMDCTSymExpHead(FourierHead): + """ + IMDCT Head module for predicting MDCT coefficients with symmetric exponential function + + Args: + dim (int): Hidden dimension of the model. 
+ mdct_frame_len (int): Length of the MDCT frame. + padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same". + sample_rate (int, optional): The sample rate of the audio. If provided, the last layer will be initialized + based on perceptual scaling. Defaults to None. + clip_audio (bool, optional): Whether to clip the audio output within the range of [-1.0, 1.0]. Defaults to False. + """ + + def __init__( + self, dim: int, gin_channels: int, mdct_frame_len: int, padding: str = "same", sample_rate: int = 24000, + ): + super().__init__() + out_dim = mdct_frame_len + self.dconv = DilatedCausalConv1d(dim, dim, 5, 1, dim, 1) + self.pconv1 = LoRALinear1d(dim, dim * 2, gin_channels, 2) + self.pconv2 = LoRALinear1d(dim * 2, out_dim, gin_channels, 2) + self.act = torch.nn.GELU() + self.imdct = IMDCT(frame_len=mdct_frame_len, padding=padding) + + if sample_rate is not None: + # optionally init the last layer following mel-scale + m_max = _hz_to_mel(sample_rate // 2) + m_pts = torch.linspace(0, m_max, out_dim) + f_pts = _mel_to_hz(m_pts) + scale = 1 - (f_pts / f_pts.max()) + + with torch.no_grad(): + self.pconv2.main_fc.weight.mul_(scale.view(-1, 1, 1)) + + def forward(self, x: torch.Tensor, g: torch.Tensor) -> torch.Tensor: + """ + Forward pass of the IMDCTSymExpHead module. + + Args: + x (Tensor): Input tensor of shape (B, L, H), where B is the batch size, + L is the sequence length, and H denotes the model dimension. + + Returns: + Tensor: Reconstructed time-domain audio signal of shape (B, T), where T is the length of the output signal. + """ + x = self.dconv(x) + x = self.pconv1(x, g) + x = self.act(x) + x = self.pconv2(x, g) + x = symexp(x) + x = torch.clip(x, min=-1e2, max=1e2) # safeguard to prevent excessively large magnitudes + audio = self.imdct(x) + return audio + + def remove_weight_norm(self): + self.dconv.remove_weight_norm() + self.pconv1.remove_weight_norm() + self.pconv2.remove_weight_norm() + + def fix_speaker(self, g): + self.pconv1.fix_speaker(g) + self.pconv2.fix_speaker(g) + + def unfix_speaker(self, g): + self.pconv1.unfix_speaker(g) + self.pconv2.unfix_speaker(g) + +def symexp(x: torch.Tensor) -> torch.Tensor: + return torch.sign(x) * (torch.exp(x.abs()) - 1) \ No newline at end of file diff --git a/server/voice_changer/RVC/inferencer/voras_beta/transforms.py b/server/voice_changer/RVC/inferencer/voras_beta/transforms.py new file mode 100644 index 00000000..6f30b717 --- /dev/null +++ b/server/voice_changer/RVC/inferencer/voras_beta/transforms.py @@ -0,0 +1,207 @@ +import numpy as np +import torch +from torch.nn import functional as F + +DEFAULT_MIN_BIN_WIDTH = 1e-3 +DEFAULT_MIN_BIN_HEIGHT = 1e-3 +DEFAULT_MIN_DERIVATIVE = 1e-3 + + +def piecewise_rational_quadratic_transform( + inputs, + unnormalized_widths, + unnormalized_heights, + unnormalized_derivatives, + inverse=False, + tails=None, + tail_bound=1.0, + min_bin_width=DEFAULT_MIN_BIN_WIDTH, + min_bin_height=DEFAULT_MIN_BIN_HEIGHT, + min_derivative=DEFAULT_MIN_DERIVATIVE, +): + if tails is None: + spline_fn = rational_quadratic_spline + spline_kwargs = {} + else: + spline_fn = unconstrained_rational_quadratic_spline + spline_kwargs = {"tails": tails, "tail_bound": tail_bound} + + outputs, logabsdet = spline_fn( + inputs=inputs, + unnormalized_widths=unnormalized_widths, + unnormalized_heights=unnormalized_heights, + unnormalized_derivatives=unnormalized_derivatives, + inverse=inverse, + min_bin_width=min_bin_width, + min_bin_height=min_bin_height, + min_derivative=min_derivative, + 
**spline_kwargs + ) + return outputs, logabsdet + + +def searchsorted(bin_locations, inputs, eps=1e-6): + bin_locations[..., -1] += eps + return torch.sum(inputs[..., None] >= bin_locations, dim=-1) - 1 + + +def unconstrained_rational_quadratic_spline( + inputs, + unnormalized_widths, + unnormalized_heights, + unnormalized_derivatives, + inverse=False, + tails="linear", + tail_bound=1.0, + min_bin_width=DEFAULT_MIN_BIN_WIDTH, + min_bin_height=DEFAULT_MIN_BIN_HEIGHT, + min_derivative=DEFAULT_MIN_DERIVATIVE, +): + inside_interval_mask = (inputs >= -tail_bound) & (inputs <= tail_bound) + outside_interval_mask = ~inside_interval_mask + + outputs = torch.zeros_like(inputs) + logabsdet = torch.zeros_like(inputs) + + if tails == "linear": + unnormalized_derivatives = F.pad(unnormalized_derivatives, pad=(1, 1)) + constant = np.log(np.exp(1 - min_derivative) - 1) + unnormalized_derivatives[..., 0] = constant + unnormalized_derivatives[..., -1] = constant + + outputs[outside_interval_mask] = inputs[outside_interval_mask] + logabsdet[outside_interval_mask] = 0 + else: + raise RuntimeError("{} tails are not implemented.".format(tails)) + + ( + outputs[inside_interval_mask], + logabsdet[inside_interval_mask], + ) = rational_quadratic_spline( + inputs=inputs[inside_interval_mask], + unnormalized_widths=unnormalized_widths[inside_interval_mask, :], + unnormalized_heights=unnormalized_heights[inside_interval_mask, :], + unnormalized_derivatives=unnormalized_derivatives[inside_interval_mask, :], + inverse=inverse, + left=-tail_bound, + right=tail_bound, + bottom=-tail_bound, + top=tail_bound, + min_bin_width=min_bin_width, + min_bin_height=min_bin_height, + min_derivative=min_derivative, + ) + + return outputs, logabsdet + + +def rational_quadratic_spline( + inputs, + unnormalized_widths, + unnormalized_heights, + unnormalized_derivatives, + inverse=False, + left=0.0, + right=1.0, + bottom=0.0, + top=1.0, + min_bin_width=DEFAULT_MIN_BIN_WIDTH, + min_bin_height=DEFAULT_MIN_BIN_HEIGHT, + min_derivative=DEFAULT_MIN_DERIVATIVE, +): + if torch.min(inputs) < left or torch.max(inputs) > right: + raise ValueError("Input to a transform is not within its domain") + + num_bins = unnormalized_widths.shape[-1] + + if min_bin_width * num_bins > 1.0: + raise ValueError("Minimal bin width too large for the number of bins") + if min_bin_height * num_bins > 1.0: + raise ValueError("Minimal bin height too large for the number of bins") + + widths = F.softmax(unnormalized_widths, dim=-1) + widths = min_bin_width + (1 - min_bin_width * num_bins) * widths + cumwidths = torch.cumsum(widths, dim=-1) + cumwidths = F.pad(cumwidths, pad=(1, 0), mode="constant", value=0.0) + cumwidths = (right - left) * cumwidths + left + cumwidths[..., 0] = left + cumwidths[..., -1] = right + widths = cumwidths[..., 1:] - cumwidths[..., :-1] + + derivatives = min_derivative + F.softplus(unnormalized_derivatives) + + heights = F.softmax(unnormalized_heights, dim=-1) + heights = min_bin_height + (1 - min_bin_height * num_bins) * heights + cumheights = torch.cumsum(heights, dim=-1) + cumheights = F.pad(cumheights, pad=(1, 0), mode="constant", value=0.0) + cumheights = (top - bottom) * cumheights + bottom + cumheights[..., 0] = bottom + cumheights[..., -1] = top + heights = cumheights[..., 1:] - cumheights[..., :-1] + + if inverse: + bin_idx = searchsorted(cumheights, inputs)[..., None] + else: + bin_idx = searchsorted(cumwidths, inputs)[..., None] + + input_cumwidths = cumwidths.gather(-1, bin_idx)[..., 0] + input_bin_widths = widths.gather(-1, 
bin_idx)[..., 0] + + input_cumheights = cumheights.gather(-1, bin_idx)[..., 0] + delta = heights / widths + input_delta = delta.gather(-1, bin_idx)[..., 0] + + input_derivatives = derivatives.gather(-1, bin_idx)[..., 0] + input_derivatives_plus_one = derivatives[..., 1:].gather(-1, bin_idx)[..., 0] + + input_heights = heights.gather(-1, bin_idx)[..., 0] + + if inverse: + a = (inputs - input_cumheights) * ( + input_derivatives + input_derivatives_plus_one - 2 * input_delta + ) + input_heights * (input_delta - input_derivatives) + b = input_heights * input_derivatives - (inputs - input_cumheights) * ( + input_derivatives + input_derivatives_plus_one - 2 * input_delta + ) + c = -input_delta * (inputs - input_cumheights) + + discriminant = b.pow(2) - 4 * a * c + assert (discriminant >= 0).all() + + root = (2 * c) / (-b - torch.sqrt(discriminant)) + outputs = root * input_bin_widths + input_cumwidths + + theta_one_minus_theta = root * (1 - root) + denominator = input_delta + ( + (input_derivatives + input_derivatives_plus_one - 2 * input_delta) + * theta_one_minus_theta + ) + derivative_numerator = input_delta.pow(2) * ( + input_derivatives_plus_one * root.pow(2) + + 2 * input_delta * theta_one_minus_theta + + input_derivatives * (1 - root).pow(2) + ) + logabsdet = torch.log(derivative_numerator) - 2 * torch.log(denominator) + + return outputs, -logabsdet + else: + theta = (inputs - input_cumwidths) / input_bin_widths + theta_one_minus_theta = theta * (1 - theta) + + numerator = input_heights * ( + input_delta * theta.pow(2) + input_derivatives * theta_one_minus_theta + ) + denominator = input_delta + ( + (input_derivatives + input_derivatives_plus_one - 2 * input_delta) + * theta_one_minus_theta + ) + outputs = input_cumheights + numerator / denominator + + derivative_numerator = input_delta.pow(2) * ( + input_derivatives_plus_one * theta.pow(2) + + 2 * input_delta * theta_one_minus_theta + + input_derivatives * (1 - theta).pow(2) + ) + logabsdet = torch.log(derivative_numerator) - 2 * torch.log(denominator) + + return outputs, logabsdet diff --git a/server/voice_changer/RVC/inferencer/voras_beta/utils.py b/server/voice_changer/RVC/inferencer/voras_beta/utils.py new file mode 100644 index 00000000..a5152da2 --- /dev/null +++ b/server/voice_changer/RVC/inferencer/voras_beta/utils.py @@ -0,0 +1,286 @@ +import glob +import logging +import os +import shutil +import socket +import sys + +import ffmpeg +import matplotlib +import matplotlib.pylab as plt +import numpy as np +import torch +from scipy.io.wavfile import read +from torch.nn import functional as F + +from modules.shared import ROOT_DIR + +from .config import TrainConfig + +matplotlib.use("Agg") +logging.getLogger("matplotlib").setLevel(logging.WARNING) + +logging.basicConfig(stream=sys.stdout, level=logging.DEBUG) +logger = logging + + +class AWP: + """ + Fast AWP + https://www.kaggle.com/code/junkoda/fast-awp + """ + def __init__(self, model, optimizer, *, adv_param='weight', + adv_lr=0.01, adv_eps=0.01): + self.model = model + self.optimizer = optimizer + self.adv_param = adv_param + self.adv_lr = adv_lr + self.adv_eps = adv_eps + self.backup = {} + + def perturb(self): + """ + Perturb model parameters for AWP gradient + Call before loss and loss.backward() + """ + self._save() # save model parameters + self._attack_step() # perturb weights + + def _attack_step(self): + e = 1e-6 + for name, param in self.model.named_parameters(): + if param.requires_grad and param.grad is not None and self.adv_param in name: + grad = 
self.optimizer.state[param]['exp_avg']
+                norm_grad = torch.norm(grad)
+                norm_data = torch.norm(param.detach())
+
+                if norm_grad != 0 and not torch.isnan(norm_grad):
+                    # Set lower and upper limits on the change
+                    limit_eps = self.adv_eps * param.detach().abs()
+                    param_min = param.data - limit_eps
+                    param_max = param.data + limit_eps
+
+                    # Perturb along the gradient
+                    # w += (adv_lr * |w| / |grad|) * grad
+                    param.data.add_(grad, alpha=(self.adv_lr * (norm_data + e) / (norm_grad + e)))
+
+                    # Apply the limit to the change
+                    param.data.clamp_(param_min, param_max)
+
+    def _save(self):
+        for name, param in self.model.named_parameters():
+            if param.requires_grad and param.grad is not None and self.adv_param in name:
+                if name not in self.backup:
+                    self.backup[name] = param.clone().detach()
+                else:
+                    self.backup[name].copy_(param.data)
+
+    def restore(self):
+        """
+        Restore model parameters to their original values; AWP does not perturb the weights permanently, it only perturbs them for the adversarial gradient step.
+        Call after loss.backward(), before optimizer.step()
+        """
+        for name, param in self.model.named_parameters():
+            if name in self.backup:
+                param.data.copy_(self.backup[name])
+
+
+def load_audio(file: str, sr):
+    try:
+        # https://github.com/openai/whisper/blob/main/whisper/audio.py#L26
+        # This launches a subprocess to decode audio while down-mixing and resampling as necessary.
+        # Requires the ffmpeg CLI and `ffmpeg-python` package to be installed.
+        file = (
+            file.strip(" ").strip('"').strip("\n").strip('"').strip(" ")
+        )  # strip stray spaces, quotes, and newlines that come with copy-pasted paths
+        out, _ = (
+            ffmpeg.input(file, threads=0)
+            .output("-", format="f32le", acodec="pcm_f32le", ac=1, ar=sr)
+            .run(cmd=["ffmpeg", "-nostdin"], capture_stdout=True, capture_stderr=True)
+        )
+    except Exception as e:
+        raise RuntimeError(f"Failed to load audio: {e}")
+
+    return np.frombuffer(out, np.float32).flatten()
+
+
+def find_empty_port():
+    s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+    s.bind(("", 0))
+    s.listen(1)
+    port = s.getsockname()[1]
+    s.close()
+    return port
+
+
+def load_checkpoint(checkpoint_path, model, optimizer=None, load_opt=1):
+    assert os.path.isfile(checkpoint_path)
+    checkpoint_dict = torch.load(checkpoint_path, map_location="cpu")
+
+    saved_state_dict = checkpoint_dict["model"]
+    if hasattr(model, "module"):
+        state_dict = model.module.state_dict()
+    else:
+        state_dict = model.state_dict()
+    new_state_dict = {}
+    for k, v in state_dict.items():  # iterate over the shapes the model expects
+        try:
+            new_state_dict[k] = saved_state_dict[k]
+            if saved_state_dict[k].shape != state_dict[k].shape:
+                print(
+                    f"shape-{k}-mismatch|need-{state_dict[k].shape}|get-{saved_state_dict[k].shape}"
+                )
+                if saved_state_dict[k].dim() == 2:  # NOTE: check whether this is ok
+                    # for embedded input 256 <==> 768
+                    # this allows continuing training from the original pretrained checkpoints when using an embedder with 768-dim output, etc.
+                    if saved_state_dict[k].dtype == torch.half:
+                        new_state_dict[k] = (
+                            F.interpolate(
+                                saved_state_dict[k].float().unsqueeze(0).unsqueeze(0),
+                                size=state_dict[k].shape,
+                                mode="bilinear",
+                            )
+                            .half()
+                            .squeeze(0)
+                            .squeeze(0)
+                        )
+                    else:
+                        new_state_dict[k] = (
+                            F.interpolate(
+                                saved_state_dict[k].unsqueeze(0).unsqueeze(0),
+                                size=state_dict[k].shape,
+                                mode="bilinear",
+                            )
+                            .squeeze(0)
+                            .squeeze(0)
+                        )
+                    print(
+                        "interpolated new_state_dict",
+                        k,
+                        "from",
+                        saved_state_dict[k].shape,
+                        "to",
+                        new_state_dict[k].shape,
+                    )
+                else:
+                    raise KeyError
+        except Exception as e:
+            # print(traceback.format_exc())
+            print(f"{k} is not in the checkpoint")
+            print("error: %s" % e)
+            new_state_dict[k] = v  # fall back to the model's own randomly initialized value
+    if hasattr(model, "module"):
+        model.module.load_state_dict(new_state_dict, strict=False)
+    else:
+        model.load_state_dict(new_state_dict, strict=False)
+    print("Loaded model weights")
+
+    epoch = checkpoint_dict["epoch"]
+    learning_rate = checkpoint_dict["learning_rate"]
+    if optimizer is not None and load_opt == 1:
+        optimizer.load_state_dict(checkpoint_dict["optimizer"])
+    print("Loaded checkpoint '{}' (epoch {})".format(checkpoint_path, epoch))
+    return model, optimizer, learning_rate, epoch
+
+
+def save_state(model, optimizer, learning_rate, epoch, checkpoint_path):
+    print(
+        "Saving model and optimizer state at epoch {} to {}".format(
+            epoch, checkpoint_path
+        )
+    )
+    if hasattr(model, "module"):
+        state_dict = model.module.state_dict()
+    else:
+        state_dict = model.state_dict()
+    torch.save(
+        {
+            "model": state_dict,
+            "epoch": epoch,
+            "optimizer": optimizer.state_dict(),
+            "learning_rate": learning_rate,
+        },
+        checkpoint_path,
+    )
+
+
+def summarize(
+    writer,
+    global_step,
+    scalars={},
+    histograms={},
+    images={},
+    audios={},
+    audio_sampling_rate=22050,
+):
+    for k, v in scalars.items():
+        writer.add_scalar(k, v, global_step)
+    for k, v in histograms.items():
+        writer.add_histogram(k, v, global_step)
+    for k, v in images.items():
+        writer.add_image(k, v, global_step, dataformats="HWC")
+    for k, v in audios.items():
+        writer.add_audio(k, v, global_step, audio_sampling_rate)
+
+
+def latest_checkpoint_path(dir_path, regex="G_*.pth"):
+    filelist = glob.glob(os.path.join(dir_path, regex))
+    if len(filelist) == 0:
+        return None
+    filelist.sort(key=lambda f: int("".join(filter(str.isdigit, f))))
+    filepath = filelist[-1]
+    return filepath
+
+
+def plot_spectrogram_to_numpy(spectrogram):
+    fig, ax = plt.subplots(figsize=(10, 2))
+    im = ax.imshow(spectrogram, aspect="auto", origin="lower", interpolation="none")
+    plt.colorbar(im, ax=ax)
+    plt.xlabel("Frames")
+    plt.ylabel("Channels")
+    plt.tight_layout()
+
+    fig.canvas.draw()
+    data = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
+    data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
+    plt.close()
+    return data
+
+
+def plot_alignment_to_numpy(alignment, info=None):
+    fig, ax = plt.subplots(figsize=(6, 4))
+    im = ax.imshow(
+        alignment.transpose(), aspect="auto", origin="lower", interpolation="none"
+    )
+    fig.colorbar(im, ax=ax)
+    xlabel = "Decoder timestep"
+    if info is not None:
+        xlabel += "\n\n" + info
+    plt.xlabel(xlabel)
+    plt.ylabel("Encoder timestep")
+    plt.tight_layout()
+
+    fig.canvas.draw()
+    data = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
+    data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
+    plt.close()
+    return data
+
+
+def load_wav_to_torch(full_path):
+    sampling_rate, data = read(full_path)
+    return 
torch.FloatTensor(data.astype(np.float32)), sampling_rate + + +def load_config(training_dir: str, sample_rate: int, emb_channels: int): + if emb_channels == 256: + config_path = os.path.join(ROOT_DIR, "configs", f"{sample_rate}.json") + else: + config_path = os.path.join( + ROOT_DIR, "configs", f"{sample_rate}-{emb_channels}.json" + ) + config_save_path = os.path.join(training_dir, "config.json") + + shutil.copyfile(config_path, config_save_path) + + return TrainConfig.parse_file(config_save_path) diff --git a/server/voice_changer/RVC/pipeline/Pipeline.py b/server/voice_changer/RVC/pipeline/Pipeline.py index 1d8fbd0c..4e24acb1 100644 --- a/server/voice_changer/RVC/pipeline/Pipeline.py +++ b/server/voice_changer/RVC/pipeline/Pipeline.py @@ -3,6 +3,7 @@ from typing import Any import math import torch import torch.nn.functional as F +from torch.cuda.amp import autocast from Exceptions import ( DeviceCannotSupportHalfPrecisionException, DeviceChangingException, @@ -118,10 +119,6 @@ class Pipeline(object): # tensor型調整 feats = audio_pad - if self.isHalf is True: - feats = feats.half() - else: - feats = feats.float() if feats.dim() == 2: # double channels feats = feats.mean(-1) assert feats.dim() == 1, feats.dim() @@ -129,19 +126,20 @@ class Pipeline(object): # embedding padding_mask = torch.BoolTensor(feats.shape).to(self.device).fill_(False) - try: - feats = self.embedder.extractFeatures(feats, embOutputLayer, useFinalProj) - if torch.isnan(feats).all(): - raise DeviceCannotSupportHalfPrecisionException() - except RuntimeError as e: - if "HALF" in e.__str__().upper(): - raise HalfPrecisionChangingException() - elif "same device" in e.__str__(): - raise DeviceChangingException() - else: - raise e - if protect < 0.5 and search_index: - feats0 = feats.clone() + with autocast(enabled=self.isHalf): + try: + feats = self.embedder.extractFeatures(feats, embOutputLayer, useFinalProj) + if torch.isnan(feats).all(): + raise DeviceCannotSupportHalfPrecisionException() + except RuntimeError as e: + if "HALF" in e.__str__().upper(): + raise HalfPrecisionChangingException() + elif "same device" in e.__str__(): + raise DeviceChangingException() + else: + raise e + if protect < 0.5 and search_index: + feats0 = feats.clone() # Index - feature抽出 # if self.index is not None and self.feature is not None and index_rate != 0: @@ -167,10 +165,8 @@ class Pipeline(object): # recover silient font npy = np.concatenate([np.zeros([npyOffset, npy.shape[1]]).astype("float32"), npy]) - if self.isHalf is True: - npy = npy.astype("float16") - feats = torch.from_numpy(npy).unsqueeze(0).to(self.device) * index_rate + (1 - index_rate) * feats + feats = F.interpolate(feats.permute(0, 2, 1), scale_factor=2).permute(0, 2, 1) if protect < 0.5 and search_index: feats0 = F.interpolate(feats0.permute(0, 2, 1), scale_factor=2).permute(0, 2, 1) @@ -207,14 +203,15 @@ class Pipeline(object): # 推論実行 try: with torch.no_grad(): - audio1 = ( - torch.clip( - self.inferencer.infer(feats, p_len, pitch, pitchf, sid)[0][0, 0].to(dtype=torch.float32), - -1.0, - 1.0, - ) - * 32767.5 - ).data.to(dtype=torch.int16) + with autocast(enabled=self.isHalf): + audio1 = ( + torch.clip( + self.inferencer.infer(feats, p_len, pitch, pitchf, sid)[0][0, 0].to(dtype=torch.float32), + -1.0, + 1.0, + ) + * 32767.5 + ).data.to(dtype=torch.int16) except RuntimeError as e: if "HALF" in e.__str__().upper(): print("11", e)
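For reference, a minimal sketch of the mixed-precision pattern that the Pipeline.py hunk above switches to: instead of manually casting features with .half()/.float(), the embedder and inferencer calls are wrapped in torch.cuda.amp.autocast(enabled=self.isHalf). The DummyEmbedder model, tensor shapes, and helper name below are illustrative assumptions, not code from this repository.

import torch
from torch.cuda.amp import autocast


class DummyEmbedder(torch.nn.Module):
    # stand-in for the feature embedder called inside Pipeline (hypothetical)
    def __init__(self):
        super().__init__()
        self.proj = torch.nn.Linear(160, 256)

    def forward(self, feats: torch.Tensor) -> torch.Tensor:
        return self.proj(feats)


def extract_features(embedder: torch.nn.Module, feats: torch.Tensor, is_half: bool) -> torch.Tensor:
    # autocast(enabled=False) is a no-op, so the same code path serves fp32 and fp16 devices;
    # this replaces the explicit feats.half()/feats.float() casts removed by the patch.
    with torch.no_grad():
        with autocast(enabled=is_half):
            out = embedder(feats)
    # downstream post-processing still expects fp32, mirroring the
    # .to(dtype=torch.float32) cast kept around the inferencer output
    return out.float()


if __name__ == "__main__":
    embedder = DummyEmbedder()
    feats = torch.randn(1, 100, 160)
    print(extract_features(embedder, feats, is_half=False).shape)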