voice-changer/server/voice_changer/RVC/inferencer/RVCInferencer.py

37 lines
1.0 KiB
Python
Raw Normal View History

2023-05-02 14:57:12 +03:00
import torch
from torch import device
from const import EnumInferenceTypes
from voice_changer.RVC.inferencer.Inferencer import Inferencer
from infer_pack.models import ( # type:ignore
SynthesizerTrnMs256NSFsid,
)
class RVCInferencer(Inferencer):
    """Inferencer that runs a PyTorch RVC ``SynthesizerTrnMs256NSFsid`` model."""

    def loadModel(self, file: str, dev: device, isHalf: bool = True):
        """Build the synthesizer from a checkpoint file and store it on ``self.model``.

        Args:
            file: Path of the checkpoint handed to ``torch.load``.
            dev: Device the model is moved to once its weights are loaded.
            isHalf: When true, the synthesizer is constructed with
                ``is_half=True`` and its parameters are cast to half precision.

        Returns:
            ``self``, so loading can be chained onto construction.
        """
        super().setProps(EnumInferenceTypes.pyTorchRVC, file, dev, isHalf)
        print("load inf", file)

        # Deserialize onto the CPU first; the model only moves to `dev` after
        # its weights are in place.
        # NOTE(review): torch.load unpickles the file, so only trusted
        # checkpoints should reach this method.
        checkpoint = torch.load(file, map_location="cpu")

        # "config" supplies the synthesizer's positional constructor
        # arguments; "weight" supplies its state dict.
        net = SynthesizerTrnMs256NSFsid(*checkpoint["config"], is_half=isHalf)
        net.eval()
        # strict=False: missing or unexpected keys in the state dict are
        # tolerated rather than raising.
        net.load_state_dict(checkpoint["weight"], strict=False)

        net = net.to(dev)
        self.model = net.half() if isHalf else net
        return self

    def infer(
        self,
        feats: torch.Tensor,
        pitch_length: torch.Tensor,
        pitch: torch.Tensor,
        pitchf: torch.Tensor,
        sid: torch.Tensor,
    ) -> torch.Tensor:
        """Run the model stored by ``loadModel`` and return its output tensor.

        The five inputs are forwarded positionally, in the order received, to
        ``self.model.infer``.
        NOTE(review): expected shapes/dtypes are defined by
        ``SynthesizerTrnMs256NSFsid.infer`` and are not visible here.
        """
        model_inputs = (feats, pitch_length, pitch, pitchf, sid)
        return self.model.infer(*model_inputs)