voice-changer/server/voice_changer/RVC/inferencer/WebUIInferencer.py

43 lines
1.3 KiB
Python
Raw Normal View History

2023-05-02 14:57:12 +03:00
import torch
2023-05-31 08:30:35 +03:00
from const import EnumInferenceTypes
2023-05-29 11:34:35 +03:00
from voice_changer.RVC.deviceManager.DeviceManager import DeviceManager
2023-05-02 14:57:12 +03:00
from voice_changer.RVC.inferencer.Inferencer import Inferencer
from .models import SynthesizerTrnMsNSFsid
class WebUIInferencer(Inferencer):
    """Inferencer backed by a ``SynthesizerTrnMsNSFsid`` PyTorch model.

    The model is built from a checkpoint file and registered with the base
    class under ``EnumInferenceTypes.pyTorchWebUI``.
    """

    def loadModel(self, file: str, gpu: int):
        """Load a checkpoint from *file* and prepare the model for inference.

        Args:
            file: Path of the checkpoint handed to ``torch.load``. The
                checkpoint must provide the keys ``"params"`` (keyword
                arguments for the model constructor) and ``"weight"``
                (the state dict).
            gpu: Device index forwarded to ``DeviceManager`` to select the
                target device and to query half-precision availability.

        Returns:
            This inferencer, with ``self.model`` populated.
        """
        # NOTE(review): the third argument is a literal True, independent of
        # the half-precision check below; presumably a half-precision flag --
        # confirm against Inferencer.setProps.
        self.setProps(EnumInferenceTypes.pyTorchWebUI, file, True, gpu)

        target_device = DeviceManager.get_instance().getDevice(gpu)
        use_half = DeviceManager.get_instance().halfPrecisionAvailable(gpu)

        # Deserialize onto the CPU first; the model is moved to the target
        # device only after its weights are in place.
        # NOTE(review): torch.load unpickles the file, so only checkpoints
        # from a trusted source should be passed here.
        checkpoint = torch.load(file, map_location="cpu")

        net = SynthesizerTrnMsNSFsid(**checkpoint["params"], is_half=use_half)
        # strict=False: key mismatches between checkpoint and model are
        # tolerated instead of raising.
        net.load_state_dict(checkpoint["weight"], strict=False)
        net.eval()
        net = net.to(target_device)

        self.model = net.half() if use_half else net
        return self

    def infer(
        self,
        feats: torch.Tensor,
        pitch_length: torch.Tensor,
        pitch: torch.Tensor,
        pitchf: torch.Tensor,
        sid: torch.Tensor,
        convert_length: int | None,
    ) -> torch.Tensor:
        """Run the loaded model and return its clipped float32 output.

        All tensor arguments are forwarded positionally, in the order given,
        to ``self.model.infer``; *convert_length* is forwarded as a keyword
        argument.

        Returns:
            Element ``[0][0, 0]`` of the model's result, cast to
            ``torch.float32`` and clipped to the range ``[-1.0, 1.0]``
            (presumably audio samples -- confirm against the model).
        """
        outputs = self.model.infer(
            feats,
            pitch_length,
            pitch,
            pitchf,
            sid,
            convert_length=convert_length,
        )
        # Take the first returned item, then its [0, 0] slice, and normalise
        # the dtype before clipping.
        selected = outputs[0][0, 0].to(dtype=torch.float32)
        return torch.clip(selected, -1.0, 1.0)