voice-changer/server/voice_changer/VoiceChangerManager.py

166 lines
5.5 KiB
Python
Raw Normal View History

2022-12-31 16:02:53 +09:00
import numpy as np
2023-06-17 15:35:43 +09:00
from downloader.SampleDownloader import downloadSample, getSampleInfos
2023-06-16 15:06:35 +09:00
from voice_changer.Local.ServerDevice import ServerDevice, ServerDeviceCallbacks
2023-06-17 14:16:29 +09:00
from voice_changer.ModelSlotManager import ModelSlotManager
2022-12-31 16:08:14 +09:00
from voice_changer.VoiceChanger import VoiceChanger
2023-04-11 00:21:17 +09:00
from const import ModelType
2023-04-28 06:39:51 +09:00
from voice_changer.utils.LoadModelParams import LoadModelParams
from voice_changer.utils.VoiceChangerModel import AudioInOut
2023-04-27 23:38:25 +09:00
from voice_changer.utils.VoiceChangerParams import VoiceChangerParams
from dataclasses import dataclass, asdict
import torch
2023-06-16 15:06:35 +09:00
import threading
2023-06-16 16:10:46 +09:00
from typing import Callable
from typing import Any
@dataclass
class GPUInfo:
    """Static description of a single CUDA device.

    Attributes:
        id: CUDA device index.
        name: Device name as reported by the CUDA runtime.
        memory: Total device memory.
    """

    id: int  # device index
    name: str  # device name
    memory: int  # total memory of the device
@dataclass
class VoiceChangerManagerSettings:
    """Manager-level settings, serialised with ``asdict`` when reporting state.

    Currently holds only a placeholder field.
    """

    dummy: int  # placeholder; carries no behavior
2022-12-31 16:02:53 +09:00
2023-01-29 09:42:45 +09:00
2023-06-16 15:06:35 +09:00
class VoiceChangerManager(ServerDeviceCallbacks):
    """Process-wide facade between the server audio device and the voice changer.

    Use ``get_instance`` to obtain the shared manager; it also attaches the
    ``VoiceChanger`` that most methods delegate to. Methods that need the
    voice changer report ``{"status": "ERROR", "msg": "no model loaded"}``
    (or a silent buffer, for audio) while none is attached.
    """

    _instance = None

    ############################
    # ServerDeviceCallbacks
    ############################
    def on_request(self, unpackedData: AudioInOut):
        """Convert one chunk of audio delivered by the server device."""
        return self.changeVoice(unpackedData)

    def emitTo(self, performance: list[float]):
        """Forward performance figures to the callback registered via ``setEmitTo``.

        Does nothing until a callback is registered: the server-device thread
        is started in ``__init__`` and may report before anyone listens.
        """
        if self.emitToFunc is not None:
            self.emitToFunc(performance)

    def get_processing_sampling_rate(self):
        """Return the processing sampling rate reported by the voice changer."""
        return self.voiceChanger.get_processing_sampling_rate()

    def setSamplingRate(self, sr: int):
        """Store ``sr`` as the voice changer's input sample rate."""
        self.voiceChanger.settings.inputSampleRate = sr

    ############################
    # VoiceChangerManager
    ############################
    def __init__(self, params: VoiceChangerParams):
        """Build the manager and start the server device on a background thread.

        Args:
            params: Server parameters; ``model_dir`` and ``sample_mode`` are
                read by this class.
        """
        self.params = params
        # Attached by get_instance(); stays None for a directly constructed manager.
        self.voiceChanger: VoiceChanger | None = None
        self.settings: VoiceChangerManagerSettings = VoiceChangerManagerSettings(dummy=0)
        # Registered later via setEmitTo(). Initialised before the device thread
        # starts so an early emitTo() call does not hit a missing attribute.
        self.emitToFunc: Callable[[Any], None] | None = None
        self.modelSlotManager = ModelSlotManager.get_instance(self.params.model_dir)
        # Collect static (hardware) information once at startup.
        self.gpus: list[dict] = self._get_gpuInfos()

        self.serverDevice = ServerDevice(self)
        # NOTE(review): this thread starts before get_instance() attaches
        # self.voiceChanger; get_processing_sampling_rate/setSamplingRate do not
        # guard against None — confirm ServerDevice.start does not call them
        # immediately.
        thread = threading.Thread(target=self.serverDevice.start)
        thread.start()

    def _get_gpuInfos(self) -> list[dict]:
        """Describe every visible CUDA device as a JSON-friendly dict.

        Each entry carries the ``GPUInfo`` fields ``id``, ``name`` and
        ``memory`` (``total_memory`` from the device properties). The list is
        empty when no CUDA device is visible.
        """
        return [
            asdict(
                GPUInfo(
                    id=index,
                    name=torch.cuda.get_device_name(index),
                    memory=torch.cuda.get_device_properties(index).total_memory,
                )
            )
            for index in range(torch.cuda.device_count())
        ]

    @classmethod
    def get_instance(cls, params: VoiceChangerParams):
        """Return the shared manager, creating it (and its VoiceChanger) on first use.

        ``params`` is only consulted on the first call.
        """
        if cls._instance is None:
            cls._instance = cls(params)
            cls._instance.voiceChanger = VoiceChanger(params)
        return cls._instance

    def loadModel(self, props: LoadModelParams):
        """Load a model into ``props.slot``.

        When ``props.params`` has a non-empty ``sampleId``, the sample is
        downloaded into the model directory and the slot list is reloaded.
        Otherwise the upload is delegated to the voice changer.

        Returns:
            ``{"status": "OK"}`` for a sample download; for an upload, the
            voice changer's result dict, returned untouched when its status is
            ``"NG"`` and with ``status`` set to ``"OK"`` otherwise.
        """
        paramDict = props.params
        if "sampleId" in paramDict and len(paramDict["sampleId"]) > 0:
            downloadSample(
                self.params.sample_mode,
                paramDict["sampleId"],
                self.params.model_dir,
                props.slot,
                {"useIndex": paramDict["rvcIndexDownload"]},
            )
            self.modelSlotManager.getAllSlotInfo(reload=True)
            return {"status": "OK"}

        print("[Voice Changer]: upload models........")
        info = self.voiceChanger.loadModel(props)
        # `info` is a dict: test the key, not an attribute (hasattr() never sees
        # dict keys), so an "NG" result is not masked as "OK".
        if "status" in info and info["status"] == "NG":
            return info
        info["status"] = "OK"
        return info

    def get_info(self):
        """Collect manager, server-device and voice-changer state into one dict.

        Returns ``{"status": "ERROR", "msg": "no model loaded"}`` when no voice
        changer is attached.
        """
        data = asdict(self.settings)
        data["gpus"] = self.gpus
        data["modelSlots"] = self.modelSlotManager.getAllSlotInfo(reload=True)
        data["sampleModels"] = getSampleInfos(self.params.sample_mode)
        data["status"] = "OK"
        data.update(self.serverDevice.get_info())
        if self.voiceChanger is None:
            return {"status": "ERROR", "msg": "no model loaded"}
        data.update(self.voiceChanger.get_info())
        return data

    def get_performance(self):
        """Return the voice changer's performance info, or an error dict if none is attached."""
        if self.voiceChanger is None:
            return {"status": "ERROR", "msg": "no model loaded"}
        return self.voiceChanger.get_performance()

    def update_settings(self, key: str, val: str | int | float):
        """Apply one setting to the server device and the voice changer.

        The server device is always updated. Returns the refreshed ``get_info()``
        dict, or an error dict when no voice changer is attached.
        """
        self.serverDevice.update_settings(key, val)
        if self.voiceChanger is None:
            return {"status": "ERROR", "msg": "no model loaded"}
        self.voiceChanger.update_settings(key, val)
        return self.get_info()

    def changeVoice(self, receivedData: AudioInOut):
        """Run ``receivedData`` through the voice changer.

        Without a voice changer, logs a hint and returns a one-sample int16
        zero buffer with an empty list, mirroring the normal ``(audio, perf)``
        pair.
        """
        if self.voiceChanger is None:
            print("Voice Change is not loaded. Did you load a correct model?")
            return np.zeros(1).astype(np.int16), []
        return self.voiceChanger.on_request(receivedData)

    def switchModelType(self, modelType: ModelType):
        """Delegate a model-type switch to the voice changer."""
        return self.voiceChanger.switchModelType(modelType)

    def getModelType(self):
        """Return the voice changer's current model type."""
        return self.voiceChanger.getModelType()

    def export2onnx(self):
        """Delegate ONNX export to the voice changer and return its result."""
        return self.voiceChanger.export2onnx()

    def merge_models(self, request: str):
        """Delegate a model-merge request, then return the refreshed ``get_info()``."""
        self.voiceChanger.merge_models(request)
        return self.get_info()

    def setEmitTo(self, emitTo: Callable[[Any], None]):
        """Register the callback that ``emitTo`` forwards performance data to."""
        self.emitToFunc = emitTo

    def update_model_default(self):
        """Delegate to the voice changer, then return the refreshed ``get_info()``."""
        self.voiceChanger.update_model_default()
        return self.get_info()

    def update_model_info(self, newData: str):
        """Pass ``newData`` to the voice changer, then return the refreshed ``get_info()``."""
        self.voiceChanger.update_model_info(newData)
        return self.get_info()

    def upload_model_assets(self, params: str):
        """Pass ``params`` to the voice changer, then return the refreshed ``get_info()``."""
        self.voiceChanger.upload_model_assets(params)
        return self.get_info()