This commit is contained in:
wataru 2023-05-03 22:36:59 +09:00
parent f2ba7e3938
commit 3a4e534091
2 changed files with 65 additions and 3 deletions

View File

@ -1,3 +1,4 @@
from concurrent.futures import ThreadPoolExecutor
import sys
from distutils.util import strtobool
@ -6,6 +7,9 @@ import socket
import platform
import os
import argparse
import requests
from tqdm import tqdm
from voice_changer.utils.VoiceChangerParams import VoiceChangerParams
import uvicorn
@ -106,6 +110,36 @@ def localServer():
)
def download(params):
url = params["url"]
saveTo = params["saveTo"]
position = params["position"]
dirname = os.path.dirname(saveTo)
os.makedirs(dirname, exist_ok=True)
try:
req = requests.get(url, stream=True, allow_redirects=True)
content_length = req.headers.get("content-length")
progress_bar = tqdm(
total=int(content_length) if content_length is not None else None,
leave=False,
unit="B",
unit_scale=True,
unit_divisor=1024,
position=position,
)
# with tqdm
with open(saveTo, "wb") as f:
for chunk in req.iter_content(chunk_size=1024):
if chunk:
progress_bar.update(len(chunk))
f.write(chunk)
except Exception as e:
print(e)
if __name__ == "MMVCServerSIO":
voiceChangerParams = VoiceChangerParams(
content_vec_500=args.content_vec_500,
@ -116,9 +150,36 @@ if __name__ == "MMVCServerSIO":
hubert_soft=args.hubert_soft,
nsf_hifigan=args.nsf_hifigan,
)
voiceChangerManager = VoiceChangerManager.get_instance(voiceChangerParams)
print("voiceChangerManager", voiceChangerManager)
# file exists check (currently only for rvc)
downloadParams = []
if os.path.exists(voiceChangerParams.hubert_base) is False:
downloadParams.append(
{
"url": "https://huggingface.co/ddPn08/rvc-webui-models/resolve/main/embeddings/hubert_base.pt",
"saveTo": voiceChangerParams.hubert_base,
"position": 0,
}
)
if os.path.exists(voiceChangerParams.hubert_base_jp) is False:
downloadParams.append(
{
"url": "https://huggingface.co/rinna/japanese-hubert-base/resolve/main/fairseq/model.pt",
"saveTo": voiceChangerParams.hubert_base_jp,
"position": 1,
}
)
with ThreadPoolExecutor() as pool:
pool.map(download, downloadParams)
if (
os.path.exists(voiceChangerParams.hubert_base) is False
or os.path.exists(voiceChangerParams.hubert_base_jp) is False
):
printMessage("RVC用のモデルファイルのダウンロードに失敗しました。", level=2)
printMessage("failed to download weight for rvc", level=2)
voiceChangerManager = VoiceChangerManager.get_instance(voiceChangerParams)
app_fastapi = MMVC_Rest.get_instance(voiceChangerManager)
app_socketio = MMVC_SocketIOApp.get_instance(app_fastapi, voiceChangerManager)

View File

@ -114,7 +114,8 @@ class RVC:
if modelSlot.embedder == EnumEmbedderTypes.hubert:
emmbedderFilename = self.params.hubert_base
elif modelSlot.embedder == EnumEmbedderTypes.contentvec:
emmbedderFilename = self.params.content_vec_500
# emmbedderFilename = self.params.content_vec_500
emmbedderFilename = self.params.hubert_base
elif modelSlot.embedder == EnumEmbedderTypes.hubert_jp:
emmbedderFilename = self.params.hubert_base_jp
else: