voice-changer/server/downloader/SampleDownloader.py

225 lines
8.3 KiB
Python
Raw Normal View History

2023-06-16 10:49:55 +03:00
import json
import os
from concurrent.futures import ThreadPoolExecutor
from typing import Any, Tuple
from const import RVCSampleMode, getSampleJsonAndModelIds
from data.ModelSample import ModelSamples, generateModelSample
2023-07-21 12:25:28 +03:00
from data.ModelSlot import DiffusionSVCModelSlot, ModelSlot, RVCModelSlot
from voice_changer.DiffusionSVC.DiffusionSVCModelSlotGenerator import DiffusionSVCModelSlotGenerator
2023-06-17 08:16:29 +03:00
from voice_changer.ModelSlotManager import ModelSlotManager
2023-06-21 03:18:51 +03:00
from voice_changer.RVC.RVCModelSlotGenerator import RVCModelSlotGenerator
2023-06-17 09:35:43 +03:00
from downloader.Downloader import download, download_no_tqdm
2023-06-16 10:49:55 +03:00
def downloadInitialSamples(mode: RVCSampleMode, model_dir: str):
    """Fetch the sample catalogue and, on first run only, populate the initial slots.

    The sample JSON files are always downloaded into the current working directory
    (they are saved under their URL basenames), so later catalogue reads such as
    ``getSampleInfos`` can open them. Model files are only downloaded when
    ``model_dir`` does not exist yet; the ``i``-th initial sample goes to slot ``i``.
    """
    sampleJsonUrls, initialModels = getSampleJsonAndModelIds(mode)
    jsonFiles = _downloadSampleJsons(sampleJsonUrls)

    # An existing model_dir means the initial samples were installed before.
    if not os.path.exists(model_dir):
        catalogue = _generateSampleList(jsonFiles)
        targetSlots = [*range(len(initialModels))]
        _downloadSamples(catalogue, initialModels, model_dir, targetSlots)
        return

    print("[Voice Changer] model_dir is already exists. skip download samples.")
def downloadSample(mode: RVCSampleMode, modelId: str, model_dir: str, slotIndex: int, params: Any):
    """Download the catalogue sample ``modelId`` into slot ``slotIndex``.

    The catalogue is read from sample JSON files that must already exist in the
    current working directory (they are written by ``_downloadSampleJsons``).
    Nothing is re-fetched here. ``params`` is passed through as the sample's
    options (``_downloadSamples`` reads ``params["useIndex"]``). Downloads run
    without tqdm progress bars.
    """
    catalogue = getSampleInfos(mode)
    requested = [(modelId, params)]
    _downloadSamples(catalogue, requested, model_dir, [slotIndex], withoutTqdm=True)
def getSampleInfos(mode: RVCSampleMode):
    """Return the parsed sample catalogue for ``mode``.

    Nothing is downloaded. The sample JSON files are opened from the current
    working directory under their URL basenames, so they must already exist there.
    """
    jsonUrls = getSampleJsonAndModelIds(mode)[0]
    localJsonFiles = _generateSampleJsons(jsonUrls)
    return _generateSampleList(localJsonFiles)
def _downloadSampleJsons(sampleJsonUrls: list[str]):
    """Download each sample JSON and return the local file names, in input order.

    Every file is saved in the current working directory under the URL's basename.
    Downloads run one after another via ``download_no_tqdm``.
    """
    savedFiles = []
    for jsonUrl in sampleJsonUrls:
        localName = os.path.basename(jsonUrl)
        job = {"url": jsonUrl, "saveTo": localName, "position": 0}
        download_no_tqdm(job)
        savedFiles.append(localName)
    return savedFiles
def _generateSampleJsons(sampleJsonUrls: list[str]):
    """Map sample JSON URLs to the local file names they are saved under.

    This is a pure path computation (the URL basename). No file is read or downloaded.
    """
    return [os.path.basename(jsonUrl) for jsonUrl in sampleJsonUrls]
def _generateSampleList(sampleJsons: list[str]):
    """Parse each sample JSON file and return every sample as a ``ModelSamples`` object.

    Each file holds a JSON object whose values are lists of sample parameter dicts.
    The keys are presumably voice-changer types; verify against the JSON files.
    Results are ordered by file, then key, then list position.
    """
    samples: list[ModelSamples] = []
    for jsonPath in sampleJsons:
        with open(jsonPath, "r", encoding="utf-8") as fp:
            catalogue = json.load(fp)
        for entries in catalogue.values():
            samples.extend(generateModelSample(params) for params in entries)
    return samples
2023-06-17 08:16:29 +03:00
def _queueDownload(downloadParams: list[dict[str, Any]], url: str, slotDir: str) -> str:
    """Append a download job for ``url`` into ``slotDir`` and return its local save path.

    The job's ``position`` is its index in ``downloadParams``, i.e. a running
    counter across every job queued for this batch.
    """
    saveTo = os.path.join(slotDir, os.path.basename(url))
    downloadParams.append({"url": url, "saveTo": saveTo, "position": len(downloadParams)})
    return saveTo


def _applyCommonSampleInfo(slotInfo: ModelSlot, sample: ModelSamples) -> None:
    """Copy the catalogue fields shared by every voice-changer type onto ``slotInfo``."""
    slotInfo.sampleId = sample.id
    slotInfo.credit = sample.credit
    slotInfo.description = sample.description
    slotInfo.name = sample.name
    slotInfo.termsOfUseUrl = sample.termsOfUseUrl
    slotInfo.defaultTune = 0


def _downloadSamples(samples: list[ModelSamples], sampleModelIds: list[Tuple[str, Any]], model_dir: str, slotIndex: list[int], withoutTqdm=False):
    """Create model slots for the requested samples, download their files, then build metadata.

    Args:
        samples: Catalogue entries, as returned by ``_generateSampleList``.
        sampleModelIds: ``(sampleId, params)`` pairs. For RVC samples,
            ``params.get("useIndex") is True`` also downloads the index file.
        model_dir: Root directory. Each slot's files go into ``model_dir/<slot index>``.
        slotIndex: Target slot index for each entry of ``sampleModelIds``, in the same order.
        withoutTqdm: Download with ``download_no_tqdm`` instead of ``download``.

    Samples that are not in the catalogue, or whose voice-changer type is
    unsupported, are reported with a print and skipped.

    Raises:
        The first exception raised by a download job (in queue order) propagates
        after the thread pool has finished the remaining jobs. Previously such
        errors were silently discarded.
    """
    downloadParams: list[dict[str, Any]] = []
    modelSlotManager = ModelSlotManager.get_instance(model_dir)

    for i, initSampleId in enumerate(sampleModelIds):
        targetSampleId = initSampleId[0]
        targetSampleParams = initSampleId[1]
        targetSlotIndex = slotIndex[i]

        # Look up the requested sample in the catalogue.
        sample = next((s for s in samples if s.id == targetSampleId), None)
        if sample is None:
            print(f"[Voice Changer] initiail sample not found. {targetSampleId}")
            continue

        # Found: create the slot directory, queue its files, and save the slot info.
        slotDir = os.path.join(model_dir, str(targetSlotIndex))
        if sample.voiceChangerType == "RVC":
            slotInfo = RVCModelSlot()
            os.makedirs(slotDir, exist_ok=True)
            slotInfo.modelFile = _queueDownload(downloadParams, sample.modelUrl, slotDir)
            # .get(): params without a "useIndex" key simply skip the index file.
            if targetSampleParams.get("useIndex") is True and getattr(sample, "indexUrl", "") != "":
                slotInfo.indexFile = _queueDownload(downloadParams, sample.indexUrl, slotDir)
            if getattr(sample, "icon", "") != "":
                slotInfo.iconFile = _queueDownload(downloadParams, sample.icon, slotDir)
            _applyCommonSampleInfo(slotInfo, sample)
            slotInfo.defaultIndexRatio = 0
            slotInfo.defaultProtect = 0.5
            slotInfo.isONNX = slotInfo.modelFile.endswith(".onnx")
            modelSlotManager.save_model_slot(targetSlotIndex, slotInfo)
        elif sample.voiceChangerType == "Diffusion-SVC":
            slotInfo = DiffusionSVCModelSlot()
            os.makedirs(slotDir, exist_ok=True)
            slotInfo.modelFile = _queueDownload(downloadParams, sample.modelUrl, slotDir)
            if getattr(sample, "icon", "") != "":
                slotInfo.iconFile = _queueDownload(downloadParams, sample.icon, slotDir)
            _applyCommonSampleInfo(slotInfo, sample)
            slotInfo.defaultKstep = 0
            slotInfo.defaultSpeedup = 0
            slotInfo.kStepMax = 0
            slotInfo.isONNX = slotInfo.modelFile.endswith(".onnx")
            modelSlotManager.save_model_slot(targetSlotIndex, slotInfo)
        else:
            print(f"[Voice Changer] {sample.voiceChangerType} is not supported.")

    # Download all queued files in parallel.
    print("[Voice Changer] Downloading model files...")
    downloadFunc = download_no_tqdm if withoutTqdm else download
    with ThreadPoolExecutor() as pool:
        # Executor.map only re-raises a worker's exception when that result is
        # retrieved. Consuming the iterator makes failed downloads surface here
        # instead of being silently dropped.
        list(pool.map(downloadFunc, downloadParams))

    # Build metadata from the downloaded model files.
    print("[Voice Changer] Generating metadata...")
    for targetSlotIndex in slotIndex:
        slotInfo = modelSlotManager.get_slot_info(targetSlotIndex)
        if slotInfo.voiceChangerType == "RVC":
            if slotInfo.isONNX:
                slotInfo = RVCModelSlotGenerator._setInfoByONNX(slotInfo)
            else:
                slotInfo = RVCModelSlotGenerator._setInfoByPytorch(slotInfo)
            modelSlotManager.save_model_slot(targetSlotIndex, slotInfo)
        elif slotInfo.voiceChangerType == "Diffusion-SVC":
            # ONNX Diffusion-SVC slots are re-saved unchanged; only PyTorch models are inspected.
            if not slotInfo.isONNX:
                slotInfo = DiffusionSVCModelSlotGenerator._setInfoByPytorch(slotInfo)
            modelSlotManager.save_model_slot(targetSlotIndex, slotInfo)