diff --git a/src/main/java/com/databricks/jdbc/api/impl/arrow/AbstractArrowResultChunk.java b/src/main/java/com/databricks/jdbc/api/impl/arrow/AbstractArrowResultChunk.java index 9254ee9ec..44ee08ed5 100644 --- a/src/main/java/com/databricks/jdbc/api/impl/arrow/AbstractArrowResultChunk.java +++ b/src/main/java/com/databricks/jdbc/api/impl/arrow/AbstractArrowResultChunk.java @@ -152,6 +152,15 @@ public void setChunkLink(ExternalLink chunk) { setStatus(ChunkStatus.URL_FETCHED); } + /** + * Returns the external link for this chunk. + * + * @return the external link, or null if not set + */ + protected ExternalLink getChunkLink() { + return chunkLink; + } + /** * Returns the current status of the chunk. * diff --git a/src/main/java/com/databricks/jdbc/api/impl/arrow/ChunkLinkDownloadService.java b/src/main/java/com/databricks/jdbc/api/impl/arrow/ChunkLinkDownloadService.java index 7d59ea13c..dfbbc78bf 100644 --- a/src/main/java/com/databricks/jdbc/api/impl/arrow/ChunkLinkDownloadService.java +++ b/src/main/java/com/databricks/jdbc/api/impl/arrow/ChunkLinkDownloadService.java @@ -10,6 +10,7 @@ import com.databricks.jdbc.log.JdbcLogger; import com.databricks.jdbc.log.JdbcLoggerFactory; import com.databricks.jdbc.model.core.ExternalLink; +import com.google.common.annotations.VisibleForTesting; import java.time.Instant; import java.util.Collection; import java.util.Map; @@ -114,6 +115,24 @@ public ChunkLinkDownloadService( this.chunkIndexToChunksMap = chunkIndexToChunksMap; + // Complete futures for chunks that already have their links (upfront-fetched) + if (nextBatchStartIndex > 0) { + LOGGER.info("Completing futures for {} upfront-fetched links", nextBatchStartIndex); + int completedCount = 0; + for (long i = 0; i < Math.min(nextBatchStartIndex, totalChunks); i++) { + T chunk = chunkIndexToChunksMap.get(i); + if (chunk != null) { + ExternalLink link = chunk.getChunkLink(); + if (link != null) { + LOGGER.debug("Completing link future for chunk {} in constructor", i); + chunkIndexToLinkFuture.get(i).complete(link); + completedCount++; + } + } + } + LOGGER.info("Completed {} futures for upfront-fetched links", completedCount); + } + if (session.getConnectionContext().getClientType() == DatabricksClientType.SEA && isDownloadChainStarted.compareAndSet(false, true)) { // SEA doesn't give all chunk links, so better to trigger download chain as soon as possible @@ -330,7 +349,7 @@ private void handleExpiredLinksAndReset(long chunkIndex) LOGGER.info( "Detected expired link for chunk {}, re-triggering batch download from the smallest index with the expired link", chunkIndex); - for (long i = 1; i < totalChunks; i++) { + for (long i = 0; i < totalChunks; i++) { if (isChunkLinkExpiredForPendingDownload(i)) { LOGGER.info("Found the smallest index {} with the expired link, initiating reset", i); cancelCurrentDownloadTask(); @@ -423,4 +442,15 @@ private boolean isChunkLinkExpired(ExternalLink link) { return expirationWithBuffer.isBefore(Instant.now()); } + + /** + * Returns the CompletableFuture for a specific chunk index for testing purposes. + * + * @param chunkIndex The index of the chunk + * @return The CompletableFuture associated with the chunk index, or null if not found + */ + @VisibleForTesting + CompletableFuture getLinkFutureForTest(long chunkIndex) { + return chunkIndexToLinkFuture.get(chunkIndex); + } } diff --git a/src/main/java/com/databricks/jdbc/common/util/DatabricksThriftUtil.java b/src/main/java/com/databricks/jdbc/common/util/DatabricksThriftUtil.java index 71f5aac2d..1566bd777 100644 --- a/src/main/java/com/databricks/jdbc/common/util/DatabricksThriftUtil.java +++ b/src/main/java/com/databricks/jdbc/common/util/DatabricksThriftUtil.java @@ -21,6 +21,7 @@ import com.databricks.jdbc.model.telemetry.enums.DatabricksDriverErrorCode; import com.databricks.sdk.service.sql.StatementState; import java.nio.ByteBuffer; +import java.time.Instant; import java.util.*; public class DatabricksThriftUtil { @@ -73,7 +74,7 @@ public static ExternalLink createExternalLink(TSparkArrowResultLink chunkInfo, l return new ExternalLink() .setExternalLink(chunkInfo.getFileLink()) .setChunkIndex(chunkIndex) - .setExpiration(Long.toString(chunkInfo.getExpiryTime())); + .setExpiration(Instant.ofEpochMilli(chunkInfo.getExpiryTime()).toString()); } public static void verifySuccessStatus(TStatus status, String errorContext) diff --git a/src/test/java/com/databricks/jdbc/api/impl/arrow/ChunkLinkDownloadServiceTest.java b/src/test/java/com/databricks/jdbc/api/impl/arrow/ChunkLinkDownloadServiceTest.java index 52981ced0..8ad09b26b 100644 --- a/src/test/java/com/databricks/jdbc/api/impl/arrow/ChunkLinkDownloadServiceTest.java +++ b/src/test/java/com/databricks/jdbc/api/impl/arrow/ChunkLinkDownloadServiceTest.java @@ -47,6 +47,7 @@ class ChunkLinkDownloadServiceTest { @BeforeEach void setUp() { when(mockSession.getConnectionContext()).thenReturn(mock(IDatabricksConnectionContext.class)); + lenient().when(mockChunkMap.get(anyLong())).thenReturn(null); } @Test @@ -267,6 +268,66 @@ void testBatchDownloadChaining() verify(mockClient, times(1)).getResultChunks(mockStatementId, 5L); } + @Test + void testUpfrontFetchedLinks_FuturesCompletedInConstructor() + throws ExecutionException, InterruptedException, TimeoutException { + when(mockSession.getConnectionContext().getClientType()) + .thenReturn(DatabricksClientType.THRIFT); + + // Create links for upfront-fetched chunks + ExternalLink link0 = + createExternalLink("url-0", 0L, Collections.emptyMap(), "2025-02-16T00:00:00Z"); + ExternalLink link1 = + createExternalLink("url-1", 1L, Collections.emptyMap(), "2025-02-16T00:00:00Z"); + ExternalLink link2 = + createExternalLink("url-2", 2L, Collections.emptyMap(), "2025-02-16T00:00:00Z"); + + // Create mock chunks with links already set + ArrowResultChunk mockChunk0 = mock(ArrowResultChunk.class); + ArrowResultChunk mockChunk1 = mock(ArrowResultChunk.class); + ArrowResultChunk mockChunk2 = mock(ArrowResultChunk.class); + + ArrowResultChunk mockChunk3 = mock(ArrowResultChunk.class); + ArrowResultChunk mockChunk4 = mock(ArrowResultChunk.class); + + when(mockChunk0.getChunkLink()).thenReturn(link0); + when(mockChunk1.getChunkLink()).thenReturn(link1); + when(mockChunk2.getChunkLink()).thenReturn(link2); + + when(mockChunkMap.get(0L)).thenReturn(mockChunk0); + when(mockChunkMap.get(1L)).thenReturn(mockChunk1); + when(mockChunkMap.get(2L)).thenReturn(mockChunk2); + lenient().when(mockChunkMap.get(3L)).thenReturn(mockChunk3); + lenient().when(mockChunkMap.get(4L)).thenReturn(mockChunk4); + + // Create service with nextBatchStartIndex = 3 (meaning chunks 0, 1, 2 were upfront-fetched) + long nextBatchStartIndex = 3L; + ChunkLinkDownloadService service = + new ChunkLinkDownloadService<>( + mockSession, mockStatementId, TOTAL_CHUNKS, mockChunkMap, nextBatchStartIndex); + + // Verify that futures for chunks 0, 1, 2 are already completed + CompletableFuture future0 = service.getLinkFutureForTest(0L); + CompletableFuture future1 = service.getLinkFutureForTest(1L); + CompletableFuture future2 = service.getLinkFutureForTest(2L); + + assertTrue(future0.isDone(), "Future for chunk 0 should be completed"); + assertTrue(future1.isDone(), "Future for chunk 1 should be completed"); + assertTrue(future2.isDone(), "Future for chunk 2 should be completed"); + + // Verify the futures contain the correct links + assertEquals(link0, future0.get(100, TimeUnit.MILLISECONDS)); + assertEquals(link1, future1.get(100, TimeUnit.MILLISECONDS)); + assertEquals(link2, future2.get(100, TimeUnit.MILLISECONDS)); + + // Verify that futures for chunks 3, 4 are not completed + CompletableFuture future3 = service.getLinkFutureForTest(3L); + CompletableFuture future4 = service.getLinkFutureForTest(4L); + + assertFalse(future3.isDone(), "Future for chunk 3 should not be completed"); + assertFalse(future4.isDone(), "Future for chunk 4 should not be completed"); + } + private ExternalLink createExternalLink( String url, long chunkIndex, Map headers, String expiration) { ExternalLink link = new ExternalLink();