@@ -759,4 +759,4 @@ teardown:
- match: { status: 400 }
- match: { error.type: status_exception }
- contains: { error.reason: "[diversify] search failed" }
- match: { error.suppressed.0.type: illegal_argument_exception }
- match: { error.suppressed.0.type: status_exception }
1 change: 1 addition & 0 deletions server/src/main/java/module-info.java
@@ -514,4 +514,5 @@
exports org.elasticsearch.index.mapper.blockloader.docvalues;
exports org.elasticsearch.index.mapper.blockloader.docvalues.fn;
exports org.elasticsearch.readiness to org.elasticsearch.internal.sigterm;
exports org.elasticsearch.search.diversification;
}
@@ -17,12 +17,16 @@
import org.elasticsearch.action.search.SearchPhaseExecutionException;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.features.NodeFeature;
import org.elasticsearch.index.mapper.InferenceMetadataFieldsMapper;
import org.elasticsearch.index.query.InterceptedQueryBuilderWrapper;
import org.elasticsearch.index.query.QueryBuilder;
import org.elasticsearch.index.query.QueryRewriteContext;
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.search.SearchHit;
import org.elasticsearch.search.builder.SearchSourceBuilder;
import org.elasticsearch.search.diversification.mmr.MMRResultDiversificationContext;
import org.elasticsearch.search.fetch.StoredFieldsContext;
import org.elasticsearch.search.fetch.subphase.FetchSourceContext;
import org.elasticsearch.search.rank.RankDoc;
import org.elasticsearch.search.retriever.CompoundRetrieverBuilder;
import org.elasticsearch.search.retriever.RetrieverBuilder;
@@ -36,6 +40,7 @@
import org.elasticsearch.xcontent.XContentParser;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
@@ -65,16 +70,16 @@ public final class DiversifyRetrieverBuilder extends CompoundRetrieverBuilder<Di
public static final ParseField LAMBDA_FIELD = new ParseField("lambda");
public static final ParseField SIZE_FIELD = new ParseField("size");

public static class RankDocWithSearchHit extends RankDoc {
private final SearchHit hit;
public static class RankDocWithDenseVector extends RankDoc {
private final VectorData vector;

public RankDocWithSearchHit(int doc, float score, int shardIndex, SearchHit hit) {
public RankDocWithDenseVector(int doc, float score, int shardIndex, @Nullable VectorData vector) {
super(doc, score, shardIndex);
this.hit = hit;
this.vector = vector;
}

public SearchHit hit() {
return hit;
public VectorData vector() {
return vector;
}
}

@@ -302,8 +307,21 @@ protected RetrieverBuilder doRewrite(QueryRewriteContext ctx) {

@Override
protected SearchSourceBuilder finalizeSourceBuilder(SearchSourceBuilder sourceBuilder) {
SearchSourceBuilder builder = sourceBuilder.from(0);
return super.finalizeSourceBuilder(builder).docValueField(diversificationField);
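// Request the inference metadata field and the diversification field on each hit;
// createRankDocFromHit reads them back to build the per-document vector.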
StoredFieldsContext sfCtx = StoredFieldsContext.fromList(List.of(InferenceMetadataFieldsMapper.NAME, diversificationField));
FetchSourceContext fsCtx = FetchSourceContext.of(
false,
false,
new String[] { InferenceMetadataFieldsMapper.NAME, diversificationField },
null
);

SearchSourceBuilder builder = sourceBuilder.from(0)
.excludeVectors(false)
.storedFields(sfCtx)
.fetchSource(fsCtx)
.fetchField(InferenceMetadataFieldsMapper.NAME)
.fetchField(diversificationField);
return super.finalizeSourceBuilder(builder);
}

@Override
@@ -351,15 +369,10 @@ protected RankDoc[] combineInnerRetrieverResults(List<ScoreDoc[]> rankResults, b
RankDoc[] results = new RankDoc[scoreDocs.length];
Map<Integer, VectorData> fieldVectors = new HashMap<>();
for (int i = 0; i < scoreDocs.length; i++) {
RankDocWithSearchHit asRankDoc = (RankDocWithSearchHit) scoreDocs[i];
RankDocWithDenseVector asRankDoc = (RankDocWithDenseVector) scoreDocs[i];
results[i] = asRankDoc;

var field = asRankDoc.hit().getFields().getOrDefault(diversificationField, null);
if (field != null) {
var fieldValue = field.getValue();
if (fieldValue != null) {
extractFieldVectorData(asRankDoc.rank, fieldValue, fieldVectors);
}
if (asRankDoc.vector() != null) {
fieldVectors.put(asRankDoc.rank, asRankDoc.vector());
}
}

@@ -397,72 +410,6 @@ private ResultDiversificationContext getResultDiversificationContext() {
throw new IllegalArgumentException("Unknown diversification type [" + diversificationType + "]");
}

private void extractFieldVectorData(int docId, Object fieldValue, Map<Integer, VectorData> fieldVectors) {
switch (fieldValue) {
case float[] floatArray -> {
fieldVectors.put(docId, new VectorData(floatArray));
return;
}
case byte[] byteArray -> {
fieldVectors.put(docId, new VectorData(byteArray));
return;
}
case Float[] boxedFloatArray -> {
fieldVectors.put(docId, new VectorData(unboxedFloatArray(boxedFloatArray)));
return;
}
case Byte[] boxedByteArray -> {
fieldVectors.put(docId, new VectorData(unboxedByteArray(boxedByteArray)));
return;
}
default -> {
}
}

// CCS search returns a generic Object[] array, so we must
// examine the individual element type here.
if (fieldValue instanceof Object[] objectArray) {
if (objectArray.length == 0) {
return;
}

if (objectArray[0] instanceof Byte) {
Byte[] asByteArray = Arrays.stream(objectArray).map(x -> (Byte) x).toArray(Byte[]::new);
fieldVectors.put(docId, new VectorData(unboxedByteArray(asByteArray)));
return;
}

if (objectArray[0] instanceof Float) {
Float[] asFloatArray = Arrays.stream(objectArray).map(x -> (Float) x).toArray(Float[]::new);
fieldVectors.put(docId, new VectorData(unboxedFloatArray(asFloatArray)));
return;
}
}

throw new ElasticsearchStatusException(
String.format(Locale.ROOT, "Failed to retrieve vectors for field [%s]. Is it a [dense_vector] field?", diversificationField),
RestStatus.BAD_REQUEST
);
}

private static float[] unboxedFloatArray(Float[] array) {
float[] unboxedArray = new float[array.length];
int bIndex = 0;
for (Float b : array) {
unboxedArray[bIndex++] = b;
}
return unboxedArray;
}

private static byte[] unboxedByteArray(Byte[] array) {
byte[] unboxedArray = new byte[array.length];
int bIndex = 0;
for (Byte b : array) {
unboxedArray[bIndex++] = b;
}
return unboxedArray;
}

@Override
public String getName() {
return NAME;
@@ -498,7 +445,52 @@ protected void doToXContent(XContentBuilder builder, Params params) throws IOExc

@Override
protected RankDoc createRankDocFromHit(int docId, SearchHit hit, int shardRequestIndex) {
return new RankDocWithSearchHit(docId, hit.getScore(), shardRequestIndex, hit);
// First, check whether the vector can be resolved from the inference metadata field.
VectorData vector = tryGetVectorFromInferenceField(hit);

if (vector == null) {
var field = hit.getFields().getOrDefault(diversificationField, null);
if (field != null) {
var fieldValue = field.getValues();
vector = extractFieldVectorData(fieldValue);
}
}

return new RankDocWithDenseVector(docId, hit.getScore(), shardRequestIndex, vector);
}

private VectorData tryGetVectorFromInferenceField(SearchHit hit) {
var inferenceFields = hit.getFields().getOrDefault(InferenceMetadataFieldsMapper.NAME, null);
if (inferenceFields == null) {
return null;
}

var fieldValues = inferenceFields.getValues();
if (fieldValues == null || fieldValues.isEmpty()) {
return null;
}

if (fieldValues.getFirst() instanceof Map<?, ?> mappedValues) {
var fieldValue = mappedValues.getOrDefault(diversificationField, null);
if (fieldValue instanceof ResultDiversificationDenseVectorSupplier vectorSupplier) {
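// Look for an InterceptedQueryBuilderWrapper among the first inner retriever's sub-searches;
// it is handed to the supplier so it can resolve the vector for this hit.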

InterceptedQueryBuilderWrapper interceptedWrapper = null;
var innerRetriever = this.innerRetrievers.getFirst();
var retrieverSource = innerRetriever.source();
if (retrieverSource != null && retrieverSource.subSearches().isEmpty() == false) {
for (var entry : retrieverSource.subSearches()) {
if (entry.getQueryBuilder() instanceof InterceptedQueryBuilderWrapper iqbw) {
interceptedWrapper = iqbw;
break;
}
}
}

return vectorSupplier.getDocumentVectorForSearchHit(diversificationField, hit, interceptedWrapper);
}
}

return null;
}

@Override
@@ -512,4 +504,78 @@ public boolean doEquals(Object o) {
|| (queryVector != null && other.queryVector != null && Objects.equals(queryVector.get(), other.queryVector.get())))
&& Objects.equals(this.queryVectorBuilder, other.queryVectorBuilder);
}

private VectorData extractFieldVectorData(Object fieldValue) {
if (fieldValue == null) {
return null;
}

switch (fieldValue) {
case float[] floatArray -> {
return new VectorData(floatArray);
}
case byte[] byteArray -> {
return new VectorData(byteArray);
}
case Float[] boxedFloatArray -> {
return new VectorData(unboxedFloatArray(boxedFloatArray));
}
case Byte[] boxedByteArray -> {
return new VectorData(unboxedByteArray(boxedByteArray));
}
default -> {
}
}

// The field value may also arrive as an ArrayList of boxed numbers, so check the element type.
if (fieldValue instanceof ArrayList<?> asArrayList && asArrayList.isEmpty() == false) {
Object firstElement = asArrayList.getFirst();
if (firstElement instanceof Float) {
var asFloatArray = asArrayList.toArray(new Float[0]);
return new VectorData(unboxedFloatArray(asFloatArray));
}
if (firstElement instanceof Byte) {
var asByteArray = asArrayList.toArray(new Byte[0]);
return new VectorData(unboxedByteArray(asByteArray));
}
}

// CCS search returns a generic Object[] array, so we must
// examine the individual element type here.
if (fieldValue instanceof Object[] objectArray) {
if (objectArray.length == 0) {
return null;
}

if (objectArray[0] instanceof Byte) {
Byte[] asByteArray = Arrays.stream(objectArray).map(x -> (Byte) x).toArray(Byte[]::new);
return new VectorData(unboxedByteArray(asByteArray));
}

if (objectArray[0] instanceof Float) {
Float[] asFloatArray = Arrays.stream(objectArray).map(x -> (Float) x).toArray(Float[]::new);
return new VectorData(unboxedFloatArray(asFloatArray));
}
}

return null;
}

private static float[] unboxedFloatArray(Float[] array) {
float[] unboxedArray = new float[array.length];
int bIndex = 0;
for (Float b : array) {
unboxedArray[bIndex++] = b;
}
return unboxedArray;
}

private static byte[] unboxedByteArray(Byte[] array) {
byte[] unboxedArray = new byte[array.length];
int bIndex = 0;
for (Byte b : array) {
unboxedArray[bIndex++] = b;
}
return unboxedArray;
}
}
@@ -0,0 +1,23 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the "Elastic License
* 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
* Public License v 1"; you may not use this file except in compliance with, at
* your election, the "Elastic License 2.0", the "GNU Affero General Public
* License v3.0 only", or the "Server Side Public License, v 1".
*/

package org.elasticsearch.search.diversification;

import org.elasticsearch.core.Nullable;
import org.elasticsearch.index.query.InterceptedQueryBuilderWrapper;
import org.elasticsearch.search.SearchHit;
import org.elasticsearch.search.vectors.VectorData;

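/**
 * Supplies the dense vector used for result diversification for a given {@link SearchHit}.
 * Implementations are looked up from the hit's inference metadata field and may use the
 * intercepted query wrapper, when present, to resolve the vector for the diversification field.
 */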
public interface ResultDiversificationDenseVectorSupplier {
VectorData getDocumentVectorForSearchHit(
String diversificationField,
SearchHit hit,
@Nullable InterceptedQueryBuilderWrapper queryWrapper
);
}
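For orientation, here is a minimal, hypothetical sketch of an implementation of this interface; the class name and the map-backed lookup are illustrative assumptions, not part of this change. A real supplier is obtained from the hit's inference metadata field (see tryGetVectorFromInferenceField in DiversifyRetrieverBuilder) and may additionally use the intercepted query wrapper.

package org.elasticsearch.search.diversification;

import org.elasticsearch.core.Nullable;
import org.elasticsearch.index.query.InterceptedQueryBuilderWrapper;
import org.elasticsearch.search.SearchHit;
import org.elasticsearch.search.vectors.VectorData;

import java.util.Map;

// Hypothetical sketch, not part of this change: a supplier backed by a precomputed map of
// vectors keyed by document id. It ignores the intercepted query wrapper entirely.
public class PrecomputedDenseVectorSupplier implements ResultDiversificationDenseVectorSupplier {

    private final Map<String, VectorData> vectorsByDocId;

    public PrecomputedDenseVectorSupplier(Map<String, VectorData> vectorsByDocId) {
        this.vectorsByDocId = vectorsByDocId;
    }

    @Override
    public VectorData getDocumentVectorForSearchHit(
        String diversificationField,
        SearchHit hit,
        @Nullable InterceptedQueryBuilderWrapper queryWrapper
    ) {
        // Returning null is allowed; the retriever simply has no vector to diversify on
        // for that document.
        return vectorsByDocId.get(hit.getId());
    }
}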