diff --git a/server/voice_changer/RVC/pipeline/Pipeline.py b/server/voice_changer/RVC/pipeline/Pipeline.py index 8c4364e5..14584699 100644 --- a/server/voice_changer/RVC/pipeline/Pipeline.py +++ b/server/voice_changer/RVC/pipeline/Pipeline.py @@ -146,11 +146,16 @@ class Pipeline(object): # D, I = self.index.search(npy, 1) # npy = self.feature[I.squeeze()] - score, ix = self.index.search(npy, k=8) - weight = np.square(1 / score) - weight /= weight.sum(axis=1, keepdims=True) - - npy = np.sum(self.big_npy[ix] * np.expand_dims(weight, axis=2), axis=1) + # TODO: kは調整できるようにする + k = 1 + if k == 1: + _, ix = self.index.search(npy, 1) + npy = self.big_npy[ix.squeeze()] + else: + score, ix = self.index.search(npy, k=8) + weight = np.square(1 / score) + weight /= weight.sum(axis=1, keepdims=True) + npy = np.sum(self.big_npy[ix] * np.expand_dims(weight, axis=2), axis=1) if self.isHalf is True: npy = npy.astype("float16")