voice-changer/server/voice_changer/RVC/modelMerger/MergeModel.py

68 lines
2.1 KiB
Python
Raw Normal View History

2023-04-30 20:34:01 +03:00
from typing import Dict, Any
2023-05-04 07:09:13 +03:00
from voice_changer.RVC.modelMerger.MergeModelRequest import MergeModelRequest
2023-04-30 20:34:01 +03:00
from collections import OrderedDict
import torch
def merge_model(request: "MergeModelRequest"):
    """Merge several RVC checkpoints into one by weighted averaging.

    Each entry of ``request.files`` provides a ``filename`` and a
    ``strength``. Entries whose strength is 0 are skipped without being
    loaded. The remaining strengths are normalised so they sum to 1, and
    every weight is combined as ``sum(alpha_i * weight_i)``.

    Metadata (``config``, ``params``, ``sr``, ``f0``, ``info``,
    ``embedder_name``) is copied from the last checkpoint that was loaded.

    Args:
        request: merge request whose ``files`` list names the checkpoints.

    Returns:
        An ``OrderedDict`` holding a ``"weight"`` mapping of merged tensors
        plus the copied metadata keys (``params`` and ``embedder_name`` are
        ``None`` when the source checkpoint lacks them).

    Raises:
        RuntimeError: if no files are given, if the non-zero strengths sum
            to zero, or if the checkpoints do not share the same weight keys.
        KeyError: if the metadata source checkpoint lacks ``config``, ``sr``,
            ``f0`` or ``info``.
    """

    def extract(ckpt: Dict[str, Any]) -> Dict[str, Any]:
        # Checkpoints that store weights under "model" are flattened to the
        # same bare name -> tensor mapping that exported checkpoints keep
        # under "weight" (the original returned the {"weight": ...} wrapper,
        # so such files could never be merged). "enc_q" entries are dropped.
        return {
            key: value
            for key, value in ckpt["model"].items()
            if "enc_q" not in key
        }

    def load_weight(path: str):
        print(f"Loading {path}...")
        # NOTE(review): torch.load unpickles the file unless weights_only is
        # in effect; only pass trusted checkpoints here.
        state_dict = torch.load(path, map_location="cpu")
        if "model" in state_dict:
            weight = extract(state_dict)
        else:
            weight = state_dict["weight"]
        return weight, state_dict

    files = request.files
    if not files:
        print("no merge file..............")
        raise RuntimeError("no merge file..............")

    # Zero-strength entries contribute nothing, so they are never loaded.
    selected = [f for f in files if f.strength != 0]
    total_strength = sum(f.strength for f in selected)
    if total_strength == 0:
        # Covers "all strengths are 0" (nothing to merge) as well as
        # strengths that cancel out (normalisation would divide by zero).
        raise RuntimeError("merge strengths must not sum to zero.")

    weights = []
    meta_source: Dict[str, Any] = {}
    for f in selected:
        weight, meta_source = load_weight(f.filename)
        weights.append(weight)
    alphas = [f.strength / total_strength for f in selected]

    reference_keys = set(weights[0])
    for weight in weights:
        if set(weight) != reference_keys:
            raise RuntimeError("Failed to merge models.")

    merged: Dict[str, Any] = OrderedDict()
    merged["weight"] = {}

    print("merge start.")
    for key in weights[0].keys():
        merged["weight"][key] = sum(
            weight[key] * alpha for weight, alpha in zip(weights, alphas)
        )

    print("merge done. write metadata.")
    # Metadata comes from the last checkpoint loaded above.
    merged["config"] = meta_source["config"]
    merged["params"] = meta_source.get("params")
    merged["sr"] = meta_source["sr"]
    merged["f0"] = meta_source["f0"]
    merged["info"] = meta_source["info"]
    merged["embedder_name"] = meta_source.get("embedder_name")

    print("write metadata done.")
    return merged