mirror of https://github.com/w-okada/voice-changer.git
synced 2025-01-23 21:45:00 +03:00

WIP: integrate vcs to new gui 4

parent 4d31d2238d
commit 0201b2e44c
@@ -4,7 +4,6 @@ from dataclasses import asdict
 import numpy as np
 import torch
 from data.ModelSlot import DDSPSVCModelSlot
-from voice_changer.DDSP_SVC.ModelSlot import ModelSlot

 from voice_changer.DDSP_SVC.deviceManager.DeviceManager import DeviceManager

@ -18,11 +17,10 @@ if sys.platform.startswith("darwin"):
|
|||||||
else:
|
else:
|
||||||
sys.path.append("DDSP-SVC")
|
sys.path.append("DDSP-SVC")
|
||||||
|
|
||||||
from diffusion.infer_gt_mel import DiffGtMel # type: ignore
|
from .models.diffusion.infer_gt_mel import DiffGtMel
|
||||||
|
|
||||||
from voice_changer.utils.VoiceChangerModel import AudioInOut
|
from voice_changer.utils.VoiceChangerModel import AudioInOut
|
||||||
from voice_changer.utils.VoiceChangerParams import VoiceChangerParams
|
from voice_changer.utils.VoiceChangerParams import VoiceChangerParams
|
||||||
from voice_changer.utils.LoadModelParams import LoadModelParams, LoadModelParams2
|
|
||||||
from voice_changer.DDSP_SVC.DDSP_SVCSetting import DDSP_SVCSettings
|
from voice_changer.DDSP_SVC.DDSP_SVCSetting import DDSP_SVCSettings
|
||||||
from voice_changer.RVC.embedder.EmbedderManager import EmbedderManager
|
from voice_changer.RVC.embedder.EmbedderManager import EmbedderManager
|
||||||
|
|
||||||
@@ -51,68 +49,39 @@ def phase_vocoder(a, b, fade_out, fade_in):

 class DDSP_SVC:
     initialLoad: bool = True
-    settings: DDSP_SVCSettings = DDSP_SVCSettings()
-    diff_model: DiffGtMel = DiffGtMel()
-    svc_model: SvcDDSP = SvcDDSP()

-    deviceManager = DeviceManager.get_instance()
-    # diff_model: DiffGtMel = DiffGtMel()
-    audio_buffer: AudioInOut | None = None
-    prevVol: float = 0
-    # resample_kernel = {}
+    def __init__(self, params: VoiceChangerParams, slotInfo: DDSPSVCModelSlot):
+        print("[Voice Changer] [DDSP-SVC] Creating instance ")
+        self.deviceManager = DeviceManager.get_instance()

-    def __init__(self, params: VoiceChangerParams):
         self.gpu_num = torch.cuda.device_count()
         self.params = params
+        self.settings = DDSP_SVCSettings()
+        self.svc_model: SvcDDSP = SvcDDSP()
+        self.diff_model: DiffGtMel = DiffGtMel()

         self.svc_model.setVCParams(params)
         EmbedderManager.initialize(params)
-        print("[Voice Changer] DDSP-SVC initialization:", params)

-    def loadModel(self, props: LoadModelParams):
-        target_slot_idx = props.slot
-        params = props.params
+        self.audio_buffer: AudioInOut | None = None
+        self.prevVol = 0.0
+        self.slotInfo = slotInfo
+        self.initialize()

-        modelFile = params["files"]["ddspSvcModel"]
-        diffusionFile = params["files"]["ddspSvcDiffusion"]
-        modelSlot = ModelSlot(
-            modelFile=modelFile,
-            diffusionFile=diffusionFile,
-            defaultTrans=params["trans"] if "trans" in params else 0,
-        )
-        self.settings.modelSlots[target_slot_idx] = modelSlot
-
-        # Load only the first time
-        # if self.initialLoad:
-        #     self.prepareModel(target_slot_idx)
-        #     self.settings.modelSlotIndex = target_slot_idx
-        #     self.switchModel()
-        #     self.initialLoad = False
-        # elif target_slot_idx == self.currentSlot:
-        #     self.prepareModel(target_slot_idx)
-        self.settings.modelSlotIndex = target_slot_idx
-        self.reloadModel()
-
-        print("params:", params)
-        return self.get_info()
-
-    def reloadModel(self):
+    def initialize(self):
         self.device = self.deviceManager.getDevice(self.settings.gpu)
-        modelFile = self.settings.modelSlots[self.settings.modelSlotIndex].modelFile
-        diffusionFile = self.settings.modelSlots[self.settings.modelSlotIndex].diffusionFile
-
         self.svc_model = SvcDDSP()
         self.svc_model.setVCParams(self.params)
-        self.svc_model.update_model(modelFile, self.device)
+        self.svc_model.update_model(self.slotInfo.modelFile, self.device)
         self.diff_model = DiffGtMel(device=self.device)
-        self.diff_model.flush_model(diffusionFile, ddsp_config=self.svc_model.args)
+        self.diff_model.flush_model(self.slotInfo.diffModelFile, ddsp_config=self.svc_model.args)

     def update_settings(self, key: str, val: int | float | str):
         if key in self.settings.intData:
             val = int(val)
             setattr(self.settings, key, val)
             if key == "gpu":
-                self.reloadModel()
+                self.initialize()
         elif key in self.settings.floatData:
             setattr(self.settings, key, float(val))
         elif key in self.settings.strData:
@@ -160,10 +129,6 @@ class DDSP_SVC:
         # raise NoModeLoadedException("ONNX")

     def _pyTorch_inference(self, data):
-        # if hasattr(self, "model") is False or self.model is None:
-        #     print("[Voice Changer] No pyTorch session.")
-        #     raise NoModeLoadedException("pytorch")
-
         input_wav = data[0]
         _audio, _model_sr = self.svc_model.infer(
             input_wav,
@@ -192,32 +157,13 @@ class DDSP_SVC:
         return _audio.cpu().numpy() * 32768.0

     def inference(self, data):
-        if self.settings.framework == "ONNX":
+        if self.slotInfo.isONNX:
             audio = self._onnx_inference(data)
         else:
             audio = self._pyTorch_inference(data)
         return audio

-    @classmethod
-    def loadModel2(cls, props: LoadModelParams2):
-        slotInfo: DDSPSVCModelSlot = DDSPSVCModelSlot()
-        for file in props.files:
-            if file.kind == "ddspSvcModelConfig":
-                slotInfo.configFile = file.name
-            elif file.kind == "ddspSvcModel":
-                slotInfo.modelFile = file.name
-            elif file.kind == "ddspSvcDiffusionConfig":
-                slotInfo.diffConfigFile = file.name
-            elif file.kind == "ddspSvcDiffusion":
-                slotInfo.diffModelFile = file.name
-        slotInfo.isONNX = slotInfo.modelFile.endswith(".onnx")
-        slotInfo.name = os.path.splitext(os.path.basename(slotInfo.modelFile))[0]
-        return slotInfo
-
     def __del__(self):
-        del self.net_g
-        del self.onnx_session
-
         remove_path = os.path.join("DDSP-SVC")
         sys.path = [x for x in sys.path if x.endswith(remove_path) is False]

server/voice_changer/DDSP_SVC/DDSP_SVCModelSlotGenerator.py (new file, 22 lines)
@@ -0,0 +1,22 @@
+import os
+from data.ModelSlot import DDSPSVCModelSlot
+from voice_changer.utils.LoadModelParams import LoadModelParams
+from voice_changer.utils.ModelSlotGenerator import ModelSlotGenerator
+
+
+class DDSP_SVCModelSlotGenerator(ModelSlotGenerator):
+    @classmethod
+    def loadModel(cls, props: LoadModelParams):
+        slotInfo: DDSPSVCModelSlot = DDSPSVCModelSlot()
+        for file in props.files:
+            if file.kind == "ddspSvcModelConfig":
+                slotInfo.configFile = file.name
+            elif file.kind == "ddspSvcModel":
+                slotInfo.modelFile = file.name
+            elif file.kind == "ddspSvcDiffusionConfig":
+                slotInfo.diffConfigFile = file.name
+            elif file.kind == "ddspSvcDiffusion":
+                slotInfo.diffModelFile = file.name
+        slotInfo.isONNX = slotInfo.modelFile.endswith(".onnx")
+        slotInfo.name = os.path.splitext(os.path.basename(slotInfo.modelFile))[0]
+        return slotInfo
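
Note: with this split, slot metadata is produced once by the generator and then handed to the voice changer, which loads models in initialize(). A minimal sketch of the new wiring, assuming a LoadModelParams value `props` and a VoiceChangerParams value `params` supplied by the caller (both hypothetical here, not part of the diff):

    from voice_changer.DDSP_SVC.DDSP_SVCModelSlotGenerator import DDSP_SVCModelSlotGenerator

    slotInfo = DDSP_SVCModelSlotGenerator.loadModel(props)  # builds a DDSPSVCModelSlot from the uploaded files
    vc = DDSP_SVC(params, slotInfo)                         # __init__ stores slotInfo and calls initialize()
    audio = vc.inference(data)                              # dispatches on slotInfo.isONNX
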
@@ -1,7 +1,5 @@
 from dataclasses import dataclass, field

-from voice_changer.DDSP_SVC.ModelSlot import ModelSlot
-

 @dataclass
 class DDSP_SVCSettings:
@@ -23,14 +21,7 @@ class DDSP_SVCSettings:
     kStep: int = 120
     threshold: int = -45

-    framework: str = "PyTorch"  # PyTorch or ONNX
-    pyTorchModelFile: str = ""
-    onnxModelFile: str = ""
-    configFile: str = ""
-
     speakers: dict[str, int] = field(default_factory=lambda: {})
-    modelSlotIndex: int = -1
-    modelSlots: list[ModelSlot] = field(default_factory=lambda: [ModelSlot()])
     # ↓ list only the mutable items
     intData = [
         "gpu",
@@ -46,4 +37,4 @@ class DDSP_SVCSettings:
         "kStep",
     ]
     floatData = ["silentThreshold"]
-    strData = ["framework", "f0Detector", "diffMethod"]
+    strData = ["f0Detector", "diffMethod"]
@@ -1,8 +0,0 @@
-from dataclasses import dataclass
-
-
-@dataclass
-class ModelSlot:
-    modelFile: str = ""
-    diffusionFile: str = ""
-    defaultTrans: int = 0
@@ -3,13 +3,13 @@
 import torch

 try:
-    from ddsp.vocoder import load_model, F0_Extractor, Volume_Extractor, Units_Encoder  # type: ignore
+    from .models.ddsp.vocoder import load_model, F0_Extractor, Volume_Extractor, Units_Encoder
 except Exception as e:
     print(e)
-    from ddsp.vocoder import load_model, F0_Extractor, Volume_Extractor, Units_Encoder  # type: ignore
+    from .models.ddsp.vocoder import load_model, F0_Extractor, Volume_Extractor, Units_Encoder

-from ddsp.core import upsample  # type: ignore
-from enhancer import Enhancer  # type: ignore
+from .models.ddsp.core import upsample
+from .models.enhancer import Enhancer
 from voice_changer.utils.VoiceChangerParams import VoiceChangerParams
 import numpy as np

@@ -38,11 +38,7 @@ class SvcDDSP:
         print("ARGS:", self.args)

         # load units encoder
-        if (
-            self.units_encoder is None
-            or self.args.data.encoder != self.encoder_type
-            or self.args.data.encoder_ckpt != self.encoder_ckpt
-        ):
+        if self.units_encoder is None or self.args.data.encoder != self.encoder_type or self.args.data.encoder_ckpt != self.encoder_ckpt:
             if self.args.data.encoder == "cnhubertsoftfish":
                 cnhubertsoft_gate = self.args.data.cnhubertsoft_gate
             else:
@@ -77,15 +73,9 @@ class SvcDDSP:
             self.encoder_ckpt = encoderPath

         # load enhancer
-        if (
-            self.enhancer is None
-            or self.args.enhancer.type != self.enhancer_type
-            or self.args.enhancer.ckpt != self.enhancer_ckpt
-        ):
+        if self.enhancer is None or self.args.enhancer.type != self.enhancer_type or self.args.enhancer.ckpt != self.enhancer_ckpt:
             enhancerPath = self.params.nsf_hifigan
-            self.enhancer = Enhancer(
-                self.args.enhancer.type, enhancerPath, device=self.device
-            )
+            self.enhancer = Enhancer(self.args.enhancer.type, enhancerPath, device=self.device)
             self.enhancer_type = self.args.enhancer.type
             self.enhancer_ckpt = enhancerPath

@@ -118,9 +108,7 @@ class SvcDDSP:
         # print("audio", audio)
         # load input
         # audio, sample_rate = librosa.load(input_wav, sr=None, mono=True)
-        hop_size = (
-            self.args.data.block_size * sample_rate / self.args.data.sampling_rate
-        )
+        hop_size = self.args.data.block_size * sample_rate / self.args.data.sampling_rate
         if audio_alignment:
             audio_length = len(audio)
         # safe front silence
@@ -131,12 +119,9 @@ class SvcDDSP:
         audio_t = torch.from_numpy(audio).float().unsqueeze(0).to(self.device)

         # extract f0
-        pitch_extractor = F0_Extractor(
-            pitch_extractor_type, sample_rate, hop_size, float(f0_min), float(f0_max)
-        )
-        f0 = pitch_extractor.extract(
-            audio, uv_interp=True, device=self.device, silence_front=silence_front
-        )
+        print("pitch_extractor_type", pitch_extractor_type)
+        pitch_extractor = F0_Extractor(pitch_extractor_type, sample_rate, hop_size, float(f0_min), float(f0_max))
+        f0 = pitch_extractor.extract(audio, uv_interp=True, device=self.device, silence_front=silence_front)
         f0 = torch.from_numpy(f0).float().to(self.device).unsqueeze(-1).unsqueeze(0)
         f0 = f0 * 2 ** (float(pitch_adjust) / 12)

@@ -148,9 +133,7 @@ class SvcDDSP:
         mask = np.array([np.max(mask[n : n + 9]) for n in range(len(mask) - 8)])  # type: ignore
         mask = torch.from_numpy(mask).float().to(self.device).unsqueeze(-1).unsqueeze(0)
         mask = upsample(mask, self.args.data.block_size).squeeze(-1)
-        volume = (
-            torch.from_numpy(volume).float().to(self.device).unsqueeze(-1).unsqueeze(0)
-        )
+        volume = torch.from_numpy(volume).float().to(self.device).unsqueeze(-1).unsqueeze(0)

         # extract units
         units = self.units_encoder.encode(audio_t, sample_rate, hop_size)
@@ -165,9 +148,7 @@ class SvcDDSP:

         # forward and return the output
         with torch.no_grad():
-            output, _, (s_h, s_n) = self.model(
-                units, f0, volume, spk_id=spk_id, spk_mix_dict=dictionary
-            )
+            output, _, (s_h, s_n) = self.model(units, f0, volume, spk_id=spk_id, spk_mix_dict=dictionary)

             if diff_use and diff_model is not None:
                 output = diff_model.infer(
server/voice_changer/DDSP_SVC/models/ddsp/core.py (new file, 281 lines)
@@ -0,0 +1,281 @@
+import torch
+import torch.nn as nn
+from torch.nn import functional as F
+
+import math
+import numpy as np
+
+def MaskedAvgPool1d(x, kernel_size):
+    x = x.unsqueeze(1)
+    x = F.pad(x, ((kernel_size - 1) // 2, kernel_size // 2), mode="reflect")
+    mask = ~torch.isnan(x)
+    masked_x = torch.where(mask, x, torch.zeros_like(x))
+    ones_kernel = torch.ones(x.size(1), 1, kernel_size, device=x.device)
+
+    # Perform sum pooling
+    sum_pooled = F.conv1d(
+        masked_x,
+        ones_kernel,
+        stride=1,
+        padding=0,
+        groups=x.size(1),
+    )
+
+    # Count the non-masked (valid) elements in each pooling window
+    valid_count = F.conv1d(
+        mask.float(),
+        ones_kernel,
+        stride=1,
+        padding=0,
+        groups=x.size(1),
+    )
+    valid_count = valid_count.clamp(min=1)  # Avoid division by zero
+
+    # Perform masked average pooling
+    avg_pooled = sum_pooled / valid_count
+
+    return avg_pooled.squeeze(1)
+
+def MedianPool1d(x, kernel_size):
+    x = x.unsqueeze(1)
+    x = F.pad(x, ((kernel_size - 1) // 2, kernel_size // 2), mode="reflect")
+    x = x.squeeze(1)
+    x = x.unfold(1, kernel_size, 1)
+    x, _ = torch.sort(x, dim=-1)
+    return x[:, :, (kernel_size - 1) // 2]
+
+def get_fft_size(frame_size: int, ir_size: int, power_of_2: bool = True):
+    """Calculate final size for efficient FFT.
+    Args:
+        frame_size: Size of the audio frame.
+        ir_size: Size of the convolving impulse response.
+        power_of_2: Constrain to be a power of 2. If False, allow other 5-smooth
+            numbers. TPU requires power of 2, while GPU is more flexible.
+    Returns:
+        fft_size: Size for efficient FFT.
+    """
+    convolved_frame_size = ir_size + frame_size - 1
+    if power_of_2:
+        # Next power of 2.
+        fft_size = int(2**np.ceil(np.log2(convolved_frame_size)))
+    else:
+        fft_size = convolved_frame_size
+    return fft_size
+
+
+def upsample(signal, factor):
+    signal = signal.permute(0, 2, 1)
+    signal = nn.functional.interpolate(torch.cat((signal, signal[:, :, -1:]), 2), size=signal.shape[-1] * factor + 1, mode='linear', align_corners=True)
+    signal = signal[:, :, :-1]
+    return signal.permute(0, 2, 1)
+
+
+def remove_above_fmax(amplitudes, pitch, fmax, level_start=1):
+    n_harm = amplitudes.shape[-1]
+    pitches = pitch * torch.arange(level_start, n_harm + level_start).to(pitch)
+    aa = (pitches < fmax).float() + 1e-7
+    return amplitudes * aa
+
+
+def crop_and_compensate_delay(audio, audio_size, ir_size,
+                              padding = 'same',
+                              delay_compensation = -1):
+    """Crop audio output from convolution to compensate for group delay.
+    Args:
+        audio: Audio after convolution. Tensor of shape [batch, time_steps].
+        audio_size: Initial size of the audio before convolution.
+        ir_size: Size of the convolving impulse response.
+        padding: Either 'valid' or 'same'. For 'same' the final output to be the
+            same size as the input audio (audio_timesteps). For 'valid' the audio is
+            extended to include the tail of the impulse response (audio_timesteps +
+            ir_timesteps - 1).
+        delay_compensation: Samples to crop from start of output audio to compensate
+            for group delay of the impulse response. If delay_compensation < 0 it
+            defaults to automatically calculating a constant group delay of the
+            windowed linear phase filter from frequency_impulse_response().
+    Returns:
+        Tensor of cropped and shifted audio.
+    Raises:
+        ValueError: If padding is not either 'valid' or 'same'.
+    """
+    # Crop the output.
+    if padding == 'valid':
+        crop_size = ir_size + audio_size - 1
+    elif padding == 'same':
+        crop_size = audio_size
+    else:
+        raise ValueError('Padding must be \'valid\' or \'same\', instead '
+                         'of {}.'.format(padding))
+
+    # Compensate for the group delay of the filter by trimming the front.
+    # For an impulse response produced by frequency_impulse_response(),
+    # the group delay is constant because the filter is linear phase.
+    total_size = int(audio.shape[-1])
+    crop = total_size - crop_size
+    start = (ir_size // 2 if delay_compensation < 0 else delay_compensation)
+    end = crop - start
+    return audio[:, start:-end]
+
+
+def fft_convolve(audio,
+                 impulse_response):  # B, n_frames, 2*(n_mags-1)
+    """Filter audio with frames of time-varying impulse responses.
+    Time-varying filter. Given audio [batch, n_samples], and a series of impulse
+    responses [batch, n_frames, n_impulse_response], splits the audio into frames,
+    applies filters, and then overlap-and-adds audio back together.
+    Applies non-windowed non-overlapping STFT/ISTFT to efficiently compute
+    convolution for large impulse response sizes.
+    Args:
+        audio: Input audio. Tensor of shape [batch, audio_timesteps].
+        impulse_response: Finite impulse response to convolve. Can either be a 2-D
+            Tensor of shape [batch, ir_size], or a 3-D Tensor of shape [batch,
+            ir_frames, ir_size]. A 2-D tensor will apply a single linear
+            time-invariant filter to the audio. A 3-D Tensor will apply a linear
+            time-varying filter. Automatically chops the audio into equally shaped
+            blocks to match ir_frames.
+    Returns:
+        audio_out: Convolved audio. Tensor of shape
+            [batch, audio_timesteps].
+    """
+    # Add a frame dimension to impulse response if it doesn't have one.
+    ir_shape = impulse_response.size()
+    if len(ir_shape) == 2:
+        impulse_response = impulse_response.unsqueeze(1)
+        ir_shape = impulse_response.size()
+
+    # Get shapes of audio and impulse response.
+    batch_size_ir, n_ir_frames, ir_size = ir_shape
+    batch_size, audio_size = audio.size()  # B, T
+
+    # Validate that batch sizes match.
+    if batch_size != batch_size_ir:
+        raise ValueError('Batch size of audio ({}) and impulse response ({}) must '
+                         'be the same.'.format(batch_size, batch_size_ir))
+
+    # Cut audio into 50% overlapped frames (center padding).
+    hop_size = int(audio_size / n_ir_frames)
+    frame_size = 2 * hop_size
+    audio_frames = F.pad(audio, (hop_size, hop_size)).unfold(1, frame_size, hop_size)
+
+    # Apply Bartlett (triangular) window
+    window = torch.bartlett_window(frame_size).to(audio_frames)
+    audio_frames = audio_frames * window
+
+    # Pad and FFT the audio and impulse responses.
+    fft_size = get_fft_size(frame_size, ir_size, power_of_2=False)
+    audio_fft = torch.fft.rfft(audio_frames, fft_size)
+    ir_fft = torch.fft.rfft(torch.cat((impulse_response, impulse_response[:, -1:, :]), 1), fft_size)
+
+    # Multiply the FFTs (same as convolution in time).
+    audio_ir_fft = torch.multiply(audio_fft, ir_fft)
+
+    # Take the IFFT to resynthesize audio.
+    audio_frames_out = torch.fft.irfft(audio_ir_fft, fft_size)
+
+    # Overlap Add
+    batch_size, n_audio_frames, frame_size = audio_frames_out.size()  # B, n_frames+1, 2*(hop_size+n_mags-1)-1
+    fold = torch.nn.Fold(output_size=(1, (n_audio_frames - 1) * hop_size + frame_size), kernel_size=(1, frame_size), stride=(1, hop_size))
+    output_signal = fold(audio_frames_out.transpose(1, 2)).squeeze(1).squeeze(1)
+
+    # Crop and shift the output audio.
+    output_signal = crop_and_compensate_delay(output_signal[:, hop_size:], audio_size, ir_size)
+    return output_signal
+
+
+def apply_window_to_impulse_response(impulse_response,  # B, n_frames, 2*(n_mag-1)
+                                     window_size: int = 0,
+                                     causal: bool = False):
+    """Apply a window to an impulse response and put in causal form.
+    Args:
+        impulse_response: A series of impulse responses frames to window, of shape
+            [batch, n_frames, ir_size]. ---------> ir_size means size of filter_bank ??????
+
+        window_size: Size of the window to apply in the time domain. If window_size
+            is less than 1, it defaults to the impulse_response size.
+        causal: Impulse response input is in causal form (peak in the middle).
+    Returns:
+        impulse_response: Windowed impulse response in causal form, with last
+            dimension cropped to window_size if window_size is greater than 0 and less
+            than ir_size.
+    """
+
+    # If IR is in causal form, put it in zero-phase form.
+    if causal:
+        impulse_response = torch.fftshift(impulse_response, axes=-1)
+
+    # Get a window for better time/frequency resolution than rectangular.
+    # Window defaults to IR size, cannot be bigger.
+    ir_size = int(impulse_response.size(-1))
+    if (window_size <= 0) or (window_size > ir_size):
+        window_size = ir_size
+    window = nn.Parameter(torch.hann_window(window_size), requires_grad = False).to(impulse_response)
+
+    # Zero pad the window and put in in zero-phase form.
+    padding = ir_size - window_size
+    if padding > 0:
+        half_idx = (window_size + 1) // 2
+        window = torch.cat([window[half_idx:],
+                            torch.zeros([padding]),
+                            window[:half_idx]], axis=0)
+    else:
+        window = window.roll(window.size(-1) // 2, -1)
+
+    # Apply the window, to get new IR (both in zero-phase form).
+    window = window.unsqueeze(0)
+    impulse_response = impulse_response * window
+
+    # Put IR in causal form and trim zero padding.
+    if padding > 0:
+        first_half_start = (ir_size - (half_idx - 1)) + 1
+        second_half_end = half_idx + 1
+        impulse_response = torch.cat([impulse_response[..., first_half_start:],
+                                      impulse_response[..., :second_half_end]],
+                                     dim=-1)
+    else:
+        impulse_response = impulse_response.roll(impulse_response.size(-1) // 2, -1)
+
+    return impulse_response
+
+
+def apply_dynamic_window_to_impulse_response(impulse_response,  # B, n_frames, 2*(n_mag-1) or 2*n_mag-1
+                                             half_width_frames):  # B, n_frames, 1
+    ir_size = int(impulse_response.size(-1))  # 2*(n_mag -1) or 2*n_mag-1
+
+    window = torch.arange(-(ir_size // 2), (ir_size + 1) // 2).to(impulse_response) / half_width_frames
+    window[window > 1] = 0
+    window = (1 + torch.cos(np.pi * window)) / 2  # B, n_frames, 2*(n_mag -1) or 2*n_mag-1
+
+    impulse_response = impulse_response.roll(ir_size // 2, -1)
+    impulse_response = impulse_response * window
+
+    return impulse_response
+
+
+def frequency_impulse_response(magnitudes,
+                               hann_window = True,
+                               half_width_frames = None):
+
+    # Get the IR
+    impulse_response = torch.fft.irfft(magnitudes)  # B, n_frames, 2*(n_mags-1)
+
+    # Window and put in causal form.
+    if hann_window:
+        if half_width_frames is None:
+            impulse_response = apply_window_to_impulse_response(impulse_response)
+        else:
+            impulse_response = apply_dynamic_window_to_impulse_response(impulse_response, half_width_frames)
+    else:
+        impulse_response = impulse_response.roll(impulse_response.size(-1) // 2, -1)
+
+    return impulse_response
+
+
+def frequency_filter(audio,
+                     magnitudes,
+                     hann_window=True,
+                     half_width_frames=None):
+
+    impulse_response = frequency_impulse_response(magnitudes, hann_window, half_width_frames)
+
+    return fft_convolve(audio, impulse_response)
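
Note: two quick sanity checks for the helpers above, as a minimal sketch (the sizes are illustrative, not values used by the repository):

    import torch
    from voice_changer.DDSP_SVC.models.ddsp.core import get_fft_size, upsample

    # get_fft_size pads frame_size + ir_size - 1 up to the next power of 2:
    # 1024 + 257 - 1 = 1280, so the efficient FFT size is 2048.
    assert get_fft_size(1024, 257, power_of_2=True) == 2048

    # upsample stretches a frame-rate signal (B, n_frames, C) to sample rate
    # by linear interpolation, as used for the volume/mask envelopes above.
    vol = torch.rand(1, 10, 1)
    assert upsample(vol, 512).shape == (1, 10 * 512, 1)
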
server/voice_changer/DDSP_SVC/models/ddsp/loss.py (new file, 57 lines)
@@ -0,0 +1,57 @@
+import numpy as np
+
+import torch
+import torch.nn as nn
+import torchaudio
+from torch.nn import functional as F
+from .core import upsample
+
+class SSSLoss(nn.Module):
+    """
+    Single-scale Spectral Loss.
+    """
+
+    def __init__(self, n_fft=111, alpha=1.0, overlap=0, eps=1e-7):
+        super().__init__()
+        self.n_fft = n_fft
+        self.alpha = alpha
+        self.eps = eps
+        self.hop_length = int(n_fft * (1 - overlap))  # 25% of the length
+        self.spec = torchaudio.transforms.Spectrogram(n_fft=self.n_fft, hop_length=self.hop_length, power=1, normalized=True, center=False)
+
+    def forward(self, x_true, x_pred):
+        S_true = self.spec(x_true) + self.eps
+        S_pred = self.spec(x_pred) + self.eps
+
+        converge_term = torch.mean(torch.linalg.norm(S_true - S_pred, dim = (1, 2)) / torch.linalg.norm(S_true + S_pred, dim = (1, 2)))
+
+        log_term = F.l1_loss(S_true.log(), S_pred.log())
+
+        loss = converge_term + self.alpha * log_term
+        return loss
+
+
+class RSSLoss(nn.Module):
+    '''
+    Random-scale Spectral Loss.
+    '''
+
+    def __init__(self, fft_min, fft_max, n_scale, alpha=1.0, overlap=0, eps=1e-7, device='cuda'):
+        super().__init__()
+        self.fft_min = fft_min
+        self.fft_max = fft_max
+        self.n_scale = n_scale
+        self.lossdict = {}
+        for n_fft in range(fft_min, fft_max):
+            self.lossdict[n_fft] = SSSLoss(n_fft, alpha, overlap, eps).to(device)
+
+    def forward(self, x_pred, x_true):
+        value = 0.
+        n_ffts = torch.randint(self.fft_min, self.fft_max, (self.n_scale,))
+        for n_fft in n_ffts:
+            loss_func = self.lossdict[int(n_fft)]
+            value += loss_func(x_true, x_pred)
+        return value / self.n_scale
+
+
+
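
Note: a minimal sketch of driving RSSLoss in a training step (the CPU device, batch size, and waveform length are assumptions for illustration):

    import torch
    from voice_changer.DDSP_SVC.models.ddsp.loss import RSSLoss

    loss_fn = RSSLoss(fft_min=256, fft_max=512, n_scale=4, device="cpu")
    x_pred = torch.randn(2, 16000)  # predicted waveforms
    x_true = torch.randn(2, 16000)  # reference waveforms
    loss = loss_fn(x_pred, x_true)  # mean of 4 randomly drawn single-scale losses
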
server/voice_changer/DDSP_SVC/models/ddsp/pcmer.py (new file, 380 lines)
@@ -0,0 +1,380 @@
+import torch
+
+from torch import nn
+import math
+from functools import partial
+from einops import rearrange, repeat
+
+from local_attention import LocalAttention
+import torch.nn.functional as F
+#import fast_transformers.causal_product.causal_product_cuda
+
+def softmax_kernel(data, *, projection_matrix, is_query, normalize_data=True, eps=1e-4, device = None):
+    b, h, *_ = data.shape
+    # (batch size, head, length, model_dim)
+
+    # normalize model dim
+    data_normalizer = (data.shape[-1] ** -0.25) if normalize_data else 1.
+
+    # what is ration?, projection_matrix.shape[0] --> 266
+
+    ratio = (projection_matrix.shape[0] ** -0.5)
+
+    projection = repeat(projection_matrix, 'j d -> b h j d', b = b, h = h)
+    projection = projection.type_as(data)
+
+    #data_dash = w^T x
+    data_dash = torch.einsum('...id,...jd->...ij', (data_normalizer * data), projection)
+
+
+    # diag_data = D**2
+    diag_data = data ** 2
+    diag_data = torch.sum(diag_data, dim=-1)
+    diag_data = (diag_data / 2.0) * (data_normalizer ** 2)
+    diag_data = diag_data.unsqueeze(dim=-1)
+
+    #print ()
+    if is_query:
+        data_dash = ratio * (
+            torch.exp(data_dash - diag_data -
+                      torch.max(data_dash, dim=-1, keepdim=True).values) + eps)
+    else:
+        data_dash = ratio * (
+            torch.exp(data_dash - diag_data + eps))#- torch.max(data_dash)) + eps)
+
+    return data_dash.type_as(data)
+
+def orthogonal_matrix_chunk(cols, qr_uniform_q = False, device = None):
+    unstructured_block = torch.randn((cols, cols), device = device)
+    q, r = torch.linalg.qr(unstructured_block.cpu(), mode='reduced')
+    q, r = map(lambda t: t.to(device), (q, r))
+
+    # proposed by @Parskatt
+    # to make sure Q is uniform https://arxiv.org/pdf/math-ph/0609050.pdf
+    if qr_uniform_q:
+        d = torch.diag(r, 0)
+        q *= d.sign()
+    return q.t()
+def exists(val):
+    return val is not None
+
+def empty(tensor):
+    return tensor.numel() == 0
+
+def default(val, d):
+    return val if exists(val) else d
+
+def cast_tuple(val):
+    return (val,) if not isinstance(val, tuple) else val
+
+class PCmer(nn.Module):
+    """The encoder that is used in the Transformer model."""
+
+    def __init__(self,
+                 num_layers,
+                 num_heads,
+                 dim_model,
+                 dim_keys,
+                 dim_values,
+                 residual_dropout,
+                 attention_dropout):
+        super().__init__()
+        self.num_layers = num_layers
+        self.num_heads = num_heads
+        self.dim_model = dim_model
+        self.dim_values = dim_values
+        self.dim_keys = dim_keys
+        self.residual_dropout = residual_dropout
+        self.attention_dropout = attention_dropout
+
+        self._layers = nn.ModuleList([_EncoderLayer(self) for _ in range(num_layers)])
+
+    # METHODS ########################################################################################################
+
+    def forward(self, phone, mask=None):
+
+        # apply all layers to the input
+        for (i, layer) in enumerate(self._layers):
+            phone = layer(phone, mask)
+        # provide the final sequence
+        return phone
+
+
+# ==================================================================================================================== #
+#  CLASS  _ E N C O D E R  L A Y E R                                                                                   #
+# ==================================================================================================================== #
+
+
+class _EncoderLayer(nn.Module):
+    """One layer of the encoder.
+
+    Attributes:
+        attn: (:class:`mha.MultiHeadAttention`): The attention mechanism that is used to read the input sequence.
+        feed_forward (:class:`ffl.FeedForwardLayer`): The feed-forward layer on top of the attention mechanism.
+    """
+
+    def __init__(self, parent: PCmer):
+        """Creates a new instance of ``_EncoderLayer``.
+
+        Args:
+            parent (Encoder): The encoder that the layers is created for.
+        """
+        super().__init__()
+
+
+        self.conformer = ConformerConvModule(parent.dim_model)
+        self.norm = nn.LayerNorm(parent.dim_model)
+        self.dropout = nn.Dropout(parent.residual_dropout)
+
+        # selfatt -> fastatt: performer!
+        self.attn = SelfAttention(dim = parent.dim_model,
+                                  heads = parent.num_heads,
+                                  causal = False)
+
+    # METHODS ########################################################################################################
+
+    def forward(self, phone, mask=None):
+
+        # compute attention sub-layer
+        phone = phone + (self.attn(self.norm(phone), mask=mask))
+
+        phone = phone + (self.conformer(phone))
+
+        return phone
+
+def calc_same_padding(kernel_size):
+    pad = kernel_size // 2
+    return (pad, pad - (kernel_size + 1) % 2)
+
+# helper classes
+
+class Swish(nn.Module):
+    def forward(self, x):
+        return x * x.sigmoid()
+
+class Transpose(nn.Module):
+    def __init__(self, dims):
+        super().__init__()
+        assert len(dims) == 2, 'dims must be a tuple of two dimensions'
+        self.dims = dims
+
+    def forward(self, x):
+        return x.transpose(*self.dims)
+
+class GLU(nn.Module):
+    def __init__(self, dim):
+        super().__init__()
+        self.dim = dim
+
+    def forward(self, x):
+        out, gate = x.chunk(2, dim=self.dim)
+        return out * gate.sigmoid()
+
+class DepthWiseConv1d(nn.Module):
+    def __init__(self, chan_in, chan_out, kernel_size, padding):
+        super().__init__()
+        self.padding = padding
+        self.conv = nn.Conv1d(chan_in, chan_out, kernel_size, groups = chan_in)
+
+    def forward(self, x):
+        x = F.pad(x, self.padding)
+        return self.conv(x)
+
+class ConformerConvModule(nn.Module):
+    def __init__(
+        self,
+        dim,
+        causal = False,
+        expansion_factor = 2,
+        kernel_size = 31,
+        dropout = 0.):
+        super().__init__()
+
+        inner_dim = dim * expansion_factor
+        padding = calc_same_padding(kernel_size) if not causal else (kernel_size - 1, 0)
+
+        self.net = nn.Sequential(
+            nn.LayerNorm(dim),
+            Transpose((1, 2)),
+            nn.Conv1d(dim, inner_dim * 2, 1),
+            GLU(dim=1),
+            DepthWiseConv1d(inner_dim, inner_dim, kernel_size = kernel_size, padding = padding),
+            #nn.BatchNorm1d(inner_dim) if not causal else nn.Identity(),
+            Swish(),
+            nn.Conv1d(inner_dim, dim, 1),
+            Transpose((1, 2)),
+            nn.Dropout(dropout)
+        )
+
+    def forward(self, x):
+        return self.net(x)
+
+def linear_attention(q, k, v):
+    if v is None:
+        #print (k.size(), q.size())
+        out = torch.einsum('...ed,...nd->...ne', k, q)
+        return out
+
+    else:
+        k_cumsum = k.sum(dim = -2)
+        #k_cumsum = k.sum(dim = -2)
+        D_inv = 1. / (torch.einsum('...nd,...d->...n', q, k_cumsum.type_as(q)) + 1e-8)
+
+        context = torch.einsum('...nd,...ne->...de', k, v)
+        #print ("TRUEEE: ", context.size(), q.size(), D_inv.size())
+        out = torch.einsum('...de,...nd,...n->...ne', context, q, D_inv)
+        return out
+
+def gaussian_orthogonal_random_matrix(nb_rows, nb_columns, scaling = 0, qr_uniform_q = False, device = None):
+    nb_full_blocks = int(nb_rows / nb_columns)
+    #print (nb_full_blocks)
+    block_list = []
+
+    for _ in range(nb_full_blocks):
+        q = orthogonal_matrix_chunk(nb_columns, qr_uniform_q = qr_uniform_q, device = device)
+        block_list.append(q)
+    # block_list[n] is a orthogonal matrix ... (model_dim * model_dim)
+    #print (block_list[0].size(), torch.einsum('...nd,...nd->...n', block_list[0], torch.roll(block_list[0],1,1)))
+    #print (nb_rows, nb_full_blocks, nb_columns)
+    remaining_rows = nb_rows - nb_full_blocks * nb_columns
+    #print (remaining_rows)
+    if remaining_rows > 0:
+        q = orthogonal_matrix_chunk(nb_columns, qr_uniform_q = qr_uniform_q, device = device)
+        #print (q[:remaining_rows].size())
+        block_list.append(q[:remaining_rows])
+
+    final_matrix = torch.cat(block_list)
+
+    if scaling == 0:
+        multiplier = torch.randn((nb_rows, nb_columns), device = device).norm(dim = 1)
+    elif scaling == 1:
+        multiplier = math.sqrt((float(nb_columns))) * torch.ones((nb_rows,), device = device)
+    else:
+        raise ValueError(f'Invalid scaling {scaling}')
+
+    return torch.diag(multiplier) @ final_matrix
+
+class FastAttention(nn.Module):
+    def __init__(self, dim_heads, nb_features = None, ortho_scaling = 0, causal = False, generalized_attention = False, kernel_fn = nn.ReLU(), qr_uniform_q = False, no_projection = False):
+        super().__init__()
+        nb_features = default(nb_features, int(dim_heads * math.log(dim_heads)))
+
+        self.dim_heads = dim_heads
+        self.nb_features = nb_features
+        self.ortho_scaling = ortho_scaling
+
+        self.create_projection = partial(gaussian_orthogonal_random_matrix, nb_rows = self.nb_features, nb_columns = dim_heads, scaling = ortho_scaling, qr_uniform_q = qr_uniform_q)
+        projection_matrix = self.create_projection()
+        self.register_buffer('projection_matrix', projection_matrix)
+
+        self.generalized_attention = generalized_attention
+        self.kernel_fn = kernel_fn
+
+        # if this is turned on, no projection will be used
+        # queries and keys will be softmax-ed as in the original efficient attention paper
+        self.no_projection = no_projection
+
+        self.causal = causal
+        if causal:
+            try:
+                import fast_transformers.causal_product.causal_product_cuda
+                self.causal_linear_fn = partial(causal_linear_attention)
+            except ImportError:
+                print('unable to import cuda code for auto-regressive Performer. will default to the memory inefficient non-cuda version')
+                self.causal_linear_fn = causal_linear_attention_noncuda
+    @torch.no_grad()
+    def redraw_projection_matrix(self):
+        projections = self.create_projection()
+        self.projection_matrix.copy_(projections)
+        del projections
+
+    def forward(self, q, k, v):
+        device = q.device
+
+        if self.no_projection:
+            q = q.softmax(dim = -1)
+            k = torch.exp(k) if self.causal else k.softmax(dim = -2)
+
+        elif self.generalized_attention:
+            create_kernel = partial(generalized_kernel, kernel_fn = self.kernel_fn, projection_matrix = self.projection_matrix, device = device)
+            q, k = map(create_kernel, (q, k))
+
+        else:
+            create_kernel = partial(softmax_kernel, projection_matrix = self.projection_matrix, device = device)
+
+            q = create_kernel(q, is_query = True)
+            k = create_kernel(k, is_query = False)
+
+        attn_fn = linear_attention if not self.causal else self.causal_linear_fn
+        if v is None:
+            out = attn_fn(q, k, None)
+            return out
+        else:
+            out = attn_fn(q, k, v)
+            return out
+class SelfAttention(nn.Module):
+    def __init__(self, dim, causal = False, heads = 8, dim_head = 64, local_heads = 0, local_window_size = 256, nb_features = None, feature_redraw_interval = 1000, generalized_attention = False, kernel_fn = nn.ReLU(), qr_uniform_q = False, dropout = 0., no_projection = False):
+        super().__init__()
+        assert dim % heads == 0, 'dimension must be divisible by number of heads'
+        dim_head = default(dim_head, dim // heads)
+        inner_dim = dim_head * heads
+        self.fast_attention = FastAttention(dim_head, nb_features, causal = causal, generalized_attention = generalized_attention, kernel_fn = kernel_fn, qr_uniform_q = qr_uniform_q, no_projection = no_projection)
+
+        self.heads = heads
+        self.global_heads = heads - local_heads
+        self.local_attn = LocalAttention(window_size = local_window_size, causal = causal, autopad = True, dropout = dropout, look_forward = int(not causal), rel_pos_emb_config = (dim_head, local_heads)) if local_heads > 0 else None
+
+        #print (heads, nb_features, dim_head)
+        #name_embedding = torch.zeros(110, heads, dim_head, dim_head)
+        #self.name_embedding = nn.Parameter(name_embedding, requires_grad=True)
+
+
+        self.to_q = nn.Linear(dim, inner_dim)
+        self.to_k = nn.Linear(dim, inner_dim)
+        self.to_v = nn.Linear(dim, inner_dim)
+        self.to_out = nn.Linear(inner_dim, dim)
+        self.dropout = nn.Dropout(dropout)
+
+    @torch.no_grad()
+    def redraw_projection_matrix(self):
+        self.fast_attention.redraw_projection_matrix()
+        #torch.nn.init.zeros_(self.name_embedding)
+        #print (torch.sum(self.name_embedding))
+    def forward(self, x, context = None, mask = None, context_mask = None, name=None, inference=False, **kwargs):
+        b, n, _, h, gh = *x.shape, self.heads, self.global_heads
+
+        cross_attend = exists(context)
+
+        context = default(context, x)
+        context_mask = default(context_mask, mask) if not cross_attend else context_mask
+        #print (torch.sum(self.name_embedding))
+        q, k, v = self.to_q(x), self.to_k(context), self.to_v(context)
+
+        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), (q, k, v))
+        (q, lq), (k, lk), (v, lv) = map(lambda t: (t[:, :gh], t[:, gh:]), (q, k, v))
+
+        attn_outs = []
+        #print (name)
+        #print (self.name_embedding[name].size())
+        if not empty(q):
+            if exists(context_mask):
+                global_mask = context_mask[:, None, :, None]
+                v.masked_fill_(~global_mask, 0.)
+            if cross_attend:
+                pass
+                #print (torch.sum(self.name_embedding))
+                #out = self.fast_attention(q,self.name_embedding[name],None)
+                #print (torch.sum(self.name_embedding[...,-1:]))
+            else:
+                out = self.fast_attention(q, k, v)
+            attn_outs.append(out)
+
+        if not empty(lq):
+            assert not cross_attend, 'local attention is not compatible with cross attention'
+            out = self.local_attn(lq, lk, lv, input_mask = mask)
+            attn_outs.append(out)
+
+        out = torch.cat(attn_outs, dim = 1)
+        out = rearrange(out, 'b h n d -> b n (h d)')
+        out = self.to_out(out)
+        return self.dropout(out)
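
Note: Unit2Control below instantiates PCmer with exactly the settings shown here; as a standalone shape check, a minimal sketch (the sequence length 100 is an arbitrary choice):

    import torch
    from voice_changer.DDSP_SVC.models.ddsp.pcmer import PCmer

    encoder = PCmer(num_layers=3, num_heads=8, dim_model=256, dim_keys=256,
                    dim_values=256, residual_dropout=0.1, attention_dropout=0.1)
    x = torch.randn(1, 100, 256)  # B x n_frames x dim_model
    y = encoder(x)                # shape preserved: (1, 100, 256)
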
server/voice_changer/DDSP_SVC/models/ddsp/unit2control.py (new file, 86 lines)
@@ -0,0 +1,86 @@
+import gin
+
+import numpy as np
+import torch
+import torch.nn as nn
+from torch.nn.utils import weight_norm
+
+from .pcmer import PCmer
+
+
+def split_to_dict(tensor, tensor_splits):
+    """Split a tensor into a dictionary of multiple tensors."""
+    labels = []
+    sizes = []
+
+    for k, v in tensor_splits.items():
+        labels.append(k)
+        sizes.append(v)
+
+    tensors = torch.split(tensor, sizes, dim=-1)
+    return dict(zip(labels, tensors))
+
+
+class Unit2Control(nn.Module):
+    def __init__(
+            self,
+            input_channel,
+            n_spk,
+            output_splits):
+        super().__init__()
+        self.output_splits = output_splits
+        self.f0_embed = nn.Linear(1, 256)
+        self.phase_embed = nn.Linear(1, 256)
+        self.volume_embed = nn.Linear(1, 256)
+        self.n_spk = n_spk
+        if n_spk is not None and n_spk > 1:
+            self.spk_embed = nn.Embedding(n_spk, 256)
+
+        # conv in stack
+        self.stack = nn.Sequential(
+            nn.Conv1d(input_channel, 256, 3, 1, 1),
+            nn.GroupNorm(4, 256),
+            nn.LeakyReLU(),
+            nn.Conv1d(256, 256, 3, 1, 1))
+
+        # transformer
+        self.decoder = PCmer(
+            num_layers=3,
+            num_heads=8,
+            dim_model=256,
+            dim_keys=256,
+            dim_values=256,
+            residual_dropout=0.1,
+            attention_dropout=0.1)
+        self.norm = nn.LayerNorm(256)
+
+        # out
+        self.n_out = sum([v for k, v in output_splits.items()])
+        self.dense_out = weight_norm(
+            nn.Linear(256, self.n_out))
+
+    def forward(self, units, f0, phase, volume, spk_id = None, spk_mix_dict = None):
+
+        '''
+        input:
+            B x n_frames x n_unit
+        return:
+            dict of B x n_frames x feat
+        '''
+
+        x = self.stack(units.transpose(1,2)).transpose(1,2)
+        x = x + self.f0_embed((1+ f0 / 700).log()) + self.phase_embed(phase / np.pi) + self.volume_embed(volume)
+        if self.n_spk is not None and self.n_spk > 1:
+            if spk_mix_dict is not None:
+                for k, v in spk_mix_dict.items():
+                    spk_id_torch = torch.LongTensor(np.array([[k]])).to(units.device)
+                    x = x + v * self.spk_embed(spk_id_torch - 1)
+            else:
+                x = x + self.spk_embed(spk_id - 1)
+        x = self.decoder(x)
+        x = self.norm(x)
+        e = self.dense_out(x)
+        controls = split_to_dict(e, self.output_splits)
+
+        return controls
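
Note: a minimal shape sketch for Unit2Control; the output_splits keys and sizes here are placeholders for illustration, not values taken from this repository:

    import torch
    from voice_changer.DDSP_SVC.models.ddsp.unit2control import Unit2Control

    net = Unit2Control(input_channel=256, n_spk=1,
                       output_splits={"harmonic_magnitude": 128, "noise_magnitude": 80})
    units = torch.randn(1, 100, 256)   # B x n_frames x n_unit
    f0 = 200 * torch.ones(1, 100, 1)   # Hz
    phase = torch.zeros(1, 100, 1)
    volume = torch.rand(1, 100, 1)
    controls = net(units, f0, phase, volume)
    # controls["harmonic_magnitude"].shape == (1, 100, 128)
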
639
server/voice_changer/DDSP_SVC/models/ddsp/vocoder.py
Normal file
639
server/voice_changer/DDSP_SVC/models/ddsp/vocoder.py
Normal file
@ -0,0 +1,639 @@
|
|||||||
|
import os
|
||||||
|
import numpy as np
|
||||||
|
import yaml
|
||||||
|
import torch
|
||||||
|
import torch.nn.functional as F
|
||||||
|
import pyworld as pw
|
||||||
|
import parselmouth
|
||||||
|
import torchcrepe
|
||||||
|
from transformers import HubertModel, Wav2Vec2FeatureExtractor
|
||||||
|
from fairseq import checkpoint_utils
|
||||||
|
from torch.nn.modules.utils import consume_prefix_in_state_dict_if_present
|
||||||
|
from torchaudio.transforms import Resample
|
||||||
|
from .unit2control import Unit2Control
|
||||||
|
from .core import frequency_filter, upsample, remove_above_fmax, MaskedAvgPool1d, MedianPool1d
|
||||||
|
from ..encoder.hubert.model import HubertSoft
|
||||||
|
|
||||||
|
CREPE_RESAMPLE_KERNEL = {}
|
||||||
|
|
||||||
|
|
||||||
|
class F0_Extractor:
|
||||||
|
def __init__(self, f0_extractor, sample_rate=44100, hop_size=512, f0_min=65, f0_max=800):
|
||||||
|
self.f0_extractor = f0_extractor
|
||||||
|
self.sample_rate = sample_rate
|
||||||
|
self.hop_size = hop_size
|
||||||
|
self.f0_min = f0_min
|
||||||
|
self.f0_max = f0_max
|
||||||
|
if f0_extractor == "crepe":
|
||||||
|
key_str = str(sample_rate)
|
||||||
|
if key_str not in CREPE_RESAMPLE_KERNEL:
|
||||||
|
CREPE_RESAMPLE_KERNEL[key_str] = Resample(sample_rate, 16000, lowpass_filter_width=128)
|
||||||
|
self.resample_kernel = CREPE_RESAMPLE_KERNEL[key_str]
|
||||||
|
|
||||||
|
def extract(self, audio, uv_interp=False, device=None, silence_front=0): # audio: 1d numpy array
|
||||||
|
# extractor start time
|
||||||
|
n_frames = int(len(audio) // self.hop_size) + 1
|
||||||
|
|
||||||
|
start_frame = int(silence_front * self.sample_rate / self.hop_size)
|
||||||
|
real_silence_front = start_frame * self.hop_size / self.sample_rate
|
||||||
|
audio = audio[int(np.round(real_silence_front * self.sample_rate)) :]
|
||||||
|
|
||||||
|
# extract f0 using parselmouth
|
||||||
|
if self.f0_extractor == "parselmouth":
|
||||||
|
f0 = parselmouth.Sound(audio, self.sample_rate).to_pitch_ac(time_step=self.hop_size / self.sample_rate, voicing_threshold=0.6, pitch_floor=self.f0_min, pitch_ceiling=self.f0_max).selected_array["frequency"]
|
||||||
|
pad_size = start_frame + (int(len(audio) // self.hop_size) - len(f0) + 1) // 2
|
||||||
|
f0 = np.pad(f0, (pad_size, n_frames - len(f0) - pad_size))
|
||||||
|
|
||||||
|
# extract f0 using dio
|
||||||
|
elif self.f0_extractor == "dio":
|
||||||
|
_f0, t = pw.dio(audio.astype("double"), self.sample_rate, f0_floor=self.f0_min, f0_ceil=self.f0_max, channels_in_octave=2, frame_period=(1000 * self.hop_size / self.sample_rate))
|
||||||
|
f0 = pw.stonemask(audio.astype("double"), _f0, t, self.sample_rate)
|
||||||
|
f0 = np.pad(f0.astype("float"), (start_frame, n_frames - len(f0) - start_frame))
|
||||||
|
|
||||||
|
# extract f0 using harvest
|
||||||
|
elif self.f0_extractor == "harvest":
|
||||||
|
f0, _ = pw.harvest(audio.astype("double"), self.sample_rate, f0_floor=self.f0_min, f0_ceil=self.f0_max, frame_period=(1000 * self.hop_size / self.sample_rate))
|
||||||
|
f0 = np.pad(f0.astype("float"), (start_frame, n_frames - len(f0) - start_frame))
|
||||||
|
|
||||||
|
# extract f0 using crepe
|
||||||
|
elif self.f0_extractor == "crepe":
|
||||||
|
if device is None:
|
||||||
|
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||||
|
resample_kernel = self.resample_kernel.to(device)
|
||||||
|
wav16k_torch = resample_kernel(torch.FloatTensor(audio).unsqueeze(0).to(device))
|
||||||
|
|
||||||
|
f0, pd = torchcrepe.predict(wav16k_torch, 16000, 80, self.f0_min, self.f0_max, pad=True, model="full", batch_size=512, device=device, return_periodicity=True)
|
||||||
|
pd = MedianPool1d(pd, 4)
|
||||||
|
f0 = torchcrepe.threshold.At(0.05)(f0, pd)
|
||||||
|
f0 = MaskedAvgPool1d(f0, 4)
|
||||||
|
|
||||||
|
f0 = f0.squeeze(0).cpu().numpy()
|
||||||
|
f0 = np.array([f0[int(min(int(np.round(n * self.hop_size / self.sample_rate / 0.005)), len(f0) - 1))] for n in range(n_frames - start_frame)])
|
||||||
|
f0 = np.pad(f0, (start_frame, 0))
|
||||||
|
|
||||||
|
else:
|
||||||
|
raise ValueError(f" [x] Unknown f0 extractor: {f0_extractor}")
|
||||||
|
|
||||||
|
# interpolate the unvoiced f0
|
||||||
|
if uv_interp:
|
||||||
|
uv = f0 == 0
|
||||||
|
if len(f0[~uv]) > 0:
|
||||||
|
f0[uv] = np.interp(np.where(uv)[0], np.where(~uv)[0], f0[~uv])
|
||||||
|
f0[f0 < self.f0_min] = self.f0_min
|
||||||
|
return f0
|
||||||
|
|
||||||
|
|
||||||
|
class Volume_Extractor:
|
||||||
|
def __init__(self, hop_size=512):
|
||||||
|
self.hop_size = hop_size
|
||||||
|
|
||||||
|
def extract(self, audio): # audio: 1d numpy array
|
||||||
|
n_frames = int(len(audio) // self.hop_size) + 1
|
||||||
|
audio2 = audio**2
|
||||||
|
audio2 = np.pad(audio2, (int(self.hop_size // 2), int((self.hop_size + 1) // 2)), mode="reflect")
|
||||||
|
volume = np.array([np.mean(audio2[int(n * self.hop_size) : int((n + 1) * self.hop_size)]) for n in range(n_frames)])
|
||||||
|
volume = np.sqrt(volume)
|
||||||
|
return volume
|
||||||
|
|
||||||
|
|
||||||
|
class Units_Encoder:
    def __init__(self, encoder, encoder_ckpt, encoder_sample_rate=16000, encoder_hop_size=320, device=None, cnhubertsoft_gate=10):
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = device

        is_loaded_encoder = False
        if encoder == "hubertsoft":
            self.model = Audio2HubertSoft(encoder_ckpt).to(device)
            is_loaded_encoder = True
        if encoder == "hubertbase":
            self.model = Audio2HubertBase(encoder_ckpt, device=device)
            is_loaded_encoder = True
        if encoder == "hubertbase768":
            self.model = Audio2HubertBase768(encoder_ckpt, device=device)
            is_loaded_encoder = True
        if encoder == "hubertbase768l12":
            self.model = Audio2HubertBase768L12(encoder_ckpt, device=device)
            is_loaded_encoder = True
        if encoder == "hubertlarge1024l24":
            self.model = Audio2HubertLarge1024L24(encoder_ckpt, device=device)
            is_loaded_encoder = True
        if encoder == "contentvec":
            self.model = Audio2ContentVec(encoder_ckpt, device=device)
            is_loaded_encoder = True
        if encoder == "contentvec768":
            self.model = Audio2ContentVec768(encoder_ckpt, device=device)
            is_loaded_encoder = True
        if encoder == "contentvec768l12":
            self.model = Audio2ContentVec768L12(encoder_ckpt, device=device)
            is_loaded_encoder = True
        if encoder == "cnhubertsoftfish":
            self.model = CNHubertSoftFish(encoder_ckpt, device=device, gate_size=cnhubertsoft_gate)
            is_loaded_encoder = True
        if not is_loaded_encoder:
            raise ValueError(f" [x] Unknown units encoder: {encoder}")

        self.resample_kernel = {}
        self.encoder_sample_rate = encoder_sample_rate
        self.encoder_hop_size = encoder_hop_size

    def encode(self, audio, sample_rate, hop_size):  # B, T
        # resample
        if sample_rate == self.encoder_sample_rate:
            audio_res = audio
        else:
            key_str = str(sample_rate)
            if key_str not in self.resample_kernel:
                self.resample_kernel[key_str] = Resample(sample_rate, self.encoder_sample_rate, lowpass_filter_width=128).to(self.device)
            audio_res = self.resample_kernel[key_str](audio)

        # encode (pad the resampled audio, not the original, so the length math holds)
        if audio_res.size(-1) < 400:
            audio_res = torch.nn.functional.pad(audio_res, (0, 400 - audio_res.size(-1)))
        units = self.model(audio_res)

        # alignment
        n_frames = audio.size(-1) // hop_size + 1
        ratio = (hop_size / sample_rate) / (self.encoder_hop_size / self.encoder_sample_rate)
        index = torch.clamp(torch.round(ratio * torch.arange(n_frames).to(self.device)).long(), max=units.size(1) - 1)
        units_aligned = torch.gather(units, 1, index.unsqueeze(0).unsqueeze(-1).repeat([1, 1, units.size(-1)]))
        return units_aligned

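A minimal sketch of driving the encoder (the checkpoint path is hypothetical, and the 256-dim output assumes the hubertsoft branch):

import torch

encoder = Units_Encoder("hubertsoft", "pretrain/hubert-soft.pt")  # hypothetical checkpoint path
audio = torch.randn(1, 44100)  # B x T, 1 s at 44.1 kHz
units = encoder.encode(audio, sample_rate=44100, hop_size=512)
# the alignment step maps encoder frames (16 kHz / 320 hop) onto output frames (44.1 kHz / 512 hop)
print(units.shape)  # e.g. torch.Size([1, 87, 256])
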
class Audio2HubertSoft(torch.nn.Module):
    def __init__(self, path, h_sample_rate=16000, h_hop_size=320):
        super().__init__()
        print(" [Encoder Model] HuBERT Soft")
        self.hubert = HubertSoft()
        print(" [Loading] " + path)
        checkpoint = torch.load(path)
        consume_prefix_in_state_dict_if_present(checkpoint, "module.")
        self.hubert.load_state_dict(checkpoint)
        self.hubert.eval()

    def forward(self, audio):  # B, T
        with torch.inference_mode():
            units = self.hubert.units(audio.unsqueeze(1))
            return units

class Audio2ContentVec:
    def __init__(self, path, h_sample_rate=16000, h_hop_size=320, device="cpu"):
        self.device = device
        print(" [Encoder Model] Content Vec")
        print(" [Loading] " + path)
        self.models, self.saved_cfg, self.task = checkpoint_utils.load_model_ensemble_and_task(
            [path],
            suffix="",
        )
        self.hubert = self.models[0]
        self.hubert = self.hubert.to(self.device)
        self.hubert.eval()

    def __call__(self, audio):  # B, T
        # wav_tensor = torch.from_numpy(audio).to(self.device)
        wav_tensor = audio
        feats = wav_tensor.view(1, -1)
        padding_mask = torch.BoolTensor(feats.shape).fill_(False)
        inputs = {
            "source": feats.to(wav_tensor.device),
            "padding_mask": padding_mask.to(wav_tensor.device),
            "output_layer": 9,  # layer 9
        }
        with torch.no_grad():
            logits = self.hubert.extract_features(**inputs)
            feats = self.hubert.final_proj(logits[0])
        units = feats  # .transpose(2, 1)
        return units

class Audio2ContentVec768:
    def __init__(self, path, h_sample_rate=16000, h_hop_size=320, device="cpu"):
        self.device = device
        print(" [Encoder Model] Content Vec")
        print(" [Loading] " + path)
        self.models, self.saved_cfg, self.task = checkpoint_utils.load_model_ensemble_and_task(
            [path],
            suffix="",
        )
        self.hubert = self.models[0]
        self.hubert = self.hubert.to(self.device)
        self.hubert.eval()

    def __call__(self, audio):  # B, T
        # wav_tensor = torch.from_numpy(audio).to(self.device)
        wav_tensor = audio
        feats = wav_tensor.view(1, -1)
        padding_mask = torch.BoolTensor(feats.shape).fill_(False)
        inputs = {
            "source": feats.to(wav_tensor.device),
            "padding_mask": padding_mask.to(wav_tensor.device),
            "output_layer": 9,  # layer 9
        }
        with torch.no_grad():
            logits = self.hubert.extract_features(**inputs)
            feats = logits[0]
        units = feats  # .transpose(2, 1)
        return units

class Audio2ContentVec768L12:
    def __init__(self, path, h_sample_rate=16000, h_hop_size=320, device="cpu"):
        self.device = device
        print(" [Encoder Model] Content Vec")
        print(" [Loading] " + path)
        self.models, self.saved_cfg, self.task = checkpoint_utils.load_model_ensemble_and_task(
            [path],
            suffix="",
        )
        self.hubert = self.models[0]
        self.hubert = self.hubert.to(self.device)
        self.hubert.eval()

    def __call__(self, audio):  # B, T
        # wav_tensor = torch.from_numpy(audio).to(self.device)
        wav_tensor = audio
        feats = wav_tensor.view(1, -1)
        padding_mask = torch.BoolTensor(feats.shape).fill_(False)
        inputs = {
            "source": feats.to(wav_tensor.device),
            "padding_mask": padding_mask.to(wav_tensor.device),
            "output_layer": 12,  # layer 12
        }
        with torch.no_grad():
            logits = self.hubert.extract_features(**inputs)
            feats = logits[0]
        units = feats  # .transpose(2, 1)
        return units

class CNHubertSoftFish(torch.nn.Module):
    def __init__(self, path, h_sample_rate=16000, h_hop_size=320, device="cpu", gate_size=10):
        super().__init__()
        self.device = device
        self.gate_size = gate_size

        self.feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained("./pretrain/TencentGameMate/chinese-hubert-base")
        self.model = HubertModel.from_pretrained("./pretrain/TencentGameMate/chinese-hubert-base")
        self.proj = torch.nn.Sequential(torch.nn.Dropout(0.1), torch.nn.Linear(768, 256))
        # self.label_embedding = nn.Embedding(128, 256)

        state_dict = torch.load(path, map_location=device)
        self.load_state_dict(state_dict)

    @torch.no_grad()
    def forward(self, audio):
        input_values = self.feature_extractor(audio, sampling_rate=16000, return_tensors="pt").input_values
        input_values = input_values.to(self.model.device)

        return self._forward(input_values[0])

    @torch.no_grad()
    def _forward(self, input_values):
        features = self.model(input_values)
        features = self.proj(features.last_hidden_state)

        # Top-k gating
        topk, indices = torch.topk(features, self.gate_size, dim=2)
        features = torch.zeros_like(features).scatter(2, indices, topk)
        features = features / features.sum(2, keepdim=True)

        return features.to(self.device)  # .transpose(1, 2)

class Audio2HubertBase:
    def __init__(self, path, h_sample_rate=16000, h_hop_size=320, device="cpu"):
        self.device = device
        print(" [Encoder Model] HuBERT Base")
        print(" [Loading] " + path)
        self.models, self.saved_cfg, self.task = checkpoint_utils.load_model_ensemble_and_task(
            [path],
            suffix="",
        )
        self.hubert = self.models[0]
        self.hubert = self.hubert.to(self.device)
        self.hubert = self.hubert.float()
        self.hubert.eval()

    def __call__(self, audio):  # B, T
        with torch.no_grad():
            padding_mask = torch.BoolTensor(audio.shape).fill_(False)
            inputs = {
                "source": audio.to(self.device),
                "padding_mask": padding_mask.to(self.device),
                "output_layer": 9,  # layer 9
            }
            logits = self.hubert.extract_features(**inputs)
            units = self.hubert.final_proj(logits[0])
            return units

class Audio2HubertBase768:
    def __init__(self, path, h_sample_rate=16000, h_hop_size=320, device="cpu"):
        self.device = device
        print(" [Encoder Model] HuBERT Base")
        print(" [Loading] " + path)
        self.models, self.saved_cfg, self.task = checkpoint_utils.load_model_ensemble_and_task(
            [path],
            suffix="",
        )
        self.hubert = self.models[0]
        self.hubert = self.hubert.to(self.device)
        self.hubert = self.hubert.float()
        self.hubert.eval()

    def __call__(self, audio):  # B, T
        with torch.no_grad():
            padding_mask = torch.BoolTensor(audio.shape).fill_(False)
            inputs = {
                "source": audio.to(self.device),
                "padding_mask": padding_mask.to(self.device),
                "output_layer": 9,  # layer 9
            }
            logits = self.hubert.extract_features(**inputs)
            units = logits[0]
            return units

class Audio2HubertBase768L12:
    def __init__(self, path, h_sample_rate=16000, h_hop_size=320, device="cpu"):
        self.device = device
        print(" [Encoder Model] HuBERT Base")
        print(" [Loading] " + path)
        self.models, self.saved_cfg, self.task = checkpoint_utils.load_model_ensemble_and_task(
            [path],
            suffix="",
        )
        self.hubert = self.models[0]
        self.hubert = self.hubert.to(self.device)
        self.hubert = self.hubert.float()
        self.hubert.eval()

    def __call__(self, audio):  # B, T
        with torch.no_grad():
            padding_mask = torch.BoolTensor(audio.shape).fill_(False)
            inputs = {
                "source": audio.to(self.device),
                "padding_mask": padding_mask.to(self.device),
                "output_layer": 12,  # layer 12
            }
            logits = self.hubert.extract_features(**inputs)
            units = logits[0]
            return units

class Audio2HubertLarge1024L24:
    def __init__(self, path, h_sample_rate=16000, h_hop_size=320, device="cpu"):
        self.device = device
        print(" [Encoder Model] HuBERT Large")
        print(" [Loading] " + path)
        self.models, self.saved_cfg, self.task = checkpoint_utils.load_model_ensemble_and_task(
            [path],
            suffix="",
        )
        self.hubert = self.models[0]
        self.hubert = self.hubert.to(self.device)
        self.hubert = self.hubert.float()
        self.hubert.eval()

    def __call__(self, audio):  # B, T
        with torch.no_grad():
            padding_mask = torch.BoolTensor(audio.shape).fill_(False)
            inputs = {
                "source": audio.to(self.device),
                "padding_mask": padding_mask.to(self.device),
                "output_layer": 24,  # layer 24
            }
            logits = self.hubert.extract_features(**inputs)
            units = logits[0]
            return units

class DotDict(dict):
    def __getattr__(*args):
        val = dict.get(*args)
        return DotDict(val) if type(val) is dict else val

    __setattr__ = dict.__setitem__
    __delattr__ = dict.__delitem__

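DotDict layers attribute access over a plain dict and re-wraps nested dicts on the fly, which is what lets the YAML config in load_model below be read as args.model.type. A quick illustration:

cfg = DotDict({"data": {"sampling_rate": 44100}, "model": {"type": "CombSubFast"}})
print(cfg.data.sampling_rate)  # 44100 -- nested dicts are wrapped on access
print(cfg.model.type)          # CombSubFast
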
def load_model(model_path, device="cpu"):
    config_file = os.path.join(os.path.split(model_path)[0], "config.yaml")
    with open(config_file, "r") as config:
        args = yaml.safe_load(config)
    args = DotDict(args)

    # load model
    model = None

    if args.model.type == "Sins":
        model = Sins(sampling_rate=args.data.sampling_rate, block_size=args.data.block_size, n_harmonics=args.model.n_harmonics, n_mag_allpass=args.model.n_mag_allpass, n_mag_noise=args.model.n_mag_noise, n_unit=args.data.encoder_out_channels, n_spk=args.model.n_spk)

    elif args.model.type == "CombSub":
        model = CombSub(sampling_rate=args.data.sampling_rate, block_size=args.data.block_size, n_mag_allpass=args.model.n_mag_allpass, n_mag_harmonic=args.model.n_mag_harmonic, n_mag_noise=args.model.n_mag_noise, n_unit=args.data.encoder_out_channels, n_spk=args.model.n_spk)

    elif args.model.type == "CombSubFast":
        model = CombSubFast(sampling_rate=args.data.sampling_rate, block_size=args.data.block_size, n_unit=args.data.encoder_out_channels, n_spk=args.model.n_spk)

    else:
        raise ValueError(f" [x] Unknown Model: {args.model.type}")

    print(" [Loading] " + model_path)
    ckpt = torch.load(model_path, map_location=torch.device(device))
    model.to(device)
    model.load_state_dict(ckpt["model"])
    model.eval()
    return model, args

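A minimal sketch of the expected layout (the checkpoint path below is hypothetical; config.yaml must sit next to the weights file, as the function assumes):

model, args = load_model("exp/combsub-test/model_300000.pt", device="cpu")  # hypothetical path
print(args.model.type, args.data.sampling_rate)
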
class Sins(torch.nn.Module):
    def __init__(self, sampling_rate, block_size, n_harmonics, n_mag_allpass, n_mag_noise, n_unit=256, n_spk=1):
        super().__init__()

        print(" [DDSP Model] Sinusoids Additive Synthesiser")

        # params
        self.register_buffer("sampling_rate", torch.tensor(sampling_rate))
        self.register_buffer("block_size", torch.tensor(block_size))
        # Unit2Control
        split_map = {
            "amplitudes": n_harmonics,
            "group_delay": n_mag_allpass,
            "noise_magnitude": n_mag_noise,
        }
        self.unit2ctrl = Unit2Control(n_unit, n_spk, split_map)

    def forward(self, units_frames, f0_frames, volume_frames, spk_id=None, spk_mix_dict=None, initial_phase=None, infer=True, max_upsample_dim=32):
        """
        units_frames: B x n_frames x n_unit
        f0_frames: B x n_frames x 1
        volume_frames: B x n_frames x 1
        spk_id: B x 1
        """
        # exciter phase
        f0 = upsample(f0_frames, self.block_size)
        if infer:
            x = torch.cumsum(f0.double() / self.sampling_rate, axis=1)
        else:
            x = torch.cumsum(f0 / self.sampling_rate, axis=1)
        if initial_phase is not None:
            x += initial_phase.to(x) / 2 / np.pi
        x = x - torch.round(x)
        x = x.to(f0)

        phase = 2 * np.pi * x
        phase_frames = phase[:, :: self.block_size, :]

        # parameter prediction
        ctrls = self.unit2ctrl(units_frames, f0_frames, phase_frames, volume_frames, spk_id=spk_id, spk_mix_dict=spk_mix_dict)

        amplitudes_frames = torch.exp(ctrls["amplitudes"]) / 128
        group_delay = np.pi * torch.tanh(ctrls["group_delay"])
        noise_param = torch.exp(ctrls["noise_magnitude"]) / 128

        # sinusoids exciter signal
        amplitudes_frames = remove_above_fmax(amplitudes_frames, f0_frames, self.sampling_rate / 2, level_start=1)
        n_harmonic = amplitudes_frames.shape[-1]
        level_harmonic = torch.arange(1, n_harmonic + 1).to(phase)
        sinusoids = 0.0
        for n in range((n_harmonic - 1) // max_upsample_dim + 1):
            start = n * max_upsample_dim
            end = (n + 1) * max_upsample_dim
            phases = phase * level_harmonic[start:end]
            amplitudes = upsample(amplitudes_frames[:, :, start:end], self.block_size)
            sinusoids += (torch.sin(phases) * amplitudes).sum(-1)

        # harmonic part filter (apply group-delay)
        harmonic = frequency_filter(sinusoids, torch.exp(1.0j * torch.cumsum(group_delay, axis=-1)), hann_window=False)

        # noise part filter
        noise = torch.rand_like(harmonic) * 2 - 1
        noise = frequency_filter(noise, torch.complex(noise_param, torch.zeros_like(noise_param)), hann_window=True)

        signal = harmonic + noise

        return signal, phase, (harmonic, noise)  # , (noise_param, noise_param)

class CombSubFast(torch.nn.Module):
    def __init__(self, sampling_rate, block_size, n_unit=256, n_spk=1):
        super().__init__()

        print(" [DDSP Model] Combtooth Subtractive Synthesiser")
        # params
        self.register_buffer("sampling_rate", torch.tensor(sampling_rate))
        self.register_buffer("block_size", torch.tensor(block_size))
        self.register_buffer("window", torch.sqrt(torch.hann_window(2 * block_size)))
        # Unit2Control
        split_map = {"harmonic_magnitude": block_size + 1, "harmonic_phase": block_size + 1, "noise_magnitude": block_size + 1}
        self.unit2ctrl = Unit2Control(n_unit, n_spk, split_map)

    def forward(self, units_frames, f0_frames, volume_frames, spk_id=None, spk_mix_dict=None, initial_phase=None, infer=True, **kwargs):
        """
        units_frames: B x n_frames x n_unit
        f0_frames: B x n_frames x 1
        volume_frames: B x n_frames x 1
        spk_id: B x 1
        """
        # exciter phase
        f0 = upsample(f0_frames, self.block_size)
        if infer:
            x = torch.cumsum(f0.double() / self.sampling_rate, axis=1)
        else:
            x = torch.cumsum(f0 / self.sampling_rate, axis=1)
        if initial_phase is not None:
            x += initial_phase.to(x) / 2 / np.pi
        x = x - torch.round(x)
        x = x.to(f0)

        phase_frames = 2 * np.pi * x[:, :: self.block_size, :]

        # parameter prediction
        ctrls = self.unit2ctrl(units_frames, f0_frames, phase_frames, volume_frames, spk_id=spk_id, spk_mix_dict=spk_mix_dict)

        src_filter = torch.exp(ctrls["harmonic_magnitude"] + 1.0j * np.pi * ctrls["harmonic_phase"])
        src_filter = torch.cat((src_filter, src_filter[:, -1:, :]), 1)
        noise_filter = torch.exp(ctrls["noise_magnitude"]) / 128
        noise_filter = torch.cat((noise_filter, noise_filter[:, -1:, :]), 1)

        # combtooth exciter signal
        combtooth = torch.sinc(self.sampling_rate * x / (f0 + 1e-3))
        combtooth = combtooth.squeeze(-1)
        combtooth_frames = F.pad(combtooth, (self.block_size, self.block_size)).unfold(1, 2 * self.block_size, self.block_size)
        combtooth_frames = combtooth_frames * self.window
        combtooth_fft = torch.fft.rfft(combtooth_frames, 2 * self.block_size)

        # noise exciter signal
        noise = torch.rand_like(combtooth) * 2 - 1
        noise_frames = F.pad(noise, (self.block_size, self.block_size)).unfold(1, 2 * self.block_size, self.block_size)
        noise_frames = noise_frames * self.window
        noise_fft = torch.fft.rfft(noise_frames, 2 * self.block_size)

        # apply the filters
        signal_fft = combtooth_fft * src_filter + noise_fft * noise_filter

        # take the ifft to resynthesize audio.
        signal_frames_out = torch.fft.irfft(signal_fft, 2 * self.block_size) * self.window

        # overlap add
        fold = torch.nn.Fold(output_size=(1, (signal_frames_out.size(1) + 1) * self.block_size), kernel_size=(1, 2 * self.block_size), stride=(1, self.block_size))
        signal = fold(signal_frames_out.transpose(1, 2))[:, 0, 0, self.block_size : -self.block_size]

        return signal, phase_frames, (signal, signal)

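The sqrt-Hann window is applied once before the FFT and once after the IFFT, so each output sample ends up weighted by a full Hann window across the two overlapping frames that cover it, and the overlap-add reconstructs amplitude exactly. A quick check of that identity:

import torch

block = 512
w = torch.sqrt(torch.hann_window(2 * block))
# with 50% overlap, w**2 contributions from consecutive frames sum to one everywhere
print(torch.allclose(w[:block] ** 2 + w[block:] ** 2, torch.ones(block)))  # True
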
class CombSub(torch.nn.Module):
    def __init__(self, sampling_rate, block_size, n_mag_allpass, n_mag_harmonic, n_mag_noise, n_unit=256, n_spk=1):
        super().__init__()

        print(" [DDSP Model] Combtooth Subtractive Synthesiser (Old Version)")
        # params
        self.register_buffer("sampling_rate", torch.tensor(sampling_rate))
        self.register_buffer("block_size", torch.tensor(block_size))
        # Unit2Control
        split_map = {"group_delay": n_mag_allpass, "harmonic_magnitude": n_mag_harmonic, "noise_magnitude": n_mag_noise}
        self.unit2ctrl = Unit2Control(n_unit, n_spk, split_map)

    def forward(self, units_frames, f0_frames, volume_frames, spk_id=None, spk_mix_dict=None, initial_phase=None, infer=True, **kwargs):
        """
        units_frames: B x n_frames x n_unit
        f0_frames: B x n_frames x 1
        volume_frames: B x n_frames x 1
        spk_id: B x 1
        """
        # exciter phase
        f0 = upsample(f0_frames, self.block_size)
        if infer:
            x = torch.cumsum(f0.double() / self.sampling_rate, axis=1)
        else:
            x = torch.cumsum(f0 / self.sampling_rate, axis=1)
        if initial_phase is not None:
            x += initial_phase.to(x) / 2 / np.pi
        x = x - torch.round(x)
        x = x.to(f0)

        phase_frames = 2 * np.pi * x[:, :: self.block_size, :]

        # parameter prediction
        ctrls = self.unit2ctrl(units_frames, f0_frames, phase_frames, volume_frames, spk_id=spk_id, spk_mix_dict=spk_mix_dict)

        group_delay = np.pi * torch.tanh(ctrls["group_delay"])
        src_param = torch.exp(ctrls["harmonic_magnitude"])
        noise_param = torch.exp(ctrls["noise_magnitude"]) / 128

        # combtooth exciter signal
        combtooth = torch.sinc(self.sampling_rate * x / (f0 + 1e-3))
        combtooth = combtooth.squeeze(-1)

        # harmonic part filter (using dynamic-windowed LTV-FIR, with group-delay prediction)
        harmonic = frequency_filter(combtooth, torch.exp(1.0j * torch.cumsum(group_delay, axis=-1)), hann_window=False)
        harmonic = frequency_filter(harmonic, torch.complex(src_param, torch.zeros_like(src_param)), hann_window=True, half_width_frames=1.5 * self.sampling_rate / (f0_frames + 1e-3))

        # noise part filter (using constant-windowed LTV-FIR, without group-delay)
        noise = torch.rand_like(harmonic) * 2 - 1
        noise = frequency_filter(noise, torch.complex(noise_param, torch.zeros_like(noise_param)), hann_window=True)

        signal = harmonic + noise

        return signal, phase_frames, (harmonic, noise)

server/voice_changer/DDSP_SVC/models/diffusion/data_loaders.py (new file, 215 lines)
@ -0,0 +1,215 @@
import os
import random
import re
import numpy as np
import librosa
import torch
from tqdm import tqdm
from torch.utils.data import Dataset


def traverse_dir(root_dir, extensions, amount=None, str_include=None, str_exclude=None, is_pure=False, is_sort=False, is_ext=True):
    file_list = []
    cnt = 0
    for root, _, files in os.walk(root_dir):
        for file in files:
            if any([file.endswith(f".{ext}") for ext in extensions]):
                # path
                mix_path = os.path.join(root, file)
                pure_path = mix_path[len(root_dir) + 1 :] if is_pure else mix_path

                # amount
                if (amount is not None) and (cnt == amount):
                    if is_sort:
                        file_list.sort()
                    return file_list

                # check string
                if (str_include is not None) and (str_include not in pure_path):
                    continue
                if (str_exclude is not None) and (str_exclude in pure_path):
                    continue

                if not is_ext:
                    ext = pure_path.split(".")[-1]
                    pure_path = pure_path[: -(len(ext) + 1)]
                file_list.append(pure_path)
                cnt += 1
    if is_sort:
        file_list.sort()
    return file_list

def get_data_loaders(args, whole_audio=False):
    data_train = AudioDataset(args.data.train_path, waveform_sec=args.data.duration, hop_size=args.data.block_size, sample_rate=args.data.sampling_rate, load_all_data=args.train.cache_all_data, whole_audio=whole_audio, extensions=args.data.extensions, n_spk=args.model.n_spk, device=args.train.cache_device, fp16=args.train.cache_fp16, use_aug=True)
    loader_train = torch.utils.data.DataLoader(data_train, batch_size=args.train.batch_size if not whole_audio else 1, shuffle=True, num_workers=args.train.num_workers if args.train.cache_device == "cpu" else 0, persistent_workers=(args.train.num_workers > 0) if args.train.cache_device == "cpu" else False, pin_memory=True if args.train.cache_device == "cpu" else False)
    data_valid = AudioDataset(args.data.valid_path, waveform_sec=args.data.duration, hop_size=args.data.block_size, sample_rate=args.data.sampling_rate, load_all_data=args.train.cache_all_data, whole_audio=True, extensions=args.data.extensions, n_spk=args.model.n_spk)
    loader_valid = torch.utils.data.DataLoader(data_valid, batch_size=1, shuffle=False, num_workers=0, pin_memory=True)
    return loader_train, loader_valid

class AudioDataset(Dataset):
    def __init__(
        self,
        path_root,
        waveform_sec,
        hop_size,
        sample_rate,
        load_all_data=True,
        whole_audio=False,
        extensions=["wav"],
        n_spk=1,
        device="cpu",
        fp16=False,
        use_aug=False,
    ):
        super().__init__()

        self.waveform_sec = waveform_sec
        self.sample_rate = sample_rate
        self.hop_size = hop_size
        self.path_root = path_root
        self.paths = traverse_dir(os.path.join(path_root, "audio"), extensions=extensions, is_pure=True, is_sort=True, is_ext=True)
        self.whole_audio = whole_audio
        self.use_aug = use_aug
        self.data_buffer = {}
        self.pitch_aug_dict = np.load(os.path.join(self.path_root, "pitch_aug_dict.npy"), allow_pickle=True).item()
        if load_all_data:
            print("Load all the data from :", path_root)
        else:
            print("Load the f0, volume data from :", path_root)
        for name_ext in tqdm(self.paths, total=len(self.paths)):
            name = os.path.splitext(name_ext)[0]  # NOQA
            path_audio = os.path.join(self.path_root, "audio", name_ext)
            duration = librosa.get_duration(filename=path_audio, sr=self.sample_rate)

            path_f0 = os.path.join(self.path_root, "f0", name_ext) + ".npy"
            f0 = np.load(path_f0)
            f0 = torch.from_numpy(f0).float().unsqueeze(-1).to(device)

            path_volume = os.path.join(self.path_root, "volume", name_ext) + ".npy"
            volume = np.load(path_volume)
            volume = torch.from_numpy(volume).float().unsqueeze(-1).to(device)

            path_augvol = os.path.join(self.path_root, "aug_vol", name_ext) + ".npy"
            aug_vol = np.load(path_augvol)
            aug_vol = torch.from_numpy(aug_vol).float().unsqueeze(-1).to(device)

            if n_spk is not None and n_spk > 1:
                dirname_split = re.split(r"_|\-", os.path.dirname(name_ext), 2)[0]
                spk_id = int(dirname_split) if str.isdigit(dirname_split) else 0
                if spk_id < 1 or spk_id > n_spk:
                    raise ValueError(" [x] Multi-speaker training error: spk_id must be a positive integer from 1 to n_spk ")
            else:
                spk_id = 1
            spk_id = torch.LongTensor(np.array([spk_id])).to(device)

            if load_all_data:
                """
                audio, sr = librosa.load(path_audio, sr=self.sample_rate)
                if len(audio.shape) > 1:
                    audio = librosa.to_mono(audio)
                audio = torch.from_numpy(audio).to(device)
                """
                path_mel = os.path.join(self.path_root, "mel", name_ext) + ".npy"
                mel = np.load(path_mel)
                mel = torch.from_numpy(mel).to(device)

                path_augmel = os.path.join(self.path_root, "aug_mel", name_ext) + ".npy"
                aug_mel = np.load(path_augmel)
                aug_mel = torch.from_numpy(aug_mel).to(device)

                path_units = os.path.join(self.path_root, "units", name_ext) + ".npy"
                units = np.load(path_units)
                units = torch.from_numpy(units).to(device)

                if fp16:
                    mel = mel.half()
                    aug_mel = aug_mel.half()
                    units = units.half()

                self.data_buffer[name_ext] = {"duration": duration, "mel": mel, "aug_mel": aug_mel, "units": units, "f0": f0, "volume": volume, "aug_vol": aug_vol, "spk_id": spk_id}
            else:
                self.data_buffer[name_ext] = {"duration": duration, "f0": f0, "volume": volume, "aug_vol": aug_vol, "spk_id": spk_id}

    def __getitem__(self, file_idx):
        name_ext = self.paths[file_idx]
        data_buffer = self.data_buffer[name_ext]
        # check duration. if too short, then skip
        if data_buffer["duration"] < (self.waveform_sec + 0.1):
            return self.__getitem__((file_idx + 1) % len(self.paths))

        # get item
        return self.get_data(name_ext, data_buffer)

    def get_data(self, name_ext, data_buffer):
        name = os.path.splitext(name_ext)[0]
        frame_resolution = self.hop_size / self.sample_rate
        duration = data_buffer["duration"]
        waveform_sec = duration if self.whole_audio else self.waveform_sec

        # load audio
        idx_from = 0 if self.whole_audio else random.uniform(0, duration - waveform_sec - 0.1)
        start_frame = int(idx_from / frame_resolution)
        units_frame_len = int(waveform_sec / frame_resolution)
        aug_flag = random.choice([True, False]) and self.use_aug
        """
        audio = data_buffer.get('audio')
        if audio is None:
            path_audio = os.path.join(self.path_root, 'audio', name) + '.wav'
            audio, sr = librosa.load(
                path_audio,
                sr = self.sample_rate,
                offset = start_frame * frame_resolution,
                duration = waveform_sec)
            if len(audio.shape) > 1:
                audio = librosa.to_mono(audio)
            # clip audio into N seconds
            audio = audio[ : audio.shape[-1] // self.hop_size * self.hop_size]
            audio = torch.from_numpy(audio).float()
        else:
            audio = audio[start_frame * self.hop_size : (start_frame + units_frame_len) * self.hop_size]
        """
        # load mel
        mel_key = "aug_mel" if aug_flag else "mel"
        mel = data_buffer.get(mel_key)
        if mel is None:
            mel = os.path.join(self.path_root, mel_key, name_ext) + ".npy"
            mel = np.load(mel)
            mel = mel[start_frame : start_frame + units_frame_len]
            mel = torch.from_numpy(mel).float()
        else:
            mel = mel[start_frame : start_frame + units_frame_len]

        # load units
        units = data_buffer.get("units")
        if units is None:
            units = os.path.join(self.path_root, "units", name_ext) + ".npy"
            units = np.load(units)
            units = units[start_frame : start_frame + units_frame_len]
            units = torch.from_numpy(units).float()
        else:
            units = units[start_frame : start_frame + units_frame_len]

        # load f0
        f0 = data_buffer.get("f0")
        aug_shift = 0
        if aug_flag:
            aug_shift = self.pitch_aug_dict[name_ext]
        f0_frames = 2 ** (aug_shift / 12) * f0[start_frame : start_frame + units_frame_len]

        # load volume
        vol_key = "aug_vol" if aug_flag else "volume"
        volume = data_buffer.get(vol_key)
        volume_frames = volume[start_frame : start_frame + units_frame_len]

        # load spk_id
        spk_id = data_buffer.get("spk_id")

        # load shift
        aug_shift = torch.from_numpy(np.array([[aug_shift]])).float()

        return dict(mel=mel, f0=f0_frames, volume=volume_frames, units=units, spk_id=spk_id, aug_shift=aug_shift, name=name, name_ext=name_ext)

    def __len__(self):
        return len(self.paths)

server/voice_changer/DDSP_SVC/models/diffusion/diffusion.py (new file, 342 lines)
@ -0,0 +1,342 @@
from collections import deque
from functools import partial
from inspect import isfunction
import torch.nn.functional as F
import numpy as np
import torch
from torch import nn
from tqdm import tqdm


def exists(x):
    return x is not None


def default(val, d):
    if exists(val):
        return val
    return d() if isfunction(d) else d


def extract(a, t, x_shape):
    b, *_ = t.shape
    out = a.gather(-1, t)
    return out.reshape(b, *((1,) * (len(x_shape) - 1)))


def noise_like(shape, device, repeat=False):
    repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat(shape[0], *((1,) * (len(shape) - 1)))  # NOQA
    noise = lambda: torch.randn(shape, device=device)  # NOQA
    return repeat_noise() if repeat else noise()


def linear_beta_schedule(timesteps, max_beta=0.02):
    """
    linear schedule
    """
    betas = np.linspace(1e-4, max_beta, timesteps)
    return betas


def cosine_beta_schedule(timesteps, s=0.008):
    """
    cosine schedule
    as proposed in https://openreview.net/forum?id=-NEXDKk8gZ
    """
    steps = timesteps + 1
    x = np.linspace(0, steps, steps)
    alphas_cumprod = np.cos(((x / steps) + s) / (1 + s) * np.pi * 0.5) ** 2
    alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
    betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
    return np.clip(betas, a_min=0, a_max=0.999)


beta_schedule = {
    "cosine": cosine_beta_schedule,
    "linear": linear_beta_schedule,
}

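A quick numeric sanity check of the schedule helpers defined above (printed values are approximate):

import numpy as np

betas = linear_beta_schedule(1000, max_beta=0.02)
alphas_cumprod = np.cumprod(1.0 - betas)
print(betas[0], betas[-1])  # 1e-4 ... 0.02, linearly spaced
print(alphas_cumprod[-1])   # ~4e-5: almost pure noise remains at the final timestep
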
class GaussianDiffusion(nn.Module):
    def __init__(self, denoise_fn, out_dims=128, timesteps=1000, k_step=1000, max_beta=0.02, spec_min=-12, spec_max=2):
        super().__init__()
        self.denoise_fn = denoise_fn
        self.out_dims = out_dims
        betas = beta_schedule["linear"](timesteps, max_beta=max_beta)

        alphas = 1.0 - betas
        alphas_cumprod = np.cumprod(alphas, axis=0)
        alphas_cumprod_prev = np.append(1.0, alphas_cumprod[:-1])

        (timesteps,) = betas.shape
        self.num_timesteps = int(timesteps)
        self.k_step = k_step

        self.noise_list = deque(maxlen=4)

        to_torch = partial(torch.tensor, dtype=torch.float32)

        self.register_buffer("betas", to_torch(betas))
        self.register_buffer("alphas_cumprod", to_torch(alphas_cumprod))
        self.register_buffer("alphas_cumprod_prev", to_torch(alphas_cumprod_prev))

        # calculations for diffusion q(x_t | x_{t-1}) and others
        self.register_buffer("sqrt_alphas_cumprod", to_torch(np.sqrt(alphas_cumprod)))
        self.register_buffer("sqrt_one_minus_alphas_cumprod", to_torch(np.sqrt(1.0 - alphas_cumprod)))
        self.register_buffer("log_one_minus_alphas_cumprod", to_torch(np.log(1.0 - alphas_cumprod)))
        self.register_buffer("sqrt_recip_alphas_cumprod", to_torch(np.sqrt(1.0 / alphas_cumprod)))
        self.register_buffer("sqrt_recipm1_alphas_cumprod", to_torch(np.sqrt(1.0 / alphas_cumprod - 1)))

        # calculations for posterior q(x_{t-1} | x_t, x_0)
        posterior_variance = betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod)
        # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)
        self.register_buffer("posterior_variance", to_torch(posterior_variance))
        # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
        self.register_buffer("posterior_log_variance_clipped", to_torch(np.log(np.maximum(posterior_variance, 1e-20))))
        self.register_buffer("posterior_mean_coef1", to_torch(betas * np.sqrt(alphas_cumprod_prev) / (1.0 - alphas_cumprod)))
        self.register_buffer("posterior_mean_coef2", to_torch((1.0 - alphas_cumprod_prev) * np.sqrt(alphas) / (1.0 - alphas_cumprod)))

        self.register_buffer("spec_min", torch.FloatTensor([spec_min])[None, None, :out_dims])
        self.register_buffer("spec_max", torch.FloatTensor([spec_max])[None, None, :out_dims])

    def q_mean_variance(self, x_start, t):
        mean = extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
        variance = extract(1.0 - self.alphas_cumprod, t, x_start.shape)
        log_variance = extract(self.log_one_minus_alphas_cumprod, t, x_start.shape)
        return mean, variance, log_variance

    def predict_start_from_noise(self, x_t, t, noise):
        return extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise

    def q_posterior(self, x_start, x_t, t):
        posterior_mean = extract(self.posterior_mean_coef1, t, x_t.shape) * x_start + extract(self.posterior_mean_coef2, t, x_t.shape) * x_t
        posterior_variance = extract(self.posterior_variance, t, x_t.shape)
        posterior_log_variance_clipped = extract(self.posterior_log_variance_clipped, t, x_t.shape)
        return posterior_mean, posterior_variance, posterior_log_variance_clipped

    def p_mean_variance(self, x, t, cond):
        noise_pred = self.denoise_fn(x, t, cond=cond)
        x_recon = self.predict_start_from_noise(x, t=t, noise=noise_pred)

        x_recon.clamp_(-1.0, 1.0)

        model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t)
        return model_mean, posterior_variance, posterior_log_variance

    @torch.no_grad()
    def p_sample(self, x, t, cond, clip_denoised=True, repeat_noise=False):
        b, *_, device = *x.shape, x.device  # type:ignore
        model_mean, _, model_log_variance = self.p_mean_variance(x=x, t=t, cond=cond)
        noise = noise_like(x.shape, device, repeat_noise)
        # no noise when t == 0
        nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
        return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise

    @torch.no_grad()
    def p_sample_ddim(self, x, t, interval, cond):
        a_t = extract(self.alphas_cumprod, t, x.shape)
        a_prev = extract(self.alphas_cumprod, torch.max(t - interval, torch.zeros_like(t)), x.shape)

        noise_pred = self.denoise_fn(x, t, cond=cond)
        x_prev = a_prev.sqrt() * (x / a_t.sqrt() + (((1 - a_prev) / a_prev).sqrt() - ((1 - a_t) / a_t).sqrt()) * noise_pred)
        return x_prev

    @torch.no_grad()
    def p_sample_plms(self, x, t, interval, cond, clip_denoised=True, repeat_noise=False):
        """
        Use the PLMS method from
        [Pseudo Numerical Methods for Diffusion Models on Manifolds](https://arxiv.org/abs/2202.09778).
        """

        def get_x_pred(x, noise_t, t):
            a_t = extract(self.alphas_cumprod, t, x.shape)
            a_prev = extract(self.alphas_cumprod, torch.max(t - interval, torch.zeros_like(t)), x.shape)
            a_t_sq, a_prev_sq = a_t.sqrt(), a_prev.sqrt()

            x_delta = (a_prev - a_t) * ((1 / (a_t_sq * (a_t_sq + a_prev_sq))) * x - 1 / (a_t_sq * (((1 - a_prev) * a_t).sqrt() + ((1 - a_t) * a_prev).sqrt())) * noise_t)
            x_pred = x + x_delta

            return x_pred

        noise_list = self.noise_list
        noise_pred = self.denoise_fn(x, t, cond=cond)

        if len(noise_list) == 0:
            x_pred = get_x_pred(x, noise_pred, t)
            noise_pred_prev = self.denoise_fn(x_pred, max(t - interval, 0), cond=cond)
            noise_pred_prime = (noise_pred + noise_pred_prev) / 2
        elif len(noise_list) == 1:
            noise_pred_prime = (3 * noise_pred - noise_list[-1]) / 2
        elif len(noise_list) == 2:
            noise_pred_prime = (23 * noise_pred - 16 * noise_list[-1] + 5 * noise_list[-2]) / 12
        else:
            noise_pred_prime = (55 * noise_pred - 59 * noise_list[-1] + 37 * noise_list[-2] - 9 * noise_list[-3]) / 24

        x_prev = get_x_pred(x, noise_pred_prime, t)
        noise_list.append(noise_pred)

        return x_prev

    def q_sample(self, x_start, t, noise=None):
        noise = default(noise, lambda: torch.randn_like(x_start))
        return extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start + extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise

    def p_losses(self, x_start, t, cond, noise=None, loss_type="l2"):
        noise = default(noise, lambda: torch.randn_like(x_start))

        x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)
        x_recon = self.denoise_fn(x_noisy, t, cond)

        if loss_type == "l1":
            loss = (noise - x_recon).abs().mean()
        elif loss_type == "l2":
            loss = F.mse_loss(noise, x_recon)
        else:
            raise NotImplementedError()

        return loss

    def forward(self, condition, gt_spec=None, infer=True, infer_speedup=10, method="dpm-solver", k_step=None, use_tqdm=True):
        """
        conditioning diffusion, use fastspeech2 encoder output as the condition
        """
        cond = condition.transpose(1, 2)
        b, device = condition.shape[0], condition.device

        if not infer:
            spec = self.norm_spec(gt_spec)
            if k_step is None:
                t_max = self.k_step
            else:
                t_max = k_step
            t = torch.randint(0, t_max, (b,), device=device).long()
            norm_spec = spec.transpose(1, 2)[:, None, :, :]  # [B, 1, M, T]
            return self.p_losses(norm_spec, t, cond=cond)
        else:
            shape = (cond.shape[0], 1, self.out_dims, cond.shape[2])

            if gt_spec is None or k_step is None:
                t = self.k_step
                x = torch.randn(shape, device=device)
            else:
                t = k_step
                norm_spec = self.norm_spec(gt_spec)
                norm_spec = norm_spec.transpose(1, 2)[:, None, :, :]
                x = self.q_sample(x_start=norm_spec, t=torch.tensor([t - 1], device=device).long())

            if method is not None and infer_speedup > 1:
                if method == "dpm-solver":
                    from .dpm_solver_pytorch import NoiseScheduleVP, model_wrapper, DPM_Solver

                    # 1. Define the noise schedule.
                    noise_schedule = NoiseScheduleVP(schedule="discrete", betas=self.betas[:t])

                    # 2. Convert your discrete-time `model` to the continuous-time
                    # noise prediction model. Here is an example for a diffusion model
                    # `model` with the noise prediction type ("noise") .
                    def my_wrapper(fn):
                        def wrapped(x, t, **kwargs):
                            ret = fn(x, t, **kwargs)
                            if use_tqdm:
                                self.bar.update(1)
                            return ret

                        return wrapped

                    model_fn = model_wrapper(my_wrapper(self.denoise_fn), noise_schedule, model_type="noise", model_kwargs={"cond": cond})  # or "x_start" or "v" or "score"

                    # 3. Define dpm-solver and sample by singlestep DPM-Solver.
                    # (We recommend singlestep DPM-Solver for unconditional sampling)
                    # You can adjust the `steps` to balance the computation
                    # costs and the sample quality.
                    dpm_solver = DPM_Solver(model_fn, noise_schedule, algorithm_type="dpmsolver++")

                    steps = t // infer_speedup
                    if use_tqdm:
                        self.bar = tqdm(desc="sample time step", total=steps)
                    x = dpm_solver.sample(
                        x,
                        steps=steps,
                        order=2,
                        skip_type="time_uniform",
                        method="multistep",
                    )
                    if use_tqdm:
                        self.bar.close()
                elif method == "unipc":
                    from .uni_pc import NoiseScheduleVP, model_wrapper, UniPC

                    # 1. Define the noise schedule.
                    noise_schedule = NoiseScheduleVP(schedule="discrete", betas=self.betas[:t])

                    # 2. Convert your discrete-time `model` to the continuous-time
                    # noise prediction model. Here is an example for a diffusion model
                    # `model` with the noise prediction type ("noise") .
                    def my_wrapper(fn):
                        def wrapped(x, t, **kwargs):
                            ret = fn(x, t, **kwargs)
                            if use_tqdm:
                                self.bar.update(1)
                            return ret

                        return wrapped

                    model_fn = model_wrapper(my_wrapper(self.denoise_fn), noise_schedule, model_type="noise", model_kwargs={"cond": cond})  # or "x_start" or "v" or "score"

                    # 3. Define uni_pc and sample by multistep UniPC.
                    # You can adjust the `steps` to balance the computation
                    # costs and the sample quality.
                    uni_pc = UniPC(model_fn, noise_schedule, variant="bh2")

                    steps = t // infer_speedup
                    if use_tqdm:
                        self.bar = tqdm(desc="sample time step", total=steps)
                    x = uni_pc.sample(
                        x,
                        steps=steps,
                        order=2,
                        skip_type="time_uniform",
                        method="multistep",
                    )
                    if use_tqdm:
                        self.bar.close()
                elif method == "pndm":
                    self.noise_list = deque(maxlen=4)
                    if use_tqdm:
                        for i in tqdm(
                            reversed(range(0, t, infer_speedup)),
                            desc="sample time step",
                            total=t // infer_speedup,
                        ):
                            x = self.p_sample_plms(x, torch.full((b,), i, device=device, dtype=torch.long), infer_speedup, cond=cond)
                    else:
                        for i in reversed(range(0, t, infer_speedup)):
                            x = self.p_sample_plms(x, torch.full((b,), i, device=device, dtype=torch.long), infer_speedup, cond=cond)
                elif method == "ddim":
                    if use_tqdm:
                        for i in tqdm(
                            reversed(range(0, t, infer_speedup)),
                            desc="sample time step",
                            total=t // infer_speedup,
                        ):
                            x = self.p_sample_ddim(x, torch.full((b,), i, device=device, dtype=torch.long), infer_speedup, cond=cond)
                    else:
                        for i in reversed(range(0, t, infer_speedup)):
                            x = self.p_sample_ddim(x, torch.full((b,), i, device=device, dtype=torch.long), infer_speedup, cond=cond)
                else:
                    raise NotImplementedError(method)
            else:
                if use_tqdm:
                    for i in tqdm(reversed(range(0, t)), desc="sample time step", total=t):
                        x = self.p_sample(x, torch.full((b,), i, device=device, dtype=torch.long), cond)
                else:
                    for i in reversed(range(0, t)):
                        x = self.p_sample(x, torch.full((b,), i, device=device, dtype=torch.long), cond)
            x = x.squeeze(1).transpose(1, 2)  # [B, T, M]
            return self.denorm_spec(x)

    def norm_spec(self, x):
        return (x - self.spec_min) / (self.spec_max - self.spec_min) * 2 - 1

    def denorm_spec(self, x):
        return (x + 1) / 2 * (self.spec_max - self.spec_min) + self.spec_min

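A hedged inference sketch for this wrapper ("denoiser" stands in for any module with the denoise_fn(x, t, cond=...) signature assumed by the constructor; the shapes follow the comments above):

import torch

diff = GaussianDiffusion(denoiser, out_dims=128, timesteps=1000, k_step=1000)  # denoiser is assumed
cond = torch.randn(1, 100, 256)  # B x n_frames x n_hidden condition
mel = diff(cond, gt_spec=None, infer=True, infer_speedup=10, method="ddim")
print(mel.shape)                 # torch.Size([1, 100, 128]), i.e. [B, T, M]
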
server/voice_changer/DDSP_SVC/models/diffusion/diffusion_onnx.py (new file, 526 lines)
@ -0,0 +1,526 @@
from collections import deque
from functools import partial
from inspect import isfunction
import torch.nn.functional as F
import numpy as np
from torch.nn import Conv1d
from torch.nn import Mish
import torch
from torch import nn
from tqdm import tqdm
import math


def exists(x):
    return x is not None


def default(val, d):
    if exists(val):
        return val
    return d() if isfunction(d) else d


def extract(a, t):
    return a[t].reshape((1, 1, 1, 1))


def noise_like(shape, device, repeat=False):
    repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat(shape[0], *((1,) * (len(shape) - 1)))  # NOQA
    noise = lambda: torch.randn(shape, device=device)  # NOQA
    return repeat_noise() if repeat else noise()


def linear_beta_schedule(timesteps, max_beta=0.02):
    """
    linear schedule
    """
    betas = np.linspace(1e-4, max_beta, timesteps)
    return betas


def cosine_beta_schedule(timesteps, s=0.008):
    """
    cosine schedule
    as proposed in https://openreview.net/forum?id=-NEXDKk8gZ
    """
    steps = timesteps + 1
    x = np.linspace(0, steps, steps)
    alphas_cumprod = np.cos(((x / steps) + s) / (1 + s) * np.pi * 0.5) ** 2
    alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
    betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
    return np.clip(betas, a_min=0, a_max=0.999)


beta_schedule = {
    "cosine": cosine_beta_schedule,
    "linear": linear_beta_schedule,
}


def extract_1(a, t):
    return a[t].reshape((1, 1, 1, 1))


def predict_stage0(noise_pred, noise_pred_prev):
    return (noise_pred + noise_pred_prev) / 2


def predict_stage1(noise_pred, noise_list):
    return (noise_pred * 3 - noise_list[-1]) / 2


def predict_stage2(noise_pred, noise_list):
    return (noise_pred * 23 - noise_list[-1] * 16 + noise_list[-2] * 5) / 12


def predict_stage3(noise_pred, noise_list):
    return (noise_pred * 55 - noise_list[-1] * 59 + noise_list[-2] * 37 - noise_list[-3] * 9) / 24

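These predict_stage helpers mirror the noise_pred_prime branches of p_sample_plms in diffusion.py, unrolled into standalone functions, evidently so each PLMS order can be traced as a fixed graph for ONNX export; the coefficient sets are the standard Adams-Bashforth multistep weights.
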
class SinusoidalPosEmb(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.dim = dim
        self.half_dim = dim // 2
        self.emb = 9.21034037 / (self.half_dim - 1)  # 9.21034037 == log(10000), kept as a literal for export
        self.emb = torch.exp(torch.arange(self.half_dim) * torch.tensor(-self.emb)).unsqueeze(0)
        self.emb = self.emb.cpu()

    def forward(self, x):
        emb = self.emb * x
        emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
        return emb

class ResidualBlock(nn.Module):
    def __init__(self, encoder_hidden, residual_channels, dilation):
        super().__init__()
        self.residual_channels = residual_channels
        self.dilated_conv = Conv1d(residual_channels, 2 * residual_channels, 3, padding=dilation, dilation=dilation)
        self.diffusion_projection = nn.Linear(residual_channels, residual_channels)
        self.conditioner_projection = Conv1d(encoder_hidden, 2 * residual_channels, 1)
        self.output_projection = Conv1d(residual_channels, 2 * residual_channels, 1)

    def forward(self, x, conditioner, diffusion_step):
        diffusion_step = self.diffusion_projection(diffusion_step).unsqueeze(-1)
        conditioner = self.conditioner_projection(conditioner)
        y = x + diffusion_step
        y = self.dilated_conv(y) + conditioner

        gate, filter_1 = torch.split(y, [self.residual_channels, self.residual_channels], dim=1)

        y = torch.sigmoid(gate) * torch.tanh(filter_1)
        y = self.output_projection(y)

        residual, skip = torch.split(y, [self.residual_channels, self.residual_channels], dim=1)

        return (x + residual) / 1.41421356, skip  # divide by sqrt(2) to keep the residual variance stable

class DiffNet(nn.Module):
|
||||||
|
def __init__(self, in_dims, n_layers, n_chans, n_hidden):
|
||||||
|
super().__init__()
|
||||||
|
self.encoder_hidden = n_hidden
|
||||||
|
self.residual_layers = n_layers
|
||||||
|
self.residual_channels = n_chans
|
||||||
|
self.input_projection = Conv1d(in_dims, self.residual_channels, 1)
|
||||||
|
self.diffusion_embedding = SinusoidalPosEmb(self.residual_channels)
|
||||||
|
dim = self.residual_channels
|
||||||
|
self.mlp = nn.Sequential(nn.Linear(dim, dim * 4), Mish(), nn.Linear(dim * 4, dim))
|
||||||
|
self.residual_layers = nn.ModuleList([ResidualBlock(self.encoder_hidden, self.residual_channels, 1) for i in range(self.residual_layers)])
|
||||||
|
self.skip_projection = Conv1d(self.residual_channels, self.residual_channels, 1)
|
||||||
|
self.output_projection = Conv1d(self.residual_channels, in_dims, 1)
|
||||||
|
nn.init.zeros_(self.output_projection.weight)
|
||||||
|
|
||||||
|
def forward(self, spec, diffusion_step, cond):
|
||||||
|
x = spec.squeeze(0)
|
||||||
|
x = self.input_projection(x) # x [B, residual_channel, T]
|
||||||
|
x = F.relu(x)
|
||||||
|
# skip = torch.randn_like(x)
|
||||||
|
diffusion_step = diffusion_step.float()
|
||||||
|
diffusion_step = self.diffusion_embedding(diffusion_step)
|
||||||
|
diffusion_step = self.mlp(diffusion_step)
|
||||||
|
|
||||||
|
x, skip = self.residual_layers[0](x, cond, diffusion_step)
|
||||||
|
# noinspection PyTypeChecker
|
||||||
|
for layer in self.residual_layers[1:]:
|
||||||
|
x, skip_connection = layer.forward(x, cond, diffusion_step)
|
||||||
|
skip = skip + skip_connection
|
||||||
|
x = skip / math.sqrt(len(self.residual_layers))
|
||||||
|
x = self.skip_projection(x)
|
||||||
|
x = F.relu(x)
|
||||||
|
x = self.output_projection(x) # [B, 80, T]
|
||||||
|
return x.unsqueeze(1)
|
||||||
|
|
||||||
|
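
# Note (added for clarity): the two scaling constants used above are the usual
# variance-preserving factors from WaveNet-style denoisers: 1.41421356 is sqrt(2),
# so (x + residual) / sqrt(2) keeps unit variance for two uncorrelated unit-variance
# terms, and skip / math.sqrt(n_layers) does the same for the sum of n_layers skip
# branches (the std of a sum of n iid terms grows like sqrt(n)).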


class AfterDiffusion(nn.Module):
    def __init__(self, spec_max, spec_min, v_type="a"):
        super().__init__()
        self.spec_max = spec_max
        self.spec_min = spec_min
        self.type = v_type

    def forward(self, x):
        x = x.squeeze(1).permute(0, 2, 1)
        mel_out = (x + 1) / 2 * (self.spec_max - self.spec_min) + self.spec_min
        if self.type == "nsf-hifigan-log10":
            mel_out = mel_out * 0.434294
        return mel_out.transpose(2, 1)
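
# Note (added for clarity): 0.434294 is log10(e), i.e. the branch above converts a
# natural-log mel spectrogram to log10 for the "nsf-hifigan-log10" vocoder convention.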


class Pred(nn.Module):
    def __init__(self, alphas_cumprod):
        super().__init__()
        self.alphas_cumprod = alphas_cumprod

    def forward(self, x_1, noise_t, t_1, t_prev):
        a_t = extract(self.alphas_cumprod, t_1).cpu()
        a_prev = extract(self.alphas_cumprod, t_prev).cpu()
        a_t_sq, a_prev_sq = a_t.sqrt().cpu(), a_prev.sqrt().cpu()
        x_delta = (a_prev - a_t) * ((1 / (a_t_sq * (a_t_sq + a_prev_sq))) * x_1 - 1 / (a_t_sq * (((1 - a_prev) * a_t).sqrt() + ((1 - a_t) * a_prev).sqrt())) * noise_t)
        x_pred = x_1 + x_delta.cpu()

        return x_pred


class GaussianDiffusion(nn.Module):
    def __init__(self, out_dims=128, n_layers=20, n_chans=384, n_hidden=256, timesteps=1000, k_step=1000, max_beta=0.02, spec_min=-12, spec_max=2):
        super().__init__()
        self.denoise_fn = DiffNet(out_dims, n_layers, n_chans, n_hidden)
        self.out_dims = out_dims
        self.mel_bins = out_dims
        self.n_hidden = n_hidden
        betas = beta_schedule["linear"](timesteps, max_beta=max_beta)

        alphas = 1.0 - betas
        alphas_cumprod = np.cumprod(alphas, axis=0)
        alphas_cumprod_prev = np.append(1.0, alphas_cumprod[:-1])
        (timesteps,) = betas.shape
        self.num_timesteps = int(timesteps)
        self.k_step = k_step

        self.noise_list = deque(maxlen=4)

        to_torch = partial(torch.tensor, dtype=torch.float32)

        self.register_buffer("betas", to_torch(betas))
        self.register_buffer("alphas_cumprod", to_torch(alphas_cumprod))
        self.register_buffer("alphas_cumprod_prev", to_torch(alphas_cumprod_prev))

        # calculations for diffusion q(x_t | x_{t-1}) and others
        self.register_buffer("sqrt_alphas_cumprod", to_torch(np.sqrt(alphas_cumprod)))
        self.register_buffer("sqrt_one_minus_alphas_cumprod", to_torch(np.sqrt(1.0 - alphas_cumprod)))
        self.register_buffer("log_one_minus_alphas_cumprod", to_torch(np.log(1.0 - alphas_cumprod)))
        self.register_buffer("sqrt_recip_alphas_cumprod", to_torch(np.sqrt(1.0 / alphas_cumprod)))
        self.register_buffer("sqrt_recipm1_alphas_cumprod", to_torch(np.sqrt(1.0 / alphas_cumprod - 1)))

        # calculations for posterior q(x_{t-1} | x_t, x_0)
        posterior_variance = betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod)
        # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)
        self.register_buffer("posterior_variance", to_torch(posterior_variance))
        # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
        self.register_buffer("posterior_log_variance_clipped", to_torch(np.log(np.maximum(posterior_variance, 1e-20))))
        self.register_buffer("posterior_mean_coef1", to_torch(betas * np.sqrt(alphas_cumprod_prev) / (1.0 - alphas_cumprod)))
        self.register_buffer("posterior_mean_coef2", to_torch((1.0 - alphas_cumprod_prev) * np.sqrt(alphas) / (1.0 - alphas_cumprod)))

        self.register_buffer("spec_min", torch.FloatTensor([spec_min])[None, None, :out_dims])
        self.register_buffer("spec_max", torch.FloatTensor([spec_max])[None, None, :out_dims])
        self.ad = AfterDiffusion(self.spec_max, self.spec_min)
        self.xp = Pred(self.alphas_cumprod)

    def q_mean_variance(self, x_start, t):
        mean = extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
        variance = extract(1.0 - self.alphas_cumprod, t, x_start.shape)
        log_variance = extract(self.log_one_minus_alphas_cumprod, t, x_start.shape)
        return mean, variance, log_variance

    def predict_start_from_noise(self, x_t, t, noise):
        return extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise

    def q_posterior(self, x_start, x_t, t):
        posterior_mean = extract(self.posterior_mean_coef1, t, x_t.shape) * x_start + extract(self.posterior_mean_coef2, t, x_t.shape) * x_t
        posterior_variance = extract(self.posterior_variance, t, x_t.shape)
        posterior_log_variance_clipped = extract(self.posterior_log_variance_clipped, t, x_t.shape)
        return posterior_mean, posterior_variance, posterior_log_variance_clipped

    def p_mean_variance(self, x, t, cond):
        noise_pred = self.denoise_fn(x, t, cond=cond)
        x_recon = self.predict_start_from_noise(x, t=t, noise=noise_pred)

        x_recon.clamp_(-1.0, 1.0)

        model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t)
        return model_mean, posterior_variance, posterior_log_variance

    @torch.no_grad()
    def p_sample(self, x, t, cond, clip_denoised=True, repeat_noise=False):
        b, *_, device = *x.shape, x.device  # type: ignore
        model_mean, _, model_log_variance = self.p_mean_variance(x=x, t=t, cond=cond)
        noise = noise_like(x.shape, device, repeat_noise)
        # no noise when t == 0
        nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
        return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise

    @torch.no_grad()
    def p_sample_plms(self, x, t, interval, cond, clip_denoised=True, repeat_noise=False):
        """
        Use the PLMS method from
        [Pseudo Numerical Methods for Diffusion Models on Manifolds](https://arxiv.org/abs/2202.09778).
        """

        def get_x_pred(x, noise_t, t):
            a_t = extract(self.alphas_cumprod, t)
            a_prev = extract(self.alphas_cumprod, torch.max(t - interval, torch.zeros_like(t)))
            a_t_sq, a_prev_sq = a_t.sqrt(), a_prev.sqrt()

            x_delta = (a_prev - a_t) * ((1 / (a_t_sq * (a_t_sq + a_prev_sq))) * x - 1 / (a_t_sq * (((1 - a_prev) * a_t).sqrt() + ((1 - a_t) * a_prev).sqrt())) * noise_t)
            x_pred = x + x_delta

            return x_pred

        noise_list = self.noise_list
        noise_pred = self.denoise_fn(x, t, cond=cond)

        if len(noise_list) == 0:
            x_pred = get_x_pred(x, noise_pred, t)
            noise_pred_prev = self.denoise_fn(x_pred, max(t - interval, 0), cond=cond)
            noise_pred_prime = (noise_pred + noise_pred_prev) / 2
        elif len(noise_list) == 1:
            noise_pred_prime = (3 * noise_pred - noise_list[-1]) / 2
        elif len(noise_list) == 2:
            noise_pred_prime = (23 * noise_pred - 16 * noise_list[-1] + 5 * noise_list[-2]) / 12
        else:
            noise_pred_prime = (55 * noise_pred - 59 * noise_list[-1] + 37 * noise_list[-2] - 9 * noise_list[-3]) / 24

        x_prev = get_x_pred(x, noise_pred_prime, t)
        noise_list.append(noise_pred)

        return x_prev
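
    # Note (added for clarity): the blended noise estimates above use the classical
    # Adams-Bashforth multistep coefficients, e.g. 4th order:
    #   (55 * e_n - 59 * e_{n-1} + 37 * e_{n-2} - 9 * e_{n-3}) / 24
    # which is exactly the PLMS update from the paper cited in the docstring.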

    def q_sample(self, x_start, t, noise=None):
        noise = default(noise, lambda: torch.randn_like(x_start))
        return extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start + extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise

    def p_losses(self, x_start, t, cond, noise=None, loss_type="l2"):
        noise = default(noise, lambda: torch.randn_like(x_start))

        x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)
        x_recon = self.denoise_fn(x_noisy, t, cond)

        if loss_type == "l1":
            loss = (noise - x_recon).abs().mean()
        elif loss_type == "l2":
            loss = F.mse_loss(noise, x_recon)
        else:
            raise NotImplementedError()

        return loss

    def org_forward(self, condition, init_noise=None, gt_spec=None, infer=True, infer_speedup=100, method="pndm", k_step=1000, use_tqdm=True):
        """
        conditioning diffusion, use fastspeech2 encoder output as the condition
        """
        cond = condition
        b, device = condition.shape[0], condition.device
        if not infer:
            spec = self.norm_spec(gt_spec)
            t = torch.randint(0, self.k_step, (b,), device=device).long()
            norm_spec = spec.transpose(1, 2)[:, None, :, :]  # [B, 1, M, T]
            return self.p_losses(norm_spec, t, cond=cond)
        else:
            shape = (cond.shape[0], 1, self.out_dims, cond.shape[2])

            if gt_spec is None:
                t = self.k_step
                if init_noise is None:
                    x = torch.randn(shape, device=device)
                else:
                    x = init_noise
            else:
                t = k_step
                norm_spec = self.norm_spec(gt_spec)
                norm_spec = norm_spec.transpose(1, 2)[:, None, :, :]
                x = self.q_sample(x_start=norm_spec, t=torch.tensor([t - 1], device=device).long())

            if method is not None and infer_speedup > 1:
                if method == "dpm-solver":
                    from .dpm_solver_pytorch import NoiseScheduleVP, model_wrapper, DPM_Solver

                    # 1. Define the noise schedule.
                    noise_schedule = NoiseScheduleVP(schedule="discrete", betas=self.betas[:t])

                    # 2. Convert your discrete-time `model` to the continuous-time
                    # noise prediction model. Here is an example for a diffusion model
                    # `model` with the noise prediction type ("noise").
                    def my_wrapper(fn):
                        def wrapped(x, t, **kwargs):
                            ret = fn(x, t, **kwargs)
                            if use_tqdm:
                                self.bar.update(1)
                            return ret

                        return wrapped

                    model_fn = model_wrapper(my_wrapper(self.denoise_fn), noise_schedule, model_type="noise", model_kwargs={"cond": cond})  # or "x_start" or "v" or "score"

                    # 3. Define dpm-solver and sample by singlestep DPM-Solver.
                    # (We recommend singlestep DPM-Solver for unconditional sampling)
                    # You can adjust the `steps` to balance the computation
                    # costs and the sample quality.
                    dpm_solver = DPM_Solver(model_fn, noise_schedule)

                    steps = t // infer_speedup
                    if use_tqdm:
                        self.bar = tqdm(desc="sample time step", total=steps)
                    x = dpm_solver.sample(
                        x,
                        steps=steps,
                        order=3,
                        skip_type="time_uniform",
                        method="singlestep",
                    )
                    if use_tqdm:
                        self.bar.close()
                elif method == "pndm":
                    self.noise_list = deque(maxlen=4)
                    if use_tqdm:
                        for i in tqdm(
                            reversed(range(0, t, infer_speedup)),
                            desc="sample time step",
                            total=t // infer_speedup,
                        ):
                            x = self.p_sample_plms(x, torch.full((b,), i, device=device, dtype=torch.long), infer_speedup, cond=cond)
                    else:
                        for i in reversed(range(0, t, infer_speedup)):
                            x = self.p_sample_plms(x, torch.full((b,), i, device=device, dtype=torch.long), infer_speedup, cond=cond)
                else:
                    raise NotImplementedError(method)
            else:
                if use_tqdm:
                    for i in tqdm(reversed(range(0, t)), desc="sample time step", total=t):
                        x = self.p_sample(x, torch.full((b,), i, device=device, dtype=torch.long), cond)
                else:
                    for i in reversed(range(0, t)):
                        x = self.p_sample(x, torch.full((b,), i, device=device, dtype=torch.long), cond)
            x = x.squeeze(1).transpose(1, 2)  # [B, T, M]
            return self.denorm_spec(x).transpose(2, 1)

    def norm_spec(self, x):
        return (x - self.spec_min) / (self.spec_max - self.spec_min) * 2 - 1

    def denorm_spec(self, x):
        return (x + 1) / 2 * (self.spec_max - self.spec_min) + self.spec_min

    def get_x_pred(self, x_1, noise_t, t_1, t_prev):
        a_t = extract(self.alphas_cumprod, t_1)
        a_prev = extract(self.alphas_cumprod, t_prev)
        a_t_sq, a_prev_sq = a_t.sqrt(), a_prev.sqrt()
        x_delta = (a_prev - a_t) * ((1 / (a_t_sq * (a_t_sq + a_prev_sq))) * x_1 - 1 / (a_t_sq * (((1 - a_prev) * a_t).sqrt() + ((1 - a_t) * a_prev).sqrt())) * noise_t)
        x_pred = x_1 + x_delta
        return x_pred

    def OnnxExport(self, project_name=None, init_noise=None, hidden_channels=256, export_denoise=True, export_pred=True, export_after=True):
        cond = torch.randn([1, self.n_hidden, 10]).cpu()
        if init_noise is None:
            x = torch.randn((1, 1, self.mel_bins, cond.shape[2]), dtype=torch.float32).cpu()
        else:
            x = init_noise
        pndms = 100

        org_y_x = self.org_forward(cond, init_noise=x)

        device = cond.device
        n_frames = cond.shape[2]
        step_range = torch.arange(0, self.k_step, pndms, dtype=torch.long, device=device).flip(0)
        plms_noise_stage = torch.tensor(0, dtype=torch.long, device=device)
        noise_list = torch.zeros((0, 1, 1, self.mel_bins, n_frames), device=device)

        ot = step_range[0]
        ot_1 = torch.full((1,), ot, device=device, dtype=torch.long)
        if export_denoise:
            torch.onnx.export(self.denoise_fn, (x.cpu(), ot_1.cpu(), cond.cpu()), f"{project_name}_denoise.onnx", input_names=["noise", "time", "condition"], output_names=["noise_pred"], dynamic_axes={"noise": [3], "condition": [2]}, opset_version=16)

        for t in step_range:
            t_1 = torch.full((1,), t, device=device, dtype=torch.long)
            noise_pred = self.denoise_fn(x, t_1, cond)
            t_prev = t_1 - pndms
            t_prev = t_prev * (t_prev > 0)
            if plms_noise_stage == 0:
                if export_pred:
                    torch.onnx.export(self.xp, (x.cpu(), noise_pred.cpu(), t_1.cpu(), t_prev.cpu()), f"{project_name}_pred.onnx", input_names=["noise", "noise_pred", "time", "time_prev"], output_names=["noise_pred_o"], dynamic_axes={"noise": [3], "noise_pred": [3]}, opset_version=16)

                x_pred = self.get_x_pred(x, noise_pred, t_1, t_prev)
                noise_pred_prev = self.denoise_fn(x_pred, t_prev, cond=cond)
                noise_pred_prime = predict_stage0(noise_pred, noise_pred_prev)

            elif plms_noise_stage == 1:
                noise_pred_prime = predict_stage1(noise_pred, noise_list)

            elif plms_noise_stage == 2:
                noise_pred_prime = predict_stage2(noise_pred, noise_list)

            else:
                noise_pred_prime = predict_stage3(noise_pred, noise_list)

            noise_pred = noise_pred.unsqueeze(0)

            if plms_noise_stage < 3:
                noise_list = torch.cat((noise_list, noise_pred), dim=0)
                plms_noise_stage = plms_noise_stage + 1

            else:
                noise_list = torch.cat((noise_list[-2:], noise_pred), dim=0)

            x = self.get_x_pred(x, noise_pred_prime, t_1, t_prev)
        if export_after:
            torch.onnx.export(self.ad, x.cpu(), f"{project_name}_after.onnx", input_names=["x"], output_names=["mel_out"], dynamic_axes={"x": [3]}, opset_version=16)
        x = self.ad(x)

        print((x == org_y_x).all())
        return x

    def forward(self, condition=None, init_noise=None, pndms=None, k_step=None):
        cond = condition
        x = init_noise

        device = cond.device
        n_frames = cond.shape[2]
        step_range = torch.arange(0, k_step.item(), pndms.item(), dtype=torch.long, device=device).flip(0)
        plms_noise_stage = torch.tensor(0, dtype=torch.long, device=device)
        noise_list = torch.zeros((0, 1, 1, self.mel_bins, n_frames), device=device)

        ot = step_range[0]
        ot_1 = torch.full((1,), ot, device=device, dtype=torch.long)  # NOQA

        for t in step_range:
            t_1 = torch.full((1,), t, device=device, dtype=torch.long)
            noise_pred = self.denoise_fn(x, t_1, cond)
            t_prev = t_1 - pndms
            t_prev = t_prev * (t_prev > 0)
            if plms_noise_stage == 0:
                x_pred = self.get_x_pred(x, noise_pred, t_1, t_prev)
                noise_pred_prev = self.denoise_fn(x_pred, t_prev, cond=cond)
                noise_pred_prime = predict_stage0(noise_pred, noise_pred_prev)

            elif plms_noise_stage == 1:
                noise_pred_prime = predict_stage1(noise_pred, noise_list)

            elif plms_noise_stage == 2:
                noise_pred_prime = predict_stage2(noise_pred, noise_list)

            else:
                noise_pred_prime = predict_stage3(noise_pred, noise_list)

            noise_pred = noise_pred.unsqueeze(0)

            if plms_noise_stage < 3:
                noise_list = torch.cat((noise_list, noise_pred), dim=0)
                plms_noise_stage = plms_noise_stage + 1

            else:
                noise_list = torch.cat((noise_list[-2:], noise_pred), dim=0)

            x = self.get_x_pred(x, noise_pred_prime, t_1, t_prev)
        x = self.ad(x)
        return x
1284 server/voice_changer/DDSP_SVC/models/diffusion/dpm_solver_pytorch.py Normal file
File diff suppressed because it is too large
@@ -0,0 +1,59 @@
import torch
import torch.nn.functional as F
from .unit2mel import load_model_vocoder


class DiffGtMel:
    def __init__(self, project_path=None, device=None):
        self.project_path = project_path
        if device is not None:
            self.device = device
        else:
            self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model = None
        self.vocoder = None
        self.args = None

    def flush_model(self, project_path, ddsp_config=None):
        if (self.model is None) or (project_path != self.project_path):
            model, vocoder, args = load_model_vocoder(project_path, device=self.device)
            if self.check_args(ddsp_config, args):
                self.model = model
                self.vocoder = vocoder
                self.args = args

    def check_args(self, args1, args2):
        if args1.data.block_size != args2.data.block_size:
            raise ValueError("The block_size of the DDSP and DIFF models does not match")
        if args1.data.sampling_rate != args2.data.sampling_rate:
            raise ValueError("The sampling_rate of the DDSP and DIFF models does not match")
        if args1.data.encoder != args2.data.encoder:
            raise ValueError("The encoder of the DDSP and DIFF models does not match")
        return True

    def __call__(self, audio, f0, hubert, volume, acc=1, spk_id=1, k_step=0, method="pndm", spk_mix_dict=None, start_frame=0):
        input_mel = self.vocoder.extract(audio, self.args.data.sampling_rate)
        out_mel = self.model(hubert, f0, volume, spk_id=spk_id, spk_mix_dict=spk_mix_dict, gt_spec=input_mel, infer=True, infer_speedup=acc, method=method, k_step=k_step, use_tqdm=False)
        if start_frame > 0:
            out_mel = out_mel[:, start_frame:, :]
            f0 = f0[:, start_frame:, :]
        output = self.vocoder.infer(out_mel, f0)
        if start_frame > 0:
            output = F.pad(output, (start_frame * self.vocoder.vocoder_hop_size, 0))
        return output

    def infer(self, audio, f0, hubert, volume, acc=1, spk_id=1, k_step=0, method="pndm", silence_front=0, use_silence=False, spk_mix_dict=None):
        start_frame = int(silence_front * self.vocoder.vocoder_sample_rate / self.vocoder.vocoder_hop_size)
        if use_silence:
            audio = audio[:, start_frame * self.vocoder.vocoder_hop_size :]
            f0 = f0[:, start_frame:, :]
            hubert = hubert[:, start_frame:, :]
            volume = volume[:, start_frame:, :]
            _start_frame = 0
        else:
            _start_frame = start_frame
        audio = self.__call__(audio, f0, hubert, volume, acc=acc, spk_id=spk_id, k_step=k_step, method=method, spk_mix_dict=spk_mix_dict, start_frame=_start_frame)
        if use_silence:
            if start_frame > 0:
                audio = F.pad(audio, (start_frame * self.vocoder.vocoder_hop_size, 0))
        return audio
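
# Note (added for clarity): silence_front is measured in seconds; the frame offset is
# int(silence_front * sample_rate / hop_size). For example (hypothetical values), with a
# 44100 Hz vocoder and hop size 512, 0.5 s of leading silence skips int(0.5 * 44100 / 512)
# = 43 frames, and F.pad later re-inserts the corresponding 43 * 512 samples of silence.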
226 server/voice_changer/DDSP_SVC/models/diffusion/onnx_export.py Normal file
@@ -0,0 +1,226 @@
from diffusion_onnx import GaussianDiffusion
import os
import yaml
import torch
import torch.nn as nn
import numpy as np
from wavenet import WaveNet
import torch.nn.functional as F
import diffusion


class DotDict(dict):
    def __getattr__(*args):
        val = dict.get(*args)
        return DotDict(val) if type(val) is dict else val

    __setattr__ = dict.__setitem__
    __delattr__ = dict.__delitem__
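
# Note (added for clarity): DotDict just layers attribute access over nested dicts, so a
# parsed YAML config reads like args.model.n_layers. A minimal sketch (hypothetical keys):
#   cfg = DotDict({"model": {"n_layers": 20}})
#   assert cfg.model.n_layers == 20
#   assert cfg.missing is None  # dict.get returns None for absent keys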


def load_model_vocoder(
        model_path,
        device='cpu'):
    config_file = os.path.join(os.path.split(model_path)[0], 'config.yaml')
    with open(config_file, "r") as config:
        args = yaml.safe_load(config)
    args = DotDict(args)

    # load model
    model = Unit2Mel(
        args.data.encoder_out_channels,
        args.model.n_spk,
        args.model.use_pitch_aug,
        128,
        args.model.n_layers,
        args.model.n_chans,
        args.model.n_hidden)

    print(' [Loading] ' + model_path)
    ckpt = torch.load(model_path, map_location=torch.device(device))
    model.to(device)
    model.load_state_dict(ckpt['model'])
    model.eval()
    return model, args


class Unit2Mel(nn.Module):
    def __init__(
            self,
            input_channel,
            n_spk,
            use_pitch_aug=False,
            out_dims=128,
            n_layers=20,
            n_chans=384,
            n_hidden=256):
        super().__init__()
        self.unit_embed = nn.Linear(input_channel, n_hidden)
        self.f0_embed = nn.Linear(1, n_hidden)
        self.volume_embed = nn.Linear(1, n_hidden)
        if use_pitch_aug:
            self.aug_shift_embed = nn.Linear(1, n_hidden, bias=False)
        else:
            self.aug_shift_embed = None
        self.n_spk = n_spk
        if n_spk is not None and n_spk > 1:
            self.spk_embed = nn.Embedding(n_spk, n_hidden)

        # diffusion
        self.decoder = GaussianDiffusion(out_dims, n_layers, n_chans, n_hidden)
        self.hidden_size = n_hidden
        self.speaker_map = torch.zeros((self.n_spk, 1, 1, n_hidden))

    def forward(self, units, mel2ph, f0, volume, g=None):
        '''
        input:
            B x n_frames x n_unit
        return:
            dict of B x n_frames x feat
        '''

        decoder_inp = F.pad(units, [0, 0, 1, 0])
        mel2ph_ = mel2ph.unsqueeze(2).repeat([1, 1, units.shape[-1]])
        units = torch.gather(decoder_inp, 1, mel2ph_)  # [B, T, H]

        x = self.unit_embed(units) + self.f0_embed((1 + f0.unsqueeze(-1) / 700).log()) + self.volume_embed(volume.unsqueeze(-1))

        if self.n_spk is not None and self.n_spk > 1:  # [N, S] * [S, B, 1, H]
            g = g.reshape((g.shape[0], g.shape[1], 1, 1, 1))  # [N, S, B, 1, 1]
            g = g * self.speaker_map  # [N, S, B, 1, H]
            g = torch.sum(g, dim=1)  # [N, 1, B, 1, H]
            g = g.transpose(0, -1).transpose(0, -2).squeeze(0)  # [B, H, N]
            x = x.transpose(1, 2) + g
            return x
        else:
            return x.transpose(1, 2)
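
    # Note (added for clarity): (1 + f0 / 700).log() is the core of the mel-scale mapping
    # mel = 2595 * log10(1 + f / 700), so the f0 embedding sees pitch on a roughly
    # perceptual log scale rather than in raw Hz.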

    def init_spkembed(self, units, f0, volume, spk_id=None, spk_mix_dict=None, aug_shift=None,
                      gt_spec=None, infer=True, infer_speedup=10, method='dpm-solver', k_step=300, use_tqdm=True):
        '''
        input:
            B x n_frames x n_unit
        return:
            dict of B x n_frames x feat
        '''
        x = self.unit_embed(units) + self.f0_embed((1 + f0 / 700).log()) + self.volume_embed(volume)
        if self.n_spk is not None and self.n_spk > 1:
            if spk_mix_dict is not None:
                spk_embed_mix = torch.zeros((1, 1, self.hidden_size))
                for k, v in spk_mix_dict.items():
                    spk_id_torch = torch.LongTensor(np.array([[k]])).to(units.device)
                    spk_embeddd = self.spk_embed(spk_id_torch)
                    self.speaker_map[k] = spk_embeddd
                    spk_embed_mix = spk_embed_mix + v * spk_embeddd
                x = x + spk_embed_mix
            else:
                x = x + self.spk_embed(spk_id - 1)
        self.speaker_map = self.speaker_map.unsqueeze(0)
        self.speaker_map = self.speaker_map.detach()
        return x.transpose(1, 2)

    def OnnxExport(self, project_name=None, init_noise=None, export_encoder=True, export_denoise=True, export_pred=True, export_after=True):
        hubert_hidden_size = 768
        n_frames = 100
        hubert = torch.randn((1, n_frames, hubert_hidden_size))
        mel2ph = torch.arange(end=n_frames).unsqueeze(0).long()
        f0 = torch.randn((1, n_frames))
        volume = torch.randn((1, n_frames))
        spk_mix = []
        spks = {}
        if self.n_spk is not None and self.n_spk > 1:
            for i in range(self.n_spk):
                spk_mix.append(1.0 / float(self.n_spk))
                spks.update({i: 1.0 / float(self.n_spk)})
        spk_mix = torch.tensor(spk_mix)
        spk_mix = spk_mix.repeat(n_frames, 1)
        orgouttt = self.init_spkembed(hubert, f0.unsqueeze(-1), volume.unsqueeze(-1), spk_mix_dict=spks)
        outtt = self.forward(hubert, mel2ph, f0, volume, spk_mix)
        if export_encoder:
            torch.onnx.export(
                self,
                (hubert, mel2ph, f0, volume, spk_mix),
                f"{project_name}_encoder.onnx",
                input_names=["hubert", "mel2ph", "f0", "volume", "spk_mix"],
                output_names=["mel_pred"],
                dynamic_axes={
                    "hubert": [1],
                    "f0": [1],
                    "volume": [1],
                    "mel2ph": [1],
                    "spk_mix": [0],
                },
                opset_version=16
            )

        self.decoder.OnnxExport(project_name, init_noise=init_noise, export_denoise=export_denoise, export_pred=export_pred, export_after=export_after)

    def ExportOnnx(self, project_name=None):
        hubert_hidden_size = 768
        n_frames = 100
        hubert = torch.randn((1, n_frames, hubert_hidden_size))
        mel2ph = torch.arange(end=n_frames).unsqueeze(0).long()
        f0 = torch.randn((1, n_frames))
        volume = torch.randn((1, n_frames))
        spk_mix = []
        spks = {}
        if self.n_spk is not None and self.n_spk > 1:
            for i in range(self.n_spk):
                spk_mix.append(1.0 / float(self.n_spk))
                spks.update({i: 1.0 / float(self.n_spk)})
        spk_mix = torch.tensor(spk_mix)
        orgouttt = self.orgforward(hubert, f0.unsqueeze(-1), volume.unsqueeze(-1), spk_mix_dict=spks)
        outtt = self.forward(hubert, mel2ph, f0, volume, spk_mix)

        torch.onnx.export(
            self,
            (hubert, mel2ph, f0, volume, spk_mix),
            f"{project_name}_encoder.onnx",
            input_names=["hubert", "mel2ph", "f0", "volume", "spk_mix"],
            output_names=["mel_pred"],
            dynamic_axes={
                "hubert": [1],
                "f0": [1],
                "volume": [1],
                "mel2ph": [1]
            },
            opset_version=16
        )

        condition = torch.randn(1, self.decoder.n_hidden, n_frames)
        noise = torch.randn((1, 1, self.decoder.mel_bins, condition.shape[2]), dtype=torch.float32)
        pndm_speedup = torch.LongTensor([100])
        K_steps = torch.LongTensor([1000])
        self.decoder = torch.jit.script(self.decoder)
        self.decoder(condition, noise, pndm_speedup, K_steps)

        torch.onnx.export(
            self.decoder,
            (condition, noise, pndm_speedup, K_steps),
            f"{project_name}_diffusion.onnx",
            input_names=["condition", "noise", "pndm_speedup", "K_steps"],
            output_names=["mel"],
            dynamic_axes={
                "condition": [2],
                "noise": [3],
            },
            opset_version=16
        )


if __name__ == "__main__":
    project_name = "dddsp"
    model_path = f'{project_name}/model_500000.pt'

    model, _ = load_model_vocoder(model_path)

    # Separate-Diffusion export (requires MoeSS/MoeVoiceStudio, or writing your own PNDM/DPM sampling)
    model.OnnxExport(project_name, export_encoder=True, export_denoise=True, export_pred=True, export_after=True)

    # Merged-Diffusion export (Encoder and Diffusion stay separate; feed the Encoder output and the initial noise directly into Diffusion)
    # model.ExportOnnx(project_name)
194 server/voice_changer/DDSP_SVC/models/diffusion/solver.py Normal file
@@ -0,0 +1,194 @@
import os
import time
import numpy as np
import torch
import librosa
from logger.saver import Saver
from logger import utils
from torch import autocast
from torch.cuda.amp import GradScaler


def test(args, model, vocoder, loader_test, saver):
    print(' [*] testing...')
    model.eval()

    # losses
    test_loss = 0.

    # initialization
    num_batches = len(loader_test)
    rtf_all = []

    # run
    with torch.no_grad():
        for bidx, data in enumerate(loader_test):
            fn = data['name'][0]
            print('--------')
            print('{}/{} - {}'.format(bidx, num_batches, fn))

            # unpack data
            for k in data.keys():
                if not k.startswith('name'):
                    data[k] = data[k].to(args.device)
            print('>>', data['name'][0])

            # forward
            st_time = time.time()
            mel = model(
                data['units'],
                data['f0'],
                data['volume'],
                data['spk_id'],
                gt_spec=None,
                infer=True,
                infer_speedup=args.infer.speedup,
                method=args.infer.method)
            signal = vocoder.infer(mel, data['f0'])
            ed_time = time.time()

            # RTF
            run_time = ed_time - st_time
            song_time = signal.shape[-1] / args.data.sampling_rate
            rtf = run_time / song_time
            print('RTF: {} | {} / {}'.format(rtf, run_time, song_time))
            rtf_all.append(rtf)

            # loss
            for i in range(args.train.batch_size):
                loss = model(
                    data['units'],
                    data['f0'],
                    data['volume'],
                    data['spk_id'],
                    gt_spec=data['mel'],
                    infer=False)
                test_loss += loss.item()

            # log mel
            saver.log_spec(data['name'][0], data['mel'], mel)

            # log audio
            path_audio = os.path.join(args.data.valid_path, 'audio', data['name_ext'][0])
            audio, sr = librosa.load(path_audio, sr=args.data.sampling_rate)
            if len(audio.shape) > 1:
                audio = librosa.to_mono(audio)
            audio = torch.from_numpy(audio).unsqueeze(0).to(signal)
            saver.log_audio({fn + '/gt.wav': audio, fn + '/pred.wav': signal})

    # report
    test_loss /= args.train.batch_size
    test_loss /= num_batches

    # check
    print(' [test_loss] test_loss:', test_loss)
    print(' Real Time Factor', np.mean(rtf_all))
    return test_loss
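
# Note (added for clarity): RTF (real-time factor) above is synthesis time divided by
# rendered audio duration, so RTF < 1 means faster than real time. For example
# (hypothetical numbers), rendering a 10 s clip in 2.5 s gives RTF = 2.5 / 10 = 0.25.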

def train(args, initial_global_step, model, optimizer, scheduler, vocoder, loader_train, loader_test):
    # saver
    saver = Saver(args, initial_global_step=initial_global_step)

    # model size
    params_count = utils.get_network_paras_amount({'model': model})
    saver.log_info('--- model size ---')
    saver.log_info(params_count)

    # run
    num_batches = len(loader_train)
    model.train()
    saver.log_info('======= start training =======')
    scaler = GradScaler()
    if args.train.amp_dtype == 'fp32':
        dtype = torch.float32
    elif args.train.amp_dtype == 'fp16':
        dtype = torch.float16
    elif args.train.amp_dtype == 'bf16':
        dtype = torch.bfloat16
    else:
        raise ValueError(' [x] Unknown amp_dtype: ' + args.train.amp_dtype)
    for epoch in range(args.train.epochs):
        for batch_idx, data in enumerate(loader_train):
            saver.global_step_increment()
            optimizer.zero_grad()

            # unpack data
            for k in data.keys():
                if not k.startswith('name'):
                    data[k] = data[k].to(args.device)

            # forward
            if dtype == torch.float32:
                loss = model(data['units'].float(), data['f0'], data['volume'], data['spk_id'],
                             aug_shift=data['aug_shift'], gt_spec=data['mel'].float(), infer=False)
            else:
                with autocast(device_type=args.device, dtype=dtype):
                    loss = model(data['units'], data['f0'], data['volume'], data['spk_id'],
                                 aug_shift=data['aug_shift'], gt_spec=data['mel'], infer=False)

            # handle nan loss
            if torch.isnan(loss):
                raise ValueError(' [x] nan loss ')
            else:
                # backpropagate
                if dtype == torch.float32:
                    loss.backward()
                    optimizer.step()
                else:
                    scaler.scale(loss).backward()
                    scaler.step(optimizer)
                    scaler.update()
                scheduler.step()

            # log loss
            if saver.global_step % args.train.interval_log == 0:
                current_lr = optimizer.param_groups[0]['lr']
                saver.log_info(
                    'epoch: {} | {:3d}/{:3d} | {} | batch/s: {:.2f} | lr: {:.6} | loss: {:.3f} | time: {} | step: {}'.format(
                        epoch,
                        batch_idx,
                        num_batches,
                        args.env.expdir,
                        args.train.interval_log / saver.get_interval_time(),
                        current_lr,
                        loss.item(),
                        saver.get_total_time(),
                        saver.global_step
                    )
                )

                saver.log_value({
                    'train/loss': loss.item()
                })

                saver.log_value({
                    'train/lr': current_lr
                })

            # validation
            if saver.global_step % args.train.interval_val == 0:
                optimizer_save = optimizer if args.train.save_opt else None

                # save latest
                saver.save_model(model, optimizer_save, postfix=f'{saver.global_step}')
                last_val_step = saver.global_step - args.train.interval_val
                if last_val_step % args.train.interval_force_save != 0:
                    saver.delete_model(postfix=f'{last_val_step}')

                # run testing set
                test_loss = test(args, model, vocoder, loader_test, saver)

                # log loss
                saver.log_info(
                    ' --- <validation> --- \nloss: {:.3f}. '.format(
                        test_loss,
                    )
                )

                saver.log_value({
                    'validation/loss': test_loss
                })

                model.train()
731 server/voice_changer/DDSP_SVC/models/diffusion/uni_pc.py Normal file
@@ -0,0 +1,731 @@
import torch
import torch.nn.functional as F
import math


class NoiseScheduleVP:
    def __init__(
            self,
            schedule='discrete',
            betas=None,
            alphas_cumprod=None,
            continuous_beta_0=0.1,
            continuous_beta_1=20.,
            dtype=torch.float32,
    ):
        """Create a wrapper class for the forward SDE (VP type).

        ***
        Update: We support discrete-time diffusion models by implementing a piecewise linear interpolation for log_alpha_t.
        We recommend using schedule='discrete' for discrete-time diffusion models, especially for high-resolution images.
        ***

        The forward SDE ensures that the conditional distribution q_{t|0}(x_t | x_0) = N ( alpha_t * x_0, sigma_t^2 * I ).
        We further define lambda_t = log(alpha_t) - log(sigma_t), which is the half-logSNR (described in the DPM-Solver paper).
        Therefore, we implement the functions for computing alpha_t, sigma_t and lambda_t. For t in [0, T], we have:

            log_alpha_t = self.marginal_log_mean_coeff(t)
            sigma_t = self.marginal_std(t)
            lambda_t = self.marginal_lambda(t)

        Moreover, as lambda(t) is an invertible function, we also support its inverse function:

            t = self.inverse_lambda(lambda_t)

        ===============================================================

        We support both discrete-time DPMs (trained on n = 0, 1, ..., N-1) and continuous-time DPMs (trained on t in [t_0, T]).

        1. For discrete-time DPMs:

            For discrete-time DPMs trained on n = 0, 1, ..., N-1, we convert the discrete steps to continuous time steps by:
                t_i = (i + 1) / N
            e.g. for N = 1000, we have t_0 = 1e-3 and T = t_{N-1} = 1.
            We solve the corresponding diffusion ODE from time T = 1 to time t_0 = 1e-3.

            Args:
                betas: A `torch.Tensor`. The beta array for the discrete-time DPM. (See the original DDPM paper for details)
                alphas_cumprod: A `torch.Tensor`. The cumprod alphas for the discrete-time DPM. (See the original DDPM paper for details)

            Note that we always have alphas_cumprod = cumprod(1 - betas). Therefore, we only need to set one of `betas` and `alphas_cumprod`.

            **Important**: Please pay special attention to the args for `alphas_cumprod`:
                The `alphas_cumprod` is the \hat{alpha_n} arrays in the notations of DDPM. Specifically, DDPMs assume that
                    q_{t_n | 0}(x_{t_n} | x_0) = N ( \sqrt{\hat{alpha_n}} * x_0, (1 - \hat{alpha_n}) * I ).
                Therefore, the notation \hat{alpha_n} is different from the notation alpha_t in DPM-Solver. In fact, we have
                    alpha_{t_n} = \sqrt{\hat{alpha_n}},
                and
                    log(alpha_{t_n}) = 0.5 * log(\hat{alpha_n}).

        2. For continuous-time DPMs:

            We support two types of VPSDEs: linear (DDPM) and cosine (improved-DDPM). The hyperparameters for the noise
            schedule are the default settings in DDPM and improved-DDPM:

            Args:
                beta_min: A `float` number. The smallest beta for the linear schedule.
                beta_max: A `float` number. The largest beta for the linear schedule.
                cosine_s: A `float` number. The hyperparameter in the cosine schedule.
                cosine_beta_max: A `float` number. The hyperparameter in the cosine schedule.
                T: A `float` number. The ending time of the forward process.

        ===============================================================

        Args:
            schedule: A `str`. The noise schedule of the forward SDE. 'discrete' for discrete-time DPMs,
                    'linear' or 'cosine' for continuous-time DPMs.
        Returns:
            A wrapper object of the forward SDE (VP type).

        ===============================================================

        Example:
            # For discrete-time DPMs, given betas (the beta array for n = 0, 1, ..., N - 1):
            >>> ns = NoiseScheduleVP('discrete', betas=betas)
            # For discrete-time DPMs, given alphas_cumprod (the \hat{alpha_n} array for n = 0, 1, ..., N - 1):
            >>> ns = NoiseScheduleVP('discrete', alphas_cumprod=alphas_cumprod)
            # For continuous-time DPMs (VPSDE), linear schedule:
            >>> ns = NoiseScheduleVP('linear', continuous_beta_0=0.1, continuous_beta_1=20.)
        """

        if schedule not in ['discrete', 'linear', 'cosine']:
            raise ValueError("Unsupported noise schedule {}. The schedule needs to be 'discrete' or 'linear' or 'cosine'".format(schedule))

        self.schedule = schedule
        if schedule == 'discrete':
            if betas is not None:
                log_alphas = 0.5 * torch.log(1 - betas).cumsum(dim=0)
            else:
                assert alphas_cumprod is not None
                log_alphas = 0.5 * torch.log(alphas_cumprod)
            self.total_N = len(log_alphas)
            self.T = 1.
            self.t_array = torch.linspace(0., 1., self.total_N + 1)[1:].reshape((1, -1)).to(dtype=dtype)
            self.log_alpha_array = log_alphas.reshape((1, -1,)).to(dtype=dtype)
        else:
            self.total_N = 1000
            self.beta_0 = continuous_beta_0
            self.beta_1 = continuous_beta_1
            self.cosine_s = 0.008
            self.cosine_beta_max = 999.
            self.cosine_t_max = math.atan(self.cosine_beta_max * (1. + self.cosine_s) / math.pi) * 2. * (1. + self.cosine_s) / math.pi - self.cosine_s
            self.cosine_log_alpha_0 = math.log(math.cos(self.cosine_s / (1. + self.cosine_s) * math.pi / 2.))
            self.schedule = schedule
            if schedule == 'cosine':
                # For the cosine schedule, T = 1 will have numerical issues. So we manually set the ending time T.
                # Note that T = 0.9946 may be not the optimal setting. However, we find it works well.
                self.T = 0.9946
            else:
                self.T = 1.

    def marginal_log_mean_coeff(self, t):
        """
        Compute log(alpha_t) of a given continuous-time label t in [0, T].
        """
        if self.schedule == 'discrete':
            return interpolate_fn(t.reshape((-1, 1)), self.t_array.to(t.device), self.log_alpha_array.to(t.device)).reshape((-1))
        elif self.schedule == 'linear':
            return -0.25 * t ** 2 * (self.beta_1 - self.beta_0) - 0.5 * t * self.beta_0
        elif self.schedule == 'cosine':
            log_alpha_fn = lambda s: torch.log(torch.cos((s + self.cosine_s) / (1. + self.cosine_s) * math.pi / 2.))
            log_alpha_t = log_alpha_fn(t) - self.cosine_log_alpha_0
            return log_alpha_t

    def marginal_alpha(self, t):
        """
        Compute alpha_t of a given continuous-time label t in [0, T].
        """
        return torch.exp(self.marginal_log_mean_coeff(t))

    def marginal_std(self, t):
        """
        Compute sigma_t of a given continuous-time label t in [0, T].
        """
        return torch.sqrt(1. - torch.exp(2. * self.marginal_log_mean_coeff(t)))

    def marginal_lambda(self, t):
        """
        Compute lambda_t = log(alpha_t) - log(sigma_t) of a given continuous-time label t in [0, T].
        """
        log_mean_coeff = self.marginal_log_mean_coeff(t)
        log_std = 0.5 * torch.log(1. - torch.exp(2. * log_mean_coeff))
        return log_mean_coeff - log_std

    def inverse_lambda(self, lamb):
        """
        Compute the continuous-time label t in [0, T] of a given half-logSNR lambda_t.
        """
        if self.schedule == 'linear':
            tmp = 2. * (self.beta_1 - self.beta_0) * torch.logaddexp(-2. * lamb, torch.zeros((1,)).to(lamb))
            Delta = self.beta_0**2 + tmp
            return tmp / (torch.sqrt(Delta) + self.beta_0) / (self.beta_1 - self.beta_0)
        elif self.schedule == 'discrete':
            log_alpha = -0.5 * torch.logaddexp(torch.zeros((1,)).to(lamb.device), -2. * lamb)
            t = interpolate_fn(log_alpha.reshape((-1, 1)), torch.flip(self.log_alpha_array.to(lamb.device), [1]), torch.flip(self.t_array.to(lamb.device), [1]))
            return t.reshape((-1,))
        else:
            log_alpha = -0.5 * torch.logaddexp(-2. * lamb, torch.zeros((1,)).to(lamb))
            t_fn = lambda log_alpha_t: torch.arccos(torch.exp(log_alpha_t + self.cosine_log_alpha_0)) * 2. * (1. + self.cosine_s) / math.pi - self.cosine_s
            t = t_fn(log_alpha)
            return t
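
# Note (added for clarity): since alpha_t^2 + sigma_t^2 = 1 for the VP SDE, the
# half-logSNR satisfies lambda_t = log(alpha_t) - log(sigma_t) = 0.5 * logSNR(t).
# A minimal round-trip sketch (illustrative only; `betas` is any valid discrete beta
# array, and the round trip is only approximate due to piecewise-linear interpolation):
#   ns = NoiseScheduleVP('discrete', betas=betas)
#   t = torch.tensor([0.5])
#   assert torch.allclose(ns.inverse_lambda(ns.marginal_lambda(t)), t, atol=1e-3)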


def model_wrapper(
        model,
        noise_schedule,
        model_type="noise",
        model_kwargs={},
        guidance_type="uncond",
        condition=None,
        unconditional_condition=None,
        guidance_scale=1.,
        classifier_fn=None,
        classifier_kwargs={},
):
    """Create a wrapper function for the noise prediction model.
    """

    def get_model_input_time(t_continuous):
        """
        Convert the continuous-time `t_continuous` (in [epsilon, T]) to the model input time.
        For discrete-time DPMs, we convert `t_continuous` in [1 / N, 1] to `t_input` in [0, 1000 * (N - 1) / N].
        For continuous-time DPMs, we just use `t_continuous`.
        """
        if noise_schedule.schedule == 'discrete':
            return (t_continuous - 1. / noise_schedule.total_N) * noise_schedule.total_N
        else:
            return t_continuous

    def noise_pred_fn(x, t_continuous, cond=None):
        t_input = get_model_input_time(t_continuous)
        if cond is None:
            output = model(x, t_input, **model_kwargs)
        else:
            output = model(x, t_input, cond, **model_kwargs)
        if model_type == "noise":
            return output
        elif model_type == "x_start":
            alpha_t, sigma_t = noise_schedule.marginal_alpha(t_continuous), noise_schedule.marginal_std(t_continuous)
            return (x - alpha_t * output) / sigma_t
        elif model_type == "v":
            alpha_t, sigma_t = noise_schedule.marginal_alpha(t_continuous), noise_schedule.marginal_std(t_continuous)
            return alpha_t * output + sigma_t * x
        elif model_type == "score":
            sigma_t = noise_schedule.marginal_std(t_continuous)
            return -sigma_t * output

    def cond_grad_fn(x, t_input):
        """
        Compute the gradient of the classifier, i.e. nabla_{x} log p_t(cond | x_t).
        """
        with torch.enable_grad():
            x_in = x.detach().requires_grad_(True)
            log_prob = classifier_fn(x_in, t_input, condition, **classifier_kwargs)
            return torch.autograd.grad(log_prob.sum(), x_in)[0]

    def model_fn(x, t_continuous):
        """
        The noise prediction model function that is used for DPM-Solver.
        """
        if guidance_type == "uncond":
            return noise_pred_fn(x, t_continuous)
        elif guidance_type == "classifier":
            assert classifier_fn is not None
            t_input = get_model_input_time(t_continuous)
            cond_grad = cond_grad_fn(x, t_input)
            sigma_t = noise_schedule.marginal_std(t_continuous)
            noise = noise_pred_fn(x, t_continuous)
            return noise - guidance_scale * sigma_t * cond_grad
        elif guidance_type == "classifier-free":
            if guidance_scale == 1. or unconditional_condition is None:
                return noise_pred_fn(x, t_continuous, cond=condition)
            else:
                x_in = torch.cat([x] * 2)
                t_in = torch.cat([t_continuous] * 2)
                c_in = torch.cat([unconditional_condition, condition])
                noise_uncond, noise = noise_pred_fn(x_in, t_in, cond=c_in).chunk(2)
                return noise_uncond + guidance_scale * (noise - noise_uncond)

    assert model_type in ["noise", "x_start", "v"]
    assert guidance_type in ["uncond", "classifier", "classifier-free"]
    return model_fn
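
# Note (added for clarity): a minimal usage sketch (illustrative only; `my_denoiser` is a
# hypothetical noise-prediction network taking (x, t_input, cond)):
#   ns = NoiseScheduleVP('discrete', betas=betas)
#   model_fn = model_wrapper(my_denoiser, ns, model_type="noise", model_kwargs={"cond": cond})
#   solver = UniPC(model_fn, ns, algorithm_type="data_prediction", variant='bh1')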


class UniPC:
    def __init__(
            self,
            model_fn,
            noise_schedule,
            algorithm_type="data_prediction",
            correcting_x0_fn=None,
            correcting_xt_fn=None,
            thresholding_max_val=1.,
            dynamic_thresholding_ratio=0.995,
            variant='bh1'
    ):
        """Construct a UniPC.

        We support both data_prediction and noise_prediction.
        """
        self.model = lambda x, t: model_fn(x, t.expand((x.shape[0])))
        self.noise_schedule = noise_schedule
        assert algorithm_type in ["data_prediction", "noise_prediction"]

        if correcting_x0_fn == "dynamic_thresholding":
            self.correcting_x0_fn = self.dynamic_thresholding_fn
        else:
            self.correcting_x0_fn = correcting_x0_fn

        self.correcting_xt_fn = correcting_xt_fn
        self.dynamic_thresholding_ratio = dynamic_thresholding_ratio
        self.thresholding_max_val = thresholding_max_val

        self.variant = variant
        self.predict_x0 = algorithm_type == "data_prediction"

    def dynamic_thresholding_fn(self, x0, t=None):
        """
        The dynamic thresholding method.
        """
        dims = x0.dim()
        p = self.dynamic_thresholding_ratio
        s = torch.quantile(torch.abs(x0).reshape((x0.shape[0], -1)), p, dim=1)
        s = expand_dims(torch.maximum(s, self.thresholding_max_val * torch.ones_like(s).to(s.device)), dims)
        x0 = torch.clamp(x0, -s, s) / s
        return x0

    def noise_prediction_fn(self, x, t):
        """
        Return the noise prediction model.
        """
        return self.model(x, t)

    def data_prediction_fn(self, x, t):
        """
        Return the data prediction model (with corrector).
        """
        noise = self.noise_prediction_fn(x, t)
        alpha_t, sigma_t = self.noise_schedule.marginal_alpha(t), self.noise_schedule.marginal_std(t)
        x0 = (x - sigma_t * noise) / alpha_t
        if self.correcting_x0_fn is not None:
            x0 = self.correcting_x0_fn(x0)
        return x0

    def model_fn(self, x, t):
        """
        Convert the model to the noise prediction model or the data prediction model.
        """
        if self.predict_x0:
            return self.data_prediction_fn(x, t)
        else:
            return self.noise_prediction_fn(x, t)

    def get_time_steps(self, skip_type, t_T, t_0, N, device):
        """Compute the intermediate time steps for sampling.
        """
        if skip_type == 'logSNR':
            lambda_T = self.noise_schedule.marginal_lambda(torch.tensor(t_T).to(device))
            lambda_0 = self.noise_schedule.marginal_lambda(torch.tensor(t_0).to(device))
            logSNR_steps = torch.linspace(lambda_T.cpu().item(), lambda_0.cpu().item(), N + 1).to(device)
            return self.noise_schedule.inverse_lambda(logSNR_steps)
        elif skip_type == 'time_uniform':
            return torch.linspace(t_T, t_0, N + 1).to(device)
        elif skip_type == 'time_quadratic':
            t_order = 2
            t = torch.linspace(t_T**(1. / t_order), t_0**(1. / t_order), N + 1).pow(t_order).to(device)
            return t
        else:
            raise ValueError("Unsupported skip_type {}, need to be 'logSNR' or 'time_uniform' or 'time_quadratic'".format(skip_type))

    def get_orders_and_timesteps_for_singlestep_solver(self, steps, order, skip_type, t_T, t_0, device):
        """
        Get the order of each step for sampling by the singlestep DPM-Solver.
        """
        if order == 3:
            K = steps // 3 + 1
            if steps % 3 == 0:
                orders = [3,] * (K - 2) + [2, 1]
            elif steps % 3 == 1:
                orders = [3,] * (K - 1) + [1]
            else:
                orders = [3,] * (K - 1) + [2]
        elif order == 2:
            if steps % 2 == 0:
                K = steps // 2
                orders = [2,] * K
            else:
                K = steps // 2 + 1
                orders = [2,] * (K - 1) + [1]
        elif order == 1:
            K = steps
            orders = [1,] * steps
        else:
            raise ValueError("'order' must be '1' or '2' or '3'.")
        if skip_type == 'logSNR':
            # To reproduce the results in DPM-Solver paper
            timesteps_outer = self.get_time_steps(skip_type, t_T, t_0, K, device)
        else:
            timesteps_outer = self.get_time_steps(skip_type, t_T, t_0, steps, device)[torch.cumsum(torch.tensor([0,] + orders), 0).to(device)]
        return timesteps_outer, orders
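
    # Note (added for clarity): a worked example of the order scheduling above
    # (illustrative numbers): steps=10, order=3 gives K = 10 // 3 + 1 = 4 and, since
    # 10 % 3 == 1, orders = [3, 3, 3, 1], which sums back to the requested 10 steps.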
|
||||||
|
    def denoise_to_zero_fn(self, x, s):
        """
        Denoise at the final step, which is equivalent to solve the ODE from lambda_s to infty by first-order discretization.
        """
        return self.data_prediction_fn(x, s)

    def multistep_uni_pc_update(self, x, model_prev_list, t_prev_list, t, order, **kwargs):
        if len(t.shape) == 0:
            t = t.view(-1)
        if 'bh' in self.variant:
            return self.multistep_uni_pc_bh_update(x, model_prev_list, t_prev_list, t, order, **kwargs)
        else:
            assert self.variant == 'vary_coeff'
            return self.multistep_uni_pc_vary_update(x, model_prev_list, t_prev_list, t, order, **kwargs)

    def multistep_uni_pc_vary_update(self, x, model_prev_list, t_prev_list, t, order, use_corrector=True):
        # print(f'using unified predictor-corrector with order {order} (solver type: vary coeff)')
        ns = self.noise_schedule
        assert order <= len(model_prev_list)

        # first compute rks
        t_prev_0 = t_prev_list[-1]
        lambda_prev_0 = ns.marginal_lambda(t_prev_0)
        lambda_t = ns.marginal_lambda(t)
        model_prev_0 = model_prev_list[-1]
        sigma_prev_0, sigma_t = ns.marginal_std(t_prev_0), ns.marginal_std(t)
        log_alpha_t = ns.marginal_log_mean_coeff(t)
        alpha_t = torch.exp(log_alpha_t)

        h = lambda_t - lambda_prev_0

        rks = []
        D1s = []
        for i in range(1, order):
            t_prev_i = t_prev_list[-(i + 1)]
            model_prev_i = model_prev_list[-(i + 1)]
            lambda_prev_i = ns.marginal_lambda(t_prev_i)
            rk = (lambda_prev_i - lambda_prev_0) / h
            rks.append(rk)
            D1s.append((model_prev_i - model_prev_0) / rk)

        rks.append(1.)
        rks = torch.tensor(rks, device=x.device)

        K = len(rks)
        # build C matrix
        C = []

        col = torch.ones_like(rks)
        for k in range(1, K + 1):
            C.append(col)
            col = col * rks / (k + 1)
        C = torch.stack(C, dim=1)

        if len(D1s) > 0:
            D1s = torch.stack(D1s, dim=1)  # (B, K)
            C_inv_p = torch.linalg.inv(C[:-1, :-1])
            A_p = C_inv_p

        if use_corrector:
            # print('using corrector')
            C_inv = torch.linalg.inv(C)
            A_c = C_inv

        hh = -h if self.predict_x0 else h
        h_phi_1 = torch.expm1(hh)
        h_phi_ks = []
        factorial_k = 1
        h_phi_k = h_phi_1
        for k in range(1, K + 2):
            h_phi_ks.append(h_phi_k)
            h_phi_k = h_phi_k / hh - 1 / factorial_k
            factorial_k *= (k + 1)

        model_t = None
        if self.predict_x0:
            x_t_ = (
                sigma_t / sigma_prev_0 * x
                - alpha_t * h_phi_1 * model_prev_0
            )
            # now predictor
            x_t = x_t_
            if len(D1s) > 0:
                # compute the residuals for predictor
                for k in range(K - 1):
                    x_t = x_t - alpha_t * h_phi_ks[k + 1] * torch.einsum('bkchw,k->bchw', D1s, A_p[k])
            # now corrector
            if use_corrector:
                model_t = self.model_fn(x_t, t)
                D1_t = (model_t - model_prev_0)
                x_t = x_t_
                k = 0
                for k in range(K - 1):
                    x_t = x_t - alpha_t * h_phi_ks[k + 1] * torch.einsum('bkchw,k->bchw', D1s, A_c[k][:-1])
                x_t = x_t - alpha_t * h_phi_ks[K] * (D1_t * A_c[k][-1])
        else:
            log_alpha_prev_0, log_alpha_t = ns.marginal_log_mean_coeff(t_prev_0), ns.marginal_log_mean_coeff(t)
            x_t_ = (
                (torch.exp(log_alpha_t - log_alpha_prev_0)) * x
                - (sigma_t * h_phi_1) * model_prev_0
            )
            # now predictor
            x_t = x_t_
            if len(D1s) > 0:
                # compute the residuals for predictor
                for k in range(K - 1):
                    x_t = x_t - sigma_t * h_phi_ks[k + 1] * torch.einsum('bkchw,k->bchw', D1s, A_p[k])
            # now corrector
            if use_corrector:
                model_t = self.model_fn(x_t, t)
                D1_t = (model_t - model_prev_0)
                x_t = x_t_
                k = 0
                for k in range(K - 1):
                    x_t = x_t - sigma_t * h_phi_ks[k + 1] * torch.einsum('bkchw,k->bchw', D1s, A_c[k][:-1])
                x_t = x_t - sigma_t * h_phi_ks[K] * (D1_t * A_c[k][-1])
        return x_t, model_t

    def multistep_uni_pc_bh_update(self, x, model_prev_list, t_prev_list, t, order, x_t=None, use_corrector=True):
        # print(f'using unified predictor-corrector with order {order} (solver type: B(h))')
        ns = self.noise_schedule
        assert order <= len(model_prev_list)

        # first compute rks
        t_prev_0 = t_prev_list[-1]
        lambda_prev_0 = ns.marginal_lambda(t_prev_0)
        lambda_t = ns.marginal_lambda(t)
        model_prev_0 = model_prev_list[-1]
        sigma_prev_0, sigma_t = ns.marginal_std(t_prev_0), ns.marginal_std(t)
        log_alpha_prev_0, log_alpha_t = ns.marginal_log_mean_coeff(t_prev_0), ns.marginal_log_mean_coeff(t)
        alpha_t = torch.exp(log_alpha_t)

        h = lambda_t - lambda_prev_0

        rks = []
        D1s = []
        for i in range(1, order):
            t_prev_i = t_prev_list[-(i + 1)]
            model_prev_i = model_prev_list[-(i + 1)]
            lambda_prev_i = ns.marginal_lambda(t_prev_i)
            rk = (lambda_prev_i - lambda_prev_0) / h
            rks.append(rk)
            D1s.append((model_prev_i - model_prev_0) / rk)

        rks.append(1.)
        rks = torch.tensor(rks, device=x.device)

        R = []
        b = []

        hh = -h if self.predict_x0 else h
        h_phi_1 = torch.expm1(hh)  # h\phi_1(h) = e^h - 1
        h_phi_k = h_phi_1 / hh - 1

        factorial_i = 1

        if self.variant == 'bh1':
            B_h = hh
        elif self.variant == 'bh2':
            B_h = torch.expm1(hh)
        else:
            raise NotImplementedError()

        for i in range(1, order + 1):
            R.append(torch.pow(rks, i - 1))
            b.append(h_phi_k * factorial_i / B_h)
            factorial_i *= (i + 1)
            h_phi_k = h_phi_k / hh - 1 / factorial_i

        R = torch.stack(R)
        b = torch.cat(b)

        # now predictor
        use_predictor = len(D1s) > 0 and x_t is None
        if len(D1s) > 0:
            D1s = torch.stack(D1s, dim=1)  # (B, K)
            if x_t is None:
                # for order 2, we use a simplified version
                if order == 2:
                    rhos_p = torch.tensor([0.5], device=b.device)
                else:
                    rhos_p = torch.linalg.solve(R[:-1, :-1], b[:-1])
        else:
            D1s = None

        if use_corrector:
            # print('using corrector')
            # for order 1, we use a simplified version
            if order == 1:
                rhos_c = torch.tensor([0.5], device=b.device)
            else:
                rhos_c = torch.linalg.solve(R, b)

        model_t = None
        if self.predict_x0:
            x_t_ = (
                sigma_t / sigma_prev_0 * x
                - alpha_t * h_phi_1 * model_prev_0
            )

            if x_t is None:
                if use_predictor:
                    pred_res = torch.einsum('k,bkchw->bchw', rhos_p, D1s)
                else:
                    pred_res = 0
                x_t = x_t_ - alpha_t * B_h * pred_res

            if use_corrector:
                model_t = self.model_fn(x_t, t)
                if D1s is not None:
                    corr_res = torch.einsum('k,bkchw->bchw', rhos_c[:-1], D1s)
                else:
                    corr_res = 0
                D1_t = (model_t - model_prev_0)
                x_t = x_t_ - alpha_t * B_h * (corr_res + rhos_c[-1] * D1_t)
        else:
            x_t_ = (
                torch.exp(log_alpha_t - log_alpha_prev_0) * x
                - sigma_t * h_phi_1 * model_prev_0
            )
            if x_t is None:
                if use_predictor:
                    pred_res = torch.einsum('k,bkchw->bchw', rhos_p, D1s)
                else:
                    pred_res = 0
                x_t = x_t_ - sigma_t * B_h * pred_res

            if use_corrector:
                model_t = self.model_fn(x_t, t)
                if D1s is not None:
                    corr_res = torch.einsum('k,bkchw->bchw', rhos_c[:-1], D1s)
                else:
                    corr_res = 0
                D1_t = (model_t - model_prev_0)
                x_t = x_t_ - sigma_t * B_h * (corr_res + rhos_c[-1] * D1_t)
        return x_t, model_t
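For intuition about the small linear system solved above, here is a standalone sketch (values are illustrative, not taken from the diff): R stacks powers of the spacing ratios `rks`, so `torch.linalg.solve(R, b)` picks weights whose weighted powers of the ratios reproduce the phi-function right-hand side.

import torch

rks = torch.tensor([0.5, 1.0])         # hypothetical lambda-spacing ratios
b = torch.tensor([0.4, 0.1])           # hypothetical phi-based right-hand side
R = torch.stack([rks ** 0, rks ** 1])  # [[1.0, 1.0], [0.5, 1.0]]
rhos = torch.linalg.solve(R, b)
assert torch.allclose(R @ rhos, b)     # the solved weights satisfy the system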
    def sample(self, x, steps=20, t_start=None, t_end=None, order=2, skip_type='time_uniform',
               method='multistep', lower_order_final=True, denoise_to_zero=False, atol=0.0078, rtol=0.05, return_intermediate=False,
               ):
        """
        Compute the sample at time `t_end` by UniPC, given the initial `x` at time `t_start`.
        """
        t_0 = 1. / self.noise_schedule.total_N if t_end is None else t_end
        t_T = self.noise_schedule.T if t_start is None else t_start
        assert t_0 > 0 and t_T > 0, "Time range needs to be greater than 0. For discrete-time DPMs, it needs to be in [1 / N, 1], where N is the length of betas array"
        if return_intermediate:
            assert method in ['multistep', 'singlestep', 'singlestep_fixed'], "Cannot use adaptive solver when saving intermediate values"
        if self.correcting_xt_fn is not None:
            assert method in ['multistep', 'singlestep', 'singlestep_fixed'], "Cannot use adaptive solver when correcting_xt_fn is not None"
        device = x.device
        intermediates = []
        with torch.no_grad():
            if method == 'multistep':
                assert steps >= order
                timesteps = self.get_time_steps(skip_type=skip_type, t_T=t_T, t_0=t_0, N=steps, device=device)
                assert timesteps.shape[0] - 1 == steps
                # Init the initial values.
                step = 0
                t = timesteps[step]
                t_prev_list = [t]
                model_prev_list = [self.model_fn(x, t)]
                if self.correcting_xt_fn is not None:
                    x = self.correcting_xt_fn(x, t, step)
                if return_intermediate:
                    intermediates.append(x)

                # Init the first `order` values by lower order multistep UniPC.
                for step in range(1, order):
                    t = timesteps[step]
                    x, model_x = self.multistep_uni_pc_update(x, model_prev_list, t_prev_list, t, step, use_corrector=True)
                    if model_x is None:
                        model_x = self.model_fn(x, t)
                    if self.correcting_xt_fn is not None:
                        x = self.correcting_xt_fn(x, t, step)
                    if return_intermediate:
                        intermediates.append(x)
                    t_prev_list.append(t)
                    model_prev_list.append(model_x)

                # Compute the remaining values by `order`-th order multistep DPM-Solver.
                for step in range(order, steps + 1):
                    t = timesteps[step]
                    if lower_order_final:
                        step_order = min(order, steps + 1 - step)
                    else:
                        step_order = order
                    if step == steps:
                        # print('do not run corrector at the last step')
                        use_corrector = False
                    else:
                        use_corrector = True
                    x, model_x = self.multistep_uni_pc_update(x, model_prev_list, t_prev_list, t, step_order, use_corrector=use_corrector)
                    if self.correcting_xt_fn is not None:
                        x = self.correcting_xt_fn(x, t, step)
                    if return_intermediate:
                        intermediates.append(x)
                    for i in range(order - 1):
                        t_prev_list[i] = t_prev_list[i + 1]
                        model_prev_list[i] = model_prev_list[i + 1]
                    t_prev_list[-1] = t
                    # We do not need to evaluate the final model value.
                    if step < steps:
                        if model_x is None:
                            model_x = self.model_fn(x, t)
                        model_prev_list[-1] = model_x
            else:
                raise ValueError("Got wrong method {}".format(method))

            if denoise_to_zero:
                t = torch.ones((1,)).to(device) * t_0
                x = self.denoise_to_zero_fn(x, t)
                if self.correcting_xt_fn is not None:
                    x = self.correcting_xt_fn(x, t, step + 1)
                if return_intermediate:
                    intermediates.append(x)
        if return_intermediate:
            return x, intermediates
        else:
            return x
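A minimal usage sketch of the multistep loop above (assumptions: `uni_pc` is a constructed instance of this solver and `x_T` is Gaussian noise shaped like the target; neither name appears in the diff):

import torch

x_T = torch.randn(1, 128, 400)     # hypothetical (B, mel_bins, n_frames) latent
x_0 = uni_pc.sample(
    x_T,
    steps=20,                      # total number of model evaluations
    order=2,                       # multistep order; lowered near the final steps
    skip_type='time_uniform',      # timestep spacing accepted by get_time_steps
    method='multistep',            # the only branch implemented above
)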
#############################################################
# other utility functions
#############################################################


def interpolate_fn(x, xp, yp):
    """
    A piecewise linear function y = f(x), using xp and yp as keypoints.
    We implement f(x) in a differentiable way (i.e. applicable for autograd).
    The function f(x) is well-defined for all x-axis. (For x beyond the bounds of xp, we use the outmost points of xp to define the linear function.)

    Args:
        x: PyTorch tensor with shape [N, C], where N is the batch size, C is the number of channels (we use C = 1 for DPM-Solver).
        xp: PyTorch tensor with shape [C, K], where K is the number of keypoints.
        yp: PyTorch tensor with shape [C, K].
    Returns:
        The function values f(x), with shape [N, C].
    """
    N, K = x.shape[0], xp.shape[1]
    all_x = torch.cat([x.unsqueeze(2), xp.unsqueeze(0).repeat((N, 1, 1))], dim=2)
    sorted_all_x, x_indices = torch.sort(all_x, dim=2)
    x_idx = torch.argmin(x_indices, dim=2)
    cand_start_idx = x_idx - 1
    start_idx = torch.where(
        torch.eq(x_idx, 0),
        torch.tensor(1, device=x.device),
        torch.where(
            torch.eq(x_idx, K), torch.tensor(K - 2, device=x.device), cand_start_idx,
        ),
    )
    end_idx = torch.where(torch.eq(start_idx, cand_start_idx), start_idx + 2, start_idx + 1)
    start_x = torch.gather(sorted_all_x, dim=2, index=start_idx.unsqueeze(2)).squeeze(2)
    end_x = torch.gather(sorted_all_x, dim=2, index=end_idx.unsqueeze(2)).squeeze(2)
    start_idx2 = torch.where(
        torch.eq(x_idx, 0),
        torch.tensor(0, device=x.device),
        torch.where(
            torch.eq(x_idx, K), torch.tensor(K - 2, device=x.device), cand_start_idx,
        ),
    )
    y_positions_expanded = yp.unsqueeze(0).expand(N, -1, -1)
    start_y = torch.gather(y_positions_expanded, dim=2, index=start_idx2.unsqueeze(2)).squeeze(2)
    end_y = torch.gather(y_positions_expanded, dim=2, index=(start_idx2 + 1).unsqueeze(2)).squeeze(2)
    cand = start_y + (x - start_x) * (end_y - start_y) / (end_x - start_x)
    return cand
def expand_dims(v, dims):
    """
    Expand the tensor `v` to the dim `dims`.

    Args:
        `v`: a PyTorch tensor with shape [N].
        `dims`: an `int`.
    Returns:
        a PyTorch tensor with shape [N, 1, 1, ..., 1] and the total dimension is `dims`.
    """
    return v[(...,) + (None,) * (dims - 1)]
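A quick standalone check of the two utilities above (values are illustrative only):

import torch

xp = torch.tensor([[0.0, 1.0, 2.0]])    # [C=1, K=3] keypoint x-coordinates
yp = torch.tensor([[0.0, 10.0, 20.0]])  # [C=1, K=3] keypoint y-coordinates
x = torch.tensor([[0.5], [1.5]])        # [N=2, C=1] query points
y = interpolate_fn(x, xp, yp)           # tensor([[ 5.], [15.]])

v = torch.tensor([1.0, 2.0])
assert expand_dims(v, 4).shape == (2, 1, 1, 1)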
77  server/voice_changer/DDSP_SVC/models/diffusion/unit2mel.py  Normal file
@ -0,0 +1,77 @@
import os
import yaml
import torch
import torch.nn as nn
import numpy as np
from .diffusion import GaussianDiffusion
from .wavenet import WaveNet
from .vocoder import Vocoder


class DotDict(dict):
    def __getattr__(*args):
        val = dict.get(*args)
        return DotDict(val) if type(val) is dict else val

    __setattr__ = dict.__setitem__
    __delattr__ = dict.__delitem__


def load_model_vocoder(model_path, device="cpu"):
    config_file = os.path.join(os.path.split(model_path)[0], "config.yaml")
    with open(config_file, "r") as config:
        args = yaml.safe_load(config)
    args = DotDict(args)

    # load vocoder
    vocoder = Vocoder(args.vocoder.type, args.vocoder.ckpt, device=device)

    # load model
    model = Unit2Mel(args.data.encoder_out_channels, args.model.n_spk, args.model.use_pitch_aug, vocoder.dimension, args.model.n_layers, args.model.n_chans, args.model.n_hidden)

    print(" [Loading] " + model_path)
    ckpt = torch.load(model_path, map_location=torch.device(device))
    model.to(device)
    model.load_state_dict(ckpt["model"])
    model.eval()
    return model, vocoder, args


class Unit2Mel(nn.Module):
    def __init__(self, input_channel, n_spk, use_pitch_aug=False, out_dims=128, n_layers=20, n_chans=384, n_hidden=256):
        super().__init__()
        self.unit_embed = nn.Linear(input_channel, n_hidden)
        self.f0_embed = nn.Linear(1, n_hidden)
        self.volume_embed = nn.Linear(1, n_hidden)
        if use_pitch_aug:
            self.aug_shift_embed = nn.Linear(1, n_hidden, bias=False)
        else:
            self.aug_shift_embed = None
        self.n_spk = n_spk
        if n_spk is not None and n_spk > 1:
            self.spk_embed = nn.Embedding(n_spk, n_hidden)

        # diffusion
        self.decoder = GaussianDiffusion(WaveNet(out_dims, n_layers, n_chans, n_hidden), out_dims=out_dims)

    def forward(self, units, f0, volume, spk_id=None, spk_mix_dict=None, aug_shift=None, gt_spec=None, infer=True, infer_speedup=10, method="dpm-solver", k_step=300, use_tqdm=True):
        """
        input:
            B x n_frames x n_unit
        return:
            dict of B x n_frames x feat
        """

        x = self.unit_embed(units) + self.f0_embed((1 + f0 / 700).log()) + self.volume_embed(volume)
        if self.n_spk is not None and self.n_spk > 1:
            if spk_mix_dict is not None:
                for k, v in spk_mix_dict.items():
                    spk_id_torch = torch.LongTensor(np.array([[k]])).to(units.device)
                    x = x + v * self.spk_embed(spk_id_torch - 1)
            else:
                x = x + self.spk_embed(spk_id - 1)
        if self.aug_shift_embed is not None and aug_shift is not None:
            x = x + self.aug_shift_embed(aug_shift / 5)
        x = self.decoder(x, gt_spec=gt_spec, infer=infer, infer_speedup=infer_speedup, method=method, k_step=k_step, use_tqdm=use_tqdm)

        return x
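A hedged loading sketch for the file above (assumptions: a checkpoint at the illustrative path with its config.yaml next to it, exactly as load_model_vocoder expects, and a multi-speaker model so spk_id applies):

import torch

model, vocoder, args = load_model_vocoder("checkpoints/model.pt", device="cpu")
units = torch.randn(1, 100, args.data.encoder_out_channels)  # B, n_frames, n_unit
f0 = torch.full((1, 100, 1), 220.0)                          # B, n_frames, 1 (Hz)
volume = torch.ones(1, 100, 1)
mel = model(units, f0, volume, spk_id=torch.LongTensor([[1]]),
            infer=True, infer_speedup=10, k_step=300)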
87  server/voice_changer/DDSP_SVC/models/diffusion/vocoder.py  Normal file
@ -0,0 +1,87 @@
import torch
from torchaudio.transforms import Resample
from ..nsf_hifigan.nvSTFT import STFT
from ..nsf_hifigan.models import load_model, load_config


class Vocoder:
    def __init__(self, vocoder_type, vocoder_ckpt, device=None):
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = device

        if vocoder_type == "nsf-hifigan":
            self.vocoder = NsfHifiGAN(vocoder_ckpt, device=device)
        elif vocoder_type == "nsf-hifigan-log10":
            self.vocoder = NsfHifiGANLog10(vocoder_ckpt, device=device)
        else:
            raise ValueError(f" [x] Unknown vocoder: {vocoder_type}")

        self.resample_kernel = {}
        self.vocoder_sample_rate = self.vocoder.sample_rate()
        self.vocoder_hop_size = self.vocoder.hop_size()
        self.dimension = self.vocoder.dimension()

    def extract(self, audio, sample_rate, keyshift=0):
        # resample
        if sample_rate == self.vocoder_sample_rate:
            audio_res = audio
        else:
            key_str = str(sample_rate)
            if key_str not in self.resample_kernel:
                self.resample_kernel[key_str] = Resample(sample_rate, self.vocoder_sample_rate, lowpass_filter_width=128).to(self.device)
            audio_res = self.resample_kernel[key_str](audio)

        # extract
        mel = self.vocoder.extract(audio_res, keyshift=keyshift)  # B, n_frames, bins
        return mel

    def infer(self, mel, f0):
        f0 = f0[:, : mel.size(1), 0]  # B, n_frames
        audio = self.vocoder(mel, f0)
        return audio


class NsfHifiGAN(torch.nn.Module):
    def __init__(self, model_path, device=None):
        super().__init__()
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = device
        self.model_path = model_path
        self.model = None
        self.h = load_config(model_path)
        self.stft = STFT(self.h.sampling_rate, self.h.num_mels, self.h.n_fft, self.h.win_size, self.h.hop_size, self.h.fmin, self.h.fmax)

    def sample_rate(self):
        return self.h.sampling_rate

    def hop_size(self):
        return self.h.hop_size

    def dimension(self):
        return self.h.num_mels

    def extract(self, audio, keyshift=0):
        mel = self.stft.get_mel(audio, keyshift=keyshift).transpose(1, 2)  # B, n_frames, bins
        return mel

    def forward(self, mel, f0):
        if self.model is None:
            print("| Load HifiGAN: ", self.model_path)
            self.model, self.h = load_model(self.model_path, device=self.device)
        with torch.no_grad():
            c = mel.transpose(1, 2)
            audio = self.model(c, f0)
            return audio


class NsfHifiGANLog10(NsfHifiGAN):
    def forward(self, mel, f0):
        if self.model is None:
            print("| Load HifiGAN: ", self.model_path)
            self.model, self.h = load_model(self.model_path, device=self.device)
        with torch.no_grad():
            c = 0.434294 * mel.transpose(1, 2)
            audio = self.model(c, f0)
            return audio
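A roundtrip sketch for the Vocoder wrapper above (assumptions: an NSF-HiFiGAN checkpoint at the illustrative path and 44.1 kHz input audio):

import torch

voc = Vocoder("nsf-hifigan", "checkpoints/nsf_hifigan/model", device="cpu")
audio = torch.randn(1, 44100)                # B, T (one second of audio)
mel = voc.extract(audio, sample_rate=44100)  # B, n_frames, num_mels
f0 = torch.full((1, mel.size(1), 1), 220.0)  # B, n_frames, 1
wav = voc.infer(mel, f0)                     # waveform from the lazily loaded model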
91  server/voice_changer/DDSP_SVC/models/diffusion/wavenet.py  Normal file
@ -0,0 +1,91 @@
import math
from math import sqrt

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import Mish


class Conv1d(torch.nn.Conv1d):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        nn.init.kaiming_normal_(self.weight)


class SinusoidalPosEmb(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, x):
        device = x.device
        half_dim = self.dim // 2
        emb = math.log(10000) / (half_dim - 1)
        emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
        emb = x[:, None] * emb[None, :]
        emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
        return emb


class ResidualBlock(nn.Module):
    def __init__(self, encoder_hidden, residual_channels, dilation):
        super().__init__()
        self.residual_channels = residual_channels
        self.dilated_conv = nn.Conv1d(residual_channels, 2 * residual_channels, kernel_size=3, padding=dilation, dilation=dilation)
        self.diffusion_projection = nn.Linear(residual_channels, residual_channels)
        self.conditioner_projection = nn.Conv1d(encoder_hidden, 2 * residual_channels, 1)
        self.output_projection = nn.Conv1d(residual_channels, 2 * residual_channels, 1)

    def forward(self, x, conditioner, diffusion_step):
        diffusion_step = self.diffusion_projection(diffusion_step).unsqueeze(-1)
        conditioner = self.conditioner_projection(conditioner)
        y = x + diffusion_step

        y = self.dilated_conv(y) + conditioner

        # Using torch.split instead of torch.chunk to avoid using onnx::Slice
        gate, filter = torch.split(y, [self.residual_channels, self.residual_channels], dim=1)
        y = torch.sigmoid(gate) * torch.tanh(filter)

        y = self.output_projection(y)

        # Using torch.split instead of torch.chunk to avoid using onnx::Slice
        residual, skip = torch.split(y, [self.residual_channels, self.residual_channels], dim=1)
        return (x + residual) / math.sqrt(2.0), skip


class WaveNet(nn.Module):
    def __init__(self, in_dims=128, n_layers=20, n_chans=384, n_hidden=256):
        super().__init__()
        self.input_projection = Conv1d(in_dims, n_chans, 1)
        self.diffusion_embedding = SinusoidalPosEmb(n_chans)
        self.mlp = nn.Sequential(nn.Linear(n_chans, n_chans * 4), Mish(), nn.Linear(n_chans * 4, n_chans))
        self.residual_layers = nn.ModuleList([ResidualBlock(encoder_hidden=n_hidden, residual_channels=n_chans, dilation=1) for i in range(n_layers)])
        self.skip_projection = Conv1d(n_chans, n_chans, 1)
        self.output_projection = Conv1d(n_chans, in_dims, 1)
        nn.init.zeros_(self.output_projection.weight)

    def forward(self, spec, diffusion_step, cond):
        """
        :param spec: [B, 1, M, T]
        :param diffusion_step: [B, 1]
        :param cond: [B, n_hidden, T] (channel count must match `encoder_hidden` of the residual blocks)
        :return: [B, 1, M, T]
        """
        x = spec.squeeze(1)
        x = self.input_projection(x)  # [B, residual_channel, T]

        x = F.relu(x)
        diffusion_step = self.diffusion_embedding(diffusion_step)
        diffusion_step = self.mlp(diffusion_step)
        skip = []
        for layer in self.residual_layers:
            x, skip_connection = layer(x, cond, diffusion_step)
            skip.append(skip_connection)

        x = torch.sum(torch.stack(skip), dim=0) / sqrt(len(self.residual_layers))
        x = self.skip_projection(x)
        x = F.relu(x)
        x = self.output_projection(x)  # [B, mel_bins, T]
        return x[:, None, :, :]
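A shape sketch for the denoiser above (sizes are illustrative; the conditioning channel count must equal n_hidden):

import torch

net = WaveNet(in_dims=128, n_layers=20, n_chans=384, n_hidden=256)
spec = torch.randn(2, 1, 128, 100)           # [B, 1, M, T]
step = torch.randint(0, 1000, (2,)).float()  # one diffusion step per batch item
cond = torch.randn(2, 256, 100)              # [B, n_hidden, T]
assert net(spec, step, cond).shape == (2, 1, 128, 100)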
293  server/voice_changer/DDSP_SVC/models/encoder/hubert/model.py  Normal file
@ -0,0 +1,293 @@
import copy
from typing import Optional, Tuple
import random

from sklearn.cluster import KMeans

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.modules.utils import consume_prefix_in_state_dict_if_present

URLS = {
    "hubert-discrete": "https://github.com/bshall/hubert/releases/download/v0.1/hubert-discrete-e9416457.pt",
    "hubert-soft": "https://github.com/bshall/hubert/releases/download/v0.1/hubert-soft-0d54a1f4.pt",
    "kmeans100": "https://github.com/bshall/hubert/releases/download/v0.1/kmeans100-50f36a95.pt",
}


class Hubert(nn.Module):
    def __init__(self, num_label_embeddings: int = 100, mask: bool = True):
        super().__init__()
        self._mask = mask
        self.feature_extractor = FeatureExtractor()
        self.feature_projection = FeatureProjection()
        self.positional_embedding = PositionalConvEmbedding()
        self.norm = nn.LayerNorm(768)
        self.dropout = nn.Dropout(0.1)
        self.encoder = TransformerEncoder(
            nn.TransformerEncoderLayer(
                768, 12, 3072, activation="gelu", batch_first=True
            ),
            12,
        )
        self.proj = nn.Linear(768, 256)

        self.masked_spec_embed = nn.Parameter(torch.FloatTensor(768).uniform_())
        self.label_embedding = nn.Embedding(num_label_embeddings, 256)

    def mask(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        mask = None
        if self.training and self._mask:
            mask = _compute_mask((x.size(0), x.size(1)), 0.8, 10, x.device, 2)
            x[mask] = self.masked_spec_embed.to(x.dtype)
        return x, mask

    def encode(
        self, x: torch.Tensor, layer: Optional[int] = None
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        x = self.feature_extractor(x)
        x = self.feature_projection(x.transpose(1, 2))
        x, mask = self.mask(x)
        x = x + self.positional_embedding(x)
        x = self.dropout(self.norm(x))
        x = self.encoder(x, output_layer=layer)
        return x, mask

    def logits(self, x: torch.Tensor) -> torch.Tensor:
        logits = torch.cosine_similarity(
            x.unsqueeze(2),
            self.label_embedding.weight.unsqueeze(0).unsqueeze(0),
            dim=-1,
        )
        return logits / 0.1

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        x, mask = self.encode(x)
        x = self.proj(x)
        logits = self.logits(x)
        return logits, mask


class HubertSoft(Hubert):
    def __init__(self):
        super().__init__()

    @torch.inference_mode()
    def units(self, wav: torch.Tensor) -> torch.Tensor:
        wav = F.pad(wav, ((400 - 320) // 2, (400 - 320) // 2))
        x, _ = self.encode(wav)
        return self.proj(x)


class HubertDiscrete(Hubert):
    def __init__(self, kmeans):
        super().__init__(504)
        self.kmeans = kmeans

    @torch.inference_mode()
    def units(self, wav: torch.Tensor) -> torch.LongTensor:
        wav = F.pad(wav, ((400 - 320) // 2, (400 - 320) // 2))
        x, _ = self.encode(wav, layer=7)
        x = self.kmeans.predict(x.squeeze().cpu().numpy())
        return torch.tensor(x, dtype=torch.long, device=wav.device)


class FeatureExtractor(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv0 = nn.Conv1d(1, 512, 10, 5, bias=False)
        self.norm0 = nn.GroupNorm(512, 512)
        self.conv1 = nn.Conv1d(512, 512, 3, 2, bias=False)
        self.conv2 = nn.Conv1d(512, 512, 3, 2, bias=False)
        self.conv3 = nn.Conv1d(512, 512, 3, 2, bias=False)
        self.conv4 = nn.Conv1d(512, 512, 3, 2, bias=False)
        self.conv5 = nn.Conv1d(512, 512, 2, 2, bias=False)
        self.conv6 = nn.Conv1d(512, 512, 2, 2, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = F.gelu(self.norm0(self.conv0(x)))
        x = F.gelu(self.conv1(x))
        x = F.gelu(self.conv2(x))
        x = F.gelu(self.conv3(x))
        x = F.gelu(self.conv4(x))
        x = F.gelu(self.conv5(x))
        x = F.gelu(self.conv6(x))
        return x


class FeatureProjection(nn.Module):
    def __init__(self):
        super().__init__()
        self.norm = nn.LayerNorm(512)
        self.projection = nn.Linear(512, 768)
        self.dropout = nn.Dropout(0.1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.norm(x)
        x = self.projection(x)
        x = self.dropout(x)
        return x


class PositionalConvEmbedding(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv1d(
            768,
            768,
            kernel_size=128,
            padding=128 // 2,
            groups=16,
        )
        self.conv = nn.utils.weight_norm(self.conv, name="weight", dim=2)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.conv(x.transpose(1, 2))
        x = F.gelu(x[:, :, :-1])
        return x.transpose(1, 2)


class TransformerEncoder(nn.Module):
    def __init__(
        self, encoder_layer: nn.TransformerEncoderLayer, num_layers: int
    ) -> None:
        super(TransformerEncoder, self).__init__()
        self.layers = nn.ModuleList(
            [copy.deepcopy(encoder_layer) for _ in range(num_layers)]
        )
        self.num_layers = num_layers

    def forward(
        self,
        src: torch.Tensor,
        mask: torch.Tensor = None,
        src_key_padding_mask: torch.Tensor = None,
        output_layer: Optional[int] = None,
    ) -> torch.Tensor:
        output = src
        for layer in self.layers[:output_layer]:
            output = layer(
                output, src_mask=mask, src_key_padding_mask=src_key_padding_mask
            )
        return output


def _compute_mask(
    shape: Tuple[int, int],
    mask_prob: float,
    mask_length: int,
    device: torch.device,
    min_masks: int = 0,
) -> torch.Tensor:
    batch_size, sequence_length = shape

    if mask_length < 1:
        raise ValueError("`mask_length` has to be bigger than 0.")

    if mask_length > sequence_length:
        raise ValueError(
            f"`mask_length` has to be smaller than `sequence_length`, but got `mask_length`: {mask_length} and `sequence_length`: {sequence_length}"
        )

    # compute number of masked spans in batch
    num_masked_spans = int(mask_prob * sequence_length / mask_length + random.random())
    num_masked_spans = max(num_masked_spans, min_masks)

    # make sure num masked indices <= sequence_length
    if num_masked_spans * mask_length > sequence_length:
        num_masked_spans = sequence_length // mask_length

    # SpecAugment mask to fill
    mask = torch.zeros((batch_size, sequence_length), device=device, dtype=torch.bool)

    # uniform distribution to sample from, make sure that offset samples are < sequence_length
    uniform_dist = torch.ones(
        (batch_size, sequence_length - (mask_length - 1)), device=device
    )

    # get random indices to mask
    mask_indices = torch.multinomial(uniform_dist, num_masked_spans)

    # expand masked indices to masked spans
    mask_indices = (
        mask_indices.unsqueeze(dim=-1)
        .expand((batch_size, num_masked_spans, mask_length))
        .reshape(batch_size, num_masked_spans * mask_length)
    )
    offsets = (
        torch.arange(mask_length, device=device)[None, None, :]
        .expand((batch_size, num_masked_spans, mask_length))
        .reshape(batch_size, num_masked_spans * mask_length)
    )
    mask_idxs = mask_indices + offsets

    # scatter indices to mask
    mask = mask.scatter(1, mask_idxs, True)

    return mask


def hubert_discrete(
    pretrained: bool = True,
    progress: bool = True,
) -> HubertDiscrete:
    r"""HuBERT-Discrete from `"A Comparison of Discrete and Soft Speech Units for Improved Voice Conversion"`.
    Args:
        pretrained (bool): load pretrained weights into the model
        progress (bool): show progress bar when downloading model
    """
    kmeans = kmeans100(pretrained=pretrained, progress=progress)
    hubert = HubertDiscrete(kmeans)
    if pretrained:
        checkpoint = torch.hub.load_state_dict_from_url(
            URLS["hubert-discrete"], progress=progress
        )
        consume_prefix_in_state_dict_if_present(checkpoint, "module.")
        hubert.load_state_dict(checkpoint)
        hubert.eval()
    return hubert


def hubert_soft(
    pretrained: bool = True,
    progress: bool = True,
) -> HubertSoft:
    r"""HuBERT-Soft from `"A Comparison of Discrete and Soft Speech Units for Improved Voice Conversion"`.
    Args:
        pretrained (bool): load pretrained weights into the model
        progress (bool): show progress bar when downloading model
    """
    hubert = HubertSoft()
    if pretrained:
        checkpoint = torch.hub.load_state_dict_from_url(
            URLS["hubert-soft"], progress=progress
        )
        consume_prefix_in_state_dict_if_present(checkpoint, "module.")
        hubert.load_state_dict(checkpoint)
        hubert.eval()
    return hubert


def _kmeans(
    num_clusters: int, pretrained: bool = True, progress: bool = True
) -> KMeans:
    kmeans = KMeans(num_clusters)
    if pretrained:
        checkpoint = torch.hub.load_state_dict_from_url(
            URLS[f"kmeans{num_clusters}"], progress=progress
        )
        kmeans.__dict__["n_features_in_"] = checkpoint["n_features_in_"]
        kmeans.__dict__["_n_threads"] = checkpoint["_n_threads"]
        kmeans.__dict__["cluster_centers_"] = checkpoint["cluster_centers_"].numpy()
    return kmeans


def kmeans100(pretrained: bool = True, progress: bool = True) -> KMeans:
    r"""
    k-means checkpoint for HuBERT-Discrete with 100 clusters.
    Args:
        pretrained (bool): load pretrained weights into the model
        progress (bool): show progress bar when downloading model
    """
    return _kmeans(100, pretrained, progress)
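A hedged extraction sketch for the factory functions above (assumptions: network access for the pretrained checkpoint; the 16 kHz mono waveform shape is illustrative):

import torch

hubert = hubert_soft(pretrained=True)
wav = torch.randn(1, 1, 16000)  # B, 1, T at 16 kHz
units = hubert.units(wav)       # B, n_frames, 256 soft speech units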
102  server/voice_changer/DDSP_SVC/models/enhancer.py  Normal file
@ -0,0 +1,102 @@
import numpy as np
import torch
import torch.nn.functional as F
from torchaudio.transforms import Resample
from .nsf_hifigan.nvSTFT import STFT
from .nsf_hifigan.models import load_model


class Enhancer:
    def __init__(self, enhancer_type, enhancer_ckpt, device=None):
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = device

        if enhancer_type == "nsf-hifigan":
            self.enhancer = NsfHifiGAN(enhancer_ckpt, device=self.device)
        else:
            raise ValueError(f" [x] Unknown enhancer: {enhancer_type}")

        self.resample_kernel = {}
        self.enhancer_sample_rate = self.enhancer.sample_rate()
        self.enhancer_hop_size = self.enhancer.hop_size()

    def enhance(self, audio, sample_rate, f0, hop_size, adaptive_key=0, silence_front=0):  # audio: 1, T; f0: 1, n_frames, 1
        # enhancer start time
        start_frame = int(silence_front * sample_rate / hop_size)
        real_silence_front = start_frame * hop_size / sample_rate
        audio = audio[:, int(np.round(real_silence_front * sample_rate)) :]
        f0 = f0[:, start_frame:, :]

        # adaptive parameters
        if adaptive_key == "auto":
            adaptive_key = 12 * np.log2(float(torch.max(f0) / 760))
            adaptive_key = max(0, np.ceil(adaptive_key))
            print("auto_adaptive_key: " + str(int(adaptive_key)))
        else:
            adaptive_key = float(adaptive_key)

        adaptive_factor = 2 ** (-adaptive_key / 12)
        adaptive_sample_rate = 100 * int(np.round(self.enhancer_sample_rate / adaptive_factor / 100))
        real_factor = self.enhancer_sample_rate / adaptive_sample_rate

        # resample the ddsp output
        if sample_rate == adaptive_sample_rate:
            audio_res = audio
        else:
            key_str = str(sample_rate) + str(adaptive_sample_rate)
            if key_str not in self.resample_kernel:
                self.resample_kernel[key_str] = Resample(sample_rate, adaptive_sample_rate, lowpass_filter_width=128).to(self.device)
            audio_res = self.resample_kernel[key_str](audio)

        n_frames = int(audio_res.size(-1) // self.enhancer_hop_size + 1)

        # resample f0
        if hop_size == self.enhancer_hop_size and sample_rate == self.enhancer_sample_rate and sample_rate == adaptive_sample_rate:
            f0_res = f0.squeeze(-1)  # 1, n_frames
        else:
            f0_np = f0.squeeze(0).squeeze(-1).cpu().numpy()
            f0_np *= real_factor
            time_org = (hop_size / sample_rate) * np.arange(len(f0_np)) / real_factor
            time_frame = (self.enhancer_hop_size / self.enhancer_sample_rate) * np.arange(n_frames)
            f0_res = np.interp(time_frame, time_org, f0_np, left=f0_np[0], right=f0_np[-1])
            f0_res = torch.from_numpy(f0_res).unsqueeze(0).float().to(self.device)  # 1, n_frames

        # enhance
        enhanced_audio, enhancer_sample_rate = self.enhancer(audio_res, f0_res)

        # resample the enhanced output
        if adaptive_sample_rate != enhancer_sample_rate:
            key_str = str(adaptive_sample_rate) + str(enhancer_sample_rate)
            if key_str not in self.resample_kernel:
                self.resample_kernel[key_str] = Resample(adaptive_sample_rate, enhancer_sample_rate, lowpass_filter_width=128).to(self.device)
            enhanced_audio = self.resample_kernel[key_str](enhanced_audio)

        # pad the silence frames
        if start_frame > 0:
            enhanced_audio = F.pad(enhanced_audio, (int(np.round(enhancer_sample_rate * real_silence_front)), 0))

        return enhanced_audio, enhancer_sample_rate


class NsfHifiGAN(torch.nn.Module):
    def __init__(self, model_path, device=None):
        super().__init__()
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = device
        print("| Load HifiGAN: ", model_path)
        self.model, self.h = load_model(model_path, device=self.device)
        self.stft = STFT(self.h.sampling_rate, self.h.num_mels, self.h.n_fft, self.h.win_size, self.h.hop_size, self.h.fmin, self.h.fmax)

    def sample_rate(self):
        return self.h.sampling_rate

    def hop_size(self):
        return self.h.hop_size

    def forward(self, audio, f0):
        with torch.no_grad():
            mel = self.stft.get_mel(audio)
            enhanced_audio = self.model(mel, f0[:, : mel.size(-1)])
            return enhanced_audio, self.h.sampling_rate
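A usage sketch for the Enhancer above (assumptions: an NSF-HiFiGAN checkpoint at the illustrative path, 44.1 kHz DDSP output, hop size 512; the frame count is illustrative):

import torch

enhancer = Enhancer("nsf-hifigan", "checkpoints/nsf_hifigan/model", device="cpu")
audio = torch.randn(1, 44100)       # 1, T (raw DDSP output)
f0 = torch.full((1, 87, 1), 220.0)  # 1, n_frames, 1
enhanced, sr = enhancer.enhance(audio, 44100, f0, hop_size=512, adaptive_key=0)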
15  server/voice_changer/DDSP_SVC/models/nsf_hifigan/env.py  Normal file
@ -0,0 +1,15 @@
import os
import shutil


class AttrDict(dict):
    def __init__(self, *args, **kwargs):
        super(AttrDict, self).__init__(*args, **kwargs)
        self.__dict__ = self


def build_env(config, config_name, path):
    t_path = os.path.join(path, config_name)
    if config != t_path:
        os.makedirs(path, exist_ok=True)
        shutil.copyfile(config, os.path.join(path, config_name))
434
server/voice_changer/DDSP_SVC/models/nsf_hifigan/models.py
Normal file
434
server/voice_changer/DDSP_SVC/models/nsf_hifigan/models.py
Normal file
@ -0,0 +1,434 @@
|
|||||||
|
import os
|
||||||
|
import json
|
||||||
|
from .env import AttrDict
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
import torch.nn.functional as F
|
||||||
|
import torch.nn as nn
|
||||||
|
from torch.nn import Conv1d, ConvTranspose1d, AvgPool1d, Conv2d
|
||||||
|
from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm
|
||||||
|
from .utils import init_weights, get_padding
|
||||||
|
|
||||||
|
LRELU_SLOPE = 0.1
|
||||||
|
|
||||||
|
|
||||||
|
def load_model(model_path, device='cuda'):
|
||||||
|
h = load_config(model_path)
|
||||||
|
|
||||||
|
generator = Generator(h).to(device)
|
||||||
|
|
||||||
|
cp_dict = torch.load(model_path, map_location=device)
|
||||||
|
generator.load_state_dict(cp_dict['generator'])
|
||||||
|
generator.eval()
|
||||||
|
generator.remove_weight_norm()
|
||||||
|
del cp_dict
|
||||||
|
return generator, h
|
||||||
|
|
||||||
|
def load_config(model_path):
|
||||||
|
config_file = os.path.join(os.path.split(model_path)[0], 'config.json')
|
||||||
|
with open(config_file) as f:
|
||||||
|
data = f.read()
|
||||||
|
|
||||||
|
json_config = json.loads(data)
|
||||||
|
h = AttrDict(json_config)
|
||||||
|
return h
|
||||||
|
|
||||||
|
|
||||||
|
class ResBlock1(torch.nn.Module):
|
||||||
|
def __init__(self, h, channels, kernel_size=3, dilation=(1, 3, 5)):
|
||||||
|
super(ResBlock1, self).__init__()
|
||||||
|
self.h = h
|
||||||
|
self.convs1 = nn.ModuleList([
|
||||||
|
weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0],
|
||||||
|
padding=get_padding(kernel_size, dilation[0]))),
|
||||||
|
weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1],
|
||||||
|
padding=get_padding(kernel_size, dilation[1]))),
|
||||||
|
weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[2],
|
||||||
|
padding=get_padding(kernel_size, dilation[2])))
|
||||||
|
])
|
||||||
|
self.convs1.apply(init_weights)
|
||||||
|
|
||||||
|
self.convs2 = nn.ModuleList([
|
||||||
|
weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
|
||||||
|
padding=get_padding(kernel_size, 1))),
|
||||||
|
weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
|
||||||
|
padding=get_padding(kernel_size, 1))),
|
||||||
|
weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
|
||||||
|
padding=get_padding(kernel_size, 1)))
|
||||||
|
])
|
||||||
|
self.convs2.apply(init_weights)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
for c1, c2 in zip(self.convs1, self.convs2):
|
||||||
|
xt = F.leaky_relu(x, LRELU_SLOPE)
|
||||||
|
xt = c1(xt)
|
||||||
|
xt = F.leaky_relu(xt, LRELU_SLOPE)
|
||||||
|
xt = c2(xt)
|
||||||
|
x = xt + x
|
||||||
|
return x
|
||||||
|
|
||||||
|
def remove_weight_norm(self):
|
||||||
|
for l in self.convs1:
|
||||||
|
remove_weight_norm(l)
|
||||||
|
for l in self.convs2:
|
||||||
|
remove_weight_norm(l)
|
||||||
|
|
||||||
|
|
||||||
|
class ResBlock2(torch.nn.Module):
|
||||||
|
def __init__(self, h, channels, kernel_size=3, dilation=(1, 3)):
|
||||||
|
super(ResBlock2, self).__init__()
|
||||||
|
self.h = h
|
||||||
|
self.convs = nn.ModuleList([
|
||||||
|
weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0],
|
||||||
|
padding=get_padding(kernel_size, dilation[0]))),
|
||||||
|
weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1],
|
||||||
|
padding=get_padding(kernel_size, dilation[1])))
|
||||||
|
])
|
||||||
|
self.convs.apply(init_weights)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
for c in self.convs:
|
||||||
|
xt = F.leaky_relu(x, LRELU_SLOPE)
|
||||||
|
xt = c(xt)
|
||||||
|
x = xt + x
|
||||||
|
return x
|
||||||
|
|
||||||
|
def remove_weight_norm(self):
|
||||||
|
for l in self.convs:
|
||||||
|
remove_weight_norm(l)
|
||||||
|
|
||||||
|
|
||||||
|
class SineGen(torch.nn.Module):
|
||||||
|
""" Definition of sine generator
|
||||||
|
SineGen(samp_rate, harmonic_num = 0,
|
||||||
|
sine_amp = 0.1, noise_std = 0.003,
|
||||||
|
voiced_threshold = 0,
|
||||||
|
flag_for_pulse=False)
|
||||||
|
samp_rate: sampling rate in Hz
|
||||||
|
harmonic_num: number of harmonic overtones (default 0)
|
||||||
|
sine_amp: amplitude of sine-wavefrom (default 0.1)
|
||||||
|
noise_std: std of Gaussian noise (default 0.003)
|
||||||
|
voiced_thoreshold: F0 threshold for U/V classification (default 0)
|
||||||
|
flag_for_pulse: this SinGen is used inside PulseGen (default False)
|
||||||
|
Note: when flag_for_pulse is True, the first time step of a voiced
|
||||||
|
segment is always sin(np.pi) or cos(0)
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, samp_rate, harmonic_num=0,
|
||||||
|
sine_amp=0.1, noise_std=0.003,
|
||||||
|
voiced_threshold=0):
|
||||||
|
super(SineGen, self).__init__()
|
||||||
|
self.sine_amp = sine_amp
|
||||||
|
self.noise_std = noise_std
|
||||||
|
self.harmonic_num = harmonic_num
|
||||||
|
self.dim = self.harmonic_num + 1
|
||||||
|
self.sampling_rate = samp_rate
|
||||||
|
self.voiced_threshold = voiced_threshold
|
||||||
|
|
||||||
|
def _f02uv(self, f0):
|
||||||
|
# generate uv signal
|
||||||
|
uv = torch.ones_like(f0)
|
||||||
|
uv = uv * (f0 > self.voiced_threshold)
|
||||||
|
return uv
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def forward(self, f0, upp):
|
||||||
|
""" sine_tensor, uv = forward(f0)
|
||||||
|
input F0: tensor(batchsize=1, length, dim=1)
|
||||||
|
f0 for unvoiced steps should be 0
|
||||||
|
output sine_tensor: tensor(batchsize=1, length, dim)
|
||||||
|
output uv: tensor(batchsize=1, length, 1)
|
||||||
|
"""
|
||||||
|
f0 = f0.unsqueeze(-1)
|
||||||
|
fn = torch.multiply(f0, torch.arange(1, self.dim + 1, device=f0.device).reshape((1, 1, -1)))
|
||||||
|
rad_values = (fn / self.sampling_rate) % 1 ###%1意味着n_har的乘积无法后处理优化
|
||||||
|
rand_ini = torch.rand(fn.shape[0], fn.shape[2], device=fn.device)
|
||||||
|
rand_ini[:, 0] = 0
|
||||||
|
rad_values[:, 0, :] = rad_values[:, 0, :] + rand_ini
|
||||||
|
is_half = rad_values.dtype is not torch.float32
|
||||||
|
tmp_over_one = torch.cumsum(rad_values.double(), 1) # % 1 #####%1意味着后面的cumsum无法再优化
|
||||||
|
if is_half:
|
||||||
|
tmp_over_one = tmp_over_one.half()
|
||||||
|
else:
|
||||||
|
tmp_over_one = tmp_over_one.float()
|
||||||
|
tmp_over_one *= upp
|
||||||
|
tmp_over_one = F.interpolate(
|
||||||
|
tmp_over_one.transpose(2, 1), scale_factor=upp,
|
||||||
|
mode='linear', align_corners=True
|
||||||
|
).transpose(2, 1)
|
||||||
|
rad_values = F.interpolate(rad_values.transpose(2, 1), scale_factor=upp, mode='nearest').transpose(2, 1)
|
||||||
|
tmp_over_one %= 1
|
||||||
|
tmp_over_one_idx = (tmp_over_one[:, 1:, :] - tmp_over_one[:, :-1, :]) < 0
|
||||||
|
cumsum_shift = torch.zeros_like(rad_values)
|
||||||
|
cumsum_shift[:, 1:, :] = tmp_over_one_idx * -1.0
|
||||||
|
rad_values = rad_values.double()
|
||||||
|
cumsum_shift = cumsum_shift.double()
|
||||||
|
sine_waves = torch.sin(torch.cumsum(rad_values + cumsum_shift, dim=1) * 2 * np.pi)
|
||||||
|
if is_half:
|
||||||
|
sine_waves = sine_waves.half()
|
||||||
|
else:
|
||||||
|
sine_waves = sine_waves.float()
|
||||||
|
sine_waves = sine_waves * self.sine_amp
|
||||||
|
return sine_waves
|
||||||
|
|
||||||
|
|
||||||
|
class SourceModuleHnNSF(torch.nn.Module):
|
||||||
|
""" SourceModule for hn-nsf
|
||||||
|
SourceModule(sampling_rate, harmonic_num=0, sine_amp=0.1,
|
||||||
|
add_noise_std=0.003, voiced_threshod=0)
|
||||||
|
sampling_rate: sampling_rate in Hz
|
||||||
|
harmonic_num: number of harmonic above F0 (default: 0)
|
||||||
|
sine_amp: amplitude of sine source signal (default: 0.1)
|
||||||
|
add_noise_std: std of additive Gaussian noise (default: 0.003)
|
||||||
|
note that amplitude of noise in unvoiced is decided
|
||||||
|
by sine_amp
|
||||||
|
voiced_threshold: threhold to set U/V given F0 (default: 0)
|
||||||
|
Sine_source, noise_source = SourceModuleHnNSF(F0_sampled)
|
||||||
|
F0_sampled (batchsize, length, 1)
|
||||||
|
Sine_source (batchsize, length, 1)
|
||||||
|
noise_source (batchsize, length 1)
|
||||||
|
uv (batchsize, length, 1)
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, sampling_rate, harmonic_num=0, sine_amp=0.1,
|
||||||
|
add_noise_std=0.003, voiced_threshod=0):
|
||||||
|
super(SourceModuleHnNSF, self).__init__()
|
||||||
|
|
||||||
|
self.sine_amp = sine_amp
|
||||||
|
self.noise_std = add_noise_std
|
||||||
|
|
||||||
|
# to produce sine waveforms
|
||||||
|
self.l_sin_gen = SineGen(sampling_rate, harmonic_num,
|
||||||
|
sine_amp, add_noise_std, voiced_threshod)
|
||||||
|
|
||||||
|
# to merge source harmonics into a single excitation
|
||||||
|
self.l_linear = torch.nn.Linear(harmonic_num + 1, 1)
|
||||||
|
self.l_tanh = torch.nn.Tanh()
|
||||||
|
|
||||||
|
def forward(self, x, upp):
|
||||||
|
sine_wavs = self.l_sin_gen(x, upp)
|
||||||
|
sine_merge = self.l_tanh(self.l_linear(sine_wavs))
|
||||||
|
return sine_merge
|
||||||
|
|
||||||
|
|
||||||
|
class Generator(torch.nn.Module):
|
||||||
|
def __init__(self, h):
|
||||||
|
super(Generator, self).__init__()
|
||||||
|
self.h = h
|
||||||
|
self.num_kernels = len(h.resblock_kernel_sizes)
|
||||||
|
self.num_upsamples = len(h.upsample_rates)
|
||||||
|
self.m_source = SourceModuleHnNSF(
|
||||||
|
sampling_rate=h.sampling_rate,
|
||||||
|
harmonic_num=8
|
||||||
|
)
|
||||||
|
self.noise_convs = nn.ModuleList()
|
||||||
|
self.conv_pre = weight_norm(Conv1d(h.num_mels, h.upsample_initial_channel, 7, 1, padding=3))
|
||||||
|
resblock = ResBlock1 if h.resblock == '1' else ResBlock2
|
||||||
|
|
||||||
|
self.ups = nn.ModuleList()
|
||||||
|
for i, (u, k) in enumerate(zip(h.upsample_rates, h.upsample_kernel_sizes)):
|
||||||
|
c_cur = h.upsample_initial_channel // (2 ** (i + 1))
|
||||||
|
self.ups.append(weight_norm(
|
||||||
|
ConvTranspose1d(h.upsample_initial_channel // (2 ** i), h.upsample_initial_channel // (2 ** (i + 1)),
|
||||||
|
k, u, padding=(k - u) // 2)))
|
||||||
|
if i + 1 < len(h.upsample_rates): #
|
||||||
|
stride_f0 = int(np.prod(h.upsample_rates[i + 1:]))
|
||||||
|
self.noise_convs.append(Conv1d(
|
||||||
|
1, c_cur, kernel_size=stride_f0 * 2, stride=stride_f0, padding=stride_f0 // 2))
|
||||||
|
else:
|
||||||
|
self.noise_convs.append(Conv1d(1, c_cur, kernel_size=1))
|
||||||
|
self.resblocks = nn.ModuleList()
|
||||||
|
ch = h.upsample_initial_channel
|
||||||
|
for i in range(len(self.ups)):
|
||||||
|
ch //= 2
|
||||||
|
for j, (k, d) in enumerate(zip(h.resblock_kernel_sizes, h.resblock_dilation_sizes)):
|
||||||
|
self.resblocks.append(resblock(h, ch, k, d))
|
||||||
|
|
||||||
|
self.conv_post = weight_norm(Conv1d(ch, 1, 7, 1, padding=3))
|
||||||
|
self.ups.apply(init_weights)
|
||||||
|
self.conv_post.apply(init_weights)
|
||||||
|
self.upp = int(np.prod(h.upsample_rates))
|
||||||
|
|
||||||
|
def forward(self, x, f0):
|
||||||
|
har_source = self.m_source(f0, self.upp).transpose(1, 2)
|
||||||
|
x = self.conv_pre(x)
|
||||||
|
for i in range(self.num_upsamples):
|
||||||
|
x = F.leaky_relu(x, LRELU_SLOPE)
|
||||||
|
x = self.ups[i](x)
|
||||||
|
x_source = self.noise_convs[i](har_source)
|
||||||
|
x = x + x_source
|
||||||
|
xs = None
|
||||||
|
for j in range(self.num_kernels):
|
||||||
|
if xs is None:
|
||||||
|
xs = self.resblocks[i * self.num_kernels + j](x)
|
||||||
|
else:
|
||||||
|
xs += self.resblocks[i * self.num_kernels + j](x)
|
||||||
|
x = xs / self.num_kernels
|
||||||
|
x = F.leaky_relu(x)
|
||||||
|
x = self.conv_post(x)
|
||||||
|
x = torch.tanh(x)
|
||||||
|
|
||||||
|
return x
|
||||||
|
|
||||||
|
def remove_weight_norm(self):
|
||||||
|
print('Removing weight norm...')
|
||||||
|
for l in self.ups:
|
||||||
|
remove_weight_norm(l)
|
||||||
|
for l in self.resblocks:
|
||||||
|
l.remove_weight_norm()
|
||||||
|
remove_weight_norm(self.conv_pre)
|
||||||
|
remove_weight_norm(self.conv_post)
|
||||||
|
|
||||||
|
|
||||||
|
class DiscriminatorP(torch.nn.Module):
|
||||||
|
def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False):
|
||||||
|
super(DiscriminatorP, self).__init__()
|
||||||
|
self.period = period
|
||||||
|
norm_f = weight_norm if use_spectral_norm == False else spectral_norm
|
||||||
|
self.convs = nn.ModuleList([
|
||||||
|
norm_f(Conv2d(1, 32, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
|
||||||
|
norm_f(Conv2d(32, 128, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
|
||||||
|
norm_f(Conv2d(128, 512, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
|
||||||
|
norm_f(Conv2d(512, 1024, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
|
||||||
|
norm_f(Conv2d(1024, 1024, (kernel_size, 1), 1, padding=(2, 0))),
|
||||||
|
])
|
||||||
|
self.conv_post = norm_f(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0)))
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
fmap = []
|
||||||
|
|
||||||
|
# 1d to 2d
|
||||||
|
b, c, t = x.shape
|
||||||
|
if t % self.period != 0: # pad first
|
||||||
|
n_pad = self.period - (t % self.period)
|
||||||
|
x = F.pad(x, (0, n_pad), "reflect")
|
||||||
|
t = t + n_pad
|
||||||
|
x = x.view(b, c, t // self.period, self.period)
|
||||||
|
|
||||||
|
for l in self.convs:
|
||||||
|
x = l(x)
|
||||||
|
x = F.leaky_relu(x, LRELU_SLOPE)
|
||||||
|
fmap.append(x)
|
||||||
|
x = self.conv_post(x)
|
||||||
|
fmap.append(x)
|
||||||
|
x = torch.flatten(x, 1, -1)
|
||||||
|
|
||||||
|
return x, fmap
|
||||||
|
|
||||||
|
|
||||||
|
class MultiPeriodDiscriminator(torch.nn.Module):
    def __init__(self, periods=None):
        super(MultiPeriodDiscriminator, self).__init__()
        self.periods = periods if periods is not None else [2, 3, 5, 7, 11]
        self.discriminators = nn.ModuleList()
        for period in self.periods:
            self.discriminators.append(DiscriminatorP(period))

    def forward(self, y, y_hat):
        y_d_rs = []
        y_d_gs = []
        fmap_rs = []
        fmap_gs = []
        for i, d in enumerate(self.discriminators):
            y_d_r, fmap_r = d(y)
            y_d_g, fmap_g = d(y_hat)
            y_d_rs.append(y_d_r)
            fmap_rs.append(fmap_r)
            y_d_gs.append(y_d_g)
            fmap_gs.append(fmap_g)

        return y_d_rs, y_d_gs, fmap_rs, fmap_gs
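A minimal usage sketch for the class above (random tensors stand in for real and generated audio):

import torch

mpd = MultiPeriodDiscriminator()
y = torch.randn(2, 1, 8192)       # real audio batch
y_hat = torch.randn(2, 1, 8192)   # generated audio batch
y_d_rs, y_d_gs, fmap_rs, fmap_gs = mpd(y, y_hat)
print(len(y_d_rs))                # 5, one score per period in [2, 3, 5, 7, 11]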
class DiscriminatorS(torch.nn.Module):
    def __init__(self, use_spectral_norm=False):
        super(DiscriminatorS, self).__init__()
        norm_f = weight_norm if use_spectral_norm == False else spectral_norm
        self.convs = nn.ModuleList([
            norm_f(Conv1d(1, 128, 15, 1, padding=7)),
            norm_f(Conv1d(128, 128, 41, 2, groups=4, padding=20)),
            norm_f(Conv1d(128, 256, 41, 2, groups=16, padding=20)),
            norm_f(Conv1d(256, 512, 41, 4, groups=16, padding=20)),
            norm_f(Conv1d(512, 1024, 41, 4, groups=16, padding=20)),
            norm_f(Conv1d(1024, 1024, 41, 1, groups=16, padding=20)),
            norm_f(Conv1d(1024, 1024, 5, 1, padding=2)),
        ])
        self.conv_post = norm_f(Conv1d(1024, 1, 3, 1, padding=1))

    def forward(self, x):
        fmap = []
        for l in self.convs:
            x = l(x)
            x = F.leaky_relu(x, LRELU_SLOPE)
            fmap.append(x)
        x = self.conv_post(x)
        fmap.append(x)
        x = torch.flatten(x, 1, -1)

        return x, fmap
class MultiScaleDiscriminator(torch.nn.Module):
    def __init__(self):
        super(MultiScaleDiscriminator, self).__init__()
        self.discriminators = nn.ModuleList([
            DiscriminatorS(use_spectral_norm=True),
            DiscriminatorS(),
            DiscriminatorS(),
        ])
        self.meanpools = nn.ModuleList([
            AvgPool1d(4, 2, padding=2),
            AvgPool1d(4, 2, padding=2)
        ])

    def forward(self, y, y_hat):
        y_d_rs = []
        y_d_gs = []
        fmap_rs = []
        fmap_gs = []
        for i, d in enumerate(self.discriminators):
            if i != 0:
                y = self.meanpools[i - 1](y)
                y_hat = self.meanpools[i - 1](y_hat)
            y_d_r, fmap_r = d(y)
            y_d_g, fmap_g = d(y_hat)
            y_d_rs.append(y_d_r)
            fmap_rs.append(fmap_r)
            y_d_gs.append(y_d_g)
            fmap_gs.append(fmap_g)

        return y_d_rs, y_d_gs, fmap_rs, fmap_gs
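The mean-pools feed each successive DiscriminatorS a roughly 2x-downsampled view of the signal; a quick shape check:

import torch

pool = torch.nn.AvgPool1d(4, 2, padding=2)
y = torch.randn(1, 1, 8192)
print(pool(y).shape)        # torch.Size([1, 1, 4097])
print(pool(pool(y)).shape)  # torch.Size([1, 1, 2049])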
def feature_loss(fmap_r, fmap_g):
    loss = 0
    for dr, dg in zip(fmap_r, fmap_g):
        for rl, gl in zip(dr, dg):
            loss += torch.mean(torch.abs(rl - gl))

    return loss * 2


def discriminator_loss(disc_real_outputs, disc_generated_outputs):
    loss = 0
    r_losses = []
    g_losses = []
    for dr, dg in zip(disc_real_outputs, disc_generated_outputs):
        r_loss = torch.mean((1 - dr) ** 2)
        g_loss = torch.mean(dg ** 2)
        loss += (r_loss + g_loss)
        r_losses.append(r_loss.item())
        g_losses.append(g_loss.item())

    return loss, r_losses, g_losses


def generator_loss(disc_outputs):
    loss = 0
    gen_losses = []
    for dg in disc_outputs:
        l = torch.mean((1 - dg) ** 2)
        gen_losses.append(l)
        loss += l

    return loss, gen_losses
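A minimal sketch of how these least-squares GAN losses are typically combined in a training step (the 45x feature-matching weight is an assumption carried over from the HiFi-GAN recipe, not something defined in this file):

import torch

# dummy discriminator outputs and feature maps for two sub-discriminators
y_d_rs = [torch.rand(2, 10), torch.rand(2, 10)]
y_d_gs = [torch.rand(2, 10), torch.rand(2, 10)]
fmap_rs = [[torch.rand(2, 4, 10)], [torch.rand(2, 4, 10)]]
fmap_gs = [[torch.rand(2, 4, 10)], [torch.rand(2, 4, 10)]]

loss_disc, r_losses, g_losses = discriminator_loss(y_d_rs, y_d_gs)  # discriminator step
loss_gen, gen_losses = generator_loss(y_d_gs)                       # generator adversarial term
loss_fm = feature_loss(fmap_rs, fmap_gs)                            # feature-matching term
loss_g_total = loss_gen + 45 * loss_fm   # plus a mel-spectrogram L1 term in practice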
server/voice_changer/DDSP_SVC/models/nsf_hifigan/nvSTFT.py (new file, 129 lines)
@@ -0,0 +1,129 @@
import math
import os
os.environ["LRU_CACHE_CAPACITY"] = "3"
import random
import torch
import torch.utils.data
import numpy as np
import librosa
from librosa.util import normalize
from librosa.filters import mel as librosa_mel_fn
from scipy.io.wavfile import read
import soundfile as sf
import torch.nn.functional as F
def load_wav_to_torch(full_path, target_sr=None, return_empty_on_exception=False):
    sampling_rate = None
    try:
        data, sampling_rate = sf.read(full_path, always_2d=True)
    except Exception as ex:
        print(f"'{full_path}' failed to load.\nException:")
        print(ex)
        if return_empty_on_exception:
            return [], sampling_rate or target_sr or 48000
        else:
            raise Exception(ex)

    if len(data.shape) > 1:
        data = data[:, 0]
    assert len(data) > 2  # check duration of audio file is > 2 samples (because otherwise the slice operation was on the wrong dimension)

    if np.issubdtype(data.dtype, np.integer):  # if audio data is type int
        max_mag = -np.iinfo(data.dtype).min  # maximum magnitude = min possible value of intXX
    else:  # if audio data is type fp32
        max_mag = max(np.amax(data), -np.amin(data))
        max_mag = (2**31) + 1 if max_mag > (2**15) else ((2**15) + 1 if max_mag > 1.01 else 1.0)  # data should be either 16-bit INT, 32-bit INT or [-1 to 1] float32

    data = torch.FloatTensor(data.astype(np.float32)) / max_mag

    if (torch.isinf(data) | torch.isnan(data)).any() and return_empty_on_exception:
        # resample would crash with inf/NaN inputs; return an empty array instead of raising
        return [], sampling_rate or target_sr or 48000
    if target_sr is not None and sampling_rate != target_sr:
        data = torch.from_numpy(librosa.core.resample(data.numpy(), orig_sr=sampling_rate, target_sr=target_sr))
        sampling_rate = target_sr

    return data, sampling_rate
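A minimal round trip for load_wav_to_torch (the temporary file and test tone are hypothetical, purely for illustration):

import tempfile

_tmp = os.path.join(tempfile.mkdtemp(), "tone.wav")
sf.write(_tmp, (0.5 * np.sin(np.linspace(0, 2 * np.pi * 440, 16000))).astype(np.float32), 16000)
audio, sr = load_wav_to_torch(_tmp, target_sr=44100)
print(sr, audio.dtype, float(audio.abs().max()))   # 44100 torch.float32 ~0.5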
def dynamic_range_compression(x, C=1, clip_val=1e-5):
    return np.log(np.clip(x, a_min=clip_val, a_max=None) * C)


def dynamic_range_decompression(x, C=1):
    return np.exp(x) / C


def dynamic_range_compression_torch(x, C=1, clip_val=1e-5):
    return torch.log(torch.clamp(x, min=clip_val) * C)


def dynamic_range_decompression_torch(x, C=1):
    return torch.exp(x) / C
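The torch compression/decompression pair round-trips exactly for values above the clip floor:

# log(clamp(x, 1e-5)) followed by exp(x) recovers x when x >= 1e-5
_x = torch.tensor([0.5, 1.0, 2.0])
assert torch.allclose(dynamic_range_decompression_torch(dynamic_range_compression_torch(_x)), _x)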
class STFT():
    def __init__(self, sr=22050, n_mels=80, n_fft=1024, win_size=1024, hop_length=256, fmin=20, fmax=11025, clip_val=1e-5):
        self.target_sr = sr

        self.n_mels = n_mels
        self.n_fft = n_fft
        self.win_size = win_size
        self.hop_length = hop_length
        self.fmin = fmin
        self.fmax = fmax
        self.clip_val = clip_val
        self.mel_basis = {}
        self.hann_window = {}

    def get_mel(self, y, keyshift=0, speed=1, center=False):
        sampling_rate = self.target_sr
        n_mels = self.n_mels
        n_fft = self.n_fft
        win_size = self.win_size
        hop_length = self.hop_length
        fmin = self.fmin
        fmax = self.fmax
        clip_val = self.clip_val

        factor = 2 ** (keyshift / 12)
        n_fft_new = int(np.round(n_fft * factor))
        win_size_new = int(np.round(win_size * factor))
        hop_length_new = int(np.round(hop_length * speed))

        if torch.min(y) < -1.:
            print('min value is ', torch.min(y))
        if torch.max(y) > 1.:
            print('max value is ', torch.max(y))

        mel_basis_key = str(fmax) + '_' + str(y.device)
        if mel_basis_key not in self.mel_basis:
            mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=n_mels, fmin=fmin, fmax=fmax)
            self.mel_basis[mel_basis_key] = torch.from_numpy(mel).float().to(y.device)

        keyshift_key = str(keyshift) + '_' + str(y.device)
        if keyshift_key not in self.hann_window:
            self.hann_window[keyshift_key] = torch.hann_window(win_size_new).to(y.device)

        pad_left = (win_size_new - hop_length_new) // 2
        pad_right = max((win_size_new - hop_length_new + 1) // 2, win_size_new - y.size(-1) - pad_left)
        if pad_right < y.size(-1):
            mode = 'reflect'
        else:
            mode = 'constant'
        y = torch.nn.functional.pad(y.unsqueeze(1), (pad_left, pad_right), mode=mode)
        y = y.squeeze(1)

        spec = torch.stft(y, n_fft_new, hop_length=hop_length_new, win_length=win_size_new, window=self.hann_window[keyshift_key],
                          center=center, pad_mode='reflect', normalized=False, onesided=True, return_complex=True)
        spec = torch.sqrt(spec.real.pow(2) + spec.imag.pow(2) + (1e-9))
        if keyshift != 0:
            size = n_fft // 2 + 1
            resize = spec.size(1)
            if resize < size:
                spec = F.pad(spec, (0, 0, 0, size - resize))
            spec = spec[:, :size, :] * win_size / win_size_new
        spec = torch.matmul(self.mel_basis[mel_basis_key], spec)
        spec = dynamic_range_compression_torch(spec, clip_val=clip_val)
        return spec

    def __call__(self, audiopath):
        audio, sr = load_wav_to_torch(audiopath, target_sr=self.target_sr)
        spect = self.get_mel(audio.unsqueeze(0)).squeeze(0)
        return spect


stft = STFT()
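A minimal usage sketch of the module-level stft instance (the waveform is random; a real call would go through __call__ with a wav path):

dummy = torch.randn(1, 22050).clamp(-1, 1)   # one second at the default 22050 Hz
mel = stft.get_mel(dummy)                    # log-compressed mel, shape (batch, n_mels, frames)
print(mel.shape)                             # torch.Size([1, 80, 86])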
server/voice_changer/DDSP_SVC/models/nsf_hifigan/utils.py (new file, 68 lines)
@@ -0,0 +1,68 @@
import glob
import os
import matplotlib
import torch
from torch.nn.utils import weight_norm
matplotlib.use("Agg")
import matplotlib.pylab as plt


def plot_spectrogram(spectrogram):
    fig, ax = plt.subplots(figsize=(10, 2))
    im = ax.imshow(spectrogram, aspect="auto", origin="lower",
                   interpolation='none')
    plt.colorbar(im, ax=ax)

    fig.canvas.draw()
    plt.close()

    return fig


def init_weights(m, mean=0.0, std=0.01):
    classname = m.__class__.__name__
    if classname.find("Conv") != -1:
        m.weight.data.normal_(mean, std)


def apply_weight_norm(m):
    classname = m.__class__.__name__
    if classname.find("Conv") != -1:
        weight_norm(m)


def get_padding(kernel_size, dilation=1):
    return int((kernel_size * dilation - dilation) / 2)


def load_checkpoint(filepath, device):
    assert os.path.isfile(filepath)
    print("Loading '{}'".format(filepath))
    checkpoint_dict = torch.load(filepath, map_location=device)
    print("Complete.")
    return checkpoint_dict


def save_checkpoint(filepath, obj):
    print("Saving checkpoint to {}".format(filepath))
    torch.save(obj, filepath)
    print("Complete.")


def del_old_checkpoints(cp_dir, prefix, n_models=2):
    pattern = os.path.join(cp_dir, prefix + '????????')
    cp_list = glob.glob(pattern)  # get checkpoint paths
    cp_list = sorted(cp_list)  # sort by iteration number
    if len(cp_list) > n_models:  # if more than n_models models are found
        for cp in cp_list[:-n_models]:  # delete the oldest models other than the latest n_models
            open(cp, 'w').close()  # empty file contents
            os.unlink(cp)  # delete file (moved to trash when using Colab)


def scan_checkpoint(cp_dir, prefix):
    pattern = os.path.join(cp_dir, prefix + '????????')
    cp_list = glob.glob(pattern)
    if len(cp_list) == 0:
        return None
    return sorted(cp_list)[-1]
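A small sketch of the checkpoint-rotation behaviour (directory and file names are hypothetical; the prefix + '????????' glob expects an 8-character suffix):

import tempfile

_d = tempfile.mkdtemp()
for step in (1000, 2000, 3000):
    open(os.path.join(_d, "g_{:08d}".format(step)), "w").close()
del_old_checkpoints(_d, "g_", n_models=2)
print(scan_checkpoint(_d, "g_"))   # .../g_00003000; only the two newest files remain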
server/voice_changer/SoVitsSvc40/models/readme.txt (new file, 2 lines)
@@ -0,0 +1,2 @@
modules in this folder from https://github.com/svc-develop-team/so-vits-svc at 7846711d90b5210d4f39a4d6fab50d1d7bbd8d73
forked repo: https://github.com/w-okada/so-vits-svc
@@ -1,466 +0,0 @@
import sys
import os

from voice_changer.utils.LoadModelParams import LoadModelParams
from voice_changer.utils.VoiceChangerModel import AudioInOut
from voice_changer.utils.VoiceChangerParams import VoiceChangerParams

if sys.platform.startswith("darwin"):
    baseDir = [x for x in sys.path if x.endswith("Contents/MacOS")]
    if len(baseDir) != 1:
        print("baseDir should be only one ", baseDir)
        sys.exit()
    modulePath = os.path.join(baseDir[0], "so-vits-svc-40v2")
    sys.path.append(modulePath)
else:
    sys.path.append("so-vits-svc-40v2")

import io
from dataclasses import dataclass, asdict, field
import numpy as np
import torch
import onnxruntime
import pyworld as pw

from models import SynthesizerTrn  # type:ignore
import cluster  # type:ignore
import utils
from fairseq import checkpoint_utils
import librosa

from Exceptions import NoModeLoadedException

providers = [
    "OpenVINOExecutionProvider",
    "CUDAExecutionProvider",
    "DmlExecutionProvider",
    "CPUExecutionProvider",
]
@dataclass
class SoVitsSvc40v2Settings:
    gpu: int = 0
    dstId: int = 0

    f0Detector: str = "harvest"  # dio or harvest
    tran: int = 20
    noiseScale: float = 0.3
    predictF0: int = 0  # 0:False, 1:True
    silentThreshold: float = 0.00001
    extraConvertSize: int = 1024 * 32
    clusterInferRatio: float = 0.1

    framework: str = "PyTorch"  # PyTorch or ONNX
    pyTorchModelFile: str | None = ""
    onnxModelFile: str | None = ""
    configFile: str = ""

    speakers: dict[str, int] = field(default_factory=lambda: {})

    # only mutable items are listed below
    intData = ["gpu", "dstId", "tran", "predictF0", "extraConvertSize"]
    floatData = ["noiseScale", "silentThreshold", "clusterInferRatio"]
    strData = ["framework", "f0Detector"]
class SoVitsSvc40v2:
    audio_buffer: AudioInOut | None = None

    def __init__(self, params: VoiceChangerParams):
        self.settings = SoVitsSvc40v2Settings()
        self.net_g = None
        self.onnx_session = None

        self.raw_path = io.BytesIO()
        self.gpu_num = torch.cuda.device_count()
        self.prevVol = 0
        self.params = params
        print("[Voice Changer] so-vits-svc 40v2 initialization:", params)

    def loadModel(self, props: LoadModelParams):
        params = props.params
        self.settings.configFile = params["files"]["soVitsSvc40v2Config"]
        self.hps = utils.get_hparams_from_file(self.settings.configFile)
        self.settings.speakers = self.hps.spk

        modelFile = params["files"]["soVitsSvc40v2Model"]
        if modelFile.endswith(".onnx"):
            self.settings.pyTorchModelFile = None
            self.settings.onnxModelFile = modelFile
        else:
            self.settings.pyTorchModelFile = modelFile
            self.settings.onnxModelFile = None

        clusterTorchModel = params["files"]["soVitsSvc40v2Cluster"]

        content_vec_path = self.params.content_vec_500
        hubert_base_path = self.params.hubert_base

        # hubert model
        try:
            if os.path.exists(content_vec_path) is False:
                content_vec_path = hubert_base_path

            models, saved_cfg, task = checkpoint_utils.load_model_ensemble_and_task(
                [content_vec_path],
                suffix="",
            )
            model = models[0]
            model.eval()
            self.hubert_model = model.cpu()
        except Exception as e:
            print("EXCEPTION during loading hubert/contentvec model", e)

        # cluster
        try:
            if clusterTorchModel is not None and os.path.exists(clusterTorchModel):
                self.cluster_model = cluster.get_cluster_model(clusterTorchModel)
            else:
                self.cluster_model = None
        except Exception as e:
            print("EXCEPTION during loading cluster model ", e)

        # build the PyTorch model
        if self.settings.pyTorchModelFile is not None:
            net_g = SynthesizerTrn(self.hps)
            net_g.eval()
            self.net_g = net_g
            utils.load_checkpoint(self.settings.pyTorchModelFile, self.net_g, None)

        # build the ONNX session
        if self.settings.onnxModelFile is not None:
            providers, options = self.getOnnxExecutionProvider()
            self.onnx_session = onnxruntime.InferenceSession(
                self.settings.onnxModelFile,
                providers=providers,
                provider_options=options,
            )
        return self.get_info()
    def getOnnxExecutionProvider(self):
        availableProviders = onnxruntime.get_available_providers()
        if self.settings.gpu >= 0 and "CUDAExecutionProvider" in availableProviders:
            return ["CUDAExecutionProvider"], [{"device_id": self.settings.gpu}]
        elif self.settings.gpu >= 0 and "DmlExecutionProvider" in availableProviders:
            return ["DmlExecutionProvider"], [{}]
        else:
            return ["CPUExecutionProvider"], [
                {
                    "intra_op_num_threads": 8,
                    "execution_mode": onnxruntime.ExecutionMode.ORT_PARALLEL,
                    "inter_op_num_threads": 8,
                }
            ]

    def isOnnx(self):
        if self.settings.onnxModelFile is not None:
            return True
        else:
            return False

    def update_settings(self, key: str, val: int | float | str):
        if key in self.settings.intData:
            val = int(val)
            setattr(self.settings, key, val)

            if key == "gpu" and self.isOnnx():
                providers, options = self.getOnnxExecutionProvider()
                if self.onnx_session is not None:
                    self.onnx_session.set_providers(
                        providers=providers,
                        provider_options=options,
                    )
        elif key in self.settings.floatData:
            setattr(self.settings, key, float(val))
        elif key in self.settings.strData:
            setattr(self.settings, key, str(val))
        else:
            return False

        return True
    def get_info(self):
        data = asdict(self.settings)

        data["onnxExecutionProviders"] = (
            self.onnx_session.get_providers() if self.onnx_session is not None else []
        )
        files = ["configFile", "pyTorchModelFile", "onnxModelFile"]
        for f in files:
            if data[f] is not None and os.path.exists(data[f]):
                data[f] = os.path.basename(data[f])
            else:
                data[f] = ""

        return data

    def get_processing_sampling_rate(self):
        if hasattr(self, "hps") is False:
            raise NoModeLoadedException("config")
        return self.hps.data.sampling_rate
    def get_unit_f0(self, audio_buffer, tran):
        wav_44k = audio_buffer
        # f0 = utils.compute_f0_parselmouth(wav, sampling_rate=self.target_sample, hop_length=self.hop_size)
        # f0 = utils.compute_f0_dio(wav_44k, sampling_rate=self.hps.data.sampling_rate, hop_length=self.hps.data.hop_length)
        if self.settings.f0Detector == "dio":
            f0 = compute_f0_dio(
                wav_44k,
                sampling_rate=self.hps.data.sampling_rate,
                hop_length=self.hps.data.hop_length,
            )
        else:
            f0 = compute_f0_harvest(
                wav_44k,
                sampling_rate=self.hps.data.sampling_rate,
                hop_length=self.hps.data.hop_length,
            )

        if wav_44k.shape[0] % self.hps.data.hop_length != 0:
            print(
                f" !!! !!! !!! wav size not multiple of hopsize: {wav_44k.shape[0] / self.hps.data.hop_length}"
            )

        f0, uv = utils.interpolate_f0(f0)
        f0 = torch.FloatTensor(f0)
        uv = torch.FloatTensor(uv)
        f0 = f0 * 2 ** (tran / 12)
        f0 = f0.unsqueeze(0)
        uv = uv.unsqueeze(0)

        # wav16k = librosa.resample(audio_buffer, orig_sr=24000, target_sr=16000)
        wav16k = librosa.resample(
            audio_buffer, orig_sr=self.hps.data.sampling_rate, target_sr=16000
        )
        wav16k = torch.from_numpy(wav16k)

        if (
            self.settings.gpu < 0 or self.gpu_num == 0
        ) or self.settings.framework == "ONNX":
            dev = torch.device("cpu")
        else:
            dev = torch.device("cuda", index=self.settings.gpu)

        self.hubert_model = self.hubert_model.to(dev)
        wav16k = wav16k.to(dev)
        uv = uv.to(dev)
        f0 = f0.to(dev)

        c = utils.get_hubert_content(self.hubert_model, wav_16k_tensor=wav16k)
        c = utils.repeat_expand_2d(c.squeeze(0), f0.shape[1])

        if (
            self.settings.clusterInferRatio != 0
            and hasattr(self, "cluster_model")
            and self.cluster_model is not None
        ):
            speaker = [
                key
                for key, value in self.settings.speakers.items()
                if value == self.settings.dstId
            ]
            if len(speaker) != 1:
                pass
                # print("not only one speaker found.", speaker)
            else:
                cluster_c = cluster.get_cluster_center_result(
                    self.cluster_model, c.cpu().numpy().T, speaker[0]
                ).T
                # cluster_c = cluster.get_cluster_center_result(self.cluster_model, c.cpu().numpy().T, self.settings.dstId).T
                cluster_c = torch.FloatTensor(cluster_c).to(dev)
                # print("cluster DEVICE", cluster_c.device, c.device)
                c = (
                    self.settings.clusterInferRatio * cluster_c
                    + (1 - self.settings.clusterInferRatio) * c
                )

        c = c.unsqueeze(0)
        return c, f0, uv
    def generate_input(
        self,
        newData: AudioInOut,
        inputSize: int,
        crossfadeSize: int,
        solaSearchFrame: int = 0,
    ):
        newData = newData.astype(np.float32) / self.hps.data.max_wav_value

        if self.audio_buffer is not None:
            self.audio_buffer = np.concatenate(
                [self.audio_buffer, newData], 0
            )  # concatenate with past data
        else:
            self.audio_buffer = newData

        convertSize = (
            inputSize + crossfadeSize + solaSearchFrame + self.settings.extraConvertSize
        )

        if convertSize % self.hps.data.hop_length != 0:  # pad to a multiple of the model's hop size, since the output is truncated otherwise
            convertSize = convertSize + (
                self.hps.data.hop_length - (convertSize % self.hps.data.hop_length)
            )
        convertOffset = -1 * convertSize
        self.audio_buffer = self.audio_buffer[convertOffset:]  # extract only the part to be converted

        cropOffset = -1 * (inputSize + crossfadeSize)
        cropEnd = -1 * (crossfadeSize)
        crop = self.audio_buffer[cropOffset:cropEnd]

        rms = np.sqrt(np.square(crop).mean(axis=0))
        vol = max(rms, self.prevVol * 0.0)
        self.prevVol = vol

        c, f0, uv = self.get_unit_f0(self.audio_buffer, self.settings.tran)
        return (c, f0, uv, convertSize, vol)
    def _onnx_inference(self, data):
        if hasattr(self, "onnx_session") is False or self.onnx_session is None:
            print("[Voice Changer] No onnx session.")
            raise NoModeLoadedException("ONNX")

        convertSize = data[3]
        vol = data[4]
        data = (
            data[0],
            data[1],
            data[2],
        )

        if vol < self.settings.silentThreshold:
            return np.zeros(convertSize).astype(np.int16)

        c, f0, uv = [x.numpy() for x in data]
        audio1 = (
            self.onnx_session.run(
                ["audio"],
                {
                    "c": c,
                    "f0": f0,
                    "g": np.array([self.settings.dstId]).astype(np.int64),
                    # bind uv / predict_f0 / noice_scale consistently with the PyTorch path
                    "uv": uv,
                    "predict_f0": np.array([self.settings.predictF0]).astype(np.int64),
                    "noice_scale": np.array([self.settings.noiseScale]).astype(np.float32),
                },
            )[0][0, 0]
            * self.hps.data.max_wav_value
        )

        audio1 = audio1 * vol

        result = audio1

        return result
    def _pyTorch_inference(self, data):
        if hasattr(self, "net_g") is False or self.net_g is None:
            print("[Voice Changer] No pyTorch session.")
            raise NoModeLoadedException("pytorch")

        if self.settings.gpu < 0 or self.gpu_num == 0:
            dev = torch.device("cpu")
        else:
            dev = torch.device("cuda", index=self.settings.gpu)

        convertSize = data[3]
        vol = data[4]
        data = (
            data[0],
            data[1],
            data[2],
        )

        if vol < self.settings.silentThreshold:
            return np.zeros(convertSize).astype(np.int16)

        with torch.no_grad():
            c, f0, uv = [x.to(dev) for x in data]
            sid_target = torch.LongTensor([self.settings.dstId]).to(dev)
            self.net_g.to(dev)
            # audio1 = self.net_g.infer(c, f0=f0, g=sid_target, uv=uv, predict_f0=True, noice_scale=0.1)[0][0, 0].data.float()
            predict_f0_flag = True if self.settings.predictF0 == 1 else False
            audio1 = self.net_g.infer(
                c,
                f0=f0,
                g=sid_target,
                uv=uv,
                predict_f0=predict_f0_flag,
                noice_scale=self.settings.noiseScale,
            )[0][0, 0].data.float()
            audio1 = audio1 * self.hps.data.max_wav_value

            audio1 = audio1 * vol

            result = audio1.float().cpu().numpy()

        # result = infer_tool.pad_array(result, length)
        return result
    def inference(self, data):
        if self.isOnnx():
            audio = self._onnx_inference(data)
        else:
            audio = self._pyTorch_inference(data)
        return audio

    def __del__(self):
        del self.net_g
        del self.onnx_session

        remove_path = os.path.join("so-vits-svc-40v2")
        sys.path = [x for x in sys.path if x.endswith(remove_path) is False]

        for key in list(sys.modules):
            val = sys.modules.get(key)
            try:
                file_path = val.__file__
                if file_path.find("so-vits-svc-40v2" + os.path.sep) >= 0:
                    # print("remove", key, file_path)
                    sys.modules.pop(key)
            except:  # type:ignore
                pass
def resize_f0(x, target_len):
    source = np.array(x)
    source[source < 0.001] = np.nan
    target = np.interp(
        np.arange(0, len(source) * target_len, len(source)) / target_len,
        np.arange(0, len(source)),
        source,
    )
    res = np.nan_to_num(target)
    return res
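A toy run of resize_f0: unvoiced frames (f0 < 0.001) become NaN, the contour is linearly resampled to target_len, and any NaNs that remain after interpolation are zeroed:

_f0 = np.array([100.0, 0.0, 120.0, 130.0])
print(resize_f0(_f0, 8))   # [100.   0.   0.   0. 120. 125. 130. 130.]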
def compute_f0_dio(wav_numpy, p_len=None, sampling_rate=44100, hop_length=512):
    if p_len is None:
        p_len = wav_numpy.shape[0] // hop_length
    f0, t = pw.dio(
        wav_numpy.astype(np.double),
        fs=sampling_rate,
        f0_ceil=800,
        frame_period=1000 * hop_length / sampling_rate,
    )
    f0 = pw.stonemask(wav_numpy.astype(np.double), f0, t, sampling_rate)
    for index, pitch in enumerate(f0):
        f0[index] = round(pitch, 1)
    return resize_f0(f0, p_len)


def compute_f0_harvest(wav_numpy, p_len=None, sampling_rate=44100, hop_length=512):
    if p_len is None:
        p_len = wav_numpy.shape[0] // hop_length
    f0, t = pw.harvest(
        wav_numpy.astype(np.double),
        fs=sampling_rate,
        frame_period=5.5,
        f0_floor=71.0,
        f0_ceil=1000.0,
    )

    for index, pitch in enumerate(f0):
        f0[index] = round(pitch, 1)
    return resize_f0(f0, p_len)
@@ -131,9 +131,9 @@ class VoiceChangerManager(ServerDeviceCallbacks):
             slotInfo = SoVitsSvc40ModelSlotGenerator.loadModel(params)
             self.modelSlotManager.save_model_slot(params.slot, slotInfo)
         elif params.voiceChangerType == "DDSP-SVC":
-            from voice_changer.DDSP_SVC.DDSP_SVC import DDSP_SVC
+            from voice_changer.DDSP_SVC.DDSP_SVCModelSlotGenerator import DDSP_SVCModelSlotGenerator
-            slotInfo = DDSP_SVC.loadModel(params)
+            slotInfo = DDSP_SVCModelSlotGenerator.loadModel(params)
             self.modelSlotManager.save_model_slot(params.slot, slotInfo)
             print("params", params)
@@ -195,6 +195,13 @@ class VoiceChangerManager(ServerDeviceCallbacks):
             self.voiceChangerModel = SoVitsSvc40(self.params, slotInfo)
             self.voiceChanger = VoiceChanger(self.params)
             self.voiceChanger.setModel(self.voiceChangerModel)
+        elif slotInfo.voiceChangerType == "DDSP-SVC":
+            print("................DDSP-SVC")
+            from voice_changer.DDSP_SVC.DDSP_SVC import DDSP_SVC
+
+            self.voiceChangerModel = DDSP_SVC(self.params, slotInfo)
+            self.voiceChanger = VoiceChanger(self.params)
+            self.voiceChanger.setModel(self.voiceChangerModel)
         else:
             print(f"[Voice Changer] unknown voice changer model: {slotInfo.voiceChangerType}")
             del self.voiceChangerModel