voice-changer/server/voice_changer/RVC/SampleDownloader.py

175 lines
5.4 KiB
Python
Raw Normal View History

2023-05-17 20:51:40 +03:00
from concurrent.futures import ThreadPoolExecutor
2023-05-30 20:26:16 +03:00
from dataclasses import asdict
2023-05-17 20:51:40 +03:00
import os
2023-06-04 23:00:57 +03:00
from const import RVC_MODEL_DIRNAME, TMP_DIR
2023-05-17 20:51:40 +03:00
from Downloader import download, download_no_tqdm
from ModelSample import RVCModelSample, getModelSamples
import json
2023-05-30 20:26:16 +03:00
from voice_changer.RVC.ModelSlot import ModelSlot
from voice_changer.RVC.ModelSlotGenerator import _setInfoByONNX, _setInfoByPytorch
2023-05-17 20:51:40 +03:00
def checkRvcModelExist(model_dir: str):
    """Report whether the RVC model directory exists below ``model_dir``.

    Args:
        model_dir: Root model directory.

    Returns:
        True if ``<model_dir>/<RVC_MODEL_DIRNAME>`` exists, otherwise False.
    """
    return os.path.exists(os.path.join(model_dir, RVC_MODEL_DIRNAME))
2023-06-04 23:00:57 +03:00
def downloadInitialSampleModels(
    sampleJsons: list[str], sampleModelIds: list[str], model_dir: str
):
    """Download the configured initial RVC sample models into numbered slots.

    For every requested sample id found in the sample catalog, a slot
    directory ``<model_dir>/<RVC_MODEL_DIRNAME>/<n>`` is created (``n`` counts
    up from 0 over the samples that were found). The model file, and
    optionally the index file and icon, are queued for download into that
    directory, and a preliminary ``params.json`` describing the slot is
    written. After all downloads have run, each slot's ``params.json`` is
    reloaded, passed to ``_setInfoByONNX`` / ``_setInfoByPytorch`` to complete
    its metadata, and written back.

    Args:
        sampleJsons: Sample catalog sources, passed to ``getModelSamples``.
        sampleModelIds: Requested samples. Each entry is indexed as
            ``entry[0]`` (sample id) and ``entry[1]`` (``True`` to also
            download the sample's index file), so entries are pairs rather
            than plain strings despite the annotation.
        model_dir: Root model directory.
    """
    sampleModels = getModelSamples(sampleJsons, "RVC")
    if sampleModels is None:
        return

    downloadParams = []
    slot_count = 0

    def queueDownload(url: str, slotDir: str) -> str:
        # Queue one file for download into slotDir and return its local path.
        # "position" is the progress-bar line, i.e. the file's queue index.
        savePath = os.path.join(slotDir, os.path.basename(url))
        downloadParams.append(
            {
                "url": url,
                "saveTo": savePath,
                "position": len(downloadParams),
            }
        )
        return savePath

    for initSampleId in sampleModelIds:
        # Look up the requested sample in the catalog.
        sample = next((s for s in sampleModels if s.id == initSampleId[0]), None)
        if sample is None:
            print(f"[Voice Changer] initiail sample not found. {initSampleId[0]}")
            continue

        # Found: lay out a new slot for it.
        slotInfo: ModelSlot = ModelSlot()
        slotDir = os.path.join(model_dir, RVC_MODEL_DIRNAME, str(slot_count))
        os.makedirs(slotDir, exist_ok=True)

        slotInfo.modelFile = queueDownload(sample.modelUrl, slotDir)
        # The index file is fetched only when requested for this sample and
        # the catalog entry actually provides one.
        if (
            initSampleId[1] is True
            and hasattr(sample, "indexUrl")
            and sample.indexUrl != ""
        ):
            slotInfo.indexFile = queueDownload(sample.indexUrl, slotDir)
        if hasattr(sample, "icon") and sample.icon != "":
            slotInfo.iconFile = queueDownload(sample.icon, slotDir)

        slotInfo.sampleId = sample.id
        slotInfo.credit = sample.credit
        slotInfo.description = sample.description
        slotInfo.name = sample.name
        slotInfo.termsOfUseUrl = sample.termsOfUseUrl
        slotInfo.defaultTune = 0
        slotInfo.defaultIndexRatio = 1
        slotInfo.defaultProtect = 0.5
        slotInfo.isONNX = slotInfo.modelFile.endswith(".onnx")
        # The files have not been downloaded yet at this point, so the
        # model-derived metadata (_setInfoByONNX / _setInfoByPytorch) is
        # filled in after the downloads below.
        with open(os.path.join(slotDir, "params.json"), "w") as f:
            json.dump(asdict(slotInfo), f)
        slot_count += 1

    # Download
    print("[Voice Changer] Downloading model files...")
    with ThreadPoolExecutor() as pool:
        futures = [pool.submit(download, param) for param in downloadParams]
    # Executor.map results were previously never consumed, which silently
    # discarded any exception raised by a download. Report each failure
    # instead (still best-effort: the remaining slots are processed).
    for param, future in zip(downloadParams, futures):
        error = future.exception()
        if error is not None:
            print(f"[Voice Changer] Failed to download {param['url']}: {error}")

    # Generate metadata
    print("[Voice Changer] Generating metadata...")
    for slotId in range(slot_count):
        slotDir = os.path.join(model_dir, RVC_MODEL_DIRNAME, str(slotId))
        paramsPath = os.path.join(slotDir, "params.json")
        with open(paramsPath) as f:
            slotInfo = ModelSlot(**json.load(f))
        if slotInfo.isONNX:
            _setInfoByONNX(slotInfo)
        else:
            _setInfoByPytorch(slotInfo)
        with open(paramsPath, "w") as f:
            json.dump(asdict(slotInfo), f)
2023-05-17 20:51:40 +03:00
2023-05-25 05:40:37 +03:00
def downloadModelFiles(sampleInfo: RVCModelSample, useIndex: bool = True):
    """Download one sample's files into ``TMP_DIR``.

    The model file is always downloaded. The index file is downloaded when
    ``useIndex`` is True and the sample has a non-empty ``indexUrl``. The icon
    is downloaded when the sample has a non-empty ``icon``. Each file is saved
    under ``TMP_DIR`` using the basename of its URL.

    Args:
        sampleInfo: Catalog entry describing the sample's download URLs.
        useIndex: Whether to also fetch the index file when one is available.

    Returns:
        ``(modelPath, indexPath, iconPath)``. ``indexPath`` and ``iconPath``
        are None when that file was not requested or not available.
    """
    modelPath = os.path.join(TMP_DIR, os.path.basename(sampleInfo.modelUrl))
    downloadParams = [
        {
            "url": sampleInfo.modelUrl,
            "saveTo": modelPath,
            "position": 0,
        }
    ]

    indexPath = None
    if (
        useIndex is True
        and hasattr(sampleInfo, "indexUrl")
        and sampleInfo.indexUrl != ""
    ):
        print("[Voice Changer] Download sample with index.")
        indexPath = os.path.join(TMP_DIR, os.path.basename(sampleInfo.indexUrl))
        downloadParams.append(
            {
                "url": sampleInfo.indexUrl,
                "saveTo": indexPath,
                "position": 1,
            }
        )

    iconPath = None
    if hasattr(sampleInfo, "icon") and sampleInfo.icon != "":
        iconPath = os.path.join(TMP_DIR, os.path.basename(sampleInfo.icon))
        downloadParams.append(
            {
                "url": sampleInfo.icon,
                "saveTo": iconPath,
                "position": 2,
            }
        )

    print("[Voice Changer] Downloading model files...", end="")
    with ThreadPoolExecutor() as pool:
        futures = [pool.submit(download_no_tqdm, param) for param in downloadParams]
    print("")
    # Executor.map results were previously never consumed, so a failed
    # download went unnoticed while its path was still returned. Report it.
    for param, future in zip(downloadParams, futures):
        error = future.exception()
        if error is not None:
            print(f"[Voice Changer] Failed to download {param['url']}: {error}")

    return modelPath, indexPath, iconPath