From 5bf1202215b5b7303096b64e1c7a82ac196971f9 Mon Sep 17 00:00:00 2001 From: w-okada Date: Fri, 14 Jul 2023 03:33:04 +0900 Subject: [PATCH] WIP: diffusion svc rt badf0 --- server/const.py | 5 + server/data/ModelSlot.py | 6 +- .../DiffusionSVC/DiffusionSVC.py | 40 +- .../DiffusionSVCModelSlotGenerator.py | 1 + .../inferencer/DiffusionSVCInferencer.py | 129 +- .../DiffusionSVC/inferencer/Inferencer.py | 50 + .../inferencer/InferencerManager.py | 29 + .../diffusion_svc_model/DiffusionSVC.py | 196 +-- .../diffusion/diffusion.py | 2 +- .../diffusion/dpm_solver_pytorch.py | 1303 +++++++++++++++++ .../diffusion/naive/naive.py | 10 +- .../diffusion/naive/pcmer.py | 7 + .../diffusion_svc_model/diffusion/unit2mel.py | 8 +- .../diffusion_svc_model/diffusion/vocoder.py | 4 +- .../diffusion_svc_model/tools/tools.py | 11 +- .../DiffusionSVC/pipeline/Pipeline.py | 181 +-- .../pipeline/PipelineGenerator.py | 69 +- server/voice_changer/ModelSlotManager.py | 1 + .../voice_changer/common/VolumeExtractor.py | 41 + 19 files changed, 1709 insertions(+), 384 deletions(-) create mode 100644 server/voice_changer/DiffusionSVC/inferencer/Inferencer.py create mode 100644 server/voice_changer/DiffusionSVC/inferencer/InferencerManager.py create mode 100644 server/voice_changer/DiffusionSVC/inferencer/diffusion_svc_model/diffusion/dpm_solver_pytorch.py create mode 100644 server/voice_changer/common/VolumeExtractor.py diff --git a/server/const.py b/server/const.py index a04b62c1..9c9fb1e8 100644 --- a/server/const.py +++ b/server/const.py @@ -66,6 +66,11 @@ class EnumInferenceTypes(Enum): onnxRVCNono = "onnxRVCNono" +DiffusionSVCInferenceType: TypeAlias = Literal[ + "combo", +] + + PitchExtractorType: TypeAlias = Literal[ "harvest", "dio", diff --git a/server/data/ModelSlot.py b/server/data/ModelSlot.py index 87e5be02..00a2df13 100644 --- a/server/data/ModelSlot.py +++ b/server/data/ModelSlot.py @@ -1,5 +1,5 @@ from typing import TypeAlias, Union -from const import MAX_SLOT_NUM, EnumInferenceTypes, EmbedderType, VoiceChangerType +from const import MAX_SLOT_NUM, DiffusionSVCInferenceType, EnumInferenceTypes, EmbedderType, VoiceChangerType from dataclasses import dataclass, asdict, field @@ -107,7 +107,7 @@ class DiffusionSVCModelSlot(ModelSlot): voiceChangerType: VoiceChangerType = "Diffusion-SVC" modelFile: str = "" isONNX: bool = False - modelType: str = "combo" + modelType: DiffusionSVCInferenceType = "combo" dstId: int = 1 sampleId: str = "" @@ -115,6 +115,8 @@ class DiffusionSVCModelSlot(ModelSlot): kstep: int = 100 speakers: dict = field(default_factory=lambda: {1: "user"}) embedder: EmbedderType = "hubert_base" + samplingRate: int = 44100 + embChannels: int = 768 ModelSlots: TypeAlias = Union[ModelSlot, RVCModelSlot, MMVCv13ModelSlot, MMVCv15ModelSlot, SoVitsSvc40ModelSlot, DDSPSVCModelSlot, DiffusionSVCModelSlot] diff --git a/server/voice_changer/DiffusionSVC/DiffusionSVC.py b/server/voice_changer/DiffusionSVC/DiffusionSVC.py index 91bddfe4..6c67f47f 100644 --- a/server/voice_changer/DiffusionSVC/DiffusionSVC.py +++ b/server/voice_changer/DiffusionSVC/DiffusionSVC.py @@ -52,7 +52,7 @@ class DiffusionSVC(VoiceChangerModel): print("[Voice Changer] [DiffusionSVC] Initializing... done") def update_settings(self, key: str, val: int | float | str): - print("[Voice Changer][RVC]: update_settings", key, val) + print("[Voice Changer][DiffusionSVC]: update_settings", key, val) if key in self.settings.intData: setattr(self.settings, key, int(val)) if key == "gpu": @@ -86,19 +86,18 @@ class DiffusionSVC(VoiceChangerModel): crossfadeSize: int, solaSearchFrame: int = 0, ): - newData = newData.astype(np.float32) / 32768.0 # RVCのモデルのサンプリングレートで入ってきている。(extraDataLength, Crossfade等も同じSRで処理)(★1) + newData = newData.astype(np.float32) / 32768.0 # DiffusionSVCのモデルのサンプリングレートで入ってきている。(extraDataLength, Crossfade等も同じSRで処理)(★1) - new_feature_length = newData.shape[0] * 100 // self.slotInfo.samplingRate + new_feature_length = newData.shape[0] * 100 // self.slotInfo.samplingRate # 100 は hubertのhosizeから (16000 / 160) if self.audio_buffer is not None: # 過去のデータに連結 self.audio_buffer = np.concatenate([self.audio_buffer, newData], 0) - if self.slotInfo.f0: - self.pitchf_buffer = np.concatenate([self.pitchf_buffer, np.zeros(new_feature_length)], 0) + self.pitchf_buffer = np.concatenate([self.pitchf_buffer, np.zeros(new_feature_length)], 0) + print("^^^self.feature_buffer.shape, self.slotInfo.embChannels",self.feature_buffer.shape, self.slotInfo.embChannels) self.feature_buffer = np.concatenate([self.feature_buffer, np.zeros([new_feature_length, self.slotInfo.embChannels])], 0) else: self.audio_buffer = newData - if self.slotInfo.f0: - self.pitchf_buffer = np.zeros(new_feature_length) + self.pitchf_buffer = np.zeros(new_feature_length) self.feature_buffer = np.zeros([new_feature_length, self.slotInfo.embChannels]) convertSize = inputSize + crossfadeSize + solaSearchFrame + self.settings.extraConvertSize @@ -110,15 +109,13 @@ class DiffusionSVC(VoiceChangerModel): # バッファがたまっていない場合はzeroで補う if self.audio_buffer.shape[0] < convertSize: self.audio_buffer = np.concatenate([np.zeros([convertSize]), self.audio_buffer]) - if self.slotInfo.f0: - self.pitchf_buffer = np.concatenate([np.zeros([convertSize * 100 // self.slotInfo.samplingRate]), self.pitchf_buffer]) + self.pitchf_buffer = np.concatenate([np.zeros([convertSize * 100 // self.slotInfo.samplingRate]), self.pitchf_buffer]) self.feature_buffer = np.concatenate([np.zeros([convertSize * 100 // self.slotInfo.samplingRate, self.slotInfo.embChannels]), self.feature_buffer]) convertOffset = -1 * convertSize featureOffset = -convertSize * 100 // self.slotInfo.samplingRate self.audio_buffer = self.audio_buffer[convertOffset:] # 変換対象の部分だけ抽出 - if self.slotInfo.f0: - self.pitchf_buffer = self.pitchf_buffer[featureOffset:] + self.pitchf_buffer = self.pitchf_buffer[featureOffset:] self.feature_buffer = self.feature_buffer[featureOffset:] # 出力部分だけ切り出して音量を確認。(TODO:段階的消音にする) @@ -145,18 +142,18 @@ class DiffusionSVC(VoiceChangerModel): if self.pipeline is not None: device = self.pipeline.device else: - device = torch.device("cpu") + device = torch.device("cpu") # TODO:pipelineが存在しない場合はzeroを返してもいいかも(要確認)。 audio = torch.from_numpy(audio).to(device=device, dtype=torch.float32) audio = torchaudio.functional.resample(audio, self.slotInfo.samplingRate, 16000, rolloff=0.99) - repeat = 1 if self.settings.rvcQuality else 0 + repeat = 0 sid = self.settings.dstId f0_up_key = self.settings.tran - index_rate = self.settings.indexRatio - protect = self.settings.protect + index_rate = 0 + protect = 0 - if_f0 = 1 if self.slotInfo.f0 else 0 - embOutputLayer = self.slotInfo.embOutputLayer - useFinalProj = self.slotInfo.useFinalProj + if_f0 = 1 + embOutputLayer = 12 + useFinalProj = False try: audio_out, self.pitchf_buffer, self.feature_buffer = self.pipeline.exec( @@ -167,14 +164,17 @@ class DiffusionSVC(VoiceChangerModel): f0_up_key, index_rate, if_f0, - self.settings.extraConvertSize / self.slotInfo.samplingRate if self.settings.silenceFront else 0., # extaraDataSizeの秒数。RVCのモデルのサンプリングレートで処理(★1)。 + self.settings.extraConvertSize / self.slotInfo.samplingRate if self.settings.silenceFront else 0., # extaraConvertSize(既にモデルのサンプリングレートにリサンプリング済み)の秒数。モデルのサンプリングレートで処理(★1)。 embOutputLayer, useFinalProj, repeat, protect, outSize ) - result = audio_out.detach().cpu().numpy() * np.sqrt(vol) + # result = audio_out.detach().cpu().numpy() * np.sqrt(vol) + result = audio_out.detach().cpu().numpy() + + print("RESULT", result) return result except DeviceCannotSupportHalfPrecisionException as e: # NOQA diff --git a/server/voice_changer/DiffusionSVC/DiffusionSVCModelSlotGenerator.py b/server/voice_changer/DiffusionSVC/DiffusionSVCModelSlotGenerator.py index 27a2250c..3ff1aa39 100644 --- a/server/voice_changer/DiffusionSVC/DiffusionSVCModelSlotGenerator.py +++ b/server/voice_changer/DiffusionSVC/DiffusionSVCModelSlotGenerator.py @@ -21,6 +21,7 @@ class DiffusionSVCModelSlotGenerator(ModelSlotGenerator): slotInfo.isONNX = slotInfo.modelFile.endswith(".onnx") slotInfo.name = os.path.splitext(os.path.basename(slotInfo.modelFile))[0] slotInfo.iconFile = "/assets/icons/noimage.png" + slotInfo.embChannels = 768 # if slotInfo.isONNX: # slotInfo = cls._setInfoByONNX(slotInfo) diff --git a/server/voice_changer/DiffusionSVC/inferencer/DiffusionSVCInferencer.py b/server/voice_changer/DiffusionSVC/inferencer/DiffusionSVCInferencer.py index 4028ac24..f80df59e 100644 --- a/server/voice_changer/DiffusionSVC/inferencer/DiffusionSVCInferencer.py +++ b/server/voice_changer/DiffusionSVC/inferencer/DiffusionSVCInferencer.py @@ -1,35 +1,134 @@ +import numpy as np import torch +from voice_changer.DiffusionSVC.inferencer.Inferencer import Inferencer +from voice_changer.DiffusionSVC.inferencer.diffusion_svc_model.diffusion.naive.naive import Unit2MelNaive +from voice_changer.DiffusionSVC.inferencer.diffusion_svc_model.diffusion.unit2mel import Unit2Mel, load_model_vocoder_from_combo +from voice_changer.DiffusionSVC.inferencer.diffusion_svc_model.diffusion.vocoder import Vocoder from voice_changer.RVC.deviceManager.DeviceManager import DeviceManager -class RVCInferencer(Inferencer): +class DiffusionSVCInferencer(Inferencer): + def __init__(self): + self.diff_model: Unit2Mel | None = None + self.naive_model: Unit2MelNaive | None = None + self.vocoder: Vocoder | None = None + def loadModel(self, file: str, gpu: int): self.setProps("DiffusionSVCCombo", file, True, gpu) - dev = DeviceManager.get_instance().getDevice(gpu) - isHalf = DeviceManager.get_instance().halfPrecisionAvailable(gpu) + self.dev = DeviceManager.get_instance().getDevice(gpu) + # isHalf = DeviceManager.get_instance().halfPrecisionAvailable(gpu) - cpt = torch.load(file, map_location="cpu") - model = SynthesizerTrnMs256NSFsid(*cpt["config"], is_half=isHalf) + diff_model, diff_args, naive_model, naive_args, vocoder = load_model_vocoder_from_combo(file, device=self.dev) + self.diff_model = diff_model + self.naive_model = naive_model + self.vocoder = vocoder + self.diff_args = diff_args + print("-----------------> diff_args", diff_args) + print("-----------------> naive_args", naive_args) - model.eval() - model.load_state_dict(cpt["weight"], strict=False) + # cpt = torch.load(file, map_location="cpu") + # model = SynthesizerTrnMs256NSFsid(*cpt["config"], is_half=isHalf) - model = model.to(dev) - if isHalf: - model = model.half() + # model.eval() + # model.load_state_dict(cpt["weight"], strict=False) - self.model = model + # model = model.to(dev) + # if isHalf: + # model = model.half() + + # self.model = model return self + + def getConfig(self) -> tuple[int, int]: + model_sampling_rate = int(self.diff_args.data.sampling_rate) + model_block_size = int(self.diff_args.data.block_size) + return model_block_size, model_sampling_rate + @torch.no_grad() # 最基本推理代码,将输入标准化为tensor,只与mel打交道 + def __call__(self, units, f0, volume, spk_id=1, spk_mix_dict=None, aug_shift=0, + gt_spec=None, infer_speedup=10, method='dpm-solver', k_step=None, use_tqdm=True, + spk_emb=None): + + if self.diff_args.model.k_step_max is not None: + if k_step is None: + raise ValueError("k_step must not None when Shallow Diffusion Model inferring") + if k_step > int(self.diff_args.model.k_step_max): + raise ValueError("k_step must <= k_step_max of Shallow Diffusion Model") + if gt_spec is None: + raise ValueError("gt_spec must not None when Shallow Diffusion Model inferring, gt_spec can from " + "input mel or output of naive model") + print(f' [INFO] k_step_max is {self.diff_args.model.k_step_max}.') + + aug_shift = torch.from_numpy(np.array([[float(aug_shift)]])).float().to(self.dev) + + # spk_id + spk_emb_dict = None + if self.diff_args.model.use_speaker_encoder: # with speaker encoder + spk_mix_dict, spk_emb = self.pre_spk_emb(spk_id, spk_mix_dict, len(units), spk_emb) + # without speaker encoder + else: + spk_id = torch.LongTensor(np.array([[int(spk_id)]])).to(self.dev) + + if k_step is not None: + print(f' [INFO] get k_step, do shallow diffusion {k_step} steps') + else: + print(f' [INFO] Do full 1000 steps depth diffusion {k_step}') + print(f" [INFO] method:{method}; infer_speedup:{infer_speedup}") + return self.diff_model(units, f0, volume, spk_id=spk_id, spk_mix_dict=spk_mix_dict, aug_shift=aug_shift, gt_spec=gt_spec, infer=True, infer_speedup=infer_speedup, method=method, k_step=k_step, use_tqdm=use_tqdm, spk_emb=spk_emb, spk_emb_dict=spk_emb_dict) + + @torch.no_grad() + def naive_model_call(self, units, f0, volume, spk_id=1, spk_mix_dict=None, + aug_shift=0, spk_emb=None): + # spk_id + spk_emb_dict = None + if self.diff_args.model.use_speaker_encoder: # with speaker encoder + spk_mix_dict, spk_emb = self.pre_spk_emb(spk_id, spk_mix_dict, len(units), spk_emb) + # without speaker encoder + else: + spk_id = torch.LongTensor(np.array([[int(spk_id)]])).to(self.dev) + aug_shift = torch.from_numpy(np.array([[float(aug_shift)]])).float().to(self.dev) + print("====> unit, f0, vol", units.shape, f0.shape, volume.shape) + print("====> *unit, f0, vol", units) + print("====> unit, *f0, vol", f0) + print("====> unit, f0, *vol", volume) + out_spec = self.naive_model(units, f0, volume, spk_id=spk_id, spk_mix_dict=spk_mix_dict, + aug_shift=aug_shift, infer=True, + spk_emb=spk_emb, spk_emb_dict=spk_emb_dict) + return out_spec + + @torch.no_grad() + def mel2wav(self, mel, f0, start_frame=0): + if start_frame == 0: + return self.vocoder.infer(mel, f0) + else: # for realtime speedup + mel = mel[:, start_frame:, :] + f0 = f0[:, start_frame:, :] + out_wav = self.vocoder.infer(mel, f0) + return torch.nn.functional.pad(out_wav, (start_frame * self.vocoder.vocoder_hop_size, 0)) + + @torch.no_grad() def infer( self, feats: torch.Tensor, - pitch_length: torch.Tensor, pitch: torch.Tensor, - pitchf: torch.Tensor, + volume: torch.Tensor, + mask: torch.Tensor, sid: torch.Tensor, - convert_length: int | None, + infer_speedup: int, + k_step: int, + silence_front: float, ) -> torch.Tensor: - return self.model.infer(feats, pitch_length, pitch, pitchf, sid, convert_length=convert_length) + print("---------------------------------shape", feats.shape, pitch.shape, volume.shape) + gt_spec = self.naive_model_call(feats, pitch, volume, spk_id=sid, spk_mix_dict=None, aug_shift=0, spk_emb=None) + print("======================>>>>>gt_spec", gt_spec) + out_mel = self.__call__(feats, pitch, volume, spk_id=sid, spk_mix_dict=None, aug_shift=0, gt_spec=gt_spec, infer_speedup=infer_speedup, method='dpm-solver', k_step=k_step, use_tqdm=False, spk_emb=None) + print("======================>>>>>out_mel", out_mel) + start_frame = int(silence_front * self.vocoder.vocoder_sample_rate / self.vocoder.vocoder_hop_size) + out_wav = self.mel2wav(out_mel, pitch, start_frame=start_frame) + + print("======================>>>>>out_wav.shape, mask.shape", out_wav.shape, mask.shape) + out_wav *= mask + print("out_wav:::::::::::", out_wav) + return out_wav.squeeze() diff --git a/server/voice_changer/DiffusionSVC/inferencer/Inferencer.py b/server/voice_changer/DiffusionSVC/inferencer/Inferencer.py new file mode 100644 index 00000000..1daa2baa --- /dev/null +++ b/server/voice_changer/DiffusionSVC/inferencer/Inferencer.py @@ -0,0 +1,50 @@ +from typing import Any, Protocol +import torch +import onnxruntime + +from const import DiffusionSVCInferenceType + + +class Inferencer(Protocol): + inferencerType: DiffusionSVCInferenceType = "combo" + file: str + isHalf: bool = True + gpu: int = 0 + + model: onnxruntime.InferenceSession | Any | None = None + + def loadModel(self, file: str, gpu: int): + ... + + def getConfig(self) -> tuple[int, int]: + ... + + def infer( + self, + feats: torch.Tensor, + pitch_length: torch.Tensor, + pitch: torch.Tensor | None, + pitchf: torch.Tensor | None, + sid: torch.Tensor, + ) -> torch.Tensor: + ... + + def setProps( + self, + inferencerType: DiffusionSVCInferenceType, + file: str, + isHalf: bool, + gpu: int, + ): + self.inferencerType = inferencerType + self.file = file + self.isHalf = isHalf + self.gpu = gpu + + def getInferencerInfo(self): + return { + "inferencerType": self.inferencerType, + "file": self.file, + "isHalf": self.isHalf, + "gpu": self.gpu, + } diff --git a/server/voice_changer/DiffusionSVC/inferencer/InferencerManager.py b/server/voice_changer/DiffusionSVC/inferencer/InferencerManager.py new file mode 100644 index 00000000..579dc892 --- /dev/null +++ b/server/voice_changer/DiffusionSVC/inferencer/InferencerManager.py @@ -0,0 +1,29 @@ +from const import DiffusionSVCInferenceType +from voice_changer.DiffusionSVC.inferencer.DiffusionSVCInferencer import DiffusionSVCInferencer +from voice_changer.RVC.inferencer.Inferencer import Inferencer + + +class InferencerManager: + currentInferencer: Inferencer | None = None + + @classmethod + def getInferencer( + cls, + inferencerType: DiffusionSVCInferenceType, + file: str, + gpu: int, + ) -> Inferencer: + cls.currentInferencer = cls.loadInferencer(inferencerType, file, gpu) + return cls.currentInferencer + + @classmethod + def loadInferencer( + cls, + inferencerType: DiffusionSVCInferenceType, + file: str, + gpu: int, + ) -> Inferencer: + if inferencerType == "combo": + return DiffusionSVCInferencer().loadModel(file, gpu) + else: + raise RuntimeError("[Voice Changer] Inferencer not found", inferencerType) diff --git a/server/voice_changer/DiffusionSVC/inferencer/diffusion_svc_model/DiffusionSVC.py b/server/voice_changer/DiffusionSVC/inferencer/diffusion_svc_model/DiffusionSVC.py index 8d4becbf..659c0587 100644 --- a/server/voice_changer/DiffusionSVC/inferencer/diffusion_svc_model/DiffusionSVC.py +++ b/server/voice_changer/DiffusionSVC/inferencer/diffusion_svc_model/DiffusionSVC.py @@ -38,21 +38,16 @@ class DiffusionSVC: self.use_combo_model = False def load_model(self, model_path, f0_model=None, f0_min=None, f0_max=None): - - if ('1234' + model_path)[-4:] == '.ptc': - self.use_combo_model = True - self.model_path = model_path - self.naive_model_path = model_path - diff_model, diff_args, naive_model, naive_args, vocoder = load_model_vocoder_from_combo(model_path, - device=self.device) - self.model = diff_model - self.args = diff_args - self.naive_model = naive_model - self.naive_model_args = naive_args - self.vocoder = vocoder - else: - self.model_path = model_path - self.model, self.vocoder, self.args = load_model_vocoder(model_path, device=self.device) + self.use_combo_model = True + self.model_path = model_path + self.naive_model_path = model_path + diff_model, diff_args, naive_model, naive_args, vocoder = load_model_vocoder_from_combo(model_path, + device=self.device) + self.model = diff_model + self.args = diff_args + self.naive_model = naive_model + self.naive_model_args = naive_args + self.vocoder = vocoder self.units_encoder = Units_Encoder( self.args.data.encoder, @@ -85,33 +80,6 @@ class DiffusionSVC: self.units_indexer = UnitsIndexer(os.path.split(model_path)[0]) - def flush(self, model_path=None, f0_model=None, f0_min=None, f0_max=None, naive_model_path=None): - assert (model_path is not None) or (naive_model_path is not None) - # flush model if changed - if ((self.model_path != model_path) or (self.f0_model != f0_model) - or (self.f0_min != f0_min) or (self.f0_max != f0_max)): - self.load_model(model_path, f0_model=f0_model, f0_min=f0_min, f0_max=f0_max) - if (self.naive_model_path != naive_model_path) and (naive_model_path is not None): - self.load_naive_model(naive_model_path) - # check args if use naive - if self.naive_model is not None: - if self.naive_model_args.data.encoder != self.args.data.encoder: - raise ValueError("encoder of Naive Model and Diffusion Model are different") - if self.naive_model_args.model.n_spk != self.args.model.n_spk: - raise ValueError("n_spk of Naive Model and Diffusion Model are different") - if bool(self.naive_model_args.model.use_speaker_encoder) != bool(self.args.model.use_speaker_encoder): - raise ValueError("use_speaker_encoder of Naive Model and Diffusion Model are different") - if self.naive_model_args.vocoder.type != self.args.vocoder.type: - raise ValueError("vocoder of Naive Model and Diffusion Model are different") - if self.naive_model_args.data.block_size != self.args.data.block_size: - raise ValueError("block_size of Naive Model and Diffusion Model are different") - if self.naive_model_args.data.sampling_rate != self.args.data.sampling_rate: - raise ValueError("sampling_rate of Naive Model and Diffusion Model are different") - - def flush_f0_extractor(self, f0_model, f0_min=None, f0_max=None): - if (f0_model != self.f0_model) and (f0_model is not None): - self.load_f0_extractor(f0_model) - def load_f0_extractor(self, f0_model, f0_min=None, f0_max=None): self.f0_model = f0_model if (f0_model is not None) else self.args.data.f0_extractor self.f0_min = f0_min if (f0_min is not None) else self.args.data.f0_min @@ -127,12 +95,6 @@ class DiffusionSVC: model_sampling_rate=self.args.data.sampling_rate ) - def load_naive_model(self, naive_model_path): - self.naive_model_path = naive_model_path - model, _, args = load_model_vocoder(naive_model_path, device=self.device, loaded_vocoder=self.vocoder) - self.naive_model = model - self.naive_model_args = args - print(f" [INFO] Load naive model from {naive_model_path}") @torch.no_grad() def naive_model_call(self, units, f0, volume, spk_id=1, spk_mix_dict=None, @@ -265,144 +227,6 @@ class DiffusionSVC: gt_spec=gt_spec, infer=True, infer_speedup=infer_speedup, method=method, k_step=k_step, use_tqdm=use_tqdm, spk_emb=spk_emb, spk_emb_dict=spk_emb_dict) - @torch.no_grad() # 比__call__多了声码器代码,输出波形 - def infer(self, units, f0, volume, gt_spec=None, spk_id=1, spk_mix_dict=None, aug_shift=0, - infer_speedup=10, method='dpm-solver', k_step=None, use_tqdm=True, - spk_emb=None): - if k_step is not None: - if self.naive_model is not None: - gt_spec = self.naive_model_call(units, f0, volume, spk_id=spk_id, spk_mix_dict=spk_mix_dict, - aug_shift=aug_shift, spk_emb=spk_emb) - print(f" [INFO] get mel from naive model out.") - assert gt_spec is not None - if self.naive_model is None: - print(f" [INFO] get mel from input wav.") - if input(" [WARN] You are attempting shallow diffusion " - "on the mel of the input source," - " Please enter 'gt_mel' to continue") != 'gt_mel': - raise ValueError("Please understand what you're doing") - k_step = int(k_step) - gt_spec = gt_spec - else: - gt_spec = None - - out_mel = self.__call__(units, f0, volume, spk_id=spk_id, spk_mix_dict=spk_mix_dict, aug_shift=aug_shift, - gt_spec=gt_spec, infer_speedup=infer_speedup, method=method, k_step=k_step, - use_tqdm=use_tqdm, spk_emb=spk_emb) - return self.mel2wav(out_mel, f0) - - @torch.no_grad() # 为实时浅扩散优化的推理代码,可以切除pad省算力 - def infer_for_realtime(self, units, f0, volume, audio_t=None, spk_id=1, spk_mix_dict=None, aug_shift=0, - infer_speedup=10, method='dpm-solver', k_step=None, use_tqdm=True, - spk_emb=None, silence_front=0, diff_jump_silence_front=False): - - start_frame = int(silence_front * self.vocoder.vocoder_sample_rate / self.vocoder.vocoder_hop_size) - - if diff_jump_silence_front: - if audio_t is not None: - audio_t = audio_t[:, start_frame * self.vocoder.vocoder_hop_size:] - f0 = f0[:, start_frame:, :] - units = units[:, start_frame:, :] - volume = volume[:, start_frame:, :] - - if k_step is not None: - assert audio_t is not None - k_step = int(k_step) - gt_spec = self.vocoder.extract(audio_t, self.args.data.sampling_rate) - # 如果缺帧再开这行gt_spec = torch.cat((gt_spec, gt_spec[:, -1:, :]), 1) - else: - gt_spec = None - - out_mel = self.__call__(units, f0, volume, spk_id=spk_id, spk_mix_dict=spk_mix_dict, aug_shift=aug_shift, - gt_spec=gt_spec, infer_speedup=infer_speedup, method=method, k_step=k_step, - use_tqdm=use_tqdm, spk_emb=spk_emb) - - if diff_jump_silence_front: - out_wav = self.mel2wav(out_mel, f0) - else: - out_wav = self.mel2wav(out_mel, f0, start_frame=start_frame) - return out_wav - - @torch.no_grad() # 不切片从音频推理代码 - def infer_from_audio(self, audio, sr=44100, key=0, spk_id=1, spk_mix_dict=None, aug_shift=0, - infer_speedup=10, method='dpm-solver', k_step=None, use_tqdm=True, - spk_emb=None, threhold=-60, index_ratio=0): - units = self.encode_units(audio, sr) - if index_ratio > 0: - units = self.units_indexer(units_t=units, spk_id=spk_id, ratio=index_ratio) - f0 = self.extract_f0(audio, key=key, sr=sr) - volume, mask = self.extract_volume_and_mask(audio, sr, threhold=float(threhold)) - if k_step is not None: - assert 0 < int(k_step) <= 1000 - k_step = int(k_step) - audio_t = torch.from_numpy(audio).float().unsqueeze(0).to(self.device) - gt_spec = self.vocoder.extract(audio_t, sr) - gt_spec = torch.cat((gt_spec, gt_spec[:, -1:, :]), 1) - else: - gt_spec = None - output = self.infer(units, f0, volume, gt_spec=gt_spec, spk_id=spk_id, spk_mix_dict=spk_mix_dict, - aug_shift=aug_shift, infer_speedup=infer_speedup, method=method, k_step=k_step, - use_tqdm=use_tqdm, spk_emb=spk_emb) - output *= mask - return output.squeeze().cpu().numpy(), self.args.data.sampling_rate - - @torch.no_grad() # 切片从音频推理代码 - def infer_from_long_audio(self, audio, sr=44100, key=0, spk_id=1, spk_mix_dict=None, aug_shift=0, - infer_speedup=10, method='dpm-solver', k_step=None, use_tqdm=True, - spk_emb=None, - threhold=-60, threhold_for_split=-40, min_len=5000, index_ratio=0): - - hop_size = self.args.data.block_size * sr / self.args.data.sampling_rate - segments = split(audio, sr, hop_size, db_thresh=threhold_for_split, min_len=min_len) - - print(f' [INFO] Extract f0 volume and mask: Use {self.f0_model}, start...') - _f0_start_time = time.time() - f0 = self.extract_f0(audio, key=key, sr=sr) - volume, mask = self.extract_volume_and_mask(audio, sr, threhold=float(threhold)) - _f0_end_time = time.time() - _f0_used_time = _f0_end_time - _f0_start_time - print(f' [INFO] Extract f0 volume and mask: Done. Use time:{_f0_used_time}') - - if k_step is not None: - assert 0 < int(k_step) <= 1000 - k_step = int(k_step) - audio_t = torch.from_numpy(audio).float().unsqueeze(0).to(self.device) - gt_spec = self.vocoder.extract(audio_t, sr) - gt_spec = torch.cat((gt_spec, gt_spec[:, -1:, :]), 1) - else: - gt_spec = None - - result = np.zeros(0) - current_length = 0 - for segment in tqdm(segments): - start_frame = segment[0] - seg_input = torch.from_numpy(segment[1]).float().unsqueeze(0).to(self.device) - seg_units = self.units_encoder.encode(seg_input, sr, hop_size) - if index_ratio > 0: - seg_units = self.units_indexer(units_t=seg_units, spk_id=spk_id, ratio=index_ratio) - seg_f0 = f0[:, start_frame: start_frame + seg_units.size(1), :] - seg_volume = volume[:, start_frame: start_frame + seg_units.size(1), :] - if gt_spec is not None: - seg_gt_spec = gt_spec[:, start_frame: start_frame + seg_units.size(1), :] - else: - seg_gt_spec = None - seg_output = self.infer(seg_units, seg_f0, seg_volume, gt_spec=seg_gt_spec, spk_id=spk_id, - spk_mix_dict=spk_mix_dict, - aug_shift=aug_shift, infer_speedup=infer_speedup, method=method, k_step=k_step, - use_tqdm=use_tqdm, spk_emb=spk_emb) - _left = start_frame * self.args.data.block_size - _right = (start_frame + seg_units.size(1)) * self.args.data.block_size - seg_output *= mask[:, _left:_right] - seg_output = seg_output.squeeze().cpu().numpy() - silent_length = round(start_frame * self.args.data.block_size) - current_length - if silent_length >= 0: - result = np.append(result, np.zeros(silent_length)) - result = np.append(result, seg_output) - else: - result = cross_fade(result, seg_output, current_length + silent_length) - current_length = current_length + silent_length + len(seg_output) - - return result, self.args.data.sampling_rate @torch.no_grad() # 为实时优化的推理代码,可以切除pad省算力 def infer_from_audio_for_realtime(self, audio, sr, key, spk_id=1, spk_mix_dict=None, aug_shift=0, diff --git a/server/voice_changer/DiffusionSVC/inferencer/diffusion_svc_model/diffusion/diffusion.py b/server/voice_changer/DiffusionSVC/inferencer/diffusion_svc_model/diffusion/diffusion.py index 3ac99797..c1f16ca8 100644 --- a/server/voice_changer/DiffusionSVC/inferencer/diffusion_svc_model/diffusion/diffusion.py +++ b/server/voice_changer/DiffusionSVC/inferencer/diffusion_svc_model/diffusion/diffusion.py @@ -252,7 +252,7 @@ class GaussianDiffusion(nn.Module): if method is not None and infer_speedup > 1: if method == 'dpm-solver': - from .dpm_solver_pytorch import NoiseScheduleVP, model_wrapper, DPM_Solver + from voice_changer.DiffusionSVC.inferencer.diffusion_svc_model.diffusion.dpm_solver_pytorch import NoiseScheduleVP, model_wrapper, DPM_Solver # 1. Define the noise schedule. noise_schedule = NoiseScheduleVP(schedule='discrete', betas=self.betas[:t]) diff --git a/server/voice_changer/DiffusionSVC/inferencer/diffusion_svc_model/diffusion/dpm_solver_pytorch.py b/server/voice_changer/DiffusionSVC/inferencer/diffusion_svc_model/diffusion/dpm_solver_pytorch.py new file mode 100644 index 00000000..c48ac017 --- /dev/null +++ b/server/voice_changer/DiffusionSVC/inferencer/diffusion_svc_model/diffusion/dpm_solver_pytorch.py @@ -0,0 +1,1303 @@ +import torch + + +class NoiseScheduleVP: + def __init__( + self, + schedule='discrete', + betas=None, + alphas_cumprod=None, + continuous_beta_0=0.1, + continuous_beta_1=20., + dtype=torch.float32 + ): + """Create a wrapper class for the forward SDE (VP type). + + *** + Update: We support discrete-time diffusion models by implementing a picewise linear interpolation for log_alpha_t. + We recommend to use schedule='discrete' for the discrete-time diffusion models, especially for high-resolution images. + *** + + The forward SDE ensures that the condition distribution q_{t|0}(x_t | x_0) = N ( alpha_t * x_0, sigma_t^2 * I ). + We further define lambda_t = log(alpha_t) - log(sigma_t), which is the half-logSNR (described in the DPM-Solver paper). + Therefore, we implement the functions for computing alpha_t, sigma_t and lambda_t. For t in [0, T], we have: + + log_alpha_t = self.marginal_log_mean_coeff(t) + sigma_t = self.marginal_std(t) + lambda_t = self.marginal_lambda(t) + + Moreover, as lambda(t) is an invertible function, we also support its inverse function: + + t = self.inverse_lambda(lambda_t) + + =============================================================== + + We support both discrete-time DPMs (trained on n = 0, 1, ..., N-1) and continuous-time DPMs (trained on t in [t_0, T]). + + 1. For discrete-time DPMs: + + For discrete-time DPMs trained on n = 0, 1, ..., N-1, we convert the discrete steps to continuous time steps by: + t_i = (i + 1) / N + e.g. for N = 1000, we have t_0 = 1e-3 and T = t_{N-1} = 1. + We solve the corresponding diffusion ODE from time T = 1 to time t_0 = 1e-3. + + Args: + betas: A `torch.Tensor`. The beta array for the discrete-time DPM. (See the original DDPM paper for details) + alphas_cumprod: A `torch.Tensor`. The cumprod alphas for the discrete-time DPM. (See the original DDPM paper for details) + + Note that we always have alphas_cumprod = cumprod(1 - betas). Therefore, we only need to set one of `betas` and `alphas_cumprod`. + + **Important**: Please pay special attention for the args for `alphas_cumprod`: + The `alphas_cumprod` is the \hat{alpha_n} arrays in the notations of DDPM. Specifically, DDPMs assume that + q_{t_n | 0}(x_{t_n} | x_0) = N ( \sqrt{\hat{alpha_n}} * x_0, (1 - \hat{alpha_n}) * I ). + Therefore, the notation \hat{alpha_n} is different from the notation alpha_t in DPM-Solver. In fact, we have + alpha_{t_n} = \sqrt{\hat{alpha_n}}, + and + log(alpha_{t_n}) = 0.5 * log(\hat{alpha_n}). + + + 2. For continuous-time DPMs: + + We support the linear VPSDE for the continuous time setting. The hyperparameters for the noise + schedule are the default settings in Yang Song's ScoreSDE: + + Args: + beta_min: A `float` number. The smallest beta for the linear schedule. + beta_max: A `float` number. The largest beta for the linear schedule. + T: A `float` number. The ending time of the forward process. + + =============================================================== + + Args: + schedule: A `str`. The noise schedule of the forward SDE. 'discrete' for discrete-time DPMs, + 'linear' for continuous-time DPMs. + Returns: + A wrapper object of the forward SDE (VP type). + + =============================================================== + + Example: + + # For discrete-time DPMs, given betas (the beta array for n = 0, 1, ..., N - 1): + >>> ns = NoiseScheduleVP('discrete', betas=betas) + + # For discrete-time DPMs, given alphas_cumprod (the \hat{alpha_n} array for n = 0, 1, ..., N - 1): + >>> ns = NoiseScheduleVP('discrete', alphas_cumprod=alphas_cumprod) + + # For continuous-time DPMs (VPSDE), linear schedule: + >>> ns = NoiseScheduleVP('linear', continuous_beta_0=0.1, continuous_beta_1=20.) + + """ + + if schedule not in ['discrete', 'linear']: + raise ValueError("Unsupported noise schedule {}. The schedule needs to be 'discrete' or 'linear'".format(schedule)) + + self.schedule = schedule + if schedule == 'discrete': + if betas is not None: + log_alphas = 0.5 * torch.log(1 - betas).cumsum(dim=0) + else: + assert alphas_cumprod is not None + log_alphas = 0.5 * torch.log(alphas_cumprod) + self.T = 1. + self.log_alpha_array = self.numerical_clip_alpha(log_alphas).reshape((1, -1,)).to(dtype=dtype) + self.total_N = self.log_alpha_array.shape[1] + self.t_array = torch.linspace(0., 1., self.total_N + 1)[1:].reshape((1, -1)).to(dtype=dtype) + else: + self.T = 1. + self.total_N = 1000 + self.beta_0 = continuous_beta_0 + self.beta_1 = continuous_beta_1 + + def numerical_clip_alpha(self, log_alphas, clipped_lambda=-5.1): + """ + For some beta schedules such as cosine schedule, the log-SNR has numerical isssues. + We clip the log-SNR near t=T within -5.1 to ensure the stability. + Such a trick is very useful for diffusion models with the cosine schedule, such as i-DDPM, guided-diffusion and GLIDE. + """ + log_sigmas = 0.5 * torch.log(1. - torch.exp(2. * log_alphas)) + lambs = log_alphas - log_sigmas + idx = torch.searchsorted(torch.flip(lambs, [0]), clipped_lambda) + if idx > 0: + log_alphas = log_alphas[:-idx] + return log_alphas + + def marginal_log_mean_coeff(self, t): + """ + Compute log(alpha_t) of a given continuous-time label t in [0, T]. + """ + if self.schedule == 'discrete': + return interpolate_fn(t.reshape((-1, 1)), self.t_array.to(t.device), self.log_alpha_array.to(t.device)).reshape((-1)) + elif self.schedule == 'linear': + return -0.25 * t ** 2 * (self.beta_1 - self.beta_0) - 0.5 * t * self.beta_0 + + def marginal_alpha(self, t): + """ + Compute alpha_t of a given continuous-time label t in [0, T]. + """ + return torch.exp(self.marginal_log_mean_coeff(t)) + + def marginal_std(self, t): + """ + Compute sigma_t of a given continuous-time label t in [0, T]. + """ + return torch.sqrt(1. - torch.exp(2. * self.marginal_log_mean_coeff(t))) + + def marginal_lambda(self, t): + """ + Compute lambda_t = log(alpha_t) - log(sigma_t) of a given continuous-time label t in [0, T]. + """ + log_mean_coeff = self.marginal_log_mean_coeff(t) + log_std = 0.5 * torch.log(1. - torch.exp(2. * log_mean_coeff)) + return log_mean_coeff - log_std + + def inverse_lambda(self, lamb): + """ + Compute the continuous-time label t in [0, T] of a given half-logSNR lambda_t. + """ + if self.schedule == 'linear': + tmp = 2. * (self.beta_1 - self.beta_0) * torch.logaddexp(-2. * lamb, torch.zeros((1,)).to(lamb)) + Delta = self.beta_0**2 + tmp + return tmp / (torch.sqrt(Delta) + self.beta_0) / (self.beta_1 - self.beta_0) + elif self.schedule == 'discrete': + log_alpha = -0.5 * torch.logaddexp(torch.zeros((1,)).to(lamb.device), -2. * lamb) + t = interpolate_fn(log_alpha.reshape((-1, 1)), torch.flip(self.log_alpha_array.to(lamb.device), [1]), torch.flip(self.t_array.to(lamb.device), [1])) + return t.reshape((-1,)) + + +def model_wrapper( + model, + noise_schedule, + model_type="noise", + model_kwargs={}, + guidance_type="uncond", + condition=None, + unconditional_condition=None, + guidance_scale=1., + classifier_fn=None, + classifier_kwargs={}, +): + """Create a wrapper function for the noise prediction model. + + DPM-Solver needs to solve the continuous-time diffusion ODEs. For DPMs trained on discrete-time labels, we need to + firstly wrap the model function to a noise prediction model that accepts the continuous time as the input. + + We support four types of the diffusion model by setting `model_type`: + + 1. "noise": noise prediction model. (Trained by predicting noise). + + 2. "x_start": data prediction model. (Trained by predicting the data x_0 at time 0). + + 3. "v": velocity prediction model. (Trained by predicting the velocity). + The "v" prediction is derivation detailed in Appendix D of [1], and is used in Imagen-Video [2]. + + [1] Salimans, Tim, and Jonathan Ho. "Progressive distillation for fast sampling of diffusion models." + arXiv preprint arXiv:2202.00512 (2022). + [2] Ho, Jonathan, et al. "Imagen Video: High Definition Video Generation with Diffusion Models." + arXiv preprint arXiv:2210.02303 (2022). + + 4. "score": marginal score function. (Trained by denoising score matching). + Note that the score function and the noise prediction model follows a simple relationship: + ``` + noise(x_t, t) = -sigma_t * score(x_t, t) + ``` + + We support three types of guided sampling by DPMs by setting `guidance_type`: + 1. "uncond": unconditional sampling by DPMs. + The input `model` has the following format: + `` + model(x, t_input, **model_kwargs) -> noise | x_start | v | score + `` + + 2. "classifier": classifier guidance sampling [3] by DPMs and another classifier. + The input `model` has the following format: + `` + model(x, t_input, **model_kwargs) -> noise | x_start | v | score + `` + + The input `classifier_fn` has the following format: + `` + classifier_fn(x, t_input, cond, **classifier_kwargs) -> logits(x, t_input, cond) + `` + + [3] P. Dhariwal and A. Q. Nichol, "Diffusion models beat GANs on image synthesis," + in Advances in Neural Information Processing Systems, vol. 34, 2021, pp. 8780-8794. + + 3. "classifier-free": classifier-free guidance sampling by conditional DPMs. + The input `model` has the following format: + `` + model(x, t_input, cond, **model_kwargs) -> noise | x_start | v | score + `` + And if cond == `unconditional_condition`, the model output is the unconditional DPM output. + + [4] Ho, Jonathan, and Tim Salimans. "Classifier-free diffusion guidance." + arXiv preprint arXiv:2207.12598 (2022). + + + The `t_input` is the time label of the model, which may be discrete-time labels (i.e. 0 to 999) + or continuous-time labels (i.e. epsilon to T). + + We wrap the model function to accept only `x` and `t_continuous` as inputs, and outputs the predicted noise: + `` + def model_fn(x, t_continuous) -> noise: + t_input = get_model_input_time(t_continuous) + return noise_pred(model, x, t_input, **model_kwargs) + `` + where `t_continuous` is the continuous time labels (i.e. epsilon to T). And we use `model_fn` for DPM-Solver. + + =============================================================== + + Args: + model: A diffusion model with the corresponding format described above. + noise_schedule: A noise schedule object, such as NoiseScheduleVP. + model_type: A `str`. The parameterization type of the diffusion model. + "noise" or "x_start" or "v" or "score". + model_kwargs: A `dict`. A dict for the other inputs of the model function. + guidance_type: A `str`. The type of the guidance for sampling. + "uncond" or "classifier" or "classifier-free". + condition: A pytorch tensor. The condition for the guided sampling. + Only used for "classifier" or "classifier-free" guidance type. + unconditional_condition: A pytorch tensor. The condition for the unconditional sampling. + Only used for "classifier-free" guidance type. + guidance_scale: A `float`. The scale for the guided sampling. + classifier_fn: A classifier function. Only used for the classifier guidance. + classifier_kwargs: A `dict`. A dict for the other inputs of the classifier function. + Returns: + A noise prediction model that accepts the noised data and the continuous time as the inputs. + """ + + def get_model_input_time(t_continuous): + """ + Convert the continuous-time `t_continuous` (in [epsilon, T]) to the model input time. + For discrete-time DPMs, we convert `t_continuous` in [1 / N, 1] to `t_input` in [0, 1000 * (N - 1) / N]. + For continuous-time DPMs, we just use `t_continuous`. + """ + if noise_schedule.schedule == 'discrete': + return (t_continuous - 1. / noise_schedule.total_N) * noise_schedule.total_N + else: + return t_continuous + + def noise_pred_fn(x, t_continuous, cond=None): + t_input = get_model_input_time(t_continuous) + if cond is None: + output = model(x, t_input, **model_kwargs) + else: + output = model(x, t_input, cond, **model_kwargs) + if model_type == "noise": + return output + elif model_type == "x_start": + alpha_t, sigma_t = noise_schedule.marginal_alpha(t_continuous), noise_schedule.marginal_std(t_continuous) + return (x - expand_dims(alpha_t, x.dim()) * output) / expand_dims(sigma_t, x.dim()) + elif model_type == "v": + alpha_t, sigma_t = noise_schedule.marginal_alpha(t_continuous), noise_schedule.marginal_std(t_continuous) + return expand_dims(alpha_t, x.dim()) * output + expand_dims(sigma_t, x.dim()) * x + elif model_type == "score": + sigma_t = noise_schedule.marginal_std(t_continuous) + return -expand_dims(sigma_t, x.dim()) * output + + def cond_grad_fn(x, t_input): + """ + Compute the gradient of the classifier, i.e. nabla_{x} log p_t(cond | x_t). + """ + with torch.enable_grad(): + x_in = x.detach().requires_grad_(True) + log_prob = classifier_fn(x_in, t_input, condition, **classifier_kwargs) + return torch.autograd.grad(log_prob.sum(), x_in)[0] + + def model_fn(x, t_continuous): + """ + The noise predicition model function that is used for DPM-Solver. + """ + if guidance_type == "uncond": + return noise_pred_fn(x, t_continuous) + elif guidance_type == "classifier": + assert classifier_fn is not None + t_input = get_model_input_time(t_continuous) + cond_grad = cond_grad_fn(x, t_input) + sigma_t = noise_schedule.marginal_std(t_continuous) + noise = noise_pred_fn(x, t_continuous) + return noise - guidance_scale * expand_dims(sigma_t, x.dim()) * cond_grad + elif guidance_type == "classifier-free": + if guidance_scale == 1. or unconditional_condition is None: + return noise_pred_fn(x, t_continuous, cond=condition) + else: + x_in = torch.cat([x] * 2) + t_in = torch.cat([t_continuous] * 2) + c_in = torch.cat([unconditional_condition, condition]) + noise_uncond, noise = noise_pred_fn(x_in, t_in, cond=c_in).chunk(2) + return noise_uncond + guidance_scale * (noise - noise_uncond) + + assert model_type in ["noise", "x_start", "v", "score"] + assert guidance_type in ["uncond", "classifier", "classifier-free"] + return model_fn + + +class DPM_Solver: + def __init__( + self, + model_fn, + noise_schedule, + algorithm_type="dpmsolver++", + correcting_x0_fn=None, + correcting_xt_fn=None, + thresholding_max_val=1., + dynamic_thresholding_ratio=0.995, + ): + """Construct a DPM-Solver. + + We support both DPM-Solver (`algorithm_type="dpmsolver"`) and DPM-Solver++ (`algorithm_type="dpmsolver++"`). + + We also support the "dynamic thresholding" method in Imagen[1]. For pixel-space diffusion models, you + can set both `algorithm_type="dpmsolver++"` and `correcting_x0_fn="dynamic_thresholding"` to use the + dynamic thresholding. The "dynamic thresholding" can greatly improve the sample quality for pixel-space + DPMs with large guidance scales. Note that the thresholding method is **unsuitable** for latent-space + DPMs (such as stable-diffusion). + + To support advanced algorithms in image-to-image applications, we also support corrector functions for + both x0 and xt. + + Args: + model_fn: A noise prediction model function which accepts the continuous-time input (t in [epsilon, T]): + `` + def model_fn(x, t_continuous): + return noise + `` + The shape of `x` is `(batch_size, **shape)`, and the shape of `t_continuous` is `(batch_size,)`. + noise_schedule: A noise schedule object, such as NoiseScheduleVP. + algorithm_type: A `str`. Either "dpmsolver" or "dpmsolver++". + correcting_x0_fn: A `str` or a function with the following format: + ``` + def correcting_x0_fn(x0, t): + x0_new = ... + return x0_new + ``` + This function is to correct the outputs of the data prediction model at each sampling step. e.g., + ``` + x0_pred = data_pred_model(xt, t) + if correcting_x0_fn is not None: + x0_pred = correcting_x0_fn(x0_pred, t) + xt_1 = update(x0_pred, xt, t) + ``` + If `correcting_x0_fn="dynamic_thresholding"`, we use the dynamic thresholding proposed in Imagen[1]. + correcting_xt_fn: A function with the following format: + ``` + def correcting_xt_fn(xt, t, step): + x_new = ... + return x_new + ``` + This function is to correct the intermediate samples xt at each sampling step. e.g., + ``` + xt = ... + xt = correcting_xt_fn(xt, t, step) + ``` + thresholding_max_val: A `float`. The max value for thresholding. + Valid only when use `dpmsolver++` and `correcting_x0_fn="dynamic_thresholding"`. + dynamic_thresholding_ratio: A `float`. The ratio for dynamic thresholding (see Imagen[1] for details). + Valid only when use `dpmsolver++` and `correcting_x0_fn="dynamic_thresholding"`. + + [1] Chitwan Saharia, William Chan, Saurabh Saxena, Lala Li, Jay Whang, Emily Denton, Seyed Kamyar Seyed Ghasemipour, + Burcu Karagol Ayan, S Sara Mahdavi, Rapha Gontijo Lopes, et al. Photorealistic text-to-image diffusion models + with deep language understanding. arXiv preprint arXiv:2205.11487, 2022b. + """ + self.model = lambda x, t: model_fn(x, t.expand((x.shape[0]))) + self.noise_schedule = noise_schedule + assert algorithm_type in ["dpmsolver", "dpmsolver++"] + self.algorithm_type = algorithm_type + if correcting_x0_fn == "dynamic_thresholding": + self.correcting_x0_fn = self.dynamic_thresholding_fn + else: + self.correcting_x0_fn = correcting_x0_fn + self.correcting_xt_fn = correcting_xt_fn + self.dynamic_thresholding_ratio = dynamic_thresholding_ratio + self.thresholding_max_val = thresholding_max_val + + def dynamic_thresholding_fn(self, x0, t): + """ + The dynamic thresholding method. + """ + dims = x0.dim() + p = self.dynamic_thresholding_ratio + s = torch.quantile(torch.abs(x0).reshape((x0.shape[0], -1)), p, dim=1) + s = expand_dims(torch.maximum(s, self.thresholding_max_val * torch.ones_like(s).to(s.device)), dims) + x0 = torch.clamp(x0, -s, s) / s + return x0 + + def noise_prediction_fn(self, x, t): + """ + Return the noise prediction model. + """ + return self.model(x, t) + + def data_prediction_fn(self, x, t): + """ + Return the data prediction model (with corrector). + """ + noise = self.noise_prediction_fn(x, t) + alpha_t, sigma_t = self.noise_schedule.marginal_alpha(t), self.noise_schedule.marginal_std(t) + x0 = (x - sigma_t * noise) / alpha_t + if self.correcting_x0_fn is not None: + x0 = self.correcting_x0_fn(x0, t) + return x0 + + def model_fn(self, x, t): + """ + Convert the model to the noise prediction model or the data prediction model. + """ + if self.algorithm_type == "dpmsolver++": + return self.data_prediction_fn(x, t) + else: + return self.noise_prediction_fn(x, t) + + def get_time_steps(self, skip_type, t_T, t_0, N, device): + """Compute the intermediate time steps for sampling. + + Args: + skip_type: A `str`. The type for the spacing of the time steps. We support three types: + - 'logSNR': uniform logSNR for the time steps. + - 'time_uniform': uniform time for the time steps. (**Recommended for high-resolutional data**.) + - 'time_quadratic': quadratic time for the time steps. (Used in DDIM for low-resolutional data.) + t_T: A `float`. The starting time of the sampling (default is T). + t_0: A `float`. The ending time of the sampling (default is epsilon). + N: A `int`. The total number of the spacing of the time steps. + device: A torch device. + Returns: + A pytorch tensor of the time steps, with the shape (N + 1,). + """ + if skip_type == 'logSNR': + lambda_T = self.noise_schedule.marginal_lambda(torch.tensor(t_T).to(device)) + lambda_0 = self.noise_schedule.marginal_lambda(torch.tensor(t_0).to(device)) + logSNR_steps = torch.linspace(lambda_T.cpu().item(), lambda_0.cpu().item(), N + 1).to(device) + return self.noise_schedule.inverse_lambda(logSNR_steps) + elif skip_type == 'time_uniform': + return torch.linspace(t_T, t_0, N + 1).to(device) + elif skip_type == 'time_quadratic': + t_order = 2 + t = torch.linspace(t_T**(1. / t_order), t_0**(1. / t_order), N + 1).pow(t_order).to(device) + return t + else: + raise ValueError("Unsupported skip_type {}, need to be 'logSNR' or 'time_uniform' or 'time_quadratic'".format(skip_type)) + + def get_orders_and_timesteps_for_singlestep_solver(self, steps, order, skip_type, t_T, t_0, device): + """ + Get the order of each step for sampling by the singlestep DPM-Solver. + + We combine both DPM-Solver-1,2,3 to use all the function evaluations, which is named as "DPM-Solver-fast". + Given a fixed number of function evaluations by `steps`, the sampling procedure by DPM-Solver-fast is: + - If order == 1: + We take `steps` of DPM-Solver-1 (i.e. DDIM). + - If order == 2: + - Denote K = (steps // 2). We take K or (K + 1) intermediate time steps for sampling. + - If steps % 2 == 0, we use K steps of DPM-Solver-2. + - If steps % 2 == 1, we use K steps of DPM-Solver-2 and 1 step of DPM-Solver-1. + - If order == 3: + - Denote K = (steps // 3 + 1). We take K intermediate time steps for sampling. + - If steps % 3 == 0, we use (K - 2) steps of DPM-Solver-3, and 1 step of DPM-Solver-2 and 1 step of DPM-Solver-1. + - If steps % 3 == 1, we use (K - 1) steps of DPM-Solver-3 and 1 step of DPM-Solver-1. + - If steps % 3 == 2, we use (K - 1) steps of DPM-Solver-3 and 1 step of DPM-Solver-2. + + ============================================ + Args: + order: A `int`. The max order for the solver (2 or 3). + steps: A `int`. The total number of function evaluations (NFE). + skip_type: A `str`. The type for the spacing of the time steps. We support three types: + - 'logSNR': uniform logSNR for the time steps. + - 'time_uniform': uniform time for the time steps. (**Recommended for high-resolutional data**.) + - 'time_quadratic': quadratic time for the time steps. (Used in DDIM for low-resolutional data.) + t_T: A `float`. The starting time of the sampling (default is T). + t_0: A `float`. The ending time of the sampling (default is epsilon). + device: A torch device. + Returns: + orders: A list of the solver order of each step. + """ + if order == 3: + K = steps // 3 + 1 + if steps % 3 == 0: + orders = [3,] * (K - 2) + [2, 1] + elif steps % 3 == 1: + orders = [3,] * (K - 1) + [1] + else: + orders = [3,] * (K - 1) + [2] + elif order == 2: + if steps % 2 == 0: + K = steps // 2 + orders = [2,] * K + else: + K = steps // 2 + 1 + orders = [2,] * (K - 1) + [1] + elif order == 1: + K = 1 + orders = [1,] * steps + else: + raise ValueError("'order' must be '1' or '2' or '3'.") + if skip_type == 'logSNR': + # To reproduce the results in DPM-Solver paper + timesteps_outer = self.get_time_steps(skip_type, t_T, t_0, K, device) + else: + timesteps_outer = self.get_time_steps(skip_type, t_T, t_0, steps, device)[torch.cumsum(torch.tensor([0,] + orders), 0).to(device)] + return timesteps_outer, orders + + def denoise_to_zero_fn(self, x, s): + """ + Denoise at the final step, which is equivalent to solve the ODE from lambda_s to infty by first-order discretization. + """ + return self.data_prediction_fn(x, s) + + def dpm_solver_first_update(self, x, s, t, model_s=None, return_intermediate=False): + """ + DPM-Solver-1 (equivalent to DDIM) from time `s` to time `t`. + + Args: + x: A pytorch tensor. The initial value at time `s`. + s: A pytorch tensor. The starting time, with the shape (1,). + t: A pytorch tensor. The ending time, with the shape (1,). + model_s: A pytorch tensor. The model function evaluated at time `s`. + If `model_s` is None, we evaluate the model by `x` and `s`; otherwise we directly use it. + return_intermediate: A `bool`. If true, also return the model value at time `s`. + Returns: + x_t: A pytorch tensor. The approximated solution at time `t`. + """ + ns = self.noise_schedule + dims = x.dim() + lambda_s, lambda_t = ns.marginal_lambda(s), ns.marginal_lambda(t) + h = lambda_t - lambda_s + log_alpha_s, log_alpha_t = ns.marginal_log_mean_coeff(s), ns.marginal_log_mean_coeff(t) + sigma_s, sigma_t = ns.marginal_std(s), ns.marginal_std(t) + alpha_t = torch.exp(log_alpha_t) + + if self.algorithm_type == "dpmsolver++": + phi_1 = torch.expm1(-h) + if model_s is None: + model_s = self.model_fn(x, s) + x_t = ( + sigma_t / sigma_s * x + - alpha_t * phi_1 * model_s + ) + if return_intermediate: + return x_t, {'model_s': model_s} + else: + return x_t + else: + phi_1 = torch.expm1(h) + if model_s is None: + model_s = self.model_fn(x, s) + x_t = ( + torch.exp(log_alpha_t - log_alpha_s) * x + - (sigma_t * phi_1) * model_s + ) + if return_intermediate: + return x_t, {'model_s': model_s} + else: + return x_t + + def singlestep_dpm_solver_second_update(self, x, s, t, r1=0.5, model_s=None, return_intermediate=False, solver_type='dpmsolver'): + """ + Singlestep solver DPM-Solver-2 from time `s` to time `t`. + + Args: + x: A pytorch tensor. The initial value at time `s`. + s: A pytorch tensor. The starting time, with the shape (1,). + t: A pytorch tensor. The ending time, with the shape (1,). + r1: A `float`. The hyperparameter of the second-order solver. + model_s: A pytorch tensor. The model function evaluated at time `s`. + If `model_s` is None, we evaluate the model by `x` and `s`; otherwise we directly use it. + return_intermediate: A `bool`. If true, also return the model value at time `s` and `s1` (the intermediate time). + solver_type: either 'dpmsolver' or 'taylor'. The type for the high-order solvers. + The type slightly impacts the performance. We recommend to use 'dpmsolver' type. + Returns: + x_t: A pytorch tensor. The approximated solution at time `t`. + """ + if solver_type not in ['dpmsolver', 'taylor']: + raise ValueError("'solver_type' must be either 'dpmsolver' or 'taylor', got {}".format(solver_type)) + if r1 is None: + r1 = 0.5 + ns = self.noise_schedule + lambda_s, lambda_t = ns.marginal_lambda(s), ns.marginal_lambda(t) + h = lambda_t - lambda_s + lambda_s1 = lambda_s + r1 * h + s1 = ns.inverse_lambda(lambda_s1) + log_alpha_s, log_alpha_s1, log_alpha_t = ns.marginal_log_mean_coeff(s), ns.marginal_log_mean_coeff(s1), ns.marginal_log_mean_coeff(t) + sigma_s, sigma_s1, sigma_t = ns.marginal_std(s), ns.marginal_std(s1), ns.marginal_std(t) + alpha_s1, alpha_t = torch.exp(log_alpha_s1), torch.exp(log_alpha_t) + + if self.algorithm_type == "dpmsolver++": + phi_11 = torch.expm1(-r1 * h) + phi_1 = torch.expm1(-h) + + if model_s is None: + model_s = self.model_fn(x, s) + x_s1 = ( + (sigma_s1 / sigma_s) * x + - (alpha_s1 * phi_11) * model_s + ) + model_s1 = self.model_fn(x_s1, s1) + if solver_type == 'dpmsolver': + x_t = ( + (sigma_t / sigma_s) * x + - (alpha_t * phi_1) * model_s + - (0.5 / r1) * (alpha_t * phi_1) * (model_s1 - model_s) + ) + elif solver_type == 'taylor': + x_t = ( + (sigma_t / sigma_s) * x + - (alpha_t * phi_1) * model_s + + (1. / r1) * (alpha_t * (phi_1 / h + 1.)) * (model_s1 - model_s) + ) + else: + phi_11 = torch.expm1(r1 * h) + phi_1 = torch.expm1(h) + + if model_s is None: + model_s = self.model_fn(x, s) + x_s1 = ( + torch.exp(log_alpha_s1 - log_alpha_s) * x + - (sigma_s1 * phi_11) * model_s + ) + model_s1 = self.model_fn(x_s1, s1) + if solver_type == 'dpmsolver': + x_t = ( + torch.exp(log_alpha_t - log_alpha_s) * x + - (sigma_t * phi_1) * model_s + - (0.5 / r1) * (sigma_t * phi_1) * (model_s1 - model_s) + ) + elif solver_type == 'taylor': + x_t = ( + torch.exp(log_alpha_t - log_alpha_s) * x + - (sigma_t * phi_1) * model_s + - (1. / r1) * (sigma_t * (phi_1 / h - 1.)) * (model_s1 - model_s) + ) + if return_intermediate: + return x_t, {'model_s': model_s, 'model_s1': model_s1} + else: + return x_t + + def singlestep_dpm_solver_third_update(self, x, s, t, r1=1./3., r2=2./3., model_s=None, model_s1=None, return_intermediate=False, solver_type='dpmsolver'): + """ + Singlestep solver DPM-Solver-3 from time `s` to time `t`. + + Args: + x: A pytorch tensor. The initial value at time `s`. + s: A pytorch tensor. The starting time, with the shape (1,). + t: A pytorch tensor. The ending time, with the shape (1,). + r1: A `float`. The hyperparameter of the third-order solver. + r2: A `float`. The hyperparameter of the third-order solver. + model_s: A pytorch tensor. The model function evaluated at time `s`. + If `model_s` is None, we evaluate the model by `x` and `s`; otherwise we directly use it. + model_s1: A pytorch tensor. The model function evaluated at time `s1` (the intermediate time given by `r1`). + If `model_s1` is None, we evaluate the model at `s1`; otherwise we directly use it. + return_intermediate: A `bool`. If true, also return the model value at time `s`, `s1` and `s2` (the intermediate times). + solver_type: either 'dpmsolver' or 'taylor'. The type for the high-order solvers. + The type slightly impacts the performance. We recommend to use 'dpmsolver' type. + Returns: + x_t: A pytorch tensor. The approximated solution at time `t`. + """ + if solver_type not in ['dpmsolver', 'taylor']: + raise ValueError("'solver_type' must be either 'dpmsolver' or 'taylor', got {}".format(solver_type)) + if r1 is None: + r1 = 1. / 3. + if r2 is None: + r2 = 2. / 3. + ns = self.noise_schedule + lambda_s, lambda_t = ns.marginal_lambda(s), ns.marginal_lambda(t) + h = lambda_t - lambda_s + lambda_s1 = lambda_s + r1 * h + lambda_s2 = lambda_s + r2 * h + s1 = ns.inverse_lambda(lambda_s1) + s2 = ns.inverse_lambda(lambda_s2) + log_alpha_s, log_alpha_s1, log_alpha_s2, log_alpha_t = ns.marginal_log_mean_coeff(s), ns.marginal_log_mean_coeff(s1), ns.marginal_log_mean_coeff(s2), ns.marginal_log_mean_coeff(t) + sigma_s, sigma_s1, sigma_s2, sigma_t = ns.marginal_std(s), ns.marginal_std(s1), ns.marginal_std(s2), ns.marginal_std(t) + alpha_s1, alpha_s2, alpha_t = torch.exp(log_alpha_s1), torch.exp(log_alpha_s2), torch.exp(log_alpha_t) + + if self.algorithm_type == "dpmsolver++": + phi_11 = torch.expm1(-r1 * h) + phi_12 = torch.expm1(-r2 * h) + phi_1 = torch.expm1(-h) + phi_22 = torch.expm1(-r2 * h) / (r2 * h) + 1. + phi_2 = phi_1 / h + 1. + phi_3 = phi_2 / h - 0.5 + + if model_s is None: + model_s = self.model_fn(x, s) + if model_s1 is None: + x_s1 = ( + (sigma_s1 / sigma_s) * x + - (alpha_s1 * phi_11) * model_s + ) + model_s1 = self.model_fn(x_s1, s1) + x_s2 = ( + (sigma_s2 / sigma_s) * x + - (alpha_s2 * phi_12) * model_s + + r2 / r1 * (alpha_s2 * phi_22) * (model_s1 - model_s) + ) + model_s2 = self.model_fn(x_s2, s2) + if solver_type == 'dpmsolver': + x_t = ( + (sigma_t / sigma_s) * x + - (alpha_t * phi_1) * model_s + + (1. / r2) * (alpha_t * phi_2) * (model_s2 - model_s) + ) + elif solver_type == 'taylor': + D1_0 = (1. / r1) * (model_s1 - model_s) + D1_1 = (1. / r2) * (model_s2 - model_s) + D1 = (r2 * D1_0 - r1 * D1_1) / (r2 - r1) + D2 = 2. * (D1_1 - D1_0) / (r2 - r1) + x_t = ( + (sigma_t / sigma_s) * x + - (alpha_t * phi_1) * model_s + + (alpha_t * phi_2) * D1 + - (alpha_t * phi_3) * D2 + ) + else: + phi_11 = torch.expm1(r1 * h) + phi_12 = torch.expm1(r2 * h) + phi_1 = torch.expm1(h) + phi_22 = torch.expm1(r2 * h) / (r2 * h) - 1. + phi_2 = phi_1 / h - 1. + phi_3 = phi_2 / h - 0.5 + + if model_s is None: + model_s = self.model_fn(x, s) + if model_s1 is None: + x_s1 = ( + (torch.exp(log_alpha_s1 - log_alpha_s)) * x + - (sigma_s1 * phi_11) * model_s + ) + model_s1 = self.model_fn(x_s1, s1) + x_s2 = ( + (torch.exp(log_alpha_s2 - log_alpha_s)) * x + - (sigma_s2 * phi_12) * model_s + - r2 / r1 * (sigma_s2 * phi_22) * (model_s1 - model_s) + ) + model_s2 = self.model_fn(x_s2, s2) + if solver_type == 'dpmsolver': + x_t = ( + (torch.exp(log_alpha_t - log_alpha_s)) * x + - (sigma_t * phi_1) * model_s + - (1. / r2) * (sigma_t * phi_2) * (model_s2 - model_s) + ) + elif solver_type == 'taylor': + D1_0 = (1. / r1) * (model_s1 - model_s) + D1_1 = (1. / r2) * (model_s2 - model_s) + D1 = (r2 * D1_0 - r1 * D1_1) / (r2 - r1) + D2 = 2. * (D1_1 - D1_0) / (r2 - r1) + x_t = ( + (torch.exp(log_alpha_t - log_alpha_s)) * x + - (sigma_t * phi_1) * model_s + - (sigma_t * phi_2) * D1 + - (sigma_t * phi_3) * D2 + ) + + if return_intermediate: + return x_t, {'model_s': model_s, 'model_s1': model_s1, 'model_s2': model_s2} + else: + return x_t + + def multistep_dpm_solver_second_update(self, x, model_prev_list, t_prev_list, t, solver_type="dpmsolver"): + """ + Multistep solver DPM-Solver-2 from time `t_prev_list[-1]` to time `t`. + + Args: + x: A pytorch tensor. The initial value at time `s`. + model_prev_list: A list of pytorch tensor. The previous computed model values. + t_prev_list: A list of pytorch tensor. The previous times, each time has the shape (1,) + t: A pytorch tensor. The ending time, with the shape (1,). + solver_type: either 'dpmsolver' or 'taylor'. The type for the high-order solvers. + The type slightly impacts the performance. We recommend to use 'dpmsolver' type. + Returns: + x_t: A pytorch tensor. The approximated solution at time `t`. + """ + if solver_type not in ['dpmsolver', 'taylor']: + raise ValueError("'solver_type' must be either 'dpmsolver' or 'taylor', got {}".format(solver_type)) + ns = self.noise_schedule + model_prev_1, model_prev_0 = model_prev_list[-2], model_prev_list[-1] + t_prev_1, t_prev_0 = t_prev_list[-2], t_prev_list[-1] + lambda_prev_1, lambda_prev_0, lambda_t = ns.marginal_lambda(t_prev_1), ns.marginal_lambda(t_prev_0), ns.marginal_lambda(t) + log_alpha_prev_0, log_alpha_t = ns.marginal_log_mean_coeff(t_prev_0), ns.marginal_log_mean_coeff(t) + sigma_prev_0, sigma_t = ns.marginal_std(t_prev_0), ns.marginal_std(t) + alpha_t = torch.exp(log_alpha_t) + + h_0 = lambda_prev_0 - lambda_prev_1 + h = lambda_t - lambda_prev_0 + r0 = h_0 / h + D1_0 = (1. / r0) * (model_prev_0 - model_prev_1) + if self.algorithm_type == "dpmsolver++": + phi_1 = torch.expm1(-h) + if solver_type == 'dpmsolver': + x_t = ( + (sigma_t / sigma_prev_0) * x + - (alpha_t * phi_1) * model_prev_0 + - 0.5 * (alpha_t * phi_1) * D1_0 + ) + elif solver_type == 'taylor': + x_t = ( + (sigma_t / sigma_prev_0) * x + - (alpha_t * phi_1) * model_prev_0 + + (alpha_t * (phi_1 / h + 1.)) * D1_0 + ) + else: + phi_1 = torch.expm1(h) + if solver_type == 'dpmsolver': + x_t = ( + (torch.exp(log_alpha_t - log_alpha_prev_0)) * x + - (sigma_t * phi_1) * model_prev_0 + - 0.5 * (sigma_t * phi_1) * D1_0 + ) + elif solver_type == 'taylor': + x_t = ( + (torch.exp(log_alpha_t - log_alpha_prev_0)) * x + - (sigma_t * phi_1) * model_prev_0 + - (sigma_t * (phi_1 / h - 1.)) * D1_0 + ) + return x_t + + def multistep_dpm_solver_third_update(self, x, model_prev_list, t_prev_list, t, solver_type='dpmsolver'): + """ + Multistep solver DPM-Solver-3 from time `t_prev_list[-1]` to time `t`. + + Args: + x: A pytorch tensor. The initial value at time `s`. + model_prev_list: A list of pytorch tensor. The previous computed model values. + t_prev_list: A list of pytorch tensor. The previous times, each time has the shape (1,) + t: A pytorch tensor. The ending time, with the shape (1,). + solver_type: either 'dpmsolver' or 'taylor'. The type for the high-order solvers. + The type slightly impacts the performance. We recommend to use 'dpmsolver' type. + Returns: + x_t: A pytorch tensor. The approximated solution at time `t`. + """ + ns = self.noise_schedule + model_prev_2, model_prev_1, model_prev_0 = model_prev_list + t_prev_2, t_prev_1, t_prev_0 = t_prev_list + lambda_prev_2, lambda_prev_1, lambda_prev_0, lambda_t = ns.marginal_lambda(t_prev_2), ns.marginal_lambda(t_prev_1), ns.marginal_lambda(t_prev_0), ns.marginal_lambda(t) + log_alpha_prev_0, log_alpha_t = ns.marginal_log_mean_coeff(t_prev_0), ns.marginal_log_mean_coeff(t) + sigma_prev_0, sigma_t = ns.marginal_std(t_prev_0), ns.marginal_std(t) + alpha_t = torch.exp(log_alpha_t) + + h_1 = lambda_prev_1 - lambda_prev_2 + h_0 = lambda_prev_0 - lambda_prev_1 + h = lambda_t - lambda_prev_0 + r0, r1 = h_0 / h, h_1 / h + D1_0 = (1. / r0) * (model_prev_0 - model_prev_1) + D1_1 = (1. / r1) * (model_prev_1 - model_prev_2) + D1 = D1_0 + (r0 / (r0 + r1)) * (D1_0 - D1_1) + D2 = (1. / (r0 + r1)) * (D1_0 - D1_1) + if self.algorithm_type == "dpmsolver++": + phi_1 = torch.expm1(-h) + phi_2 = phi_1 / h + 1. + phi_3 = phi_2 / h - 0.5 + x_t = ( + (sigma_t / sigma_prev_0) * x + - (alpha_t * phi_1) * model_prev_0 + + (alpha_t * phi_2) * D1 + - (alpha_t * phi_3) * D2 + ) + else: + phi_1 = torch.expm1(h) + phi_2 = phi_1 / h - 1. + phi_3 = phi_2 / h - 0.5 + x_t = ( + (torch.exp(log_alpha_t - log_alpha_prev_0)) * x + - (sigma_t * phi_1) * model_prev_0 + - (sigma_t * phi_2) * D1 + - (sigma_t * phi_3) * D2 + ) + return x_t + + def singlestep_dpm_solver_update(self, x, s, t, order, return_intermediate=False, solver_type='dpmsolver', r1=None, r2=None): + """ + Singlestep DPM-Solver with the order `order` from time `s` to time `t`. + + Args: + x: A pytorch tensor. The initial value at time `s`. + s: A pytorch tensor. The starting time, with the shape (1,). + t: A pytorch tensor. The ending time, with the shape (1,). + order: A `int`. The order of DPM-Solver. We only support order == 1 or 2 or 3. + return_intermediate: A `bool`. If true, also return the model value at time `s`, `s1` and `s2` (the intermediate times). + solver_type: either 'dpmsolver' or 'taylor'. The type for the high-order solvers. + The type slightly impacts the performance. We recommend to use 'dpmsolver' type. + r1: A `float`. The hyperparameter of the second-order or third-order solver. + r2: A `float`. The hyperparameter of the third-order solver. + Returns: + x_t: A pytorch tensor. The approximated solution at time `t`. + """ + if order == 1: + return self.dpm_solver_first_update(x, s, t, return_intermediate=return_intermediate) + elif order == 2: + return self.singlestep_dpm_solver_second_update(x, s, t, return_intermediate=return_intermediate, solver_type=solver_type, r1=r1) + elif order == 3: + return self.singlestep_dpm_solver_third_update(x, s, t, return_intermediate=return_intermediate, solver_type=solver_type, r1=r1, r2=r2) + else: + raise ValueError("Solver order must be 1 or 2 or 3, got {}".format(order)) + + def multistep_dpm_solver_update(self, x, model_prev_list, t_prev_list, t, order, solver_type='dpmsolver'): + """ + Multistep DPM-Solver with the order `order` from time `t_prev_list[-1]` to time `t`. + + Args: + x: A pytorch tensor. The initial value at time `s`. + model_prev_list: A list of pytorch tensor. The previous computed model values. + t_prev_list: A list of pytorch tensor. The previous times, each time has the shape (1,) + t: A pytorch tensor. The ending time, with the shape (1,). + order: A `int`. The order of DPM-Solver. We only support order == 1 or 2 or 3. + solver_type: either 'dpmsolver' or 'taylor'. The type for the high-order solvers. + The type slightly impacts the performance. We recommend to use 'dpmsolver' type. + Returns: + x_t: A pytorch tensor. The approximated solution at time `t`. + """ + if order == 1: + return self.dpm_solver_first_update(x, t_prev_list[-1], t, model_s=model_prev_list[-1]) + elif order == 2: + return self.multistep_dpm_solver_second_update(x, model_prev_list, t_prev_list, t, solver_type=solver_type) + elif order == 3: + return self.multistep_dpm_solver_third_update(x, model_prev_list, t_prev_list, t, solver_type=solver_type) + else: + raise ValueError("Solver order must be 1 or 2 or 3, got {}".format(order)) + + def dpm_solver_adaptive(self, x, order, t_T, t_0, h_init=0.05, atol=0.0078, rtol=0.05, theta=0.9, t_err=1e-5, solver_type='dpmsolver'): + """ + The adaptive step size solver based on singlestep DPM-Solver. + + Args: + x: A pytorch tensor. The initial value at time `t_T`. + order: A `int`. The (higher) order of the solver. We only support order == 2 or 3. + t_T: A `float`. The starting time of the sampling (default is T). + t_0: A `float`. The ending time of the sampling (default is epsilon). + h_init: A `float`. The initial step size (for logSNR). + atol: A `float`. The absolute tolerance of the solver. For image data, the default setting is 0.0078, followed [1]. + rtol: A `float`. The relative tolerance of the solver. The default setting is 0.05. + theta: A `float`. The safety hyperparameter for adapting the step size. The default setting is 0.9, followed [1]. + t_err: A `float`. The tolerance for the time. We solve the diffusion ODE until the absolute error between the + current time and `t_0` is less than `t_err`. The default setting is 1e-5. + solver_type: either 'dpmsolver' or 'taylor'. The type for the high-order solvers. + The type slightly impacts the performance. We recommend to use 'dpmsolver' type. + Returns: + x_0: A pytorch tensor. The approximated solution at time `t_0`. + + [1] A. Jolicoeur-Martineau, K. Li, R. Piché-Taillefer, T. Kachman, and I. Mitliagkas, "Gotta go fast when generating data with score-based models," arXiv preprint arXiv:2105.14080, 2021. + """ + ns = self.noise_schedule + s = t_T * torch.ones((1,)).to(x) + lambda_s = ns.marginal_lambda(s) + lambda_0 = ns.marginal_lambda(t_0 * torch.ones_like(s).to(x)) + h = h_init * torch.ones_like(s).to(x) + x_prev = x + nfe = 0 + if order == 2: + r1 = 0.5 + lower_update = lambda x, s, t: self.dpm_solver_first_update(x, s, t, return_intermediate=True) + higher_update = lambda x, s, t, **kwargs: self.singlestep_dpm_solver_second_update(x, s, t, r1=r1, solver_type=solver_type, **kwargs) + elif order == 3: + r1, r2 = 1. / 3., 2. / 3. + lower_update = lambda x, s, t: self.singlestep_dpm_solver_second_update(x, s, t, r1=r1, return_intermediate=True, solver_type=solver_type) + higher_update = lambda x, s, t, **kwargs: self.singlestep_dpm_solver_third_update(x, s, t, r1=r1, r2=r2, solver_type=solver_type, **kwargs) + else: + raise ValueError("For adaptive step size solver, order must be 2 or 3, got {}".format(order)) + while torch.abs((s - t_0)).mean() > t_err: + t = ns.inverse_lambda(lambda_s + h) + x_lower, lower_noise_kwargs = lower_update(x, s, t) + x_higher = higher_update(x, s, t, **lower_noise_kwargs) + delta = torch.max(torch.ones_like(x).to(x) * atol, rtol * torch.max(torch.abs(x_lower), torch.abs(x_prev))) + norm_fn = lambda v: torch.sqrt(torch.square(v.reshape((v.shape[0], -1))).mean(dim=-1, keepdim=True)) + E = norm_fn((x_higher - x_lower) / delta).max() + if torch.all(E <= 1.): + x = x_higher + s = t + x_prev = x_lower + lambda_s = ns.marginal_lambda(s) + h = torch.min(theta * h * torch.float_power(E, -1. / order).float(), lambda_0 - lambda_s) + nfe += order + print('adaptive solver nfe', nfe) + return x + + def add_noise(self, x, t, noise=None): + """ + Compute the noised input xt = alpha_t * x + sigma_t * noise. + + Args: + x: A `torch.Tensor` with shape `(batch_size, *shape)`. + t: A `torch.Tensor` with shape `(t_size,)`. + Returns: + xt with shape `(t_size, batch_size, *shape)`. + """ + alpha_t, sigma_t = self.noise_schedule.marginal_alpha(t), self.noise_schedule.marginal_std(t) + if noise is None: + noise = torch.randn((t.shape[0], *x.shape), device=x.device) + x = x.reshape((-1, *x.shape)) + xt = expand_dims(alpha_t, x.dim()) * x + expand_dims(sigma_t, x.dim()) * noise + if t.shape[0] == 1: + return xt.squeeze(0) + else: + return xt + + def inverse(self, x, steps=20, t_start=None, t_end=None, order=2, skip_type='time_uniform', + method='multistep', lower_order_final=True, denoise_to_zero=False, solver_type='dpmsolver', + atol=0.0078, rtol=0.05, return_intermediate=False, + ): + """ + Inverse the sample `x` from time `t_start` to `t_end` by DPM-Solver. + For discrete-time DPMs, we use `t_start=1/N`, where `N` is the total time steps during training. + """ + t_0 = 1. / self.noise_schedule.total_N if t_start is None else t_start + t_T = self.noise_schedule.T if t_end is None else t_end + assert t_0 > 0 and t_T > 0, "Time range needs to be greater than 0. For discrete-time DPMs, it needs to be in [1 / N, 1], where N is the length of betas array" + return self.sample(x, steps=steps, t_start=t_0, t_end=t_T, order=order, skip_type=skip_type, + method=method, lower_order_final=lower_order_final, denoise_to_zero=denoise_to_zero, solver_type=solver_type, + atol=atol, rtol=rtol, return_intermediate=return_intermediate) + + def sample(self, x, steps=20, t_start=None, t_end=None, order=2, skip_type='time_uniform', + method='multistep', lower_order_final=True, denoise_to_zero=False, solver_type='dpmsolver', + atol=0.0078, rtol=0.05, return_intermediate=False, + ): + """ + Compute the sample at time `t_end` by DPM-Solver, given the initial `x` at time `t_start`. + + ===================================================== + + We support the following algorithms for both noise prediction model and data prediction model: + - 'singlestep': + Singlestep DPM-Solver (i.e. "DPM-Solver-fast" in the paper), which combines different orders of singlestep DPM-Solver. + We combine all the singlestep solvers with order <= `order` to use up all the function evaluations (steps). + The total number of function evaluations (NFE) == `steps`. + Given a fixed NFE == `steps`, the sampling procedure is: + - If `order` == 1: + - Denote K = steps. We use K steps of DPM-Solver-1 (i.e. DDIM). + - If `order` == 2: + - Denote K = (steps // 2) + (steps % 2). We take K intermediate time steps for sampling. + - If steps % 2 == 0, we use K steps of singlestep DPM-Solver-2. + - If steps % 2 == 1, we use (K - 1) steps of singlestep DPM-Solver-2 and 1 step of DPM-Solver-1. + - If `order` == 3: + - Denote K = (steps // 3 + 1). We take K intermediate time steps for sampling. + - If steps % 3 == 0, we use (K - 2) steps of singlestep DPM-Solver-3, and 1 step of singlestep DPM-Solver-2 and 1 step of DPM-Solver-1. + - If steps % 3 == 1, we use (K - 1) steps of singlestep DPM-Solver-3 and 1 step of DPM-Solver-1. + - If steps % 3 == 2, we use (K - 1) steps of singlestep DPM-Solver-3 and 1 step of singlestep DPM-Solver-2. + - 'multistep': + Multistep DPM-Solver with the order of `order`. The total number of function evaluations (NFE) == `steps`. + We initialize the first `order` values by lower order multistep solvers. + Given a fixed NFE == `steps`, the sampling procedure is: + Denote K = steps. + - If `order` == 1: + - We use K steps of DPM-Solver-1 (i.e. DDIM). + - If `order` == 2: + - We firstly use 1 step of DPM-Solver-1, then use (K - 1) step of multistep DPM-Solver-2. + - If `order` == 3: + - We firstly use 1 step of DPM-Solver-1, then 1 step of multistep DPM-Solver-2, then (K - 2) step of multistep DPM-Solver-3. + - 'singlestep_fixed': + Fixed order singlestep DPM-Solver (i.e. DPM-Solver-1 or singlestep DPM-Solver-2 or singlestep DPM-Solver-3). + We use singlestep DPM-Solver-`order` for `order`=1 or 2 or 3, with total [`steps` // `order`] * `order` NFE. + - 'adaptive': + Adaptive step size DPM-Solver (i.e. "DPM-Solver-12" and "DPM-Solver-23" in the paper). + We ignore `steps` and use adaptive step size DPM-Solver with a higher order of `order`. + You can adjust the absolute tolerance `atol` and the relative tolerance `rtol` to balance the computatation costs + (NFE) and the sample quality. + - If `order` == 2, we use DPM-Solver-12 which combines DPM-Solver-1 and singlestep DPM-Solver-2. + - If `order` == 3, we use DPM-Solver-23 which combines singlestep DPM-Solver-2 and singlestep DPM-Solver-3. + + ===================================================== + + Some advices for choosing the algorithm: + - For **unconditional sampling** or **guided sampling with small guidance scale** by DPMs: + Use singlestep DPM-Solver or DPM-Solver++ ("DPM-Solver-fast" in the paper) with `order = 3`. + e.g., DPM-Solver: + >>> dpm_solver = DPM_Solver(model_fn, noise_schedule, algorithm_type="dpmsolver") + >>> x_sample = dpm_solver.sample(x, steps=steps, t_start=t_start, t_end=t_end, order=3, + skip_type='time_uniform', method='singlestep') + e.g., DPM-Solver++: + >>> dpm_solver = DPM_Solver(model_fn, noise_schedule, algorithm_type="dpmsolver++") + >>> x_sample = dpm_solver.sample(x, steps=steps, t_start=t_start, t_end=t_end, order=3, + skip_type='time_uniform', method='singlestep') + - For **guided sampling with large guidance scale** by DPMs: + Use multistep DPM-Solver with `algorithm_type="dpmsolver++"` and `order = 2`. + e.g. + >>> dpm_solver = DPM_Solver(model_fn, noise_schedule, algorithm_type="dpmsolver++") + >>> x_sample = dpm_solver.sample(x, steps=steps, t_start=t_start, t_end=t_end, order=2, + skip_type='time_uniform', method='multistep') + + We support three types of `skip_type`: + - 'logSNR': uniform logSNR for the time steps. **Recommended for low-resolutional images** + - 'time_uniform': uniform time for the time steps. **Recommended for high-resolutional images**. + - 'time_quadratic': quadratic time for the time steps. + + ===================================================== + Args: + x: A pytorch tensor. The initial value at time `t_start` + e.g. if `t_start` == T, then `x` is a sample from the standard normal distribution. + steps: A `int`. The total number of function evaluations (NFE). + t_start: A `float`. The starting time of the sampling. + If `T` is None, we use self.noise_schedule.T (default is 1.0). + t_end: A `float`. The ending time of the sampling. + If `t_end` is None, we use 1. / self.noise_schedule.total_N. + e.g. if total_N == 1000, we have `t_end` == 1e-3. + For discrete-time DPMs: + - We recommend `t_end` == 1. / self.noise_schedule.total_N. + For continuous-time DPMs: + - We recommend `t_end` == 1e-3 when `steps` <= 15; and `t_end` == 1e-4 when `steps` > 15. + order: A `int`. The order of DPM-Solver. + skip_type: A `str`. The type for the spacing of the time steps. 'time_uniform' or 'logSNR' or 'time_quadratic'. + method: A `str`. The method for sampling. 'singlestep' or 'multistep' or 'singlestep_fixed' or 'adaptive'. + denoise_to_zero: A `bool`. Whether to denoise to time 0 at the final step. + Default is `False`. If `denoise_to_zero` is `True`, the total NFE is (`steps` + 1). + + This trick is firstly proposed by DDPM (https://arxiv.org/abs/2006.11239) and + score_sde (https://arxiv.org/abs/2011.13456). Such trick can improve the FID + for diffusion models sampling by diffusion SDEs for low-resolutional images + (such as CIFAR-10). However, we observed that such trick does not matter for + high-resolutional images. As it needs an additional NFE, we do not recommend + it for high-resolutional images. + lower_order_final: A `bool`. Whether to use lower order solvers at the final steps. + Only valid for `method=multistep` and `steps < 15`. We empirically find that + this trick is a key to stabilizing the sampling by DPM-Solver with very few steps + (especially for steps <= 10). So we recommend to set it to be `True`. + solver_type: A `str`. The taylor expansion type for the solver. `dpmsolver` or `taylor`. We recommend `dpmsolver`. + atol: A `float`. The absolute tolerance of the adaptive step size solver. Valid when `method` == 'adaptive'. + rtol: A `float`. The relative tolerance of the adaptive step size solver. Valid when `method` == 'adaptive'. + return_intermediate: A `bool`. Whether to save the xt at each step. + When set to `True`, method returns a tuple (x0, intermediates); when set to False, method returns only x0. + Returns: + x_end: A pytorch tensor. The approximated solution at time `t_end`. + + """ + t_0 = 1. / self.noise_schedule.total_N if t_end is None else t_end + t_T = self.noise_schedule.T if t_start is None else t_start + assert t_0 > 0 and t_T > 0, "Time range needs to be greater than 0. For discrete-time DPMs, it needs to be in [1 / N, 1], where N is the length of betas array" + if return_intermediate: + assert method in ['multistep', 'singlestep', 'singlestep_fixed'], "Cannot use adaptive solver when saving intermediate values" + if self.correcting_xt_fn is not None: + assert method in ['multistep', 'singlestep', 'singlestep_fixed'], "Cannot use adaptive solver when correcting_xt_fn is not None" + device = x.device + intermediates = [] + with torch.no_grad(): + if method == 'adaptive': + x = self.dpm_solver_adaptive(x, order=order, t_T=t_T, t_0=t_0, atol=atol, rtol=rtol, solver_type=solver_type) + elif method == 'multistep': + assert steps >= order + timesteps = self.get_time_steps(skip_type=skip_type, t_T=t_T, t_0=t_0, N=steps, device=device) + assert timesteps.shape[0] - 1 == steps + # Init the initial values. + step = 0 + t = timesteps[step] + t_prev_list = [t] + model_prev_list = [self.model_fn(x, t)] + if self.correcting_xt_fn is not None: + x = self.correcting_xt_fn(x, t, step) + if return_intermediate: + intermediates.append(x) + # Init the first `order` values by lower order multistep DPM-Solver. + for step in range(1, order): + t = timesteps[step] + x = self.multistep_dpm_solver_update(x, model_prev_list, t_prev_list, t, step, solver_type=solver_type) + if self.correcting_xt_fn is not None: + x = self.correcting_xt_fn(x, t, step) + if return_intermediate: + intermediates.append(x) + t_prev_list.append(t) + model_prev_list.append(self.model_fn(x, t)) + # Compute the remaining values by `order`-th order multistep DPM-Solver. + for step in range(order, steps + 1): + t = timesteps[step] + # We only use lower order for steps < 10 + if lower_order_final and steps < 10: + step_order = min(order, steps + 1 - step) + else: + step_order = order + x = self.multistep_dpm_solver_update(x, model_prev_list, t_prev_list, t, step_order, solver_type=solver_type) + if self.correcting_xt_fn is not None: + x = self.correcting_xt_fn(x, t, step) + if return_intermediate: + intermediates.append(x) + for i in range(order - 1): + t_prev_list[i] = t_prev_list[i + 1] + model_prev_list[i] = model_prev_list[i + 1] + t_prev_list[-1] = t + # We do not need to evaluate the final model value. + if step < steps: + model_prev_list[-1] = self.model_fn(x, t) + elif method in ['singlestep', 'singlestep_fixed']: + if method == 'singlestep': + timesteps_outer, orders = self.get_orders_and_timesteps_for_singlestep_solver(steps=steps, order=order, skip_type=skip_type, t_T=t_T, t_0=t_0, device=device) + elif method == 'singlestep_fixed': + K = steps // order + orders = [order,] * K + timesteps_outer = self.get_time_steps(skip_type=skip_type, t_T=t_T, t_0=t_0, N=K, device=device) + for step, order in enumerate(orders): + s, t = timesteps_outer[step], timesteps_outer[step + 1] + timesteps_inner = self.get_time_steps(skip_type=skip_type, t_T=s.item(), t_0=t.item(), N=order, device=device) + lambda_inner = self.noise_schedule.marginal_lambda(timesteps_inner) + h = lambda_inner[-1] - lambda_inner[0] + r1 = None if order <= 1 else (lambda_inner[1] - lambda_inner[0]) / h + r2 = None if order <= 2 else (lambda_inner[2] - lambda_inner[0]) / h + x = self.singlestep_dpm_solver_update(x, s, t, order, solver_type=solver_type, r1=r1, r2=r2) + if self.correcting_xt_fn is not None: + x = self.correcting_xt_fn(x, t, step) + if return_intermediate: + intermediates.append(x) + else: + raise ValueError("Got wrong method {}".format(method)) + if denoise_to_zero: + t = torch.ones((1,)).to(device) * t_0 + x = self.denoise_to_zero_fn(x, t) + if self.correcting_xt_fn is not None: + x = self.correcting_xt_fn(x, t, step + 1) + if return_intermediate: + intermediates.append(x) + if return_intermediate: + return x, intermediates + else: + return x + + + +############################################################# +# other utility functions +############################################################# + +def interpolate_fn(x, xp, yp): + """ + A piecewise linear function y = f(x), using xp and yp as keypoints. + We implement f(x) in a differentiable way (i.e. applicable for autograd). + The function f(x) is well-defined for all x-axis. (For x beyond the bounds of xp, we use the outmost points of xp to define the linear function.) + + Args: + x: PyTorch tensor with shape [N, C], where N is the batch size, C is the number of channels (we use C = 1 for DPM-Solver). + xp: PyTorch tensor with shape [C, K], where K is the number of keypoints. + yp: PyTorch tensor with shape [C, K]. + Returns: + The function values f(x), with shape [N, C]. + """ + N, K = x.shape[0], xp.shape[1] + all_x = torch.cat([x.unsqueeze(2), xp.unsqueeze(0).repeat((N, 1, 1))], dim=2) + sorted_all_x, x_indices = torch.sort(all_x, dim=2) + x_idx = torch.argmin(x_indices, dim=2) + cand_start_idx = x_idx - 1 + start_idx = torch.where( + torch.eq(x_idx, 0), + torch.tensor(1, device=x.device), + torch.where( + torch.eq(x_idx, K), torch.tensor(K - 2, device=x.device), cand_start_idx, + ), + ) + end_idx = torch.where(torch.eq(start_idx, cand_start_idx), start_idx + 2, start_idx + 1) + start_x = torch.gather(sorted_all_x, dim=2, index=start_idx.unsqueeze(2)).squeeze(2) + end_x = torch.gather(sorted_all_x, dim=2, index=end_idx.unsqueeze(2)).squeeze(2) + start_idx2 = torch.where( + torch.eq(x_idx, 0), + torch.tensor(0, device=x.device), + torch.where( + torch.eq(x_idx, K), torch.tensor(K - 2, device=x.device), cand_start_idx, + ), + ) + y_positions_expanded = yp.unsqueeze(0).expand(N, -1, -1) + start_y = torch.gather(y_positions_expanded, dim=2, index=start_idx2.unsqueeze(2)).squeeze(2) + end_y = torch.gather(y_positions_expanded, dim=2, index=(start_idx2 + 1).unsqueeze(2)).squeeze(2) + cand = start_y + (x - start_x) * (end_y - start_y) / (end_x - start_x) + return cand + + +def expand_dims(v, dims): + """ + Expand the tensor `v` to the dim `dims`. + + Args: + `v`: a PyTorch tensor with shape [N]. + `dim`: a `int`. + Returns: + a PyTorch tensor with shape [N, 1, 1, ..., 1] and the total dimension is `dims`. + """ + return v[(...,) + (None,)*(dims - 1)] \ No newline at end of file diff --git a/server/voice_changer/DiffusionSVC/inferencer/diffusion_svc_model/diffusion/naive/naive.py b/server/voice_changer/DiffusionSVC/inferencer/diffusion_svc_model/diffusion/naive/naive.py index 571c09a0..0430e417 100644 --- a/server/voice_changer/DiffusionSVC/inferencer/diffusion_svc_model/diffusion/naive/naive.py +++ b/server/voice_changer/DiffusionSVC/inferencer/diffusion_svc_model/diffusion/naive/naive.py @@ -56,6 +56,7 @@ class Unit2MelNaive(nn.Module): residual_dropout=0.1, attention_dropout=0.1) else: + print("[[[[[PCmer]]]]]") self.decoder = PCmer( num_layers=n_layers, num_heads=8, @@ -81,8 +82,9 @@ class Unit2MelNaive(nn.Module): return: dict of B x n_frames x feat ''' - x = self.stack(units.transpose(1,2)).transpose(1,2) - x = x + self.f0_embed((1+ f0 / 700).log()) + self.volume_embed(volume) + x = self.stack(units.transpose(1, 2)).transpose(1, 2) + x = x + self.f0_embed((1 + f0 / 700).log()) + self.volume_embed(volume) + print("-----------------x1>", x) if self.use_speaker_encoder: if spk_mix_dict is not None: assert spk_emb_dict is not None @@ -104,9 +106,13 @@ class Unit2MelNaive(nn.Module): if self.aug_shift_embed is not None and aug_shift is not None: x = x + self.aug_shift_embed(aug_shift / 5) + print("-----------------x2>", x) x = self.decoder(x) + print("-----------------x3>", x) x = self.norm(x) + print("-----------------x4>", x) x = self.dense_out(x) + print("-----------------x5>", x) if not infer: x = F.mse_loss(x, gt_spec) if self.l2reg_loss > 0: diff --git a/server/voice_changer/DiffusionSVC/inferencer/diffusion_svc_model/diffusion/naive/pcmer.py b/server/voice_changer/DiffusionSVC/inferencer/diffusion_svc_model/diffusion/naive/pcmer.py index f0eb32a5..4a4d06eb 100644 --- a/server/voice_changer/DiffusionSVC/inferencer/diffusion_svc_model/diffusion/naive/pcmer.py +++ b/server/voice_changer/DiffusionSVC/inferencer/diffusion_svc_model/diffusion/naive/pcmer.py @@ -94,9 +94,12 @@ class PCmer(nn.Module): def forward(self, phone, mask=None): # apply all layers to the input + print("[[[[[PCmer]]]]1]", phone, mask) for (i, layer) in enumerate(self._layers): phone = layer(phone, mask) + # print("[[[[[PCmer]]]] 2 ]", phone) # provide the final sequence + print("[[[[[PCmer]]]]3]", phone) return phone @@ -136,9 +139,13 @@ class _EncoderLayer(nn.Module): def forward(self, phone, mask=None): # compute attention sub-layer + print("Phone:::::1:", phone) + print("Phone:::::16:", self.norm(phone)) phone = phone + (self.attn(self.norm(phone), mask=mask)) + print("Phone:::::2:", phone) phone = phone + (self.conformer(phone)) + print("Phone:::::3:", phone) return phone diff --git a/server/voice_changer/DiffusionSVC/inferencer/diffusion_svc_model/diffusion/unit2mel.py b/server/voice_changer/DiffusionSVC/inferencer/diffusion_svc_model/diffusion/unit2mel.py index 033ead68..1abdd3e8 100644 --- a/server/voice_changer/DiffusionSVC/inferencer/diffusion_svc_model/diffusion/unit2mel.py +++ b/server/voice_changer/DiffusionSVC/inferencer/diffusion_svc_model/diffusion/unit2mel.py @@ -3,10 +3,10 @@ import yaml import torch import torch.nn as nn import numpy as np -from .diffusion import GaussianDiffusion -from .wavenet import WaveNet -from .vocoder import Vocoder -from .naive.naive import Unit2MelNaive +from voice_changer.DiffusionSVC.inferencer.diffusion_svc_model.diffusion.diffusion import GaussianDiffusion +from voice_changer.DiffusionSVC.inferencer.diffusion_svc_model.diffusion.wavenet import WaveNet +from voice_changer.DiffusionSVC.inferencer.diffusion_svc_model.diffusion.vocoder import Vocoder +from voice_changer.DiffusionSVC.inferencer.diffusion_svc_model.diffusion.naive.naive import Unit2MelNaive class DotDict(dict): diff --git a/server/voice_changer/DiffusionSVC/inferencer/diffusion_svc_model/diffusion/vocoder.py b/server/voice_changer/DiffusionSVC/inferencer/diffusion_svc_model/diffusion/vocoder.py index 6a4c9f88..c3bd0ad6 100644 --- a/server/voice_changer/DiffusionSVC/inferencer/diffusion_svc_model/diffusion/vocoder.py +++ b/server/voice_changer/DiffusionSVC/inferencer/diffusion_svc_model/diffusion/vocoder.py @@ -1,6 +1,6 @@ import torch -from nsf_hifigan.nvSTFT import STFT -from nsf_hifigan.models import load_model, load_config +from voice_changer.DiffusionSVC.inferencer.diffusion_svc_model.nsf_hifigan.nvSTFT import STFT +from voice_changer.DiffusionSVC.inferencer.diffusion_svc_model.nsf_hifigan.models import load_model, load_config from torchaudio.transforms import Resample diff --git a/server/voice_changer/DiffusionSVC/inferencer/diffusion_svc_model/tools/tools.py b/server/voice_changer/DiffusionSVC/inferencer/diffusion_svc_model/tools/tools.py index 89e3e793..df0b0d3b 100644 --- a/server/voice_changer/DiffusionSVC/inferencer/diffusion_svc_model/tools/tools.py +++ b/server/voice_changer/DiffusionSVC/inferencer/diffusion_svc_model/tools/tools.py @@ -1,7 +1,7 @@ import numpy as np import torch import torch.nn.functional as F -import torch.nn as nn + import pyworld as pw import parselmouth import torchcrepe @@ -789,15 +789,6 @@ def median_pool_1d(x, kernel_size): x, _ = torch.sort(x, dim=-1) return x[:, :, (kernel_size - 1) // 2] - -def upsample(signal, factor): - signal = signal.permute(0, 2, 1) - signal = nn.functional.interpolate(torch.cat((signal, signal[:, :, -1:]), 2), size=signal.shape[-1] * factor + 1, - mode='linear', align_corners=True) - signal = signal[:, :, :-1] - return signal.permute(0, 2, 1) - - def cross_fade(a: np.ndarray, b: np.ndarray, idx: int): result = np.zeros(idx + b.shape[0]) fade_len = a.shape[0] - idx diff --git a/server/voice_changer/DiffusionSVC/pipeline/Pipeline.py b/server/voice_changer/DiffusionSVC/pipeline/Pipeline.py index 62eecb54..a4784b38 100644 --- a/server/voice_changer/DiffusionSVC/pipeline/Pipeline.py +++ b/server/voice_changer/DiffusionSVC/pipeline/Pipeline.py @@ -1,4 +1,3 @@ -import numpy as np from typing import Any import math import torch @@ -10,13 +9,14 @@ from Exceptions import ( HalfPrecisionChangingException, NotEnoughDataExtimateF0, ) +from voice_changer.DiffusionSVC.inferencer.Inferencer import Inferencer from voice_changer.RVC.embedder.Embedder import Embedder -from voice_changer.RVC.inferencer.Inferencer import Inferencer from voice_changer.RVC.inferencer.OnnxRVCInferencer import OnnxRVCInferencer from voice_changer.RVC.inferencer.OnnxRVCInferencerNono import OnnxRVCInferencerNono from voice_changer.RVC.pitchExtractor.PitchExtractor import PitchExtractor +from voice_changer.common.VolumeExtractor import VolumeExtractor class Pipeline(object): @@ -37,29 +37,30 @@ class Pipeline(object): embedder: Embedder, inferencer: Inferencer, pitchExtractor: PitchExtractor, - index: Any | None, - # feature: Any | None, + # index: Any | None, targetSR, device, isHalf, ): + model_block_size, model_sampling_rate = inferencer.getConfig() + self.hop_size = model_block_size * 16000 / model_sampling_rate # 16000はオーディオのサンプルレート。この時点で16Kになっている。 + + self.volumeExtractor = VolumeExtractor(self.hop_size, model_block_size, model_sampling_rate, audio_sampling_rate=16000) self.embedder = embedder + self.inferencer = inferencer self.pitchExtractor = pitchExtractor print("GENERATE INFERENCER", self.inferencer) print("GENERATE EMBEDDER", self.embedder) print("GENERATE PITCH EXTRACTOR", self.pitchExtractor) - self.index = index - self.big_npy = index.reconstruct_n(0, index.ntotal) if index is not None else None - # self.feature = feature - self.targetSR = targetSR self.device = device - self.isHalf = isHalf + # self.isHalf = isHalf + self.isHalf = False - self.sr = 16000 - self.window = 160 + # self.sr = 16000 + # self.window = 160 def getPipelineInfo(self): inferencerInfo = self.inferencer.getInferencerInfo() if self.inferencer else {} @@ -70,6 +71,13 @@ class Pipeline(object): def setPitchExtractor(self, pitchExtractor: PitchExtractor): self.pitchExtractor = pitchExtractor + @torch.no_grad() + def extract_volume_and_mask(self, audio, threhold): + volume = self.volumeExtractor.extract(audio) + mask = self.volumeExtractor.get_mask_from_volume(volume, threhold=threhold, device=self.device) + volume = torch.from_numpy(volume).float().to(self.device).unsqueeze(-1).unsqueeze(0) + return volume, mask + def exec( self, sid, @@ -87,56 +95,45 @@ class Pipeline(object): out_size=None, ): # 16000のサンプリングレートで入ってきている。以降この世界は16000で処理。 - - search_index = self.index is not None and self.big_npy is not None and index_rate != 0 - # self.t_pad = self.sr * repeat # 1秒 - # self.t_pad_tgt = self.targetSR * repeat # 1秒 出力時のトリミング(モデルのサンプリングで出力される) audio = audio.unsqueeze(0) - - quality_padding_sec = (repeat * (audio.shape[1] - 1)) / self.sr # padding(reflect)のサイズは元のサイズより小さい必要がある。 - - self.t_pad = round(self.sr * quality_padding_sec) # 前後に音声を追加 - self.t_pad_tgt = round(self.targetSR * quality_padding_sec) # 前後に音声を追加 出力時のトリミング(モデルのサンプリングで出力される) + self.t_pad = 0 audio_pad = F.pad(audio, (self.t_pad, self.t_pad), mode="reflect").squeeze(0) - p_len = audio_pad.shape[0] // self.window sid = torch.tensor(sid, device=self.device).unsqueeze(0).long() - # RVC QualityがOnのときにはsilence_frontをオフに。 - silence_front = silence_front if repeat == 0 else 0 - pitchf = pitchf if repeat == 0 else np.zeros(p_len) - out_size = out_size if repeat == 0 else None + n_frames = int(audio_pad.size(-1) // self.hop_size + 1) + print("--------------------> n_frames:", n_frames) + volume, mask = self.extract_volume_and_mask(audio, threhold=-60.0) + print("--------------------> volume:", volume.shape) # ピッチ検出 try: - if if_f0 == 1: - pitch, pitchf = self.pitchExtractor.extract( - audio_pad, - pitchf, - f0_up_key, - self.sr, - self.window, - silence_front=silence_front, - ) - # pitch = pitch[:p_len] - # pitchf = pitchf[:p_len] - pitch = torch.tensor(pitch, device=self.device).unsqueeze(0).long() - pitchf = torch.tensor(pitchf, device=self.device, dtype=torch.float).unsqueeze(0) - else: - pitch = None - pitchf = None - except IndexError: + pitch, pitchf = self.pitchExtractor.extract( + audio_pad, + pitchf, + f0_up_key, + 16000, # 音声のサンプリングレート(既に16000) + # int(self.hop_size), # 処理のwindowサイズ (44100における512) + int(self.hop_size), # 処理のwindowサイズ (44100における512) + silence_front=silence_front, + ) + print("--------------------> pitch11111111111111111111111111111111:", pitch[1:], pitch.shape) + + pitch = torch.tensor(pitch[-n_frames:], device=self.device).unsqueeze(0).long() # 160window sizeを前提にバッファを作っているので切る。 + pitchf = torch.tensor(pitchf[-n_frames:], device=self.device, dtype=torch.float).unsqueeze(0) # 160window sizeを前提にバッファを作っているので切る。 + except IndexError as e: + print(e) # print(e) raise NotEnoughDataExtimateF0() + print("--------------------> pitch:", pitch, pitch.shape) + # tensor型調整 feats = audio_pad if feats.dim() == 2: # double channels feats = feats.mean(-1) - assert feats.dim() == 1, feats.dim() feats = feats.view(1, -1) # embedding - padding_mask = torch.BoolTensor(feats.shape).to(self.device).fill_(False) with autocast(enabled=self.isHalf): try: feats = self.embedder.extractFeatures(feats, embOutputLayer, useFinalProj) @@ -149,74 +146,46 @@ class Pipeline(object): raise DeviceChangingException() else: raise e - if protect < 0.5 and search_index: - feats0 = feats.clone() - # Index - feature抽出 - # if self.index is not None and self.feature is not None and index_rate != 0: - if search_index: - npy = feats[0].cpu().numpy() - # apply silent front for indexsearch - npyOffset = math.floor(silence_front * 16000) // 360 - npy = npy[npyOffset:] + print("--------------------> feats1:", feats, feats.shape) - if self.isHalf is True: - npy = npy.astype("float32") + # feats = F.interpolate(feats.permute(0, 2, 1), scale_factor=2).permute(0, 2, 1) + feats = F.interpolate(feats.permute(0, 2, 1), size=int(n_frames), mode='nearest').permute(0, 2, 1) - # TODO: kは調整できるようにする - k = 1 - if k == 1: - _, ix = self.index.search(npy, 1) - npy = self.big_npy[ix.squeeze()] - else: - score, ix = self.index.search(npy, k=8) - weight = np.square(1 / score) - weight /= weight.sum(axis=1, keepdims=True) - npy = np.sum(self.big_npy[ix] * np.expand_dims(weight, axis=2), axis=1) + if protect < 0.5: + feats0 = feats.clone() + print("--------------------> feats2:", feats, feats.shape) - # recover silient font - npy = np.concatenate([np.zeros([npyOffset, npy.shape[1]], dtype=np.float32), feature[:npyOffset:2].astype("float32"), npy])[-feats.shape[1]:] - feats = torch.from_numpy(npy).unsqueeze(0).to(self.device) * index_rate + (1 - index_rate) * feats - feats = F.interpolate(feats.permute(0, 2, 1), scale_factor=2).permute(0, 2, 1) - if protect < 0.5 and search_index: - feats0 = F.interpolate(feats0.permute(0, 2, 1), scale_factor=2).permute(0, 2, 1) + # # ピッチサイズ調整 + # p_len = audio_pad.shape[0] // self.window + # feats_len = feats.shape[1] + # if feats.shape[1] < p_len: + # p_len = feats_len + # pitch = pitch[:, :feats_len] + # pitchf = pitchf[:, :feats_len] - # ピッチサイズ調整 - p_len = audio_pad.shape[0] // self.window - if feats.shape[1] < p_len: - p_len = feats.shape[1] - if pitch is not None and pitchf is not None: - pitch = pitch[:, :p_len] - pitchf = pitchf[:, :p_len] + # pitch = pitch[:, -feats_len:] + # pitchf = pitchf[:, -feats_len:] + # p_len = torch.tensor([feats_len], device=self.device).long() - feats_len = feats.shape[1] - if pitch is not None and pitchf is not None: - pitch = pitch[:, -feats_len:] - pitchf = pitchf[:, -feats_len:] - p_len = torch.tensor([feats_len], device=self.device).long() + # print("----------plen::1:", p_len) # pitchの推定が上手くいかない(pitchf=0)場合、検索前の特徴を混ぜる # pitchffの作り方の疑問はあるが、本家通りなので、このまま使うことにする。 # https://github.com/w-okada/voice-changer/pull/276#issuecomment-1571336929 - if protect < 0.5 and search_index: + if protect < 0.5: pitchff = pitchf.clone() pitchff[pitchf > 0] = 1 pitchff[pitchf < 1] = protect pitchff = pitchff.unsqueeze(-1) feats = feats * pitchff + feats0 * (1 - pitchff) feats = feats.to(feats0.dtype) - p_len = torch.tensor([p_len], device=self.device).long() + # p_len = torch.tensor([p_len], device=self.device).long() - # apply silent front for inference - if type(self.inferencer) in [OnnxRVCInferencer, OnnxRVCInferencerNono]: - npyOffset = math.floor(silence_front * 16000) // 360 - feats = feats[:, npyOffset * 2 :, :] # NOQA - - feats_len = feats.shape[1] - if pitch is not None and pitchf is not None: - pitch = pitch[:, -feats_len:] - pitchf = pitchf[:, -feats_len:] - p_len = torch.tensor([feats_len], device=self.device).long() + # # apply silent front for inference + # if type(self.inferencer) in [OnnxRVCInferencer, OnnxRVCInferencerNono]: + # npyOffset = math.floor(silence_front * 16000) // 360 # 160x2 = 360 + # feats = feats[:, npyOffset * 2 :, :] # NOQA # 推論実行 try: @@ -224,7 +193,16 @@ class Pipeline(object): with autocast(enabled=self.isHalf): audio1 = ( torch.clip( - self.inferencer.infer(feats, p_len, pitch, pitchf, sid, out_size)[0][0, 0].to(dtype=torch.float32), + self.inferencer.infer( + feats, + pitch.unsqueeze(-1), + volume, + mask, + sid, + infer_speedup=10, + k_step=20, + silence_front=silence_front + ).to(dtype=torch.float32), -1.0, 1.0, ) @@ -243,16 +221,7 @@ class Pipeline(object): else: pitchf_buffer = None - del p_len, padding_mask, pitch, pitchf, feats + del pitch, pitchf, feats, sid torch.cuda.empty_cache() - # inferで出力されるサンプリングレートはモデルのサンプリングレートになる。 - # pipelineに(入力されるときはhubertように16k) - if self.t_pad_tgt != 0: - offset = self.t_pad_tgt - end = -1 * self.t_pad_tgt - audio1 = audio1[offset:end] - - del sid - torch.cuda.empty_cache() return audio1, pitchf_buffer, feats_buffer diff --git a/server/voice_changer/DiffusionSVC/pipeline/PipelineGenerator.py b/server/voice_changer/DiffusionSVC/pipeline/PipelineGenerator.py index dff294ab..6545b6f1 100644 --- a/server/voice_changer/DiffusionSVC/pipeline/PipelineGenerator.py +++ b/server/voice_changer/DiffusionSVC/pipeline/PipelineGenerator.py @@ -1,51 +1,48 @@ -import os import traceback -import faiss -from data.ModelSlot import DiffusionSVCModelSlot, RVCModelSlot +from data.ModelSlot import DiffusionSVCModelSlot +from voice_changer.DiffusionSVC.inferencer.InferencerManager import InferencerManager +from voice_changer.DiffusionSVC.pipeline.Pipeline import Pipeline from voice_changer.RVC.deviceManager.DeviceManager import DeviceManager from voice_changer.RVC.embedder.EmbedderManager import EmbedderManager -from voice_changer.RVC.inferencer.InferencerManager import InferencerManager -from voice_changer.RVC.pipeline.Pipeline import Pipeline from voice_changer.RVC.pitchExtractor.PitchExtractorManager import PitchExtractorManager def createPipeline(modelSlot: DiffusionSVCModelSlot, gpu: int, f0Detector: str): dev = DeviceManager.get_instance().getDevice(gpu) - half = DeviceManager.get_instance().halfPrecisionAvailable(gpu) + # half = DeviceManager.get_instance().halfPrecisionAvailable(gpu) + half = False - # # Inferencer 生成 - # try: - # inferencer = InferencerManager.getInferencer(modelSlot.modelType, modelSlot.modelFile, gpu) - # except Exception as e: - # print("[Voice Changer] exception! loading inferencer", e) - # traceback.print_exc() + # Inferencer 生成 + try: + inferencer = InferencerManager.getInferencer(modelSlot.modelType, modelSlot.modelFile, gpu) + except Exception as e: + print("[Voice Changer] exception! loading inferencer", e) + traceback.print_exc() - # # Embedder 生成 - # try: - # embedder = EmbedderManager.getEmbedder( - # modelSlot.embedder, - # # emmbedderFilename, - # half, - # dev, - # ) - # except Exception as e: - # print("[Voice Changer] exception! loading embedder", e) - # traceback.print_exc() + # Embedder 生成 + try: + embedder = EmbedderManager.getEmbedder( + modelSlot.embedder, + # emmbedderFilename, + half, + dev, + ) + except Exception as e: + print("[Voice Changer] exception! loading embedder", e) + traceback.print_exc() - # # pitchExtractor - # pitchExtractor = PitchExtractorManager.getPitchExtractor(f0Detector, gpu) + # pitchExtractor + pitchExtractor = PitchExtractorManager.getPitchExtractor(f0Detector, gpu) + pipeline = Pipeline( + embedder, + inferencer, + pitchExtractor, + modelSlot.samplingRate, + dev, + half, + ) - # pipeline = Pipeline( - # embedder, - # inferencer, - # pitchExtractor, - # index, - # modelSlot.samplingRate, - # dev, - # half, - # ) - - # return pipeline + return pipeline diff --git a/server/voice_changer/ModelSlotManager.py b/server/voice_changer/ModelSlotManager.py index f589ce10..94926431 100644 --- a/server/voice_changer/ModelSlotManager.py +++ b/server/voice_changer/ModelSlotManager.py @@ -11,6 +11,7 @@ class ModelSlotManager: def __init__(self, model_dir: str): self.model_dir = model_dir self.modelSlots = loadAllSlotInfo(self.model_dir) + print("MODEL SLOT INFO-------------->>>>>", self.modelSlots) @classmethod def get_instance(cls, model_dir: str): diff --git a/server/voice_changer/common/VolumeExtractor.py b/server/voice_changer/common/VolumeExtractor.py new file mode 100644 index 00000000..8c76413a --- /dev/null +++ b/server/voice_changer/common/VolumeExtractor.py @@ -0,0 +1,41 @@ +import numpy as np +import torch +import torch.nn as nn + + +class VolumeExtractor: + def __init__(self, hop_size: float, block_size: int, model_sampling_rate: int, audio_sampling_rate: int): + self.hop_size = hop_size + self.block_size = block_size + self.model_sampling_rate = model_sampling_rate + self.audio_sampling_rate = audio_sampling_rate + # self.hop_size = self.block_size * self.audio_sampling_rate / self.model_sampling_rate # モデルの処理単位が512(Diffusion-SVC), 入力のサンプリングレートのサイズにhopsizeを合わせる。 + + def extract(self, audio): # audio: 1d numpy array + audio = audio.squeeze().cpu() + print("----VolExtractor2", audio.shape, self.block_size, self.model_sampling_rate, self.audio_sampling_rate, self.hop_size) + n_frames = int(len(audio) // self.hop_size) + 1 + print("=======> n_frames", n_frames) + audio2 = audio ** 2 + print("----VolExtractor3", audio2.shape) + audio2 = np.pad(audio2, (int(self.hop_size // 2), int((self.hop_size + 1) // 2)), mode='reflect') + print("----VolExtractor4", audio2.shape) + volume = np.array( + [np.mean(audio2[int(n * self.hop_size): int((n + 1) * self.hop_size)]) for n in range(n_frames)]) + volume = np.sqrt(volume) + return volume + + def get_mask_from_volume(self, volume, threhold=-60.0, device='cpu') -> torch.Tensor: + mask = (volume > 10 ** (float(threhold) / 20)).astype('float') + mask = np.pad(mask, (4, 4), constant_values=(mask[0], mask[-1])) + mask = np.array([np.max(mask[n: n + 9]) for n in range(len(mask) - 8)]) + mask = torch.from_numpy(mask).float().to(device).unsqueeze(-1).unsqueeze(0) + mask = upsample(mask, self.block_size).squeeze(-1) + return mask + + +def upsample(signal: torch.Tensor, factor: int) -> torch.Tensor: + signal = signal.permute(0, 2, 1) + signal = nn.functional.interpolate(torch.cat((signal, signal[:, :, -1:]), 2), size=signal.shape[-1] * factor + 1, mode='linear', align_corners=True) + signal = signal[:, :, :-1] + return signal.permute(0, 2, 1)