# based on https://github.com/YangangCao/Causal-U-Net/blob/main/cunet.py
import torch
import torch.nn as nn
import torch.nn.functional as F


class ResidualBlock(nn.Module):
    """
    Based on https://github.com/f90/Seq-U-Net/blob/master/sequnet_res.py
    """

    def __init__(self, in_channels, out_channels, kernel_size, dilation, dropout, use_2d):
        super().__init__()
        self.use_2d = use_2d
        if use_2d:
            self.filter = nn.Conv2d(in_channels, out_channels, kernel_size, dilation=dilation)
            self.gate = nn.Conv2d(in_channels, out_channels, kernel_size, dilation=dilation)
            self.dropout = nn.Dropout2d(dropout)
        else:
            self.filter = nn.Conv1d(in_channels, out_channels, kernel_size, dilation=dilation)
            self.gate = nn.Conv1d(in_channels, out_channels, kernel_size, dilation=dilation)
            self.dropout = nn.Dropout1d(dropout)
        # The unpadded convolutions shrink the output by dilation * (kernel_size - 1)
        # samples, so the skip connection is cropped by the same amount below.
        self.output_crop = dilation * (kernel_size - 1)

    def forward(self, x):
        # Gated activation: tanh "filter" modulated by a sigmoid "gate".
        filtered = torch.tanh(self.filter(x))
        gated = torch.sigmoid(self.gate(x))
        residual = filtered * gated
        # pad dim 1 of x to match residual
        if self.use_2d:
            x = F.pad(x, (0, 0, 0, 0, 0, residual.shape[1] - x.shape[1]))
            output = x[..., self.output_crop :, self.output_crop :] + residual
        else:
            x = F.pad(x, (0, 0, 0, residual.shape[1] - x.shape[1]))
            output = x[..., self.output_crop :] + residual
        output = self.dropout(output)
        return output


class CausalConvBlock(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, dilation, dropout, use_2d):
        super().__init__()
        if use_2d:
            conv_layer = nn.Conv2d
            batchnorm_layer = nn.BatchNorm2d
            dropout_layer = nn.Dropout2d
        else:
            conv_layer = nn.Conv1d
            batchnorm_layer = nn.BatchNorm1d
            dropout_layer = nn.Dropout1d
        self.conv = nn.Sequential(
            conv_layer(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, dilation=dilation),
            batchnorm_layer(num_features=out_channels),
            dropout_layer(dropout),
            nn.LeakyReLU(inplace=True),
        )

    def forward(self, x):
        """
        Causal convolution (1D or 2D, depending on `use_2d`).
        """
        return self.conv(x)


class CachedConvNet(nn.Module):
    def __init__(self, num_channels, kernel_sizes, dilations, dropout, combine_residuals, use_residual_blocks, out_channels, use_2d, use_pool=False, pool_kernel=2):
        super().__init__()
        assert len(kernel_sizes) == len(dilations), "kernel_sizes and dilations must be the same length"
        assert len(kernel_sizes) == len(out_channels), "kernel_sizes and out_channels must be the same length"
        self.num_layers = len(kernel_sizes)
        self.ctx_height = max(out_channels)
        self.down_convs = nn.ModuleList()
        self.num_channels = num_channels
        self.kernel_sizes = kernel_sizes
        self.combine_residuals = combine_residuals
        self.use_2d = use_2d
        self.use_pool = use_pool

        # compute buffer lengths for each layer
        self.buf_lengths = [(k - 1) * d for k, d in zip(kernel_sizes, dilations)]

        # Compute buffer start indices for each layer
        self.buf_indices = [0]
        for i in range(len(kernel_sizes) - 1):
            self.buf_indices.append(self.buf_indices[-1] + self.buf_lengths[i])

        if use_residual_blocks:
            block = ResidualBlock
        else:
            block = CausalConvBlock

        if self.use_pool:
            self.pool = nn.AvgPool1d(kernel_size=pool_kernel)

        for i in range(self.num_layers):
            in_channel = num_channels if i == 0 else out_channels[i - 1]
            self.down_convs.append(block(in_channels=in_channel, out_channels=out_channels[i], kernel_size=kernel_sizes[i], dilation=dilations[i], dropout=dropout, use_2d=use_2d))

    def init_ctx_buf(self, batch_size, device, height=None):
        """
        Initialize context buffer for each layer.
        """
        if height is not None:
            up_ctx = torch.zeros((batch_size, self.ctx_height, height, sum(self.buf_lengths))).to(device)
        else:
            up_ctx = torch.zeros((batch_size, self.ctx_height, sum(self.buf_lengths))).to(device)
        return up_ctx

    def forward(self, x, ctx):
        """
        Args:
            x: [B, in_channels, T]
                Input
            ctx: [B, self.ctx_height, sum(self.buf_lengths)]
                Single tensor holding the cached context for all layers
                (an extra height dimension is present when use_2d is True);
                layer i reads and writes the slice of length self.buf_lengths[i]
                starting at self.buf_indices[i] along the last dimension.
        Returns:
            x: [B, out_channels, T]
            ctx: [B, self.ctx_height, sum(self.buf_lengths)]
                Updated context buffer holding the most recent input samples
                of each layer as the last elements of its slice.
        """
        if self.use_pool:
            x = self.pool(x)

        for i in range(self.num_layers):
            buf_start_idx = self.buf_indices[i]
            buf_end_idx = self.buf_indices[i] + self.buf_lengths[i]

            # concatenate context buffer with input
            if self.use_2d:
                conv_in = torch.cat((ctx[..., : x.shape[1], : x.shape[-2], buf_start_idx:buf_end_idx], x), dim=-1)
            else:
                conv_in = torch.cat((ctx[..., : x.shape[-2], buf_start_idx:buf_end_idx], x), dim=-1)

            # Push the most recent input samples to the context buffer
            if self.use_2d:
                ctx[..., : x.shape[1], : x.shape[-2], buf_start_idx:buf_end_idx] = conv_in[..., -self.buf_lengths[i] :]
            else:
                ctx[..., : x.shape[1], buf_start_idx:buf_end_idx] = conv_in[..., -self.buf_lengths[i] :]

            # pad second-to-last index of input with self.buf_lengths[i] // 2 zeros
            # on each side to ensure that height of output is the same as input
            if self.use_2d:
                conv_in = F.pad(conv_in, (0, 0, self.buf_lengths[i] // 2, self.buf_lengths[i] // 2))

            if self.combine_residuals == "add":
                x = x + self.down_convs[i](conv_in)
            elif self.combine_residuals == "multiply":
                x = x * self.down_convs[i](conv_in)
            else:
                x = self.down_convs[i](conv_in)

        if self.use_pool:
            x = F.interpolate(x, scale_factor=self.pool.kernel_size[0])

        return x, ctx
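

# ---------------------------------------------------------------------------
# Minimal streaming sketch (not part of the original module). It shows how
# init_ctx_buf and forward are meant to be used together: the context buffer
# is created once and carried across chunks so the convolutions stay causal
# at chunk boundaries. All parameter values below are illustrative only, not
# the settings used by the surrounding project.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    torch.manual_seed(0)

    net = CachedConvNet(
        num_channels=1,
        kernel_sizes=[3, 3],
        dilations=[1, 2],
        dropout=0.1,
        combine_residuals=None,
        use_residual_blocks=False,
        out_channels=[4, 4],
        use_2d=False,
    )
    net.eval()

    batch_size = 1
    chunk_size = 160
    ctx = net.init_ctx_buf(batch_size, device="cpu")

    # Feed consecutive chunks while reusing the returned context buffer.
    with torch.no_grad():
        for _ in range(3):
            chunk = torch.randn(batch_size, 1, chunk_size)
            out, ctx = net(chunk, ctx)
            print(out.shape, ctx.shape)  # torch.Size([1, 4, 160]) torch.Size([1, 4, 6])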