mirror of
https://github.com/w-okada/voice-changer.git
synced 2025-01-23 21:45:00 +03:00
bugfix: merge model
This commit is contained in:
parent
8b5eb32047
commit
aef9d2c34b
4
client/demo/dist/index.js
vendored
4
client/demo/dist/index.js
vendored
File diff suppressed because one or more lines are too long
31
client/demo/dist/index.js.LICENSE.txt
vendored
31
client/demo/dist/index.js.LICENSE.txt
vendored
@ -1,31 +0,0 @@
|
||||
/*! regenerator-runtime -- Copyright (c) 2014-present, Facebook, Inc. -- license (MIT): https://github.com/facebook/regenerator/blob/main/LICENSE */
|
||||
|
||||
/**
|
||||
* @license React
|
||||
* react-dom.production.min.js
|
||||
*
|
||||
* Copyright (c) Facebook, Inc. and its affiliates.
|
||||
*
|
||||
* This source code is licensed under the MIT license found in the
|
||||
* LICENSE file in the root directory of this source tree.
|
||||
*/
|
||||
|
||||
/**
|
||||
* @license React
|
||||
* react.production.min.js
|
||||
*
|
||||
* Copyright (c) Facebook, Inc. and its affiliates.
|
||||
*
|
||||
* This source code is licensed under the MIT license found in the
|
||||
* LICENSE file in the root directory of this source tree.
|
||||
*/
|
||||
|
||||
/**
|
||||
* @license React
|
||||
* scheduler.production.min.js
|
||||
*
|
||||
* Copyright (c) Facebook, Inc. and its affiliates.
|
||||
*
|
||||
* This source code is licensed under the MIT license found in the
|
||||
* LICENSE file in the root directory of this source tree.
|
||||
*/
|
@ -3,158 +3,182 @@ import { useGuiState } from "./001_GuiStateProvider";
|
||||
import { useAppState } from "../../001_provider/001_AppStateProvider";
|
||||
import { MergeElement, RVCModelSlot, RVCModelType, VoiceChangerType } from "@dannadori/voice-changer-client-js";
|
||||
|
||||
|
||||
export const MergeLabDialog = () => {
|
||||
const guiState = useGuiState()
|
||||
const guiState = useGuiState();
|
||||
|
||||
const { serverSetting } = useAppState()
|
||||
const [currentFilter, setCurrentFilter] = useState<string>("")
|
||||
const [mergeElements, setMergeElements] = useState<MergeElement[]>([])
|
||||
const { serverSetting } = useAppState();
|
||||
const [currentFilter, setCurrentFilter] = useState<string>("");
|
||||
const [mergeElements, setMergeElements] = useState<MergeElement[]>([]);
|
||||
|
||||
// スロットが変更されたときの初期化処理
|
||||
const newSlotChangeKey = useMemo(() => {
|
||||
if (!serverSetting.serverSetting.modelSlots) {
|
||||
return ""
|
||||
return "";
|
||||
}
|
||||
return serverSetting.serverSetting.modelSlots.reduce((prev, cur) => {
|
||||
return prev + "_" + cur.modelFile
|
||||
}, "")
|
||||
}, [serverSetting.serverSetting.modelSlots])
|
||||
return prev + "_" + cur.modelFile;
|
||||
}, "");
|
||||
}, [serverSetting.serverSetting.modelSlots]);
|
||||
|
||||
const filterItems = useMemo(() => {
|
||||
return serverSetting.serverSetting.modelSlots.reduce((prev, cur) => {
|
||||
return serverSetting.serverSetting.modelSlots.reduce(
|
||||
(prev, cur) => {
|
||||
if (cur.voiceChangerType != "RVC") {
|
||||
return prev
|
||||
return prev;
|
||||
}
|
||||
const curRVC = cur as RVCModelSlot
|
||||
const key = `${curRVC.modelType},${cur.samplingRate},${curRVC.embChannels}`
|
||||
const val = { type: curRVC.modelType, samplingRate: cur.samplingRate, embChannels: curRVC.embChannels }
|
||||
const existKeys = Object.keys(prev)
|
||||
const curRVC = cur as RVCModelSlot;
|
||||
const key = `${curRVC.modelType},${cur.samplingRate},${curRVC.embChannels}`;
|
||||
const val = { type: curRVC.modelType, samplingRate: cur.samplingRate, embChannels: curRVC.embChannels };
|
||||
const existKeys = Object.keys(prev);
|
||||
if (!cur.modelFile || cur.modelFile.length == 0) {
|
||||
return prev
|
||||
return prev;
|
||||
}
|
||||
if (curRVC.modelType == "onnxRVC" || curRVC.modelType == "onnxRVCNono") {
|
||||
return prev
|
||||
return prev;
|
||||
}
|
||||
if (!existKeys.includes(key)) {
|
||||
prev[key] = val
|
||||
prev[key] = val;
|
||||
}
|
||||
return prev
|
||||
}, {} as { [key: string]: { type: RVCModelType, samplingRate: number, embChannels: number } })
|
||||
|
||||
}, [newSlotChangeKey])
|
||||
return prev;
|
||||
},
|
||||
{} as { [key: string]: { type: RVCModelType; samplingRate: number; embChannels: number } },
|
||||
);
|
||||
}, [newSlotChangeKey]);
|
||||
|
||||
const models = useMemo(() => {
|
||||
return serverSetting.serverSetting.modelSlots.filter(x => {
|
||||
return serverSetting.serverSetting.modelSlots.filter((x) => {
|
||||
if (x.voiceChangerType != "RVC") {
|
||||
return
|
||||
return;
|
||||
}
|
||||
const xRVC = x as RVCModelSlot
|
||||
const filterVals = filterItems[currentFilter]
|
||||
const xRVC = x as RVCModelSlot;
|
||||
const filterVals = filterItems[currentFilter];
|
||||
if (!filterVals) {
|
||||
return false
|
||||
return false;
|
||||
}
|
||||
if (xRVC.modelType == filterVals.type && xRVC.samplingRate == filterVals.samplingRate && xRVC.embChannels == filterVals.embChannels) {
|
||||
return true
|
||||
return true;
|
||||
} else {
|
||||
return false
|
||||
return false;
|
||||
}
|
||||
})
|
||||
}, [filterItems, currentFilter])
|
||||
});
|
||||
}, [filterItems, currentFilter]);
|
||||
|
||||
useEffect(() => {
|
||||
if (Object.keys(filterItems).length > 0) {
|
||||
setCurrentFilter(Object.keys(filterItems)[0])
|
||||
setCurrentFilter(Object.keys(filterItems)[0]);
|
||||
}
|
||||
}, [filterItems])
|
||||
}, [filterItems]);
|
||||
useEffect(() => {
|
||||
// models はフィルタ後の配列
|
||||
const newMergeElements = models.map((x) => {
|
||||
return { filename: x.modelFile, strength: 0 }
|
||||
})
|
||||
setMergeElements(newMergeElements)
|
||||
}, [models])
|
||||
return { slotIndex: x.slotIndex, filename: x.modelFile, strength: 0 };
|
||||
});
|
||||
setMergeElements(newMergeElements);
|
||||
}, [models]);
|
||||
|
||||
const dialog = useMemo(() => {
|
||||
const closeButtonRow = (
|
||||
<div className="body-row split-3-4-3 left-padding-1">
|
||||
<div className="body-item-text">
|
||||
</div>
|
||||
<div className="body-item-text"></div>
|
||||
<div className="body-button-container body-button-container-space-around">
|
||||
<div className="body-button" onClick={() => { guiState.stateControls.showMergeLabCheckbox.updateState(false) }} >close</div>
|
||||
<div
|
||||
className="body-button"
|
||||
onClick={() => {
|
||||
guiState.stateControls.showMergeLabCheckbox.updateState(false);
|
||||
}}
|
||||
>
|
||||
close
|
||||
</div>
|
||||
</div>
|
||||
<div className="body-item-text"></div>
|
||||
</div>
|
||||
)
|
||||
);
|
||||
|
||||
|
||||
const filterOptions = Object.keys(filterItems).map(x => {
|
||||
return <option key={x} value={x}>{x}</option>
|
||||
}).filter(x => x != null)
|
||||
|
||||
const onMergeElementsChanged = (filename: string, strength: number) => {
|
||||
const newMergeElements = mergeElements.map((x) => {
|
||||
if (x.filename == filename) {
|
||||
return { filename: x.filename, strength: strength }
|
||||
} else {
|
||||
return x
|
||||
}
|
||||
const filterOptions = Object.keys(filterItems)
|
||||
.map((x) => {
|
||||
return (
|
||||
<option key={x} value={x}>
|
||||
{x}
|
||||
</option>
|
||||
);
|
||||
})
|
||||
setMergeElements(newMergeElements)
|
||||
.filter((x) => x != null);
|
||||
|
||||
const onMergeElementsChanged = (slotIndex: number, strength: number) => {
|
||||
const newMergeElements = mergeElements.map((x) => {
|
||||
if (x.slotIndex == slotIndex) {
|
||||
return { slotIndex: x.slotIndex, filename: x.filename, strength: strength };
|
||||
} else {
|
||||
return x;
|
||||
}
|
||||
});
|
||||
setMergeElements(newMergeElements);
|
||||
};
|
||||
|
||||
const onMergeClicked = () => {
|
||||
const validMergeElements = mergeElements.filter((x) => {
|
||||
return x.strength > 0;
|
||||
});
|
||||
serverSetting.mergeModel({
|
||||
voiceChangerType: VoiceChangerType.RVC,
|
||||
command: "mix",
|
||||
files: mergeElements
|
||||
})
|
||||
}
|
||||
files: validMergeElements,
|
||||
});
|
||||
};
|
||||
|
||||
const modelList = mergeElements.map((x, index) => {
|
||||
const name = models.find(model => { return model.modelFile == x.filename })?.name || ""
|
||||
const name =
|
||||
models.find((model) => {
|
||||
return model.slotIndex == x.slotIndex;
|
||||
})?.name || "";
|
||||
|
||||
return (
|
||||
<div key={index} className="merge-lab-model-item">
|
||||
<div>{name}</div>
|
||||
<div>
|
||||
{name}
|
||||
</div>
|
||||
<div>
|
||||
<input type="range" className="body-item-input-slider" min="0" max="100" step="1" value={x.strength} onChange={(e) => {
|
||||
onMergeElementsChanged(x.filename, Number(e.target.value))
|
||||
}}></input>
|
||||
<input
|
||||
type="range"
|
||||
className="body-item-input-slider"
|
||||
min="0"
|
||||
max="100"
|
||||
step="1"
|
||||
value={x.strength}
|
||||
onChange={(e) => {
|
||||
onMergeElementsChanged(x.slotIndex, Number(e.target.value));
|
||||
}}
|
||||
></input>
|
||||
<span className="body-item-input-slider-val">{x.strength}</span>
|
||||
</div>
|
||||
</div>
|
||||
)
|
||||
})
|
||||
|
||||
);
|
||||
});
|
||||
|
||||
const content = (
|
||||
<div className="merge-lab-container">
|
||||
<div className="merge-lab-type-filter">
|
||||
<div>Type:</div>
|
||||
<div>
|
||||
Type:
|
||||
</div>
|
||||
<div>
|
||||
<select value={currentFilter} onChange={(e) => { setCurrentFilter(e.target.value) }}>
|
||||
<select
|
||||
value={currentFilter}
|
||||
onChange={(e) => {
|
||||
setCurrentFilter(e.target.value);
|
||||
}}
|
||||
>
|
||||
{filterOptions}
|
||||
</select>
|
||||
</div>
|
||||
</div>
|
||||
<div className="merge-lab-manipulator">
|
||||
<div className="merge-lab-model-list">
|
||||
{modelList}
|
||||
</div>
|
||||
<div className="merge-lab-model-list">{modelList}</div>
|
||||
<div className="merge-lab-merge-buttons">
|
||||
<div className="merge-lab-merge-buttons-notice">
|
||||
The merged model is stored in the final slot. If you assign this slot, it will be overwritten.
|
||||
</div>
|
||||
<div className="merge-lab-merge-buttons-notice">The merged model is stored in the final slot. If you assign this slot, it will be overwritten.</div>
|
||||
<div className="merge-lab-merge-button" onClick={onMergeClicked}>
|
||||
merge
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
)
|
||||
);
|
||||
return (
|
||||
<div className="dialog-frame">
|
||||
<div className="dialog-title">MergeLab</div>
|
||||
@ -166,5 +190,4 @@ export const MergeLabDialog = () => {
|
||||
);
|
||||
}, [newSlotChangeKey, currentFilter, mergeElements, models]);
|
||||
return dialog;
|
||||
|
||||
};
|
||||
|
@ -38,7 +38,7 @@ export const ModelSlotArea = (_props: ModelSlotAreaProps) => {
|
||||
if (!x.modelFile || x.modelFile.length == 0) {
|
||||
return null;
|
||||
}
|
||||
const tileContainerClass = x.id == serverSetting.serverSetting.modelSlotIndex ? "model-slot-tile-container-selected" : "model-slot-tile-container";
|
||||
const tileContainerClass = x.slotIndex == serverSetting.serverSetting.modelSlotIndex ? "model-slot-tile-container-selected" : "model-slot-tile-container";
|
||||
const name = x.name.length > 8 ? x.name.substring(0, 7) + "..." : x.name;
|
||||
const iconElem =
|
||||
x.iconFile.length > 0 ? (
|
||||
@ -54,7 +54,7 @@ export const ModelSlotArea = (_props: ModelSlotAreaProps) => {
|
||||
);
|
||||
|
||||
const clickAction = async () => {
|
||||
const dummyModelSlotIndex = Math.floor(Date.now() / 1000) * 1000 + x.id;
|
||||
const dummyModelSlotIndex = Math.floor(Date.now() / 1000) * 1000 + x.slotIndex;
|
||||
await serverSetting.updateServerSettings({ ...serverSetting.serverSetting, modelSlotIndex: dummyModelSlotIndex });
|
||||
setTimeout(() => {
|
||||
// quick hack
|
||||
|
@ -193,7 +193,7 @@ export type VoiceChangerServerSetting = {
|
||||
}
|
||||
|
||||
type ModelSlot = {
|
||||
id: number
|
||||
slotIndex: number
|
||||
voiceChangerType: VoiceChangerType
|
||||
name: string,
|
||||
description: string,
|
||||
@ -539,7 +539,8 @@ export type OnnxExporterInfo = {
|
||||
|
||||
// Merge
|
||||
export type MergeElement = {
|
||||
filename: string
|
||||
slotIndex: number
|
||||
filename: string // 一意性は保障されない場合がある(フォルダコピーされたときとか)
|
||||
strength: number
|
||||
}
|
||||
export type MergeModelRequest = {
|
||||
|
@ -9,7 +9,7 @@ import json
|
||||
|
||||
@dataclass
|
||||
class ModelSlot:
|
||||
id: int = -1
|
||||
slotIndex: int = -1
|
||||
voiceChangerType: VoiceChangerType | None = None
|
||||
name: str = ""
|
||||
description: str = ""
|
||||
@ -133,19 +133,26 @@ def loadSlotInfo(model_dir: str, slotIndex: int) -> ModelSlots:
|
||||
if not os.path.exists(jsonFile):
|
||||
return ModelSlot()
|
||||
jsonDict = json.load(open(os.path.join(slotDir, "params.json")))
|
||||
slotInfo = ModelSlot(**{k: v for k, v in jsonDict.items() if k in ModelSlot.__annotations__})
|
||||
slotInfoKey = list(ModelSlot.__annotations__.keys())
|
||||
slotInfo = ModelSlot(**{k: v for k, v in jsonDict.items() if k in slotInfoKey})
|
||||
if slotInfo.voiceChangerType == "RVC":
|
||||
return RVCModelSlot(**jsonDict)
|
||||
slotInfoKey.extend(list(RVCModelSlot.__annotations__.keys()))
|
||||
return RVCModelSlot(**{k: v for k, v in jsonDict.items() if k in slotInfoKey})
|
||||
elif slotInfo.voiceChangerType == "MMVCv13":
|
||||
return MMVCv13ModelSlot(**jsonDict)
|
||||
slotInfoKey.extend(list(MMVCv13ModelSlot.__annotations__.keys()))
|
||||
return MMVCv13ModelSlot(**{k: v for k, v in jsonDict.items() if k in slotInfoKey})
|
||||
elif slotInfo.voiceChangerType == "MMVCv15":
|
||||
return MMVCv15ModelSlot(**jsonDict)
|
||||
slotInfoKey.extend(list(MMVCv15ModelSlot.__annotations__.keys()))
|
||||
return MMVCv15ModelSlot(**{k: v for k, v in jsonDict.items() if k in slotInfoKey})
|
||||
elif slotInfo.voiceChangerType == "so-vits-svc-40":
|
||||
return SoVitsSvc40ModelSlot(**jsonDict)
|
||||
slotInfoKey.extend(list(SoVitsSvc40ModelSlot.__annotations__.keys()))
|
||||
return SoVitsSvc40ModelSlot(**{k: v for k, v in jsonDict.items() if k in slotInfoKey})
|
||||
elif slotInfo.voiceChangerType == "DDSP-SVC":
|
||||
return DDSPSVCModelSlot(**jsonDict)
|
||||
slotInfoKey.extend(list(DDSPSVCModelSlot.__annotations__.keys()))
|
||||
return DDSPSVCModelSlot(**{k: v for k, v in jsonDict.items() if k in slotInfoKey})
|
||||
elif slotInfo.voiceChangerType == "Diffusion-SVC":
|
||||
return DiffusionSVCModelSlot(**jsonDict)
|
||||
slotInfoKey.extend(list(DiffusionSVCModelSlot.__annotations__.keys()))
|
||||
return DiffusionSVCModelSlot(**{k: v for k, v in jsonDict.items() if k in slotInfoKey})
|
||||
else:
|
||||
return ModelSlot()
|
||||
|
||||
@ -154,11 +161,13 @@ def loadAllSlotInfo(model_dir: str):
|
||||
slotInfos: list[ModelSlots] = []
|
||||
for slotIndex in range(MAX_SLOT_NUM):
|
||||
slotInfo = loadSlotInfo(model_dir, slotIndex)
|
||||
slotInfo.id = slotIndex
|
||||
slotInfo.slotIndex = slotIndex # スロットインデックスは動的に注入
|
||||
slotInfos.append(slotInfo)
|
||||
return slotInfos
|
||||
|
||||
|
||||
def saveSlotInfo(model_dir: str, slotIndex: int, slotInfo: ModelSlots):
|
||||
slotDir = os.path.join(model_dir, str(slotIndex))
|
||||
json.dump(asdict(slotInfo), open(os.path.join(slotDir, "params.json"), "w"))
|
||||
slotInfoDict = asdict(slotInfo)
|
||||
slotInfo.slotIndex = -1 # スロットインデックスは動的に注入
|
||||
json.dump(slotInfoDict, open(os.path.join(slotDir, "params.json"), "w"))
|
||||
|
@ -113,6 +113,8 @@ class MMVC_Rest_Fileuploader:
|
||||
return JSONResponse(content=json_compatible_item_data)
|
||||
except Exception as e:
|
||||
print("[Voice Changer] post_merge_models ex:", e)
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
|
||||
def post_update_model_default(self):
|
||||
try:
|
||||
|
@ -4,17 +4,17 @@ import torch
|
||||
from const import UPLOAD_DIR
|
||||
from voice_changer.RVC.modelMerger.MergeModel import merge_model
|
||||
from voice_changer.utils.ModelMerger import ModelMerger, ModelMergerRequest
|
||||
from voice_changer.utils.VoiceChangerParams import VoiceChangerParams
|
||||
|
||||
|
||||
class RVCModelMerger(ModelMerger):
|
||||
@classmethod
|
||||
def merge_models(cls, request: ModelMergerRequest, storeSlot: int):
|
||||
print("[Voice Changer] MergeRequest:", request)
|
||||
merged = merge_model(request)
|
||||
def merge_models(cls, params: VoiceChangerParams, request: ModelMergerRequest, storeSlot: int):
|
||||
merged = merge_model(params, request)
|
||||
|
||||
# いったんは、アップロードフォルダに格納する。(歴史的経緯)
|
||||
# 後続のloadmodelを呼び出すことで永続化モデルフォルダに移動させられる。
|
||||
storeDir = os.path.join(UPLOAD_DIR, f"{storeSlot}")
|
||||
storeDir = os.path.join(UPLOAD_DIR)
|
||||
print("[Voice Changer] store merged model to:", storeDir)
|
||||
os.makedirs(storeDir, exist_ok=True)
|
||||
storeFile = os.path.join(storeDir, "merged.pth")
|
||||
|
@ -1,12 +1,14 @@
|
||||
from typing import Dict, Any
|
||||
|
||||
import os
|
||||
from collections import OrderedDict
|
||||
import torch
|
||||
from voice_changer.ModelSlotManager import ModelSlotManager
|
||||
|
||||
from voice_changer.utils.ModelMerger import ModelMergerRequest
|
||||
from voice_changer.utils.VoiceChangerParams import VoiceChangerParams
|
||||
|
||||
|
||||
def merge_model(request: ModelMergerRequest):
|
||||
def merge_model(params: VoiceChangerParams, request: ModelMergerRequest):
|
||||
def extract(ckpt: Dict[str, Any]):
|
||||
a = ckpt["model"]
|
||||
opt: Dict[str, Any] = OrderedDict()
|
||||
@ -34,11 +36,16 @@ def merge_model(request: ModelMergerRequest):
|
||||
|
||||
weights = []
|
||||
alphas = []
|
||||
slotManager = ModelSlotManager.get_instance(params.model_dir)
|
||||
for f in files:
|
||||
strength = f.strength
|
||||
if strength == 0:
|
||||
continue
|
||||
weight, state_dict = load_weight(f.filename)
|
||||
slotInfo = slotManager.get_slot_info(f.slotIndex)
|
||||
|
||||
filename = os.path.join(params.model_dir, str(f.slotIndex), os.path.basename(slotInfo.modelFile)) # slotInfo.modelFileはv.1.5.3.11以前はmodel_dirから含まれている。
|
||||
|
||||
weight, state_dict = load_weight(filename)
|
||||
weights.append(weight)
|
||||
alphas.append(f.strength)
|
||||
|
||||
|
@ -306,8 +306,8 @@ class VoiceChangerManager(ServerDeviceCallbacks):
|
||||
req.files = [MergeElement(**f) for f in req.files]
|
||||
slot = len(self.modelSlotManager.getAllSlotInfo()) - 1
|
||||
if req.voiceChangerType == "RVC":
|
||||
merged = RVCModelMerger.merge_models(req, slot)
|
||||
loadParam = LoadModelParams(voiceChangerType="RVC", slot=slot, isSampleMode=False, sampleId="", files=[LoadModelParamFile(name=os.path.basename(merged), kind="rvcModel", dir=f"{slot}")], params={})
|
||||
merged = RVCModelMerger.merge_models(self.params, req, slot)
|
||||
loadParam = LoadModelParams(voiceChangerType="RVC", slot=slot, isSampleMode=False, sampleId="", files=[LoadModelParamFile(name=os.path.basename(merged), kind="rvcModel", dir="")], params={})
|
||||
self.loadModel(loadParam)
|
||||
return self.get_info()
|
||||
|
||||
|
@ -5,6 +5,7 @@ from dataclasses import dataclass
|
||||
|
||||
@dataclass
|
||||
class MergeElement:
|
||||
slotIndex: int
|
||||
filename: str
|
||||
strength: int
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user