2023-05-02 14:57:12 +03:00
|
|
|
import torch
|
2023-05-31 08:30:35 +03:00
|
|
|
from const import EnumInferenceTypes
|
2023-05-29 11:34:35 +03:00
|
|
|
from voice_changer.RVC.deviceManager.DeviceManager import DeviceManager
|
2023-05-02 14:57:12 +03:00
|
|
|
|
|
|
|
from voice_changer.RVC.inferencer.Inferencer import Inferencer
|
|
|
|
from .models import SynthesizerTrnMsNSFsid
|
|
|
|
|
|
|
|
|
|
|
|
class WebUIInferencer(Inferencer):
    """Inferencer backed by a PyTorch ``SynthesizerTrnMsNSFsid`` model."""

    def loadModel(self, file: str, gpu: int):
        """Load the checkpoint at *file* and prepare the model for *gpu*.

        The checkpoint is read onto the CPU, the network is built from its
        ``"params"`` entry and populated from its ``"weight"`` entry, then
        moved to the device that ``DeviceManager`` returns for *gpu*.

        Args:
            file: Path of the checkpoint handed to ``torch.load``.
            gpu: Device identifier forwarded to ``DeviceManager``.

        Returns:
            This inferencer, with ``self.model`` set to the loaded network.
        """
        # NOTE(review): the literal True passed here does not depend on the
        # half-precision flag computed below -- confirm that is intended.
        self.setProps(EnumInferenceTypes.pyTorchWebUI, file, True, gpu)

        device = DeviceManager.get_instance().getDevice(gpu)
        use_half = DeviceManager.get_instance().halfPrecisionAvailable(gpu)

        # NOTE(review): torch.load may unpickle arbitrary objects depending
        # on the installed torch version -- only load trusted checkpoints.
        checkpoint = torch.load(file, map_location="cpu")
        hyper_params = checkpoint["params"]
        net = SynthesizerTrnMsNSFsid(**hyper_params, is_half=use_half)

        net.eval()
        # strict=False: key mismatches between the checkpoint and the
        # network are not raised as errors.
        net.load_state_dict(checkpoint["weight"], strict=False)

        net = net.to(device)
        net = net.half() if use_half else net

        self.model = net
        return self

    def infer(
        self,
        feats: torch.Tensor,
        pitch_length: torch.Tensor,
        pitch: torch.Tensor,
        pitchf: torch.Tensor,
        sid: torch.Tensor,
    ) -> torch.Tensor:
        """Run the loaded model's ``infer`` on the given tensors.

        All arguments are forwarded positionally, in the order received, to
        ``self.model.infer``; its result is returned unchanged.
        """
        output = self.model.infer(feats, pitch_length, pitch, pitchf, sid)
        return output
|