mps not support float64

This commit is contained in:
wataru 2023-06-04 17:19:59 +09:00
parent 26a59857b3
commit 3f4e68294a

View File

@ -166,7 +166,7 @@ class Pipeline(object):
npy = np.sum(self.big_npy[ix] * np.expand_dims(weight, axis=2), axis=1) npy = np.sum(self.big_npy[ix] * np.expand_dims(weight, axis=2), axis=1)
# recover silient font # recover silient font
npy = np.concatenate([np.zeros([npyOffset, npy.shape[1]]), npy]) npy = np.concatenate([np.zeros([npyOffset, npy.shape[1]]).astype("float32"), npy])
if self.isHalf is True: if self.isHalf is True:
npy = npy.astype("float16") npy = npy.astype("float16")