mirror of
https://github.com/w-okada/voice-changer.git
synced 2025-01-23 21:45:00 +03:00
WIP: refactor, model switcher
This commit is contained in:
parent
cdb3234111
commit
7fa6de855d
@ -14,7 +14,7 @@ from mods.ssl import create_self_signed_cert
|
|||||||
from voice_changer.VoiceChangerManager import VoiceChangerManager
|
from voice_changer.VoiceChangerManager import VoiceChangerManager
|
||||||
from sio.MMVC_SocketIOApp import MMVC_SocketIOApp
|
from sio.MMVC_SocketIOApp import MMVC_SocketIOApp
|
||||||
from restapi.MMVC_Rest import MMVC_Rest
|
from restapi.MMVC_Rest import MMVC_Rest
|
||||||
from const import NATIVE_CLIENT_FILE_MAC, NATIVE_CLIENT_FILE_WIN, SSL_KEY_DIR
|
from const import NATIVE_CLIENT_FILE_MAC, NATIVE_CLIENT_FILE_WIN, SSL_KEY_DIR, setModelType
|
||||||
import subprocess
|
import subprocess
|
||||||
import multiprocessing as mp
|
import multiprocessing as mp
|
||||||
|
|
||||||
@ -37,6 +37,8 @@ def setupArgParser():
|
|||||||
default=True, help="generate self-signed certificate")
|
default=True, help="generate self-signed certificate")
|
||||||
parser.add_argument("--colab", type=strtobool,
|
parser.add_argument("--colab", type=strtobool,
|
||||||
default=False, help="run on colab")
|
default=False, help="run on colab")
|
||||||
|
parser.add_argument("--modelType", type=str,
|
||||||
|
default="MMVCv15", help="model type")
|
||||||
|
|
||||||
return parser
|
return parser
|
||||||
|
|
||||||
@ -62,12 +64,8 @@ def printMessage(message, level=0):
|
|||||||
else:
|
else:
|
||||||
print(f"\033[47m {message}\033[0m")
|
print(f"\033[47m {message}\033[0m")
|
||||||
|
|
||||||
# global app_socketio
|
|
||||||
# global app_fastapi
|
|
||||||
|
|
||||||
|
|
||||||
parser = setupArgParser()
|
parser = setupArgParser()
|
||||||
# args = parser.parse_args()
|
|
||||||
args, unknown = parser.parse_known_args()
|
args, unknown = parser.parse_known_args()
|
||||||
|
|
||||||
# printMessage(f"Phase name:{__name__}", level=2)
|
# printMessage(f"Phase name:{__name__}", level=2)
|
||||||
@ -75,11 +73,14 @@ args, unknown = parser.parse_known_args()
|
|||||||
|
|
||||||
# if __name__ == thisFilename or args.colab == True:
|
# if __name__ == thisFilename or args.colab == True:
|
||||||
# printMessage(f"PHASE3:{__name__}", level=2)
|
# printMessage(f"PHASE3:{__name__}", level=2)
|
||||||
|
|
||||||
TYPE = args.t
|
TYPE = args.t
|
||||||
PORT = args.p
|
PORT = args.p
|
||||||
CONFIG = args.c
|
CONFIG = args.c
|
||||||
MODEL = args.m if args.m != None else None
|
MODEL = args.m if args.m != None else None
|
||||||
ONNX_MODEL = args.o if args.o != None else None
|
ONNX_MODEL = args.o if args.o != None else None
|
||||||
|
MODEL_TYPE = args.modelType
|
||||||
|
setModelType(MODEL_TYPE)
|
||||||
|
|
||||||
|
|
||||||
def localServer():
|
def localServer():
|
||||||
|
@ -2,15 +2,6 @@ import os
|
|||||||
import sys
|
import sys
|
||||||
import tempfile
|
import tempfile
|
||||||
|
|
||||||
# MODEL_TYPE = "MMVCv13"
|
|
||||||
MODEL_TYPE = "MMVCv15"
|
|
||||||
# MODEL_TYPE = "sovt-vits-svc"
|
|
||||||
|
|
||||||
if MODEL_TYPE == "MMVCv15":
|
|
||||||
frontend_path = os.path.join(sys._MEIPASS, "dist") if hasattr(sys, "_MEIPASS") else "../client/demo_v15/dist"
|
|
||||||
elif MODEL_TYPE == "MMVCv13":
|
|
||||||
frontend_path = os.path.join(sys._MEIPASS, "dist") if hasattr(sys, "_MEIPASS") else "../client/demo_v13/dist"
|
|
||||||
|
|
||||||
ERROR_NO_ONNX_SESSION = "ERROR_NO_ONNX_SESSION"
|
ERROR_NO_ONNX_SESSION = "ERROR_NO_ONNX_SESSION"
|
||||||
|
|
||||||
|
|
||||||
@ -30,3 +21,24 @@ os.makedirs(TMP_DIR, exist_ok=True)
|
|||||||
# SSL_KEY_DIR = os.path.join(sys._MEIPASS, "keys") if hasattr(sys, "_MEIPASS") else "keys"
|
# SSL_KEY_DIR = os.path.join(sys._MEIPASS, "keys") if hasattr(sys, "_MEIPASS") else "keys"
|
||||||
# MODEL_DIR = os.path.join(sys._MEIPASS, "logs") if hasattr(sys, "_MEIPASS") else "logs"
|
# MODEL_DIR = os.path.join(sys._MEIPASS, "logs") if hasattr(sys, "_MEIPASS") else "logs"
|
||||||
# UPLOAD_DIR = os.path.join(sys._MEIPASS, "upload_dir") if hasattr(sys, "_MEIPASS") else "upload_dir"
|
# UPLOAD_DIR = os.path.join(sys._MEIPASS, "upload_dir") if hasattr(sys, "_MEIPASS") else "upload_dir"
|
||||||
|
|
||||||
|
|
||||||
|
modelType = "MMVCv15"
|
||||||
|
|
||||||
|
|
||||||
|
def getModelType():
|
||||||
|
return modelType
|
||||||
|
|
||||||
|
|
||||||
|
def setModelType(_modelType: str):
|
||||||
|
global modelType
|
||||||
|
modelType = _modelType
|
||||||
|
|
||||||
|
|
||||||
|
def getFrontendPath():
|
||||||
|
if modelType == "MMVCv15":
|
||||||
|
frontend_path = os.path.join(sys._MEIPASS, "dist") if hasattr(sys, "_MEIPASS") else "../client/demo_v15/dist"
|
||||||
|
elif modelType == "MMVCv13":
|
||||||
|
frontend_path = os.path.join(sys._MEIPASS, "dist") if hasattr(sys, "_MEIPASS") else "../client/demo_v13/dist"
|
||||||
|
|
||||||
|
return frontend_path
|
||||||
|
@ -9,7 +9,7 @@ from restapi.MMVC_Rest_Hello import MMVC_Rest_Hello
|
|||||||
from restapi.MMVC_Rest_VoiceChanger import MMVC_Rest_VoiceChanger
|
from restapi.MMVC_Rest_VoiceChanger import MMVC_Rest_VoiceChanger
|
||||||
from restapi.MMVC_Rest_Fileuploader import MMVC_Rest_Fileuploader
|
from restapi.MMVC_Rest_Fileuploader import MMVC_Rest_Fileuploader
|
||||||
from restapi.MMVC_Rest_Trainer import MMVC_Rest_Trainer
|
from restapi.MMVC_Rest_Trainer import MMVC_Rest_Trainer
|
||||||
from const import frontend_path, TMP_DIR
|
from const import getFrontendPath, TMP_DIR
|
||||||
|
|
||||||
|
|
||||||
class ValidationErrorLoggingRoute(APIRoute):
|
class ValidationErrorLoggingRoute(APIRoute):
|
||||||
@ -44,13 +44,13 @@ class MMVC_Rest:
|
|||||||
)
|
)
|
||||||
|
|
||||||
app_fastapi.mount(
|
app_fastapi.mount(
|
||||||
"/front", StaticFiles(directory=f'{frontend_path}', html=True), name="static")
|
"/front", StaticFiles(directory=f'{getFrontendPath()}', html=True), name="static")
|
||||||
|
|
||||||
app_fastapi.mount(
|
app_fastapi.mount(
|
||||||
"/trainer", StaticFiles(directory=f'{frontend_path}', html=True), name="static")
|
"/trainer", StaticFiles(directory=f'{getFrontendPath()}', html=True), name="static")
|
||||||
|
|
||||||
app_fastapi.mount(
|
app_fastapi.mount(
|
||||||
"/recorder", StaticFiles(directory=f'{frontend_path}', html=True), name="static")
|
"/recorder", StaticFiles(directory=f'{getFrontendPath()}', html=True), name="static")
|
||||||
app_fastapi.mount(
|
app_fastapi.mount(
|
||||||
"/tmp", StaticFiles(directory=f'{TMP_DIR}'), name="static")
|
"/tmp", StaticFiles(directory=f'{TMP_DIR}'), name="static")
|
||||||
|
|
||||||
|
@ -2,7 +2,7 @@ import socketio
|
|||||||
|
|
||||||
from sio.MMVC_SocketIOServer import MMVC_SocketIOServer
|
from sio.MMVC_SocketIOServer import MMVC_SocketIOServer
|
||||||
from voice_changer.VoiceChangerManager import VoiceChangerManager
|
from voice_changer.VoiceChangerManager import VoiceChangerManager
|
||||||
from const import frontend_path
|
from const import getFrontendPath
|
||||||
|
|
||||||
|
|
||||||
class MMVC_SocketIOApp():
|
class MMVC_SocketIOApp():
|
||||||
@ -15,19 +15,19 @@ class MMVC_SocketIOApp():
|
|||||||
other_asgi_app=app_fastapi,
|
other_asgi_app=app_fastapi,
|
||||||
static_files={
|
static_files={
|
||||||
'/assets/icons/github.svg': {
|
'/assets/icons/github.svg': {
|
||||||
'filename': f'{frontend_path}/assets/icons/github.svg',
|
'filename': f'{getFrontendPath()}/assets/icons/github.svg',
|
||||||
'content_type': 'image/svg+xml'
|
'content_type': 'image/svg+xml'
|
||||||
},
|
},
|
||||||
'/assets/icons/help-circle.svg': {
|
'/assets/icons/help-circle.svg': {
|
||||||
'filename': f'{frontend_path}/assets/icons/help-circle.svg',
|
'filename': f'{getFrontendPath()}/assets/icons/help-circle.svg',
|
||||||
'content_type': 'image/svg+xml'
|
'content_type': 'image/svg+xml'
|
||||||
},
|
},
|
||||||
'/buymeacoffee.png': {
|
'/buymeacoffee.png': {
|
||||||
'filename': f'{frontend_path}/assets/buymeacoffee.png',
|
'filename': f'{getFrontendPath()}/assets/buymeacoffee.png',
|
||||||
'content_type': 'image/png'
|
'content_type': 'image/png'
|
||||||
},
|
},
|
||||||
'': f'{frontend_path}',
|
'': f'{getFrontendPath()}',
|
||||||
'/': f'{frontend_path}/index.html',
|
'/': f'{getFrontendPath()}/index.html',
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
from const import TMP_DIR, MODEL_TYPE
|
from const import TMP_DIR, getModelType
|
||||||
import torch
|
import torch
|
||||||
import os
|
import os
|
||||||
import traceback
|
import traceback
|
||||||
@ -7,11 +7,6 @@ from dataclasses import dataclass, asdict
|
|||||||
import resampy
|
import resampy
|
||||||
|
|
||||||
|
|
||||||
if MODEL_TYPE == "MMVCv15":
|
|
||||||
from voice_changer.MMVCv15.MMVCv15 import MMVCv15
|
|
||||||
else:
|
|
||||||
from voice_changer.MMVCv13.MMVCv13 import MMVCv13
|
|
||||||
|
|
||||||
from voice_changer.IORecorder import IORecorder
|
from voice_changer.IORecorder import IORecorder
|
||||||
from voice_changer.IOAnalyzer import IOAnalyzer
|
from voice_changer.IOAnalyzer import IOAnalyzer
|
||||||
|
|
||||||
@ -53,9 +48,13 @@ class VoiceChanger():
|
|||||||
self.currentCrossFadeEndRate = 0
|
self.currentCrossFadeEndRate = 0
|
||||||
self.currentCrossFadeOverlapSize = 0
|
self.currentCrossFadeOverlapSize = 0
|
||||||
|
|
||||||
if MODEL_TYPE == "MMVCv15":
|
modelType = getModelType()
|
||||||
|
print("[VoiceChanger] activate model type:", modelType)
|
||||||
|
if modelType == "MMVCv15":
|
||||||
|
from voice_changer.MMVCv15.MMVCv15 import MMVCv15
|
||||||
self.voiceChanger = MMVCv15()
|
self.voiceChanger = MMVCv15()
|
||||||
else:
|
else:
|
||||||
|
from voice_changer.MMVCv13.MMVCv13 import MMVCv13
|
||||||
self.voiceChanger = MMVCv13()
|
self.voiceChanger = MMVCv13()
|
||||||
|
|
||||||
self.gpu_num = torch.cuda.device_count()
|
self.gpu_num = torch.cuda.device_count()
|
||||||
|
Loading…
Reference in New Issue
Block a user