From 7fa6de855df1986b72e7c31e30319878441c7448 Mon Sep 17 00:00:00 2001 From: wataru Date: Wed, 8 Mar 2023 09:48:50 +0900 Subject: [PATCH] WIP: refactor, model switcher --- server/MMVCServerSIO.py | 11 +++++----- server/const.py | 30 +++++++++++++++++++--------- server/restapi/MMVC_Rest.py | 8 ++++---- server/sio/MMVC_SocketIOApp.py | 12 +++++------ server/voice_changer/VoiceChanger.py | 13 ++++++------ 5 files changed, 43 insertions(+), 31 deletions(-) diff --git a/server/MMVCServerSIO.py b/server/MMVCServerSIO.py index e389cf51..510d7e91 100755 --- a/server/MMVCServerSIO.py +++ b/server/MMVCServerSIO.py @@ -14,7 +14,7 @@ from mods.ssl import create_self_signed_cert from voice_changer.VoiceChangerManager import VoiceChangerManager from sio.MMVC_SocketIOApp import MMVC_SocketIOApp from restapi.MMVC_Rest import MMVC_Rest -from const import NATIVE_CLIENT_FILE_MAC, NATIVE_CLIENT_FILE_WIN, SSL_KEY_DIR +from const import NATIVE_CLIENT_FILE_MAC, NATIVE_CLIENT_FILE_WIN, SSL_KEY_DIR, setModelType import subprocess import multiprocessing as mp @@ -37,6 +37,8 @@ def setupArgParser(): default=True, help="generate self-signed certificate") parser.add_argument("--colab", type=strtobool, default=False, help="run on colab") + parser.add_argument("--modelType", type=str, + default="MMVCv15", help="model type") return parser @@ -62,12 +64,8 @@ def printMessage(message, level=0): else: print(f"\033[47m {message}\033[0m") -# global app_socketio -# global app_fastapi - parser = setupArgParser() -# args = parser.parse_args() args, unknown = parser.parse_known_args() # printMessage(f"Phase name:{__name__}", level=2) @@ -75,11 +73,14 @@ args, unknown = parser.parse_known_args() # if __name__ == thisFilename or args.colab == True: # printMessage(f"PHASE3:{__name__}", level=2) + TYPE = args.t PORT = args.p CONFIG = args.c MODEL = args.m if args.m != None else None ONNX_MODEL = args.o if args.o != None else None +MODEL_TYPE = args.modelType +setModelType(MODEL_TYPE) def localServer(): diff --git a/server/const.py b/server/const.py index 076439b2..d6ba7ed8 100644 --- a/server/const.py +++ b/server/const.py @@ -2,15 +2,6 @@ import os import sys import tempfile -# MODEL_TYPE = "MMVCv13" -MODEL_TYPE = "MMVCv15" -# MODEL_TYPE = "sovt-vits-svc" - -if MODEL_TYPE == "MMVCv15": - frontend_path = os.path.join(sys._MEIPASS, "dist") if hasattr(sys, "_MEIPASS") else "../client/demo_v15/dist" -elif MODEL_TYPE == "MMVCv13": - frontend_path = os.path.join(sys._MEIPASS, "dist") if hasattr(sys, "_MEIPASS") else "../client/demo_v13/dist" - ERROR_NO_ONNX_SESSION = "ERROR_NO_ONNX_SESSION" @@ -30,3 +21,24 @@ os.makedirs(TMP_DIR, exist_ok=True) # SSL_KEY_DIR = os.path.join(sys._MEIPASS, "keys") if hasattr(sys, "_MEIPASS") else "keys" # MODEL_DIR = os.path.join(sys._MEIPASS, "logs") if hasattr(sys, "_MEIPASS") else "logs" # UPLOAD_DIR = os.path.join(sys._MEIPASS, "upload_dir") if hasattr(sys, "_MEIPASS") else "upload_dir" + + +modelType = "MMVCv15" + + +def getModelType(): + return modelType + + +def setModelType(_modelType: str): + global modelType + modelType = _modelType + + +def getFrontendPath(): + if modelType == "MMVCv15": + frontend_path = os.path.join(sys._MEIPASS, "dist") if hasattr(sys, "_MEIPASS") else "../client/demo_v15/dist" + elif modelType == "MMVCv13": + frontend_path = os.path.join(sys._MEIPASS, "dist") if hasattr(sys, "_MEIPASS") else "../client/demo_v13/dist" + + return frontend_path diff --git a/server/restapi/MMVC_Rest.py b/server/restapi/MMVC_Rest.py index e5014c4d..a109247e 100644 --- a/server/restapi/MMVC_Rest.py +++ b/server/restapi/MMVC_Rest.py @@ -9,7 +9,7 @@ from restapi.MMVC_Rest_Hello import MMVC_Rest_Hello from restapi.MMVC_Rest_VoiceChanger import MMVC_Rest_VoiceChanger from restapi.MMVC_Rest_Fileuploader import MMVC_Rest_Fileuploader from restapi.MMVC_Rest_Trainer import MMVC_Rest_Trainer -from const import frontend_path, TMP_DIR +from const import getFrontendPath, TMP_DIR class ValidationErrorLoggingRoute(APIRoute): @@ -44,13 +44,13 @@ class MMVC_Rest: ) app_fastapi.mount( - "/front", StaticFiles(directory=f'{frontend_path}', html=True), name="static") + "/front", StaticFiles(directory=f'{getFrontendPath()}', html=True), name="static") app_fastapi.mount( - "/trainer", StaticFiles(directory=f'{frontend_path}', html=True), name="static") + "/trainer", StaticFiles(directory=f'{getFrontendPath()}', html=True), name="static") app_fastapi.mount( - "/recorder", StaticFiles(directory=f'{frontend_path}', html=True), name="static") + "/recorder", StaticFiles(directory=f'{getFrontendPath()}', html=True), name="static") app_fastapi.mount( "/tmp", StaticFiles(directory=f'{TMP_DIR}'), name="static") diff --git a/server/sio/MMVC_SocketIOApp.py b/server/sio/MMVC_SocketIOApp.py index fdc1e86a..cc8ec449 100644 --- a/server/sio/MMVC_SocketIOApp.py +++ b/server/sio/MMVC_SocketIOApp.py @@ -2,7 +2,7 @@ import socketio from sio.MMVC_SocketIOServer import MMVC_SocketIOServer from voice_changer.VoiceChangerManager import VoiceChangerManager -from const import frontend_path +from const import getFrontendPath class MMVC_SocketIOApp(): @@ -15,19 +15,19 @@ class MMVC_SocketIOApp(): other_asgi_app=app_fastapi, static_files={ '/assets/icons/github.svg': { - 'filename': f'{frontend_path}/assets/icons/github.svg', + 'filename': f'{getFrontendPath()}/assets/icons/github.svg', 'content_type': 'image/svg+xml' }, '/assets/icons/help-circle.svg': { - 'filename': f'{frontend_path}/assets/icons/help-circle.svg', + 'filename': f'{getFrontendPath()}/assets/icons/help-circle.svg', 'content_type': 'image/svg+xml' }, '/buymeacoffee.png': { - 'filename': f'{frontend_path}/assets/buymeacoffee.png', + 'filename': f'{getFrontendPath()}/assets/buymeacoffee.png', 'content_type': 'image/png' }, - '': f'{frontend_path}', - '/': f'{frontend_path}/index.html', + '': f'{getFrontendPath()}', + '/': f'{getFrontendPath()}/index.html', } ) diff --git a/server/voice_changer/VoiceChanger.py b/server/voice_changer/VoiceChanger.py index 5d46c70b..9f1a16ad 100755 --- a/server/voice_changer/VoiceChanger.py +++ b/server/voice_changer/VoiceChanger.py @@ -1,4 +1,4 @@ -from const import TMP_DIR, MODEL_TYPE +from const import TMP_DIR, getModelType import torch import os import traceback @@ -7,11 +7,6 @@ from dataclasses import dataclass, asdict import resampy -if MODEL_TYPE == "MMVCv15": - from voice_changer.MMVCv15.MMVCv15 import MMVCv15 -else: - from voice_changer.MMVCv13.MMVCv13 import MMVCv13 - from voice_changer.IORecorder import IORecorder from voice_changer.IOAnalyzer import IOAnalyzer @@ -53,9 +48,13 @@ class VoiceChanger(): self.currentCrossFadeEndRate = 0 self.currentCrossFadeOverlapSize = 0 - if MODEL_TYPE == "MMVCv15": + modelType = getModelType() + print("[VoiceChanger] activate model type:", modelType) + if modelType == "MMVCv15": + from voice_changer.MMVCv15.MMVCv15 import MMVCv15 self.voiceChanger = MMVCv15() else: + from voice_changer.MMVCv13.MMVCv13 import MMVCv13 self.voiceChanger = MMVCv13() self.gpu_num = torch.cuda.device_count()