mirror of
https://github.com/w-okada/voice-changer.git
synced 2025-02-02 16:23:58 +03:00
WIP:common sample
This commit is contained in:
parent
4a8851aff1
commit
9dd2808509
@ -1,56 +0,0 @@
|
||||
import requests # type: ignore
|
||||
import os
|
||||
from tqdm import tqdm
|
||||
|
||||
|
||||
def download(params):
|
||||
url = params["url"]
|
||||
saveTo = params["saveTo"]
|
||||
position = params["position"]
|
||||
dirname = os.path.dirname(saveTo)
|
||||
if dirname != "":
|
||||
os.makedirs(dirname, exist_ok=True)
|
||||
|
||||
try:
|
||||
req = requests.get(url, stream=True, allow_redirects=True)
|
||||
content_length = req.headers.get("content-length")
|
||||
progress_bar = tqdm(
|
||||
total=int(content_length) if content_length is not None else None,
|
||||
leave=False,
|
||||
unit="B",
|
||||
unit_scale=True,
|
||||
unit_divisor=1024,
|
||||
position=position,
|
||||
)
|
||||
|
||||
# with tqdm
|
||||
with open(saveTo, "wb") as f:
|
||||
for chunk in req.iter_content(chunk_size=1024):
|
||||
if chunk:
|
||||
progress_bar.update(len(chunk))
|
||||
f.write(chunk)
|
||||
|
||||
except Exception as e:
|
||||
print(e)
|
||||
|
||||
|
||||
def download_no_tqdm(params):
|
||||
url = params["url"]
|
||||
saveTo = params["saveTo"]
|
||||
dirname = os.path.dirname(saveTo)
|
||||
if dirname != "":
|
||||
os.makedirs(dirname, exist_ok=True)
|
||||
try:
|
||||
req = requests.get(url, stream=True, allow_redirects=True)
|
||||
with open(saveTo, "wb") as f:
|
||||
countToDot = 0
|
||||
for chunk in req.iter_content(chunk_size=1024):
|
||||
if chunk:
|
||||
f.write(chunk)
|
||||
countToDot += 1
|
||||
if countToDot % 1024 == 0:
|
||||
print(".", end="", flush=True)
|
||||
|
||||
print("+", end="", flush=True)
|
||||
except Exception as e:
|
||||
print(e)
|
@ -1,44 +0,0 @@
|
||||
from dataclasses import dataclass, field
|
||||
import json
|
||||
|
||||
from const import ModelType
|
||||
|
||||
|
||||
@dataclass
|
||||
class RVCModelSample:
|
||||
id: str = ""
|
||||
lang: str = ""
|
||||
tag: list[str] = field(default_factory=lambda: [])
|
||||
name: str = ""
|
||||
modelUrl: str = ""
|
||||
indexUrl: str = ""
|
||||
termsOfUseUrl: str = ""
|
||||
icon: str = ""
|
||||
credit: str = ""
|
||||
description: str = ""
|
||||
|
||||
sampleRate: int = 48000
|
||||
modelType: str = ""
|
||||
f0: bool = True
|
||||
|
||||
|
||||
def getModelSamples(jsonFiles: list[str], modelType: ModelType):
|
||||
try:
|
||||
samples: list[RVCModelSample] = []
|
||||
for file in jsonFiles:
|
||||
with open(file, "r", encoding="utf-8") as f:
|
||||
jsonDict = json.load(f)
|
||||
|
||||
modelList = jsonDict[modelType]
|
||||
if modelType == "RVC":
|
||||
for s in modelList:
|
||||
modelSample = RVCModelSample(**s)
|
||||
samples.append(modelSample)
|
||||
|
||||
else:
|
||||
raise RuntimeError(f"Unknown model type {modelType}")
|
||||
return samples
|
||||
|
||||
except Exception as e:
|
||||
print("[Voice Changer] loading sample info error:", e)
|
||||
return None
|
@ -23,11 +23,8 @@ ModelType: TypeAlias = Literal[
|
||||
"RVC",
|
||||
]
|
||||
|
||||
ERROR_NO_ONNX_SESSION = "ERROR_NO_ONNX_SESSION"
|
||||
|
||||
|
||||
tmpdir = tempfile.TemporaryDirectory()
|
||||
# print("generate tmpdir:::",tmpdir)
|
||||
SSL_KEY_DIR = os.path.join(tmpdir.name, "keys") if hasattr(sys, "_MEIPASS") else "keys"
|
||||
MODEL_DIR = os.path.join(tmpdir.name, "logs") if hasattr(sys, "_MEIPASS") else "logs"
|
||||
UPLOAD_DIR = os.path.join(tmpdir.name, "upload_dir") if hasattr(sys, "_MEIPASS") else "upload_dir"
|
||||
|
@ -77,7 +77,7 @@ class MMVC_Rest_Fileuploader:
|
||||
return JSONResponse(content=json_compatible_item_data)
|
||||
except Exception as e:
|
||||
print("[Voice Changer] ex:", e)
|
||||
print(sys.exc_info()[2].tb_lineno)
|
||||
print(sys.exc_info())
|
||||
|
||||
def post_load_model(
|
||||
self,
|
||||
|
@ -5,7 +5,8 @@ from typing import Any, Tuple
|
||||
|
||||
from const import RVCSampleMode, getSampleJsonAndModelIds
|
||||
from data.ModelSample import ModelSamples, generateModelSample
|
||||
from data.ModelSlot import RVCModelSlot, loadSlotInfo, saveSlotInfo
|
||||
from data.ModelSlot import RVCModelSlot
|
||||
from voice_changer.ModelSlotManager import ModelSlotManager
|
||||
from voice_changer.RVC.ModelSlotGenerator import _setInfoByONNX, _setInfoByPytorch
|
||||
from utils.downloader.Downloader import download, download_no_tqdm
|
||||
|
||||
@ -27,7 +28,7 @@ def downloadSample(mode: RVCSampleMode, modelId: str, model_dir: str, slotIndex:
|
||||
sampleJsonUrls, _sampleModels = getSampleJsonAndModelIds(mode)
|
||||
sampleJsons = _generateSampleJsons(sampleJsonUrls)
|
||||
samples = _generateSampleList(sampleJsons)
|
||||
_downloadSamples(samples, [(modelId, params)], model_dir, [slotIndex])
|
||||
_downloadSamples(samples, [(modelId, params)], model_dir, [slotIndex], withoutTqdm=True)
|
||||
pass
|
||||
|
||||
|
||||
@ -67,9 +68,10 @@ def _generateSampleList(sampleJsons: list[str]):
|
||||
return samples
|
||||
|
||||
|
||||
def _downloadSamples(samples: list[ModelSamples], sampleModelIds: list[Tuple[str, Any]], model_dir: str, slotIndex: list[int]):
|
||||
def _downloadSamples(samples: list[ModelSamples], sampleModelIds: list[Tuple[str, Any]], model_dir: str, slotIndex: list[int], withoutTqdm=False):
|
||||
downloadParams = []
|
||||
line_num = 0
|
||||
modelSlotManager = ModelSlotManager.get_instance(model_dir)
|
||||
|
||||
for i, initSampleId in enumerate(sampleModelIds):
|
||||
targetSampleId = initSampleId[0]
|
||||
@ -145,22 +147,26 @@ def _downloadSamples(samples: list[ModelSamples], sampleModelIds: list[Tuple[str
|
||||
slotInfo.defaultIndexRatio = 1
|
||||
slotInfo.defaultProtect = 0.5
|
||||
slotInfo.isONNX = slotInfo.modelFile.endswith(".onnx")
|
||||
saveSlotInfo(model_dir, tagetSlotIndex, slotInfo)
|
||||
modelSlotManager.save_model_slot(tagetSlotIndex, slotInfo)
|
||||
else:
|
||||
print(f"[Voice Changer] {sample.voiceChangerType} is not supported.")
|
||||
|
||||
# ダウンロード
|
||||
print("[Voice Changer] Downloading model files...")
|
||||
with ThreadPoolExecutor() as pool:
|
||||
pool.map(download, downloadParams)
|
||||
if withoutTqdm:
|
||||
with ThreadPoolExecutor() as pool:
|
||||
pool.map(download_no_tqdm, downloadParams)
|
||||
else:
|
||||
with ThreadPoolExecutor() as pool:
|
||||
pool.map(download, downloadParams)
|
||||
|
||||
# メタデータ作成
|
||||
print("[Voice Changer] Generating metadata...")
|
||||
for targetSlotIndex in slotIndex:
|
||||
slotInfo = loadSlotInfo(model_dir, targetSlotIndex)
|
||||
slotInfo = modelSlotManager.get_slot_info(targetSlotIndex)
|
||||
if slotInfo.voiceChangerType == "RVC":
|
||||
if slotInfo.isONNX:
|
||||
_setInfoByONNX(slotInfo)
|
||||
else:
|
||||
_setInfoByPytorch(slotInfo)
|
||||
saveSlotInfo(model_dir, targetSlotIndex, slotInfo)
|
||||
modelSlotManager.save_model_slot(tagetSlotIndex, slotInfo)
|
||||
|
@ -1,6 +1,5 @@
|
||||
from const import UPLOAD_DIR
|
||||
from data.ModelSlot import ModelSlots, loadAllSlotInfo, saveSlotInfo
|
||||
from voice_changer.utils.VoiceChangerParams import VoiceChangerParams
|
||||
import json
|
||||
import os
|
||||
import shutil
|
||||
@ -9,24 +8,26 @@ import shutil
|
||||
class ModelSlotManager:
|
||||
_instance = None
|
||||
|
||||
def __init__(self, params: VoiceChangerParams):
|
||||
self.params = params
|
||||
self.modelSlots = loadAllSlotInfo(self.params.model_dir)
|
||||
def __init__(self, model_dir: str):
|
||||
self.model_dir = model_dir
|
||||
self.modelSlots = loadAllSlotInfo(self.model_dir)
|
||||
|
||||
@classmethod
|
||||
def get_instance(cls, params: VoiceChangerParams):
|
||||
def get_instance(cls, model_dir: str):
|
||||
if cls._instance is None:
|
||||
cls._instance = cls(params)
|
||||
cls._instance = cls(model_dir)
|
||||
return cls._instance
|
||||
|
||||
def _save_model_slot(self, slotIndex: int, slotInfo: ModelSlots):
|
||||
saveSlotInfo(self.params.model_dir, slotIndex, slotInfo)
|
||||
self.modelSlots = loadAllSlotInfo(self.params.model_dir)
|
||||
saveSlotInfo(self.model_dir, slotIndex, slotInfo)
|
||||
self.modelSlots = loadAllSlotInfo(self.model_dir)
|
||||
|
||||
def _load_model_slot(self, slotIndex: int):
|
||||
return self.modelSlots[slotIndex]
|
||||
|
||||
def getAllSlotInfo(self):
|
||||
def getAllSlotInfo(self, reload: bool = False):
|
||||
if reload:
|
||||
self.modelSlots = loadAllSlotInfo(self.model_dir)
|
||||
return self.modelSlots
|
||||
|
||||
def get_slot_info(self, slotIndex: int):
|
||||
@ -46,7 +47,7 @@ class ModelSlotManager:
|
||||
print("[Voice Changer] UPLOAD ASSETS", params)
|
||||
paramsDict = json.loads(params)
|
||||
uploadPath = os.path.join(UPLOAD_DIR, paramsDict["file"])
|
||||
storeDir = os.path.join(self.params.model_dir, str(paramsDict["slot"]))
|
||||
storeDir = os.path.join(self.model_dir, str(paramsDict["slot"]))
|
||||
storePath = os.path.join(
|
||||
storeDir,
|
||||
paramsDict["file"],
|
||||
|
@ -6,7 +6,6 @@ import numpy as np
|
||||
import torch
|
||||
import torchaudio
|
||||
from data.ModelSlot import RVCModelSlot
|
||||
from utils.downloader.SampleDownloader import getSampleInfos
|
||||
from voice_changer.ModelSlotManager import ModelSlotManager
|
||||
|
||||
|
||||
@ -67,13 +66,9 @@ class RVC:
|
||||
self.params = params
|
||||
EmbedderManager.initialize(params)
|
||||
print("[Voice Changer] RVC initialization: ", params)
|
||||
self.modelSlotManager = ModelSlotManager.get_instance(self.params)
|
||||
self.modelSlotManager = ModelSlotManager.get_instance(self.params.model_dir)
|
||||
|
||||
# サンプルカタログ作成
|
||||
samples = getSampleInfos(params.sample_mode)
|
||||
self.settings.sampleModels = samples
|
||||
# 起動時にスロットにモデルがある場合はロードしておく
|
||||
|
||||
allSlots = self.modelSlotManager.getAllSlotInfo()
|
||||
availableIndex = -1
|
||||
for i, slot in enumerate(allSlots):
|
||||
@ -88,13 +83,6 @@ class RVC:
|
||||
|
||||
self.prevVol = 0.0
|
||||
|
||||
def getSampleInfo(self, id: str):
|
||||
sampleInfos = list(filter(lambda x: x.id == id, self.settings.sampleModels))
|
||||
if len(sampleInfos) > 0:
|
||||
return sampleInfos[0]
|
||||
else:
|
||||
None
|
||||
|
||||
def moveToModelDir(self, file: str, dstDir: str):
|
||||
dst = os.path.join(dstDir, os.path.basename(file))
|
||||
if os.path.exists(dst):
|
||||
|
@ -1,8 +1,4 @@
|
||||
from dataclasses import dataclass, field
|
||||
from ModelSample import RVCModelSample
|
||||
|
||||
# from const import MAX_SLOT_NUM
|
||||
# from data.ModelSlot import ModelSlot, ModelSlots
|
||||
|
||||
|
||||
@dataclass
|
||||
@ -17,9 +13,6 @@ class RVCSettings:
|
||||
clusterInferRatio: float = 0.1
|
||||
|
||||
framework: str = "PyTorch" # PyTorch or ONNX
|
||||
# modelSlots: list[ModelSlots] = field(default_factory=lambda: [ModelSlot() for _x in range(MAX_SLOT_NUM)])
|
||||
|
||||
sampleModels: list[RVCModelSample] = field(default_factory=lambda: [])
|
||||
|
||||
indexRatio: float = 0
|
||||
protect: float = 0.5
|
||||
|
@ -60,7 +60,7 @@ class VoiceChangerSettings:
|
||||
|
||||
class VoiceChanger:
|
||||
ioRecorder: IORecorder
|
||||
# sola_buffer: AudioInOut
|
||||
sola_buffer: AudioInOut
|
||||
|
||||
def __init__(self, params: VoiceChangerParams):
|
||||
# 初期化
|
||||
|
@ -1,7 +1,7 @@
|
||||
import numpy as np
|
||||
from data.ModelSlot import loadAllSlotInfo
|
||||
from utils.downloader.SampleDownloader import downloadSample
|
||||
from utils.downloader.SampleDownloader import downloadSample, getSampleInfos
|
||||
from voice_changer.Local.ServerDevice import ServerDevice, ServerDeviceCallbacks
|
||||
from voice_changer.ModelSlotManager import ModelSlotManager
|
||||
from voice_changer.VoiceChanger import VoiceChanger
|
||||
from const import ModelType
|
||||
from voice_changer.utils.LoadModelParams import LoadModelParams
|
||||
@ -53,6 +53,8 @@ class VoiceChangerManager(ServerDeviceCallbacks):
|
||||
self.params = params
|
||||
self.voiceChanger: VoiceChanger = None
|
||||
self.settings: VoiceChangerManagerSettings = VoiceChangerManagerSettings(dummy=0)
|
||||
|
||||
self.modelSlotManager = ModelSlotManager.get_instance(self.params.model_dir)
|
||||
# スタティックな情報を収集
|
||||
self.gpus: list[GPUInfo] = self._get_gpuInfos()
|
||||
|
||||
@ -82,6 +84,7 @@ class VoiceChangerManager(ServerDeviceCallbacks):
|
||||
paramDict = props.params
|
||||
if "sampleId" in paramDict and len(paramDict["sampleId"]) > 0:
|
||||
downloadSample(self.params.sample_mode, paramDict["sampleId"], self.params.model_dir, props.slot, {"useIndex": paramDict["rvcIndexDownload"]})
|
||||
self.modelSlotManager.getAllSlotInfo(reload=True)
|
||||
info = {"status": "OK"}
|
||||
return info
|
||||
else:
|
||||
@ -96,7 +99,8 @@ class VoiceChangerManager(ServerDeviceCallbacks):
|
||||
def get_info(self):
|
||||
data = asdict(self.settings)
|
||||
data["gpus"] = self.gpus
|
||||
data["modelSlots"] = loadAllSlotInfo(self.params.model_dir)
|
||||
data["modelSlots"] = self.modelSlotManager.getAllSlotInfo(reload=True)
|
||||
data["sampleModels"] = getSampleInfos(self.params.sample_mode)
|
||||
|
||||
data["status"] = "OK"
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user