From 0201b2e44cbd1b2db22428930da2255980dd0f2a Mon Sep 17 00:00:00 2001 From: wataru Date: Thu, 22 Jun 2023 10:46:12 +0900 Subject: [PATCH] WIP: integrate vcs to new gui 4 --- server/voice_changer/DDSP_SVC/DDSP_SVC.py | 88 +- .../DDSP_SVC/DDSP_SVCModelSlotGenerator.py | 22 + .../voice_changer/DDSP_SVC/DDSP_SVCSetting.py | 11 +- server/voice_changer/DDSP_SVC/ModelSlot.py | 8 - server/voice_changer/DDSP_SVC/SvcDDSP.py | 45 +- .../DDSP_SVC/models/ddsp/__init__.py | 0 .../DDSP_SVC/models/ddsp/core.py | 281 ++++ .../DDSP_SVC/models/ddsp/loss.py | 57 + .../DDSP_SVC/models/ddsp/pcmer.py | 380 +++++ .../DDSP_SVC/models/ddsp/unit2control.py | 86 ++ .../DDSP_SVC/models/ddsp/vocoder.py | 639 ++++++++ .../DDSP_SVC/models/diffusion/data_loaders.py | 215 +++ .../DDSP_SVC/models/diffusion/diffusion.py | 342 +++++ .../models/diffusion/diffusion_onnx.py | 526 +++++++ .../models/diffusion/dpm_solver_pytorch.py | 1284 +++++++++++++++++ .../DDSP_SVC/models/diffusion/infer_gt_mel.py | 59 + .../DDSP_SVC/models/diffusion/onnx_export.py | 226 +++ .../DDSP_SVC/models/diffusion/solver.py | 194 +++ .../DDSP_SVC/models/diffusion/uni_pc.py | 731 ++++++++++ .../DDSP_SVC/models/diffusion/unit2mel.py | 77 + .../DDSP_SVC/models/diffusion/vocoder.py | 87 ++ .../DDSP_SVC/models/diffusion/wavenet.py | 91 ++ .../DDSP_SVC/models/encoder/hubert/model.py | 293 ++++ .../voice_changer/DDSP_SVC/models/enhancer.py | 102 ++ .../DDSP_SVC/models/nsf_hifigan/env.py | 15 + .../DDSP_SVC/models/nsf_hifigan/models.py | 434 ++++++ .../DDSP_SVC/models/nsf_hifigan/nvSTFT.py | 129 ++ .../DDSP_SVC/models/nsf_hifigan/utils.py | 68 + .../SoVitsSvc40/models/readme.txt | 2 + .../SoVitsSvc40v2/SoVitsSvc40v2.py | 466 ------ server/voice_changer/VoiceChangerManager.py | 11 +- 31 files changed, 6380 insertions(+), 589 deletions(-) create mode 100644 server/voice_changer/DDSP_SVC/DDSP_SVCModelSlotGenerator.py delete mode 100644 server/voice_changer/DDSP_SVC/ModelSlot.py create mode 100644 server/voice_changer/DDSP_SVC/models/ddsp/__init__.py create mode 100644 server/voice_changer/DDSP_SVC/models/ddsp/core.py create mode 100644 server/voice_changer/DDSP_SVC/models/ddsp/loss.py create mode 100644 server/voice_changer/DDSP_SVC/models/ddsp/pcmer.py create mode 100644 server/voice_changer/DDSP_SVC/models/ddsp/unit2control.py create mode 100644 server/voice_changer/DDSP_SVC/models/ddsp/vocoder.py create mode 100644 server/voice_changer/DDSP_SVC/models/diffusion/data_loaders.py create mode 100644 server/voice_changer/DDSP_SVC/models/diffusion/diffusion.py create mode 100644 server/voice_changer/DDSP_SVC/models/diffusion/diffusion_onnx.py create mode 100644 server/voice_changer/DDSP_SVC/models/diffusion/dpm_solver_pytorch.py create mode 100644 server/voice_changer/DDSP_SVC/models/diffusion/infer_gt_mel.py create mode 100644 server/voice_changer/DDSP_SVC/models/diffusion/onnx_export.py create mode 100644 server/voice_changer/DDSP_SVC/models/diffusion/solver.py create mode 100644 server/voice_changer/DDSP_SVC/models/diffusion/uni_pc.py create mode 100644 server/voice_changer/DDSP_SVC/models/diffusion/unit2mel.py create mode 100644 server/voice_changer/DDSP_SVC/models/diffusion/vocoder.py create mode 100644 server/voice_changer/DDSP_SVC/models/diffusion/wavenet.py create mode 100644 server/voice_changer/DDSP_SVC/models/encoder/hubert/model.py create mode 100644 server/voice_changer/DDSP_SVC/models/enhancer.py create mode 100644 server/voice_changer/DDSP_SVC/models/nsf_hifigan/env.py create mode 100644 server/voice_changer/DDSP_SVC/models/nsf_hifigan/models.py create mode 100644 server/voice_changer/DDSP_SVC/models/nsf_hifigan/nvSTFT.py create mode 100644 server/voice_changer/DDSP_SVC/models/nsf_hifigan/utils.py create mode 100644 server/voice_changer/SoVitsSvc40/models/readme.txt delete mode 100644 server/voice_changer/SoVitsSvc40v2/SoVitsSvc40v2.py diff --git a/server/voice_changer/DDSP_SVC/DDSP_SVC.py b/server/voice_changer/DDSP_SVC/DDSP_SVC.py index 2c306ed9..5ac4a5fc 100644 --- a/server/voice_changer/DDSP_SVC/DDSP_SVC.py +++ b/server/voice_changer/DDSP_SVC/DDSP_SVC.py @@ -4,7 +4,6 @@ from dataclasses import asdict import numpy as np import torch from data.ModelSlot import DDSPSVCModelSlot -from voice_changer.DDSP_SVC.ModelSlot import ModelSlot from voice_changer.DDSP_SVC.deviceManager.DeviceManager import DeviceManager @@ -18,11 +17,10 @@ if sys.platform.startswith("darwin"): else: sys.path.append("DDSP-SVC") -from diffusion.infer_gt_mel import DiffGtMel # type: ignore +from .models.diffusion.infer_gt_mel import DiffGtMel from voice_changer.utils.VoiceChangerModel import AudioInOut from voice_changer.utils.VoiceChangerParams import VoiceChangerParams -from voice_changer.utils.LoadModelParams import LoadModelParams, LoadModelParams2 from voice_changer.DDSP_SVC.DDSP_SVCSetting import DDSP_SVCSettings from voice_changer.RVC.embedder.EmbedderManager import EmbedderManager @@ -51,68 +49,39 @@ def phase_vocoder(a, b, fade_out, fade_in): class DDSP_SVC: initialLoad: bool = True - settings: DDSP_SVCSettings = DDSP_SVCSettings() - diff_model: DiffGtMel = DiffGtMel() - svc_model: SvcDDSP = SvcDDSP() - deviceManager = DeviceManager.get_instance() - # diff_model: DiffGtMel = DiffGtMel() - - audio_buffer: AudioInOut | None = None - prevVol: float = 0 - # resample_kernel = {} - - def __init__(self, params: VoiceChangerParams): + def __init__(self, params: VoiceChangerParams, slotInfo: DDSPSVCModelSlot): + print("[Voice Changer] [DDSP-SVC] Creating instance ") + self.deviceManager = DeviceManager.get_instance() self.gpu_num = torch.cuda.device_count() self.params = params + self.settings = DDSP_SVCSettings() + self.svc_model: SvcDDSP = SvcDDSP() + self.diff_model: DiffGtMel = DiffGtMel() + self.svc_model.setVCParams(params) EmbedderManager.initialize(params) - print("[Voice Changer] DDSP-SVC initialization:", params) - def loadModel(self, props: LoadModelParams): - target_slot_idx = props.slot - params = props.params + self.audio_buffer: AudioInOut | None = None + self.prevVol = 0.0 + self.slotInfo = slotInfo + self.initialize() - modelFile = params["files"]["ddspSvcModel"] - diffusionFile = params["files"]["ddspSvcDiffusion"] - modelSlot = ModelSlot( - modelFile=modelFile, - diffusionFile=diffusionFile, - defaultTrans=params["trans"] if "trans" in params else 0, - ) - self.settings.modelSlots[target_slot_idx] = modelSlot - - # 初回のみロード - # if self.initialLoad: - # self.prepareModel(target_slot_idx) - # self.settings.modelSlotIndex = target_slot_idx - # self.switchModel() - # self.initialLoad = False - # elif target_slot_idx == self.currentSlot: - # self.prepareModel(target_slot_idx) - self.settings.modelSlotIndex = target_slot_idx - self.reloadModel() - - print("params:", params) - return self.get_info() - - def reloadModel(self): + def initialize(self): self.device = self.deviceManager.getDevice(self.settings.gpu) - modelFile = self.settings.modelSlots[self.settings.modelSlotIndex].modelFile - diffusionFile = self.settings.modelSlots[self.settings.modelSlotIndex].diffusionFile self.svc_model = SvcDDSP() self.svc_model.setVCParams(self.params) - self.svc_model.update_model(modelFile, self.device) + self.svc_model.update_model(self.slotInfo.modelFile, self.device) self.diff_model = DiffGtMel(device=self.device) - self.diff_model.flush_model(diffusionFile, ddsp_config=self.svc_model.args) + self.diff_model.flush_model(self.slotInfo.diffModelFile, ddsp_config=self.svc_model.args) def update_settings(self, key: str, val: int | float | str): if key in self.settings.intData: val = int(val) setattr(self.settings, key, val) if key == "gpu": - self.reloadModel() + self.initialize() elif key in self.settings.floatData: setattr(self.settings, key, float(val)) elif key in self.settings.strData: @@ -160,10 +129,6 @@ class DDSP_SVC: # raise NoModeLoadedException("ONNX") def _pyTorch_inference(self, data): - # if hasattr(self, "model") is False or self.model is None: - # print("[Voice Changer] No pyTorch session.") - # raise NoModeLoadedException("pytorch") - input_wav = data[0] _audio, _model_sr = self.svc_model.infer( input_wav, @@ -192,32 +157,13 @@ class DDSP_SVC: return _audio.cpu().numpy() * 32768.0 def inference(self, data): - if self.settings.framework == "ONNX": + if self.slotInfo.isONNX: audio = self._onnx_inference(data) else: audio = self._pyTorch_inference(data) return audio - @classmethod - def loadModel2(cls, props: LoadModelParams2): - slotInfo: DDSPSVCModelSlot = DDSPSVCModelSlot() - for file in props.files: - if file.kind == "ddspSvcModelConfig": - slotInfo.configFile = file.name - elif file.kind == "ddspSvcModel": - slotInfo.modelFile = file.name - elif file.kind == "ddspSvcDiffusionConfig": - slotInfo.diffConfigFile = file.name - elif file.kind == "ddspSvcDiffusion": - slotInfo.diffModelFile = file.name - slotInfo.isONNX = slotInfo.modelFile.endswith(".onnx") - slotInfo.name = os.path.splitext(os.path.basename(slotInfo.modelFile))[0] - return slotInfo - def __del__(self): - del self.net_g - del self.onnx_session - remove_path = os.path.join("DDSP-SVC") sys.path = [x for x in sys.path if x.endswith(remove_path) is False] diff --git a/server/voice_changer/DDSP_SVC/DDSP_SVCModelSlotGenerator.py b/server/voice_changer/DDSP_SVC/DDSP_SVCModelSlotGenerator.py new file mode 100644 index 00000000..4b59ce07 --- /dev/null +++ b/server/voice_changer/DDSP_SVC/DDSP_SVCModelSlotGenerator.py @@ -0,0 +1,22 @@ +import os +from data.ModelSlot import DDSPSVCModelSlot +from voice_changer.utils.LoadModelParams import LoadModelParams +from voice_changer.utils.ModelSlotGenerator import ModelSlotGenerator + + +class DDSP_SVCModelSlotGenerator(ModelSlotGenerator): + @classmethod + def loadModel(cls, props: LoadModelParams): + slotInfo: DDSPSVCModelSlot = DDSPSVCModelSlot() + for file in props.files: + if file.kind == "ddspSvcModelConfig": + slotInfo.configFile = file.name + elif file.kind == "ddspSvcModel": + slotInfo.modelFile = file.name + elif file.kind == "ddspSvcDiffusionConfig": + slotInfo.diffConfigFile = file.name + elif file.kind == "ddspSvcDiffusion": + slotInfo.diffModelFile = file.name + slotInfo.isONNX = slotInfo.modelFile.endswith(".onnx") + slotInfo.name = os.path.splitext(os.path.basename(slotInfo.modelFile))[0] + return slotInfo diff --git a/server/voice_changer/DDSP_SVC/DDSP_SVCSetting.py b/server/voice_changer/DDSP_SVC/DDSP_SVCSetting.py index bf3e5957..d97dd262 100644 --- a/server/voice_changer/DDSP_SVC/DDSP_SVCSetting.py +++ b/server/voice_changer/DDSP_SVC/DDSP_SVCSetting.py @@ -1,7 +1,5 @@ from dataclasses import dataclass, field -from voice_changer.DDSP_SVC.ModelSlot import ModelSlot - @dataclass class DDSP_SVCSettings: @@ -23,14 +21,7 @@ class DDSP_SVCSettings: kStep: int = 120 threshold: int = -45 - framework: str = "PyTorch" # PyTorch or ONNX - pyTorchModelFile: str = "" - onnxModelFile: str = "" - configFile: str = "" - speakers: dict[str, int] = field(default_factory=lambda: {}) - modelSlotIndex: int = -1 - modelSlots: list[ModelSlot] = field(default_factory=lambda: [ModelSlot()]) # ↓mutableな物だけ列挙 intData = [ "gpu", @@ -46,4 +37,4 @@ class DDSP_SVCSettings: "kStep", ] floatData = ["silentThreshold"] - strData = ["framework", "f0Detector", "diffMethod"] + strData = ["f0Detector", "diffMethod"] diff --git a/server/voice_changer/DDSP_SVC/ModelSlot.py b/server/voice_changer/DDSP_SVC/ModelSlot.py deleted file mode 100644 index 16566a7e..00000000 --- a/server/voice_changer/DDSP_SVC/ModelSlot.py +++ /dev/null @@ -1,8 +0,0 @@ -from dataclasses import dataclass - - -@dataclass -class ModelSlot: - modelFile: str = "" - diffusionFile: str = "" - defaultTrans: int = 0 diff --git a/server/voice_changer/DDSP_SVC/SvcDDSP.py b/server/voice_changer/DDSP_SVC/SvcDDSP.py index 65cd424e..e9e982b4 100644 --- a/server/voice_changer/DDSP_SVC/SvcDDSP.py +++ b/server/voice_changer/DDSP_SVC/SvcDDSP.py @@ -3,13 +3,13 @@ import torch try: - from ddsp.vocoder import load_model, F0_Extractor, Volume_Extractor, Units_Encoder # type: ignore + from .models.ddsp.vocoder import load_model, F0_Extractor, Volume_Extractor, Units_Encoder except Exception as e: print(e) - from ddsp.vocoder import load_model, F0_Extractor, Volume_Extractor, Units_Encoder # type: ignore + from .models.ddsp.vocoder import load_model, F0_Extractor, Volume_Extractor, Units_Encoder -from ddsp.core import upsample # type: ignore -from enhancer import Enhancer # type: ignore +from .models.ddsp.core import upsample +from .models.enhancer import Enhancer from voice_changer.utils.VoiceChangerParams import VoiceChangerParams import numpy as np @@ -38,11 +38,7 @@ class SvcDDSP: print("ARGS:", self.args) # load units encoder - if ( - self.units_encoder is None - or self.args.data.encoder != self.encoder_type - or self.args.data.encoder_ckpt != self.encoder_ckpt - ): + if self.units_encoder is None or self.args.data.encoder != self.encoder_type or self.args.data.encoder_ckpt != self.encoder_ckpt: if self.args.data.encoder == "cnhubertsoftfish": cnhubertsoft_gate = self.args.data.cnhubertsoft_gate else: @@ -77,15 +73,9 @@ class SvcDDSP: self.encoder_ckpt = encoderPath # load enhancer - if ( - self.enhancer is None - or self.args.enhancer.type != self.enhancer_type - or self.args.enhancer.ckpt != self.enhancer_ckpt - ): + if self.enhancer is None or self.args.enhancer.type != self.enhancer_type or self.args.enhancer.ckpt != self.enhancer_ckpt: enhancerPath = self.params.nsf_hifigan - self.enhancer = Enhancer( - self.args.enhancer.type, enhancerPath, device=self.device - ) + self.enhancer = Enhancer(self.args.enhancer.type, enhancerPath, device=self.device) self.enhancer_type = self.args.enhancer.type self.enhancer_ckpt = enhancerPath @@ -118,9 +108,7 @@ class SvcDDSP: # print("audio", audio) # load input # audio, sample_rate = librosa.load(input_wav, sr=None, mono=True) - hop_size = ( - self.args.data.block_size * sample_rate / self.args.data.sampling_rate - ) + hop_size = self.args.data.block_size * sample_rate / self.args.data.sampling_rate if audio_alignment: audio_length = len(audio) # safe front silence @@ -131,12 +119,9 @@ class SvcDDSP: audio_t = torch.from_numpy(audio).float().unsqueeze(0).to(self.device) # extract f0 - pitch_extractor = F0_Extractor( - pitch_extractor_type, sample_rate, hop_size, float(f0_min), float(f0_max) - ) - f0 = pitch_extractor.extract( - audio, uv_interp=True, device=self.device, silence_front=silence_front - ) + print("pitch_extractor_type", pitch_extractor_type) + pitch_extractor = F0_Extractor(pitch_extractor_type, sample_rate, hop_size, float(f0_min), float(f0_max)) + f0 = pitch_extractor.extract(audio, uv_interp=True, device=self.device, silence_front=silence_front) f0 = torch.from_numpy(f0).float().to(self.device).unsqueeze(-1).unsqueeze(0) f0 = f0 * 2 ** (float(pitch_adjust) / 12) @@ -148,9 +133,7 @@ class SvcDDSP: mask = np.array([np.max(mask[n : n + 9]) for n in range(len(mask) - 8)]) # type: ignore mask = torch.from_numpy(mask).float().to(self.device).unsqueeze(-1).unsqueeze(0) mask = upsample(mask, self.args.data.block_size).squeeze(-1) - volume = ( - torch.from_numpy(volume).float().to(self.device).unsqueeze(-1).unsqueeze(0) - ) + volume = torch.from_numpy(volume).float().to(self.device).unsqueeze(-1).unsqueeze(0) # extract units units = self.units_encoder.encode(audio_t, sample_rate, hop_size) @@ -165,9 +148,7 @@ class SvcDDSP: # forward and return the output with torch.no_grad(): - output, _, (s_h, s_n) = self.model( - units, f0, volume, spk_id=spk_id, spk_mix_dict=dictionary - ) + output, _, (s_h, s_n) = self.model(units, f0, volume, spk_id=spk_id, spk_mix_dict=dictionary) if diff_use and diff_model is not None: output = diff_model.infer( diff --git a/server/voice_changer/DDSP_SVC/models/ddsp/__init__.py b/server/voice_changer/DDSP_SVC/models/ddsp/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/server/voice_changer/DDSP_SVC/models/ddsp/core.py b/server/voice_changer/DDSP_SVC/models/ddsp/core.py new file mode 100644 index 00000000..ae17c351 --- /dev/null +++ b/server/voice_changer/DDSP_SVC/models/ddsp/core.py @@ -0,0 +1,281 @@ +import torch +import torch.nn as nn +from torch.nn import functional as F + +import math +import numpy as np + +def MaskedAvgPool1d(x, kernel_size): + x = x.unsqueeze(1) + x = F.pad(x, ((kernel_size - 1) // 2, kernel_size // 2), mode="reflect") + mask = ~torch.isnan(x) + masked_x = torch.where(mask, x, torch.zeros_like(x)) + ones_kernel = torch.ones(x.size(1), 1, kernel_size, device=x.device) + + # Perform sum pooling + sum_pooled = F.conv1d( + masked_x, + ones_kernel, + stride=1, + padding=0, + groups=x.size(1), + ) + + # Count the non-masked (valid) elements in each pooling window + valid_count = F.conv1d( + mask.float(), + ones_kernel, + stride=1, + padding=0, + groups=x.size(1), + ) + valid_count = valid_count.clamp(min=1) # Avoid division by zero + + # Perform masked average pooling + avg_pooled = sum_pooled / valid_count + + return avg_pooled.squeeze(1) + +def MedianPool1d(x, kernel_size): + x = x.unsqueeze(1) + x = F.pad(x, ((kernel_size - 1) // 2, kernel_size // 2), mode="reflect") + x = x.squeeze(1) + x = x.unfold(1, kernel_size, 1) + x, _ = torch.sort(x, dim=-1) + return x[:, :, (kernel_size - 1) // 2] + +def get_fft_size(frame_size: int, ir_size: int, power_of_2: bool = True): + """Calculate final size for efficient FFT. + Args: + frame_size: Size of the audio frame. + ir_size: Size of the convolving impulse response. + power_of_2: Constrain to be a power of 2. If False, allow other 5-smooth + numbers. TPU requires power of 2, while GPU is more flexible. + Returns: + fft_size: Size for efficient FFT. + """ + convolved_frame_size = ir_size + frame_size - 1 + if power_of_2: + # Next power of 2. + fft_size = int(2**np.ceil(np.log2(convolved_frame_size))) + else: + fft_size = convolved_frame_size + return fft_size + + +def upsample(signal, factor): + signal = signal.permute(0, 2, 1) + signal = nn.functional.interpolate(torch.cat((signal,signal[:,:,-1:]),2), size=signal.shape[-1] * factor + 1, mode='linear', align_corners=True) + signal = signal[:,:,:-1] + return signal.permute(0, 2, 1) + + +def remove_above_fmax(amplitudes, pitch, fmax, level_start=1): + n_harm = amplitudes.shape[-1] + pitches = pitch * torch.arange(level_start, n_harm + level_start).to(pitch) + aa = (pitches < fmax).float() + 1e-7 + return amplitudes * aa + + +def crop_and_compensate_delay(audio, audio_size, ir_size, + padding = 'same', + delay_compensation = -1): + """Crop audio output from convolution to compensate for group delay. + Args: + audio: Audio after convolution. Tensor of shape [batch, time_steps]. + audio_size: Initial size of the audio before convolution. + ir_size: Size of the convolving impulse response. + padding: Either 'valid' or 'same'. For 'same' the final output to be the + same size as the input audio (audio_timesteps). For 'valid' the audio is + extended to include the tail of the impulse response (audio_timesteps + + ir_timesteps - 1). + delay_compensation: Samples to crop from start of output audio to compensate + for group delay of the impulse response. If delay_compensation < 0 it + defaults to automatically calculating a constant group delay of the + windowed linear phase filter from frequency_impulse_response(). + Returns: + Tensor of cropped and shifted audio. + Raises: + ValueError: If padding is not either 'valid' or 'same'. + """ + # Crop the output. + if padding == 'valid': + crop_size = ir_size + audio_size - 1 + elif padding == 'same': + crop_size = audio_size + else: + raise ValueError('Padding must be \'valid\' or \'same\', instead ' + 'of {}.'.format(padding)) + + # Compensate for the group delay of the filter by trimming the front. + # For an impulse response produced by frequency_impulse_response(), + # the group delay is constant because the filter is linear phase. + total_size = int(audio.shape[-1]) + crop = total_size - crop_size + start = (ir_size // 2 if delay_compensation < 0 else delay_compensation) + end = crop - start + return audio[:, start:-end] + + +def fft_convolve(audio, + impulse_response): # B, n_frames, 2*(n_mags-1) + """Filter audio with frames of time-varying impulse responses. + Time-varying filter. Given audio [batch, n_samples], and a series of impulse + responses [batch, n_frames, n_impulse_response], splits the audio into frames, + applies filters, and then overlap-and-adds audio back together. + Applies non-windowed non-overlapping STFT/ISTFT to efficiently compute + convolution for large impulse response sizes. + Args: + audio: Input audio. Tensor of shape [batch, audio_timesteps]. + impulse_response: Finite impulse response to convolve. Can either be a 2-D + Tensor of shape [batch, ir_size], or a 3-D Tensor of shape [batch, + ir_frames, ir_size]. A 2-D tensor will apply a single linear + time-invariant filter to the audio. A 3-D Tensor will apply a linear + time-varying filter. Automatically chops the audio into equally shaped + blocks to match ir_frames. + Returns: + audio_out: Convolved audio. Tensor of shape + [batch, audio_timesteps]. + """ + # Add a frame dimension to impulse response if it doesn't have one. + ir_shape = impulse_response.size() + if len(ir_shape) == 2: + impulse_response = impulse_response.unsqueeze(1) + ir_shape = impulse_response.size() + + # Get shapes of audio and impulse response. + batch_size_ir, n_ir_frames, ir_size = ir_shape + batch_size, audio_size = audio.size() # B, T + + # Validate that batch sizes match. + if batch_size != batch_size_ir: + raise ValueError('Batch size of audio ({}) and impulse response ({}) must ' + 'be the same.'.format(batch_size, batch_size_ir)) + + # Cut audio into 50% overlapped frames (center padding). + hop_size = int(audio_size / n_ir_frames) + frame_size = 2 * hop_size + audio_frames = F.pad(audio, (hop_size, hop_size)).unfold(1, frame_size, hop_size) + + # Apply Bartlett (triangular) window + window = torch.bartlett_window(frame_size).to(audio_frames) + audio_frames = audio_frames * window + + # Pad and FFT the audio and impulse responses. + fft_size = get_fft_size(frame_size, ir_size, power_of_2=False) + audio_fft = torch.fft.rfft(audio_frames, fft_size) + ir_fft = torch.fft.rfft(torch.cat((impulse_response,impulse_response[:,-1:,:]),1), fft_size) + + # Multiply the FFTs (same as convolution in time). + audio_ir_fft = torch.multiply(audio_fft, ir_fft) + + # Take the IFFT to resynthesize audio. + audio_frames_out = torch.fft.irfft(audio_ir_fft, fft_size) + + # Overlap Add + batch_size, n_audio_frames, frame_size = audio_frames_out.size() # # B, n_frames+1, 2*(hop_size+n_mags-1)-1 + fold = torch.nn.Fold(output_size=(1, (n_audio_frames - 1) * hop_size + frame_size),kernel_size=(1, frame_size),stride=(1, hop_size)) + output_signal = fold(audio_frames_out.transpose(1, 2)).squeeze(1).squeeze(1) + + # Crop and shift the output audio. + output_signal = crop_and_compensate_delay(output_signal[:,hop_size:], audio_size, ir_size) + return output_signal + + +def apply_window_to_impulse_response(impulse_response, # B, n_frames, 2*(n_mag-1) + window_size: int = 0, + causal: bool = False): + """Apply a window to an impulse response and put in causal form. + Args: + impulse_response: A series of impulse responses frames to window, of shape + [batch, n_frames, ir_size]. ---------> ir_size means size of filter_bank ?????? + + window_size: Size of the window to apply in the time domain. If window_size + is less than 1, it defaults to the impulse_response size. + causal: Impulse response input is in causal form (peak in the middle). + Returns: + impulse_response: Windowed impulse response in causal form, with last + dimension cropped to window_size if window_size is greater than 0 and less + than ir_size. + """ + + # If IR is in causal form, put it in zero-phase form. + if causal: + impulse_response = torch.fftshift(impulse_response, axes=-1) + + # Get a window for better time/frequency resolution than rectangular. + # Window defaults to IR size, cannot be bigger. + ir_size = int(impulse_response.size(-1)) + if (window_size <= 0) or (window_size > ir_size): + window_size = ir_size + window = nn.Parameter(torch.hann_window(window_size), requires_grad = False).to(impulse_response) + + # Zero pad the window and put in in zero-phase form. + padding = ir_size - window_size + if padding > 0: + half_idx = (window_size + 1) // 2 + window = torch.cat([window[half_idx:], + torch.zeros([padding]), + window[:half_idx]], axis=0) + else: + window = window.roll(window.size(-1)//2, -1) + + # Apply the window, to get new IR (both in zero-phase form). + window = window.unsqueeze(0) + impulse_response = impulse_response * window + + # Put IR in causal form and trim zero padding. + if padding > 0: + first_half_start = (ir_size - (half_idx - 1)) + 1 + second_half_end = half_idx + 1 + impulse_response = torch.cat([impulse_response[..., first_half_start:], + impulse_response[..., :second_half_end]], + dim=-1) + else: + impulse_response = impulse_response.roll(impulse_response.size(-1)//2, -1) + + return impulse_response + + +def apply_dynamic_window_to_impulse_response(impulse_response, # B, n_frames, 2*(n_mag-1) or 2*n_mag-1 + half_width_frames): # B,n_frames, 1 + ir_size = int(impulse_response.size(-1)) # 2*(n_mag -1) or 2*n_mag-1 + + window = torch.arange(-(ir_size // 2), (ir_size + 1) // 2).to(impulse_response) / half_width_frames + window[window > 1] = 0 + window = (1 + torch.cos(np.pi * window)) / 2 # B, n_frames, 2*(n_mag -1) or 2*n_mag-1 + + impulse_response = impulse_response.roll(ir_size // 2, -1) + impulse_response = impulse_response * window + + return impulse_response + + +def frequency_impulse_response(magnitudes, + hann_window = True, + half_width_frames = None): + + # Get the IR + impulse_response = torch.fft.irfft(magnitudes) # B, n_frames, 2*(n_mags-1) + + # Window and put in causal form. + if hann_window: + if half_width_frames is None: + impulse_response = apply_window_to_impulse_response(impulse_response) + else: + impulse_response = apply_dynamic_window_to_impulse_response(impulse_response, half_width_frames) + else: + impulse_response = impulse_response.roll(impulse_response.size(-1) // 2, -1) + + return impulse_response + + +def frequency_filter(audio, + magnitudes, + hann_window=True, + half_width_frames=None): + + impulse_response = frequency_impulse_response(magnitudes, hann_window, half_width_frames) + + return fft_convolve(audio, impulse_response) + \ No newline at end of file diff --git a/server/voice_changer/DDSP_SVC/models/ddsp/loss.py b/server/voice_changer/DDSP_SVC/models/ddsp/loss.py new file mode 100644 index 00000000..47a9be63 --- /dev/null +++ b/server/voice_changer/DDSP_SVC/models/ddsp/loss.py @@ -0,0 +1,57 @@ +import numpy as np + +import torch +import torch.nn as nn +import torchaudio +from torch.nn import functional as F +from .core import upsample + +class SSSLoss(nn.Module): + """ + Single-scale Spectral Loss. + """ + + def __init__(self, n_fft=111, alpha=1.0, overlap=0, eps=1e-7): + super().__init__() + self.n_fft = n_fft + self.alpha = alpha + self.eps = eps + self.hop_length = int(n_fft * (1 - overlap)) # 25% of the length + self.spec = torchaudio.transforms.Spectrogram(n_fft=self.n_fft, hop_length=self.hop_length, power=1, normalized=True, center=False) + + def forward(self, x_true, x_pred): + S_true = self.spec(x_true) + self.eps + S_pred = self.spec(x_pred) + self.eps + + converge_term = torch.mean(torch.linalg.norm(S_true - S_pred, dim = (1, 2)) / torch.linalg.norm(S_true + S_pred, dim = (1, 2))) + + log_term = F.l1_loss(S_true.log(), S_pred.log()) + + loss = converge_term + self.alpha * log_term + return loss + + +class RSSLoss(nn.Module): + ''' + Random-scale Spectral Loss. + ''' + + def __init__(self, fft_min, fft_max, n_scale, alpha=1.0, overlap=0, eps=1e-7, device='cuda'): + super().__init__() + self.fft_min = fft_min + self.fft_max = fft_max + self.n_scale = n_scale + self.lossdict = {} + for n_fft in range(fft_min, fft_max): + self.lossdict[n_fft] = SSSLoss(n_fft, alpha, overlap, eps).to(device) + + def forward(self, x_pred, x_true): + value = 0. + n_ffts = torch.randint(self.fft_min, self.fft_max, (self.n_scale,)) + for n_fft in n_ffts: + loss_func = self.lossdict[int(n_fft)] + value += loss_func(x_true, x_pred) + return value / self.n_scale + + + diff --git a/server/voice_changer/DDSP_SVC/models/ddsp/pcmer.py b/server/voice_changer/DDSP_SVC/models/ddsp/pcmer.py new file mode 100644 index 00000000..f0eb32a5 --- /dev/null +++ b/server/voice_changer/DDSP_SVC/models/ddsp/pcmer.py @@ -0,0 +1,380 @@ +import torch + +from torch import nn +import math +from functools import partial +from einops import rearrange, repeat + +from local_attention import LocalAttention +import torch.nn.functional as F +#import fast_transformers.causal_product.causal_product_cuda + +def softmax_kernel(data, *, projection_matrix, is_query, normalize_data=True, eps=1e-4, device = None): + b, h, *_ = data.shape + # (batch size, head, length, model_dim) + + # normalize model dim + data_normalizer = (data.shape[-1] ** -0.25) if normalize_data else 1. + + # what is ration?, projection_matrix.shape[0] --> 266 + + ratio = (projection_matrix.shape[0] ** -0.5) + + projection = repeat(projection_matrix, 'j d -> b h j d', b = b, h = h) + projection = projection.type_as(data) + + #data_dash = w^T x + data_dash = torch.einsum('...id,...jd->...ij', (data_normalizer * data), projection) + + + # diag_data = D**2 + diag_data = data ** 2 + diag_data = torch.sum(diag_data, dim=-1) + diag_data = (diag_data / 2.0) * (data_normalizer ** 2) + diag_data = diag_data.unsqueeze(dim=-1) + + #print () + if is_query: + data_dash = ratio * ( + torch.exp(data_dash - diag_data - + torch.max(data_dash, dim=-1, keepdim=True).values) + eps) + else: + data_dash = ratio * ( + torch.exp(data_dash - diag_data + eps))#- torch.max(data_dash)) + eps) + + return data_dash.type_as(data) + +def orthogonal_matrix_chunk(cols, qr_uniform_q = False, device = None): + unstructured_block = torch.randn((cols, cols), device = device) + q, r = torch.linalg.qr(unstructured_block.cpu(), mode='reduced') + q, r = map(lambda t: t.to(device), (q, r)) + + # proposed by @Parskatt + # to make sure Q is uniform https://arxiv.org/pdf/math-ph/0609050.pdf + if qr_uniform_q: + d = torch.diag(r, 0) + q *= d.sign() + return q.t() +def exists(val): + return val is not None + +def empty(tensor): + return tensor.numel() == 0 + +def default(val, d): + return val if exists(val) else d + +def cast_tuple(val): + return (val,) if not isinstance(val, tuple) else val + +class PCmer(nn.Module): + """The encoder that is used in the Transformer model.""" + + def __init__(self, + num_layers, + num_heads, + dim_model, + dim_keys, + dim_values, + residual_dropout, + attention_dropout): + super().__init__() + self.num_layers = num_layers + self.num_heads = num_heads + self.dim_model = dim_model + self.dim_values = dim_values + self.dim_keys = dim_keys + self.residual_dropout = residual_dropout + self.attention_dropout = attention_dropout + + self._layers = nn.ModuleList([_EncoderLayer(self) for _ in range(num_layers)]) + + # METHODS ######################################################################################################## + + def forward(self, phone, mask=None): + + # apply all layers to the input + for (i, layer) in enumerate(self._layers): + phone = layer(phone, mask) + # provide the final sequence + return phone + + +# ==================================================================================================================== # +# CLASS _ E N C O D E R L A Y E R # +# ==================================================================================================================== # + + +class _EncoderLayer(nn.Module): + """One layer of the encoder. + + Attributes: + attn: (:class:`mha.MultiHeadAttention`): The attention mechanism that is used to read the input sequence. + feed_forward (:class:`ffl.FeedForwardLayer`): The feed-forward layer on top of the attention mechanism. + """ + + def __init__(self, parent: PCmer): + """Creates a new instance of ``_EncoderLayer``. + + Args: + parent (Encoder): The encoder that the layers is created for. + """ + super().__init__() + + + self.conformer = ConformerConvModule(parent.dim_model) + self.norm = nn.LayerNorm(parent.dim_model) + self.dropout = nn.Dropout(parent.residual_dropout) + + # selfatt -> fastatt: performer! + self.attn = SelfAttention(dim = parent.dim_model, + heads = parent.num_heads, + causal = False) + + # METHODS ######################################################################################################## + + def forward(self, phone, mask=None): + + # compute attention sub-layer + phone = phone + (self.attn(self.norm(phone), mask=mask)) + + phone = phone + (self.conformer(phone)) + + return phone + +def calc_same_padding(kernel_size): + pad = kernel_size // 2 + return (pad, pad - (kernel_size + 1) % 2) + +# helper classes + +class Swish(nn.Module): + def forward(self, x): + return x * x.sigmoid() + +class Transpose(nn.Module): + def __init__(self, dims): + super().__init__() + assert len(dims) == 2, 'dims must be a tuple of two dimensions' + self.dims = dims + + def forward(self, x): + return x.transpose(*self.dims) + +class GLU(nn.Module): + def __init__(self, dim): + super().__init__() + self.dim = dim + + def forward(self, x): + out, gate = x.chunk(2, dim=self.dim) + return out * gate.sigmoid() + +class DepthWiseConv1d(nn.Module): + def __init__(self, chan_in, chan_out, kernel_size, padding): + super().__init__() + self.padding = padding + self.conv = nn.Conv1d(chan_in, chan_out, kernel_size, groups = chan_in) + + def forward(self, x): + x = F.pad(x, self.padding) + return self.conv(x) + +class ConformerConvModule(nn.Module): + def __init__( + self, + dim, + causal = False, + expansion_factor = 2, + kernel_size = 31, + dropout = 0.): + super().__init__() + + inner_dim = dim * expansion_factor + padding = calc_same_padding(kernel_size) if not causal else (kernel_size - 1, 0) + + self.net = nn.Sequential( + nn.LayerNorm(dim), + Transpose((1, 2)), + nn.Conv1d(dim, inner_dim * 2, 1), + GLU(dim=1), + DepthWiseConv1d(inner_dim, inner_dim, kernel_size = kernel_size, padding = padding), + #nn.BatchNorm1d(inner_dim) if not causal else nn.Identity(), + Swish(), + nn.Conv1d(inner_dim, dim, 1), + Transpose((1, 2)), + nn.Dropout(dropout) + ) + + def forward(self, x): + return self.net(x) + +def linear_attention(q, k, v): + if v is None: + #print (k.size(), q.size()) + out = torch.einsum('...ed,...nd->...ne', k, q) + return out + + else: + k_cumsum = k.sum(dim = -2) + #k_cumsum = k.sum(dim = -2) + D_inv = 1. / (torch.einsum('...nd,...d->...n', q, k_cumsum.type_as(q)) + 1e-8) + + context = torch.einsum('...nd,...ne->...de', k, v) + #print ("TRUEEE: ", context.size(), q.size(), D_inv.size()) + out = torch.einsum('...de,...nd,...n->...ne', context, q, D_inv) + return out + +def gaussian_orthogonal_random_matrix(nb_rows, nb_columns, scaling = 0, qr_uniform_q = False, device = None): + nb_full_blocks = int(nb_rows / nb_columns) + #print (nb_full_blocks) + block_list = [] + + for _ in range(nb_full_blocks): + q = orthogonal_matrix_chunk(nb_columns, qr_uniform_q = qr_uniform_q, device = device) + block_list.append(q) + # block_list[n] is a orthogonal matrix ... (model_dim * model_dim) + #print (block_list[0].size(), torch.einsum('...nd,...nd->...n', block_list[0], torch.roll(block_list[0],1,1))) + #print (nb_rows, nb_full_blocks, nb_columns) + remaining_rows = nb_rows - nb_full_blocks * nb_columns + #print (remaining_rows) + if remaining_rows > 0: + q = orthogonal_matrix_chunk(nb_columns, qr_uniform_q = qr_uniform_q, device = device) + #print (q[:remaining_rows].size()) + block_list.append(q[:remaining_rows]) + + final_matrix = torch.cat(block_list) + + if scaling == 0: + multiplier = torch.randn((nb_rows, nb_columns), device = device).norm(dim = 1) + elif scaling == 1: + multiplier = math.sqrt((float(nb_columns))) * torch.ones((nb_rows,), device = device) + else: + raise ValueError(f'Invalid scaling {scaling}') + + return torch.diag(multiplier) @ final_matrix + +class FastAttention(nn.Module): + def __init__(self, dim_heads, nb_features = None, ortho_scaling = 0, causal = False, generalized_attention = False, kernel_fn = nn.ReLU(), qr_uniform_q = False, no_projection = False): + super().__init__() + nb_features = default(nb_features, int(dim_heads * math.log(dim_heads))) + + self.dim_heads = dim_heads + self.nb_features = nb_features + self.ortho_scaling = ortho_scaling + + self.create_projection = partial(gaussian_orthogonal_random_matrix, nb_rows = self.nb_features, nb_columns = dim_heads, scaling = ortho_scaling, qr_uniform_q = qr_uniform_q) + projection_matrix = self.create_projection() + self.register_buffer('projection_matrix', projection_matrix) + + self.generalized_attention = generalized_attention + self.kernel_fn = kernel_fn + + # if this is turned on, no projection will be used + # queries and keys will be softmax-ed as in the original efficient attention paper + self.no_projection = no_projection + + self.causal = causal + if causal: + try: + import fast_transformers.causal_product.causal_product_cuda + self.causal_linear_fn = partial(causal_linear_attention) + except ImportError: + print('unable to import cuda code for auto-regressive Performer. will default to the memory inefficient non-cuda version') + self.causal_linear_fn = causal_linear_attention_noncuda + @torch.no_grad() + def redraw_projection_matrix(self): + projections = self.create_projection() + self.projection_matrix.copy_(projections) + del projections + + def forward(self, q, k, v): + device = q.device + + if self.no_projection: + q = q.softmax(dim = -1) + k = torch.exp(k) if self.causal else k.softmax(dim = -2) + + elif self.generalized_attention: + create_kernel = partial(generalized_kernel, kernel_fn = self.kernel_fn, projection_matrix = self.projection_matrix, device = device) + q, k = map(create_kernel, (q, k)) + + else: + create_kernel = partial(softmax_kernel, projection_matrix = self.projection_matrix, device = device) + + q = create_kernel(q, is_query = True) + k = create_kernel(k, is_query = False) + + attn_fn = linear_attention if not self.causal else self.causal_linear_fn + if v is None: + out = attn_fn(q, k, None) + return out + else: + out = attn_fn(q, k, v) + return out +class SelfAttention(nn.Module): + def __init__(self, dim, causal = False, heads = 8, dim_head = 64, local_heads = 0, local_window_size = 256, nb_features = None, feature_redraw_interval = 1000, generalized_attention = False, kernel_fn = nn.ReLU(), qr_uniform_q = False, dropout = 0., no_projection = False): + super().__init__() + assert dim % heads == 0, 'dimension must be divisible by number of heads' + dim_head = default(dim_head, dim // heads) + inner_dim = dim_head * heads + self.fast_attention = FastAttention(dim_head, nb_features, causal = causal, generalized_attention = generalized_attention, kernel_fn = kernel_fn, qr_uniform_q = qr_uniform_q, no_projection = no_projection) + + self.heads = heads + self.global_heads = heads - local_heads + self.local_attn = LocalAttention(window_size = local_window_size, causal = causal, autopad = True, dropout = dropout, look_forward = int(not causal), rel_pos_emb_config = (dim_head, local_heads)) if local_heads > 0 else None + + #print (heads, nb_features, dim_head) + #name_embedding = torch.zeros(110, heads, dim_head, dim_head) + #self.name_embedding = nn.Parameter(name_embedding, requires_grad=True) + + + self.to_q = nn.Linear(dim, inner_dim) + self.to_k = nn.Linear(dim, inner_dim) + self.to_v = nn.Linear(dim, inner_dim) + self.to_out = nn.Linear(inner_dim, dim) + self.dropout = nn.Dropout(dropout) + + @torch.no_grad() + def redraw_projection_matrix(self): + self.fast_attention.redraw_projection_matrix() + #torch.nn.init.zeros_(self.name_embedding) + #print (torch.sum(self.name_embedding)) + def forward(self, x, context = None, mask = None, context_mask = None, name=None, inference=False, **kwargs): + b, n, _, h, gh = *x.shape, self.heads, self.global_heads + + cross_attend = exists(context) + + context = default(context, x) + context_mask = default(context_mask, mask) if not cross_attend else context_mask + #print (torch.sum(self.name_embedding)) + q, k, v = self.to_q(x), self.to_k(context), self.to_v(context) + + q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), (q, k, v)) + (q, lq), (k, lk), (v, lv) = map(lambda t: (t[:, :gh], t[:, gh:]), (q, k, v)) + + attn_outs = [] + #print (name) + #print (self.name_embedding[name].size()) + if not empty(q): + if exists(context_mask): + global_mask = context_mask[:, None, :, None] + v.masked_fill_(~global_mask, 0.) + if cross_attend: + pass + #print (torch.sum(self.name_embedding)) + #out = self.fast_attention(q,self.name_embedding[name],None) + #print (torch.sum(self.name_embedding[...,-1:])) + else: + out = self.fast_attention(q, k, v) + attn_outs.append(out) + + if not empty(lq): + assert not cross_attend, 'local attention is not compatible with cross attention' + out = self.local_attn(lq, lk, lv, input_mask = mask) + attn_outs.append(out) + + out = torch.cat(attn_outs, dim = 1) + out = rearrange(out, 'b h n d -> b n (h d)') + out = self.to_out(out) + return self.dropout(out) \ No newline at end of file diff --git a/server/voice_changer/DDSP_SVC/models/ddsp/unit2control.py b/server/voice_changer/DDSP_SVC/models/ddsp/unit2control.py new file mode 100644 index 00000000..838ba892 --- /dev/null +++ b/server/voice_changer/DDSP_SVC/models/ddsp/unit2control.py @@ -0,0 +1,86 @@ +import gin + +import numpy as np +import torch +import torch.nn as nn +from torch.nn.utils import weight_norm + +from .pcmer import PCmer + + +def split_to_dict(tensor, tensor_splits): + """Split a tensor into a dictionary of multiple tensors.""" + labels = [] + sizes = [] + + for k, v in tensor_splits.items(): + labels.append(k) + sizes.append(v) + + tensors = torch.split(tensor, sizes, dim=-1) + return dict(zip(labels, tensors)) + + +class Unit2Control(nn.Module): + def __init__( + self, + input_channel, + n_spk, + output_splits): + super().__init__() + self.output_splits = output_splits + self.f0_embed = nn.Linear(1, 256) + self.phase_embed = nn.Linear(1, 256) + self.volume_embed = nn.Linear(1, 256) + self.n_spk = n_spk + if n_spk is not None and n_spk > 1: + self.spk_embed = nn.Embedding(n_spk, 256) + + # conv in stack + self.stack = nn.Sequential( + nn.Conv1d(input_channel, 256, 3, 1, 1), + nn.GroupNorm(4, 256), + nn.LeakyReLU(), + nn.Conv1d(256, 256, 3, 1, 1)) + + # transformer + self.decoder = PCmer( + num_layers=3, + num_heads=8, + dim_model=256, + dim_keys=256, + dim_values=256, + residual_dropout=0.1, + attention_dropout=0.1) + self.norm = nn.LayerNorm(256) + + # out + self.n_out = sum([v for k, v in output_splits.items()]) + self.dense_out = weight_norm( + nn.Linear(256, self.n_out)) + + def forward(self, units, f0, phase, volume, spk_id = None, spk_mix_dict = None): + + ''' + input: + B x n_frames x n_unit + return: + dict of B x n_frames x feat + ''' + + x = self.stack(units.transpose(1,2)).transpose(1,2) + x = x + self.f0_embed((1+ f0 / 700).log()) + self.phase_embed(phase / np.pi) + self.volume_embed(volume) + if self.n_spk is not None and self.n_spk > 1: + if spk_mix_dict is not None: + for k, v in spk_mix_dict.items(): + spk_id_torch = torch.LongTensor(np.array([[k]])).to(units.device) + x = x + v * self.spk_embed(spk_id_torch - 1) + else: + x = x + self.spk_embed(spk_id - 1) + x = self.decoder(x) + x = self.norm(x) + e = self.dense_out(x) + controls = split_to_dict(e, self.output_splits) + + return controls + diff --git a/server/voice_changer/DDSP_SVC/models/ddsp/vocoder.py b/server/voice_changer/DDSP_SVC/models/ddsp/vocoder.py new file mode 100644 index 00000000..ffa3db0c --- /dev/null +++ b/server/voice_changer/DDSP_SVC/models/ddsp/vocoder.py @@ -0,0 +1,639 @@ +import os +import numpy as np +import yaml +import torch +import torch.nn.functional as F +import pyworld as pw +import parselmouth +import torchcrepe +from transformers import HubertModel, Wav2Vec2FeatureExtractor +from fairseq import checkpoint_utils +from torch.nn.modules.utils import consume_prefix_in_state_dict_if_present +from torchaudio.transforms import Resample +from .unit2control import Unit2Control +from .core import frequency_filter, upsample, remove_above_fmax, MaskedAvgPool1d, MedianPool1d +from ..encoder.hubert.model import HubertSoft + +CREPE_RESAMPLE_KERNEL = {} + + +class F0_Extractor: + def __init__(self, f0_extractor, sample_rate=44100, hop_size=512, f0_min=65, f0_max=800): + self.f0_extractor = f0_extractor + self.sample_rate = sample_rate + self.hop_size = hop_size + self.f0_min = f0_min + self.f0_max = f0_max + if f0_extractor == "crepe": + key_str = str(sample_rate) + if key_str not in CREPE_RESAMPLE_KERNEL: + CREPE_RESAMPLE_KERNEL[key_str] = Resample(sample_rate, 16000, lowpass_filter_width=128) + self.resample_kernel = CREPE_RESAMPLE_KERNEL[key_str] + + def extract(self, audio, uv_interp=False, device=None, silence_front=0): # audio: 1d numpy array + # extractor start time + n_frames = int(len(audio) // self.hop_size) + 1 + + start_frame = int(silence_front * self.sample_rate / self.hop_size) + real_silence_front = start_frame * self.hop_size / self.sample_rate + audio = audio[int(np.round(real_silence_front * self.sample_rate)) :] + + # extract f0 using parselmouth + if self.f0_extractor == "parselmouth": + f0 = parselmouth.Sound(audio, self.sample_rate).to_pitch_ac(time_step=self.hop_size / self.sample_rate, voicing_threshold=0.6, pitch_floor=self.f0_min, pitch_ceiling=self.f0_max).selected_array["frequency"] + pad_size = start_frame + (int(len(audio) // self.hop_size) - len(f0) + 1) // 2 + f0 = np.pad(f0, (pad_size, n_frames - len(f0) - pad_size)) + + # extract f0 using dio + elif self.f0_extractor == "dio": + _f0, t = pw.dio(audio.astype("double"), self.sample_rate, f0_floor=self.f0_min, f0_ceil=self.f0_max, channels_in_octave=2, frame_period=(1000 * self.hop_size / self.sample_rate)) + f0 = pw.stonemask(audio.astype("double"), _f0, t, self.sample_rate) + f0 = np.pad(f0.astype("float"), (start_frame, n_frames - len(f0) - start_frame)) + + # extract f0 using harvest + elif self.f0_extractor == "harvest": + f0, _ = pw.harvest(audio.astype("double"), self.sample_rate, f0_floor=self.f0_min, f0_ceil=self.f0_max, frame_period=(1000 * self.hop_size / self.sample_rate)) + f0 = np.pad(f0.astype("float"), (start_frame, n_frames - len(f0) - start_frame)) + + # extract f0 using crepe + elif self.f0_extractor == "crepe": + if device is None: + device = "cuda" if torch.cuda.is_available() else "cpu" + resample_kernel = self.resample_kernel.to(device) + wav16k_torch = resample_kernel(torch.FloatTensor(audio).unsqueeze(0).to(device)) + + f0, pd = torchcrepe.predict(wav16k_torch, 16000, 80, self.f0_min, self.f0_max, pad=True, model="full", batch_size=512, device=device, return_periodicity=True) + pd = MedianPool1d(pd, 4) + f0 = torchcrepe.threshold.At(0.05)(f0, pd) + f0 = MaskedAvgPool1d(f0, 4) + + f0 = f0.squeeze(0).cpu().numpy() + f0 = np.array([f0[int(min(int(np.round(n * self.hop_size / self.sample_rate / 0.005)), len(f0) - 1))] for n in range(n_frames - start_frame)]) + f0 = np.pad(f0, (start_frame, 0)) + + else: + raise ValueError(f" [x] Unknown f0 extractor: {f0_extractor}") + + # interpolate the unvoiced f0 + if uv_interp: + uv = f0 == 0 + if len(f0[~uv]) > 0: + f0[uv] = np.interp(np.where(uv)[0], np.where(~uv)[0], f0[~uv]) + f0[f0 < self.f0_min] = self.f0_min + return f0 + + +class Volume_Extractor: + def __init__(self, hop_size=512): + self.hop_size = hop_size + + def extract(self, audio): # audio: 1d numpy array + n_frames = int(len(audio) // self.hop_size) + 1 + audio2 = audio**2 + audio2 = np.pad(audio2, (int(self.hop_size // 2), int((self.hop_size + 1) // 2)), mode="reflect") + volume = np.array([np.mean(audio2[int(n * self.hop_size) : int((n + 1) * self.hop_size)]) for n in range(n_frames)]) + volume = np.sqrt(volume) + return volume + + +class Units_Encoder: + def __init__(self, encoder, encoder_ckpt, encoder_sample_rate=16000, encoder_hop_size=320, device=None, cnhubertsoft_gate=10): + if device is None: + device = "cuda" if torch.cuda.is_available() else "cpu" + self.device = device + + is_loaded_encoder = False + if encoder == "hubertsoft": + self.model = Audio2HubertSoft(encoder_ckpt).to(device) + is_loaded_encoder = True + if encoder == "hubertbase": + self.model = Audio2HubertBase(encoder_ckpt, device=device) + is_loaded_encoder = True + if encoder == "hubertbase768": + self.model = Audio2HubertBase768(encoder_ckpt, device=device) + is_loaded_encoder = True + if encoder == "hubertbase768l12": + self.model = Audio2HubertBase768L12(encoder_ckpt, device=device) + is_loaded_encoder = True + if encoder == "hubertlarge1024l24": + self.model = Audio2HubertLarge1024L24(encoder_ckpt, device=device) + is_loaded_encoder = True + if encoder == "contentvec": + self.model = Audio2ContentVec(encoder_ckpt, device=device) + is_loaded_encoder = True + if encoder == "contentvec768": + self.model = Audio2ContentVec768(encoder_ckpt, device=device) + is_loaded_encoder = True + if encoder == "contentvec768l12": + self.model = Audio2ContentVec768L12(encoder_ckpt, device=device) + is_loaded_encoder = True + if encoder == "cnhubertsoftfish": + self.model = CNHubertSoftFish(encoder_ckpt, device=device, gate_size=cnhubertsoft_gate) + is_loaded_encoder = True + if not is_loaded_encoder: + raise ValueError(f" [x] Unknown units encoder: {encoder}") + + self.resample_kernel = {} + self.encoder_sample_rate = encoder_sample_rate + self.encoder_hop_size = encoder_hop_size + + def encode(self, audio, sample_rate, hop_size): # B, T + # resample + if sample_rate == self.encoder_sample_rate: + audio_res = audio + else: + key_str = str(sample_rate) + if key_str not in self.resample_kernel: + self.resample_kernel[key_str] = Resample(sample_rate, self.encoder_sample_rate, lowpass_filter_width=128).to(self.device) + audio_res = self.resample_kernel[key_str](audio) + + # encode + if audio_res.size(-1) < 400: + audio_res = torch.nn.functional.pad(audio, (0, 400 - audio_res.size(-1))) + units = self.model(audio_res) + + # alignment + n_frames = audio.size(-1) // hop_size + 1 + ratio = (hop_size / sample_rate) / (self.encoder_hop_size / self.encoder_sample_rate) + index = torch.clamp(torch.round(ratio * torch.arange(n_frames).to(self.device)).long(), max=units.size(1) - 1) + units_aligned = torch.gather(units, 1, index.unsqueeze(0).unsqueeze(-1).repeat([1, 1, units.size(-1)])) + return units_aligned + + +class Audio2HubertSoft(torch.nn.Module): + def __init__(self, path, h_sample_rate=16000, h_hop_size=320): + super().__init__() + print(" [Encoder Model] HuBERT Soft") + self.hubert = HubertSoft() + print(" [Loading] " + path) + checkpoint = torch.load(path) + consume_prefix_in_state_dict_if_present(checkpoint, "module.") + self.hubert.load_state_dict(checkpoint) + self.hubert.eval() + + def forward(self, audio): # B, T + with torch.inference_mode(): + units = self.hubert.units(audio.unsqueeze(1)) + return units + + +class Audio2ContentVec: + def __init__(self, path, h_sample_rate=16000, h_hop_size=320, device="cpu"): + self.device = device + print(" [Encoder Model] Content Vec") + print(" [Loading] " + path) + self.models, self.saved_cfg, self.task = checkpoint_utils.load_model_ensemble_and_task( + [path], + suffix="", + ) + self.hubert = self.models[0] + self.hubert = self.hubert.to(self.device) + self.hubert.eval() + + def __call__(self, audio): # B, T + # wav_tensor = torch.from_numpy(audio).to(self.device) + wav_tensor = audio + feats = wav_tensor.view(1, -1) + padding_mask = torch.BoolTensor(feats.shape).fill_(False) + inputs = { + "source": feats.to(wav_tensor.device), + "padding_mask": padding_mask.to(wav_tensor.device), + "output_layer": 9, # layer 9 + } + with torch.no_grad(): + logits = self.hubert.extract_features(**inputs) + feats = self.hubert.final_proj(logits[0]) + units = feats # .transpose(2, 1) + return units + + +class Audio2ContentVec768: + def __init__(self, path, h_sample_rate=16000, h_hop_size=320, device="cpu"): + self.device = device + print(" [Encoder Model] Content Vec") + print(" [Loading] " + path) + self.models, self.saved_cfg, self.task = checkpoint_utils.load_model_ensemble_and_task( + [path], + suffix="", + ) + self.hubert = self.models[0] + self.hubert = self.hubert.to(self.device) + self.hubert.eval() + + def __call__(self, audio): # B, T + # wav_tensor = torch.from_numpy(audio).to(self.device) + wav_tensor = audio + feats = wav_tensor.view(1, -1) + padding_mask = torch.BoolTensor(feats.shape).fill_(False) + inputs = { + "source": feats.to(wav_tensor.device), + "padding_mask": padding_mask.to(wav_tensor.device), + "output_layer": 9, # layer 9 + } + with torch.no_grad(): + logits = self.hubert.extract_features(**inputs) + feats = logits[0] + units = feats # .transpose(2, 1) + return units + + +class Audio2ContentVec768L12: + def __init__(self, path, h_sample_rate=16000, h_hop_size=320, device="cpu"): + self.device = device + print(" [Encoder Model] Content Vec") + print(" [Loading] " + path) + self.models, self.saved_cfg, self.task = checkpoint_utils.load_model_ensemble_and_task( + [path], + suffix="", + ) + self.hubert = self.models[0] + self.hubert = self.hubert.to(self.device) + self.hubert.eval() + + def __call__(self, audio): # B, T + # wav_tensor = torch.from_numpy(audio).to(self.device) + wav_tensor = audio + feats = wav_tensor.view(1, -1) + padding_mask = torch.BoolTensor(feats.shape).fill_(False) + inputs = { + "source": feats.to(wav_tensor.device), + "padding_mask": padding_mask.to(wav_tensor.device), + "output_layer": 12, # layer 12 + } + with torch.no_grad(): + logits = self.hubert.extract_features(**inputs) + feats = logits[0] + units = feats # .transpose(2, 1) + return units + + +class CNHubertSoftFish(torch.nn.Module): + def __init__(self, path, h_sample_rate=16000, h_hop_size=320, device="cpu", gate_size=10): + super().__init__() + self.device = device + self.gate_size = gate_size + + self.feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained("./pretrain/TencentGameMate/chinese-hubert-base") + self.model = HubertModel.from_pretrained("./pretrain/TencentGameMate/chinese-hubert-base") + self.proj = torch.nn.Sequential(torch.nn.Dropout(0.1), torch.nn.Linear(768, 256)) + # self.label_embedding = nn.Embedding(128, 256) + + state_dict = torch.load(path, map_location=device) + self.load_state_dict(state_dict) + + @torch.no_grad() + def forward(self, audio): + input_values = self.feature_extractor(audio, sampling_rate=16000, return_tensors="pt").input_values + input_values = input_values.to(self.model.device) + + return self._forward(input_values[0]) + + @torch.no_grad() + def _forward(self, input_values): + features = self.model(input_values) + features = self.proj(features.last_hidden_state) + + # Top-k gating + topk, indices = torch.topk(features, self.gate_size, dim=2) + features = torch.zeros_like(features).scatter(2, indices, topk) + features = features / features.sum(2, keepdim=True) + + return features.to(self.device) # .transpose(1, 2) + + +class Audio2HubertBase: + def __init__(self, path, h_sample_rate=16000, h_hop_size=320, device="cpu"): + self.device = device + print(" [Encoder Model] HuBERT Base") + print(" [Loading] " + path) + self.models, self.saved_cfg, self.task = checkpoint_utils.load_model_ensemble_and_task( + [path], + suffix="", + ) + self.hubert = self.models[0] + self.hubert = self.hubert.to(self.device) + self.hubert = self.hubert.float() + self.hubert.eval() + + def __call__(self, audio): # B, T + with torch.no_grad(): + padding_mask = torch.BoolTensor(audio.shape).fill_(False) + inputs = { + "source": audio.to(self.device), + "padding_mask": padding_mask.to(self.device), + "output_layer": 9, # layer 9 + } + logits = self.hubert.extract_features(**inputs) + units = self.hubert.final_proj(logits[0]) + return units + + +class Audio2HubertBase768: + def __init__(self, path, h_sample_rate=16000, h_hop_size=320, device="cpu"): + self.device = device + print(" [Encoder Model] HuBERT Base") + print(" [Loading] " + path) + self.models, self.saved_cfg, self.task = checkpoint_utils.load_model_ensemble_and_task( + [path], + suffix="", + ) + self.hubert = self.models[0] + self.hubert = self.hubert.to(self.device) + self.hubert = self.hubert.float() + self.hubert.eval() + + def __call__(self, audio): # B, T + with torch.no_grad(): + padding_mask = torch.BoolTensor(audio.shape).fill_(False) + inputs = { + "source": audio.to(self.device), + "padding_mask": padding_mask.to(self.device), + "output_layer": 9, # layer 9 + } + logits = self.hubert.extract_features(**inputs) + units = logits[0] + return units + + +class Audio2HubertBase768L12: + def __init__(self, path, h_sample_rate=16000, h_hop_size=320, device="cpu"): + self.device = device + print(" [Encoder Model] HuBERT Base") + print(" [Loading] " + path) + self.models, self.saved_cfg, self.task = checkpoint_utils.load_model_ensemble_and_task( + [path], + suffix="", + ) + self.hubert = self.models[0] + self.hubert = self.hubert.to(self.device) + self.hubert = self.hubert.float() + self.hubert.eval() + + def __call__(self, audio): # B, T + with torch.no_grad(): + padding_mask = torch.BoolTensor(audio.shape).fill_(False) + inputs = { + "source": audio.to(self.device), + "padding_mask": padding_mask.to(self.device), + "output_layer": 12, # layer 12 + } + logits = self.hubert.extract_features(**inputs) + units = logits[0] + return units + + +class Audio2HubertLarge1024L24: + def __init__(self, path, h_sample_rate=16000, h_hop_size=320, device="cpu"): + self.device = device + print(" [Encoder Model] HuBERT Base") + print(" [Loading] " + path) + self.models, self.saved_cfg, self.task = checkpoint_utils.load_model_ensemble_and_task( + [path], + suffix="", + ) + self.hubert = self.models[0] + self.hubert = self.hubert.to(self.device) + self.hubert = self.hubert.float() + self.hubert.eval() + + def __call__(self, audio): # B, T + with torch.no_grad(): + padding_mask = torch.BoolTensor(audio.shape).fill_(False) + inputs = { + "source": audio.to(self.device), + "padding_mask": padding_mask.to(self.device), + "output_layer": 24, # layer 24 + } + logits = self.hubert.extract_features(**inputs) + units = logits[0] + return units + + +class DotDict(dict): + def __getattr__(*args): + val = dict.get(*args) + return DotDict(val) if type(val) is dict else val + + __setattr__ = dict.__setitem__ + __delattr__ = dict.__delitem__ + + +def load_model(model_path, device="cpu"): + config_file = os.path.join(os.path.split(model_path)[0], "config.yaml") + with open(config_file, "r") as config: + args = yaml.safe_load(config) + args = DotDict(args) + + # load model + model = None + + if args.model.type == "Sins": + model = Sins(sampling_rate=args.data.sampling_rate, block_size=args.data.block_size, n_harmonics=args.model.n_harmonics, n_mag_allpass=args.model.n_mag_allpass, n_mag_noise=args.model.n_mag_noise, n_unit=args.data.encoder_out_channels, n_spk=args.model.n_spk) + + elif args.model.type == "CombSub": + model = CombSub(sampling_rate=args.data.sampling_rate, block_size=args.data.block_size, n_mag_allpass=args.model.n_mag_allpass, n_mag_harmonic=args.model.n_mag_harmonic, n_mag_noise=args.model.n_mag_noise, n_unit=args.data.encoder_out_channels, n_spk=args.model.n_spk) + + elif args.model.type == "CombSubFast": + model = CombSubFast(sampling_rate=args.data.sampling_rate, block_size=args.data.block_size, n_unit=args.data.encoder_out_channels, n_spk=args.model.n_spk) + + else: + raise ValueError(f" [x] Unknown Model: {args.model.type}") + + print(" [Loading] " + model_path) + ckpt = torch.load(model_path, map_location=torch.device(device)) + model.to(device) + model.load_state_dict(ckpt["model"]) + model.eval() + return model, args + + +class Sins(torch.nn.Module): + def __init__(self, sampling_rate, block_size, n_harmonics, n_mag_allpass, n_mag_noise, n_unit=256, n_spk=1): + super().__init__() + + print(" [DDSP Model] Sinusoids Additive Synthesiser") + + # params + self.register_buffer("sampling_rate", torch.tensor(sampling_rate)) + self.register_buffer("block_size", torch.tensor(block_size)) + # Unit2Control + split_map = { + "amplitudes": n_harmonics, + "group_delay": n_mag_allpass, + "noise_magnitude": n_mag_noise, + } + self.unit2ctrl = Unit2Control(n_unit, n_spk, split_map) + + def forward(self, units_frames, f0_frames, volume_frames, spk_id=None, spk_mix_dict=None, initial_phase=None, infer=True, max_upsample_dim=32): + """ + units_frames: B x n_frames x n_unit + f0_frames: B x n_frames x 1 + volume_frames: B x n_frames x 1 + spk_id: B x 1 + """ + # exciter phase + f0 = upsample(f0_frames, self.block_size) + if infer: + x = torch.cumsum(f0.double() / self.sampling_rate, axis=1) + else: + x = torch.cumsum(f0 / self.sampling_rate, axis=1) + if initial_phase is not None: + x += initial_phase.to(x) / 2 / np.pi + x = x - torch.round(x) + x = x.to(f0) + + phase = 2 * np.pi * x + phase_frames = phase[:, :: self.block_size, :] + + # parameter prediction + ctrls = self.unit2ctrl(units_frames, f0_frames, phase_frames, volume_frames, spk_id=spk_id, spk_mix_dict=spk_mix_dict) + + amplitudes_frames = torch.exp(ctrls["amplitudes"]) / 128 + group_delay = np.pi * torch.tanh(ctrls["group_delay"]) + noise_param = torch.exp(ctrls["noise_magnitude"]) / 128 + + # sinusoids exciter signal + amplitudes_frames = remove_above_fmax(amplitudes_frames, f0_frames, self.sampling_rate / 2, level_start=1) + n_harmonic = amplitudes_frames.shape[-1] + level_harmonic = torch.arange(1, n_harmonic + 1).to(phase) + sinusoids = 0.0 + for n in range((n_harmonic - 1) // max_upsample_dim + 1): + start = n * max_upsample_dim + end = (n + 1) * max_upsample_dim + phases = phase * level_harmonic[start:end] + amplitudes = upsample(amplitudes_frames[:, :, start:end], self.block_size) + sinusoids += (torch.sin(phases) * amplitudes).sum(-1) + + # harmonic part filter (apply group-delay) + harmonic = frequency_filter(sinusoids, torch.exp(1.0j * torch.cumsum(group_delay, axis=-1)), hann_window=False) + + # noise part filter + noise = torch.rand_like(harmonic) * 2 - 1 + noise = frequency_filter(noise, torch.complex(noise_param, torch.zeros_like(noise_param)), hann_window=True) + + signal = harmonic + noise + + return signal, phase, (harmonic, noise) # , (noise_param, noise_param) + + +class CombSubFast(torch.nn.Module): + def __init__(self, sampling_rate, block_size, n_unit=256, n_spk=1): + super().__init__() + + print(" [DDSP Model] Combtooth Subtractive Synthesiser") + # params + self.register_buffer("sampling_rate", torch.tensor(sampling_rate)) + self.register_buffer("block_size", torch.tensor(block_size)) + self.register_buffer("window", torch.sqrt(torch.hann_window(2 * block_size))) + # Unit2Control + split_map = {"harmonic_magnitude": block_size + 1, "harmonic_phase": block_size + 1, "noise_magnitude": block_size + 1} + self.unit2ctrl = Unit2Control(n_unit, n_spk, split_map) + + def forward(self, units_frames, f0_frames, volume_frames, spk_id=None, spk_mix_dict=None, initial_phase=None, infer=True, **kwargs): + """ + units_frames: B x n_frames x n_unit + f0_frames: B x n_frames x 1 + volume_frames: B x n_frames x 1 + spk_id: B x 1 + """ + # exciter phase + f0 = upsample(f0_frames, self.block_size) + if infer: + x = torch.cumsum(f0.double() / self.sampling_rate, axis=1) + else: + x = torch.cumsum(f0 / self.sampling_rate, axis=1) + if initial_phase is not None: + x += initial_phase.to(x) / 2 / np.pi + x = x - torch.round(x) + x = x.to(f0) + + phase_frames = 2 * np.pi * x[:, :: self.block_size, :] + + # parameter prediction + ctrls = self.unit2ctrl(units_frames, f0_frames, phase_frames, volume_frames, spk_id=spk_id, spk_mix_dict=spk_mix_dict) + + src_filter = torch.exp(ctrls["harmonic_magnitude"] + 1.0j * np.pi * ctrls["harmonic_phase"]) + src_filter = torch.cat((src_filter, src_filter[:, -1:, :]), 1) + noise_filter = torch.exp(ctrls["noise_magnitude"]) / 128 + noise_filter = torch.cat((noise_filter, noise_filter[:, -1:, :]), 1) + + # combtooth exciter signal + combtooth = torch.sinc(self.sampling_rate * x / (f0 + 1e-3)) + combtooth = combtooth.squeeze(-1) + combtooth_frames = F.pad(combtooth, (self.block_size, self.block_size)).unfold(1, 2 * self.block_size, self.block_size) + combtooth_frames = combtooth_frames * self.window + combtooth_fft = torch.fft.rfft(combtooth_frames, 2 * self.block_size) + + # noise exciter signal + noise = torch.rand_like(combtooth) * 2 - 1 + noise_frames = F.pad(noise, (self.block_size, self.block_size)).unfold(1, 2 * self.block_size, self.block_size) + noise_frames = noise_frames * self.window + noise_fft = torch.fft.rfft(noise_frames, 2 * self.block_size) + + # apply the filters + signal_fft = combtooth_fft * src_filter + noise_fft * noise_filter + + # take the ifft to resynthesize audio. + signal_frames_out = torch.fft.irfft(signal_fft, 2 * self.block_size) * self.window + + # overlap add + fold = torch.nn.Fold(output_size=(1, (signal_frames_out.size(1) + 1) * self.block_size), kernel_size=(1, 2 * self.block_size), stride=(1, self.block_size)) + signal = fold(signal_frames_out.transpose(1, 2))[:, 0, 0, self.block_size : -self.block_size] + + return signal, phase_frames, (signal, signal) + + +class CombSub(torch.nn.Module): + def __init__(self, sampling_rate, block_size, n_mag_allpass, n_mag_harmonic, n_mag_noise, n_unit=256, n_spk=1): + super().__init__() + + print(" [DDSP Model] Combtooth Subtractive Synthesiser (Old Version)") + # params + self.register_buffer("sampling_rate", torch.tensor(sampling_rate)) + self.register_buffer("block_size", torch.tensor(block_size)) + # Unit2Control + split_map = {"group_delay": n_mag_allpass, "harmonic_magnitude": n_mag_harmonic, "noise_magnitude": n_mag_noise} + self.unit2ctrl = Unit2Control(n_unit, n_spk, split_map) + + def forward(self, units_frames, f0_frames, volume_frames, spk_id=None, spk_mix_dict=None, initial_phase=None, infer=True, **kwargs): + """ + units_frames: B x n_frames x n_unit + f0_frames: B x n_frames x 1 + volume_frames: B x n_frames x 1 + spk_id: B x 1 + """ + # exciter phase + f0 = upsample(f0_frames, self.block_size) + if infer: + x = torch.cumsum(f0.double() / self.sampling_rate, axis=1) + else: + x = torch.cumsum(f0 / self.sampling_rate, axis=1) + if initial_phase is not None: + x += initial_phase.to(x) / 2 / np.pi + x = x - torch.round(x) + x = x.to(f0) + + phase_frames = 2 * np.pi * x[:, :: self.block_size, :] + + # parameter prediction + ctrls = self.unit2ctrl(units_frames, f0_frames, phase_frames, volume_frames, spk_id=spk_id, spk_mix_dict=spk_mix_dict) + + group_delay = np.pi * torch.tanh(ctrls["group_delay"]) + src_param = torch.exp(ctrls["harmonic_magnitude"]) + noise_param = torch.exp(ctrls["noise_magnitude"]) / 128 + + # combtooth exciter signal + combtooth = torch.sinc(self.sampling_rate * x / (f0 + 1e-3)) + combtooth = combtooth.squeeze(-1) + + # harmonic part filter (using dynamic-windowed LTV-FIR, with group-delay prediction) + harmonic = frequency_filter(combtooth, torch.exp(1.0j * torch.cumsum(group_delay, axis=-1)), hann_window=False) + harmonic = frequency_filter(harmonic, torch.complex(src_param, torch.zeros_like(src_param)), hann_window=True, half_width_frames=1.5 * self.sampling_rate / (f0_frames + 1e-3)) + + # noise part filter (using constant-windowed LTV-FIR, without group-delay) + noise = torch.rand_like(harmonic) * 2 - 1 + noise = frequency_filter(noise, torch.complex(noise_param, torch.zeros_like(noise_param)), hann_window=True) + + signal = harmonic + noise + + return signal, phase_frames, (harmonic, noise) diff --git a/server/voice_changer/DDSP_SVC/models/diffusion/data_loaders.py b/server/voice_changer/DDSP_SVC/models/diffusion/data_loaders.py new file mode 100644 index 00000000..3919f317 --- /dev/null +++ b/server/voice_changer/DDSP_SVC/models/diffusion/data_loaders.py @@ -0,0 +1,215 @@ +import os +import random +import re +import numpy as np +import librosa +import torch +from tqdm import tqdm +from torch.utils.data import Dataset + + +def traverse_dir(root_dir, extensions, amount=None, str_include=None, str_exclude=None, is_pure=False, is_sort=False, is_ext=True): + file_list = [] + cnt = 0 + for root, _, files in os.walk(root_dir): + for file in files: + if any([file.endswith(f".{ext}") for ext in extensions]): + # path + mix_path = os.path.join(root, file) + pure_path = mix_path[len(root_dir) + 1 :] if is_pure else mix_path + + # amount + if (amount is not None) and (cnt == amount): + if is_sort: + file_list.sort() + return file_list + + # check string + if (str_include is not None) and (str_include not in pure_path): + continue + if (str_exclude is not None) and (str_exclude in pure_path): + continue + + if not is_ext: + ext = pure_path.split(".")[-1] + pure_path = pure_path[: -(len(ext) + 1)] + file_list.append(pure_path) + cnt += 1 + if is_sort: + file_list.sort() + return file_list + + +def get_data_loaders(args, whole_audio=False): + data_train = AudioDataset(args.data.train_path, waveform_sec=args.data.duration, hop_size=args.data.block_size, sample_rate=args.data.sampling_rate, load_all_data=args.train.cache_all_data, whole_audio=whole_audio, extensions=args.data.extensions, n_spk=args.model.n_spk, device=args.train.cache_device, fp16=args.train.cache_fp16, use_aug=True) + loader_train = torch.utils.data.DataLoader(data_train, batch_size=args.train.batch_size if not whole_audio else 1, shuffle=True, num_workers=args.train.num_workers if args.train.cache_device == "cpu" else 0, persistent_workers=(args.train.num_workers > 0) if args.train.cache_device == "cpu" else False, pin_memory=True if args.train.cache_device == "cpu" else False) + data_valid = AudioDataset(args.data.valid_path, waveform_sec=args.data.duration, hop_size=args.data.block_size, sample_rate=args.data.sampling_rate, load_all_data=args.train.cache_all_data, whole_audio=True, extensions=args.data.extensions, n_spk=args.model.n_spk) + loader_valid = torch.utils.data.DataLoader(data_valid, batch_size=1, shuffle=False, num_workers=0, pin_memory=True) + return loader_train, loader_valid + + +class AudioDataset(Dataset): + def __init__( + self, + path_root, + waveform_sec, + hop_size, + sample_rate, + load_all_data=True, + whole_audio=False, + extensions=["wav"], + n_spk=1, + device="cpu", + fp16=False, + use_aug=False, + ): + super().__init__() + + self.waveform_sec = waveform_sec + self.sample_rate = sample_rate + self.hop_size = hop_size + self.path_root = path_root + self.paths = traverse_dir(os.path.join(path_root, "audio"), extensions=extensions, is_pure=True, is_sort=True, is_ext=True) + self.whole_audio = whole_audio + self.use_aug = use_aug + self.data_buffer = {} + self.pitch_aug_dict = np.load(os.path.join(self.path_root, "pitch_aug_dict.npy"), allow_pickle=True).item() + if load_all_data: + print("Load all the data from :", path_root) + else: + print("Load the f0, volume data from :", path_root) + for name_ext in tqdm(self.paths, total=len(self.paths)): + name = os.path.splitext(name_ext)[0] # NOQA + path_audio = os.path.join(self.path_root, "audio", name_ext) + duration = librosa.get_duration(filename=path_audio, sr=self.sample_rate) + + path_f0 = os.path.join(self.path_root, "f0", name_ext) + ".npy" + f0 = np.load(path_f0) + f0 = torch.from_numpy(f0).float().unsqueeze(-1).to(device) + + path_volume = os.path.join(self.path_root, "volume", name_ext) + ".npy" + volume = np.load(path_volume) + volume = torch.from_numpy(volume).float().unsqueeze(-1).to(device) + + path_augvol = os.path.join(self.path_root, "aug_vol", name_ext) + ".npy" + aug_vol = np.load(path_augvol) + aug_vol = torch.from_numpy(aug_vol).float().unsqueeze(-1).to(device) + + if n_spk is not None and n_spk > 1: + dirname_split = re.split(r"_|\-", os.path.dirname(name_ext), 2)[0] + spk_id = int(dirname_split) if str.isdigit(dirname_split) else 0 + if spk_id < 1 or spk_id > n_spk: + raise ValueError(" [x] Muiti-speaker traing error : spk_id must be a positive integer from 1 to n_spk ") + else: + spk_id = 1 + spk_id = torch.LongTensor(np.array([spk_id])).to(device) + + if load_all_data: + """ + audio, sr = librosa.load(path_audio, sr=self.sample_rate) + if len(audio.shape) > 1: + audio = librosa.to_mono(audio) + audio = torch.from_numpy(audio).to(device) + """ + path_mel = os.path.join(self.path_root, "mel", name_ext) + ".npy" + mel = np.load(path_mel) + mel = torch.from_numpy(mel).to(device) + + path_augmel = os.path.join(self.path_root, "aug_mel", name_ext) + ".npy" + aug_mel = np.load(path_augmel) + aug_mel = torch.from_numpy(aug_mel).to(device) + + path_units = os.path.join(self.path_root, "units", name_ext) + ".npy" + units = np.load(path_units) + units = torch.from_numpy(units).to(device) + + if fp16: + mel = mel.half() + aug_mel = aug_mel.half() + units = units.half() + + self.data_buffer[name_ext] = {"duration": duration, "mel": mel, "aug_mel": aug_mel, "units": units, "f0": f0, "volume": volume, "aug_vol": aug_vol, "spk_id": spk_id} + else: + self.data_buffer[name_ext] = {"duration": duration, "f0": f0, "volume": volume, "aug_vol": aug_vol, "spk_id": spk_id} + + def __getitem__(self, file_idx): + name_ext = self.paths[file_idx] + data_buffer = self.data_buffer[name_ext] + # check duration. if too short, then skip + if data_buffer["duration"] < (self.waveform_sec + 0.1): + return self.__getitem__((file_idx + 1) % len(self.paths)) + + # get item + return self.get_data(name_ext, data_buffer) + + def get_data(self, name_ext, data_buffer): + name = os.path.splitext(name_ext)[0] + frame_resolution = self.hop_size / self.sample_rate + duration = data_buffer["duration"] + waveform_sec = duration if self.whole_audio else self.waveform_sec + + # load audio + idx_from = 0 if self.whole_audio else random.uniform(0, duration - waveform_sec - 0.1) + start_frame = int(idx_from / frame_resolution) + units_frame_len = int(waveform_sec / frame_resolution) + aug_flag = random.choice([True, False]) and self.use_aug + """ + audio = data_buffer.get('audio') + if audio is None: + path_audio = os.path.join(self.path_root, 'audio', name) + '.wav' + audio, sr = librosa.load( + path_audio, + sr = self.sample_rate, + offset = start_frame * frame_resolution, + duration = waveform_sec) + if len(audio.shape) > 1: + audio = librosa.to_mono(audio) + # clip audio into N seconds + audio = audio[ : audio.shape[-1] // self.hop_size * self.hop_size] + audio = torch.from_numpy(audio).float() + else: + audio = audio[start_frame * self.hop_size : (start_frame + units_frame_len) * self.hop_size] + """ + # load mel + mel_key = "aug_mel" if aug_flag else "mel" + mel = data_buffer.get(mel_key) + if mel is None: + mel = os.path.join(self.path_root, mel_key, name_ext) + ".npy" + mel = np.load(mel) + mel = mel[start_frame : start_frame + units_frame_len] + mel = torch.from_numpy(mel).float() + else: + mel = mel[start_frame : start_frame + units_frame_len] + + # load units + units = data_buffer.get("units") + if units is None: + units = os.path.join(self.path_root, "units", name_ext) + ".npy" + units = np.load(units) + units = units[start_frame : start_frame + units_frame_len] + units = torch.from_numpy(units).float() + else: + units = units[start_frame : start_frame + units_frame_len] + + # load f0 + f0 = data_buffer.get("f0") + aug_shift = 0 + if aug_flag: + aug_shift = self.pitch_aug_dict[name_ext] + f0_frames = 2 ** (aug_shift / 12) * f0[start_frame : start_frame + units_frame_len] + + # load volume + vol_key = "aug_vol" if aug_flag else "volume" + volume = data_buffer.get(vol_key) + volume_frames = volume[start_frame : start_frame + units_frame_len] + + # load spk_id + spk_id = data_buffer.get("spk_id") + + # load shift + aug_shift = torch.from_numpy(np.array([[aug_shift]])).float() + + return dict(mel=mel, f0=f0_frames, volume=volume_frames, units=units, spk_id=spk_id, aug_shift=aug_shift, name=name, name_ext=name_ext) + + def __len__(self): + return len(self.paths) diff --git a/server/voice_changer/DDSP_SVC/models/diffusion/diffusion.py b/server/voice_changer/DDSP_SVC/models/diffusion/diffusion.py new file mode 100644 index 00000000..0d3b2fe1 --- /dev/null +++ b/server/voice_changer/DDSP_SVC/models/diffusion/diffusion.py @@ -0,0 +1,342 @@ +from collections import deque +from functools import partial +from inspect import isfunction +import torch.nn.functional as F +import numpy as np +import torch +from torch import nn +from tqdm import tqdm + + +def exists(x): + return x is not None + + +def default(val, d): + if exists(val): + return val + return d() if isfunction(d) else d + + +def extract(a, t, x_shape): + b, *_ = t.shape + out = a.gather(-1, t) + return out.reshape(b, *((1,) * (len(x_shape) - 1))) + + +def noise_like(shape, device, repeat=False): + repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat(shape[0], *((1,) * (len(shape) - 1))) # NOQA + noise = lambda: torch.randn(shape, device=device) # NOQA + return repeat_noise() if repeat else noise() + + +def linear_beta_schedule(timesteps, max_beta=0.02): + """ + linear schedule + """ + betas = np.linspace(1e-4, max_beta, timesteps) + return betas + + +def cosine_beta_schedule(timesteps, s=0.008): + """ + cosine schedule + as proposed in https://openreview.net/forum?id=-NEXDKk8gZ + """ + steps = timesteps + 1 + x = np.linspace(0, steps, steps) + alphas_cumprod = np.cos(((x / steps) + s) / (1 + s) * np.pi * 0.5) ** 2 + alphas_cumprod = alphas_cumprod / alphas_cumprod[0] + betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1]) + return np.clip(betas, a_min=0, a_max=0.999) + + +beta_schedule = { + "cosine": cosine_beta_schedule, + "linear": linear_beta_schedule, +} + + +class GaussianDiffusion(nn.Module): + def __init__(self, denoise_fn, out_dims=128, timesteps=1000, k_step=1000, max_beta=0.02, spec_min=-12, spec_max=2): + super().__init__() + self.denoise_fn = denoise_fn + self.out_dims = out_dims + betas = beta_schedule["linear"](timesteps, max_beta=max_beta) + + alphas = 1.0 - betas + alphas_cumprod = np.cumprod(alphas, axis=0) + alphas_cumprod_prev = np.append(1.0, alphas_cumprod[:-1]) + + (timesteps,) = betas.shape + self.num_timesteps = int(timesteps) + self.k_step = k_step + + self.noise_list = deque(maxlen=4) + + to_torch = partial(torch.tensor, dtype=torch.float32) + + self.register_buffer("betas", to_torch(betas)) + self.register_buffer("alphas_cumprod", to_torch(alphas_cumprod)) + self.register_buffer("alphas_cumprod_prev", to_torch(alphas_cumprod_prev)) + + # calculations for diffusion q(x_t | x_{t-1}) and others + self.register_buffer("sqrt_alphas_cumprod", to_torch(np.sqrt(alphas_cumprod))) + self.register_buffer("sqrt_one_minus_alphas_cumprod", to_torch(np.sqrt(1.0 - alphas_cumprod))) + self.register_buffer("log_one_minus_alphas_cumprod", to_torch(np.log(1.0 - alphas_cumprod))) + self.register_buffer("sqrt_recip_alphas_cumprod", to_torch(np.sqrt(1.0 / alphas_cumprod))) + self.register_buffer("sqrt_recipm1_alphas_cumprod", to_torch(np.sqrt(1.0 / alphas_cumprod - 1))) + + # calculations for posterior q(x_{t-1} | x_t, x_0) + posterior_variance = betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod) + # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t) + self.register_buffer("posterior_variance", to_torch(posterior_variance)) + # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain + self.register_buffer("posterior_log_variance_clipped", to_torch(np.log(np.maximum(posterior_variance, 1e-20)))) + self.register_buffer("posterior_mean_coef1", to_torch(betas * np.sqrt(alphas_cumprod_prev) / (1.0 - alphas_cumprod))) + self.register_buffer("posterior_mean_coef2", to_torch((1.0 - alphas_cumprod_prev) * np.sqrt(alphas) / (1.0 - alphas_cumprod))) + + self.register_buffer("spec_min", torch.FloatTensor([spec_min])[None, None, :out_dims]) + self.register_buffer("spec_max", torch.FloatTensor([spec_max])[None, None, :out_dims]) + + def q_mean_variance(self, x_start, t): + mean = extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start + variance = extract(1.0 - self.alphas_cumprod, t, x_start.shape) + log_variance = extract(self.log_one_minus_alphas_cumprod, t, x_start.shape) + return mean, variance, log_variance + + def predict_start_from_noise(self, x_t, t, noise): + return extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise + + def q_posterior(self, x_start, x_t, t): + posterior_mean = extract(self.posterior_mean_coef1, t, x_t.shape) * x_start + extract(self.posterior_mean_coef2, t, x_t.shape) * x_t + posterior_variance = extract(self.posterior_variance, t, x_t.shape) + posterior_log_variance_clipped = extract(self.posterior_log_variance_clipped, t, x_t.shape) + return posterior_mean, posterior_variance, posterior_log_variance_clipped + + def p_mean_variance(self, x, t, cond): + noise_pred = self.denoise_fn(x, t, cond=cond) + x_recon = self.predict_start_from_noise(x, t=t, noise=noise_pred) + + x_recon.clamp_(-1.0, 1.0) + + model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t) + return model_mean, posterior_variance, posterior_log_variance + + @torch.no_grad() + def p_sample(self, x, t, cond, clip_denoised=True, repeat_noise=False): + b, *_, device = *x.shape, x.device # type:ignore + model_mean, _, model_log_variance = self.p_mean_variance(x=x, t=t, cond=cond) + noise = noise_like(x.shape, device, repeat_noise) + # no noise when t == 0 + nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1))) + return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise + + @torch.no_grad() + def p_sample_ddim(self, x, t, interval, cond): + a_t = extract(self.alphas_cumprod, t, x.shape) + a_prev = extract(self.alphas_cumprod, torch.max(t - interval, torch.zeros_like(t)), x.shape) + + noise_pred = self.denoise_fn(x, t, cond=cond) + x_prev = a_prev.sqrt() * (x / a_t.sqrt() + (((1 - a_prev) / a_prev).sqrt() - ((1 - a_t) / a_t).sqrt()) * noise_pred) + return x_prev + + @torch.no_grad() + def p_sample_plms(self, x, t, interval, cond, clip_denoised=True, repeat_noise=False): + """ + Use the PLMS method from + [Pseudo Numerical Methods for Diffusion Models on Manifolds](https://arxiv.org/abs/2202.09778). + """ + + def get_x_pred(x, noise_t, t): + a_t = extract(self.alphas_cumprod, t, x.shape) + a_prev = extract(self.alphas_cumprod, torch.max(t - interval, torch.zeros_like(t)), x.shape) + a_t_sq, a_prev_sq = a_t.sqrt(), a_prev.sqrt() + + x_delta = (a_prev - a_t) * ((1 / (a_t_sq * (a_t_sq + a_prev_sq))) * x - 1 / (a_t_sq * (((1 - a_prev) * a_t).sqrt() + ((1 - a_t) * a_prev).sqrt())) * noise_t) + x_pred = x + x_delta + + return x_pred + + noise_list = self.noise_list + noise_pred = self.denoise_fn(x, t, cond=cond) + + if len(noise_list) == 0: + x_pred = get_x_pred(x, noise_pred, t) + noise_pred_prev = self.denoise_fn(x_pred, max(t - interval, 0), cond=cond) + noise_pred_prime = (noise_pred + noise_pred_prev) / 2 + elif len(noise_list) == 1: + noise_pred_prime = (3 * noise_pred - noise_list[-1]) / 2 + elif len(noise_list) == 2: + noise_pred_prime = (23 * noise_pred - 16 * noise_list[-1] + 5 * noise_list[-2]) / 12 + else: + noise_pred_prime = (55 * noise_pred - 59 * noise_list[-1] + 37 * noise_list[-2] - 9 * noise_list[-3]) / 24 + + x_prev = get_x_pred(x, noise_pred_prime, t) + noise_list.append(noise_pred) + + return x_prev + + def q_sample(self, x_start, t, noise=None): + noise = default(noise, lambda: torch.randn_like(x_start)) + return extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start + extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise + + def p_losses(self, x_start, t, cond, noise=None, loss_type="l2"): + noise = default(noise, lambda: torch.randn_like(x_start)) + + x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise) + x_recon = self.denoise_fn(x_noisy, t, cond) + + if loss_type == "l1": + loss = (noise - x_recon).abs().mean() + elif loss_type == "l2": + loss = F.mse_loss(noise, x_recon) + else: + raise NotImplementedError() + + return loss + + def forward(self, condition, gt_spec=None, infer=True, infer_speedup=10, method="dpm-solver", k_step=None, use_tqdm=True): + """ + conditioning diffusion, use fastspeech2 encoder output as the condition + """ + cond = condition.transpose(1, 2) + b, device = condition.shape[0], condition.device + + if not infer: + spec = self.norm_spec(gt_spec) + if k_step is None: + t_max = self.k_step + else: + t_max = k_step + t = torch.randint(0, t_max, (b,), device=device).long() + norm_spec = spec.transpose(1, 2)[:, None, :, :] # [B, 1, M, T] + return self.p_losses(norm_spec, t, cond=cond) + else: + shape = (cond.shape[0], 1, self.out_dims, cond.shape[2]) + + if gt_spec is None or k_step is None: + t = self.k_step + x = torch.randn(shape, device=device) + else: + t = k_step + norm_spec = self.norm_spec(gt_spec) + norm_spec = norm_spec.transpose(1, 2)[:, None, :, :] + x = self.q_sample(x_start=norm_spec, t=torch.tensor([t - 1], device=device).long()) + + if method is not None and infer_speedup > 1: + if method == "dpm-solver": + from .dpm_solver_pytorch import NoiseScheduleVP, model_wrapper, DPM_Solver + + # 1. Define the noise schedule. + noise_schedule = NoiseScheduleVP(schedule="discrete", betas=self.betas[:t]) + + # 2. Convert your discrete-time `model` to the continuous-time + # noise prediction model. Here is an example for a diffusion model + # `model` with the noise prediction type ("noise") . + def my_wrapper(fn): + def wrapped(x, t, **kwargs): + ret = fn(x, t, **kwargs) + if use_tqdm: + self.bar.update(1) + return ret + + return wrapped + + model_fn = model_wrapper(my_wrapper(self.denoise_fn), noise_schedule, model_type="noise", model_kwargs={"cond": cond}) # or "x_start" or "v" or "score" + + # 3. Define dpm-solver and sample by singlestep DPM-Solver. + # (We recommend singlestep DPM-Solver for unconditional sampling) + # You can adjust the `steps` to balance the computation + # costs and the sample quality. + dpm_solver = DPM_Solver(model_fn, noise_schedule, algorithm_type="dpmsolver++") + + steps = t // infer_speedup + if use_tqdm: + self.bar = tqdm(desc="sample time step", total=steps) + x = dpm_solver.sample( + x, + steps=steps, + order=2, + skip_type="time_uniform", + method="multistep", + ) + if use_tqdm: + self.bar.close() + elif method == "unipc": + from .uni_pc import NoiseScheduleVP, model_wrapper, UniPC + + # 1. Define the noise schedule. + noise_schedule = NoiseScheduleVP(schedule="discrete", betas=self.betas[:t]) + + # 2. Convert your discrete-time `model` to the continuous-time + # noise prediction model. Here is an example for a diffusion model + # `model` with the noise prediction type ("noise") . + def my_wrapper(fn): + def wrapped(x, t, **kwargs): + ret = fn(x, t, **kwargs) + if use_tqdm: + self.bar.update(1) + return ret + + return wrapped + + model_fn = model_wrapper(my_wrapper(self.denoise_fn), noise_schedule, model_type="noise", model_kwargs={"cond": cond}) # or "x_start" or "v" or "score" + + # 3. Define uni_pc and sample by multistep UniPC. + # You can adjust the `steps` to balance the computation + # costs and the sample quality. + uni_pc = UniPC(model_fn, noise_schedule, variant="bh2") + + steps = t // infer_speedup + if use_tqdm: + self.bar = tqdm(desc="sample time step", total=steps) + x = uni_pc.sample( + x, + steps=steps, + order=2, + skip_type="time_uniform", + method="multistep", + ) + if use_tqdm: + self.bar.close() + elif method == "pndm": + self.noise_list = deque(maxlen=4) + if use_tqdm: + for i in tqdm( + reversed(range(0, t, infer_speedup)), + desc="sample time step", + total=t // infer_speedup, + ): + x = self.p_sample_plms(x, torch.full((b,), i, device=device, dtype=torch.long), infer_speedup, cond=cond) + else: + for i in reversed(range(0, t, infer_speedup)): + x = self.p_sample_plms(x, torch.full((b,), i, device=device, dtype=torch.long), infer_speedup, cond=cond) + elif method == "ddim": + if use_tqdm: + for i in tqdm( + reversed(range(0, t, infer_speedup)), + desc="sample time step", + total=t // infer_speedup, + ): + x = self.p_sample_ddim(x, torch.full((b,), i, device=device, dtype=torch.long), infer_speedup, cond=cond) + else: + for i in reversed(range(0, t, infer_speedup)): + x = self.p_sample_ddim(x, torch.full((b,), i, device=device, dtype=torch.long), infer_speedup, cond=cond) + else: + raise NotImplementedError(method) + else: + if use_tqdm: + for i in tqdm(reversed(range(0, t)), desc="sample time step", total=t): + x = self.p_sample(x, torch.full((b,), i, device=device, dtype=torch.long), cond) + else: + for i in reversed(range(0, t)): + x = self.p_sample(x, torch.full((b,), i, device=device, dtype=torch.long), cond) + x = x.squeeze(1).transpose(1, 2) # [B, T, M] + return self.denorm_spec(x) + + def norm_spec(self, x): + return (x - self.spec_min) / (self.spec_max - self.spec_min) * 2 - 1 + + def denorm_spec(self, x): + return (x + 1) / 2 * (self.spec_max - self.spec_min) + self.spec_min diff --git a/server/voice_changer/DDSP_SVC/models/diffusion/diffusion_onnx.py b/server/voice_changer/DDSP_SVC/models/diffusion/diffusion_onnx.py new file mode 100644 index 00000000..042ad35e --- /dev/null +++ b/server/voice_changer/DDSP_SVC/models/diffusion/diffusion_onnx.py @@ -0,0 +1,526 @@ +from collections import deque +from functools import partial +from inspect import isfunction +import torch.nn.functional as F +import numpy as np +from torch.nn import Conv1d +from torch.nn import Mish +import torch +from torch import nn +from tqdm import tqdm +import math + + +def exists(x): + return x is not None + + +def default(val, d): + if exists(val): + return val + return d() if isfunction(d) else d + + +def extract(a, t): + return a[t].reshape((1, 1, 1, 1)) + + +def noise_like(shape, device, repeat=False): + repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat(shape[0], *((1,) * (len(shape) - 1))) # NOQA + noise = lambda: torch.randn(shape, device=device) # NOQA + return repeat_noise() if repeat else noise() + + +def linear_beta_schedule(timesteps, max_beta=0.02): + """ + linear schedule + """ + betas = np.linspace(1e-4, max_beta, timesteps) + return betas + + +def cosine_beta_schedule(timesteps, s=0.008): + """ + cosine schedule + as proposed in https://openreview.net/forum?id=-NEXDKk8gZ + """ + steps = timesteps + 1 + x = np.linspace(0, steps, steps) + alphas_cumprod = np.cos(((x / steps) + s) / (1 + s) * np.pi * 0.5) ** 2 + alphas_cumprod = alphas_cumprod / alphas_cumprod[0] + betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1]) + return np.clip(betas, a_min=0, a_max=0.999) + + +beta_schedule = { + "cosine": cosine_beta_schedule, + "linear": linear_beta_schedule, +} + + +def extract_1(a, t): + return a[t].reshape((1, 1, 1, 1)) + + +def predict_stage0(noise_pred, noise_pred_prev): + return (noise_pred + noise_pred_prev) / 2 + + +def predict_stage1(noise_pred, noise_list): + return (noise_pred * 3 - noise_list[-1]) / 2 + + +def predict_stage2(noise_pred, noise_list): + return (noise_pred * 23 - noise_list[-1] * 16 + noise_list[-2] * 5) / 12 + + +def predict_stage3(noise_pred, noise_list): + return (noise_pred * 55 - noise_list[-1] * 59 + noise_list[-2] * 37 - noise_list[-3] * 9) / 24 + + +class SinusoidalPosEmb(nn.Module): + def __init__(self, dim): + super().__init__() + self.dim = dim + self.half_dim = dim // 2 + self.emb = 9.21034037 / (self.half_dim - 1) + self.emb = torch.exp(torch.arange(self.half_dim) * torch.tensor(-self.emb)).unsqueeze(0) + self.emb = self.emb.cpu() + + def forward(self, x): + emb = self.emb * x + emb = torch.cat((emb.sin(), emb.cos()), dim=-1) + return emb + + +class ResidualBlock(nn.Module): + def __init__(self, encoder_hidden, residual_channels, dilation): + super().__init__() + self.residual_channels = residual_channels + self.dilated_conv = Conv1d(residual_channels, 2 * residual_channels, 3, padding=dilation, dilation=dilation) + self.diffusion_projection = nn.Linear(residual_channels, residual_channels) + self.conditioner_projection = Conv1d(encoder_hidden, 2 * residual_channels, 1) + self.output_projection = Conv1d(residual_channels, 2 * residual_channels, 1) + + def forward(self, x, conditioner, diffusion_step): + diffusion_step = self.diffusion_projection(diffusion_step).unsqueeze(-1) + conditioner = self.conditioner_projection(conditioner) + y = x + diffusion_step + y = self.dilated_conv(y) + conditioner + + gate, filter_1 = torch.split(y, [self.residual_channels, self.residual_channels], dim=1) + + y = torch.sigmoid(gate) * torch.tanh(filter_1) + y = self.output_projection(y) + + residual, skip = torch.split(y, [self.residual_channels, self.residual_channels], dim=1) + + return (x + residual) / 1.41421356, skip + + +class DiffNet(nn.Module): + def __init__(self, in_dims, n_layers, n_chans, n_hidden): + super().__init__() + self.encoder_hidden = n_hidden + self.residual_layers = n_layers + self.residual_channels = n_chans + self.input_projection = Conv1d(in_dims, self.residual_channels, 1) + self.diffusion_embedding = SinusoidalPosEmb(self.residual_channels) + dim = self.residual_channels + self.mlp = nn.Sequential(nn.Linear(dim, dim * 4), Mish(), nn.Linear(dim * 4, dim)) + self.residual_layers = nn.ModuleList([ResidualBlock(self.encoder_hidden, self.residual_channels, 1) for i in range(self.residual_layers)]) + self.skip_projection = Conv1d(self.residual_channels, self.residual_channels, 1) + self.output_projection = Conv1d(self.residual_channels, in_dims, 1) + nn.init.zeros_(self.output_projection.weight) + + def forward(self, spec, diffusion_step, cond): + x = spec.squeeze(0) + x = self.input_projection(x) # x [B, residual_channel, T] + x = F.relu(x) + # skip = torch.randn_like(x) + diffusion_step = diffusion_step.float() + diffusion_step = self.diffusion_embedding(diffusion_step) + diffusion_step = self.mlp(diffusion_step) + + x, skip = self.residual_layers[0](x, cond, diffusion_step) + # noinspection PyTypeChecker + for layer in self.residual_layers[1:]: + x, skip_connection = layer.forward(x, cond, diffusion_step) + skip = skip + skip_connection + x = skip / math.sqrt(len(self.residual_layers)) + x = self.skip_projection(x) + x = F.relu(x) + x = self.output_projection(x) # [B, 80, T] + return x.unsqueeze(1) + + +class AfterDiffusion(nn.Module): + def __init__(self, spec_max, spec_min, v_type="a"): + super().__init__() + self.spec_max = spec_max + self.spec_min = spec_min + self.type = v_type + + def forward(self, x): + x = x.squeeze(1).permute(0, 2, 1) + mel_out = (x + 1) / 2 * (self.spec_max - self.spec_min) + self.spec_min + if self.type == "nsf-hifigan-log10": + mel_out = mel_out * 0.434294 + return mel_out.transpose(2, 1) + + +class Pred(nn.Module): + def __init__(self, alphas_cumprod): + super().__init__() + self.alphas_cumprod = alphas_cumprod + + def forward(self, x_1, noise_t, t_1, t_prev): + a_t = extract(self.alphas_cumprod, t_1).cpu() + a_prev = extract(self.alphas_cumprod, t_prev).cpu() + a_t_sq, a_prev_sq = a_t.sqrt().cpu(), a_prev.sqrt().cpu() + x_delta = (a_prev - a_t) * ((1 / (a_t_sq * (a_t_sq + a_prev_sq))) * x_1 - 1 / (a_t_sq * (((1 - a_prev) * a_t).sqrt() + ((1 - a_t) * a_prev).sqrt())) * noise_t) + x_pred = x_1 + x_delta.cpu() + + return x_pred + + +class GaussianDiffusion(nn.Module): + def __init__(self, out_dims=128, n_layers=20, n_chans=384, n_hidden=256, timesteps=1000, k_step=1000, max_beta=0.02, spec_min=-12, spec_max=2): + super().__init__() + self.denoise_fn = DiffNet(out_dims, n_layers, n_chans, n_hidden) + self.out_dims = out_dims + self.mel_bins = out_dims + self.n_hidden = n_hidden + betas = beta_schedule["linear"](timesteps, max_beta=max_beta) + + alphas = 1.0 - betas + alphas_cumprod = np.cumprod(alphas, axis=0) + alphas_cumprod_prev = np.append(1.0, alphas_cumprod[:-1]) + (timesteps,) = betas.shape + self.num_timesteps = int(timesteps) + self.k_step = k_step + + self.noise_list = deque(maxlen=4) + + to_torch = partial(torch.tensor, dtype=torch.float32) + + self.register_buffer("betas", to_torch(betas)) + self.register_buffer("alphas_cumprod", to_torch(alphas_cumprod)) + self.register_buffer("alphas_cumprod_prev", to_torch(alphas_cumprod_prev)) + + # calculations for diffusion q(x_t | x_{t-1}) and others + self.register_buffer("sqrt_alphas_cumprod", to_torch(np.sqrt(alphas_cumprod))) + self.register_buffer("sqrt_one_minus_alphas_cumprod", to_torch(np.sqrt(1.0 - alphas_cumprod))) + self.register_buffer("log_one_minus_alphas_cumprod", to_torch(np.log(1.0 - alphas_cumprod))) + self.register_buffer("sqrt_recip_alphas_cumprod", to_torch(np.sqrt(1.0 / alphas_cumprod))) + self.register_buffer("sqrt_recipm1_alphas_cumprod", to_torch(np.sqrt(1.0 / alphas_cumprod - 1))) + + # calculations for posterior q(x_{t-1} | x_t, x_0) + posterior_variance = betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod) + # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t) + self.register_buffer("posterior_variance", to_torch(posterior_variance)) + # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain + self.register_buffer("posterior_log_variance_clipped", to_torch(np.log(np.maximum(posterior_variance, 1e-20)))) + self.register_buffer("posterior_mean_coef1", to_torch(betas * np.sqrt(alphas_cumprod_prev) / (1.0 - alphas_cumprod))) + self.register_buffer("posterior_mean_coef2", to_torch((1.0 - alphas_cumprod_prev) * np.sqrt(alphas) / (1.0 - alphas_cumprod))) + + self.register_buffer("spec_min", torch.FloatTensor([spec_min])[None, None, :out_dims]) + self.register_buffer("spec_max", torch.FloatTensor([spec_max])[None, None, :out_dims]) + self.ad = AfterDiffusion(self.spec_max, self.spec_min) + self.xp = Pred(self.alphas_cumprod) + + def q_mean_variance(self, x_start, t): + mean = extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start + variance = extract(1.0 - self.alphas_cumprod, t, x_start.shape) + log_variance = extract(self.log_one_minus_alphas_cumprod, t, x_start.shape) + return mean, variance, log_variance + + def predict_start_from_noise(self, x_t, t, noise): + return extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise + + def q_posterior(self, x_start, x_t, t): + posterior_mean = extract(self.posterior_mean_coef1, t, x_t.shape) * x_start + extract(self.posterior_mean_coef2, t, x_t.shape) * x_t + posterior_variance = extract(self.posterior_variance, t, x_t.shape) + posterior_log_variance_clipped = extract(self.posterior_log_variance_clipped, t, x_t.shape) + return posterior_mean, posterior_variance, posterior_log_variance_clipped + + def p_mean_variance(self, x, t, cond): + noise_pred = self.denoise_fn(x, t, cond=cond) + x_recon = self.predict_start_from_noise(x, t=t, noise=noise_pred) + + x_recon.clamp_(-1.0, 1.0) + + model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t) + return model_mean, posterior_variance, posterior_log_variance + + @torch.no_grad() + def p_sample(self, x, t, cond, clip_denoised=True, repeat_noise=False): + b, *_, device = *x.shape, x.device # type: ignore + model_mean, _, model_log_variance = self.p_mean_variance(x=x, t=t, cond=cond) + noise = noise_like(x.shape, device, repeat_noise) + # no noise when t == 0 + nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1))) + return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise + + @torch.no_grad() + def p_sample_plms(self, x, t, interval, cond, clip_denoised=True, repeat_noise=False): + """ + Use the PLMS method from + [Pseudo Numerical Methods for Diffusion Models on Manifolds](https://arxiv.org/abs/2202.09778). + """ + + def get_x_pred(x, noise_t, t): + a_t = extract(self.alphas_cumprod, t) + a_prev = extract(self.alphas_cumprod, torch.max(t - interval, torch.zeros_like(t))) + a_t_sq, a_prev_sq = a_t.sqrt(), a_prev.sqrt() + + x_delta = (a_prev - a_t) * ((1 / (a_t_sq * (a_t_sq + a_prev_sq))) * x - 1 / (a_t_sq * (((1 - a_prev) * a_t).sqrt() + ((1 - a_t) * a_prev).sqrt())) * noise_t) + x_pred = x + x_delta + + return x_pred + + noise_list = self.noise_list + noise_pred = self.denoise_fn(x, t, cond=cond) + + if len(noise_list) == 0: + x_pred = get_x_pred(x, noise_pred, t) + noise_pred_prev = self.denoise_fn(x_pred, max(t - interval, 0), cond=cond) + noise_pred_prime = (noise_pred + noise_pred_prev) / 2 + elif len(noise_list) == 1: + noise_pred_prime = (3 * noise_pred - noise_list[-1]) / 2 + elif len(noise_list) == 2: + noise_pred_prime = (23 * noise_pred - 16 * noise_list[-1] + 5 * noise_list[-2]) / 12 + else: + noise_pred_prime = (55 * noise_pred - 59 * noise_list[-1] + 37 * noise_list[-2] - 9 * noise_list[-3]) / 24 + + x_prev = get_x_pred(x, noise_pred_prime, t) + noise_list.append(noise_pred) + + return x_prev + + def q_sample(self, x_start, t, noise=None): + noise = default(noise, lambda: torch.randn_like(x_start)) + return extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start + extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise + + def p_losses(self, x_start, t, cond, noise=None, loss_type="l2"): + noise = default(noise, lambda: torch.randn_like(x_start)) + + x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise) + x_recon = self.denoise_fn(x_noisy, t, cond) + + if loss_type == "l1": + loss = (noise - x_recon).abs().mean() + elif loss_type == "l2": + loss = F.mse_loss(noise, x_recon) + else: + raise NotImplementedError() + + return loss + + def org_forward(self, condition, init_noise=None, gt_spec=None, infer=True, infer_speedup=100, method="pndm", k_step=1000, use_tqdm=True): + """ + conditioning diffusion, use fastspeech2 encoder output as the condition + """ + cond = condition + b, device = condition.shape[0], condition.device + if not infer: + spec = self.norm_spec(gt_spec) + t = torch.randint(0, self.k_step, (b,), device=device).long() + norm_spec = spec.transpose(1, 2)[:, None, :, :] # [B, 1, M, T] + return self.p_losses(norm_spec, t, cond=cond) + else: + shape = (cond.shape[0], 1, self.out_dims, cond.shape[2]) + + if gt_spec is None: + t = self.k_step + if init_noise is None: + x = torch.randn(shape, device=device) + else: + x = init_noise + else: + t = k_step + norm_spec = self.norm_spec(gt_spec) + norm_spec = norm_spec.transpose(1, 2)[:, None, :, :] + x = self.q_sample(x_start=norm_spec, t=torch.tensor([t - 1], device=device).long()) + + if method is not None and infer_speedup > 1: + if method == "dpm-solver": + from .dpm_solver_pytorch import NoiseScheduleVP, model_wrapper, DPM_Solver + + # 1. Define the noise schedule. + noise_schedule = NoiseScheduleVP(schedule="discrete", betas=self.betas[:t]) + + # 2. Convert your discrete-time `model` to the continuous-time + # noise prediction model. Here is an example for a diffusion model + # `model` with the noise prediction type ("noise") . + def my_wrapper(fn): + def wrapped(x, t, **kwargs): + ret = fn(x, t, **kwargs) + if use_tqdm: + self.bar.update(1) + return ret + + return wrapped + + model_fn = model_wrapper(my_wrapper(self.denoise_fn), noise_schedule, model_type="noise", model_kwargs={"cond": cond}) # or "x_start" or "v" or "score" + + # 3. Define dpm-solver and sample by singlestep DPM-Solver. + # (We recommend singlestep DPM-Solver for unconditional sampling) + # You can adjust the `steps` to balance the computation + # costs and the sample quality. + dpm_solver = DPM_Solver(model_fn, noise_schedule) + + steps = t // infer_speedup + if use_tqdm: + self.bar = tqdm(desc="sample time step", total=steps) + x = dpm_solver.sample( + x, + steps=steps, + order=3, + skip_type="time_uniform", + method="singlestep", + ) + if use_tqdm: + self.bar.close() + elif method == "pndm": + self.noise_list = deque(maxlen=4) + if use_tqdm: + for i in tqdm( + reversed(range(0, t, infer_speedup)), + desc="sample time step", + total=t // infer_speedup, + ): + x = self.p_sample_plms(x, torch.full((b,), i, device=device, dtype=torch.long), infer_speedup, cond=cond) + else: + for i in reversed(range(0, t, infer_speedup)): + x = self.p_sample_plms(x, torch.full((b,), i, device=device, dtype=torch.long), infer_speedup, cond=cond) + else: + raise NotImplementedError(method) + else: + if use_tqdm: + for i in tqdm(reversed(range(0, t)), desc="sample time step", total=t): + x = self.p_sample(x, torch.full((b,), i, device=device, dtype=torch.long), cond) + else: + for i in reversed(range(0, t)): + x = self.p_sample(x, torch.full((b,), i, device=device, dtype=torch.long), cond) + x = x.squeeze(1).transpose(1, 2) # [B, T, M] + return self.denorm_spec(x).transpose(2, 1) + + def norm_spec(self, x): + return (x - self.spec_min) / (self.spec_max - self.spec_min) * 2 - 1 + + def denorm_spec(self, x): + return (x + 1) / 2 * (self.spec_max - self.spec_min) + self.spec_min + + def get_x_pred(self, x_1, noise_t, t_1, t_prev): + a_t = extract(self.alphas_cumprod, t_1) + a_prev = extract(self.alphas_cumprod, t_prev) + a_t_sq, a_prev_sq = a_t.sqrt(), a_prev.sqrt() + x_delta = (a_prev - a_t) * ((1 / (a_t_sq * (a_t_sq + a_prev_sq))) * x_1 - 1 / (a_t_sq * (((1 - a_prev) * a_t).sqrt() + ((1 - a_t) * a_prev).sqrt())) * noise_t) + x_pred = x_1 + x_delta + return x_pred + + def OnnxExport(self, project_name=None, init_noise=None, hidden_channels=256, export_denoise=True, export_pred=True, export_after=True): + cond = torch.randn([1, self.n_hidden, 10]).cpu() + if init_noise is None: + x = torch.randn((1, 1, self.mel_bins, cond.shape[2]), dtype=torch.float32).cpu() + else: + x = init_noise + pndms = 100 + + org_y_x = self.org_forward(cond, init_noise=x) + + device = cond.device + n_frames = cond.shape[2] + step_range = torch.arange(0, self.k_step, pndms, dtype=torch.long, device=device).flip(0) + plms_noise_stage = torch.tensor(0, dtype=torch.long, device=device) + noise_list = torch.zeros((0, 1, 1, self.mel_bins, n_frames), device=device) + + ot = step_range[0] + ot_1 = torch.full((1,), ot, device=device, dtype=torch.long) + if export_denoise: + torch.onnx.export(self.denoise_fn, (x.cpu(), ot_1.cpu(), cond.cpu()), f"{project_name}_denoise.onnx", input_names=["noise", "time", "condition"], output_names=["noise_pred"], dynamic_axes={"noise": [3], "condition": [2]}, opset_version=16) + + for t in step_range: + t_1 = torch.full((1,), t, device=device, dtype=torch.long) + noise_pred = self.denoise_fn(x, t_1, cond) + t_prev = t_1 - pndms + t_prev = t_prev * (t_prev > 0) + if plms_noise_stage == 0: + if export_pred: + torch.onnx.export(self.xp, (x.cpu(), noise_pred.cpu(), t_1.cpu(), t_prev.cpu()), f"{project_name}_pred.onnx", input_names=["noise", "noise_pred", "time", "time_prev"], output_names=["noise_pred_o"], dynamic_axes={"noise": [3], "noise_pred": [3]}, opset_version=16) + + x_pred = self.get_x_pred(x, noise_pred, t_1, t_prev) + noise_pred_prev = self.denoise_fn(x_pred, t_prev, cond=cond) + noise_pred_prime = predict_stage0(noise_pred, noise_pred_prev) + + elif plms_noise_stage == 1: + noise_pred_prime = predict_stage1(noise_pred, noise_list) + + elif plms_noise_stage == 2: + noise_pred_prime = predict_stage2(noise_pred, noise_list) + + else: + noise_pred_prime = predict_stage3(noise_pred, noise_list) + + noise_pred = noise_pred.unsqueeze(0) + + if plms_noise_stage < 3: + noise_list = torch.cat((noise_list, noise_pred), dim=0) + plms_noise_stage = plms_noise_stage + 1 + + else: + noise_list = torch.cat((noise_list[-2:], noise_pred), dim=0) + + x = self.get_x_pred(x, noise_pred_prime, t_1, t_prev) + if export_after: + torch.onnx.export(self.ad, x.cpu(), f"{project_name}_after.onnx", input_names=["x"], output_names=["mel_out"], dynamic_axes={"x": [3]}, opset_version=16) + x = self.ad(x) + + print((x == org_y_x).all()) + return x + + def forward(self, condition=None, init_noise=None, pndms=None, k_step=None): + cond = condition + x = init_noise + + device = cond.device + n_frames = cond.shape[2] + step_range = torch.arange(0, k_step.item(), pndms.item(), dtype=torch.long, device=device).flip(0) + plms_noise_stage = torch.tensor(0, dtype=torch.long, device=device) + noise_list = torch.zeros((0, 1, 1, self.mel_bins, n_frames), device=device) + + ot = step_range[0] + ot_1 = torch.full((1,), ot, device=device, dtype=torch.long) # NOQA + + for t in step_range: + t_1 = torch.full((1,), t, device=device, dtype=torch.long) + noise_pred = self.denoise_fn(x, t_1, cond) + t_prev = t_1 - pndms + t_prev = t_prev * (t_prev > 0) + if plms_noise_stage == 0: + x_pred = self.get_x_pred(x, noise_pred, t_1, t_prev) + noise_pred_prev = self.denoise_fn(x_pred, t_prev, cond=cond) + noise_pred_prime = predict_stage0(noise_pred, noise_pred_prev) + + elif plms_noise_stage == 1: + noise_pred_prime = predict_stage1(noise_pred, noise_list) + + elif plms_noise_stage == 2: + noise_pred_prime = predict_stage2(noise_pred, noise_list) + + else: + noise_pred_prime = predict_stage3(noise_pred, noise_list) + + noise_pred = noise_pred.unsqueeze(0) + + if plms_noise_stage < 3: + noise_list = torch.cat((noise_list, noise_pred), dim=0) + plms_noise_stage = plms_noise_stage + 1 + + else: + noise_list = torch.cat((noise_list[-2:], noise_pred), dim=0) + + x = self.get_x_pred(x, noise_pred_prime, t_1, t_prev) + x = self.ad(x) + return x diff --git a/server/voice_changer/DDSP_SVC/models/diffusion/dpm_solver_pytorch.py b/server/voice_changer/DDSP_SVC/models/diffusion/dpm_solver_pytorch.py new file mode 100644 index 00000000..3fade3bc --- /dev/null +++ b/server/voice_changer/DDSP_SVC/models/diffusion/dpm_solver_pytorch.py @@ -0,0 +1,1284 @@ +import torch + + +class NoiseScheduleVP: + def __init__( + self, + schedule="discrete", + betas=None, + alphas_cumprod=None, + continuous_beta_0=0.1, + continuous_beta_1=20.0, + dtype=torch.float32, + ): + """Create a wrapper class for the forward SDE (VP type). + + *** + Update: We support discrete-time diffusion models by implementing a picewise linear interpolation for log_alpha_t. + We recommend to use schedule='discrete' for the discrete-time diffusion models, especially for high-resolution images. + *** + + The forward SDE ensures that the condition distribution q_{t|0}(x_t | x_0) = N ( alpha_t * x_0, sigma_t^2 * I ). + We further define lambda_t = log(alpha_t) - log(sigma_t), which is the half-logSNR (described in the DPM-Solver paper). + Therefore, we implement the functions for computing alpha_t, sigma_t and lambda_t. For t in [0, T], we have: + + log_alpha_t = self.marginal_log_mean_coeff(t) + sigma_t = self.marginal_std(t) + lambda_t = self.marginal_lambda(t) + + Moreover, as lambda(t) is an invertible function, we also support its inverse function: + + t = self.inverse_lambda(lambda_t) + + =============================================================== + + We support both discrete-time DPMs (trained on n = 0, 1, ..., N-1) and continuous-time DPMs (trained on t in [t_0, T]). + + 1. For discrete-time DPMs: + + For discrete-time DPMs trained on n = 0, 1, ..., N-1, we convert the discrete steps to continuous time steps by: + t_i = (i + 1) / N + e.g. for N = 1000, we have t_0 = 1e-3 and T = t_{N-1} = 1. + We solve the corresponding diffusion ODE from time T = 1 to time t_0 = 1e-3. + + Args: + betas: A `torch.Tensor`. The beta array for the discrete-time DPM. (See the original DDPM paper for details) + alphas_cumprod: A `torch.Tensor`. The cumprod alphas for the discrete-time DPM. (See the original DDPM paper for details) + + Note that we always have alphas_cumprod = cumprod(1 - betas). Therefore, we only need to set one of `betas` and `alphas_cumprod`. + + **Important**: Please pay special attention for the args for `alphas_cumprod`: + The `alphas_cumprod` is the \hat{alpha_n} arrays in the notations of DDPM. Specifically, DDPMs assume that + q_{t_n | 0}(x_{t_n} | x_0) = N ( \sqrt{\hat{alpha_n}} * x_0, (1 - \hat{alpha_n}) * I ). + Therefore, the notation \hat{alpha_n} is different from the notation alpha_t in DPM-Solver. In fact, we have + alpha_{t_n} = \sqrt{\hat{alpha_n}}, + and + log(alpha_{t_n}) = 0.5 * log(\hat{alpha_n}). + + + 2. For continuous-time DPMs: + + We support the linear VPSDE for the continuous time setting. The hyperparameters for the noise + schedule are the default settings in Yang Song's ScoreSDE: + + Args: + beta_min: A `float` number. The smallest beta for the linear schedule. + beta_max: A `float` number. The largest beta for the linear schedule. + T: A `float` number. The ending time of the forward process. + + =============================================================== + + Args: + schedule: A `str`. The noise schedule of the forward SDE. 'discrete' for discrete-time DPMs, + 'linear' for continuous-time DPMs. + Returns: + A wrapper object of the forward SDE (VP type). + + =============================================================== + + Example: + + # For discrete-time DPMs, given betas (the beta array for n = 0, 1, ..., N - 1): + >>> ns = NoiseScheduleVP('discrete', betas=betas) + + # For discrete-time DPMs, given alphas_cumprod (the \hat{alpha_n} array for n = 0, 1, ..., N - 1): + >>> ns = NoiseScheduleVP('discrete', alphas_cumprod=alphas_cumprod) + + # For continuous-time DPMs (VPSDE), linear schedule: + >>> ns = NoiseScheduleVP('linear', continuous_beta_0=0.1, continuous_beta_1=20.) + + """ + + if schedule not in ["discrete", "linear"]: + raise ValueError("Unsupported noise schedule {}. The schedule needs to be 'discrete' or 'linear'".format(schedule)) + + self.schedule = schedule + if schedule == "discrete": + if betas is not None: + log_alphas = 0.5 * torch.log(1 - betas).cumsum(dim=0) + else: + assert alphas_cumprod is not None + log_alphas = 0.5 * torch.log(alphas_cumprod) + self.T = 1.0 + self.log_alpha_array = ( + self.numerical_clip_alpha(log_alphas) + .reshape( + ( + 1, + -1, + ) + ) + .to(dtype=dtype) + ) + self.total_N = self.log_alpha_array.shape[1] + self.t_array = torch.linspace(0.0, 1.0, self.total_N + 1)[1:].reshape((1, -1)).to(dtype=dtype) + else: + self.T = 1.0 + self.total_N = 1000 + self.beta_0 = continuous_beta_0 + self.beta_1 = continuous_beta_1 + + def numerical_clip_alpha(self, log_alphas, clipped_lambda=-5.1): + """ + For some beta schedules such as cosine schedule, the log-SNR has numerical isssues. + We clip the log-SNR near t=T within -5.1 to ensure the stability. + Such a trick is very useful for diffusion models with the cosine schedule, such as i-DDPM, guided-diffusion and GLIDE. + """ + log_sigmas = 0.5 * torch.log(1.0 - torch.exp(2.0 * log_alphas)) + lambs = log_alphas - log_sigmas + idx = torch.searchsorted(torch.flip(lambs, [0]), clipped_lambda) + if idx > 0: + log_alphas = log_alphas[:-idx] + return log_alphas + + def marginal_log_mean_coeff(self, t): + """ + Compute log(alpha_t) of a given continuous-time label t in [0, T]. + """ + if self.schedule == "discrete": + return interpolate_fn(t.reshape((-1, 1)), self.t_array.to(t.device), self.log_alpha_array.to(t.device)).reshape((-1)) + elif self.schedule == "linear": + return -0.25 * t**2 * (self.beta_1 - self.beta_0) - 0.5 * t * self.beta_0 + + def marginal_alpha(self, t): + """ + Compute alpha_t of a given continuous-time label t in [0, T]. + """ + return torch.exp(self.marginal_log_mean_coeff(t)) + + def marginal_std(self, t): + """ + Compute sigma_t of a given continuous-time label t in [0, T]. + """ + return torch.sqrt(1.0 - torch.exp(2.0 * self.marginal_log_mean_coeff(t))) + + def marginal_lambda(self, t): + """ + Compute lambda_t = log(alpha_t) - log(sigma_t) of a given continuous-time label t in [0, T]. + """ + log_mean_coeff = self.marginal_log_mean_coeff(t) + log_std = 0.5 * torch.log(1.0 - torch.exp(2.0 * log_mean_coeff)) + return log_mean_coeff - log_std + + def inverse_lambda(self, lamb): + """ + Compute the continuous-time label t in [0, T] of a given half-logSNR lambda_t. + """ + if self.schedule == "linear": + tmp = 2.0 * (self.beta_1 - self.beta_0) * torch.logaddexp(-2.0 * lamb, torch.zeros((1,)).to(lamb)) + Delta = self.beta_0**2 + tmp + return tmp / (torch.sqrt(Delta) + self.beta_0) / (self.beta_1 - self.beta_0) + elif self.schedule == "discrete": + log_alpha = -0.5 * torch.logaddexp(torch.zeros((1,)).to(lamb.device), -2.0 * lamb) + t = interpolate_fn(log_alpha.reshape((-1, 1)), torch.flip(self.log_alpha_array.to(lamb.device), [1]), torch.flip(self.t_array.to(lamb.device), [1])) + return t.reshape((-1,)) + + +def model_wrapper( + model, + noise_schedule, + model_type="noise", + model_kwargs={}, + guidance_type="uncond", + condition=None, + unconditional_condition=None, + guidance_scale=1.0, + classifier_fn=None, + classifier_kwargs={}, +): + """Create a wrapper function for the noise prediction model. + + DPM-Solver needs to solve the continuous-time diffusion ODEs. For DPMs trained on discrete-time labels, we need to + firstly wrap the model function to a noise prediction model that accepts the continuous time as the input. + + We support four types of the diffusion model by setting `model_type`: + + 1. "noise": noise prediction model. (Trained by predicting noise). + + 2. "x_start": data prediction model. (Trained by predicting the data x_0 at time 0). + + 3. "v": velocity prediction model. (Trained by predicting the velocity). + The "v" prediction is derivation detailed in Appendix D of [1], and is used in Imagen-Video [2]. + + [1] Salimans, Tim, and Jonathan Ho. "Progressive distillation for fast sampling of diffusion models." + arXiv preprint arXiv:2202.00512 (2022). + [2] Ho, Jonathan, et al. "Imagen Video: High Definition Video Generation with Diffusion Models." + arXiv preprint arXiv:2210.02303 (2022). + + 4. "score": marginal score function. (Trained by denoising score matching). + Note that the score function and the noise prediction model follows a simple relationship: + ``` + noise(x_t, t) = -sigma_t * score(x_t, t) + ``` + + We support three types of guided sampling by DPMs by setting `guidance_type`: + 1. "uncond": unconditional sampling by DPMs. + The input `model` has the following format: + `` + model(x, t_input, **model_kwargs) -> noise | x_start | v | score + `` + + 2. "classifier": classifier guidance sampling [3] by DPMs and another classifier. + The input `model` has the following format: + `` + model(x, t_input, **model_kwargs) -> noise | x_start | v | score + `` + + The input `classifier_fn` has the following format: + `` + classifier_fn(x, t_input, cond, **classifier_kwargs) -> logits(x, t_input, cond) + `` + + [3] P. Dhariwal and A. Q. Nichol, "Diffusion models beat GANs on image synthesis," + in Advances in Neural Information Processing Systems, vol. 34, 2021, pp. 8780-8794. + + 3. "classifier-free": classifier-free guidance sampling by conditional DPMs. + The input `model` has the following format: + `` + model(x, t_input, cond, **model_kwargs) -> noise | x_start | v | score + `` + And if cond == `unconditional_condition`, the model output is the unconditional DPM output. + + [4] Ho, Jonathan, and Tim Salimans. "Classifier-free diffusion guidance." + arXiv preprint arXiv:2207.12598 (2022). + + + The `t_input` is the time label of the model, which may be discrete-time labels (i.e. 0 to 999) + or continuous-time labels (i.e. epsilon to T). + + We wrap the model function to accept only `x` and `t_continuous` as inputs, and outputs the predicted noise: + `` + def model_fn(x, t_continuous) -> noise: + t_input = get_model_input_time(t_continuous) + return noise_pred(model, x, t_input, **model_kwargs) + `` + where `t_continuous` is the continuous time labels (i.e. epsilon to T). And we use `model_fn` for DPM-Solver. + + =============================================================== + + Args: + model: A diffusion model with the corresponding format described above. + noise_schedule: A noise schedule object, such as NoiseScheduleVP. + model_type: A `str`. The parameterization type of the diffusion model. + "noise" or "x_start" or "v" or "score". + model_kwargs: A `dict`. A dict for the other inputs of the model function. + guidance_type: A `str`. The type of the guidance for sampling. + "uncond" or "classifier" or "classifier-free". + condition: A pytorch tensor. The condition for the guided sampling. + Only used for "classifier" or "classifier-free" guidance type. + unconditional_condition: A pytorch tensor. The condition for the unconditional sampling. + Only used for "classifier-free" guidance type. + guidance_scale: A `float`. The scale for the guided sampling. + classifier_fn: A classifier function. Only used for the classifier guidance. + classifier_kwargs: A `dict`. A dict for the other inputs of the classifier function. + Returns: + A noise prediction model that accepts the noised data and the continuous time as the inputs. + """ + + def get_model_input_time(t_continuous): + """ + Convert the continuous-time `t_continuous` (in [epsilon, T]) to the model input time. + For discrete-time DPMs, we convert `t_continuous` in [1 / N, 1] to `t_input` in [0, 1000 * (N - 1) / N]. + For continuous-time DPMs, we just use `t_continuous`. + """ + if noise_schedule.schedule == "discrete": + return (t_continuous - 1.0 / noise_schedule.total_N) * noise_schedule.total_N + else: + return t_continuous + + def noise_pred_fn(x, t_continuous, cond=None): + t_input = get_model_input_time(t_continuous) + if cond is None: + output = model(x, t_input, **model_kwargs) + else: + output = model(x, t_input, cond, **model_kwargs) + if model_type == "noise": + return output + elif model_type == "x_start": + alpha_t, sigma_t = noise_schedule.marginal_alpha(t_continuous), noise_schedule.marginal_std(t_continuous) + return (x - expand_dims(alpha_t, x.dim()) * output) / expand_dims(sigma_t, x.dim()) + elif model_type == "v": + alpha_t, sigma_t = noise_schedule.marginal_alpha(t_continuous), noise_schedule.marginal_std(t_continuous) + return expand_dims(alpha_t, x.dim()) * output + expand_dims(sigma_t, x.dim()) * x + elif model_type == "score": + sigma_t = noise_schedule.marginal_std(t_continuous) + return -expand_dims(sigma_t, x.dim()) * output + + def cond_grad_fn(x, t_input): + """ + Compute the gradient of the classifier, i.e. nabla_{x} log p_t(cond | x_t). + """ + with torch.enable_grad(): + x_in = x.detach().requires_grad_(True) + log_prob = classifier_fn(x_in, t_input, condition, **classifier_kwargs) + return torch.autograd.grad(log_prob.sum(), x_in)[0] + + def model_fn(x, t_continuous): + """ + The noise predicition model function that is used for DPM-Solver. + """ + if guidance_type == "uncond": + return noise_pred_fn(x, t_continuous) + elif guidance_type == "classifier": + assert classifier_fn is not None + t_input = get_model_input_time(t_continuous) + cond_grad = cond_grad_fn(x, t_input) + sigma_t = noise_schedule.marginal_std(t_continuous) + noise = noise_pred_fn(x, t_continuous) + return noise - guidance_scale * expand_dims(sigma_t, x.dim()) * cond_grad + elif guidance_type == "classifier-free": + if guidance_scale == 1.0 or unconditional_condition is None: + return noise_pred_fn(x, t_continuous, cond=condition) + else: + x_in = torch.cat([x] * 2) + t_in = torch.cat([t_continuous] * 2) + c_in = torch.cat([unconditional_condition, condition]) + noise_uncond, noise = noise_pred_fn(x_in, t_in, cond=c_in).chunk(2) + return noise_uncond + guidance_scale * (noise - noise_uncond) + + assert model_type in ["noise", "x_start", "v", "score"] + assert guidance_type in ["uncond", "classifier", "classifier-free"] + return model_fn + + +class DPM_Solver: + def __init__( + self, + model_fn, + noise_schedule, + algorithm_type="dpmsolver++", + correcting_x0_fn=None, + correcting_xt_fn=None, + thresholding_max_val=1.0, + dynamic_thresholding_ratio=0.995, + ): + """Construct a DPM-Solver. + + We support both DPM-Solver (`algorithm_type="dpmsolver"`) and DPM-Solver++ (`algorithm_type="dpmsolver++"`). + + We also support the "dynamic thresholding" method in Imagen[1]. For pixel-space diffusion models, you + can set both `algorithm_type="dpmsolver++"` and `correcting_x0_fn="dynamic_thresholding"` to use the + dynamic thresholding. The "dynamic thresholding" can greatly improve the sample quality for pixel-space + DPMs with large guidance scales. Note that the thresholding method is **unsuitable** for latent-space + DPMs (such as stable-diffusion). + + To support advanced algorithms in image-to-image applications, we also support corrector functions for + both x0 and xt. + + Args: + model_fn: A noise prediction model function which accepts the continuous-time input (t in [epsilon, T]): + `` + def model_fn(x, t_continuous): + return noise + `` + The shape of `x` is `(batch_size, **shape)`, and the shape of `t_continuous` is `(batch_size,)`. + noise_schedule: A noise schedule object, such as NoiseScheduleVP. + algorithm_type: A `str`. Either "dpmsolver" or "dpmsolver++". + correcting_x0_fn: A `str` or a function with the following format: + ``` + def correcting_x0_fn(x0, t): + x0_new = ... + return x0_new + ``` + This function is to correct the outputs of the data prediction model at each sampling step. e.g., + ``` + x0_pred = data_pred_model(xt, t) + if correcting_x0_fn is not None: + x0_pred = correcting_x0_fn(x0_pred, t) + xt_1 = update(x0_pred, xt, t) + ``` + If `correcting_x0_fn="dynamic_thresholding"`, we use the dynamic thresholding proposed in Imagen[1]. + correcting_xt_fn: A function with the following format: + ``` + def correcting_xt_fn(xt, t, step): + x_new = ... + return x_new + ``` + This function is to correct the intermediate samples xt at each sampling step. e.g., + ``` + xt = ... + xt = correcting_xt_fn(xt, t, step) + ``` + thresholding_max_val: A `float`. The max value for thresholding. + Valid only when use `dpmsolver++` and `correcting_x0_fn="dynamic_thresholding"`. + dynamic_thresholding_ratio: A `float`. The ratio for dynamic thresholding (see Imagen[1] for details). + Valid only when use `dpmsolver++` and `correcting_x0_fn="dynamic_thresholding"`. + + [1] Chitwan Saharia, William Chan, Saurabh Saxena, Lala Li, Jay Whang, Emily Denton, Seyed Kamyar Seyed Ghasemipour, + Burcu Karagol Ayan, S Sara Mahdavi, Rapha Gontijo Lopes, et al. Photorealistic text-to-image diffusion models + with deep language understanding. arXiv preprint arXiv:2205.11487, 2022b. + """ + self.model = lambda x, t: model_fn(x, t.expand((x.shape[0]))) + self.noise_schedule = noise_schedule + assert algorithm_type in ["dpmsolver", "dpmsolver++"] + self.algorithm_type = algorithm_type + if correcting_x0_fn == "dynamic_thresholding": + self.correcting_x0_fn = self.dynamic_thresholding_fn + else: + self.correcting_x0_fn = correcting_x0_fn + self.correcting_xt_fn = correcting_xt_fn + self.dynamic_thresholding_ratio = dynamic_thresholding_ratio + self.thresholding_max_val = thresholding_max_val + + def dynamic_thresholding_fn(self, x0, t): + """ + The dynamic thresholding method. + """ + dims = x0.dim() + p = self.dynamic_thresholding_ratio + s = torch.quantile(torch.abs(x0).reshape((x0.shape[0], -1)), p, dim=1) + s = expand_dims(torch.maximum(s, self.thresholding_max_val * torch.ones_like(s).to(s.device)), dims) + x0 = torch.clamp(x0, -s, s) / s + return x0 + + def noise_prediction_fn(self, x, t): + """ + Return the noise prediction model. + """ + return self.model(x, t) + + def data_prediction_fn(self, x, t): + """ + Return the data prediction model (with corrector). + """ + noise = self.noise_prediction_fn(x, t) + alpha_t, sigma_t = self.noise_schedule.marginal_alpha(t), self.noise_schedule.marginal_std(t) + x0 = (x - sigma_t * noise) / alpha_t + if self.correcting_x0_fn is not None: + x0 = self.correcting_x0_fn(x0, t) + return x0 + + def model_fn(self, x, t): + """ + Convert the model to the noise prediction model or the data prediction model. + """ + if self.algorithm_type == "dpmsolver++": + return self.data_prediction_fn(x, t) + else: + return self.noise_prediction_fn(x, t) + + def get_time_steps(self, skip_type, t_T, t_0, N, device): + """Compute the intermediate time steps for sampling. + + Args: + skip_type: A `str`. The type for the spacing of the time steps. We support three types: + - 'logSNR': uniform logSNR for the time steps. + - 'time_uniform': uniform time for the time steps. (**Recommended for high-resolutional data**.) + - 'time_quadratic': quadratic time for the time steps. (Used in DDIM for low-resolutional data.) + t_T: A `float`. The starting time of the sampling (default is T). + t_0: A `float`. The ending time of the sampling (default is epsilon). + N: A `int`. The total number of the spacing of the time steps. + device: A torch device. + Returns: + A pytorch tensor of the time steps, with the shape (N + 1,). + """ + if skip_type == "logSNR": + lambda_T = self.noise_schedule.marginal_lambda(torch.tensor(t_T).to(device)) + lambda_0 = self.noise_schedule.marginal_lambda(torch.tensor(t_0).to(device)) + logSNR_steps = torch.linspace(lambda_T.cpu().item(), lambda_0.cpu().item(), N + 1).to(device) + return self.noise_schedule.inverse_lambda(logSNR_steps) + elif skip_type == "time_uniform": + return torch.linspace(t_T, t_0, N + 1).to(device) + elif skip_type == "time_quadratic": + t_order = 2 + t = torch.linspace(t_T ** (1.0 / t_order), t_0 ** (1.0 / t_order), N + 1).pow(t_order).to(device) + return t + else: + raise ValueError("Unsupported skip_type {}, need to be 'logSNR' or 'time_uniform' or 'time_quadratic'".format(skip_type)) + + def get_orders_and_timesteps_for_singlestep_solver(self, steps, order, skip_type, t_T, t_0, device): + """ + Get the order of each step for sampling by the singlestep DPM-Solver. + + We combine both DPM-Solver-1,2,3 to use all the function evaluations, which is named as "DPM-Solver-fast". + Given a fixed number of function evaluations by `steps`, the sampling procedure by DPM-Solver-fast is: + - If order == 1: + We take `steps` of DPM-Solver-1 (i.e. DDIM). + - If order == 2: + - Denote K = (steps // 2). We take K or (K + 1) intermediate time steps for sampling. + - If steps % 2 == 0, we use K steps of DPM-Solver-2. + - If steps % 2 == 1, we use K steps of DPM-Solver-2 and 1 step of DPM-Solver-1. + - If order == 3: + - Denote K = (steps // 3 + 1). We take K intermediate time steps for sampling. + - If steps % 3 == 0, we use (K - 2) steps of DPM-Solver-3, and 1 step of DPM-Solver-2 and 1 step of DPM-Solver-1. + - If steps % 3 == 1, we use (K - 1) steps of DPM-Solver-3 and 1 step of DPM-Solver-1. + - If steps % 3 == 2, we use (K - 1) steps of DPM-Solver-3 and 1 step of DPM-Solver-2. + + ============================================ + Args: + order: A `int`. The max order for the solver (2 or 3). + steps: A `int`. The total number of function evaluations (NFE). + skip_type: A `str`. The type for the spacing of the time steps. We support three types: + - 'logSNR': uniform logSNR for the time steps. + - 'time_uniform': uniform time for the time steps. (**Recommended for high-resolutional data**.) + - 'time_quadratic': quadratic time for the time steps. (Used in DDIM for low-resolutional data.) + t_T: A `float`. The starting time of the sampling (default is T). + t_0: A `float`. The ending time of the sampling (default is epsilon). + device: A torch device. + Returns: + orders: A list of the solver order of each step. + """ + if order == 3: + K = steps // 3 + 1 + if steps % 3 == 0: + orders = [ + 3, + ] * ( + K - 2 + ) + [2, 1] + elif steps % 3 == 1: + orders = [ + 3, + ] * ( + K - 1 + ) + [1] + else: + orders = [ + 3, + ] * ( + K - 1 + ) + [2] + elif order == 2: + if steps % 2 == 0: + K = steps // 2 + orders = [ + 2, + ] * K + else: + K = steps // 2 + 1 + orders = [ + 2, + ] * ( + K - 1 + ) + [1] + elif order == 1: + K = 1 + orders = [ + 1, + ] * steps + else: + raise ValueError("'order' must be '1' or '2' or '3'.") + if skip_type == "logSNR": + # To reproduce the results in DPM-Solver paper + timesteps_outer = self.get_time_steps(skip_type, t_T, t_0, K, device) + else: + timesteps_outer = self.get_time_steps(skip_type, t_T, t_0, steps, device)[ + torch.cumsum( + torch.tensor( + [ + 0, + ] + + orders + ), + 0, + ).to(device) + ] + return timesteps_outer, orders + + def denoise_to_zero_fn(self, x, s): + """ + Denoise at the final step, which is equivalent to solve the ODE from lambda_s to infty by first-order discretization. + """ + return self.data_prediction_fn(x, s) + + def dpm_solver_first_update(self, x, s, t, model_s=None, return_intermediate=False): + """ + DPM-Solver-1 (equivalent to DDIM) from time `s` to time `t`. + + Args: + x: A pytorch tensor. The initial value at time `s`. + s: A pytorch tensor. The starting time, with the shape (1,). + t: A pytorch tensor. The ending time, with the shape (1,). + model_s: A pytorch tensor. The model function evaluated at time `s`. + If `model_s` is None, we evaluate the model by `x` and `s`; otherwise we directly use it. + return_intermediate: A `bool`. If true, also return the model value at time `s`. + Returns: + x_t: A pytorch tensor. The approximated solution at time `t`. + """ + ns = self.noise_schedule + dims = x.dim() # NOQA + lambda_s, lambda_t = ns.marginal_lambda(s), ns.marginal_lambda(t) + h = lambda_t - lambda_s + log_alpha_s, log_alpha_t = ns.marginal_log_mean_coeff(s), ns.marginal_log_mean_coeff(t) + sigma_s, sigma_t = ns.marginal_std(s), ns.marginal_std(t) + alpha_t = torch.exp(log_alpha_t) + + if self.algorithm_type == "dpmsolver++": + phi_1 = torch.expm1(-h) + if model_s is None: + model_s = self.model_fn(x, s) + x_t = sigma_t / sigma_s * x - alpha_t * phi_1 * model_s + if return_intermediate: + return x_t, {"model_s": model_s} + else: + return x_t + else: + phi_1 = torch.expm1(h) + if model_s is None: + model_s = self.model_fn(x, s) + x_t = torch.exp(log_alpha_t - log_alpha_s) * x - (sigma_t * phi_1) * model_s + if return_intermediate: + return x_t, {"model_s": model_s} + else: + return x_t + + def singlestep_dpm_solver_second_update(self, x, s, t, r1=0.5, model_s=None, return_intermediate=False, solver_type="dpmsolver"): + """ + Singlestep solver DPM-Solver-2 from time `s` to time `t`. + + Args: + x: A pytorch tensor. The initial value at time `s`. + s: A pytorch tensor. The starting time, with the shape (1,). + t: A pytorch tensor. The ending time, with the shape (1,). + r1: A `float`. The hyperparameter of the second-order solver. + model_s: A pytorch tensor. The model function evaluated at time `s`. + If `model_s` is None, we evaluate the model by `x` and `s`; otherwise we directly use it. + return_intermediate: A `bool`. If true, also return the model value at time `s` and `s1` (the intermediate time). + solver_type: either 'dpmsolver' or 'taylor'. The type for the high-order solvers. + The type slightly impacts the performance. We recommend to use 'dpmsolver' type. + Returns: + x_t: A pytorch tensor. The approximated solution at time `t`. + """ + if solver_type not in ["dpmsolver", "taylor"]: + raise ValueError("'solver_type' must be either 'dpmsolver' or 'taylor', got {}".format(solver_type)) + if r1 is None: + r1 = 0.5 + ns = self.noise_schedule + lambda_s, lambda_t = ns.marginal_lambda(s), ns.marginal_lambda(t) + h = lambda_t - lambda_s + lambda_s1 = lambda_s + r1 * h + s1 = ns.inverse_lambda(lambda_s1) + log_alpha_s, log_alpha_s1, log_alpha_t = ns.marginal_log_mean_coeff(s), ns.marginal_log_mean_coeff(s1), ns.marginal_log_mean_coeff(t) + sigma_s, sigma_s1, sigma_t = ns.marginal_std(s), ns.marginal_std(s1), ns.marginal_std(t) + alpha_s1, alpha_t = torch.exp(log_alpha_s1), torch.exp(log_alpha_t) + + if self.algorithm_type == "dpmsolver++": + phi_11 = torch.expm1(-r1 * h) + phi_1 = torch.expm1(-h) + + if model_s is None: + model_s = self.model_fn(x, s) + x_s1 = (sigma_s1 / sigma_s) * x - (alpha_s1 * phi_11) * model_s + model_s1 = self.model_fn(x_s1, s1) + if solver_type == "dpmsolver": + x_t = (sigma_t / sigma_s) * x - (alpha_t * phi_1) * model_s - (0.5 / r1) * (alpha_t * phi_1) * (model_s1 - model_s) + elif solver_type == "taylor": + x_t = (sigma_t / sigma_s) * x - (alpha_t * phi_1) * model_s + (1.0 / r1) * (alpha_t * (phi_1 / h + 1.0)) * (model_s1 - model_s) + else: + phi_11 = torch.expm1(r1 * h) + phi_1 = torch.expm1(h) + + if model_s is None: + model_s = self.model_fn(x, s) + x_s1 = torch.exp(log_alpha_s1 - log_alpha_s) * x - (sigma_s1 * phi_11) * model_s + model_s1 = self.model_fn(x_s1, s1) + if solver_type == "dpmsolver": + x_t = torch.exp(log_alpha_t - log_alpha_s) * x - (sigma_t * phi_1) * model_s - (0.5 / r1) * (sigma_t * phi_1) * (model_s1 - model_s) + elif solver_type == "taylor": + x_t = torch.exp(log_alpha_t - log_alpha_s) * x - (sigma_t * phi_1) * model_s - (1.0 / r1) * (sigma_t * (phi_1 / h - 1.0)) * (model_s1 - model_s) + if return_intermediate: + return x_t, {"model_s": model_s, "model_s1": model_s1} + else: + return x_t + + def singlestep_dpm_solver_third_update(self, x, s, t, r1=1.0 / 3.0, r2=2.0 / 3.0, model_s=None, model_s1=None, return_intermediate=False, solver_type="dpmsolver"): + """ + Singlestep solver DPM-Solver-3 from time `s` to time `t`. + + Args: + x: A pytorch tensor. The initial value at time `s`. + s: A pytorch tensor. The starting time, with the shape (1,). + t: A pytorch tensor. The ending time, with the shape (1,). + r1: A `float`. The hyperparameter of the third-order solver. + r2: A `float`. The hyperparameter of the third-order solver. + model_s: A pytorch tensor. The model function evaluated at time `s`. + If `model_s` is None, we evaluate the model by `x` and `s`; otherwise we directly use it. + model_s1: A pytorch tensor. The model function evaluated at time `s1` (the intermediate time given by `r1`). + If `model_s1` is None, we evaluate the model at `s1`; otherwise we directly use it. + return_intermediate: A `bool`. If true, also return the model value at time `s`, `s1` and `s2` (the intermediate times). + solver_type: either 'dpmsolver' or 'taylor'. The type for the high-order solvers. + The type slightly impacts the performance. We recommend to use 'dpmsolver' type. + Returns: + x_t: A pytorch tensor. The approximated solution at time `t`. + """ + if solver_type not in ["dpmsolver", "taylor"]: + raise ValueError("'solver_type' must be either 'dpmsolver' or 'taylor', got {}".format(solver_type)) + if r1 is None: + r1 = 1.0 / 3.0 + if r2 is None: + r2 = 2.0 / 3.0 + ns = self.noise_schedule + lambda_s, lambda_t = ns.marginal_lambda(s), ns.marginal_lambda(t) + h = lambda_t - lambda_s + lambda_s1 = lambda_s + r1 * h + lambda_s2 = lambda_s + r2 * h + s1 = ns.inverse_lambda(lambda_s1) + s2 = ns.inverse_lambda(lambda_s2) + log_alpha_s, log_alpha_s1, log_alpha_s2, log_alpha_t = ns.marginal_log_mean_coeff(s), ns.marginal_log_mean_coeff(s1), ns.marginal_log_mean_coeff(s2), ns.marginal_log_mean_coeff(t) + sigma_s, sigma_s1, sigma_s2, sigma_t = ns.marginal_std(s), ns.marginal_std(s1), ns.marginal_std(s2), ns.marginal_std(t) + alpha_s1, alpha_s2, alpha_t = torch.exp(log_alpha_s1), torch.exp(log_alpha_s2), torch.exp(log_alpha_t) + + if self.algorithm_type == "dpmsolver++": + phi_11 = torch.expm1(-r1 * h) + phi_12 = torch.expm1(-r2 * h) + phi_1 = torch.expm1(-h) + phi_22 = torch.expm1(-r2 * h) / (r2 * h) + 1.0 + phi_2 = phi_1 / h + 1.0 + phi_3 = phi_2 / h - 0.5 + + if model_s is None: + model_s = self.model_fn(x, s) + if model_s1 is None: + x_s1 = (sigma_s1 / sigma_s) * x - (alpha_s1 * phi_11) * model_s + model_s1 = self.model_fn(x_s1, s1) + x_s2 = (sigma_s2 / sigma_s) * x - (alpha_s2 * phi_12) * model_s + r2 / r1 * (alpha_s2 * phi_22) * (model_s1 - model_s) + model_s2 = self.model_fn(x_s2, s2) + if solver_type == "dpmsolver": + x_t = (sigma_t / sigma_s) * x - (alpha_t * phi_1) * model_s + (1.0 / r2) * (alpha_t * phi_2) * (model_s2 - model_s) + elif solver_type == "taylor": + D1_0 = (1.0 / r1) * (model_s1 - model_s) + D1_1 = (1.0 / r2) * (model_s2 - model_s) + D1 = (r2 * D1_0 - r1 * D1_1) / (r2 - r1) + D2 = 2.0 * (D1_1 - D1_0) / (r2 - r1) + x_t = (sigma_t / sigma_s) * x - (alpha_t * phi_1) * model_s + (alpha_t * phi_2) * D1 - (alpha_t * phi_3) * D2 + else: + phi_11 = torch.expm1(r1 * h) + phi_12 = torch.expm1(r2 * h) + phi_1 = torch.expm1(h) + phi_22 = torch.expm1(r2 * h) / (r2 * h) - 1.0 + phi_2 = phi_1 / h - 1.0 + phi_3 = phi_2 / h - 0.5 + + if model_s is None: + model_s = self.model_fn(x, s) + if model_s1 is None: + x_s1 = (torch.exp(log_alpha_s1 - log_alpha_s)) * x - (sigma_s1 * phi_11) * model_s + model_s1 = self.model_fn(x_s1, s1) + x_s2 = (torch.exp(log_alpha_s2 - log_alpha_s)) * x - (sigma_s2 * phi_12) * model_s - r2 / r1 * (sigma_s2 * phi_22) * (model_s1 - model_s) + model_s2 = self.model_fn(x_s2, s2) + if solver_type == "dpmsolver": + x_t = (torch.exp(log_alpha_t - log_alpha_s)) * x - (sigma_t * phi_1) * model_s - (1.0 / r2) * (sigma_t * phi_2) * (model_s2 - model_s) + elif solver_type == "taylor": + D1_0 = (1.0 / r1) * (model_s1 - model_s) + D1_1 = (1.0 / r2) * (model_s2 - model_s) + D1 = (r2 * D1_0 - r1 * D1_1) / (r2 - r1) + D2 = 2.0 * (D1_1 - D1_0) / (r2 - r1) + x_t = (torch.exp(log_alpha_t - log_alpha_s)) * x - (sigma_t * phi_1) * model_s - (sigma_t * phi_2) * D1 - (sigma_t * phi_3) * D2 + + if return_intermediate: + return x_t, {"model_s": model_s, "model_s1": model_s1, "model_s2": model_s2} + else: + return x_t + + def multistep_dpm_solver_second_update(self, x, model_prev_list, t_prev_list, t, solver_type="dpmsolver"): + """ + Multistep solver DPM-Solver-2 from time `t_prev_list[-1]` to time `t`. + + Args: + x: A pytorch tensor. The initial value at time `s`. + model_prev_list: A list of pytorch tensor. The previous computed model values. + t_prev_list: A list of pytorch tensor. The previous times, each time has the shape (1,) + t: A pytorch tensor. The ending time, with the shape (1,). + solver_type: either 'dpmsolver' or 'taylor'. The type for the high-order solvers. + The type slightly impacts the performance. We recommend to use 'dpmsolver' type. + Returns: + x_t: A pytorch tensor. The approximated solution at time `t`. + """ + if solver_type not in ["dpmsolver", "taylor"]: + raise ValueError("'solver_type' must be either 'dpmsolver' or 'taylor', got {}".format(solver_type)) + ns = self.noise_schedule + model_prev_1, model_prev_0 = model_prev_list[-2], model_prev_list[-1] + t_prev_1, t_prev_0 = t_prev_list[-2], t_prev_list[-1] + lambda_prev_1, lambda_prev_0, lambda_t = ns.marginal_lambda(t_prev_1), ns.marginal_lambda(t_prev_0), ns.marginal_lambda(t) + log_alpha_prev_0, log_alpha_t = ns.marginal_log_mean_coeff(t_prev_0), ns.marginal_log_mean_coeff(t) + sigma_prev_0, sigma_t = ns.marginal_std(t_prev_0), ns.marginal_std(t) + alpha_t = torch.exp(log_alpha_t) + + h_0 = lambda_prev_0 - lambda_prev_1 + h = lambda_t - lambda_prev_0 + r0 = h_0 / h + D1_0 = (1.0 / r0) * (model_prev_0 - model_prev_1) + if self.algorithm_type == "dpmsolver++": + phi_1 = torch.expm1(-h) + if solver_type == "dpmsolver": + x_t = (sigma_t / sigma_prev_0) * x - (alpha_t * phi_1) * model_prev_0 - 0.5 * (alpha_t * phi_1) * D1_0 + elif solver_type == "taylor": + x_t = (sigma_t / sigma_prev_0) * x - (alpha_t * phi_1) * model_prev_0 + (alpha_t * (phi_1 / h + 1.0)) * D1_0 + else: + phi_1 = torch.expm1(h) + if solver_type == "dpmsolver": + x_t = (torch.exp(log_alpha_t - log_alpha_prev_0)) * x - (sigma_t * phi_1) * model_prev_0 - 0.5 * (sigma_t * phi_1) * D1_0 + elif solver_type == "taylor": + x_t = (torch.exp(log_alpha_t - log_alpha_prev_0)) * x - (sigma_t * phi_1) * model_prev_0 - (sigma_t * (phi_1 / h - 1.0)) * D1_0 + return x_t + + def multistep_dpm_solver_third_update(self, x, model_prev_list, t_prev_list, t, solver_type="dpmsolver"): + """ + Multistep solver DPM-Solver-3 from time `t_prev_list[-1]` to time `t`. + + Args: + x: A pytorch tensor. The initial value at time `s`. + model_prev_list: A list of pytorch tensor. The previous computed model values. + t_prev_list: A list of pytorch tensor. The previous times, each time has the shape (1,) + t: A pytorch tensor. The ending time, with the shape (1,). + solver_type: either 'dpmsolver' or 'taylor'. The type for the high-order solvers. + The type slightly impacts the performance. We recommend to use 'dpmsolver' type. + Returns: + x_t: A pytorch tensor. The approximated solution at time `t`. + """ + ns = self.noise_schedule + model_prev_2, model_prev_1, model_prev_0 = model_prev_list + t_prev_2, t_prev_1, t_prev_0 = t_prev_list + lambda_prev_2, lambda_prev_1, lambda_prev_0, lambda_t = ns.marginal_lambda(t_prev_2), ns.marginal_lambda(t_prev_1), ns.marginal_lambda(t_prev_0), ns.marginal_lambda(t) + log_alpha_prev_0, log_alpha_t = ns.marginal_log_mean_coeff(t_prev_0), ns.marginal_log_mean_coeff(t) + sigma_prev_0, sigma_t = ns.marginal_std(t_prev_0), ns.marginal_std(t) + alpha_t = torch.exp(log_alpha_t) + + h_1 = lambda_prev_1 - lambda_prev_2 + h_0 = lambda_prev_0 - lambda_prev_1 + h = lambda_t - lambda_prev_0 + r0, r1 = h_0 / h, h_1 / h + D1_0 = (1.0 / r0) * (model_prev_0 - model_prev_1) + D1_1 = (1.0 / r1) * (model_prev_1 - model_prev_2) + D1 = D1_0 + (r0 / (r0 + r1)) * (D1_0 - D1_1) + D2 = (1.0 / (r0 + r1)) * (D1_0 - D1_1) + if self.algorithm_type == "dpmsolver++": + phi_1 = torch.expm1(-h) + phi_2 = phi_1 / h + 1.0 + phi_3 = phi_2 / h - 0.5 + x_t = (sigma_t / sigma_prev_0) * x - (alpha_t * phi_1) * model_prev_0 + (alpha_t * phi_2) * D1 - (alpha_t * phi_3) * D2 + else: + phi_1 = torch.expm1(h) + phi_2 = phi_1 / h - 1.0 + phi_3 = phi_2 / h - 0.5 + x_t = (torch.exp(log_alpha_t - log_alpha_prev_0)) * x - (sigma_t * phi_1) * model_prev_0 - (sigma_t * phi_2) * D1 - (sigma_t * phi_3) * D2 + return x_t + + def singlestep_dpm_solver_update(self, x, s, t, order, return_intermediate=False, solver_type="dpmsolver", r1=None, r2=None): + """ + Singlestep DPM-Solver with the order `order` from time `s` to time `t`. + + Args: + x: A pytorch tensor. The initial value at time `s`. + s: A pytorch tensor. The starting time, with the shape (1,). + t: A pytorch tensor. The ending time, with the shape (1,). + order: A `int`. The order of DPM-Solver. We only support order == 1 or 2 or 3. + return_intermediate: A `bool`. If true, also return the model value at time `s`, `s1` and `s2` (the intermediate times). + solver_type: either 'dpmsolver' or 'taylor'. The type for the high-order solvers. + The type slightly impacts the performance. We recommend to use 'dpmsolver' type. + r1: A `float`. The hyperparameter of the second-order or third-order solver. + r2: A `float`. The hyperparameter of the third-order solver. + Returns: + x_t: A pytorch tensor. The approximated solution at time `t`. + """ + if order == 1: + return self.dpm_solver_first_update(x, s, t, return_intermediate=return_intermediate) + elif order == 2: + return self.singlestep_dpm_solver_second_update(x, s, t, return_intermediate=return_intermediate, solver_type=solver_type, r1=r1) + elif order == 3: + return self.singlestep_dpm_solver_third_update(x, s, t, return_intermediate=return_intermediate, solver_type=solver_type, r1=r1, r2=r2) + else: + raise ValueError("Solver order must be 1 or 2 or 3, got {}".format(order)) + + def multistep_dpm_solver_update(self, x, model_prev_list, t_prev_list, t, order, solver_type="dpmsolver"): + """ + Multistep DPM-Solver with the order `order` from time `t_prev_list[-1]` to time `t`. + + Args: + x: A pytorch tensor. The initial value at time `s`. + model_prev_list: A list of pytorch tensor. The previous computed model values. + t_prev_list: A list of pytorch tensor. The previous times, each time has the shape (1,) + t: A pytorch tensor. The ending time, with the shape (1,). + order: A `int`. The order of DPM-Solver. We only support order == 1 or 2 or 3. + solver_type: either 'dpmsolver' or 'taylor'. The type for the high-order solvers. + The type slightly impacts the performance. We recommend to use 'dpmsolver' type. + Returns: + x_t: A pytorch tensor. The approximated solution at time `t`. + """ + if order == 1: + return self.dpm_solver_first_update(x, t_prev_list[-1], t, model_s=model_prev_list[-1]) + elif order == 2: + return self.multistep_dpm_solver_second_update(x, model_prev_list, t_prev_list, t, solver_type=solver_type) + elif order == 3: + return self.multistep_dpm_solver_third_update(x, model_prev_list, t_prev_list, t, solver_type=solver_type) + else: + raise ValueError("Solver order must be 1 or 2 or 3, got {}".format(order)) + + def dpm_solver_adaptive(self, x, order, t_T, t_0, h_init=0.05, atol=0.0078, rtol=0.05, theta=0.9, t_err=1e-5, solver_type="dpmsolver"): + """ + The adaptive step size solver based on singlestep DPM-Solver. + + Args: + x: A pytorch tensor. The initial value at time `t_T`. + order: A `int`. The (higher) order of the solver. We only support order == 2 or 3. + t_T: A `float`. The starting time of the sampling (default is T). + t_0: A `float`. The ending time of the sampling (default is epsilon). + h_init: A `float`. The initial step size (for logSNR). + atol: A `float`. The absolute tolerance of the solver. For image data, the default setting is 0.0078, followed [1]. + rtol: A `float`. The relative tolerance of the solver. The default setting is 0.05. + theta: A `float`. The safety hyperparameter for adapting the step size. The default setting is 0.9, followed [1]. + t_err: A `float`. The tolerance for the time. We solve the diffusion ODE until the absolute error between the + current time and `t_0` is less than `t_err`. The default setting is 1e-5. + solver_type: either 'dpmsolver' or 'taylor'. The type for the high-order solvers. + The type slightly impacts the performance. We recommend to use 'dpmsolver' type. + Returns: + x_0: A pytorch tensor. The approximated solution at time `t_0`. + + [1] A. Jolicoeur-Martineau, K. Li, R. Piché-Taillefer, T. Kachman, and I. Mitliagkas, "Gotta go fast when generating data with score-based models," arXiv preprint arXiv:2105.14080, 2021. + """ + ns = self.noise_schedule + s = t_T * torch.ones((1,)).to(x) + lambda_s = ns.marginal_lambda(s) + lambda_0 = ns.marginal_lambda(t_0 * torch.ones_like(s).to(x)) + h = h_init * torch.ones_like(s).to(x) + x_prev = x + nfe = 0 + if order == 2: + r1 = 0.5 + lower_update = lambda x, s, t: self.dpm_solver_first_update(x, s, t, return_intermediate=True) # NOQA + higher_update = lambda x, s, t, **kwargs: self.singlestep_dpm_solver_second_update(x, s, t, r1=r1, solver_type=solver_type, **kwargs) # NOQA + elif order == 3: + r1, r2 = 1.0 / 3.0, 2.0 / 3.0 + lower_update = lambda x, s, t: self.singlestep_dpm_solver_second_update(x, s, t, r1=r1, return_intermediate=True, solver_type=solver_type) # NOQA + higher_update = lambda x, s, t, **kwargs: self.singlestep_dpm_solver_third_update(x, s, t, r1=r1, r2=r2, solver_type=solver_type, **kwargs) # NOQA + else: + raise ValueError("For adaptive step size solver, order must be 2 or 3, got {}".format(order)) + while torch.abs((s - t_0)).mean() > t_err: + t = ns.inverse_lambda(lambda_s + h) + x_lower, lower_noise_kwargs = lower_update(x, s, t) + x_higher = higher_update(x, s, t, **lower_noise_kwargs) + delta = torch.max(torch.ones_like(x).to(x) * atol, rtol * torch.max(torch.abs(x_lower), torch.abs(x_prev))) + norm_fn = lambda v: torch.sqrt(torch.square(v.reshape((v.shape[0], -1))).mean(dim=-1, keepdim=True)) # NOQA + E = norm_fn((x_higher - x_lower) / delta).max() + if torch.all(E <= 1.0): + x = x_higher + s = t + x_prev = x_lower + lambda_s = ns.marginal_lambda(s) + h = torch.min(theta * h * torch.float_power(E, -1.0 / order).float(), lambda_0 - lambda_s) + nfe += order + print("adaptive solver nfe", nfe) + return x + + def add_noise(self, x, t, noise=None): + """ + Compute the noised input xt = alpha_t * x + sigma_t * noise. + + Args: + x: A `torch.Tensor` with shape `(batch_size, *shape)`. + t: A `torch.Tensor` with shape `(t_size,)`. + Returns: + xt with shape `(t_size, batch_size, *shape)`. + """ + alpha_t, sigma_t = self.noise_schedule.marginal_alpha(t), self.noise_schedule.marginal_std(t) + if noise is None: + noise = torch.randn((t.shape[0], *x.shape), device=x.device) + x = x.reshape((-1, *x.shape)) + xt = expand_dims(alpha_t, x.dim()) * x + expand_dims(sigma_t, x.dim()) * noise + if t.shape[0] == 1: + return xt.squeeze(0) + else: + return xt + + def inverse( + self, + x, + steps=20, + t_start=None, + t_end=None, + order=2, + skip_type="time_uniform", + method="multistep", + lower_order_final=True, + denoise_to_zero=False, + solver_type="dpmsolver", + atol=0.0078, + rtol=0.05, + return_intermediate=False, + ): + """ + Inverse the sample `x` from time `t_start` to `t_end` by DPM-Solver. + For discrete-time DPMs, we use `t_start=1/N`, where `N` is the total time steps during training. + """ + t_0 = 1.0 / self.noise_schedule.total_N if t_start is None else t_start + t_T = self.noise_schedule.T if t_end is None else t_end + assert t_0 > 0 and t_T > 0, "Time range needs to be greater than 0. For discrete-time DPMs, it needs to be in [1 / N, 1], where N is the length of betas array" + return self.sample(x, steps=steps, t_start=t_0, t_end=t_T, order=order, skip_type=skip_type, method=method, lower_order_final=lower_order_final, denoise_to_zero=denoise_to_zero, solver_type=solver_type, atol=atol, rtol=rtol, return_intermediate=return_intermediate) + + def sample( + self, + x, + steps=20, + t_start=None, + t_end=None, + order=2, + skip_type="time_uniform", + method="multistep", + lower_order_final=True, + denoise_to_zero=False, + solver_type="dpmsolver", + atol=0.0078, + rtol=0.05, + return_intermediate=False, + ): + """ + Compute the sample at time `t_end` by DPM-Solver, given the initial `x` at time `t_start`. + + ===================================================== + + We support the following algorithms for both noise prediction model and data prediction model: + - 'singlestep': + Singlestep DPM-Solver (i.e. "DPM-Solver-fast" in the paper), which combines different orders of singlestep DPM-Solver. + We combine all the singlestep solvers with order <= `order` to use up all the function evaluations (steps). + The total number of function evaluations (NFE) == `steps`. + Given a fixed NFE == `steps`, the sampling procedure is: + - If `order` == 1: + - Denote K = steps. We use K steps of DPM-Solver-1 (i.e. DDIM). + - If `order` == 2: + - Denote K = (steps // 2) + (steps % 2). We take K intermediate time steps for sampling. + - If steps % 2 == 0, we use K steps of singlestep DPM-Solver-2. + - If steps % 2 == 1, we use (K - 1) steps of singlestep DPM-Solver-2 and 1 step of DPM-Solver-1. + - If `order` == 3: + - Denote K = (steps // 3 + 1). We take K intermediate time steps for sampling. + - If steps % 3 == 0, we use (K - 2) steps of singlestep DPM-Solver-3, and 1 step of singlestep DPM-Solver-2 and 1 step of DPM-Solver-1. + - If steps % 3 == 1, we use (K - 1) steps of singlestep DPM-Solver-3 and 1 step of DPM-Solver-1. + - If steps % 3 == 2, we use (K - 1) steps of singlestep DPM-Solver-3 and 1 step of singlestep DPM-Solver-2. + - 'multistep': + Multistep DPM-Solver with the order of `order`. The total number of function evaluations (NFE) == `steps`. + We initialize the first `order` values by lower order multistep solvers. + Given a fixed NFE == `steps`, the sampling procedure is: + Denote K = steps. + - If `order` == 1: + - We use K steps of DPM-Solver-1 (i.e. DDIM). + - If `order` == 2: + - We firstly use 1 step of DPM-Solver-1, then use (K - 1) step of multistep DPM-Solver-2. + - If `order` == 3: + - We firstly use 1 step of DPM-Solver-1, then 1 step of multistep DPM-Solver-2, then (K - 2) step of multistep DPM-Solver-3. + - 'singlestep_fixed': + Fixed order singlestep DPM-Solver (i.e. DPM-Solver-1 or singlestep DPM-Solver-2 or singlestep DPM-Solver-3). + We use singlestep DPM-Solver-`order` for `order`=1 or 2 or 3, with total [`steps` // `order`] * `order` NFE. + - 'adaptive': + Adaptive step size DPM-Solver (i.e. "DPM-Solver-12" and "DPM-Solver-23" in the paper). + We ignore `steps` and use adaptive step size DPM-Solver with a higher order of `order`. + You can adjust the absolute tolerance `atol` and the relative tolerance `rtol` to balance the computatation costs + (NFE) and the sample quality. + - If `order` == 2, we use DPM-Solver-12 which combines DPM-Solver-1 and singlestep DPM-Solver-2. + - If `order` == 3, we use DPM-Solver-23 which combines singlestep DPM-Solver-2 and singlestep DPM-Solver-3. + + ===================================================== + + Some advices for choosing the algorithm: + - For **unconditional sampling** or **guided sampling with small guidance scale** by DPMs: + Use singlestep DPM-Solver or DPM-Solver++ ("DPM-Solver-fast" in the paper) with `order = 3`. + e.g., DPM-Solver: + >>> dpm_solver = DPM_Solver(model_fn, noise_schedule, algorithm_type="dpmsolver") + >>> x_sample = dpm_solver.sample(x, steps=steps, t_start=t_start, t_end=t_end, order=3, + skip_type='time_uniform', method='singlestep') + e.g., DPM-Solver++: + >>> dpm_solver = DPM_Solver(model_fn, noise_schedule, algorithm_type="dpmsolver++") + >>> x_sample = dpm_solver.sample(x, steps=steps, t_start=t_start, t_end=t_end, order=3, + skip_type='time_uniform', method='singlestep') + - For **guided sampling with large guidance scale** by DPMs: + Use multistep DPM-Solver with `algorithm_type="dpmsolver++"` and `order = 2`. + e.g. + >>> dpm_solver = DPM_Solver(model_fn, noise_schedule, algorithm_type="dpmsolver++") + >>> x_sample = dpm_solver.sample(x, steps=steps, t_start=t_start, t_end=t_end, order=2, + skip_type='time_uniform', method='multistep') + + We support three types of `skip_type`: + - 'logSNR': uniform logSNR for the time steps. **Recommended for low-resolutional images** + - 'time_uniform': uniform time for the time steps. **Recommended for high-resolutional images**. + - 'time_quadratic': quadratic time for the time steps. + + ===================================================== + Args: + x: A pytorch tensor. The initial value at time `t_start` + e.g. if `t_start` == T, then `x` is a sample from the standard normal distribution. + steps: A `int`. The total number of function evaluations (NFE). + t_start: A `float`. The starting time of the sampling. + If `T` is None, we use self.noise_schedule.T (default is 1.0). + t_end: A `float`. The ending time of the sampling. + If `t_end` is None, we use 1. / self.noise_schedule.total_N. + e.g. if total_N == 1000, we have `t_end` == 1e-3. + For discrete-time DPMs: + - We recommend `t_end` == 1. / self.noise_schedule.total_N. + For continuous-time DPMs: + - We recommend `t_end` == 1e-3 when `steps` <= 15; and `t_end` == 1e-4 when `steps` > 15. + order: A `int`. The order of DPM-Solver. + skip_type: A `str`. The type for the spacing of the time steps. 'time_uniform' or 'logSNR' or 'time_quadratic'. + method: A `str`. The method for sampling. 'singlestep' or 'multistep' or 'singlestep_fixed' or 'adaptive'. + denoise_to_zero: A `bool`. Whether to denoise to time 0 at the final step. + Default is `False`. If `denoise_to_zero` is `True`, the total NFE is (`steps` + 1). + + This trick is firstly proposed by DDPM (https://arxiv.org/abs/2006.11239) and + score_sde (https://arxiv.org/abs/2011.13456). Such trick can improve the FID + for diffusion models sampling by diffusion SDEs for low-resolutional images + (such as CIFAR-10). However, we observed that such trick does not matter for + high-resolutional images. As it needs an additional NFE, we do not recommend + it for high-resolutional images. + lower_order_final: A `bool`. Whether to use lower order solvers at the final steps. + Only valid for `method=multistep` and `steps < 15`. We empirically find that + this trick is a key to stabilizing the sampling by DPM-Solver with very few steps + (especially for steps <= 10). So we recommend to set it to be `True`. + solver_type: A `str`. The taylor expansion type for the solver. `dpmsolver` or `taylor`. We recommend `dpmsolver`. + atol: A `float`. The absolute tolerance of the adaptive step size solver. Valid when `method` == 'adaptive'. + rtol: A `float`. The relative tolerance of the adaptive step size solver. Valid when `method` == 'adaptive'. + return_intermediate: A `bool`. Whether to save the xt at each step. + When set to `True`, method returns a tuple (x0, intermediates); when set to False, method returns only x0. + Returns: + x_end: A pytorch tensor. The approximated solution at time `t_end`. + + """ + t_0 = 1.0 / self.noise_schedule.total_N if t_end is None else t_end + t_T = self.noise_schedule.T if t_start is None else t_start + assert t_0 > 0 and t_T > 0, "Time range needs to be greater than 0. For discrete-time DPMs, it needs to be in [1 / N, 1], where N is the length of betas array" + if return_intermediate: + assert method in ["multistep", "singlestep", "singlestep_fixed"], "Cannot use adaptive solver when saving intermediate values" + if self.correcting_xt_fn is not None: + assert method in ["multistep", "singlestep", "singlestep_fixed"], "Cannot use adaptive solver when correcting_xt_fn is not None" + device = x.device + intermediates = [] + with torch.no_grad(): + if method == "adaptive": + x = self.dpm_solver_adaptive(x, order=order, t_T=t_T, t_0=t_0, atol=atol, rtol=rtol, solver_type=solver_type) + elif method == "multistep": + assert steps >= order + timesteps = self.get_time_steps(skip_type=skip_type, t_T=t_T, t_0=t_0, N=steps, device=device) + assert timesteps.shape[0] - 1 == steps + # Init the initial values. + step = 0 + t = timesteps[step] + t_prev_list = [t] + model_prev_list = [self.model_fn(x, t)] + if self.correcting_xt_fn is not None: + x = self.correcting_xt_fn(x, t, step) + if return_intermediate: + intermediates.append(x) + # Init the first `order` values by lower order multistep DPM-Solver. + for step in range(1, order): + t = timesteps[step] + x = self.multistep_dpm_solver_update(x, model_prev_list, t_prev_list, t, step, solver_type=solver_type) + if self.correcting_xt_fn is not None: + x = self.correcting_xt_fn(x, t, step) + if return_intermediate: + intermediates.append(x) + t_prev_list.append(t) + model_prev_list.append(self.model_fn(x, t)) + # Compute the remaining values by `order`-th order multistep DPM-Solver. + for step in range(order, steps + 1): + t = timesteps[step] + # We only use lower order for steps < 10 + if lower_order_final and steps < 10: + step_order = min(order, steps + 1 - step) + else: + step_order = order + x = self.multistep_dpm_solver_update(x, model_prev_list, t_prev_list, t, step_order, solver_type=solver_type) + if self.correcting_xt_fn is not None: + x = self.correcting_xt_fn(x, t, step) + if return_intermediate: + intermediates.append(x) + for i in range(order - 1): + t_prev_list[i] = t_prev_list[i + 1] + model_prev_list[i] = model_prev_list[i + 1] + t_prev_list[-1] = t + # We do not need to evaluate the final model value. + if step < steps: + model_prev_list[-1] = self.model_fn(x, t) + elif method in ["singlestep", "singlestep_fixed"]: + if method == "singlestep": + timesteps_outer, orders = self.get_orders_and_timesteps_for_singlestep_solver(steps=steps, order=order, skip_type=skip_type, t_T=t_T, t_0=t_0, device=device) + elif method == "singlestep_fixed": + K = steps // order + orders = [ + order, + ] * K + timesteps_outer = self.get_time_steps(skip_type=skip_type, t_T=t_T, t_0=t_0, N=K, device=device) + for step, order in enumerate(orders): + s, t = timesteps_outer[step], timesteps_outer[step + 1] + timesteps_inner = self.get_time_steps(skip_type=skip_type, t_T=s.item(), t_0=t.item(), N=order, device=device) + lambda_inner = self.noise_schedule.marginal_lambda(timesteps_inner) + h = lambda_inner[-1] - lambda_inner[0] + r1 = None if order <= 1 else (lambda_inner[1] - lambda_inner[0]) / h + r2 = None if order <= 2 else (lambda_inner[2] - lambda_inner[0]) / h + x = self.singlestep_dpm_solver_update(x, s, t, order, solver_type=solver_type, r1=r1, r2=r2) + if self.correcting_xt_fn is not None: + x = self.correcting_xt_fn(x, t, step) + if return_intermediate: + intermediates.append(x) + else: + raise ValueError("Got wrong method {}".format(method)) + if denoise_to_zero: + t = torch.ones((1,)).to(device) * t_0 + x = self.denoise_to_zero_fn(x, t) + if self.correcting_xt_fn is not None: + x = self.correcting_xt_fn(x, t, step + 1) + if return_intermediate: + intermediates.append(x) + if return_intermediate: + return x, intermediates + else: + return x + + +############################################################# +# other utility functions +############################################################# + + +def interpolate_fn(x, xp, yp): + """ + A piecewise linear function y = f(x), using xp and yp as keypoints. + We implement f(x) in a differentiable way (i.e. applicable for autograd). + The function f(x) is well-defined for all x-axis. (For x beyond the bounds of xp, we use the outmost points of xp to define the linear function.) + + Args: + x: PyTorch tensor with shape [N, C], where N is the batch size, C is the number of channels (we use C = 1 for DPM-Solver). + xp: PyTorch tensor with shape [C, K], where K is the number of keypoints. + yp: PyTorch tensor with shape [C, K]. + Returns: + The function values f(x), with shape [N, C]. + """ + N, K = x.shape[0], xp.shape[1] + all_x = torch.cat([x.unsqueeze(2), xp.unsqueeze(0).repeat((N, 1, 1))], dim=2) + sorted_all_x, x_indices = torch.sort(all_x, dim=2) + x_idx = torch.argmin(x_indices, dim=2) + cand_start_idx = x_idx - 1 + start_idx = torch.where( + torch.eq(x_idx, 0), + torch.tensor(1, device=x.device), + torch.where( + torch.eq(x_idx, K), + torch.tensor(K - 2, device=x.device), + cand_start_idx, + ), + ) + end_idx = torch.where(torch.eq(start_idx, cand_start_idx), start_idx + 2, start_idx + 1) + start_x = torch.gather(sorted_all_x, dim=2, index=start_idx.unsqueeze(2)).squeeze(2) + end_x = torch.gather(sorted_all_x, dim=2, index=end_idx.unsqueeze(2)).squeeze(2) + start_idx2 = torch.where( + torch.eq(x_idx, 0), + torch.tensor(0, device=x.device), + torch.where( + torch.eq(x_idx, K), + torch.tensor(K - 2, device=x.device), + cand_start_idx, + ), + ) + y_positions_expanded = yp.unsqueeze(0).expand(N, -1, -1) + start_y = torch.gather(y_positions_expanded, dim=2, index=start_idx2.unsqueeze(2)).squeeze(2) + end_y = torch.gather(y_positions_expanded, dim=2, index=(start_idx2 + 1).unsqueeze(2)).squeeze(2) + cand = start_y + (x - start_x) * (end_y - start_y) / (end_x - start_x) + return cand + + +def expand_dims(v, dims): + """ + Expand the tensor `v` to the dim `dims`. + + Args: + `v`: a PyTorch tensor with shape [N]. + `dim`: a `int`. + Returns: + a PyTorch tensor with shape [N, 1, 1, ..., 1] and the total dimension is `dims`. + """ + return v[(...,) + (None,) * (dims - 1)] diff --git a/server/voice_changer/DDSP_SVC/models/diffusion/infer_gt_mel.py b/server/voice_changer/DDSP_SVC/models/diffusion/infer_gt_mel.py new file mode 100644 index 00000000..a9b96322 --- /dev/null +++ b/server/voice_changer/DDSP_SVC/models/diffusion/infer_gt_mel.py @@ -0,0 +1,59 @@ +import torch +import torch.nn.functional as F +from .unit2mel import load_model_vocoder + + +class DiffGtMel: + def __init__(self, project_path=None, device=None): + self.project_path = project_path + if device is not None: + self.device = device + else: + self.device = "cuda" if torch.cuda.is_available() else "cpu" + self.model = None + self.vocoder = None + self.args = None + + def flush_model(self, project_path, ddsp_config=None): + if (self.model is None) or (project_path != self.project_path): + model, vocoder, args = load_model_vocoder(project_path, device=self.device) + if self.check_args(ddsp_config, args): + self.model = model + self.vocoder = vocoder + self.args = args + + def check_args(self, args1, args2): + if args1.data.block_size != args2.data.block_size: + raise ValueError("DDSP与DIFF模型的block_size不一致") + if args1.data.sampling_rate != args2.data.sampling_rate: + raise ValueError("DDSP与DIFF模型的sampling_rate不一致") + if args1.data.encoder != args2.data.encoder: + raise ValueError("DDSP与DIFF模型的encoder不一致") + return True + + def __call__(self, audio, f0, hubert, volume, acc=1, spk_id=1, k_step=0, method="pndm", spk_mix_dict=None, start_frame=0): + input_mel = self.vocoder.extract(audio, self.args.data.sampling_rate) + out_mel = self.model(hubert, f0, volume, spk_id=spk_id, spk_mix_dict=spk_mix_dict, gt_spec=input_mel, infer=True, infer_speedup=acc, method=method, k_step=k_step, use_tqdm=False) + if start_frame > 0: + out_mel = out_mel[:, start_frame:, :] + f0 = f0[:, start_frame:, :] + output = self.vocoder.infer(out_mel, f0) + if start_frame > 0: + output = F.pad(output, (start_frame * self.vocoder.vocoder_hop_size, 0)) + return output + + def infer(self, audio, f0, hubert, volume, acc=1, spk_id=1, k_step=0, method="pndm", silence_front=0, use_silence=False, spk_mix_dict=None): + start_frame = int(silence_front * self.vocoder.vocoder_sample_rate / self.vocoder.vocoder_hop_size) + if use_silence: + audio = audio[:, start_frame * self.vocoder.vocoder_hop_size :] + f0 = f0[:, start_frame:, :] + hubert = hubert[:, start_frame:, :] + volume = volume[:, start_frame:, :] + _start_frame = 0 + else: + _start_frame = start_frame + audio = self.__call__(audio, f0, hubert, volume, acc=acc, spk_id=spk_id, k_step=k_step, method=method, spk_mix_dict=spk_mix_dict, start_frame=_start_frame) + if use_silence: + if start_frame > 0: + audio = F.pad(audio, (start_frame * self.vocoder.vocoder_hop_size, 0)) + return audio diff --git a/server/voice_changer/DDSP_SVC/models/diffusion/onnx_export.py b/server/voice_changer/DDSP_SVC/models/diffusion/onnx_export.py new file mode 100644 index 00000000..41508bac --- /dev/null +++ b/server/voice_changer/DDSP_SVC/models/diffusion/onnx_export.py @@ -0,0 +1,226 @@ +from diffusion_onnx import GaussianDiffusion +import os +import yaml +import torch +import torch.nn as nn +import numpy as np +from wavenet import WaveNet +import torch.nn.functional as F +import diffusion + +class DotDict(dict): + def __getattr__(*args): + val = dict.get(*args) + return DotDict(val) if type(val) is dict else val + + __setattr__ = dict.__setitem__ + __delattr__ = dict.__delitem__ + + +def load_model_vocoder( + model_path, + device='cpu'): + config_file = os.path.join(os.path.split(model_path)[0], 'config.yaml') + with open(config_file, "r") as config: + args = yaml.safe_load(config) + args = DotDict(args) + + # load model + model = Unit2Mel( + args.data.encoder_out_channels, + args.model.n_spk, + args.model.use_pitch_aug, + 128, + args.model.n_layers, + args.model.n_chans, + args.model.n_hidden) + + print(' [Loading] ' + model_path) + ckpt = torch.load(model_path, map_location=torch.device(device)) + model.to(device) + model.load_state_dict(ckpt['model']) + model.eval() + return model, args + + +class Unit2Mel(nn.Module): + def __init__( + self, + input_channel, + n_spk, + use_pitch_aug=False, + out_dims=128, + n_layers=20, + n_chans=384, + n_hidden=256): + super().__init__() + self.unit_embed = nn.Linear(input_channel, n_hidden) + self.f0_embed = nn.Linear(1, n_hidden) + self.volume_embed = nn.Linear(1, n_hidden) + if use_pitch_aug: + self.aug_shift_embed = nn.Linear(1, n_hidden, bias=False) + else: + self.aug_shift_embed = None + self.n_spk = n_spk + if n_spk is not None and n_spk > 1: + self.spk_embed = nn.Embedding(n_spk, n_hidden) + + # diffusion + self.decoder = GaussianDiffusion(out_dims, n_layers, n_chans, n_hidden) + self.hidden_size = n_hidden + self.speaker_map = torch.zeros((self.n_spk,1,1,n_hidden)) + + + + def forward(self, units, mel2ph, f0, volume, g = None): + + ''' + input: + B x n_frames x n_unit + return: + dict of B x n_frames x feat + ''' + + decoder_inp = F.pad(units, [0, 0, 1, 0]) + mel2ph_ = mel2ph.unsqueeze(2).repeat([1, 1, units.shape[-1]]) + units = torch.gather(decoder_inp, 1, mel2ph_) # [B, T, H] + + x = self.unit_embed(units) + self.f0_embed((1 + f0.unsqueeze(-1) / 700).log()) + self.volume_embed(volume.unsqueeze(-1)) + + if self.n_spk is not None and self.n_spk > 1: # [N, S] * [S, B, 1, H] + g = g.reshape((g.shape[0], g.shape[1], 1, 1, 1)) # [N, S, B, 1, 1] + g = g * self.speaker_map # [N, S, B, 1, H] + g = torch.sum(g, dim=1) # [N, 1, B, 1, H] + g = g.transpose(0, -1).transpose(0, -2).squeeze(0) # [B, H, N] + x = x.transpose(1, 2) + g + return x + else: + return x.transpose(1, 2) + + + def init_spkembed(self, units, f0, volume, spk_id = None, spk_mix_dict = None, aug_shift = None, + gt_spec=None, infer=True, infer_speedup=10, method='dpm-solver', k_step=300, use_tqdm=True): + + ''' + input: + B x n_frames x n_unit + return: + dict of B x n_frames x feat + ''' + x = self.unit_embed(units) + self.f0_embed((1+ f0 / 700).log()) + self.volume_embed(volume) + if self.n_spk is not None and self.n_spk > 1: + if spk_mix_dict is not None: + spk_embed_mix = torch.zeros((1,1,self.hidden_size)) + for k, v in spk_mix_dict.items(): + spk_id_torch = torch.LongTensor(np.array([[k]])).to(units.device) + spk_embeddd = self.spk_embed(spk_id_torch) + self.speaker_map[k] = spk_embeddd + spk_embed_mix = spk_embed_mix + v * spk_embeddd + x = x + spk_embed_mix + else: + x = x + self.spk_embed(spk_id - 1) + self.speaker_map = self.speaker_map.unsqueeze(0) + self.speaker_map = self.speaker_map.detach() + return x.transpose(1, 2) + + def OnnxExport(self, project_name=None, init_noise=None, export_encoder=True, export_denoise=True, export_pred=True, export_after=True): + hubert_hidden_size = 768 + n_frames = 100 + hubert = torch.randn((1, n_frames, hubert_hidden_size)) + mel2ph = torch.arange(end=n_frames).unsqueeze(0).long() + f0 = torch.randn((1, n_frames)) + volume = torch.randn((1, n_frames)) + spk_mix = [] + spks = {} + if self.n_spk is not None and self.n_spk > 1: + for i in range(self.n_spk): + spk_mix.append(1.0/float(self.n_spk)) + spks.update({i:1.0/float(self.n_spk)}) + spk_mix = torch.tensor(spk_mix) + spk_mix = spk_mix.repeat(n_frames, 1) + orgouttt = self.init_spkembed(hubert, f0.unsqueeze(-1), volume.unsqueeze(-1), spk_mix_dict=spks) + outtt = self.forward(hubert, mel2ph, f0, volume, spk_mix) + if export_encoder: + torch.onnx.export( + self, + (hubert, mel2ph, f0, volume, spk_mix), + f"{project_name}_encoder.onnx", + input_names=["hubert", "mel2ph", "f0", "volume", "spk_mix"], + output_names=["mel_pred"], + dynamic_axes={ + "hubert": [1], + "f0": [1], + "volume": [1], + "mel2ph": [1], + "spk_mix": [0], + }, + opset_version=16 + ) + + self.decoder.OnnxExport(project_name, init_noise=init_noise, export_denoise=export_denoise, export_pred=export_pred, export_after=export_after) + + def ExportOnnx(self, project_name=None): + hubert_hidden_size = 768 + n_frames = 100 + hubert = torch.randn((1, n_frames, hubert_hidden_size)) + mel2ph = torch.arange(end=n_frames).unsqueeze(0).long() + f0 = torch.randn((1, n_frames)) + volume = torch.randn((1, n_frames)) + spk_mix = [] + spks = {} + if self.n_spk is not None and self.n_spk > 1: + for i in range(self.n_spk): + spk_mix.append(1.0/float(self.n_spk)) + spks.update({i:1.0/float(self.n_spk)}) + spk_mix = torch.tensor(spk_mix) + orgouttt = self.orgforward(hubert, f0.unsqueeze(-1), volume.unsqueeze(-1), spk_mix_dict=spks) + outtt = self.forward(hubert, mel2ph, f0, volume, spk_mix) + + torch.onnx.export( + self, + (hubert, mel2ph, f0, volume, spk_mix), + f"{project_name}_encoder.onnx", + input_names=["hubert", "mel2ph", "f0", "volume", "spk_mix"], + output_names=["mel_pred"], + dynamic_axes={ + "hubert": [1], + "f0": [1], + "volume": [1], + "mel2ph": [1] + }, + opset_version=16 + ) + + condition = torch.randn(1,self.decoder.n_hidden,n_frames) + noise = torch.randn((1, 1, self.decoder.mel_bins, condition.shape[2]), dtype=torch.float32) + pndm_speedup = torch.LongTensor([100]) + K_steps = torch.LongTensor([1000]) + self.decoder = torch.jit.script(self.decoder) + self.decoder(condition, noise, pndm_speedup, K_steps) + + torch.onnx.export( + self.decoder, + (condition, noise, pndm_speedup, K_steps), + f"{project_name}_diffusion.onnx", + input_names=["condition", "noise", "pndm_speedup", "K_steps"], + output_names=["mel"], + dynamic_axes={ + "condition": [2], + "noise": [3], + }, + opset_version=16 + ) + + +if __name__ == "__main__": + project_name = "dddsp" + model_path = f'{project_name}/model_500000.pt' + + model, _ = load_model_vocoder(model_path) + + # 分开Diffusion导出(需要使用MoeSS/MoeVoiceStudio或者自己编写Pndm/Dpm采样) + model.OnnxExport(project_name, export_encoder=True, export_denoise=True, export_pred=True, export_after=True) + + # 合并Diffusion导出(Encoder和Diffusion分开,直接将Encoder的结果和初始噪声输入Diffusion即可) + # model.ExportOnnx(project_name) + diff --git a/server/voice_changer/DDSP_SVC/models/diffusion/solver.py b/server/voice_changer/DDSP_SVC/models/diffusion/solver.py new file mode 100644 index 00000000..7df93bab --- /dev/null +++ b/server/voice_changer/DDSP_SVC/models/diffusion/solver.py @@ -0,0 +1,194 @@ +import os +import time +import numpy as np +import torch +import librosa +from logger.saver import Saver +from logger import utils +from torch import autocast +from torch.cuda.amp import GradScaler + +def test(args, model, vocoder, loader_test, saver): + print(' [*] testing...') + model.eval() + + # losses + test_loss = 0. + + # intialization + num_batches = len(loader_test) + rtf_all = [] + + # run + with torch.no_grad(): + for bidx, data in enumerate(loader_test): + fn = data['name'][0] + print('--------') + print('{}/{} - {}'.format(bidx, num_batches, fn)) + + # unpack data + for k in data.keys(): + if not k.startswith('name'): + data[k] = data[k].to(args.device) + print('>>', data['name'][0]) + + # forward + st_time = time.time() + mel = model( + data['units'], + data['f0'], + data['volume'], + data['spk_id'], + gt_spec=None, + infer=True, + infer_speedup=args.infer.speedup, + method=args.infer.method) + signal = vocoder.infer(mel, data['f0']) + ed_time = time.time() + + # RTF + run_time = ed_time - st_time + song_time = signal.shape[-1] / args.data.sampling_rate + rtf = run_time / song_time + print('RTF: {} | {} / {}'.format(rtf, run_time, song_time)) + rtf_all.append(rtf) + + # loss + for i in range(args.train.batch_size): + loss = model( + data['units'], + data['f0'], + data['volume'], + data['spk_id'], + gt_spec=data['mel'], + infer=False) + test_loss += loss.item() + + # log mel + saver.log_spec(data['name'][0], data['mel'], mel) + + # log audio + path_audio = os.path.join(args.data.valid_path, 'audio', data['name_ext'][0]) + audio, sr = librosa.load(path_audio, sr=args.data.sampling_rate) + if len(audio.shape) > 1: + audio = librosa.to_mono(audio) + audio = torch.from_numpy(audio).unsqueeze(0).to(signal) + saver.log_audio({fn+'/gt.wav': audio, fn+'/pred.wav': signal}) + + # report + test_loss /= args.train.batch_size + test_loss /= num_batches + + # check + print(' [test_loss] test_loss:', test_loss) + print(' Real Time Factor', np.mean(rtf_all)) + return test_loss + + +def train(args, initial_global_step, model, optimizer, scheduler, vocoder, loader_train, loader_test): + # saver + saver = Saver(args, initial_global_step=initial_global_step) + + # model size + params_count = utils.get_network_paras_amount({'model': model}) + saver.log_info('--- model size ---') + saver.log_info(params_count) + + # run + num_batches = len(loader_train) + model.train() + saver.log_info('======= start training =======') + scaler = GradScaler() + if args.train.amp_dtype == 'fp32': + dtype = torch.float32 + elif args.train.amp_dtype == 'fp16': + dtype = torch.float16 + elif args.train.amp_dtype == 'bf16': + dtype = torch.bfloat16 + else: + raise ValueError(' [x] Unknown amp_dtype: ' + args.train.amp_dtype) + for epoch in range(args.train.epochs): + for batch_idx, data in enumerate(loader_train): + saver.global_step_increment() + optimizer.zero_grad() + + # unpack data + for k in data.keys(): + if not k.startswith('name'): + data[k] = data[k].to(args.device) + + # forward + if dtype == torch.float32: + loss = model(data['units'].float(), data['f0'], data['volume'], data['spk_id'], + aug_shift = data['aug_shift'], gt_spec=data['mel'].float(), infer=False) + else: + with autocast(device_type=args.device, dtype=dtype): + loss = model(data['units'], data['f0'], data['volume'], data['spk_id'], + aug_shift = data['aug_shift'], gt_spec=data['mel'], infer=False) + + # handle nan loss + if torch.isnan(loss): + raise ValueError(' [x] nan loss ') + else: + # backpropagate + if dtype == torch.float32: + loss.backward() + optimizer.step() + else: + scaler.scale(loss).backward() + scaler.step(optimizer) + scaler.update() + scheduler.step() + + # log loss + if saver.global_step % args.train.interval_log == 0: + current_lr = optimizer.param_groups[0]['lr'] + saver.log_info( + 'epoch: {} | {:3d}/{:3d} | {} | batch/s: {:.2f} | lr: {:.6} | loss: {:.3f} | time: {} | step: {}'.format( + epoch, + batch_idx, + num_batches, + args.env.expdir, + args.train.interval_log/saver.get_interval_time(), + current_lr, + loss.item(), + saver.get_total_time(), + saver.global_step + ) + ) + + saver.log_value({ + 'train/loss': loss.item() + }) + + saver.log_value({ + 'train/lr': current_lr + }) + + # validation + if saver.global_step % args.train.interval_val == 0: + optimizer_save = optimizer if args.train.save_opt else None + + # save latest + saver.save_model(model, optimizer_save, postfix=f'{saver.global_step}') + last_val_step = saver.global_step - args.train.interval_val + if last_val_step % args.train.interval_force_save != 0: + saver.delete_model(postfix=f'{last_val_step}') + + # run testing set + test_loss = test(args, model, vocoder, loader_test, saver) + + # log loss + saver.log_info( + ' --- --- \nloss: {:.3f}. '.format( + test_loss, + ) + ) + + saver.log_value({ + 'validation/loss': test_loss + }) + + model.train() + + diff --git a/server/voice_changer/DDSP_SVC/models/diffusion/uni_pc.py b/server/voice_changer/DDSP_SVC/models/diffusion/uni_pc.py new file mode 100644 index 00000000..4226570c --- /dev/null +++ b/server/voice_changer/DDSP_SVC/models/diffusion/uni_pc.py @@ -0,0 +1,731 @@ +import torch +import torch.nn.functional as F +import math + + +class NoiseScheduleVP: + def __init__( + self, + schedule='discrete', + betas=None, + alphas_cumprod=None, + continuous_beta_0=0.1, + continuous_beta_1=20., + dtype=torch.float32, + ): + """Create a wrapper class for the forward SDE (VP type). + *** + Update: We support discrete-time diffusion models by implementing a picewise linear interpolation for log_alpha_t. + We recommend to use schedule='discrete' for the discrete-time diffusion models, especially for high-resolution images. + *** + The forward SDE ensures that the condition distribution q_{t|0}(x_t | x_0) = N ( alpha_t * x_0, sigma_t^2 * I ). + We further define lambda_t = log(alpha_t) - log(sigma_t), which is the half-logSNR (described in the DPM-Solver paper). + Therefore, we implement the functions for computing alpha_t, sigma_t and lambda_t. For t in [0, T], we have: + log_alpha_t = self.marginal_log_mean_coeff(t) + sigma_t = self.marginal_std(t) + lambda_t = self.marginal_lambda(t) + Moreover, as lambda(t) is an invertible function, we also support its inverse function: + t = self.inverse_lambda(lambda_t) + =============================================================== + We support both discrete-time DPMs (trained on n = 0, 1, ..., N-1) and continuous-time DPMs (trained on t in [t_0, T]). + 1. For discrete-time DPMs: + For discrete-time DPMs trained on n = 0, 1, ..., N-1, we convert the discrete steps to continuous time steps by: + t_i = (i + 1) / N + e.g. for N = 1000, we have t_0 = 1e-3 and T = t_{N-1} = 1. + We solve the corresponding diffusion ODE from time T = 1 to time t_0 = 1e-3. + Args: + betas: A `torch.Tensor`. The beta array for the discrete-time DPM. (See the original DDPM paper for details) + alphas_cumprod: A `torch.Tensor`. The cumprod alphas for the discrete-time DPM. (See the original DDPM paper for details) + Note that we always have alphas_cumprod = cumprod(1 - betas). Therefore, we only need to set one of `betas` and `alphas_cumprod`. + **Important**: Please pay special attention for the args for `alphas_cumprod`: + The `alphas_cumprod` is the \hat{alpha_n} arrays in the notations of DDPM. Specifically, DDPMs assume that + q_{t_n | 0}(x_{t_n} | x_0) = N ( \sqrt{\hat{alpha_n}} * x_0, (1 - \hat{alpha_n}) * I ). + Therefore, the notation \hat{alpha_n} is different from the notation alpha_t in DPM-Solver. In fact, we have + alpha_{t_n} = \sqrt{\hat{alpha_n}}, + and + log(alpha_{t_n}) = 0.5 * log(\hat{alpha_n}). + 2. For continuous-time DPMs: + We support two types of VPSDEs: linear (DDPM) and cosine (improved-DDPM). The hyperparameters for the noise + schedule are the default settings in DDPM and improved-DDPM: + Args: + beta_min: A `float` number. The smallest beta for the linear schedule. + beta_max: A `float` number. The largest beta for the linear schedule. + cosine_s: A `float` number. The hyperparameter in the cosine schedule. + cosine_beta_max: A `float` number. The hyperparameter in the cosine schedule. + T: A `float` number. The ending time of the forward process. + =============================================================== + Args: + schedule: A `str`. The noise schedule of the forward SDE. 'discrete' for discrete-time DPMs, + 'linear' or 'cosine' for continuous-time DPMs. + Returns: + A wrapper object of the forward SDE (VP type). + + =============================================================== + Example: + # For discrete-time DPMs, given betas (the beta array for n = 0, 1, ..., N - 1): + >>> ns = NoiseScheduleVP('discrete', betas=betas) + # For discrete-time DPMs, given alphas_cumprod (the \hat{alpha_n} array for n = 0, 1, ..., N - 1): + >>> ns = NoiseScheduleVP('discrete', alphas_cumprod=alphas_cumprod) + # For continuous-time DPMs (VPSDE), linear schedule: + >>> ns = NoiseScheduleVP('linear', continuous_beta_0=0.1, continuous_beta_1=20.) + """ + + if schedule not in ['discrete', 'linear', 'cosine']: + raise ValueError("Unsupported noise schedule {}. The schedule needs to be 'discrete' or 'linear' or 'cosine'".format(schedule)) + + self.schedule = schedule + if schedule == 'discrete': + if betas is not None: + log_alphas = 0.5 * torch.log(1 - betas).cumsum(dim=0) + else: + assert alphas_cumprod is not None + log_alphas = 0.5 * torch.log(alphas_cumprod) + self.total_N = len(log_alphas) + self.T = 1. + self.t_array = torch.linspace(0., 1., self.total_N + 1)[1:].reshape((1, -1)).to(dtype=dtype) + self.log_alpha_array = log_alphas.reshape((1, -1,)).to(dtype=dtype) + else: + self.total_N = 1000 + self.beta_0 = continuous_beta_0 + self.beta_1 = continuous_beta_1 + self.cosine_s = 0.008 + self.cosine_beta_max = 999. + self.cosine_t_max = math.atan(self.cosine_beta_max * (1. + self.cosine_s) / math.pi) * 2. * (1. + self.cosine_s) / math.pi - self.cosine_s + self.cosine_log_alpha_0 = math.log(math.cos(self.cosine_s / (1. + self.cosine_s) * math.pi / 2.)) + self.schedule = schedule + if schedule == 'cosine': + # For the cosine schedule, T = 1 will have numerical issues. So we manually set the ending time T. + # Note that T = 0.9946 may be not the optimal setting. However, we find it works well. + self.T = 0.9946 + else: + self.T = 1. + + def marginal_log_mean_coeff(self, t): + """ + Compute log(alpha_t) of a given continuous-time label t in [0, T]. + """ + if self.schedule == 'discrete': + return interpolate_fn(t.reshape((-1, 1)), self.t_array.to(t.device), self.log_alpha_array.to(t.device)).reshape((-1)) + elif self.schedule == 'linear': + return -0.25 * t ** 2 * (self.beta_1 - self.beta_0) - 0.5 * t * self.beta_0 + elif self.schedule == 'cosine': + log_alpha_fn = lambda s: torch.log(torch.cos((s + self.cosine_s) / (1. + self.cosine_s) * math.pi / 2.)) + log_alpha_t = log_alpha_fn(t) - self.cosine_log_alpha_0 + return log_alpha_t + + def marginal_alpha(self, t): + """ + Compute alpha_t of a given continuous-time label t in [0, T]. + """ + return torch.exp(self.marginal_log_mean_coeff(t)) + + def marginal_std(self, t): + """ + Compute sigma_t of a given continuous-time label t in [0, T]. + """ + return torch.sqrt(1. - torch.exp(2. * self.marginal_log_mean_coeff(t))) + + def marginal_lambda(self, t): + """ + Compute lambda_t = log(alpha_t) - log(sigma_t) of a given continuous-time label t in [0, T]. + """ + log_mean_coeff = self.marginal_log_mean_coeff(t) + log_std = 0.5 * torch.log(1. - torch.exp(2. * log_mean_coeff)) + return log_mean_coeff - log_std + + def inverse_lambda(self, lamb): + """ + Compute the continuous-time label t in [0, T] of a given half-logSNR lambda_t. + """ + if self.schedule == 'linear': + tmp = 2. * (self.beta_1 - self.beta_0) * torch.logaddexp(-2. * lamb, torch.zeros((1,)).to(lamb)) + Delta = self.beta_0**2 + tmp + return tmp / (torch.sqrt(Delta) + self.beta_0) / (self.beta_1 - self.beta_0) + elif self.schedule == 'discrete': + log_alpha = -0.5 * torch.logaddexp(torch.zeros((1,)).to(lamb.device), -2. * lamb) + t = interpolate_fn(log_alpha.reshape((-1, 1)), torch.flip(self.log_alpha_array.to(lamb.device), [1]), torch.flip(self.t_array.to(lamb.device), [1])) + return t.reshape((-1,)) + else: + log_alpha = -0.5 * torch.logaddexp(-2. * lamb, torch.zeros((1,)).to(lamb)) + t_fn = lambda log_alpha_t: torch.arccos(torch.exp(log_alpha_t + self.cosine_log_alpha_0)) * 2. * (1. + self.cosine_s) / math.pi - self.cosine_s + t = t_fn(log_alpha) + return t + + +def model_wrapper( + model, + noise_schedule, + model_type="noise", + model_kwargs={}, + guidance_type="uncond", + condition=None, + unconditional_condition=None, + guidance_scale=1., + classifier_fn=None, + classifier_kwargs={}, +): + """Create a wrapper function for the noise prediction model. + """ + + def get_model_input_time(t_continuous): + """ + Convert the continuous-time `t_continuous` (in [epsilon, T]) to the model input time. + For discrete-time DPMs, we convert `t_continuous` in [1 / N, 1] to `t_input` in [0, 1000 * (N - 1) / N]. + For continuous-time DPMs, we just use `t_continuous`. + """ + if noise_schedule.schedule == 'discrete': + return (t_continuous - 1. / noise_schedule.total_N) * noise_schedule.total_N + else: + return t_continuous + + def noise_pred_fn(x, t_continuous, cond=None): + t_input = get_model_input_time(t_continuous) + if cond is None: + output = model(x, t_input, **model_kwargs) + else: + output = model(x, t_input, cond, **model_kwargs) + if model_type == "noise": + return output + elif model_type == "x_start": + alpha_t, sigma_t = noise_schedule.marginal_alpha(t_continuous), noise_schedule.marginal_std(t_continuous) + return (x - alpha_t * output) / sigma_t + elif model_type == "v": + alpha_t, sigma_t = noise_schedule.marginal_alpha(t_continuous), noise_schedule.marginal_std(t_continuous) + return alpha_t * output + sigma_t * x + elif model_type == "score": + sigma_t = noise_schedule.marginal_std(t_continuous) + return -sigma_t * output + + def cond_grad_fn(x, t_input): + """ + Compute the gradient of the classifier, i.e. nabla_{x} log p_t(cond | x_t). + """ + with torch.enable_grad(): + x_in = x.detach().requires_grad_(True) + log_prob = classifier_fn(x_in, t_input, condition, **classifier_kwargs) + return torch.autograd.grad(log_prob.sum(), x_in)[0] + + def model_fn(x, t_continuous): + """ + The noise predicition model function that is used for DPM-Solver. + """ + if guidance_type == "uncond": + return noise_pred_fn(x, t_continuous) + elif guidance_type == "classifier": + assert classifier_fn is not None + t_input = get_model_input_time(t_continuous) + cond_grad = cond_grad_fn(x, t_input) + sigma_t = noise_schedule.marginal_std(t_continuous) + noise = noise_pred_fn(x, t_continuous) + return noise - guidance_scale * sigma_t * cond_grad + elif guidance_type == "classifier-free": + if guidance_scale == 1. or unconditional_condition is None: + return noise_pred_fn(x, t_continuous, cond=condition) + else: + x_in = torch.cat([x] * 2) + t_in = torch.cat([t_continuous] * 2) + c_in = torch.cat([unconditional_condition, condition]) + noise_uncond, noise = noise_pred_fn(x_in, t_in, cond=c_in).chunk(2) + return noise_uncond + guidance_scale * (noise - noise_uncond) + + assert model_type in ["noise", "x_start", "v"] + assert guidance_type in ["uncond", "classifier", "classifier-free"] + return model_fn + + +class UniPC: + def __init__( + self, + model_fn, + noise_schedule, + algorithm_type="data_prediction", + correcting_x0_fn=None, + correcting_xt_fn=None, + thresholding_max_val=1., + dynamic_thresholding_ratio=0.995, + variant='bh1' + ): + """Construct a UniPC. + + We support both data_prediction and noise_prediction. + """ + self.model = lambda x, t: model_fn(x, t.expand((x.shape[0]))) + self.noise_schedule = noise_schedule + assert algorithm_type in ["data_prediction", "noise_prediction"] + + if correcting_x0_fn == "dynamic_thresholding": + self.correcting_x0_fn = self.dynamic_thresholding_fn + else: + self.correcting_x0_fn = correcting_x0_fn + + self.correcting_xt_fn = correcting_xt_fn + self.dynamic_thresholding_ratio = dynamic_thresholding_ratio + self.thresholding_max_val = thresholding_max_val + + self.variant = variant + self.predict_x0 = algorithm_type == "data_prediction" + + def dynamic_thresholding_fn(self, x0, t=None): + """ + The dynamic thresholding method. + """ + dims = x0.dim() + p = self.dynamic_thresholding_ratio + s = torch.quantile(torch.abs(x0).reshape((x0.shape[0], -1)), p, dim=1) + s = expand_dims(torch.maximum(s, self.thresholding_max_val * torch.ones_like(s).to(s.device)), dims) + x0 = torch.clamp(x0, -s, s) / s + return x0 + + def noise_prediction_fn(self, x, t): + """ + Return the noise prediction model. + """ + return self.model(x, t) + + def data_prediction_fn(self, x, t): + """ + Return the data prediction model (with corrector). + """ + noise = self.noise_prediction_fn(x, t) + alpha_t, sigma_t = self.noise_schedule.marginal_alpha(t), self.noise_schedule.marginal_std(t) + x0 = (x - sigma_t * noise) / alpha_t + if self.correcting_x0_fn is not None: + x0 = self.correcting_x0_fn(x0) + return x0 + + def model_fn(self, x, t): + """ + Convert the model to the noise prediction model or the data prediction model. + """ + if self.predict_x0: + return self.data_prediction_fn(x, t) + else: + return self.noise_prediction_fn(x, t) + + def get_time_steps(self, skip_type, t_T, t_0, N, device): + """Compute the intermediate time steps for sampling. + """ + if skip_type == 'logSNR': + lambda_T = self.noise_schedule.marginal_lambda(torch.tensor(t_T).to(device)) + lambda_0 = self.noise_schedule.marginal_lambda(torch.tensor(t_0).to(device)) + logSNR_steps = torch.linspace(lambda_T.cpu().item(), lambda_0.cpu().item(), N + 1).to(device) + return self.noise_schedule.inverse_lambda(logSNR_steps) + elif skip_type == 'time_uniform': + return torch.linspace(t_T, t_0, N + 1).to(device) + elif skip_type == 'time_quadratic': + t_order = 2 + t = torch.linspace(t_T**(1. / t_order), t_0**(1. / t_order), N + 1).pow(t_order).to(device) + return t + else: + raise ValueError("Unsupported skip_type {}, need to be 'logSNR' or 'time_uniform' or 'time_quadratic'".format(skip_type)) + + def get_orders_and_timesteps_for_singlestep_solver(self, steps, order, skip_type, t_T, t_0, device): + """ + Get the order of each step for sampling by the singlestep DPM-Solver. + """ + if order == 3: + K = steps // 3 + 1 + if steps % 3 == 0: + orders = [3,] * (K - 2) + [2, 1] + elif steps % 3 == 1: + orders = [3,] * (K - 1) + [1] + else: + orders = [3,] * (K - 1) + [2] + elif order == 2: + if steps % 2 == 0: + K = steps // 2 + orders = [2,] * K + else: + K = steps // 2 + 1 + orders = [2,] * (K - 1) + [1] + elif order == 1: + K = steps + orders = [1,] * steps + else: + raise ValueError("'order' must be '1' or '2' or '3'.") + if skip_type == 'logSNR': + # To reproduce the results in DPM-Solver paper + timesteps_outer = self.get_time_steps(skip_type, t_T, t_0, K, device) + else: + timesteps_outer = self.get_time_steps(skip_type, t_T, t_0, steps, device)[torch.cumsum(torch.tensor([0,] + orders), 0).to(device)] + return timesteps_outer, orders + + def denoise_to_zero_fn(self, x, s): + """ + Denoise at the final step, which is equivalent to solve the ODE from lambda_s to infty by first-order discretization. + """ + return self.data_prediction_fn(x, s) + + def multistep_uni_pc_update(self, x, model_prev_list, t_prev_list, t, order, **kwargs): + if len(t.shape) == 0: + t = t.view(-1) + if 'bh' in self.variant: + return self.multistep_uni_pc_bh_update(x, model_prev_list, t_prev_list, t, order, **kwargs) + else: + assert self.variant == 'vary_coeff' + return self.multistep_uni_pc_vary_update(x, model_prev_list, t_prev_list, t, order, **kwargs) + + def multistep_uni_pc_vary_update(self, x, model_prev_list, t_prev_list, t, order, use_corrector=True): + #print(f'using unified predictor-corrector with order {order} (solver type: vary coeff)') + ns = self.noise_schedule + assert order <= len(model_prev_list) + + # first compute rks + t_prev_0 = t_prev_list[-1] + lambda_prev_0 = ns.marginal_lambda(t_prev_0) + lambda_t = ns.marginal_lambda(t) + model_prev_0 = model_prev_list[-1] + sigma_prev_0, sigma_t = ns.marginal_std(t_prev_0), ns.marginal_std(t) + log_alpha_t = ns.marginal_log_mean_coeff(t) + alpha_t = torch.exp(log_alpha_t) + + h = lambda_t - lambda_prev_0 + + rks = [] + D1s = [] + for i in range(1, order): + t_prev_i = t_prev_list[-(i + 1)] + model_prev_i = model_prev_list[-(i + 1)] + lambda_prev_i = ns.marginal_lambda(t_prev_i) + rk = (lambda_prev_i - lambda_prev_0) / h + rks.append(rk) + D1s.append((model_prev_i - model_prev_0) / rk) + + rks.append(1.) + rks = torch.tensor(rks, device=x.device) + + K = len(rks) + # build C matrix + C = [] + + col = torch.ones_like(rks) + for k in range(1, K + 1): + C.append(col) + col = col * rks / (k + 1) + C = torch.stack(C, dim=1) + + if len(D1s) > 0: + D1s = torch.stack(D1s, dim=1) # (B, K) + C_inv_p = torch.linalg.inv(C[:-1, :-1]) + A_p = C_inv_p + + if use_corrector: + #print('using corrector') + C_inv = torch.linalg.inv(C) + A_c = C_inv + + hh = -h if self.predict_x0 else h + h_phi_1 = torch.expm1(hh) + h_phi_ks = [] + factorial_k = 1 + h_phi_k = h_phi_1 + for k in range(1, K + 2): + h_phi_ks.append(h_phi_k) + h_phi_k = h_phi_k / hh - 1 / factorial_k + factorial_k *= (k + 1) + + model_t = None + if self.predict_x0: + x_t_ = ( + sigma_t / sigma_prev_0 * x + - alpha_t * h_phi_1 * model_prev_0 + ) + # now predictor + x_t = x_t_ + if len(D1s) > 0: + # compute the residuals for predictor + for k in range(K - 1): + x_t = x_t - alpha_t * h_phi_ks[k + 1] * torch.einsum('bkchw,k->bchw', D1s, A_p[k]) + # now corrector + if use_corrector: + model_t = self.model_fn(x_t, t) + D1_t = (model_t - model_prev_0) + x_t = x_t_ + k = 0 + for k in range(K - 1): + x_t = x_t - alpha_t * h_phi_ks[k + 1] * torch.einsum('bkchw,k->bchw', D1s, A_c[k][:-1]) + x_t = x_t - alpha_t * h_phi_ks[K] * (D1_t * A_c[k][-1]) + else: + log_alpha_prev_0, log_alpha_t = ns.marginal_log_mean_coeff(t_prev_0), ns.marginal_log_mean_coeff(t) + x_t_ = ( + (torch.exp(log_alpha_t - log_alpha_prev_0)) * x + - (sigma_t * h_phi_1) * model_prev_0 + ) + # now predictor + x_t = x_t_ + if len(D1s) > 0: + # compute the residuals for predictor + for k in range(K - 1): + x_t = x_t - sigma_t * h_phi_ks[k + 1] * torch.einsum('bkchw,k->bchw', D1s, A_p[k]) + # now corrector + if use_corrector: + model_t = self.model_fn(x_t, t) + D1_t = (model_t - model_prev_0) + x_t = x_t_ + k = 0 + for k in range(K - 1): + x_t = x_t - sigma_t * h_phi_ks[k + 1] * torch.einsum('bkchw,k->bchw', D1s, A_c[k][:-1]) + x_t = x_t - sigma_t * h_phi_ks[K] * (D1_t * A_c[k][-1]) + return x_t, model_t + + def multistep_uni_pc_bh_update(self, x, model_prev_list, t_prev_list, t, order, x_t=None, use_corrector=True): + #print(f'using unified predictor-corrector with order {order} (solver type: B(h))') + ns = self.noise_schedule + assert order <= len(model_prev_list) + + # first compute rks + t_prev_0 = t_prev_list[-1] + lambda_prev_0 = ns.marginal_lambda(t_prev_0) + lambda_t = ns.marginal_lambda(t) + model_prev_0 = model_prev_list[-1] + sigma_prev_0, sigma_t = ns.marginal_std(t_prev_0), ns.marginal_std(t) + log_alpha_prev_0, log_alpha_t = ns.marginal_log_mean_coeff(t_prev_0), ns.marginal_log_mean_coeff(t) + alpha_t = torch.exp(log_alpha_t) + + h = lambda_t - lambda_prev_0 + + rks = [] + D1s = [] + for i in range(1, order): + t_prev_i = t_prev_list[-(i + 1)] + model_prev_i = model_prev_list[-(i + 1)] + lambda_prev_i = ns.marginal_lambda(t_prev_i) + rk = (lambda_prev_i - lambda_prev_0) / h + rks.append(rk) + D1s.append((model_prev_i - model_prev_0) / rk) + + rks.append(1.) + rks = torch.tensor(rks, device=x.device) + + R = [] + b = [] + + hh = -h if self.predict_x0 else h + h_phi_1 = torch.expm1(hh) # h\phi_1(h) = e^h - 1 + h_phi_k = h_phi_1 / hh - 1 + + factorial_i = 1 + + if self.variant == 'bh1': + B_h = hh + elif self.variant == 'bh2': + B_h = torch.expm1(hh) + else: + raise NotImplementedError() + + for i in range(1, order + 1): + R.append(torch.pow(rks, i - 1)) + b.append(h_phi_k * factorial_i / B_h) + factorial_i *= (i + 1) + h_phi_k = h_phi_k / hh - 1 / factorial_i + + R = torch.stack(R) + b = torch.cat(b) + + # now predictor + use_predictor = len(D1s) > 0 and x_t is None + if len(D1s) > 0: + D1s = torch.stack(D1s, dim=1) # (B, K) + if x_t is None: + # for order 2, we use a simplified version + if order == 2: + rhos_p = torch.tensor([0.5], device=b.device) + else: + rhos_p = torch.linalg.solve(R[:-1, :-1], b[:-1]) + else: + D1s = None + + if use_corrector: + #print('using corrector') + # for order 1, we use a simplified version + if order == 1: + rhos_c = torch.tensor([0.5], device=b.device) + else: + rhos_c = torch.linalg.solve(R, b) + + model_t = None + if self.predict_x0: + x_t_ = ( + sigma_t / sigma_prev_0 * x + - alpha_t * h_phi_1 * model_prev_0 + ) + + if x_t is None: + if use_predictor: + pred_res = torch.einsum('k,bkchw->bchw', rhos_p, D1s) + else: + pred_res = 0 + x_t = x_t_ - alpha_t * B_h * pred_res + + if use_corrector: + model_t = self.model_fn(x_t, t) + if D1s is not None: + corr_res = torch.einsum('k,bkchw->bchw', rhos_c[:-1], D1s) + else: + corr_res = 0 + D1_t = (model_t - model_prev_0) + x_t = x_t_ - alpha_t * B_h * (corr_res + rhos_c[-1] * D1_t) + else: + x_t_ = ( + torch.exp(log_alpha_t - log_alpha_prev_0) * x + - sigma_t * h_phi_1 * model_prev_0 + ) + if x_t is None: + if use_predictor: + pred_res = torch.einsum('k,bkchw->bchw', rhos_p, D1s) + else: + pred_res = 0 + x_t = x_t_ - sigma_t * B_h * pred_res + + if use_corrector: + model_t = self.model_fn(x_t, t) + if D1s is not None: + corr_res = torch.einsum('k,bkchw->bchw', rhos_c[:-1], D1s) + else: + corr_res = 0 + D1_t = (model_t - model_prev_0) + x_t = x_t_ - sigma_t * B_h * (corr_res + rhos_c[-1] * D1_t) + return x_t, model_t + + def sample(self, x, steps=20, t_start=None, t_end=None, order=2, skip_type='time_uniform', + method='multistep', lower_order_final=True, denoise_to_zero=False, atol=0.0078, rtol=0.05, return_intermediate=False, + ): + """ + Compute the sample at time `t_end` by UniPC, given the initial `x` at time `t_start`. + """ + t_0 = 1. / self.noise_schedule.total_N if t_end is None else t_end + t_T = self.noise_schedule.T if t_start is None else t_start + assert t_0 > 0 and t_T > 0, "Time range needs to be greater than 0. For discrete-time DPMs, it needs to be in [1 / N, 1], where N is the length of betas array" + if return_intermediate: + assert method in ['multistep', 'singlestep', 'singlestep_fixed'], "Cannot use adaptive solver when saving intermediate values" + if self.correcting_xt_fn is not None: + assert method in ['multistep', 'singlestep', 'singlestep_fixed'], "Cannot use adaptive solver when correcting_xt_fn is not None" + device = x.device + intermediates = [] + with torch.no_grad(): + if method == 'multistep': + assert steps >= order + timesteps = self.get_time_steps(skip_type=skip_type, t_T=t_T, t_0=t_0, N=steps, device=device) + assert timesteps.shape[0] - 1 == steps + # Init the initial values. + step = 0 + t = timesteps[step] + t_prev_list = [t] + model_prev_list = [self.model_fn(x, t)] + if self.correcting_xt_fn is not None: + x = self.correcting_xt_fn(x, t, step) + if return_intermediate: + intermediates.append(x) + + # Init the first `order` values by lower order multistep UniPC. + for step in range(1, order): + t = timesteps[step] + x, model_x = self.multistep_uni_pc_update(x, model_prev_list, t_prev_list, t, step, use_corrector=True) + if model_x is None: + model_x = self.model_fn(x, t) + if self.correcting_xt_fn is not None: + x = self.correcting_xt_fn(x, t, step) + if return_intermediate: + intermediates.append(x) + t_prev_list.append(t) + model_prev_list.append(model_x) + + # Compute the remaining values by `order`-th order multistep DPM-Solver. + for step in range(order, steps + 1): + t = timesteps[step] + if lower_order_final: + step_order = min(order, steps + 1 - step) + else: + step_order = order + if step == steps: + #print('do not run corrector at the last step') + use_corrector = False + else: + use_corrector = True + x, model_x = self.multistep_uni_pc_update(x, model_prev_list, t_prev_list, t, step_order, use_corrector=use_corrector) + if self.correcting_xt_fn is not None: + x = self.correcting_xt_fn(x, t, step) + if return_intermediate: + intermediates.append(x) + for i in range(order - 1): + t_prev_list[i] = t_prev_list[i + 1] + model_prev_list[i] = model_prev_list[i + 1] + t_prev_list[-1] = t + # We do not need to evaluate the final model value. + if step < steps: + if model_x is None: + model_x = self.model_fn(x, t) + model_prev_list[-1] = model_x + else: + raise ValueError("Got wrong method {}".format(method)) + + if denoise_to_zero: + t = torch.ones((1,)).to(device) * t_0 + x = self.denoise_to_zero_fn(x, t) + if self.correcting_xt_fn is not None: + x = self.correcting_xt_fn(x, t, step + 1) + if return_intermediate: + intermediates.append(x) + if return_intermediate: + return x, intermediates + else: + return x + + +############################################################# +# other utility functions +############################################################# + +def interpolate_fn(x, xp, yp): + """ + A piecewise linear function y = f(x), using xp and yp as keypoints. + We implement f(x) in a differentiable way (i.e. applicable for autograd). + The function f(x) is well-defined for all x-axis. (For x beyond the bounds of xp, we use the outmost points of xp to define the linear function.) + + Args: + x: PyTorch tensor with shape [N, C], where N is the batch size, C is the number of channels (we use C = 1 for DPM-Solver). + xp: PyTorch tensor with shape [C, K], where K is the number of keypoints. + yp: PyTorch tensor with shape [C, K]. + Returns: + The function values f(x), with shape [N, C]. + """ + N, K = x.shape[0], xp.shape[1] + all_x = torch.cat([x.unsqueeze(2), xp.unsqueeze(0).repeat((N, 1, 1))], dim=2) + sorted_all_x, x_indices = torch.sort(all_x, dim=2) + x_idx = torch.argmin(x_indices, dim=2) + cand_start_idx = x_idx - 1 + start_idx = torch.where( + torch.eq(x_idx, 0), + torch.tensor(1, device=x.device), + torch.where( + torch.eq(x_idx, K), torch.tensor(K - 2, device=x.device), cand_start_idx, + ), + ) + end_idx = torch.where(torch.eq(start_idx, cand_start_idx), start_idx + 2, start_idx + 1) + start_x = torch.gather(sorted_all_x, dim=2, index=start_idx.unsqueeze(2)).squeeze(2) + end_x = torch.gather(sorted_all_x, dim=2, index=end_idx.unsqueeze(2)).squeeze(2) + start_idx2 = torch.where( + torch.eq(x_idx, 0), + torch.tensor(0, device=x.device), + torch.where( + torch.eq(x_idx, K), torch.tensor(K - 2, device=x.device), cand_start_idx, + ), + ) + y_positions_expanded = yp.unsqueeze(0).expand(N, -1, -1) + start_y = torch.gather(y_positions_expanded, dim=2, index=start_idx2.unsqueeze(2)).squeeze(2) + end_y = torch.gather(y_positions_expanded, dim=2, index=(start_idx2 + 1).unsqueeze(2)).squeeze(2) + cand = start_y + (x - start_x) * (end_y - start_y) / (end_x - start_x) + return cand + + +def expand_dims(v, dims): + """ + Expand the tensor `v` to the dim `dims`. + + Args: + `v`: a PyTorch tensor with shape [N]. + `dim`: a `int`. + Returns: + a PyTorch tensor with shape [N, 1, 1, ..., 1] and the total dimension is `dims`. + """ + return v[(...,) + (None,)*(dims - 1)] \ No newline at end of file diff --git a/server/voice_changer/DDSP_SVC/models/diffusion/unit2mel.py b/server/voice_changer/DDSP_SVC/models/diffusion/unit2mel.py new file mode 100644 index 00000000..ea0c2598 --- /dev/null +++ b/server/voice_changer/DDSP_SVC/models/diffusion/unit2mel.py @@ -0,0 +1,77 @@ +import os +import yaml +import torch +import torch.nn as nn +import numpy as np +from .diffusion import GaussianDiffusion +from .wavenet import WaveNet +from .vocoder import Vocoder + + +class DotDict(dict): + def __getattr__(*args): + val = dict.get(*args) + return DotDict(val) if type(val) is dict else val + + __setattr__ = dict.__setitem__ + __delattr__ = dict.__delitem__ + + +def load_model_vocoder(model_path, device="cpu"): + config_file = os.path.join(os.path.split(model_path)[0], "config.yaml") + with open(config_file, "r") as config: + args = yaml.safe_load(config) + args = DotDict(args) + + # load vocoder + vocoder = Vocoder(args.vocoder.type, args.vocoder.ckpt, device=device) + + # load model + model = Unit2Mel(args.data.encoder_out_channels, args.model.n_spk, args.model.use_pitch_aug, vocoder.dimension, args.model.n_layers, args.model.n_chans, args.model.n_hidden) + + print(" [Loading] " + model_path) + ckpt = torch.load(model_path, map_location=torch.device(device)) + model.to(device) + model.load_state_dict(ckpt["model"]) + model.eval() + return model, vocoder, args + + +class Unit2Mel(nn.Module): + def __init__(self, input_channel, n_spk, use_pitch_aug=False, out_dims=128, n_layers=20, n_chans=384, n_hidden=256): + super().__init__() + self.unit_embed = nn.Linear(input_channel, n_hidden) + self.f0_embed = nn.Linear(1, n_hidden) + self.volume_embed = nn.Linear(1, n_hidden) + if use_pitch_aug: + self.aug_shift_embed = nn.Linear(1, n_hidden, bias=False) + else: + self.aug_shift_embed = None + self.n_spk = n_spk + if n_spk is not None and n_spk > 1: + self.spk_embed = nn.Embedding(n_spk, n_hidden) + + # diffusion + self.decoder = GaussianDiffusion(WaveNet(out_dims, n_layers, n_chans, n_hidden), out_dims=out_dims) + + def forward(self, units, f0, volume, spk_id=None, spk_mix_dict=None, aug_shift=None, gt_spec=None, infer=True, infer_speedup=10, method="dpm-solver", k_step=300, use_tqdm=True): + """ + input: + B x n_frames x n_unit + return: + dict of B x n_frames x feat + """ + + x = self.unit_embed(units) + self.f0_embed((1 + f0 / 700).log()) + self.volume_embed(volume) + if self.n_spk is not None and self.n_spk > 1: + if spk_mix_dict is not None: + for k, v in spk_mix_dict.items(): + spk_id_torch = torch.LongTensor(np.array([[k]])).to(units.device) + x = x + v * self.spk_embed(spk_id_torch - 1) + else: + x = x + self.spk_embed(spk_id - 1) + if self.aug_shift_embed is not None and aug_shift is not None: + x = x + self.aug_shift_embed(aug_shift / 5) + x = self.decoder(x, gt_spec=gt_spec, infer=infer, infer_speedup=infer_speedup, method=method, k_step=k_step, use_tqdm=use_tqdm) + + return x diff --git a/server/voice_changer/DDSP_SVC/models/diffusion/vocoder.py b/server/voice_changer/DDSP_SVC/models/diffusion/vocoder.py new file mode 100644 index 00000000..7a53099f --- /dev/null +++ b/server/voice_changer/DDSP_SVC/models/diffusion/vocoder.py @@ -0,0 +1,87 @@ +import torch +from torchaudio.transforms import Resample +from ..nsf_hifigan.nvSTFT import STFT +from ..nsf_hifigan.models import load_model, load_config + + +class Vocoder: + def __init__(self, vocoder_type, vocoder_ckpt, device=None): + if device is None: + device = "cuda" if torch.cuda.is_available() else "cpu" + self.device = device + + if vocoder_type == "nsf-hifigan": + self.vocoder = NsfHifiGAN(vocoder_ckpt, device=device) + elif vocoder_type == "nsf-hifigan-log10": + self.vocoder = NsfHifiGANLog10(vocoder_ckpt, device=device) + else: + raise ValueError(f" [x] Unknown vocoder: {vocoder_type}") + + self.resample_kernel = {} + self.vocoder_sample_rate = self.vocoder.sample_rate() + self.vocoder_hop_size = self.vocoder.hop_size() + self.dimension = self.vocoder.dimension() + + def extract(self, audio, sample_rate, keyshift=0): + # resample + if sample_rate == self.vocoder_sample_rate: + audio_res = audio + else: + key_str = str(sample_rate) + if key_str not in self.resample_kernel: + self.resample_kernel[key_str] = Resample(sample_rate, self.vocoder_sample_rate, lowpass_filter_width=128).to(self.device) + audio_res = self.resample_kernel[key_str](audio) + + # extract + mel = self.vocoder.extract(audio_res, keyshift=keyshift) # B, n_frames, bins + return mel + + def infer(self, mel, f0): + f0 = f0[:, : mel.size(1), 0] # B, n_frames + audio = self.vocoder(mel, f0) + return audio + + +class NsfHifiGAN(torch.nn.Module): + def __init__(self, model_path, device=None): + super().__init__() + if device is None: + device = "cuda" if torch.cuda.is_available() else "cpu" + self.device = device + self.model_path = model_path + self.model = None + self.h = load_config(model_path) + self.stft = STFT(self.h.sampling_rate, self.h.num_mels, self.h.n_fft, self.h.win_size, self.h.hop_size, self.h.fmin, self.h.fmax) + + def sample_rate(self): + return self.h.sampling_rate + + def hop_size(self): + return self.h.hop_size + + def dimension(self): + return self.h.num_mels + + def extract(self, audio, keyshift=0): + mel = self.stft.get_mel(audio, keyshift=keyshift).transpose(1, 2) # B, n_frames, bins + return mel + + def forward(self, mel, f0): + if self.model is None: + print("| Load HifiGAN: ", self.model_path) + self.model, self.h = load_model(self.model_path, device=self.device) + with torch.no_grad(): + c = mel.transpose(1, 2) + audio = self.model(c, f0) + return audio + + +class NsfHifiGANLog10(NsfHifiGAN): + def forward(self, mel, f0): + if self.model is None: + print("| Load HifiGAN: ", self.model_path) + self.model, self.h = load_model(self.model_path, device=self.device) + with torch.no_grad(): + c = 0.434294 * mel.transpose(1, 2) + audio = self.model(c, f0) + return audio diff --git a/server/voice_changer/DDSP_SVC/models/diffusion/wavenet.py b/server/voice_changer/DDSP_SVC/models/diffusion/wavenet.py new file mode 100644 index 00000000..5a1c7a35 --- /dev/null +++ b/server/voice_changer/DDSP_SVC/models/diffusion/wavenet.py @@ -0,0 +1,91 @@ +import math +from math import sqrt + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.nn import Mish + + +class Conv1d(torch.nn.Conv1d): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + nn.init.kaiming_normal_(self.weight) + + +class SinusoidalPosEmb(nn.Module): + def __init__(self, dim): + super().__init__() + self.dim = dim + + def forward(self, x): + device = x.device + half_dim = self.dim // 2 + emb = math.log(10000) / (half_dim - 1) + emb = torch.exp(torch.arange(half_dim, device=device) * -emb) + emb = x[:, None] * emb[None, :] + emb = torch.cat((emb.sin(), emb.cos()), dim=-1) + return emb + + +class ResidualBlock(nn.Module): + def __init__(self, encoder_hidden, residual_channels, dilation): + super().__init__() + self.residual_channels = residual_channels + self.dilated_conv = nn.Conv1d(residual_channels, 2 * residual_channels, kernel_size=3, padding=dilation, dilation=dilation) + self.diffusion_projection = nn.Linear(residual_channels, residual_channels) + self.conditioner_projection = nn.Conv1d(encoder_hidden, 2 * residual_channels, 1) + self.output_projection = nn.Conv1d(residual_channels, 2 * residual_channels, 1) + + def forward(self, x, conditioner, diffusion_step): + diffusion_step = self.diffusion_projection(diffusion_step).unsqueeze(-1) + conditioner = self.conditioner_projection(conditioner) + y = x + diffusion_step + + y = self.dilated_conv(y) + conditioner + + # Using torch.split instead of torch.chunk to avoid using onnx::Slice + gate, filter = torch.split(y, [self.residual_channels, self.residual_channels], dim=1) + y = torch.sigmoid(gate) * torch.tanh(filter) + + y = self.output_projection(y) + + # Using torch.split instead of torch.chunk to avoid using onnx::Slice + residual, skip = torch.split(y, [self.residual_channels, self.residual_channels], dim=1) + return (x + residual) / math.sqrt(2.0), skip + + +class WaveNet(nn.Module): + def __init__(self, in_dims=128, n_layers=20, n_chans=384, n_hidden=256): + super().__init__() + self.input_projection = Conv1d(in_dims, n_chans, 1) + self.diffusion_embedding = SinusoidalPosEmb(n_chans) + self.mlp = nn.Sequential(nn.Linear(n_chans, n_chans * 4), Mish(), nn.Linear(n_chans * 4, n_chans)) + self.residual_layers = nn.ModuleList([ResidualBlock(encoder_hidden=n_hidden, residual_channels=n_chans, dilation=1) for i in range(n_layers)]) + self.skip_projection = Conv1d(n_chans, n_chans, 1) + self.output_projection = Conv1d(n_chans, in_dims, 1) + nn.init.zeros_(self.output_projection.weight) + + def forward(self, spec, diffusion_step, cond): + """ + :param spec: [B, 1, M, T] + :param diffusion_step: [B, 1] + :param cond: [B, M, T] + :return: + """ + x = spec.squeeze(1) + x = self.input_projection(x) # [B, residual_channel, T] + + x = F.relu(x) + diffusion_step = self.diffusion_embedding(diffusion_step) + diffusion_step = self.mlp(diffusion_step) + skip = [] + for layer in self.residual_layers: + x, skip_connection = layer(x, cond, diffusion_step) + skip.append(skip_connection) + + x = torch.sum(torch.stack(skip), dim=0) / sqrt(len(self.residual_layers)) + x = self.skip_projection(x) + x = F.relu(x) + x = self.output_projection(x) # [B, mel_bins, T] + return x[:, None, :, :] diff --git a/server/voice_changer/DDSP_SVC/models/encoder/hubert/model.py b/server/voice_changer/DDSP_SVC/models/encoder/hubert/model.py new file mode 100644 index 00000000..3fd58192 --- /dev/null +++ b/server/voice_changer/DDSP_SVC/models/encoder/hubert/model.py @@ -0,0 +1,293 @@ +import copy +from typing import Optional, Tuple +import random + +from sklearn.cluster import KMeans + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.nn.modules.utils import consume_prefix_in_state_dict_if_present + +URLS = { + "hubert-discrete": "https://github.com/bshall/hubert/releases/download/v0.1/hubert-discrete-e9416457.pt", + "hubert-soft": "https://github.com/bshall/hubert/releases/download/v0.1/hubert-soft-0d54a1f4.pt", + "kmeans100": "https://github.com/bshall/hubert/releases/download/v0.1/kmeans100-50f36a95.pt", +} + + +class Hubert(nn.Module): + def __init__(self, num_label_embeddings: int = 100, mask: bool = True): + super().__init__() + self._mask = mask + self.feature_extractor = FeatureExtractor() + self.feature_projection = FeatureProjection() + self.positional_embedding = PositionalConvEmbedding() + self.norm = nn.LayerNorm(768) + self.dropout = nn.Dropout(0.1) + self.encoder = TransformerEncoder( + nn.TransformerEncoderLayer( + 768, 12, 3072, activation="gelu", batch_first=True + ), + 12, + ) + self.proj = nn.Linear(768, 256) + + self.masked_spec_embed = nn.Parameter(torch.FloatTensor(768).uniform_()) + self.label_embedding = nn.Embedding(num_label_embeddings, 256) + + def mask(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + mask = None + if self.training and self._mask: + mask = _compute_mask((x.size(0), x.size(1)), 0.8, 10, x.device, 2) + x[mask] = self.masked_spec_embed.to(x.dtype) + return x, mask + + def encode( + self, x: torch.Tensor, layer: Optional[int] = None + ) -> Tuple[torch.Tensor, torch.Tensor]: + x = self.feature_extractor(x) + x = self.feature_projection(x.transpose(1, 2)) + x, mask = self.mask(x) + x = x + self.positional_embedding(x) + x = self.dropout(self.norm(x)) + x = self.encoder(x, output_layer=layer) + return x, mask + + def logits(self, x: torch.Tensor) -> torch.Tensor: + logits = torch.cosine_similarity( + x.unsqueeze(2), + self.label_embedding.weight.unsqueeze(0).unsqueeze(0), + dim=-1, + ) + return logits / 0.1 + + def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + x, mask = self.encode(x) + x = self.proj(x) + logits = self.logits(x) + return logits, mask + + +class HubertSoft(Hubert): + def __init__(self): + super().__init__() + + @torch.inference_mode() + def units(self, wav: torch.Tensor) -> torch.Tensor: + wav = F.pad(wav, ((400 - 320) // 2, (400 - 320) // 2)) + x, _ = self.encode(wav) + return self.proj(x) + + +class HubertDiscrete(Hubert): + def __init__(self, kmeans): + super().__init__(504) + self.kmeans = kmeans + + @torch.inference_mode() + def units(self, wav: torch.Tensor) -> torch.LongTensor: + wav = F.pad(wav, ((400 - 320) // 2, (400 - 320) // 2)) + x, _ = self.encode(wav, layer=7) + x = self.kmeans.predict(x.squeeze().cpu().numpy()) + return torch.tensor(x, dtype=torch.long, device=wav.device) + + +class FeatureExtractor(nn.Module): + def __init__(self): + super().__init__() + self.conv0 = nn.Conv1d(1, 512, 10, 5, bias=False) + self.norm0 = nn.GroupNorm(512, 512) + self.conv1 = nn.Conv1d(512, 512, 3, 2, bias=False) + self.conv2 = nn.Conv1d(512, 512, 3, 2, bias=False) + self.conv3 = nn.Conv1d(512, 512, 3, 2, bias=False) + self.conv4 = nn.Conv1d(512, 512, 3, 2, bias=False) + self.conv5 = nn.Conv1d(512, 512, 2, 2, bias=False) + self.conv6 = nn.Conv1d(512, 512, 2, 2, bias=False) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = F.gelu(self.norm0(self.conv0(x))) + x = F.gelu(self.conv1(x)) + x = F.gelu(self.conv2(x)) + x = F.gelu(self.conv3(x)) + x = F.gelu(self.conv4(x)) + x = F.gelu(self.conv5(x)) + x = F.gelu(self.conv6(x)) + return x + + +class FeatureProjection(nn.Module): + def __init__(self): + super().__init__() + self.norm = nn.LayerNorm(512) + self.projection = nn.Linear(512, 768) + self.dropout = nn.Dropout(0.1) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.norm(x) + x = self.projection(x) + x = self.dropout(x) + return x + + +class PositionalConvEmbedding(nn.Module): + def __init__(self): + super().__init__() + self.conv = nn.Conv1d( + 768, + 768, + kernel_size=128, + padding=128 // 2, + groups=16, + ) + self.conv = nn.utils.weight_norm(self.conv, name="weight", dim=2) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.conv(x.transpose(1, 2)) + x = F.gelu(x[:, :, :-1]) + return x.transpose(1, 2) + + +class TransformerEncoder(nn.Module): + def __init__( + self, encoder_layer: nn.TransformerEncoderLayer, num_layers: int + ) -> None: + super(TransformerEncoder, self).__init__() + self.layers = nn.ModuleList( + [copy.deepcopy(encoder_layer) for _ in range(num_layers)] + ) + self.num_layers = num_layers + + def forward( + self, + src: torch.Tensor, + mask: torch.Tensor = None, + src_key_padding_mask: torch.Tensor = None, + output_layer: Optional[int] = None, + ) -> torch.Tensor: + output = src + for layer in self.layers[:output_layer]: + output = layer( + output, src_mask=mask, src_key_padding_mask=src_key_padding_mask + ) + return output + + +def _compute_mask( + shape: Tuple[int, int], + mask_prob: float, + mask_length: int, + device: torch.device, + min_masks: int = 0, +) -> torch.Tensor: + batch_size, sequence_length = shape + + if mask_length < 1: + raise ValueError("`mask_length` has to be bigger than 0.") + + if mask_length > sequence_length: + raise ValueError( + f"`mask_length` has to be smaller than `sequence_length`, but got `mask_length`: {mask_length} and `sequence_length`: {sequence_length}`" + ) + + # compute number of masked spans in batch + num_masked_spans = int(mask_prob * sequence_length / mask_length + random.random()) + num_masked_spans = max(num_masked_spans, min_masks) + + # make sure num masked indices <= sequence_length + if num_masked_spans * mask_length > sequence_length: + num_masked_spans = sequence_length // mask_length + + # SpecAugment mask to fill + mask = torch.zeros((batch_size, sequence_length), device=device, dtype=torch.bool) + + # uniform distribution to sample from, make sure that offset samples are < sequence_length + uniform_dist = torch.ones( + (batch_size, sequence_length - (mask_length - 1)), device=device + ) + + # get random indices to mask + mask_indices = torch.multinomial(uniform_dist, num_masked_spans) + + # expand masked indices to masked spans + mask_indices = ( + mask_indices.unsqueeze(dim=-1) + .expand((batch_size, num_masked_spans, mask_length)) + .reshape(batch_size, num_masked_spans * mask_length) + ) + offsets = ( + torch.arange(mask_length, device=device)[None, None, :] + .expand((batch_size, num_masked_spans, mask_length)) + .reshape(batch_size, num_masked_spans * mask_length) + ) + mask_idxs = mask_indices + offsets + + # scatter indices to mask + mask = mask.scatter(1, mask_idxs, True) + + return mask + + +def hubert_discrete( + pretrained: bool = True, + progress: bool = True, +) -> HubertDiscrete: + r"""HuBERT-Discrete from `"A Comparison of Discrete and Soft Speech Units for Improved Voice Conversion"`. + Args: + pretrained (bool): load pretrained weights into the model + progress (bool): show progress bar when downloading model + """ + kmeans = kmeans100(pretrained=pretrained, progress=progress) + hubert = HubertDiscrete(kmeans) + if pretrained: + checkpoint = torch.hub.load_state_dict_from_url( + URLS["hubert-discrete"], progress=progress + ) + consume_prefix_in_state_dict_if_present(checkpoint, "module.") + hubert.load_state_dict(checkpoint) + hubert.eval() + return hubert + + +def hubert_soft( + pretrained: bool = True, + progress: bool = True, +) -> HubertSoft: + r"""HuBERT-Soft from `"A Comparison of Discrete and Soft Speech Units for Improved Voice Conversion"`. + Args: + pretrained (bool): load pretrained weights into the model + progress (bool): show progress bar when downloading model + """ + hubert = HubertSoft() + if pretrained: + checkpoint = torch.hub.load_state_dict_from_url( + URLS["hubert-soft"], progress=progress + ) + consume_prefix_in_state_dict_if_present(checkpoint, "module.") + hubert.load_state_dict(checkpoint) + hubert.eval() + return hubert + + +def _kmeans( + num_clusters: int, pretrained: bool = True, progress: bool = True +) -> KMeans: + kmeans = KMeans(num_clusters) + if pretrained: + checkpoint = torch.hub.load_state_dict_from_url( + URLS[f"kmeans{num_clusters}"], progress=progress + ) + kmeans.__dict__["n_features_in_"] = checkpoint["n_features_in_"] + kmeans.__dict__["_n_threads"] = checkpoint["_n_threads"] + kmeans.__dict__["cluster_centers_"] = checkpoint["cluster_centers_"].numpy() + return kmeans + + +def kmeans100(pretrained: bool = True, progress: bool = True) -> KMeans: + r""" + k-means checkpoint for HuBERT-Discrete with 100 clusters. + Args: + pretrained (bool): load pretrained weights into the model + progress (bool): show progress bar when downloading model + """ + return _kmeans(100, pretrained, progress) \ No newline at end of file diff --git a/server/voice_changer/DDSP_SVC/models/enhancer.py b/server/voice_changer/DDSP_SVC/models/enhancer.py new file mode 100644 index 00000000..c85aa848 --- /dev/null +++ b/server/voice_changer/DDSP_SVC/models/enhancer.py @@ -0,0 +1,102 @@ +import numpy as np +import torch +import torch.nn.functional as F +from torchaudio.transforms import Resample +from .nsf_hifigan.nvSTFT import STFT +from .nsf_hifigan.models import load_model + + +class Enhancer: + def __init__(self, enhancer_type, enhancer_ckpt, device=None): + if device is None: + device = "cuda" if torch.cuda.is_available() else "cpu" + self.device = device + + if enhancer_type == "nsf-hifigan": + self.enhancer = NsfHifiGAN(enhancer_ckpt, device=self.device) + else: + raise ValueError(f" [x] Unknown enhancer: {enhancer_type}") + + self.resample_kernel = {} + self.enhancer_sample_rate = self.enhancer.sample_rate() + self.enhancer_hop_size = self.enhancer.hop_size() + + def enhance(self, audio, sample_rate, f0, hop_size, adaptive_key=0, silence_front=0): # 1, T # 1, n_frames, 1 + # enhancer start time + start_frame = int(silence_front * sample_rate / hop_size) + real_silence_front = start_frame * hop_size / sample_rate + audio = audio[:, int(np.round(real_silence_front * sample_rate)) :] + f0 = f0[:, start_frame:, :] + + # adaptive parameters + if adaptive_key == "auto": + adaptive_key = 12 * np.log2(float(torch.max(f0) / 760)) + adaptive_key = max(0, np.ceil(adaptive_key)) + print("auto_adaptive_key: " + str(int(adaptive_key))) + else: + adaptive_key = float(adaptive_key) + + adaptive_factor = 2 ** (-adaptive_key / 12) + adaptive_sample_rate = 100 * int(np.round(self.enhancer_sample_rate / adaptive_factor / 100)) + real_factor = self.enhancer_sample_rate / adaptive_sample_rate + + # resample the ddsp output + if sample_rate == adaptive_sample_rate: + audio_res = audio + else: + key_str = str(sample_rate) + str(adaptive_sample_rate) + if key_str not in self.resample_kernel: + self.resample_kernel[key_str] = Resample(sample_rate, adaptive_sample_rate, lowpass_filter_width=128).to(self.device) + audio_res = self.resample_kernel[key_str](audio) + + n_frames = int(audio_res.size(-1) // self.enhancer_hop_size + 1) + + # resample f0 + if hop_size == self.enhancer_hop_size and sample_rate == self.enhancer_sample_rate and sample_rate == adaptive_sample_rate: + f0_res = f0.squeeze(-1) # 1, n_frames + else: + f0_np = f0.squeeze(0).squeeze(-1).cpu().numpy() + f0_np *= real_factor + time_org = (hop_size / sample_rate) * np.arange(len(f0_np)) / real_factor + time_frame = (self.enhancer_hop_size / self.enhancer_sample_rate) * np.arange(n_frames) + f0_res = np.interp(time_frame, time_org, f0_np, left=f0_np[0], right=f0_np[-1]) + f0_res = torch.from_numpy(f0_res).unsqueeze(0).float().to(self.device) # 1, n_frames + + # enhance + enhanced_audio, enhancer_sample_rate = self.enhancer(audio_res, f0_res) + + # resample the enhanced output + if adaptive_sample_rate != enhancer_sample_rate: + key_str = str(adaptive_sample_rate) + str(enhancer_sample_rate) + if key_str not in self.resample_kernel: + self.resample_kernel[key_str] = Resample(adaptive_sample_rate, enhancer_sample_rate, lowpass_filter_width=128).to(self.device) + enhanced_audio = self.resample_kernel[key_str](enhanced_audio) + + # pad the silence frames + if start_frame > 0: + enhanced_audio = F.pad(enhanced_audio, (int(np.round(enhancer_sample_rate * real_silence_front)), 0)) + + return enhanced_audio, enhancer_sample_rate + + +class NsfHifiGAN(torch.nn.Module): + def __init__(self, model_path, device=None): + super().__init__() + if device is None: + device = "cuda" if torch.cuda.is_available() else "cpu" + self.device = device + print("| Load HifiGAN: ", model_path) + self.model, self.h = load_model(model_path, device=self.device) + self.stft = STFT(self.h.sampling_rate, self.h.num_mels, self.h.n_fft, self.h.win_size, self.h.hop_size, self.h.fmin, self.h.fmax) + + def sample_rate(self): + return self.h.sampling_rate + + def hop_size(self): + return self.h.hop_size + + def forward(self, audio, f0): + with torch.no_grad(): + mel = self.stft.get_mel(audio) + enhanced_audio = self.model(mel, f0[:, : mel.size(-1)]) + return enhanced_audio, self.h.sampling_rate diff --git a/server/voice_changer/DDSP_SVC/models/nsf_hifigan/env.py b/server/voice_changer/DDSP_SVC/models/nsf_hifigan/env.py new file mode 100644 index 00000000..2bdbc95d --- /dev/null +++ b/server/voice_changer/DDSP_SVC/models/nsf_hifigan/env.py @@ -0,0 +1,15 @@ +import os +import shutil + + +class AttrDict(dict): + def __init__(self, *args, **kwargs): + super(AttrDict, self).__init__(*args, **kwargs) + self.__dict__ = self + + +def build_env(config, config_name, path): + t_path = os.path.join(path, config_name) + if config != t_path: + os.makedirs(path, exist_ok=True) + shutil.copyfile(config, os.path.join(path, config_name)) diff --git a/server/voice_changer/DDSP_SVC/models/nsf_hifigan/models.py b/server/voice_changer/DDSP_SVC/models/nsf_hifigan/models.py new file mode 100644 index 00000000..f53ff392 --- /dev/null +++ b/server/voice_changer/DDSP_SVC/models/nsf_hifigan/models.py @@ -0,0 +1,434 @@ +import os +import json +from .env import AttrDict +import numpy as np +import torch +import torch.nn.functional as F +import torch.nn as nn +from torch.nn import Conv1d, ConvTranspose1d, AvgPool1d, Conv2d +from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm +from .utils import init_weights, get_padding + +LRELU_SLOPE = 0.1 + + +def load_model(model_path, device='cuda'): + h = load_config(model_path) + + generator = Generator(h).to(device) + + cp_dict = torch.load(model_path, map_location=device) + generator.load_state_dict(cp_dict['generator']) + generator.eval() + generator.remove_weight_norm() + del cp_dict + return generator, h + +def load_config(model_path): + config_file = os.path.join(os.path.split(model_path)[0], 'config.json') + with open(config_file) as f: + data = f.read() + + json_config = json.loads(data) + h = AttrDict(json_config) + return h + + +class ResBlock1(torch.nn.Module): + def __init__(self, h, channels, kernel_size=3, dilation=(1, 3, 5)): + super(ResBlock1, self).__init__() + self.h = h + self.convs1 = nn.ModuleList([ + weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0], + padding=get_padding(kernel_size, dilation[0]))), + weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1], + padding=get_padding(kernel_size, dilation[1]))), + weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[2], + padding=get_padding(kernel_size, dilation[2]))) + ]) + self.convs1.apply(init_weights) + + self.convs2 = nn.ModuleList([ + weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1, + padding=get_padding(kernel_size, 1))), + weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1, + padding=get_padding(kernel_size, 1))), + weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1, + padding=get_padding(kernel_size, 1))) + ]) + self.convs2.apply(init_weights) + + def forward(self, x): + for c1, c2 in zip(self.convs1, self.convs2): + xt = F.leaky_relu(x, LRELU_SLOPE) + xt = c1(xt) + xt = F.leaky_relu(xt, LRELU_SLOPE) + xt = c2(xt) + x = xt + x + return x + + def remove_weight_norm(self): + for l in self.convs1: + remove_weight_norm(l) + for l in self.convs2: + remove_weight_norm(l) + + +class ResBlock2(torch.nn.Module): + def __init__(self, h, channels, kernel_size=3, dilation=(1, 3)): + super(ResBlock2, self).__init__() + self.h = h + self.convs = nn.ModuleList([ + weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0], + padding=get_padding(kernel_size, dilation[0]))), + weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1], + padding=get_padding(kernel_size, dilation[1]))) + ]) + self.convs.apply(init_weights) + + def forward(self, x): + for c in self.convs: + xt = F.leaky_relu(x, LRELU_SLOPE) + xt = c(xt) + x = xt + x + return x + + def remove_weight_norm(self): + for l in self.convs: + remove_weight_norm(l) + + +class SineGen(torch.nn.Module): + """ Definition of sine generator + SineGen(samp_rate, harmonic_num = 0, + sine_amp = 0.1, noise_std = 0.003, + voiced_threshold = 0, + flag_for_pulse=False) + samp_rate: sampling rate in Hz + harmonic_num: number of harmonic overtones (default 0) + sine_amp: amplitude of sine-wavefrom (default 0.1) + noise_std: std of Gaussian noise (default 0.003) + voiced_thoreshold: F0 threshold for U/V classification (default 0) + flag_for_pulse: this SinGen is used inside PulseGen (default False) + Note: when flag_for_pulse is True, the first time step of a voiced + segment is always sin(np.pi) or cos(0) + """ + + def __init__(self, samp_rate, harmonic_num=0, + sine_amp=0.1, noise_std=0.003, + voiced_threshold=0): + super(SineGen, self).__init__() + self.sine_amp = sine_amp + self.noise_std = noise_std + self.harmonic_num = harmonic_num + self.dim = self.harmonic_num + 1 + self.sampling_rate = samp_rate + self.voiced_threshold = voiced_threshold + + def _f02uv(self, f0): + # generate uv signal + uv = torch.ones_like(f0) + uv = uv * (f0 > self.voiced_threshold) + return uv + + @torch.no_grad() + def forward(self, f0, upp): + """ sine_tensor, uv = forward(f0) + input F0: tensor(batchsize=1, length, dim=1) + f0 for unvoiced steps should be 0 + output sine_tensor: tensor(batchsize=1, length, dim) + output uv: tensor(batchsize=1, length, 1) + """ + f0 = f0.unsqueeze(-1) + fn = torch.multiply(f0, torch.arange(1, self.dim + 1, device=f0.device).reshape((1, 1, -1))) + rad_values = (fn / self.sampling_rate) % 1 ###%1意味着n_har的乘积无法后处理优化 + rand_ini = torch.rand(fn.shape[0], fn.shape[2], device=fn.device) + rand_ini[:, 0] = 0 + rad_values[:, 0, :] = rad_values[:, 0, :] + rand_ini + is_half = rad_values.dtype is not torch.float32 + tmp_over_one = torch.cumsum(rad_values.double(), 1) # % 1 #####%1意味着后面的cumsum无法再优化 + if is_half: + tmp_over_one = tmp_over_one.half() + else: + tmp_over_one = tmp_over_one.float() + tmp_over_one *= upp + tmp_over_one = F.interpolate( + tmp_over_one.transpose(2, 1), scale_factor=upp, + mode='linear', align_corners=True + ).transpose(2, 1) + rad_values = F.interpolate(rad_values.transpose(2, 1), scale_factor=upp, mode='nearest').transpose(2, 1) + tmp_over_one %= 1 + tmp_over_one_idx = (tmp_over_one[:, 1:, :] - tmp_over_one[:, :-1, :]) < 0 + cumsum_shift = torch.zeros_like(rad_values) + cumsum_shift[:, 1:, :] = tmp_over_one_idx * -1.0 + rad_values = rad_values.double() + cumsum_shift = cumsum_shift.double() + sine_waves = torch.sin(torch.cumsum(rad_values + cumsum_shift, dim=1) * 2 * np.pi) + if is_half: + sine_waves = sine_waves.half() + else: + sine_waves = sine_waves.float() + sine_waves = sine_waves * self.sine_amp + return sine_waves + + +class SourceModuleHnNSF(torch.nn.Module): + """ SourceModule for hn-nsf + SourceModule(sampling_rate, harmonic_num=0, sine_amp=0.1, + add_noise_std=0.003, voiced_threshod=0) + sampling_rate: sampling_rate in Hz + harmonic_num: number of harmonic above F0 (default: 0) + sine_amp: amplitude of sine source signal (default: 0.1) + add_noise_std: std of additive Gaussian noise (default: 0.003) + note that amplitude of noise in unvoiced is decided + by sine_amp + voiced_threshold: threhold to set U/V given F0 (default: 0) + Sine_source, noise_source = SourceModuleHnNSF(F0_sampled) + F0_sampled (batchsize, length, 1) + Sine_source (batchsize, length, 1) + noise_source (batchsize, length 1) + uv (batchsize, length, 1) + """ + + def __init__(self, sampling_rate, harmonic_num=0, sine_amp=0.1, + add_noise_std=0.003, voiced_threshod=0): + super(SourceModuleHnNSF, self).__init__() + + self.sine_amp = sine_amp + self.noise_std = add_noise_std + + # to produce sine waveforms + self.l_sin_gen = SineGen(sampling_rate, harmonic_num, + sine_amp, add_noise_std, voiced_threshod) + + # to merge source harmonics into a single excitation + self.l_linear = torch.nn.Linear(harmonic_num + 1, 1) + self.l_tanh = torch.nn.Tanh() + + def forward(self, x, upp): + sine_wavs = self.l_sin_gen(x, upp) + sine_merge = self.l_tanh(self.l_linear(sine_wavs)) + return sine_merge + + +class Generator(torch.nn.Module): + def __init__(self, h): + super(Generator, self).__init__() + self.h = h + self.num_kernels = len(h.resblock_kernel_sizes) + self.num_upsamples = len(h.upsample_rates) + self.m_source = SourceModuleHnNSF( + sampling_rate=h.sampling_rate, + harmonic_num=8 + ) + self.noise_convs = nn.ModuleList() + self.conv_pre = weight_norm(Conv1d(h.num_mels, h.upsample_initial_channel, 7, 1, padding=3)) + resblock = ResBlock1 if h.resblock == '1' else ResBlock2 + + self.ups = nn.ModuleList() + for i, (u, k) in enumerate(zip(h.upsample_rates, h.upsample_kernel_sizes)): + c_cur = h.upsample_initial_channel // (2 ** (i + 1)) + self.ups.append(weight_norm( + ConvTranspose1d(h.upsample_initial_channel // (2 ** i), h.upsample_initial_channel // (2 ** (i + 1)), + k, u, padding=(k - u) // 2))) + if i + 1 < len(h.upsample_rates): # + stride_f0 = int(np.prod(h.upsample_rates[i + 1:])) + self.noise_convs.append(Conv1d( + 1, c_cur, kernel_size=stride_f0 * 2, stride=stride_f0, padding=stride_f0 // 2)) + else: + self.noise_convs.append(Conv1d(1, c_cur, kernel_size=1)) + self.resblocks = nn.ModuleList() + ch = h.upsample_initial_channel + for i in range(len(self.ups)): + ch //= 2 + for j, (k, d) in enumerate(zip(h.resblock_kernel_sizes, h.resblock_dilation_sizes)): + self.resblocks.append(resblock(h, ch, k, d)) + + self.conv_post = weight_norm(Conv1d(ch, 1, 7, 1, padding=3)) + self.ups.apply(init_weights) + self.conv_post.apply(init_weights) + self.upp = int(np.prod(h.upsample_rates)) + + def forward(self, x, f0): + har_source = self.m_source(f0, self.upp).transpose(1, 2) + x = self.conv_pre(x) + for i in range(self.num_upsamples): + x = F.leaky_relu(x, LRELU_SLOPE) + x = self.ups[i](x) + x_source = self.noise_convs[i](har_source) + x = x + x_source + xs = None + for j in range(self.num_kernels): + if xs is None: + xs = self.resblocks[i * self.num_kernels + j](x) + else: + xs += self.resblocks[i * self.num_kernels + j](x) + x = xs / self.num_kernels + x = F.leaky_relu(x) + x = self.conv_post(x) + x = torch.tanh(x) + + return x + + def remove_weight_norm(self): + print('Removing weight norm...') + for l in self.ups: + remove_weight_norm(l) + for l in self.resblocks: + l.remove_weight_norm() + remove_weight_norm(self.conv_pre) + remove_weight_norm(self.conv_post) + + +class DiscriminatorP(torch.nn.Module): + def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False): + super(DiscriminatorP, self).__init__() + self.period = period + norm_f = weight_norm if use_spectral_norm == False else spectral_norm + self.convs = nn.ModuleList([ + norm_f(Conv2d(1, 32, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))), + norm_f(Conv2d(32, 128, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))), + norm_f(Conv2d(128, 512, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))), + norm_f(Conv2d(512, 1024, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))), + norm_f(Conv2d(1024, 1024, (kernel_size, 1), 1, padding=(2, 0))), + ]) + self.conv_post = norm_f(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0))) + + def forward(self, x): + fmap = [] + + # 1d to 2d + b, c, t = x.shape + if t % self.period != 0: # pad first + n_pad = self.period - (t % self.period) + x = F.pad(x, (0, n_pad), "reflect") + t = t + n_pad + x = x.view(b, c, t // self.period, self.period) + + for l in self.convs: + x = l(x) + x = F.leaky_relu(x, LRELU_SLOPE) + fmap.append(x) + x = self.conv_post(x) + fmap.append(x) + x = torch.flatten(x, 1, -1) + + return x, fmap + + +class MultiPeriodDiscriminator(torch.nn.Module): + def __init__(self, periods=None): + super(MultiPeriodDiscriminator, self).__init__() + self.periods = periods if periods is not None else [2, 3, 5, 7, 11] + self.discriminators = nn.ModuleList() + for period in self.periods: + self.discriminators.append(DiscriminatorP(period)) + + def forward(self, y, y_hat): + y_d_rs = [] + y_d_gs = [] + fmap_rs = [] + fmap_gs = [] + for i, d in enumerate(self.discriminators): + y_d_r, fmap_r = d(y) + y_d_g, fmap_g = d(y_hat) + y_d_rs.append(y_d_r) + fmap_rs.append(fmap_r) + y_d_gs.append(y_d_g) + fmap_gs.append(fmap_g) + + return y_d_rs, y_d_gs, fmap_rs, fmap_gs + + +class DiscriminatorS(torch.nn.Module): + def __init__(self, use_spectral_norm=False): + super(DiscriminatorS, self).__init__() + norm_f = weight_norm if use_spectral_norm == False else spectral_norm + self.convs = nn.ModuleList([ + norm_f(Conv1d(1, 128, 15, 1, padding=7)), + norm_f(Conv1d(128, 128, 41, 2, groups=4, padding=20)), + norm_f(Conv1d(128, 256, 41, 2, groups=16, padding=20)), + norm_f(Conv1d(256, 512, 41, 4, groups=16, padding=20)), + norm_f(Conv1d(512, 1024, 41, 4, groups=16, padding=20)), + norm_f(Conv1d(1024, 1024, 41, 1, groups=16, padding=20)), + norm_f(Conv1d(1024, 1024, 5, 1, padding=2)), + ]) + self.conv_post = norm_f(Conv1d(1024, 1, 3, 1, padding=1)) + + def forward(self, x): + fmap = [] + for l in self.convs: + x = l(x) + x = F.leaky_relu(x, LRELU_SLOPE) + fmap.append(x) + x = self.conv_post(x) + fmap.append(x) + x = torch.flatten(x, 1, -1) + + return x, fmap + + +class MultiScaleDiscriminator(torch.nn.Module): + def __init__(self): + super(MultiScaleDiscriminator, self).__init__() + self.discriminators = nn.ModuleList([ + DiscriminatorS(use_spectral_norm=True), + DiscriminatorS(), + DiscriminatorS(), + ]) + self.meanpools = nn.ModuleList([ + AvgPool1d(4, 2, padding=2), + AvgPool1d(4, 2, padding=2) + ]) + + def forward(self, y, y_hat): + y_d_rs = [] + y_d_gs = [] + fmap_rs = [] + fmap_gs = [] + for i, d in enumerate(self.discriminators): + if i != 0: + y = self.meanpools[i - 1](y) + y_hat = self.meanpools[i - 1](y_hat) + y_d_r, fmap_r = d(y) + y_d_g, fmap_g = d(y_hat) + y_d_rs.append(y_d_r) + fmap_rs.append(fmap_r) + y_d_gs.append(y_d_g) + fmap_gs.append(fmap_g) + + return y_d_rs, y_d_gs, fmap_rs, fmap_gs + + +def feature_loss(fmap_r, fmap_g): + loss = 0 + for dr, dg in zip(fmap_r, fmap_g): + for rl, gl in zip(dr, dg): + loss += torch.mean(torch.abs(rl - gl)) + + return loss * 2 + + +def discriminator_loss(disc_real_outputs, disc_generated_outputs): + loss = 0 + r_losses = [] + g_losses = [] + for dr, dg in zip(disc_real_outputs, disc_generated_outputs): + r_loss = torch.mean((1 - dr) ** 2) + g_loss = torch.mean(dg ** 2) + loss += (r_loss + g_loss) + r_losses.append(r_loss.item()) + g_losses.append(g_loss.item()) + + return loss, r_losses, g_losses + + +def generator_loss(disc_outputs): + loss = 0 + gen_losses = [] + for dg in disc_outputs: + l = torch.mean((1 - dg) ** 2) + gen_losses.append(l) + loss += l + + return loss, gen_losses diff --git a/server/voice_changer/DDSP_SVC/models/nsf_hifigan/nvSTFT.py b/server/voice_changer/DDSP_SVC/models/nsf_hifigan/nvSTFT.py new file mode 100644 index 00000000..6babb65e --- /dev/null +++ b/server/voice_changer/DDSP_SVC/models/nsf_hifigan/nvSTFT.py @@ -0,0 +1,129 @@ +import math +import os +os.environ["LRU_CACHE_CAPACITY"] = "3" +import random +import torch +import torch.utils.data +import numpy as np +import librosa +from librosa.util import normalize +from librosa.filters import mel as librosa_mel_fn +from scipy.io.wavfile import read +import soundfile as sf +import torch.nn.functional as F + +def load_wav_to_torch(full_path, target_sr=None, return_empty_on_exception=False): + sampling_rate = None + try: + data, sampling_rate = sf.read(full_path, always_2d=True)# than soundfile. + except Exception as ex: + print(f"'{full_path}' failed to load.\nException:") + print(ex) + if return_empty_on_exception: + return [], sampling_rate or target_sr or 48000 + else: + raise Exception(ex) + + if len(data.shape) > 1: + data = data[:, 0] + assert len(data) > 2# check duration of audio file is > 2 samples (because otherwise the slice operation was on the wrong dimension) + + if np.issubdtype(data.dtype, np.integer): # if audio data is type int + max_mag = -np.iinfo(data.dtype).min # maximum magnitude = min possible value of intXX + else: # if audio data is type fp32 + max_mag = max(np.amax(data), -np.amin(data)) + max_mag = (2**31)+1 if max_mag > (2**15) else ((2**15)+1 if max_mag > 1.01 else 1.0) # data should be either 16-bit INT, 32-bit INT or [-1 to 1] float32 + + data = torch.FloatTensor(data.astype(np.float32))/max_mag + + if (torch.isinf(data) | torch.isnan(data)).any() and return_empty_on_exception:# resample will crash with inf/NaN inputs. return_empty_on_exception will return empty arr instead of except + return [], sampling_rate or target_sr or 48000 + if target_sr is not None and sampling_rate != target_sr: + data = torch.from_numpy(librosa.core.resample(data.numpy(), orig_sr=sampling_rate, target_sr=target_sr)) + sampling_rate = target_sr + + return data, sampling_rate + +def dynamic_range_compression(x, C=1, clip_val=1e-5): + return np.log(np.clip(x, a_min=clip_val, a_max=None) * C) + +def dynamic_range_decompression(x, C=1): + return np.exp(x) / C + +def dynamic_range_compression_torch(x, C=1, clip_val=1e-5): + return torch.log(torch.clamp(x, min=clip_val) * C) + +def dynamic_range_decompression_torch(x, C=1): + return torch.exp(x) / C + +class STFT(): + def __init__(self, sr=22050, n_mels=80, n_fft=1024, win_size=1024, hop_length=256, fmin=20, fmax=11025, clip_val=1e-5): + self.target_sr = sr + + self.n_mels = n_mels + self.n_fft = n_fft + self.win_size = win_size + self.hop_length = hop_length + self.fmin = fmin + self.fmax = fmax + self.clip_val = clip_val + self.mel_basis = {} + self.hann_window = {} + + def get_mel(self, y, keyshift=0, speed=1, center=False): + sampling_rate = self.target_sr + n_mels = self.n_mels + n_fft = self.n_fft + win_size = self.win_size + hop_length = self.hop_length + fmin = self.fmin + fmax = self.fmax + clip_val = self.clip_val + + factor = 2 ** (keyshift / 12) + n_fft_new = int(np.round(n_fft * factor)) + win_size_new = int(np.round(win_size * factor)) + hop_length_new = int(np.round(hop_length * speed)) + + if torch.min(y) < -1.: + print('min value is ', torch.min(y)) + if torch.max(y) > 1.: + print('max value is ', torch.max(y)) + + mel_basis_key = str(fmax)+'_'+str(y.device) + if mel_basis_key not in self.mel_basis: + mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=n_mels, fmin=fmin, fmax=fmax) + self.mel_basis[mel_basis_key] = torch.from_numpy(mel).float().to(y.device) + + keyshift_key = str(keyshift)+'_'+str(y.device) + if keyshift_key not in self.hann_window: + self.hann_window[keyshift_key] = torch.hann_window(win_size_new).to(y.device) + + pad_left = (win_size_new - hop_length_new) //2 + pad_right = max((win_size_new- hop_length_new + 1) //2, win_size_new - y.size(-1) - pad_left) + if pad_right < y.size(-1): + mode = 'reflect' + else: + mode = 'constant' + y = torch.nn.functional.pad(y.unsqueeze(1), (pad_left, pad_right), mode = mode) + y = y.squeeze(1) + + spec = torch.stft(y, n_fft_new, hop_length=hop_length_new, win_length=win_size_new, window=self.hann_window[keyshift_key], + center=center, pad_mode='reflect', normalized=False, onesided=True, return_complex=True) + spec = torch.sqrt(spec.real.pow(2) + spec.imag.pow(2) + (1e-9)) + if keyshift != 0: + size = n_fft // 2 + 1 + resize = spec.size(1) + if resize < size: + spec = F.pad(spec, (0, 0, 0, size-resize)) + spec = spec[:, :size, :] * win_size / win_size_new + spec = torch.matmul(self.mel_basis[mel_basis_key], spec) + spec = dynamic_range_compression_torch(spec, clip_val=clip_val) + return spec + + def __call__(self, audiopath): + audio, sr = load_wav_to_torch(audiopath, target_sr=self.target_sr) + spect = self.get_mel(audio.unsqueeze(0)).squeeze(0) + return spect + +stft = STFT() diff --git a/server/voice_changer/DDSP_SVC/models/nsf_hifigan/utils.py b/server/voice_changer/DDSP_SVC/models/nsf_hifigan/utils.py new file mode 100644 index 00000000..84bff024 --- /dev/null +++ b/server/voice_changer/DDSP_SVC/models/nsf_hifigan/utils.py @@ -0,0 +1,68 @@ +import glob +import os +import matplotlib +import torch +from torch.nn.utils import weight_norm +matplotlib.use("Agg") +import matplotlib.pylab as plt + + +def plot_spectrogram(spectrogram): + fig, ax = plt.subplots(figsize=(10, 2)) + im = ax.imshow(spectrogram, aspect="auto", origin="lower", + interpolation='none') + plt.colorbar(im, ax=ax) + + fig.canvas.draw() + plt.close() + + return fig + + +def init_weights(m, mean=0.0, std=0.01): + classname = m.__class__.__name__ + if classname.find("Conv") != -1: + m.weight.data.normal_(mean, std) + + +def apply_weight_norm(m): + classname = m.__class__.__name__ + if classname.find("Conv") != -1: + weight_norm(m) + + +def get_padding(kernel_size, dilation=1): + return int((kernel_size*dilation - dilation)/2) + + +def load_checkpoint(filepath, device): + assert os.path.isfile(filepath) + print("Loading '{}'".format(filepath)) + checkpoint_dict = torch.load(filepath, map_location=device) + print("Complete.") + return checkpoint_dict + + +def save_checkpoint(filepath, obj): + print("Saving checkpoint to {}".format(filepath)) + torch.save(obj, filepath) + print("Complete.") + + +def del_old_checkpoints(cp_dir, prefix, n_models=2): + pattern = os.path.join(cp_dir, prefix + '????????') + cp_list = glob.glob(pattern) # get checkpoint paths + cp_list = sorted(cp_list)# sort by iter + if len(cp_list) > n_models: # if more than n_models models are found + for cp in cp_list[:-n_models]:# delete the oldest models other than lastest n_models + open(cp, 'w').close()# empty file contents + os.unlink(cp)# delete file (move to trash when using Colab) + + +def scan_checkpoint(cp_dir, prefix): + pattern = os.path.join(cp_dir, prefix + '????????') + cp_list = glob.glob(pattern) + if len(cp_list) == 0: + return None + return sorted(cp_list)[-1] + diff --git a/server/voice_changer/SoVitsSvc40/models/readme.txt b/server/voice_changer/SoVitsSvc40/models/readme.txt new file mode 100644 index 00000000..1df97936 --- /dev/null +++ b/server/voice_changer/SoVitsSvc40/models/readme.txt @@ -0,0 +1,2 @@ +modules in this folder from https://github.com/svc-develop-team/so-vits-svc at 7846711d90b5210d4f39a4d6fab50d1d7bbd8d73 +forked repo: https://github.com/w-okada/so-vits-svc diff --git a/server/voice_changer/SoVitsSvc40v2/SoVitsSvc40v2.py b/server/voice_changer/SoVitsSvc40v2/SoVitsSvc40v2.py deleted file mode 100644 index c7e199d7..00000000 --- a/server/voice_changer/SoVitsSvc40v2/SoVitsSvc40v2.py +++ /dev/null @@ -1,466 +0,0 @@ -import sys -import os - -from voice_changer.utils.LoadModelParams import LoadModelParams -from voice_changer.utils.VoiceChangerModel import AudioInOut -from voice_changer.utils.VoiceChangerParams import VoiceChangerParams - -if sys.platform.startswith("darwin"): - baseDir = [x for x in sys.path if x.endswith("Contents/MacOS")] - if len(baseDir) != 1: - print("baseDir should be only one ", baseDir) - sys.exit() - modulePath = os.path.join(baseDir[0], "so-vits-svc-40v2") - sys.path.append(modulePath) -else: - sys.path.append("so-vits-svc-40v2") - -import io -from dataclasses import dataclass, asdict, field -import numpy as np -import torch -import onnxruntime -import pyworld as pw - -from models import SynthesizerTrn # type:ignore -import cluster # type:ignore -import utils -from fairseq import checkpoint_utils -import librosa - -from Exceptions import NoModeLoadedException - -providers = [ - "OpenVINOExecutionProvider", - "CUDAExecutionProvider", - "DmlExecutionProvider", - "CPUExecutionProvider", -] - - -@dataclass -class SoVitsSvc40v2Settings: - gpu: int = 0 - dstId: int = 0 - - f0Detector: str = "harvest" # dio or harvest - tran: int = 20 - noiseScale: float = 0.3 - predictF0: int = 0 # 0:False, 1:True - silentThreshold: float = 0.00001 - extraConvertSize: int = 1024 * 32 - clusterInferRatio: float = 0.1 - - framework: str = "PyTorch" # PyTorch or ONNX - pyTorchModelFile: str | None = "" - onnxModelFile: str | None = "" - configFile: str = "" - - speakers: dict[str, int] = field(default_factory=lambda: {}) - - # ↓mutableな物だけ列挙 - intData = ["gpu", "dstId", "tran", "predictF0", "extraConvertSize"] - floatData = ["noiseScale", "silentThreshold", "clusterInferRatio"] - strData = ["framework", "f0Detector"] - - -class SoVitsSvc40v2: - audio_buffer: AudioInOut | None = None - - def __init__(self, params: VoiceChangerParams): - self.settings = SoVitsSvc40v2Settings() - self.net_g = None - self.onnx_session = None - - self.raw_path = io.BytesIO() - self.gpu_num = torch.cuda.device_count() - self.prevVol = 0 - self.params = params - print("[Voice Changer] so-vits-svc 40v2 initialization:", params) - - def loadModel(self, props: LoadModelParams): - params = props.params - self.settings.configFile = params["files"]["soVitsSvc40v2Config"] - self.hps = utils.get_hparams_from_file(self.settings.configFile) - self.settings.speakers = self.hps.spk - - modelFile = params["files"]["soVitsSvc40v2Model"] - if modelFile.endswith(".onnx"): - self.settings.pyTorchModelFile = None - self.settings.onnxModelFile = modelFile - else: - self.settings.pyTorchModelFile = modelFile - self.settings.onnxModelFile = None - - clusterTorchModel = params["files"]["soVitsSvc40v2Cluster"] - - content_vec_path = self.params.content_vec_500 - hubert_base_path = self.params.hubert_base - - # hubert model - try: - if os.path.exists(content_vec_path) is False: - content_vec_path = hubert_base_path - - models, saved_cfg, task = checkpoint_utils.load_model_ensemble_and_task( - [content_vec_path], - suffix="", - ) - model = models[0] - model.eval() - self.hubert_model = model.cpu() - except Exception as e: - print("EXCEPTION during loading hubert/contentvec model", e) - - # cluster - try: - if clusterTorchModel is not None and os.path.exists(clusterTorchModel): - self.cluster_model = cluster.get_cluster_model(clusterTorchModel) - else: - self.cluster_model = None - except Exception as e: - print("EXCEPTION during loading cluster model ", e) - - # PyTorchモデル生成 - if self.settings.pyTorchModelFile is not None: - net_g = SynthesizerTrn(self.hps) - net_g.eval() - self.net_g = net_g - utils.load_checkpoint(self.settings.pyTorchModelFile, self.net_g, None) - - # ONNXモデル生成 - if self.settings.onnxModelFile is not None: - providers, options = self.getOnnxExecutionProvider() - self.onnx_session = onnxruntime.InferenceSession( - self.settings.onnxModelFile, - providers=providers, - provider_options=options, - ) - return self.get_info() - - def getOnnxExecutionProvider(self): - availableProviders = onnxruntime.get_available_providers() - if self.settings.gpu >= 0 and "CUDAExecutionProvider" in availableProviders: - return ["CUDAExecutionProvider"], [{"device_id": self.settings.gpu}] - elif self.settings.gpu >= 0 and "DmlExecutionProvider" in availableProviders: - return ["DmlExecutionProvider"], [{}] - else: - return ["CPUExecutionProvider"], [ - { - "intra_op_num_threads": 8, - "execution_mode": onnxruntime.ExecutionMode.ORT_PARALLEL, - "inter_op_num_threads": 8, - } - ] - - def isOnnx(self): - if self.settings.onnxModelFile is not None: - return True - else: - return False - - def update_settings(self, key: str, val: int | float | str): - if key in self.settings.intData: - val = int(val) - setattr(self.settings, key, val) - - if key == "gpu" and self.isOnnx(): - providers, options = self.getOnnxExecutionProvider() - if self.onnx_session is not None: - self.onnx_session.set_providers( - providers=providers, - provider_options=options, - ) - elif key in self.settings.floatData: - setattr(self.settings, key, float(val)) - elif key in self.settings.strData: - setattr(self.settings, key, str(val)) - else: - return False - - return True - - def get_info(self): - data = asdict(self.settings) - - data["onnxExecutionProviders"] = ( - self.onnx_session.get_providers() if self.onnx_session is not None else [] - ) - files = ["configFile", "pyTorchModelFile", "onnxModelFile"] - for f in files: - if data[f] is not None and os.path.exists(data[f]): - data[f] = os.path.basename(data[f]) - else: - data[f] = "" - - return data - - def get_processing_sampling_rate(self): - if hasattr(self, "hps") is False: - raise NoModeLoadedException("config") - return self.hps.data.sampling_rate - - def get_unit_f0(self, audio_buffer, tran): - wav_44k = audio_buffer - # f0 = utils.compute_f0_parselmouth(wav, sampling_rate=self.target_sample, hop_length=self.hop_size) - # f0 = utils.compute_f0_dio(wav_44k, sampling_rate=self.hps.data.sampling_rate, hop_length=self.hps.data.hop_length) - if self.settings.f0Detector == "dio": - f0 = compute_f0_dio( - wav_44k, - sampling_rate=self.hps.data.sampling_rate, - hop_length=self.hps.data.hop_length, - ) - else: - f0 = compute_f0_harvest( - wav_44k, - sampling_rate=self.hps.data.sampling_rate, - hop_length=self.hps.data.hop_length, - ) - - if wav_44k.shape[0] % self.hps.data.hop_length != 0: - print( - f" !!! !!! !!! wav size not multiple of hopsize: {wav_44k.shape[0] / self.hps.data.hop_length}" - ) - - f0, uv = utils.interpolate_f0(f0) - f0 = torch.FloatTensor(f0) - uv = torch.FloatTensor(uv) - f0 = f0 * 2 ** (tran / 12) - f0 = f0.unsqueeze(0) - uv = uv.unsqueeze(0) - - # wav16k = librosa.resample(audio_buffer, orig_sr=24000, target_sr=16000) - wav16k = librosa.resample( - audio_buffer, orig_sr=self.hps.data.sampling_rate, target_sr=16000 - ) - wav16k = torch.from_numpy(wav16k) - - if ( - self.settings.gpu < 0 or self.gpu_num == 0 - ) or self.settings.framework == "ONNX": - dev = torch.device("cpu") - else: - dev = torch.device("cuda", index=self.settings.gpu) - - self.hubert_model = self.hubert_model.to(dev) - wav16k = wav16k.to(dev) - uv = uv.to(dev) - f0 = f0.to(dev) - - c = utils.get_hubert_content(self.hubert_model, wav_16k_tensor=wav16k) - c = utils.repeat_expand_2d(c.squeeze(0), f0.shape[1]) - - if ( - self.settings.clusterInferRatio != 0 - and hasattr(self, "cluster_model") - and self.cluster_model is not None - ): - speaker = [ - key - for key, value in self.settings.speakers.items() - if value == self.settings.dstId - ] - if len(speaker) != 1: - pass - # print("not only one speaker found.", speaker) - else: - cluster_c = cluster.get_cluster_center_result( - self.cluster_model, c.cpu().numpy().T, speaker[0] - ).T - # cluster_c = cluster.get_cluster_center_result(self.cluster_model, c.cpu().numpy().T, self.settings.dstId).T - cluster_c = torch.FloatTensor(cluster_c).to(dev) - # print("cluster DEVICE", cluster_c.device, c.device) - c = ( - self.settings.clusterInferRatio * cluster_c - + (1 - self.settings.clusterInferRatio) * c - ) - - c = c.unsqueeze(0) - return c, f0, uv - - def generate_input( - self, - newData: AudioInOut, - inputSize: int, - crossfadeSize: int, - solaSearchFrame: int = 0, - ): - newData = newData.astype(np.float32) / self.hps.data.max_wav_value - - if self.audio_buffer is not None: - self.audio_buffer = np.concatenate( - [self.audio_buffer, newData], 0 - ) # 過去のデータに連結 - else: - self.audio_buffer = newData - - convertSize = ( - inputSize + crossfadeSize + solaSearchFrame + self.settings.extraConvertSize - ) - - if convertSize % self.hps.data.hop_length != 0: # モデルの出力のホップサイズで切り捨てが発生するので補う。 - convertSize = convertSize + ( - self.hps.data.hop_length - (convertSize % self.hps.data.hop_length) - ) - convertOffset = -1 * convertSize - self.audio_buffer = self.audio_buffer[convertOffset:] # 変換対象の部分だけ抽出 - - cropOffset = -1 * (inputSize + crossfadeSize) - cropEnd = -1 * (crossfadeSize) - crop = self.audio_buffer[cropOffset:cropEnd] - - rms = np.sqrt(np.square(crop).mean(axis=0)) - vol = max(rms, self.prevVol * 0.0) - self.prevVol = vol - - c, f0, uv = self.get_unit_f0(self.audio_buffer, self.settings.tran) - return (c, f0, uv, convertSize, vol) - - def _onnx_inference(self, data): - if hasattr(self, "onnx_session") is False or self.onnx_session is None: - print("[Voice Changer] No onnx session.") - raise NoModeLoadedException("ONNX") - - convertSize = data[3] - vol = data[4] - data = ( - data[0], - data[1], - data[2], - ) - - if vol < self.settings.silentThreshold: - return np.zeros(convertSize).astype(np.int16) - - c, f0, uv = [x.numpy() for x in data] - audio1 = ( - self.onnx_session.run( - ["audio"], - { - "c": c, - "f0": f0, - "g": np.array([self.settings.dstId]).astype(np.int64), - "uv": np.array([self.settings.dstId]).astype(np.int64), - "predict_f0": np.array([self.settings.dstId]).astype(np.int64), - "noice_scale": np.array([self.settings.dstId]).astype(np.int64), - }, - )[0][0, 0] - * self.hps.data.max_wav_value - ) - - audio1 = audio1 * vol - - result = audio1 - - return result - - def _pyTorch_inference(self, data): - if hasattr(self, "net_g") is False or self.net_g is None: - print("[Voice Changer] No pyTorch session.") - raise NoModeLoadedException("pytorch") - - if self.settings.gpu < 0 or self.gpu_num == 0: - dev = torch.device("cpu") - else: - dev = torch.device("cuda", index=self.settings.gpu) - - convertSize = data[3] - vol = data[4] - data = ( - data[0], - data[1], - data[2], - ) - - if vol < self.settings.silentThreshold: - return np.zeros(convertSize).astype(np.int16) - - with torch.no_grad(): - c, f0, uv = [x.to(dev) for x in data] - sid_target = torch.LongTensor([self.settings.dstId]).to(dev) - self.net_g.to(dev) - # audio1 = self.net_g.infer(c, f0=f0, g=sid_target, uv=uv, predict_f0=True, noice_scale=0.1)[0][0, 0].data.float() - predict_f0_flag = True if self.settings.predictF0 == 1 else False - audio1 = self.net_g.infer( - c, - f0=f0, - g=sid_target, - uv=uv, - predict_f0=predict_f0_flag, - noice_scale=self.settings.noiseScale, - )[0][0, 0].data.float() - audio1 = audio1 * self.hps.data.max_wav_value - - audio1 = audio1 * vol - - result = audio1.float().cpu().numpy() - - # result = infer_tool.pad_array(result, length) - return result - - def inference(self, data): - if self.isOnnx(): - audio = self._onnx_inference(data) - else: - audio = self._pyTorch_inference(data) - return audio - - def __del__(self): - del self.net_g - del self.onnx_session - - remove_path = os.path.join("so-vits-svc-40v2") - sys.path = [x for x in sys.path if x.endswith(remove_path) is False] - - for key in list(sys.modules): - val = sys.modules.get(key) - try: - file_path = val.__file__ - if file_path.find("so-vits-svc-40v2" + os.path.sep) >= 0: - # print("remove", key, file_path) - sys.modules.pop(key) - except: # type:ignore - pass - - -def resize_f0(x, target_len): - source = np.array(x) - source[source < 0.001] = np.nan - target = np.interp( - np.arange(0, len(source) * target_len, len(source)) / target_len, - np.arange(0, len(source)), - source, - ) - res = np.nan_to_num(target) - return res - - -def compute_f0_dio(wav_numpy, p_len=None, sampling_rate=44100, hop_length=512): - if p_len is None: - p_len = wav_numpy.shape[0] // hop_length - f0, t = pw.dio( - wav_numpy.astype(np.double), - fs=sampling_rate, - f0_ceil=800, - frame_period=1000 * hop_length / sampling_rate, - ) - f0 = pw.stonemask(wav_numpy.astype(np.double), f0, t, sampling_rate) - for index, pitch in enumerate(f0): - f0[index] = round(pitch, 1) - return resize_f0(f0, p_len) - - -def compute_f0_harvest(wav_numpy, p_len=None, sampling_rate=44100, hop_length=512): - if p_len is None: - p_len = wav_numpy.shape[0] // hop_length - f0, t = pw.harvest( - wav_numpy.astype(np.double), - fs=sampling_rate, - frame_period=5.5, - f0_floor=71.0, - f0_ceil=1000.0, - ) - - for index, pitch in enumerate(f0): - f0[index] = round(pitch, 1) - return resize_f0(f0, p_len) diff --git a/server/voice_changer/VoiceChangerManager.py b/server/voice_changer/VoiceChangerManager.py index 97321814..4d375496 100644 --- a/server/voice_changer/VoiceChangerManager.py +++ b/server/voice_changer/VoiceChangerManager.py @@ -131,9 +131,9 @@ class VoiceChangerManager(ServerDeviceCallbacks): slotInfo = SoVitsSvc40ModelSlotGenerator.loadModel(params) self.modelSlotManager.save_model_slot(params.slot, slotInfo) elif params.voiceChangerType == "DDSP-SVC": - from voice_changer.DDSP_SVC.DDSP_SVC import DDSP_SVC + from voice_changer.DDSP_SVC.DDSP_SVCModelSlotGenerator import DDSP_SVCModelSlotGenerator - slotInfo = DDSP_SVC.loadModel(params) + slotInfo = DDSP_SVCModelSlotGenerator.loadModel(params) self.modelSlotManager.save_model_slot(params.slot, slotInfo) print("params", params) @@ -195,6 +195,13 @@ class VoiceChangerManager(ServerDeviceCallbacks): self.voiceChangerModel = SoVitsSvc40(self.params, slotInfo) self.voiceChanger = VoiceChanger(self.params) self.voiceChanger.setModel(self.voiceChangerModel) + elif slotInfo.voiceChangerType == "DDSP-SVC": + print("................DDSP-SVC") + from voice_changer.DDSP_SVC.DDSP_SVC import DDSP_SVC + + self.voiceChangerModel = DDSP_SVC(self.params, slotInfo) + self.voiceChanger = VoiceChanger(self.params) + self.voiceChanger.setModel(self.voiceChangerModel) else: print(f"[Voice Changer] unknown voice changer model: {slotInfo.voiceChangerType}") del self.voiceChangerModel