2023-04-30 20:34:01 +03:00
|
|
|
from typing import Dict, Any
|
2023-05-04 07:09:13 +03:00
|
|
|
from voice_changer.RVC.modelMerger.MergeModelRequest import MergeModelRequest
|
2023-04-30 20:34:01 +03:00
|
|
|
from collections import OrderedDict
|
|
|
|
import torch
|
|
|
|
|
|
|
|
|
|
|
|
def merge_model(request: "MergeModelRequest"):
    """Merge several model files into one via a weighted average of their weights.

    Each entry of ``request.files`` provides a ``filename`` and a ``strength``.
    Entries whose strength is 0 are skipped (never loaded). The remaining
    strengths are normalized to sum to 1 and used as per-file coefficients.

    Args:
        request: merge request whose ``files`` attribute lists the files to merge.

    Returns:
        An ``OrderedDict`` holding the merged ``"weight"`` dict plus the metadata
        keys ``config``, ``params``, ``version``, ``sr``, ``f0``, ``info`` and
        ``embedder_name``. The metadata is copied from the last loaded file that
        contains ``"config"``, or from the first loaded file if none does.

    Raises:
        RuntimeError: if no files are given, every strength is 0, the strengths
            sum to 0, or the models differ in parameter names or shapes.
    """

    def extract(ckpt: Dict[str, Any]):
        # Wrap the checkpoint's "model" entries as {"weight": {...}},
        # dropping every key that contains "enc_q".
        a = ckpt["model"]
        opt: Dict[str, Any] = OrderedDict()
        opt["weight"] = {}
        for key in a.keys():
            if "enc_q" in key:
                continue
            opt["weight"][key] = a[key]
        return opt

    def load_weight(path: str):
        # Return (flat name -> tensor dict, full loaded state dict) for one file.
        print(f"Loading {path}...")
        # NOTE(review): torch.load unpickles the file; only load trusted models.
        state_dict = torch.load(path, map_location="cpu")
        if "model" in state_dict:
            # extract() returns {"weight": {...}}; unwrap it so both branches
            # yield the same flat mapping the merge loop expects.
            weight = extract(state_dict)["weight"]
        else:
            weight = state_dict["weight"]
        return weight, state_dict

    files = request.files
    if len(files) == 0:
        print("no merge file..............")
        raise RuntimeError("no merge file..............")

    weights = []
    alphas = []
    meta_source = None  # loaded state dict the metadata will be copied from
    for f in files:
        strength = f.strength
        if strength == 0:
            continue
        weight, state_dict = load_weight(f.filename)
        weights.append(weight)
        alphas.append(strength)
        # Prefer the last file that carries "config": a file read through the
        # "model" branch may lack the metadata keys read below.
        if meta_source is None or "config" in state_dict:
            meta_source = state_dict

    if not weights:
        raise RuntimeError("all merge strengths are zero.")

    total = sum(alphas)  # hoisted: computed once, not per element
    if total == 0:
        raise RuntimeError("sum of merge strengths is zero.")
    alphas = [alpha / total for alpha in alphas]

    reference = weights[0]
    reference_keys = set(reference.keys())
    for weight in weights:
        if set(weight.keys()) != reference_keys:
            raise RuntimeError("Failed to merge models.")
        for key in reference_keys:
            # Equal names are not enough: tensors of different shapes would be
            # silently broadcast by the arithmetic below into a corrupt model.
            if getattr(weight[key], "shape", None) != getattr(reference[key], "shape", None):
                raise RuntimeError(f"Failed to merge models. shape mismatch: {key}")

    merged: Dict[str, Any] = OrderedDict()
    merged["weight"] = {}
    print("merge start.")
    for key in reference.keys():
        merged["weight"][key] = 0
        for alpha, weight in zip(alphas, weights):
            merged["weight"][key] += weight[key] * alpha
    print("merge done. write metadata.")

    merged["config"] = meta_source["config"]
    merged["params"] = meta_source.get("params")
    merged["version"] = meta_source.get("version")
    merged["sr"] = meta_source["sr"]
    merged["f0"] = meta_source["f0"]
    merged["info"] = meta_source["info"]
    merged["embedder_name"] = meta_source.get("embedder_name")
    print("write metadata done.")
    return merged
|