voice-changer/demo/SoftVcServerFastAPI.py
2022-09-18 22:09:33 +09:00

159 lines
4.9 KiB
Python
Executable File

import uvicorn
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from fastapi.encoders import jsonable_encoder
from fastapi.responses import JSONResponse
from fastapi.staticfiles import StaticFiles
import logging
from logging.config import dictConfig
import os, sys, math, base64, struct, traceback, time
import torch, torchaudio
import numpy as np
from scipy.io.wavfile import write, read
from datetime import datetime
# Command line: <PORT> <MODE>. MODE == "colab" loads the models through
# torch.hub; any other value uses the Docker image layout in the else-branch.
if len(sys.argv) < 3:
    sys.exit(f"usage: {sys.argv[0]} <PORT> <MODE>   (MODE: colab, or anything else for Docker)")
args = sys.argv
PORT = args[1]
MODE = args[2]
logger = logging.getLogger('uvicorn')
app = FastAPI()
# Permissive CORS: any origin, method and header may call the API.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# The front end is mounted exactly once, inside the branch for the current
# MODE. Starlette dispatches to the first matching route, so a second "/front"
# mount would never be reached. StaticFiles also raises at construction when
# its directory is missing, so mounting the colab path unconditionally would
# fail at startup whenever that directory is absent.
if MODE == "colab":
    print("ENV: colab")
    app.mount("/front", StaticFiles(directory="voice-changer/frontend/dist", html=True), name="static")
    hubert_model = torch.hub.load("bshall/hubert:main", "hubert_soft").cuda()
    acoustic_model = torch.hub.load("bshall/acoustic-model:main", "hubert_soft").cuda()
    hifigan_model = torch.hub.load("bshall/hifigan:main", "hifigan_hubert_soft").cuda()
else:
    print("ENV: Docker")
    app.mount("/front", StaticFiles(directory="../frontend/dist", html=True), name="static")
    # torch.load below returns objects that are used directly as models, so
    # the packages defining them are put on sys.path and imported first
    # (presumably so unpickling can resolve their classes).
    # NOTE(review): the `acoustic` import rebinds hubert_discrete/hubert_soft
    # from the `hubert` import; neither name is used later in the visible code.
    sys.path.append("/hubert")
    from hubert import hubert_discrete, hubert_soft, kmeans100
    sys.path.append("/acoustic-model")
    from acoustic import hubert_discrete, hubert_soft
    sys.path.append("/hifigan")
    from hifigan import hifigan
    hubert_model = torch.load("/models/bshall_hubert_main.pt").cuda()
    acoustic_model = torch.load("/models/bshall_acoustic-model_main.pt").cuda()
    hifigan_model = torch.load("/models/bshall_hifigan_main.pt").cuda()
def applyVol(i, chunk, vols):
    """Gate one chunk of output audio by the source loudness of frame ``i``.

    Args:
        i: index of the chunk; selects the matching entry of ``vols``.
        chunk: tensor of output samples for that frame window (post_test
            passes 1-D slices produced by ``torch.split``).
        vols: per-frame loudness values (post_test passes, for each source
            mel frame, the largest mel-bin value).

    Returns:
        A new tensor shaped like ``chunk`` with an extra leading dimension of
        size 1 (so chunks can be concatenated along dim 1). It is all zeros
        when half of the frame's loudness is below 1e-4, and a copy of
        ``chunk`` otherwise.

    Nothing guarantees there are no more chunks than loudness frames (the
    converted audio may come out longer than the source). In that case the
    last frame's loudness is reused instead of raising IndexError.
    """
    frame_vol = vols[min(i, len(vols) - 1)] / 2
    # Keep the `< threshold` form: a NaN loudness passes the chunk through.
    if frame_vol < 0.0001:
        gate = torch.zeros_like(chunk)
    else:
        gate = torch.ones_like(chunk)
    return torch.mul(gate, chunk).unsqueeze(0)
@app.get("/test")
def get_test(query: str = ""):
    """Echo the ``query`` URL parameter back to the caller (smoke-test endpoint).

    FastAPI fills ``query`` from the query string (``GET /test?query=...``)
    and uses ``""`` when it is absent. Nothing here can raise, so no error
    handling is needed.
    """
    return query
class VoiceModel(BaseModel):
    """JSON request body for ``POST /test``.

    ``gpu``, ``srcId``, ``dstId`` and ``timestamp`` are copied unchanged into
    the response by ``post_test``; only ``buffer`` is processed.
    """
    gpu: int  # echoed back; the handler does not use it to pick a device
    srcId: int  # echoed back; no speaker selection is applied in this file
    dstId: int  # echoed back; no speaker selection is applied in this file
    timestamp: int  # echoed back (presumably lets the client match replies to requests)
    buffer: str  # base64 of little-endian int16 PCM, processed as 24 kHz mono
def _decode_pcm16(buffer):
    """Decode a base64 string of little-endian 16-bit PCM into an int16 array.

    A trailing odd byte (an incomplete sample) is ignored rather than causing
    an error.
    """
    raw = base64.b64decode(buffer)
    return np.frombuffer(raw, dtype="<i2", count=len(raw) // 2).astype(np.int16)


def _encode_pcm16(samples):
    """Convert float samples (nominally in [-1, 1]) to int16 PCM.

    Samples are clipped to [-1, 1] before scaling by 32767, so peaks outside
    that range saturate instead of overflowing the int16 cast.
    """
    int_size = 2**(16 - 1) - 1
    return (np.clip(samples, -1.0, 1.0) * int_size).astype(np.int16)


@app.post("/test")
def post_test(voice: VoiceModel):
    """Convert one chunk of speech with Soft-VC and return it base64-encoded.

    The request ``buffer`` holds base64 little-endian int16 PCM at 24 kHz.
    The JSON response echoes gpu/srcId/dstId/timestamp and adds
    ``changedVoiceBase64``: the converted audio as int16 PCM at 24 kHz, with
    output muted wherever the source is quiet. On any exception the error
    message is printed and returned as a plain string.
    """
    try:
        print("POST REQUEST PROCESSING....")
        gpu = voice.gpu
        srcId = voice.srcId
        dstId = voice.dstId
        timestamp = voice.timestamp

        pcm = _decode_pcm16(voice.buffer)
        # Debug dump only. The audio is deliberately NOT read back from this
        # file: the path is shared by every request, so concurrent requests
        # could otherwise read each other's input.
        received_data_file = "received_data.wav"
        write(received_data_file, 24000, pcm)

        # Float32 in [-1, 1), shape (1, T): the same scaling that
        # torchaudio.load(normalize=True) applies to 16-bit PCM (divide by 2**15).
        source = torch.from_numpy(pcm.astype(np.float32) / 32768.0).unsqueeze(0)
        source_16k = torchaudio.functional.resample(source, 24000, 16000)
        source_16k = source_16k.unsqueeze(0).cuda()

        # SOFT-VC: units -> mel -> waveform at 16 kHz, then back to 24 kHz.
        with torch.inference_mode():
            units = hubert_model.units(source_16k)
            mel = acoustic_model.generate(units).transpose(1, 2)
            target = hifigan_model(mel)
            dest = torchaudio.functional.resample(target, 16000, 24000)
        dest = dest.squeeze().cpu()

        # Mute the output where the source is quiet. For each source mel frame,
        # take the loudest mel bin, then let applyVol gate the matching window
        # of output samples.
        specgram = torchaudio.transforms.MelSpectrogram(sample_rate=24000)(source)
        vol_apply_window_size = math.ceil(len(source[0]) / specgram.size()[2])
        vols = [torch.max(frame) for frame in specgram.transpose(1, 2)[0]]
        chunks = torch.split(dest, vol_apply_window_size, 0)
        dest = torch.cat([applyVol(i, c, vols) for i, c in enumerate(chunks)], 1)

        arr = _encode_pcm16(dest.squeeze().numpy())
        write("converted_data.wav", 24000, arr)  # debug dump of the output
        changedVoiceBase64 = base64.b64encode(arr).decode('utf-8')
        data = {
            "gpu": gpu,
            "srcId": srcId,
            "dstId": dstId,
            "timestamp": timestamp,
            "changedVoiceBase64": changedVoiceBase64,
        }
        json_compatible_item_data = jsonable_encoder(data)
        return JSONResponse(content=json_compatible_item_data)
    except Exception as e:
        print("REQUEST PROCESSING!!!! EXCEPTION!!!", e)
        print(traceback.format_exc())
        return str(e)
if __name__ == '__main__':
    # The CLI arguments were already parsed at import time; rebind them here
    # from the same argv positions.
    args = sys.argv
    PORT, MODE = args[1], args[2]
    # NOTE(review): the models were already loaded by the module-level code
    # above, so 'INITIALIZE MODEL' is logged after the fact.
    for message in ('INITIALIZE MODEL', 'START APP'):
        logger.info(message)
    # Serve this file's `app` via an import string "<module>:app", where
    # <module> is the file name minus its 3-character ".py" suffix.
    app_import_string = f"{os.path.basename(__file__)[:-3]}:app"
    uvicorn.run(
        app_import_string,
        host="0.0.0.0",
        port=int(PORT),
        reload=True,
        log_level="info",
    )