update contentvec layer9

This commit is contained in:
wataru 2023-04-03 17:33:28 +09:00
parent 5ec945faac
commit 7f40fcfd47

View File

@ -225,6 +225,11 @@ class SoVitsSvc40:
})
c = torch.from_numpy(np.array(c)).squeeze(0).transpose(1, 2)
# print("onnx hubert:", self.hubert_onnx.get_providers())
else:
if self.hps.model.ssl_dim == 768:
self.hubert_model = self.hubert_model.to(dev)
wav16k_tensor = wav16k_tensor.to(dev)
c = get_hubert_content_layer9(self.hubert_model, wav_16k_tensor=wav16k_tensor)
else:
self.hubert_model = self.hubert_model.to(dev)
wav16k_tensor = wav16k_tensor.to(dev)
@ -384,3 +389,21 @@ def compute_f0_harvest(wav_numpy, p_len=None, sampling_rate=44100, hop_length=51
for index, pitch in enumerate(f0):
f0[index] = round(pitch, 1)
return resize_f0(f0, p_len)
def get_hubert_content_layer9(hmodel, wav_16k_tensor):
    """Extract layer-9 features from a HuBERT-style model for one waveform.

    Args:
        hmodel: Object exposing ``extract_features(source=..., padding_mask=...,
            output_layer=...)`` whose return value is indexable; only element
            ``[0]`` is used. Presumably a fairseq HuBERT/ContentVec model --
            confirm against the loader.
        wav_16k_tensor: 1-D waveform tensor, or a 2-D tensor that is averaged
            over its last axis before use (assumes a ``(samples, channels)``
            layout -- TODO confirm with callers). The name suggests 16 kHz
            audio; this function does not check the sample rate.

    Returns:
        ``extract_features(...)[0]`` with axes 1 and 2 swapped. Presumably
        ``(1, feature_dim, frames)`` -- verify against the model's output.

    Raises:
        AssertionError: If the input is not 1-D after the optional down-mix.
    """
    output_layer = 9  # transformer layer whose activations are requested

    feats = wav_16k_tensor
    if feats.dim() == 2:  # double channels
        feats = feats.mean(-1)
    assert feats.dim() == 1, feats.dim()

    # Add the leading batch axis expected by the model: (N,) -> (1, N).
    feats = feats.unsqueeze(0)

    # All-False mask (nothing is padding), allocated directly on the input's
    # device instead of being built on the CPU and copied over.
    padding_mask = torch.zeros_like(feats, dtype=torch.bool)

    with torch.no_grad():
        logits = hmodel.extract_features(
            source=feats,
            padding_mask=padding_mask,
            output_layer=output_layer,
        )
    return logits[0].transpose(1, 2)