voice-changer/server/voice_changer/RVC/modelMerger/MergeModel.py

77 lines
2.7 KiB
Python
Raw Normal View History

2023-04-30 20:34:01 +03:00
from typing import Dict, Any
2023-08-04 21:02:43 +03:00
import os
2023-04-30 20:34:01 +03:00
from collections import OrderedDict
import torch
2023-08-04 21:02:43 +03:00
from voice_changer.ModelSlotManager import ModelSlotManager
2023-04-30 20:34:01 +03:00
2023-06-23 08:54:39 +03:00
from voice_changer.utils.ModelMerger import ModelMergerRequest
2023-08-04 21:02:43 +03:00
from voice_changer.utils.VoiceChangerParams import VoiceChangerParams
2023-06-23 08:54:39 +03:00
2023-04-30 20:34:01 +03:00
2023-08-04 21:02:43 +03:00
def merge_model(params: VoiceChangerParams, request: ModelMergerRequest):
    """Merge the weights of several RVC model slots into a single checkpoint dict.

    Each entry in ``request.files`` names a model slot (``slotIndex``) and a
    mixing ``strength``. Entries whose strength is 0 are skipped. The remaining
    strengths are normalized so they sum to 1, and every weight tensor of the
    result is the strength-weighted sum of the corresponding tensors.

    Args:
        params: Server parameters; ``params.model_dir`` is the root directory
            that holds one sub-directory per slot index.
        request: Merge request whose ``files`` items expose ``slotIndex`` and
            ``strength`` attributes.

    Returns:
        An ``OrderedDict`` with a ``"weight"`` mapping (parameter name -> merged
        tensor) plus metadata keys ``config``, ``params``, ``version``, ``sr``,
        ``f0``, ``info``, ``embedder_name`` and ``embedder_output_layer``.
        The metadata is copied from the checkpoint that was loaded LAST.
        Optional keys missing there are set to ``None``.

    Raises:
        RuntimeError: if ``request.files`` is empty, every strength is 0, the
            non-zero strengths sum to 0, or the models' weight keys differ.
    """

    def extract(ckpt: Dict[str, Any]) -> Dict[str, Any]:
        # Training-style checkpoints keep their tensors under "model". Drop the
        # enc_q entries and return a flat name -> tensor mapping, i.e. the same
        # layout as the "weight" entry of an exported model file.
        return {key: value for key, value in ckpt["model"].items() if "enc_q" not in key}

    def load_weight(path: str):
        # Returns (flat name -> tensor mapping, full loaded checkpoint dict).
        print(f"Loading {path}...")
        # NOTE(review): torch.load unpickles the file, so it must only be used on
        # trusted model files.
        state_dict = torch.load(path, map_location="cpu")
        if "model" in state_dict:
            weight = extract(state_dict)
        else:
            weight = state_dict["weight"]
        return weight, state_dict

    files = request.files
    if not files:
        print("no merge file..............")
        raise RuntimeError("no merge file..............")

    slotManager = ModelSlotManager.get_instance(params.model_dir)

    weights = []
    alphas = []
    state_dict = None  # ends up holding the last loaded checkpoint (metadata source)
    for f in files:
        strength = f.strength
        if strength == 0:
            continue
        slotInfo = slotManager.get_slot_info(f.slotIndex)
        # Up to v1.5.3.11, slotInfo.modelFile already contained the model_dir
        # prefix. Taking the basename makes old and new slot records resolve to
        # <model_dir>/<slotIndex>/<file>.
        filename = os.path.join(params.model_dir, str(f.slotIndex), os.path.basename(slotInfo.modelFile))
        weight, state_dict = load_weight(filename)
        weights.append(weight)
        alphas.append(strength)

    if not weights:
        # Every requested strength was 0, so there is nothing to merge. Without
        # this check the code below would fail on weights[0] / an unset state_dict.
        raise RuntimeError("no merge file with non-zero strength.")

    total = sum(alphas)
    if total == 0:
        # e.g. strengths 1 and -1: the mixing coefficients cannot be normalized.
        raise RuntimeError("merge strengths sum to zero; cannot normalize.")
    alphas = [alpha / total for alpha in alphas]

    # Every model must expose exactly the same parameter names.
    reference_keys = set(weights[0].keys())
    for weight in weights:
        if set(weight.keys()) != reference_keys:
            raise RuntimeError("Failed to merge models.")

    merged: Dict[str, Any] = OrderedDict()
    print("merge start.")
    merged["weight"] = {
        key: sum(weight[key] * alpha for weight, alpha in zip(weights, alphas))
        for key in weights[0].keys()
    }

    print("merge done. write metadata.")
    # Metadata comes from the last loaded checkpoint. "config", "sr", "f0" and
    # "info" are required; the remaining keys are optional and default to None.
    merged["config"] = state_dict["config"]
    merged["params"] = state_dict.get("params")
    merged["version"] = state_dict.get("version")
    merged["sr"] = state_dict["sr"]
    merged["f0"] = state_dict["f0"]
    merged["info"] = state_dict["info"]
    merged["embedder_name"] = state_dict.get("embedder_name")
    merged["embedder_output_layer"] = state_dict.get("embedder_output_layer")
    print("write metadata done.")
    return merged