From 04847306afa42c9f521f286173f70cfa9018f6fe Mon Sep 17 00:00:00 2001 From: nadare <1na2da0re3@gmail.com> Date: Sun, 28 May 2023 01:13:33 +0900 Subject: [PATCH] fix infer faiss params --- server/voice_changer/RVC/pipeline/Pipeline.py | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/server/voice_changer/RVC/pipeline/Pipeline.py b/server/voice_changer/RVC/pipeline/Pipeline.py index 8c4364e5..14584699 100644 --- a/server/voice_changer/RVC/pipeline/Pipeline.py +++ b/server/voice_changer/RVC/pipeline/Pipeline.py @@ -146,11 +146,16 @@ class Pipeline(object): # D, I = self.index.search(npy, 1) # npy = self.feature[I.squeeze()] - score, ix = self.index.search(npy, k=8) - weight = np.square(1 / score) - weight /= weight.sum(axis=1, keepdims=True) - - npy = np.sum(self.big_npy[ix] * np.expand_dims(weight, axis=2), axis=1) + # TODO: kは調整できるようにする + k = 1 + if k == 1: + _, ix = self.index.search(npy, 1) + npy = self.big_npy[ix.squeeze()] + else: + score, ix = self.index.search(npy, k=8) + weight = np.square(1 / score) + weight /= weight.sum(axis=1, keepdims=True) + npy = np.sum(self.big_npy[ix] * np.expand_dims(weight, axis=2), axis=1) if self.isHalf is True: npy = npy.astype("float16")