voice-changer/server/voice_changer/RVC/embedder/FairseqHubert.py

47 lines
1.4 KiB
Python
Raw Normal View History

2023-05-02 06:11:00 +03:00
import torch
from torch import device
from const import EnumEmbedderTypes
from voice_changer.RVC.embedder.Embedder import Embedder
from fairseq import checkpoint_utils
class FairseqHubert(Embedder):
    """HuBERT feature extractor backed by a fairseq checkpoint."""

    def loadModel(self, file: str, dev: device, isHalf: bool = True) -> Embedder:
        """Load a fairseq HuBERT checkpoint and prepare it for inference.

        Args:
            file: Path to the fairseq checkpoint to load.
            dev: Device the model is moved to.
            isHalf: If True, the model is converted to half precision.

        Returns:
            ``self``, so the call can be chained.
        """
        # Hand type/file/device/precision to the base class. extractFeatures
        # later reads ``self.dev`` (presumably populated by setProps — confirm
        # in Embedder).
        super().setProps(EnumEmbedderTypes.hubert, file, dev, isHalf)

        # Only the model list is used; the saved config and task are discarded.
        models, _saved_cfg, _task = checkpoint_utils.load_model_ensemble_and_task(
            [file],
            suffix="",
        )
        model = models[0]
        model.eval()  # put modules into evaluation (inference) mode
        model = model.to(dev)
        if isHalf:
            model = model.half()
        self.model = model
        return self

    def extractFeatures(
        self,
        feats: torch.Tensor,
        embOutputLayer: int = 9,
        useFinalProj: bool = True,
    ) -> torch.Tensor:
        """Run the HuBERT encoder on ``feats`` and return one layer's features.

        Args:
            feats: Tensor passed to the model as ``source`` after being moved
                to ``self.dev``. (Presumably a (batch, samples) waveform —
                TODO confirm against callers.)
            embOutputLayer: Layer whose output is requested from the model
                (9 or 12 in practice).
            useFinalProj: If True, the model's ``final_proj`` is applied to the
                selected layer's output.

        Returns:
            The extracted feature tensor, computed without autograd tracking.
        """
        # Original v1 applied final_proj to the layer-9 output (-> 256 dims).
        # Original v2 does NOT apply final_proj to the layer-12 output (-> 768 dims).

        # All-False mask (nothing is treated as padding) with the same shape as
        # feats, allocated directly on the target device in one step.
        padding_mask = torch.zeros(feats.shape, dtype=torch.bool, device=self.dev)
        inputs = {
            "source": feats.to(self.dev),
            "padding_mask": padding_mask,
            "output_layer": embOutputLayer,  # 9 or 12
        }

        # The projection is kept inside no_grad too: this is pure inference, and
        # running final_proj under autograd would needlessly record a graph for
        # its parameters and return a grad-tracking tensor.
        with torch.no_grad():
            # extract_features returns a tuple; element 0 holds the layer features.
            outputs = self.model.extract_features(**inputs)
            features = outputs[0]
            if useFinalProj:
                features = self.model.final_proj(features)
        return features