bugfix: halfprecision

This commit is contained in:
wataru 2023-06-03 22:17:15 +09:00
parent dc7f9a0e15
commit 26a59857b3

View File

@ -165,11 +165,10 @@ class Pipeline(object):
weight /= weight.sum(axis=1, keepdims=True)
npy = np.sum(self.big_npy[ix] * np.expand_dims(weight, axis=2), axis=1)
if self.isHalf is True:
npy = npy.astype("float16")
# recover silient font
npy = np.concatenate([np.zeros([npyOffset, npy.shape[1]]), npy])
if self.isHalf is True:
npy = npy.astype("float16")
feats = (
torch.from_numpy(npy).unsqueeze(0).to(self.device) * index_rate
@ -225,6 +224,7 @@ class Pipeline(object):
).data.to(dtype=torch.int16)
except RuntimeError as e:
if "HALF" in e.__str__().upper():
print("11", e)
raise HalfPrecisionChangingException()
else:
raise e