Merge pull request #261 from nadare881/master

fix infer faiss params
This commit is contained in:
w-okada 2023-05-28 11:30:57 +09:00 committed by GitHub
commit 1cccd90546
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -146,11 +146,16 @@ class Pipeline(object):
# D, I = self.index.search(npy, 1)
# npy = self.feature[I.squeeze()]
score, ix = self.index.search(npy, k=8)
weight = np.square(1 / score)
weight /= weight.sum(axis=1, keepdims=True)
npy = np.sum(self.big_npy[ix] * np.expand_dims(weight, axis=2), axis=1)
# TODO: kは調整できるようにする
k = 1
if k == 1:
_, ix = self.index.search(npy, 1)
npy = self.big_npy[ix.squeeze()]
else:
score, ix = self.index.search(npy, k=8)
weight = np.square(1 / score)
weight /= weight.sum(axis=1, keepdims=True)
npy = np.sum(self.big_npy[ix] * np.expand_dims(weight, axis=2), axis=1)
if self.isHalf is True:
npy = npy.astype("float16")