This commit is contained in:
wataru 2022-11-12 10:27:34 +09:00
parent 22d9f723b3
commit 10fc4d9678
5 changed files with 244 additions and 191 deletions

View File

@ -1,61 +1,65 @@
import sys, os, struct, argparse, logging, shutil, base64, traceback import sys, os, struct, argparse, logging, shutil, base64, traceback
from dataclasses import dataclass from dataclasses import dataclass
from datetime import datetime
from distutils.util import strtobool
import numpy as np
from scipy.io.wavfile import write, read
sys.path.append("/MMVC_Trainer") sys.path.append("/MMVC_Trainer")
sys.path.append("/MMVC_Trainer/text") sys.path.append("/MMVC_Trainer/text")
import uvicorn from fastapi.routing import APIRoute
from fastapi import FastAPI, UploadFile, File, Form from fastapi import HTTPException, Request, Response, FastAPI, UploadFile, File, Form
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from fastapi.encoders import jsonable_encoder
from fastapi import FastAPI, HTTPException
from fastapi.staticfiles import StaticFiles from fastapi.staticfiles import StaticFiles
from fastapi.encoders import jsonable_encoder
from fastapi.responses import JSONResponse
from fastapi.middleware.cors import CORSMiddleware
import uvicorn
import socketio
from pydantic import BaseModel from pydantic import BaseModel
from scipy.io.wavfile import write, read from typing import Callable
import socketio from mods.Trainer_Speakers import mod_get_speakers
from distutils.util import strtobool from mods.Trainer_Training import mod_post_pre_training, mod_post_start_training, mod_post_stop_training, mod_get_related_files, mod_get_tail_training_log
from datetime import datetime from mods.Trainer_Model import mod_get_model, mod_delete_model
from mods.Trainer_Models import mod_get_models
import torch from mods.Trainer_MultiSpeakerSetting import mod_get_multi_speaker_setting, mod_post_multi_speaker_setting
import numpy as np from mods.Trainer_Speaker_Voice import mod_get_speaker_voice
from mods.Trainer_Speaker_Voices import mod_get_speaker_voices
from mods.Trainer_Speaker import mod_delete_speaker
from mods.ssl import create_self_signed_cert from mods.FileUploader import upload_file, concat_file_chunks
from mods.VoiceChanger import VoiceChanger from mods.VoiceChanger import VoiceChanger
from mods.ssl import create_self_signed_cert
# File Uploader # File Uploader
from mods.FileUploader import upload_file, concat_file_chunks
# Trainer Rest Internal # Trainer Rest Internal
from mods.Trainer_Speakers import mod_get_speakers
from mods.Trainer_Speaker import mod_delete_speaker
from mods.Trainer_Speaker_Voices import mod_get_speaker_voices
from mods.Trainer_Speaker_Voice import mod_get_speaker_voice
from mods.Trainer_MultiSpeakerSetting import mod_get_multi_speaker_setting, mod_post_multi_speaker_setting
from mods.Trainer_Models import mod_get_models
from mods.Trainer_Model import mod_get_model, mod_delete_model
from mods.Trainer_Training import mod_post_pre_training, mod_post_start_training, mod_post_stop_training, mod_get_related_files, mod_get_tail_training_log
class UvicornSuppressFilter(logging.Filter): class UvicornSuppressFilter(logging.Filter):
def filter(self, record): def filter(self, record):
return False return False
logger = logging.getLogger("uvicorn.error") logger = logging.getLogger("uvicorn.error")
logger.addFilter(UvicornSuppressFilter()) logger.addFilter(UvicornSuppressFilter())
# logger.propagate = False # logger.propagate = False
logger = logging.getLogger("multipart.multipart") logger = logging.getLogger("multipart.multipart")
logger.propagate = False logger.propagate = False
@dataclass @dataclass
class ExApplicationInfo(): class ExApplicationInfo():
external_tensorboard_port:int external_tensorboard_port: int
exApplitionInfo = ExApplicationInfo(external_tensorboard_port=0) exApplitionInfo = ExApplicationInfo(external_tensorboard_port=0)
class VoiceModel(BaseModel): class VoiceModel(BaseModel):
gpu: int gpu: int
srcId: int srcId: int
@ -88,14 +92,12 @@ class MyCustomNamespace(socketio.AsyncNamespace):
print("Voice Change is not loaded. Did you load a correct model?") print("Voice Change is not loaded. Did you load a correct model?")
return np.zeros(1).astype(np.int16) return np.zeros(1).astype(np.int16)
# def transcribe(self): # def transcribe(self):
# if hasattr(self, 'whisper') == True: # if hasattr(self, 'whisper') == True:
# self.whisper.transcribe(0) # self.whisper.transcribe(0)
# else: # else:
# print("whisper not found") # print("whisper not found")
def on_connect(self, sid, environ): def on_connect(self, sid, environ):
# print('[{}] connet sid : {}'.format(datetime.now().strftime('%Y-%m-%d %H:%M:%S') , sid)) # print('[{}] connet sid : {}'.format(datetime.now().strftime('%Y-%m-%d %H:%M:%S') , sid))
pass pass
@ -109,29 +111,39 @@ class MyCustomNamespace(socketio.AsyncNamespace):
prefixChunkSize = int(msg[4]) prefixChunkSize = int(msg[4])
data = msg[5] data = msg[5]
# print(srcId, dstId, timestamp) # print(srcId, dstId, timestamp)
unpackedData = np.array(struct.unpack('<%sh'%(len(data) // struct.calcsize('<h') ), data)) unpackedData = np.array(struct.unpack(
audio1 = self.changeVoice(gpu, srcId, dstId, timestamp, prefixChunkSize, unpackedData) '<%sh' % (len(data) // struct.calcsize('<h')), data))
audio1 = self.changeVoice(
gpu, srcId, dstId, timestamp, prefixChunkSize, unpackedData)
bin = struct.pack('<%sh'%len(audio1), *audio1) bin = struct.pack('<%sh' % len(audio1), *audio1)
await self.emit('response',[timestamp, bin]) await self.emit('response', [timestamp, bin])
def on_disconnect(self, sid): def on_disconnect(self, sid):
# print('[{}] disconnect'.format(datetime.now().strftime('%Y-%m-%d %H:%M:%S'))) # print('[{}] disconnect'.format(datetime.now().strftime('%Y-%m-%d %H:%M:%S')))
pass; pass
def setupArgParser(): def setupArgParser():
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument("-t", type=str, default="MMVC",
help="Server type. MMVC|TRAIN")
parser.add_argument("-p", type=int, default=8080, help="port") parser.add_argument("-p", type=int, default=8080, help="port")
parser.add_argument("-c", type=str, help="path for the config.json") parser.add_argument("-c", type=str, help="path for the config.json")
parser.add_argument("-m", type=str, help="path for the model file") parser.add_argument("-m", type=str, help="path for the model file")
parser.add_argument("--https", type=strtobool, default=False, help="use https") parser.add_argument("--https", type=strtobool,
parser.add_argument("--httpsKey", type=str, default="ssl.key", help="path for the key of https") default=False, help="use https")
parser.add_argument("--httpsCert", type=str, default="ssl.cert", help="path for the cert of https") parser.add_argument("--httpsKey", type=str,
parser.add_argument("--httpsSelfSigned", type=strtobool, default=True, help="generate self-signed certificate") default="ssl.key", help="path for the key of https")
parser.add_argument("--colab", type=strtobool, default=False, help="run on colab") parser.add_argument("--httpsCert", type=str,
default="ssl.cert", help="path for the cert of https")
parser.add_argument("--httpsSelfSigned", type=strtobool,
default=True, help="generate self-signed certificate")
parser.add_argument("--colab", type=strtobool,
default=False, help="run on colab")
return parser return parser
def printMessage(message, level=0): def printMessage(message, level=0):
if level == 0: if level == 0:
print(f"\033[17m{message}\033[0m") print(f"\033[17m{message}\033[0m")
@ -142,6 +154,7 @@ def printMessage(message, level=0):
else: else:
print(f"\033[47m {message}\033[0m") print(f"\033[47m {message}\033[0m")
global app_socketio global app_socketio
global app_fastapi global app_fastapi
@ -151,10 +164,7 @@ args = parser.parse_args()
printMessage(f"Phase name:{__name__}", level=2) printMessage(f"Phase name:{__name__}", level=2)
thisFilename = os.path.basename(__file__)[:-3] thisFilename = os.path.basename(__file__)[:-3]
from typing import Callable, List
from fastapi import Body, FastAPI, HTTPException, Request, Response
from fastapi.exceptions import RequestValidationError
from fastapi.routing import APIRoute
class ValidationErrorLoggingRoute(APIRoute): class ValidationErrorLoggingRoute(APIRoute):
def get_route_handler(self) -> Callable: def get_route_handler(self) -> Callable:
original_route_handler = super().get_route_handler() original_route_handler = super().get_route_handler()
@ -170,8 +180,10 @@ class ValidationErrorLoggingRoute(APIRoute):
return custom_route_handler return custom_route_handler
if __name__ == thisFilename or args.colab == True: if __name__ == thisFilename or args.colab == True:
printMessage(f"PHASE3:{__name__}", level=2) printMessage(f"PHASE3:{__name__}", level=2)
TYPE = args.t
PORT = args.p PORT = args.p
CONFIG = args.c CONFIG = args.c
MODEL = args.m MODEL = args.m
@ -186,11 +198,14 @@ if __name__ == thisFilename or args.colab == True:
allow_headers=["*"], allow_headers=["*"],
) )
app_fastapi.mount("/front", StaticFiles(directory="../frontend/dist", html=True), name="static") app_fastapi.mount(
"/front", StaticFiles(directory="../frontend/dist", html=True), name="static")
app_fastapi.mount("/trainer", StaticFiles(directory="../frontend/dist", html=True), name="static") app_fastapi.mount(
"/trainer", StaticFiles(directory="../frontend/dist", html=True), name="static")
app_fastapi.mount("/recorder", StaticFiles(directory="../frontend/dist", html=True), name="static") app_fastapi.mount(
"/recorder", StaticFiles(directory="../frontend/dist", html=True), name="static")
sio = socketio.AsyncServer( sio = socketio.AsyncServer(
async_mode='asgi', async_mode='asgi',
@ -202,14 +217,13 @@ if __name__ == thisFilename or args.colab == True:
namespace.loadModel(CONFIG, MODEL) namespace.loadModel(CONFIG, MODEL)
# namespace.loadWhisperModel("base") # namespace.loadWhisperModel("base")
app_socketio = socketio.ASGIApp( app_socketio = socketio.ASGIApp(
sio, sio,
other_asgi_app=app_fastapi, other_asgi_app=app_fastapi,
static_files={ static_files={
'/assets/icons/github.svg': { '/assets/icons/github.svg': {
'filename':'../frontend/dist/assets/icons/github.svg', 'filename': '../frontend/dist/assets/icons/github.svg',
'content_type':'image/svg+xml' 'content_type': 'image/svg+xml'
}, },
'': '../frontend/dist', '': '../frontend/dist',
'/': '../frontend/dist/index.html', '/': '../frontend/dist/index.html',
@ -220,7 +234,6 @@ if __name__ == thisFilename or args.colab == True:
async def index(): async def index():
return {"result": "Index"} return {"result": "Index"}
############ ############
# File Uploder # File Uploder
# ########## # ##########
@ -231,7 +244,7 @@ if __name__ == thisFilename or args.colab == True:
@app_fastapi.post("/upload_file") @app_fastapi.post("/upload_file")
async def post_upload_file( async def post_upload_file(
file:UploadFile = File(...), file: UploadFile = File(...),
filename: str = Form(...) filename: str = Form(...)
): ):
return upload_file(UPLOAD_DIR, file, filename) return upload_file(UPLOAD_DIR, file, filename)
@ -243,7 +256,8 @@ if __name__ == thisFilename or args.colab == True:
configFilename: str = Form(...) configFilename: str = Form(...)
): ):
modelFilePath = concat_file_chunks(UPLOAD_DIR, modelFilename, modelFilenameChunkNum,UPLOAD_DIR) modelFilePath = concat_file_chunks(
UPLOAD_DIR, modelFilename, modelFilenameChunkNum, UPLOAD_DIR)
print(f'File saved to: {modelFilePath}') print(f'File saved to: {modelFilePath}')
configFilePath = os.path.join(UPLOAD_DIR, configFilename) configFilePath = os.path.join(UPLOAD_DIR, configFilename)
@ -258,28 +272,28 @@ if __name__ == thisFilename or args.colab == True:
modelDFilenameChunkNum: int = Form(...), modelDFilenameChunkNum: int = Form(...),
): ):
modelGFilePath = concat_file_chunks(
modelGFilePath = concat_file_chunks(UPLOAD_DIR, modelGFilename, modelGFilenameChunkNum, MODEL_DIR) UPLOAD_DIR, modelGFilename, modelGFilenameChunkNum, MODEL_DIR)
modelDFilePath = concat_file_chunks(UPLOAD_DIR, modelDFilename, modelDFilenameChunkNum,MODEL_DIR) modelDFilePath = concat_file_chunks(
UPLOAD_DIR, modelDFilename, modelDFilenameChunkNum, MODEL_DIR)
return {"File saved": f"{modelGFilePath}, {modelDFilePath}"} return {"File saved": f"{modelGFilePath}, {modelDFilePath}"}
@app_fastapi.post("/extract_voices") @app_fastapi.post("/extract_voices")
async def post_load_model( async def post_load_model(
zipFilename: str = Form(...), zipFilename: str = Form(...),
zipFileChunkNum: int = Form(...), zipFileChunkNum: int = Form(...),
): ):
zipFilePath = concat_file_chunks(UPLOAD_DIR, zipFilename, zipFileChunkNum, UPLOAD_DIR) zipFilePath = concat_file_chunks(
UPLOAD_DIR, zipFilename, zipFileChunkNum, UPLOAD_DIR)
shutil.unpack_archive(zipFilePath, "/MMVC_Trainer/dataset/textful/") shutil.unpack_archive(zipFilePath, "/MMVC_Trainer/dataset/textful/")
return {"Zip file unpacked": f"{zipFilePath}"} return {"Zip file unpacked": f"{zipFilePath}"}
############ ############
# Voice Changer # Voice Changer
# ########## # ##########
@app_fastapi.post("/test") @app_fastapi.post("/test")
async def post_test(voice:VoiceModel): async def post_test(voice: VoiceModel):
try: try:
# print("POST REQUEST PROCESSING....") # print("POST REQUEST PROCESSING....")
gpu = voice.gpu gpu = voice.gpu
@ -290,23 +304,26 @@ if __name__ == thisFilename or args.colab == True:
buffer = voice.buffer buffer = voice.buffer
wav = base64.b64decode(buffer) wav = base64.b64decode(buffer)
if wav==0: if wav == 0:
samplerate, data=read("dummy.wav") samplerate, data = read("dummy.wav")
unpackedData = data unpackedData = data
else: else:
unpackedData = np.array(struct.unpack('<%sh'%(len(wav) // struct.calcsize('<h') ), wav)) unpackedData = np.array(struct.unpack(
write("logs/received_data.wav", 24000, unpackedData.astype(np.int16)) '<%sh' % (len(wav) // struct.calcsize('<h')), wav))
write("logs/received_data.wav", 24000,
unpackedData.astype(np.int16))
changedVoice = namespace.changeVoice(gpu, srcId, dstId, timestamp, prefixChunkSize, unpackedData) changedVoice = namespace.changeVoice(
gpu, srcId, dstId, timestamp, prefixChunkSize, unpackedData)
changedVoiceBase64 = base64.b64encode(changedVoice).decode('utf-8') changedVoiceBase64 = base64.b64encode(changedVoice).decode('utf-8')
data = { data = {
"gpu":gpu, "gpu": gpu,
"srcId":srcId, "srcId": srcId,
"dstId":dstId, "dstId": dstId,
"timestamp":timestamp, "timestamp": timestamp,
"prefixChunkSize":prefixChunkSize, "prefixChunkSize": prefixChunkSize,
"changedVoiceBase64":changedVoiceBase64 "changedVoiceBase64": changedVoiceBase64
} }
json_compatible_item_data = jsonable_encoder(data) json_compatible_item_data = jsonable_encoder(data)
@ -318,25 +335,24 @@ if __name__ == thisFilename or args.colab == True:
print(traceback.format_exc()) print(traceback.format_exc())
return str(e) return str(e)
# Trainer REST API ※ ColabがTop直下のパスにしかPOSTを投げれないようなので"REST風" # Trainer REST API ※ ColabがTop直下のパスにしかPOSTを投げれないようなので"REST風"
@app_fastapi.get("/get_speakers") @app_fastapi.get("/get_speakers")
async def get_speakers(): async def get_speakers():
return mod_get_speakers() return mod_get_speakers()
@app_fastapi.delete("/delete_speaker") @app_fastapi.delete("/delete_speaker")
async def delete_speaker(speaker:str= Form(...)): async def delete_speaker(speaker: str = Form(...)):
return mod_delete_speaker(speaker) return mod_delete_speaker(speaker)
@app_fastapi.get("/get_speaker_voices") @app_fastapi.get("/get_speaker_voices")
async def get_speaker_voices(speaker:str): async def get_speaker_voices(speaker: str):
return mod_get_speaker_voices(speaker) return mod_get_speaker_voices(speaker)
@app_fastapi.get("/get_speaker_voice") @app_fastapi.get("/get_speaker_voice")
async def get_speaker_voices(speaker:str, voice:str): async def get_speaker_voices(speaker: str, voice: str):
return mod_get_speaker_voice(speaker, voice) return mod_get_speaker_voice(speaker, voice)
@app_fastapi.get("/get_multi_speaker_setting") @app_fastapi.get("/get_multi_speaker_setting")
async def get_multi_speaker_setting(): async def get_multi_speaker_setting():
return mod_get_multi_speaker_setting() return mod_get_multi_speaker_setting()
@ -350,16 +366,15 @@ if __name__ == thisFilename or args.colab == True:
return mod_get_models() return mod_get_models()
@app_fastapi.get("/get_model") @app_fastapi.get("/get_model")
async def get_model(model:str): async def get_model(model: str):
return mod_get_model(model) return mod_get_model(model)
@app_fastapi.delete("/delete_model") @app_fastapi.delete("/delete_model")
async def delete_model(model:str= Form(...)): async def delete_model(model: str = Form(...)):
return mod_delete_model(model) return mod_delete_model(model)
@app_fastapi.post("/post_pre_training") @app_fastapi.post("/post_pre_training")
async def post_pre_training(batch:int= Form(...)): async def post_pre_training(batch: int = Form(...)):
return mod_post_pre_training(batch) return mod_post_pre_training(batch)
@app_fastapi.post("/post_start_training") @app_fastapi.post("/post_start_training")
@ -377,7 +392,7 @@ if __name__ == thisFilename or args.colab == True:
return mod_get_related_files() return mod_get_related_files()
@app_fastapi.get("/get_tail_training_log") @app_fastapi.get("/get_tail_training_log")
async def get_tail_training_log(num:int): async def get_tail_training_log(num: int):
return mod_get_tail_training_log(num) return mod_get_tail_training_log(num)
@app_fastapi.get("/get_ex_application_info") @app_fastapi.get("/get_ex_application_info")
@ -392,18 +407,23 @@ if __name__ == '__mp_main__':
if __name__ == '__main__': if __name__ == '__main__':
printMessage(f"PHASE1:{__name__}", level=2) printMessage(f"PHASE1:{__name__}", level=2)
TYPE = args.t
PORT = args.p PORT = args.p
CONFIG = args.c CONFIG = args.c
MODEL = args.m MODEL = args.m
if TYPE != "MMVC" and TYPE != "TRAIN":
print("Type(-t) should be MMVC or TRAIN")
exit(1)
printMessage(f"Start MMVC SocketIO Server", level=0) printMessage(f"Start MMVC SocketIO Server", level=0)
printMessage(f"CONFIG:{CONFIG}, MODEL:{MODEL}", level=1) printMessage(f"CONFIG:{CONFIG}, MODEL:{MODEL}", level=1)
if args.colab == False: if args.colab == False:
if os.getenv("EX_PORT"): if os.getenv("EX_PORT"):
EX_PORT = os.environ["EX_PORT"] EX_PORT = os.environ["EX_PORT"]
printMessage(f"External_Port:{EX_PORT} Internal_Port:{PORT}", level=1) printMessage(
f"External_Port:{EX_PORT} Internal_Port:{PORT}", level=1)
else: else:
printMessage(f"Internal_Port:{PORT}", level=1) printMessage(f"Internal_Port:{PORT}", level=1)
@ -423,39 +443,45 @@ if __name__ == '__main__':
key_base_name = f"{datetime.now().strftime('%Y%m%d_%H%M%S')}" key_base_name = f"{datetime.now().strftime('%Y%m%d_%H%M%S')}"
keyname = f"{key_base_name}.key" keyname = f"{key_base_name}.key"
certname = f"{key_base_name}.cert" certname = f"{key_base_name}.cert"
create_self_signed_cert(certname, keyname, certargs= create_self_signed_cert(certname, keyname, certargs={"Country": "JP",
{"Country": "JP",
"State": "Tokyo", "State": "Tokyo",
"City": "Chuo-ku", "City": "Chuo-ku",
"Organization": "F", "Organization": "F",
"Org. Unit": "F"}, cert_dir="./key") "Org. Unit": "F"}, cert_dir="./key")
key_path = os.path.join("./key", keyname) key_path = os.path.join("./key", keyname)
cert_path = os.path.join("./key", certname) cert_path = os.path.join("./key", certname)
printMessage(f"protocol: HTTPS(self-signed), key:{key_path}, cert:{cert_path}", level=1) printMessage(
f"protocol: HTTPS(self-signed), key:{key_path}, cert:{cert_path}", level=1)
elif args.https and args.httpsSelfSigned == 0: elif args.https and args.httpsSelfSigned == 0:
# HTTPS # HTTPS
key_path = args.httpsKey key_path = args.httpsKey
cert_path = args.httpsCert cert_path = args.httpsCert
printMessage(f"protocol: HTTPS, key:{key_path}, cert:{cert_path}", level=1) printMessage(
f"protocol: HTTPS, key:{key_path}, cert:{cert_path}", level=1)
else: else:
# HTTP # HTTP
printMessage(f"protocol: HTTP", level=1) printMessage(f"protocol: HTTP", level=1)
# アドレス表示 # アドレス表示
if args.https == 1: if args.https == 1:
printMessage(f"open https://<IP>:<PORT>/ with your browser.", level=0) printMessage(
f"open https://<IP>:<PORT>/ with your browser.", level=0)
else: else:
printMessage(f"open http://<IP>:<PORT>/ with your browser.", level=0) printMessage(
f"open http://<IP>:<PORT>/ with your browser.", level=0)
if TYPE == "MMVC":
path = ""
else:
path = "trainer"
if EX_PORT and EX_IP and args.https == 1: if EX_PORT and EX_IP and args.https == 1:
printMessage(f"In many cases it is one of the following", level=1) printMessage(f"In many cases it is one of the following", level=1)
printMessage(f"https://localhost:{EX_PORT}/", level=1) printMessage(f"https://localhost:{EX_PORT}/{path}", level=1)
for ip in EX_IP.strip().split(" "): for ip in EX_IP.strip().split(" "):
printMessage(f"https://{ip}:{EX_PORT}/", level=1) printMessage(f"https://{ip}:{EX_PORT}/{path}", level=1)
elif EX_PORT and EX_IP and args.https == 0: elif EX_PORT and EX_IP and args.https == 0:
printMessage(f"In many cases it is one of the following", level=1) printMessage(f"In many cases it is one of the following", level=1)
printMessage(f"http://localhost:{EX_PORT}/", level=1) printMessage(f"http://localhost:{EX_PORT}/{path}", level=1)
# サーバ起動 # サーバ起動
if args.https: if args.https:
@ -465,8 +491,8 @@ if __name__ == '__main__':
host="0.0.0.0", host="0.0.0.0",
port=int(PORT), port=int(PORT),
reload=True, reload=True,
ssl_keyfile = key_path, ssl_keyfile=key_path,
ssl_certfile = cert_path, ssl_certfile=cert_path,
log_level="critical" log_level="critical"
) )
else: else:
@ -486,5 +512,3 @@ if __name__ == '__main__':
reload=True, reload=True,
log_level="critical" log_level="critical"
) )

View File

@ -1,7 +1,7 @@
import torch import torch
from scipy.io.wavfile import write, read from scipy.io.wavfile import write, read
import numpy as np import numpy as np
import struct, traceback import traceback
import utils import utils
import commons import commons

View File

@ -1,28 +1,56 @@
#!/bin/bash #!/bin/bash
cp -r /resources/* .
TYPE=$1 set -eu
PARAMS=${@:2:($#-1)}
if [ $# = 0 ]; then
echo "
usage:
$0 -t <TYPE> <params...>
TYPE: select one of ['TRAIN', 'MMVC']
" >&2
exit 1
fi
# TYPE=$1
# PARAMS=${@:2:($#-1)}
# echo $TYPE
# echo $PARAMS
if [ -e /resources ]; then
echo "/resources の中身をコピーします。"
cp -r /resources/* .
else
echo "/resourcesが存在しません。デフォルトの動作をします。"
fi
echo $TYPE
echo $PARAMS
## Config 設置 ## Config 設置
if [[ -e ./setting.json ]]; then if [[ -e ./setting.json ]]; then
echo "カスタムセッティングを使用" echo "カスタムセッティングを使用"
cp ./setting.json ../frontend/dist/assets/setting.json cp ./setting.json ../frontend/dist/assets/setting.json
else
cp ../frontend/dist/assets/setting_mmvc.json ../frontend/dist/assets/setting.json
fi fi
echo "起動します" "$@"
python3 MMVCServerSIO.py "$@"
# 起動 ###
if [ "${TYPE}" = "MMVC" ] ; then # 起動パラメータ
echo "MMVCを起動します" # (1) トレーニングの場合
python3 MMVCServerSIO.py $PARAMS 2>stderr.txt # python3 MMVCServerSIO.py <type>
elif [ "${TYPE}" = "MMVC_VERBOSE" ] ; then # 環境変数:
echo "MMVCを起動します(verbose)" # ※ Colabの場合python3 MMVCServerSIO.py -t Train -p {PORT} --colab True
python3 MMVCServerSIO.py $PARAMS # (2) VCの場合
fi
# # 起動
# if [ "${TYPE}" = "MMVC" ] ; then
# elif [ "${TYPE}" = "MMVC_VERBOSE" ] ; then
# echo "MMVCを起動します(verbose)"
# python3 MMVCServerSIO.py $PARAMS
# fi

View File

@ -1,4 +1,4 @@
FROM dannadori/voice-changer-internal:20221112_092232 as front FROM dannadori/voice-changer-internal:20221112_102341 as front
FROM debian:bullseye-slim as base FROM debian:bullseye-slim as base
ARG DEBIAN_FRONTEND=noninteractive ARG DEBIAN_FRONTEND=noninteractive

View File

@ -1,7 +1,7 @@
#!/bin/bash #!/bin/bash
set -eu set -eu
DOCKER_IMAGE=dannadori/voice-changer:20221112_092328 DOCKER_IMAGE=dannadori/voice-changer:20221112_102442
# DOCKER_IMAGE=voice-changer # DOCKER_IMAGE=voice-changer
if [ $# = 0 ]; then if [ $# = 0 ]; then
@ -44,7 +44,7 @@ if [ "${MODE}" = "TRAIN" ]; then
-e EX_PORT=${EX_PORT} -e EX_TB_PORT=${EX_TB_PORT} \ -e EX_PORT=${EX_PORT} -e EX_TB_PORT=${EX_TB_PORT} \
-e EX_IP="`hostname -I`" \ -e EX_IP="`hostname -I`" \
-p ${EX_PORT}:8080 -p ${EX_TB_PORT}:6006 \ -p ${EX_PORT}:8080 -p ${EX_TB_PORT}:6006 \
$DOCKER_IMAGE "$@" $DOCKER_IMAGE -t TRAIN "$@"
elif [ "${MODE}" = "MMVC" ]; then elif [ "${MODE}" = "MMVC" ]; then
@ -66,7 +66,8 @@ elif [ "${MODE}" = "MMVC" ]; then
-e LOCAL_GID=$(id -g $USER) \ -e LOCAL_GID=$(id -g $USER) \
-e EX_IP="`hostname -I`" \ -e EX_IP="`hostname -I`" \
-e EX_PORT=${EX_PORT} \ -e EX_PORT=${EX_PORT} \
-p ${EX_PORT}:8080 $DOCKER_IMAGE "$@" -p ${EX_PORT}:8080 \
$DOCKER_IMAGE -t MMVC "$@"
fi fi
else else
echo " echo "