Skip to content

Commit 1e780d5

Browse files
[python/knowpro] Follow-up to last night (#1207)
Streamline fuzzyindex.py and textlocindex.py. Improve query pretty-printing.
1 parent ff4aafa commit 1e780d5

File tree

8 files changed

+161
-264
lines changed

8 files changed

+161
-264
lines changed

python/ta/test/test_demo.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -117,7 +117,7 @@ async def main(filename: str):
117117
assert isinstance(ser2, dict), f"ser2 is not dict but {type(ser2)!r}"
118118
assert len(ser2) > 0, f"ser2 is empty {ser2!r}"
119119
assert "semanticRefs" in ser2, f"'semantic_refs' is not a key in {ser2.keys()!r}"
120-
assert ser1 == ser2, f"ser1 != ser2"
120+
assert str(ser1) == str(ser2), f"ser1 != ser2"
121121

122122

123123
if __name__ == "__main__":

python/ta/test/test_vectorbase.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77
from typeagent.aitools.vectorbase import (
88
VectorBase,
99
TextEmbeddingIndexSettings,
10-
ScoredOrdinal,
1110
)
1211
from typeagent.aitools.embeddings import AsyncEmbeddingModel, NormalizedEmbedding
1312

@@ -106,7 +105,7 @@ async def test_fuzzy_lookup(
106105

107106
results = await vector_base.fuzzy_lookup("word1", max_hits=2)
108107
assert len(results) == 2
109-
assert results[0].ordinal == 0
108+
assert results[0].item == 0
110109
assert results[0].score > 0.9 # High similarity score for the same word
111110

112111

python/ta/typeagent/aitools/vectorbase.py

Lines changed: 58 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,21 @@
11
# Copyright (c) Microsoft Corporation.
22
# Licensed under the MIT License.
33

4+
from collections.abc import Iterable
45
from dataclasses import dataclass
5-
from typing import NamedTuple
6+
from typing import Callable
67

78
import numpy as np
89

910
from .embeddings import AsyncEmbeddingModel, NormalizedEmbedding, NormalizedEmbeddings
1011

1112

13+
@dataclass
14+
class Scored:
15+
item: int
16+
score: float
17+
18+
1219
@dataclass
1320
class TextEmbeddingIndexSettings:
1421
embedding_model: AsyncEmbeddingModel
@@ -35,11 +42,6 @@ def __init__(
3542
self.max_matches = max_matches
3643

3744

38-
class ScoredOrdinal(NamedTuple):
39-
ordinal: int
40-
score: float
41-
42-
4345
class VectorBase:
4446
_vectors: NormalizedEmbeddings
4547

@@ -79,6 +81,11 @@ def add_embedding(self, key: str | None, embedding: NormalizedEmbedding) -> None
7981
if key is not None:
8082
self._model.add_embedding(key, embedding)
8183

84+
def add_embeddings(self, embeddings: NormalizedEmbeddings) -> None:
85+
assert embeddings.ndim == 2
86+
assert embeddings.shape[1] == self._embedding_size
87+
self._vectors = np.concatenate((self._vectors, embeddings), axis=0)
88+
8289
async def add_key(self, key: str, cache: bool = True) -> None:
8390
embeddings = (await self.get_embedding(key, cache=cache)).reshape(
8491
1, -1
@@ -89,29 +96,64 @@ async def add_keys(self, keys: list[str], cache: bool = True) -> None:
8996
embeddings = await self.get_embeddings(keys, cache=cache)
9097
self._vectors = np.concatenate((self._vectors, embeddings), axis=0)
9198

92-
async def fuzzy_lookup(
93-
self, key: str, max_hits: int | None = None, min_score: float | None = None
94-
) -> list[ScoredOrdinal]:
99+
def fuzzy_lookup_embedding(
100+
self,
101+
embedding: NormalizedEmbedding,
102+
max_hits: int | None = None,
103+
min_score: float | None = None,
104+
predicate: Callable[[int], bool] | None = None,
105+
) -> list[Scored]:
95106
if max_hits is None:
96107
max_hits = 10
97108
if min_score is None:
98109
min_score = 0.0
99-
embedding = await self.get_embedding(key)
100-
scores = np.dot(self._vectors, embedding) # This does most of the work
110+
# This line does most of the work:
111+
scores: Iterable[float] = np.dot(self._vectors, embedding)
101112
scored_ordinals = [
102-
ScoredOrdinal(i, float(score))
113+
Scored(i, score)
103114
for i, score in enumerate(scores)
104-
if score >= min_score
115+
if score >= min_score and (predicate is None or predicate(i))
105116
]
106117
scored_ordinals.sort(key=lambda x: x.score, reverse=True)
107118
return scored_ordinals[:max_hits]
108119

120+
# TODO: Make this and fuzzy_lookup_embedding() more similar.
121+
def fuzzy_lookup_embedding_in_subset(
122+
self,
123+
embedding: NormalizedEmbedding,
124+
ordinals_of_subset: list[int],
125+
max_hits: int | None = None,
126+
min_score: float | None = None,
127+
) -> list[Scored]:
128+
return self.fuzzy_lookup_embedding(
129+
embedding, max_hits, min_score, lambda i: i in ordinals_of_subset
130+
)
131+
132+
async def fuzzy_lookup(
133+
self,
134+
key: str,
135+
max_hits: int | None = None,
136+
min_score: float | None = None,
137+
predicate: Callable[[int], bool] | None = None,
138+
) -> list[Scored]:
139+
embedding = await self.get_embedding(key)
140+
return self.fuzzy_lookup_embedding(
141+
embedding, max_hits=max_hits, min_score=min_score, predicate=predicate
142+
)
143+
109144
def clear(self) -> None:
110145
self._vectors = np.array([], dtype=np.float32)
111146
self._vectors.shape = (0, self._embedding_size)
112147

113-
def serialize_embedding_at(self, ordinal: int) -> NormalizedEmbedding | None:
114-
return self._vectors[ordinal] if 0 <= ordinal < len(self._vectors) else None
148+
def get_embedding_at(self, pos: int) -> NormalizedEmbedding:
149+
if 0 <= pos < len(self._vectors):
150+
return self._vectors[pos]
151+
raise IndexError(
152+
f"Index {pos} out of bounds for embedding index of size {len(self)}"
153+
)
154+
155+
def serialize_embedding_at(self, pos: int) -> NormalizedEmbedding | None:
156+
return self._vectors[pos] if 0 <= pos < len(self._vectors) else None
115157

116158
def serialize(self) -> NormalizedEmbeddings:
117159
assert self._vectors.shape == (len(self._vectors), self._embedding_size)
@@ -181,7 +223,7 @@ def debugv(heading: str):
181223
log("\nFuzzy lookups:")
182224
for word in words + ["pancakes", "hello world", "book", "author"]:
183225
neighbors = await v.fuzzy_lookup(word, max_hits=3)
184-
log(f"{word}:", [(nb.ordinal, nb.score) for nb in neighbors])
226+
log(f"{word}:", [(nb.item, nb.score) for nb in neighbors])
185227

186228

187229
if __name__ == "__main__":

python/ta/typeagent/demo/ui.py

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,13 @@
33

44
import asyncio
55
import io
6-
from pprint import pprint
76
import readline
87
import shutil
98
import sys
109
import traceback
1110
from typing import Any
1211

12+
from black import format_str, FileMode
1313
import typechat
1414

1515
from ..aitools.auth import load_dotenv
@@ -35,6 +35,15 @@
3535
cap = min # More readable name for capping a value at some limit
3636

3737

38+
def pretty_print(obj: object) -> None:
39+
"""Pretty-print an object using black.
40+
41+
Only works if the repr() is a valid Python expression.
42+
"""
43+
line_width = cap(200, shutil.get_terminal_size().columns)
44+
print(format_str(repr(obj), mode=FileMode(line_length=line_width)))
45+
46+
3847
def main() -> None:
3948
load_dotenv()
4049
translator = create_translator()
@@ -124,8 +133,6 @@ async def process_query(
124133
conversation: IConversation[IMessage, Any],
125134
translator: typechat.TypeChatJsonTranslator[SearchQuery],
126135
):
127-
line_width = cap(200, shutil.get_terminal_size().columns)
128-
129136
# Gradually turn the query text into something we can use to search.
130137

131138
# TODO: # 0. Recognize @-commands like "@search" and handle them specially.
@@ -138,7 +145,7 @@ async def process_query(
138145
if search_query is None:
139146
print("Failed to translate command to search terms.")
140147
return
141-
pprint(search_query, width=line_width)
148+
pretty_print(search_query)
142149
print()
143150

144151
# 2. Translate the search query into something directly usable as a query.
@@ -149,7 +156,7 @@ async def process_query(
149156
return
150157
for i, query_expr in enumerate(query_exprs):
151158
print(f"---------- {i} ----------")
152-
pprint(query_expr, width=line_width)
159+
pretty_print(query_expr)
153160
print()
154161

155162
# 3. Search!

python/ta/typeagent/knowpro/convthreads.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ async def lookup_thread(
3737
)
3838
return [
3939
ScoredThreadOrdinal(
40-
match.ordinal,
40+
match.item,
4141
match.score,
4242
)
4343
for match in matches

0 commit comments

Comments (0)