WIP: refactor, model switcher

This commit is contained in:
wataru 2023-03-08 09:48:50 +09:00
parent cdb3234111
commit 7fa6de855d
5 changed files with 43 additions and 31 deletions

View File

@ -14,7 +14,7 @@ from mods.ssl import create_self_signed_cert
from voice_changer.VoiceChangerManager import VoiceChangerManager from voice_changer.VoiceChangerManager import VoiceChangerManager
from sio.MMVC_SocketIOApp import MMVC_SocketIOApp from sio.MMVC_SocketIOApp import MMVC_SocketIOApp
from restapi.MMVC_Rest import MMVC_Rest from restapi.MMVC_Rest import MMVC_Rest
from const import NATIVE_CLIENT_FILE_MAC, NATIVE_CLIENT_FILE_WIN, SSL_KEY_DIR from const import NATIVE_CLIENT_FILE_MAC, NATIVE_CLIENT_FILE_WIN, SSL_KEY_DIR, setModelType
import subprocess import subprocess
import multiprocessing as mp import multiprocessing as mp
@ -37,6 +37,8 @@ def setupArgParser():
default=True, help="generate self-signed certificate") default=True, help="generate self-signed certificate")
parser.add_argument("--colab", type=strtobool, parser.add_argument("--colab", type=strtobool,
default=False, help="run on colab") default=False, help="run on colab")
parser.add_argument("--modelType", type=str,
default="MMVCv15", help="model type")
return parser return parser
@ -62,12 +64,8 @@ def printMessage(message, level=0):
else: else:
print(f"\033[47m {message}\033[0m") print(f"\033[47m {message}\033[0m")
# global app_socketio
# global app_fastapi
parser = setupArgParser() parser = setupArgParser()
# args = parser.parse_args()
args, unknown = parser.parse_known_args() args, unknown = parser.parse_known_args()
# printMessage(f"Phase name:{__name__}", level=2) # printMessage(f"Phase name:{__name__}", level=2)
@ -75,11 +73,14 @@ args, unknown = parser.parse_known_args()
# if __name__ == thisFilename or args.colab == True: # if __name__ == thisFilename or args.colab == True:
# printMessage(f"PHASE3:{__name__}", level=2) # printMessage(f"PHASE3:{__name__}", level=2)
TYPE = args.t TYPE = args.t
PORT = args.p PORT = args.p
CONFIG = args.c CONFIG = args.c
MODEL = args.m if args.m != None else None MODEL = args.m if args.m != None else None
ONNX_MODEL = args.o if args.o != None else None ONNX_MODEL = args.o if args.o != None else None
MODEL_TYPE = args.modelType
setModelType(MODEL_TYPE)
def localServer(): def localServer():

View File

@ -2,15 +2,6 @@ import os
import sys import sys
import tempfile import tempfile
# MODEL_TYPE = "MMVCv13"
MODEL_TYPE = "MMVCv15"
# MODEL_TYPE = "sovt-vits-svc"
if MODEL_TYPE == "MMVCv15":
frontend_path = os.path.join(sys._MEIPASS, "dist") if hasattr(sys, "_MEIPASS") else "../client/demo_v15/dist"
elif MODEL_TYPE == "MMVCv13":
frontend_path = os.path.join(sys._MEIPASS, "dist") if hasattr(sys, "_MEIPASS") else "../client/demo_v13/dist"
ERROR_NO_ONNX_SESSION = "ERROR_NO_ONNX_SESSION" ERROR_NO_ONNX_SESSION = "ERROR_NO_ONNX_SESSION"
@ -30,3 +21,24 @@ os.makedirs(TMP_DIR, exist_ok=True)
# SSL_KEY_DIR = os.path.join(sys._MEIPASS, "keys") if hasattr(sys, "_MEIPASS") else "keys" # SSL_KEY_DIR = os.path.join(sys._MEIPASS, "keys") if hasattr(sys, "_MEIPASS") else "keys"
# MODEL_DIR = os.path.join(sys._MEIPASS, "logs") if hasattr(sys, "_MEIPASS") else "logs" # MODEL_DIR = os.path.join(sys._MEIPASS, "logs") if hasattr(sys, "_MEIPASS") else "logs"
# UPLOAD_DIR = os.path.join(sys._MEIPASS, "upload_dir") if hasattr(sys, "_MEIPASS") else "upload_dir" # UPLOAD_DIR = os.path.join(sys._MEIPASS, "upload_dir") if hasattr(sys, "_MEIPASS") else "upload_dir"
modelType = "MMVCv15"
def getModelType():
return modelType
def setModelType(_modelType: str):
global modelType
modelType = _modelType
def getFrontendPath():
if modelType == "MMVCv15":
frontend_path = os.path.join(sys._MEIPASS, "dist") if hasattr(sys, "_MEIPASS") else "../client/demo_v15/dist"
elif modelType == "MMVCv13":
frontend_path = os.path.join(sys._MEIPASS, "dist") if hasattr(sys, "_MEIPASS") else "../client/demo_v13/dist"
return frontend_path

View File

@ -9,7 +9,7 @@ from restapi.MMVC_Rest_Hello import MMVC_Rest_Hello
from restapi.MMVC_Rest_VoiceChanger import MMVC_Rest_VoiceChanger from restapi.MMVC_Rest_VoiceChanger import MMVC_Rest_VoiceChanger
from restapi.MMVC_Rest_Fileuploader import MMVC_Rest_Fileuploader from restapi.MMVC_Rest_Fileuploader import MMVC_Rest_Fileuploader
from restapi.MMVC_Rest_Trainer import MMVC_Rest_Trainer from restapi.MMVC_Rest_Trainer import MMVC_Rest_Trainer
from const import frontend_path, TMP_DIR from const import getFrontendPath, TMP_DIR
class ValidationErrorLoggingRoute(APIRoute): class ValidationErrorLoggingRoute(APIRoute):
@ -44,13 +44,13 @@ class MMVC_Rest:
) )
app_fastapi.mount( app_fastapi.mount(
"/front", StaticFiles(directory=f'{frontend_path}', html=True), name="static") "/front", StaticFiles(directory=f'{getFrontendPath()}', html=True), name="static")
app_fastapi.mount( app_fastapi.mount(
"/trainer", StaticFiles(directory=f'{frontend_path}', html=True), name="static") "/trainer", StaticFiles(directory=f'{getFrontendPath()}', html=True), name="static")
app_fastapi.mount( app_fastapi.mount(
"/recorder", StaticFiles(directory=f'{frontend_path}', html=True), name="static") "/recorder", StaticFiles(directory=f'{getFrontendPath()}', html=True), name="static")
app_fastapi.mount( app_fastapi.mount(
"/tmp", StaticFiles(directory=f'{TMP_DIR}'), name="static") "/tmp", StaticFiles(directory=f'{TMP_DIR}'), name="static")

View File

@ -2,7 +2,7 @@ import socketio
from sio.MMVC_SocketIOServer import MMVC_SocketIOServer from sio.MMVC_SocketIOServer import MMVC_SocketIOServer
from voice_changer.VoiceChangerManager import VoiceChangerManager from voice_changer.VoiceChangerManager import VoiceChangerManager
from const import frontend_path from const import getFrontendPath
class MMVC_SocketIOApp(): class MMVC_SocketIOApp():
@ -15,19 +15,19 @@ class MMVC_SocketIOApp():
other_asgi_app=app_fastapi, other_asgi_app=app_fastapi,
static_files={ static_files={
'/assets/icons/github.svg': { '/assets/icons/github.svg': {
'filename': f'{frontend_path}/assets/icons/github.svg', 'filename': f'{getFrontendPath()}/assets/icons/github.svg',
'content_type': 'image/svg+xml' 'content_type': 'image/svg+xml'
}, },
'/assets/icons/help-circle.svg': { '/assets/icons/help-circle.svg': {
'filename': f'{frontend_path}/assets/icons/help-circle.svg', 'filename': f'{getFrontendPath()}/assets/icons/help-circle.svg',
'content_type': 'image/svg+xml' 'content_type': 'image/svg+xml'
}, },
'/buymeacoffee.png': { '/buymeacoffee.png': {
'filename': f'{frontend_path}/assets/buymeacoffee.png', 'filename': f'{getFrontendPath()}/assets/buymeacoffee.png',
'content_type': 'image/png' 'content_type': 'image/png'
}, },
'': f'{frontend_path}', '': f'{getFrontendPath()}',
'/': f'{frontend_path}/index.html', '/': f'{getFrontendPath()}/index.html',
} }
) )

View File

@ -1,4 +1,4 @@
from const import TMP_DIR, MODEL_TYPE from const import TMP_DIR, getModelType
import torch import torch
import os import os
import traceback import traceback
@ -7,11 +7,6 @@ from dataclasses import dataclass, asdict
import resampy import resampy
if MODEL_TYPE == "MMVCv15":
from voice_changer.MMVCv15.MMVCv15 import MMVCv15
else:
from voice_changer.MMVCv13.MMVCv13 import MMVCv13
from voice_changer.IORecorder import IORecorder from voice_changer.IORecorder import IORecorder
from voice_changer.IOAnalyzer import IOAnalyzer from voice_changer.IOAnalyzer import IOAnalyzer
@ -53,9 +48,13 @@ class VoiceChanger():
self.currentCrossFadeEndRate = 0 self.currentCrossFadeEndRate = 0
self.currentCrossFadeOverlapSize = 0 self.currentCrossFadeOverlapSize = 0
if MODEL_TYPE == "MMVCv15": modelType = getModelType()
print("[VoiceChanger] activate model type:", modelType)
if modelType == "MMVCv15":
from voice_changer.MMVCv15.MMVCv15 import MMVCv15
self.voiceChanger = MMVCv15() self.voiceChanger = MMVCv15()
else: else:
from voice_changer.MMVCv13.MMVCv13 import MMVCv13
self.voiceChanger = MMVCv13() self.voiceChanger = MMVCv13()
self.gpu_num = torch.cuda.device_count() self.gpu_num = torch.cuda.device_count()