2023-05-02 14:57:12 +03:00
|
|
|
from typing import Any, Protocol
|
|
|
|
|
|
|
|
import torch
|
|
|
|
from torch import device
|
|
|
|
|
|
|
|
from const import EnumInferenceTypes
|
2023-05-03 07:14:00 +03:00
|
|
|
import onnxruntime
|
2023-05-02 14:57:12 +03:00
|
|
|
|
|
|
|
|
|
|
|
class Inferencer(Protocol):
|
|
|
|
inferencerType: EnumInferenceTypes = EnumInferenceTypes.pyTorchRVC
|
|
|
|
file: str
|
|
|
|
isHalf: bool = True
|
|
|
|
dev: device
|
|
|
|
|
2023-05-03 07:14:00 +03:00
|
|
|
model: onnxruntime.InferenceSession | Any | None = None
|
2023-05-02 14:57:12 +03:00
|
|
|
|
|
|
|
def loadModel(self, file: str, dev: device, isHalf: bool = True):
|
|
|
|
...
|
|
|
|
|
|
|
|
def infer(
|
|
|
|
self,
|
|
|
|
feats: torch.Tensor,
|
|
|
|
pitch_length: torch.Tensor,
|
|
|
|
pitch: torch.Tensor | None,
|
|
|
|
pitchf: torch.Tensor | None,
|
|
|
|
sid: torch.Tensor,
|
|
|
|
) -> torch.Tensor:
|
|
|
|
...
|
|
|
|
|
|
|
|
def setProps(
|
|
|
|
self,
|
|
|
|
inferencerType: EnumInferenceTypes,
|
|
|
|
file: str,
|
|
|
|
dev: device,
|
|
|
|
isHalf: bool = True,
|
|
|
|
):
|
|
|
|
self.inferencerType = inferencerType
|
|
|
|
self.file = file
|
|
|
|
self.isHalf = isHalf
|
|
|
|
self.dev = dev
|
|
|
|
|
|
|
|
def setHalf(self, isHalf: bool):
|
|
|
|
self.isHalf = isHalf
|
|
|
|
if self.model is not None and isHalf:
|
|
|
|
self.model = self.model.half()
|
2023-05-03 07:14:00 +03:00
|
|
|
elif self.model is not None and isHalf is False:
|
|
|
|
self.model = self.model.float()
|
2023-05-02 14:57:12 +03:00
|
|
|
|
|
|
|
def setDevice(self, dev: device):
|
|
|
|
self.dev = dev
|
|
|
|
if self.model is not None:
|
|
|
|
self.model = self.model.to(self.dev)
|
|
|
|
return self
|