import uvicorn from fastapi import FastAPI from fastapi.middleware.cors import CORSMiddleware from pydantic import BaseModel from fastapi.encoders import jsonable_encoder from fastapi.responses import JSONResponse from fastapi.staticfiles import StaticFiles import logging import os, sys, base64, traceback, struct import torch import numpy as np from scipy.io.wavfile import write, read sys.path.append("mod") sys.path.append("mod/text") import utils from data_utils import TextAudioSpeakerLoader, TextAudioSpeakerCollate from models import SynthesizerTrn from text.symbols import symbols class VoiceChanger(): def __init__(self, config, model): self.hps =utils.get_hparams_from_file(config) self.net_g = SynthesizerTrn( len(symbols), self.hps.data.filter_length // 2 + 1, self.hps.train.segment_size // self.hps.data.hop_length, n_speakers=self.hps.data.n_speakers, **self.hps.model) self.net_g.eval() self.gpu_num = torch.cuda.device_count() print("GPU_NUM:",self.gpu_num) utils.load_checkpoint( model, self.net_g, None) def on_request(self, gpu, srcId, dstId, timestamp, wav): if wav==0: samplerate, data=read("dummy.wav") unpackedData = data else: unpackedData = np.array(struct.unpack('<%sh'%(len(wav) // struct.calcsize('<h') ), wav)) write("logs/received_data.wav", 24000, unpackedData.astype(np.int16)) try: if gpu<0 or self.gpu_num==0 : with torch.no_grad(): dataset = TextAudioSpeakerLoader("dummy.txt", self.hps.data, no_use_textfile=True) data = dataset.get_audio_text_speaker_pair([ unpackedData, srcId, "a"]) data = TextAudioSpeakerCollate()([data]) x, x_lengths, spec, spec_lengths, y, y_lengths, sid_src = [x.cpu() for x in data] sid_tgt1 = torch.LongTensor([dstId]).cpu() audio1 = (self.net_g.cpu().voice_conversion(spec, spec_lengths, sid_src=sid_src, sid_tgt=sid_tgt1)[0][0,0].data * self.hps.data.max_wav_value).cpu().float().numpy() else: with torch.no_grad(): dataset = TextAudioSpeakerLoader("dummy.txt", self.hps.data, no_use_textfile=True) data = dataset.get_audio_text_speaker_pair([ unpackedData, srcId, "a"]) data = TextAudioSpeakerCollate()([data]) x, x_lengths, spec, spec_lengths, y, y_lengths, sid_src = [x.cuda(gpu) for x in data] sid_tgt1 = torch.LongTensor([dstId]).cuda(gpu) audio1 = (self.net_g.cuda(gpu).voice_conversion(spec, spec_lengths, sid_src=sid_src, sid_tgt=sid_tgt1)[0][0,0].data * self.hps.data.max_wav_value).cpu().float().numpy() except Exception as e: print("VC PROCESSING!!!! EXCEPTION!!!", e) print(traceback.format_exc()) audio1 = audio1.astype(np.int16) return audio1 logger = logging.getLogger('uvicorn') args = sys.argv PORT = args[1] CONFIG = args[2] MODEL = args[3] logger.info('INITIALIZE MODEL') voiceChanger = VoiceChanger(CONFIG, MODEL) voiceChanger.on_request(0,0,0,0,0) app = FastAPI() app.add_middleware( CORSMiddleware, allow_origins=["*"], allow_credentials=True, allow_methods=["*"], allow_headers=["*"], ) app.mount("/front", StaticFiles(directory="../frontend/dist", html=True), name="static") @app.get("/test") def get_test(): try: return request.args.get('query', '') except Exception as e: print("REQUEST PROCESSING!!!! EXCEPTION!!!", e) print(traceback.format_exc()) return str(e) class VoiceModel(BaseModel): gpu: int srcId: int dstId: int timestamp: int buffer: str @app.post("/test") def post_test(voice:VoiceModel): global voiceChanger try: print("POST REQUEST PROCESSING....") gpu = voice.gpu srcId = voice.srcId dstId = voice.dstId timestamp = voice.timestamp buffer = voice.buffer wav = base64.b64decode(buffer) changedVoice = voiceChanger.on_request(gpu, srcId, dstId, timestamp, wav) changedVoiceBase64 = base64.b64encode(changedVoice).decode('utf-8') data = { "gpu":gpu, "srcId":srcId, "dstId":dstId, "timestamp":timestamp, "changedVoiceBase64":changedVoiceBase64 } json_compatible_item_data = jsonable_encoder(data) return JSONResponse(content=json_compatible_item_data) except Exception as e: print("REQUEST PROCESSING!!!! EXCEPTION!!!", e) print(traceback.format_exc()) return str(e) if __name__ == '__main__': logger.info('START APP') uvicorn.run(f"{os.path.basename(__file__)[:-3]}:app", host="0.0.0.0", port=int(PORT), reload=True, log_level="info")