update contentvec layer9

This commit is contained in:
wataru 2023-04-03 17:33:28 +09:00
parent 5ec945faac
commit 7f40fcfd47

View File

@ -226,9 +226,14 @@ class SoVitsSvc40:
c = torch.from_numpy(np.array(c)).squeeze(0).transpose(1, 2) c = torch.from_numpy(np.array(c)).squeeze(0).transpose(1, 2)
# print("onnx hubert:", self.hubert_onnx.get_providers()) # print("onnx hubert:", self.hubert_onnx.get_providers())
else: else:
self.hubert_model = self.hubert_model.to(dev) if self.hps.model.ssl_dim == 768:
wav16k_tensor = wav16k_tensor.to(dev) self.hubert_model = self.hubert_model.to(dev)
c = utils.get_hubert_content(self.hubert_model, wav_16k_tensor=wav16k_tensor) wav16k_tensor = wav16k_tensor.to(dev)
c = get_hubert_content_layer9(self.hubert_model, wav_16k_tensor=wav16k_tensor)
else:
self.hubert_model = self.hubert_model.to(dev)
wav16k_tensor = wav16k_tensor.to(dev)
c = utils.get_hubert_content(self.hubert_model, wav_16k_tensor=wav16k_tensor)
uv = uv.to(dev) uv = uv.to(dev)
f0 = f0.to(dev) f0 = f0.to(dev)
@ -384,3 +389,21 @@ def compute_f0_harvest(wav_numpy, p_len=None, sampling_rate=44100, hop_length=51
for index, pitch in enumerate(f0): for index, pitch in enumerate(f0):
f0[index] = round(pitch, 1) f0[index] = round(pitch, 1)
return resize_f0(f0, p_len) return resize_f0(f0, p_len)
def get_hubert_content_layer9(hmodel, wav_16k_tensor):
feats = wav_16k_tensor
if feats.dim() == 2: # double channels
feats = feats.mean(-1)
assert feats.dim() == 1, feats.dim()
feats = feats.view(1, -1)
padding_mask = torch.BoolTensor(feats.shape).fill_(False)
inputs = {
"source": feats.to(wav_16k_tensor.device),
"padding_mask": padding_mask.to(wav_16k_tensor.device),
"output_layer": 9, # layer 9
}
with torch.no_grad():
logits = hmodel.extract_features(**inputs)
return logits[0].transpose(1, 2)