forked from redis-performance/vector-db-benchmark
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathann_h5_reader.py
More file actions
58 lines (45 loc) · 1.88 KB
/
ann_h5_reader.py
File metadata and controls
58 lines (45 loc) · 1.88 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
import itertools
from typing import Iterator
import h5py
import numpy as np
from benchmark import DATASETS_DIR
from dataset_reader.base_reader import BaseReader, Query, Record
class AnnH5Reader(BaseReader):
    """Dataset reader backed by a single HDF5 file.

    The file is read through the keys ``train`` (records), ``test`` (query
    vectors) and ``neighbors`` (expected results per query). The key
    ``distances`` (expected scores per query) is optional.
    """

    def __init__(self, path, normalize: bool = False):
        """Remember where the HDF5 file lives and whether to normalize.

        Args:
            path: Location of the HDF5 file, passed as-is to ``h5py.File``.
            normalize: When true, every yielded vector is divided by its
                norm (as computed by ``np.linalg.norm``) before conversion
                to a list.
        """
        self.path = path
        self.normalize = normalize

    def read_queries(self) -> Iterator[Query]:
        """Yield one ``Query`` per entry of the ``test`` dataset.

        Each query carries the matching row of ``neighbors`` as
        ``expected_result`` and, when the file has a ``distances`` dataset,
        the matching row of it as ``expected_scores`` (otherwise ``None``).

        The file is opened read-only and stays open only while the
        generator is being consumed; it is closed once the generator is
        exhausted or closed.
        """
        with h5py.File(self.path, "r") as data:
            # Without stored distances, pair every query with None so the
            # zip below is still driven by `test` / `neighbors`.
            distances = (
                data["distances"] if "distances" in data else itertools.repeat(None)
            )
            for vector, expected_result, expected_scores in zip(
                data["test"], data["neighbors"], distances
            ):
                if self.normalize:
                    vector /= np.linalg.norm(vector)
                yield Query(
                    vector=vector.tolist(),
                    meta_conditions=None,
                    expected_result=expected_result.tolist(),
                    expected_scores=(
                        expected_scores.tolist()
                        if expected_scores is not None
                        else None
                    ),
                )

    def read_data(self, *args, **kwargs) -> Iterator[Record]:
        """Yield one ``Record`` per entry of the ``train`` dataset.

        Record ids are the zero-based positions of the vectors in ``train``;
        ``metadata`` is always ``None``. Extra positional and keyword
        arguments are accepted and ignored.

        The file is opened read-only and stays open only while the
        generator is being consumed; it is closed once the generator is
        exhausted or closed.
        """
        with h5py.File(self.path, "r") as data:
            for idx, vector in enumerate(data["train"]):
                if self.normalize:
                    vector /= np.linalg.norm(vector)
                yield Record(id=idx, vector=vector.tolist(), metadata=None)
if __name__ == "__main__":
    import os

    # Manual smoke check: print the first record and the first query of the
    # glove-100-angular dataset.
    #
    # The h5py file has 4 keys:
    #   `train`     - float vectors (num vectors 1183514)
    #   `test`      - float vectors (num vectors 10000)
    #   `neighbors` - int - indices of nearest neighbors for test (num items
    #                 10k, each item contains info about 100 nearest neighbors)
    #   `distances` - float - distances for nearest neighbors for test vectors
    path_parts = ("glove-100-angular", "glove-100-angular.hdf5")
    dataset_file = os.path.join(DATASETS_DIR, *path_parts)

    first_record = next(AnnH5Reader(dataset_file).read_data())
    print(first_record, end="\n\n")

    first_query = next(AnnH5Reader(dataset_file).read_queries())
    print(first_query)