From 86e71d05a366f59f44f6b81b2c05bce7f6c4b96f Mon Sep 17 00:00:00 2001 From: wataru Date: Sat, 31 Dec 2022 16:02:53 +0900 Subject: [PATCH] separate sio --- .gitignore | 5 + server/MMVCServerSIO.py | 463 ++++++++++++++++++++ server/sio/MMVC_Namespace.py | 43 ++ server/voice_changer/VoiceChangerManager.py | 21 + 4 files changed, 532 insertions(+) create mode 100755 server/MMVCServerSIO.py create mode 100644 server/sio/MMVC_Namespace.py create mode 100644 server/voice_changer/VoiceChangerManager.py diff --git a/.gitignore b/.gitignore index 13682990..b02c05d6 100644 --- a/.gitignore +++ b/.gitignore @@ -1,2 +1,7 @@ dummy node_modules +__pycache__ + +server/upload_dir/ +server/MMVC_Trainer/ +server/key \ No newline at end of file diff --git a/server/MMVCServerSIO.py b/server/MMVCServerSIO.py new file mode 100755 index 00000000..519d9a5d --- /dev/null +++ b/server/MMVCServerSIO.py @@ -0,0 +1,463 @@ +import sys, os, struct, argparse, logging, shutil, base64, traceback +logging.getLogger('numba').setLevel(logging.WARNING) + +class UvicornSuppressFilter(logging.Filter): + def filter(self, record): + return False + +logger = logging.getLogger("uvicorn.error") +logger.addFilter(UvicornSuppressFilter()) +# logger.propagate = False +logger = logging.getLogger("multipart.multipart") +logger.propagate = False + + +from dataclasses import dataclass +from datetime import datetime +from distutils.util import strtobool + +import numpy as np +from scipy.io.wavfile import write, read + +sys.path.append("MMVC_Trainer") +sys.path.append("MMVC_Trainer/text") + +from fastapi.routing import APIRoute +from fastapi import HTTPException, Request, Response, FastAPI, UploadFile, File, Form +from fastapi.staticfiles import StaticFiles +from fastapi.encoders import jsonable_encoder +from fastapi.responses import JSONResponse +from fastapi.middleware.cors import CORSMiddleware +import uvicorn +import socketio +from pydantic import BaseModel + +from typing import Callable + +from mods.Trainer_Speakers import mod_get_speakers +from mods.Trainer_Training import mod_post_pre_training, mod_post_start_training, mod_post_stop_training, mod_get_related_files, mod_get_tail_training_log +from mods.Trainer_Model import mod_get_model, mod_delete_model + +from mods.Trainer_Models import mod_get_models +from mods.Trainer_MultiSpeakerSetting import mod_get_multi_speaker_setting, mod_post_multi_speaker_setting +from mods.Trainer_Speaker_Voice import mod_get_speaker_voice +from mods.Trainer_Speaker_Voices import mod_get_speaker_voices + +from mods.Trainer_Speaker import mod_delete_speaker +from mods.FileUploader import upload_file, concat_file_chunks + +from mods.VoiceChanger import VoiceChanger + +from mods.ssl import create_self_signed_cert + +from sio.MMVC_Namespace import MMVC_Namespace +from voice_changer.VoiceChangerManager import VoiceChangerManager +@dataclass +class ExApplicationInfo(): + external_tensorboard_port: int + + +exApplitionInfo = ExApplicationInfo(external_tensorboard_port=0) + + +class VoiceModel(BaseModel): + gpu: int + srcId: int + dstId: int + timestamp: int + prefixChunkSize: int + buffer: str + +def setupArgParser(): + parser = argparse.ArgumentParser() + parser.add_argument("-t", type=str, default="MMVC", + help="Server type. MMVC|TRAIN") + parser.add_argument("-p", type=int, default=8080, help="port") + parser.add_argument("-c", type=str, help="path for the config.json") + parser.add_argument("-m", type=str, help="path for the model file") + parser.add_argument("--https", type=strtobool, + default=False, help="use https") + parser.add_argument("--httpsKey", type=str, + default="ssl.key", help="path for the key of https") + parser.add_argument("--httpsCert", type=str, + default="ssl.cert", help="path for the cert of https") + parser.add_argument("--httpsSelfSigned", type=strtobool, + default=True, help="generate self-signed certificate") + parser.add_argument("--colab", type=strtobool, + default=False, help="run on colab") + return parser + + +def printMessage(message, level=0): + if level == 0: + print(f"\033[17m{message}\033[0m") + elif level == 1: + print(f"\033[34m {message}\033[0m") + elif level == 2: + print(f"\033[32m {message}\033[0m") + else: + print(f"\033[47m {message}\033[0m") + + + +global app_socketio +global app_fastapi + +parser = setupArgParser() +args = parser.parse_args() + +printMessage(f"Phase name:{__name__}", level=2) +thisFilename = os.path.basename(__file__)[:-3] + + +class ValidationErrorLoggingRoute(APIRoute): + def get_route_handler(self) -> Callable: + original_route_handler = super().get_route_handler() + + async def custom_route_handler(request: Request) -> Response: + try: + return await original_route_handler(request) + except Exception as exc: + print("Exception", request.url, str(exc)) + body = await request.body() + detail = {"errors": exc.errors(), "body": body.decode()} + raise HTTPException(status_code=422, detail=detail) + + return custom_route_handler + + +if __name__ == thisFilename or args.colab == True: + printMessage(f"PHASE3:{__name__}", level=2) + TYPE = args.t + PORT = args.p + CONFIG = args.c + MODEL = args.m + + if os.getenv("EX_TB_PORT"): + EX_TB_PORT = os.environ["EX_TB_PORT"] + exApplitionInfo.external_tensorboard_port = int(EX_TB_PORT) + + + app_fastapi = FastAPI() + app_fastapi.router.route_class = ValidationErrorLoggingRoute + app_fastapi.add_middleware( + CORSMiddleware, + allow_origins=["*"], + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], + ) + + app_fastapi.mount( + "/front", StaticFiles(directory="../frontend/dist", html=True), name="static") + + app_fastapi.mount( + "/trainer", StaticFiles(directory="../frontend/dist", html=True), name="static") + + app_fastapi.mount( + "/recorder", StaticFiles(directory="../frontend/dist", html=True), name="static") + + sio = socketio.AsyncServer( + async_mode='asgi', + cors_allowed_origins='*' + ) + voiceChangerManager = VoiceChangerManager.get_instance() + namespace = MMVC_Namespace.get_instance(voiceChangerManager) + sio.register_namespace(namespace) + if CONFIG and MODEL: + voiceChangerManager.loadModel(CONFIG, MODEL) + # namespace.loadWhisperModel("base") + + app_socketio = socketio.ASGIApp( + sio, + other_asgi_app=app_fastapi, + static_files={ + '/assets/icons/github.svg': { + 'filename': '../frontend/dist/assets/icons/github.svg', + 'content_type': 'image/svg+xml' + }, + '': '../frontend/dist', + '/': '../frontend/dist/index.html', + } + ) + + @app_fastapi.get("/api/hello") + async def index(): + return {"result": "Index"} + + ############ + # File Uploder + # ########## + UPLOAD_DIR = "upload_dir" + os.makedirs(UPLOAD_DIR, exist_ok=True) + MODEL_DIR = "MMVC_Trainer/logs" + os.makedirs(MODEL_DIR, exist_ok=True) + + @app_fastapi.post("/upload_file") + async def post_upload_file( + file: UploadFile = File(...), + filename: str = Form(...) + ): + return upload_file(UPLOAD_DIR, file, filename) + + @app_fastapi.post("/load_model") + async def post_load_model( + modelFilename: str = Form(...), + modelFilenameChunkNum: int = Form(...), + configFilename: str = Form(...) + ): + + modelFilePath = concat_file_chunks( + UPLOAD_DIR, modelFilename, modelFilenameChunkNum, UPLOAD_DIR) + print(f'File saved to: {modelFilePath}') + configFilePath = os.path.join(UPLOAD_DIR, configFilename) + + voiceChangerManager.loadModel(configFilePath, modelFilePath) + return {"load": f"{modelFilePath}, {configFilePath}"} + + @app_fastapi.post("/load_model_for_train") + async def post_load_model_for_train( + modelGFilename: str = Form(...), + modelGFilenameChunkNum: int = Form(...), + modelDFilename: str = Form(...), + modelDFilenameChunkNum: int = Form(...), + ): + + modelGFilePath = concat_file_chunks( + UPLOAD_DIR, modelGFilename, modelGFilenameChunkNum, MODEL_DIR) + modelDFilePath = concat_file_chunks( + UPLOAD_DIR, modelDFilename, modelDFilenameChunkNum, MODEL_DIR) + return {"File saved": f"{modelGFilePath}, {modelDFilePath}"} + + @app_fastapi.post("/extract_voices") + async def post_load_model( + zipFilename: str = Form(...), + zipFileChunkNum: int = Form(...), + ): + zipFilePath = concat_file_chunks( + UPLOAD_DIR, zipFilename, zipFileChunkNum, UPLOAD_DIR) + shutil.unpack_archive(zipFilePath, "MMVC_Trainer/dataset/textful/") + return {"Zip file unpacked": f"{zipFilePath}"} + + ############ + # Voice Changer + # ########## + + @app_fastapi.post("/test") + async def post_test(voice: VoiceModel): + try: + # print("POST REQUEST PROCESSING....") + gpu = voice.gpu + srcId = voice.srcId + dstId = voice.dstId + timestamp = voice.timestamp + prefixChunkSize = voice.prefixChunkSize + buffer = voice.buffer + wav = base64.b64decode(buffer) + + if wav == 0: + samplerate, data = read("dummy.wav") + unpackedData = data + else: + unpackedData = np.array(struct.unpack( + '<%sh' % (len(wav) // struct.calcsize(':/ with your browser.", level=0) + else: + printMessage( + f"open http://:/ with your browser.", level=0) + + if TYPE == "MMVC": + path = "" + else: + path = "trainer" + if "EX_PORT" in locals() and "EX_IP" in locals() and args.https == 1: + printMessage(f"In many cases it is one of the following", level=1) + printMessage(f"https://localhost:{EX_PORT}/{path}", level=1) + for ip in EX_IP.strip().split(" "): + printMessage(f"https://{ip}:{EX_PORT}/{path}", level=1) + elif "EX_PORT" in locals() and "EX_IP" in locals() and args.https == 0: + printMessage(f"In many cases it is one of the following", level=1) + printMessage(f"http://localhost:{EX_PORT}/{path}", level=1) + + # サーバ起動 + if args.https: + # HTTPS サーバ起動 + uvicorn.run( + f"{os.path.basename(__file__)[:-3]}:app_socketio", + host="0.0.0.0", + port=int(PORT), + reload=True, + ssl_keyfile=key_path, + ssl_certfile=cert_path, + log_level="critical" + ) + else: + # HTTP サーバ起動 + if args.colab == True: + uvicorn.run( + f"{os.path.basename(__file__)[:-3]}:app_fastapi", + host="0.0.0.0", + port=int(PORT), + log_level="critical" + ) + else: + uvicorn.run( + f"{os.path.basename(__file__)[:-3]}:app_socketio", + host="0.0.0.0", + port=int(PORT), + reload=True, + log_level="critical" + ) + diff --git a/server/sio/MMVC_Namespace.py b/server/sio/MMVC_Namespace.py new file mode 100644 index 00000000..7d999a04 --- /dev/null +++ b/server/sio/MMVC_Namespace.py @@ -0,0 +1,43 @@ +import struct +from datetime import datetime +import numpy as np +import socketio +from voice_changer.VoiceChangerManager import VoiceChangerManager + + +class MMVC_Namespace(socketio.AsyncNamespace): + def __init__(self, namespace:str, voiceChangerManager:VoiceChangerManager): + super().__init__(namespace) + self.voiceChangerManager = voiceChangerManager + + @classmethod + def get_instance(cls, voiceChangerManager:VoiceChangerManager): + if not hasattr(cls, "_instance"): + cls._instance = cls("/test", voiceChangerManager) + return cls._instance + + def on_connect(self, sid, environ): + # print('[{}] connet sid : {}'.format(datetime.now().strftime('%Y-%m-%d %H:%M:%S') , sid)) + pass + + async def on_request_message(self, sid, msg): + # print("on_request_message", torch.cuda.memory_allocated()) + gpu = int(msg[0]) + srcId = int(msg[1]) + dstId = int(msg[2]) + timestamp = int(msg[3]) + prefixChunkSize = int(msg[4]) + data = msg[5] + # print(srcId, dstId, timestamp) + unpackedData = np.array(struct.unpack( + '<%sh' % (len(data) // struct.calcsize('