mirror of
https://github.com/w-okada/voice-changer.git
synced 2025-02-09 03:37:51 +03:00
commit
1cccd90546
@ -146,11 +146,16 @@ class Pipeline(object):
|
|||||||
# D, I = self.index.search(npy, 1)
|
# D, I = self.index.search(npy, 1)
|
||||||
# npy = self.feature[I.squeeze()]
|
# npy = self.feature[I.squeeze()]
|
||||||
|
|
||||||
score, ix = self.index.search(npy, k=8)
|
# TODO: kは調整できるようにする
|
||||||
weight = np.square(1 / score)
|
k = 1
|
||||||
weight /= weight.sum(axis=1, keepdims=True)
|
if k == 1:
|
||||||
|
_, ix = self.index.search(npy, 1)
|
||||||
npy = np.sum(self.big_npy[ix] * np.expand_dims(weight, axis=2), axis=1)
|
npy = self.big_npy[ix.squeeze()]
|
||||||
|
else:
|
||||||
|
score, ix = self.index.search(npy, k=8)
|
||||||
|
weight = np.square(1 / score)
|
||||||
|
weight /= weight.sum(axis=1, keepdims=True)
|
||||||
|
npy = np.sum(self.big_npy[ix] * np.expand_dims(weight, axis=2), axis=1)
|
||||||
|
|
||||||
if self.isHalf is True:
|
if self.isHalf is True:
|
||||||
npy = npy.astype("float16")
|
npy = npy.astype("float16")
|
||||||
|
Loading…
Reference in New Issue
Block a user