WIP: switch base from trainer to client

This commit is contained in:
wataru 2023-01-14 06:44:30 +09:00
parent ae66ec3d3f
commit 9f3dab0295
5 changed files with 200 additions and 22 deletions

1
.gitignore vendored
View File

@ -4,6 +4,7 @@ __pycache__
server/upload_dir/
server/MMVC_Trainer/
server/MMVC_Client/
server/key
server/info

View File

@ -8,8 +8,9 @@ from distutils.util import strtobool
from scipy.io.wavfile import write, read
sys.path.append("MMVC_Trainer")
sys.path.append("MMVC_Trainer/text")
# sys.path.append("MMVC_Trainer")
# sys.path.append("MMVC_Trainer/text")
sys.path.append("MMVC_Client/python")
from fastapi.routing import APIRoute
from fastapi import HTTPException, FastAPI, UploadFile, File, Form

View File

@ -2,13 +2,15 @@ import logging
# logging.getLogger('numba').setLevel(logging.WARNING)
class UvicornSuppressFilter(logging.Filter):
def filter(self, record):
return False
# class UvicornSuppressFilter(logging.Filter):
# def filter(self, record):
# return False
# logger = logging.getLogger("uvicorn.error")
# logger.addFilter(UvicornSuppressFilter())
logger = logging.getLogger("uvicorn.error")
logger.addFilter(UvicornSuppressFilter())
# logger.propagate = False
logger = logging.getLogger("multipart.multipart")
logger.propagate = False

View File

@ -0,0 +1,151 @@
import torch
import os, sys, json
import logging
logging.basicConfig(stream=sys.stdout, level=logging.DEBUG)
logger = logging
hann_window = {}
def spectrogram_torch(y, n_fft, sampling_rate, hop_size, win_size, center=False):
if torch.min(y) < -1.:
print('min value is ', torch.min(y))
if torch.max(y) > 1.:
print('max value is ', torch.max(y))
global hann_window
dtype_device = str(y.dtype) + '_' + str(y.device)
wnsize_dtype_device = str(win_size) + '_' + dtype_device
if wnsize_dtype_device not in hann_window:
hann_window[wnsize_dtype_device] = torch.hann_window(win_size).to(dtype=y.dtype, device=y.device)
y = torch.nn.functional.pad(y.unsqueeze(1), (int((n_fft-hop_size)/2), int((n_fft-hop_size)/2)), mode='reflect')
y = y.squeeze(1)
spec = torch.stft(y, n_fft, hop_length=hop_size, win_length=win_size, window=hann_window[wnsize_dtype_device],
center=center, pad_mode='reflect', normalized=False, onesided=True, return_complex=True)
spec = torch.view_as_real(spec)
spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6)
return spec
class TextAudioSpeakerCollate():
""" Zero-pads model inputs and targets
"""
def __init__(self, return_ids=False, no_text = False):
self.return_ids = return_ids
self.no_text = no_text
def __call__(self, batch):
"""Collate's training batch from normalized text, audio and speaker identities
PARAMS
------
batch: [text_normalized, spec_normalized, wav_normalized, sid]
"""
# Right zero-pad all one-hot text sequences to max input length
_, ids_sorted_decreasing = torch.sort(
torch.LongTensor([x[1].size(1) for x in batch]),
dim=0, descending=True)
max_text_len = max([len(x[0]) for x in batch])
max_spec_len = max([x[1].size(1) for x in batch])
max_wav_len = max([x[2].size(1) for x in batch])
text_lengths = torch.LongTensor(len(batch))
spec_lengths = torch.LongTensor(len(batch))
wav_lengths = torch.LongTensor(len(batch))
sid = torch.LongTensor(len(batch))
text_padded = torch.LongTensor(len(batch), max_text_len)
spec_padded = torch.FloatTensor(len(batch), batch[0][1].size(0), max_spec_len)
wav_padded = torch.FloatTensor(len(batch), 1, max_wav_len)
text_padded.zero_()
spec_padded.zero_()
wav_padded.zero_()
for i in range(len(ids_sorted_decreasing)):
row = batch[ids_sorted_decreasing[i]]
text = row[0]
text_padded[i, :text.size(0)] = text
text_lengths[i] = text.size(0)
spec = row[1]
spec_padded[i, :, :spec.size(1)] = spec
spec_lengths[i] = spec.size(1)
wav = row[2]
wav_padded[i, :, :wav.size(1)] = wav
wav_lengths[i] = wav.size(1)
sid[i] = row[3]
if self.return_ids:
return text_padded, text_lengths, spec_padded, spec_lengths, wav_padded, wav_lengths, sid, ids_sorted_decreasing
return text_padded, text_lengths, spec_padded, spec_lengths, wav_padded, wav_lengths, sid
def load_checkpoint(checkpoint_path, model, optimizer=None):
assert os.path.isfile(checkpoint_path), f"No such file or directory: {checkpoint_path}"
checkpoint_dict = torch.load(checkpoint_path, map_location='cpu')
iteration = checkpoint_dict['iteration']
learning_rate = checkpoint_dict['learning_rate']
if optimizer is not None:
optimizer.load_state_dict(checkpoint_dict['optimizer'])
saved_state_dict = checkpoint_dict['model']
if hasattr(model, 'module'):
state_dict = model.module.state_dict()
else:
state_dict = model.state_dict()
new_state_dict= {}
for k, v in state_dict.items():
try:
new_state_dict[k] = saved_state_dict[k]
except:
logger.info("%s is not in the checkpoint" % k)
new_state_dict[k] = v
if hasattr(model, 'module'):
model.module.load_state_dict(new_state_dict)
else:
model.load_state_dict(new_state_dict)
logger.info("Loaded checkpoint '{}' (iteration {})" .format(
checkpoint_path, iteration))
return model, optimizer, learning_rate, iteration
def get_hparams_from_file(config_path):
with open(config_path, "r") as f:
data = f.read()
config = json.loads(data)
hparams =HParams(**config)
return hparams
class HParams():
def __init__(self, **kwargs):
for k, v in kwargs.items():
if type(v) == dict:
v = HParams(**v)
self[k] = v
def keys(self):
return self.__dict__.keys()
def items(self):
return self.__dict__.items()
def values(self):
return self.__dict__.values()
def __len__(self):
return len(self.__dict__)
def __getitem__(self, key):
return getattr(self, key)
def __setitem__(self, key, value):
return setattr(self, key, value)
def __contains__(self, key):
return key in self.__dict__
def __repr__(self):
return self.__dict__.__repr__()

View File

@ -4,17 +4,34 @@ import math, os, traceback
from scipy.io.wavfile import write, read
import numpy as np
from dataclasses import dataclass, asdict
import utils
import commons
from models import SynthesizerTrn
from text.symbols import symbols
from data_utils import TextAudioSpeakerLoader, TextAudioSpeakerCollate
from mel_processing import spectrogram_torch
from text import text_to_sequence, cleaned_text_to_sequence
import onnxruntime
# import utils
# import commons
# from models import SynthesizerTrn
#from text.symbols import symbols
# from data_utils import TextAudioSpeakerLoader, TextAudioSpeakerCollate
# from mel_processing import spectrogram_torch
#from text import text_to_sequence, cleaned_text_to_sequence
################
from symbols import symbols
# from mmvc_client import get_hparams_from_file, load_checkpoint
from models import SynthesizerTrn
################
# from voice_changer.utils import get_hparams_from_file, load_checkpoint
# from voice_changer.models import SynthesizerTrn
# from voice_changer.symbols import symbols
from voice_changer.TrainerFunctions import TextAudioSpeakerCollate, spectrogram_torch, load_checkpoint, get_hparams_from_file
providers = ['OpenVINOExecutionProvider',"CUDAExecutionProvider","DmlExecutionProvider","CPUExecutionProvider"]
@dataclass
@ -49,12 +66,17 @@ class VoiceChanger():
self.currentCrossFadeOverlapRate=0
# 共通で使用する情報を収集
self.hps = utils.get_hparams_from_file(config)
# self.hps = utils.get_hparams_from_file(config)
self.hps = get_hparams_from_file(config)
self.gpu_num = torch.cuda.device_count()
text_norm = text_to_sequence("a", self.hps.data.text_cleaners)
text_norm = commons.intersperse(text_norm, 0)
self.text_norm = torch.LongTensor(text_norm)
# text_norm = text_to_sequence("a", self.hps.data.text_cleaners)
# print("text_norm1: ",text_norm)
# text_norm = commons.intersperse(text_norm, 0)
# print("text_norm2: ",text_norm)
# self.text_norm = torch.LongTensor(text_norm)
self.text_norm = torch.LongTensor([0, 6, 0])
self.audio_buffer = torch.zeros(1, 0)
self.prev_audio = np.zeros(1)
self.mps_enabled = getattr(torch.backends, "mps", None) is not None and torch.backends.mps.is_available()
@ -77,7 +99,8 @@ class VoiceChanger():
n_speakers=self.hps.data.n_speakers,
**self.hps.model)
self.net_g.eval()
utils.load_checkpoint(pyTorch_model_file, self.net_g, None)
load_checkpoint(pyTorch_model_file, self.net_g, None)
# utils.load_checkpoint(pyTorch_model_file, self.net_g, None)
# ONNXモデル生成
if onnx_model_file != None:
@ -232,7 +255,7 @@ class VoiceChanger():
with torch.no_grad():
x, x_lengths, spec, spec_lengths, y, y_lengths, sid_src = [x.cpu() for x in data]
sid_tgt1 = torch.LongTensor([self.settings.dstId]).cpu()
audio1 = (self.net_g.cpu().voice_conversion(spec, spec_lengths, sid_src=sid_src, sid_tgt=sid_tgt1)[0][0, 0].data * self.hps.data.max_wav_value)
audio1 = (self.net_g.cpu().voice_conversion(spec, spec_lengths, sid_src=sid_src, sid_tgt=sid_tgt1)[0, 0].data * self.hps.data.max_wav_value)
if self.prev_strength.device != torch.device('cpu'):
print(f"prev_strength move from {self.prev_strength.device} to cpu")
@ -263,7 +286,7 @@ class VoiceChanger():
with torch.no_grad():
x, x_lengths, spec, spec_lengths, y, y_lengths, sid_src = [x.cuda(self.settings.gpu) for x in data]
sid_tgt1 = torch.LongTensor([self.settings.dstId]).cuda(self.settings.gpu)
audio1 = self.net_g.cuda(self.settings.gpu).voice_conversion(spec, spec_lengths, sid_src=sid_src, sid_tgt=sid_tgt1)[0][0, 0].data * self.hps.data.max_wav_value
audio1 = self.net_g.cuda(self.settings.gpu).voice_conversion(spec, spec_lengths, sid_src=sid_src, sid_tgt=sid_tgt1)[0, 0].data * self.hps.data.max_wav_value
if self.prev_strength.device != torch.device('cuda', self.settings.gpu):
print(f"prev_strength move from {self.prev_strength.device} to gpu{self.settings.gpu}")