bugfix: merge model

This commit is contained in:
w-okada 2023-08-05 03:02:43 +09:00
parent 8b5eb32047
commit aef9d2c34b
12 changed files with 157 additions and 145 deletions

File diff suppressed because one or more lines are too long

View File

@ -1,31 +0,0 @@
/*! regenerator-runtime -- Copyright (c) 2014-present, Facebook, Inc. -- license (MIT): https://github.com/facebook/regenerator/blob/main/LICENSE */
/**
* @license React
* react-dom.production.min.js
*
* Copyright (c) Facebook, Inc. and its affiliates.
*
* This source code is licensed under the MIT license found in the
* LICENSE file in the root directory of this source tree.
*/
/**
* @license React
* react.production.min.js
*
* Copyright (c) Facebook, Inc. and its affiliates.
*
* This source code is licensed under the MIT license found in the
* LICENSE file in the root directory of this source tree.
*/
/**
* @license React
* scheduler.production.min.js
*
* Copyright (c) Facebook, Inc. and its affiliates.
*
* This source code is licensed under the MIT license found in the
* LICENSE file in the root directory of this source tree.
*/

View File

@ -3,158 +3,182 @@ import { useGuiState } from "./001_GuiStateProvider";
import { useAppState } from "../../001_provider/001_AppStateProvider";
import { MergeElement, RVCModelSlot, RVCModelType, VoiceChangerType } from "@dannadori/voice-changer-client-js";
export const MergeLabDialog = () => {
const guiState = useGuiState()
const guiState = useGuiState();
const { serverSetting } = useAppState()
const [currentFilter, setCurrentFilter] = useState<string>("")
const [mergeElements, setMergeElements] = useState<MergeElement[]>([])
const { serverSetting } = useAppState();
const [currentFilter, setCurrentFilter] = useState<string>("");
const [mergeElements, setMergeElements] = useState<MergeElement[]>([]);
// スロットが変更されたときの初期化処理
const newSlotChangeKey = useMemo(() => {
if (!serverSetting.serverSetting.modelSlots) {
return ""
return "";
}
return serverSetting.serverSetting.modelSlots.reduce((prev, cur) => {
return prev + "_" + cur.modelFile
}, "")
}, [serverSetting.serverSetting.modelSlots])
return prev + "_" + cur.modelFile;
}, "");
}, [serverSetting.serverSetting.modelSlots]);
const filterItems = useMemo(() => {
return serverSetting.serverSetting.modelSlots.reduce((prev, cur) => {
return serverSetting.serverSetting.modelSlots.reduce(
(prev, cur) => {
if (cur.voiceChangerType != "RVC") {
return prev
return prev;
}
const curRVC = cur as RVCModelSlot
const key = `${curRVC.modelType},${cur.samplingRate},${curRVC.embChannels}`
const val = { type: curRVC.modelType, samplingRate: cur.samplingRate, embChannels: curRVC.embChannels }
const existKeys = Object.keys(prev)
const curRVC = cur as RVCModelSlot;
const key = `${curRVC.modelType},${cur.samplingRate},${curRVC.embChannels}`;
const val = { type: curRVC.modelType, samplingRate: cur.samplingRate, embChannels: curRVC.embChannels };
const existKeys = Object.keys(prev);
if (!cur.modelFile || cur.modelFile.length == 0) {
return prev
return prev;
}
if (curRVC.modelType == "onnxRVC" || curRVC.modelType == "onnxRVCNono") {
return prev
return prev;
}
if (!existKeys.includes(key)) {
prev[key] = val
prev[key] = val;
}
return prev
}, {} as { [key: string]: { type: RVCModelType, samplingRate: number, embChannels: number } })
}, [newSlotChangeKey])
return prev;
},
{} as { [key: string]: { type: RVCModelType; samplingRate: number; embChannels: number } },
);
}, [newSlotChangeKey]);
const models = useMemo(() => {
return serverSetting.serverSetting.modelSlots.filter(x => {
return serverSetting.serverSetting.modelSlots.filter((x) => {
if (x.voiceChangerType != "RVC") {
return
return;
}
const xRVC = x as RVCModelSlot
const filterVals = filterItems[currentFilter]
const xRVC = x as RVCModelSlot;
const filterVals = filterItems[currentFilter];
if (!filterVals) {
return false
return false;
}
if (xRVC.modelType == filterVals.type && xRVC.samplingRate == filterVals.samplingRate && xRVC.embChannels == filterVals.embChannels) {
return true
return true;
} else {
return false
return false;
}
})
}, [filterItems, currentFilter])
});
}, [filterItems, currentFilter]);
useEffect(() => {
if (Object.keys(filterItems).length > 0) {
setCurrentFilter(Object.keys(filterItems)[0])
setCurrentFilter(Object.keys(filterItems)[0]);
}
}, [filterItems])
}, [filterItems]);
useEffect(() => {
// models はフィルタ後の配列
const newMergeElements = models.map((x) => {
return { filename: x.modelFile, strength: 0 }
})
setMergeElements(newMergeElements)
}, [models])
return { slotIndex: x.slotIndex, filename: x.modelFile, strength: 0 };
});
setMergeElements(newMergeElements);
}, [models]);
const dialog = useMemo(() => {
const closeButtonRow = (
<div className="body-row split-3-4-3 left-padding-1">
<div className="body-item-text">
</div>
<div className="body-item-text"></div>
<div className="body-button-container body-button-container-space-around">
<div className="body-button" onClick={() => { guiState.stateControls.showMergeLabCheckbox.updateState(false) }} >close</div>
<div
className="body-button"
onClick={() => {
guiState.stateControls.showMergeLabCheckbox.updateState(false);
}}
>
close
</div>
</div>
<div className="body-item-text"></div>
</div>
)
);
const filterOptions = Object.keys(filterItems).map(x => {
return <option key={x} value={x}>{x}</option>
}).filter(x => x != null)
const onMergeElementsChanged = (filename: string, strength: number) => {
const newMergeElements = mergeElements.map((x) => {
if (x.filename == filename) {
return { filename: x.filename, strength: strength }
} else {
return x
}
const filterOptions = Object.keys(filterItems)
.map((x) => {
return (
<option key={x} value={x}>
{x}
</option>
);
})
setMergeElements(newMergeElements)
.filter((x) => x != null);
const onMergeElementsChanged = (slotIndex: number, strength: number) => {
const newMergeElements = mergeElements.map((x) => {
if (x.slotIndex == slotIndex) {
return { slotIndex: x.slotIndex, filename: x.filename, strength: strength };
} else {
return x;
}
});
setMergeElements(newMergeElements);
};
const onMergeClicked = () => {
const validMergeElements = mergeElements.filter((x) => {
return x.strength > 0;
});
serverSetting.mergeModel({
voiceChangerType: VoiceChangerType.RVC,
command: "mix",
files: mergeElements
})
}
files: validMergeElements,
});
};
const modelList = mergeElements.map((x, index) => {
const name = models.find(model => { return model.modelFile == x.filename })?.name || ""
const name =
models.find((model) => {
return model.slotIndex == x.slotIndex;
})?.name || "";
return (
<div key={index} className="merge-lab-model-item">
<div>{name}</div>
<div>
{name}
</div>
<div>
<input type="range" className="body-item-input-slider" min="0" max="100" step="1" value={x.strength} onChange={(e) => {
onMergeElementsChanged(x.filename, Number(e.target.value))
}}></input>
<input
type="range"
className="body-item-input-slider"
min="0"
max="100"
step="1"
value={x.strength}
onChange={(e) => {
onMergeElementsChanged(x.slotIndex, Number(e.target.value));
}}
></input>
<span className="body-item-input-slider-val">{x.strength}</span>
</div>
</div>
)
})
);
});
const content = (
<div className="merge-lab-container">
<div className="merge-lab-type-filter">
<div>Type:</div>
<div>
Type:
</div>
<div>
<select value={currentFilter} onChange={(e) => { setCurrentFilter(e.target.value) }}>
<select
value={currentFilter}
onChange={(e) => {
setCurrentFilter(e.target.value);
}}
>
{filterOptions}
</select>
</div>
</div>
<div className="merge-lab-manipulator">
<div className="merge-lab-model-list">
{modelList}
</div>
<div className="merge-lab-model-list">{modelList}</div>
<div className="merge-lab-merge-buttons">
<div className="merge-lab-merge-buttons-notice">
The merged model is stored in the final slot. If you assign this slot, it will be overwritten.
</div>
<div className="merge-lab-merge-buttons-notice">The merged model is stored in the final slot. If you assign this slot, it will be overwritten.</div>
<div className="merge-lab-merge-button" onClick={onMergeClicked}>
merge
</div>
</div>
</div>
</div>
)
);
return (
<div className="dialog-frame">
<div className="dialog-title">MergeLab</div>
@ -166,5 +190,4 @@ export const MergeLabDialog = () => {
);
}, [newSlotChangeKey, currentFilter, mergeElements, models]);
return dialog;
};

View File

@ -38,7 +38,7 @@ export const ModelSlotArea = (_props: ModelSlotAreaProps) => {
if (!x.modelFile || x.modelFile.length == 0) {
return null;
}
const tileContainerClass = x.id == serverSetting.serverSetting.modelSlotIndex ? "model-slot-tile-container-selected" : "model-slot-tile-container";
const tileContainerClass = x.slotIndex == serverSetting.serverSetting.modelSlotIndex ? "model-slot-tile-container-selected" : "model-slot-tile-container";
const name = x.name.length > 8 ? x.name.substring(0, 7) + "..." : x.name;
const iconElem =
x.iconFile.length > 0 ? (
@ -54,7 +54,7 @@ export const ModelSlotArea = (_props: ModelSlotAreaProps) => {
);
const clickAction = async () => {
const dummyModelSlotIndex = Math.floor(Date.now() / 1000) * 1000 + x.id;
const dummyModelSlotIndex = Math.floor(Date.now() / 1000) * 1000 + x.slotIndex;
await serverSetting.updateServerSettings({ ...serverSetting.serverSetting, modelSlotIndex: dummyModelSlotIndex });
setTimeout(() => {
// quick hack

View File

@ -193,7 +193,7 @@ export type VoiceChangerServerSetting = {
}
type ModelSlot = {
id: number
slotIndex: number
voiceChangerType: VoiceChangerType
name: string,
description: string,
@ -539,7 +539,8 @@ export type OnnxExporterInfo = {
// Merge
export type MergeElement = {
filename: string
slotIndex: number
filename: string // 一意性は保障されない場合がある(フォルダコピーされたときとか)
strength: number
}
export type MergeModelRequest = {

View File

@ -9,7 +9,7 @@ import json
@dataclass
class ModelSlot:
id: int = -1
slotIndex: int = -1
voiceChangerType: VoiceChangerType | None = None
name: str = ""
description: str = ""
@ -133,19 +133,26 @@ def loadSlotInfo(model_dir: str, slotIndex: int) -> ModelSlots:
if not os.path.exists(jsonFile):
return ModelSlot()
jsonDict = json.load(open(os.path.join(slotDir, "params.json")))
slotInfo = ModelSlot(**{k: v for k, v in jsonDict.items() if k in ModelSlot.__annotations__})
slotInfoKey = list(ModelSlot.__annotations__.keys())
slotInfo = ModelSlot(**{k: v for k, v in jsonDict.items() if k in slotInfoKey})
if slotInfo.voiceChangerType == "RVC":
return RVCModelSlot(**jsonDict)
slotInfoKey.extend(list(RVCModelSlot.__annotations__.keys()))
return RVCModelSlot(**{k: v for k, v in jsonDict.items() if k in slotInfoKey})
elif slotInfo.voiceChangerType == "MMVCv13":
return MMVCv13ModelSlot(**jsonDict)
slotInfoKey.extend(list(MMVCv13ModelSlot.__annotations__.keys()))
return MMVCv13ModelSlot(**{k: v for k, v in jsonDict.items() if k in slotInfoKey})
elif slotInfo.voiceChangerType == "MMVCv15":
return MMVCv15ModelSlot(**jsonDict)
slotInfoKey.extend(list(MMVCv15ModelSlot.__annotations__.keys()))
return MMVCv15ModelSlot(**{k: v for k, v in jsonDict.items() if k in slotInfoKey})
elif slotInfo.voiceChangerType == "so-vits-svc-40":
return SoVitsSvc40ModelSlot(**jsonDict)
slotInfoKey.extend(list(SoVitsSvc40ModelSlot.__annotations__.keys()))
return SoVitsSvc40ModelSlot(**{k: v for k, v in jsonDict.items() if k in slotInfoKey})
elif slotInfo.voiceChangerType == "DDSP-SVC":
return DDSPSVCModelSlot(**jsonDict)
slotInfoKey.extend(list(DDSPSVCModelSlot.__annotations__.keys()))
return DDSPSVCModelSlot(**{k: v for k, v in jsonDict.items() if k in slotInfoKey})
elif slotInfo.voiceChangerType == "Diffusion-SVC":
return DiffusionSVCModelSlot(**jsonDict)
slotInfoKey.extend(list(DiffusionSVCModelSlot.__annotations__.keys()))
return DiffusionSVCModelSlot(**{k: v for k, v in jsonDict.items() if k in slotInfoKey})
else:
return ModelSlot()
@ -154,11 +161,13 @@ def loadAllSlotInfo(model_dir: str):
slotInfos: list[ModelSlots] = []
for slotIndex in range(MAX_SLOT_NUM):
slotInfo = loadSlotInfo(model_dir, slotIndex)
slotInfo.id = slotIndex
slotInfo.slotIndex = slotIndex # スロットインデックスは動的に注入
slotInfos.append(slotInfo)
return slotInfos
def saveSlotInfo(model_dir: str, slotIndex: int, slotInfo: ModelSlots):
slotDir = os.path.join(model_dir, str(slotIndex))
json.dump(asdict(slotInfo), open(os.path.join(slotDir, "params.json"), "w"))
slotInfoDict = asdict(slotInfo)
slotInfo.slotIndex = -1 # スロットインデックスは動的に注入
json.dump(slotInfoDict, open(os.path.join(slotDir, "params.json"), "w"))

View File

@ -113,6 +113,8 @@ class MMVC_Rest_Fileuploader:
return JSONResponse(content=json_compatible_item_data)
except Exception as e:
print("[Voice Changer] post_merge_models ex:", e)
import traceback
traceback.print_exc()
def post_update_model_default(self):
try:

View File

@ -4,17 +4,17 @@ import torch
from const import UPLOAD_DIR
from voice_changer.RVC.modelMerger.MergeModel import merge_model
from voice_changer.utils.ModelMerger import ModelMerger, ModelMergerRequest
from voice_changer.utils.VoiceChangerParams import VoiceChangerParams
class RVCModelMerger(ModelMerger):
@classmethod
def merge_models(cls, request: ModelMergerRequest, storeSlot: int):
print("[Voice Changer] MergeRequest:", request)
merged = merge_model(request)
def merge_models(cls, params: VoiceChangerParams, request: ModelMergerRequest, storeSlot: int):
merged = merge_model(params, request)
# いったんは、アップロードフォルダに格納する。(歴史的経緯)
# 後続のloadmodelを呼び出すことで永続化モデルフォルダに移動させられる。
storeDir = os.path.join(UPLOAD_DIR, f"{storeSlot}")
storeDir = os.path.join(UPLOAD_DIR)
print("[Voice Changer] store merged model to:", storeDir)
os.makedirs(storeDir, exist_ok=True)
storeFile = os.path.join(storeDir, "merged.pth")

View File

@ -1,12 +1,14 @@
from typing import Dict, Any
import os
from collections import OrderedDict
import torch
from voice_changer.ModelSlotManager import ModelSlotManager
from voice_changer.utils.ModelMerger import ModelMergerRequest
from voice_changer.utils.VoiceChangerParams import VoiceChangerParams
def merge_model(request: ModelMergerRequest):
def merge_model(params: VoiceChangerParams, request: ModelMergerRequest):
def extract(ckpt: Dict[str, Any]):
a = ckpt["model"]
opt: Dict[str, Any] = OrderedDict()
@ -34,11 +36,16 @@ def merge_model(request: ModelMergerRequest):
weights = []
alphas = []
slotManager = ModelSlotManager.get_instance(params.model_dir)
for f in files:
strength = f.strength
if strength == 0:
continue
weight, state_dict = load_weight(f.filename)
slotInfo = slotManager.get_slot_info(f.slotIndex)
filename = os.path.join(params.model_dir, str(f.slotIndex), os.path.basename(slotInfo.modelFile)) # slotInfo.modelFileはv.1.5.3.11以前はmodel_dirから含まれている。
weight, state_dict = load_weight(filename)
weights.append(weight)
alphas.append(f.strength)

View File

@ -306,8 +306,8 @@ class VoiceChangerManager(ServerDeviceCallbacks):
req.files = [MergeElement(**f) for f in req.files]
slot = len(self.modelSlotManager.getAllSlotInfo()) - 1
if req.voiceChangerType == "RVC":
merged = RVCModelMerger.merge_models(req, slot)
loadParam = LoadModelParams(voiceChangerType="RVC", slot=slot, isSampleMode=False, sampleId="", files=[LoadModelParamFile(name=os.path.basename(merged), kind="rvcModel", dir=f"{slot}")], params={})
merged = RVCModelMerger.merge_models(self.params, req, slot)
loadParam = LoadModelParams(voiceChangerType="RVC", slot=slot, isSampleMode=False, sampleId="", files=[LoadModelParamFile(name=os.path.basename(merged), kind="rvcModel", dir="")], params={})
self.loadModel(loadParam)
return self.get_info()

View File

@ -5,6 +5,7 @@ from dataclasses import dataclass
@dataclass
class MergeElement:
slotIndex: int
filename: str
strength: int