bugfix: merge model

This commit is contained in:
w-okada 2023-08-05 03:02:43 +09:00
parent 8b5eb32047
commit aef9d2c34b
12 changed files with 157 additions and 145 deletions

File diff suppressed because one or more lines are too long

View File

@ -1,31 +0,0 @@
/*! regenerator-runtime -- Copyright (c) 2014-present, Facebook, Inc. -- license (MIT): https://github.com/facebook/regenerator/blob/main/LICENSE */
/**
* @license React
* react-dom.production.min.js
*
* Copyright (c) Facebook, Inc. and its affiliates.
*
* This source code is licensed under the MIT license found in the
* LICENSE file in the root directory of this source tree.
*/
/**
* @license React
* react.production.min.js
*
* Copyright (c) Facebook, Inc. and its affiliates.
*
* This source code is licensed under the MIT license found in the
* LICENSE file in the root directory of this source tree.
*/
/**
* @license React
* scheduler.production.min.js
*
* Copyright (c) Facebook, Inc. and its affiliates.
*
* This source code is licensed under the MIT license found in the
* LICENSE file in the root directory of this source tree.
*/

View File

@ -3,158 +3,182 @@ import { useGuiState } from "./001_GuiStateProvider";
import { useAppState } from "../../001_provider/001_AppStateProvider"; import { useAppState } from "../../001_provider/001_AppStateProvider";
import { MergeElement, RVCModelSlot, RVCModelType, VoiceChangerType } from "@dannadori/voice-changer-client-js"; import { MergeElement, RVCModelSlot, RVCModelType, VoiceChangerType } from "@dannadori/voice-changer-client-js";
export const MergeLabDialog = () => { export const MergeLabDialog = () => {
const guiState = useGuiState() const guiState = useGuiState();
const { serverSetting } = useAppState() const { serverSetting } = useAppState();
const [currentFilter, setCurrentFilter] = useState<string>("") const [currentFilter, setCurrentFilter] = useState<string>("");
const [mergeElements, setMergeElements] = useState<MergeElement[]>([]) const [mergeElements, setMergeElements] = useState<MergeElement[]>([]);
// スロットが変更されたときの初期化処理 // スロットが変更されたときの初期化処理
const newSlotChangeKey = useMemo(() => { const newSlotChangeKey = useMemo(() => {
if (!serverSetting.serverSetting.modelSlots) { if (!serverSetting.serverSetting.modelSlots) {
return "" return "";
} }
return serverSetting.serverSetting.modelSlots.reduce((prev, cur) => { return serverSetting.serverSetting.modelSlots.reduce((prev, cur) => {
return prev + "_" + cur.modelFile return prev + "_" + cur.modelFile;
}, "") }, "");
}, [serverSetting.serverSetting.modelSlots]) }, [serverSetting.serverSetting.modelSlots]);
const filterItems = useMemo(() => { const filterItems = useMemo(() => {
return serverSetting.serverSetting.modelSlots.reduce((prev, cur) => { return serverSetting.serverSetting.modelSlots.reduce(
if (cur.voiceChangerType != "RVC") { (prev, cur) => {
return prev if (cur.voiceChangerType != "RVC") {
} return prev;
const curRVC = cur as RVCModelSlot }
const key = `${curRVC.modelType},${cur.samplingRate},${curRVC.embChannels}` const curRVC = cur as RVCModelSlot;
const val = { type: curRVC.modelType, samplingRate: cur.samplingRate, embChannels: curRVC.embChannels } const key = `${curRVC.modelType},${cur.samplingRate},${curRVC.embChannels}`;
const existKeys = Object.keys(prev) const val = { type: curRVC.modelType, samplingRate: cur.samplingRate, embChannels: curRVC.embChannels };
if (!cur.modelFile || cur.modelFile.length == 0) { const existKeys = Object.keys(prev);
return prev if (!cur.modelFile || cur.modelFile.length == 0) {
} return prev;
if (curRVC.modelType == "onnxRVC" || curRVC.modelType == "onnxRVCNono") { }
return prev if (curRVC.modelType == "onnxRVC" || curRVC.modelType == "onnxRVCNono") {
} return prev;
if (!existKeys.includes(key)) { }
prev[key] = val if (!existKeys.includes(key)) {
} prev[key] = val;
return prev }
}, {} as { [key: string]: { type: RVCModelType, samplingRate: number, embChannels: number } }) return prev;
},
}, [newSlotChangeKey]) {} as { [key: string]: { type: RVCModelType; samplingRate: number; embChannels: number } },
);
}, [newSlotChangeKey]);
const models = useMemo(() => { const models = useMemo(() => {
return serverSetting.serverSetting.modelSlots.filter(x => { return serverSetting.serverSetting.modelSlots.filter((x) => {
if (x.voiceChangerType != "RVC") { if (x.voiceChangerType != "RVC") {
return return;
} }
const xRVC = x as RVCModelSlot const xRVC = x as RVCModelSlot;
const filterVals = filterItems[currentFilter] const filterVals = filterItems[currentFilter];
if (!filterVals) { if (!filterVals) {
return false return false;
} }
if (xRVC.modelType == filterVals.type && xRVC.samplingRate == filterVals.samplingRate && xRVC.embChannels == filterVals.embChannels) { if (xRVC.modelType == filterVals.type && xRVC.samplingRate == filterVals.samplingRate && xRVC.embChannels == filterVals.embChannels) {
return true return true;
} else { } else {
return false return false;
} }
}) });
}, [filterItems, currentFilter]) }, [filterItems, currentFilter]);
useEffect(() => { useEffect(() => {
if (Object.keys(filterItems).length > 0) { if (Object.keys(filterItems).length > 0) {
setCurrentFilter(Object.keys(filterItems)[0]) setCurrentFilter(Object.keys(filterItems)[0]);
} }
}, [filterItems]) }, [filterItems]);
useEffect(() => { useEffect(() => {
// models はフィルタ後の配列
const newMergeElements = models.map((x) => { const newMergeElements = models.map((x) => {
return { filename: x.modelFile, strength: 0 } return { slotIndex: x.slotIndex, filename: x.modelFile, strength: 0 };
}) });
setMergeElements(newMergeElements) setMergeElements(newMergeElements);
}, [models]) }, [models]);
const dialog = useMemo(() => { const dialog = useMemo(() => {
const closeButtonRow = ( const closeButtonRow = (
<div className="body-row split-3-4-3 left-padding-1"> <div className="body-row split-3-4-3 left-padding-1">
<div className="body-item-text"> <div className="body-item-text"></div>
</div>
<div className="body-button-container body-button-container-space-around"> <div className="body-button-container body-button-container-space-around">
<div className="body-button" onClick={() => { guiState.stateControls.showMergeLabCheckbox.updateState(false) }} >close</div> <div
className="body-button"
onClick={() => {
guiState.stateControls.showMergeLabCheckbox.updateState(false);
}}
>
close
</div>
</div> </div>
<div className="body-item-text"></div> <div className="body-item-text"></div>
</div> </div>
) );
const filterOptions = Object.keys(filterItems)
const filterOptions = Object.keys(filterItems).map(x => { .map((x) => {
return <option key={x} value={x}>{x}</option> return (
}).filter(x => x != null) <option key={x} value={x}>
{x}
const onMergeElementsChanged = (filename: string, strength: number) => { </option>
const newMergeElements = mergeElements.map((x) => { );
if (x.filename == filename) {
return { filename: x.filename, strength: strength }
} else {
return x
}
}) })
setMergeElements(newMergeElements) .filter((x) => x != null);
}
const onMergeElementsChanged = (slotIndex: number, strength: number) => {
const newMergeElements = mergeElements.map((x) => {
if (x.slotIndex == slotIndex) {
return { slotIndex: x.slotIndex, filename: x.filename, strength: strength };
} else {
return x;
}
});
setMergeElements(newMergeElements);
};
const onMergeClicked = () => { const onMergeClicked = () => {
const validMergeElements = mergeElements.filter((x) => {
return x.strength > 0;
});
serverSetting.mergeModel({ serverSetting.mergeModel({
voiceChangerType: VoiceChangerType.RVC, voiceChangerType: VoiceChangerType.RVC,
command: "mix", command: "mix",
files: mergeElements files: validMergeElements,
}) });
} };
const modelList = mergeElements.map((x, index) => { const modelList = mergeElements.map((x, index) => {
const name = models.find(model => { return model.modelFile == x.filename })?.name || "" const name =
models.find((model) => {
return model.slotIndex == x.slotIndex;
})?.name || "";
return ( return (
<div key={index} className="merge-lab-model-item"> <div key={index} className="merge-lab-model-item">
<div>{name}</div>
<div> <div>
{name} <input
</div> type="range"
<div> className="body-item-input-slider"
<input type="range" className="body-item-input-slider" min="0" max="100" step="1" value={x.strength} onChange={(e) => { min="0"
onMergeElementsChanged(x.filename, Number(e.target.value)) max="100"
}}></input> step="1"
value={x.strength}
onChange={(e) => {
onMergeElementsChanged(x.slotIndex, Number(e.target.value));
}}
></input>
<span className="body-item-input-slider-val">{x.strength}</span> <span className="body-item-input-slider-val">{x.strength}</span>
</div> </div>
</div> </div>
) );
}) });
const content = ( const content = (
<div className="merge-lab-container"> <div className="merge-lab-container">
<div className="merge-lab-type-filter"> <div className="merge-lab-type-filter">
<div>Type:</div>
<div> <div>
Type: <select
</div> value={currentFilter}
<div> onChange={(e) => {
<select value={currentFilter} onChange={(e) => { setCurrentFilter(e.target.value) }}> setCurrentFilter(e.target.value);
}}
>
{filterOptions} {filterOptions}
</select> </select>
</div> </div>
</div> </div>
<div className="merge-lab-manipulator"> <div className="merge-lab-manipulator">
<div className="merge-lab-model-list"> <div className="merge-lab-model-list">{modelList}</div>
{modelList}
</div>
<div className="merge-lab-merge-buttons"> <div className="merge-lab-merge-buttons">
<div className="merge-lab-merge-buttons-notice"> <div className="merge-lab-merge-buttons-notice">The merged model is stored in the final slot. If you assign this slot, it will be overwritten.</div>
The merged model is stored in the final slot. If you assign this slot, it will be overwritten.
</div>
<div className="merge-lab-merge-button" onClick={onMergeClicked}> <div className="merge-lab-merge-button" onClick={onMergeClicked}>
merge merge
</div> </div>
</div> </div>
</div> </div>
</div> </div>
) );
return ( return (
<div className="dialog-frame"> <div className="dialog-frame">
<div className="dialog-title">MergeLab</div> <div className="dialog-title">MergeLab</div>
@ -166,5 +190,4 @@ export const MergeLabDialog = () => {
); );
}, [newSlotChangeKey, currentFilter, mergeElements, models]); }, [newSlotChangeKey, currentFilter, mergeElements, models]);
return dialog; return dialog;
}; };

View File

@ -38,7 +38,7 @@ export const ModelSlotArea = (_props: ModelSlotAreaProps) => {
if (!x.modelFile || x.modelFile.length == 0) { if (!x.modelFile || x.modelFile.length == 0) {
return null; return null;
} }
const tileContainerClass = x.id == serverSetting.serverSetting.modelSlotIndex ? "model-slot-tile-container-selected" : "model-slot-tile-container"; const tileContainerClass = x.slotIndex == serverSetting.serverSetting.modelSlotIndex ? "model-slot-tile-container-selected" : "model-slot-tile-container";
const name = x.name.length > 8 ? x.name.substring(0, 7) + "..." : x.name; const name = x.name.length > 8 ? x.name.substring(0, 7) + "..." : x.name;
const iconElem = const iconElem =
x.iconFile.length > 0 ? ( x.iconFile.length > 0 ? (
@ -54,7 +54,7 @@ export const ModelSlotArea = (_props: ModelSlotAreaProps) => {
); );
const clickAction = async () => { const clickAction = async () => {
const dummyModelSlotIndex = Math.floor(Date.now() / 1000) * 1000 + x.id; const dummyModelSlotIndex = Math.floor(Date.now() / 1000) * 1000 + x.slotIndex;
await serverSetting.updateServerSettings({ ...serverSetting.serverSetting, modelSlotIndex: dummyModelSlotIndex }); await serverSetting.updateServerSettings({ ...serverSetting.serverSetting, modelSlotIndex: dummyModelSlotIndex });
setTimeout(() => { setTimeout(() => {
// quick hack // quick hack

View File

@ -193,7 +193,7 @@ export type VoiceChangerServerSetting = {
} }
type ModelSlot = { type ModelSlot = {
id: number slotIndex: number
voiceChangerType: VoiceChangerType voiceChangerType: VoiceChangerType
name: string, name: string,
description: string, description: string,
@ -539,7 +539,8 @@ export type OnnxExporterInfo = {
// Merge // Merge
export type MergeElement = { export type MergeElement = {
filename: string slotIndex: number
filename: string // 一意性は保障されない場合がある(フォルダコピーされたときとか)
strength: number strength: number
} }
export type MergeModelRequest = { export type MergeModelRequest = {

View File

@ -124,7 +124,7 @@ if __name__ == "MMVCServerSIO":
if __name__ == "__mp_main__": if __name__ == "__mp_main__":
# printMessage("サーバプロセスを起動しています。", level=2) # printMessage("サーバプロセスを起動しています。", level=2)
printMessage("The server process is starting up.", level=2) printMessage("The server process is starting up.", level=2)
if __name__ == "__main__": if __name__ == "__main__":

View File

@ -9,7 +9,7 @@ import json
@dataclass @dataclass
class ModelSlot: class ModelSlot:
id: int = -1 slotIndex: int = -1
voiceChangerType: VoiceChangerType | None = None voiceChangerType: VoiceChangerType | None = None
name: str = "" name: str = ""
description: str = "" description: str = ""
@ -133,19 +133,26 @@ def loadSlotInfo(model_dir: str, slotIndex: int) -> ModelSlots:
if not os.path.exists(jsonFile): if not os.path.exists(jsonFile):
return ModelSlot() return ModelSlot()
jsonDict = json.load(open(os.path.join(slotDir, "params.json"))) jsonDict = json.load(open(os.path.join(slotDir, "params.json")))
slotInfo = ModelSlot(**{k: v for k, v in jsonDict.items() if k in ModelSlot.__annotations__}) slotInfoKey = list(ModelSlot.__annotations__.keys())
slotInfo = ModelSlot(**{k: v for k, v in jsonDict.items() if k in slotInfoKey})
if slotInfo.voiceChangerType == "RVC": if slotInfo.voiceChangerType == "RVC":
return RVCModelSlot(**jsonDict) slotInfoKey.extend(list(RVCModelSlot.__annotations__.keys()))
return RVCModelSlot(**{k: v for k, v in jsonDict.items() if k in slotInfoKey})
elif slotInfo.voiceChangerType == "MMVCv13": elif slotInfo.voiceChangerType == "MMVCv13":
return MMVCv13ModelSlot(**jsonDict) slotInfoKey.extend(list(MMVCv13ModelSlot.__annotations__.keys()))
return MMVCv13ModelSlot(**{k: v for k, v in jsonDict.items() if k in slotInfoKey})
elif slotInfo.voiceChangerType == "MMVCv15": elif slotInfo.voiceChangerType == "MMVCv15":
return MMVCv15ModelSlot(**jsonDict) slotInfoKey.extend(list(MMVCv15ModelSlot.__annotations__.keys()))
return MMVCv15ModelSlot(**{k: v for k, v in jsonDict.items() if k in slotInfoKey})
elif slotInfo.voiceChangerType == "so-vits-svc-40": elif slotInfo.voiceChangerType == "so-vits-svc-40":
return SoVitsSvc40ModelSlot(**jsonDict) slotInfoKey.extend(list(SoVitsSvc40ModelSlot.__annotations__.keys()))
return SoVitsSvc40ModelSlot(**{k: v for k, v in jsonDict.items() if k in slotInfoKey})
elif slotInfo.voiceChangerType == "DDSP-SVC": elif slotInfo.voiceChangerType == "DDSP-SVC":
return DDSPSVCModelSlot(**jsonDict) slotInfoKey.extend(list(DDSPSVCModelSlot.__annotations__.keys()))
return DDSPSVCModelSlot(**{k: v for k, v in jsonDict.items() if k in slotInfoKey})
elif slotInfo.voiceChangerType == "Diffusion-SVC": elif slotInfo.voiceChangerType == "Diffusion-SVC":
return DiffusionSVCModelSlot(**jsonDict) slotInfoKey.extend(list(DiffusionSVCModelSlot.__annotations__.keys()))
return DiffusionSVCModelSlot(**{k: v for k, v in jsonDict.items() if k in slotInfoKey})
else: else:
return ModelSlot() return ModelSlot()
@ -154,11 +161,13 @@ def loadAllSlotInfo(model_dir: str):
slotInfos: list[ModelSlots] = [] slotInfos: list[ModelSlots] = []
for slotIndex in range(MAX_SLOT_NUM): for slotIndex in range(MAX_SLOT_NUM):
slotInfo = loadSlotInfo(model_dir, slotIndex) slotInfo = loadSlotInfo(model_dir, slotIndex)
slotInfo.id = slotIndex slotInfo.slotIndex = slotIndex # スロットインデックスは動的に注入
slotInfos.append(slotInfo) slotInfos.append(slotInfo)
return slotInfos return slotInfos
def saveSlotInfo(model_dir: str, slotIndex: int, slotInfo: ModelSlots): def saveSlotInfo(model_dir: str, slotIndex: int, slotInfo: ModelSlots):
slotDir = os.path.join(model_dir, str(slotIndex)) slotDir = os.path.join(model_dir, str(slotIndex))
json.dump(asdict(slotInfo), open(os.path.join(slotDir, "params.json"), "w")) slotInfoDict = asdict(slotInfo)
slotInfo.slotIndex = -1 # スロットインデックスは動的に注入
json.dump(slotInfoDict, open(os.path.join(slotDir, "params.json"), "w"))

View File

@ -113,6 +113,8 @@ class MMVC_Rest_Fileuploader:
return JSONResponse(content=json_compatible_item_data) return JSONResponse(content=json_compatible_item_data)
except Exception as e: except Exception as e:
print("[Voice Changer] post_merge_models ex:", e) print("[Voice Changer] post_merge_models ex:", e)
import traceback
traceback.print_exc()
def post_update_model_default(self): def post_update_model_default(self):
try: try:

View File

@ -4,17 +4,17 @@ import torch
from const import UPLOAD_DIR from const import UPLOAD_DIR
from voice_changer.RVC.modelMerger.MergeModel import merge_model from voice_changer.RVC.modelMerger.MergeModel import merge_model
from voice_changer.utils.ModelMerger import ModelMerger, ModelMergerRequest from voice_changer.utils.ModelMerger import ModelMerger, ModelMergerRequest
from voice_changer.utils.VoiceChangerParams import VoiceChangerParams
class RVCModelMerger(ModelMerger): class RVCModelMerger(ModelMerger):
@classmethod @classmethod
def merge_models(cls, request: ModelMergerRequest, storeSlot: int): def merge_models(cls, params: VoiceChangerParams, request: ModelMergerRequest, storeSlot: int):
print("[Voice Changer] MergeRequest:", request) merged = merge_model(params, request)
merged = merge_model(request)
# いったんは、アップロードフォルダに格納する。(歴史的経緯) # いったんは、アップロードフォルダに格納する。(歴史的経緯)
# 後続のloadmodelを呼び出すことで永続化モデルフォルダに移動させられる。 # 後続のloadmodelを呼び出すことで永続化モデルフォルダに移動させられる。
storeDir = os.path.join(UPLOAD_DIR, f"{storeSlot}") storeDir = os.path.join(UPLOAD_DIR)
print("[Voice Changer] store merged model to:", storeDir) print("[Voice Changer] store merged model to:", storeDir)
os.makedirs(storeDir, exist_ok=True) os.makedirs(storeDir, exist_ok=True)
storeFile = os.path.join(storeDir, "merged.pth") storeFile = os.path.join(storeDir, "merged.pth")

View File

@ -1,12 +1,14 @@
from typing import Dict, Any from typing import Dict, Any
import os
from collections import OrderedDict from collections import OrderedDict
import torch import torch
from voice_changer.ModelSlotManager import ModelSlotManager
from voice_changer.utils.ModelMerger import ModelMergerRequest from voice_changer.utils.ModelMerger import ModelMergerRequest
from voice_changer.utils.VoiceChangerParams import VoiceChangerParams
def merge_model(request: ModelMergerRequest): def merge_model(params: VoiceChangerParams, request: ModelMergerRequest):
def extract(ckpt: Dict[str, Any]): def extract(ckpt: Dict[str, Any]):
a = ckpt["model"] a = ckpt["model"]
opt: Dict[str, Any] = OrderedDict() opt: Dict[str, Any] = OrderedDict()
@ -34,11 +36,16 @@ def merge_model(request: ModelMergerRequest):
weights = [] weights = []
alphas = [] alphas = []
slotManager = ModelSlotManager.get_instance(params.model_dir)
for f in files: for f in files:
strength = f.strength strength = f.strength
if strength == 0: if strength == 0:
continue continue
weight, state_dict = load_weight(f.filename) slotInfo = slotManager.get_slot_info(f.slotIndex)
filename = os.path.join(params.model_dir, str(f.slotIndex), os.path.basename(slotInfo.modelFile)) # slotInfo.modelFileはv.1.5.3.11以前はmodel_dirから含まれている。
weight, state_dict = load_weight(filename)
weights.append(weight) weights.append(weight)
alphas.append(f.strength) alphas.append(f.strength)

View File

@ -306,8 +306,8 @@ class VoiceChangerManager(ServerDeviceCallbacks):
req.files = [MergeElement(**f) for f in req.files] req.files = [MergeElement(**f) for f in req.files]
slot = len(self.modelSlotManager.getAllSlotInfo()) - 1 slot = len(self.modelSlotManager.getAllSlotInfo()) - 1
if req.voiceChangerType == "RVC": if req.voiceChangerType == "RVC":
merged = RVCModelMerger.merge_models(req, slot) merged = RVCModelMerger.merge_models(self.params, req, slot)
loadParam = LoadModelParams(voiceChangerType="RVC", slot=slot, isSampleMode=False, sampleId="", files=[LoadModelParamFile(name=os.path.basename(merged), kind="rvcModel", dir=f"{slot}")], params={}) loadParam = LoadModelParams(voiceChangerType="RVC", slot=slot, isSampleMode=False, sampleId="", files=[LoadModelParamFile(name=os.path.basename(merged), kind="rvcModel", dir="")], params={})
self.loadModel(loadParam) self.loadModel(loadParam)
return self.get_info() return self.get_info()

View File

@ -5,6 +5,7 @@ from dataclasses import dataclass
@dataclass @dataclass
class MergeElement: class MergeElement:
slotIndex: int
filename: str filename: str
strength: int strength: int