 # Copyright (c) Microsoft Corporation.
 # Licensed under the MIT License.

+from collections.abc import Iterable
 from dataclasses import dataclass
-from typing import NamedTuple
+from typing import Callable

 import numpy as np

 from .embeddings import AsyncEmbeddingModel, NormalizedEmbedding, NormalizedEmbeddings


+@dataclass
+class Scored:
+    item: int
+    score: float
+
+
 @dataclass
 class TextEmbeddingIndexSettings:
     embedding_model: AsyncEmbeddingModel
@@ -35,11 +42,6 @@ def __init__(
         self.max_matches = max_matches


-class ScoredOrdinal(NamedTuple):
-    ordinal: int
-    score: float
-
-
 class VectorBase:
     _vectors: NormalizedEmbeddings

@@ -79,6 +81,11 @@ def add_embedding(self, key: str | None, embedding: NormalizedEmbedding) -> None
         if key is not None:
             self._model.add_embedding(key, embedding)

+    def add_embeddings(self, embeddings: NormalizedEmbeddings) -> None:
+        assert embeddings.ndim == 2
+        assert embeddings.shape[1] == self._embedding_size
+        self._vectors = np.concatenate((self._vectors, embeddings), axis=0)
+
     async def add_key(self, key: str, cache: bool = True) -> None:
         embeddings = (await self.get_embedding(key, cache=cache)).reshape(
             1, -1
@@ -89,29 +96,64 @@ async def add_keys(self, keys: list[str], cache: bool = True) -> None:
         embeddings = await self.get_embeddings(keys, cache=cache)
         self._vectors = np.concatenate((self._vectors, embeddings), axis=0)

-    async def fuzzy_lookup(
-        self, key: str, max_hits: int | None = None, min_score: float | None = None
-    ) -> list[ScoredOrdinal]:
+    def fuzzy_lookup_embedding(
+        self,
+        embedding: NormalizedEmbedding,
+        max_hits: int | None = None,
+        min_score: float | None = None,
+        predicate: Callable[[int], bool] | None = None,
+    ) -> list[Scored]:
         if max_hits is None:
             max_hits = 10
         if min_score is None:
             min_score = 0.0
-        embedding = await self.get_embedding(key)
-        scores = np.dot(self._vectors, embedding)  # This does most of the work
+        # This line does most of the work:
+        scores: Iterable[float] = np.dot(self._vectors, embedding)
         scored_ordinals = [
-            ScoredOrdinal(i, float(score))
+            Scored(i, score)
             for i, score in enumerate(scores)
-            if score >= min_score
+            if score >= min_score and (predicate is None or predicate(i))
         ]
         scored_ordinals.sort(key=lambda x: x.score, reverse=True)
         return scored_ordinals[:max_hits]

+    # TODO: Make this and fuzzy_lookup_embedding() more similar.
+    def fuzzy_lookup_embedding_in_subset(
+        self,
+        embedding: NormalizedEmbedding,
+        ordinals_of_subset: list[int],
+        max_hits: int | None = None,
+        min_score: float | None = None,
+    ) -> list[Scored]:
+        return self.fuzzy_lookup_embedding(
+            embedding, max_hits, min_score, lambda i: i in ordinals_of_subset
+        )
+
+    async def fuzzy_lookup(
+        self,
+        key: str,
+        max_hits: int | None = None,
+        min_score: float | None = None,
+        predicate: Callable[[int], bool] | None = None,
+    ) -> list[Scored]:
+        embedding = await self.get_embedding(key)
+        return self.fuzzy_lookup_embedding(
+            embedding, max_hits=max_hits, min_score=min_score, predicate=predicate
+        )
+
     def clear(self) -> None:
         self._vectors = np.array([], dtype=np.float32)
         self._vectors.shape = (0, self._embedding_size)

-    def serialize_embedding_at(self, ordinal: int) -> NormalizedEmbedding | None:
-        return self._vectors[ordinal] if 0 <= ordinal < len(self._vectors) else None
+    def get_embedding_at(self, pos: int) -> NormalizedEmbedding:
+        if 0 <= pos < len(self._vectors):
+            return self._vectors[pos]
+        raise IndexError(
+            f"Index {pos} out of bounds for embedding index of size {len(self)}"
+        )
+
+    def serialize_embedding_at(self, pos: int) -> NormalizedEmbedding | None:
+        return self._vectors[pos] if 0 <= pos < len(self._vectors) else None

     def serialize(self) -> NormalizedEmbeddings:
         assert self._vectors.shape == (len(self._vectors), self._embedding_size)
@@ -181,7 +223,7 @@ def debugv(heading: str):
     log("\nFuzzy lookups:")
     for word in words + ["pancakes", "hello world", "book", "author"]:
         neighbors = await v.fuzzy_lookup(word, max_hits=3)
-        log(f"{word}:", [(nb.ordinal, nb.score) for nb in neighbors])
+        log(f"{word}:", [(nb.item, nb.score) for nb in neighbors])


 if __name__ == "__main__":
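A minimal usage sketch of the new lookup API, assuming an already-populated VectorBase instance and an async context (construction and event-loop setup are not shown in this diff; the word "fruit" and the ordinals are illustrative):

    async def demo(v: VectorBase) -> None:
        # Plain key lookup: returns Scored(item, score) entries, best first.
        hits = await v.fuzzy_lookup("fruit", max_hits=3, min_score=0.5)
        print([(h.item, h.score) for h in hits])
        # Restrict the candidate set to specific row ordinals.
        emb = await v.get_embedding("fruit")
        subset_hits = v.fuzzy_lookup_embedding_in_subset(emb, [0, 2, 5], max_hits=3)
        print([(s.item, s.score) for s in subset_hits])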