Skip to content

Commit 63d0ece

Browse files
committed
Tag scope
1 parent 8ab1854 commit 63d0ece

File tree

5 files changed

+76
-29
lines changed

5 files changed

+76
-29
lines changed

ts/packages/knowPro/src/accumulators.ts

+34-13
Original file line number | Diff line number | Diff line change
@@ -117,15 +117,11 @@ export class MatchAccumulator<T = any> {
117117
}
118118
}
119119

120-
public getMatches(): IterableIterator<Match<T>> {
121-
return this.matches.values();
122-
}
123-
124-
public *getMatchesWhere(
125-
predicate: (match: Match<T>) => boolean,
120+
public *getMatches(
121+
predicate?: (match: Match<T>) => boolean,
126122
): IterableIterator<Match<T>> {
127123
for (const match of this.matches.values()) {
128-
if (predicate(match)) {
124+
if (predicate === undefined || predicate(match)) {
129125
yield match;
130126
}
131127
}
@@ -165,7 +161,7 @@ export class MatchAccumulator<T = any> {
165161
minHitCount: number | undefined,
166162
): IterableIterator<Match<T>> {
167163
return minHitCount !== undefined && minHitCount > 0
168-
? this.getMatchesWhere((m) => m.hitCount >= minHitCount)
164+
? this.getMatches((m) => m.hitCount >= minHitCount)
169165
: this.matches.values();
170166
}
171167
}
@@ -235,6 +231,17 @@ export class SemanticRefAccumulator extends MatchAccumulator<SemanticRefIndex> {
235231
);
236232
}
237233

234+
public *getSemanticRefs(
235+
semanticRefs: SemanticRef[],
236+
predicate?: (semanticRef: SemanticRef) => boolean,
237+
) {
238+
for (const match of this.getMatches()) {
239+
const semanticRef = semanticRefs[match.value];
240+
if (predicate === undefined || predicate(semanticRef))
241+
yield semanticRef;
242+
}
243+
}
244+
238245
public groupMatchesByKnowledgeType(
239246
semanticRefs: SemanticRef[],
240247
): Map<KnowledgeType, SemanticRefAccumulator> {
@@ -252,6 +259,19 @@ export class SemanticRefAccumulator extends MatchAccumulator<SemanticRefIndex> {
252259
return groups;
253260
}
254261

262+
public selectInScope(
263+
semanticRefs: SemanticRef[],
264+
scope: TextRangeAccumulator,
265+
) {
266+
const accumulator = new SemanticRefAccumulator(this.queryTermMatches);
267+
for (const match of this.getMatches()) {
268+
if (scope.isInRange(semanticRefs[match.value].range)) {
269+
accumulator.setMatch(match);
270+
}
271+
}
272+
return accumulator;
273+
}
274+
255275
public toScoredSemanticRefs(): ScoredSemanticRef[] {
256276
return this.getSortedByScore(0).map((m) => {
257277
return {
@@ -348,13 +368,14 @@ export class QueryTermAccumulator {
348368

349369
export class TextRangeAccumulator {
350370
constructor(
351-
public rangesForMessage: Map<MessageIndex, TextRange[]> = new Map<
352-
MessageIndex,
353-
TextRange[]
354-
>(),
371+
private rangesForMessage = new Map<MessageIndex, TextRange[]>(),
355372
) {}
356373

357-
public addTextRange(textRange: TextRange) {
374+
public get size() {
375+
return this.rangesForMessage.size;
376+
}
377+
378+
public addRange(textRange: TextRange) {
358379
const messageIndex = textRange.start.messageIndex;
359380
let textRanges = this.rangesForMessage.get(messageIndex);
360381
if (textRanges === undefined) {

ts/packages/knowPro/src/query.ts

+22-7
Original file line number | Diff line number | Diff line change
@@ -19,6 +19,7 @@ import {
1919
MatchAccumulator,
2020
QueryTermAccumulator,
2121
SemanticRefAccumulator,
22+
TextRangeAccumulator,
2223
} from "./accumulators.js";
2324
import { collections, dateTime } from "typeagent";
2425

@@ -312,7 +313,7 @@ export class WhereSemanticRefExpr
312313
accumulator.queryTermMatches,
313314
);
314315
filtered.setMatches(
315-
accumulator.getMatchesWhere((match) =>
316+
accumulator.getMatches((match) =>
316317
this.matchPredicatesOr(
317318
context,
318319
accumulator.queryTermMatches,
@@ -430,16 +431,30 @@ export class ActionPredicate implements IQuerySemanticRefPredicate {
430431
}
431432
}
432433

433-
export class ScopeExpr implements IQueryOpExpr<void> {
434-
constructor(public predicates: IQueryScopePredicate[]) {}
434+
export class ApplyTagScopeExpr implements IQueryOpExpr<SemanticRefAccumulator> {
435+
constructor(public sourceExpr: IQueryOpExpr<SemanticRefAccumulator>) {}
435436

436-
public eval(context: QueryEvalContext): Promise<void> {
437-
return Promise.resolve();
437+
public async eval(
438+
context: QueryEvalContext,
439+
): Promise<SemanticRefAccumulator> {
440+
let accumulator = await this.sourceExpr.eval(context);
441+
const tagScope = new TextRangeAccumulator();
442+
for (const semanticRef of accumulator.getSemanticRefs(
443+
context.semanticRefs,
444+
(sr) => sr.knowledgeType === "tag",
445+
)) {
446+
tagScope.addRange(semanticRef.range);
447+
}
448+
if (tagScope.size > 0) {
449+
accumulator = accumulator.selectInScope(
450+
context.semanticRefs,
451+
tagScope,
452+
);
453+
}
454+
return Promise.resolve(accumulator);
438455
}
439456
}
440457

441-
export interface IQueryScopePredicate {}
442-
443458
function isPropertyMatch(
444459
termMatches: QueryTermAccumulator,
445460
testText: string | string[] | undefined,

ts/packages/knowPro/src/search.ts

+17-6
Original file line number | Diff line number | Diff line change
@@ -72,6 +72,21 @@ function createTermSearchQuery(
7272
wherePredicates?: q.IQuerySemanticRefPredicate[] | undefined,
7373
maxMatches?: number,
7474
minHitCount?: number,
75+
) {
76+
const query = new q.SelectTopNKnowledgeGroupExpr(
77+
new q.GroupByKnowledgeTypeExpr(
78+
createTermsMatch(conversation, terms, wherePredicates),
79+
),
80+
maxMatches,
81+
minHitCount,
82+
);
83+
return query;
84+
}
85+
86+
function createTermsMatch(
87+
conversation: IConversation,
88+
terms: QueryTerm[],
89+
wherePredicates?: q.IQuerySemanticRefPredicate[] | undefined,
7590
) {
7691
const queryTerms = new q.QueryTermsExpr(terms);
7792
let termsMatchExpr: q.IQueryOpExpr<SemanticRefAccumulator> =
@@ -80,18 +95,14 @@ function createTermSearchQuery(
8095
? new q.ResolveRelatedTermsExpr(queryTerms)
8196
: queryTerms,
8297
);
98+
termsMatchExpr = new q.ApplyTagScopeExpr(termsMatchExpr);
8399
if (wherePredicates !== undefined && wherePredicates.length > 0) {
84100
termsMatchExpr = new q.WhereSemanticRefExpr(
85101
termsMatchExpr,
86102
wherePredicates,
87103
);
88104
}
89-
const query = new q.SelectTopNKnowledgeGroupExpr(
90-
new q.GroupByKnowledgeTypeExpr(termsMatchExpr),
91-
maxMatches,
92-
minHitCount,
93-
);
94-
return query;
105+
return termsMatchExpr;
95106
}
96107

97108
function toGroupedSearchResults(

ts/packages/knowPro/src/termIndex.ts

+2-1
Original file line number | Diff line number | Diff line change
@@ -18,6 +18,7 @@ import {
1818
ITextEmbeddingDataItem,
1919
ITextEmbeddingData,
2020
} from "./dataFormat.js";
21+
import { createEmbeddingCache } from "knowledge-processor";
2122

2223
export async function buildTermSemanticIndex(
2324
settings: SemanticIndexSettings,
@@ -166,7 +167,7 @@ export type SemanticIndexSettings = {
166167

167168
export function createSemanticIndexSettings(): SemanticIndexSettings {
168169
return {
169-
embeddingModel: openai.createEmbeddingModel(),
170+
embeddingModel: createEmbeddingCache(openai.createEmbeddingModel(), 64),
170171
minScore: 0.8,
171172
retryMaxAttempts: 2,
172173
retryPauseMs: 2000,

ts/packages/knowledgeProcessor/src/modelCache.ts

+1-2
Original file line number | Diff line number | Diff line change
@@ -19,13 +19,12 @@ export function createEmbeddingCache(
1919
model: TextEmbeddingModel,
2020
cacheSize: number,
2121
): TextEmbeddingModelWithCache {
22-
const maxBatchSize = 1;
2322
const cache: collections.Cache<string, number[]> =
2423
collections.createLRUCache(cacheSize);
2524
const modelWithCache: TextEmbeddingModelWithCache = {
2625
cache,
2726
generateEmbedding,
28-
maxBatchSize,
27+
maxBatchSize: model.maxBatchSize,
2928
};
3029
if (model.generateEmbeddingBatch) {
3130
modelWithCache.generateEmbeddingBatch = generateEmbeddingBatch;

0 commit comments

Comments
 (0)