fix: variable name has been changed

This commit is contained in:
kremnik 2024-09-30 19:55:13 +03:00
parent 3bb317ac63
commit 8adca77ca6

View File

@ -505,7 +505,7 @@ def find_batched(
""" """
embeddings_list = [] embeddings_list = []
valid_mask = [] valid_mask = []
other_keys = set() metadata = set()
for item in representations: for item in representations:
emb = item.get('embedding') emb = item.get('embedding')
@ -516,18 +516,18 @@ def find_batched(
embeddings_list.append(np.zeros_like(representations[0]['embedding'])) embeddings_list.append(np.zeros_like(representations[0]['embedding']))
valid_mask.append(False) valid_mask.append(False)
other_keys.update(item.keys()) metadata.update(item.keys())
# remove embedding key from other keys # remove embedding key from other keys
other_keys.discard('embedding') metadata.discard('embedding')
other_keys = list(other_keys) metadata = list(metadata)
embeddings = np.array(embeddings_list) # (N, D) embeddings = np.array(embeddings_list) # (N, D)
valid_mask = np.array(valid_mask) # (N,) valid_mask = np.array(valid_mask) # (N,)
data = { data = {
key: np.array([item.get(key, None) for item in representations]) key: np.array([item.get(key, None) for item in representations])
for key in other_keys for key in metadata
} }
target_embeddings = [] target_embeddings = []