import argparse
import base64
import logging
import os
import shutil
import struct
import sys
import traceback
from dataclasses import dataclass
from datetime import datetime
from typing import Callable

try:
    from distutils.util import strtobool
except ImportError:
    # distutils was removed from the standard library in Python 3.12
    # (PEP 632).  Provide a drop-in replacement with the same accepted
    # spellings, return values and error so the module still imports.
    def strtobool(val):
        """Convert a string representation of truth to 1 or 0.

        True values are 'y', 'yes', 't', 'true', 'on' and '1'; false values
        are 'n', 'no', 'f', 'false', 'off' and '0' (case-insensitive).

        Raises:
            ValueError: if ``val`` is none of the spellings above.
        """
        val = val.lower()
        if val in ("y", "yes", "t", "true", "on", "1"):
            return 1
        if val in ("n", "no", "f", "false", "off", "0"):
            return 0
        raise ValueError("invalid truth value %r" % (val,))

import numpy as np
from scipy.io.wavfile import write, read

# Extend the module search path before the remaining imports.
# NOTE(review): presumably required by the project-local ``mods`` imports
# below -- keep these two appends ahead of them.
sys.path.append("MMVC_Trainer")
sys.path.append("MMVC_Trainer/text")

from fastapi.routing import APIRoute
from fastapi import HTTPException, Request, Response, FastAPI, UploadFile, File, Form
from fastapi.staticfiles import StaticFiles
from fastapi.encoders import jsonable_encoder
from fastapi.responses import JSONResponse
from fastapi.middleware.cors import CORSMiddleware

import uvicorn
import socketio
from pydantic import BaseModel

# Trainer REST internals
from mods.Trainer_Speakers import mod_get_speakers
from mods.Trainer_Training import mod_post_pre_training, mod_post_start_training, mod_post_stop_training, mod_get_related_files, mod_get_tail_training_log
from mods.Trainer_Model import mod_get_model, mod_delete_model
from mods.Trainer_Models import mod_get_models
from mods.Trainer_MultiSpeakerSetting import mod_get_multi_speaker_setting, mod_post_multi_speaker_setting
from mods.Trainer_Speaker_Voice import mod_get_speaker_voice
from mods.Trainer_Speaker_Voices import mod_get_speaker_voices
from mods.Trainer_Speaker import mod_delete_speaker

# File uploader
from mods.FileUploader import upload_file, concat_file_chunks

from mods.VoiceChanger import VoiceChanger
from mods.ssl import create_self_signed_cert


class UvicornSuppressFilter(logging.Filter):
    """Logging filter that rejects every record it is asked about."""

    def filter(self, record):
        # A falsy return value tells the logging machinery to drop the record.
        return False


# Silence the "uvicorn.error" logger by attaching the drop-everything filter.
logger = logging.getLogger("uvicorn.error")
logger.addFilter(UvicornSuppressFilter())
# logger.propagate = False

# The module-level name ``logger`` is deliberately rebound here, exactly as in
# the original: it ends up referring to the "multipart.multipart" logger,
# whose records are no longer propagated to ancestor loggers.
logger = logging.getLogger("multipart.multipart")
logger.propagate = False


@dataclass
class ExApplicationInfo:
    """Settings describing externally exposed companion applications."""

    # Externally visible TensorBoard port; the module-level instance below
    # starts with 0.
    external_tensorboard_port: int


# NOTE: the name is misspelled ("Applition") but is referenced elsewhere in
# this file, so it is kept as-is.
exApplitionInfo = ExApplicationInfo(external_tensorboard_port=0)


class VoiceModel(BaseModel):
    """Request body of a voice-conversion call (parameter of ``post_test``)."""

    gpu: int
    srcId: int
    dstId: int
    timestamp: int
    prefixChunkSize: int
    # Audio payload as text; ``post_test`` passes it to base64.b64decode.
    buffer: str
# ---------------------------------------------------------------------------
# Socket.IO namespace, error-logging route handler, and server bootstrap.
#
# NOTE(review): text is missing from this span in two places -- both inside a
# "struct.calcsize('" literal (in on_request_message, and again in post_test).
# As a result the end of on_request_message, the header of the route class,
# the rest of post_test, and the definitions of args, thisFilename,
# printMessage, EX_PORT, EX_IP, key_path and cert_path are not visible here.
# Restore them from version control before editing this region.
#
# MyCustomNamespace (socketio.AsyncNamespace)
#   loadModel(config, model)  - calls destroy() on an existing
#                               self.voiceChanger (if any), then replaces it
#                               with VoiceChanger(config, model).
#   changeVoice(...)          - forwards to self.voiceChanger.on_request(...);
#                               when no model has been loaded it prints a
#                               warning and returns
#                               np.zeros(1).astype(np.int16).
#   on_connect(sid, environ)  - no-op.
#   on_request_message(sid, msg)
#                             - reads gpu, srcId, dstId, timestamp and
#                               prefixChunkSize from msg[0..4] via int(), and
#                               the raw payload from msg[5].
#
# custom_route_handler (installed via app_fastapi.router.route_class =
# ValidationErrorLoggingRoute)
#   Awaits the original handler; on any Exception it prints the request URL
#   and message, then raises HTTPException(status_code=422) whose detail
#   carries exc.errors() and the decoded request body.
#   NOTE(review): exc.errors() is called for every Exception caught here, but
#   a generic Exception has no errors() method -- confirm this cannot mask the
#   original error with an AttributeError.
#
# Bootstrap (runs when __name__ == thisFilename or args.colab == True)
#   * Takes TYPE / PORT / CONFIG / MODEL from args.t / .p / .c / .m.
#   * Copies the EX_TB_PORT environment variable (when set) into
#     exApplitionInfo.external_tensorboard_port as an int.
#   * Builds the FastAPI app with a wildcard CORS middleware and serves
#     ../frontend/dist under /front, /trainer and /recorder (all three mounts
#     are registered with name="static").
#   * Registers MyCustomNamespace('/test') on a socketio.AsyncServer, preloads
#     the model when both CONFIG and MODEL are given, and wraps everything in
#     socketio.ASGIApp as app_socketio.
#   * Creates upload_dir and MMVC_Trainer/logs, then declares the endpoints
#     GET /api/hello and POST /upload_file, /load_model,
#     /load_model_for_train, /extract_voices, /test.
#   NOTE(review): the /load_model and /extract_voices handlers are both named
#   post_load_model; each is registered by its decorator, but the second
#   definition rebinds the name.
#   NOTE(review): /extract_voices passes the uploaded archive straight to
#   shutil.unpack_archive; no validation of archive member paths is visible.
#   NOTE(review): in post_test, wav is the result of base64.b64decode() (a
#   bytes object), so "wav == 0" looks like it can never be true and the
#   dummy.wav branch appears unreachable -- confirm the intended check.
#
# Startup
#   args.https set -> uvicorn serves "<this module>:app_socketio" with
#                     ssl_keyfile/ssl_certfile and reload=True.
#   args.colab     -> uvicorn serves "<this module>:app_fastapi" (no reload).
#   otherwise      -> uvicorn serves "<this module>:app_socketio" with
#                     reload=True.
#   Every variant binds 0.0.0.0 on int(PORT) with log_level="critical".
# ---------------------------------------------------------------------------
class MyCustomNamespace(socketio.AsyncNamespace): def __init__(self, namespace): super().__init__(namespace) def loadModel(self, config, model): if hasattr(self, 'voiceChanger') == True: self.voiceChanger.destroy() self.voiceChanger = VoiceChanger(config, model) # def loadWhisperModel(self, model): # self.whisper = Whisper() # self.whisper.loadModel("tiny") # print("load") def changeVoice(self, gpu, srcId, dstId, timestamp, prefixChunkSize, unpackedData): # if hasattr(self, 'whisper') == True: # self.whisper.addData(unpackedData) if hasattr(self, 'voiceChanger') == True: return self.voiceChanger.on_request(gpu, srcId, dstId, timestamp, prefixChunkSize, unpackedData) else: print("Voice Change is not loaded. Did you load a correct model?") return np.zeros(1).astype(np.int16) # def transcribe(self): # if hasattr(self, 'whisper') == True: # self.whisper.transcribe(0) # else: # print("whisper not found") def on_connect(self, sid, environ): # print('[{}] connet sid : {}'.format(datetime.now().strftime('%Y-%m-%d %H:%M:%S') , sid)) pass async def on_request_message(self, sid, msg): # print("on_request_message", torch.cuda.memory_allocated()) gpu = int(msg[0]) srcId = int(msg[1]) dstId = int(msg[2]) timestamp = int(msg[3]) prefixChunkSize = int(msg[4]) data = msg[5] # print(srcId, dstId, timestamp) unpackedData = np.array(struct.unpack( '<%sh' % (len(data) // struct.calcsize(' Callable: original_route_handler = super().get_route_handler() async def custom_route_handler(request: Request) -> Response: try: return await original_route_handler(request) except Exception as exc: print("Exception", request.url, str(exc)) body = await request.body() detail = {"errors": exc.errors(), "body": body.decode()} raise HTTPException(status_code=422, detail=detail) return custom_route_handler if __name__ == thisFilename or args.colab == True: printMessage(f"PHASE3:{__name__}", level=2) TYPE = args.t PORT = args.p CONFIG = args.c MODEL = args.m if os.getenv("EX_TB_PORT"): EX_TB_PORT = 
os.environ["EX_TB_PORT"] exApplitionInfo.external_tensorboard_port = int(EX_TB_PORT) app_fastapi = FastAPI() app_fastapi.router.route_class = ValidationErrorLoggingRoute app_fastapi.add_middleware( CORSMiddleware, allow_origins=["*"], allow_credentials=True, allow_methods=["*"], allow_headers=["*"], ) app_fastapi.mount( "/front", StaticFiles(directory="../frontend/dist", html=True), name="static") app_fastapi.mount( "/trainer", StaticFiles(directory="../frontend/dist", html=True), name="static") app_fastapi.mount( "/recorder", StaticFiles(directory="../frontend/dist", html=True), name="static") sio = socketio.AsyncServer( async_mode='asgi', cors_allowed_origins='*' ) namespace = MyCustomNamespace('/test') sio.register_namespace(namespace) if CONFIG and MODEL: namespace.loadModel(CONFIG, MODEL) # namespace.loadWhisperModel("base") app_socketio = socketio.ASGIApp( sio, other_asgi_app=app_fastapi, static_files={ '/assets/icons/github.svg': { 'filename': '../frontend/dist/assets/icons/github.svg', 'content_type': 'image/svg+xml' }, '': '../frontend/dist', '/': '../frontend/dist/index.html', } ) @app_fastapi.get("/api/hello") async def index(): return {"result": "Index"} ############ # File Uploder # ########## UPLOAD_DIR = "upload_dir" os.makedirs(UPLOAD_DIR, exist_ok=True) MODEL_DIR = "MMVC_Trainer/logs" os.makedirs(MODEL_DIR, exist_ok=True) @app_fastapi.post("/upload_file") async def post_upload_file( file: UploadFile = File(...), filename: str = Form(...) ): return upload_file(UPLOAD_DIR, file, filename) @app_fastapi.post("/load_model") async def post_load_model( modelFilename: str = Form(...), modelFilenameChunkNum: int = Form(...), configFilename: str = Form(...) 
): modelFilePath = concat_file_chunks( UPLOAD_DIR, modelFilename, modelFilenameChunkNum, UPLOAD_DIR) print(f'File saved to: {modelFilePath}') configFilePath = os.path.join(UPLOAD_DIR, configFilename) namespace.loadModel(configFilePath, modelFilePath) return {"load": f"{modelFilePath}, {configFilePath}"} @app_fastapi.post("/load_model_for_train") async def post_load_model_for_train( modelGFilename: str = Form(...), modelGFilenameChunkNum: int = Form(...), modelDFilename: str = Form(...), modelDFilenameChunkNum: int = Form(...), ): modelGFilePath = concat_file_chunks( UPLOAD_DIR, modelGFilename, modelGFilenameChunkNum, MODEL_DIR) modelDFilePath = concat_file_chunks( UPLOAD_DIR, modelDFilename, modelDFilenameChunkNum, MODEL_DIR) return {"File saved": f"{modelGFilePath}, {modelDFilePath}"} @app_fastapi.post("/extract_voices") async def post_load_model( zipFilename: str = Form(...), zipFileChunkNum: int = Form(...), ): zipFilePath = concat_file_chunks( UPLOAD_DIR, zipFilename, zipFileChunkNum, UPLOAD_DIR) shutil.unpack_archive(zipFilePath, "MMVC_Trainer/dataset/textful/") return {"Zip file unpacked": f"{zipFilePath}"} ############ # Voice Changer # ########## @app_fastapi.post("/test") async def post_test(voice: VoiceModel): try: # print("POST REQUEST PROCESSING....") gpu = voice.gpu srcId = voice.srcId dstId = voice.dstId timestamp = voice.timestamp prefixChunkSize = voice.prefixChunkSize buffer = voice.buffer wav = base64.b64decode(buffer) if wav == 0: samplerate, data = read("dummy.wav") unpackedData = data else: unpackedData = np.array(struct.unpack( '<%sh' % (len(wav) // struct.calcsize(':/ with your browser.", level=0) else: printMessage( f"open http://:/ with your browser.", level=0) if TYPE == "MMVC": path = "" else: path = "trainer" if EX_PORT and EX_IP and args.https == 1: printMessage(f"In many cases it is one of the following", level=1) printMessage(f"https://localhost:{EX_PORT}/{path}", level=1) for ip in EX_IP.strip().split(" "): 
printMessage(f"https://{ip}:{EX_PORT}/{path}", level=1) elif EX_PORT and EX_IP and args.https == 0: printMessage(f"In many cases it is one of the following", level=1) printMessage(f"http://localhost:{EX_PORT}/{path}", level=1) # Start the server if args.https: # Start the HTTPS server uvicorn.run( f"{os.path.basename(__file__)[:-3]}:app_socketio", host="0.0.0.0", port=int(PORT), reload=True, ssl_keyfile=key_path, ssl_certfile=cert_path, log_level="critical" ) else: # Start the HTTP server if args.colab == True: uvicorn.run( f"{os.path.basename(__file__)[:-3]}:app_fastapi", host="0.0.0.0", port=int(PORT), log_level="critical" ) else: uvicorn.run( f"{os.path.basename(__file__)[:-3]}:app_socketio", host="0.0.0.0", port=int(PORT), reload=True, log_level="critical" )