mirror of
https://github.com/w-okada/voice-changer.git
synced 2025-01-27 21:41:41 +03:00
48 lines
1.4 KiB
Python
48 lines
1.4 KiB
Python
|
import torch
|
||
|
from torch import device
|
||
|
from const import EnumEmbedderTypes
|
||
|
from voice_changer.RVC.embedder.Embedder import Embedder
|
||
|
from fairseq import checkpoint_utils
|
||
|
|
||
|
|
||
|
class FairseqHubert(Embedder):
    """Embedder backed by a fairseq checkpoint, registered as the ``hubert`` type."""

    def loadModel(self, file: str, dev: device, isHalf: bool = True) -> Embedder:
        """Load the fairseq checkpoint at *file* and prepare it for inference.

        Args:
            file: Path of the checkpoint handed to fairseq.
            dev: Device the loaded model is moved to.
            isHalf: When true, the model is converted with ``.half()``.

        Returns:
            This instance, with ``model`` and ``embedderType`` set.
        """
        super().loadModel(file, dev, isHalf)

        # Only the model list is needed; the saved config and task are unused.
        models, _, _ = checkpoint_utils.load_model_ensemble_and_task(
            [file],
            suffix="",
        )
        model = models[0]
        model.eval()

        model = model.to(dev)
        if isHalf:
            model = model.half()

        self.model = model
        self.embedderType = EnumEmbedderTypes.hubert
        return self

    def extractFeatures(
        self, feats: torch.Tensor, embChannels: int = 256
    ) -> torch.Tensor:
        """Run the loaded model on *feats* and return the extracted features.

        With ``embChannels == 256`` the features are taken at output layer 9
        and passed through the model's ``final_proj``. For any other value no
        ``output_layer`` is passed (fairseq's default applies) and the raw
        features are returned without projection.

        Args:
            feats: Input tensor; it is moved to ``self.dev`` before the call.
            embChannels: Selects the extraction mode described above.

        Returns:
            The feature tensor produced by the model.
        """
        # NOTE(review): self.dev is presumably set by the base-class
        # loadModel() - confirm against Embedder.
        source = feats.to(self.dev)
        # All-False mask (no padded positions), allocated directly on the
        # target device instead of being built on the CPU and copied over.
        padding_mask = torch.zeros(
            feats.shape, dtype=torch.bool, device=self.dev
        )

        use_final_proj = embChannels == 256
        inputs = {
            "source": source,
            "padding_mask": padding_mask,
        }
        if use_final_proj:
            inputs["output_layer"] = 9  # layer 9

        with torch.no_grad():
            extracted = self.model.extract_features(**inputs)
            if use_final_proj:
                return self.model.final_proj(extracted[0])
            return extracted[0]
|