WIP:common sample

This commit is contained in:
wataru 2023-06-17 14:16:29 +09:00
parent 4a8851aff1
commit 9dd2808509
10 changed files with 35 additions and 146 deletions

View File

@ -1,56 +0,0 @@
import requests # type: ignore
import os
from tqdm import tqdm
def download(params):
url = params["url"]
saveTo = params["saveTo"]
position = params["position"]
dirname = os.path.dirname(saveTo)
if dirname != "":
os.makedirs(dirname, exist_ok=True)
try:
req = requests.get(url, stream=True, allow_redirects=True)
content_length = req.headers.get("content-length")
progress_bar = tqdm(
total=int(content_length) if content_length is not None else None,
leave=False,
unit="B",
unit_scale=True,
unit_divisor=1024,
position=position,
)
# with tqdm
with open(saveTo, "wb") as f:
for chunk in req.iter_content(chunk_size=1024):
if chunk:
progress_bar.update(len(chunk))
f.write(chunk)
except Exception as e:
print(e)
def download_no_tqdm(params):
url = params["url"]
saveTo = params["saveTo"]
dirname = os.path.dirname(saveTo)
if dirname != "":
os.makedirs(dirname, exist_ok=True)
try:
req = requests.get(url, stream=True, allow_redirects=True)
with open(saveTo, "wb") as f:
countToDot = 0
for chunk in req.iter_content(chunk_size=1024):
if chunk:
f.write(chunk)
countToDot += 1
if countToDot % 1024 == 0:
print(".", end="", flush=True)
print("+", end="", flush=True)
except Exception as e:
print(e)

View File

@ -1,44 +0,0 @@
from dataclasses import dataclass, field
import json
from const import ModelType
@dataclass
class RVCModelSample:
id: str = ""
lang: str = ""
tag: list[str] = field(default_factory=lambda: [])
name: str = ""
modelUrl: str = ""
indexUrl: str = ""
termsOfUseUrl: str = ""
icon: str = ""
credit: str = ""
description: str = ""
sampleRate: int = 48000
modelType: str = ""
f0: bool = True
def getModelSamples(jsonFiles: list[str], modelType: ModelType):
try:
samples: list[RVCModelSample] = []
for file in jsonFiles:
with open(file, "r", encoding="utf-8") as f:
jsonDict = json.load(f)
modelList = jsonDict[modelType]
if modelType == "RVC":
for s in modelList:
modelSample = RVCModelSample(**s)
samples.append(modelSample)
else:
raise RuntimeError(f"Unknown model type {modelType}")
return samples
except Exception as e:
print("[Voice Changer] loading sample info error:", e)
return None

View File

@ -23,11 +23,8 @@ ModelType: TypeAlias = Literal[
"RVC",
]
ERROR_NO_ONNX_SESSION = "ERROR_NO_ONNX_SESSION"
tmpdir = tempfile.TemporaryDirectory()
# print("generate tmpdir:::",tmpdir)
SSL_KEY_DIR = os.path.join(tmpdir.name, "keys") if hasattr(sys, "_MEIPASS") else "keys"
MODEL_DIR = os.path.join(tmpdir.name, "logs") if hasattr(sys, "_MEIPASS") else "logs"
UPLOAD_DIR = os.path.join(tmpdir.name, "upload_dir") if hasattr(sys, "_MEIPASS") else "upload_dir"

View File

@ -77,7 +77,7 @@ class MMVC_Rest_Fileuploader:
return JSONResponse(content=json_compatible_item_data)
except Exception as e:
print("[Voice Changer] ex:", e)
print(sys.exc_info()[2].tb_lineno)
print(sys.exc_info())
def post_load_model(
self,

View File

@ -5,7 +5,8 @@ from typing import Any, Tuple
from const import RVCSampleMode, getSampleJsonAndModelIds
from data.ModelSample import ModelSamples, generateModelSample
from data.ModelSlot import RVCModelSlot, loadSlotInfo, saveSlotInfo
from data.ModelSlot import RVCModelSlot
from voice_changer.ModelSlotManager import ModelSlotManager
from voice_changer.RVC.ModelSlotGenerator import _setInfoByONNX, _setInfoByPytorch
from utils.downloader.Downloader import download, download_no_tqdm
@ -27,7 +28,7 @@ def downloadSample(mode: RVCSampleMode, modelId: str, model_dir: str, slotIndex:
sampleJsonUrls, _sampleModels = getSampleJsonAndModelIds(mode)
sampleJsons = _generateSampleJsons(sampleJsonUrls)
samples = _generateSampleList(sampleJsons)
_downloadSamples(samples, [(modelId, params)], model_dir, [slotIndex])
_downloadSamples(samples, [(modelId, params)], model_dir, [slotIndex], withoutTqdm=True)
pass
@ -67,9 +68,10 @@ def _generateSampleList(sampleJsons: list[str]):
return samples
def _downloadSamples(samples: list[ModelSamples], sampleModelIds: list[Tuple[str, Any]], model_dir: str, slotIndex: list[int]):
def _downloadSamples(samples: list[ModelSamples], sampleModelIds: list[Tuple[str, Any]], model_dir: str, slotIndex: list[int], withoutTqdm=False):
downloadParams = []
line_num = 0
modelSlotManager = ModelSlotManager.get_instance(model_dir)
for i, initSampleId in enumerate(sampleModelIds):
targetSampleId = initSampleId[0]
@ -145,22 +147,26 @@ def _downloadSamples(samples: list[ModelSamples], sampleModelIds: list[Tuple[str
slotInfo.defaultIndexRatio = 1
slotInfo.defaultProtect = 0.5
slotInfo.isONNX = slotInfo.modelFile.endswith(".onnx")
saveSlotInfo(model_dir, tagetSlotIndex, slotInfo)
modelSlotManager.save_model_slot(tagetSlotIndex, slotInfo)
else:
print(f"[Voice Changer] {sample.voiceChangerType} is not supported.")
# ダウンロード
print("[Voice Changer] Downloading model files...")
with ThreadPoolExecutor() as pool:
pool.map(download, downloadParams)
if withoutTqdm:
with ThreadPoolExecutor() as pool:
pool.map(download_no_tqdm, downloadParams)
else:
with ThreadPoolExecutor() as pool:
pool.map(download, downloadParams)
# メタデータ作成
print("[Voice Changer] Generating metadata...")
for targetSlotIndex in slotIndex:
slotInfo = loadSlotInfo(model_dir, targetSlotIndex)
slotInfo = modelSlotManager.get_slot_info(targetSlotIndex)
if slotInfo.voiceChangerType == "RVC":
if slotInfo.isONNX:
_setInfoByONNX(slotInfo)
else:
_setInfoByPytorch(slotInfo)
saveSlotInfo(model_dir, targetSlotIndex, slotInfo)
modelSlotManager.save_model_slot(tagetSlotIndex, slotInfo)

View File

@ -1,6 +1,5 @@
from const import UPLOAD_DIR
from data.ModelSlot import ModelSlots, loadAllSlotInfo, saveSlotInfo
from voice_changer.utils.VoiceChangerParams import VoiceChangerParams
import json
import os
import shutil
@ -9,24 +8,26 @@ import shutil
class ModelSlotManager:
_instance = None
def __init__(self, params: VoiceChangerParams):
self.params = params
self.modelSlots = loadAllSlotInfo(self.params.model_dir)
def __init__(self, model_dir: str):
self.model_dir = model_dir
self.modelSlots = loadAllSlotInfo(self.model_dir)
@classmethod
def get_instance(cls, params: VoiceChangerParams):
def get_instance(cls, model_dir: str):
if cls._instance is None:
cls._instance = cls(params)
cls._instance = cls(model_dir)
return cls._instance
def _save_model_slot(self, slotIndex: int, slotInfo: ModelSlots):
saveSlotInfo(self.params.model_dir, slotIndex, slotInfo)
self.modelSlots = loadAllSlotInfo(self.params.model_dir)
saveSlotInfo(self.model_dir, slotIndex, slotInfo)
self.modelSlots = loadAllSlotInfo(self.model_dir)
def _load_model_slot(self, slotIndex: int):
return self.modelSlots[slotIndex]
def getAllSlotInfo(self):
def getAllSlotInfo(self, reload: bool = False):
if reload:
self.modelSlots = loadAllSlotInfo(self.model_dir)
return self.modelSlots
def get_slot_info(self, slotIndex: int):
@ -46,7 +47,7 @@ class ModelSlotManager:
print("[Voice Changer] UPLOAD ASSETS", params)
paramsDict = json.loads(params)
uploadPath = os.path.join(UPLOAD_DIR, paramsDict["file"])
storeDir = os.path.join(self.params.model_dir, str(paramsDict["slot"]))
storeDir = os.path.join(self.model_dir, str(paramsDict["slot"]))
storePath = os.path.join(
storeDir,
paramsDict["file"],

View File

@ -6,7 +6,6 @@ import numpy as np
import torch
import torchaudio
from data.ModelSlot import RVCModelSlot
from utils.downloader.SampleDownloader import getSampleInfos
from voice_changer.ModelSlotManager import ModelSlotManager
@ -67,13 +66,9 @@ class RVC:
self.params = params
EmbedderManager.initialize(params)
print("[Voice Changer] RVC initialization: ", params)
self.modelSlotManager = ModelSlotManager.get_instance(self.params)
self.modelSlotManager = ModelSlotManager.get_instance(self.params.model_dir)
# サンプルカタログ作成
samples = getSampleInfos(params.sample_mode)
self.settings.sampleModels = samples
# 起動時にスロットにモデルがある場合はロードしておく
allSlots = self.modelSlotManager.getAllSlotInfo()
availableIndex = -1
for i, slot in enumerate(allSlots):
@ -88,13 +83,6 @@ class RVC:
self.prevVol = 0.0
def getSampleInfo(self, id: str):
sampleInfos = list(filter(lambda x: x.id == id, self.settings.sampleModels))
if len(sampleInfos) > 0:
return sampleInfos[0]
else:
None
def moveToModelDir(self, file: str, dstDir: str):
dst = os.path.join(dstDir, os.path.basename(file))
if os.path.exists(dst):

View File

@ -1,8 +1,4 @@
from dataclasses import dataclass, field
from ModelSample import RVCModelSample
# from const import MAX_SLOT_NUM
# from data.ModelSlot import ModelSlot, ModelSlots
@dataclass
@ -17,9 +13,6 @@ class RVCSettings:
clusterInferRatio: float = 0.1
framework: str = "PyTorch" # PyTorch or ONNX
# modelSlots: list[ModelSlots] = field(default_factory=lambda: [ModelSlot() for _x in range(MAX_SLOT_NUM)])
sampleModels: list[RVCModelSample] = field(default_factory=lambda: [])
indexRatio: float = 0
protect: float = 0.5

View File

@ -60,7 +60,7 @@ class VoiceChangerSettings:
class VoiceChanger:
ioRecorder: IORecorder
# sola_buffer: AudioInOut
sola_buffer: AudioInOut
def __init__(self, params: VoiceChangerParams):
# 初期化

View File

@ -1,7 +1,7 @@
import numpy as np
from data.ModelSlot import loadAllSlotInfo
from utils.downloader.SampleDownloader import downloadSample
from utils.downloader.SampleDownloader import downloadSample, getSampleInfos
from voice_changer.Local.ServerDevice import ServerDevice, ServerDeviceCallbacks
from voice_changer.ModelSlotManager import ModelSlotManager
from voice_changer.VoiceChanger import VoiceChanger
from const import ModelType
from voice_changer.utils.LoadModelParams import LoadModelParams
@ -53,6 +53,8 @@ class VoiceChangerManager(ServerDeviceCallbacks):
self.params = params
self.voiceChanger: VoiceChanger = None
self.settings: VoiceChangerManagerSettings = VoiceChangerManagerSettings(dummy=0)
self.modelSlotManager = ModelSlotManager.get_instance(self.params.model_dir)
# スタティックな情報を収集
self.gpus: list[GPUInfo] = self._get_gpuInfos()
@ -82,6 +84,7 @@ class VoiceChangerManager(ServerDeviceCallbacks):
paramDict = props.params
if "sampleId" in paramDict and len(paramDict["sampleId"]) > 0:
downloadSample(self.params.sample_mode, paramDict["sampleId"], self.params.model_dir, props.slot, {"useIndex": paramDict["rvcIndexDownload"]})
self.modelSlotManager.getAllSlotInfo(reload=True)
info = {"status": "OK"}
return info
else:
@ -96,7 +99,8 @@ class VoiceChangerManager(ServerDeviceCallbacks):
def get_info(self):
data = asdict(self.settings)
data["gpus"] = self.gpus
data["modelSlots"] = loadAllSlotInfo(self.params.model_dir)
data["modelSlots"] = self.modelSlotManager.getAllSlotInfo(reload=True)
data["sampleModels"] = getSampleInfos(self.params.sample_mode)
data["status"] = "OK"