diff --git a/ts/examples/chat/src/memory/knowproMemory.ts b/ts/examples/chat/src/memory/knowproMemory.ts index af7a7c09d..24dc85b2b 100644 --- a/ts/examples/chat/src/memory/knowproMemory.ts +++ b/ts/examples/chat/src/memory/knowproMemory.ts @@ -165,6 +165,7 @@ export async function createKnowproCommands( description: "Search current knowPro conversation by terms", options: { maxToDisplay: argNum("Maximum matches to display", 25), + type: arg("Knowledge type"), }, }; } @@ -188,9 +189,10 @@ export async function createKnowproCommands( `Searching ${conversation.nameTag}...`, ); - const matches = await kp.searchTermsInConversation( + const matches = await kp.searchConversation( conversation, terms, + namedArgs.type, ); if (matches === undefined || matches.size === 0) { context.printer.writeLine("No matches"); diff --git a/ts/packages/knowPro/src/accumulators.ts b/ts/packages/knowPro/src/accumulators.ts new file mode 100644 index 000000000..f01411584 --- /dev/null +++ b/ts/packages/knowPro/src/accumulators.ts @@ -0,0 +1,379 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +import { collections, createTopNList } from "typeagent"; +import { + IMessage, + KnowledgeType, + MessageIndex, + ScoredSemanticRef, + SemanticRef, + SemanticRefIndex, + Term, + TextRange, +} from "./dataFormat.js"; +import { isInTextRange } from "./query.js"; + +export interface Match { + value: T; + score: number; + hitCount: number; +} + +/** + * Sort in place + * @param matches + */ +export function sortMatchesByRelevance(matches: Match[]) { + matches.sort((x, y) => y.score - x.score); +} + +export class MatchAccumulator { + private matches: Map>; + private maxHitCount: number; + + constructor() { + this.matches = new Map>(); + this.maxHitCount = 0; + } + + public get numMatches(): number { + return this.matches.size; + } + + public get maxHits(): number { + return this.maxHitCount; + } + + public has(value: T): boolean { + return this.matches.has(value); + } + + public getMatch(value: T): Match | undefined { + return this.matches.get(value); + } + + public setMatch(match: Match): void { + this.matches.set(match.value, match); + if (match.hitCount > this.maxHitCount) { + this.maxHitCount = match.hitCount; + } + } + + public setMatches(matches: Match[] | IterableIterator>): void { + for (const match of matches) { + this.setMatch(match); + } + } + + public add(value: T, score: number): void { + let match = this.matches.get(value); + if (match !== undefined) { + match.hitCount += 1; + match.score += score; + } else { + match = { + value, + score, + hitCount: 1, + }; + this.matches.set(value, match); + } + if (match.hitCount > this.maxHitCount) { + this.maxHitCount = match.hitCount; + } + } + + public getSortedByScore(minHitCount?: number): Match[] { + if (this.matches.size === 0) { + return []; + } + const matches = [...this.matchesWithMinHitCount(minHitCount)]; + matches.sort((x, y) => y.score - x.score); + return matches; + } + + /** + * Return the top N scoring matches + * @param maxMatches + * @returns + */ + public getTopNScoring( + maxMatches?: number, + minHitCount?: number, + ): Match[] { + if (this.matches.size === 0) { + return []; + } + if (maxMatches && maxMatches > 0) { + const topList = createTopNList(maxMatches); + for (const match of this.matchesWithMinHitCount(minHitCount)) { + topList.push(match.value, match.score); + } + const ranked = topList.byRank(); + return ranked.map((m) => this.matches.get(m.item)!); + } else { + return this.getSortedByScore(minHitCount); + } + } + + public getMatches(): IterableIterator> { + return this.matches.values(); + } + + public *getMatchesWhere( + predicate: (match: Match) => boolean, + ): IterableIterator> { + for (const match of this.matches.values()) { + if (predicate(match)) { + yield match; + } + } + } + + public clearMatches(): void { + this.matches.clear(); + this.maxHitCount = 0; + } + + public reduceTopNScoring( + maxMatches?: number, + minHitCount?: number, + ): number { + const topN = this.getTopNScoring(maxMatches, minHitCount); + this.clearMatches(); + if (topN.length > 0) { + this.setMatches(topN); + } + return topN.length; + } + + public union(other: MatchAccumulator): void { + for (const matchFrom of other.matches.values()) { + const matchTo = this.matches.get(matchFrom.value); + if (matchTo !== undefined) { + // Existing + matchTo.hitCount += matchFrom.hitCount; + matchTo.score += matchFrom.score; + } else { + this.matches.set(matchFrom.value, matchFrom); + } + } + } + + private matchesWithMinHitCount( + minHitCount: number | undefined, + ): IterableIterator> { + return minHitCount !== undefined && minHitCount > 0 + ? this.getMatchesWhere((m) => m.hitCount >= minHitCount) + : this.matches.values(); + } +} + +export class SemanticRefAccumulator extends MatchAccumulator { + constructor(public queryTermMatches = new QueryTermAccumulator()) { + super(); + } + + public addTermMatch( + term: Term, + semanticRefs: ScoredSemanticRef[] | undefined, + scoreBoost?: number, + ) { + if (semanticRefs) { + scoreBoost ??= term.score ?? 0; + for (const match of semanticRefs) { + this.add(match.semanticRefIndex, match.score + scoreBoost); + } + this.queryTermMatches.add(term); + } + } + + public addRelatedTermMatch( + primaryTerm: Term, + relatedTerm: Term, + semanticRefs: ScoredSemanticRef[] | undefined, + scoreBoost?: number, + ) { + if (semanticRefs) { + // Related term matches count as matches for the queryTerm... + // BUT are scored with the score of the related term + scoreBoost ??= relatedTerm.score ?? 0; + for (const semanticRef of semanticRefs) { + let score = semanticRef.score + scoreBoost; + let match = this.getMatch(semanticRef.semanticRefIndex); + if (match !== undefined) { + if (match.score < score) { + match.score = score; + } + } else { + match = { + value: semanticRef.semanticRefIndex, + score, + hitCount: 1, + }; + this.setMatch(match); + } + } + this.queryTermMatches.add(primaryTerm, relatedTerm); + } + } + + public override getSortedByScore( + minHitCount?: number, + ): Match[] { + return super.getSortedByScore(this.getMinHitCount(minHitCount)); + } + + public override getTopNScoring( + maxMatches?: number, + minHitCount?: number, + ): Match[] { + return super.getTopNScoring( + maxMatches, + this.getMinHitCount(minHitCount), + ); + } + + public groupMatchesByKnowledgeType( + semanticRefs: SemanticRef[], + ): Map { + const groups = new Map(); + for (const match of this.getMatches()) { + const semanticRef = semanticRefs[match.value]; + let group = groups.get(semanticRef.knowledgeType); + if (group === undefined) { + group = new SemanticRefAccumulator(); + group.queryTermMatches = this.queryTermMatches; + groups.set(semanticRef.knowledgeType, group); + } + group.setMatch(match); + } + return groups; + } + + public toScoredSemanticRefs(): ScoredSemanticRef[] { + return this.getSortedByScore(0).map((m) => { + return { + semanticRefIndex: m.value, + score: m.score, + }; + }, 0); + } + + private getMinHitCount(minHitCount?: number): number { + return minHitCount !== undefined + ? minHitCount + : //: this.queryTermMatches.termMatches.size; + this.maxHits; + } +} + +export class QueryTermAccumulator { + constructor( + public termMatches: Set = new Set(), + public relatedTermToTerms: Map> = new Map< + string, + Set + >(), + ) {} + + public add(term: Term, relatedTerm?: Term) { + this.termMatches.add(term.text); + if (relatedTerm !== undefined) { + let relatedTermToTerms = this.relatedTermToTerms.get( + relatedTerm.text, + ); + if (relatedTermToTerms === undefined) { + relatedTermToTerms = new Set(); + this.relatedTermToTerms.set( + relatedTerm.text, + relatedTermToTerms, + ); + } + relatedTermToTerms.add(term.text); + } + } + + public matched(testText: string | string[], expectedText: string): boolean { + if (Array.isArray(testText)) { + if (testText.length > 0) { + for (const text of testText) { + if (this.matched(text, expectedText)) { + return true; + } + } + } + return false; + } + + if ( + this.termMatches.has(testText) && + collections.stringEquals(testText, expectedText, false) + ) { + return true; + } + + // Maybe the test text matched a related term. + // If so, the matching related term should have matched *on behalf* of + // of expectedTerm + const relatedTermToTerms = this.relatedTermToTerms.get(testText); + return relatedTermToTerms !== undefined + ? relatedTermToTerms.has(expectedText) + : false; + } + + public didValueMatch( + obj: Record, + key: string, + expectedValue: string, + ): boolean { + const value = obj[key]; + if (value === undefined) { + return false; + } + if (Array.isArray(value)) { + for (const item of value) { + if (this.didValueMatch(item, key, expectedValue)) { + return true; + } + } + return false; + } else { + const stringValue = value.toString().toLowerCase(); + return this.matched(stringValue, expectedValue); + } + } +} + +export class TextRangeAccumulator { + constructor( + public rangesForMessage: Map = new Map< + MessageIndex, + TextRange[] + >(), + ) {} + + public addTextRange(textRange: TextRange) { + const messageIndex = textRange.start.messageIndex; + let textRanges = this.rangesForMessage.get(messageIndex); + if (textRanges === undefined) { + textRanges = [textRange]; + } + textRanges.push(textRange); + } + + public isInRange(textRange: TextRange): boolean { + const textRanges = this.rangesForMessage.get( + textRange.start.messageIndex, + ); + if (textRanges === undefined) { + return false; + } + return textRanges.some((outerRange) => + isInTextRange(outerRange, textRange), + ); + } +} + +export class MessageAccumulator extends MatchAccumulator {} diff --git a/ts/packages/knowPro/src/dataFormat.ts b/ts/packages/knowPro/src/dataFormat.ts index 829ca0c23..b36074fe3 100644 --- a/ts/packages/knowPro/src/dataFormat.ts +++ b/ts/packages/knowPro/src/dataFormat.ts @@ -78,9 +78,11 @@ export interface IConversation { relatedTermsIndex?: ITermToRelatedTermsIndex | undefined; } +export type MessageIndex = number; + export interface TextLocation { // the index of the message - messageIndex: number; + messageIndex: MessageIndex; // the index of the chunk chunkIndex?: number; // the index of the character within the chunk diff --git a/ts/packages/knowPro/src/query.ts b/ts/packages/knowPro/src/query.ts index d0998ff4f..fc88b9889 100644 --- a/ts/packages/knowPro/src/query.ts +++ b/ts/packages/knowPro/src/query.ts @@ -1,19 +1,26 @@ // Copyright (c) Microsoft Corporation. // Licensed under the MIT License. -import { createTopNList } from "typeagent"; import { IConversation, + IMessage, ITermToRelatedTermsIndex, ITermToSemanticRefIndex, KnowledgeType, QueryTerm, - ScoredSemanticRef, SemanticRef, SemanticRefIndex, - Term, + TextLocation, + TextRange, } from "./dataFormat.js"; import * as knowLib from "knowledge-processor"; +import { + Match, + MatchAccumulator, + QueryTermAccumulator, + SemanticRefAccumulator, +} from "./accumulators.js"; +import { collections, dateTime } from "typeagent"; export function isConversationSearchable(conversation: IConversation): boolean { return ( @@ -22,6 +29,110 @@ export function isConversationSearchable(conversation: IConversation): boolean { ); } +export function textRangeForConversation( + conversation: IConversation, +): TextRange { + const messages = conversation.messages; + return { + start: { messageIndex: 0 }, + end: { messageIndex: messages.length - 1 }, + }; +} + +export type TimestampRange = { + start: string; + end?: string | undefined; +}; + +export function timestampRangeForConversation( + conversation: IConversation, +): TimestampRange | undefined { + const messages = conversation.messages; + const start = messages[0].timestamp; + const end = messages[messages.length - 1].timestamp; + if (start !== undefined) { + return { + start, + end, + }; + } + return undefined; +} + +/** + * Assumes messages are in timestamp order. + * @param conversation + */ +export function getMessagesInDateRange( + conversation: IConversation, + dateRange: DateRange, +): IMessage[] { + return collections.getInRange( + conversation.messages, + dateTime.timestampString(dateRange.start), + dateRange.end ? dateTime.timestampString(dateRange.end) : undefined, + (x, y) => x.localeCompare(y), + ); +} +/** + * Returns: + * 0 if locations are equal + * < 0 if x is less than y + * > 0 if x is greater than y + * @param x + * @param y + * @returns + */ +export function compareTextLocation(x: TextLocation, y: TextLocation): number { + let cmp = x.messageIndex - y.messageIndex; + if (cmp !== 0) { + return cmp; + } + cmp = (x.chunkIndex ?? 0) - (y.chunkIndex ?? 0); + if (cmp !== 0) { + return cmp; + } + return (x.charIndex ?? 0) - (y.charIndex ?? 0); +} + +const MaxTextLocation: TextLocation = { + messageIndex: Number.MAX_SAFE_INTEGER, + chunkIndex: Number.MAX_SAFE_INTEGER, + charIndex: Number.MAX_SAFE_INTEGER, +}; + +export function isInTextRange( + outerRange: TextRange, + innerRange: TextRange, +): boolean { + // outer start must be <= inner start + // inner end must be <= outerEnd + let cmpStart = compareTextLocation(outerRange.start, innerRange.start); + let cmpEnd = compareTextLocation( + innerRange.end ?? MaxTextLocation, + outerRange.end ?? MaxTextLocation, + ); + return cmpStart <= 0 && cmpEnd <= 0; +} + +export type DateRange = { + start: Date; + end?: Date | undefined; +}; + +export function compareDates(x: Date, y: Date): number { + return x.getTime() - y.getTime(); +} + +export function isDateInRange(outerRange: DateRange, date: Date): boolean { + // outer start must be <= date + // date must be <= outer end + let cmpStart = compareDates(outerRange.start, date); + let cmpEnd = + outerRange.end !== undefined ? compareDates(date, outerRange.end) : -1; + return cmpStart <= 0 && cmpEnd <= 0; +} + // Query eval expressions export interface IQueryOpExpr { @@ -36,14 +147,26 @@ export class QueryEvalContext { } public get semanticRefIndex(): ITermToSemanticRefIndex { + this.conversation.messages; return this.conversation.semanticRefIndex!; } + public get semanticRefs(): SemanticRef[] { return this.conversation.semanticRefs!; } + public get relatedTermIndex(): ITermToRelatedTermsIndex | undefined { return this.conversation.relatedTermsIndex; } + + public getSemanticRef(semanticRefIndex: SemanticRefIndex): SemanticRef { + return this.semanticRefs[semanticRefIndex]; + } + + public getMessageForRef(semanticRef: SemanticRef): IMessage { + const messageIndex = semanticRef.range.start.messageIndex; + return this.conversation.messages[messageIndex]; + } } export class SelectTopNExpr @@ -178,7 +301,7 @@ export class WhereSemanticRefExpr { constructor( public sourceExpr: IQueryOpExpr, - public predicates: IQueryOpPredicate[], + public predicates: IQuerySemanticRefPredicate[], ) {} public async eval( @@ -188,35 +311,55 @@ export class WhereSemanticRefExpr const filtered = new SemanticRefAccumulator( accumulator.queryTermMatches, ); - const semanticRefs = context.semanticRefs; filtered.setMatches( accumulator.getMatchesWhere((match) => - this.testOr(semanticRefs, accumulator.queryTermMatches, match), + this.matchPredicatesOr( + context, + accumulator.queryTermMatches, + match, + ), ), ); return filtered; } - private testOr( - semanticRefs: SemanticRef[], + private matchPredicatesOr( + context: QueryEvalContext, queryTermMatches: QueryTermAccumulator, match: Match, ) { for (let i = 0; i < this.predicates.length; ++i) { - const semanticRef = semanticRefs[match.value]; - if (this.predicates[i].eval(queryTermMatches, semanticRef)) { + const semanticRef = context.getSemanticRef(match.value); + if ( + this.predicates[i].eval(context, queryTermMatches, semanticRef) + ) { return true; } } return false; } } +export interface IQuerySemanticRefPredicate { + eval( + context: QueryEvalContext, + termMatches: QueryTermAccumulator, + semanticRef: SemanticRef, + ): boolean; +} + +export class KnowledgeTypePredicate implements IQuerySemanticRefPredicate { + constructor(public type: KnowledgeType) {} -export interface IQueryOpPredicate { - eval(termMatches: QueryTermAccumulator, semanticRef: SemanticRef): boolean; + public eval( + context: QueryEvalContext, + termMatches: QueryTermAccumulator, + semanticRef: SemanticRef, + ): boolean { + return semanticRef.knowledgeType === this.type; + } } -export class EntityPredicate implements IQueryOpPredicate { +export class EntityPredicate implements IQuerySemanticRefPredicate { constructor( public type: string | undefined, public name: string | undefined, @@ -224,6 +367,7 @@ export class EntityPredicate implements IQueryOpPredicate { ) {} public eval( + context: QueryEvalContext, termMatches: QueryTermAccumulator, semanticRef: SemanticRef, ): boolean { @@ -233,8 +377,8 @@ export class EntityPredicate implements IQueryOpPredicate { const entity = semanticRef.knowledge as knowLib.conversation.ConcreteEntity; return ( - termMatches.matched(entity.type, this.type) && - termMatches.matched(entity.name, this.name) && + isPropertyMatch(termMatches, entity.type, this.type) && + isPropertyMatch(termMatches, entity.name, this.name) && this.matchFacet(termMatches, entity, this.facetName) ); } @@ -248,7 +392,7 @@ export class EntityPredicate implements IQueryOpPredicate { return false; } for (const facet of entity.facets) { - if (termMatches.matched(facet.name, facetName)) { + if (isPropertyMatch(termMatches, facet.name, facetName)) { return true; } } @@ -256,13 +400,14 @@ export class EntityPredicate implements IQueryOpPredicate { } } -export class ActionPredicate implements IQueryOpPredicate { +export class ActionPredicate implements IQuerySemanticRefPredicate { constructor( public subjectEntityName?: string | undefined, public objectEntityName?: string | undefined, ) {} public eval( + context: QueryEvalContext, termMatches: QueryTermAccumulator, semanticRef: SemanticRef, ): boolean { @@ -271,319 +416,37 @@ export class ActionPredicate implements IQueryOpPredicate { } const action = semanticRef.knowledge as knowLib.conversation.Action; return ( - termMatches.matched( + isPropertyMatch( + termMatches, action.subjectEntityName, this.subjectEntityName, ) && - termMatches.matched(action.objectEntityName, this.objectEntityName) + isPropertyMatch( + termMatches, + action.objectEntityName, + this.objectEntityName, + ) ); } } -export interface Match { - value: T; - score: number; - hitCount: number; -} - -/** - * Sort in place - * @param matches - */ -export function sortMatchesByRelevance(matches: Match[]) { - matches.sort((x, y) => y.score - x.score); -} - -export class MatchAccumulator { - private matches: Map>; - - constructor() { - this.matches = new Map>(); - } - - public get numMatches(): number { - return this.matches.size; - } +export class ScopeExpr implements IQueryOpExpr { + constructor(public predicates: IQueryScopePredicate[]) {} - public getMatch(value: T): Match | undefined { - return this.matches.get(value); - } - - public setMatch(match: Match): void { - this.matches.set(match.value, match); - } - - public setMatches(matches: Match[] | IterableIterator>): void { - for (const match of matches) { - this.matches.set(match.value, match); - } - } - - public add(value: T, score: number): void { - let match = this.matches.get(value); - if (match !== undefined) { - match.hitCount += 1; - match.score += score; - } else { - match = { - value, - score, - hitCount: 1, - }; - this.matches.set(value, match); - } - } - - public getSortedByScore(minHitCount?: number): Match[] { - if (this.matches.size === 0) { - return []; - } - const matches = [...this.matchesWithMinHitCount(minHitCount)]; - matches.sort((x, y) => y.score - x.score); - return matches; - } - - /** - * Return the top N scoring matches - * @param maxMatches - * @returns - */ - public getTopNScoring( - maxMatches?: number, - minHitCount?: number, - ): Match[] { - if (this.matches.size === 0) { - return []; - } - if (maxMatches && maxMatches > 0) { - const topList = createTopNList(maxMatches); - for (const match of this.matchesWithMinHitCount(minHitCount)) { - topList.push(match.value, match.score); - } - const ranked = topList.byRank(); - return ranked.map((m) => this.matches.get(m.item)!); - } else { - return this.getSortedByScore(minHitCount); - } - } - - public getMatches(): IterableIterator> { - return this.matches.values(); - } - - public *getMatchesWhere( - predicate: (match: Match) => boolean, - ): IterableIterator> { - for (const match of this.matches.values()) { - if (predicate(match)) { - yield match; - } - } - } - - public removeMatchesWhere(predicate: (match: Match) => boolean): void { - const valuesToRemove: T[] = []; - for (const match of this.getMatchesWhere(predicate)) { - valuesToRemove.push(match.value); - } - this.removeMatches(valuesToRemove); - } - - public removeMatches(valuesToRemove: T[]): void { - if (valuesToRemove.length > 0) { - for (const item of valuesToRemove) { - this.matches.delete(item); - } - } - } - - public clearMatches(): void { - this.matches.clear(); - } - - public mapMatches(map: (m: Match) => M): M[] { - const items: M[] = []; - for (const match of this.matches.values()) { - items.push(map(match)); - } - return items; - } - - public reduceTopNScoring( - maxMatches?: number, - minHitCount?: number, - ): number { - const topN = this.getTopNScoring(maxMatches, minHitCount); - this.clearMatches(); - if (topN.length > 0) { - this.setMatches(topN); - } - return topN.length; - } - - private matchesWithMinHitCount( - minHitCount: number | undefined, - ): IterableIterator> { - return minHitCount !== undefined && minHitCount > 0 - ? this.getMatchesWhere((m) => m.hitCount >= minHitCount) - : this.matches.values(); + public eval(context: QueryEvalContext): Promise { + return Promise.resolve(); } } -export class SemanticRefAccumulator extends MatchAccumulator { - constructor(public queryTermMatches = new QueryTermAccumulator()) { - super(); - } - - public addTermMatch( - term: Term, - semanticRefs: ScoredSemanticRef[] | undefined, - scoreBoost?: number, - ) { - if (semanticRefs) { - scoreBoost ??= term.score ?? 0; - for (const match of semanticRefs) { - this.add(match.semanticRefIndex, match.score + scoreBoost); - } - this.queryTermMatches.add(term); - } - } - - public addRelatedTermMatch( - primaryTerm: Term, - relatedTerm: Term, - semanticRefs: ScoredSemanticRef[] | undefined, - scoreBoost?: number, - ) { - if (semanticRefs) { - // Related term matches count as matches for the queryTerm... - // BUT are scored with the score of the related term - scoreBoost ??= relatedTerm.score ?? 0; - for (const semanticRef of semanticRefs) { - let score = semanticRef.score + scoreBoost; - let match = this.getMatch(semanticRef.semanticRefIndex); - if (match !== undefined) { - if (match.score < score) { - match.score = score; - } - } else { - match = { - value: semanticRef.semanticRefIndex, - score, - hitCount: 1, - }; - this.setMatch(match); - } - } - this.queryTermMatches.add(primaryTerm, relatedTerm); - } - } - - public override getSortedByScore( - minHitCount?: number, - ): Match[] { - return super.getSortedByScore(this.getMinHitCount(minHitCount)); - } - - public override getTopNScoring( - maxMatches?: number, - minHitCount?: number, - ): Match[] { - return super.getTopNScoring( - maxMatches, - this.getMinHitCount(minHitCount), - ); - } - - public groupMatchesByKnowledgeType( - semanticRefs: SemanticRef[], - ): Map { - const groups = new Map(); - for (const match of this.getMatches()) { - const semanticRef = semanticRefs[match.value]; - let group = groups.get(semanticRef.knowledgeType); - if (group === undefined) { - group = new SemanticRefAccumulator(); - group.queryTermMatches = this.queryTermMatches; - groups.set(semanticRef.knowledgeType, group); - } - group.setMatch(match); - } - return groups; - } - - public toScoredSemanticRefs(): ScoredSemanticRef[] { - return this.getSortedByScore(0).map((m) => { - return { - semanticRefIndex: m.value, - score: m.score, - }; - }, 0); - } - - private getMinHitCount(minHitCount?: number): number { - return minHitCount !== undefined - ? minHitCount - : this.queryTermMatches.termMatches.size; - } -} - -export class QueryTermAccumulator { - constructor( - public termMatches: Set = new Set(), - public relatedTermToTerms: Map> = new Map< - string, - Set - >(), - ) {} - - public add(term: Term, relatedTerm?: Term) { - this.termMatches.add(term.text); - if (relatedTerm !== undefined) { - let relatedTermToTerms = this.relatedTermToTerms.get( - relatedTerm.text, - ); - if (relatedTermToTerms === undefined) { - relatedTermToTerms = new Set(); - this.relatedTermToTerms.set( - relatedTerm.text, - relatedTermToTerms, - ); - } - relatedTermToTerms.add(term.text); - } - } - - public matched( - testText: string | string[] | undefined, - expectedText: string | undefined, - ): boolean { - if (expectedText === undefined) { - return true; - } - if (testText === undefined) { - return false; - } - - if (Array.isArray(testText)) { - for (const text of testText) { - if (this.matched(text, expectedText)) { - return true; - } - } - return false; - } - - if (testText === expectedText) { - return true; - } +export interface IQueryScopePredicate {} - // Maybe the test text matched a related term. - // If so, the matching related term should have matched on behalf of - // of a term === expectedTerm - const relatedTermToTerms = this.relatedTermToTerms.get(testText); - return relatedTermToTerms !== undefined - ? relatedTermToTerms.has(expectedText) - : false; +function isPropertyMatch( + termMatches: QueryTermAccumulator, + testText: string | string[] | undefined, + expectedText: string | undefined, +) { + if (testText !== undefined && expectedText !== undefined) { + return termMatches.matched(testText, expectedText); } + return testText === undefined && expectedText === undefined; } diff --git a/ts/packages/knowPro/src/search.ts b/ts/packages/knowPro/src/search.ts index 3bb026626..22ed3c23f 100644 --- a/ts/packages/knowPro/src/search.ts +++ b/ts/packages/knowPro/src/search.ts @@ -1,6 +1,7 @@ // Copyright (c) Microsoft Corporation. // Licensed under the MIT License. +import { SemanticRefAccumulator } from "./accumulators.js"; import { IConversation, KnowledgeType, @@ -22,43 +23,27 @@ export type SearchResult = { * @param minHitCount * @returns */ -export async function searchTermsInConversation( +export async function searchConversation( conversation: IConversation, terms: QueryTerm[], + type?: KnowledgeType, maxMatches?: number, - minHitCount?: number, ): Promise | undefined> { if (!q.isConversationSearchable(conversation)) { return undefined; } const context = new q.QueryEvalContext(conversation); - const queryTerms = new q.QueryTermsExpr(terms); - const query = new q.SelectTopNKnowledgeGroupExpr( - new q.GroupByKnowledgeTypeExpr( - new q.TermsMatchExpr( - conversation.relatedTermsIndex !== undefined - ? new q.ResolveRelatedTermsExpr(queryTerms) - : queryTerms, - ), - ), + const query = createTermSearchQuery( + conversation, + terms, + type ? [new q.KnowledgeTypePredicate(type)] : undefined, maxMatches, - minHitCount, ); - const evalResults = await query.eval(context); - const semanticRefMatches = new Map(); - for (const [type, accumulator] of evalResults) { - if (accumulator.numMatches > 0) { - semanticRefMatches.set(type, { - termMatches: accumulator.queryTermMatches.termMatches, - semanticRefMatches: accumulator.toScoredSemanticRefs(), - }); - } - } - return semanticRefMatches; + return toGroupedSearchResults(await query.eval(context)); } -export async function searchTermsInConversationExact( +export async function searchConversationExact( conversation: IConversation, terms: QueryTerm[], maxMatches?: number, @@ -80,3 +65,46 @@ export async function searchTermsInConversationExact( semanticRefMatches: evalResults.toScoredSemanticRefs(), }; } + +function createTermSearchQuery( + conversation: IConversation, + terms: QueryTerm[], + wherePredicates?: q.IQuerySemanticRefPredicate[] | undefined, + maxMatches?: number, + minHitCount?: number, +) { + const queryTerms = new q.QueryTermsExpr(terms); + let termsMatchExpr: q.IQueryOpExpr = + new q.TermsMatchExpr( + conversation.relatedTermsIndex !== undefined + ? new q.ResolveRelatedTermsExpr(queryTerms) + : queryTerms, + ); + if (wherePredicates !== undefined && wherePredicates.length > 0) { + termsMatchExpr = new q.WhereSemanticRefExpr( + termsMatchExpr, + wherePredicates, + ); + } + const query = new q.SelectTopNKnowledgeGroupExpr( + new q.GroupByKnowledgeTypeExpr(termsMatchExpr), + maxMatches, + minHitCount, + ); + return query; +} + +function toGroupedSearchResults( + evalResults: Map, +) { + const semanticRefMatches = new Map(); + for (const [type, accumulator] of evalResults) { + if (accumulator.numMatches > 0) { + semanticRefMatches.set(type, { + termMatches: accumulator.queryTermMatches.termMatches, + semanticRefMatches: accumulator.toScoredSemanticRefs(), + }); + } + } + return semanticRefMatches; +}