voice-changer/server/voice_changer/MMVCv15/models/index.py

83 lines
2.2 KiB
Python
Raw Normal View History

2023-06-22 00:56:00 +03:00
# -*- coding: utf-8 -*-
# Copyright 2020 Yi-Chiao Wu (Nagoya University)
# MIT License (https://opensource.org/licenses/MIT)
"""Indexing-related functions."""
import torch
from torch.nn import ConstantPad1d as pad1d
def pd_indexing(x, d, dilation, batch_index, ch_index):
    """Gather pitch-dependent past and future samples from a feature map.

    For every time step the look-back / look-ahead distance is
    ``d * dilation``, rounded to the nearest integer. ``x`` is zero-padded
    on the left (for the past lookup) or on the right (for the future
    lookup) by exactly the amount the farthest index requires.

    Args:
        x (Tensor): Input feature map (B, C, T).
        d (Tensor): Input pitch-dependent dilated factors (B, 1, T).
        dilation (Int): Dilation size.
        batch_index (Tensor): Batch index
        ch_index (Tensor): Channel index

    Returns:
        Tensor: Past output tensor (B, out_channels, T)
        Tensor: Future output tensor (B, out_channels, T)

    """
    _, _, n_frames = d.size()
    offsets = d * dilation
    device = x.device

    # Past positions are negative, i.e. counted back from the end of the
    # left-padded tensor.
    past_base = torch.arange(-n_frames, 0).float().to(device)
    past_pos = torch.add(-offsets, past_base).round().long()
    # Amount by which the most distant past index falls before frame 0.
    pad_left = -(torch.min(past_pos) + n_frames)
    assert pad_left >= 0
    past_selector = (batch_index, ch_index, past_pos)
    x_past = pad1d((pad_left, 0), 0)(x)

    # Future positions are non-negative, counted from the start of the
    # right-padded tensor.
    future_base = torch.arange(0, n_frames).float().to(device)
    future_pos = torch.add(offsets, future_base).round().long()
    # Amount by which the most distant future index runs past frame T - 1.
    pad_right = torch.max(future_pos) - (n_frames - 1)
    assert pad_right >= 0
    future_selector = (batch_index, ch_index, future_pos)
    x_future = pad1d((0, pad_right), 0)(x)

    return x_past[past_selector], x_future[future_selector]
def index_initial(n_batch, n_ch, tensor=True):
    """Build the batch and channel index grids.

    Both results have the nested layout (n_batch, n_ch, 1), where
    ``batch_index[b][c] == [b]`` and ``ch_index[b][c] == [c]``.

    Args:
        n_batch (Int): Number of batch.
        n_ch (Int): Number of channel.
        tensor (bool): If True, return ``torch.Tensor`` objects, moved to
            CUDA whenever ``torch.cuda.is_available()``. If False, return
            plain nested Python lists.

    Returns:
        Tensor or list: Batch index
        Tensor or list: Channel index

    """
    # Each innermost list is created independently. Multiplying a list of
    # lists would repeat references to one shared object, so mutating a
    # single entry of the list form would change every alias.
    batch_index = [[[b] for _ in range(n_ch)] for b in range(n_batch)]
    ch_index = [[[c] for c in range(n_ch)] for _ in range(n_batch)]

    if tensor:
        batch_index = torch.tensor(batch_index)
        ch_index = torch.tensor(ch_index)
        if torch.cuda.is_available():
            batch_index = batch_index.cuda()
            ch_index = ch_index.cuda()
    return batch_index, ch_index