mirror of
https://github.com/w-okada/voice-changer.git
synced 2025-01-23 21:45:00 +03:00
WIP: integrate vcs to new gui 4
This commit is contained in:
parent
4d31d2238d
commit
0201b2e44c
@ -4,7 +4,6 @@ from dataclasses import asdict
|
||||
import numpy as np
|
||||
import torch
|
||||
from data.ModelSlot import DDSPSVCModelSlot
|
||||
from voice_changer.DDSP_SVC.ModelSlot import ModelSlot
|
||||
|
||||
from voice_changer.DDSP_SVC.deviceManager.DeviceManager import DeviceManager
|
||||
|
||||
@ -18,11 +17,10 @@ if sys.platform.startswith("darwin"):
|
||||
else:
|
||||
sys.path.append("DDSP-SVC")
|
||||
|
||||
from diffusion.infer_gt_mel import DiffGtMel # type: ignore
|
||||
from .models.diffusion.infer_gt_mel import DiffGtMel
|
||||
|
||||
from voice_changer.utils.VoiceChangerModel import AudioInOut
|
||||
from voice_changer.utils.VoiceChangerParams import VoiceChangerParams
|
||||
from voice_changer.utils.LoadModelParams import LoadModelParams, LoadModelParams2
|
||||
from voice_changer.DDSP_SVC.DDSP_SVCSetting import DDSP_SVCSettings
|
||||
from voice_changer.RVC.embedder.EmbedderManager import EmbedderManager
|
||||
|
||||
@ -51,68 +49,39 @@ def phase_vocoder(a, b, fade_out, fade_in):
|
||||
|
||||
class DDSP_SVC:
|
||||
initialLoad: bool = True
|
||||
settings: DDSP_SVCSettings = DDSP_SVCSettings()
|
||||
diff_model: DiffGtMel = DiffGtMel()
|
||||
svc_model: SvcDDSP = SvcDDSP()
|
||||
|
||||
deviceManager = DeviceManager.get_instance()
|
||||
# diff_model: DiffGtMel = DiffGtMel()
|
||||
|
||||
audio_buffer: AudioInOut | None = None
|
||||
prevVol: float = 0
|
||||
# resample_kernel = {}
|
||||
|
||||
def __init__(self, params: VoiceChangerParams):
|
||||
def __init__(self, params: VoiceChangerParams, slotInfo: DDSPSVCModelSlot):
|
||||
print("[Voice Changer] [DDSP-SVC] Creating instance ")
|
||||
self.deviceManager = DeviceManager.get_instance()
|
||||
self.gpu_num = torch.cuda.device_count()
|
||||
self.params = params
|
||||
self.settings = DDSP_SVCSettings()
|
||||
self.svc_model: SvcDDSP = SvcDDSP()
|
||||
self.diff_model: DiffGtMel = DiffGtMel()
|
||||
|
||||
self.svc_model.setVCParams(params)
|
||||
EmbedderManager.initialize(params)
|
||||
print("[Voice Changer] DDSP-SVC initialization:", params)
|
||||
|
||||
def loadModel(self, props: LoadModelParams):
|
||||
target_slot_idx = props.slot
|
||||
params = props.params
|
||||
self.audio_buffer: AudioInOut | None = None
|
||||
self.prevVol = 0.0
|
||||
self.slotInfo = slotInfo
|
||||
self.initialize()
|
||||
|
||||
modelFile = params["files"]["ddspSvcModel"]
|
||||
diffusionFile = params["files"]["ddspSvcDiffusion"]
|
||||
modelSlot = ModelSlot(
|
||||
modelFile=modelFile,
|
||||
diffusionFile=diffusionFile,
|
||||
defaultTrans=params["trans"] if "trans" in params else 0,
|
||||
)
|
||||
self.settings.modelSlots[target_slot_idx] = modelSlot
|
||||
|
||||
# 初回のみロード
|
||||
# if self.initialLoad:
|
||||
# self.prepareModel(target_slot_idx)
|
||||
# self.settings.modelSlotIndex = target_slot_idx
|
||||
# self.switchModel()
|
||||
# self.initialLoad = False
|
||||
# elif target_slot_idx == self.currentSlot:
|
||||
# self.prepareModel(target_slot_idx)
|
||||
self.settings.modelSlotIndex = target_slot_idx
|
||||
self.reloadModel()
|
||||
|
||||
print("params:", params)
|
||||
return self.get_info()
|
||||
|
||||
def reloadModel(self):
|
||||
def initialize(self):
|
||||
self.device = self.deviceManager.getDevice(self.settings.gpu)
|
||||
modelFile = self.settings.modelSlots[self.settings.modelSlotIndex].modelFile
|
||||
diffusionFile = self.settings.modelSlots[self.settings.modelSlotIndex].diffusionFile
|
||||
|
||||
self.svc_model = SvcDDSP()
|
||||
self.svc_model.setVCParams(self.params)
|
||||
self.svc_model.update_model(modelFile, self.device)
|
||||
self.svc_model.update_model(self.slotInfo.modelFile, self.device)
|
||||
self.diff_model = DiffGtMel(device=self.device)
|
||||
self.diff_model.flush_model(diffusionFile, ddsp_config=self.svc_model.args)
|
||||
self.diff_model.flush_model(self.slotInfo.diffModelFile, ddsp_config=self.svc_model.args)
|
||||
|
||||
def update_settings(self, key: str, val: int | float | str):
|
||||
if key in self.settings.intData:
|
||||
val = int(val)
|
||||
setattr(self.settings, key, val)
|
||||
if key == "gpu":
|
||||
self.reloadModel()
|
||||
self.initialize()
|
||||
elif key in self.settings.floatData:
|
||||
setattr(self.settings, key, float(val))
|
||||
elif key in self.settings.strData:
|
||||
@ -160,10 +129,6 @@ class DDSP_SVC:
|
||||
# raise NoModeLoadedException("ONNX")
|
||||
|
||||
def _pyTorch_inference(self, data):
|
||||
# if hasattr(self, "model") is False or self.model is None:
|
||||
# print("[Voice Changer] No pyTorch session.")
|
||||
# raise NoModeLoadedException("pytorch")
|
||||
|
||||
input_wav = data[0]
|
||||
_audio, _model_sr = self.svc_model.infer(
|
||||
input_wav,
|
||||
@ -192,32 +157,13 @@ class DDSP_SVC:
|
||||
return _audio.cpu().numpy() * 32768.0
|
||||
|
||||
def inference(self, data):
|
||||
if self.settings.framework == "ONNX":
|
||||
if self.slotInfo.isONNX:
|
||||
audio = self._onnx_inference(data)
|
||||
else:
|
||||
audio = self._pyTorch_inference(data)
|
||||
return audio
|
||||
|
||||
@classmethod
|
||||
def loadModel2(cls, props: LoadModelParams2):
|
||||
slotInfo: DDSPSVCModelSlot = DDSPSVCModelSlot()
|
||||
for file in props.files:
|
||||
if file.kind == "ddspSvcModelConfig":
|
||||
slotInfo.configFile = file.name
|
||||
elif file.kind == "ddspSvcModel":
|
||||
slotInfo.modelFile = file.name
|
||||
elif file.kind == "ddspSvcDiffusionConfig":
|
||||
slotInfo.diffConfigFile = file.name
|
||||
elif file.kind == "ddspSvcDiffusion":
|
||||
slotInfo.diffModelFile = file.name
|
||||
slotInfo.isONNX = slotInfo.modelFile.endswith(".onnx")
|
||||
slotInfo.name = os.path.splitext(os.path.basename(slotInfo.modelFile))[0]
|
||||
return slotInfo
|
||||
|
||||
def __del__(self):
|
||||
del self.net_g
|
||||
del self.onnx_session
|
||||
|
||||
remove_path = os.path.join("DDSP-SVC")
|
||||
sys.path = [x for x in sys.path if x.endswith(remove_path) is False]
|
||||
|
||||
|
22
server/voice_changer/DDSP_SVC/DDSP_SVCModelSlotGenerator.py
Normal file
22
server/voice_changer/DDSP_SVC/DDSP_SVCModelSlotGenerator.py
Normal file
@ -0,0 +1,22 @@
|
||||
import os
|
||||
from data.ModelSlot import DDSPSVCModelSlot
|
||||
from voice_changer.utils.LoadModelParams import LoadModelParams
|
||||
from voice_changer.utils.ModelSlotGenerator import ModelSlotGenerator
|
||||
|
||||
|
||||
class DDSP_SVCModelSlotGenerator(ModelSlotGenerator):
|
||||
@classmethod
|
||||
def loadModel(cls, props: LoadModelParams):
|
||||
slotInfo: DDSPSVCModelSlot = DDSPSVCModelSlot()
|
||||
for file in props.files:
|
||||
if file.kind == "ddspSvcModelConfig":
|
||||
slotInfo.configFile = file.name
|
||||
elif file.kind == "ddspSvcModel":
|
||||
slotInfo.modelFile = file.name
|
||||
elif file.kind == "ddspSvcDiffusionConfig":
|
||||
slotInfo.diffConfigFile = file.name
|
||||
elif file.kind == "ddspSvcDiffusion":
|
||||
slotInfo.diffModelFile = file.name
|
||||
slotInfo.isONNX = slotInfo.modelFile.endswith(".onnx")
|
||||
slotInfo.name = os.path.splitext(os.path.basename(slotInfo.modelFile))[0]
|
||||
return slotInfo
|
@ -1,7 +1,5 @@
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from voice_changer.DDSP_SVC.ModelSlot import ModelSlot
|
||||
|
||||
|
||||
@dataclass
|
||||
class DDSP_SVCSettings:
|
||||
@ -23,14 +21,7 @@ class DDSP_SVCSettings:
|
||||
kStep: int = 120
|
||||
threshold: int = -45
|
||||
|
||||
framework: str = "PyTorch" # PyTorch or ONNX
|
||||
pyTorchModelFile: str = ""
|
||||
onnxModelFile: str = ""
|
||||
configFile: str = ""
|
||||
|
||||
speakers: dict[str, int] = field(default_factory=lambda: {})
|
||||
modelSlotIndex: int = -1
|
||||
modelSlots: list[ModelSlot] = field(default_factory=lambda: [ModelSlot()])
|
||||
# ↓mutableな物だけ列挙
|
||||
intData = [
|
||||
"gpu",
|
||||
@ -46,4 +37,4 @@ class DDSP_SVCSettings:
|
||||
"kStep",
|
||||
]
|
||||
floatData = ["silentThreshold"]
|
||||
strData = ["framework", "f0Detector", "diffMethod"]
|
||||
strData = ["f0Detector", "diffMethod"]
|
||||
|
@ -1,8 +0,0 @@
|
||||
from dataclasses import dataclass
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelSlot:
|
||||
modelFile: str = ""
|
||||
diffusionFile: str = ""
|
||||
defaultTrans: int = 0
|
@ -3,13 +3,13 @@
|
||||
import torch
|
||||
|
||||
try:
|
||||
from ddsp.vocoder import load_model, F0_Extractor, Volume_Extractor, Units_Encoder # type: ignore
|
||||
from .models.ddsp.vocoder import load_model, F0_Extractor, Volume_Extractor, Units_Encoder
|
||||
except Exception as e:
|
||||
print(e)
|
||||
from ddsp.vocoder import load_model, F0_Extractor, Volume_Extractor, Units_Encoder # type: ignore
|
||||
from .models.ddsp.vocoder import load_model, F0_Extractor, Volume_Extractor, Units_Encoder
|
||||
|
||||
from ddsp.core import upsample # type: ignore
|
||||
from enhancer import Enhancer # type: ignore
|
||||
from .models.ddsp.core import upsample
|
||||
from .models.enhancer import Enhancer
|
||||
from voice_changer.utils.VoiceChangerParams import VoiceChangerParams
|
||||
import numpy as np
|
||||
|
||||
@ -38,11 +38,7 @@ class SvcDDSP:
|
||||
print("ARGS:", self.args)
|
||||
|
||||
# load units encoder
|
||||
if (
|
||||
self.units_encoder is None
|
||||
or self.args.data.encoder != self.encoder_type
|
||||
or self.args.data.encoder_ckpt != self.encoder_ckpt
|
||||
):
|
||||
if self.units_encoder is None or self.args.data.encoder != self.encoder_type or self.args.data.encoder_ckpt != self.encoder_ckpt:
|
||||
if self.args.data.encoder == "cnhubertsoftfish":
|
||||
cnhubertsoft_gate = self.args.data.cnhubertsoft_gate
|
||||
else:
|
||||
@ -77,15 +73,9 @@ class SvcDDSP:
|
||||
self.encoder_ckpt = encoderPath
|
||||
|
||||
# load enhancer
|
||||
if (
|
||||
self.enhancer is None
|
||||
or self.args.enhancer.type != self.enhancer_type
|
||||
or self.args.enhancer.ckpt != self.enhancer_ckpt
|
||||
):
|
||||
if self.enhancer is None or self.args.enhancer.type != self.enhancer_type or self.args.enhancer.ckpt != self.enhancer_ckpt:
|
||||
enhancerPath = self.params.nsf_hifigan
|
||||
self.enhancer = Enhancer(
|
||||
self.args.enhancer.type, enhancerPath, device=self.device
|
||||
)
|
||||
self.enhancer = Enhancer(self.args.enhancer.type, enhancerPath, device=self.device)
|
||||
self.enhancer_type = self.args.enhancer.type
|
||||
self.enhancer_ckpt = enhancerPath
|
||||
|
||||
@ -118,9 +108,7 @@ class SvcDDSP:
|
||||
# print("audio", audio)
|
||||
# load input
|
||||
# audio, sample_rate = librosa.load(input_wav, sr=None, mono=True)
|
||||
hop_size = (
|
||||
self.args.data.block_size * sample_rate / self.args.data.sampling_rate
|
||||
)
|
||||
hop_size = self.args.data.block_size * sample_rate / self.args.data.sampling_rate
|
||||
if audio_alignment:
|
||||
audio_length = len(audio)
|
||||
# safe front silence
|
||||
@ -131,12 +119,9 @@ class SvcDDSP:
|
||||
audio_t = torch.from_numpy(audio).float().unsqueeze(0).to(self.device)
|
||||
|
||||
# extract f0
|
||||
pitch_extractor = F0_Extractor(
|
||||
pitch_extractor_type, sample_rate, hop_size, float(f0_min), float(f0_max)
|
||||
)
|
||||
f0 = pitch_extractor.extract(
|
||||
audio, uv_interp=True, device=self.device, silence_front=silence_front
|
||||
)
|
||||
print("pitch_extractor_type", pitch_extractor_type)
|
||||
pitch_extractor = F0_Extractor(pitch_extractor_type, sample_rate, hop_size, float(f0_min), float(f0_max))
|
||||
f0 = pitch_extractor.extract(audio, uv_interp=True, device=self.device, silence_front=silence_front)
|
||||
f0 = torch.from_numpy(f0).float().to(self.device).unsqueeze(-1).unsqueeze(0)
|
||||
f0 = f0 * 2 ** (float(pitch_adjust) / 12)
|
||||
|
||||
@ -148,9 +133,7 @@ class SvcDDSP:
|
||||
mask = np.array([np.max(mask[n : n + 9]) for n in range(len(mask) - 8)]) # type: ignore
|
||||
mask = torch.from_numpy(mask).float().to(self.device).unsqueeze(-1).unsqueeze(0)
|
||||
mask = upsample(mask, self.args.data.block_size).squeeze(-1)
|
||||
volume = (
|
||||
torch.from_numpy(volume).float().to(self.device).unsqueeze(-1).unsqueeze(0)
|
||||
)
|
||||
volume = torch.from_numpy(volume).float().to(self.device).unsqueeze(-1).unsqueeze(0)
|
||||
|
||||
# extract units
|
||||
units = self.units_encoder.encode(audio_t, sample_rate, hop_size)
|
||||
@ -165,9 +148,7 @@ class SvcDDSP:
|
||||
|
||||
# forward and return the output
|
||||
with torch.no_grad():
|
||||
output, _, (s_h, s_n) = self.model(
|
||||
units, f0, volume, spk_id=spk_id, spk_mix_dict=dictionary
|
||||
)
|
||||
output, _, (s_h, s_n) = self.model(units, f0, volume, spk_id=spk_id, spk_mix_dict=dictionary)
|
||||
|
||||
if diff_use and diff_model is not None:
|
||||
output = diff_model.infer(
|
||||
|
281
server/voice_changer/DDSP_SVC/models/ddsp/core.py
Normal file
281
server/voice_changer/DDSP_SVC/models/ddsp/core.py
Normal file
@ -0,0 +1,281 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.nn import functional as F
|
||||
|
||||
import math
|
||||
import numpy as np
|
||||
|
||||
def MaskedAvgPool1d(x, kernel_size):
|
||||
x = x.unsqueeze(1)
|
||||
x = F.pad(x, ((kernel_size - 1) // 2, kernel_size // 2), mode="reflect")
|
||||
mask = ~torch.isnan(x)
|
||||
masked_x = torch.where(mask, x, torch.zeros_like(x))
|
||||
ones_kernel = torch.ones(x.size(1), 1, kernel_size, device=x.device)
|
||||
|
||||
# Perform sum pooling
|
||||
sum_pooled = F.conv1d(
|
||||
masked_x,
|
||||
ones_kernel,
|
||||
stride=1,
|
||||
padding=0,
|
||||
groups=x.size(1),
|
||||
)
|
||||
|
||||
# Count the non-masked (valid) elements in each pooling window
|
||||
valid_count = F.conv1d(
|
||||
mask.float(),
|
||||
ones_kernel,
|
||||
stride=1,
|
||||
padding=0,
|
||||
groups=x.size(1),
|
||||
)
|
||||
valid_count = valid_count.clamp(min=1) # Avoid division by zero
|
||||
|
||||
# Perform masked average pooling
|
||||
avg_pooled = sum_pooled / valid_count
|
||||
|
||||
return avg_pooled.squeeze(1)
|
||||
|
||||
def MedianPool1d(x, kernel_size):
|
||||
x = x.unsqueeze(1)
|
||||
x = F.pad(x, ((kernel_size - 1) // 2, kernel_size // 2), mode="reflect")
|
||||
x = x.squeeze(1)
|
||||
x = x.unfold(1, kernel_size, 1)
|
||||
x, _ = torch.sort(x, dim=-1)
|
||||
return x[:, :, (kernel_size - 1) // 2]
|
||||
|
||||
def get_fft_size(frame_size: int, ir_size: int, power_of_2: bool = True):
|
||||
"""Calculate final size for efficient FFT.
|
||||
Args:
|
||||
frame_size: Size of the audio frame.
|
||||
ir_size: Size of the convolving impulse response.
|
||||
power_of_2: Constrain to be a power of 2. If False, allow other 5-smooth
|
||||
numbers. TPU requires power of 2, while GPU is more flexible.
|
||||
Returns:
|
||||
fft_size: Size for efficient FFT.
|
||||
"""
|
||||
convolved_frame_size = ir_size + frame_size - 1
|
||||
if power_of_2:
|
||||
# Next power of 2.
|
||||
fft_size = int(2**np.ceil(np.log2(convolved_frame_size)))
|
||||
else:
|
||||
fft_size = convolved_frame_size
|
||||
return fft_size
|
||||
|
||||
|
||||
def upsample(signal, factor):
|
||||
signal = signal.permute(0, 2, 1)
|
||||
signal = nn.functional.interpolate(torch.cat((signal,signal[:,:,-1:]),2), size=signal.shape[-1] * factor + 1, mode='linear', align_corners=True)
|
||||
signal = signal[:,:,:-1]
|
||||
return signal.permute(0, 2, 1)
|
||||
|
||||
|
||||
def remove_above_fmax(amplitudes, pitch, fmax, level_start=1):
|
||||
n_harm = amplitudes.shape[-1]
|
||||
pitches = pitch * torch.arange(level_start, n_harm + level_start).to(pitch)
|
||||
aa = (pitches < fmax).float() + 1e-7
|
||||
return amplitudes * aa
|
||||
|
||||
|
||||
def crop_and_compensate_delay(audio, audio_size, ir_size,
|
||||
padding = 'same',
|
||||
delay_compensation = -1):
|
||||
"""Crop audio output from convolution to compensate for group delay.
|
||||
Args:
|
||||
audio: Audio after convolution. Tensor of shape [batch, time_steps].
|
||||
audio_size: Initial size of the audio before convolution.
|
||||
ir_size: Size of the convolving impulse response.
|
||||
padding: Either 'valid' or 'same'. For 'same' the final output to be the
|
||||
same size as the input audio (audio_timesteps). For 'valid' the audio is
|
||||
extended to include the tail of the impulse response (audio_timesteps +
|
||||
ir_timesteps - 1).
|
||||
delay_compensation: Samples to crop from start of output audio to compensate
|
||||
for group delay of the impulse response. If delay_compensation < 0 it
|
||||
defaults to automatically calculating a constant group delay of the
|
||||
windowed linear phase filter from frequency_impulse_response().
|
||||
Returns:
|
||||
Tensor of cropped and shifted audio.
|
||||
Raises:
|
||||
ValueError: If padding is not either 'valid' or 'same'.
|
||||
"""
|
||||
# Crop the output.
|
||||
if padding == 'valid':
|
||||
crop_size = ir_size + audio_size - 1
|
||||
elif padding == 'same':
|
||||
crop_size = audio_size
|
||||
else:
|
||||
raise ValueError('Padding must be \'valid\' or \'same\', instead '
|
||||
'of {}.'.format(padding))
|
||||
|
||||
# Compensate for the group delay of the filter by trimming the front.
|
||||
# For an impulse response produced by frequency_impulse_response(),
|
||||
# the group delay is constant because the filter is linear phase.
|
||||
total_size = int(audio.shape[-1])
|
||||
crop = total_size - crop_size
|
||||
start = (ir_size // 2 if delay_compensation < 0 else delay_compensation)
|
||||
end = crop - start
|
||||
return audio[:, start:-end]
|
||||
|
||||
|
||||
def fft_convolve(audio,
|
||||
impulse_response): # B, n_frames, 2*(n_mags-1)
|
||||
"""Filter audio with frames of time-varying impulse responses.
|
||||
Time-varying filter. Given audio [batch, n_samples], and a series of impulse
|
||||
responses [batch, n_frames, n_impulse_response], splits the audio into frames,
|
||||
applies filters, and then overlap-and-adds audio back together.
|
||||
Applies non-windowed non-overlapping STFT/ISTFT to efficiently compute
|
||||
convolution for large impulse response sizes.
|
||||
Args:
|
||||
audio: Input audio. Tensor of shape [batch, audio_timesteps].
|
||||
impulse_response: Finite impulse response to convolve. Can either be a 2-D
|
||||
Tensor of shape [batch, ir_size], or a 3-D Tensor of shape [batch,
|
||||
ir_frames, ir_size]. A 2-D tensor will apply a single linear
|
||||
time-invariant filter to the audio. A 3-D Tensor will apply a linear
|
||||
time-varying filter. Automatically chops the audio into equally shaped
|
||||
blocks to match ir_frames.
|
||||
Returns:
|
||||
audio_out: Convolved audio. Tensor of shape
|
||||
[batch, audio_timesteps].
|
||||
"""
|
||||
# Add a frame dimension to impulse response if it doesn't have one.
|
||||
ir_shape = impulse_response.size()
|
||||
if len(ir_shape) == 2:
|
||||
impulse_response = impulse_response.unsqueeze(1)
|
||||
ir_shape = impulse_response.size()
|
||||
|
||||
# Get shapes of audio and impulse response.
|
||||
batch_size_ir, n_ir_frames, ir_size = ir_shape
|
||||
batch_size, audio_size = audio.size() # B, T
|
||||
|
||||
# Validate that batch sizes match.
|
||||
if batch_size != batch_size_ir:
|
||||
raise ValueError('Batch size of audio ({}) and impulse response ({}) must '
|
||||
'be the same.'.format(batch_size, batch_size_ir))
|
||||
|
||||
# Cut audio into 50% overlapped frames (center padding).
|
||||
hop_size = int(audio_size / n_ir_frames)
|
||||
frame_size = 2 * hop_size
|
||||
audio_frames = F.pad(audio, (hop_size, hop_size)).unfold(1, frame_size, hop_size)
|
||||
|
||||
# Apply Bartlett (triangular) window
|
||||
window = torch.bartlett_window(frame_size).to(audio_frames)
|
||||
audio_frames = audio_frames * window
|
||||
|
||||
# Pad and FFT the audio and impulse responses.
|
||||
fft_size = get_fft_size(frame_size, ir_size, power_of_2=False)
|
||||
audio_fft = torch.fft.rfft(audio_frames, fft_size)
|
||||
ir_fft = torch.fft.rfft(torch.cat((impulse_response,impulse_response[:,-1:,:]),1), fft_size)
|
||||
|
||||
# Multiply the FFTs (same as convolution in time).
|
||||
audio_ir_fft = torch.multiply(audio_fft, ir_fft)
|
||||
|
||||
# Take the IFFT to resynthesize audio.
|
||||
audio_frames_out = torch.fft.irfft(audio_ir_fft, fft_size)
|
||||
|
||||
# Overlap Add
|
||||
batch_size, n_audio_frames, frame_size = audio_frames_out.size() # # B, n_frames+1, 2*(hop_size+n_mags-1)-1
|
||||
fold = torch.nn.Fold(output_size=(1, (n_audio_frames - 1) * hop_size + frame_size),kernel_size=(1, frame_size),stride=(1, hop_size))
|
||||
output_signal = fold(audio_frames_out.transpose(1, 2)).squeeze(1).squeeze(1)
|
||||
|
||||
# Crop and shift the output audio.
|
||||
output_signal = crop_and_compensate_delay(output_signal[:,hop_size:], audio_size, ir_size)
|
||||
return output_signal
|
||||
|
||||
|
||||
def apply_window_to_impulse_response(impulse_response, # B, n_frames, 2*(n_mag-1)
|
||||
window_size: int = 0,
|
||||
causal: bool = False):
|
||||
"""Apply a window to an impulse response and put in causal form.
|
||||
Args:
|
||||
impulse_response: A series of impulse responses frames to window, of shape
|
||||
[batch, n_frames, ir_size]. ---------> ir_size means size of filter_bank ??????
|
||||
|
||||
window_size: Size of the window to apply in the time domain. If window_size
|
||||
is less than 1, it defaults to the impulse_response size.
|
||||
causal: Impulse response input is in causal form (peak in the middle).
|
||||
Returns:
|
||||
impulse_response: Windowed impulse response in causal form, with last
|
||||
dimension cropped to window_size if window_size is greater than 0 and less
|
||||
than ir_size.
|
||||
"""
|
||||
|
||||
# If IR is in causal form, put it in zero-phase form.
|
||||
if causal:
|
||||
impulse_response = torch.fftshift(impulse_response, axes=-1)
|
||||
|
||||
# Get a window for better time/frequency resolution than rectangular.
|
||||
# Window defaults to IR size, cannot be bigger.
|
||||
ir_size = int(impulse_response.size(-1))
|
||||
if (window_size <= 0) or (window_size > ir_size):
|
||||
window_size = ir_size
|
||||
window = nn.Parameter(torch.hann_window(window_size), requires_grad = False).to(impulse_response)
|
||||
|
||||
# Zero pad the window and put in in zero-phase form.
|
||||
padding = ir_size - window_size
|
||||
if padding > 0:
|
||||
half_idx = (window_size + 1) // 2
|
||||
window = torch.cat([window[half_idx:],
|
||||
torch.zeros([padding]),
|
||||
window[:half_idx]], axis=0)
|
||||
else:
|
||||
window = window.roll(window.size(-1)//2, -1)
|
||||
|
||||
# Apply the window, to get new IR (both in zero-phase form).
|
||||
window = window.unsqueeze(0)
|
||||
impulse_response = impulse_response * window
|
||||
|
||||
# Put IR in causal form and trim zero padding.
|
||||
if padding > 0:
|
||||
first_half_start = (ir_size - (half_idx - 1)) + 1
|
||||
second_half_end = half_idx + 1
|
||||
impulse_response = torch.cat([impulse_response[..., first_half_start:],
|
||||
impulse_response[..., :second_half_end]],
|
||||
dim=-1)
|
||||
else:
|
||||
impulse_response = impulse_response.roll(impulse_response.size(-1)//2, -1)
|
||||
|
||||
return impulse_response
|
||||
|
||||
|
||||
def apply_dynamic_window_to_impulse_response(impulse_response, # B, n_frames, 2*(n_mag-1) or 2*n_mag-1
|
||||
half_width_frames): # B,n_frames, 1
|
||||
ir_size = int(impulse_response.size(-1)) # 2*(n_mag -1) or 2*n_mag-1
|
||||
|
||||
window = torch.arange(-(ir_size // 2), (ir_size + 1) // 2).to(impulse_response) / half_width_frames
|
||||
window[window > 1] = 0
|
||||
window = (1 + torch.cos(np.pi * window)) / 2 # B, n_frames, 2*(n_mag -1) or 2*n_mag-1
|
||||
|
||||
impulse_response = impulse_response.roll(ir_size // 2, -1)
|
||||
impulse_response = impulse_response * window
|
||||
|
||||
return impulse_response
|
||||
|
||||
|
||||
def frequency_impulse_response(magnitudes,
|
||||
hann_window = True,
|
||||
half_width_frames = None):
|
||||
|
||||
# Get the IR
|
||||
impulse_response = torch.fft.irfft(magnitudes) # B, n_frames, 2*(n_mags-1)
|
||||
|
||||
# Window and put in causal form.
|
||||
if hann_window:
|
||||
if half_width_frames is None:
|
||||
impulse_response = apply_window_to_impulse_response(impulse_response)
|
||||
else:
|
||||
impulse_response = apply_dynamic_window_to_impulse_response(impulse_response, half_width_frames)
|
||||
else:
|
||||
impulse_response = impulse_response.roll(impulse_response.size(-1) // 2, -1)
|
||||
|
||||
return impulse_response
|
||||
|
||||
|
||||
def frequency_filter(audio,
|
||||
magnitudes,
|
||||
hann_window=True,
|
||||
half_width_frames=None):
|
||||
|
||||
impulse_response = frequency_impulse_response(magnitudes, hann_window, half_width_frames)
|
||||
|
||||
return fft_convolve(audio, impulse_response)
|
||||
|
57
server/voice_changer/DDSP_SVC/models/ddsp/loss.py
Normal file
57
server/voice_changer/DDSP_SVC/models/ddsp/loss.py
Normal file
@ -0,0 +1,57 @@
|
||||
import numpy as np
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torchaudio
|
||||
from torch.nn import functional as F
|
||||
from .core import upsample
|
||||
|
||||
class SSSLoss(nn.Module):
|
||||
"""
|
||||
Single-scale Spectral Loss.
|
||||
"""
|
||||
|
||||
def __init__(self, n_fft=111, alpha=1.0, overlap=0, eps=1e-7):
|
||||
super().__init__()
|
||||
self.n_fft = n_fft
|
||||
self.alpha = alpha
|
||||
self.eps = eps
|
||||
self.hop_length = int(n_fft * (1 - overlap)) # 25% of the length
|
||||
self.spec = torchaudio.transforms.Spectrogram(n_fft=self.n_fft, hop_length=self.hop_length, power=1, normalized=True, center=False)
|
||||
|
||||
def forward(self, x_true, x_pred):
|
||||
S_true = self.spec(x_true) + self.eps
|
||||
S_pred = self.spec(x_pred) + self.eps
|
||||
|
||||
converge_term = torch.mean(torch.linalg.norm(S_true - S_pred, dim = (1, 2)) / torch.linalg.norm(S_true + S_pred, dim = (1, 2)))
|
||||
|
||||
log_term = F.l1_loss(S_true.log(), S_pred.log())
|
||||
|
||||
loss = converge_term + self.alpha * log_term
|
||||
return loss
|
||||
|
||||
|
||||
class RSSLoss(nn.Module):
|
||||
'''
|
||||
Random-scale Spectral Loss.
|
||||
'''
|
||||
|
||||
def __init__(self, fft_min, fft_max, n_scale, alpha=1.0, overlap=0, eps=1e-7, device='cuda'):
|
||||
super().__init__()
|
||||
self.fft_min = fft_min
|
||||
self.fft_max = fft_max
|
||||
self.n_scale = n_scale
|
||||
self.lossdict = {}
|
||||
for n_fft in range(fft_min, fft_max):
|
||||
self.lossdict[n_fft] = SSSLoss(n_fft, alpha, overlap, eps).to(device)
|
||||
|
||||
def forward(self, x_pred, x_true):
|
||||
value = 0.
|
||||
n_ffts = torch.randint(self.fft_min, self.fft_max, (self.n_scale,))
|
||||
for n_fft in n_ffts:
|
||||
loss_func = self.lossdict[int(n_fft)]
|
||||
value += loss_func(x_true, x_pred)
|
||||
return value / self.n_scale
|
||||
|
||||
|
||||
|
380
server/voice_changer/DDSP_SVC/models/ddsp/pcmer.py
Normal file
380
server/voice_changer/DDSP_SVC/models/ddsp/pcmer.py
Normal file
@ -0,0 +1,380 @@
|
||||
import torch
|
||||
|
||||
from torch import nn
|
||||
import math
|
||||
from functools import partial
|
||||
from einops import rearrange, repeat
|
||||
|
||||
from local_attention import LocalAttention
|
||||
import torch.nn.functional as F
|
||||
#import fast_transformers.causal_product.causal_product_cuda
|
||||
|
||||
def softmax_kernel(data, *, projection_matrix, is_query, normalize_data=True, eps=1e-4, device = None):
|
||||
b, h, *_ = data.shape
|
||||
# (batch size, head, length, model_dim)
|
||||
|
||||
# normalize model dim
|
||||
data_normalizer = (data.shape[-1] ** -0.25) if normalize_data else 1.
|
||||
|
||||
# what is ration?, projection_matrix.shape[0] --> 266
|
||||
|
||||
ratio = (projection_matrix.shape[0] ** -0.5)
|
||||
|
||||
projection = repeat(projection_matrix, 'j d -> b h j d', b = b, h = h)
|
||||
projection = projection.type_as(data)
|
||||
|
||||
#data_dash = w^T x
|
||||
data_dash = torch.einsum('...id,...jd->...ij', (data_normalizer * data), projection)
|
||||
|
||||
|
||||
# diag_data = D**2
|
||||
diag_data = data ** 2
|
||||
diag_data = torch.sum(diag_data, dim=-1)
|
||||
diag_data = (diag_data / 2.0) * (data_normalizer ** 2)
|
||||
diag_data = diag_data.unsqueeze(dim=-1)
|
||||
|
||||
#print ()
|
||||
if is_query:
|
||||
data_dash = ratio * (
|
||||
torch.exp(data_dash - diag_data -
|
||||
torch.max(data_dash, dim=-1, keepdim=True).values) + eps)
|
||||
else:
|
||||
data_dash = ratio * (
|
||||
torch.exp(data_dash - diag_data + eps))#- torch.max(data_dash)) + eps)
|
||||
|
||||
return data_dash.type_as(data)
|
||||
|
||||
def orthogonal_matrix_chunk(cols, qr_uniform_q = False, device = None):
|
||||
unstructured_block = torch.randn((cols, cols), device = device)
|
||||
q, r = torch.linalg.qr(unstructured_block.cpu(), mode='reduced')
|
||||
q, r = map(lambda t: t.to(device), (q, r))
|
||||
|
||||
# proposed by @Parskatt
|
||||
# to make sure Q is uniform https://arxiv.org/pdf/math-ph/0609050.pdf
|
||||
if qr_uniform_q:
|
||||
d = torch.diag(r, 0)
|
||||
q *= d.sign()
|
||||
return q.t()
|
||||
def exists(val):
|
||||
return val is not None
|
||||
|
||||
def empty(tensor):
|
||||
return tensor.numel() == 0
|
||||
|
||||
def default(val, d):
|
||||
return val if exists(val) else d
|
||||
|
||||
def cast_tuple(val):
|
||||
return (val,) if not isinstance(val, tuple) else val
|
||||
|
||||
class PCmer(nn.Module):
|
||||
"""The encoder that is used in the Transformer model."""
|
||||
|
||||
def __init__(self,
|
||||
num_layers,
|
||||
num_heads,
|
||||
dim_model,
|
||||
dim_keys,
|
||||
dim_values,
|
||||
residual_dropout,
|
||||
attention_dropout):
|
||||
super().__init__()
|
||||
self.num_layers = num_layers
|
||||
self.num_heads = num_heads
|
||||
self.dim_model = dim_model
|
||||
self.dim_values = dim_values
|
||||
self.dim_keys = dim_keys
|
||||
self.residual_dropout = residual_dropout
|
||||
self.attention_dropout = attention_dropout
|
||||
|
||||
self._layers = nn.ModuleList([_EncoderLayer(self) for _ in range(num_layers)])
|
||||
|
||||
# METHODS ########################################################################################################
|
||||
|
||||
def forward(self, phone, mask=None):
|
||||
|
||||
# apply all layers to the input
|
||||
for (i, layer) in enumerate(self._layers):
|
||||
phone = layer(phone, mask)
|
||||
# provide the final sequence
|
||||
return phone
|
||||
|
||||
|
||||
# ==================================================================================================================== #
|
||||
# CLASS _ E N C O D E R L A Y E R #
|
||||
# ==================================================================================================================== #
|
||||
|
||||
|
||||
class _EncoderLayer(nn.Module):
|
||||
"""One layer of the encoder.
|
||||
|
||||
Attributes:
|
||||
attn: (:class:`mha.MultiHeadAttention`): The attention mechanism that is used to read the input sequence.
|
||||
feed_forward (:class:`ffl.FeedForwardLayer`): The feed-forward layer on top of the attention mechanism.
|
||||
"""
|
||||
|
||||
def __init__(self, parent: PCmer):
|
||||
"""Creates a new instance of ``_EncoderLayer``.
|
||||
|
||||
Args:
|
||||
parent (Encoder): The encoder that the layers is created for.
|
||||
"""
|
||||
super().__init__()
|
||||
|
||||
|
||||
self.conformer = ConformerConvModule(parent.dim_model)
|
||||
self.norm = nn.LayerNorm(parent.dim_model)
|
||||
self.dropout = nn.Dropout(parent.residual_dropout)
|
||||
|
||||
# selfatt -> fastatt: performer!
|
||||
self.attn = SelfAttention(dim = parent.dim_model,
|
||||
heads = parent.num_heads,
|
||||
causal = False)
|
||||
|
||||
# METHODS ########################################################################################################
|
||||
|
||||
def forward(self, phone, mask=None):
|
||||
|
||||
# compute attention sub-layer
|
||||
phone = phone + (self.attn(self.norm(phone), mask=mask))
|
||||
|
||||
phone = phone + (self.conformer(phone))
|
||||
|
||||
return phone
|
||||
|
||||
def calc_same_padding(kernel_size):
|
||||
pad = kernel_size // 2
|
||||
return (pad, pad - (kernel_size + 1) % 2)
|
||||
|
||||
# helper classes
|
||||
|
||||
class Swish(nn.Module):
|
||||
def forward(self, x):
|
||||
return x * x.sigmoid()
|
||||
|
||||
class Transpose(nn.Module):
|
||||
def __init__(self, dims):
|
||||
super().__init__()
|
||||
assert len(dims) == 2, 'dims must be a tuple of two dimensions'
|
||||
self.dims = dims
|
||||
|
||||
def forward(self, x):
|
||||
return x.transpose(*self.dims)
|
||||
|
||||
class GLU(nn.Module):
|
||||
def __init__(self, dim):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
|
||||
def forward(self, x):
|
||||
out, gate = x.chunk(2, dim=self.dim)
|
||||
return out * gate.sigmoid()
|
||||
|
||||
class DepthWiseConv1d(nn.Module):
|
||||
def __init__(self, chan_in, chan_out, kernel_size, padding):
|
||||
super().__init__()
|
||||
self.padding = padding
|
||||
self.conv = nn.Conv1d(chan_in, chan_out, kernel_size, groups = chan_in)
|
||||
|
||||
def forward(self, x):
|
||||
x = F.pad(x, self.padding)
|
||||
return self.conv(x)
|
||||
|
||||
class ConformerConvModule(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim,
|
||||
causal = False,
|
||||
expansion_factor = 2,
|
||||
kernel_size = 31,
|
||||
dropout = 0.):
|
||||
super().__init__()
|
||||
|
||||
inner_dim = dim * expansion_factor
|
||||
padding = calc_same_padding(kernel_size) if not causal else (kernel_size - 1, 0)
|
||||
|
||||
self.net = nn.Sequential(
|
||||
nn.LayerNorm(dim),
|
||||
Transpose((1, 2)),
|
||||
nn.Conv1d(dim, inner_dim * 2, 1),
|
||||
GLU(dim=1),
|
||||
DepthWiseConv1d(inner_dim, inner_dim, kernel_size = kernel_size, padding = padding),
|
||||
#nn.BatchNorm1d(inner_dim) if not causal else nn.Identity(),
|
||||
Swish(),
|
||||
nn.Conv1d(inner_dim, dim, 1),
|
||||
Transpose((1, 2)),
|
||||
nn.Dropout(dropout)
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
return self.net(x)
|
||||
|
||||
def linear_attention(q, k, v):
|
||||
if v is None:
|
||||
#print (k.size(), q.size())
|
||||
out = torch.einsum('...ed,...nd->...ne', k, q)
|
||||
return out
|
||||
|
||||
else:
|
||||
k_cumsum = k.sum(dim = -2)
|
||||
#k_cumsum = k.sum(dim = -2)
|
||||
D_inv = 1. / (torch.einsum('...nd,...d->...n', q, k_cumsum.type_as(q)) + 1e-8)
|
||||
|
||||
context = torch.einsum('...nd,...ne->...de', k, v)
|
||||
#print ("TRUEEE: ", context.size(), q.size(), D_inv.size())
|
||||
out = torch.einsum('...de,...nd,...n->...ne', context, q, D_inv)
|
||||
return out
|
||||
|
||||
def gaussian_orthogonal_random_matrix(nb_rows, nb_columns, scaling = 0, qr_uniform_q = False, device = None):
|
||||
nb_full_blocks = int(nb_rows / nb_columns)
|
||||
#print (nb_full_blocks)
|
||||
block_list = []
|
||||
|
||||
for _ in range(nb_full_blocks):
|
||||
q = orthogonal_matrix_chunk(nb_columns, qr_uniform_q = qr_uniform_q, device = device)
|
||||
block_list.append(q)
|
||||
# block_list[n] is a orthogonal matrix ... (model_dim * model_dim)
|
||||
#print (block_list[0].size(), torch.einsum('...nd,...nd->...n', block_list[0], torch.roll(block_list[0],1,1)))
|
||||
#print (nb_rows, nb_full_blocks, nb_columns)
|
||||
remaining_rows = nb_rows - nb_full_blocks * nb_columns
|
||||
#print (remaining_rows)
|
||||
if remaining_rows > 0:
|
||||
q = orthogonal_matrix_chunk(nb_columns, qr_uniform_q = qr_uniform_q, device = device)
|
||||
#print (q[:remaining_rows].size())
|
||||
block_list.append(q[:remaining_rows])
|
||||
|
||||
final_matrix = torch.cat(block_list)
|
||||
|
||||
if scaling == 0:
|
||||
multiplier = torch.randn((nb_rows, nb_columns), device = device).norm(dim = 1)
|
||||
elif scaling == 1:
|
||||
multiplier = math.sqrt((float(nb_columns))) * torch.ones((nb_rows,), device = device)
|
||||
else:
|
||||
raise ValueError(f'Invalid scaling {scaling}')
|
||||
|
||||
return torch.diag(multiplier) @ final_matrix
|
||||
|
||||
class FastAttention(nn.Module):
|
||||
def __init__(self, dim_heads, nb_features = None, ortho_scaling = 0, causal = False, generalized_attention = False, kernel_fn = nn.ReLU(), qr_uniform_q = False, no_projection = False):
|
||||
super().__init__()
|
||||
nb_features = default(nb_features, int(dim_heads * math.log(dim_heads)))
|
||||
|
||||
self.dim_heads = dim_heads
|
||||
self.nb_features = nb_features
|
||||
self.ortho_scaling = ortho_scaling
|
||||
|
||||
self.create_projection = partial(gaussian_orthogonal_random_matrix, nb_rows = self.nb_features, nb_columns = dim_heads, scaling = ortho_scaling, qr_uniform_q = qr_uniform_q)
|
||||
projection_matrix = self.create_projection()
|
||||
self.register_buffer('projection_matrix', projection_matrix)
|
||||
|
||||
self.generalized_attention = generalized_attention
|
||||
self.kernel_fn = kernel_fn
|
||||
|
||||
# if this is turned on, no projection will be used
|
||||
# queries and keys will be softmax-ed as in the original efficient attention paper
|
||||
self.no_projection = no_projection
|
||||
|
||||
self.causal = causal
|
||||
if causal:
|
||||
try:
|
||||
import fast_transformers.causal_product.causal_product_cuda
|
||||
self.causal_linear_fn = partial(causal_linear_attention)
|
||||
except ImportError:
|
||||
print('unable to import cuda code for auto-regressive Performer. will default to the memory inefficient non-cuda version')
|
||||
self.causal_linear_fn = causal_linear_attention_noncuda
|
||||
@torch.no_grad()
|
||||
def redraw_projection_matrix(self):
|
||||
projections = self.create_projection()
|
||||
self.projection_matrix.copy_(projections)
|
||||
del projections
|
||||
|
||||
def forward(self, q, k, v):
|
||||
device = q.device
|
||||
|
||||
if self.no_projection:
|
||||
q = q.softmax(dim = -1)
|
||||
k = torch.exp(k) if self.causal else k.softmax(dim = -2)
|
||||
|
||||
elif self.generalized_attention:
|
||||
create_kernel = partial(generalized_kernel, kernel_fn = self.kernel_fn, projection_matrix = self.projection_matrix, device = device)
|
||||
q, k = map(create_kernel, (q, k))
|
||||
|
||||
else:
|
||||
create_kernel = partial(softmax_kernel, projection_matrix = self.projection_matrix, device = device)
|
||||
|
||||
q = create_kernel(q, is_query = True)
|
||||
k = create_kernel(k, is_query = False)
|
||||
|
||||
attn_fn = linear_attention if not self.causal else self.causal_linear_fn
|
||||
if v is None:
|
||||
out = attn_fn(q, k, None)
|
||||
return out
|
||||
else:
|
||||
out = attn_fn(q, k, v)
|
||||
return out
|
||||
class SelfAttention(nn.Module):
|
||||
def __init__(self, dim, causal = False, heads = 8, dim_head = 64, local_heads = 0, local_window_size = 256, nb_features = None, feature_redraw_interval = 1000, generalized_attention = False, kernel_fn = nn.ReLU(), qr_uniform_q = False, dropout = 0., no_projection = False):
|
||||
super().__init__()
|
||||
assert dim % heads == 0, 'dimension must be divisible by number of heads'
|
||||
dim_head = default(dim_head, dim // heads)
|
||||
inner_dim = dim_head * heads
|
||||
self.fast_attention = FastAttention(dim_head, nb_features, causal = causal, generalized_attention = generalized_attention, kernel_fn = kernel_fn, qr_uniform_q = qr_uniform_q, no_projection = no_projection)
|
||||
|
||||
self.heads = heads
|
||||
self.global_heads = heads - local_heads
|
||||
self.local_attn = LocalAttention(window_size = local_window_size, causal = causal, autopad = True, dropout = dropout, look_forward = int(not causal), rel_pos_emb_config = (dim_head, local_heads)) if local_heads > 0 else None
|
||||
|
||||
#print (heads, nb_features, dim_head)
|
||||
#name_embedding = torch.zeros(110, heads, dim_head, dim_head)
|
||||
#self.name_embedding = nn.Parameter(name_embedding, requires_grad=True)
|
||||
|
||||
|
||||
self.to_q = nn.Linear(dim, inner_dim)
|
||||
self.to_k = nn.Linear(dim, inner_dim)
|
||||
self.to_v = nn.Linear(dim, inner_dim)
|
||||
self.to_out = nn.Linear(inner_dim, dim)
|
||||
self.dropout = nn.Dropout(dropout)
|
||||
|
||||
@torch.no_grad()
|
||||
def redraw_projection_matrix(self):
|
||||
self.fast_attention.redraw_projection_matrix()
|
||||
#torch.nn.init.zeros_(self.name_embedding)
|
||||
#print (torch.sum(self.name_embedding))
|
||||
def forward(self, x, context = None, mask = None, context_mask = None, name=None, inference=False, **kwargs):
|
||||
b, n, _, h, gh = *x.shape, self.heads, self.global_heads
|
||||
|
||||
cross_attend = exists(context)
|
||||
|
||||
context = default(context, x)
|
||||
context_mask = default(context_mask, mask) if not cross_attend else context_mask
|
||||
#print (torch.sum(self.name_embedding))
|
||||
q, k, v = self.to_q(x), self.to_k(context), self.to_v(context)
|
||||
|
||||
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), (q, k, v))
|
||||
(q, lq), (k, lk), (v, lv) = map(lambda t: (t[:, :gh], t[:, gh:]), (q, k, v))
|
||||
|
||||
attn_outs = []
|
||||
#print (name)
|
||||
#print (self.name_embedding[name].size())
|
||||
if not empty(q):
|
||||
if exists(context_mask):
|
||||
global_mask = context_mask[:, None, :, None]
|
||||
v.masked_fill_(~global_mask, 0.)
|
||||
if cross_attend:
|
||||
pass
|
||||
#print (torch.sum(self.name_embedding))
|
||||
#out = self.fast_attention(q,self.name_embedding[name],None)
|
||||
#print (torch.sum(self.name_embedding[...,-1:]))
|
||||
else:
|
||||
out = self.fast_attention(q, k, v)
|
||||
attn_outs.append(out)
|
||||
|
||||
if not empty(lq):
|
||||
assert not cross_attend, 'local attention is not compatible with cross attention'
|
||||
out = self.local_attn(lq, lk, lv, input_mask = mask)
|
||||
attn_outs.append(out)
|
||||
|
||||
out = torch.cat(attn_outs, dim = 1)
|
||||
out = rearrange(out, 'b h n d -> b n (h d)')
|
||||
out = self.to_out(out)
|
||||
return self.dropout(out)
|
86
server/voice_changer/DDSP_SVC/models/ddsp/unit2control.py
Normal file
86
server/voice_changer/DDSP_SVC/models/ddsp/unit2control.py
Normal file
@ -0,0 +1,86 @@
|
||||
import gin
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.nn.utils import weight_norm
|
||||
|
||||
from .pcmer import PCmer
|
||||
|
||||
|
||||
def split_to_dict(tensor, tensor_splits):
|
||||
"""Split a tensor into a dictionary of multiple tensors."""
|
||||
labels = []
|
||||
sizes = []
|
||||
|
||||
for k, v in tensor_splits.items():
|
||||
labels.append(k)
|
||||
sizes.append(v)
|
||||
|
||||
tensors = torch.split(tensor, sizes, dim=-1)
|
||||
return dict(zip(labels, tensors))
|
||||
|
||||
|
||||
class Unit2Control(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
input_channel,
|
||||
n_spk,
|
||||
output_splits):
|
||||
super().__init__()
|
||||
self.output_splits = output_splits
|
||||
self.f0_embed = nn.Linear(1, 256)
|
||||
self.phase_embed = nn.Linear(1, 256)
|
||||
self.volume_embed = nn.Linear(1, 256)
|
||||
self.n_spk = n_spk
|
||||
if n_spk is not None and n_spk > 1:
|
||||
self.spk_embed = nn.Embedding(n_spk, 256)
|
||||
|
||||
# conv in stack
|
||||
self.stack = nn.Sequential(
|
||||
nn.Conv1d(input_channel, 256, 3, 1, 1),
|
||||
nn.GroupNorm(4, 256),
|
||||
nn.LeakyReLU(),
|
||||
nn.Conv1d(256, 256, 3, 1, 1))
|
||||
|
||||
# transformer
|
||||
self.decoder = PCmer(
|
||||
num_layers=3,
|
||||
num_heads=8,
|
||||
dim_model=256,
|
||||
dim_keys=256,
|
||||
dim_values=256,
|
||||
residual_dropout=0.1,
|
||||
attention_dropout=0.1)
|
||||
self.norm = nn.LayerNorm(256)
|
||||
|
||||
# out
|
||||
self.n_out = sum([v for k, v in output_splits.items()])
|
||||
self.dense_out = weight_norm(
|
||||
nn.Linear(256, self.n_out))
|
||||
|
||||
def forward(self, units, f0, phase, volume, spk_id = None, spk_mix_dict = None):
|
||||
|
||||
'''
|
||||
input:
|
||||
B x n_frames x n_unit
|
||||
return:
|
||||
dict of B x n_frames x feat
|
||||
'''
|
||||
|
||||
x = self.stack(units.transpose(1,2)).transpose(1,2)
|
||||
x = x + self.f0_embed((1+ f0 / 700).log()) + self.phase_embed(phase / np.pi) + self.volume_embed(volume)
|
||||
if self.n_spk is not None and self.n_spk > 1:
|
||||
if spk_mix_dict is not None:
|
||||
for k, v in spk_mix_dict.items():
|
||||
spk_id_torch = torch.LongTensor(np.array([[k]])).to(units.device)
|
||||
x = x + v * self.spk_embed(spk_id_torch - 1)
|
||||
else:
|
||||
x = x + self.spk_embed(spk_id - 1)
|
||||
x = self.decoder(x)
|
||||
x = self.norm(x)
|
||||
e = self.dense_out(x)
|
||||
controls = split_to_dict(e, self.output_splits)
|
||||
|
||||
return controls
|
||||
|
639
server/voice_changer/DDSP_SVC/models/ddsp/vocoder.py
Normal file
639
server/voice_changer/DDSP_SVC/models/ddsp/vocoder.py
Normal file
@ -0,0 +1,639 @@
|
||||
import os
|
||||
import numpy as np
|
||||
import yaml
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import pyworld as pw
|
||||
import parselmouth
|
||||
import torchcrepe
|
||||
from transformers import HubertModel, Wav2Vec2FeatureExtractor
|
||||
from fairseq import checkpoint_utils
|
||||
from torch.nn.modules.utils import consume_prefix_in_state_dict_if_present
|
||||
from torchaudio.transforms import Resample
|
||||
from .unit2control import Unit2Control
|
||||
from .core import frequency_filter, upsample, remove_above_fmax, MaskedAvgPool1d, MedianPool1d
|
||||
from ..encoder.hubert.model import HubertSoft
|
||||
|
||||
CREPE_RESAMPLE_KERNEL = {}
|
||||
|
||||
|
||||
class F0_Extractor:
|
||||
def __init__(self, f0_extractor, sample_rate=44100, hop_size=512, f0_min=65, f0_max=800):
|
||||
self.f0_extractor = f0_extractor
|
||||
self.sample_rate = sample_rate
|
||||
self.hop_size = hop_size
|
||||
self.f0_min = f0_min
|
||||
self.f0_max = f0_max
|
||||
if f0_extractor == "crepe":
|
||||
key_str = str(sample_rate)
|
||||
if key_str not in CREPE_RESAMPLE_KERNEL:
|
||||
CREPE_RESAMPLE_KERNEL[key_str] = Resample(sample_rate, 16000, lowpass_filter_width=128)
|
||||
self.resample_kernel = CREPE_RESAMPLE_KERNEL[key_str]
|
||||
|
||||
def extract(self, audio, uv_interp=False, device=None, silence_front=0): # audio: 1d numpy array
|
||||
# extractor start time
|
||||
n_frames = int(len(audio) // self.hop_size) + 1
|
||||
|
||||
start_frame = int(silence_front * self.sample_rate / self.hop_size)
|
||||
real_silence_front = start_frame * self.hop_size / self.sample_rate
|
||||
audio = audio[int(np.round(real_silence_front * self.sample_rate)) :]
|
||||
|
||||
# extract f0 using parselmouth
|
||||
if self.f0_extractor == "parselmouth":
|
||||
f0 = parselmouth.Sound(audio, self.sample_rate).to_pitch_ac(time_step=self.hop_size / self.sample_rate, voicing_threshold=0.6, pitch_floor=self.f0_min, pitch_ceiling=self.f0_max).selected_array["frequency"]
|
||||
pad_size = start_frame + (int(len(audio) // self.hop_size) - len(f0) + 1) // 2
|
||||
f0 = np.pad(f0, (pad_size, n_frames - len(f0) - pad_size))
|
||||
|
||||
# extract f0 using dio
|
||||
elif self.f0_extractor == "dio":
|
||||
_f0, t = pw.dio(audio.astype("double"), self.sample_rate, f0_floor=self.f0_min, f0_ceil=self.f0_max, channels_in_octave=2, frame_period=(1000 * self.hop_size / self.sample_rate))
|
||||
f0 = pw.stonemask(audio.astype("double"), _f0, t, self.sample_rate)
|
||||
f0 = np.pad(f0.astype("float"), (start_frame, n_frames - len(f0) - start_frame))
|
||||
|
||||
# extract f0 using harvest
|
||||
elif self.f0_extractor == "harvest":
|
||||
f0, _ = pw.harvest(audio.astype("double"), self.sample_rate, f0_floor=self.f0_min, f0_ceil=self.f0_max, frame_period=(1000 * self.hop_size / self.sample_rate))
|
||||
f0 = np.pad(f0.astype("float"), (start_frame, n_frames - len(f0) - start_frame))
|
||||
|
||||
# extract f0 using crepe
|
||||
elif self.f0_extractor == "crepe":
|
||||
if device is None:
|
||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
resample_kernel = self.resample_kernel.to(device)
|
||||
wav16k_torch = resample_kernel(torch.FloatTensor(audio).unsqueeze(0).to(device))
|
||||
|
||||
f0, pd = torchcrepe.predict(wav16k_torch, 16000, 80, self.f0_min, self.f0_max, pad=True, model="full", batch_size=512, device=device, return_periodicity=True)
|
||||
pd = MedianPool1d(pd, 4)
|
||||
f0 = torchcrepe.threshold.At(0.05)(f0, pd)
|
||||
f0 = MaskedAvgPool1d(f0, 4)
|
||||
|
||||
f0 = f0.squeeze(0).cpu().numpy()
|
||||
f0 = np.array([f0[int(min(int(np.round(n * self.hop_size / self.sample_rate / 0.005)), len(f0) - 1))] for n in range(n_frames - start_frame)])
|
||||
f0 = np.pad(f0, (start_frame, 0))
|
||||
|
||||
else:
|
||||
raise ValueError(f" [x] Unknown f0 extractor: {f0_extractor}")
|
||||
|
||||
# interpolate the unvoiced f0
|
||||
if uv_interp:
|
||||
uv = f0 == 0
|
||||
if len(f0[~uv]) > 0:
|
||||
f0[uv] = np.interp(np.where(uv)[0], np.where(~uv)[0], f0[~uv])
|
||||
f0[f0 < self.f0_min] = self.f0_min
|
||||
return f0
|
||||
|
||||
|
||||
class Volume_Extractor:
|
||||
def __init__(self, hop_size=512):
|
||||
self.hop_size = hop_size
|
||||
|
||||
def extract(self, audio): # audio: 1d numpy array
|
||||
n_frames = int(len(audio) // self.hop_size) + 1
|
||||
audio2 = audio**2
|
||||
audio2 = np.pad(audio2, (int(self.hop_size // 2), int((self.hop_size + 1) // 2)), mode="reflect")
|
||||
volume = np.array([np.mean(audio2[int(n * self.hop_size) : int((n + 1) * self.hop_size)]) for n in range(n_frames)])
|
||||
volume = np.sqrt(volume)
|
||||
return volume
|
||||
|
||||
|
||||
class Units_Encoder:
|
||||
def __init__(self, encoder, encoder_ckpt, encoder_sample_rate=16000, encoder_hop_size=320, device=None, cnhubertsoft_gate=10):
|
||||
if device is None:
|
||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
self.device = device
|
||||
|
||||
is_loaded_encoder = False
|
||||
if encoder == "hubertsoft":
|
||||
self.model = Audio2HubertSoft(encoder_ckpt).to(device)
|
||||
is_loaded_encoder = True
|
||||
if encoder == "hubertbase":
|
||||
self.model = Audio2HubertBase(encoder_ckpt, device=device)
|
||||
is_loaded_encoder = True
|
||||
if encoder == "hubertbase768":
|
||||
self.model = Audio2HubertBase768(encoder_ckpt, device=device)
|
||||
is_loaded_encoder = True
|
||||
if encoder == "hubertbase768l12":
|
||||
self.model = Audio2HubertBase768L12(encoder_ckpt, device=device)
|
||||
is_loaded_encoder = True
|
||||
if encoder == "hubertlarge1024l24":
|
||||
self.model = Audio2HubertLarge1024L24(encoder_ckpt, device=device)
|
||||
is_loaded_encoder = True
|
||||
if encoder == "contentvec":
|
||||
self.model = Audio2ContentVec(encoder_ckpt, device=device)
|
||||
is_loaded_encoder = True
|
||||
if encoder == "contentvec768":
|
||||
self.model = Audio2ContentVec768(encoder_ckpt, device=device)
|
||||
is_loaded_encoder = True
|
||||
if encoder == "contentvec768l12":
|
||||
self.model = Audio2ContentVec768L12(encoder_ckpt, device=device)
|
||||
is_loaded_encoder = True
|
||||
if encoder == "cnhubertsoftfish":
|
||||
self.model = CNHubertSoftFish(encoder_ckpt, device=device, gate_size=cnhubertsoft_gate)
|
||||
is_loaded_encoder = True
|
||||
if not is_loaded_encoder:
|
||||
raise ValueError(f" [x] Unknown units encoder: {encoder}")
|
||||
|
||||
self.resample_kernel = {}
|
||||
self.encoder_sample_rate = encoder_sample_rate
|
||||
self.encoder_hop_size = encoder_hop_size
|
||||
|
||||
def encode(self, audio, sample_rate, hop_size): # B, T
|
||||
# resample
|
||||
if sample_rate == self.encoder_sample_rate:
|
||||
audio_res = audio
|
||||
else:
|
||||
key_str = str(sample_rate)
|
||||
if key_str not in self.resample_kernel:
|
||||
self.resample_kernel[key_str] = Resample(sample_rate, self.encoder_sample_rate, lowpass_filter_width=128).to(self.device)
|
||||
audio_res = self.resample_kernel[key_str](audio)
|
||||
|
||||
# encode
|
||||
if audio_res.size(-1) < 400:
|
||||
audio_res = torch.nn.functional.pad(audio, (0, 400 - audio_res.size(-1)))
|
||||
units = self.model(audio_res)
|
||||
|
||||
# alignment
|
||||
n_frames = audio.size(-1) // hop_size + 1
|
||||
ratio = (hop_size / sample_rate) / (self.encoder_hop_size / self.encoder_sample_rate)
|
||||
index = torch.clamp(torch.round(ratio * torch.arange(n_frames).to(self.device)).long(), max=units.size(1) - 1)
|
||||
units_aligned = torch.gather(units, 1, index.unsqueeze(0).unsqueeze(-1).repeat([1, 1, units.size(-1)]))
|
||||
return units_aligned
|
||||
|
||||
|
||||
class Audio2HubertSoft(torch.nn.Module):
|
||||
def __init__(self, path, h_sample_rate=16000, h_hop_size=320):
|
||||
super().__init__()
|
||||
print(" [Encoder Model] HuBERT Soft")
|
||||
self.hubert = HubertSoft()
|
||||
print(" [Loading] " + path)
|
||||
checkpoint = torch.load(path)
|
||||
consume_prefix_in_state_dict_if_present(checkpoint, "module.")
|
||||
self.hubert.load_state_dict(checkpoint)
|
||||
self.hubert.eval()
|
||||
|
||||
def forward(self, audio): # B, T
|
||||
with torch.inference_mode():
|
||||
units = self.hubert.units(audio.unsqueeze(1))
|
||||
return units
|
||||
|
||||
|
||||
class Audio2ContentVec:
|
||||
def __init__(self, path, h_sample_rate=16000, h_hop_size=320, device="cpu"):
|
||||
self.device = device
|
||||
print(" [Encoder Model] Content Vec")
|
||||
print(" [Loading] " + path)
|
||||
self.models, self.saved_cfg, self.task = checkpoint_utils.load_model_ensemble_and_task(
|
||||
[path],
|
||||
suffix="",
|
||||
)
|
||||
self.hubert = self.models[0]
|
||||
self.hubert = self.hubert.to(self.device)
|
||||
self.hubert.eval()
|
||||
|
||||
def __call__(self, audio): # B, T
|
||||
# wav_tensor = torch.from_numpy(audio).to(self.device)
|
||||
wav_tensor = audio
|
||||
feats = wav_tensor.view(1, -1)
|
||||
padding_mask = torch.BoolTensor(feats.shape).fill_(False)
|
||||
inputs = {
|
||||
"source": feats.to(wav_tensor.device),
|
||||
"padding_mask": padding_mask.to(wav_tensor.device),
|
||||
"output_layer": 9, # layer 9
|
||||
}
|
||||
with torch.no_grad():
|
||||
logits = self.hubert.extract_features(**inputs)
|
||||
feats = self.hubert.final_proj(logits[0])
|
||||
units = feats # .transpose(2, 1)
|
||||
return units
|
||||
|
||||
|
||||
class Audio2ContentVec768:
|
||||
def __init__(self, path, h_sample_rate=16000, h_hop_size=320, device="cpu"):
|
||||
self.device = device
|
||||
print(" [Encoder Model] Content Vec")
|
||||
print(" [Loading] " + path)
|
||||
self.models, self.saved_cfg, self.task = checkpoint_utils.load_model_ensemble_and_task(
|
||||
[path],
|
||||
suffix="",
|
||||
)
|
||||
self.hubert = self.models[0]
|
||||
self.hubert = self.hubert.to(self.device)
|
||||
self.hubert.eval()
|
||||
|
||||
def __call__(self, audio): # B, T
|
||||
# wav_tensor = torch.from_numpy(audio).to(self.device)
|
||||
wav_tensor = audio
|
||||
feats = wav_tensor.view(1, -1)
|
||||
padding_mask = torch.BoolTensor(feats.shape).fill_(False)
|
||||
inputs = {
|
||||
"source": feats.to(wav_tensor.device),
|
||||
"padding_mask": padding_mask.to(wav_tensor.device),
|
||||
"output_layer": 9, # layer 9
|
||||
}
|
||||
with torch.no_grad():
|
||||
logits = self.hubert.extract_features(**inputs)
|
||||
feats = logits[0]
|
||||
units = feats # .transpose(2, 1)
|
||||
return units
|
||||
|
||||
|
||||
class Audio2ContentVec768L12:
|
||||
def __init__(self, path, h_sample_rate=16000, h_hop_size=320, device="cpu"):
|
||||
self.device = device
|
||||
print(" [Encoder Model] Content Vec")
|
||||
print(" [Loading] " + path)
|
||||
self.models, self.saved_cfg, self.task = checkpoint_utils.load_model_ensemble_and_task(
|
||||
[path],
|
||||
suffix="",
|
||||
)
|
||||
self.hubert = self.models[0]
|
||||
self.hubert = self.hubert.to(self.device)
|
||||
self.hubert.eval()
|
||||
|
||||
def __call__(self, audio): # B, T
|
||||
# wav_tensor = torch.from_numpy(audio).to(self.device)
|
||||
wav_tensor = audio
|
||||
feats = wav_tensor.view(1, -1)
|
||||
padding_mask = torch.BoolTensor(feats.shape).fill_(False)
|
||||
inputs = {
|
||||
"source": feats.to(wav_tensor.device),
|
||||
"padding_mask": padding_mask.to(wav_tensor.device),
|
||||
"output_layer": 12, # layer 12
|
||||
}
|
||||
with torch.no_grad():
|
||||
logits = self.hubert.extract_features(**inputs)
|
||||
feats = logits[0]
|
||||
units = feats # .transpose(2, 1)
|
||||
return units
|
||||
|
||||
|
||||
class CNHubertSoftFish(torch.nn.Module):
|
||||
def __init__(self, path, h_sample_rate=16000, h_hop_size=320, device="cpu", gate_size=10):
|
||||
super().__init__()
|
||||
self.device = device
|
||||
self.gate_size = gate_size
|
||||
|
||||
self.feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained("./pretrain/TencentGameMate/chinese-hubert-base")
|
||||
self.model = HubertModel.from_pretrained("./pretrain/TencentGameMate/chinese-hubert-base")
|
||||
self.proj = torch.nn.Sequential(torch.nn.Dropout(0.1), torch.nn.Linear(768, 256))
|
||||
# self.label_embedding = nn.Embedding(128, 256)
|
||||
|
||||
state_dict = torch.load(path, map_location=device)
|
||||
self.load_state_dict(state_dict)
|
||||
|
||||
@torch.no_grad()
|
||||
def forward(self, audio):
|
||||
input_values = self.feature_extractor(audio, sampling_rate=16000, return_tensors="pt").input_values
|
||||
input_values = input_values.to(self.model.device)
|
||||
|
||||
return self._forward(input_values[0])
|
||||
|
||||
@torch.no_grad()
|
||||
def _forward(self, input_values):
|
||||
features = self.model(input_values)
|
||||
features = self.proj(features.last_hidden_state)
|
||||
|
||||
# Top-k gating
|
||||
topk, indices = torch.topk(features, self.gate_size, dim=2)
|
||||
features = torch.zeros_like(features).scatter(2, indices, topk)
|
||||
features = features / features.sum(2, keepdim=True)
|
||||
|
||||
return features.to(self.device) # .transpose(1, 2)
|
||||
|
||||
|
||||
class Audio2HubertBase:
|
||||
def __init__(self, path, h_sample_rate=16000, h_hop_size=320, device="cpu"):
|
||||
self.device = device
|
||||
print(" [Encoder Model] HuBERT Base")
|
||||
print(" [Loading] " + path)
|
||||
self.models, self.saved_cfg, self.task = checkpoint_utils.load_model_ensemble_and_task(
|
||||
[path],
|
||||
suffix="",
|
||||
)
|
||||
self.hubert = self.models[0]
|
||||
self.hubert = self.hubert.to(self.device)
|
||||
self.hubert = self.hubert.float()
|
||||
self.hubert.eval()
|
||||
|
||||
def __call__(self, audio): # B, T
|
||||
with torch.no_grad():
|
||||
padding_mask = torch.BoolTensor(audio.shape).fill_(False)
|
||||
inputs = {
|
||||
"source": audio.to(self.device),
|
||||
"padding_mask": padding_mask.to(self.device),
|
||||
"output_layer": 9, # layer 9
|
||||
}
|
||||
logits = self.hubert.extract_features(**inputs)
|
||||
units = self.hubert.final_proj(logits[0])
|
||||
return units
|
||||
|
||||
|
||||
class Audio2HubertBase768:
|
||||
def __init__(self, path, h_sample_rate=16000, h_hop_size=320, device="cpu"):
|
||||
self.device = device
|
||||
print(" [Encoder Model] HuBERT Base")
|
||||
print(" [Loading] " + path)
|
||||
self.models, self.saved_cfg, self.task = checkpoint_utils.load_model_ensemble_and_task(
|
||||
[path],
|
||||
suffix="",
|
||||
)
|
||||
self.hubert = self.models[0]
|
||||
self.hubert = self.hubert.to(self.device)
|
||||
self.hubert = self.hubert.float()
|
||||
self.hubert.eval()
|
||||
|
||||
def __call__(self, audio): # B, T
|
||||
with torch.no_grad():
|
||||
padding_mask = torch.BoolTensor(audio.shape).fill_(False)
|
||||
inputs = {
|
||||
"source": audio.to(self.device),
|
||||
"padding_mask": padding_mask.to(self.device),
|
||||
"output_layer": 9, # layer 9
|
||||
}
|
||||
logits = self.hubert.extract_features(**inputs)
|
||||
units = logits[0]
|
||||
return units
|
||||
|
||||
|
||||
class Audio2HubertBase768L12:
|
||||
def __init__(self, path, h_sample_rate=16000, h_hop_size=320, device="cpu"):
|
||||
self.device = device
|
||||
print(" [Encoder Model] HuBERT Base")
|
||||
print(" [Loading] " + path)
|
||||
self.models, self.saved_cfg, self.task = checkpoint_utils.load_model_ensemble_and_task(
|
||||
[path],
|
||||
suffix="",
|
||||
)
|
||||
self.hubert = self.models[0]
|
||||
self.hubert = self.hubert.to(self.device)
|
||||
self.hubert = self.hubert.float()
|
||||
self.hubert.eval()
|
||||
|
||||
def __call__(self, audio): # B, T
|
||||
with torch.no_grad():
|
||||
padding_mask = torch.BoolTensor(audio.shape).fill_(False)
|
||||
inputs = {
|
||||
"source": audio.to(self.device),
|
||||
"padding_mask": padding_mask.to(self.device),
|
||||
"output_layer": 12, # layer 12
|
||||
}
|
||||
logits = self.hubert.extract_features(**inputs)
|
||||
units = logits[0]
|
||||
return units
|
||||
|
||||
|
||||
class Audio2HubertLarge1024L24:
|
||||
def __init__(self, path, h_sample_rate=16000, h_hop_size=320, device="cpu"):
|
||||
self.device = device
|
||||
print(" [Encoder Model] HuBERT Base")
|
||||
print(" [Loading] " + path)
|
||||
self.models, self.saved_cfg, self.task = checkpoint_utils.load_model_ensemble_and_task(
|
||||
[path],
|
||||
suffix="",
|
||||
)
|
||||
self.hubert = self.models[0]
|
||||
self.hubert = self.hubert.to(self.device)
|
||||
self.hubert = self.hubert.float()
|
||||
self.hubert.eval()
|
||||
|
||||
def __call__(self, audio): # B, T
|
||||
with torch.no_grad():
|
||||
padding_mask = torch.BoolTensor(audio.shape).fill_(False)
|
||||
inputs = {
|
||||
"source": audio.to(self.device),
|
||||
"padding_mask": padding_mask.to(self.device),
|
||||
"output_layer": 24, # layer 24
|
||||
}
|
||||
logits = self.hubert.extract_features(**inputs)
|
||||
units = logits[0]
|
||||
return units
|
||||
|
||||
|
||||
class DotDict(dict):
|
||||
def __getattr__(*args):
|
||||
val = dict.get(*args)
|
||||
return DotDict(val) if type(val) is dict else val
|
||||
|
||||
__setattr__ = dict.__setitem__
|
||||
__delattr__ = dict.__delitem__
|
||||
|
||||
|
||||
def load_model(model_path, device="cpu"):
|
||||
config_file = os.path.join(os.path.split(model_path)[0], "config.yaml")
|
||||
with open(config_file, "r") as config:
|
||||
args = yaml.safe_load(config)
|
||||
args = DotDict(args)
|
||||
|
||||
# load model
|
||||
model = None
|
||||
|
||||
if args.model.type == "Sins":
|
||||
model = Sins(sampling_rate=args.data.sampling_rate, block_size=args.data.block_size, n_harmonics=args.model.n_harmonics, n_mag_allpass=args.model.n_mag_allpass, n_mag_noise=args.model.n_mag_noise, n_unit=args.data.encoder_out_channels, n_spk=args.model.n_spk)
|
||||
|
||||
elif args.model.type == "CombSub":
|
||||
model = CombSub(sampling_rate=args.data.sampling_rate, block_size=args.data.block_size, n_mag_allpass=args.model.n_mag_allpass, n_mag_harmonic=args.model.n_mag_harmonic, n_mag_noise=args.model.n_mag_noise, n_unit=args.data.encoder_out_channels, n_spk=args.model.n_spk)
|
||||
|
||||
elif args.model.type == "CombSubFast":
|
||||
model = CombSubFast(sampling_rate=args.data.sampling_rate, block_size=args.data.block_size, n_unit=args.data.encoder_out_channels, n_spk=args.model.n_spk)
|
||||
|
||||
else:
|
||||
raise ValueError(f" [x] Unknown Model: {args.model.type}")
|
||||
|
||||
print(" [Loading] " + model_path)
|
||||
ckpt = torch.load(model_path, map_location=torch.device(device))
|
||||
model.to(device)
|
||||
model.load_state_dict(ckpt["model"])
|
||||
model.eval()
|
||||
return model, args
|
||||
|
||||
|
||||
class Sins(torch.nn.Module):
|
||||
def __init__(self, sampling_rate, block_size, n_harmonics, n_mag_allpass, n_mag_noise, n_unit=256, n_spk=1):
|
||||
super().__init__()
|
||||
|
||||
print(" [DDSP Model] Sinusoids Additive Synthesiser")
|
||||
|
||||
# params
|
||||
self.register_buffer("sampling_rate", torch.tensor(sampling_rate))
|
||||
self.register_buffer("block_size", torch.tensor(block_size))
|
||||
# Unit2Control
|
||||
split_map = {
|
||||
"amplitudes": n_harmonics,
|
||||
"group_delay": n_mag_allpass,
|
||||
"noise_magnitude": n_mag_noise,
|
||||
}
|
||||
self.unit2ctrl = Unit2Control(n_unit, n_spk, split_map)
|
||||
|
||||
def forward(self, units_frames, f0_frames, volume_frames, spk_id=None, spk_mix_dict=None, initial_phase=None, infer=True, max_upsample_dim=32):
|
||||
"""
|
||||
units_frames: B x n_frames x n_unit
|
||||
f0_frames: B x n_frames x 1
|
||||
volume_frames: B x n_frames x 1
|
||||
spk_id: B x 1
|
||||
"""
|
||||
# exciter phase
|
||||
f0 = upsample(f0_frames, self.block_size)
|
||||
if infer:
|
||||
x = torch.cumsum(f0.double() / self.sampling_rate, axis=1)
|
||||
else:
|
||||
x = torch.cumsum(f0 / self.sampling_rate, axis=1)
|
||||
if initial_phase is not None:
|
||||
x += initial_phase.to(x) / 2 / np.pi
|
||||
x = x - torch.round(x)
|
||||
x = x.to(f0)
|
||||
|
||||
phase = 2 * np.pi * x
|
||||
phase_frames = phase[:, :: self.block_size, :]
|
||||
|
||||
# parameter prediction
|
||||
ctrls = self.unit2ctrl(units_frames, f0_frames, phase_frames, volume_frames, spk_id=spk_id, spk_mix_dict=spk_mix_dict)
|
||||
|
||||
amplitudes_frames = torch.exp(ctrls["amplitudes"]) / 128
|
||||
group_delay = np.pi * torch.tanh(ctrls["group_delay"])
|
||||
noise_param = torch.exp(ctrls["noise_magnitude"]) / 128
|
||||
|
||||
# sinusoids exciter signal
|
||||
amplitudes_frames = remove_above_fmax(amplitudes_frames, f0_frames, self.sampling_rate / 2, level_start=1)
|
||||
n_harmonic = amplitudes_frames.shape[-1]
|
||||
level_harmonic = torch.arange(1, n_harmonic + 1).to(phase)
|
||||
sinusoids = 0.0
|
||||
for n in range((n_harmonic - 1) // max_upsample_dim + 1):
|
||||
start = n * max_upsample_dim
|
||||
end = (n + 1) * max_upsample_dim
|
||||
phases = phase * level_harmonic[start:end]
|
||||
amplitudes = upsample(amplitudes_frames[:, :, start:end], self.block_size)
|
||||
sinusoids += (torch.sin(phases) * amplitudes).sum(-1)
|
||||
|
||||
# harmonic part filter (apply group-delay)
|
||||
harmonic = frequency_filter(sinusoids, torch.exp(1.0j * torch.cumsum(group_delay, axis=-1)), hann_window=False)
|
||||
|
||||
# noise part filter
|
||||
noise = torch.rand_like(harmonic) * 2 - 1
|
||||
noise = frequency_filter(noise, torch.complex(noise_param, torch.zeros_like(noise_param)), hann_window=True)
|
||||
|
||||
signal = harmonic + noise
|
||||
|
||||
return signal, phase, (harmonic, noise) # , (noise_param, noise_param)
|
||||
|
||||
|
||||
class CombSubFast(torch.nn.Module):
|
||||
def __init__(self, sampling_rate, block_size, n_unit=256, n_spk=1):
|
||||
super().__init__()
|
||||
|
||||
print(" [DDSP Model] Combtooth Subtractive Synthesiser")
|
||||
# params
|
||||
self.register_buffer("sampling_rate", torch.tensor(sampling_rate))
|
||||
self.register_buffer("block_size", torch.tensor(block_size))
|
||||
self.register_buffer("window", torch.sqrt(torch.hann_window(2 * block_size)))
|
||||
# Unit2Control
|
||||
split_map = {"harmonic_magnitude": block_size + 1, "harmonic_phase": block_size + 1, "noise_magnitude": block_size + 1}
|
||||
self.unit2ctrl = Unit2Control(n_unit, n_spk, split_map)
|
||||
|
||||
def forward(self, units_frames, f0_frames, volume_frames, spk_id=None, spk_mix_dict=None, initial_phase=None, infer=True, **kwargs):
|
||||
"""
|
||||
units_frames: B x n_frames x n_unit
|
||||
f0_frames: B x n_frames x 1
|
||||
volume_frames: B x n_frames x 1
|
||||
spk_id: B x 1
|
||||
"""
|
||||
# exciter phase
|
||||
f0 = upsample(f0_frames, self.block_size)
|
||||
if infer:
|
||||
x = torch.cumsum(f0.double() / self.sampling_rate, axis=1)
|
||||
else:
|
||||
x = torch.cumsum(f0 / self.sampling_rate, axis=1)
|
||||
if initial_phase is not None:
|
||||
x += initial_phase.to(x) / 2 / np.pi
|
||||
x = x - torch.round(x)
|
||||
x = x.to(f0)
|
||||
|
||||
phase_frames = 2 * np.pi * x[:, :: self.block_size, :]
|
||||
|
||||
# parameter prediction
|
||||
ctrls = self.unit2ctrl(units_frames, f0_frames, phase_frames, volume_frames, spk_id=spk_id, spk_mix_dict=spk_mix_dict)
|
||||
|
||||
src_filter = torch.exp(ctrls["harmonic_magnitude"] + 1.0j * np.pi * ctrls["harmonic_phase"])
|
||||
src_filter = torch.cat((src_filter, src_filter[:, -1:, :]), 1)
|
||||
noise_filter = torch.exp(ctrls["noise_magnitude"]) / 128
|
||||
noise_filter = torch.cat((noise_filter, noise_filter[:, -1:, :]), 1)
|
||||
|
||||
# combtooth exciter signal
|
||||
combtooth = torch.sinc(self.sampling_rate * x / (f0 + 1e-3))
|
||||
combtooth = combtooth.squeeze(-1)
|
||||
combtooth_frames = F.pad(combtooth, (self.block_size, self.block_size)).unfold(1, 2 * self.block_size, self.block_size)
|
||||
combtooth_frames = combtooth_frames * self.window
|
||||
combtooth_fft = torch.fft.rfft(combtooth_frames, 2 * self.block_size)
|
||||
|
||||
# noise exciter signal
|
||||
noise = torch.rand_like(combtooth) * 2 - 1
|
||||
noise_frames = F.pad(noise, (self.block_size, self.block_size)).unfold(1, 2 * self.block_size, self.block_size)
|
||||
noise_frames = noise_frames * self.window
|
||||
noise_fft = torch.fft.rfft(noise_frames, 2 * self.block_size)
|
||||
|
||||
# apply the filters
|
||||
signal_fft = combtooth_fft * src_filter + noise_fft * noise_filter
|
||||
|
||||
# take the ifft to resynthesize audio.
|
||||
signal_frames_out = torch.fft.irfft(signal_fft, 2 * self.block_size) * self.window
|
||||
|
||||
# overlap add
|
||||
fold = torch.nn.Fold(output_size=(1, (signal_frames_out.size(1) + 1) * self.block_size), kernel_size=(1, 2 * self.block_size), stride=(1, self.block_size))
|
||||
signal = fold(signal_frames_out.transpose(1, 2))[:, 0, 0, self.block_size : -self.block_size]
|
||||
|
||||
return signal, phase_frames, (signal, signal)
|
||||
|
||||
|
||||
class CombSub(torch.nn.Module):
|
||||
def __init__(self, sampling_rate, block_size, n_mag_allpass, n_mag_harmonic, n_mag_noise, n_unit=256, n_spk=1):
|
||||
super().__init__()
|
||||
|
||||
print(" [DDSP Model] Combtooth Subtractive Synthesiser (Old Version)")
|
||||
# params
|
||||
self.register_buffer("sampling_rate", torch.tensor(sampling_rate))
|
||||
self.register_buffer("block_size", torch.tensor(block_size))
|
||||
# Unit2Control
|
||||
split_map = {"group_delay": n_mag_allpass, "harmonic_magnitude": n_mag_harmonic, "noise_magnitude": n_mag_noise}
|
||||
self.unit2ctrl = Unit2Control(n_unit, n_spk, split_map)
|
||||
|
||||
def forward(self, units_frames, f0_frames, volume_frames, spk_id=None, spk_mix_dict=None, initial_phase=None, infer=True, **kwargs):
|
||||
"""
|
||||
units_frames: B x n_frames x n_unit
|
||||
f0_frames: B x n_frames x 1
|
||||
volume_frames: B x n_frames x 1
|
||||
spk_id: B x 1
|
||||
"""
|
||||
# exciter phase
|
||||
f0 = upsample(f0_frames, self.block_size)
|
||||
if infer:
|
||||
x = torch.cumsum(f0.double() / self.sampling_rate, axis=1)
|
||||
else:
|
||||
x = torch.cumsum(f0 / self.sampling_rate, axis=1)
|
||||
if initial_phase is not None:
|
||||
x += initial_phase.to(x) / 2 / np.pi
|
||||
x = x - torch.round(x)
|
||||
x = x.to(f0)
|
||||
|
||||
phase_frames = 2 * np.pi * x[:, :: self.block_size, :]
|
||||
|
||||
# parameter prediction
|
||||
ctrls = self.unit2ctrl(units_frames, f0_frames, phase_frames, volume_frames, spk_id=spk_id, spk_mix_dict=spk_mix_dict)
|
||||
|
||||
group_delay = np.pi * torch.tanh(ctrls["group_delay"])
|
||||
src_param = torch.exp(ctrls["harmonic_magnitude"])
|
||||
noise_param = torch.exp(ctrls["noise_magnitude"]) / 128
|
||||
|
||||
# combtooth exciter signal
|
||||
combtooth = torch.sinc(self.sampling_rate * x / (f0 + 1e-3))
|
||||
combtooth = combtooth.squeeze(-1)
|
||||
|
||||
# harmonic part filter (using dynamic-windowed LTV-FIR, with group-delay prediction)
|
||||
harmonic = frequency_filter(combtooth, torch.exp(1.0j * torch.cumsum(group_delay, axis=-1)), hann_window=False)
|
||||
harmonic = frequency_filter(harmonic, torch.complex(src_param, torch.zeros_like(src_param)), hann_window=True, half_width_frames=1.5 * self.sampling_rate / (f0_frames + 1e-3))
|
||||
|
||||
# noise part filter (using constant-windowed LTV-FIR, without group-delay)
|
||||
noise = torch.rand_like(harmonic) * 2 - 1
|
||||
noise = frequency_filter(noise, torch.complex(noise_param, torch.zeros_like(noise_param)), hann_window=True)
|
||||
|
||||
signal = harmonic + noise
|
||||
|
||||
return signal, phase_frames, (harmonic, noise)
|
215
server/voice_changer/DDSP_SVC/models/diffusion/data_loaders.py
Normal file
215
server/voice_changer/DDSP_SVC/models/diffusion/data_loaders.py
Normal file
@ -0,0 +1,215 @@
|
||||
import os
|
||||
import random
|
||||
import re
|
||||
import numpy as np
|
||||
import librosa
|
||||
import torch
|
||||
from tqdm import tqdm
|
||||
from torch.utils.data import Dataset
|
||||
|
||||
|
||||
def traverse_dir(root_dir, extensions, amount=None, str_include=None, str_exclude=None, is_pure=False, is_sort=False, is_ext=True):
|
||||
file_list = []
|
||||
cnt = 0
|
||||
for root, _, files in os.walk(root_dir):
|
||||
for file in files:
|
||||
if any([file.endswith(f".{ext}") for ext in extensions]):
|
||||
# path
|
||||
mix_path = os.path.join(root, file)
|
||||
pure_path = mix_path[len(root_dir) + 1 :] if is_pure else mix_path
|
||||
|
||||
# amount
|
||||
if (amount is not None) and (cnt == amount):
|
||||
if is_sort:
|
||||
file_list.sort()
|
||||
return file_list
|
||||
|
||||
# check string
|
||||
if (str_include is not None) and (str_include not in pure_path):
|
||||
continue
|
||||
if (str_exclude is not None) and (str_exclude in pure_path):
|
||||
continue
|
||||
|
||||
if not is_ext:
|
||||
ext = pure_path.split(".")[-1]
|
||||
pure_path = pure_path[: -(len(ext) + 1)]
|
||||
file_list.append(pure_path)
|
||||
cnt += 1
|
||||
if is_sort:
|
||||
file_list.sort()
|
||||
return file_list
|
||||
|
||||
|
||||
def get_data_loaders(args, whole_audio=False):
|
||||
data_train = AudioDataset(args.data.train_path, waveform_sec=args.data.duration, hop_size=args.data.block_size, sample_rate=args.data.sampling_rate, load_all_data=args.train.cache_all_data, whole_audio=whole_audio, extensions=args.data.extensions, n_spk=args.model.n_spk, device=args.train.cache_device, fp16=args.train.cache_fp16, use_aug=True)
|
||||
loader_train = torch.utils.data.DataLoader(data_train, batch_size=args.train.batch_size if not whole_audio else 1, shuffle=True, num_workers=args.train.num_workers if args.train.cache_device == "cpu" else 0, persistent_workers=(args.train.num_workers > 0) if args.train.cache_device == "cpu" else False, pin_memory=True if args.train.cache_device == "cpu" else False)
|
||||
data_valid = AudioDataset(args.data.valid_path, waveform_sec=args.data.duration, hop_size=args.data.block_size, sample_rate=args.data.sampling_rate, load_all_data=args.train.cache_all_data, whole_audio=True, extensions=args.data.extensions, n_spk=args.model.n_spk)
|
||||
loader_valid = torch.utils.data.DataLoader(data_valid, batch_size=1, shuffle=False, num_workers=0, pin_memory=True)
|
||||
return loader_train, loader_valid
|
||||
|
||||
|
||||
class AudioDataset(Dataset):
|
||||
def __init__(
|
||||
self,
|
||||
path_root,
|
||||
waveform_sec,
|
||||
hop_size,
|
||||
sample_rate,
|
||||
load_all_data=True,
|
||||
whole_audio=False,
|
||||
extensions=["wav"],
|
||||
n_spk=1,
|
||||
device="cpu",
|
||||
fp16=False,
|
||||
use_aug=False,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.waveform_sec = waveform_sec
|
||||
self.sample_rate = sample_rate
|
||||
self.hop_size = hop_size
|
||||
self.path_root = path_root
|
||||
self.paths = traverse_dir(os.path.join(path_root, "audio"), extensions=extensions, is_pure=True, is_sort=True, is_ext=True)
|
||||
self.whole_audio = whole_audio
|
||||
self.use_aug = use_aug
|
||||
self.data_buffer = {}
|
||||
self.pitch_aug_dict = np.load(os.path.join(self.path_root, "pitch_aug_dict.npy"), allow_pickle=True).item()
|
||||
if load_all_data:
|
||||
print("Load all the data from :", path_root)
|
||||
else:
|
||||
print("Load the f0, volume data from :", path_root)
|
||||
for name_ext in tqdm(self.paths, total=len(self.paths)):
|
||||
name = os.path.splitext(name_ext)[0] # NOQA
|
||||
path_audio = os.path.join(self.path_root, "audio", name_ext)
|
||||
duration = librosa.get_duration(filename=path_audio, sr=self.sample_rate)
|
||||
|
||||
path_f0 = os.path.join(self.path_root, "f0", name_ext) + ".npy"
|
||||
f0 = np.load(path_f0)
|
||||
f0 = torch.from_numpy(f0).float().unsqueeze(-1).to(device)
|
||||
|
||||
path_volume = os.path.join(self.path_root, "volume", name_ext) + ".npy"
|
||||
volume = np.load(path_volume)
|
||||
volume = torch.from_numpy(volume).float().unsqueeze(-1).to(device)
|
||||
|
||||
path_augvol = os.path.join(self.path_root, "aug_vol", name_ext) + ".npy"
|
||||
aug_vol = np.load(path_augvol)
|
||||
aug_vol = torch.from_numpy(aug_vol).float().unsqueeze(-1).to(device)
|
||||
|
||||
if n_spk is not None and n_spk > 1:
|
||||
dirname_split = re.split(r"_|\-", os.path.dirname(name_ext), 2)[0]
|
||||
spk_id = int(dirname_split) if str.isdigit(dirname_split) else 0
|
||||
if spk_id < 1 or spk_id > n_spk:
|
||||
raise ValueError(" [x] Muiti-speaker traing error : spk_id must be a positive integer from 1 to n_spk ")
|
||||
else:
|
||||
spk_id = 1
|
||||
spk_id = torch.LongTensor(np.array([spk_id])).to(device)
|
||||
|
||||
if load_all_data:
|
||||
"""
|
||||
audio, sr = librosa.load(path_audio, sr=self.sample_rate)
|
||||
if len(audio.shape) > 1:
|
||||
audio = librosa.to_mono(audio)
|
||||
audio = torch.from_numpy(audio).to(device)
|
||||
"""
|
||||
path_mel = os.path.join(self.path_root, "mel", name_ext) + ".npy"
|
||||
mel = np.load(path_mel)
|
||||
mel = torch.from_numpy(mel).to(device)
|
||||
|
||||
path_augmel = os.path.join(self.path_root, "aug_mel", name_ext) + ".npy"
|
||||
aug_mel = np.load(path_augmel)
|
||||
aug_mel = torch.from_numpy(aug_mel).to(device)
|
||||
|
||||
path_units = os.path.join(self.path_root, "units", name_ext) + ".npy"
|
||||
units = np.load(path_units)
|
||||
units = torch.from_numpy(units).to(device)
|
||||
|
||||
if fp16:
|
||||
mel = mel.half()
|
||||
aug_mel = aug_mel.half()
|
||||
units = units.half()
|
||||
|
||||
self.data_buffer[name_ext] = {"duration": duration, "mel": mel, "aug_mel": aug_mel, "units": units, "f0": f0, "volume": volume, "aug_vol": aug_vol, "spk_id": spk_id}
|
||||
else:
|
||||
self.data_buffer[name_ext] = {"duration": duration, "f0": f0, "volume": volume, "aug_vol": aug_vol, "spk_id": spk_id}
|
||||
|
||||
def __getitem__(self, file_idx):
|
||||
name_ext = self.paths[file_idx]
|
||||
data_buffer = self.data_buffer[name_ext]
|
||||
# check duration. if too short, then skip
|
||||
if data_buffer["duration"] < (self.waveform_sec + 0.1):
|
||||
return self.__getitem__((file_idx + 1) % len(self.paths))
|
||||
|
||||
# get item
|
||||
return self.get_data(name_ext, data_buffer)
|
||||
|
||||
def get_data(self, name_ext, data_buffer):
|
||||
name = os.path.splitext(name_ext)[0]
|
||||
frame_resolution = self.hop_size / self.sample_rate
|
||||
duration = data_buffer["duration"]
|
||||
waveform_sec = duration if self.whole_audio else self.waveform_sec
|
||||
|
||||
# load audio
|
||||
idx_from = 0 if self.whole_audio else random.uniform(0, duration - waveform_sec - 0.1)
|
||||
start_frame = int(idx_from / frame_resolution)
|
||||
units_frame_len = int(waveform_sec / frame_resolution)
|
||||
aug_flag = random.choice([True, False]) and self.use_aug
|
||||
"""
|
||||
audio = data_buffer.get('audio')
|
||||
if audio is None:
|
||||
path_audio = os.path.join(self.path_root, 'audio', name) + '.wav'
|
||||
audio, sr = librosa.load(
|
||||
path_audio,
|
||||
sr = self.sample_rate,
|
||||
offset = start_frame * frame_resolution,
|
||||
duration = waveform_sec)
|
||||
if len(audio.shape) > 1:
|
||||
audio = librosa.to_mono(audio)
|
||||
# clip audio into N seconds
|
||||
audio = audio[ : audio.shape[-1] // self.hop_size * self.hop_size]
|
||||
audio = torch.from_numpy(audio).float()
|
||||
else:
|
||||
audio = audio[start_frame * self.hop_size : (start_frame + units_frame_len) * self.hop_size]
|
||||
"""
|
||||
# load mel
|
||||
mel_key = "aug_mel" if aug_flag else "mel"
|
||||
mel = data_buffer.get(mel_key)
|
||||
if mel is None:
|
||||
mel = os.path.join(self.path_root, mel_key, name_ext) + ".npy"
|
||||
mel = np.load(mel)
|
||||
mel = mel[start_frame : start_frame + units_frame_len]
|
||||
mel = torch.from_numpy(mel).float()
|
||||
else:
|
||||
mel = mel[start_frame : start_frame + units_frame_len]
|
||||
|
||||
# load units
|
||||
units = data_buffer.get("units")
|
||||
if units is None:
|
||||
units = os.path.join(self.path_root, "units", name_ext) + ".npy"
|
||||
units = np.load(units)
|
||||
units = units[start_frame : start_frame + units_frame_len]
|
||||
units = torch.from_numpy(units).float()
|
||||
else:
|
||||
units = units[start_frame : start_frame + units_frame_len]
|
||||
|
||||
# load f0
|
||||
f0 = data_buffer.get("f0")
|
||||
aug_shift = 0
|
||||
if aug_flag:
|
||||
aug_shift = self.pitch_aug_dict[name_ext]
|
||||
f0_frames = 2 ** (aug_shift / 12) * f0[start_frame : start_frame + units_frame_len]
|
||||
|
||||
# load volume
|
||||
vol_key = "aug_vol" if aug_flag else "volume"
|
||||
volume = data_buffer.get(vol_key)
|
||||
volume_frames = volume[start_frame : start_frame + units_frame_len]
|
||||
|
||||
# load spk_id
|
||||
spk_id = data_buffer.get("spk_id")
|
||||
|
||||
# load shift
|
||||
aug_shift = torch.from_numpy(np.array([[aug_shift]])).float()
|
||||
|
||||
return dict(mel=mel, f0=f0_frames, volume=volume_frames, units=units, spk_id=spk_id, aug_shift=aug_shift, name=name, name_ext=name_ext)
|
||||
|
||||
def __len__(self):
|
||||
return len(self.paths)
|
342
server/voice_changer/DDSP_SVC/models/diffusion/diffusion.py
Normal file
342
server/voice_changer/DDSP_SVC/models/diffusion/diffusion.py
Normal file
@ -0,0 +1,342 @@
|
||||
from collections import deque
|
||||
from functools import partial
|
||||
from inspect import isfunction
|
||||
import torch.nn.functional as F
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch import nn
|
||||
from tqdm import tqdm
|
||||
|
||||
|
||||
def exists(x):
|
||||
return x is not None
|
||||
|
||||
|
||||
def default(val, d):
|
||||
if exists(val):
|
||||
return val
|
||||
return d() if isfunction(d) else d
|
||||
|
||||
|
||||
def extract(a, t, x_shape):
|
||||
b, *_ = t.shape
|
||||
out = a.gather(-1, t)
|
||||
return out.reshape(b, *((1,) * (len(x_shape) - 1)))
|
||||
|
||||
|
||||
def noise_like(shape, device, repeat=False):
|
||||
repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat(shape[0], *((1,) * (len(shape) - 1))) # NOQA
|
||||
noise = lambda: torch.randn(shape, device=device) # NOQA
|
||||
return repeat_noise() if repeat else noise()
|
||||
|
||||
|
||||
def linear_beta_schedule(timesteps, max_beta=0.02):
|
||||
"""
|
||||
linear schedule
|
||||
"""
|
||||
betas = np.linspace(1e-4, max_beta, timesteps)
|
||||
return betas
|
||||
|
||||
|
||||
def cosine_beta_schedule(timesteps, s=0.008):
|
||||
"""
|
||||
cosine schedule
|
||||
as proposed in https://openreview.net/forum?id=-NEXDKk8gZ
|
||||
"""
|
||||
steps = timesteps + 1
|
||||
x = np.linspace(0, steps, steps)
|
||||
alphas_cumprod = np.cos(((x / steps) + s) / (1 + s) * np.pi * 0.5) ** 2
|
||||
alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
|
||||
betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
|
||||
return np.clip(betas, a_min=0, a_max=0.999)
|
||||
|
||||
|
||||
beta_schedule = {
|
||||
"cosine": cosine_beta_schedule,
|
||||
"linear": linear_beta_schedule,
|
||||
}
|
||||
|
||||
|
||||
class GaussianDiffusion(nn.Module):
|
||||
def __init__(self, denoise_fn, out_dims=128, timesteps=1000, k_step=1000, max_beta=0.02, spec_min=-12, spec_max=2):
|
||||
super().__init__()
|
||||
self.denoise_fn = denoise_fn
|
||||
self.out_dims = out_dims
|
||||
betas = beta_schedule["linear"](timesteps, max_beta=max_beta)
|
||||
|
||||
alphas = 1.0 - betas
|
||||
alphas_cumprod = np.cumprod(alphas, axis=0)
|
||||
alphas_cumprod_prev = np.append(1.0, alphas_cumprod[:-1])
|
||||
|
||||
(timesteps,) = betas.shape
|
||||
self.num_timesteps = int(timesteps)
|
||||
self.k_step = k_step
|
||||
|
||||
self.noise_list = deque(maxlen=4)
|
||||
|
||||
to_torch = partial(torch.tensor, dtype=torch.float32)
|
||||
|
||||
self.register_buffer("betas", to_torch(betas))
|
||||
self.register_buffer("alphas_cumprod", to_torch(alphas_cumprod))
|
||||
self.register_buffer("alphas_cumprod_prev", to_torch(alphas_cumprod_prev))
|
||||
|
||||
# calculations for diffusion q(x_t | x_{t-1}) and others
|
||||
self.register_buffer("sqrt_alphas_cumprod", to_torch(np.sqrt(alphas_cumprod)))
|
||||
self.register_buffer("sqrt_one_minus_alphas_cumprod", to_torch(np.sqrt(1.0 - alphas_cumprod)))
|
||||
self.register_buffer("log_one_minus_alphas_cumprod", to_torch(np.log(1.0 - alphas_cumprod)))
|
||||
self.register_buffer("sqrt_recip_alphas_cumprod", to_torch(np.sqrt(1.0 / alphas_cumprod)))
|
||||
self.register_buffer("sqrt_recipm1_alphas_cumprod", to_torch(np.sqrt(1.0 / alphas_cumprod - 1)))
|
||||
|
||||
# calculations for posterior q(x_{t-1} | x_t, x_0)
|
||||
posterior_variance = betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod)
|
||||
# above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)
|
||||
self.register_buffer("posterior_variance", to_torch(posterior_variance))
|
||||
# below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
|
||||
self.register_buffer("posterior_log_variance_clipped", to_torch(np.log(np.maximum(posterior_variance, 1e-20))))
|
||||
self.register_buffer("posterior_mean_coef1", to_torch(betas * np.sqrt(alphas_cumprod_prev) / (1.0 - alphas_cumprod)))
|
||||
self.register_buffer("posterior_mean_coef2", to_torch((1.0 - alphas_cumprod_prev) * np.sqrt(alphas) / (1.0 - alphas_cumprod)))
|
||||
|
||||
self.register_buffer("spec_min", torch.FloatTensor([spec_min])[None, None, :out_dims])
|
||||
self.register_buffer("spec_max", torch.FloatTensor([spec_max])[None, None, :out_dims])
|
||||
|
||||
def q_mean_variance(self, x_start, t):
|
||||
mean = extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
|
||||
variance = extract(1.0 - self.alphas_cumprod, t, x_start.shape)
|
||||
log_variance = extract(self.log_one_minus_alphas_cumprod, t, x_start.shape)
|
||||
return mean, variance, log_variance
|
||||
|
||||
def predict_start_from_noise(self, x_t, t, noise):
|
||||
return extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise
|
||||
|
||||
def q_posterior(self, x_start, x_t, t):
|
||||
posterior_mean = extract(self.posterior_mean_coef1, t, x_t.shape) * x_start + extract(self.posterior_mean_coef2, t, x_t.shape) * x_t
|
||||
posterior_variance = extract(self.posterior_variance, t, x_t.shape)
|
||||
posterior_log_variance_clipped = extract(self.posterior_log_variance_clipped, t, x_t.shape)
|
||||
return posterior_mean, posterior_variance, posterior_log_variance_clipped
|
||||
|
||||
def p_mean_variance(self, x, t, cond):
|
||||
noise_pred = self.denoise_fn(x, t, cond=cond)
|
||||
x_recon = self.predict_start_from_noise(x, t=t, noise=noise_pred)
|
||||
|
||||
x_recon.clamp_(-1.0, 1.0)
|
||||
|
||||
model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t)
|
||||
return model_mean, posterior_variance, posterior_log_variance
|
||||
|
||||
@torch.no_grad()
|
||||
def p_sample(self, x, t, cond, clip_denoised=True, repeat_noise=False):
|
||||
b, *_, device = *x.shape, x.device # type:ignore
|
||||
model_mean, _, model_log_variance = self.p_mean_variance(x=x, t=t, cond=cond)
|
||||
noise = noise_like(x.shape, device, repeat_noise)
|
||||
# no noise when t == 0
|
||||
nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
|
||||
return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
|
||||
|
||||
@torch.no_grad()
|
||||
def p_sample_ddim(self, x, t, interval, cond):
|
||||
a_t = extract(self.alphas_cumprod, t, x.shape)
|
||||
a_prev = extract(self.alphas_cumprod, torch.max(t - interval, torch.zeros_like(t)), x.shape)
|
||||
|
||||
noise_pred = self.denoise_fn(x, t, cond=cond)
|
||||
x_prev = a_prev.sqrt() * (x / a_t.sqrt() + (((1 - a_prev) / a_prev).sqrt() - ((1 - a_t) / a_t).sqrt()) * noise_pred)
|
||||
return x_prev
|
||||
|
||||
@torch.no_grad()
|
||||
def p_sample_plms(self, x, t, interval, cond, clip_denoised=True, repeat_noise=False):
|
||||
"""
|
||||
Use the PLMS method from
|
||||
[Pseudo Numerical Methods for Diffusion Models on Manifolds](https://arxiv.org/abs/2202.09778).
|
||||
"""
|
||||
|
||||
def get_x_pred(x, noise_t, t):
|
||||
a_t = extract(self.alphas_cumprod, t, x.shape)
|
||||
a_prev = extract(self.alphas_cumprod, torch.max(t - interval, torch.zeros_like(t)), x.shape)
|
||||
a_t_sq, a_prev_sq = a_t.sqrt(), a_prev.sqrt()
|
||||
|
||||
x_delta = (a_prev - a_t) * ((1 / (a_t_sq * (a_t_sq + a_prev_sq))) * x - 1 / (a_t_sq * (((1 - a_prev) * a_t).sqrt() + ((1 - a_t) * a_prev).sqrt())) * noise_t)
|
||||
x_pred = x + x_delta
|
||||
|
||||
return x_pred
|
||||
|
||||
noise_list = self.noise_list
|
||||
noise_pred = self.denoise_fn(x, t, cond=cond)
|
||||
|
||||
if len(noise_list) == 0:
|
||||
x_pred = get_x_pred(x, noise_pred, t)
|
||||
noise_pred_prev = self.denoise_fn(x_pred, max(t - interval, 0), cond=cond)
|
||||
noise_pred_prime = (noise_pred + noise_pred_prev) / 2
|
||||
elif len(noise_list) == 1:
|
||||
noise_pred_prime = (3 * noise_pred - noise_list[-1]) / 2
|
||||
elif len(noise_list) == 2:
|
||||
noise_pred_prime = (23 * noise_pred - 16 * noise_list[-1] + 5 * noise_list[-2]) / 12
|
||||
else:
|
||||
noise_pred_prime = (55 * noise_pred - 59 * noise_list[-1] + 37 * noise_list[-2] - 9 * noise_list[-3]) / 24
|
||||
|
||||
x_prev = get_x_pred(x, noise_pred_prime, t)
|
||||
noise_list.append(noise_pred)
|
||||
|
||||
return x_prev
|
||||
|
||||
def q_sample(self, x_start, t, noise=None):
|
||||
noise = default(noise, lambda: torch.randn_like(x_start))
|
||||
return extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start + extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise
|
||||
|
||||
def p_losses(self, x_start, t, cond, noise=None, loss_type="l2"):
|
||||
noise = default(noise, lambda: torch.randn_like(x_start))
|
||||
|
||||
x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)
|
||||
x_recon = self.denoise_fn(x_noisy, t, cond)
|
||||
|
||||
if loss_type == "l1":
|
||||
loss = (noise - x_recon).abs().mean()
|
||||
elif loss_type == "l2":
|
||||
loss = F.mse_loss(noise, x_recon)
|
||||
else:
|
||||
raise NotImplementedError()
|
||||
|
||||
return loss
|
||||
|
||||
def forward(self, condition, gt_spec=None, infer=True, infer_speedup=10, method="dpm-solver", k_step=None, use_tqdm=True):
|
||||
"""
|
||||
conditioning diffusion, use fastspeech2 encoder output as the condition
|
||||
"""
|
||||
cond = condition.transpose(1, 2)
|
||||
b, device = condition.shape[0], condition.device
|
||||
|
||||
if not infer:
|
||||
spec = self.norm_spec(gt_spec)
|
||||
if k_step is None:
|
||||
t_max = self.k_step
|
||||
else:
|
||||
t_max = k_step
|
||||
t = torch.randint(0, t_max, (b,), device=device).long()
|
||||
norm_spec = spec.transpose(1, 2)[:, None, :, :] # [B, 1, M, T]
|
||||
return self.p_losses(norm_spec, t, cond=cond)
|
||||
else:
|
||||
shape = (cond.shape[0], 1, self.out_dims, cond.shape[2])
|
||||
|
||||
if gt_spec is None or k_step is None:
|
||||
t = self.k_step
|
||||
x = torch.randn(shape, device=device)
|
||||
else:
|
||||
t = k_step
|
||||
norm_spec = self.norm_spec(gt_spec)
|
||||
norm_spec = norm_spec.transpose(1, 2)[:, None, :, :]
|
||||
x = self.q_sample(x_start=norm_spec, t=torch.tensor([t - 1], device=device).long())
|
||||
|
||||
if method is not None and infer_speedup > 1:
|
||||
if method == "dpm-solver":
|
||||
from .dpm_solver_pytorch import NoiseScheduleVP, model_wrapper, DPM_Solver
|
||||
|
||||
# 1. Define the noise schedule.
|
||||
noise_schedule = NoiseScheduleVP(schedule="discrete", betas=self.betas[:t])
|
||||
|
||||
# 2. Convert your discrete-time `model` to the continuous-time
|
||||
# noise prediction model. Here is an example for a diffusion model
|
||||
# `model` with the noise prediction type ("noise") .
|
||||
def my_wrapper(fn):
|
||||
def wrapped(x, t, **kwargs):
|
||||
ret = fn(x, t, **kwargs)
|
||||
if use_tqdm:
|
||||
self.bar.update(1)
|
||||
return ret
|
||||
|
||||
return wrapped
|
||||
|
||||
model_fn = model_wrapper(my_wrapper(self.denoise_fn), noise_schedule, model_type="noise", model_kwargs={"cond": cond}) # or "x_start" or "v" or "score"
|
||||
|
||||
# 3. Define dpm-solver and sample by singlestep DPM-Solver.
|
||||
# (We recommend singlestep DPM-Solver for unconditional sampling)
|
||||
# You can adjust the `steps` to balance the computation
|
||||
# costs and the sample quality.
|
||||
dpm_solver = DPM_Solver(model_fn, noise_schedule, algorithm_type="dpmsolver++")
|
||||
|
||||
steps = t // infer_speedup
|
||||
if use_tqdm:
|
||||
self.bar = tqdm(desc="sample time step", total=steps)
|
||||
x = dpm_solver.sample(
|
||||
x,
|
||||
steps=steps,
|
||||
order=2,
|
||||
skip_type="time_uniform",
|
||||
method="multistep",
|
||||
)
|
||||
if use_tqdm:
|
||||
self.bar.close()
|
||||
elif method == "unipc":
|
||||
from .uni_pc import NoiseScheduleVP, model_wrapper, UniPC
|
||||
|
||||
# 1. Define the noise schedule.
|
||||
noise_schedule = NoiseScheduleVP(schedule="discrete", betas=self.betas[:t])
|
||||
|
||||
# 2. Convert your discrete-time `model` to the continuous-time
|
||||
# noise prediction model. Here is an example for a diffusion model
|
||||
# `model` with the noise prediction type ("noise") .
|
||||
def my_wrapper(fn):
|
||||
def wrapped(x, t, **kwargs):
|
||||
ret = fn(x, t, **kwargs)
|
||||
if use_tqdm:
|
||||
self.bar.update(1)
|
||||
return ret
|
||||
|
||||
return wrapped
|
||||
|
||||
model_fn = model_wrapper(my_wrapper(self.denoise_fn), noise_schedule, model_type="noise", model_kwargs={"cond": cond}) # or "x_start" or "v" or "score"
|
||||
|
||||
# 3. Define uni_pc and sample by multistep UniPC.
|
||||
# You can adjust the `steps` to balance the computation
|
||||
# costs and the sample quality.
|
||||
uni_pc = UniPC(model_fn, noise_schedule, variant="bh2")
|
||||
|
||||
steps = t // infer_speedup
|
||||
if use_tqdm:
|
||||
self.bar = tqdm(desc="sample time step", total=steps)
|
||||
x = uni_pc.sample(
|
||||
x,
|
||||
steps=steps,
|
||||
order=2,
|
||||
skip_type="time_uniform",
|
||||
method="multistep",
|
||||
)
|
||||
if use_tqdm:
|
||||
self.bar.close()
|
||||
elif method == "pndm":
|
||||
self.noise_list = deque(maxlen=4)
|
||||
if use_tqdm:
|
||||
for i in tqdm(
|
||||
reversed(range(0, t, infer_speedup)),
|
||||
desc="sample time step",
|
||||
total=t // infer_speedup,
|
||||
):
|
||||
x = self.p_sample_plms(x, torch.full((b,), i, device=device, dtype=torch.long), infer_speedup, cond=cond)
|
||||
else:
|
||||
for i in reversed(range(0, t, infer_speedup)):
|
||||
x = self.p_sample_plms(x, torch.full((b,), i, device=device, dtype=torch.long), infer_speedup, cond=cond)
|
||||
elif method == "ddim":
|
||||
if use_tqdm:
|
||||
for i in tqdm(
|
||||
reversed(range(0, t, infer_speedup)),
|
||||
desc="sample time step",
|
||||
total=t // infer_speedup,
|
||||
):
|
||||
x = self.p_sample_ddim(x, torch.full((b,), i, device=device, dtype=torch.long), infer_speedup, cond=cond)
|
||||
else:
|
||||
for i in reversed(range(0, t, infer_speedup)):
|
||||
x = self.p_sample_ddim(x, torch.full((b,), i, device=device, dtype=torch.long), infer_speedup, cond=cond)
|
||||
else:
|
||||
raise NotImplementedError(method)
|
||||
else:
|
||||
if use_tqdm:
|
||||
for i in tqdm(reversed(range(0, t)), desc="sample time step", total=t):
|
||||
x = self.p_sample(x, torch.full((b,), i, device=device, dtype=torch.long), cond)
|
||||
else:
|
||||
for i in reversed(range(0, t)):
|
||||
x = self.p_sample(x, torch.full((b,), i, device=device, dtype=torch.long), cond)
|
||||
x = x.squeeze(1).transpose(1, 2) # [B, T, M]
|
||||
return self.denorm_spec(x)
|
||||
|
||||
def norm_spec(self, x):
|
||||
return (x - self.spec_min) / (self.spec_max - self.spec_min) * 2 - 1
|
||||
|
||||
def denorm_spec(self, x):
|
||||
return (x + 1) / 2 * (self.spec_max - self.spec_min) + self.spec_min
|
526
server/voice_changer/DDSP_SVC/models/diffusion/diffusion_onnx.py
Normal file
526
server/voice_changer/DDSP_SVC/models/diffusion/diffusion_onnx.py
Normal file
@ -0,0 +1,526 @@
|
||||
from collections import deque
|
||||
from functools import partial
|
||||
from inspect import isfunction
|
||||
import torch.nn.functional as F
|
||||
import numpy as np
|
||||
from torch.nn import Conv1d
|
||||
from torch.nn import Mish
|
||||
import torch
|
||||
from torch import nn
|
||||
from tqdm import tqdm
|
||||
import math
|
||||
|
||||
|
||||
def exists(x):
|
||||
return x is not None
|
||||
|
||||
|
||||
def default(val, d):
|
||||
if exists(val):
|
||||
return val
|
||||
return d() if isfunction(d) else d
|
||||
|
||||
|
||||
def extract(a, t):
|
||||
return a[t].reshape((1, 1, 1, 1))
|
||||
|
||||
|
||||
def noise_like(shape, device, repeat=False):
|
||||
repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat(shape[0], *((1,) * (len(shape) - 1))) # NOQA
|
||||
noise = lambda: torch.randn(shape, device=device) # NOQA
|
||||
return repeat_noise() if repeat else noise()
|
||||
|
||||
|
||||
def linear_beta_schedule(timesteps, max_beta=0.02):
|
||||
"""
|
||||
linear schedule
|
||||
"""
|
||||
betas = np.linspace(1e-4, max_beta, timesteps)
|
||||
return betas
|
||||
|
||||
|
||||
def cosine_beta_schedule(timesteps, s=0.008):
|
||||
"""
|
||||
cosine schedule
|
||||
as proposed in https://openreview.net/forum?id=-NEXDKk8gZ
|
||||
"""
|
||||
steps = timesteps + 1
|
||||
x = np.linspace(0, steps, steps)
|
||||
alphas_cumprod = np.cos(((x / steps) + s) / (1 + s) * np.pi * 0.5) ** 2
|
||||
alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
|
||||
betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
|
||||
return np.clip(betas, a_min=0, a_max=0.999)
|
||||
|
||||
|
||||
beta_schedule = {
|
||||
"cosine": cosine_beta_schedule,
|
||||
"linear": linear_beta_schedule,
|
||||
}
|
||||
|
||||
|
||||
def extract_1(a, t):
|
||||
return a[t].reshape((1, 1, 1, 1))
|
||||
|
||||
|
||||
def predict_stage0(noise_pred, noise_pred_prev):
|
||||
return (noise_pred + noise_pred_prev) / 2
|
||||
|
||||
|
||||
def predict_stage1(noise_pred, noise_list):
|
||||
return (noise_pred * 3 - noise_list[-1]) / 2
|
||||
|
||||
|
||||
def predict_stage2(noise_pred, noise_list):
|
||||
return (noise_pred * 23 - noise_list[-1] * 16 + noise_list[-2] * 5) / 12
|
||||
|
||||
|
||||
def predict_stage3(noise_pred, noise_list):
|
||||
return (noise_pred * 55 - noise_list[-1] * 59 + noise_list[-2] * 37 - noise_list[-3] * 9) / 24
|
||||
|
||||
|
||||
class SinusoidalPosEmb(nn.Module):
|
||||
def __init__(self, dim):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.half_dim = dim // 2
|
||||
self.emb = 9.21034037 / (self.half_dim - 1)
|
||||
self.emb = torch.exp(torch.arange(self.half_dim) * torch.tensor(-self.emb)).unsqueeze(0)
|
||||
self.emb = self.emb.cpu()
|
||||
|
||||
def forward(self, x):
|
||||
emb = self.emb * x
|
||||
emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
|
||||
return emb
|
||||
|
||||
|
||||
class ResidualBlock(nn.Module):
|
||||
def __init__(self, encoder_hidden, residual_channels, dilation):
|
||||
super().__init__()
|
||||
self.residual_channels = residual_channels
|
||||
self.dilated_conv = Conv1d(residual_channels, 2 * residual_channels, 3, padding=dilation, dilation=dilation)
|
||||
self.diffusion_projection = nn.Linear(residual_channels, residual_channels)
|
||||
self.conditioner_projection = Conv1d(encoder_hidden, 2 * residual_channels, 1)
|
||||
self.output_projection = Conv1d(residual_channels, 2 * residual_channels, 1)
|
||||
|
||||
def forward(self, x, conditioner, diffusion_step):
|
||||
diffusion_step = self.diffusion_projection(diffusion_step).unsqueeze(-1)
|
||||
conditioner = self.conditioner_projection(conditioner)
|
||||
y = x + diffusion_step
|
||||
y = self.dilated_conv(y) + conditioner
|
||||
|
||||
gate, filter_1 = torch.split(y, [self.residual_channels, self.residual_channels], dim=1)
|
||||
|
||||
y = torch.sigmoid(gate) * torch.tanh(filter_1)
|
||||
y = self.output_projection(y)
|
||||
|
||||
residual, skip = torch.split(y, [self.residual_channels, self.residual_channels], dim=1)
|
||||
|
||||
return (x + residual) / 1.41421356, skip
|
||||
|
||||
|
||||
class DiffNet(nn.Module):
|
||||
def __init__(self, in_dims, n_layers, n_chans, n_hidden):
|
||||
super().__init__()
|
||||
self.encoder_hidden = n_hidden
|
||||
self.residual_layers = n_layers
|
||||
self.residual_channels = n_chans
|
||||
self.input_projection = Conv1d(in_dims, self.residual_channels, 1)
|
||||
self.diffusion_embedding = SinusoidalPosEmb(self.residual_channels)
|
||||
dim = self.residual_channels
|
||||
self.mlp = nn.Sequential(nn.Linear(dim, dim * 4), Mish(), nn.Linear(dim * 4, dim))
|
||||
self.residual_layers = nn.ModuleList([ResidualBlock(self.encoder_hidden, self.residual_channels, 1) for i in range(self.residual_layers)])
|
||||
self.skip_projection = Conv1d(self.residual_channels, self.residual_channels, 1)
|
||||
self.output_projection = Conv1d(self.residual_channels, in_dims, 1)
|
||||
nn.init.zeros_(self.output_projection.weight)
|
||||
|
||||
def forward(self, spec, diffusion_step, cond):
|
||||
x = spec.squeeze(0)
|
||||
x = self.input_projection(x) # x [B, residual_channel, T]
|
||||
x = F.relu(x)
|
||||
# skip = torch.randn_like(x)
|
||||
diffusion_step = diffusion_step.float()
|
||||
diffusion_step = self.diffusion_embedding(diffusion_step)
|
||||
diffusion_step = self.mlp(diffusion_step)
|
||||
|
||||
x, skip = self.residual_layers[0](x, cond, diffusion_step)
|
||||
# noinspection PyTypeChecker
|
||||
for layer in self.residual_layers[1:]:
|
||||
x, skip_connection = layer.forward(x, cond, diffusion_step)
|
||||
skip = skip + skip_connection
|
||||
x = skip / math.sqrt(len(self.residual_layers))
|
||||
x = self.skip_projection(x)
|
||||
x = F.relu(x)
|
||||
x = self.output_projection(x) # [B, 80, T]
|
||||
return x.unsqueeze(1)
|
||||
|
||||
|
||||
class AfterDiffusion(nn.Module):
|
||||
def __init__(self, spec_max, spec_min, v_type="a"):
|
||||
super().__init__()
|
||||
self.spec_max = spec_max
|
||||
self.spec_min = spec_min
|
||||
self.type = v_type
|
||||
|
||||
def forward(self, x):
|
||||
x = x.squeeze(1).permute(0, 2, 1)
|
||||
mel_out = (x + 1) / 2 * (self.spec_max - self.spec_min) + self.spec_min
|
||||
if self.type == "nsf-hifigan-log10":
|
||||
mel_out = mel_out * 0.434294
|
||||
return mel_out.transpose(2, 1)
|
||||
|
||||
|
||||
class Pred(nn.Module):
|
||||
def __init__(self, alphas_cumprod):
|
||||
super().__init__()
|
||||
self.alphas_cumprod = alphas_cumprod
|
||||
|
||||
def forward(self, x_1, noise_t, t_1, t_prev):
|
||||
a_t = extract(self.alphas_cumprod, t_1).cpu()
|
||||
a_prev = extract(self.alphas_cumprod, t_prev).cpu()
|
||||
a_t_sq, a_prev_sq = a_t.sqrt().cpu(), a_prev.sqrt().cpu()
|
||||
x_delta = (a_prev - a_t) * ((1 / (a_t_sq * (a_t_sq + a_prev_sq))) * x_1 - 1 / (a_t_sq * (((1 - a_prev) * a_t).sqrt() + ((1 - a_t) * a_prev).sqrt())) * noise_t)
|
||||
x_pred = x_1 + x_delta.cpu()
|
||||
|
||||
return x_pred
|
||||
|
||||
|
||||
class GaussianDiffusion(nn.Module):
|
||||
def __init__(self, out_dims=128, n_layers=20, n_chans=384, n_hidden=256, timesteps=1000, k_step=1000, max_beta=0.02, spec_min=-12, spec_max=2):
|
||||
super().__init__()
|
||||
self.denoise_fn = DiffNet(out_dims, n_layers, n_chans, n_hidden)
|
||||
self.out_dims = out_dims
|
||||
self.mel_bins = out_dims
|
||||
self.n_hidden = n_hidden
|
||||
betas = beta_schedule["linear"](timesteps, max_beta=max_beta)
|
||||
|
||||
alphas = 1.0 - betas
|
||||
alphas_cumprod = np.cumprod(alphas, axis=0)
|
||||
alphas_cumprod_prev = np.append(1.0, alphas_cumprod[:-1])
|
||||
(timesteps,) = betas.shape
|
||||
self.num_timesteps = int(timesteps)
|
||||
self.k_step = k_step
|
||||
|
||||
self.noise_list = deque(maxlen=4)
|
||||
|
||||
to_torch = partial(torch.tensor, dtype=torch.float32)
|
||||
|
||||
self.register_buffer("betas", to_torch(betas))
|
||||
self.register_buffer("alphas_cumprod", to_torch(alphas_cumprod))
|
||||
self.register_buffer("alphas_cumprod_prev", to_torch(alphas_cumprod_prev))
|
||||
|
||||
# calculations for diffusion q(x_t | x_{t-1}) and others
|
||||
self.register_buffer("sqrt_alphas_cumprod", to_torch(np.sqrt(alphas_cumprod)))
|
||||
self.register_buffer("sqrt_one_minus_alphas_cumprod", to_torch(np.sqrt(1.0 - alphas_cumprod)))
|
||||
self.register_buffer("log_one_minus_alphas_cumprod", to_torch(np.log(1.0 - alphas_cumprod)))
|
||||
self.register_buffer("sqrt_recip_alphas_cumprod", to_torch(np.sqrt(1.0 / alphas_cumprod)))
|
||||
self.register_buffer("sqrt_recipm1_alphas_cumprod", to_torch(np.sqrt(1.0 / alphas_cumprod - 1)))
|
||||
|
||||
# calculations for posterior q(x_{t-1} | x_t, x_0)
|
||||
posterior_variance = betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod)
|
||||
# above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)
|
||||
self.register_buffer("posterior_variance", to_torch(posterior_variance))
|
||||
# below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
|
||||
self.register_buffer("posterior_log_variance_clipped", to_torch(np.log(np.maximum(posterior_variance, 1e-20))))
|
||||
self.register_buffer("posterior_mean_coef1", to_torch(betas * np.sqrt(alphas_cumprod_prev) / (1.0 - alphas_cumprod)))
|
||||
self.register_buffer("posterior_mean_coef2", to_torch((1.0 - alphas_cumprod_prev) * np.sqrt(alphas) / (1.0 - alphas_cumprod)))
|
||||
|
||||
self.register_buffer("spec_min", torch.FloatTensor([spec_min])[None, None, :out_dims])
|
||||
self.register_buffer("spec_max", torch.FloatTensor([spec_max])[None, None, :out_dims])
|
||||
self.ad = AfterDiffusion(self.spec_max, self.spec_min)
|
||||
self.xp = Pred(self.alphas_cumprod)
|
||||
|
||||
def q_mean_variance(self, x_start, t):
|
||||
mean = extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
|
||||
variance = extract(1.0 - self.alphas_cumprod, t, x_start.shape)
|
||||
log_variance = extract(self.log_one_minus_alphas_cumprod, t, x_start.shape)
|
||||
return mean, variance, log_variance
|
||||
|
||||
def predict_start_from_noise(self, x_t, t, noise):
|
||||
return extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise
|
||||
|
||||
def q_posterior(self, x_start, x_t, t):
|
||||
posterior_mean = extract(self.posterior_mean_coef1, t, x_t.shape) * x_start + extract(self.posterior_mean_coef2, t, x_t.shape) * x_t
|
||||
posterior_variance = extract(self.posterior_variance, t, x_t.shape)
|
||||
posterior_log_variance_clipped = extract(self.posterior_log_variance_clipped, t, x_t.shape)
|
||||
return posterior_mean, posterior_variance, posterior_log_variance_clipped
|
||||
|
||||
def p_mean_variance(self, x, t, cond):
|
||||
noise_pred = self.denoise_fn(x, t, cond=cond)
|
||||
x_recon = self.predict_start_from_noise(x, t=t, noise=noise_pred)
|
||||
|
||||
x_recon.clamp_(-1.0, 1.0)
|
||||
|
||||
model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t)
|
||||
return model_mean, posterior_variance, posterior_log_variance
|
||||
|
||||
@torch.no_grad()
|
||||
def p_sample(self, x, t, cond, clip_denoised=True, repeat_noise=False):
|
||||
b, *_, device = *x.shape, x.device # type: ignore
|
||||
model_mean, _, model_log_variance = self.p_mean_variance(x=x, t=t, cond=cond)
|
||||
noise = noise_like(x.shape, device, repeat_noise)
|
||||
# no noise when t == 0
|
||||
nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
|
||||
return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
|
||||
|
||||
@torch.no_grad()
|
||||
def p_sample_plms(self, x, t, interval, cond, clip_denoised=True, repeat_noise=False):
|
||||
"""
|
||||
Use the PLMS method from
|
||||
[Pseudo Numerical Methods for Diffusion Models on Manifolds](https://arxiv.org/abs/2202.09778).
|
||||
"""
|
||||
|
||||
def get_x_pred(x, noise_t, t):
|
||||
a_t = extract(self.alphas_cumprod, t)
|
||||
a_prev = extract(self.alphas_cumprod, torch.max(t - interval, torch.zeros_like(t)))
|
||||
a_t_sq, a_prev_sq = a_t.sqrt(), a_prev.sqrt()
|
||||
|
||||
x_delta = (a_prev - a_t) * ((1 / (a_t_sq * (a_t_sq + a_prev_sq))) * x - 1 / (a_t_sq * (((1 - a_prev) * a_t).sqrt() + ((1 - a_t) * a_prev).sqrt())) * noise_t)
|
||||
x_pred = x + x_delta
|
||||
|
||||
return x_pred
|
||||
|
||||
noise_list = self.noise_list
|
||||
noise_pred = self.denoise_fn(x, t, cond=cond)
|
||||
|
||||
if len(noise_list) == 0:
|
||||
x_pred = get_x_pred(x, noise_pred, t)
|
||||
noise_pred_prev = self.denoise_fn(x_pred, max(t - interval, 0), cond=cond)
|
||||
noise_pred_prime = (noise_pred + noise_pred_prev) / 2
|
||||
elif len(noise_list) == 1:
|
||||
noise_pred_prime = (3 * noise_pred - noise_list[-1]) / 2
|
||||
elif len(noise_list) == 2:
|
||||
noise_pred_prime = (23 * noise_pred - 16 * noise_list[-1] + 5 * noise_list[-2]) / 12
|
||||
else:
|
||||
noise_pred_prime = (55 * noise_pred - 59 * noise_list[-1] + 37 * noise_list[-2] - 9 * noise_list[-3]) / 24
|
||||
|
||||
x_prev = get_x_pred(x, noise_pred_prime, t)
|
||||
noise_list.append(noise_pred)
|
||||
|
||||
return x_prev
|
||||
|
||||
def q_sample(self, x_start, t, noise=None):
|
||||
noise = default(noise, lambda: torch.randn_like(x_start))
|
||||
return extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start + extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise
|
||||
|
||||
def p_losses(self, x_start, t, cond, noise=None, loss_type="l2"):
|
||||
noise = default(noise, lambda: torch.randn_like(x_start))
|
||||
|
||||
x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)
|
||||
x_recon = self.denoise_fn(x_noisy, t, cond)
|
||||
|
||||
if loss_type == "l1":
|
||||
loss = (noise - x_recon).abs().mean()
|
||||
elif loss_type == "l2":
|
||||
loss = F.mse_loss(noise, x_recon)
|
||||
else:
|
||||
raise NotImplementedError()
|
||||
|
||||
return loss
|
||||
|
||||
def org_forward(self, condition, init_noise=None, gt_spec=None, infer=True, infer_speedup=100, method="pndm", k_step=1000, use_tqdm=True):
|
||||
"""
|
||||
conditioning diffusion, use fastspeech2 encoder output as the condition
|
||||
"""
|
||||
cond = condition
|
||||
b, device = condition.shape[0], condition.device
|
||||
if not infer:
|
||||
spec = self.norm_spec(gt_spec)
|
||||
t = torch.randint(0, self.k_step, (b,), device=device).long()
|
||||
norm_spec = spec.transpose(1, 2)[:, None, :, :] # [B, 1, M, T]
|
||||
return self.p_losses(norm_spec, t, cond=cond)
|
||||
else:
|
||||
shape = (cond.shape[0], 1, self.out_dims, cond.shape[2])
|
||||
|
||||
if gt_spec is None:
|
||||
t = self.k_step
|
||||
if init_noise is None:
|
||||
x = torch.randn(shape, device=device)
|
||||
else:
|
||||
x = init_noise
|
||||
else:
|
||||
t = k_step
|
||||
norm_spec = self.norm_spec(gt_spec)
|
||||
norm_spec = norm_spec.transpose(1, 2)[:, None, :, :]
|
||||
x = self.q_sample(x_start=norm_spec, t=torch.tensor([t - 1], device=device).long())
|
||||
|
||||
if method is not None and infer_speedup > 1:
|
||||
if method == "dpm-solver":
|
||||
from .dpm_solver_pytorch import NoiseScheduleVP, model_wrapper, DPM_Solver
|
||||
|
||||
# 1. Define the noise schedule.
|
||||
noise_schedule = NoiseScheduleVP(schedule="discrete", betas=self.betas[:t])
|
||||
|
||||
# 2. Convert your discrete-time `model` to the continuous-time
|
||||
# noise prediction model. Here is an example for a diffusion model
|
||||
# `model` with the noise prediction type ("noise") .
|
||||
def my_wrapper(fn):
|
||||
def wrapped(x, t, **kwargs):
|
||||
ret = fn(x, t, **kwargs)
|
||||
if use_tqdm:
|
||||
self.bar.update(1)
|
||||
return ret
|
||||
|
||||
return wrapped
|
||||
|
||||
model_fn = model_wrapper(my_wrapper(self.denoise_fn), noise_schedule, model_type="noise", model_kwargs={"cond": cond}) # or "x_start" or "v" or "score"
|
||||
|
||||
# 3. Define dpm-solver and sample by singlestep DPM-Solver.
|
||||
# (We recommend singlestep DPM-Solver for unconditional sampling)
|
||||
# You can adjust the `steps` to balance the computation
|
||||
# costs and the sample quality.
|
||||
dpm_solver = DPM_Solver(model_fn, noise_schedule)
|
||||
|
||||
steps = t // infer_speedup
|
||||
if use_tqdm:
|
||||
self.bar = tqdm(desc="sample time step", total=steps)
|
||||
x = dpm_solver.sample(
|
||||
x,
|
||||
steps=steps,
|
||||
order=3,
|
||||
skip_type="time_uniform",
|
||||
method="singlestep",
|
||||
)
|
||||
if use_tqdm:
|
||||
self.bar.close()
|
||||
elif method == "pndm":
|
||||
self.noise_list = deque(maxlen=4)
|
||||
if use_tqdm:
|
||||
for i in tqdm(
|
||||
reversed(range(0, t, infer_speedup)),
|
||||
desc="sample time step",
|
||||
total=t // infer_speedup,
|
||||
):
|
||||
x = self.p_sample_plms(x, torch.full((b,), i, device=device, dtype=torch.long), infer_speedup, cond=cond)
|
||||
else:
|
||||
for i in reversed(range(0, t, infer_speedup)):
|
||||
x = self.p_sample_plms(x, torch.full((b,), i, device=device, dtype=torch.long), infer_speedup, cond=cond)
|
||||
else:
|
||||
raise NotImplementedError(method)
|
||||
else:
|
||||
if use_tqdm:
|
||||
for i in tqdm(reversed(range(0, t)), desc="sample time step", total=t):
|
||||
x = self.p_sample(x, torch.full((b,), i, device=device, dtype=torch.long), cond)
|
||||
else:
|
||||
for i in reversed(range(0, t)):
|
||||
x = self.p_sample(x, torch.full((b,), i, device=device, dtype=torch.long), cond)
|
||||
x = x.squeeze(1).transpose(1, 2) # [B, T, M]
|
||||
return self.denorm_spec(x).transpose(2, 1)
|
||||
|
||||
def norm_spec(self, x):
|
||||
return (x - self.spec_min) / (self.spec_max - self.spec_min) * 2 - 1
|
||||
|
||||
def denorm_spec(self, x):
|
||||
return (x + 1) / 2 * (self.spec_max - self.spec_min) + self.spec_min
|
||||
|
||||
def get_x_pred(self, x_1, noise_t, t_1, t_prev):
|
||||
a_t = extract(self.alphas_cumprod, t_1)
|
||||
a_prev = extract(self.alphas_cumprod, t_prev)
|
||||
a_t_sq, a_prev_sq = a_t.sqrt(), a_prev.sqrt()
|
||||
x_delta = (a_prev - a_t) * ((1 / (a_t_sq * (a_t_sq + a_prev_sq))) * x_1 - 1 / (a_t_sq * (((1 - a_prev) * a_t).sqrt() + ((1 - a_t) * a_prev).sqrt())) * noise_t)
|
||||
x_pred = x_1 + x_delta
|
||||
return x_pred
|
||||
|
||||
def OnnxExport(self, project_name=None, init_noise=None, hidden_channels=256, export_denoise=True, export_pred=True, export_after=True):
|
||||
cond = torch.randn([1, self.n_hidden, 10]).cpu()
|
||||
if init_noise is None:
|
||||
x = torch.randn((1, 1, self.mel_bins, cond.shape[2]), dtype=torch.float32).cpu()
|
||||
else:
|
||||
x = init_noise
|
||||
pndms = 100
|
||||
|
||||
org_y_x = self.org_forward(cond, init_noise=x)
|
||||
|
||||
device = cond.device
|
||||
n_frames = cond.shape[2]
|
||||
step_range = torch.arange(0, self.k_step, pndms, dtype=torch.long, device=device).flip(0)
|
||||
plms_noise_stage = torch.tensor(0, dtype=torch.long, device=device)
|
||||
noise_list = torch.zeros((0, 1, 1, self.mel_bins, n_frames), device=device)
|
||||
|
||||
ot = step_range[0]
|
||||
ot_1 = torch.full((1,), ot, device=device, dtype=torch.long)
|
||||
if export_denoise:
|
||||
torch.onnx.export(self.denoise_fn, (x.cpu(), ot_1.cpu(), cond.cpu()), f"{project_name}_denoise.onnx", input_names=["noise", "time", "condition"], output_names=["noise_pred"], dynamic_axes={"noise": [3], "condition": [2]}, opset_version=16)
|
||||
|
||||
for t in step_range:
|
||||
t_1 = torch.full((1,), t, device=device, dtype=torch.long)
|
||||
noise_pred = self.denoise_fn(x, t_1, cond)
|
||||
t_prev = t_1 - pndms
|
||||
t_prev = t_prev * (t_prev > 0)
|
||||
if plms_noise_stage == 0:
|
||||
if export_pred:
|
||||
torch.onnx.export(self.xp, (x.cpu(), noise_pred.cpu(), t_1.cpu(), t_prev.cpu()), f"{project_name}_pred.onnx", input_names=["noise", "noise_pred", "time", "time_prev"], output_names=["noise_pred_o"], dynamic_axes={"noise": [3], "noise_pred": [3]}, opset_version=16)
|
||||
|
||||
x_pred = self.get_x_pred(x, noise_pred, t_1, t_prev)
|
||||
noise_pred_prev = self.denoise_fn(x_pred, t_prev, cond=cond)
|
||||
noise_pred_prime = predict_stage0(noise_pred, noise_pred_prev)
|
||||
|
||||
elif plms_noise_stage == 1:
|
||||
noise_pred_prime = predict_stage1(noise_pred, noise_list)
|
||||
|
||||
elif plms_noise_stage == 2:
|
||||
noise_pred_prime = predict_stage2(noise_pred, noise_list)
|
||||
|
||||
else:
|
||||
noise_pred_prime = predict_stage3(noise_pred, noise_list)
|
||||
|
||||
noise_pred = noise_pred.unsqueeze(0)
|
||||
|
||||
if plms_noise_stage < 3:
|
||||
noise_list = torch.cat((noise_list, noise_pred), dim=0)
|
||||
plms_noise_stage = plms_noise_stage + 1
|
||||
|
||||
else:
|
||||
noise_list = torch.cat((noise_list[-2:], noise_pred), dim=0)
|
||||
|
||||
x = self.get_x_pred(x, noise_pred_prime, t_1, t_prev)
|
||||
if export_after:
|
||||
torch.onnx.export(self.ad, x.cpu(), f"{project_name}_after.onnx", input_names=["x"], output_names=["mel_out"], dynamic_axes={"x": [3]}, opset_version=16)
|
||||
x = self.ad(x)
|
||||
|
||||
print((x == org_y_x).all())
|
||||
return x
|
||||
|
||||
def forward(self, condition=None, init_noise=None, pndms=None, k_step=None):
|
||||
cond = condition
|
||||
x = init_noise
|
||||
|
||||
device = cond.device
|
||||
n_frames = cond.shape[2]
|
||||
step_range = torch.arange(0, k_step.item(), pndms.item(), dtype=torch.long, device=device).flip(0)
|
||||
plms_noise_stage = torch.tensor(0, dtype=torch.long, device=device)
|
||||
noise_list = torch.zeros((0, 1, 1, self.mel_bins, n_frames), device=device)
|
||||
|
||||
ot = step_range[0]
|
||||
ot_1 = torch.full((1,), ot, device=device, dtype=torch.long) # NOQA
|
||||
|
||||
for t in step_range:
|
||||
t_1 = torch.full((1,), t, device=device, dtype=torch.long)
|
||||
noise_pred = self.denoise_fn(x, t_1, cond)
|
||||
t_prev = t_1 - pndms
|
||||
t_prev = t_prev * (t_prev > 0)
|
||||
if plms_noise_stage == 0:
|
||||
x_pred = self.get_x_pred(x, noise_pred, t_1, t_prev)
|
||||
noise_pred_prev = self.denoise_fn(x_pred, t_prev, cond=cond)
|
||||
noise_pred_prime = predict_stage0(noise_pred, noise_pred_prev)
|
||||
|
||||
elif plms_noise_stage == 1:
|
||||
noise_pred_prime = predict_stage1(noise_pred, noise_list)
|
||||
|
||||
elif plms_noise_stage == 2:
|
||||
noise_pred_prime = predict_stage2(noise_pred, noise_list)
|
||||
|
||||
else:
|
||||
noise_pred_prime = predict_stage3(noise_pred, noise_list)
|
||||
|
||||
noise_pred = noise_pred.unsqueeze(0)
|
||||
|
||||
if plms_noise_stage < 3:
|
||||
noise_list = torch.cat((noise_list, noise_pred), dim=0)
|
||||
plms_noise_stage = plms_noise_stage + 1
|
||||
|
||||
else:
|
||||
noise_list = torch.cat((noise_list[-2:], noise_pred), dim=0)
|
||||
|
||||
x = self.get_x_pred(x, noise_pred_prime, t_1, t_prev)
|
||||
x = self.ad(x)
|
||||
return x
|
1284
server/voice_changer/DDSP_SVC/models/diffusion/dpm_solver_pytorch.py
Normal file
1284
server/voice_changer/DDSP_SVC/models/diffusion/dpm_solver_pytorch.py
Normal file
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,59 @@
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from .unit2mel import load_model_vocoder
|
||||
|
||||
|
||||
class DiffGtMel:
|
||||
def __init__(self, project_path=None, device=None):
|
||||
self.project_path = project_path
|
||||
if device is not None:
|
||||
self.device = device
|
||||
else:
|
||||
self.device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
self.model = None
|
||||
self.vocoder = None
|
||||
self.args = None
|
||||
|
||||
def flush_model(self, project_path, ddsp_config=None):
|
||||
if (self.model is None) or (project_path != self.project_path):
|
||||
model, vocoder, args = load_model_vocoder(project_path, device=self.device)
|
||||
if self.check_args(ddsp_config, args):
|
||||
self.model = model
|
||||
self.vocoder = vocoder
|
||||
self.args = args
|
||||
|
||||
def check_args(self, args1, args2):
|
||||
if args1.data.block_size != args2.data.block_size:
|
||||
raise ValueError("DDSP与DIFF模型的block_size不一致")
|
||||
if args1.data.sampling_rate != args2.data.sampling_rate:
|
||||
raise ValueError("DDSP与DIFF模型的sampling_rate不一致")
|
||||
if args1.data.encoder != args2.data.encoder:
|
||||
raise ValueError("DDSP与DIFF模型的encoder不一致")
|
||||
return True
|
||||
|
||||
def __call__(self, audio, f0, hubert, volume, acc=1, spk_id=1, k_step=0, method="pndm", spk_mix_dict=None, start_frame=0):
|
||||
input_mel = self.vocoder.extract(audio, self.args.data.sampling_rate)
|
||||
out_mel = self.model(hubert, f0, volume, spk_id=spk_id, spk_mix_dict=spk_mix_dict, gt_spec=input_mel, infer=True, infer_speedup=acc, method=method, k_step=k_step, use_tqdm=False)
|
||||
if start_frame > 0:
|
||||
out_mel = out_mel[:, start_frame:, :]
|
||||
f0 = f0[:, start_frame:, :]
|
||||
output = self.vocoder.infer(out_mel, f0)
|
||||
if start_frame > 0:
|
||||
output = F.pad(output, (start_frame * self.vocoder.vocoder_hop_size, 0))
|
||||
return output
|
||||
|
||||
def infer(self, audio, f0, hubert, volume, acc=1, spk_id=1, k_step=0, method="pndm", silence_front=0, use_silence=False, spk_mix_dict=None):
|
||||
start_frame = int(silence_front * self.vocoder.vocoder_sample_rate / self.vocoder.vocoder_hop_size)
|
||||
if use_silence:
|
||||
audio = audio[:, start_frame * self.vocoder.vocoder_hop_size :]
|
||||
f0 = f0[:, start_frame:, :]
|
||||
hubert = hubert[:, start_frame:, :]
|
||||
volume = volume[:, start_frame:, :]
|
||||
_start_frame = 0
|
||||
else:
|
||||
_start_frame = start_frame
|
||||
audio = self.__call__(audio, f0, hubert, volume, acc=acc, spk_id=spk_id, k_step=k_step, method=method, spk_mix_dict=spk_mix_dict, start_frame=_start_frame)
|
||||
if use_silence:
|
||||
if start_frame > 0:
|
||||
audio = F.pad(audio, (start_frame * self.vocoder.vocoder_hop_size, 0))
|
||||
return audio
|
226
server/voice_changer/DDSP_SVC/models/diffusion/onnx_export.py
Normal file
226
server/voice_changer/DDSP_SVC/models/diffusion/onnx_export.py
Normal file
@ -0,0 +1,226 @@
|
||||
from diffusion_onnx import GaussianDiffusion
|
||||
import os
|
||||
import yaml
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import numpy as np
|
||||
from wavenet import WaveNet
|
||||
import torch.nn.functional as F
|
||||
import diffusion
|
||||
|
||||
class DotDict(dict):
|
||||
def __getattr__(*args):
|
||||
val = dict.get(*args)
|
||||
return DotDict(val) if type(val) is dict else val
|
||||
|
||||
__setattr__ = dict.__setitem__
|
||||
__delattr__ = dict.__delitem__
|
||||
|
||||
|
||||
def load_model_vocoder(
|
||||
model_path,
|
||||
device='cpu'):
|
||||
config_file = os.path.join(os.path.split(model_path)[0], 'config.yaml')
|
||||
with open(config_file, "r") as config:
|
||||
args = yaml.safe_load(config)
|
||||
args = DotDict(args)
|
||||
|
||||
# load model
|
||||
model = Unit2Mel(
|
||||
args.data.encoder_out_channels,
|
||||
args.model.n_spk,
|
||||
args.model.use_pitch_aug,
|
||||
128,
|
||||
args.model.n_layers,
|
||||
args.model.n_chans,
|
||||
args.model.n_hidden)
|
||||
|
||||
print(' [Loading] ' + model_path)
|
||||
ckpt = torch.load(model_path, map_location=torch.device(device))
|
||||
model.to(device)
|
||||
model.load_state_dict(ckpt['model'])
|
||||
model.eval()
|
||||
return model, args
|
||||
|
||||
|
||||
class Unit2Mel(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
input_channel,
|
||||
n_spk,
|
||||
use_pitch_aug=False,
|
||||
out_dims=128,
|
||||
n_layers=20,
|
||||
n_chans=384,
|
||||
n_hidden=256):
|
||||
super().__init__()
|
||||
self.unit_embed = nn.Linear(input_channel, n_hidden)
|
||||
self.f0_embed = nn.Linear(1, n_hidden)
|
||||
self.volume_embed = nn.Linear(1, n_hidden)
|
||||
if use_pitch_aug:
|
||||
self.aug_shift_embed = nn.Linear(1, n_hidden, bias=False)
|
||||
else:
|
||||
self.aug_shift_embed = None
|
||||
self.n_spk = n_spk
|
||||
if n_spk is not None and n_spk > 1:
|
||||
self.spk_embed = nn.Embedding(n_spk, n_hidden)
|
||||
|
||||
# diffusion
|
||||
self.decoder = GaussianDiffusion(out_dims, n_layers, n_chans, n_hidden)
|
||||
self.hidden_size = n_hidden
|
||||
self.speaker_map = torch.zeros((self.n_spk,1,1,n_hidden))
|
||||
|
||||
|
||||
|
||||
def forward(self, units, mel2ph, f0, volume, g = None):
|
||||
|
||||
'''
|
||||
input:
|
||||
B x n_frames x n_unit
|
||||
return:
|
||||
dict of B x n_frames x feat
|
||||
'''
|
||||
|
||||
decoder_inp = F.pad(units, [0, 0, 1, 0])
|
||||
mel2ph_ = mel2ph.unsqueeze(2).repeat([1, 1, units.shape[-1]])
|
||||
units = torch.gather(decoder_inp, 1, mel2ph_) # [B, T, H]
|
||||
|
||||
x = self.unit_embed(units) + self.f0_embed((1 + f0.unsqueeze(-1) / 700).log()) + self.volume_embed(volume.unsqueeze(-1))
|
||||
|
||||
if self.n_spk is not None and self.n_spk > 1: # [N, S] * [S, B, 1, H]
|
||||
g = g.reshape((g.shape[0], g.shape[1], 1, 1, 1)) # [N, S, B, 1, 1]
|
||||
g = g * self.speaker_map # [N, S, B, 1, H]
|
||||
g = torch.sum(g, dim=1) # [N, 1, B, 1, H]
|
||||
g = g.transpose(0, -1).transpose(0, -2).squeeze(0) # [B, H, N]
|
||||
x = x.transpose(1, 2) + g
|
||||
return x
|
||||
else:
|
||||
return x.transpose(1, 2)
|
||||
|
||||
|
||||
def init_spkembed(self, units, f0, volume, spk_id = None, spk_mix_dict = None, aug_shift = None,
|
||||
gt_spec=None, infer=True, infer_speedup=10, method='dpm-solver', k_step=300, use_tqdm=True):
|
||||
|
||||
'''
|
||||
input:
|
||||
B x n_frames x n_unit
|
||||
return:
|
||||
dict of B x n_frames x feat
|
||||
'''
|
||||
x = self.unit_embed(units) + self.f0_embed((1+ f0 / 700).log()) + self.volume_embed(volume)
|
||||
if self.n_spk is not None and self.n_spk > 1:
|
||||
if spk_mix_dict is not None:
|
||||
spk_embed_mix = torch.zeros((1,1,self.hidden_size))
|
||||
for k, v in spk_mix_dict.items():
|
||||
spk_id_torch = torch.LongTensor(np.array([[k]])).to(units.device)
|
||||
spk_embeddd = self.spk_embed(spk_id_torch)
|
||||
self.speaker_map[k] = spk_embeddd
|
||||
spk_embed_mix = spk_embed_mix + v * spk_embeddd
|
||||
x = x + spk_embed_mix
|
||||
else:
|
||||
x = x + self.spk_embed(spk_id - 1)
|
||||
self.speaker_map = self.speaker_map.unsqueeze(0)
|
||||
self.speaker_map = self.speaker_map.detach()
|
||||
return x.transpose(1, 2)
|
||||
|
||||
def OnnxExport(self, project_name=None, init_noise=None, export_encoder=True, export_denoise=True, export_pred=True, export_after=True):
|
||||
hubert_hidden_size = 768
|
||||
n_frames = 100
|
||||
hubert = torch.randn((1, n_frames, hubert_hidden_size))
|
||||
mel2ph = torch.arange(end=n_frames).unsqueeze(0).long()
|
||||
f0 = torch.randn((1, n_frames))
|
||||
volume = torch.randn((1, n_frames))
|
||||
spk_mix = []
|
||||
spks = {}
|
||||
if self.n_spk is not None and self.n_spk > 1:
|
||||
for i in range(self.n_spk):
|
||||
spk_mix.append(1.0/float(self.n_spk))
|
||||
spks.update({i:1.0/float(self.n_spk)})
|
||||
spk_mix = torch.tensor(spk_mix)
|
||||
spk_mix = spk_mix.repeat(n_frames, 1)
|
||||
orgouttt = self.init_spkembed(hubert, f0.unsqueeze(-1), volume.unsqueeze(-1), spk_mix_dict=spks)
|
||||
outtt = self.forward(hubert, mel2ph, f0, volume, spk_mix)
|
||||
if export_encoder:
|
||||
torch.onnx.export(
|
||||
self,
|
||||
(hubert, mel2ph, f0, volume, spk_mix),
|
||||
f"{project_name}_encoder.onnx",
|
||||
input_names=["hubert", "mel2ph", "f0", "volume", "spk_mix"],
|
||||
output_names=["mel_pred"],
|
||||
dynamic_axes={
|
||||
"hubert": [1],
|
||||
"f0": [1],
|
||||
"volume": [1],
|
||||
"mel2ph": [1],
|
||||
"spk_mix": [0],
|
||||
},
|
||||
opset_version=16
|
||||
)
|
||||
|
||||
self.decoder.OnnxExport(project_name, init_noise=init_noise, export_denoise=export_denoise, export_pred=export_pred, export_after=export_after)
|
||||
|
||||
def ExportOnnx(self, project_name=None):
|
||||
hubert_hidden_size = 768
|
||||
n_frames = 100
|
||||
hubert = torch.randn((1, n_frames, hubert_hidden_size))
|
||||
mel2ph = torch.arange(end=n_frames).unsqueeze(0).long()
|
||||
f0 = torch.randn((1, n_frames))
|
||||
volume = torch.randn((1, n_frames))
|
||||
spk_mix = []
|
||||
spks = {}
|
||||
if self.n_spk is not None and self.n_spk > 1:
|
||||
for i in range(self.n_spk):
|
||||
spk_mix.append(1.0/float(self.n_spk))
|
||||
spks.update({i:1.0/float(self.n_spk)})
|
||||
spk_mix = torch.tensor(spk_mix)
|
||||
orgouttt = self.orgforward(hubert, f0.unsqueeze(-1), volume.unsqueeze(-1), spk_mix_dict=spks)
|
||||
outtt = self.forward(hubert, mel2ph, f0, volume, spk_mix)
|
||||
|
||||
torch.onnx.export(
|
||||
self,
|
||||
(hubert, mel2ph, f0, volume, spk_mix),
|
||||
f"{project_name}_encoder.onnx",
|
||||
input_names=["hubert", "mel2ph", "f0", "volume", "spk_mix"],
|
||||
output_names=["mel_pred"],
|
||||
dynamic_axes={
|
||||
"hubert": [1],
|
||||
"f0": [1],
|
||||
"volume": [1],
|
||||
"mel2ph": [1]
|
||||
},
|
||||
opset_version=16
|
||||
)
|
||||
|
||||
condition = torch.randn(1,self.decoder.n_hidden,n_frames)
|
||||
noise = torch.randn((1, 1, self.decoder.mel_bins, condition.shape[2]), dtype=torch.float32)
|
||||
pndm_speedup = torch.LongTensor([100])
|
||||
K_steps = torch.LongTensor([1000])
|
||||
self.decoder = torch.jit.script(self.decoder)
|
||||
self.decoder(condition, noise, pndm_speedup, K_steps)
|
||||
|
||||
torch.onnx.export(
|
||||
self.decoder,
|
||||
(condition, noise, pndm_speedup, K_steps),
|
||||
f"{project_name}_diffusion.onnx",
|
||||
input_names=["condition", "noise", "pndm_speedup", "K_steps"],
|
||||
output_names=["mel"],
|
||||
dynamic_axes={
|
||||
"condition": [2],
|
||||
"noise": [3],
|
||||
},
|
||||
opset_version=16
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
project_name = "dddsp"
|
||||
model_path = f'{project_name}/model_500000.pt'
|
||||
|
||||
model, _ = load_model_vocoder(model_path)
|
||||
|
||||
# 分开Diffusion导出(需要使用MoeSS/MoeVoiceStudio或者自己编写Pndm/Dpm采样)
|
||||
model.OnnxExport(project_name, export_encoder=True, export_denoise=True, export_pred=True, export_after=True)
|
||||
|
||||
# 合并Diffusion导出(Encoder和Diffusion分开,直接将Encoder的结果和初始噪声输入Diffusion即可)
|
||||
# model.ExportOnnx(project_name)
|
||||
|
194
server/voice_changer/DDSP_SVC/models/diffusion/solver.py
Normal file
194
server/voice_changer/DDSP_SVC/models/diffusion/solver.py
Normal file
@ -0,0 +1,194 @@
|
||||
import os
|
||||
import time
|
||||
import numpy as np
|
||||
import torch
|
||||
import librosa
|
||||
from logger.saver import Saver
|
||||
from logger import utils
|
||||
from torch import autocast
|
||||
from torch.cuda.amp import GradScaler
|
||||
|
||||
def test(args, model, vocoder, loader_test, saver):
|
||||
print(' [*] testing...')
|
||||
model.eval()
|
||||
|
||||
# losses
|
||||
test_loss = 0.
|
||||
|
||||
# intialization
|
||||
num_batches = len(loader_test)
|
||||
rtf_all = []
|
||||
|
||||
# run
|
||||
with torch.no_grad():
|
||||
for bidx, data in enumerate(loader_test):
|
||||
fn = data['name'][0]
|
||||
print('--------')
|
||||
print('{}/{} - {}'.format(bidx, num_batches, fn))
|
||||
|
||||
# unpack data
|
||||
for k in data.keys():
|
||||
if not k.startswith('name'):
|
||||
data[k] = data[k].to(args.device)
|
||||
print('>>', data['name'][0])
|
||||
|
||||
# forward
|
||||
st_time = time.time()
|
||||
mel = model(
|
||||
data['units'],
|
||||
data['f0'],
|
||||
data['volume'],
|
||||
data['spk_id'],
|
||||
gt_spec=None,
|
||||
infer=True,
|
||||
infer_speedup=args.infer.speedup,
|
||||
method=args.infer.method)
|
||||
signal = vocoder.infer(mel, data['f0'])
|
||||
ed_time = time.time()
|
||||
|
||||
# RTF
|
||||
run_time = ed_time - st_time
|
||||
song_time = signal.shape[-1] / args.data.sampling_rate
|
||||
rtf = run_time / song_time
|
||||
print('RTF: {} | {} / {}'.format(rtf, run_time, song_time))
|
||||
rtf_all.append(rtf)
|
||||
|
||||
# loss
|
||||
for i in range(args.train.batch_size):
|
||||
loss = model(
|
||||
data['units'],
|
||||
data['f0'],
|
||||
data['volume'],
|
||||
data['spk_id'],
|
||||
gt_spec=data['mel'],
|
||||
infer=False)
|
||||
test_loss += loss.item()
|
||||
|
||||
# log mel
|
||||
saver.log_spec(data['name'][0], data['mel'], mel)
|
||||
|
||||
# log audio
|
||||
path_audio = os.path.join(args.data.valid_path, 'audio', data['name_ext'][0])
|
||||
audio, sr = librosa.load(path_audio, sr=args.data.sampling_rate)
|
||||
if len(audio.shape) > 1:
|
||||
audio = librosa.to_mono(audio)
|
||||
audio = torch.from_numpy(audio).unsqueeze(0).to(signal)
|
||||
saver.log_audio({fn+'/gt.wav': audio, fn+'/pred.wav': signal})
|
||||
|
||||
# report
|
||||
test_loss /= args.train.batch_size
|
||||
test_loss /= num_batches
|
||||
|
||||
# check
|
||||
print(' [test_loss] test_loss:', test_loss)
|
||||
print(' Real Time Factor', np.mean(rtf_all))
|
||||
return test_loss
|
||||
|
||||
|
||||
def train(args, initial_global_step, model, optimizer, scheduler, vocoder, loader_train, loader_test):
|
||||
# saver
|
||||
saver = Saver(args, initial_global_step=initial_global_step)
|
||||
|
||||
# model size
|
||||
params_count = utils.get_network_paras_amount({'model': model})
|
||||
saver.log_info('--- model size ---')
|
||||
saver.log_info(params_count)
|
||||
|
||||
# run
|
||||
num_batches = len(loader_train)
|
||||
model.train()
|
||||
saver.log_info('======= start training =======')
|
||||
scaler = GradScaler()
|
||||
if args.train.amp_dtype == 'fp32':
|
||||
dtype = torch.float32
|
||||
elif args.train.amp_dtype == 'fp16':
|
||||
dtype = torch.float16
|
||||
elif args.train.amp_dtype == 'bf16':
|
||||
dtype = torch.bfloat16
|
||||
else:
|
||||
raise ValueError(' [x] Unknown amp_dtype: ' + args.train.amp_dtype)
|
||||
for epoch in range(args.train.epochs):
|
||||
for batch_idx, data in enumerate(loader_train):
|
||||
saver.global_step_increment()
|
||||
optimizer.zero_grad()
|
||||
|
||||
# unpack data
|
||||
for k in data.keys():
|
||||
if not k.startswith('name'):
|
||||
data[k] = data[k].to(args.device)
|
||||
|
||||
# forward
|
||||
if dtype == torch.float32:
|
||||
loss = model(data['units'].float(), data['f0'], data['volume'], data['spk_id'],
|
||||
aug_shift = data['aug_shift'], gt_spec=data['mel'].float(), infer=False)
|
||||
else:
|
||||
with autocast(device_type=args.device, dtype=dtype):
|
||||
loss = model(data['units'], data['f0'], data['volume'], data['spk_id'],
|
||||
aug_shift = data['aug_shift'], gt_spec=data['mel'], infer=False)
|
||||
|
||||
# handle nan loss
|
||||
if torch.isnan(loss):
|
||||
raise ValueError(' [x] nan loss ')
|
||||
else:
|
||||
# backpropagate
|
||||
if dtype == torch.float32:
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
else:
|
||||
scaler.scale(loss).backward()
|
||||
scaler.step(optimizer)
|
||||
scaler.update()
|
||||
scheduler.step()
|
||||
|
||||
# log loss
|
||||
if saver.global_step % args.train.interval_log == 0:
|
||||
current_lr = optimizer.param_groups[0]['lr']
|
||||
saver.log_info(
|
||||
'epoch: {} | {:3d}/{:3d} | {} | batch/s: {:.2f} | lr: {:.6} | loss: {:.3f} | time: {} | step: {}'.format(
|
||||
epoch,
|
||||
batch_idx,
|
||||
num_batches,
|
||||
args.env.expdir,
|
||||
args.train.interval_log/saver.get_interval_time(),
|
||||
current_lr,
|
||||
loss.item(),
|
||||
saver.get_total_time(),
|
||||
saver.global_step
|
||||
)
|
||||
)
|
||||
|
||||
saver.log_value({
|
||||
'train/loss': loss.item()
|
||||
})
|
||||
|
||||
saver.log_value({
|
||||
'train/lr': current_lr
|
||||
})
|
||||
|
||||
# validation
|
||||
if saver.global_step % args.train.interval_val == 0:
|
||||
optimizer_save = optimizer if args.train.save_opt else None
|
||||
|
||||
# save latest
|
||||
saver.save_model(model, optimizer_save, postfix=f'{saver.global_step}')
|
||||
last_val_step = saver.global_step - args.train.interval_val
|
||||
if last_val_step % args.train.interval_force_save != 0:
|
||||
saver.delete_model(postfix=f'{last_val_step}')
|
||||
|
||||
# run testing set
|
||||
test_loss = test(args, model, vocoder, loader_test, saver)
|
||||
|
||||
# log loss
|
||||
saver.log_info(
|
||||
' --- <validation> --- \nloss: {:.3f}. '.format(
|
||||
test_loss,
|
||||
)
|
||||
)
|
||||
|
||||
saver.log_value({
|
||||
'validation/loss': test_loss
|
||||
})
|
||||
|
||||
model.train()
|
||||
|
||||
|
731
server/voice_changer/DDSP_SVC/models/diffusion/uni_pc.py
Normal file
731
server/voice_changer/DDSP_SVC/models/diffusion/uni_pc.py
Normal file
@ -0,0 +1,731 @@
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import math
|
||||
|
||||
|
||||
class NoiseScheduleVP:
|
||||
def __init__(
|
||||
self,
|
||||
schedule='discrete',
|
||||
betas=None,
|
||||
alphas_cumprod=None,
|
||||
continuous_beta_0=0.1,
|
||||
continuous_beta_1=20.,
|
||||
dtype=torch.float32,
|
||||
):
|
||||
"""Create a wrapper class for the forward SDE (VP type).
|
||||
***
|
||||
Update: We support discrete-time diffusion models by implementing a picewise linear interpolation for log_alpha_t.
|
||||
We recommend to use schedule='discrete' for the discrete-time diffusion models, especially for high-resolution images.
|
||||
***
|
||||
The forward SDE ensures that the condition distribution q_{t|0}(x_t | x_0) = N ( alpha_t * x_0, sigma_t^2 * I ).
|
||||
We further define lambda_t = log(alpha_t) - log(sigma_t), which is the half-logSNR (described in the DPM-Solver paper).
|
||||
Therefore, we implement the functions for computing alpha_t, sigma_t and lambda_t. For t in [0, T], we have:
|
||||
log_alpha_t = self.marginal_log_mean_coeff(t)
|
||||
sigma_t = self.marginal_std(t)
|
||||
lambda_t = self.marginal_lambda(t)
|
||||
Moreover, as lambda(t) is an invertible function, we also support its inverse function:
|
||||
t = self.inverse_lambda(lambda_t)
|
||||
===============================================================
|
||||
We support both discrete-time DPMs (trained on n = 0, 1, ..., N-1) and continuous-time DPMs (trained on t in [t_0, T]).
|
||||
1. For discrete-time DPMs:
|
||||
For discrete-time DPMs trained on n = 0, 1, ..., N-1, we convert the discrete steps to continuous time steps by:
|
||||
t_i = (i + 1) / N
|
||||
e.g. for N = 1000, we have t_0 = 1e-3 and T = t_{N-1} = 1.
|
||||
We solve the corresponding diffusion ODE from time T = 1 to time t_0 = 1e-3.
|
||||
Args:
|
||||
betas: A `torch.Tensor`. The beta array for the discrete-time DPM. (See the original DDPM paper for details)
|
||||
alphas_cumprod: A `torch.Tensor`. The cumprod alphas for the discrete-time DPM. (See the original DDPM paper for details)
|
||||
Note that we always have alphas_cumprod = cumprod(1 - betas). Therefore, we only need to set one of `betas` and `alphas_cumprod`.
|
||||
**Important**: Please pay special attention for the args for `alphas_cumprod`:
|
||||
The `alphas_cumprod` is the \hat{alpha_n} arrays in the notations of DDPM. Specifically, DDPMs assume that
|
||||
q_{t_n | 0}(x_{t_n} | x_0) = N ( \sqrt{\hat{alpha_n}} * x_0, (1 - \hat{alpha_n}) * I ).
|
||||
Therefore, the notation \hat{alpha_n} is different from the notation alpha_t in DPM-Solver. In fact, we have
|
||||
alpha_{t_n} = \sqrt{\hat{alpha_n}},
|
||||
and
|
||||
log(alpha_{t_n}) = 0.5 * log(\hat{alpha_n}).
|
||||
2. For continuous-time DPMs:
|
||||
We support two types of VPSDEs: linear (DDPM) and cosine (improved-DDPM). The hyperparameters for the noise
|
||||
schedule are the default settings in DDPM and improved-DDPM:
|
||||
Args:
|
||||
beta_min: A `float` number. The smallest beta for the linear schedule.
|
||||
beta_max: A `float` number. The largest beta for the linear schedule.
|
||||
cosine_s: A `float` number. The hyperparameter in the cosine schedule.
|
||||
cosine_beta_max: A `float` number. The hyperparameter in the cosine schedule.
|
||||
T: A `float` number. The ending time of the forward process.
|
||||
===============================================================
|
||||
Args:
|
||||
schedule: A `str`. The noise schedule of the forward SDE. 'discrete' for discrete-time DPMs,
|
||||
'linear' or 'cosine' for continuous-time DPMs.
|
||||
Returns:
|
||||
A wrapper object of the forward SDE (VP type).
|
||||
|
||||
===============================================================
|
||||
Example:
|
||||
# For discrete-time DPMs, given betas (the beta array for n = 0, 1, ..., N - 1):
|
||||
>>> ns = NoiseScheduleVP('discrete', betas=betas)
|
||||
# For discrete-time DPMs, given alphas_cumprod (the \hat{alpha_n} array for n = 0, 1, ..., N - 1):
|
||||
>>> ns = NoiseScheduleVP('discrete', alphas_cumprod=alphas_cumprod)
|
||||
# For continuous-time DPMs (VPSDE), linear schedule:
|
||||
>>> ns = NoiseScheduleVP('linear', continuous_beta_0=0.1, continuous_beta_1=20.)
|
||||
"""
|
||||
|
||||
if schedule not in ['discrete', 'linear', 'cosine']:
|
||||
raise ValueError("Unsupported noise schedule {}. The schedule needs to be 'discrete' or 'linear' or 'cosine'".format(schedule))
|
||||
|
||||
self.schedule = schedule
|
||||
if schedule == 'discrete':
|
||||
if betas is not None:
|
||||
log_alphas = 0.5 * torch.log(1 - betas).cumsum(dim=0)
|
||||
else:
|
||||
assert alphas_cumprod is not None
|
||||
log_alphas = 0.5 * torch.log(alphas_cumprod)
|
||||
self.total_N = len(log_alphas)
|
||||
self.T = 1.
|
||||
self.t_array = torch.linspace(0., 1., self.total_N + 1)[1:].reshape((1, -1)).to(dtype=dtype)
|
||||
self.log_alpha_array = log_alphas.reshape((1, -1,)).to(dtype=dtype)
|
||||
else:
|
||||
self.total_N = 1000
|
||||
self.beta_0 = continuous_beta_0
|
||||
self.beta_1 = continuous_beta_1
|
||||
self.cosine_s = 0.008
|
||||
self.cosine_beta_max = 999.
|
||||
self.cosine_t_max = math.atan(self.cosine_beta_max * (1. + self.cosine_s) / math.pi) * 2. * (1. + self.cosine_s) / math.pi - self.cosine_s
|
||||
self.cosine_log_alpha_0 = math.log(math.cos(self.cosine_s / (1. + self.cosine_s) * math.pi / 2.))
|
||||
self.schedule = schedule
|
||||
if schedule == 'cosine':
|
||||
# For the cosine schedule, T = 1 will have numerical issues. So we manually set the ending time T.
|
||||
# Note that T = 0.9946 may be not the optimal setting. However, we find it works well.
|
||||
self.T = 0.9946
|
||||
else:
|
||||
self.T = 1.
|
||||
|
||||
def marginal_log_mean_coeff(self, t):
|
||||
"""
|
||||
Compute log(alpha_t) of a given continuous-time label t in [0, T].
|
||||
"""
|
||||
if self.schedule == 'discrete':
|
||||
return interpolate_fn(t.reshape((-1, 1)), self.t_array.to(t.device), self.log_alpha_array.to(t.device)).reshape((-1))
|
||||
elif self.schedule == 'linear':
|
||||
return -0.25 * t ** 2 * (self.beta_1 - self.beta_0) - 0.5 * t * self.beta_0
|
||||
elif self.schedule == 'cosine':
|
||||
log_alpha_fn = lambda s: torch.log(torch.cos((s + self.cosine_s) / (1. + self.cosine_s) * math.pi / 2.))
|
||||
log_alpha_t = log_alpha_fn(t) - self.cosine_log_alpha_0
|
||||
return log_alpha_t
|
||||
|
||||
def marginal_alpha(self, t):
|
||||
"""
|
||||
Compute alpha_t of a given continuous-time label t in [0, T].
|
||||
"""
|
||||
return torch.exp(self.marginal_log_mean_coeff(t))
|
||||
|
||||
def marginal_std(self, t):
|
||||
"""
|
||||
Compute sigma_t of a given continuous-time label t in [0, T].
|
||||
"""
|
||||
return torch.sqrt(1. - torch.exp(2. * self.marginal_log_mean_coeff(t)))
|
||||
|
||||
def marginal_lambda(self, t):
|
||||
"""
|
||||
Compute lambda_t = log(alpha_t) - log(sigma_t) of a given continuous-time label t in [0, T].
|
||||
"""
|
||||
log_mean_coeff = self.marginal_log_mean_coeff(t)
|
||||
log_std = 0.5 * torch.log(1. - torch.exp(2. * log_mean_coeff))
|
||||
return log_mean_coeff - log_std
|
||||
|
||||
def inverse_lambda(self, lamb):
|
||||
"""
|
||||
Compute the continuous-time label t in [0, T] of a given half-logSNR lambda_t.
|
||||
"""
|
||||
if self.schedule == 'linear':
|
||||
tmp = 2. * (self.beta_1 - self.beta_0) * torch.logaddexp(-2. * lamb, torch.zeros((1,)).to(lamb))
|
||||
Delta = self.beta_0**2 + tmp
|
||||
return tmp / (torch.sqrt(Delta) + self.beta_0) / (self.beta_1 - self.beta_0)
|
||||
elif self.schedule == 'discrete':
|
||||
log_alpha = -0.5 * torch.logaddexp(torch.zeros((1,)).to(lamb.device), -2. * lamb)
|
||||
t = interpolate_fn(log_alpha.reshape((-1, 1)), torch.flip(self.log_alpha_array.to(lamb.device), [1]), torch.flip(self.t_array.to(lamb.device), [1]))
|
||||
return t.reshape((-1,))
|
||||
else:
|
||||
log_alpha = -0.5 * torch.logaddexp(-2. * lamb, torch.zeros((1,)).to(lamb))
|
||||
t_fn = lambda log_alpha_t: torch.arccos(torch.exp(log_alpha_t + self.cosine_log_alpha_0)) * 2. * (1. + self.cosine_s) / math.pi - self.cosine_s
|
||||
t = t_fn(log_alpha)
|
||||
return t
|
||||
|
||||
|
||||
def model_wrapper(
|
||||
model,
|
||||
noise_schedule,
|
||||
model_type="noise",
|
||||
model_kwargs={},
|
||||
guidance_type="uncond",
|
||||
condition=None,
|
||||
unconditional_condition=None,
|
||||
guidance_scale=1.,
|
||||
classifier_fn=None,
|
||||
classifier_kwargs={},
|
||||
):
|
||||
"""Create a wrapper function for the noise prediction model.
|
||||
"""
|
||||
|
||||
def get_model_input_time(t_continuous):
|
||||
"""
|
||||
Convert the continuous-time `t_continuous` (in [epsilon, T]) to the model input time.
|
||||
For discrete-time DPMs, we convert `t_continuous` in [1 / N, 1] to `t_input` in [0, 1000 * (N - 1) / N].
|
||||
For continuous-time DPMs, we just use `t_continuous`.
|
||||
"""
|
||||
if noise_schedule.schedule == 'discrete':
|
||||
return (t_continuous - 1. / noise_schedule.total_N) * noise_schedule.total_N
|
||||
else:
|
||||
return t_continuous
|
||||
|
||||
def noise_pred_fn(x, t_continuous, cond=None):
|
||||
t_input = get_model_input_time(t_continuous)
|
||||
if cond is None:
|
||||
output = model(x, t_input, **model_kwargs)
|
||||
else:
|
||||
output = model(x, t_input, cond, **model_kwargs)
|
||||
if model_type == "noise":
|
||||
return output
|
||||
elif model_type == "x_start":
|
||||
alpha_t, sigma_t = noise_schedule.marginal_alpha(t_continuous), noise_schedule.marginal_std(t_continuous)
|
||||
return (x - alpha_t * output) / sigma_t
|
||||
elif model_type == "v":
|
||||
alpha_t, sigma_t = noise_schedule.marginal_alpha(t_continuous), noise_schedule.marginal_std(t_continuous)
|
||||
return alpha_t * output + sigma_t * x
|
||||
elif model_type == "score":
|
||||
sigma_t = noise_schedule.marginal_std(t_continuous)
|
||||
return -sigma_t * output
|
||||
|
||||
def cond_grad_fn(x, t_input):
|
||||
"""
|
||||
Compute the gradient of the classifier, i.e. nabla_{x} log p_t(cond | x_t).
|
||||
"""
|
||||
with torch.enable_grad():
|
||||
x_in = x.detach().requires_grad_(True)
|
||||
log_prob = classifier_fn(x_in, t_input, condition, **classifier_kwargs)
|
||||
return torch.autograd.grad(log_prob.sum(), x_in)[0]
|
||||
|
||||
def model_fn(x, t_continuous):
|
||||
"""
|
||||
The noise predicition model function that is used for DPM-Solver.
|
||||
"""
|
||||
if guidance_type == "uncond":
|
||||
return noise_pred_fn(x, t_continuous)
|
||||
elif guidance_type == "classifier":
|
||||
assert classifier_fn is not None
|
||||
t_input = get_model_input_time(t_continuous)
|
||||
cond_grad = cond_grad_fn(x, t_input)
|
||||
sigma_t = noise_schedule.marginal_std(t_continuous)
|
||||
noise = noise_pred_fn(x, t_continuous)
|
||||
return noise - guidance_scale * sigma_t * cond_grad
|
||||
elif guidance_type == "classifier-free":
|
||||
if guidance_scale == 1. or unconditional_condition is None:
|
||||
return noise_pred_fn(x, t_continuous, cond=condition)
|
||||
else:
|
||||
x_in = torch.cat([x] * 2)
|
||||
t_in = torch.cat([t_continuous] * 2)
|
||||
c_in = torch.cat([unconditional_condition, condition])
|
||||
noise_uncond, noise = noise_pred_fn(x_in, t_in, cond=c_in).chunk(2)
|
||||
return noise_uncond + guidance_scale * (noise - noise_uncond)
|
||||
|
||||
assert model_type in ["noise", "x_start", "v"]
|
||||
assert guidance_type in ["uncond", "classifier", "classifier-free"]
|
||||
return model_fn
|
||||
|
||||
|
||||
class UniPC:
|
||||
def __init__(
|
||||
self,
|
||||
model_fn,
|
||||
noise_schedule,
|
||||
algorithm_type="data_prediction",
|
||||
correcting_x0_fn=None,
|
||||
correcting_xt_fn=None,
|
||||
thresholding_max_val=1.,
|
||||
dynamic_thresholding_ratio=0.995,
|
||||
variant='bh1'
|
||||
):
|
||||
"""Construct a UniPC.
|
||||
|
||||
We support both data_prediction and noise_prediction.
|
||||
"""
|
||||
self.model = lambda x, t: model_fn(x, t.expand((x.shape[0])))
|
||||
self.noise_schedule = noise_schedule
|
||||
assert algorithm_type in ["data_prediction", "noise_prediction"]
|
||||
|
||||
if correcting_x0_fn == "dynamic_thresholding":
|
||||
self.correcting_x0_fn = self.dynamic_thresholding_fn
|
||||
else:
|
||||
self.correcting_x0_fn = correcting_x0_fn
|
||||
|
||||
self.correcting_xt_fn = correcting_xt_fn
|
||||
self.dynamic_thresholding_ratio = dynamic_thresholding_ratio
|
||||
self.thresholding_max_val = thresholding_max_val
|
||||
|
||||
self.variant = variant
|
||||
self.predict_x0 = algorithm_type == "data_prediction"
|
||||
|
||||
def dynamic_thresholding_fn(self, x0, t=None):
|
||||
"""
|
||||
The dynamic thresholding method.
|
||||
"""
|
||||
dims = x0.dim()
|
||||
p = self.dynamic_thresholding_ratio
|
||||
s = torch.quantile(torch.abs(x0).reshape((x0.shape[0], -1)), p, dim=1)
|
||||
s = expand_dims(torch.maximum(s, self.thresholding_max_val * torch.ones_like(s).to(s.device)), dims)
|
||||
x0 = torch.clamp(x0, -s, s) / s
|
||||
return x0
|
||||
|
||||
def noise_prediction_fn(self, x, t):
|
||||
"""
|
||||
Return the noise prediction model.
|
||||
"""
|
||||
return self.model(x, t)
|
||||
|
||||
def data_prediction_fn(self, x, t):
|
||||
"""
|
||||
Return the data prediction model (with corrector).
|
||||
"""
|
||||
noise = self.noise_prediction_fn(x, t)
|
||||
alpha_t, sigma_t = self.noise_schedule.marginal_alpha(t), self.noise_schedule.marginal_std(t)
|
||||
x0 = (x - sigma_t * noise) / alpha_t
|
||||
if self.correcting_x0_fn is not None:
|
||||
x0 = self.correcting_x0_fn(x0)
|
||||
return x0
|
||||
|
||||
def model_fn(self, x, t):
|
||||
"""
|
||||
Convert the model to the noise prediction model or the data prediction model.
|
||||
"""
|
||||
if self.predict_x0:
|
||||
return self.data_prediction_fn(x, t)
|
||||
else:
|
||||
return self.noise_prediction_fn(x, t)
|
||||
|
||||
def get_time_steps(self, skip_type, t_T, t_0, N, device):
|
||||
"""Compute the intermediate time steps for sampling.
|
||||
"""
|
||||
if skip_type == 'logSNR':
|
||||
lambda_T = self.noise_schedule.marginal_lambda(torch.tensor(t_T).to(device))
|
||||
lambda_0 = self.noise_schedule.marginal_lambda(torch.tensor(t_0).to(device))
|
||||
logSNR_steps = torch.linspace(lambda_T.cpu().item(), lambda_0.cpu().item(), N + 1).to(device)
|
||||
return self.noise_schedule.inverse_lambda(logSNR_steps)
|
||||
elif skip_type == 'time_uniform':
|
||||
return torch.linspace(t_T, t_0, N + 1).to(device)
|
||||
elif skip_type == 'time_quadratic':
|
||||
t_order = 2
|
||||
t = torch.linspace(t_T**(1. / t_order), t_0**(1. / t_order), N + 1).pow(t_order).to(device)
|
||||
return t
|
||||
else:
|
||||
raise ValueError("Unsupported skip_type {}, need to be 'logSNR' or 'time_uniform' or 'time_quadratic'".format(skip_type))
|
||||
|
||||
def get_orders_and_timesteps_for_singlestep_solver(self, steps, order, skip_type, t_T, t_0, device):
|
||||
"""
|
||||
Get the order of each step for sampling by the singlestep DPM-Solver.
|
||||
"""
|
||||
if order == 3:
|
||||
K = steps // 3 + 1
|
||||
if steps % 3 == 0:
|
||||
orders = [3,] * (K - 2) + [2, 1]
|
||||
elif steps % 3 == 1:
|
||||
orders = [3,] * (K - 1) + [1]
|
||||
else:
|
||||
orders = [3,] * (K - 1) + [2]
|
||||
elif order == 2:
|
||||
if steps % 2 == 0:
|
||||
K = steps // 2
|
||||
orders = [2,] * K
|
||||
else:
|
||||
K = steps // 2 + 1
|
||||
orders = [2,] * (K - 1) + [1]
|
||||
elif order == 1:
|
||||
K = steps
|
||||
orders = [1,] * steps
|
||||
else:
|
||||
raise ValueError("'order' must be '1' or '2' or '3'.")
|
||||
if skip_type == 'logSNR':
|
||||
# To reproduce the results in DPM-Solver paper
|
||||
timesteps_outer = self.get_time_steps(skip_type, t_T, t_0, K, device)
|
||||
else:
|
||||
timesteps_outer = self.get_time_steps(skip_type, t_T, t_0, steps, device)[torch.cumsum(torch.tensor([0,] + orders), 0).to(device)]
|
||||
return timesteps_outer, orders
|
||||
|
||||
def denoise_to_zero_fn(self, x, s):
|
||||
"""
|
||||
Denoise at the final step, which is equivalent to solve the ODE from lambda_s to infty by first-order discretization.
|
||||
"""
|
||||
return self.data_prediction_fn(x, s)
|
||||
|
||||
def multistep_uni_pc_update(self, x, model_prev_list, t_prev_list, t, order, **kwargs):
|
||||
if len(t.shape) == 0:
|
||||
t = t.view(-1)
|
||||
if 'bh' in self.variant:
|
||||
return self.multistep_uni_pc_bh_update(x, model_prev_list, t_prev_list, t, order, **kwargs)
|
||||
else:
|
||||
assert self.variant == 'vary_coeff'
|
||||
return self.multistep_uni_pc_vary_update(x, model_prev_list, t_prev_list, t, order, **kwargs)
|
||||
|
||||
def multistep_uni_pc_vary_update(self, x, model_prev_list, t_prev_list, t, order, use_corrector=True):
|
||||
#print(f'using unified predictor-corrector with order {order} (solver type: vary coeff)')
|
||||
ns = self.noise_schedule
|
||||
assert order <= len(model_prev_list)
|
||||
|
||||
# first compute rks
|
||||
t_prev_0 = t_prev_list[-1]
|
||||
lambda_prev_0 = ns.marginal_lambda(t_prev_0)
|
||||
lambda_t = ns.marginal_lambda(t)
|
||||
model_prev_0 = model_prev_list[-1]
|
||||
sigma_prev_0, sigma_t = ns.marginal_std(t_prev_0), ns.marginal_std(t)
|
||||
log_alpha_t = ns.marginal_log_mean_coeff(t)
|
||||
alpha_t = torch.exp(log_alpha_t)
|
||||
|
||||
h = lambda_t - lambda_prev_0
|
||||
|
||||
rks = []
|
||||
D1s = []
|
||||
for i in range(1, order):
|
||||
t_prev_i = t_prev_list[-(i + 1)]
|
||||
model_prev_i = model_prev_list[-(i + 1)]
|
||||
lambda_prev_i = ns.marginal_lambda(t_prev_i)
|
||||
rk = (lambda_prev_i - lambda_prev_0) / h
|
||||
rks.append(rk)
|
||||
D1s.append((model_prev_i - model_prev_0) / rk)
|
||||
|
||||
rks.append(1.)
|
||||
rks = torch.tensor(rks, device=x.device)
|
||||
|
||||
K = len(rks)
|
||||
# build C matrix
|
||||
C = []
|
||||
|
||||
col = torch.ones_like(rks)
|
||||
for k in range(1, K + 1):
|
||||
C.append(col)
|
||||
col = col * rks / (k + 1)
|
||||
C = torch.stack(C, dim=1)
|
||||
|
||||
if len(D1s) > 0:
|
||||
D1s = torch.stack(D1s, dim=1) # (B, K)
|
||||
C_inv_p = torch.linalg.inv(C[:-1, :-1])
|
||||
A_p = C_inv_p
|
||||
|
||||
if use_corrector:
|
||||
#print('using corrector')
|
||||
C_inv = torch.linalg.inv(C)
|
||||
A_c = C_inv
|
||||
|
||||
hh = -h if self.predict_x0 else h
|
||||
h_phi_1 = torch.expm1(hh)
|
||||
h_phi_ks = []
|
||||
factorial_k = 1
|
||||
h_phi_k = h_phi_1
|
||||
for k in range(1, K + 2):
|
||||
h_phi_ks.append(h_phi_k)
|
||||
h_phi_k = h_phi_k / hh - 1 / factorial_k
|
||||
factorial_k *= (k + 1)
|
||||
|
||||
model_t = None
|
||||
if self.predict_x0:
|
||||
x_t_ = (
|
||||
sigma_t / sigma_prev_0 * x
|
||||
- alpha_t * h_phi_1 * model_prev_0
|
||||
)
|
||||
# now predictor
|
||||
x_t = x_t_
|
||||
if len(D1s) > 0:
|
||||
# compute the residuals for predictor
|
||||
for k in range(K - 1):
|
||||
x_t = x_t - alpha_t * h_phi_ks[k + 1] * torch.einsum('bkchw,k->bchw', D1s, A_p[k])
|
||||
# now corrector
|
||||
if use_corrector:
|
||||
model_t = self.model_fn(x_t, t)
|
||||
D1_t = (model_t - model_prev_0)
|
||||
x_t = x_t_
|
||||
k = 0
|
||||
for k in range(K - 1):
|
||||
x_t = x_t - alpha_t * h_phi_ks[k + 1] * torch.einsum('bkchw,k->bchw', D1s, A_c[k][:-1])
|
||||
x_t = x_t - alpha_t * h_phi_ks[K] * (D1_t * A_c[k][-1])
|
||||
else:
|
||||
log_alpha_prev_0, log_alpha_t = ns.marginal_log_mean_coeff(t_prev_0), ns.marginal_log_mean_coeff(t)
|
||||
x_t_ = (
|
||||
(torch.exp(log_alpha_t - log_alpha_prev_0)) * x
|
||||
- (sigma_t * h_phi_1) * model_prev_0
|
||||
)
|
||||
# now predictor
|
||||
x_t = x_t_
|
||||
if len(D1s) > 0:
|
||||
# compute the residuals for predictor
|
||||
for k in range(K - 1):
|
||||
x_t = x_t - sigma_t * h_phi_ks[k + 1] * torch.einsum('bkchw,k->bchw', D1s, A_p[k])
|
||||
# now corrector
|
||||
if use_corrector:
|
||||
model_t = self.model_fn(x_t, t)
|
||||
D1_t = (model_t - model_prev_0)
|
||||
x_t = x_t_
|
||||
k = 0
|
||||
for k in range(K - 1):
|
||||
x_t = x_t - sigma_t * h_phi_ks[k + 1] * torch.einsum('bkchw,k->bchw', D1s, A_c[k][:-1])
|
||||
x_t = x_t - sigma_t * h_phi_ks[K] * (D1_t * A_c[k][-1])
|
||||
return x_t, model_t
|
||||
|
||||
def multistep_uni_pc_bh_update(self, x, model_prev_list, t_prev_list, t, order, x_t=None, use_corrector=True):
|
||||
#print(f'using unified predictor-corrector with order {order} (solver type: B(h))')
|
||||
ns = self.noise_schedule
|
||||
assert order <= len(model_prev_list)
|
||||
|
||||
# first compute rks
|
||||
t_prev_0 = t_prev_list[-1]
|
||||
lambda_prev_0 = ns.marginal_lambda(t_prev_0)
|
||||
lambda_t = ns.marginal_lambda(t)
|
||||
model_prev_0 = model_prev_list[-1]
|
||||
sigma_prev_0, sigma_t = ns.marginal_std(t_prev_0), ns.marginal_std(t)
|
||||
log_alpha_prev_0, log_alpha_t = ns.marginal_log_mean_coeff(t_prev_0), ns.marginal_log_mean_coeff(t)
|
||||
alpha_t = torch.exp(log_alpha_t)
|
||||
|
||||
h = lambda_t - lambda_prev_0
|
||||
|
||||
rks = []
|
||||
D1s = []
|
||||
for i in range(1, order):
|
||||
t_prev_i = t_prev_list[-(i + 1)]
|
||||
model_prev_i = model_prev_list[-(i + 1)]
|
||||
lambda_prev_i = ns.marginal_lambda(t_prev_i)
|
||||
rk = (lambda_prev_i - lambda_prev_0) / h
|
||||
rks.append(rk)
|
||||
D1s.append((model_prev_i - model_prev_0) / rk)
|
||||
|
||||
rks.append(1.)
|
||||
rks = torch.tensor(rks, device=x.device)
|
||||
|
||||
R = []
|
||||
b = []
|
||||
|
||||
hh = -h if self.predict_x0 else h
|
||||
h_phi_1 = torch.expm1(hh) # h\phi_1(h) = e^h - 1
|
||||
h_phi_k = h_phi_1 / hh - 1
|
||||
|
||||
factorial_i = 1
|
||||
|
||||
if self.variant == 'bh1':
|
||||
B_h = hh
|
||||
elif self.variant == 'bh2':
|
||||
B_h = torch.expm1(hh)
|
||||
else:
|
||||
raise NotImplementedError()
|
||||
|
||||
for i in range(1, order + 1):
|
||||
R.append(torch.pow(rks, i - 1))
|
||||
b.append(h_phi_k * factorial_i / B_h)
|
||||
factorial_i *= (i + 1)
|
||||
h_phi_k = h_phi_k / hh - 1 / factorial_i
|
||||
|
||||
R = torch.stack(R)
|
||||
b = torch.cat(b)
|
||||
|
||||
# now predictor
|
||||
use_predictor = len(D1s) > 0 and x_t is None
|
||||
if len(D1s) > 0:
|
||||
D1s = torch.stack(D1s, dim=1) # (B, K)
|
||||
if x_t is None:
|
||||
# for order 2, we use a simplified version
|
||||
if order == 2:
|
||||
rhos_p = torch.tensor([0.5], device=b.device)
|
||||
else:
|
||||
rhos_p = torch.linalg.solve(R[:-1, :-1], b[:-1])
|
||||
else:
|
||||
D1s = None
|
||||
|
||||
if use_corrector:
|
||||
#print('using corrector')
|
||||
# for order 1, we use a simplified version
|
||||
if order == 1:
|
||||
rhos_c = torch.tensor([0.5], device=b.device)
|
||||
else:
|
||||
rhos_c = torch.linalg.solve(R, b)
|
||||
|
||||
model_t = None
|
||||
if self.predict_x0:
|
||||
x_t_ = (
|
||||
sigma_t / sigma_prev_0 * x
|
||||
- alpha_t * h_phi_1 * model_prev_0
|
||||
)
|
||||
|
||||
if x_t is None:
|
||||
if use_predictor:
|
||||
pred_res = torch.einsum('k,bkchw->bchw', rhos_p, D1s)
|
||||
else:
|
||||
pred_res = 0
|
||||
x_t = x_t_ - alpha_t * B_h * pred_res
|
||||
|
||||
if use_corrector:
|
||||
model_t = self.model_fn(x_t, t)
|
||||
if D1s is not None:
|
||||
corr_res = torch.einsum('k,bkchw->bchw', rhos_c[:-1], D1s)
|
||||
else:
|
||||
corr_res = 0
|
||||
D1_t = (model_t - model_prev_0)
|
||||
x_t = x_t_ - alpha_t * B_h * (corr_res + rhos_c[-1] * D1_t)
|
||||
else:
|
||||
x_t_ = (
|
||||
torch.exp(log_alpha_t - log_alpha_prev_0) * x
|
||||
- sigma_t * h_phi_1 * model_prev_0
|
||||
)
|
||||
if x_t is None:
|
||||
if use_predictor:
|
||||
pred_res = torch.einsum('k,bkchw->bchw', rhos_p, D1s)
|
||||
else:
|
||||
pred_res = 0
|
||||
x_t = x_t_ - sigma_t * B_h * pred_res
|
||||
|
||||
if use_corrector:
|
||||
model_t = self.model_fn(x_t, t)
|
||||
if D1s is not None:
|
||||
corr_res = torch.einsum('k,bkchw->bchw', rhos_c[:-1], D1s)
|
||||
else:
|
||||
corr_res = 0
|
||||
D1_t = (model_t - model_prev_0)
|
||||
x_t = x_t_ - sigma_t * B_h * (corr_res + rhos_c[-1] * D1_t)
|
||||
return x_t, model_t
|
||||
|
||||
def sample(self, x, steps=20, t_start=None, t_end=None, order=2, skip_type='time_uniform',
|
||||
method='multistep', lower_order_final=True, denoise_to_zero=False, atol=0.0078, rtol=0.05, return_intermediate=False,
|
||||
):
|
||||
"""
|
||||
Compute the sample at time `t_end` by UniPC, given the initial `x` at time `t_start`.
|
||||
"""
|
||||
t_0 = 1. / self.noise_schedule.total_N if t_end is None else t_end
|
||||
t_T = self.noise_schedule.T if t_start is None else t_start
|
||||
assert t_0 > 0 and t_T > 0, "Time range needs to be greater than 0. For discrete-time DPMs, it needs to be in [1 / N, 1], where N is the length of betas array"
|
||||
if return_intermediate:
|
||||
assert method in ['multistep', 'singlestep', 'singlestep_fixed'], "Cannot use adaptive solver when saving intermediate values"
|
||||
if self.correcting_xt_fn is not None:
|
||||
assert method in ['multistep', 'singlestep', 'singlestep_fixed'], "Cannot use adaptive solver when correcting_xt_fn is not None"
|
||||
device = x.device
|
||||
intermediates = []
|
||||
with torch.no_grad():
|
||||
if method == 'multistep':
|
||||
assert steps >= order
|
||||
timesteps = self.get_time_steps(skip_type=skip_type, t_T=t_T, t_0=t_0, N=steps, device=device)
|
||||
assert timesteps.shape[0] - 1 == steps
|
||||
# Init the initial values.
|
||||
step = 0
|
||||
t = timesteps[step]
|
||||
t_prev_list = [t]
|
||||
model_prev_list = [self.model_fn(x, t)]
|
||||
if self.correcting_xt_fn is not None:
|
||||
x = self.correcting_xt_fn(x, t, step)
|
||||
if return_intermediate:
|
||||
intermediates.append(x)
|
||||
|
||||
# Init the first `order` values by lower order multistep UniPC.
|
||||
for step in range(1, order):
|
||||
t = timesteps[step]
|
||||
x, model_x = self.multistep_uni_pc_update(x, model_prev_list, t_prev_list, t, step, use_corrector=True)
|
||||
if model_x is None:
|
||||
model_x = self.model_fn(x, t)
|
||||
if self.correcting_xt_fn is not None:
|
||||
x = self.correcting_xt_fn(x, t, step)
|
||||
if return_intermediate:
|
||||
intermediates.append(x)
|
||||
t_prev_list.append(t)
|
||||
model_prev_list.append(model_x)
|
||||
|
||||
# Compute the remaining values by `order`-th order multistep DPM-Solver.
|
||||
for step in range(order, steps + 1):
|
||||
t = timesteps[step]
|
||||
if lower_order_final:
|
||||
step_order = min(order, steps + 1 - step)
|
||||
else:
|
||||
step_order = order
|
||||
if step == steps:
|
||||
#print('do not run corrector at the last step')
|
||||
use_corrector = False
|
||||
else:
|
||||
use_corrector = True
|
||||
x, model_x = self.multistep_uni_pc_update(x, model_prev_list, t_prev_list, t, step_order, use_corrector=use_corrector)
|
||||
if self.correcting_xt_fn is not None:
|
||||
x = self.correcting_xt_fn(x, t, step)
|
||||
if return_intermediate:
|
||||
intermediates.append(x)
|
||||
for i in range(order - 1):
|
||||
t_prev_list[i] = t_prev_list[i + 1]
|
||||
model_prev_list[i] = model_prev_list[i + 1]
|
||||
t_prev_list[-1] = t
|
||||
# We do not need to evaluate the final model value.
|
||||
if step < steps:
|
||||
if model_x is None:
|
||||
model_x = self.model_fn(x, t)
|
||||
model_prev_list[-1] = model_x
|
||||
else:
|
||||
raise ValueError("Got wrong method {}".format(method))
|
||||
|
||||
if denoise_to_zero:
|
||||
t = torch.ones((1,)).to(device) * t_0
|
||||
x = self.denoise_to_zero_fn(x, t)
|
||||
if self.correcting_xt_fn is not None:
|
||||
x = self.correcting_xt_fn(x, t, step + 1)
|
||||
if return_intermediate:
|
||||
intermediates.append(x)
|
||||
if return_intermediate:
|
||||
return x, intermediates
|
||||
else:
|
||||
return x
|
||||
|
||||
|
||||
#############################################################
|
||||
# other utility functions
|
||||
#############################################################
|
||||
|
||||
def interpolate_fn(x, xp, yp):
|
||||
"""
|
||||
A piecewise linear function y = f(x), using xp and yp as keypoints.
|
||||
We implement f(x) in a differentiable way (i.e. applicable for autograd).
|
||||
The function f(x) is well-defined for all x-axis. (For x beyond the bounds of xp, we use the outmost points of xp to define the linear function.)
|
||||
|
||||
Args:
|
||||
x: PyTorch tensor with shape [N, C], where N is the batch size, C is the number of channels (we use C = 1 for DPM-Solver).
|
||||
xp: PyTorch tensor with shape [C, K], where K is the number of keypoints.
|
||||
yp: PyTorch tensor with shape [C, K].
|
||||
Returns:
|
||||
The function values f(x), with shape [N, C].
|
||||
"""
|
||||
N, K = x.shape[0], xp.shape[1]
|
||||
all_x = torch.cat([x.unsqueeze(2), xp.unsqueeze(0).repeat((N, 1, 1))], dim=2)
|
||||
sorted_all_x, x_indices = torch.sort(all_x, dim=2)
|
||||
x_idx = torch.argmin(x_indices, dim=2)
|
||||
cand_start_idx = x_idx - 1
|
||||
start_idx = torch.where(
|
||||
torch.eq(x_idx, 0),
|
||||
torch.tensor(1, device=x.device),
|
||||
torch.where(
|
||||
torch.eq(x_idx, K), torch.tensor(K - 2, device=x.device), cand_start_idx,
|
||||
),
|
||||
)
|
||||
end_idx = torch.where(torch.eq(start_idx, cand_start_idx), start_idx + 2, start_idx + 1)
|
||||
start_x = torch.gather(sorted_all_x, dim=2, index=start_idx.unsqueeze(2)).squeeze(2)
|
||||
end_x = torch.gather(sorted_all_x, dim=2, index=end_idx.unsqueeze(2)).squeeze(2)
|
||||
start_idx2 = torch.where(
|
||||
torch.eq(x_idx, 0),
|
||||
torch.tensor(0, device=x.device),
|
||||
torch.where(
|
||||
torch.eq(x_idx, K), torch.tensor(K - 2, device=x.device), cand_start_idx,
|
||||
),
|
||||
)
|
||||
y_positions_expanded = yp.unsqueeze(0).expand(N, -1, -1)
|
||||
start_y = torch.gather(y_positions_expanded, dim=2, index=start_idx2.unsqueeze(2)).squeeze(2)
|
||||
end_y = torch.gather(y_positions_expanded, dim=2, index=(start_idx2 + 1).unsqueeze(2)).squeeze(2)
|
||||
cand = start_y + (x - start_x) * (end_y - start_y) / (end_x - start_x)
|
||||
return cand
|
||||
|
||||
|
||||
def expand_dims(v, dims):
|
||||
"""
|
||||
Expand the tensor `v` to the dim `dims`.
|
||||
|
||||
Args:
|
||||
`v`: a PyTorch tensor with shape [N].
|
||||
`dim`: a `int`.
|
||||
Returns:
|
||||
a PyTorch tensor with shape [N, 1, 1, ..., 1] and the total dimension is `dims`.
|
||||
"""
|
||||
return v[(...,) + (None,)*(dims - 1)]
|
77
server/voice_changer/DDSP_SVC/models/diffusion/unit2mel.py
Normal file
77
server/voice_changer/DDSP_SVC/models/diffusion/unit2mel.py
Normal file
@ -0,0 +1,77 @@
|
||||
import os
|
||||
import yaml
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import numpy as np
|
||||
from .diffusion import GaussianDiffusion
|
||||
from .wavenet import WaveNet
|
||||
from .vocoder import Vocoder
|
||||
|
||||
|
||||
class DotDict(dict):
|
||||
def __getattr__(*args):
|
||||
val = dict.get(*args)
|
||||
return DotDict(val) if type(val) is dict else val
|
||||
|
||||
__setattr__ = dict.__setitem__
|
||||
__delattr__ = dict.__delitem__
|
||||
|
||||
|
||||
def load_model_vocoder(model_path, device="cpu"):
|
||||
config_file = os.path.join(os.path.split(model_path)[0], "config.yaml")
|
||||
with open(config_file, "r") as config:
|
||||
args = yaml.safe_load(config)
|
||||
args = DotDict(args)
|
||||
|
||||
# load vocoder
|
||||
vocoder = Vocoder(args.vocoder.type, args.vocoder.ckpt, device=device)
|
||||
|
||||
# load model
|
||||
model = Unit2Mel(args.data.encoder_out_channels, args.model.n_spk, args.model.use_pitch_aug, vocoder.dimension, args.model.n_layers, args.model.n_chans, args.model.n_hidden)
|
||||
|
||||
print(" [Loading] " + model_path)
|
||||
ckpt = torch.load(model_path, map_location=torch.device(device))
|
||||
model.to(device)
|
||||
model.load_state_dict(ckpt["model"])
|
||||
model.eval()
|
||||
return model, vocoder, args
|
||||
|
||||
|
||||
class Unit2Mel(nn.Module):
|
||||
def __init__(self, input_channel, n_spk, use_pitch_aug=False, out_dims=128, n_layers=20, n_chans=384, n_hidden=256):
|
||||
super().__init__()
|
||||
self.unit_embed = nn.Linear(input_channel, n_hidden)
|
||||
self.f0_embed = nn.Linear(1, n_hidden)
|
||||
self.volume_embed = nn.Linear(1, n_hidden)
|
||||
if use_pitch_aug:
|
||||
self.aug_shift_embed = nn.Linear(1, n_hidden, bias=False)
|
||||
else:
|
||||
self.aug_shift_embed = None
|
||||
self.n_spk = n_spk
|
||||
if n_spk is not None and n_spk > 1:
|
||||
self.spk_embed = nn.Embedding(n_spk, n_hidden)
|
||||
|
||||
# diffusion
|
||||
self.decoder = GaussianDiffusion(WaveNet(out_dims, n_layers, n_chans, n_hidden), out_dims=out_dims)
|
||||
|
||||
def forward(self, units, f0, volume, spk_id=None, spk_mix_dict=None, aug_shift=None, gt_spec=None, infer=True, infer_speedup=10, method="dpm-solver", k_step=300, use_tqdm=True):
|
||||
"""
|
||||
input:
|
||||
B x n_frames x n_unit
|
||||
return:
|
||||
dict of B x n_frames x feat
|
||||
"""
|
||||
|
||||
x = self.unit_embed(units) + self.f0_embed((1 + f0 / 700).log()) + self.volume_embed(volume)
|
||||
if self.n_spk is not None and self.n_spk > 1:
|
||||
if spk_mix_dict is not None:
|
||||
for k, v in spk_mix_dict.items():
|
||||
spk_id_torch = torch.LongTensor(np.array([[k]])).to(units.device)
|
||||
x = x + v * self.spk_embed(spk_id_torch - 1)
|
||||
else:
|
||||
x = x + self.spk_embed(spk_id - 1)
|
||||
if self.aug_shift_embed is not None and aug_shift is not None:
|
||||
x = x + self.aug_shift_embed(aug_shift / 5)
|
||||
x = self.decoder(x, gt_spec=gt_spec, infer=infer, infer_speedup=infer_speedup, method=method, k_step=k_step, use_tqdm=use_tqdm)
|
||||
|
||||
return x
|
87
server/voice_changer/DDSP_SVC/models/diffusion/vocoder.py
Normal file
87
server/voice_changer/DDSP_SVC/models/diffusion/vocoder.py
Normal file
@ -0,0 +1,87 @@
|
||||
import torch
|
||||
from torchaudio.transforms import Resample
|
||||
from ..nsf_hifigan.nvSTFT import STFT
|
||||
from ..nsf_hifigan.models import load_model, load_config
|
||||
|
||||
|
||||
class Vocoder:
|
||||
def __init__(self, vocoder_type, vocoder_ckpt, device=None):
|
||||
if device is None:
|
||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
self.device = device
|
||||
|
||||
if vocoder_type == "nsf-hifigan":
|
||||
self.vocoder = NsfHifiGAN(vocoder_ckpt, device=device)
|
||||
elif vocoder_type == "nsf-hifigan-log10":
|
||||
self.vocoder = NsfHifiGANLog10(vocoder_ckpt, device=device)
|
||||
else:
|
||||
raise ValueError(f" [x] Unknown vocoder: {vocoder_type}")
|
||||
|
||||
self.resample_kernel = {}
|
||||
self.vocoder_sample_rate = self.vocoder.sample_rate()
|
||||
self.vocoder_hop_size = self.vocoder.hop_size()
|
||||
self.dimension = self.vocoder.dimension()
|
||||
|
||||
def extract(self, audio, sample_rate, keyshift=0):
|
||||
# resample
|
||||
if sample_rate == self.vocoder_sample_rate:
|
||||
audio_res = audio
|
||||
else:
|
||||
key_str = str(sample_rate)
|
||||
if key_str not in self.resample_kernel:
|
||||
self.resample_kernel[key_str] = Resample(sample_rate, self.vocoder_sample_rate, lowpass_filter_width=128).to(self.device)
|
||||
audio_res = self.resample_kernel[key_str](audio)
|
||||
|
||||
# extract
|
||||
mel = self.vocoder.extract(audio_res, keyshift=keyshift) # B, n_frames, bins
|
||||
return mel
|
||||
|
||||
def infer(self, mel, f0):
|
||||
f0 = f0[:, : mel.size(1), 0] # B, n_frames
|
||||
audio = self.vocoder(mel, f0)
|
||||
return audio
|
||||
|
||||
|
||||
class NsfHifiGAN(torch.nn.Module):
|
||||
def __init__(self, model_path, device=None):
|
||||
super().__init__()
|
||||
if device is None:
|
||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
self.device = device
|
||||
self.model_path = model_path
|
||||
self.model = None
|
||||
self.h = load_config(model_path)
|
||||
self.stft = STFT(self.h.sampling_rate, self.h.num_mels, self.h.n_fft, self.h.win_size, self.h.hop_size, self.h.fmin, self.h.fmax)
|
||||
|
||||
def sample_rate(self):
|
||||
return self.h.sampling_rate
|
||||
|
||||
def hop_size(self):
|
||||
return self.h.hop_size
|
||||
|
||||
def dimension(self):
|
||||
return self.h.num_mels
|
||||
|
||||
def extract(self, audio, keyshift=0):
|
||||
mel = self.stft.get_mel(audio, keyshift=keyshift).transpose(1, 2) # B, n_frames, bins
|
||||
return mel
|
||||
|
||||
def forward(self, mel, f0):
|
||||
if self.model is None:
|
||||
print("| Load HifiGAN: ", self.model_path)
|
||||
self.model, self.h = load_model(self.model_path, device=self.device)
|
||||
with torch.no_grad():
|
||||
c = mel.transpose(1, 2)
|
||||
audio = self.model(c, f0)
|
||||
return audio
|
||||
|
||||
|
||||
class NsfHifiGANLog10(NsfHifiGAN):
|
||||
def forward(self, mel, f0):
|
||||
if self.model is None:
|
||||
print("| Load HifiGAN: ", self.model_path)
|
||||
self.model, self.h = load_model(self.model_path, device=self.device)
|
||||
with torch.no_grad():
|
||||
c = 0.434294 * mel.transpose(1, 2)
|
||||
audio = self.model(c, f0)
|
||||
return audio
|
91
server/voice_changer/DDSP_SVC/models/diffusion/wavenet.py
Normal file
91
server/voice_changer/DDSP_SVC/models/diffusion/wavenet.py
Normal file
@ -0,0 +1,91 @@
|
||||
import math
|
||||
from math import sqrt
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from torch.nn import Mish
|
||||
|
||||
|
||||
class Conv1d(torch.nn.Conv1d):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
nn.init.kaiming_normal_(self.weight)
|
||||
|
||||
|
||||
class SinusoidalPosEmb(nn.Module):
|
||||
def __init__(self, dim):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
|
||||
def forward(self, x):
|
||||
device = x.device
|
||||
half_dim = self.dim // 2
|
||||
emb = math.log(10000) / (half_dim - 1)
|
||||
emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
|
||||
emb = x[:, None] * emb[None, :]
|
||||
emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
|
||||
return emb
|
||||
|
||||
|
||||
class ResidualBlock(nn.Module):
|
||||
def __init__(self, encoder_hidden, residual_channels, dilation):
|
||||
super().__init__()
|
||||
self.residual_channels = residual_channels
|
||||
self.dilated_conv = nn.Conv1d(residual_channels, 2 * residual_channels, kernel_size=3, padding=dilation, dilation=dilation)
|
||||
self.diffusion_projection = nn.Linear(residual_channels, residual_channels)
|
||||
self.conditioner_projection = nn.Conv1d(encoder_hidden, 2 * residual_channels, 1)
|
||||
self.output_projection = nn.Conv1d(residual_channels, 2 * residual_channels, 1)
|
||||
|
||||
def forward(self, x, conditioner, diffusion_step):
|
||||
diffusion_step = self.diffusion_projection(diffusion_step).unsqueeze(-1)
|
||||
conditioner = self.conditioner_projection(conditioner)
|
||||
y = x + diffusion_step
|
||||
|
||||
y = self.dilated_conv(y) + conditioner
|
||||
|
||||
# Using torch.split instead of torch.chunk to avoid using onnx::Slice
|
||||
gate, filter = torch.split(y, [self.residual_channels, self.residual_channels], dim=1)
|
||||
y = torch.sigmoid(gate) * torch.tanh(filter)
|
||||
|
||||
y = self.output_projection(y)
|
||||
|
||||
# Using torch.split instead of torch.chunk to avoid using onnx::Slice
|
||||
residual, skip = torch.split(y, [self.residual_channels, self.residual_channels], dim=1)
|
||||
return (x + residual) / math.sqrt(2.0), skip
|
||||
|
||||
|
||||
class WaveNet(nn.Module):
|
||||
def __init__(self, in_dims=128, n_layers=20, n_chans=384, n_hidden=256):
|
||||
super().__init__()
|
||||
self.input_projection = Conv1d(in_dims, n_chans, 1)
|
||||
self.diffusion_embedding = SinusoidalPosEmb(n_chans)
|
||||
self.mlp = nn.Sequential(nn.Linear(n_chans, n_chans * 4), Mish(), nn.Linear(n_chans * 4, n_chans))
|
||||
self.residual_layers = nn.ModuleList([ResidualBlock(encoder_hidden=n_hidden, residual_channels=n_chans, dilation=1) for i in range(n_layers)])
|
||||
self.skip_projection = Conv1d(n_chans, n_chans, 1)
|
||||
self.output_projection = Conv1d(n_chans, in_dims, 1)
|
||||
nn.init.zeros_(self.output_projection.weight)
|
||||
|
||||
def forward(self, spec, diffusion_step, cond):
|
||||
"""
|
||||
:param spec: [B, 1, M, T]
|
||||
:param diffusion_step: [B, 1]
|
||||
:param cond: [B, M, T]
|
||||
:return:
|
||||
"""
|
||||
x = spec.squeeze(1)
|
||||
x = self.input_projection(x) # [B, residual_channel, T]
|
||||
|
||||
x = F.relu(x)
|
||||
diffusion_step = self.diffusion_embedding(diffusion_step)
|
||||
diffusion_step = self.mlp(diffusion_step)
|
||||
skip = []
|
||||
for layer in self.residual_layers:
|
||||
x, skip_connection = layer(x, cond, diffusion_step)
|
||||
skip.append(skip_connection)
|
||||
|
||||
x = torch.sum(torch.stack(skip), dim=0) / sqrt(len(self.residual_layers))
|
||||
x = self.skip_projection(x)
|
||||
x = F.relu(x)
|
||||
x = self.output_projection(x) # [B, mel_bins, T]
|
||||
return x[:, None, :, :]
|
293
server/voice_changer/DDSP_SVC/models/encoder/hubert/model.py
Normal file
293
server/voice_changer/DDSP_SVC/models/encoder/hubert/model.py
Normal file
@ -0,0 +1,293 @@
|
||||
import copy
|
||||
from typing import Optional, Tuple
|
||||
import random
|
||||
|
||||
from sklearn.cluster import KMeans
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from torch.nn.modules.utils import consume_prefix_in_state_dict_if_present
|
||||
|
||||
URLS = {
|
||||
"hubert-discrete": "https://github.com/bshall/hubert/releases/download/v0.1/hubert-discrete-e9416457.pt",
|
||||
"hubert-soft": "https://github.com/bshall/hubert/releases/download/v0.1/hubert-soft-0d54a1f4.pt",
|
||||
"kmeans100": "https://github.com/bshall/hubert/releases/download/v0.1/kmeans100-50f36a95.pt",
|
||||
}
|
||||
|
||||
|
||||
class Hubert(nn.Module):
|
||||
def __init__(self, num_label_embeddings: int = 100, mask: bool = True):
|
||||
super().__init__()
|
||||
self._mask = mask
|
||||
self.feature_extractor = FeatureExtractor()
|
||||
self.feature_projection = FeatureProjection()
|
||||
self.positional_embedding = PositionalConvEmbedding()
|
||||
self.norm = nn.LayerNorm(768)
|
||||
self.dropout = nn.Dropout(0.1)
|
||||
self.encoder = TransformerEncoder(
|
||||
nn.TransformerEncoderLayer(
|
||||
768, 12, 3072, activation="gelu", batch_first=True
|
||||
),
|
||||
12,
|
||||
)
|
||||
self.proj = nn.Linear(768, 256)
|
||||
|
||||
self.masked_spec_embed = nn.Parameter(torch.FloatTensor(768).uniform_())
|
||||
self.label_embedding = nn.Embedding(num_label_embeddings, 256)
|
||||
|
||||
def mask(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
mask = None
|
||||
if self.training and self._mask:
|
||||
mask = _compute_mask((x.size(0), x.size(1)), 0.8, 10, x.device, 2)
|
||||
x[mask] = self.masked_spec_embed.to(x.dtype)
|
||||
return x, mask
|
||||
|
||||
def encode(
|
||||
self, x: torch.Tensor, layer: Optional[int] = None
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
x = self.feature_extractor(x)
|
||||
x = self.feature_projection(x.transpose(1, 2))
|
||||
x, mask = self.mask(x)
|
||||
x = x + self.positional_embedding(x)
|
||||
x = self.dropout(self.norm(x))
|
||||
x = self.encoder(x, output_layer=layer)
|
||||
return x, mask
|
||||
|
||||
def logits(self, x: torch.Tensor) -> torch.Tensor:
|
||||
logits = torch.cosine_similarity(
|
||||
x.unsqueeze(2),
|
||||
self.label_embedding.weight.unsqueeze(0).unsqueeze(0),
|
||||
dim=-1,
|
||||
)
|
||||
return logits / 0.1
|
||||
|
||||
def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
x, mask = self.encode(x)
|
||||
x = self.proj(x)
|
||||
logits = self.logits(x)
|
||||
return logits, mask
|
||||
|
||||
|
||||
class HubertSoft(Hubert):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@torch.inference_mode()
|
||||
def units(self, wav: torch.Tensor) -> torch.Tensor:
|
||||
wav = F.pad(wav, ((400 - 320) // 2, (400 - 320) // 2))
|
||||
x, _ = self.encode(wav)
|
||||
return self.proj(x)
|
||||
|
||||
|
||||
class HubertDiscrete(Hubert):
|
||||
def __init__(self, kmeans):
|
||||
super().__init__(504)
|
||||
self.kmeans = kmeans
|
||||
|
||||
@torch.inference_mode()
|
||||
def units(self, wav: torch.Tensor) -> torch.LongTensor:
|
||||
wav = F.pad(wav, ((400 - 320) // 2, (400 - 320) // 2))
|
||||
x, _ = self.encode(wav, layer=7)
|
||||
x = self.kmeans.predict(x.squeeze().cpu().numpy())
|
||||
return torch.tensor(x, dtype=torch.long, device=wav.device)
|
||||
|
||||
|
||||
class FeatureExtractor(nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.conv0 = nn.Conv1d(1, 512, 10, 5, bias=False)
|
||||
self.norm0 = nn.GroupNorm(512, 512)
|
||||
self.conv1 = nn.Conv1d(512, 512, 3, 2, bias=False)
|
||||
self.conv2 = nn.Conv1d(512, 512, 3, 2, bias=False)
|
||||
self.conv3 = nn.Conv1d(512, 512, 3, 2, bias=False)
|
||||
self.conv4 = nn.Conv1d(512, 512, 3, 2, bias=False)
|
||||
self.conv5 = nn.Conv1d(512, 512, 2, 2, bias=False)
|
||||
self.conv6 = nn.Conv1d(512, 512, 2, 2, bias=False)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
x = F.gelu(self.norm0(self.conv0(x)))
|
||||
x = F.gelu(self.conv1(x))
|
||||
x = F.gelu(self.conv2(x))
|
||||
x = F.gelu(self.conv3(x))
|
||||
x = F.gelu(self.conv4(x))
|
||||
x = F.gelu(self.conv5(x))
|
||||
x = F.gelu(self.conv6(x))
|
||||
return x
|
||||
|
||||
|
||||
class FeatureProjection(nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.norm = nn.LayerNorm(512)
|
||||
self.projection = nn.Linear(512, 768)
|
||||
self.dropout = nn.Dropout(0.1)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
x = self.norm(x)
|
||||
x = self.projection(x)
|
||||
x = self.dropout(x)
|
||||
return x
|
||||
|
||||
|
||||
class PositionalConvEmbedding(nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.conv = nn.Conv1d(
|
||||
768,
|
||||
768,
|
||||
kernel_size=128,
|
||||
padding=128 // 2,
|
||||
groups=16,
|
||||
)
|
||||
self.conv = nn.utils.weight_norm(self.conv, name="weight", dim=2)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
x = self.conv(x.transpose(1, 2))
|
||||
x = F.gelu(x[:, :, :-1])
|
||||
return x.transpose(1, 2)
|
||||
|
||||
|
||||
class TransformerEncoder(nn.Module):
|
||||
def __init__(
|
||||
self, encoder_layer: nn.TransformerEncoderLayer, num_layers: int
|
||||
) -> None:
|
||||
super(TransformerEncoder, self).__init__()
|
||||
self.layers = nn.ModuleList(
|
||||
[copy.deepcopy(encoder_layer) for _ in range(num_layers)]
|
||||
)
|
||||
self.num_layers = num_layers
|
||||
|
||||
def forward(
|
||||
self,
|
||||
src: torch.Tensor,
|
||||
mask: torch.Tensor = None,
|
||||
src_key_padding_mask: torch.Tensor = None,
|
||||
output_layer: Optional[int] = None,
|
||||
) -> torch.Tensor:
|
||||
output = src
|
||||
for layer in self.layers[:output_layer]:
|
||||
output = layer(
|
||||
output, src_mask=mask, src_key_padding_mask=src_key_padding_mask
|
||||
)
|
||||
return output
|
||||
|
||||
|
||||
def _compute_mask(
|
||||
shape: Tuple[int, int],
|
||||
mask_prob: float,
|
||||
mask_length: int,
|
||||
device: torch.device,
|
||||
min_masks: int = 0,
|
||||
) -> torch.Tensor:
|
||||
batch_size, sequence_length = shape
|
||||
|
||||
if mask_length < 1:
|
||||
raise ValueError("`mask_length` has to be bigger than 0.")
|
||||
|
||||
if mask_length > sequence_length:
|
||||
raise ValueError(
|
||||
f"`mask_length` has to be smaller than `sequence_length`, but got `mask_length`: {mask_length} and `sequence_length`: {sequence_length}`"
|
||||
)
|
||||
|
||||
# compute number of masked spans in batch
|
||||
num_masked_spans = int(mask_prob * sequence_length / mask_length + random.random())
|
||||
num_masked_spans = max(num_masked_spans, min_masks)
|
||||
|
||||
# make sure num masked indices <= sequence_length
|
||||
if num_masked_spans * mask_length > sequence_length:
|
||||
num_masked_spans = sequence_length // mask_length
|
||||
|
||||
# SpecAugment mask to fill
|
||||
mask = torch.zeros((batch_size, sequence_length), device=device, dtype=torch.bool)
|
||||
|
||||
# uniform distribution to sample from, make sure that offset samples are < sequence_length
|
||||
uniform_dist = torch.ones(
|
||||
(batch_size, sequence_length - (mask_length - 1)), device=device
|
||||
)
|
||||
|
||||
# get random indices to mask
|
||||
mask_indices = torch.multinomial(uniform_dist, num_masked_spans)
|
||||
|
||||
# expand masked indices to masked spans
|
||||
mask_indices = (
|
||||
mask_indices.unsqueeze(dim=-1)
|
||||
.expand((batch_size, num_masked_spans, mask_length))
|
||||
.reshape(batch_size, num_masked_spans * mask_length)
|
||||
)
|
||||
offsets = (
|
||||
torch.arange(mask_length, device=device)[None, None, :]
|
||||
.expand((batch_size, num_masked_spans, mask_length))
|
||||
.reshape(batch_size, num_masked_spans * mask_length)
|
||||
)
|
||||
mask_idxs = mask_indices + offsets
|
||||
|
||||
# scatter indices to mask
|
||||
mask = mask.scatter(1, mask_idxs, True)
|
||||
|
||||
return mask
|
||||
|
||||
|
||||
def hubert_discrete(
|
||||
pretrained: bool = True,
|
||||
progress: bool = True,
|
||||
) -> HubertDiscrete:
|
||||
r"""HuBERT-Discrete from `"A Comparison of Discrete and Soft Speech Units for Improved Voice Conversion"`.
|
||||
Args:
|
||||
pretrained (bool): load pretrained weights into the model
|
||||
progress (bool): show progress bar when downloading model
|
||||
"""
|
||||
kmeans = kmeans100(pretrained=pretrained, progress=progress)
|
||||
hubert = HubertDiscrete(kmeans)
|
||||
if pretrained:
|
||||
checkpoint = torch.hub.load_state_dict_from_url(
|
||||
URLS["hubert-discrete"], progress=progress
|
||||
)
|
||||
consume_prefix_in_state_dict_if_present(checkpoint, "module.")
|
||||
hubert.load_state_dict(checkpoint)
|
||||
hubert.eval()
|
||||
return hubert
|
||||
|
||||
|
||||
def hubert_soft(
|
||||
pretrained: bool = True,
|
||||
progress: bool = True,
|
||||
) -> HubertSoft:
|
||||
r"""HuBERT-Soft from `"A Comparison of Discrete and Soft Speech Units for Improved Voice Conversion"`.
|
||||
Args:
|
||||
pretrained (bool): load pretrained weights into the model
|
||||
progress (bool): show progress bar when downloading model
|
||||
"""
|
||||
hubert = HubertSoft()
|
||||
if pretrained:
|
||||
checkpoint = torch.hub.load_state_dict_from_url(
|
||||
URLS["hubert-soft"], progress=progress
|
||||
)
|
||||
consume_prefix_in_state_dict_if_present(checkpoint, "module.")
|
||||
hubert.load_state_dict(checkpoint)
|
||||
hubert.eval()
|
||||
return hubert
|
||||
|
||||
|
||||
def _kmeans(
|
||||
num_clusters: int, pretrained: bool = True, progress: bool = True
|
||||
) -> KMeans:
|
||||
kmeans = KMeans(num_clusters)
|
||||
if pretrained:
|
||||
checkpoint = torch.hub.load_state_dict_from_url(
|
||||
URLS[f"kmeans{num_clusters}"], progress=progress
|
||||
)
|
||||
kmeans.__dict__["n_features_in_"] = checkpoint["n_features_in_"]
|
||||
kmeans.__dict__["_n_threads"] = checkpoint["_n_threads"]
|
||||
kmeans.__dict__["cluster_centers_"] = checkpoint["cluster_centers_"].numpy()
|
||||
return kmeans
|
||||
|
||||
|
||||
def kmeans100(pretrained: bool = True, progress: bool = True) -> KMeans:
|
||||
r"""
|
||||
k-means checkpoint for HuBERT-Discrete with 100 clusters.
|
||||
Args:
|
||||
pretrained (bool): load pretrained weights into the model
|
||||
progress (bool): show progress bar when downloading model
|
||||
"""
|
||||
return _kmeans(100, pretrained, progress)
|
102
server/voice_changer/DDSP_SVC/models/enhancer.py
Normal file
102
server/voice_changer/DDSP_SVC/models/enhancer.py
Normal file
@ -0,0 +1,102 @@
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torchaudio.transforms import Resample
|
||||
from .nsf_hifigan.nvSTFT import STFT
|
||||
from .nsf_hifigan.models import load_model
|
||||
|
||||
|
||||
class Enhancer:
|
||||
def __init__(self, enhancer_type, enhancer_ckpt, device=None):
|
||||
if device is None:
|
||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
self.device = device
|
||||
|
||||
if enhancer_type == "nsf-hifigan":
|
||||
self.enhancer = NsfHifiGAN(enhancer_ckpt, device=self.device)
|
||||
else:
|
||||
raise ValueError(f" [x] Unknown enhancer: {enhancer_type}")
|
||||
|
||||
self.resample_kernel = {}
|
||||
self.enhancer_sample_rate = self.enhancer.sample_rate()
|
||||
self.enhancer_hop_size = self.enhancer.hop_size()
|
||||
|
||||
def enhance(self, audio, sample_rate, f0, hop_size, adaptive_key=0, silence_front=0): # 1, T # 1, n_frames, 1
|
||||
# enhancer start time
|
||||
start_frame = int(silence_front * sample_rate / hop_size)
|
||||
real_silence_front = start_frame * hop_size / sample_rate
|
||||
audio = audio[:, int(np.round(real_silence_front * sample_rate)) :]
|
||||
f0 = f0[:, start_frame:, :]
|
||||
|
||||
# adaptive parameters
|
||||
if adaptive_key == "auto":
|
||||
adaptive_key = 12 * np.log2(float(torch.max(f0) / 760))
|
||||
adaptive_key = max(0, np.ceil(adaptive_key))
|
||||
print("auto_adaptive_key: " + str(int(adaptive_key)))
|
||||
else:
|
||||
adaptive_key = float(adaptive_key)
|
||||
|
||||
adaptive_factor = 2 ** (-adaptive_key / 12)
|
||||
adaptive_sample_rate = 100 * int(np.round(self.enhancer_sample_rate / adaptive_factor / 100))
|
||||
real_factor = self.enhancer_sample_rate / adaptive_sample_rate
|
||||
|
||||
# resample the ddsp output
|
||||
if sample_rate == adaptive_sample_rate:
|
||||
audio_res = audio
|
||||
else:
|
||||
key_str = str(sample_rate) + str(adaptive_sample_rate)
|
||||
if key_str not in self.resample_kernel:
|
||||
self.resample_kernel[key_str] = Resample(sample_rate, adaptive_sample_rate, lowpass_filter_width=128).to(self.device)
|
||||
audio_res = self.resample_kernel[key_str](audio)
|
||||
|
||||
n_frames = int(audio_res.size(-1) // self.enhancer_hop_size + 1)
|
||||
|
||||
# resample f0
|
||||
if hop_size == self.enhancer_hop_size and sample_rate == self.enhancer_sample_rate and sample_rate == adaptive_sample_rate:
|
||||
f0_res = f0.squeeze(-1) # 1, n_frames
|
||||
else:
|
||||
f0_np = f0.squeeze(0).squeeze(-1).cpu().numpy()
|
||||
f0_np *= real_factor
|
||||
time_org = (hop_size / sample_rate) * np.arange(len(f0_np)) / real_factor
|
||||
time_frame = (self.enhancer_hop_size / self.enhancer_sample_rate) * np.arange(n_frames)
|
||||
f0_res = np.interp(time_frame, time_org, f0_np, left=f0_np[0], right=f0_np[-1])
|
||||
f0_res = torch.from_numpy(f0_res).unsqueeze(0).float().to(self.device) # 1, n_frames
|
||||
|
||||
# enhance
|
||||
enhanced_audio, enhancer_sample_rate = self.enhancer(audio_res, f0_res)
|
||||
|
||||
# resample the enhanced output
|
||||
if adaptive_sample_rate != enhancer_sample_rate:
|
||||
key_str = str(adaptive_sample_rate) + str(enhancer_sample_rate)
|
||||
if key_str not in self.resample_kernel:
|
||||
self.resample_kernel[key_str] = Resample(adaptive_sample_rate, enhancer_sample_rate, lowpass_filter_width=128).to(self.device)
|
||||
enhanced_audio = self.resample_kernel[key_str](enhanced_audio)
|
||||
|
||||
# pad the silence frames
|
||||
if start_frame > 0:
|
||||
enhanced_audio = F.pad(enhanced_audio, (int(np.round(enhancer_sample_rate * real_silence_front)), 0))
|
||||
|
||||
return enhanced_audio, enhancer_sample_rate
|
||||
|
||||
|
||||
class NsfHifiGAN(torch.nn.Module):
|
||||
def __init__(self, model_path, device=None):
|
||||
super().__init__()
|
||||
if device is None:
|
||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
self.device = device
|
||||
print("| Load HifiGAN: ", model_path)
|
||||
self.model, self.h = load_model(model_path, device=self.device)
|
||||
self.stft = STFT(self.h.sampling_rate, self.h.num_mels, self.h.n_fft, self.h.win_size, self.h.hop_size, self.h.fmin, self.h.fmax)
|
||||
|
||||
def sample_rate(self):
|
||||
return self.h.sampling_rate
|
||||
|
||||
def hop_size(self):
|
||||
return self.h.hop_size
|
||||
|
||||
def forward(self, audio, f0):
|
||||
with torch.no_grad():
|
||||
mel = self.stft.get_mel(audio)
|
||||
enhanced_audio = self.model(mel, f0[:, : mel.size(-1)])
|
||||
return enhanced_audio, self.h.sampling_rate
|
15
server/voice_changer/DDSP_SVC/models/nsf_hifigan/env.py
Normal file
15
server/voice_changer/DDSP_SVC/models/nsf_hifigan/env.py
Normal file
@ -0,0 +1,15 @@
|
||||
import os
|
||||
import shutil
|
||||
|
||||
|
||||
class AttrDict(dict):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super(AttrDict, self).__init__(*args, **kwargs)
|
||||
self.__dict__ = self
|
||||
|
||||
|
||||
def build_env(config, config_name, path):
|
||||
t_path = os.path.join(path, config_name)
|
||||
if config != t_path:
|
||||
os.makedirs(path, exist_ok=True)
|
||||
shutil.copyfile(config, os.path.join(path, config_name))
|
434
server/voice_changer/DDSP_SVC/models/nsf_hifigan/models.py
Normal file
434
server/voice_changer/DDSP_SVC/models/nsf_hifigan/models.py
Normal file
@ -0,0 +1,434 @@
|
||||
import os
|
||||
import json
|
||||
from .env import AttrDict
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import torch.nn as nn
|
||||
from torch.nn import Conv1d, ConvTranspose1d, AvgPool1d, Conv2d
|
||||
from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm
|
||||
from .utils import init_weights, get_padding
|
||||
|
||||
LRELU_SLOPE = 0.1
|
||||
|
||||
|
||||
def load_model(model_path, device='cuda'):
|
||||
h = load_config(model_path)
|
||||
|
||||
generator = Generator(h).to(device)
|
||||
|
||||
cp_dict = torch.load(model_path, map_location=device)
|
||||
generator.load_state_dict(cp_dict['generator'])
|
||||
generator.eval()
|
||||
generator.remove_weight_norm()
|
||||
del cp_dict
|
||||
return generator, h
|
||||
|
||||
def load_config(model_path):
|
||||
config_file = os.path.join(os.path.split(model_path)[0], 'config.json')
|
||||
with open(config_file) as f:
|
||||
data = f.read()
|
||||
|
||||
json_config = json.loads(data)
|
||||
h = AttrDict(json_config)
|
||||
return h
|
||||
|
||||
|
||||
class ResBlock1(torch.nn.Module):
|
||||
def __init__(self, h, channels, kernel_size=3, dilation=(1, 3, 5)):
|
||||
super(ResBlock1, self).__init__()
|
||||
self.h = h
|
||||
self.convs1 = nn.ModuleList([
|
||||
weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0],
|
||||
padding=get_padding(kernel_size, dilation[0]))),
|
||||
weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1],
|
||||
padding=get_padding(kernel_size, dilation[1]))),
|
||||
weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[2],
|
||||
padding=get_padding(kernel_size, dilation[2])))
|
||||
])
|
||||
self.convs1.apply(init_weights)
|
||||
|
||||
self.convs2 = nn.ModuleList([
|
||||
weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
|
||||
padding=get_padding(kernel_size, 1))),
|
||||
weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
|
||||
padding=get_padding(kernel_size, 1))),
|
||||
weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
|
||||
padding=get_padding(kernel_size, 1)))
|
||||
])
|
||||
self.convs2.apply(init_weights)
|
||||
|
||||
def forward(self, x):
|
||||
for c1, c2 in zip(self.convs1, self.convs2):
|
||||
xt = F.leaky_relu(x, LRELU_SLOPE)
|
||||
xt = c1(xt)
|
||||
xt = F.leaky_relu(xt, LRELU_SLOPE)
|
||||
xt = c2(xt)
|
||||
x = xt + x
|
||||
return x
|
||||
|
||||
def remove_weight_norm(self):
|
||||
for l in self.convs1:
|
||||
remove_weight_norm(l)
|
||||
for l in self.convs2:
|
||||
remove_weight_norm(l)
|
||||
|
||||
|
||||
class ResBlock2(torch.nn.Module):
|
||||
def __init__(self, h, channels, kernel_size=3, dilation=(1, 3)):
|
||||
super(ResBlock2, self).__init__()
|
||||
self.h = h
|
||||
self.convs = nn.ModuleList([
|
||||
weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0],
|
||||
padding=get_padding(kernel_size, dilation[0]))),
|
||||
weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1],
|
||||
padding=get_padding(kernel_size, dilation[1])))
|
||||
])
|
||||
self.convs.apply(init_weights)
|
||||
|
||||
def forward(self, x):
|
||||
for c in self.convs:
|
||||
xt = F.leaky_relu(x, LRELU_SLOPE)
|
||||
xt = c(xt)
|
||||
x = xt + x
|
||||
return x
|
||||
|
||||
def remove_weight_norm(self):
|
||||
for l in self.convs:
|
||||
remove_weight_norm(l)
|
||||
|
||||
|
||||
class SineGen(torch.nn.Module):
|
||||
""" Definition of sine generator
|
||||
SineGen(samp_rate, harmonic_num = 0,
|
||||
sine_amp = 0.1, noise_std = 0.003,
|
||||
voiced_threshold = 0,
|
||||
flag_for_pulse=False)
|
||||
samp_rate: sampling rate in Hz
|
||||
harmonic_num: number of harmonic overtones (default 0)
|
||||
sine_amp: amplitude of sine-wavefrom (default 0.1)
|
||||
noise_std: std of Gaussian noise (default 0.003)
|
||||
voiced_thoreshold: F0 threshold for U/V classification (default 0)
|
||||
flag_for_pulse: this SinGen is used inside PulseGen (default False)
|
||||
Note: when flag_for_pulse is True, the first time step of a voiced
|
||||
segment is always sin(np.pi) or cos(0)
|
||||
"""
|
||||
|
||||
def __init__(self, samp_rate, harmonic_num=0,
|
||||
sine_amp=0.1, noise_std=0.003,
|
||||
voiced_threshold=0):
|
||||
super(SineGen, self).__init__()
|
||||
self.sine_amp = sine_amp
|
||||
self.noise_std = noise_std
|
||||
self.harmonic_num = harmonic_num
|
||||
self.dim = self.harmonic_num + 1
|
||||
self.sampling_rate = samp_rate
|
||||
self.voiced_threshold = voiced_threshold
|
||||
|
||||
def _f02uv(self, f0):
|
||||
# generate uv signal
|
||||
uv = torch.ones_like(f0)
|
||||
uv = uv * (f0 > self.voiced_threshold)
|
||||
return uv
|
||||
|
||||
@torch.no_grad()
|
||||
def forward(self, f0, upp):
|
||||
""" sine_tensor, uv = forward(f0)
|
||||
input F0: tensor(batchsize=1, length, dim=1)
|
||||
f0 for unvoiced steps should be 0
|
||||
output sine_tensor: tensor(batchsize=1, length, dim)
|
||||
output uv: tensor(batchsize=1, length, 1)
|
||||
"""
|
||||
f0 = f0.unsqueeze(-1)
|
||||
fn = torch.multiply(f0, torch.arange(1, self.dim + 1, device=f0.device).reshape((1, 1, -1)))
|
||||
rad_values = (fn / self.sampling_rate) % 1 ###%1意味着n_har的乘积无法后处理优化
|
||||
rand_ini = torch.rand(fn.shape[0], fn.shape[2], device=fn.device)
|
||||
rand_ini[:, 0] = 0
|
||||
rad_values[:, 0, :] = rad_values[:, 0, :] + rand_ini
|
||||
is_half = rad_values.dtype is not torch.float32
|
||||
tmp_over_one = torch.cumsum(rad_values.double(), 1) # % 1 #####%1意味着后面的cumsum无法再优化
|
||||
if is_half:
|
||||
tmp_over_one = tmp_over_one.half()
|
||||
else:
|
||||
tmp_over_one = tmp_over_one.float()
|
||||
tmp_over_one *= upp
|
||||
tmp_over_one = F.interpolate(
|
||||
tmp_over_one.transpose(2, 1), scale_factor=upp,
|
||||
mode='linear', align_corners=True
|
||||
).transpose(2, 1)
|
||||
rad_values = F.interpolate(rad_values.transpose(2, 1), scale_factor=upp, mode='nearest').transpose(2, 1)
|
||||
tmp_over_one %= 1
|
||||
tmp_over_one_idx = (tmp_over_one[:, 1:, :] - tmp_over_one[:, :-1, :]) < 0
|
||||
cumsum_shift = torch.zeros_like(rad_values)
|
||||
cumsum_shift[:, 1:, :] = tmp_over_one_idx * -1.0
|
||||
rad_values = rad_values.double()
|
||||
cumsum_shift = cumsum_shift.double()
|
||||
sine_waves = torch.sin(torch.cumsum(rad_values + cumsum_shift, dim=1) * 2 * np.pi)
|
||||
if is_half:
|
||||
sine_waves = sine_waves.half()
|
||||
else:
|
||||
sine_waves = sine_waves.float()
|
||||
sine_waves = sine_waves * self.sine_amp
|
||||
return sine_waves
|
||||
|
||||
|
||||
class SourceModuleHnNSF(torch.nn.Module):
|
||||
""" SourceModule for hn-nsf
|
||||
SourceModule(sampling_rate, harmonic_num=0, sine_amp=0.1,
|
||||
add_noise_std=0.003, voiced_threshod=0)
|
||||
sampling_rate: sampling_rate in Hz
|
||||
harmonic_num: number of harmonic above F0 (default: 0)
|
||||
sine_amp: amplitude of sine source signal (default: 0.1)
|
||||
add_noise_std: std of additive Gaussian noise (default: 0.003)
|
||||
note that amplitude of noise in unvoiced is decided
|
||||
by sine_amp
|
||||
voiced_threshold: threhold to set U/V given F0 (default: 0)
|
||||
Sine_source, noise_source = SourceModuleHnNSF(F0_sampled)
|
||||
F0_sampled (batchsize, length, 1)
|
||||
Sine_source (batchsize, length, 1)
|
||||
noise_source (batchsize, length 1)
|
||||
uv (batchsize, length, 1)
|
||||
"""
|
||||
|
||||
def __init__(self, sampling_rate, harmonic_num=0, sine_amp=0.1,
|
||||
add_noise_std=0.003, voiced_threshod=0):
|
||||
super(SourceModuleHnNSF, self).__init__()
|
||||
|
||||
self.sine_amp = sine_amp
|
||||
self.noise_std = add_noise_std
|
||||
|
||||
# to produce sine waveforms
|
||||
self.l_sin_gen = SineGen(sampling_rate, harmonic_num,
|
||||
sine_amp, add_noise_std, voiced_threshod)
|
||||
|
||||
# to merge source harmonics into a single excitation
|
||||
self.l_linear = torch.nn.Linear(harmonic_num + 1, 1)
|
||||
self.l_tanh = torch.nn.Tanh()
|
||||
|
||||
def forward(self, x, upp):
|
||||
sine_wavs = self.l_sin_gen(x, upp)
|
||||
sine_merge = self.l_tanh(self.l_linear(sine_wavs))
|
||||
return sine_merge
|
||||
|
||||
|
||||
class Generator(torch.nn.Module):
|
||||
def __init__(self, h):
|
||||
super(Generator, self).__init__()
|
||||
self.h = h
|
||||
self.num_kernels = len(h.resblock_kernel_sizes)
|
||||
self.num_upsamples = len(h.upsample_rates)
|
||||
self.m_source = SourceModuleHnNSF(
|
||||
sampling_rate=h.sampling_rate,
|
||||
harmonic_num=8
|
||||
)
|
||||
self.noise_convs = nn.ModuleList()
|
||||
self.conv_pre = weight_norm(Conv1d(h.num_mels, h.upsample_initial_channel, 7, 1, padding=3))
|
||||
resblock = ResBlock1 if h.resblock == '1' else ResBlock2
|
||||
|
||||
self.ups = nn.ModuleList()
|
||||
for i, (u, k) in enumerate(zip(h.upsample_rates, h.upsample_kernel_sizes)):
|
||||
c_cur = h.upsample_initial_channel // (2 ** (i + 1))
|
||||
self.ups.append(weight_norm(
|
||||
ConvTranspose1d(h.upsample_initial_channel // (2 ** i), h.upsample_initial_channel // (2 ** (i + 1)),
|
||||
k, u, padding=(k - u) // 2)))
|
||||
if i + 1 < len(h.upsample_rates): #
|
||||
stride_f0 = int(np.prod(h.upsample_rates[i + 1:]))
|
||||
self.noise_convs.append(Conv1d(
|
||||
1, c_cur, kernel_size=stride_f0 * 2, stride=stride_f0, padding=stride_f0 // 2))
|
||||
else:
|
||||
self.noise_convs.append(Conv1d(1, c_cur, kernel_size=1))
|
||||
self.resblocks = nn.ModuleList()
|
||||
ch = h.upsample_initial_channel
|
||||
for i in range(len(self.ups)):
|
||||
ch //= 2
|
||||
for j, (k, d) in enumerate(zip(h.resblock_kernel_sizes, h.resblock_dilation_sizes)):
|
||||
self.resblocks.append(resblock(h, ch, k, d))
|
||||
|
||||
self.conv_post = weight_norm(Conv1d(ch, 1, 7, 1, padding=3))
|
||||
self.ups.apply(init_weights)
|
||||
self.conv_post.apply(init_weights)
|
||||
self.upp = int(np.prod(h.upsample_rates))
|
||||
|
||||
def forward(self, x, f0):
|
||||
har_source = self.m_source(f0, self.upp).transpose(1, 2)
|
||||
x = self.conv_pre(x)
|
||||
for i in range(self.num_upsamples):
|
||||
x = F.leaky_relu(x, LRELU_SLOPE)
|
||||
x = self.ups[i](x)
|
||||
x_source = self.noise_convs[i](har_source)
|
||||
x = x + x_source
|
||||
xs = None
|
||||
for j in range(self.num_kernels):
|
||||
if xs is None:
|
||||
xs = self.resblocks[i * self.num_kernels + j](x)
|
||||
else:
|
||||
xs += self.resblocks[i * self.num_kernels + j](x)
|
||||
x = xs / self.num_kernels
|
||||
x = F.leaky_relu(x)
|
||||
x = self.conv_post(x)
|
||||
x = torch.tanh(x)
|
||||
|
||||
return x
|
||||
|
||||
def remove_weight_norm(self):
|
||||
print('Removing weight norm...')
|
||||
for l in self.ups:
|
||||
remove_weight_norm(l)
|
||||
for l in self.resblocks:
|
||||
l.remove_weight_norm()
|
||||
remove_weight_norm(self.conv_pre)
|
||||
remove_weight_norm(self.conv_post)
|
||||
|
||||
|
||||
class DiscriminatorP(torch.nn.Module):
|
||||
def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False):
|
||||
super(DiscriminatorP, self).__init__()
|
||||
self.period = period
|
||||
norm_f = weight_norm if use_spectral_norm == False else spectral_norm
|
||||
self.convs = nn.ModuleList([
|
||||
norm_f(Conv2d(1, 32, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
|
||||
norm_f(Conv2d(32, 128, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
|
||||
norm_f(Conv2d(128, 512, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
|
||||
norm_f(Conv2d(512, 1024, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
|
||||
norm_f(Conv2d(1024, 1024, (kernel_size, 1), 1, padding=(2, 0))),
|
||||
])
|
||||
self.conv_post = norm_f(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0)))
|
||||
|
||||
def forward(self, x):
|
||||
fmap = []
|
||||
|
||||
# 1d to 2d
|
||||
b, c, t = x.shape
|
||||
if t % self.period != 0: # pad first
|
||||
n_pad = self.period - (t % self.period)
|
||||
x = F.pad(x, (0, n_pad), "reflect")
|
||||
t = t + n_pad
|
||||
x = x.view(b, c, t // self.period, self.period)
|
||||
|
||||
for l in self.convs:
|
||||
x = l(x)
|
||||
x = F.leaky_relu(x, LRELU_SLOPE)
|
||||
fmap.append(x)
|
||||
x = self.conv_post(x)
|
||||
fmap.append(x)
|
||||
x = torch.flatten(x, 1, -1)
|
||||
|
||||
return x, fmap
|
||||
|
||||
|
||||
class MultiPeriodDiscriminator(torch.nn.Module):
|
||||
def __init__(self, periods=None):
|
||||
super(MultiPeriodDiscriminator, self).__init__()
|
||||
self.periods = periods if periods is not None else [2, 3, 5, 7, 11]
|
||||
self.discriminators = nn.ModuleList()
|
||||
for period in self.periods:
|
||||
self.discriminators.append(DiscriminatorP(period))
|
||||
|
||||
def forward(self, y, y_hat):
|
||||
y_d_rs = []
|
||||
y_d_gs = []
|
||||
fmap_rs = []
|
||||
fmap_gs = []
|
||||
for i, d in enumerate(self.discriminators):
|
||||
y_d_r, fmap_r = d(y)
|
||||
y_d_g, fmap_g = d(y_hat)
|
||||
y_d_rs.append(y_d_r)
|
||||
fmap_rs.append(fmap_r)
|
||||
y_d_gs.append(y_d_g)
|
||||
fmap_gs.append(fmap_g)
|
||||
|
||||
return y_d_rs, y_d_gs, fmap_rs, fmap_gs
|
||||
|
||||
|
||||
class DiscriminatorS(torch.nn.Module):
|
||||
def __init__(self, use_spectral_norm=False):
|
||||
super(DiscriminatorS, self).__init__()
|
||||
norm_f = weight_norm if use_spectral_norm == False else spectral_norm
|
||||
self.convs = nn.ModuleList([
|
||||
norm_f(Conv1d(1, 128, 15, 1, padding=7)),
|
||||
norm_f(Conv1d(128, 128, 41, 2, groups=4, padding=20)),
|
||||
norm_f(Conv1d(128, 256, 41, 2, groups=16, padding=20)),
|
||||
norm_f(Conv1d(256, 512, 41, 4, groups=16, padding=20)),
|
||||
norm_f(Conv1d(512, 1024, 41, 4, groups=16, padding=20)),
|
||||
norm_f(Conv1d(1024, 1024, 41, 1, groups=16, padding=20)),
|
||||
norm_f(Conv1d(1024, 1024, 5, 1, padding=2)),
|
||||
])
|
||||
self.conv_post = norm_f(Conv1d(1024, 1, 3, 1, padding=1))
|
||||
|
||||
def forward(self, x):
|
||||
fmap = []
|
||||
for l in self.convs:
|
||||
x = l(x)
|
||||
x = F.leaky_relu(x, LRELU_SLOPE)
|
||||
fmap.append(x)
|
||||
x = self.conv_post(x)
|
||||
fmap.append(x)
|
||||
x = torch.flatten(x, 1, -1)
|
||||
|
||||
return x, fmap
|
||||
|
||||
|
||||
class MultiScaleDiscriminator(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super(MultiScaleDiscriminator, self).__init__()
|
||||
self.discriminators = nn.ModuleList([
|
||||
DiscriminatorS(use_spectral_norm=True),
|
||||
DiscriminatorS(),
|
||||
DiscriminatorS(),
|
||||
])
|
||||
self.meanpools = nn.ModuleList([
|
||||
AvgPool1d(4, 2, padding=2),
|
||||
AvgPool1d(4, 2, padding=2)
|
||||
])
|
||||
|
||||
def forward(self, y, y_hat):
|
||||
y_d_rs = []
|
||||
y_d_gs = []
|
||||
fmap_rs = []
|
||||
fmap_gs = []
|
||||
for i, d in enumerate(self.discriminators):
|
||||
if i != 0:
|
||||
y = self.meanpools[i - 1](y)
|
||||
y_hat = self.meanpools[i - 1](y_hat)
|
||||
y_d_r, fmap_r = d(y)
|
||||
y_d_g, fmap_g = d(y_hat)
|
||||
y_d_rs.append(y_d_r)
|
||||
fmap_rs.append(fmap_r)
|
||||
y_d_gs.append(y_d_g)
|
||||
fmap_gs.append(fmap_g)
|
||||
|
||||
return y_d_rs, y_d_gs, fmap_rs, fmap_gs
|
||||
|
||||
|
||||
def feature_loss(fmap_r, fmap_g):
|
||||
loss = 0
|
||||
for dr, dg in zip(fmap_r, fmap_g):
|
||||
for rl, gl in zip(dr, dg):
|
||||
loss += torch.mean(torch.abs(rl - gl))
|
||||
|
||||
return loss * 2
|
||||
|
||||
|
||||
def discriminator_loss(disc_real_outputs, disc_generated_outputs):
|
||||
loss = 0
|
||||
r_losses = []
|
||||
g_losses = []
|
||||
for dr, dg in zip(disc_real_outputs, disc_generated_outputs):
|
||||
r_loss = torch.mean((1 - dr) ** 2)
|
||||
g_loss = torch.mean(dg ** 2)
|
||||
loss += (r_loss + g_loss)
|
||||
r_losses.append(r_loss.item())
|
||||
g_losses.append(g_loss.item())
|
||||
|
||||
return loss, r_losses, g_losses
|
||||
|
||||
|
||||
def generator_loss(disc_outputs):
|
||||
loss = 0
|
||||
gen_losses = []
|
||||
for dg in disc_outputs:
|
||||
l = torch.mean((1 - dg) ** 2)
|
||||
gen_losses.append(l)
|
||||
loss += l
|
||||
|
||||
return loss, gen_losses
|
129
server/voice_changer/DDSP_SVC/models/nsf_hifigan/nvSTFT.py
Normal file
129
server/voice_changer/DDSP_SVC/models/nsf_hifigan/nvSTFT.py
Normal file
@ -0,0 +1,129 @@
|
||||
import math
|
||||
import os
|
||||
os.environ["LRU_CACHE_CAPACITY"] = "3"
|
||||
import random
|
||||
import torch
|
||||
import torch.utils.data
|
||||
import numpy as np
|
||||
import librosa
|
||||
from librosa.util import normalize
|
||||
from librosa.filters import mel as librosa_mel_fn
|
||||
from scipy.io.wavfile import read
|
||||
import soundfile as sf
|
||||
import torch.nn.functional as F
|
||||
|
||||
def load_wav_to_torch(full_path, target_sr=None, return_empty_on_exception=False):
|
||||
sampling_rate = None
|
||||
try:
|
||||
data, sampling_rate = sf.read(full_path, always_2d=True)# than soundfile.
|
||||
except Exception as ex:
|
||||
print(f"'{full_path}' failed to load.\nException:")
|
||||
print(ex)
|
||||
if return_empty_on_exception:
|
||||
return [], sampling_rate or target_sr or 48000
|
||||
else:
|
||||
raise Exception(ex)
|
||||
|
||||
if len(data.shape) > 1:
|
||||
data = data[:, 0]
|
||||
assert len(data) > 2# check duration of audio file is > 2 samples (because otherwise the slice operation was on the wrong dimension)
|
||||
|
||||
if np.issubdtype(data.dtype, np.integer): # if audio data is type int
|
||||
max_mag = -np.iinfo(data.dtype).min # maximum magnitude = min possible value of intXX
|
||||
else: # if audio data is type fp32
|
||||
max_mag = max(np.amax(data), -np.amin(data))
|
||||
max_mag = (2**31)+1 if max_mag > (2**15) else ((2**15)+1 if max_mag > 1.01 else 1.0) # data should be either 16-bit INT, 32-bit INT or [-1 to 1] float32
|
||||
|
||||
data = torch.FloatTensor(data.astype(np.float32))/max_mag
|
||||
|
||||
if (torch.isinf(data) | torch.isnan(data)).any() and return_empty_on_exception:# resample will crash with inf/NaN inputs. return_empty_on_exception will return empty arr instead of except
|
||||
return [], sampling_rate or target_sr or 48000
|
||||
if target_sr is not None and sampling_rate != target_sr:
|
||||
data = torch.from_numpy(librosa.core.resample(data.numpy(), orig_sr=sampling_rate, target_sr=target_sr))
|
||||
sampling_rate = target_sr
|
||||
|
||||
return data, sampling_rate
|
||||
|
||||
def dynamic_range_compression(x, C=1, clip_val=1e-5):
|
||||
return np.log(np.clip(x, a_min=clip_val, a_max=None) * C)
|
||||
|
||||
def dynamic_range_decompression(x, C=1):
|
||||
return np.exp(x) / C
|
||||
|
||||
def dynamic_range_compression_torch(x, C=1, clip_val=1e-5):
|
||||
return torch.log(torch.clamp(x, min=clip_val) * C)
|
||||
|
||||
def dynamic_range_decompression_torch(x, C=1):
|
||||
return torch.exp(x) / C
|
||||
|
||||
class STFT():
|
||||
def __init__(self, sr=22050, n_mels=80, n_fft=1024, win_size=1024, hop_length=256, fmin=20, fmax=11025, clip_val=1e-5):
|
||||
self.target_sr = sr
|
||||
|
||||
self.n_mels = n_mels
|
||||
self.n_fft = n_fft
|
||||
self.win_size = win_size
|
||||
self.hop_length = hop_length
|
||||
self.fmin = fmin
|
||||
self.fmax = fmax
|
||||
self.clip_val = clip_val
|
||||
self.mel_basis = {}
|
||||
self.hann_window = {}
|
||||
|
||||
def get_mel(self, y, keyshift=0, speed=1, center=False):
|
||||
sampling_rate = self.target_sr
|
||||
n_mels = self.n_mels
|
||||
n_fft = self.n_fft
|
||||
win_size = self.win_size
|
||||
hop_length = self.hop_length
|
||||
fmin = self.fmin
|
||||
fmax = self.fmax
|
||||
clip_val = self.clip_val
|
||||
|
||||
factor = 2 ** (keyshift / 12)
|
||||
n_fft_new = int(np.round(n_fft * factor))
|
||||
win_size_new = int(np.round(win_size * factor))
|
||||
hop_length_new = int(np.round(hop_length * speed))
|
||||
|
||||
if torch.min(y) < -1.:
|
||||
print('min value is ', torch.min(y))
|
||||
if torch.max(y) > 1.:
|
||||
print('max value is ', torch.max(y))
|
||||
|
||||
mel_basis_key = str(fmax)+'_'+str(y.device)
|
||||
if mel_basis_key not in self.mel_basis:
|
||||
mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=n_mels, fmin=fmin, fmax=fmax)
|
||||
self.mel_basis[mel_basis_key] = torch.from_numpy(mel).float().to(y.device)
|
||||
|
||||
keyshift_key = str(keyshift)+'_'+str(y.device)
|
||||
if keyshift_key not in self.hann_window:
|
||||
self.hann_window[keyshift_key] = torch.hann_window(win_size_new).to(y.device)
|
||||
|
||||
pad_left = (win_size_new - hop_length_new) //2
|
||||
pad_right = max((win_size_new- hop_length_new + 1) //2, win_size_new - y.size(-1) - pad_left)
|
||||
if pad_right < y.size(-1):
|
||||
mode = 'reflect'
|
||||
else:
|
||||
mode = 'constant'
|
||||
y = torch.nn.functional.pad(y.unsqueeze(1), (pad_left, pad_right), mode = mode)
|
||||
y = y.squeeze(1)
|
||||
|
||||
spec = torch.stft(y, n_fft_new, hop_length=hop_length_new, win_length=win_size_new, window=self.hann_window[keyshift_key],
|
||||
center=center, pad_mode='reflect', normalized=False, onesided=True, return_complex=True)
|
||||
spec = torch.sqrt(spec.real.pow(2) + spec.imag.pow(2) + (1e-9))
|
||||
if keyshift != 0:
|
||||
size = n_fft // 2 + 1
|
||||
resize = spec.size(1)
|
||||
if resize < size:
|
||||
spec = F.pad(spec, (0, 0, 0, size-resize))
|
||||
spec = spec[:, :size, :] * win_size / win_size_new
|
||||
spec = torch.matmul(self.mel_basis[mel_basis_key], spec)
|
||||
spec = dynamic_range_compression_torch(spec, clip_val=clip_val)
|
||||
return spec
|
||||
|
||||
def __call__(self, audiopath):
|
||||
audio, sr = load_wav_to_torch(audiopath, target_sr=self.target_sr)
|
||||
spect = self.get_mel(audio.unsqueeze(0)).squeeze(0)
|
||||
return spect
|
||||
|
||||
stft = STFT()
|
68
server/voice_changer/DDSP_SVC/models/nsf_hifigan/utils.py
Normal file
68
server/voice_changer/DDSP_SVC/models/nsf_hifigan/utils.py
Normal file
@ -0,0 +1,68 @@
|
||||
import glob
|
||||
import os
|
||||
import matplotlib
|
||||
import torch
|
||||
from torch.nn.utils import weight_norm
|
||||
matplotlib.use("Agg")
|
||||
import matplotlib.pylab as plt
|
||||
|
||||
|
||||
def plot_spectrogram(spectrogram):
|
||||
fig, ax = plt.subplots(figsize=(10, 2))
|
||||
im = ax.imshow(spectrogram, aspect="auto", origin="lower",
|
||||
interpolation='none')
|
||||
plt.colorbar(im, ax=ax)
|
||||
|
||||
fig.canvas.draw()
|
||||
plt.close()
|
||||
|
||||
return fig
|
||||
|
||||
|
||||
def init_weights(m, mean=0.0, std=0.01):
|
||||
classname = m.__class__.__name__
|
||||
if classname.find("Conv") != -1:
|
||||
m.weight.data.normal_(mean, std)
|
||||
|
||||
|
||||
def apply_weight_norm(m):
|
||||
classname = m.__class__.__name__
|
||||
if classname.find("Conv") != -1:
|
||||
weight_norm(m)
|
||||
|
||||
|
||||
def get_padding(kernel_size, dilation=1):
|
||||
return int((kernel_size*dilation - dilation)/2)
|
||||
|
||||
|
||||
def load_checkpoint(filepath, device):
|
||||
assert os.path.isfile(filepath)
|
||||
print("Loading '{}'".format(filepath))
|
||||
checkpoint_dict = torch.load(filepath, map_location=device)
|
||||
print("Complete.")
|
||||
return checkpoint_dict
|
||||
|
||||
|
||||
def save_checkpoint(filepath, obj):
|
||||
print("Saving checkpoint to {}".format(filepath))
|
||||
torch.save(obj, filepath)
|
||||
print("Complete.")
|
||||
|
||||
|
||||
def del_old_checkpoints(cp_dir, prefix, n_models=2):
|
||||
pattern = os.path.join(cp_dir, prefix + '????????')
|
||||
cp_list = glob.glob(pattern) # get checkpoint paths
|
||||
cp_list = sorted(cp_list)# sort by iter
|
||||
if len(cp_list) > n_models: # if more than n_models models are found
|
||||
for cp in cp_list[:-n_models]:# delete the oldest models other than lastest n_models
|
||||
open(cp, 'w').close()# empty file contents
|
||||
os.unlink(cp)# delete file (move to trash when using Colab)
|
||||
|
||||
|
||||
def scan_checkpoint(cp_dir, prefix):
|
||||
pattern = os.path.join(cp_dir, prefix + '????????')
|
||||
cp_list = glob.glob(pattern)
|
||||
if len(cp_list) == 0:
|
||||
return None
|
||||
return sorted(cp_list)[-1]
|
||||
|
2
server/voice_changer/SoVitsSvc40/models/readme.txt
Normal file
2
server/voice_changer/SoVitsSvc40/models/readme.txt
Normal file
@ -0,0 +1,2 @@
|
||||
modules in this folder from https://github.com/svc-develop-team/so-vits-svc at 7846711d90b5210d4f39a4d6fab50d1d7bbd8d73
|
||||
forked repo: https://github.com/w-okada/so-vits-svc
|
@ -1,466 +0,0 @@
|
||||
import sys
|
||||
import os
|
||||
|
||||
from voice_changer.utils.LoadModelParams import LoadModelParams
|
||||
from voice_changer.utils.VoiceChangerModel import AudioInOut
|
||||
from voice_changer.utils.VoiceChangerParams import VoiceChangerParams
|
||||
|
||||
if sys.platform.startswith("darwin"):
|
||||
baseDir = [x for x in sys.path if x.endswith("Contents/MacOS")]
|
||||
if len(baseDir) != 1:
|
||||
print("baseDir should be only one ", baseDir)
|
||||
sys.exit()
|
||||
modulePath = os.path.join(baseDir[0], "so-vits-svc-40v2")
|
||||
sys.path.append(modulePath)
|
||||
else:
|
||||
sys.path.append("so-vits-svc-40v2")
|
||||
|
||||
import io
|
||||
from dataclasses import dataclass, asdict, field
|
||||
import numpy as np
|
||||
import torch
|
||||
import onnxruntime
|
||||
import pyworld as pw
|
||||
|
||||
from models import SynthesizerTrn # type:ignore
|
||||
import cluster # type:ignore
|
||||
import utils
|
||||
from fairseq import checkpoint_utils
|
||||
import librosa
|
||||
|
||||
from Exceptions import NoModeLoadedException
|
||||
|
||||
providers = [
|
||||
"OpenVINOExecutionProvider",
|
||||
"CUDAExecutionProvider",
|
||||
"DmlExecutionProvider",
|
||||
"CPUExecutionProvider",
|
||||
]
|
||||
|
||||
|
||||
@dataclass
|
||||
class SoVitsSvc40v2Settings:
|
||||
gpu: int = 0
|
||||
dstId: int = 0
|
||||
|
||||
f0Detector: str = "harvest" # dio or harvest
|
||||
tran: int = 20
|
||||
noiseScale: float = 0.3
|
||||
predictF0: int = 0 # 0:False, 1:True
|
||||
silentThreshold: float = 0.00001
|
||||
extraConvertSize: int = 1024 * 32
|
||||
clusterInferRatio: float = 0.1
|
||||
|
||||
framework: str = "PyTorch" # PyTorch or ONNX
|
||||
pyTorchModelFile: str | None = ""
|
||||
onnxModelFile: str | None = ""
|
||||
configFile: str = ""
|
||||
|
||||
speakers: dict[str, int] = field(default_factory=lambda: {})
|
||||
|
||||
# ↓mutableな物だけ列挙
|
||||
intData = ["gpu", "dstId", "tran", "predictF0", "extraConvertSize"]
|
||||
floatData = ["noiseScale", "silentThreshold", "clusterInferRatio"]
|
||||
strData = ["framework", "f0Detector"]
|
||||
|
||||
|
||||
class SoVitsSvc40v2:
|
||||
audio_buffer: AudioInOut | None = None
|
||||
|
||||
def __init__(self, params: VoiceChangerParams):
|
||||
self.settings = SoVitsSvc40v2Settings()
|
||||
self.net_g = None
|
||||
self.onnx_session = None
|
||||
|
||||
self.raw_path = io.BytesIO()
|
||||
self.gpu_num = torch.cuda.device_count()
|
||||
self.prevVol = 0
|
||||
self.params = params
|
||||
print("[Voice Changer] so-vits-svc 40v2 initialization:", params)
|
||||
|
||||
def loadModel(self, props: LoadModelParams):
|
||||
params = props.params
|
||||
self.settings.configFile = params["files"]["soVitsSvc40v2Config"]
|
||||
self.hps = utils.get_hparams_from_file(self.settings.configFile)
|
||||
self.settings.speakers = self.hps.spk
|
||||
|
||||
modelFile = params["files"]["soVitsSvc40v2Model"]
|
||||
if modelFile.endswith(".onnx"):
|
||||
self.settings.pyTorchModelFile = None
|
||||
self.settings.onnxModelFile = modelFile
|
||||
else:
|
||||
self.settings.pyTorchModelFile = modelFile
|
||||
self.settings.onnxModelFile = None
|
||||
|
||||
clusterTorchModel = params["files"]["soVitsSvc40v2Cluster"]
|
||||
|
||||
content_vec_path = self.params.content_vec_500
|
||||
hubert_base_path = self.params.hubert_base
|
||||
|
||||
# hubert model
|
||||
try:
|
||||
if os.path.exists(content_vec_path) is False:
|
||||
content_vec_path = hubert_base_path
|
||||
|
||||
models, saved_cfg, task = checkpoint_utils.load_model_ensemble_and_task(
|
||||
[content_vec_path],
|
||||
suffix="",
|
||||
)
|
||||
model = models[0]
|
||||
model.eval()
|
||||
self.hubert_model = model.cpu()
|
||||
except Exception as e:
|
||||
print("EXCEPTION during loading hubert/contentvec model", e)
|
||||
|
||||
# cluster
|
||||
try:
|
||||
if clusterTorchModel is not None and os.path.exists(clusterTorchModel):
|
||||
self.cluster_model = cluster.get_cluster_model(clusterTorchModel)
|
||||
else:
|
||||
self.cluster_model = None
|
||||
except Exception as e:
|
||||
print("EXCEPTION during loading cluster model ", e)
|
||||
|
||||
# PyTorchモデル生成
|
||||
if self.settings.pyTorchModelFile is not None:
|
||||
net_g = SynthesizerTrn(self.hps)
|
||||
net_g.eval()
|
||||
self.net_g = net_g
|
||||
utils.load_checkpoint(self.settings.pyTorchModelFile, self.net_g, None)
|
||||
|
||||
# ONNXモデル生成
|
||||
if self.settings.onnxModelFile is not None:
|
||||
providers, options = self.getOnnxExecutionProvider()
|
||||
self.onnx_session = onnxruntime.InferenceSession(
|
||||
self.settings.onnxModelFile,
|
||||
providers=providers,
|
||||
provider_options=options,
|
||||
)
|
||||
return self.get_info()
|
||||
|
||||
def getOnnxExecutionProvider(self):
|
||||
availableProviders = onnxruntime.get_available_providers()
|
||||
if self.settings.gpu >= 0 and "CUDAExecutionProvider" in availableProviders:
|
||||
return ["CUDAExecutionProvider"], [{"device_id": self.settings.gpu}]
|
||||
elif self.settings.gpu >= 0 and "DmlExecutionProvider" in availableProviders:
|
||||
return ["DmlExecutionProvider"], [{}]
|
||||
else:
|
||||
return ["CPUExecutionProvider"], [
|
||||
{
|
||||
"intra_op_num_threads": 8,
|
||||
"execution_mode": onnxruntime.ExecutionMode.ORT_PARALLEL,
|
||||
"inter_op_num_threads": 8,
|
||||
}
|
||||
]
|
||||
|
||||
def isOnnx(self):
|
||||
if self.settings.onnxModelFile is not None:
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
|
||||
def update_settings(self, key: str, val: int | float | str):
|
||||
if key in self.settings.intData:
|
||||
val = int(val)
|
||||
setattr(self.settings, key, val)
|
||||
|
||||
if key == "gpu" and self.isOnnx():
|
||||
providers, options = self.getOnnxExecutionProvider()
|
||||
if self.onnx_session is not None:
|
||||
self.onnx_session.set_providers(
|
||||
providers=providers,
|
||||
provider_options=options,
|
||||
)
|
||||
elif key in self.settings.floatData:
|
||||
setattr(self.settings, key, float(val))
|
||||
elif key in self.settings.strData:
|
||||
setattr(self.settings, key, str(val))
|
||||
else:
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
def get_info(self):
|
||||
data = asdict(self.settings)
|
||||
|
||||
data["onnxExecutionProviders"] = (
|
||||
self.onnx_session.get_providers() if self.onnx_session is not None else []
|
||||
)
|
||||
files = ["configFile", "pyTorchModelFile", "onnxModelFile"]
|
||||
for f in files:
|
||||
if data[f] is not None and os.path.exists(data[f]):
|
||||
data[f] = os.path.basename(data[f])
|
||||
else:
|
||||
data[f] = ""
|
||||
|
||||
return data
|
||||
|
||||
def get_processing_sampling_rate(self):
|
||||
if hasattr(self, "hps") is False:
|
||||
raise NoModeLoadedException("config")
|
||||
return self.hps.data.sampling_rate
|
||||
|
||||
def get_unit_f0(self, audio_buffer, tran):
|
||||
wav_44k = audio_buffer
|
||||
# f0 = utils.compute_f0_parselmouth(wav, sampling_rate=self.target_sample, hop_length=self.hop_size)
|
||||
# f0 = utils.compute_f0_dio(wav_44k, sampling_rate=self.hps.data.sampling_rate, hop_length=self.hps.data.hop_length)
|
||||
if self.settings.f0Detector == "dio":
|
||||
f0 = compute_f0_dio(
|
||||
wav_44k,
|
||||
sampling_rate=self.hps.data.sampling_rate,
|
||||
hop_length=self.hps.data.hop_length,
|
||||
)
|
||||
else:
|
||||
f0 = compute_f0_harvest(
|
||||
wav_44k,
|
||||
sampling_rate=self.hps.data.sampling_rate,
|
||||
hop_length=self.hps.data.hop_length,
|
||||
)
|
||||
|
||||
if wav_44k.shape[0] % self.hps.data.hop_length != 0:
|
||||
print(
|
||||
f" !!! !!! !!! wav size not multiple of hopsize: {wav_44k.shape[0] / self.hps.data.hop_length}"
|
||||
)
|
||||
|
||||
f0, uv = utils.interpolate_f0(f0)
|
||||
f0 = torch.FloatTensor(f0)
|
||||
uv = torch.FloatTensor(uv)
|
||||
f0 = f0 * 2 ** (tran / 12)
|
||||
f0 = f0.unsqueeze(0)
|
||||
uv = uv.unsqueeze(0)
|
||||
|
||||
# wav16k = librosa.resample(audio_buffer, orig_sr=24000, target_sr=16000)
|
||||
wav16k = librosa.resample(
|
||||
audio_buffer, orig_sr=self.hps.data.sampling_rate, target_sr=16000
|
||||
)
|
||||
wav16k = torch.from_numpy(wav16k)
|
||||
|
||||
if (
|
||||
self.settings.gpu < 0 or self.gpu_num == 0
|
||||
) or self.settings.framework == "ONNX":
|
||||
dev = torch.device("cpu")
|
||||
else:
|
||||
dev = torch.device("cuda", index=self.settings.gpu)
|
||||
|
||||
self.hubert_model = self.hubert_model.to(dev)
|
||||
wav16k = wav16k.to(dev)
|
||||
uv = uv.to(dev)
|
||||
f0 = f0.to(dev)
|
||||
|
||||
c = utils.get_hubert_content(self.hubert_model, wav_16k_tensor=wav16k)
|
||||
c = utils.repeat_expand_2d(c.squeeze(0), f0.shape[1])
|
||||
|
||||
if (
|
||||
self.settings.clusterInferRatio != 0
|
||||
and hasattr(self, "cluster_model")
|
||||
and self.cluster_model is not None
|
||||
):
|
||||
speaker = [
|
||||
key
|
||||
for key, value in self.settings.speakers.items()
|
||||
if value == self.settings.dstId
|
||||
]
|
||||
if len(speaker) != 1:
|
||||
pass
|
||||
# print("not only one speaker found.", speaker)
|
||||
else:
|
||||
cluster_c = cluster.get_cluster_center_result(
|
||||
self.cluster_model, c.cpu().numpy().T, speaker[0]
|
||||
).T
|
||||
# cluster_c = cluster.get_cluster_center_result(self.cluster_model, c.cpu().numpy().T, self.settings.dstId).T
|
||||
cluster_c = torch.FloatTensor(cluster_c).to(dev)
|
||||
# print("cluster DEVICE", cluster_c.device, c.device)
|
||||
c = (
|
||||
self.settings.clusterInferRatio * cluster_c
|
||||
+ (1 - self.settings.clusterInferRatio) * c
|
||||
)
|
||||
|
||||
c = c.unsqueeze(0)
|
||||
return c, f0, uv
|
||||
|
||||
def generate_input(
|
||||
self,
|
||||
newData: AudioInOut,
|
||||
inputSize: int,
|
||||
crossfadeSize: int,
|
||||
solaSearchFrame: int = 0,
|
||||
):
|
||||
newData = newData.astype(np.float32) / self.hps.data.max_wav_value
|
||||
|
||||
if self.audio_buffer is not None:
|
||||
self.audio_buffer = np.concatenate(
|
||||
[self.audio_buffer, newData], 0
|
||||
) # 過去のデータに連結
|
||||
else:
|
||||
self.audio_buffer = newData
|
||||
|
||||
convertSize = (
|
||||
inputSize + crossfadeSize + solaSearchFrame + self.settings.extraConvertSize
|
||||
)
|
||||
|
||||
if convertSize % self.hps.data.hop_length != 0: # モデルの出力のホップサイズで切り捨てが発生するので補う。
|
||||
convertSize = convertSize + (
|
||||
self.hps.data.hop_length - (convertSize % self.hps.data.hop_length)
|
||||
)
|
||||
convertOffset = -1 * convertSize
|
||||
self.audio_buffer = self.audio_buffer[convertOffset:] # 変換対象の部分だけ抽出
|
||||
|
||||
cropOffset = -1 * (inputSize + crossfadeSize)
|
||||
cropEnd = -1 * (crossfadeSize)
|
||||
crop = self.audio_buffer[cropOffset:cropEnd]
|
||||
|
||||
rms = np.sqrt(np.square(crop).mean(axis=0))
|
||||
vol = max(rms, self.prevVol * 0.0)
|
||||
self.prevVol = vol
|
||||
|
||||
c, f0, uv = self.get_unit_f0(self.audio_buffer, self.settings.tran)
|
||||
return (c, f0, uv, convertSize, vol)
|
||||
|
||||
def _onnx_inference(self, data):
|
||||
if hasattr(self, "onnx_session") is False or self.onnx_session is None:
|
||||
print("[Voice Changer] No onnx session.")
|
||||
raise NoModeLoadedException("ONNX")
|
||||
|
||||
convertSize = data[3]
|
||||
vol = data[4]
|
||||
data = (
|
||||
data[0],
|
||||
data[1],
|
||||
data[2],
|
||||
)
|
||||
|
||||
if vol < self.settings.silentThreshold:
|
||||
return np.zeros(convertSize).astype(np.int16)
|
||||
|
||||
c, f0, uv = [x.numpy() for x in data]
|
||||
audio1 = (
|
||||
self.onnx_session.run(
|
||||
["audio"],
|
||||
{
|
||||
"c": c,
|
||||
"f0": f0,
|
||||
"g": np.array([self.settings.dstId]).astype(np.int64),
|
||||
"uv": np.array([self.settings.dstId]).astype(np.int64),
|
||||
"predict_f0": np.array([self.settings.dstId]).astype(np.int64),
|
||||
"noice_scale": np.array([self.settings.dstId]).astype(np.int64),
|
||||
},
|
||||
)[0][0, 0]
|
||||
* self.hps.data.max_wav_value
|
||||
)
|
||||
|
||||
audio1 = audio1 * vol
|
||||
|
||||
result = audio1
|
||||
|
||||
return result
|
||||
|
||||
def _pyTorch_inference(self, data):
|
||||
if hasattr(self, "net_g") is False or self.net_g is None:
|
||||
print("[Voice Changer] No pyTorch session.")
|
||||
raise NoModeLoadedException("pytorch")
|
||||
|
||||
if self.settings.gpu < 0 or self.gpu_num == 0:
|
||||
dev = torch.device("cpu")
|
||||
else:
|
||||
dev = torch.device("cuda", index=self.settings.gpu)
|
||||
|
||||
convertSize = data[3]
|
||||
vol = data[4]
|
||||
data = (
|
||||
data[0],
|
||||
data[1],
|
||||
data[2],
|
||||
)
|
||||
|
||||
if vol < self.settings.silentThreshold:
|
||||
return np.zeros(convertSize).astype(np.int16)
|
||||
|
||||
with torch.no_grad():
|
||||
c, f0, uv = [x.to(dev) for x in data]
|
||||
sid_target = torch.LongTensor([self.settings.dstId]).to(dev)
|
||||
self.net_g.to(dev)
|
||||
# audio1 = self.net_g.infer(c, f0=f0, g=sid_target, uv=uv, predict_f0=True, noice_scale=0.1)[0][0, 0].data.float()
|
||||
predict_f0_flag = True if self.settings.predictF0 == 1 else False
|
||||
audio1 = self.net_g.infer(
|
||||
c,
|
||||
f0=f0,
|
||||
g=sid_target,
|
||||
uv=uv,
|
||||
predict_f0=predict_f0_flag,
|
||||
noice_scale=self.settings.noiseScale,
|
||||
)[0][0, 0].data.float()
|
||||
audio1 = audio1 * self.hps.data.max_wav_value
|
||||
|
||||
audio1 = audio1 * vol
|
||||
|
||||
result = audio1.float().cpu().numpy()
|
||||
|
||||
# result = infer_tool.pad_array(result, length)
|
||||
return result
|
||||
|
||||
def inference(self, data):
|
||||
if self.isOnnx():
|
||||
audio = self._onnx_inference(data)
|
||||
else:
|
||||
audio = self._pyTorch_inference(data)
|
||||
return audio
|
||||
|
||||
def __del__(self):
|
||||
del self.net_g
|
||||
del self.onnx_session
|
||||
|
||||
remove_path = os.path.join("so-vits-svc-40v2")
|
||||
sys.path = [x for x in sys.path if x.endswith(remove_path) is False]
|
||||
|
||||
for key in list(sys.modules):
|
||||
val = sys.modules.get(key)
|
||||
try:
|
||||
file_path = val.__file__
|
||||
if file_path.find("so-vits-svc-40v2" + os.path.sep) >= 0:
|
||||
# print("remove", key, file_path)
|
||||
sys.modules.pop(key)
|
||||
except: # type:ignore
|
||||
pass
|
||||
|
||||
|
||||
def resize_f0(x, target_len):
|
||||
source = np.array(x)
|
||||
source[source < 0.001] = np.nan
|
||||
target = np.interp(
|
||||
np.arange(0, len(source) * target_len, len(source)) / target_len,
|
||||
np.arange(0, len(source)),
|
||||
source,
|
||||
)
|
||||
res = np.nan_to_num(target)
|
||||
return res
|
||||
|
||||
|
||||
def compute_f0_dio(wav_numpy, p_len=None, sampling_rate=44100, hop_length=512):
|
||||
if p_len is None:
|
||||
p_len = wav_numpy.shape[0] // hop_length
|
||||
f0, t = pw.dio(
|
||||
wav_numpy.astype(np.double),
|
||||
fs=sampling_rate,
|
||||
f0_ceil=800,
|
||||
frame_period=1000 * hop_length / sampling_rate,
|
||||
)
|
||||
f0 = pw.stonemask(wav_numpy.astype(np.double), f0, t, sampling_rate)
|
||||
for index, pitch in enumerate(f0):
|
||||
f0[index] = round(pitch, 1)
|
||||
return resize_f0(f0, p_len)
|
||||
|
||||
|
||||
def compute_f0_harvest(wav_numpy, p_len=None, sampling_rate=44100, hop_length=512):
|
||||
if p_len is None:
|
||||
p_len = wav_numpy.shape[0] // hop_length
|
||||
f0, t = pw.harvest(
|
||||
wav_numpy.astype(np.double),
|
||||
fs=sampling_rate,
|
||||
frame_period=5.5,
|
||||
f0_floor=71.0,
|
||||
f0_ceil=1000.0,
|
||||
)
|
||||
|
||||
for index, pitch in enumerate(f0):
|
||||
f0[index] = round(pitch, 1)
|
||||
return resize_f0(f0, p_len)
|
@ -131,9 +131,9 @@ class VoiceChangerManager(ServerDeviceCallbacks):
|
||||
slotInfo = SoVitsSvc40ModelSlotGenerator.loadModel(params)
|
||||
self.modelSlotManager.save_model_slot(params.slot, slotInfo)
|
||||
elif params.voiceChangerType == "DDSP-SVC":
|
||||
from voice_changer.DDSP_SVC.DDSP_SVC import DDSP_SVC
|
||||
from voice_changer.DDSP_SVC.DDSP_SVCModelSlotGenerator import DDSP_SVCModelSlotGenerator
|
||||
|
||||
slotInfo = DDSP_SVC.loadModel(params)
|
||||
slotInfo = DDSP_SVCModelSlotGenerator.loadModel(params)
|
||||
self.modelSlotManager.save_model_slot(params.slot, slotInfo)
|
||||
print("params", params)
|
||||
|
||||
@ -195,6 +195,13 @@ class VoiceChangerManager(ServerDeviceCallbacks):
|
||||
self.voiceChangerModel = SoVitsSvc40(self.params, slotInfo)
|
||||
self.voiceChanger = VoiceChanger(self.params)
|
||||
self.voiceChanger.setModel(self.voiceChangerModel)
|
||||
elif slotInfo.voiceChangerType == "DDSP-SVC":
|
||||
print("................DDSP-SVC")
|
||||
from voice_changer.DDSP_SVC.DDSP_SVC import DDSP_SVC
|
||||
|
||||
self.voiceChangerModel = DDSP_SVC(self.params, slotInfo)
|
||||
self.voiceChanger = VoiceChanger(self.params)
|
||||
self.voiceChanger.setModel(self.voiceChangerModel)
|
||||
else:
|
||||
print(f"[Voice Changer] unknown voice changer model: {slotInfo.voiceChangerType}")
|
||||
del self.voiceChangerModel
|
||||
|
Loading…
Reference in New Issue
Block a user