fix: variable name has been changed

This commit is contained in:
kremnik 2024-09-30 19:55:13 +03:00
parent 3bb317ac63
commit 8adca77ca6

View File

@ -505,7 +505,7 @@ def find_batched(
"""
embeddings_list = []
valid_mask = []
other_keys = set()
metadata = set()
for item in representations:
emb = item.get('embedding')
@ -516,18 +516,18 @@ def find_batched(
embeddings_list.append(np.zeros_like(representations[0]['embedding']))
valid_mask.append(False)
other_keys.update(item.keys())
metadata.update(item.keys())
# remove embedding key from other keys
other_keys.discard('embedding')
other_keys = list(other_keys)
metadata.discard('embedding')
metadata = list(metadata)
embeddings = np.array(embeddings_list) # (N, D)
valid_mask = np.array(valid_mask) # (N,)
data = {
key: np.array([item.get(key, None) for item in representations])
for key in other_keys
for key in metadata
}
target_embeddings = []