WIP:common sample

This commit is contained in:
wataru 2023-06-16 17:12:03 +09:00
parent 87de3dc10f
commit d7f5828710
5 changed files with 28 additions and 23 deletions

View File

@ -171,4 +171,4 @@ def getSampleJsonAndModelIds(mode: RVCSampleMode):
RVC_MODEL_DIRNAME = "rvc"
RVC_MAX_SLOT_NUM = 10
MAX_SLOT_NUM = 10

View File

@ -1,5 +1,5 @@
from typing import TypeAlias, Union
from const import EnumInferenceTypes, EnumEmbedderTypes, VoiceChangerType
from const import MAX_SLOT_NUM, EnumInferenceTypes, EnumEmbedderTypes, VoiceChangerType
from dataclasses import dataclass, asdict
@ -54,6 +54,14 @@ def loadSlotInfo(model_dir: str, slotIndex: int) -> ModelSlots:
return ModelSlot()
def loadAllSlotInfo(model_dir: str):
slotInfos: list[ModelSlots] = []
for slotIndex in range(MAX_SLOT_NUM):
slotInfo = loadSlotInfo(model_dir, slotIndex)
slotInfos.append(slotInfo)
return slotInfos
def saveSlotInfo(model_dir: str, slotIndex: int, slotInfo: ModelSlots):
slotDir = os.path.join(model_dir, str(slotIndex))
json.dump(asdict(slotInfo), open(os.path.join(slotDir, "params.json"), "w"))

View File

@ -5,6 +5,7 @@ from typing import cast
import numpy as np
import torch
import torchaudio
from data.ModelSlot import loadAllSlotInfo
from utils.downloader.SampleDownloader import getSampleInfos
from voice_changer.RVC.ModelSlot import ModelSlot
from voice_changer.RVC.SampleDownloader import downloadModelFiles
@ -67,18 +68,12 @@ class RVC:
self.pitchExtractor = PitchExtractorManager.getPitchExtractor(self.settings.f0Detector)
self.params = params
EmbedderManager.initialize(params)
self.loadSlots()
# self.loadSlots()
self.settings.modelSlots = loadAllSlotInfo(self.params.model_dir)
print("[Voice Changer] RVC initialization: ", params)
# サンプルカタログ作成
# sampleJsons: list[str] = []
samples = getSampleInfos(params.sample_mode)
# for url in sampleJsonUrls:
# filename = os.path.basename(url)
# sampleJsons.append(filename)
# sampleModels = getModelSamples(sampleJsons, "RVC")
# if sampleModels is not None:
# self.settings.sampleModels = sampleModels
self.settings.sampleModels = samples
# 起動時にスロットにモデルがある場合はロードしておく
if len(self.settings.modelSlots) > 0:
@ -160,7 +155,8 @@ class RVC:
if slotInfo.iconFile is not None and len(slotInfo.iconFile) > 0:
slotInfo.iconFile = self.moveToModelDir(slotInfo.iconFile, slotDir)
json.dump(asdict(slotInfo), open(os.path.join(slotDir, "params.json"), "w"))
self.loadSlots()
# self.loadSlots()
self.settings.modelSlots = loadAllSlotInfo(self.params.model_dir)
# 初回のみロード(起動時にスロットにモデルがあった場合はinitialLoadはFalseになっている)
if self.initialLoad:
@ -444,7 +440,8 @@ class RVC:
params["defaultProtect"] = self.settings.protect
json.dump(params, open(os.path.join(slotDir, "params.json"), "w"))
self.loadSlots()
# self.loadSlots()
self.settings.modelSlots = loadAllSlotInfo(self.params.model_dir)
def update_model_info(self, newData: str):
print("[Voice Changer] UPDATE MODEL INFO", newData)
@ -456,7 +453,8 @@ class RVC:
params = json.load(open(os.path.join(slotDir, "params.json"), "r", encoding="utf-8"))
params[newDataDict["key"]] = newDataDict["val"]
json.dump(params, open(os.path.join(slotDir, "params.json"), "w"))
self.loadSlots()
# self.loadSlots()
self.settings.modelSlots = loadAllSlotInfo(self.params.model_dir)
def upload_model_assets(self, params: str):
print("[Voice Changer] UPLOAD ASSETS", params)
@ -479,4 +477,5 @@ class RVC:
except Exception as e:
print("Exception::::", e)
self.loadSlots()
# self.loadSlots()
self.settings.modelSlots = loadAllSlotInfo(self.params.model_dir)

View File

@ -1,8 +1,7 @@
from dataclasses import dataclass, field
from ModelSample import RVCModelSample
from const import RVC_MAX_SLOT_NUM
from voice_changer.RVC.ModelSlot import ModelSlot
from const import MAX_SLOT_NUM
from data.ModelSlot import ModelSlot, ModelSlots
@dataclass
@ -17,9 +16,7 @@ class RVCSettings:
clusterInferRatio: float = 0.1
framework: str = "PyTorch" # PyTorch or ONNX
modelSlots: list[ModelSlot] = field(
default_factory=lambda: [ModelSlot() for _x in range(RVC_MAX_SLOT_NUM)]
)
modelSlots: list[ModelSlots] = field(default_factory=lambda: [ModelSlot() for _x in range(MAX_SLOT_NUM)])
sampleModels: list[RVCModelSample] = field(default_factory=lambda: [])

View File

@ -14,7 +14,7 @@ from voice_changer.IORecorder import IORecorder
from voice_changer.utils.LoadModelParams import LoadModelParams
from voice_changer.utils.Timer import Timer
from voice_changer.utils.VoiceChangerModel import VoiceChangerModel, AudioInOut
from voice_changer.utils.VoiceChangerModel import AudioInOut
from Exceptions import (
DeviceCannotSupportHalfPrecisionException,
DeviceChangingException,
@ -60,8 +60,6 @@ class VoiceChangerSettings:
class VoiceChanger:
settings: VoiceChangerSettings = VoiceChangerSettings()
voiceChanger: VoiceChangerModel | None = None
ioRecorder: IORecorder
sola_buffer: AudioInOut
namespace: socketio.AsyncNamespace | None = None
@ -148,7 +146,10 @@ class VoiceChanger:
def get_info(self):
data = asdict(self.settings)
if self.voiceChanger is not None:
print("------------------ self.voiceChanger is not None")
data.update(self.voiceChanger.get_info())
else:
print("------------------ self.voiceChanger is None")
return data
def get_performance(self):