mirror of
https://github.com/w-okada/voice-changer.git
synced 2025-01-24 22:15:02 +03:00
47 lines
1.4 KiB
Python
47 lines
1.4 KiB
Python
import torch
|
|
from torch import device
|
|
from const import EnumEmbedderTypes
|
|
from voice_changer.RVC.embedder.Embedder import Embedder
|
|
from fairseq import checkpoint_utils
|
|
|
|
|
|
class FairseqHubert(Embedder):
    """Embedder backed by a fairseq HuBERT checkpoint.

    The model is loaded through ``fairseq.checkpoint_utils`` and features are
    produced by the model's ``extract_features`` method.
    """

    def loadModel(self, file: str, dev: device, isHalf: bool = True) -> Embedder:
        """Load the fairseq checkpoint at ``file`` and prepare it for inference.

        Args:
            file: Path to the fairseq checkpoint to load.
            dev: Device the model is moved to.
            isHalf: When True, the model is converted to half precision.

        Returns:
            ``self``, with the loaded model stored in ``self.model``.
        """
        super().setProps(EnumEmbedderTypes.hubert, file, dev, isHalf)

        # The loader returns (models, cfg, task); only the first model of the
        # ensemble is used, so the config and task are discarded.
        models, _saved_cfg, _task = checkpoint_utils.load_model_ensemble_and_task(
            [file],
            suffix="",
        )
        model = models[0]
        model.eval()  # inference mode (e.g. disables dropout)

        model = model.to(dev)
        if isHalf:
            model = model.half()

        self.model = model
        return self

    def extractFeatures(self, feats: torch.Tensor, embChannels=256) -> torch.Tensor:
        """Run ``feats`` through the loaded model and return its features.

        Args:
            feats: Input tensor passed to the model as ``source``.
                NOTE(review): presumably a (batch, samples) waveform whose
                dtype already matches the model (half when ``isHalf``) —
                confirm against callers.
            embChannels: When 256, features are taken from layer 9 and passed
                through ``model.final_proj``. For any other value, the default
                ``extract_features`` output (no ``output_layer``) is returned
                unprojected.

        Returns:
            The first element of ``extract_features``' result, projected by
            ``final_proj`` when ``embChannels == 256``.
        """
        source = feats.to(self.dev)
        # All-False mask: no position is treated as padding. Allocated directly
        # on the target device instead of building it on the host and copying.
        padding_mask = torch.zeros(source.shape, dtype=torch.bool, device=self.dev)

        inputs = {
            "source": source,
            "padding_mask": padding_mask,
        }
        useProjection = embChannels == 256
        if useProjection:
            inputs["output_layer"] = 9  # take the output of layer 9

        with torch.no_grad():
            logits = self.model.extract_features(**inputs)
            if useProjection:
                return self.model.final_proj(logits[0])
            return logits[0]
|