From 3386c6310aeb426c1936ffe5ff8107d84156f6c2 Mon Sep 17 00:00:00 2001 From: Yanming Zhou Date: Thu, 3 Apr 2025 10:30:46 +0800 Subject: [PATCH] Discard further rows once maxRows has been reached See https://github.com/spring-projects/spring-framework/issues/34666#issuecomment-2773151317 Signed-off-by: Yanming Zhou --- .../jdbc/core/JdbcTemplate.java | 40 +++++++++------- .../core/RowMapperResultSetExtractor.java | 17 ++++++- .../jdbc/core/JdbcTemplateTests.java | 47 +++++++++++++++++++ 3 files changed, 87 insertions(+), 17 deletions(-) diff --git a/spring-jdbc/src/main/java/org/springframework/jdbc/core/JdbcTemplate.java b/spring-jdbc/src/main/java/org/springframework/jdbc/core/JdbcTemplate.java index 1b3e14d3686..11833ec090f 100644 --- a/spring-jdbc/src/main/java/org/springframework/jdbc/core/JdbcTemplate.java +++ b/spring-jdbc/src/main/java/org/springframework/jdbc/core/JdbcTemplate.java @@ -102,6 +102,7 @@ * @author Rod Johnson * @author Juergen Hoeller * @author Thomas Risberg + * @author Yanming Zhou * @since May 3, 2001 * @see JdbcOperations * @see PreparedStatementCreator @@ -493,12 +494,12 @@ public String getSql() { @Override public void query(String sql, RowCallbackHandler rch) throws DataAccessException { - query(sql, new RowCallbackHandlerResultSetExtractor(rch)); + query(sql, new RowCallbackHandlerResultSetExtractor(rch, this.maxRows)); } @Override public List query(String sql, RowMapper rowMapper) throws DataAccessException { - return result(query(sql, new RowMapperResultSetExtractor<>(rowMapper))); + return result(query(sql, new RowMapperResultSetExtractor<>(rowMapper, 0, this.maxRows))); } @Override @@ -508,7 +509,7 @@ class StreamStatementCallback implements StatementCallback>, SqlProvid public Stream doInStatement(Statement stmt) throws SQLException { ResultSet rs = stmt.executeQuery(sql); Connection con = stmt.getConnection(); - return new ResultSetSpliterator<>(rs, rowMapper).stream().onClose(() -> { + return new ResultSetSpliterator<>(rs, rowMapper, JdbcTemplate.this.maxRows).stream().onClose(() -> { JdbcUtils.closeResultSet(rs); JdbcUtils.closeStatement(stmt); DataSourceUtils.releaseConnection(con, getDataSource()); @@ -773,12 +774,12 @@ private String appendSql(@Nullable String sql, String statement) { @Override public void query(PreparedStatementCreator psc, RowCallbackHandler rch) throws DataAccessException { - query(psc, new RowCallbackHandlerResultSetExtractor(rch)); + query(psc, new RowCallbackHandlerResultSetExtractor(rch, this.maxRows)); } @Override public void query(String sql, @Nullable PreparedStatementSetter pss, RowCallbackHandler rch) throws DataAccessException { - query(sql, pss, new RowCallbackHandlerResultSetExtractor(rch)); + query(sql, pss, new RowCallbackHandlerResultSetExtractor(rch, this.maxRows)); } @Override @@ -799,28 +800,28 @@ public void query(String sql, RowCallbackHandler rch, @Nullable Object @Nullable @Override public List query(PreparedStatementCreator psc, RowMapper rowMapper) throws DataAccessException { - return result(query(psc, new RowMapperResultSetExtractor<>(rowMapper))); + return result(query(psc, new RowMapperResultSetExtractor<>(rowMapper, 0, this.maxRows))); } @Override public List query(String sql, @Nullable PreparedStatementSetter pss, RowMapper rowMapper) throws DataAccessException { - return result(query(sql, pss, new RowMapperResultSetExtractor<>(rowMapper))); + return result(query(sql, pss, new RowMapperResultSetExtractor<>(rowMapper, 0, this.maxRows))); } @Override public List query(String sql, @Nullable Object @Nullable [] args, int[] argTypes, RowMapper rowMapper) throws DataAccessException { - return result(query(sql, args, argTypes, new RowMapperResultSetExtractor<>(rowMapper))); + return result(query(sql, args, argTypes, new RowMapperResultSetExtractor<>(rowMapper, 0, this.maxRows))); } @Deprecated(since = "5.3") @Override public List query(String sql, @Nullable Object @Nullable [] args, RowMapper rowMapper) throws DataAccessException { - return result(query(sql, newArgPreparedStatementSetter(args), new RowMapperResultSetExtractor<>(rowMapper))); + return result(query(sql, newArgPreparedStatementSetter(args), new RowMapperResultSetExtractor<>(rowMapper, 0, this.maxRows))); } @Override public List query(String sql, RowMapper rowMapper, @Nullable Object @Nullable ... args) throws DataAccessException { - return result(query(sql, newArgPreparedStatementSetter(args), new RowMapperResultSetExtractor<>(rowMapper))); + return result(query(sql, newArgPreparedStatementSetter(args), new RowMapperResultSetExtractor<>(rowMapper, 0, this.maxRows))); } /** @@ -845,7 +846,7 @@ public Stream queryForStream(PreparedStatementCreator psc, @Nullable Prep } ResultSet rs = ps.executeQuery(); Connection con = ps.getConnection(); - return new ResultSetSpliterator<>(rs, rowMapper).stream().onClose(() -> { + return new ResultSetSpliterator<>(rs, rowMapper, this.maxRows).stream().onClose(() -> { JdbcUtils.closeResultSet(rs); if (pss instanceof ParameterDisposer parameterDisposer) { parameterDisposer.cleanupParameters(); @@ -1364,7 +1365,7 @@ protected Map processResultSet( } else if (param.getRowCallbackHandler() != null) { RowCallbackHandler rch = param.getRowCallbackHandler(); - (new RowCallbackHandlerResultSetExtractor(rch)).extractData(rs); + (new RowCallbackHandlerResultSetExtractor(rch, -1)).extractData(rs); return Collections.singletonMap(param.getName(), "ResultSet returned from stored procedure was processed"); } @@ -1747,13 +1748,17 @@ private static class RowCallbackHandlerResultSetExtractor implements ResultSetEx private final RowCallbackHandler rch; - public RowCallbackHandlerResultSetExtractor(RowCallbackHandler rch) { + private final int maxRows; + + public RowCallbackHandlerResultSetExtractor(RowCallbackHandler rch, int maxRows) { this.rch = rch; + this.maxRows = maxRows; } @Override public @Nullable Object extractData(ResultSet rs) throws SQLException { - while (rs.next()) { + int processed = 0; + while (rs.next() && (this.maxRows == -1 || (processed++) < this.maxRows)) { this.rch.processRow(rs); } return null; @@ -1771,17 +1776,20 @@ private static class ResultSetSpliterator implements Spliterator { private final RowMapper rowMapper; + private final int maxRows; + private int rowNum = 0; - public ResultSetSpliterator(ResultSet rs, RowMapper rowMapper) { + public ResultSetSpliterator(ResultSet rs, RowMapper rowMapper, int maxRows) { this.rs = rs; this.rowMapper = rowMapper; + this.maxRows = maxRows; } @Override public boolean tryAdvance(Consumer action) { try { - if (this.rs.next()) { + if (this.rs.next() && (this.maxRows == -1 || this.rowNum < this.maxRows)) { action.accept(this.rowMapper.mapRow(this.rs, this.rowNum++)); return true; } diff --git a/spring-jdbc/src/main/java/org/springframework/jdbc/core/RowMapperResultSetExtractor.java b/spring-jdbc/src/main/java/org/springframework/jdbc/core/RowMapperResultSetExtractor.java index 66311a18c93..e353c850b09 100644 --- a/spring-jdbc/src/main/java/org/springframework/jdbc/core/RowMapperResultSetExtractor.java +++ b/spring-jdbc/src/main/java/org/springframework/jdbc/core/RowMapperResultSetExtractor.java @@ -52,6 +52,7 @@ * you can have executable query objects (containing row-mapping logic) there. * * @author Juergen Hoeller + * @author Yanming Zhou * @since 1.0.2 * @param the result element type * @see RowMapper @@ -64,6 +65,8 @@ public class RowMapperResultSetExtractor implements ResultSetExtractor rowMapper) { * (just used for optimized collection handling) */ public RowMapperResultSetExtractor(RowMapper rowMapper, int rowsExpected) { + this(rowMapper, rowsExpected, -1); + } + + /** + * Create a new RowMapperResultSetExtractor. + * @param rowMapper the RowMapper which creates an object for each row + * @param rowsExpected the number of expected rows + * (just used for optimized collection handling) + * @param maxRows the number of max rows + */ + public RowMapperResultSetExtractor(RowMapper rowMapper, int rowsExpected, int maxRows) { Assert.notNull(rowMapper, "RowMapper must not be null"); this.rowMapper = rowMapper; this.rowsExpected = rowsExpected; + this.maxRows = maxRows; } @@ -90,7 +105,7 @@ public RowMapperResultSetExtractor(RowMapper rowMapper, int rowsExpected) { public List extractData(ResultSet rs) throws SQLException { List results = (this.rowsExpected > 0 ? new ArrayList<>(this.rowsExpected) : new ArrayList<>()); int rowNum = 0; - while (rs.next()) { + while (rs.next() && (this.maxRows == -1 || rowNum < this.maxRows)) { results.add(this.rowMapper.mapRow(rs, rowNum++)); } return results; diff --git a/spring-jdbc/src/test/java/org/springframework/jdbc/core/JdbcTemplateTests.java b/spring-jdbc/src/test/java/org/springframework/jdbc/core/JdbcTemplateTests.java index 6389af71735..a019470cbdf 100644 --- a/spring-jdbc/src/test/java/org/springframework/jdbc/core/JdbcTemplateTests.java +++ b/spring-jdbc/src/test/java/org/springframework/jdbc/core/JdbcTemplateTests.java @@ -32,7 +32,9 @@ import java.util.Collections; import java.util.List; import java.util.Map; +import java.util.function.BiFunction; import java.util.function.Consumer; +import java.util.stream.Stream; import javax.sql.DataSource; @@ -77,6 +79,7 @@ * @author Thomas Risberg * @author Juergen Hoeller * @author Phillip Webb + * @author Yanming Zhou */ class JdbcTemplateTests { @@ -1236,6 +1239,50 @@ public int getBatchSize() { Collections.singletonMap("someId", 456)); } + @Test + void testSkipFurtherRowsOnceMaxRowsHasBeenReachedForRowMapper() throws Exception { + testDiscardFurtherRowsOnceMaxRowsHasBeenReached((template, sql) -> + template.query(sql, (rs, rowNum) -> rs.getString(1))); + } + + @Test + void testDiscardFurtherRowsOnceMaxRowsHasBeenReachedForRowCallbackHandler() throws Exception { + testDiscardFurtherRowsOnceMaxRowsHasBeenReached((template, sql) -> { + List list = new ArrayList<>(); + template.query(sql, (RowCallbackHandler) rs -> list.add(rs.getString(1))); + return list; + }); + } + + @Test + void testDiscardFurtherRowsOnceMaxRowsHasBeenReachedForStream() throws Exception { + testDiscardFurtherRowsOnceMaxRowsHasBeenReached((template, sql) -> { + try (Stream stream = template.queryForStream(sql, (rs, rowNum) -> rs.getString(1))) { + return stream.toList(); + } + }); + } + + private void testDiscardFurtherRowsOnceMaxRowsHasBeenReached(BiFunction> function) throws Exception { + String sql = "SELECT FORENAME FROM CUSTMR"; + String[] results = {"rod", "gary", " portia"}; + int maxRows = 2; + + given(this.resultSet.next()).willReturn(true, true, true, false); + given(this.resultSet.getString(1)).willReturn(results[0], results[1], results[2]); + given(this.connection.createStatement()).willReturn(this.preparedStatement); + + JdbcTemplate template = new JdbcTemplate(); + template.setDataSource(this.dataSource); + template.setMaxRows(maxRows); + + assertThat(function.apply(template, sql)).as("same length").hasSize(maxRows); + + verify(this.resultSet).close(); + verify(this.preparedStatement).close(); + verify(this.connection).close(); + } + private void mockDatabaseMetaData(boolean supportsBatchUpdates) throws SQLException { DatabaseMetaData databaseMetaData = mock(); given(databaseMetaData.getDatabaseProductName()).willReturn("MySQL");