diff --git a/src/main/java/com/databricks/jdbc/api/impl/DatabricksConnectionContext.java b/src/main/java/com/databricks/jdbc/api/impl/DatabricksConnectionContext.java index 0aef7a3cd..83112d648 100644 --- a/src/main/java/com/databricks/jdbc/api/impl/DatabricksConnectionContext.java +++ b/src/main/java/com/databricks/jdbc/api/impl/DatabricksConnectionContext.java @@ -1185,4 +1185,14 @@ public boolean getDisableOauthRefreshToken() { public boolean isTokenFederationEnabled() { return getParameter(DatabricksJdbcUrlParams.ENABLE_TOKEN_FEDERATION, "1").equals("1"); } + + @Override + public boolean isStreamingChunkProviderEnabled() { + return getParameter(DatabricksJdbcUrlParams.ENABLE_STREAMING_CHUNK_PROVIDER).equals("1"); + } + + @Override + public int getLinkPrefetchWindow() { + return Integer.parseInt(getParameter(DatabricksJdbcUrlParams.LINK_PREFETCH_WINDOW)); + } } diff --git a/src/main/java/com/databricks/jdbc/api/impl/arrow/AbstractArrowResultChunk.java b/src/main/java/com/databricks/jdbc/api/impl/arrow/AbstractArrowResultChunk.java index 6b364c676..93f366549 100644 --- a/src/main/java/com/databricks/jdbc/api/impl/arrow/AbstractArrowResultChunk.java +++ b/src/main/java/com/databricks/jdbc/api/impl/arrow/AbstractArrowResultChunk.java @@ -116,6 +116,15 @@ public Long getChunkIndex() { return chunkIndex; } + /** + * Returns the starting row offset for this chunk. + * + * @return the row offset + */ + public long getRowOffset() { + return rowOffset; + } + /** * Checks if the chunk link is invalid or expired. * diff --git a/src/main/java/com/databricks/jdbc/api/impl/arrow/ArrowResultChunk.java b/src/main/java/com/databricks/jdbc/api/impl/arrow/ArrowResultChunk.java index 9852dc8af..9aba1304d 100644 --- a/src/main/java/com/databricks/jdbc/api/impl/arrow/ArrowResultChunk.java +++ b/src/main/java/com/databricks/jdbc/api/impl/arrow/ArrowResultChunk.java @@ -158,7 +158,7 @@ private void logDownloadMetrics( double speedMBps = (contentLength / 1024.0 / 1024.0) / (downloadTimeMs / 1000.0); String baseUrl = url.split("\\?")[0]; - LOGGER.info( + LOGGER.debug( String.format( "CloudFetch download: %.4f MB/s, %d bytes in %dms from %s", speedMBps, contentLength, downloadTimeMs, baseUrl)); @@ -197,6 +197,23 @@ public Builder withChunkInfo(BaseChunkInfo baseChunkInfo) { return this; } + /** + * Sets chunk metadata directly without requiring a BaseChunkInfo object. Useful for streaming + * chunk creation where metadata comes from ExternalLink. + * + * @param chunkIndex The index of this chunk + * @param rowCount The number of rows in this chunk + * @param rowOffset The starting row offset for this chunk + * @return this builder + */ + public Builder withChunkMetadata(long chunkIndex, long rowCount, long rowOffset) { + this.chunkIndex = chunkIndex; + this.numRows = rowCount; + this.rowOffset = rowOffset; + this.status = status == null ? 
ChunkStatus.PENDING : status; + return this; + } + public Builder withInputStream(InputStream stream, long rowCount) { this.numRows = rowCount; this.inputStream = stream; diff --git a/src/main/java/com/databricks/jdbc/api/impl/arrow/ArrowStreamResult.java b/src/main/java/com/databricks/jdbc/api/impl/arrow/ArrowStreamResult.java index c86c27447..98a021554 100644 --- a/src/main/java/com/databricks/jdbc/api/impl/arrow/ArrowStreamResult.java +++ b/src/main/java/com/databricks/jdbc/api/impl/arrow/ArrowStreamResult.java @@ -1,9 +1,11 @@ package com.databricks.jdbc.api.impl.arrow; +import static com.databricks.jdbc.common.util.DatabricksThriftUtil.createExternalLink; import static com.databricks.jdbc.common.util.DatabricksThriftUtil.getColumnInfoFromTColumnDesc; import com.databricks.jdbc.api.impl.ComplexDataTypeParser; import com.databricks.jdbc.api.impl.IExecutionResult; +import com.databricks.jdbc.api.internal.IDatabricksConnectionContext; import com.databricks.jdbc.api.internal.IDatabricksSession; import com.databricks.jdbc.api.internal.IDatabricksStatementInternal; import com.databricks.jdbc.common.CompressionCodec; @@ -16,12 +18,16 @@ import com.databricks.jdbc.model.client.thrift.generated.TColumnDesc; import com.databricks.jdbc.model.client.thrift.generated.TFetchResultsResp; import com.databricks.jdbc.model.client.thrift.generated.TGetResultSetMetadataResp; +import com.databricks.jdbc.model.client.thrift.generated.TSparkArrowResultLink; +import com.databricks.jdbc.model.core.ChunkLinkFetchResult; import com.databricks.jdbc.model.core.ColumnInfo; import com.databricks.jdbc.model.core.ColumnInfoTypeName; +import com.databricks.jdbc.model.core.ExternalLink; import com.databricks.jdbc.model.core.ResultData; import com.databricks.jdbc.model.core.ResultManifest; import com.google.common.annotations.VisibleForTesting; import java.util.ArrayList; +import java.util.Collection; import java.util.List; /** Result container for Arrow-based query results. */ @@ -69,13 +75,7 @@ public ArrowStreamResult( "Creating ArrowStreamResult with remote links for statementId: {}", statementId.toSQLExecStatementId()); this.chunkProvider = - new RemoteChunkProvider( - statementId, - resultManifest, - resultData, - session, - httpClient, - session.getConnectionContext().getCloudFetchThreadPoolSize()); + createRemoteChunkProvider(statementId, resultManifest, resultData, session, httpClient); } this.columnInfos = resultManifest.getSchema().getColumnCount() == 0 @@ -83,6 +83,63 @@ public ArrowStreamResult( : new ArrayList<>(resultManifest.getSchema().getColumns()); } + /** + * Creates the appropriate remote chunk provider based on configuration. 
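<p>A sketch of a JDBC URL that opts into the streaming path (host, port, and schema path are placeholders; the two parameters are the ones added in this PR): <pre>{@code jdbc:databricks://myhost:443/default;EnableStreamingChunkProvider=1;LinkPrefetchWindow=128}</pre>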
+ * + * @param statementId The statement ID + * @param resultManifest The result manifest containing chunk metadata + * @param resultData The result data containing initial external links + * @param session The session for fetching additional chunks + * @param httpClient The HTTP client for downloading chunk data + * @return A ChunkProvider instance + */ + private static ChunkProvider createRemoteChunkProvider( + StatementId statementId, + ResultManifest resultManifest, + ResultData resultData, + IDatabricksSession session, + IDatabricksHttpClient httpClient) + throws DatabricksSQLException { + + IDatabricksConnectionContext connectionContext = session.getConnectionContext(); + + if (connectionContext.isStreamingChunkProviderEnabled()) { + LOGGER.info( + "Using StreamingChunkProvider for statementId: {}", statementId.toSQLExecStatementId()); + + ChunkLinkFetcher linkFetcher = new SeaChunkLinkFetcher(session, statementId); + CompressionCodec compressionCodec = resultManifest.getResultCompression(); + int maxChunksInMemory = connectionContext.getCloudFetchThreadPoolSize(); + int linkPrefetchWindow = connectionContext.getLinkPrefetchWindow(); + int chunkReadyTimeoutSeconds = connectionContext.getChunkReadyTimeoutSeconds(); + double cloudFetchSpeedThreshold = connectionContext.getCloudFetchSpeedThreshold(); + + // Convert ExternalLinks to ChunkLinkFetchResult for the provider + ChunkLinkFetchResult initialLinks = + convertToChunkLinkFetchResult(resultData.getExternalLinks()); + + return new StreamingChunkProvider( + linkFetcher, + httpClient, + compressionCodec, + statementId, + maxChunksInMemory, + linkPrefetchWindow, + chunkReadyTimeoutSeconds, + cloudFetchSpeedThreshold, + initialLinks); + } else { + // Use the original RemoteChunkProvider + return new RemoteChunkProvider( + statementId, + resultManifest, + resultData, + session, + httpClient, + connectionContext.getCloudFetchThreadPoolSize()); + } + } + public ArrowStreamResult( TFetchResultsResp resultsResp, boolean isInlineArrow, @@ -110,16 +167,63 @@ public ArrowStreamResult( if (isInlineArrow) { this.chunkProvider = new InlineChunkProvider(resultsResp, parentStatement, session); } else { - CompressionCodec compressionCodec = - CompressionCodec.getCompressionMapping(resultsResp.getResultSetMetadata()); this.chunkProvider = - new RemoteChunkProvider( - parentStatement, - resultsResp, - session, - httpClient, - session.getConnectionContext().getCloudFetchThreadPoolSize(), - compressionCodec); + createThriftRemoteChunkProvider(resultsResp, parentStatement, session, httpClient); + } + } + + /** + * Creates the appropriate remote chunk provider for Thrift based on configuration. 
+ * + * @param resultsResp The Thrift fetch results response + * @param parentStatement The parent statement for fetching additional chunks + * @param session The session for fetching additional chunks + * @param httpClient The HTTP client for downloading chunk data + * @return A ChunkProvider instance + */ + private static ChunkProvider createThriftRemoteChunkProvider( + TFetchResultsResp resultsResp, + IDatabricksStatementInternal parentStatement, + IDatabricksSession session, + IDatabricksHttpClient httpClient) + throws DatabricksSQLException { + + IDatabricksConnectionContext connectionContext = session.getConnectionContext(); + CompressionCodec compressionCodec = + CompressionCodec.getCompressionMapping(resultsResp.getResultSetMetadata()); + + if (connectionContext.isStreamingChunkProviderEnabled()) { + StatementId statementId = parentStatement.getStatementId(); + LOGGER.info("Using StreamingChunkProvider for Thrift statementId: {}", statementId); + + ChunkLinkFetcher linkFetcher = new ThriftChunkLinkFetcher(session, statementId); + int maxChunksInMemory = connectionContext.getCloudFetchThreadPoolSize(); + int linkPrefetchWindow = connectionContext.getLinkPrefetchWindow(); + int chunkReadyTimeoutSeconds = connectionContext.getChunkReadyTimeoutSeconds(); + double cloudFetchSpeedThreshold = connectionContext.getCloudFetchSpeedThreshold(); + + // Convert initial Thrift links to ChunkLinkFetchResult + ChunkLinkFetchResult initialLinks = convertThriftLinksToChunkLinkFetchResult(resultsResp); + + return new StreamingChunkProvider( + linkFetcher, + httpClient, + compressionCodec, + statementId, + maxChunksInMemory, + linkPrefetchWindow, + chunkReadyTimeoutSeconds, + cloudFetchSpeedThreshold, + initialLinks); + } else { + // Use the original RemoteChunkProvider + return new RemoteChunkProvider( + parentStatement, + resultsResp, + session, + httpClient, + connectionContext.getCloudFetchThreadPoolSize(), + compressionCodec); } } @@ -268,4 +372,79 @@ private void setColumnInfo(TGetResultSetMetadataResp resultManifest) { columnInfos.add(getColumnInfoFromTColumnDesc(tColumnDesc)); } } + + /** + * Converts a collection of ExternalLinks to a ChunkLinkFetchResult. + * + * @param externalLinks The external links to convert, may be null + * @return A ChunkLinkFetchResult, or null if input is null or empty + */ + private static ChunkLinkFetchResult convertToChunkLinkFetchResult( + Collection externalLinks) { + if (externalLinks == null || externalLinks.isEmpty()) { + return null; + } + + List linkList = + externalLinks instanceof List + ? (List) externalLinks + : new ArrayList<>(externalLinks); + + // Derive hasMore and nextRowOffset from last link (SEA style) + ExternalLink lastLink = linkList.get(linkList.size() - 1); + boolean hasMore = lastLink.getNextChunkIndex() != null; + long nextFetchIndex = hasMore ? lastLink.getNextChunkIndex() : -1; + long nextRowOffset = lastLink.getRowOffset() + lastLink.getRowCount(); + + return ChunkLinkFetchResult.of(linkList, hasMore, nextFetchIndex, nextRowOffset); + } + + /** + * Converts Thrift result links to a ChunkLinkFetchResult. + * + *
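<p>Worked example of the mapping below: three Thrift links with {@code hasMoreRows=true} become ExternalLinks with nextChunkIndex 1, 2, and 3 and a nextFetchIndex of 3; with {@code hasMoreRows=false} the last link keeps a null nextChunkIndex and nextFetchIndex is -1. + *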
<p>
This method converts TSparkArrowResultLink from the Thrift response to the unified + * ChunkLinkFetchResult format used by StreamingChunkProvider. + * + * @param resultsResp The Thrift fetch results response containing initial links + * @return A ChunkLinkFetchResult, or null if no links + */ + private static ChunkLinkFetchResult convertThriftLinksToChunkLinkFetchResult( + TFetchResultsResp resultsResp) { + List resultLinks = resultsResp.getResults().getResultLinks(); + if (resultLinks == null || resultLinks.isEmpty()) { + return null; + } + + List chunkLinks = new ArrayList<>(); + int lastIndex = resultLinks.size() - 1; + boolean hasMoreRows = resultsResp.hasMoreRows; + + for (int i = 0; i < resultLinks.size(); i++) { + TSparkArrowResultLink thriftLink = resultLinks.get(i); + + // Convert Thrift link to ExternalLink (sets chunkIndex, rowOffset, rowCount, etc.) + ExternalLink externalLink = createExternalLink(thriftLink, i); + + // For the last link, set nextChunkIndex based on hasMoreRows + if (i == lastIndex) { + if (hasMoreRows) { + // More chunks available - next fetch should start from lastIndex + 1 + externalLink.setNextChunkIndex((long) i + 1); + } + // If hasMoreRows is false, nextChunkIndex remains null (end of stream) + } else { + // Not the last link - next chunk follows immediately + externalLink.setNextChunkIndex((long) i + 1); + } + + chunkLinks.add(externalLink); + } + + // Calculate next fetch positions from last link + TSparkArrowResultLink lastThriftLink = resultLinks.get(lastIndex); + long nextFetchIndex = hasMoreRows ? lastIndex + 1 : -1; + long nextRowOffset = lastThriftLink.getStartRowOffset() + lastThriftLink.getRowCount(); + + return ChunkLinkFetchResult.of(chunkLinks, hasMoreRows, nextFetchIndex, nextRowOffset); + } } diff --git a/src/main/java/com/databricks/jdbc/api/impl/arrow/ChunkLinkDownloadService.java b/src/main/java/com/databricks/jdbc/api/impl/arrow/ChunkLinkDownloadService.java index 7d59ea13c..b04fdfd10 100644 --- a/src/main/java/com/databricks/jdbc/api/impl/arrow/ChunkLinkDownloadService.java +++ b/src/main/java/com/databricks/jdbc/api/impl/arrow/ChunkLinkDownloadService.java @@ -9,9 +9,9 @@ import com.databricks.jdbc.exception.DatabricksValidationException; import com.databricks.jdbc.log.JdbcLogger; import com.databricks.jdbc.log.JdbcLoggerFactory; +import com.databricks.jdbc.model.core.ChunkLinkFetchResult; import com.databricks.jdbc.model.core.ExternalLink; import java.time.Instant; -import java.util.Collection; import java.util.Map; import java.util.concurrent.CompletableFuture; import java.util.concurrent.ConcurrentHashMap; @@ -219,16 +219,18 @@ private void triggerNextBatchDownload() { CompletableFuture.runAsync( () -> { try { - Collection links = - session.getDatabricksClient().getResultChunks(statementId, batchStartIndex); + // rowOffset is 0 here as this service is used by RemoteChunkProvider (SEA-only) + // which fetches by chunkIndex, not rowOffset + ChunkLinkFetchResult result = + session.getDatabricksClient().getResultChunks(statementId, batchStartIndex, 0); LOGGER.info( "Retrieved {} links for batch starting at {} for statement id {}", - links.size(), + result.getChunkLinks().size(), batchStartIndex, statementId); // Complete futures for all chunks in this batch - for (ExternalLink link : links) { + for (ExternalLink link : result.getChunkLinks()) { CompletableFuture future = chunkIndexToLinkFuture.get(link.getChunkIndex()); if (future != null) { @@ -241,9 +243,12 @@ private void triggerNextBatchDownload() { } // Update next batch start 
index and trigger next batch - if (!links.isEmpty()) { + if (!result.getChunkLinks().isEmpty()) { long maxChunkIndex = - links.stream().mapToLong(ExternalLink::getChunkIndex).max().getAsLong(); + result.getChunkLinks().stream() + .mapToLong(ExternalLink::getChunkIndex) + .max() + .getAsLong(); nextBatchStartIndex.set(maxChunkIndex + 1); LOGGER.debug("Updated next batch start index to {}", maxChunkIndex + 1); diff --git a/src/main/java/com/databricks/jdbc/api/impl/arrow/ChunkLinkFetcher.java b/src/main/java/com/databricks/jdbc/api/impl/arrow/ChunkLinkFetcher.java new file mode 100644 index 000000000..6d6a88e56 --- /dev/null +++ b/src/main/java/com/databricks/jdbc/api/impl/arrow/ChunkLinkFetcher.java @@ -0,0 +1,51 @@ +package com.databricks.jdbc.api.impl.arrow; + +import com.databricks.jdbc.exception.DatabricksSQLException; +import com.databricks.jdbc.model.core.ChunkLinkFetchResult; +import com.databricks.jdbc.model.core.ExternalLink; + +/** + * Abstraction for fetching chunk links from either SEA or Thrift backend. Implementations handle + * the protocol-specific details of how links are retrieved. + * + *
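<p>A sketch of how a caller obtains the protocol-appropriate implementation ({@code useSea} is an illustrative flag; ArrowStreamResult makes this choice elsewhere in this PR): <pre>{@code ChunkLinkFetcher fetcher = useSea ? new SeaChunkLinkFetcher(session, statementId) : new ThriftChunkLinkFetcher(session, statementId); }</pre> + *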
<p>
This interface enables a unified streaming approach for chunk downloads regardless of the + * underlying client type (SEA or Thrift). + */ +public interface ChunkLinkFetcher { + + /** + * Fetches the next batch of chunk links starting from the given position. + * + *
<p>
The implementation may return one or more links in a single call. The returned {@link + * ChunkLinkFetchResult} indicates whether more chunks are available. + * + *
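<p>A sketch of the intended paging loop (illustrative; error handling elided): <pre>{@code long nextIndex = 0; long nextOffset = 0; ChunkLinkFetchResult result; do { result = fetcher.fetchLinks(nextIndex, nextOffset); for (ExternalLink link : result.getChunkLinks()) { /* schedule download */ } nextIndex = result.getNextFetchIndex(); nextOffset = result.getNextRowOffset(); } while (result.hasMore()); }</pre> + *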
<p>
SEA implementations use startChunkIndex while Thrift implementations use startRowOffset. + * Each implementation uses the parameter relevant to its protocol and ignores the other. + * + * @param startChunkIndex The chunk index to start fetching from (used by SEA) + * @param startRowOffset The row offset to start fetching from (used by Thrift with + * FETCH_ABSOLUTE) + * @return ChunkLinkFetchResult containing the fetched links and continuation information + * @throws DatabricksSQLException if the fetch operation fails + */ + ChunkLinkFetchResult fetchLinks(long startChunkIndex, long startRowOffset) + throws DatabricksSQLException; + + /** + * Refetches a specific chunk link that may have expired. + * + *
<p>
This is used when a previously fetched link has expired before the chunk could be + * downloaded. Both SEA and Thrift clients support this via the getResultChunks API. + * + *
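<p>Typical recovery path inside a download task (a sketch mirroring StreamingChunkDownloadTask in this PR): <pre>{@code if (chunk.isChunkLinkInvalid()) { chunk.setChunkLink(fetcher.refetchLink(chunk.getChunkIndex(), chunk.getRowOffset())); } }</pre> + *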
<p>
SEA uses chunkIndex while Thrift uses rowOffset to identify the chunk to refetch. + * + * @param chunkIndex The specific chunk index to refetch (used by SEA) + * @param rowOffset The row offset of the chunk to refetch (used by Thrift) + * @return The refreshed ExternalLink with a new expiration time + * @throws DatabricksSQLException if the refetch operation fails + */ + ExternalLink refetchLink(long chunkIndex, long rowOffset) throws DatabricksSQLException; + + /** Closes any resources held by the fetcher. */ + void close(); +} diff --git a/src/main/java/com/databricks/jdbc/api/impl/arrow/SeaChunkLinkFetcher.java b/src/main/java/com/databricks/jdbc/api/impl/arrow/SeaChunkLinkFetcher.java new file mode 100644 index 000000000..0f6930631 --- /dev/null +++ b/src/main/java/com/databricks/jdbc/api/impl/arrow/SeaChunkLinkFetcher.java @@ -0,0 +1,81 @@ +package com.databricks.jdbc.api.impl.arrow; + +import com.databricks.jdbc.api.internal.IDatabricksSession; +import com.databricks.jdbc.dbclient.impl.common.StatementId; +import com.databricks.jdbc.exception.DatabricksSQLException; +import com.databricks.jdbc.log.JdbcLogger; +import com.databricks.jdbc.log.JdbcLoggerFactory; +import com.databricks.jdbc.model.core.ChunkLinkFetchResult; +import com.databricks.jdbc.model.core.ExternalLink; +import com.databricks.jdbc.model.telemetry.enums.DatabricksDriverErrorCode; + +/** + * ChunkLinkFetcher implementation for the SQL Execution API (SEA) client. + * + *
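<p>For example (illustrative values), {@code fetcher.fetchLinks(4, 0)} asks SEA for links starting at chunk index 4; the row-offset argument is ignored by this implementation. + *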
<p>
SEA provides chunk links via the getResultChunks API, which returns links with nextChunkIndex + * to indicate continuation. When nextChunkIndex is null, it indicates no more chunks. + */ +public class SeaChunkLinkFetcher implements ChunkLinkFetcher { + + private static final JdbcLogger LOGGER = JdbcLoggerFactory.getLogger(SeaChunkLinkFetcher.class); + + private final IDatabricksSession session; + private final StatementId statementId; + + public SeaChunkLinkFetcher(IDatabricksSession session, StatementId statementId) { + this.session = session; + this.statementId = statementId; + LOGGER.debug("Created SeaChunkLinkFetcher for statement {}", statementId); + } + + @Override + public ChunkLinkFetchResult fetchLinks(long startChunkIndex, long startRowOffset) + throws DatabricksSQLException { + // SEA uses startChunkIndex; startRowOffset is ignored + LOGGER.debug( + "Fetching links starting from chunk index {} for statement {}", + startChunkIndex, + statementId); + + return session + .getDatabricksClient() + .getResultChunks(statementId, startChunkIndex, startRowOffset); + } + + @Override + public ExternalLink refetchLink(long chunkIndex, long rowOffset) throws DatabricksSQLException { + // SEA uses chunkIndex; rowOffset is ignored + LOGGER.info("Refetching expired link for chunk {} of statement {}", chunkIndex, statementId); + + ChunkLinkFetchResult result = + session.getDatabricksClient().getResultChunks(statementId, chunkIndex, rowOffset); + + if (result.isEndOfStream()) { + throw new DatabricksSQLException( + String.format("Failed to refetch link for chunk %d: no links returned", chunkIndex), + DatabricksDriverErrorCode.CHUNK_READY_ERROR); + } + + // Find the link for the requested chunk index + for (ExternalLink link : result.getChunkLinks()) { + if (link.getChunkIndex() == chunkIndex) { + LOGGER.debug( + "Successfully refetched link for chunk {} of statement {}", chunkIndex, statementId); + return link; + } + } + + // Exact match not found - this indicates a server bug + throw new DatabricksSQLException( + String.format( + "Failed to refetch link for chunk %d: server returned links but none matched requested index", + chunkIndex), + DatabricksDriverErrorCode.CHUNK_READY_ERROR); + } + + @Override + public void close() { + LOGGER.debug("Closing SeaChunkLinkFetcher for statement {}", statementId); + // No resources to clean up for SEA fetcher + } +} diff --git a/src/main/java/com/databricks/jdbc/api/impl/arrow/StreamingChunkDownloadTask.java b/src/main/java/com/databricks/jdbc/api/impl/arrow/StreamingChunkDownloadTask.java new file mode 100644 index 000000000..2bd8ed2eb --- /dev/null +++ b/src/main/java/com/databricks/jdbc/api/impl/arrow/StreamingChunkDownloadTask.java @@ -0,0 +1,113 @@ +package com.databricks.jdbc.api.impl.arrow; + +import com.databricks.jdbc.common.CompressionCodec; +import com.databricks.jdbc.dbclient.IDatabricksHttpClient; +import com.databricks.jdbc.exception.DatabricksSQLException; +import com.databricks.jdbc.log.JdbcLogger; +import com.databricks.jdbc.log.JdbcLoggerFactory; +import com.databricks.jdbc.model.core.ExternalLink; +import com.databricks.jdbc.model.telemetry.enums.DatabricksDriverErrorCode; +import java.io.IOException; +import java.util.concurrent.Callable; + +/** + * A download task for streaming chunk provider. Simpler than ChunkDownloadTask - uses + * ChunkLinkFetcher directly for link refresh instead of ChunkLinkDownloadService. 
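<p>Intended submission pattern (a sketch mirroring StreamingChunkProvider#submitDownloadTask): <pre>{@code downloadExecutor.submit(new StreamingChunkDownloadTask(chunk, httpClient, compressionCodec, linkFetcher, cloudFetchSpeedThreshold)); }</pre>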
+ */ +public class StreamingChunkDownloadTask implements Callable { + + private static final JdbcLogger LOGGER = + JdbcLoggerFactory.getLogger(StreamingChunkDownloadTask.class); + + private static final int MAX_RETRIES = 5; + private static final long RETRY_DELAY_MS = 1500; + + private final ArrowResultChunk chunk; + private final IDatabricksHttpClient httpClient; + private final CompressionCodec compressionCodec; + private final ChunkLinkFetcher linkFetcher; + private final double cloudFetchSpeedThreshold; + + public StreamingChunkDownloadTask( + ArrowResultChunk chunk, + IDatabricksHttpClient httpClient, + CompressionCodec compressionCodec, + ChunkLinkFetcher linkFetcher, + double cloudFetchSpeedThreshold) { + this.chunk = chunk; + this.httpClient = httpClient; + this.compressionCodec = compressionCodec; + this.linkFetcher = linkFetcher; + this.cloudFetchSpeedThreshold = cloudFetchSpeedThreshold; + } + + @Override + public Void call() throws DatabricksSQLException { + int retries = 0; + boolean downloadSuccessful = false; + + try { + while (!downloadSuccessful) { + try { + // Check if link is expired and refresh if needed + if (chunk.isChunkLinkInvalid()) { + LOGGER.debug("Link invalid for chunk {}, refetching", chunk.getChunkIndex()); + ExternalLink freshLink = + linkFetcher.refetchLink(chunk.getChunkIndex(), chunk.getRowOffset()); + chunk.setChunkLink(freshLink); + } + + // Perform the download + chunk.downloadData(httpClient, compressionCodec, cloudFetchSpeedThreshold); + downloadSuccessful = true; + + LOGGER.debug("Successfully downloaded chunk {}", chunk.getChunkIndex()); + + } catch (IOException | DatabricksSQLException e) { + retries++; + if (retries >= MAX_RETRIES) { + LOGGER.error( + "Failed to download chunk {} after {} attempts: {}", + chunk.getChunkIndex(), + MAX_RETRIES, + e.getMessage()); + // Status will be set to DOWNLOAD_FAILED in the finally block + throw new DatabricksSQLException( + String.format( + "Failed to download chunk %d after %d attempts", + chunk.getChunkIndex(), MAX_RETRIES), + e, + DatabricksDriverErrorCode.CHUNK_DOWNLOAD_ERROR); + } else { + LOGGER.warn( + "Retry {} for chunk {}: {}", retries, chunk.getChunkIndex(), e.getMessage()); + chunk.setStatus(ChunkStatus.DOWNLOAD_RETRY); + try { + Thread.sleep(RETRY_DELAY_MS); + } catch (InterruptedException ie) { + Thread.currentThread().interrupt(); + throw new DatabricksSQLException( + "Chunk download interrupted", + ie, + DatabricksDriverErrorCode.THREAD_INTERRUPTED_ERROR); + } + } + } + } + } finally { + if (downloadSuccessful) { + chunk.getChunkReadyFuture().complete(null); + } else { + chunk.setStatus(ChunkStatus.DOWNLOAD_FAILED); + chunk + .getChunkReadyFuture() + .completeExceptionally( + new DatabricksSQLException( + "Download failed for chunk " + chunk.getChunkIndex(), + DatabricksDriverErrorCode.CHUNK_DOWNLOAD_ERROR)); + } + } + + return null; + } +} diff --git a/src/main/java/com/databricks/jdbc/api/impl/arrow/StreamingChunkProvider.java b/src/main/java/com/databricks/jdbc/api/impl/arrow/StreamingChunkProvider.java new file mode 100644 index 000000000..c50081eb5 --- /dev/null +++ b/src/main/java/com/databricks/jdbc/api/impl/arrow/StreamingChunkProvider.java @@ -0,0 +1,604 @@ +package com.databricks.jdbc.api.impl.arrow; + +import com.databricks.jdbc.common.CompressionCodec; +import com.databricks.jdbc.dbclient.IDatabricksHttpClient; +import com.databricks.jdbc.dbclient.impl.common.StatementId; +import com.databricks.jdbc.exception.DatabricksParsingException; +import 
com.databricks.jdbc.exception.DatabricksSQLException; +import com.databricks.jdbc.log.JdbcLogger; +import com.databricks.jdbc.log.JdbcLoggerFactory; +import com.databricks.jdbc.model.core.ChunkLinkFetchResult; +import com.databricks.jdbc.model.core.ExternalLink; +import com.databricks.jdbc.model.telemetry.enums.DatabricksDriverErrorCode; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ConcurrentMap; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.ThreadFactory; +import java.util.concurrent.TimeoutException; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.concurrent.atomic.AtomicLong; +import java.util.concurrent.locks.Condition; +import java.util.concurrent.locks.ReentrantLock; +import javax.annotation.Nonnull; + +/** + * A streaming chunk provider that fetches chunk links proactively and downloads chunks in parallel. + * + *
<p>
Key features: + * + * <ul> + *   <li>No dependency on total chunk count - streams until end of data + *   <li>Proactive link prefetching with configurable window + *   <li>Memory-bounded parallel downloads + *   <li>Automatic link refresh on expiration + * </ul> + * + * <p>This provider uses two key windows: + * + * <ul> + *   <li>Link prefetch window: How many links to fetch ahead of consumption + *   <li>Download window: How many chunks to keep in memory (downloading or ready) + * </ul>
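+ * + * <p>For example, with {@code maxChunksInMemory = 8} and {@code linkPrefetchWindow = 128} (illustrative values; both come from the connection context), at most eight chunks are buffered or downloading at any time while links are kept fresh up to 128 chunks ahead of the consumer. + * + * <p>Consumption follows the ChunkProvider contract (a sketch; the real consumer is ArrowStreamResult): <pre>{@code while (provider.hasNextChunk() && provider.next()) { AbstractArrowResultChunk chunk = provider.getChunk(); /* read Arrow data */ } provider.close(); }</pre>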
+ */ +public class StreamingChunkProvider implements ChunkProvider { + + private static final JdbcLogger LOGGER = + JdbcLoggerFactory.getLogger(StreamingChunkProvider.class); + private static final String DOWNLOAD_THREAD_PREFIX = "databricks-jdbc-streaming-downloader-"; + private static final String PREFETCH_THREAD_NAME = "databricks-jdbc-link-prefetcher"; + + // Configuration + private final int linkPrefetchWindow; + private final int maxChunksInMemory; + private final int chunkReadyTimeoutSeconds; + + // Dependencies + private final ChunkLinkFetcher linkFetcher; + private final IDatabricksHttpClient httpClient; + private final CompressionCodec compressionCodec; + private final StatementId statementId; + private final double cloudFetchSpeedThreshold; + + // Chunk storage + private final ConcurrentMap chunks = new ConcurrentHashMap<>(); + + // Position tracking + // Using AtomicLong for single-writer variables to make thread-safety explicit: + // - currentChunkIndex: written only by consumer thread + // - highestKnownChunkIndex: written only by prefetch thread (after construction) + // - nextDownloadIndex: written only under downloadLock, but AtomicLong for consistency + private final AtomicLong currentChunkIndex = new AtomicLong(-1); + private final AtomicLong highestKnownChunkIndex = new AtomicLong(-1); + private volatile long nextLinkFetchIndex = 0; + private volatile long nextRowOffsetToFetch = 0; + private final AtomicLong nextDownloadIndex = new AtomicLong(0); + + // State flags + private volatile boolean endOfStreamReached = false; + private volatile boolean closed = false; + private volatile DatabricksSQLException prefetchError = null; + + // Row tracking + private final AtomicLong totalRowCount = new AtomicLong(0); + + // Synchronization for prefetch thread + private final ReentrantLock prefetchLock = new ReentrantLock(); + private final Condition consumerAdvanced = prefetchLock.newCondition(); + private final Condition chunkCreated = prefetchLock.newCondition(); + + // Synchronization for download coordination. + // This lock is needed because triggerDownloads() is called from both the prefetch thread + // (via fetchNextLinkBatch) and the consumer thread (via releaseChunk), and the download + // logic reads multiple shared variables (chunksInMemory, nextDownloadIndex, + // highestKnownChunkIndex) + // that must be consistent within the loop. + private final ReentrantLock downloadLock = new ReentrantLock(); + + // Executors + private final ExecutorService downloadExecutor; + private final Thread linkPrefetchThread; + + // Track chunks currently in memory (for sliding window) + private final AtomicInteger chunksInMemory = new AtomicInteger(0); + + /** + * Creates a new StreamingChunkProvider. 
+ * + * @param linkFetcher Fetcher for chunk links + * @param httpClient HTTP client for downloads + * @param compressionCodec Codec for decompressing chunk data + * @param statementId Statement ID for logging and chunk creation + * @param maxChunksInMemory Maximum chunks to keep in memory (download window) + * @param linkPrefetchWindow How many links to fetch ahead + * @param chunkReadyTimeoutSeconds Timeout waiting for chunk to be ready + * @param cloudFetchSpeedThreshold Speed threshold for logging warnings + * @param initialLinks Initial links provided with result data (avoids extra fetch), may be null + */ + public StreamingChunkProvider( + ChunkLinkFetcher linkFetcher, + IDatabricksHttpClient httpClient, + CompressionCodec compressionCodec, + StatementId statementId, + int maxChunksInMemory, + int linkPrefetchWindow, + int chunkReadyTimeoutSeconds, + double cloudFetchSpeedThreshold, + ChunkLinkFetchResult initialLinks) + throws DatabricksParsingException { + + this.linkFetcher = linkFetcher; + this.httpClient = httpClient; + this.compressionCodec = compressionCodec; + this.statementId = statementId; + this.maxChunksInMemory = maxChunksInMemory; + this.linkPrefetchWindow = linkPrefetchWindow; + this.chunkReadyTimeoutSeconds = chunkReadyTimeoutSeconds; + this.cloudFetchSpeedThreshold = cloudFetchSpeedThreshold; + + LOGGER.info( + "Creating StreamingChunkProvider for statement {}: maxChunksInMemory={}, linkPrefetchWindow={}", + statementId, + maxChunksInMemory, + linkPrefetchWindow); + + // Process initial links if provided + processInitialLinks(initialLinks); + + // Create download executor + this.downloadExecutor = createDownloadExecutor(maxChunksInMemory); + + // Start link prefetch thread + this.linkPrefetchThread = new Thread(this::linkPrefetchLoop, PREFETCH_THREAD_NAME); + this.linkPrefetchThread.setDaemon(true); + this.linkPrefetchThread.start(); + + // Trigger initial downloads and prefetch + triggerDownloads(); + notifyConsumerAdvanced(); + } + + // ==================== ChunkProvider Interface ==================== + + @Override + public boolean hasNextChunk() { + if (closed) { + return false; + } + + // If we haven't reached end of stream, there might be more + if (!endOfStreamReached) { + return true; + } + + // We've reached end of stream - check if there are unconsumed chunks + return currentChunkIndex.get() < highestKnownChunkIndex.get(); + } + + @Override + public boolean next() throws DatabricksSQLException { + if (closed) { + return false; + } + + // Release previous chunk if any + long prevIndex = currentChunkIndex.get(); + if (prevIndex >= 0) { + releaseChunk(prevIndex); + } + + if (!hasNextChunk()) { + return false; + } + + currentChunkIndex.incrementAndGet(); + + // Notify prefetch thread that consumer advanced + notifyConsumerAdvanced(); + + return true; + } + + @Override + public AbstractArrowResultChunk getChunk() throws DatabricksSQLException { + long chunkIdx = currentChunkIndex.get(); + if (chunkIdx < 0) { + return null; + } + + ArrowResultChunk chunk = chunks.get(chunkIdx); + + if (chunk == null) { + // Chunk not yet created - wait for it + LOGGER.debug("Chunk {} not yet available, waiting for prefetch", chunkIdx); + waitForChunkCreation(chunkIdx); + chunk = chunks.get(chunkIdx); + } + + if (chunk == null) { + throw new DatabricksSQLException( + "Chunk " + chunkIdx + " not found after waiting", + DatabricksDriverErrorCode.CHUNK_READY_ERROR); + } + + // Wait for chunk to be ready (downloaded and processed) + try { + chunk.waitForChunkReady(); + } catch 
(InterruptedException e) { + Thread.currentThread().interrupt(); + throw new DatabricksSQLException( + "Interrupted waiting for chunk " + chunkIdx, + e, + DatabricksDriverErrorCode.THREAD_INTERRUPTED_ERROR); + } catch (ExecutionException e) { + throw new DatabricksSQLException( + "Failed to prepare chunk " + chunkIdx, + e.getCause(), + DatabricksDriverErrorCode.CHUNK_READY_ERROR); + } catch (TimeoutException e) { + throw new DatabricksSQLException( + "Timeout waiting for chunk " + chunkIdx + " (timeout: " + chunkReadyTimeoutSeconds + "s)", + DatabricksDriverErrorCode.CHUNK_READY_ERROR); + } + + return chunk; + } + + @Override + public void close() { + if (closed) { + return; + } + + LOGGER.info("Closing StreamingChunkProvider for statement {}", statementId); + closed = true; + + // Wake up any waiting threads so they can exit + notifyConsumerAdvanced(); + notifyChunkCreated(); + + // Interrupt prefetch thread + if (linkPrefetchThread != null) { + linkPrefetchThread.interrupt(); + } + + // Shutdown download executor + if (downloadExecutor != null) { + downloadExecutor.shutdownNow(); + } + + // Release all chunks + for (ArrowResultChunk chunk : chunks.values()) { + try { + chunk.releaseChunk(); + } catch (Exception e) { + LOGGER.warn("Error releasing chunk: {}", e.getMessage()); + } + } + chunks.clear(); + + // Close link fetcher + if (linkFetcher != null) { + linkFetcher.close(); + } + } + + @Override + public long getRowCount() { + return totalRowCount.get(); + } + + @Override + public long getChunkCount() { + // In streaming mode, we don't know total chunks until end of stream + if (endOfStreamReached) { + return highestKnownChunkIndex.get() + 1; + } + return -1; // Unknown + } + + @Override + public boolean isClosed() { + return closed; + } + + // ==================== Link Prefetch Logic ==================== + + private void linkPrefetchLoop() { + LOGGER.debug("Link prefetch thread started for statement {}", statementId); + + while (!closed && !Thread.currentThread().isInterrupted()) { + try { + prefetchLock.lock(); + try { + long targetIndex = currentChunkIndex.get() + linkPrefetchWindow; + + // Wait if we're caught up + while (!endOfStreamReached && nextLinkFetchIndex > targetIndex) { + if (closed) break; + LOGGER.debug( + "Prefetch caught up, waiting for consumer. 
next={}, target={}", + nextLinkFetchIndex, + targetIndex); + consumerAdvanced.await(); + targetIndex = currentChunkIndex.get() + linkPrefetchWindow; + } + } finally { + prefetchLock.unlock(); + } + + if (closed || endOfStreamReached) { + break; + } + + // Fetch next batch of links + fetchNextLinkBatch(); + + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + LOGGER.debug("Link prefetch thread interrupted"); + break; + } catch (DatabricksSQLException e) { + LOGGER.error("Error fetching links: {}", e.getMessage()); + prefetchError = e; + notifyChunkCreated(); // Wake up any waiting consumer to check the error + break; + } + } + + LOGGER.debug("Link prefetch thread exiting for statement {}", statementId); + } + + private void fetchNextLinkBatch() throws DatabricksSQLException { + if (endOfStreamReached || closed) { + return; + } + + LOGGER.debug( + "Fetching links starting from index {}, row offset {} for statement {}", + nextLinkFetchIndex, + nextRowOffsetToFetch, + statementId); + + ChunkLinkFetchResult result = linkFetcher.fetchLinks(nextLinkFetchIndex, nextRowOffsetToFetch); + + if (result.isEndOfStream()) { + LOGGER.info("End of stream reached for statement {}", statementId); + endOfStreamReached = true; + return; + } + + // Process received links - create chunks + for (ExternalLink link : result.getChunkLinks()) { + createChunkFromLink(link); + } + + // Update next fetch positions + if (result.hasMore()) { + nextLinkFetchIndex = result.getNextFetchIndex(); + nextRowOffsetToFetch = result.getNextRowOffset(); + } else { + endOfStreamReached = true; + LOGGER.info("End of stream reached for statement {} (hasMore=false)", statementId); + } + + // Trigger downloads for new chunks + triggerDownloads(); + } + + /** + * Processes initial links provided with the result data. This avoids an extra fetch call for + * links the server already provided. + * + * @param initialLinks The initial links from ResultData, may be null + */ + private void processInitialLinks(ChunkLinkFetchResult initialLinks) + throws DatabricksParsingException { + if (initialLinks == null) { + LOGGER.debug("No initial links provided for statement {}", statementId); + return; + } + + LOGGER.info( + "Processing {} initial links for statement {}", + initialLinks.getChunkLinks().size(), + statementId); + + for (ExternalLink link : initialLinks.getChunkLinks()) { + createChunkFromLink(link); + } + + // Set next fetch positions using unified API + if (initialLinks.hasMore()) { + nextLinkFetchIndex = initialLinks.getNextFetchIndex(); + nextRowOffsetToFetch = initialLinks.getNextRowOffset(); + LOGGER.debug( + "Next fetch position set to chunk index {}, row offset {} from initial links", + nextLinkFetchIndex, + nextRowOffsetToFetch); + } else { + endOfStreamReached = true; + LOGGER.info("End of stream reached from initial links for statement {}", statementId); + } + } + + /** + * Creates a chunk from an external link and registers it for download. 
+ * + * @param link The external link containing chunkIndex, rowCount, rowOffset, and download URL + */ + private void createChunkFromLink(ExternalLink link) throws DatabricksParsingException { + long chunkIndex = link.getChunkIndex(); + if (chunks.containsKey(chunkIndex)) { + LOGGER.debug("Chunk {} already exists, skipping creation", chunkIndex); + return; + } + + long rowCount = link.getRowCount(); + long rowOffset = link.getRowOffset(); + + ArrowResultChunk chunk = + ArrowResultChunk.builder() + .withStatementId(statementId) + .withChunkMetadata(chunkIndex, rowCount, rowOffset) + .withChunkReadyTimeoutSeconds(chunkReadyTimeoutSeconds) + .build(); + + chunk.setChunkLink(link); + chunks.put(chunkIndex, chunk); + highestKnownChunkIndex.updateAndGet(current -> Math.max(current, chunkIndex)); + totalRowCount.addAndGet(rowCount); + + // Notify any waiting consumers that a chunk is available + notifyChunkCreated(); + + LOGGER.debug( + "Created chunk {} with {} rows for statement {}", chunkIndex, rowCount, statementId); + } + + // ==================== Download Coordination ==================== + + private void triggerDownloads() { + downloadLock.lock(); + try { + long downloadIdx = nextDownloadIndex.get(); + while (!closed + && chunksInMemory.get() < maxChunksInMemory + && downloadIdx <= highestKnownChunkIndex.get()) { + ArrowResultChunk chunk = chunks.get(downloadIdx); + + if (chunk == null) { + // Chunk not yet created, wait for prefetch + break; + } + + // Only submit if not already downloading/downloaded + ChunkStatus status = chunk.getStatus(); + if (status == ChunkStatus.PENDING || status == ChunkStatus.URL_FETCHED) { + submitDownloadTask(chunk); + chunksInMemory.incrementAndGet(); + } + + downloadIdx = nextDownloadIndex.incrementAndGet(); + } + } finally { + downloadLock.unlock(); + } + } + + private void submitDownloadTask(ArrowResultChunk chunk) { + LOGGER.debug("Submitting download task for chunk {}", chunk.getChunkIndex()); + + StreamingChunkDownloadTask task = + new StreamingChunkDownloadTask( + chunk, httpClient, compressionCodec, linkFetcher, cloudFetchSpeedThreshold); + + downloadExecutor.submit(task); + } + + // ==================== Resource Management ==================== + + private void releaseChunk(long chunkIndex) { + ArrowResultChunk chunk = chunks.get(chunkIndex); + if (chunk != null && chunk.releaseChunk()) { + chunks.remove(chunkIndex); + chunksInMemory.decrementAndGet(); + + LOGGER.debug("Released chunk {}, chunksInMemory={}", chunkIndex, chunksInMemory.get()); + + // Trigger more downloads to fill the freed slot + triggerDownloads(); + } + } + + /** + * Waits for a chunk to be created by the prefetch thread. + * + *
<p>
This method waits indefinitely for the chunk to be created, relying on the following exit + * conditions: + * + * <ul> + *   <li>Chunk is created (success) + *   <li>Provider is closed + *   <li>Prefetch thread encountered an error + *   <li>End of stream reached and chunk doesn't exist + *   <li>Thread is interrupted + * </ul> + * + * <p>
The overall timeout for chunk retrieval is enforced by {@link + * ArrowResultChunk#waitForChunkReady()} which has a configurable timeout. + */ + private void waitForChunkCreation(long chunkIndex) throws DatabricksSQLException { + prefetchLock.lock(); + try { + while (!closed && !chunks.containsKey(chunkIndex)) { + // Check if prefetch thread encountered an error + if (prefetchError != null) { + throw new DatabricksSQLException( + "Link prefetch failed: " + prefetchError.getMessage(), + prefetchError, + DatabricksDriverErrorCode.CHUNK_READY_ERROR); + } + + long highestKnown = highestKnownChunkIndex.get(); + if (endOfStreamReached && chunkIndex > highestKnown) { + throw new DatabricksSQLException( + "Chunk " + chunkIndex + " does not exist (highest known: " + highestKnown + ")", + DatabricksDriverErrorCode.CHUNK_READY_ERROR); + } + + try { + chunkCreated.await(); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + throw new DatabricksSQLException( + "Interrupted waiting for chunk creation", + e, + DatabricksDriverErrorCode.THREAD_INTERRUPTED_ERROR); + } + } + } finally { + prefetchLock.unlock(); + } + } + + // ==================== Synchronization Helpers ==================== + + private void notifyConsumerAdvanced() { + prefetchLock.lock(); + try { + consumerAdvanced.signalAll(); + } finally { + prefetchLock.unlock(); + } + } + + private void notifyChunkCreated() { + prefetchLock.lock(); + try { + chunkCreated.signalAll(); + } finally { + prefetchLock.unlock(); + } + } + + // ==================== Executor Creation ==================== + + private ExecutorService createDownloadExecutor(int poolSize) { + ThreadFactory threadFactory = + new ThreadFactory() { + private final AtomicInteger threadCount = new AtomicInteger(1); + + @Override + public Thread newThread(@Nonnull Runnable r) { + Thread thread = new Thread(r); + thread.setName(DOWNLOAD_THREAD_PREFIX + threadCount.getAndIncrement()); + thread.setDaemon(true); + return thread; + } + }; + + return Executors.newFixedThreadPool(poolSize, threadFactory); + } +} diff --git a/src/main/java/com/databricks/jdbc/api/impl/arrow/ThriftChunkLinkFetcher.java b/src/main/java/com/databricks/jdbc/api/impl/arrow/ThriftChunkLinkFetcher.java new file mode 100644 index 000000000..d38653271 --- /dev/null +++ b/src/main/java/com/databricks/jdbc/api/impl/arrow/ThriftChunkLinkFetcher.java @@ -0,0 +1,116 @@ +package com.databricks.jdbc.api.impl.arrow; + +import com.databricks.jdbc.api.internal.IDatabricksSession; +import com.databricks.jdbc.dbclient.impl.common.StatementId; +import com.databricks.jdbc.exception.DatabricksSQLException; +import com.databricks.jdbc.log.JdbcLogger; +import com.databricks.jdbc.log.JdbcLoggerFactory; +import com.databricks.jdbc.model.core.ChunkLinkFetchResult; +import com.databricks.jdbc.model.core.ExternalLink; +import com.databricks.jdbc.model.telemetry.enums.DatabricksDriverErrorCode; + +/** + * ChunkLinkFetcher implementation for the Thrift client. + * + *
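<p>For example (illustrative values), {@code fetcher.fetchLinks(4, 40_000)} seeks to row offset 40,000 via FETCH_ABSOLUTE; the chunk index is carried along only to label the resulting links. + *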
<p>
Thrift provides chunk links via the getResultChunks API, which returns links with + * nextChunkIndex to indicate continuation. When nextChunkIndex is null, it indicates no more + * chunks. + */ +public class ThriftChunkLinkFetcher implements ChunkLinkFetcher { + + private static final JdbcLogger LOGGER = + JdbcLoggerFactory.getLogger(ThriftChunkLinkFetcher.class); + + private final IDatabricksSession session; + private final StatementId statementId; + + public ThriftChunkLinkFetcher(IDatabricksSession session, StatementId statementId) { + this.session = session; + this.statementId = statementId; + LOGGER.debug("Created ThriftChunkLinkFetcher for statement {}", statementId); + } + + @Override + public ChunkLinkFetchResult fetchLinks(long startChunkIndex, long startRowOffset) + throws DatabricksSQLException { + // Thrift uses startRowOffset with FETCH_ABSOLUTE; startChunkIndex is used for metadata + LOGGER.debug( + "Fetching links starting from chunk index {}, row offset {} for statement {}", + startChunkIndex, + startRowOffset, + statementId); + + return session + .getDatabricksClient() + .getResultChunks(statementId, startChunkIndex, startRowOffset); + } + + @Override + public ExternalLink refetchLink(long chunkIndex, long rowOffset) throws DatabricksSQLException { + // Thrift uses rowOffset with FETCH_ABSOLUTE + LOGGER.info( + "Refetching expired link for chunk {}, row offset {} of statement {}", + chunkIndex, + rowOffset, + statementId); + + // For Thrift, we may need to retry if hasMore=true but no links returned yet + int maxRetries = 100; // Reasonable limit to prevent infinite loops + int retryCount = 0; + + while (retryCount < maxRetries) { + ChunkLinkFetchResult result = + session.getDatabricksClient().getResultChunks(statementId, chunkIndex, rowOffset); + + if (!result.getChunkLinks().isEmpty()) { + // Find the link for the requested chunk index + for (ExternalLink link : result.getChunkLinks()) { + if (link.getChunkIndex() == chunkIndex) { + LOGGER.debug( + "Successfully refetched link for chunk {} of statement {}", + chunkIndex, + statementId); + return link; + } + } + + // Exact match not found - this indicates a server bug + throw new DatabricksSQLException( + String.format( + "Failed to refetch link for chunk %d: server returned links but none matched requested index", + chunkIndex), + DatabricksDriverErrorCode.CHUNK_READY_ERROR); + } + + // No links returned - check if we should retry + if (!result.hasMore()) { + // No more data and no links - this is unexpected for a refetch + throw new DatabricksSQLException( + String.format( + "Failed to refetch link for chunk %d: no links returned and hasMore=false", + chunkIndex), + DatabricksDriverErrorCode.CHUNK_READY_ERROR); + } + + // hasMore=true but no links yet - retry + retryCount++; + LOGGER.debug( + "No links returned for chunk {} but hasMore=true, retrying ({}/{})", + chunkIndex, + retryCount, + maxRetries); + } + + throw new DatabricksSQLException( + String.format( + "Failed to refetch link for chunk %d: max retries (%d) exceeded", + chunkIndex, maxRetries), + DatabricksDriverErrorCode.CHUNK_READY_ERROR); + } + + @Override + public void close() { + LOGGER.debug("Closing ThriftChunkLinkFetcher for statement {}", statementId); + // No resources to clean up for Thrift fetcher + } +} diff --git a/src/main/java/com/databricks/jdbc/api/internal/IDatabricksConnectionContext.java b/src/main/java/com/databricks/jdbc/api/internal/IDatabricksConnectionContext.java index 202358564..f3a8c0911 100644 --- 
a/src/main/java/com/databricks/jdbc/api/internal/IDatabricksConnectionContext.java +++ b/src/main/java/com/databricks/jdbc/api/internal/IDatabricksConnectionContext.java @@ -422,4 +422,19 @@ public interface IDatabricksConnectionContext { /** Returns whether token federation is enabled for authentication. */ boolean isTokenFederationEnabled(); + + /** Returns whether streaming chunk provider is enabled for result fetching. */ + boolean isStreamingChunkProviderEnabled(); + + /** + * Returns the number of chunk links to prefetch ahead of consumption. + * + *
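<p>For example, a consumer that does heavy per-row work might set {@code ;LinkPrefetchWindow=16} in the JDBC URL so that fewer pre-signed links are outstanding and at risk of expiring before use (illustrative value; the default is 128). + *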
<p>
This controls how far ahead the streaming chunk provider fetches links before they are + * needed. Higher values reduce latency by ensuring links are ready when needed. Lower values + * reduce the risk of link expiry for workloads that process data slowly (e.g., heavy computation + * per row), since prefetched links may expire before being used. + * + * @return the link prefetch window size (default: 128) + */ + int getLinkPrefetchWindow(); } diff --git a/src/main/java/com/databricks/jdbc/common/DatabricksJdbcUrlParams.java b/src/main/java/com/databricks/jdbc/common/DatabricksJdbcUrlParams.java index df5057885..ddbec923f 100644 --- a/src/main/java/com/databricks/jdbc/common/DatabricksJdbcUrlParams.java +++ b/src/main/java/com/databricks/jdbc/common/DatabricksJdbcUrlParams.java @@ -188,6 +188,16 @@ public enum DatabricksJdbcUrlParams { "1"), ENABLE_TOKEN_FEDERATION( "EnableTokenFederation", "Enable token federation for authentication", "1"), + ENABLE_STREAMING_CHUNK_PROVIDER( + "EnableStreamingChunkProvider", + "Enable streaming chunk provider for result fetching (experimental)", + "0"), + LINK_PREFETCH_WINDOW( + "LinkPrefetchWindow", + "Number of chunk links to prefetch ahead of consumption. " + + "Higher values reduce latency by having links ready sooner. " + + "Lower values reduce risk of link expiry for slow processing workloads", + "128"), API_RETRIABLE_HTTP_CODES( "ApiRetriableHttpCodes", "Comma-separated list of HTTP status codes that should be retried irrespective of Retry-After header.", diff --git a/src/main/java/com/databricks/jdbc/common/util/DatabricksThriftUtil.java b/src/main/java/com/databricks/jdbc/common/util/DatabricksThriftUtil.java index b5df82129..308dfafe6 100644 --- a/src/main/java/com/databricks/jdbc/common/util/DatabricksThriftUtil.java +++ b/src/main/java/com/databricks/jdbc/common/util/DatabricksThriftUtil.java @@ -21,6 +21,7 @@ import com.databricks.jdbc.model.telemetry.enums.DatabricksDriverErrorCode; import com.databricks.sdk.service.sql.StatementState; import java.nio.ByteBuffer; +import java.time.Instant; import java.util.*; public class DatabricksThriftUtil { @@ -73,7 +74,10 @@ public static ExternalLink createExternalLink(TSparkArrowResultLink chunkInfo, l return new ExternalLink() .setExternalLink(chunkInfo.getFileLink()) .setChunkIndex(chunkIndex) - .setExpiration(Long.toString(chunkInfo.getExpiryTime())); + .setExpiration(Instant.ofEpochMilli(chunkInfo.getExpiryTime()).toString()) + .setRowOffset(chunkInfo.getStartRowOffset()) + .setByteCount(chunkInfo.getBytesNum()) + .setRowCount(chunkInfo.getRowCount()); } public static void verifySuccessStatus(TStatus status, String errorContext) diff --git a/src/main/java/com/databricks/jdbc/dbclient/IDatabricksClient.java b/src/main/java/com/databricks/jdbc/dbclient/IDatabricksClient.java index e1c6398f1..612b10475 100644 --- a/src/main/java/com/databricks/jdbc/dbclient/IDatabricksClient.java +++ b/src/main/java/com/databricks/jdbc/dbclient/IDatabricksClient.java @@ -9,12 +9,11 @@ import com.databricks.jdbc.dbclient.impl.common.StatementId; import com.databricks.jdbc.exception.DatabricksSQLException; import com.databricks.jdbc.model.client.thrift.generated.TFetchResultsResp; -import com.databricks.jdbc.model.core.ExternalLink; +import com.databricks.jdbc.model.core.ChunkLinkFetchResult; import com.databricks.jdbc.model.core.ResultData; import com.databricks.jdbc.telemetry.latency.DatabricksMetricsTimed; import com.databricks.sdk.core.DatabricksConfig; import java.sql.SQLException; -import java.util.Collection; 
import java.util.Map; /** Interface for Databricks client which abstracts the integration with Databricks server. */ @@ -115,12 +114,25 @@ DatabricksResultSet getStatementResult( throws SQLException; /** - * Fetches the chunk details for given chunk index and statement-Id. + * Fetches the chunk links for given chunk index and statement-Id. + * + *
<p>
For SEA clients, the chunkIndex is used to identify which chunk to fetch. For Thrift + * clients, the rowOffset is used with FETCH_ABSOLUTE orientation to seek to the correct position. + * + * <p>The returned {@link ChunkLinkFetchResult} contains the chunk links and continuation + * information: + * + * <ul> + *   <li>SEA: hasMore derived from last link's nextChunkIndex + *   <li>Thrift: hasMore from server's hasMoreRows flag, nextRowOffset for continuation + * </ul> +
* * @param statementId statement-Id for which chunk should be fetched - * @param chunkIndex chunkIndex for which chunk should be fetched + * @param chunkIndex chunkIndex for which chunk should be fetched (used by SEA) + * @param rowOffset row offset for fetching results (used by Thrift with FETCH_ABSOLUTE) + * @return ChunkLinkFetchResult containing links and continuation information */ - Collection getResultChunks(StatementId statementId, long chunkIndex) + ChunkLinkFetchResult getResultChunks(StatementId statementId, long chunkIndex, long rowOffset) throws DatabricksSQLException; /** diff --git a/src/main/java/com/databricks/jdbc/dbclient/impl/sqlexec/DatabricksSdkClient.java b/src/main/java/com/databricks/jdbc/dbclient/impl/sqlexec/DatabricksSdkClient.java index b8dfed049..c0bfcaabc 100644 --- a/src/main/java/com/databricks/jdbc/dbclient/impl/sqlexec/DatabricksSdkClient.java +++ b/src/main/java/com/databricks/jdbc/dbclient/impl/sqlexec/DatabricksSdkClient.java @@ -29,6 +29,7 @@ import com.databricks.jdbc.model.client.sqlexec.ExecuteStatementResponse; import com.databricks.jdbc.model.client.sqlexec.GetStatementResponse; import com.databricks.jdbc.model.client.thrift.generated.TFetchResultsResp; +import com.databricks.jdbc.model.core.ChunkLinkFetchResult; import com.databricks.jdbc.model.core.Disposition; import com.databricks.jdbc.model.core.ExternalLink; import com.databricks.jdbc.model.core.ResultData; @@ -409,14 +410,12 @@ public void cancelStatement(StatementId typedStatementId) throws DatabricksSQLEx } @Override - public Collection getResultChunks(StatementId typedStatementId, long chunkIndex) - throws DatabricksSQLException { - DatabricksThreadContextHolder.setStatementId(typedStatementId); + public ChunkLinkFetchResult getResultChunks( + StatementId typedStatementId, long chunkIndex, long rowOffset) throws DatabricksSQLException { + // SEA uses chunkIndex; rowOffset is ignored String statementId = typedStatementId.toSQLExecStatementId(); LOGGER.debug( - "public Optional getResultChunk(String statementId = {}, long chunkIndex = {})", - statementId, - chunkIndex); + "getResultChunks(statementId={}, chunkIndex={}) using SEA client", statementId, chunkIndex); GetStatementResultChunkNRequest request = new GetStatementResultChunkNRequest().setStatementId(statementId).setChunkIndex(chunkIndex); String path = String.format(RESULT_CHUNK_PATH, statementId, chunkIndex); @@ -424,7 +423,7 @@ public Collection getResultChunks(StatementId typedStatementId, lo Request req = new Request(Request.GET, path, apiClient.serialize(request)); req.withHeaders(getHeaders("getStatementResultN")); ResultData resultData = apiClient.execute(req, ResultData.class); - return resultData.getExternalLinks(); + return buildChunkLinkFetchResult(resultData.getExternalLinks()); } catch (IOException e) { String errorMessage = "Error while processing the get result chunk request"; LOGGER.error(errorMessage, e); @@ -432,6 +431,36 @@ public Collection getResultChunks(StatementId typedStatementId, lo } } + /** + * Builds a ChunkLinkFetchResult from SEA external links. + * + * @param links The external links from the SEA response + * @return ChunkLinkFetchResult with links and continuation info + */ + private ChunkLinkFetchResult buildChunkLinkFetchResult(Collection links) { + if (links == null || links.isEmpty()) { + return ChunkLinkFetchResult.endOfStream(); + } + + List linkList = + links instanceof List ? 
(List) links : new ArrayList<>(links); + + // Derive continuation info from last link + ExternalLink lastLink = linkList.get(linkList.size() - 1); + boolean hasMore = lastLink.getNextChunkIndex() != null; + long nextFetchIndex = hasMore ? lastLink.getNextChunkIndex() : -1; + long nextRowOffset = lastLink.getRowOffset() + lastLink.getRowCount(); + + LOGGER.debug( + "Built ChunkLinkFetchResult with {} links, hasMore={}, nextFetchIndex={}, nextRowOffset={}", + linkList.size(), + hasMore, + nextFetchIndex, + nextRowOffset); + + return ChunkLinkFetchResult.of(linkList, hasMore, nextFetchIndex, nextRowOffset); + } + @Override public ResultData getResultChunksData(StatementId typedStatementId, long chunkIndex) throws DatabricksSQLException { diff --git a/src/main/java/com/databricks/jdbc/dbclient/impl/thrift/DatabricksThriftAccessor.java b/src/main/java/com/databricks/jdbc/dbclient/impl/thrift/DatabricksThriftAccessor.java index 009ace07e..68e8888e8 100644 --- a/src/main/java/com/databricks/jdbc/dbclient/impl/thrift/DatabricksThriftAccessor.java +++ b/src/main/java/com/databricks/jdbc/dbclient/impl/thrift/DatabricksThriftAccessor.java @@ -521,6 +521,58 @@ TFetchResultsResp getResultSetResp( return response; } + /** + * Fetches results using FETCH_ABSOLUTE orientation starting from the given row offset. + * + *
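<p>A sketch of a hypothetical caller (names follow this PR; the context string is used only for logging): <pre>{@code TFetchResultsResp resp = accessor.fetchResultsWithAbsoluteOffset(operationHandle, nextRowOffset, "streaming-link-fetch"); boolean more = resp.hasMoreRows; }</pre> + *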

This method is used by the streaming chunk provider to seek to a specific row position and + * fetch a batch of results. + * + * @param operationHandle The operation handle for the statement + * @param startRowOffset The row offset to start fetching from (0-indexed) + * @param context Context string for logging + * @return The fetch results response + * @throws DatabricksHttpException if the fetch fails + */ + TFetchResultsResp fetchResultsWithAbsoluteOffset( + TOperationHandle operationHandle, long startRowOffset, String context) + throws DatabricksHttpException { + String statementId = StatementId.loggableStatementId(operationHandle); + LOGGER.debug( + "Fetching results with FETCH_ABSOLUTE at offset {} for statement {}", + startRowOffset, + statementId); + + TFetchResultsReq request = + new TFetchResultsReq() + .setOperationHandle(operationHandle) + .setStartRowOffset(startRowOffset) + .setFetchType((short) 0) // 0 represents Query output + .setMaxRows(maxRowsPerBlock) + .setMaxBytes(DEFAULT_BYTE_LIMIT); + + TFetchResultsResp response; + try { + response = getThriftClient().FetchResults(request); + } catch (TException e) { + String errorMessage = + String.format( + "Error while fetching results from Thrift server with FETCH_ABSOLUTE. " + + "startRowOffset=%d, maxRows=%d, Error {%s}", + startRowOffset, request.getMaxRows(), e.getMessage()); + LOGGER.error(e, errorMessage); + throw new DatabricksHttpException(errorMessage, e, DatabricksDriverErrorCode.INVALID_STATE); + } + + verifySuccessStatus( + response.getStatus(), + String.format( + "Error while fetching results with FETCH_ABSOLUTE. startRowOffset=%d, hasMoreRows=%s", + startRowOffset, response.hasMoreRows), + statementId); + + return response; + } + private TFetchResultsResp listFunctions(TGetFunctionsReq request) throws TException, DatabricksSQLException { if (enableDirectResults) request.setGetDirectResults(DEFAULT_DIRECT_RESULTS); diff --git a/src/main/java/com/databricks/jdbc/dbclient/impl/thrift/DatabricksThriftServiceClient.java b/src/main/java/com/databricks/jdbc/dbclient/impl/thrift/DatabricksThriftServiceClient.java index 4cea9523a..e7d2ee647 100644 --- a/src/main/java/com/databricks/jdbc/dbclient/impl/thrift/DatabricksThriftServiceClient.java +++ b/src/main/java/com/databricks/jdbc/dbclient/impl/thrift/DatabricksThriftServiceClient.java @@ -29,6 +29,7 @@ import com.databricks.jdbc.log.JdbcLogger; import com.databricks.jdbc.log.JdbcLoggerFactory; import com.databricks.jdbc.model.client.thrift.generated.*; +import com.databricks.jdbc.model.core.ChunkLinkFetchResult; import com.databricks.jdbc.model.core.ExternalLink; import com.databricks.jdbc.model.core.ResultData; import com.databricks.jdbc.model.telemetry.enums.DatabricksDriverErrorCode; @@ -37,7 +38,6 @@ import java.math.BigDecimal; import java.sql.SQLException; import java.util.*; -import java.util.concurrent.atomic.AtomicInteger; import java.util.stream.Collectors; public class DatabricksThriftServiceClient implements IDatabricksClient, IDatabricksMetadataClient { @@ -301,32 +301,67 @@ public DatabricksResultSet getStatementResult( } @Override - public Collection getResultChunks(StatementId statementId, long chunkIndex) - throws DatabricksSQLException { - String context = - String.format( - "public Optional getResultChunk(String statementId = {%s}, long chunkIndex = {%s}) using Thrift client", - statementId, chunkIndex); - LOGGER.debug(context); - DatabricksThreadContextHolder.setStatementId(statementId); - TFetchResultsResp fetchResultsResp; - List externalLinks = new 
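For orientation, here is a minimal sketch (illustrative only, not part of this diff) of how a caller might page through a result set with the new accessor. The names `accessor` and `operationHandle` stand in for instances a real caller would already hold:

    long offset = 0;
    boolean hasMore = true;
    while (hasMore) {
      // Seek to `offset` and fetch the next batch of result links.
      TFetchResultsResp resp =
          accessor.fetchResultsWithAbsoluteOffset(operationHandle, offset, "paging-sketch");
      List<TSparkArrowResultLink> links = resp.getResults().getResultLinks();
      if (links != null && !links.isEmpty()) {
        TSparkArrowResultLink last = links.get(links.size() - 1);
        offset = last.getStartRowOffset() + last.getRowCount(); // next seek position
      }
      hasMore = resp.hasMoreRows; // the server flag, not the link count, ends the loop
    }

Note that the sketch terminates on the server's hasMoreRows flag rather than on an empty link list, matching the "source of truth" rule the Thrift client applies below.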
diff --git a/src/main/java/com/databricks/jdbc/dbclient/impl/thrift/DatabricksThriftServiceClient.java b/src/main/java/com/databricks/jdbc/dbclient/impl/thrift/DatabricksThriftServiceClient.java
index 4cea9523a..e7d2ee647 100644
--- a/src/main/java/com/databricks/jdbc/dbclient/impl/thrift/DatabricksThriftServiceClient.java
+++ b/src/main/java/com/databricks/jdbc/dbclient/impl/thrift/DatabricksThriftServiceClient.java
@@ -29,6 +29,7 @@
 import com.databricks.jdbc.log.JdbcLogger;
 import com.databricks.jdbc.log.JdbcLoggerFactory;
 import com.databricks.jdbc.model.client.thrift.generated.*;
+import com.databricks.jdbc.model.core.ChunkLinkFetchResult;
 import com.databricks.jdbc.model.core.ExternalLink;
 import com.databricks.jdbc.model.core.ResultData;
 import com.databricks.jdbc.model.telemetry.enums.DatabricksDriverErrorCode;
@@ -37,7 +38,6 @@
 import java.math.BigDecimal;
 import java.sql.SQLException;
 import java.util.*;
-import java.util.concurrent.atomic.AtomicInteger;
 import java.util.stream.Collectors;
 
 public class DatabricksThriftServiceClient implements IDatabricksClient, IDatabricksMetadataClient {
@@ -301,32 +301,67 @@ public DatabricksResultSet getStatementResult(
   }
 
   @Override
-  public Collection<ExternalLink> getResultChunks(StatementId statementId, long chunkIndex)
-      throws DatabricksSQLException {
-    String context =
-        String.format(
-            "public Optional<ExternalLink> getResultChunk(String statementId = {%s}, long chunkIndex = {%s}) using Thrift client",
-            statementId, chunkIndex);
-    LOGGER.debug(context);
-    DatabricksThreadContextHolder.setStatementId(statementId);
-    TFetchResultsResp fetchResultsResp;
-    List<ExternalLink> externalLinks = new ArrayList<>();
-    AtomicInteger index = new AtomicInteger(0);
-    do {
-      fetchResultsResp = thriftAccessor.getResultSetResp(getOperationHandle(statementId), context);
-      fetchResultsResp
-          .getResults()
-          .getResultLinks()
-          .forEach(
-              resultLink ->
-                  externalLinks.add(createExternalLink(resultLink, index.getAndIncrement())));
-    } while (fetchResultsResp.hasMoreRows);
-    if (chunkIndex < 0 || externalLinks.size() <= chunkIndex) {
-      String error = String.format("Out of bounds error for chunkIndex. Context: %s", context);
-      LOGGER.error(error);
-      throw new DatabricksSQLException(error, DatabricksDriverErrorCode.INVALID_STATE);
+  public ChunkLinkFetchResult getResultChunks(
+      StatementId statementId, long chunkIndex, long rowOffset) throws DatabricksSQLException {
+    // Thrift uses rowOffset with FETCH_ABSOLUTE; chunkIndex is used for link metadata
+    LOGGER.debug(
+        "getResultChunks(statementId={}, chunkIndex={}, rowOffset={}) using Thrift client",
+        statementId,
+        chunkIndex,
+        rowOffset);
+
+    TFetchResultsResp fetchResultsResp =
+        thriftAccessor.fetchResultsWithAbsoluteOffset(
+            getOperationHandle(statementId), rowOffset, "getResultChunks");
+
+    boolean hasMoreRows = fetchResultsResp.hasMoreRows;
+    List<TSparkArrowResultLink> resultLinks = fetchResultsResp.getResults().getResultLinks();
+
+    if (resultLinks == null || resultLinks.isEmpty()) {
+      LOGGER.debug(
+          "No result links returned for statement {}, hasMoreRows={}", statementId, hasMoreRows);
+      // For Thrift, hasMoreRows is the source of truth. Even with no links,
+      // if hasMoreRows is true, we should indicate continuation with the same offset.
+      return ChunkLinkFetchResult.of(new ArrayList<>(), hasMoreRows, chunkIndex, rowOffset);
+    }
+
+    List<ExternalLink> chunkLinks = new ArrayList<>();
+    int lastIndex = resultLinks.size() - 1;
+    long nextRowOffset = rowOffset;
+    long nextFetchIndex = chunkIndex;
+
+    for (int i = 0; i < resultLinks.size(); i++) {
+      TSparkArrowResultLink thriftLink = resultLinks.get(i);
+      long linkChunkIndex = chunkIndex + i;
+
+      // createExternalLink sets chunkIndex, rowOffset, rowCount, byteCount, expiration,
+      // externalLink
+      ExternalLink externalLink = createExternalLink(thriftLink, linkChunkIndex);
+
+      // Set nextChunkIndex based on position and hasMoreRows
+      if (i == lastIndex) {
+        if (hasMoreRows) {
+          externalLink.setNextChunkIndex(linkChunkIndex + 1);
+          nextFetchIndex = linkChunkIndex + 1;
+        }
+        nextRowOffset = thriftLink.getStartRowOffset() + thriftLink.getRowCount();
+      } else {
+        externalLink.setNextChunkIndex(linkChunkIndex + 1);
+      }
+
+      chunkLinks.add(externalLink);
     }
-    return externalLinks;
+
+    LOGGER.debug(
+        "Built ChunkLinkFetchResult with {} links for statement {}, hasMore={}, nextFetchIndex={}, nextRowOffset={}",
+        chunkLinks.size(),
+        statementId,
+        hasMoreRows,
+        nextFetchIndex,
+        nextRowOffset);
+
+    return ChunkLinkFetchResult.of(
+        chunkLinks, hasMoreRows, hasMoreRows ? nextFetchIndex : -1, nextRowOffset);
   }
 
   @Override
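Since both the SEA and Thrift implementations now return a ChunkLinkFetchResult, a consumer can stay protocol-agnostic. A minimal consumption sketch, assuming hypothetical stand-ins `client`, `statementId`, and `download`:

    long nextIndex = 0;
    long nextOffset = 0;
    ChunkLinkFetchResult batch;
    do {
      // SEA reads chunkIndex and ignores rowOffset; Thrift does the reverse.
      // Threading both cursors through keeps the loop identical for either backend.
      batch = client.getResultChunks(statementId, nextIndex, nextOffset);
      for (ExternalLink link : batch.getChunkLinks()) {
        download(link); // hand each presigned link to the chunk downloader
      }
      nextIndex = batch.getNextFetchIndex();
      nextOffset = batch.getNextRowOffset();
    } while (batch.hasMore());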
diff --git a/src/main/java/com/databricks/jdbc/model/core/ChunkLinkFetchResult.java b/src/main/java/com/databricks/jdbc/model/core/ChunkLinkFetchResult.java
new file mode 100644
index 000000000..51c2ad071
--- /dev/null
+++ b/src/main/java/com/databricks/jdbc/model/core/ChunkLinkFetchResult.java
@@ -0,0 +1,101 @@
+package com.databricks.jdbc.model.core;
+
+import java.util.Collections;
+import java.util.List;
+
+/**
+ * Result of fetching chunk links from the server.
+ *
+ * <p>Contains the fetched chunk links and continuation information for both SEA and Thrift
+ * protocols:
+ *
+ * <ul>
+ *   <li>SEA: uses chunkIndex for continuation; hasMore is derived from nextChunkIndex on the last
+ *       link
+ *   <li>Thrift: uses rowOffset for continuation; hasMore comes from the server's hasMoreRows flag
+ * </ul>
+ *
+ * <p>Each {@link ExternalLink} contains chunkIndex, rowCount, rowOffset, and the download URL.
+ */
+public class ChunkLinkFetchResult {
+
+  private final List<ExternalLink> chunkLinks;
+  private final boolean hasMore;
+  private final long nextFetchIndex;
+  private final long nextRowOffset;
+
+  private ChunkLinkFetchResult(
+      List<ExternalLink> chunkLinks, boolean hasMore, long nextFetchIndex, long nextRowOffset) {
+    this.chunkLinks = chunkLinks;
+    this.hasMore = hasMore;
+    this.nextFetchIndex = nextFetchIndex;
+    this.nextRowOffset = nextRowOffset;
+  }
+
+  /**
+   * Creates a result with the given links and continuation info.
+   *
+   * @param links The fetched external links (each contains chunkIndex, rowCount, rowOffset, URL)
+   * @param hasMore Whether more chunks are available
+   * @param nextFetchIndex The next chunk index to fetch from, or -1 if no more
+   * @param nextRowOffset The next row offset for Thrift FETCH_ABSOLUTE
+   * @return A new ChunkLinkFetchResult
+   */
+  public static ChunkLinkFetchResult of(
+      List<ExternalLink> links, boolean hasMore, long nextFetchIndex, long nextRowOffset) {
+    return new ChunkLinkFetchResult(links, hasMore, nextFetchIndex, nextRowOffset);
+  }
+
+  /**
+   * Creates a result indicating the end of the stream (no more chunks).
+   *
+   * @return A ChunkLinkFetchResult representing end of stream
+   */
+  public static ChunkLinkFetchResult endOfStream() {
+    return new ChunkLinkFetchResult(Collections.emptyList(), false, -1, 0);
+  }
+
+  /**
+   * Returns the list of external links fetched in this batch.
+   *
+   * @return List of ExternalLink, may be empty
+   */
+  public List<ExternalLink> getChunkLinks() {
+    return chunkLinks;
+  }
+
+  /**
+   * Returns whether more chunks are available after this batch.
+   *
+   * @return true if more chunks can be fetched, false otherwise
+   */
+  public boolean hasMore() {
+    return hasMore;
+  }
+
+  /**
+   * Returns the next chunk index to fetch from.
+   *
+   * @return The next fetch index, or -1 if no more chunks
+   */
+  public long getNextFetchIndex() {
+    return nextFetchIndex;
+  }
+
+  /**
+   * Returns the next row offset for Thrift FETCH_ABSOLUTE continuation.
+   *
+   * @return The next row offset, or 0 if not applicable
+   */
+  public long getNextRowOffset() {
+    return nextRowOffset;
+  }
+
+  /**
+   * Checks if this result represents the end of the chunk stream.
+   *
+   * @return true if no more chunks are available
+   */
+  public boolean isEndOfStream() {
+    return !hasMore && chunkLinks.isEmpty();
+  }
+}
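A short worked example (values illustrative only; `links` stands for a fetched `List<ExternalLink>` covering rows 0-199) of how the factory methods encode the two continuation styles:

    // SEA: the last link carries nextChunkIndex = 2, so the caller resumes at chunk 2.
    ChunkLinkFetchResult seaBatch = ChunkLinkFetchResult.of(links, true, 2, 200);
    // Thrift: the server reported hasMoreRows = false after the same batch.
    ChunkLinkFetchResult thriftTail = ChunkLinkFetchResult.of(links, false, -1, 200);
    // A final batch that still carries links is not isEndOfStream();
    // only an exhausted, empty result is:
    ChunkLinkFetchResult done = ChunkLinkFetchResult.endOfStream();
    // done.isEndOfStream() == true, thriftTail.isEndOfStream() == false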
diff --git a/src/test/java/com/databricks/jdbc/api/impl/arrow/ArrowStreamResultTest.java b/src/test/java/com/databricks/jdbc/api/impl/arrow/ArrowStreamResultTest.java
index 7bd55237c..989e4cb4d 100644
--- a/src/test/java/com/databricks/jdbc/api/impl/arrow/ArrowStreamResultTest.java
+++ b/src/test/java/com/databricks/jdbc/api/impl/arrow/ArrowStreamResultTest.java
@@ -3,8 +3,7 @@
 import static com.databricks.jdbc.TestConstants.*;
 import static java.lang.Math.min;
 import static org.junit.jupiter.api.Assertions.*;
-import static org.mockito.ArgumentMatchers.eq;
-import static org.mockito.ArgumentMatchers.isA;
+import static org.mockito.ArgumentMatchers.*;
 import static org.mockito.Mockito.when;
 
 import com.databricks.jdbc.api.impl.DatabricksConnectionContextFactory;
@@ -21,6 +20,7 @@
 import com.databricks.jdbc.model.client.thrift.generated.TGetResultSetMetadataResp;
 import com.databricks.jdbc.model.client.thrift.generated.TRowSet;
 import com.databricks.jdbc.model.client.thrift.generated.TSparkArrowResultLink;
+import com.databricks.jdbc.model.core.ChunkLinkFetchResult;
 import com.databricks.jdbc.model.core.ColumnInfo;
 import com.databricks.jdbc.model.core.ColumnInfoTypeName;
 import com.databricks.jdbc.model.core.ExternalLink;
@@ -248,7 +248,9 @@ private List<ExternalLink> getChunkLinks(long chunkIndex, boolean isLast) {
         new ExternalLink()
             .setChunkIndex(chunkIndex)
             .setExternalLink(CHUNK_URL_PREFIX + chunkIndex)
-            .setExpiration(Instant.now().plusSeconds(3600L).toString());
+            .setExpiration(Instant.now().plusSeconds(3600L).toString())
+            .setRowOffset(chunkIndex * this.rowsInChunk)
+            .setRowCount(this.rowsInChunk);
     if (!isLast) {
       chunkLink.setNextChunkIndex(chunkIndex + 1);
     }
@@ -283,11 +285,24 @@ private void setupMockResponse() throws Exception {
 
   private void setupResultChunkMocks() throws DatabricksSQLException {
     for (int chunkIndex = 1; chunkIndex < numberOfChunks; chunkIndex++) {
       boolean isLastChunk = (chunkIndex == (numberOfChunks - 1));
-      when(mockedSdkClient.getResultChunks(STATEMENT_ID, chunkIndex))
-          .thenReturn(getChunkLinks(chunkIndex, isLastChunk));
+      when(mockedSdkClient.getResultChunks(eq(STATEMENT_ID), eq((long) chunkIndex), anyLong()))
+          .thenReturn(buildChunkLinkFetchResult(getChunkLinks(chunkIndex, isLastChunk)));
     }
   }
 
+  private ChunkLinkFetchResult buildChunkLinkFetchResult(List<ExternalLink> links) {
+    if (links == null || links.isEmpty()) {
+      return ChunkLinkFetchResult.endOfStream();
+    }
+
+    ExternalLink lastLink = links.get(links.size() - 1);
+    boolean hasMore = lastLink.getNextChunkIndex() != null;
+    long nextFetchIndex = hasMore ? lastLink.getNextChunkIndex() : -1;
+    long nextRowOffset = lastLink.getRowOffset() + lastLink.getRowCount();
+
+    return ChunkLinkFetchResult.of(links, hasMore, nextFetchIndex, nextRowOffset);
+  }
+
   private File createTestArrowFile(
       String fileName, Schema schema, Object[][] testData, RootAllocator allocator)
       throws IOException {
diff --git a/src/test/java/com/databricks/jdbc/api/impl/arrow/ChunkLinkDownloadServiceTest.java b/src/test/java/com/databricks/jdbc/api/impl/arrow/ChunkLinkDownloadServiceTest.java
index 52981ced0..13207280a 100644
--- a/src/test/java/com/databricks/jdbc/api/impl/arrow/ChunkLinkDownloadServiceTest.java
+++ b/src/test/java/com/databricks/jdbc/api/impl/arrow/ChunkLinkDownloadServiceTest.java
@@ -11,6 +11,7 @@
 import com.databricks.jdbc.dbclient.impl.common.StatementId;
 import com.databricks.jdbc.exception.DatabricksSQLException;
 import com.databricks.jdbc.exception.DatabricksValidationException;
+import com.databricks.jdbc.model.core.ChunkLinkFetchResult;
 import com.databricks.jdbc.model.core.ExternalLink;
 import com.databricks.jdbc.model.telemetry.enums.DatabricksDriverErrorCode;
 import java.time.Instant;
@@ -55,14 +56,14 @@ void testGetLinkForChunk_Success()
     when(mockSession.getDatabricksClient()).thenReturn(mockClient);
 
     // Mock the response to link requests
-    when(mockClient.getResultChunks(eq(mockStatementId), eq(1L)))
-        .thenReturn(Collections.singletonList(linkForChunkIndex_1));
-    when(mockClient.getResultChunks(eq(mockStatementId), eq(2L)))
-        .thenReturn(Collections.singletonList(linkForChunkIndex_2));
-    when(mockClient.getResultChunks(eq(mockStatementId), eq(3L)))
-        .thenReturn(Collections.singletonList(linkForChunkIndex_3));
-    when(mockClient.getResultChunks(eq(mockStatementId), eq(4L)))
-        .thenReturn(Collections.singletonList(linkForChunkIndex_4));
+    when(mockClient.getResultChunks(eq(mockStatementId), eq(1L), eq(0L)))
+        .thenReturn(buildChunkLinkFetchResult(Collections.singletonList(linkForChunkIndex_1)));
+    when(mockClient.getResultChunks(eq(mockStatementId), eq(2L), eq(0L)))
+        .thenReturn(buildChunkLinkFetchResult(Collections.singletonList(linkForChunkIndex_2)));
+    when(mockClient.getResultChunks(eq(mockStatementId), eq(3L), eq(0L)))
+        .thenReturn(buildChunkLinkFetchResult(Collections.singletonList(linkForChunkIndex_3)));
+    when(mockClient.getResultChunks(eq(mockStatementId), eq(4L), eq(0L)))
+        .thenReturn(buildChunkLinkFetchResult(Collections.singletonList(linkForChunkIndex_4)));
 
     long chunkIndex = 1L;
     when(mockChunkMap.get(chunkIndex)).thenReturn(mock(ArrowResultChunk.class));
@@ -78,7 +79,7 @@
     TimeUnit.MILLISECONDS.sleep(500);
 
     assertEquals(linkForChunkIndex_1, result);
-    verify(mockClient).getResultChunks(mockStatementId, NEXT_BATCH_START_INDEX);
+    verify(mockClient).getResultChunks(mockStatementId, NEXT_BATCH_START_INDEX, 0L);
   }
 
   @Test
@@ -115,7 +116,8 @@ void testGetLinkForChunk_ClientError()
         new DatabricksSQLException("Test error", DatabricksDriverErrorCode.INVALID_STATE);
     when(mockSession.getDatabricksClient()).thenReturn(mockClient);
     // Mock an error in response to the link request
-    when(mockClient.getResultChunks(eq(mockStatementId), anyLong())).thenThrow(expectedError);
+    when(mockClient.getResultChunks(eq(mockStatementId), anyLong(), anyLong()))
+        .thenThrow(expectedError);
     when(mockChunkMap.get(chunkIndex)).thenReturn(mock(ArrowResultChunk.class));
 
     ChunkLinkDownloadService<ArrowResultChunk> service =
         new ChunkLinkDownloadService<>(
@@ -133,14 +135,14 @@ void testAutoTriggerForSEAClient() throws DatabricksSQLException, InterruptedException {
     when(mockSession.getDatabricksClient()).thenReturn(mockClient);
     // Mock the response to link requests
-    when(mockClient.getResultChunks(eq(mockStatementId), eq(1L)))
-        .thenReturn(Collections.singletonList(linkForChunkIndex_1));
-    when(mockClient.getResultChunks(eq(mockStatementId), eq(2L)))
-        .thenReturn(Collections.singletonList(linkForChunkIndex_2));
-    when(mockClient.getResultChunks(eq(mockStatementId), eq(3L)))
-        .thenReturn(Collections.singletonList(linkForChunkIndex_3));
-    when(mockClient.getResultChunks(eq(mockStatementId), eq(4L)))
-        .thenReturn(Collections.singletonList(linkForChunkIndex_4));
+    when(mockClient.getResultChunks(eq(mockStatementId), eq(1L), eq(0L)))
+        .thenReturn(buildChunkLinkFetchResult(Collections.singletonList(linkForChunkIndex_1)));
+    when(mockClient.getResultChunks(eq(mockStatementId), eq(2L), eq(0L)))
+        .thenReturn(buildChunkLinkFetchResult(Collections.singletonList(linkForChunkIndex_2)));
+    when(mockClient.getResultChunks(eq(mockStatementId), eq(3L), eq(0L)))
+        .thenReturn(buildChunkLinkFetchResult(Collections.singletonList(linkForChunkIndex_3)));
+    when(mockClient.getResultChunks(eq(mockStatementId), eq(4L), eq(0L)))
+        .thenReturn(buildChunkLinkFetchResult(Collections.singletonList(linkForChunkIndex_4)));
 
     // Download chain will be triggered immediately in the constructor
     when(mockSession.getConnectionContext().getClientType()).thenReturn(DatabricksClientType.SEA);
@@ -150,7 +152,7 @@ void testAutoTriggerForSEAClient() throws DatabricksSQLException, InterruptedException {
     // Sleep to allow the service to complete the download pipeline
     TimeUnit.MILLISECONDS.sleep(500);
 
-    verify(mockClient).getResultChunks(mockStatementId, NEXT_BATCH_START_INDEX);
+    verify(mockClient).getResultChunks(mockStatementId, NEXT_BATCH_START_INDEX, 0L);
   }
 
   @Test
@@ -163,14 +165,15 @@ void testHandleExpiredLinks()
     when(mockSession.getDatabricksClient()).thenReturn(mockClient);
     // Mock the response to link requests. Return the expired link for chunk index 1
-    when(mockClient.getResultChunks(eq(mockStatementId), eq(1L)))
-        .thenReturn(Collections.singletonList(expiredLinkForChunkIndex_1));
-    when(mockClient.getResultChunks(eq(mockStatementId), eq(2L)))
-        .thenReturn(Collections.singletonList(linkForChunkIndex_2));
-    when(mockClient.getResultChunks(eq(mockStatementId), eq(3L)))
-        .thenReturn(Collections.singletonList(linkForChunkIndex_3));
-    when(mockClient.getResultChunks(eq(mockStatementId), eq(4L)))
-        .thenReturn(Collections.singletonList(linkForChunkIndex_4));
+    when(mockClient.getResultChunks(eq(mockStatementId), eq(1L), anyLong()))
+        .thenReturn(
+            buildChunkLinkFetchResult(Collections.singletonList(expiredLinkForChunkIndex_1)));
+    when(mockClient.getResultChunks(eq(mockStatementId), eq(2L), anyLong()))
+        .thenReturn(buildChunkLinkFetchResult(Collections.singletonList(linkForChunkIndex_2)));
+    when(mockClient.getResultChunks(eq(mockStatementId), eq(3L), anyLong()))
+        .thenReturn(buildChunkLinkFetchResult(Collections.singletonList(linkForChunkIndex_3)));
+    when(mockClient.getResultChunks(eq(mockStatementId), eq(4L), anyLong()))
+        .thenReturn(buildChunkLinkFetchResult(Collections.singletonList(linkForChunkIndex_4)));
 
     long chunkIndex = 1L;
     ArrowResultChunk mockChunk = mock(ArrowResultChunk.class);
@@ -185,8 +188,8 @@
     TimeUnit.MILLISECONDS.sleep(500);
 
     // Mock a new valid link for chunk index 1
-    when(mockClient.getResultChunks(eq(mockStatementId), eq(1L)))
-        .thenReturn(Collections.singletonList(linkForChunkIndex_1));
+    when(mockClient.getResultChunks(eq(mockStatementId), eq(1L), eq(0L)))
+        .thenReturn(buildChunkLinkFetchResult(Collections.singletonList(linkForChunkIndex_1)));
     // Try to get the link for chunk index 1. Download chain will be re-triggered because the link
     // is expired
     CompletableFuture<ExternalLink> future = service.getLinkForChunk(chunkIndex);
@@ -195,7 +198,7 @@
     TimeUnit.MILLISECONDS.sleep(500);
 
     assertEquals(linkForChunkIndex_1, result);
-    verify(mockClient, times(2)).getResultChunks(mockStatementId, chunkIndex);
+    verify(mockClient, times(2)).getResultChunks(mockStatementId, chunkIndex, 0L);
   }
 
   @Test
@@ -222,14 +225,17 @@ void testBatchDownloadChaining()
     when(mockSession.getDatabricksClient()).thenReturn(mockClient);
     // Mock the links for the first batch. The link futures for both chunks will be completed at the
     // same time
-    when(mockClient.getResultChunks(eq(mockStatementId), eq(1L)))
-        .thenReturn(Arrays.asList(linkForChunkIndex_1, linkForChunkIndex_2));
+    when(mockClient.getResultChunks(eq(mockStatementId), eq(1L), eq(0L)))
+        .thenReturn(
+            buildChunkLinkFetchResult(Arrays.asList(linkForChunkIndex_1, linkForChunkIndex_2)));
     // Mock the links for the second batch.
-    when(mockClient.getResultChunks(eq(mockStatementId), eq(3L)))
-        .thenReturn(Arrays.asList(linkForChunkIndex_3, linkForChunkIndex_4));
+    when(mockClient.getResultChunks(eq(mockStatementId), eq(3L), eq(0L)))
+        .thenReturn(
+            buildChunkLinkFetchResult(Arrays.asList(linkForChunkIndex_3, linkForChunkIndex_4)));
     // Mock the links for the third batch.
-    when(mockClient.getResultChunks(eq(mockStatementId), eq(5L)))
-        .thenReturn(Arrays.asList(linkForChunkIndex_5, linkForChunkIndex_6));
+    when(mockClient.getResultChunks(eq(mockStatementId), eq(5L), eq(0L)))
+        .thenReturn(
+            buildChunkLinkFetchResult(Arrays.asList(linkForChunkIndex_5, linkForChunkIndex_6)));
 
     ChunkLinkDownloadService<ArrowResultChunk> service =
         new ChunkLinkDownloadService<>(
@@ -260,11 +266,11 @@
     assertEquals(linkForChunkIndex_5, result5);
     assertEquals(linkForChunkIndex_6, result6);
     // Verify the request for first batch
-    verify(mockClient, times(1)).getResultChunks(mockStatementId, 1L);
+    verify(mockClient, times(1)).getResultChunks(mockStatementId, 1L, 0L);
     // Verify the request for second batch
-    verify(mockClient, times(1)).getResultChunks(mockStatementId, 3L);
+    verify(mockClient, times(1)).getResultChunks(mockStatementId, 3L, 0L);
     // Verify the request for third batch
-    verify(mockClient, times(1)).getResultChunks(mockStatementId, 5L);
+    verify(mockClient, times(1)).getResultChunks(mockStatementId, 5L, 0L);
   }
 
   private ExternalLink createExternalLink(
@@ -274,7 +280,26 @@
     link.setChunkIndex(chunkIndex);
     link.setHttpHeaders(headers);
     link.setExpiration(expiration);
+    link.setRowOffset(chunkIndex * 100L);
+    link.setRowCount(100L);
     return link;
   }
+
+  /**
+   * Helper method to build a ChunkLinkFetchResult from a list of ExternalLinks. This mimics the
+   * behavior of the SEA client's buildChunkLinkFetchResult method.
+   */
+  private ChunkLinkFetchResult buildChunkLinkFetchResult(List<ExternalLink> links) {
+    if (links == null || links.isEmpty()) {
+      return ChunkLinkFetchResult.endOfStream();
+    }
+
+    ExternalLink lastLink = links.get(links.size() - 1);
+    boolean hasMore = lastLink.getNextChunkIndex() != null;
+    long nextFetchIndex = hasMore ? lastLink.getNextChunkIndex() : -1;
+    long nextRowOffset = lastLink.getRowOffset() + lastLink.getRowCount();
+
+    return ChunkLinkFetchResult.of(links, hasMore, nextFetchIndex, nextRowOffset);
+  }
 }
diff --git a/src/test/java/com/databricks/jdbc/dbclient/impl/thrift/DatabricksThriftServiceClientTest.java b/src/test/java/com/databricks/jdbc/dbclient/impl/thrift/DatabricksThriftServiceClientTest.java
index 27f4c04e5..98ff3d274 100644
--- a/src/test/java/com/databricks/jdbc/dbclient/impl/thrift/DatabricksThriftServiceClientTest.java
+++ b/src/test/java/com/databricks/jdbc/dbclient/impl/thrift/DatabricksThriftServiceClientTest.java
@@ -10,6 +10,7 @@
 import static com.databricks.jdbc.model.core.ColumnInfoTypeName.*;
 import static org.junit.jupiter.api.Assertions.*;
 import static org.mockito.ArgumentMatchers.any;
+import static org.mockito.ArgumentMatchers.anyLong;
 import static org.mockito.ArgumentMatchers.eq;
 import static org.mockito.Mockito.verify;
 import static org.mockito.Mockito.when;
@@ -23,6 +24,7 @@
 import com.databricks.jdbc.exception.DatabricksParsingException;
 import com.databricks.jdbc.exception.DatabricksSQLException;
 import com.databricks.jdbc.model.client.thrift.generated.*;
+import com.databricks.jdbc.model.core.ChunkLinkFetchResult;
 import com.databricks.jdbc.model.core.ExternalLink;
 import com.databricks.jdbc.model.core.ResultColumn;
 import com.databricks.sdk.core.DatabricksConfig;
@@ -369,17 +371,20 @@ void testGetResultChunks() throws SQLException {
             .setStatus(new TStatus().setStatusCode(TStatusCode.SUCCESS_STATUS))
             .setResults(resultData)
             .setResultSetMetadata(resultMetadataData);
-    when(thriftAccessor.getResultSetResp(any(), any())).thenReturn(response);
+    when(thriftAccessor.fetchResultsWithAbsoluteOffset(any(), anyLong(), any()))
+        .thenReturn(response);
     when(resultData.getResultLinks())
        .thenReturn(
            Collections.singletonList(new TSparkArrowResultLink().setFileLink(TEST_STRING)));
-    Collection<ExternalLink> resultChunks = client.getResultChunks(TEST_STMT_ID, 0);
-    assertEquals(resultChunks.size(), 1);
-    assertEquals(resultChunks.stream().findFirst().get().getExternalLink(), TEST_STRING);
+    // Pass chunkIndex=0 and rowOffset=0 for the first chunk
+    ChunkLinkFetchResult result = client.getResultChunks(TEST_STMT_ID, 0, 0);
+    List<ExternalLink> chunkLinks = result.getChunkLinks();
+    assertEquals(1, chunkLinks.size());
+    assertEquals(TEST_STRING, chunkLinks.get(0).getExternalLink());
   }
 
   @Test
-  void testGetResultChunksThrowsError() throws SQLException {
+  void testGetResultChunksReturnsEmptyWhenNoLinks() throws SQLException {
     DatabricksThriftServiceClient client =
         new DatabricksThriftServiceClient(thriftAccessor, connectionContext);
     TFetchResultsResp response =
@@ -387,10 +392,12 @@
             .setStatus(new TStatus().setStatusCode(TStatusCode.SUCCESS_STATUS))
             .setResults(resultData)
             .setResultSetMetadata(resultMetadataData);
-    when(thriftAccessor.getResultSetResp(any(), any())).thenReturn(response);
-    assertThrows(DatabricksSQLException.class, () -> client.getResultChunks(TEST_STMT_ID, -1));
-    assertThrows(DatabricksSQLException.class, () -> client.getResultChunks(TEST_STMT_ID, 2));
-    assertThrows(DatabricksSQLException.class, () -> client.getResultChunks(TEST_STMT_ID, 1));
+    when(thriftAccessor.fetchResultsWithAbsoluteOffset(any(), anyLong(), any()))
+        .thenReturn(response);
+    when(resultData.getResultLinks()).thenReturn(null);
+    ChunkLinkFetchResult result = client.getResultChunks(TEST_STMT_ID, 0, 0);
+    assertTrue(result.isEndOfStream());
+    assertEquals(0, result.getChunkLinks().size());
   }
 
   @Test