voice-changer/server/voice_changer/RVC/embedder/FairseqHubert.py
2023-05-24 16:56:23 +09:00

47 lines
1.4 KiB
Python

import torch
from torch import device
from const import EnumEmbedderTypes
from voice_changer.RVC.embedder.Embedder import Embedder
from fairseq import checkpoint_utils
class FairseqHubert(Embedder):
    """Embedder backed by a fairseq HuBERT checkpoint."""

    def loadModel(self, file: str, dev: device, isHalf: bool = True) -> Embedder:
        """Load the HuBERT checkpoint at *file*, move it to *dev*, and
        optionally cast it to fp16. Returns self for chaining."""
        super().setProps(EnumEmbedderTypes.hubert, file, dev, isHalf)

        # fairseq returns (models, cfg, task); only the first model is used.
        models, saved_cfg, task = checkpoint_utils.load_model_ensemble_and_task(
            [file],
            suffix="",
        )
        hubert = models[0]
        hubert.eval()
        hubert = hubert.to(dev)
        if isHalf:
            hubert = hubert.half()
        self.model = hubert
        return self

    def extractFeatures(
        self, feats: torch.Tensor, embOutputLayer=9, useFinalProj=True
    ) -> torch.Tensor:
        """Extract HuBERT features from *feats*.

        Original v1 applied final_proj to layer 9 output (-> 256 dims);
        original v2 takes layer 12 output without final_proj (-> 768 dims).
        """
        # All-False mask: no frames are treated as padding.
        mask = torch.zeros(feats.shape, dtype=torch.bool).to(self.dev)

        with torch.no_grad():
            logits = self.model.extract_features(
                source=feats.to(self.dev),
                padding_mask=mask,
                output_layer=embOutputLayer,  # 9 or 12
            )

        features = logits[0]
        if useFinalProj:
            features = self.model.final_proj(features)
        return features