Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions docs/changelog/141788.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
area: Inference
issues: []
pr: 141788
summary: "[Inference API] Handle preconfigured endpoints with embedding task type"
type: enhancement
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import static org.elasticsearch.xpack.inference.services.elastic.response.ElasticInferenceServiceAuthorizationResponseEntityTests.ELSER_V2_ENDPOINT_ID;
import static org.elasticsearch.xpack.inference.services.elastic.response.ElasticInferenceServiceAuthorizationResponseEntityTests.GP_LLM_V2_CHAT_COMPLETION_ENDPOINT_ID;
import static org.elasticsearch.xpack.inference.services.elastic.response.ElasticInferenceServiceAuthorizationResponseEntityTests.GP_LLM_V2_COMPLETION_ENDPOINT_ID;
import static org.elasticsearch.xpack.inference.services.elastic.response.ElasticInferenceServiceAuthorizationResponseEntityTests.JINA_CLIP_V2_ENDPOINT_ID;
import static org.elasticsearch.xpack.inference.services.elastic.response.ElasticInferenceServiceAuthorizationResponseEntityTests.JINA_EMBED_V3_ENDPOINT_ID;
import static org.elasticsearch.xpack.inference.services.elastic.response.ElasticInferenceServiceAuthorizationResponseEntityTests.RAINBOW_SPRINKLES_ENDPOINT_ID;
import static org.elasticsearch.xpack.inference.services.elastic.response.ElasticInferenceServiceAuthorizationResponseEntityTests.RERANK_V1_ENDPOINT_ID;
Expand Down Expand Up @@ -49,7 +50,7 @@ public void testGetDefaultEndpoints() throws IOException {
var chatCompletionModels = getModels("_all", TaskType.CHAT_COMPLETION);
var completionModels = getModels("_all", TaskType.COMPLETION);

assertThat(allModels, hasSize(9));
assertThat(allModels, hasSize(10));
assertThat(chatCompletionModels, hasSize(2));
assertThat(completionModels, hasSize(1));

Expand All @@ -66,6 +67,7 @@ public void testGetDefaultEndpoints() throws IOException {
assertInferenceIdTaskType(allModels, GP_LLM_V2_COMPLETION_ENDPOINT_ID, TaskType.COMPLETION);
assertInferenceIdTaskType(allModels, ELSER_V2_ENDPOINT_ID, TaskType.SPARSE_EMBEDDING);
assertInferenceIdTaskType(allModels, JINA_EMBED_V3_ENDPOINT_ID, TaskType.TEXT_EMBEDDING);
assertInferenceIdTaskType(allModels, JINA_CLIP_V2_ENDPOINT_ID, TaskType.EMBEDDING);
assertInferenceIdTaskType(allModels, RERANK_V1_ENDPOINT_ID, TaskType.RERANK);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -234,7 +234,10 @@ public void testGetServicesWithSparseEmbeddingTaskType() throws IOException {
}

public void testGetServicesWithEmbeddingTaskType() throws IOException {
assertThat(providersFor(TaskType.EMBEDDING), containsInAnyOrder(List.of("text_embedding_test_service", "jinaai").toArray()));
assertThat(
providersFor(TaskType.EMBEDDING),
containsInAnyOrder(List.of("text_embedding_test_service", "jinaai", "elastic").toArray())
);
}

private List<Object> getAllServices() throws IOException {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ private static ElasticInferenceServiceModel createModel(
case CHAT_COMPLETION -> createCompletionModel(authorizedEndpoint, TaskType.CHAT_COMPLETION, components);
case COMPLETION -> createCompletionModel(authorizedEndpoint, TaskType.COMPLETION, components);
case SPARSE_EMBEDDING -> createSparseTextEmbeddingsModel(authorizedEndpoint, components);
case TEXT_EMBEDDING -> createDenseTextEmbeddingsModel(authorizedEndpoint, components);
case TEXT_EMBEDDING, EMBEDDING -> createDenseEmbeddingsModel(authorizedEndpoint, components, taskType);
case RERANK -> createRerankModel(authorizedEndpoint, components);
default -> {
logger.info(UNSUPPORTED_TASK_TYPE_LOG_MESSAGE, authorizedEndpoint.id(), taskType);
Expand Down Expand Up @@ -166,16 +166,17 @@ private static Map<String, Object> getChunkingSettingsMap(
return Objects.requireNonNullElse(configuration.chunkingSettings(), new HashMap<>());
}

private static ElasticInferenceServiceDenseEmbeddingsModel createDenseTextEmbeddingsModel(
private static ElasticInferenceServiceDenseEmbeddingsModel createDenseEmbeddingsModel(
ElasticInferenceServiceAuthorizationResponseEntity.AuthorizedEndpoint authorizedEndpoint,
ElasticInferenceServiceComponents components
ElasticInferenceServiceComponents components,
TaskType taskType
) {
var config = getConfigurationOrEmpty(authorizedEndpoint);
validateConfigurationForTextEmbedding(config);
validateConfigurationForDenseEmbedding(config, taskType);

return new ElasticInferenceServiceDenseEmbeddingsModel(
authorizedEndpoint.id(),
TaskType.TEXT_EMBEDDING,
taskType,
ElasticInferenceService.NAME,
new ElasticInferenceServiceDenseEmbeddingsServiceSettings(
authorizedEndpoint.modelName(),
Expand All @@ -190,22 +191,13 @@ private static ElasticInferenceServiceDenseEmbeddingsModel createDenseTextEmbedd
);
}

private static void validateConfigurationForTextEmbedding(ElasticInferenceServiceAuthorizationResponseEntity.Configuration config) {
validateFieldPresent(
ElasticInferenceServiceAuthorizationResponseEntity.Configuration.ELEMENT_TYPE,
config.elementType(),
TaskType.TEXT_EMBEDDING
);
validateFieldPresent(
ElasticInferenceServiceAuthorizationResponseEntity.Configuration.DIMENSIONS,
config.dimensions(),
TaskType.TEXT_EMBEDDING
);
validateFieldPresent(
ElasticInferenceServiceAuthorizationResponseEntity.Configuration.SIMILARITY,
config.similarity(),
TaskType.TEXT_EMBEDDING
);
private static void validateConfigurationForDenseEmbedding(
ElasticInferenceServiceAuthorizationResponseEntity.Configuration config,
TaskType taskType
) {
validateFieldPresent(ElasticInferenceServiceAuthorizationResponseEntity.Configuration.ELEMENT_TYPE, config.elementType(), taskType);
validateFieldPresent(ElasticInferenceServiceAuthorizationResponseEntity.Configuration.DIMENSIONS, config.dimensions(), taskType);
validateFieldPresent(ElasticInferenceServiceAuthorizationResponseEntity.Configuration.SIMILARITY, config.similarity(), taskType);

var configElementType = config.elementType().toLowerCase(Locale.ROOT);
var supportedElementTypes = getSupportedElementTypes();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
import java.util.Set;

import static org.elasticsearch.xpack.inference.services.elastic.response.ElasticInferenceServiceAuthorizationResponseEntityTests.EIS_CHAT_PATH;
import static org.elasticsearch.xpack.inference.services.elastic.response.ElasticInferenceServiceAuthorizationResponseEntityTests.EIS_MULTIMODAL_EMBED_PATH;
import static org.elasticsearch.xpack.inference.services.elastic.response.ElasticInferenceServiceAuthorizationResponseEntityTests.EIS_SPARSE_PATH;
import static org.elasticsearch.xpack.inference.services.elastic.response.ElasticInferenceServiceAuthorizationResponseEntityTests.EIS_TEXT_EMBED_PATH;
import static org.elasticsearch.xpack.inference.services.elastic.response.ElasticInferenceServiceAuthorizationResponseEntityTests.createTaskTypeObject;
Expand Down Expand Up @@ -600,12 +601,14 @@ public void testReturnsAuthorizedEndpoints_FiltersUnsupportedElementType() {
public void testCreatesAllSupportedTaskTypesAndReturnsCorrectModels() {
var idChat = "id_chat";
var idSparse = "id_sparse";
var idDense = "id_dense";
var idDenseMultimodal = "id_dense_multimodal";
var idDenseText = "id_dense_text";
var idRerank = "id_rerank";

var nameChat = "chat_model";
var nameSparse = "sparse_model";
var nameDense = "dense_model";
var nameDenseMultimodal = "dense_multimodal_model";
var nameDenseText = "dense_text_model";
var nameRerank = "rerank_model";

var similarity = SimilarityMeasure.COSINE;
Expand All @@ -614,6 +617,12 @@ public void testCreatesAllSupportedTaskTypesAndReturnsCorrectModels() {

var url = "base_url";

var denseEmbeddingConfiguration = new ElasticInferenceServiceAuthorizationResponseEntity.Configuration(
similarity.toString(),
dimensions,
elementType,
null
);
var response = new ElasticInferenceServiceAuthorizationResponseEntity(
List.of(
new ElasticInferenceServiceAuthorizationResponseEntity.AuthorizedEndpoint(
Expand All @@ -637,19 +646,24 @@ public void testCreatesAllSupportedTaskTypesAndReturnsCorrectModels() {
null
),
new ElasticInferenceServiceAuthorizationResponseEntity.AuthorizedEndpoint(
idDense,
nameDense,
idDenseMultimodal,
nameDenseMultimodal,
createTaskTypeObject(EIS_MULTIMODAL_EMBED_PATH, TaskType.EMBEDDING.toString()),
"ga",
null,
"",
"",
denseEmbeddingConfiguration
),
new ElasticInferenceServiceAuthorizationResponseEntity.AuthorizedEndpoint(
idDenseText,
nameDenseText,
createTaskTypeObject(EIS_TEXT_EMBED_PATH, TaskType.TEXT_EMBEDDING.toString()),
"ga",
null,
"",
"",
new ElasticInferenceServiceAuthorizationResponseEntity.Configuration(
similarity.toString(),
dimensions,
elementType,
null
)
denseEmbeddingConfiguration
),
new ElasticInferenceServiceAuthorizationResponseEntity.AuthorizedEndpoint(
idRerank,
Expand All @@ -666,8 +680,8 @@ public void testCreatesAllSupportedTaskTypesAndReturnsCorrectModels() {

var auth = ElasticInferenceServiceAuthorizationModel.of(response, url);

var endpoints = auth.getEndpoints(Set.of(idChat, idSparse, idDense, idRerank));
assertThat(endpoints.size(), is(4));
var endpoints = auth.getEndpoints(Set.of(idChat, idSparse, idDenseMultimodal, idDenseText, idRerank));
assertThat(endpoints.size(), is(5));
assertThat(
endpoints,
containsInAnyOrder(
Expand All @@ -691,10 +705,20 @@ public void testCreatesAllSupportedTaskTypesAndReturnsCorrectModels() {
ChunkingSettingsBuilder.DEFAULT_SETTINGS
),
new ElasticInferenceServiceDenseEmbeddingsModel(
idDense,
idDenseMultimodal,
TaskType.EMBEDDING,
ElasticInferenceService.NAME,
new ElasticInferenceServiceDenseEmbeddingsServiceSettings(nameDenseMultimodal, similarity, dimensions, null),
EmptyTaskSettings.INSTANCE,
EmptySecretSettings.INSTANCE,
new ElasticInferenceServiceComponents(url),
ChunkingSettingsBuilder.DEFAULT_SETTINGS
),
new ElasticInferenceServiceDenseEmbeddingsModel(
idDenseText,
TaskType.TEXT_EMBEDDING,
ElasticInferenceService.NAME,
new ElasticInferenceServiceDenseEmbeddingsServiceSettings(nameDense, similarity, dimensions, null),
new ElasticInferenceServiceDenseEmbeddingsServiceSettings(nameDenseText, similarity, dimensions, null),
EmptyTaskSettings.INSTANCE,
EmptySecretSettings.INSTANCE,
new ElasticInferenceServiceComponents(url),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -230,6 +230,7 @@ public void testGetAuthorization_ReturnsAValidResponse() throws IOException {
TaskType.CHAT_COMPLETION,
TaskType.SPARSE_EMBEDDING,
TaskType.TEXT_EMBEDDING,
TaskType.EMBEDDING,
TaskType.RERANK,
TaskType.COMPLETION
)
Expand Down Expand Up @@ -339,6 +340,7 @@ private void assertReturnsValidResponse(
TaskType.CHAT_COMPLETION,
TaskType.SPARSE_EMBEDDING,
TaskType.TEXT_EMBEDDING,
TaskType.EMBEDDING,
TaskType.RERANK,
TaskType.COMPLETION
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,8 @@ public class ElasticInferenceServiceAuthorizationResponseEntityTests extends Abs
public static final String EIS_TEXT_EMBED_PATH = "embed/text/dense";

// multimodal embedding
public static final String JINA_CLIP_V2_ENDPOINT_ID = ".jina-clip-v2";
public static final String JINA_CLIP_V2_MODEL_NAME = "jina-clip-v2";
public static final String EIS_MULTIMODAL_EMBED_PATH = "embed/dense";

// rerank-v1
Expand Down Expand Up @@ -252,6 +254,31 @@ public record EisAuthorizationResponse(
}
}
},
{
"id": ".jina-clip-v2",
"model_name": "jina-clip-v2",
"task_types": {
"eis": "embed/dense",
"elasticsearch": "embedding"
},
"status": "beta",
"properties": [
"multilingual",
"multimodal",
"open-weights"
],
"release_date": "2026-02-01",
"configuration": {
"similarity": "cosine",
"dimensions": 1024,
"element_type": "float",
"chunking_settings": {
"strategy": "word",
"max_chunk_size": 500,
"overlap": 2
}
}
},
{
"id": ".jina-reranker-v2",
"model_name": "jina-reranker-v2",
Expand Down Expand Up @@ -327,6 +354,7 @@ public static EisAuthorizationResponse getEisAuthorizationResponseWithMultipleEn
createGpLlmV2CompletionAuthorizedEndpoint(),
createElserAuthorizedEndpoint(),
createJinaTextEmbedAuthorizedEndpoint(),
createJinaMultimodalEmbedAuthorizedEndpoint(),
createRerankV1AuthorizedEndpoint()
);

Expand All @@ -343,6 +371,7 @@ public static EisAuthorizationResponse getEisAuthorizationResponseWithMultipleEn
createGpLlmV2CompletionExpectedEndpoint(url),
createElserExpectedEndpoint(url),
createJinaExpectedTextEmbeddingEndpoint(url),
createJinaExpectedMultimodalEmbeddingEndpoint(url),
createRerankV1ExpectedEndpoint(url)
),
inferenceIds
Expand Down Expand Up @@ -485,6 +514,37 @@ private static ElasticInferenceServiceModel createJinaExpectedTextEmbeddingEndpo
);
}

/**
 * Builds the authorized-endpoint fixture for the {@code jina-clip-v2} multimodal embedding model.
 * <p>
 * The endpoint is identified by {@link #JINA_CLIP_V2_ENDPOINT_ID}, maps the EIS path
 * {@link #EIS_MULTIMODAL_EMBED_PATH} to the {@code "embedding"} task type, and carries a dense
 * embedding configuration (cosine similarity, 1024 dimensions, float elements) together with
 * word-based chunking settings.
 *
 * @return the authorized endpoint describing the multimodal embedding model
 */
private static ElasticInferenceServiceAuthorizationResponseEntity.AuthorizedEndpoint createJinaMultimodalEmbedAuthorizedEndpoint() {
    // Task-type mapping: EIS-side path paired with the Elasticsearch-side task type name.
    var taskTypes = createTaskTypeObject(EIS_MULTIMODAL_EMBED_PATH, "embedding");

    // Dense embedding configuration, including the raw chunking-settings map.
    var denseConfiguration = new ElasticInferenceServiceAuthorizationResponseEntity.Configuration(
        "cosine",
        1024,
        "float",
        Map.of("strategy", "word", "max_chunk_size", 500, "overlap", 2)
    );

    return new ElasticInferenceServiceAuthorizationResponseEntity.AuthorizedEndpoint(
        JINA_CLIP_V2_ENDPOINT_ID,
        JINA_CLIP_V2_MODEL_NAME,
        taskTypes,
        "beta",
        List.of("multilingual", "multimodal", "open-weights"),
        "2026-02-01",
        null,
        denseConfiguration
    );
}

/**
 * Builds the model that tests expect to be created for the {@code jina-clip-v2} multimodal
 * embedding endpoint.
 * <p>
 * The model uses {@link TaskType#EMBEDDING}, cosine similarity with 1024 dimensions, empty task
 * and secret settings, and word-boundary chunking settings of (500, 2).
 *
 * @param url the base URL used to construct the {@link ElasticInferenceServiceComponents}
 * @return the expected dense embeddings model for {@link #JINA_CLIP_V2_ENDPOINT_ID}
 */
private static ElasticInferenceServiceModel createJinaExpectedMultimodalEmbeddingEndpoint(String url) {
    var serviceSettings = new ElasticInferenceServiceDenseEmbeddingsServiceSettings(
        JINA_CLIP_V2_MODEL_NAME,
        SimilarityMeasure.COSINE,
        1024,
        null
    );
    var components = new ElasticInferenceServiceComponents(url);
    var chunkingSettings = new WordBoundaryChunkingSettings(500, 2);

    return new ElasticInferenceServiceDenseEmbeddingsModel(
        JINA_CLIP_V2_ENDPOINT_ID,
        TaskType.EMBEDDING,
        ElasticInferenceService.NAME,
        serviceSettings,
        EmptyTaskSettings.INSTANCE,
        EmptySecretSettings.INSTANCE,
        components,
        chunkingSettings
    );
}

private static ElasticInferenceServiceAuthorizationResponseEntity.AuthorizedEndpoint createRerankV1AuthorizedEndpoint() {
return new ElasticInferenceServiceAuthorizationResponseEntity.AuthorizedEndpoint(
RERANK_V1_ENDPOINT_ID,
Expand Down Expand Up @@ -632,6 +692,7 @@ public void testParseAllFields() throws IOException {
TaskType.CHAT_COMPLETION,
TaskType.SPARSE_EMBEDDING,
TaskType.TEXT_EMBEDDING,
TaskType.EMBEDDING,
TaskType.RERANK,
TaskType.COMPLETION
)
Expand Down