voice-changer/server/voice_changer/RVC/embedder/FairseqHubert.py

import torch
from torch import device

from const import EnumEmbedderTypes
from voice_changer.RVC.embedder.Embedder import Embedder
from fairseq import checkpoint_utils


class FairseqHubert(Embedder):
    def loadModel(self, file: str, dev: device, isHalf: bool = True) -> Embedder:
        super().setProps(EnumEmbedderTypes.hubert, file, dev, isHalf)
        # Load the HuBERT checkpoint via fairseq; only the first model of the
        # ensemble is used.
        models, saved_cfg, task = checkpoint_utils.load_model_ensemble_and_task(
            [file],
            suffix="",
        )
        model = models[0]
        model.eval()

        # Move to the target device and optionally switch to half precision.
        model = model.to(dev)
        if isHalf:
            model = model.half()

        self.model = model
        return self

    def extractFeatures(self, feats: torch.Tensor, embChannels=256) -> torch.Tensor:
        # No frames are masked out: the padding mask is all False.
        padding_mask = torch.BoolTensor(feats.shape).to(self.dev).fill_(False)

        if embChannels == 256:
            # 256-channel embeddings: take the intermediate output of layer 9
            # and project it with final_proj below.
            inputs = {
                "source": feats.to(self.dev),
                "padding_mask": padding_mask,
                "output_layer": 9,  # layer 9
            }
        else:
            # Other channel counts: no explicit output layer is requested, and
            # the features are returned without the final projection.
            inputs = {
                "source": feats.to(self.dev),
                "padding_mask": padding_mask,
            }

        with torch.no_grad():
            logits = self.model.extract_features(**inputs)
            if embChannels == 256:
                feats = self.model.final_proj(logits[0])
            else:
                feats = logits[0]
        return feats
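
# --- Usage sketch (added for illustration; not part of the original file) ---
# A minimal sketch of how this embedder could be driven, assuming the Embedder
# base class needs no constructor arguments and that setProps stores the device
# on self.dev. The checkpoint path "hubert_base.pt", the CPU device, and the
# 16 kHz silent waveform below are illustrative assumptions, not values taken
# from this repository.
#
#   import torch
#   from voice_changer.RVC.embedder.FairseqHubert import FairseqHubert
#
#   embedder = FairseqHubert().loadModel("hubert_base.pt", torch.device("cpu"), isHalf=False)
#   audio = torch.zeros(1, 16000)            # 1 s of silence at 16 kHz, shape (batch, samples)
#   feats = embedder.extractFeatures(audio)  # roughly (1, frames, 256) with default embChannels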