mirror of
https://github.com/w-okada/voice-changer.git
synced 2025-01-23 13:35:12 +03:00
WIP: refactor, model switcher
This commit is contained in:
parent
cdb3234111
commit
7fa6de855d
@ -14,7 +14,7 @@ from mods.ssl import create_self_signed_cert
|
||||
from voice_changer.VoiceChangerManager import VoiceChangerManager
|
||||
from sio.MMVC_SocketIOApp import MMVC_SocketIOApp
|
||||
from restapi.MMVC_Rest import MMVC_Rest
|
||||
from const import NATIVE_CLIENT_FILE_MAC, NATIVE_CLIENT_FILE_WIN, SSL_KEY_DIR
|
||||
from const import NATIVE_CLIENT_FILE_MAC, NATIVE_CLIENT_FILE_WIN, SSL_KEY_DIR, setModelType
|
||||
import subprocess
|
||||
import multiprocessing as mp
|
||||
|
||||
@ -37,6 +37,8 @@ def setupArgParser():
|
||||
default=True, help="generate self-signed certificate")
|
||||
parser.add_argument("--colab", type=strtobool,
|
||||
default=False, help="run on colab")
|
||||
parser.add_argument("--modelType", type=str,
|
||||
default="MMVCv15", help="model type")
|
||||
|
||||
return parser
|
||||
|
||||
@ -62,12 +64,8 @@ def printMessage(message, level=0):
|
||||
else:
|
||||
print(f"\033[47m {message}\033[0m")
|
||||
|
||||
# global app_socketio
|
||||
# global app_fastapi
|
||||
|
||||
|
||||
parser = setupArgParser()
|
||||
# args = parser.parse_args()
|
||||
args, unknown = parser.parse_known_args()
|
||||
|
||||
# printMessage(f"Phase name:{__name__}", level=2)
|
||||
@ -75,11 +73,14 @@ args, unknown = parser.parse_known_args()
|
||||
|
||||
# if __name__ == thisFilename or args.colab == True:
|
||||
# printMessage(f"PHASE3:{__name__}", level=2)
|
||||
|
||||
TYPE = args.t
|
||||
PORT = args.p
|
||||
CONFIG = args.c
|
||||
MODEL = args.m if args.m != None else None
|
||||
ONNX_MODEL = args.o if args.o != None else None
|
||||
MODEL_TYPE = args.modelType
|
||||
setModelType(MODEL_TYPE)
|
||||
|
||||
|
||||
def localServer():
|
||||
|
@ -2,15 +2,6 @@ import os
|
||||
import sys
|
||||
import tempfile
|
||||
|
||||
# MODEL_TYPE = "MMVCv13"
|
||||
MODEL_TYPE = "MMVCv15"
|
||||
# MODEL_TYPE = "sovt-vits-svc"
|
||||
|
||||
if MODEL_TYPE == "MMVCv15":
|
||||
frontend_path = os.path.join(sys._MEIPASS, "dist") if hasattr(sys, "_MEIPASS") else "../client/demo_v15/dist"
|
||||
elif MODEL_TYPE == "MMVCv13":
|
||||
frontend_path = os.path.join(sys._MEIPASS, "dist") if hasattr(sys, "_MEIPASS") else "../client/demo_v13/dist"
|
||||
|
||||
ERROR_NO_ONNX_SESSION = "ERROR_NO_ONNX_SESSION"
|
||||
|
||||
|
||||
@ -30,3 +21,24 @@ os.makedirs(TMP_DIR, exist_ok=True)
|
||||
# SSL_KEY_DIR = os.path.join(sys._MEIPASS, "keys") if hasattr(sys, "_MEIPASS") else "keys"
|
||||
# MODEL_DIR = os.path.join(sys._MEIPASS, "logs") if hasattr(sys, "_MEIPASS") else "logs"
|
||||
# UPLOAD_DIR = os.path.join(sys._MEIPASS, "upload_dir") if hasattr(sys, "_MEIPASS") else "upload_dir"
|
||||
|
||||
|
||||
modelType = "MMVCv15"
|
||||
|
||||
|
||||
def getModelType():
|
||||
return modelType
|
||||
|
||||
|
||||
def setModelType(_modelType: str):
|
||||
global modelType
|
||||
modelType = _modelType
|
||||
|
||||
|
||||
def getFrontendPath():
|
||||
if modelType == "MMVCv15":
|
||||
frontend_path = os.path.join(sys._MEIPASS, "dist") if hasattr(sys, "_MEIPASS") else "../client/demo_v15/dist"
|
||||
elif modelType == "MMVCv13":
|
||||
frontend_path = os.path.join(sys._MEIPASS, "dist") if hasattr(sys, "_MEIPASS") else "../client/demo_v13/dist"
|
||||
|
||||
return frontend_path
|
||||
|
@ -9,7 +9,7 @@ from restapi.MMVC_Rest_Hello import MMVC_Rest_Hello
|
||||
from restapi.MMVC_Rest_VoiceChanger import MMVC_Rest_VoiceChanger
|
||||
from restapi.MMVC_Rest_Fileuploader import MMVC_Rest_Fileuploader
|
||||
from restapi.MMVC_Rest_Trainer import MMVC_Rest_Trainer
|
||||
from const import frontend_path, TMP_DIR
|
||||
from const import getFrontendPath, TMP_DIR
|
||||
|
||||
|
||||
class ValidationErrorLoggingRoute(APIRoute):
|
||||
@ -44,13 +44,13 @@ class MMVC_Rest:
|
||||
)
|
||||
|
||||
app_fastapi.mount(
|
||||
"/front", StaticFiles(directory=f'{frontend_path}', html=True), name="static")
|
||||
"/front", StaticFiles(directory=f'{getFrontendPath()}', html=True), name="static")
|
||||
|
||||
app_fastapi.mount(
|
||||
"/trainer", StaticFiles(directory=f'{frontend_path}', html=True), name="static")
|
||||
"/trainer", StaticFiles(directory=f'{getFrontendPath()}', html=True), name="static")
|
||||
|
||||
app_fastapi.mount(
|
||||
"/recorder", StaticFiles(directory=f'{frontend_path}', html=True), name="static")
|
||||
"/recorder", StaticFiles(directory=f'{getFrontendPath()}', html=True), name="static")
|
||||
app_fastapi.mount(
|
||||
"/tmp", StaticFiles(directory=f'{TMP_DIR}'), name="static")
|
||||
|
||||
|
@ -2,7 +2,7 @@ import socketio
|
||||
|
||||
from sio.MMVC_SocketIOServer import MMVC_SocketIOServer
|
||||
from voice_changer.VoiceChangerManager import VoiceChangerManager
|
||||
from const import frontend_path
|
||||
from const import getFrontendPath
|
||||
|
||||
|
||||
class MMVC_SocketIOApp():
|
||||
@ -15,19 +15,19 @@ class MMVC_SocketIOApp():
|
||||
other_asgi_app=app_fastapi,
|
||||
static_files={
|
||||
'/assets/icons/github.svg': {
|
||||
'filename': f'{frontend_path}/assets/icons/github.svg',
|
||||
'filename': f'{getFrontendPath()}/assets/icons/github.svg',
|
||||
'content_type': 'image/svg+xml'
|
||||
},
|
||||
'/assets/icons/help-circle.svg': {
|
||||
'filename': f'{frontend_path}/assets/icons/help-circle.svg',
|
||||
'filename': f'{getFrontendPath()}/assets/icons/help-circle.svg',
|
||||
'content_type': 'image/svg+xml'
|
||||
},
|
||||
'/buymeacoffee.png': {
|
||||
'filename': f'{frontend_path}/assets/buymeacoffee.png',
|
||||
'filename': f'{getFrontendPath()}/assets/buymeacoffee.png',
|
||||
'content_type': 'image/png'
|
||||
},
|
||||
'': f'{frontend_path}',
|
||||
'/': f'{frontend_path}/index.html',
|
||||
'': f'{getFrontendPath()}',
|
||||
'/': f'{getFrontendPath()}/index.html',
|
||||
}
|
||||
)
|
||||
|
||||
|
@ -1,4 +1,4 @@
|
||||
from const import TMP_DIR, MODEL_TYPE
|
||||
from const import TMP_DIR, getModelType
|
||||
import torch
|
||||
import os
|
||||
import traceback
|
||||
@ -7,11 +7,6 @@ from dataclasses import dataclass, asdict
|
||||
import resampy
|
||||
|
||||
|
||||
if MODEL_TYPE == "MMVCv15":
|
||||
from voice_changer.MMVCv15.MMVCv15 import MMVCv15
|
||||
else:
|
||||
from voice_changer.MMVCv13.MMVCv13 import MMVCv13
|
||||
|
||||
from voice_changer.IORecorder import IORecorder
|
||||
from voice_changer.IOAnalyzer import IOAnalyzer
|
||||
|
||||
@ -53,9 +48,13 @@ class VoiceChanger():
|
||||
self.currentCrossFadeEndRate = 0
|
||||
self.currentCrossFadeOverlapSize = 0
|
||||
|
||||
if MODEL_TYPE == "MMVCv15":
|
||||
modelType = getModelType()
|
||||
print("[VoiceChanger] activate model type:", modelType)
|
||||
if modelType == "MMVCv15":
|
||||
from voice_changer.MMVCv15.MMVCv15 import MMVCv15
|
||||
self.voiceChanger = MMVCv15()
|
||||
else:
|
||||
from voice_changer.MMVCv13.MMVCv13 import MMVCv13
|
||||
self.voiceChanger = MMVCv13()
|
||||
|
||||
self.gpu_num = torch.cuda.device_count()
|
||||
|
Loading…
Reference in New Issue
Block a user