fix infer faiss params

This commit is contained in:
nadare 2023-05-28 01:13:33 +09:00
parent 4d6d5a27cb
commit 04847306af

View File

@ -146,11 +146,16 @@ class Pipeline(object):
# D, I = self.index.search(npy, 1)
# npy = self.feature[I.squeeze()]
score, ix = self.index.search(npy, k=8)
weight = np.square(1 / score)
weight /= weight.sum(axis=1, keepdims=True)
npy = np.sum(self.big_npy[ix] * np.expand_dims(weight, axis=2), axis=1)
# TODO: kは調整できるようにする
k = 1
if k == 1:
_, ix = self.index.search(npy, 1)
npy = self.big_npy[ix.squeeze()]
else:
score, ix = self.index.search(npy, k=8)
weight = np.square(1 / score)
weight /= weight.sum(axis=1, keepdims=True)
npy = np.sum(self.big_npy[ix] * np.expand_dims(weight, axis=2), axis=1)
if self.isHalf is True:
npy = npy.astype("float16")