From 4b65eb2c2ea90a6887a4cc403ff0e7700ae25a03 Mon Sep 17 00:00:00 2001 From: "slav.babanin" Date: Wed, 10 Sep 2025 22:39:47 -0700 Subject: [PATCH 01/37] Introduce batch support. --- ...ongoPreparedStatementIntegrationTests.java | 281 +++++++++++++++++- .../query/AbstractQueryIntegrationTests.java | 12 +- .../mutation/BatchUpdateIntegrationTests.java | 239 +++++++++++++++ .../jdbc/MongoPreparedStatement.java | 49 ++- .../hibernate/jdbc/MongoStatement.java | 139 +++++++-- .../jdbc/MongoPreparedStatementTests.java | 184 ++++++++++-- 6 files changed, 841 insertions(+), 63 deletions(-) create mode 100644 src/integrationTest/java/com/mongodb/hibernate/query/mutation/BatchUpdateIntegrationTests.java diff --git a/src/integrationTest/java/com/mongodb/hibernate/jdbc/MongoPreparedStatementIntegrationTests.java b/src/integrationTest/java/com/mongodb/hibernate/jdbc/MongoPreparedStatementIntegrationTests.java index 84148316..7ac6905e 100644 --- a/src/integrationTest/java/com/mongodb/hibernate/jdbc/MongoPreparedStatementIntegrationTests.java +++ b/src/integrationTest/java/com/mongodb/hibernate/jdbc/MongoPreparedStatementIntegrationTests.java @@ -20,12 +20,15 @@ import static com.mongodb.hibernate.jdbc.MongoStatementIntegrationTests.doAndTerminateTransaction; import static com.mongodb.hibernate.jdbc.MongoStatementIntegrationTests.doWithSpecifiedAutoCommit; import static com.mongodb.hibernate.jdbc.MongoStatementIntegrationTests.insertTestData; +import static java.lang.String.format; +import static java.sql.Statement.SUCCESS_NO_INFO; import static org.assertj.core.api.Assertions.assertThat; import static org.junit.jupiter.api.Assertions.assertAll; import static org.junit.jupiter.api.Assertions.assertArrayEquals; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertFalse; import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertThrows; import static org.junit.jupiter.api.Assertions.assertTrue; import com.mongodb.client.MongoCollection; @@ -35,7 +38,10 @@ import com.mongodb.hibernate.junit.MongoExtension; import java.math.BigDecimal; import java.sql.Connection; +import java.sql.PreparedStatement; import java.sql.SQLException; +import java.sql.SQLFeatureNotSupportedException; +import java.util.ArrayList; import java.util.List; import java.util.Random; import java.util.function.Function; @@ -47,9 +53,12 @@ import org.junit.jupiter.api.AutoClose; import org.junit.jupiter.api.BeforeAll; import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.DisplayName; import org.junit.jupiter.api.Nested; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.ExtendWith; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.ValueSource; @ExtendWith(MongoExtension.class) class MongoPreparedStatementIntegrationTests { @@ -124,8 +133,19 @@ void testExecuteQuery() { } @Test - void testPreparedStatementAndResultSetRoundTrip() { + void testPreparedStatementExecuteUpdateAndResultSetRoundTrip() { + assertRoundTrip(PreparedStatement::executeUpdate); + } + + @Test + void testPreparedStatementExecuteBatchAndResultSetRoundTrip() { + assertRoundTrip(preparedStatement -> { + preparedStatement.addBatch(); + preparedStatement.executeBatch(); + }); + } + private void assertRoundTrip(SqlConsumer executor) { var random = new Random(); boolean booleanValue = random.nextBoolean(); @@ -165,8 +185,7 @@ void testPreparedStatementAndResultSetRoundTrip() { pstmt.setString(5, stringValue); pstmt.setBigDecimal(6, bigDecimalValue); pstmt.setBytes(7, bytes); - - pstmt.executeUpdate(); + executor.accept(pstmt); } }); @@ -210,6 +229,241 @@ void testPreparedStatementAndResultSetRoundTrip() { }); } + @Nested + class ExecuteBatchTests { + private static final String INSERT_MQL = + """ + { + insert: "books", + documents: [ + { + _id: 1, + title: "War and Peace" + }, + { + _id: 2, + title: "Anna Karenina" + }, + { + _id: 3, + title: "Crime and Punishment" + } + ] + }"""; + + @Test + void testEmptyBatch() { + doWorkAwareOfAutoCommit(connection -> { + try { + var pstmt = (MongoPreparedStatement) + connection.prepareStatement( + """ + { + insert: "books", + documents: [ + { + _id: 1 + } + ] + }"""); + int[] updateCounts = pstmt.executeBatch(); + assertEquals(0, updateCounts.length); + } catch (SQLException e) { + throw new RuntimeException(e); + } + }); + + assertThat(mongoCollection.find()).isEmpty(); + } + + @Test + @DisplayName("Test statement’s batch queue is reset once executeBatch returns") + void testBatchQueueIsResetAfterExecute() { + doWorkAwareOfAutoCommit(connection -> { + var pstmt = (MongoPreparedStatement) + connection.prepareStatement( + """ + { + insert: "books", + documents: [ + { + _id: {$undefined: true}, + title: {$undefined: true} + } + ] + }"""); + + pstmt.setInt(1, 1); + pstmt.setString(2, "War and Peace"); + pstmt.addBatch(); + assertExecuteBatch(pstmt, 1); + + assertExecuteBatch(pstmt, 0); + }); + + assertThat(mongoCollection.find()) + .containsExactly( + BsonDocument.parse( + """ + { + _id: 1, + title: "War and Peace" + }""")); + } + + @Test + @DisplayName("Test values set for the parameter markers of PreparedStatement are not reset when it is executed") + void testBatchParametersReuse() { + doWorkAwareOfAutoCommit(connection -> { + var pstmt = (MongoPreparedStatement) + connection.prepareStatement( + """ + { + insert: "books", + documents: [ + { + _id: {$undefined: true}, + title: {$undefined: true} + } + ] + }"""); + + pstmt.setInt(1, 1); + pstmt.setString(2, "War and Peace"); + pstmt.addBatch(); + assertExecuteBatch(pstmt, 1); + + pstmt.setInt(1, 2); + // No need to set title again, it should be reused from the previous execution + pstmt.addBatch(); + assertExecuteBatch(pstmt, 1); + }); + + assertThat(mongoCollection.find()) + .containsExactly( + BsonDocument.parse( + """ + { + _id: 1, + title: "War and Peace" + }"""), + BsonDocument.parse( + """ + { + _id: 2, + title: "War and Peace" + }""")); + } + + @Test + void testBatchInsert() { + int batchCount = 3; + doWorkAwareOfAutoCommit(connection -> { + var pstmt = (MongoPreparedStatement) + connection.prepareStatement( + """ + { + insert: "books", + documents: [{ + _id: {$undefined: true}, + title: {$undefined: true} + }] + }"""); + + for (int i = 1; i <= batchCount; i++) { + pstmt.setInt(1, i); + pstmt.setString(2, "Book " + i); + pstmt.addBatch(); + } + assertExecuteBatch(pstmt, batchCount); + }); + + var expectedDocs = new ArrayList(); + for (int i = 0; i < batchCount; i++) { + expectedDocs.add(BsonDocument.parse(format( + """ + { + "_id": %d, + "title": "Book %d" + }""", + i + 1, i + 1))); + } + assertThat(mongoCollection.find()).containsExactlyElementsOf(expectedDocs); + } + + @Test + void testBatchUpdate() { + insertTestData(session, INSERT_MQL); + + int batchCount = 3; + doWorkAwareOfAutoCommit(connection -> { + var pstmt = (MongoPreparedStatement) + connection.prepareStatement( + """ + { + update: "books", + updates: [{ + q: { _id: { $undefined: true } }, + u: { $set: { title: { $undefined: true } } }, + multi: true + }] + }"""); + for (int i = 1; i <= batchCount; i++) { + pstmt.setInt(1, i); + pstmt.setString(2, "Book " + i); + pstmt.addBatch(); + } + assertExecuteBatch(pstmt, batchCount); + }); + + var expectedDocs = new ArrayList(); + for (int i = 0; i < batchCount; i++) { + expectedDocs.add(BsonDocument.parse(format( + """ + { + "_id": %d, + "title": "Book %d" + }""", + i + 1, i + 1))); + } + assertThat(mongoCollection.find()).containsExactlyElementsOf(expectedDocs); + } + + @Test + void testBatchDelete() { + insertTestData(session, INSERT_MQL); + + int batchCount = 3; + doWorkAwareOfAutoCommit(connection -> { + var pstmt = (MongoPreparedStatement) + connection.prepareStatement( + """ + { + delete: "books", + deletes: [{ + q: { _id: { $undefined: true } }, + limit: 0 + }] + }"""); + for (int i = 1; i <= batchCount; i++) { + pstmt.setInt(1, i); + pstmt.addBatch(); + } + assertExecuteBatch(pstmt, batchCount); + }); + + assertThat(mongoCollection.find()).isEmpty(); + } + + private void assertExecuteBatch(MongoPreparedStatement pstmt, int expectedBatchResultSize) throws SQLException { + int[] updateCounts = pstmt.executeBatch(); + assertEquals(expectedBatchResultSize, updateCounts.length); + for (int updateCount : updateCounts) { + assertEquals(SUCCESS_NO_INFO, updateCount); + } + } + } + @Nested class ExecuteUpdateTests { @@ -386,6 +640,23 @@ void testDelete() { }"""))); } + @ParameterizedTest(name = "testNotSupportedCommands. Parameters: {0}") + @ValueSource(strings = {"findAndModify", "aggregate", "bulkWrite"}) + void testNotSupportedCommands(String commandName) { + doWorkAwareOfAutoCommit(connection -> { + try (PreparedStatement findAndModify = connection.prepareStatement(format( + """ + { + findAndModify: "books" + }""", + commandName))) { + SQLFeatureNotSupportedException exception = + assertThrows(SQLFeatureNotSupportedException.class, findAndModify::executeUpdate); + assertThat(exception.getMessage()).contains("findAndModify"); + } + }); + } + private void assertExecuteUpdate( Function pstmtProvider, int expectedUpdatedRowCount, @@ -407,4 +678,8 @@ private void doWorkAwareOfAutoCommit(Work work) { void doAwareOfAutoCommit(Connection connection, SqlExecutable work) throws SQLException { doWithSpecifiedAutoCommit(false, connection, () -> doAndTerminateTransaction(connection, work)); } + + interface SqlConsumer { + void accept(T t) throws SQLException; + } } diff --git a/src/integrationTest/java/com/mongodb/hibernate/query/AbstractQueryIntegrationTests.java b/src/integrationTest/java/com/mongodb/hibernate/query/AbstractQueryIntegrationTests.java index 19ee1e81..5a5007b5 100644 --- a/src/integrationTest/java/com/mongodb/hibernate/query/AbstractQueryIntegrationTests.java +++ b/src/integrationTest/java/com/mongodb/hibernate/query/AbstractQueryIntegrationTests.java @@ -178,13 +178,13 @@ protected void assertSelectQueryFailure( expectedExceptionMessageParameters); } - protected void assertActualCommand(BsonDocument expectedCommand) { + protected void assertActualCommand(BsonDocument... expectedCommands) { var capturedCommands = testCommandListener.getStartedCommands(); - - assertThat(capturedCommands) - .singleElement() - .asInstanceOf(InstanceOfAssertFactories.MAP) - .containsAllEntriesOf(expectedCommand); + assertThat(capturedCommands).hasSize(expectedCommands.length); + for (int i = 0; i < expectedCommands.length; i++) { + BsonDocument actual = capturedCommands.get(i); + assertThat(actual).asInstanceOf(InstanceOfAssertFactories.MAP).containsAllEntriesOf(expectedCommands[i]); + } } protected void assertMutationQuery( diff --git a/src/integrationTest/java/com/mongodb/hibernate/query/mutation/BatchUpdateIntegrationTests.java b/src/integrationTest/java/com/mongodb/hibernate/query/mutation/BatchUpdateIntegrationTests.java new file mode 100644 index 00000000..732b45ad --- /dev/null +++ b/src/integrationTest/java/com/mongodb/hibernate/query/mutation/BatchUpdateIntegrationTests.java @@ -0,0 +1,239 @@ +/* + * Copyright 2025-present MongoDB, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.mongodb.hibernate.query.mutation; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.bson.RawBsonDocument.parse; + +import com.mongodb.client.MongoCollection; +import com.mongodb.hibernate.junit.InjectMongoCollection; +import com.mongodb.hibernate.query.AbstractQueryIntegrationTests; +import jakarta.persistence.Entity; +import jakarta.persistence.Id; +import jakarta.persistence.Table; +import org.bson.BsonDocument; +import org.hibernate.cfg.AvailableSettings; +import org.hibernate.engine.spi.SessionImplementor; +import org.hibernate.testing.orm.junit.DomainModel; +import org.hibernate.testing.orm.junit.ServiceRegistry; +import org.hibernate.testing.orm.junit.Setting; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Nested; +import org.junit.jupiter.api.Test; + +@DomainModel(annotatedClasses = BatchUpdateIntegrationTests.Item.class) +@ServiceRegistry(settings = @Setting(name = AvailableSettings.STATEMENT_BATCH_SIZE, value = "3")) +class BatchUpdateIntegrationTests extends AbstractQueryIntegrationTests { + + private static final String COLLECTION_NAME = "items"; + private static final int BATCH_COUNT = 5; + + @InjectMongoCollection(COLLECTION_NAME) + private static MongoCollection collection; + + @BeforeEach + void beforeEach() { + getTestCommandListener().clear(); + } + + @Test + // TODO remove this test.We should forbid native mutation queries with batching + void testNativeInsertMutationQuery() { + getSessionFactoryScope().inTransaction(session -> { + session.createNativeMutationQuery( + """ + { + "insert": "items", + "ordered": true, + "documents": [ + { "_id": 101, "string": "native101"}, + { "_id": 102, "string": "native102"} + ] + } + """) + .executeUpdate(); + + assertActualCommand( + parse( + """ + { + "insert": "items", + "ordered": true, + "documents": [ + { "_id": 101, "string": "native101"}, + { "_id": 102, "string": "native102"} + ] + } + """)); + }); + } + + @Test + void testBatchInsert() { + getSessionFactoryScope().inTransaction(session -> { + for (int i = 1; i <= BATCH_COUNT; i++) { + session.persist(new Item(i, String.valueOf(i))); + } + session.flush(); + assertActualCommand( + parse( + """ + { + "insert": "items", + "ordered": true, + "documents": [ + { "_id": 1, "string": "1"}, + { "_id": 2, "string": "2"}, + { "_id": 3, "string": "3"} + ] + } + """), + parse( + """ + { + "insert": "items", + "ordered": true, + "documents": [ + { "_id": 4, "string": "4"}, + { "_id": 5, "string": "5"} + ] + } + """)); + }); + + assertThat(collection.find()) + .containsExactlyElementsOf(java.util.List.of( + BsonDocument.parse("{ _id: 1, string: '1' }"), + BsonDocument.parse("{ _id: 2, string: '2' }"), + BsonDocument.parse("{ _id: 3, string: '3' }"), + BsonDocument.parse("{ _id: 4, string: '4' }"), + BsonDocument.parse("{ _id: 5, string: '5' }"))); + } + + @Nested + class BatchUpdateTests { + @Test + void testBatchUpdate() { + getSessionFactoryScope().inTransaction(session -> { + insertTestData(session); + for (int i = 1; i <= BATCH_COUNT; i++) { + Item item = session.find(Item.class, i); + item.string = "u" + i; + } + session.flush(); + assertActualCommand( + parse( + """ + { + "update": "items", + "ordered": true, + "updates": [ + { "q": { "_id": { "$eq": 1 } }, "u": { "$set": { "string": "u1" } }, "multi": true }, + { "q": { "_id": { "$eq": 2 } }, "u": { "$set": { "string": "u2" } }, "multi": true }, + { "q": { "_id": { "$eq": 3 } }, "u": { "$set": { "string": "u3" } }, "multi": true } + ] + } + """), + parse( + """ + { + "update": "items", + "ordered": true, + "updates": [ + { "q": { "_id": { "$eq": 4 } }, "u": { "$set": { "string": "u4" } }, "multi": true }, + { "q": { "_id": { "$eq": 5 } }, "u": { "$set": { "string": "u5" } }, "multi": true } + ] + } + """)); + }); + + assertThat(collection.find()) + .containsExactlyElementsOf(java.util.List.of( + BsonDocument.parse("{ _id: 1, string: 'u1' }"), + BsonDocument.parse("{ _id: 2, string: 'u2' }"), + BsonDocument.parse("{ _id: 3, string: 'u3' }"), + BsonDocument.parse("{ _id: 4, string: 'u4' }"), + BsonDocument.parse("{ _id: 5, string: 'u5' }"))); + } + } + + @Nested + class BatchDeleteTests { + + @Test + void testBatchDelete() { + getSessionFactoryScope().inTransaction(session -> { + insertTestData(session); + for (int i = 1; i <= BATCH_COUNT; i++) { + var item = session.find(Item.class, i); + session.remove(item); + } + session.flush(); + assertActualCommand( + parse( + """ + { + "delete": "items", + "ordered": true, + "deletes": [ + {"q": {"_id": {"$eq": 1}}, "limit": 0}, + {"q": {"_id": {"$eq": 2}}, "limit": 0}, + {"q": {"_id": {"$eq": 3}}, "limit": 0} + ] + } + """), + parse( + """ + { + "delete": "items", + "ordered": true, + "deletes": [ + {"q": {"_id": {"$eq": 4}}, "limit": 0} + {"q": {"_id": {"$eq": 5}}, "limit": 0} + ] + } + """)); + }); + + assertThat(collection.find()).isEmpty(); + } + } + + private void insertTestData(final SessionImplementor session) { + for (int i = 1; i <= 5; i++) { + session.persist(new Item(i, String.valueOf(i))); + } + session.flush(); + getTestCommandListener().clear(); + } + + @Entity + @Table(name = COLLECTION_NAME) + static class Item { + @Id + int id; + + String string; + + Item() {} + + Item(int id, String string) { + this.id = id; + this.string = string; + } + } +} diff --git a/src/main/java/com/mongodb/hibernate/jdbc/MongoPreparedStatement.java b/src/main/java/com/mongodb/hibernate/jdbc/MongoPreparedStatement.java index 20c4ec07..ac79a143 100644 --- a/src/main/java/com/mongodb/hibernate/jdbc/MongoPreparedStatement.java +++ b/src/main/java/com/mongodb/hibernate/jdbc/MongoPreparedStatement.java @@ -21,6 +21,7 @@ import static com.mongodb.hibernate.internal.type.ValueConversions.toBsonValue; import static java.lang.String.format; +import com.mongodb.MongoBulkWriteException; import com.mongodb.client.ClientSession; import com.mongodb.client.MongoDatabase; import com.mongodb.hibernate.internal.FeatureNotSupportedException; @@ -28,6 +29,7 @@ import com.mongodb.hibernate.internal.type.ObjectIdJdbcType; import java.math.BigDecimal; import java.sql.Array; +import java.sql.BatchUpdateException; import java.sql.Date; import java.sql.JDBCType; import java.sql.PreparedStatement; @@ -35,10 +37,12 @@ import java.sql.SQLException; import java.sql.SQLFeatureNotSupportedException; import java.sql.SQLSyntaxErrorException; +import java.sql.Statement; import java.sql.Time; import java.sql.Timestamp; import java.sql.Types; import java.util.ArrayList; +import java.util.Arrays; import java.util.Calendar; import java.util.List; import java.util.Set; @@ -52,15 +56,17 @@ final class MongoPreparedStatement extends MongoStatement implements PreparedStatementAdapter { private final BsonDocument command; - + private final List batchCommands; private final List parameterValueSetters; + private static final int[] EMPTY_BATCH_RESULT = new int[0]; MongoPreparedStatement( MongoDatabase mongoDatabase, ClientSession clientSession, MongoConnection mongoConnection, String mql) throws SQLSyntaxErrorException { super(mongoDatabase, clientSession, mongoConnection); - this.command = MongoStatement.parse(mql); - this.parameterValueSetters = new ArrayList<>(); + command = MongoStatement.parse(mql); + batchCommands = new ArrayList<>(); + parameterValueSetters = new ArrayList<>(); parseParameters(command, parameterValueSetters); } @@ -200,7 +206,34 @@ public void setObject(int parameterIndex, Object x, int targetSqlType) throws SQ @Override public void addBatch() throws SQLException { checkClosed(); - throw new SQLFeatureNotSupportedException("TODO-HIBERNATE-35 https://jira.mongodb.org/browse/HIBERNATE-35"); + checkAllParametersSet(); // TODO first check that all parameters are set for the previous batch. + batchCommands.add(command.clone()); + } + + @Override + public void clearBatch() throws SQLException { + checkClosed(); + batchCommands.clear(); + } + + @Override + public int[] executeBatch() throws SQLException { + checkClosed(); + closeLastOpenResultSet(); + if (batchCommands.isEmpty()) { + return EMPTY_BATCH_RESULT; + } + try { + executeBulkWrite(batchCommands); + var rowCounts = new int[batchCommands.size()]; + // We cannot determine the actual number of rows affected for each command in the batch. + Arrays.fill(rowCounts, Statement.SUCCESS_NO_INFO); + return rowCounts; + } catch (MongoBulkWriteException mongoBulkWriteException) { + throw createBatchUpdateException(mongoBulkWriteException, command.getFirstKey()); + } finally { + batchCommands.clear(); + } } @Override @@ -360,4 +393,12 @@ private static void checkComparatorNotComparingWithNullValues(BsonDocument docum } } } + + static BatchUpdateException createBatchUpdateException( + final MongoBulkWriteException mongoBulkWriteException, final String commandName) { + int updateCount = getUpdateCount(commandName, mongoBulkWriteException.getWriteResult()); + int[] updateCounts = new int[updateCount]; + Arrays.fill(updateCounts, SUCCESS_NO_INFO); + return new BatchUpdateException(mongoBulkWriteException.getMessage(), updateCounts, mongoBulkWriteException); + } } diff --git a/src/main/java/com/mongodb/hibernate/jdbc/MongoStatement.java b/src/main/java/com/mongodb/hibernate/jdbc/MongoStatement.java index 2442a4e8..7f416310 100644 --- a/src/main/java/com/mongodb/hibernate/jdbc/MongoStatement.java +++ b/src/main/java/com/mongodb/hibernate/jdbc/MongoStatement.java @@ -16,13 +16,30 @@ package com.mongodb.hibernate.jdbc; +import static com.mongodb.hibernate.internal.MongoAssertions.assertNotNull; +import static com.mongodb.hibernate.internal.MongoAssertions.assertTrue; import static com.mongodb.hibernate.internal.MongoConstants.ID_FIELD_NAME; import static com.mongodb.hibernate.internal.VisibleForTesting.AccessModifier.PRIVATE; import static java.lang.String.format; +import static java.util.Collections.singletonList; import static java.util.stream.Collectors.toCollection; +import com.mongodb.MongoBulkWriteException; +import com.mongodb.MongoExecutionTimeoutException; +import com.mongodb.MongoSocketReadTimeoutException; +import com.mongodb.MongoSocketWriteTimeoutException; +import com.mongodb.MongoTimeoutException; +import com.mongodb.bulk.BulkWriteResult; import com.mongodb.client.ClientSession; +import com.mongodb.client.MongoCollection; import com.mongodb.client.MongoDatabase; +import com.mongodb.client.model.DeleteManyModel; +import com.mongodb.client.model.DeleteOneModel; +import com.mongodb.client.model.DeleteOptions; +import com.mongodb.client.model.InsertOneModel; +import com.mongodb.client.model.UpdateManyModel; +import com.mongodb.client.model.UpdateOneModel; +import com.mongodb.client.model.WriteModel; import com.mongodb.hibernate.internal.FeatureNotSupportedException; import com.mongodb.hibernate.internal.VisibleForTesting; import java.sql.Connection; @@ -30,8 +47,10 @@ import java.sql.SQLException; import java.sql.SQLFeatureNotSupportedException; import java.sql.SQLSyntaxErrorException; +import java.sql.SQLTimeoutException; import java.sql.SQLWarning; import java.util.ArrayList; +import java.util.Collection; import java.util.List; import java.util.Map; import org.bson.BsonDocument; @@ -127,10 +146,37 @@ public int executeUpdate(String mql) throws SQLException { int executeUpdateCommand(BsonDocument command) throws SQLException { try { - startTransactionIfNeeded(); - return mongoDatabase.runCommand(clientSession, command).getInteger("n"); - } catch (RuntimeException e) { - throw new SQLException("Failed to execute update command", e); + var bulkWriteResult = executeBulkWrite(singletonList(command)); + return getUpdateCount(command.getFirstKey(), bulkWriteResult); + } catch (MongoBulkWriteException mongoBulkWriteException) { + throw new SQLException(mongoBulkWriteException.getMessage(), mongoBulkWriteException); + } + } + + BulkWriteResult executeBulkWrite(List commandBatch) throws SQLException { + startTransactionIfNeeded(); + var firstDocumentInBatch = commandBatch.get(0); + var commandName = assertNotNull(firstDocumentInBatch.getFirstKey()); + var collectionName = + assertNotNull(firstDocumentInBatch.getString(commandName).getValue()); + MongoCollection collection = mongoDatabase.getCollection(collectionName, BsonDocument.class); + + try { + var writeModels = new ArrayList>(commandBatch.size()); + for (var command : commandBatch) { + assertTrue(collectionName.equals(command.getString(commandName).getValue())); + convertToWriteModels(command, commandName, writeModels); + } + return collection.bulkWrite(clientSession, writeModels); + } catch (MongoSocketReadTimeoutException + | MongoSocketWriteTimeoutException + | MongoTimeoutException + | MongoExecutionTimeoutException mongoTimeoutException) { + throw new SQLTimeoutException(mongoTimeoutException.getMessage(), mongoTimeoutException); + } catch (MongoBulkWriteException mongoBulkWriteException) { + throw mongoBulkWriteException; + } catch (RuntimeException runtimeException) { + throw new SQLException("Failed to execute update", runtimeException); } } @@ -186,25 +232,6 @@ public int getUpdateCount() throws SQLException { throw new SQLFeatureNotSupportedException("To be implemented in scope of index and unique constraint creation"); } - @Override - public void addBatch(String mql) throws SQLException { - checkClosed(); - throw new SQLFeatureNotSupportedException("TODO-HIBERNATE-35 https://jira.mongodb.org/browse/HIBERNATE-35"); - } - - @Override - public void clearBatch() throws SQLException { - checkClosed(); - throw new SQLFeatureNotSupportedException("TODO-HIBERNATE-35 https://jira.mongodb.org/browse/HIBERNATE-35"); - } - - @Override - public int[] executeBatch() throws SQLException { - checkClosed(); - closeLastOpenResultSet(); - throw new SQLFeatureNotSupportedException("TODO-HIBERNATE-35 https://jira.mongodb.org/browse/HIBERNATE-35"); - } - @Override public Connection getConnection() throws SQLException { checkClosed(); @@ -240,9 +267,73 @@ static BsonDocument parse(String mql) throws SQLSyntaxErrorException { * Starts transaction for the first {@link java.sql.Statement} executing if * {@linkplain MongoConnection#getAutoCommit() auto-commit} is disabled. */ - private void startTransactionIfNeeded() throws SQLException { + void startTransactionIfNeeded() throws SQLException { if (!mongoConnection.getAutoCommit() && !clientSession.hasActiveTransaction()) { clientSession.startTransaction(); } } + + static int getUpdateCount(final String commandName, final BulkWriteResult bulkWriteResult) { + return switch (commandName) { + case "insert" -> bulkWriteResult.getInsertedCount(); + case "update" -> bulkWriteResult.getModifiedCount(); + case "delete" -> bulkWriteResult.getDeletedCount(); + default -> throw new FeatureNotSupportedException("Unsupported command: " + commandName); + }; + } + + private static void convertToWriteModels( + final BsonDocument command, + final String commandName, + final Collection> writeModels) + throws SQLFeatureNotSupportedException { + switch (commandName) { + case "insert": + var documents = command.getArray("documents"); + for (var insertDocument : documents) { + writeModels.add(createInsertModel(insertDocument.asDocument())); + } + break; + case "update": + var updates = command.getArray("updates").getValues(); + for (var updateDocument : updates) { + writeModels.add(createUpdateModel(updateDocument.asDocument())); + } + break; + case "delete": + var deletes = command.getArray("deletes"); + for (var deleteDocument : deletes) { + writeModels.add(createDeleteModel(deleteDocument.asDocument())); + } + break; + default: + throw new SQLFeatureNotSupportedException("Unsupported command: " + commandName); + } + } + + private static WriteModel createInsertModel(final BsonDocument document) { + return new InsertOneModel<>(document); + } + + private static WriteModel createDeleteModel(final BsonDocument deleteDocument) { + var isSingleDelete = deleteDocument.getNumber("limit").intValue() == 1; + var queryFilter = deleteDocument.getDocument("q"); + + if (isSingleDelete) { + new DeleteOptions(); + return new DeleteOneModel<>(queryFilter); + } + return new DeleteManyModel<>(queryFilter); + } + + private static WriteModel createUpdateModel(final BsonDocument updateDocument) { + var isMulti = updateDocument.getBoolean("multi").getValue(); + var queryFilter = updateDocument.getDocument("q"); + var updatePipeline = updateDocument.getDocument("u"); + + if (isMulti) { + return new UpdateManyModel<>(queryFilter, updatePipeline); + } + return new UpdateOneModel<>(queryFilter, updatePipeline); + } } diff --git a/src/test/java/com/mongodb/hibernate/jdbc/MongoPreparedStatementTests.java b/src/test/java/com/mongodb/hibernate/jdbc/MongoPreparedStatementTests.java index be8f84ac..1e68da1b 100644 --- a/src/test/java/com/mongodb/hibernate/jdbc/MongoPreparedStatementTests.java +++ b/src/test/java/com/mongodb/hibernate/jdbc/MongoPreparedStatementTests.java @@ -16,28 +16,38 @@ package com.mongodb.hibernate.jdbc; +import static java.sql.Statement.SUCCESS_NO_INFO; import static org.assertj.core.api.Assertions.assertThat; import static org.junit.jupiter.api.Assertions.assertAll; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertInstanceOf; +import static org.junit.jupiter.api.Assertions.assertNull; import static org.junit.jupiter.api.Assertions.assertThrows; import static org.junit.jupiter.api.Assertions.assertTrue; -import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.anyList; import static org.mockito.ArgumentMatchers.anyString; import static org.mockito.ArgumentMatchers.eq; import static org.mockito.ArgumentMatchers.same; import static org.mockito.Mockito.doReturn; +import static org.mockito.Mockito.doThrow; import static org.mockito.Mockito.verify; +import com.mongodb.MongoBulkWriteException; +import com.mongodb.ServerAddress; +import com.mongodb.bulk.BulkWriteError; +import com.mongodb.bulk.BulkWriteResult; import com.mongodb.client.AggregateIterable; import com.mongodb.client.ClientSession; import com.mongodb.client.MongoCollection; import com.mongodb.client.MongoCursor; import com.mongodb.client.MongoDatabase; +import com.mongodb.client.model.InsertOneModel; +import com.mongodb.client.model.WriteModel; import com.mongodb.hibernate.internal.type.ObjectIdJdbcType; import java.math.BigDecimal; import java.sql.Array; +import java.sql.BatchUpdateException; import java.sql.Date; import java.sql.ResultSet; import java.sql.SQLException; @@ -47,6 +57,7 @@ import java.sql.Types; import java.util.Calendar; import java.util.List; +import java.util.Set; import java.util.function.Consumer; import org.bson.BsonArray; import org.bson.BsonBoolean; @@ -54,7 +65,6 @@ import org.bson.BsonInt32; import org.bson.BsonObjectId; import org.bson.BsonString; -import org.bson.Document; import org.bson.types.ObjectId; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.DisplayName; @@ -74,6 +84,9 @@ class MongoPreparedStatementTests { @Mock private MongoDatabase mongoDatabase; + @Mock + MongoCollection mongoCollection; + @Mock private ClientSession clientSession; @@ -108,15 +121,16 @@ private MongoPreparedStatement createMongoPreparedStatement(String mql) throws S class ParameterValueSettingTests { @Captor - private ArgumentCaptor commandCaptor; + private ArgumentCaptor>> commandCaptor; @Test @DisplayName("Happy path when all parameters are provided values") void testSuccess() throws SQLException { + BulkWriteResult bulkWriteResult = Mockito.mock(BulkWriteResult.class); - doReturn(Document.parse("{ok: 1.0, n: 1}")) - .when(mongoDatabase) - .runCommand(eq(clientSession), any(BsonDocument.class)); + doReturn(mongoCollection).when(mongoDatabase).getCollection(anyString(), eq(BsonDocument.class)); + doReturn(bulkWriteResult).when(mongoCollection).bulkWrite(eq(clientSession), anyList()); + doReturn(1).when(bulkWriteResult).getInsertedCount(); try (var preparedStatement = createMongoPreparedStatement(EXAMPLE_MQL)) { @@ -130,24 +144,139 @@ void testSuccess() throws SQLException { preparedStatement.executeUpdate(); - verify(mongoDatabase).runCommand(eq(clientSession), commandCaptor.capture()); - var command = commandCaptor.getValue(); + verify(mongoCollection).bulkWrite(eq(clientSession), commandCaptor.capture()); + var writeModels = commandCaptor.getValue(); + assertEquals(1, writeModels.size()); var expectedDoc = new BsonDocument() - .append("insert", new BsonString("items")) + .append("string1", new BsonString("s1")) + .append("string2", new BsonString("s2")) + .append("int32", new BsonInt32(1)) + .append("boolean", BsonBoolean.TRUE) .append( - "documents", - new BsonArray(List.of(new BsonDocument() - .append("string1", new BsonString("s1")) - .append("string2", new BsonString("s2")) - .append("int32", new BsonInt32(1)) - .append("boolean", BsonBoolean.TRUE) - .append( - "stringAndObjectId", - new BsonArray(List.of( - new BsonString("array element"), - new BsonObjectId(new ObjectId(1, 2))))) - .append("objectId", new BsonObjectId(new ObjectId(2, 0)))))); - assertEquals(expectedDoc, command); + "stringAndObjectId", + new BsonArray( + List.of(new BsonString("array element"), new BsonObjectId(new ObjectId(1, 2))))) + .append("objectId", new BsonObjectId(new ObjectId(2, 0))); + assertInstanceOf(InsertOneModel.class, writeModels.get(0)); + assertEquals(expectedDoc, ((InsertOneModel) writeModels.get(0)).getDocument()); + } + } + } + + @Nested + class ExecuteBatchThrows { + + static final String BULK_WRITE_ERROR_MESSAGE = "Test message"; + + @Mock + BulkWriteResult bulkWriteResult; + + @BeforeEach + void beforeEach() { + doReturn(mongoCollection).when(mongoDatabase).getCollection(anyString(), eq(BsonDocument.class)); + BulkWriteError bulkWriteError = new BulkWriteError(10, BULK_WRITE_ERROR_MESSAGE, new BsonDocument(), 0); + List writeErrors = List.of(bulkWriteError); + + MongoBulkWriteException mongoBulkWriteException = new MongoBulkWriteException( + bulkWriteResult, writeErrors, null, new ServerAddress("localhost"), Set.of("label")); + + doThrow(mongoBulkWriteException).when(mongoCollection).bulkWrite(eq(clientSession), anyList()); + } + + @Test + void testBatchInsert() throws SQLException { + doReturn(1).when(bulkWriteResult).getInsertedCount(); + + String mql = + """ + { + insert: "items", + documents: [ + { _id: { $undefined: true } } + ] + } + """; + + try (MongoPreparedStatement mongoPreparedStatement = createMongoPreparedStatement(mql)) { + mongoPreparedStatement.setInt(1, 1); + mongoPreparedStatement.addBatch(); + mongoPreparedStatement.addBatch(); + + BatchUpdateException batchUpdateException = + assertThrows(BatchUpdateException.class, mongoPreparedStatement::executeBatch); + + int[] updateCounts = batchUpdateException.getUpdateCounts(); + assertEquals(1, updateCounts.length); + assertTrue(batchUpdateException.getMessage().contains(BULK_WRITE_ERROR_MESSAGE)); + assertEquals(0, batchUpdateException.getErrorCode()); + assertUpdateCounts(updateCounts); + } + } + + @Test + void testBatchUpdate() throws SQLException { + doReturn(1).when(bulkWriteResult).getModifiedCount(); + + String mql = + """ + { + update: "items", + updates: [ + { q: { _id: { $undefined: true } }, u: { $set: { touched: true } }, multi: false } + ] + } + """; + + try (MongoPreparedStatement mongoPreparedStatement = createMongoPreparedStatement(mql)) { + mongoPreparedStatement.setInt(1, 1); + mongoPreparedStatement.addBatch(); + mongoPreparedStatement.addBatch(); + + BatchUpdateException batchUpdateException = + assertThrows(BatchUpdateException.class, mongoPreparedStatement::executeBatch); + + int[] updateCounts = batchUpdateException.getUpdateCounts(); + assertEquals(1, updateCounts.length); + assertTrue(batchUpdateException.getMessage().contains(BULK_WRITE_ERROR_MESSAGE)); + assertEquals(0, batchUpdateException.getErrorCode()); + assertUpdateCounts(updateCounts); + } + } + + @Test + void testBatchDelete() throws SQLException { + doReturn(1).when(bulkWriteResult).getDeletedCount(); + + String mql = + """ + { + delete: "items", + deletes: [ + { q: { _id: { $undefined: true } }, limit: 1 } + ] + } + """; + + try (MongoPreparedStatement mongoPreparedStatement = createMongoPreparedStatement(mql)) { + mongoPreparedStatement.setInt(1, 1); + mongoPreparedStatement.addBatch(); + mongoPreparedStatement.addBatch(); + + BatchUpdateException batchUpdateException = + assertThrows(BatchUpdateException.class, mongoPreparedStatement::executeBatch); + + int[] updateCounts = batchUpdateException.getUpdateCounts(); + assertEquals(1, updateCounts.length); + assertTrue(batchUpdateException.getMessage().contains(BULK_WRITE_ERROR_MESSAGE)); + assertNull(batchUpdateException.getSQLState()); + assertEquals(0, batchUpdateException.getErrorCode()); + assertUpdateCounts(updateCounts); + } + } + + private static void assertUpdateCounts(final int[] updateCounts) { + for (int count : updateCounts) { + assertEquals(SUCCESS_NO_INFO, count); } } } @@ -208,10 +337,13 @@ void testExecuteQuery() throws SQLException { @Test void testExecuteUpdate() throws SQLException { - doReturn(Document.parse("{n: 10}")) - .when(mongoDatabase) - .runCommand(eq(clientSession), any(BsonDocument.class)); - mongoPreparedStatement.executeUpdate(); + assertThrows(SQLException.class, () -> mongoPreparedStatement.executeUpdate()); + assertTrue(lastOpenResultSet.isClosed()); + } + + @Test + void testExecuteBatch() throws SQLException { + mongoPreparedStatement.executeBatch(); assertTrue(lastOpenResultSet.isClosed()); } } From 5b17d582ce116a7317c045ce3759380e4a423830 Mon Sep 17 00:00:00 2001 From: "slav.babanin" Date: Thu, 2 Oct 2025 10:21:25 -0700 Subject: [PATCH 02/37] Add tests and refactor MongoStatements. --- ...ongoPreparedStatementIntegrationTests.java | 55 ++- .../mutation/BatchUpdateIntegrationTests.java | 203 +++++------ .../jdbc/MongoPreparedStatement.java | 20 +- .../hibernate/jdbc/MongoStatement.java | 231 ++++++++++--- .../jdbc/MongoPreparedStatementTests.java | 321 +++++++++++++----- .../hibernate/jdbc/MongoStatementTests.java | 32 +- 6 files changed, 567 insertions(+), 295 deletions(-) diff --git a/src/integrationTest/java/com/mongodb/hibernate/jdbc/MongoPreparedStatementIntegrationTests.java b/src/integrationTest/java/com/mongodb/hibernate/jdbc/MongoPreparedStatementIntegrationTests.java index 7ac6905e..3cb88160 100644 --- a/src/integrationTest/java/com/mongodb/hibernate/jdbc/MongoPreparedStatementIntegrationTests.java +++ b/src/integrationTest/java/com/mongodb/hibernate/jdbc/MongoPreparedStatementIntegrationTests.java @@ -644,15 +644,62 @@ void testDelete() { @ValueSource(strings = {"findAndModify", "aggregate", "bulkWrite"}) void testNotSupportedCommands(String commandName) { doWorkAwareOfAutoCommit(connection -> { - try (PreparedStatement findAndModify = connection.prepareStatement(format( + try (PreparedStatement pstm = connection.prepareStatement(format( """ { - findAndModify: "books" + %s: "books" }""", commandName))) { SQLFeatureNotSupportedException exception = - assertThrows(SQLFeatureNotSupportedException.class, findAndModify::executeUpdate); - assertThat(exception.getMessage()).contains("findAndModify"); + assertThrows(SQLFeatureNotSupportedException.class, pstm::executeUpdate); + assertThat(exception.getMessage()).contains(commandName); + } + }); + } + + @Test + void testNotSupportedUpdateElements() { + doWorkAwareOfAutoCommit(connection -> { + try (PreparedStatement pstm = connection.prepareStatement( + format( + """ + { + update: "books", + updates: [ + { + q: { author: { $eq: "Leo Tolstoy" } }, + u: { $set: { outOfStock: true } }, + multi: true, + hint: { _id: 1 } + } + ] + }"""))) { + SQLFeatureNotSupportedException exception = + assertThrows(SQLFeatureNotSupportedException.class, pstm::executeUpdate); + assertThat(exception.getMessage()).isEqualTo("Unsupported elements in update command: [hint]"); + } + }); + } + + @Test + void testNotSupportedDeleteElements() { + doWorkAwareOfAutoCommit(connection -> { + try (PreparedStatement pstm = connection.prepareStatement( + format( + """ + { + delete: "books", + deletes: [ + { + q: { author: { $eq: "Leo Tolstoy" } }, + limit: 0, + hint: { _id: 1 } + } + ] + }"""))) { + SQLFeatureNotSupportedException exception = + assertThrows(SQLFeatureNotSupportedException.class, pstm::executeUpdate); + assertThat(exception.getMessage()).isEqualTo("Unsupported elements in delete command: [hint]"); } }); } diff --git a/src/integrationTest/java/com/mongodb/hibernate/query/mutation/BatchUpdateIntegrationTests.java b/src/integrationTest/java/com/mongodb/hibernate/query/mutation/BatchUpdateIntegrationTests.java index 732b45ad..b037d1e3 100644 --- a/src/integrationTest/java/com/mongodb/hibernate/query/mutation/BatchUpdateIntegrationTests.java +++ b/src/integrationTest/java/com/mongodb/hibernate/query/mutation/BatchUpdateIntegrationTests.java @@ -25,6 +25,7 @@ import jakarta.persistence.Entity; import jakarta.persistence.Id; import jakarta.persistence.Table; +import java.util.List; import org.bson.BsonDocument; import org.hibernate.cfg.AvailableSettings; import org.hibernate.engine.spi.SessionImplementor; @@ -32,7 +33,6 @@ import org.hibernate.testing.orm.junit.ServiceRegistry; import org.hibernate.testing.orm.junit.Setting; import org.junit.jupiter.api.BeforeEach; -import org.junit.jupiter.api.Nested; import org.junit.jupiter.api.Test; @DomainModel(annotatedClasses = BatchUpdateIntegrationTests.Item.class) @@ -40,7 +40,7 @@ class BatchUpdateIntegrationTests extends AbstractQueryIntegrationTests { private static final String COLLECTION_NAME = "items"; - private static final int BATCH_COUNT = 5; + private static final int ENTITIES_TO_PERSIST_COUNT = 5; @InjectMongoCollection(COLLECTION_NAME) private static MongoCollection collection; @@ -50,42 +50,10 @@ void beforeEach() { getTestCommandListener().clear(); } - @Test - // TODO remove this test.We should forbid native mutation queries with batching - void testNativeInsertMutationQuery() { - getSessionFactoryScope().inTransaction(session -> { - session.createNativeMutationQuery( - """ - { - "insert": "items", - "ordered": true, - "documents": [ - { "_id": 101, "string": "native101"}, - { "_id": 102, "string": "native102"} - ] - } - """) - .executeUpdate(); - - assertActualCommand( - parse( - """ - { - "insert": "items", - "ordered": true, - "documents": [ - { "_id": 101, "string": "native101"}, - { "_id": 102, "string": "native102"} - ] - } - """)); - }); - } - @Test void testBatchInsert() { getSessionFactoryScope().inTransaction(session -> { - for (int i = 1; i <= BATCH_COUNT; i++) { + for (int i = 1; i <= ENTITIES_TO_PERSIST_COUNT; i++) { session.persist(new Item(i, String.valueOf(i))); } session.flush(); @@ -116,7 +84,7 @@ void testBatchInsert() { }); assertThat(collection.find()) - .containsExactlyElementsOf(java.util.List.of( + .containsExactlyElementsOf(List.of( BsonDocument.parse("{ _id: 1, string: '1' }"), BsonDocument.parse("{ _id: 2, string: '2' }"), BsonDocument.parse("{ _id: 3, string: '3' }"), @@ -124,93 +92,86 @@ void testBatchInsert() { BsonDocument.parse("{ _id: 5, string: '5' }"))); } - @Nested - class BatchUpdateTests { - @Test - void testBatchUpdate() { - getSessionFactoryScope().inTransaction(session -> { - insertTestData(session); - for (int i = 1; i <= BATCH_COUNT; i++) { - Item item = session.find(Item.class, i); - item.string = "u" + i; - } - session.flush(); - assertActualCommand( - parse( - """ - { - "update": "items", - "ordered": true, - "updates": [ - { "q": { "_id": { "$eq": 1 } }, "u": { "$set": { "string": "u1" } }, "multi": true }, - { "q": { "_id": { "$eq": 2 } }, "u": { "$set": { "string": "u2" } }, "multi": true }, - { "q": { "_id": { "$eq": 3 } }, "u": { "$set": { "string": "u3" } }, "multi": true } - ] - } - """), - parse( - """ - { - "update": "items", - "ordered": true, - "updates": [ - { "q": { "_id": { "$eq": 4 } }, "u": { "$set": { "string": "u4" } }, "multi": true }, - { "q": { "_id": { "$eq": 5 } }, "u": { "$set": { "string": "u5" } }, "multi": true } - ] - } - """)); - }); - - assertThat(collection.find()) - .containsExactlyElementsOf(java.util.List.of( - BsonDocument.parse("{ _id: 1, string: 'u1' }"), - BsonDocument.parse("{ _id: 2, string: 'u2' }"), - BsonDocument.parse("{ _id: 3, string: 'u3' }"), - BsonDocument.parse("{ _id: 4, string: 'u4' }"), - BsonDocument.parse("{ _id: 5, string: 'u5' }"))); - } + @Test + void testBatchUpdate() { + getSessionFactoryScope().inTransaction(session -> { + insertTestData(session); + for (int i = 1; i <= ENTITIES_TO_PERSIST_COUNT; i++) { + Item item = session.find(Item.class, i); + item.string = "u" + i; + } + session.flush(); + assertActualCommand( + parse( + """ + { + "update": "items", + "ordered": true, + "updates": [ + { "q": { "_id": { "$eq": 1 } }, "u": { "$set": { "string": "u1" } }, "multi": true }, + { "q": { "_id": { "$eq": 2 } }, "u": { "$set": { "string": "u2" } }, "multi": true }, + { "q": { "_id": { "$eq": 3 } }, "u": { "$set": { "string": "u3" } }, "multi": true } + ] + } + """), + parse( + """ + { + "update": "items", + "ordered": true, + "updates": [ + { "q": { "_id": { "$eq": 4 } }, "u": { "$set": { "string": "u4" } }, "multi": true }, + { "q": { "_id": { "$eq": 5 } }, "u": { "$set": { "string": "u5" } }, "multi": true } + ] + } + """)); + }); + + assertThat(collection.find()) + .containsExactlyElementsOf(java.util.List.of( + BsonDocument.parse("{ _id: 1, string: 'u1' }"), + BsonDocument.parse("{ _id: 2, string: 'u2' }"), + BsonDocument.parse("{ _id: 3, string: 'u3' }"), + BsonDocument.parse("{ _id: 4, string: 'u4' }"), + BsonDocument.parse("{ _id: 5, string: 'u5' }"))); } - @Nested - class BatchDeleteTests { - - @Test - void testBatchDelete() { - getSessionFactoryScope().inTransaction(session -> { - insertTestData(session); - for (int i = 1; i <= BATCH_COUNT; i++) { - var item = session.find(Item.class, i); - session.remove(item); - } - session.flush(); - assertActualCommand( - parse( - """ - { - "delete": "items", - "ordered": true, - "deletes": [ - {"q": {"_id": {"$eq": 1}}, "limit": 0}, - {"q": {"_id": {"$eq": 2}}, "limit": 0}, - {"q": {"_id": {"$eq": 3}}, "limit": 0} - ] - } - """), - parse( - """ - { - "delete": "items", - "ordered": true, - "deletes": [ - {"q": {"_id": {"$eq": 4}}, "limit": 0} - {"q": {"_id": {"$eq": 5}}, "limit": 0} - ] - } - """)); - }); - - assertThat(collection.find()).isEmpty(); - } + @Test + void testBatchDelete() { + getSessionFactoryScope().inTransaction(session -> { + insertTestData(session); + for (int i = 1; i <= ENTITIES_TO_PERSIST_COUNT; i++) { + var item = session.find(Item.class, i); + session.remove(item); + } + session.flush(); + assertActualCommand( + parse( + """ + { + "delete": "items", + "ordered": true, + "deletes": [ + {"q": {"_id": {"$eq": 1}}, "limit": 0}, + {"q": {"_id": {"$eq": 2}}, "limit": 0}, + {"q": {"_id": {"$eq": 3}}, "limit": 0} + ] + } + """), + parse( + """ + { + "delete": "items", + "ordered": true, + "deletes": [ + {"q": {"_id": {"$eq": 4}}, "limit": 0}, + {"q": {"_id": {"$eq": 5}}, "limit": 0} + ] + } + """)); + }); + + assertThat(collection.find()).isEmpty(); } private void insertTestData(final SessionImplementor session) { diff --git a/src/main/java/com/mongodb/hibernate/jdbc/MongoPreparedStatement.java b/src/main/java/com/mongodb/hibernate/jdbc/MongoPreparedStatement.java index ac79a143..9c23e338 100644 --- a/src/main/java/com/mongodb/hibernate/jdbc/MongoPreparedStatement.java +++ b/src/main/java/com/mongodb/hibernate/jdbc/MongoPreparedStatement.java @@ -21,7 +21,6 @@ import static com.mongodb.hibernate.internal.type.ValueConversions.toBsonValue; import static java.lang.String.format; -import com.mongodb.MongoBulkWriteException; import com.mongodb.client.ClientSession; import com.mongodb.client.MongoDatabase; import com.mongodb.hibernate.internal.FeatureNotSupportedException; @@ -29,7 +28,6 @@ import com.mongodb.hibernate.internal.type.ObjectIdJdbcType; import java.math.BigDecimal; import java.sql.Array; -import java.sql.BatchUpdateException; import java.sql.Date; import java.sql.JDBCType; import java.sql.PreparedStatement; @@ -55,10 +53,10 @@ final class MongoPreparedStatement extends MongoStatement implements PreparedStatementAdapter { + private static final int[] EMPTY_BATCH_RESULT = new int[0]; private final BsonDocument command; private final List batchCommands; private final List parameterValueSetters; - private static final int[] EMPTY_BATCH_RESULT = new int[0]; MongoPreparedStatement( MongoDatabase mongoDatabase, ClientSession clientSession, MongoConnection mongoConnection, String mql) @@ -83,6 +81,7 @@ public int executeUpdate() throws SQLException { checkClosed(); closeLastOpenResultSet(); checkAllParametersSet(); + checkUpdateOperation(command); return executeUpdateCommand(command); } @@ -206,7 +205,8 @@ public void setObject(int parameterIndex, Object x, int targetSqlType) throws SQ @Override public void addBatch() throws SQLException { checkClosed(); - checkAllParametersSet(); // TODO first check that all parameters are set for the previous batch. + checkAllParametersSet(); + checkUpdateOperation(command); batchCommands.add(command.clone()); } @@ -224,13 +224,11 @@ public int[] executeBatch() throws SQLException { return EMPTY_BATCH_RESULT; } try { - executeBulkWrite(batchCommands); + executeBulkWrite(batchCommands, ExecutionType.BATCH); var rowCounts = new int[batchCommands.size()]; // We cannot determine the actual number of rows affected for each command in the batch. Arrays.fill(rowCounts, Statement.SUCCESS_NO_INFO); return rowCounts; - } catch (MongoBulkWriteException mongoBulkWriteException) { - throw createBatchUpdateException(mongoBulkWriteException, command.getFirstKey()); } finally { batchCommands.clear(); } @@ -393,12 +391,4 @@ private static void checkComparatorNotComparingWithNullValues(BsonDocument docum } } } - - static BatchUpdateException createBatchUpdateException( - final MongoBulkWriteException mongoBulkWriteException, final String commandName) { - int updateCount = getUpdateCount(commandName, mongoBulkWriteException.getWriteResult()); - int[] updateCounts = new int[updateCount]; - Arrays.fill(updateCounts, SUCCESS_NO_INFO); - return new BatchUpdateException(mongoBulkWriteException.getMessage(), updateCounts, mongoBulkWriteException); - } } diff --git a/src/main/java/com/mongodb/hibernate/jdbc/MongoStatement.java b/src/main/java/com/mongodb/hibernate/jdbc/MongoStatement.java index 7f416310..4aec7e1f 100644 --- a/src/main/java/com/mongodb/hibernate/jdbc/MongoStatement.java +++ b/src/main/java/com/mongodb/hibernate/jdbc/MongoStatement.java @@ -16,9 +16,10 @@ package com.mongodb.hibernate.jdbc; +import static com.mongodb.assertions.Assertions.fail; import static com.mongodb.hibernate.internal.MongoAssertions.assertNotNull; -import static com.mongodb.hibernate.internal.MongoAssertions.assertTrue; import static com.mongodb.hibernate.internal.MongoConstants.ID_FIELD_NAME; +import static com.mongodb.hibernate.internal.MongoConstants.MONGO_DBMS_NAME; import static com.mongodb.hibernate.internal.VisibleForTesting.AccessModifier.PRIVATE; import static java.lang.String.format; import static java.util.Collections.singletonList; @@ -29,19 +30,20 @@ import com.mongodb.MongoSocketReadTimeoutException; import com.mongodb.MongoSocketWriteTimeoutException; import com.mongodb.MongoTimeoutException; +import com.mongodb.bulk.BulkWriteError; import com.mongodb.bulk.BulkWriteResult; import com.mongodb.client.ClientSession; import com.mongodb.client.MongoCollection; import com.mongodb.client.MongoDatabase; import com.mongodb.client.model.DeleteManyModel; import com.mongodb.client.model.DeleteOneModel; -import com.mongodb.client.model.DeleteOptions; import com.mongodb.client.model.InsertOneModel; import com.mongodb.client.model.UpdateManyModel; import com.mongodb.client.model.UpdateOneModel; import com.mongodb.client.model.WriteModel; import com.mongodb.hibernate.internal.FeatureNotSupportedException; import com.mongodb.hibernate.internal.VisibleForTesting; +import java.sql.BatchUpdateException; import java.sql.Connection; import java.sql.ResultSet; import java.sql.SQLException; @@ -49,7 +51,9 @@ import java.sql.SQLSyntaxErrorException; import java.sql.SQLTimeoutException; import java.sql.SQLWarning; +import java.sql.Statement; import java.util.ArrayList; +import java.util.Arrays; import java.util.Collection; import java.util.List; import java.util.Map; @@ -59,6 +63,9 @@ class MongoStatement implements StatementAdapter { + private static final List SUPPORTED_UPDATE_COMMAND_ELEMENTS = List.of("q", "u", "multi"); + private static final List SUPPORTED_DELETE_COMMAND_ELEMENTS = List.of("q", "limit"); + private static final String EXCEPTION_MESSAGE_FAILED_TO_EXECUTE_OPERATION = "Failed to execute operation"; private final MongoDatabase mongoDatabase; private final MongoConnection mongoConnection; private final ClientSession clientSession; @@ -87,12 +94,11 @@ void closeLastOpenResultSet() throws SQLException { } ResultSet executeQueryCommand(BsonDocument command) throws SQLException { + var mongoCommand = getCommandType(command); try { startTransactionIfNeeded(); - - var collectionName = command.getString("aggregate").getValue(); - var collection = mongoDatabase.getCollection(collectionName, BsonDocument.class); - + checkQueryOperation(command); + var collection = getCollection(mongoCommand, command); var pipeline = command.getArray("pipeline").stream() .map(BsonValue::asDocument) .toList(); @@ -101,8 +107,8 @@ ResultSet executeQueryCommand(BsonDocument command) throws SQLException { return resultSet = new MongoResultSet( collection.aggregate(clientSession, pipeline).cursor(), fieldNames); - } catch (RuntimeException e) { - throw new SQLException("Failed to execute query", e); + } catch (Exception exception) { + throw handleException(exception, mongoCommand, ExecutionType.QUERY); } } @@ -141,42 +147,28 @@ public int executeUpdate(String mql) throws SQLException { checkClosed(); closeLastOpenResultSet(); var command = parse(mql); + checkUpdateOperation(command); return executeUpdateCommand(command); } int executeUpdateCommand(BsonDocument command) throws SQLException { - try { - var bulkWriteResult = executeBulkWrite(singletonList(command)); - return getUpdateCount(command.getFirstKey(), bulkWriteResult); - } catch (MongoBulkWriteException mongoBulkWriteException) { - throw new SQLException(mongoBulkWriteException.getMessage(), mongoBulkWriteException); - } + return executeBulkWrite(singletonList(command), ExecutionType.UPDATE); } - BulkWriteResult executeBulkWrite(List commandBatch) throws SQLException { - startTransactionIfNeeded(); + int executeBulkWrite(List commandBatch, ExecutionType executionType) throws SQLException { var firstDocumentInBatch = commandBatch.get(0); - var commandName = assertNotNull(firstDocumentInBatch.getFirstKey()); - var collectionName = - assertNotNull(firstDocumentInBatch.getString(commandName).getValue()); - MongoCollection collection = mongoDatabase.getCollection(collectionName, BsonDocument.class); - + var commandType = getCommandType(firstDocumentInBatch); + var collection = getCollection(commandType, firstDocumentInBatch); try { + startTransactionIfNeeded(); var writeModels = new ArrayList>(commandBatch.size()); for (var command : commandBatch) { - assertTrue(collectionName.equals(command.getString(commandName).getValue())); - convertToWriteModels(command, commandName, writeModels); + convertToWriteModels(commandType, command, writeModels); } - return collection.bulkWrite(clientSession, writeModels); - } catch (MongoSocketReadTimeoutException - | MongoSocketWriteTimeoutException - | MongoTimeoutException - | MongoExecutionTimeoutException mongoTimeoutException) { - throw new SQLTimeoutException(mongoTimeoutException.getMessage(), mongoTimeoutException); - } catch (MongoBulkWriteException mongoBulkWriteException) { - throw mongoBulkWriteException; - } catch (RuntimeException runtimeException) { - throw new SQLException("Failed to execute update", runtimeException); + var bulkWriteResult = collection.bulkWrite(clientSession, writeModels); + return getUpdateCount(commandType, bulkWriteResult); + } catch (Exception exception) { + throw handleException(exception, commandType, executionType); } } @@ -255,6 +247,24 @@ void checkClosed() throws SQLException { } } + private void checkQueryOperation(BsonDocument command) throws SQLFeatureNotSupportedException { + CommandType commandType = getCommandType(command); + if (commandType != CommandType.AGGREGATE) { + throw new SQLFeatureNotSupportedException( + format("Unsupported command for query operation: %s", commandType.getCommandName())); + } + } + + void checkUpdateOperation(BsonDocument command) throws SQLException { + CommandType commandType = getCommandType(command); + if (commandType != CommandType.INSERT + && commandType != CommandType.UPDATE + && commandType != CommandType.DELETE) { + throw new SQLFeatureNotSupportedException( + format("Unsupported command for batch operation: %s", commandType.getCommandName())); + } + } + static BsonDocument parse(String mql) throws SQLSyntaxErrorException { try { return BsonDocument.parse(mql); @@ -273,60 +283,108 @@ void startTransactionIfNeeded() throws SQLException { } } - static int getUpdateCount(final String commandName, final BulkWriteResult bulkWriteResult) { - return switch (commandName) { - case "insert" -> bulkWriteResult.getInsertedCount(); - case "update" -> bulkWriteResult.getModifiedCount(); - case "delete" -> bulkWriteResult.getDeletedCount(); - default -> throw new FeatureNotSupportedException("Unsupported command: " + commandName); - }; + static CommandType getCommandType(BsonDocument command) throws SQLFeatureNotSupportedException { + // The first key is always the command name, e.g. "insert", "update", "delete". + return CommandType.fromString(assertNotNull(command.getFirstKey())); + } + + private MongoCollection getCollection(CommandType commandType, BsonDocument command) { + var collectionName = + assertNotNull(command.getString(commandType.getCommandName()).getValue()); + return mongoDatabase.getCollection(collectionName, BsonDocument.class); + } + + private static SQLException handleException( + Exception exception, CommandType commandType, ExecutionType executionType) { + + if (exception instanceof SQLException sqlException) { + return sqlException; + } + if (exception instanceof MongoSocketReadTimeoutException + || exception instanceof MongoSocketWriteTimeoutException + || exception instanceof MongoTimeoutException + || exception instanceof MongoExecutionTimeoutException) { + return new SQLTimeoutException( + format("Timeout while waiting for %s operation to complete", MONGO_DBMS_NAME), exception); + } + + if (exception instanceof MongoBulkWriteException mongoBulkWriteException) { + if (executionType == ExecutionType.BATCH) { + return createBatchUpdateException(mongoBulkWriteException, commandType); + } else { + return new SQLException( + EXCEPTION_MESSAGE_FAILED_TO_EXECUTE_OPERATION, + null, + getErrorCode(mongoBulkWriteException), + mongoBulkWriteException); + } + } + + return new SQLException(EXCEPTION_MESSAGE_FAILED_TO_EXECUTE_OPERATION, exception); + } + + static BatchUpdateException createBatchUpdateException( + MongoBulkWriteException mongoBulkWriteException, CommandType commandType) { + int updateCount = getUpdateCount(commandType, mongoBulkWriteException.getWriteResult()); + int[] updateCounts = new int[updateCount]; + Arrays.fill(updateCounts, Statement.SUCCESS_NO_INFO); + int code = getErrorCode(mongoBulkWriteException); + return new BatchUpdateException( + EXCEPTION_MESSAGE_FAILED_TO_EXECUTE_OPERATION, null, code, updateCounts, mongoBulkWriteException); + } + + private static int getErrorCode(final MongoBulkWriteException mongoBulkWriteException) { + List writeErrors = mongoBulkWriteException.getWriteErrors(); + // Since we are executing an ordered bulk write, there will be at most one BulkWriteError. + return writeErrors.isEmpty() ? 0 : writeErrors.get(0).getCode(); } private static void convertToWriteModels( - final BsonDocument command, - final String commandName, - final Collection> writeModels) + CommandType commandType, BsonDocument command, Collection> writeModels) throws SQLFeatureNotSupportedException { - switch (commandName) { - case "insert": + switch (commandType) { + case INSERT: var documents = command.getArray("documents"); for (var insertDocument : documents) { writeModels.add(createInsertModel(insertDocument.asDocument())); } break; - case "update": + case UPDATE: var updates = command.getArray("updates").getValues(); for (var updateDocument : updates) { writeModels.add(createUpdateModel(updateDocument.asDocument())); } break; - case "delete": + case DELETE: var deletes = command.getArray("deletes"); for (var deleteDocument : deletes) { writeModels.add(createDeleteModel(deleteDocument.asDocument())); } break; default: - throw new SQLFeatureNotSupportedException("Unsupported command: " + commandName); + throw fail(); } } - private static WriteModel createInsertModel(final BsonDocument document) { - return new InsertOneModel<>(document); + private static WriteModel createInsertModel(final BsonDocument insertDocument) { + return new InsertOneModel<>(insertDocument); } - private static WriteModel createDeleteModel(final BsonDocument deleteDocument) { + private static WriteModel createDeleteModel(final BsonDocument deleteDocument) + throws SQLFeatureNotSupportedException { + checkDeleteElements(deleteDocument); var isSingleDelete = deleteDocument.getNumber("limit").intValue() == 1; var queryFilter = deleteDocument.getDocument("q"); if (isSingleDelete) { - new DeleteOptions(); return new DeleteOneModel<>(queryFilter); } return new DeleteManyModel<>(queryFilter); } - private static WriteModel createUpdateModel(final BsonDocument updateDocument) { + private static WriteModel createUpdateModel(final BsonDocument updateDocument) + throws SQLFeatureNotSupportedException { + checkUpdateElements(updateDocument); var isMulti = updateDocument.getBoolean("multi").getValue(); var queryFilter = updateDocument.getDocument("q"); var updatePipeline = updateDocument.getDocument("u"); @@ -336,4 +394,71 @@ private static WriteModel createUpdateModel(final BsonDocument upd } return new UpdateOneModel<>(queryFilter, updatePipeline); } + + private static void checkDeleteElements(final BsonDocument deleteDocument) throws SQLFeatureNotSupportedException { + if (deleteDocument.size() > SUPPORTED_DELETE_COMMAND_ELEMENTS.size()) { + List unSupportedElements = + getUnsupportedElements(deleteDocument, SUPPORTED_DELETE_COMMAND_ELEMENTS); + throw new SQLFeatureNotSupportedException( + format("Unsupported elements in delete command: %s", unSupportedElements)); + } + } + + private static void checkUpdateElements(final BsonDocument updateDocument) throws SQLFeatureNotSupportedException { + if (updateDocument.size() > SUPPORTED_UPDATE_COMMAND_ELEMENTS.size()) { + List unSupportedElements = + getUnsupportedElements(updateDocument, SUPPORTED_UPDATE_COMMAND_ELEMENTS); + throw new SQLFeatureNotSupportedException( + format("Unsupported elements in update command: %s", unSupportedElements)); + } + } + + private static List getUnsupportedElements( + final BsonDocument deleteDocument, final List supportedElements) { + return deleteDocument.keySet().stream() + .filter((key) -> !supportedElements.contains(key)) + .toList(); + } + + static int getUpdateCount(CommandType commandType, BulkWriteResult bulkWriteResult) { + return switch (commandType) { + case INSERT -> bulkWriteResult.getInsertedCount(); + case UPDATE -> bulkWriteResult.getModifiedCount(); + case DELETE -> bulkWriteResult.getDeletedCount(); + default -> throw fail(); + }; + } + + enum CommandType { + INSERT("insert"), + UPDATE("update"), + DELETE("delete"), + AGGREGATE("aggregate"); + + private final String commandName; + + CommandType(String commandName) { + this.commandName = commandName; + } + + String getCommandName() { + return commandName; + } + + static CommandType fromString(String commandName) throws SQLFeatureNotSupportedException { + return switch (commandName) { + case "insert" -> INSERT; + case "update" -> UPDATE; + case "delete" -> DELETE; + case "aggregate" -> AGGREGATE; + default -> throw new SQLFeatureNotSupportedException(format("Unsupported command: %s", commandName)); + }; + } + } + + enum ExecutionType { + UPDATE, + BATCH, + QUERY + } } diff --git a/src/test/java/com/mongodb/hibernate/jdbc/MongoPreparedStatementTests.java b/src/test/java/com/mongodb/hibernate/jdbc/MongoPreparedStatementTests.java index 1e68da1b..b9751cac 100644 --- a/src/test/java/com/mongodb/hibernate/jdbc/MongoPreparedStatementTests.java +++ b/src/test/java/com/mongodb/hibernate/jdbc/MongoPreparedStatementTests.java @@ -17,7 +17,11 @@ package com.mongodb.hibernate.jdbc; import static java.sql.Statement.SUCCESS_NO_INFO; +import static java.util.Collections.emptyList; +import static java.util.Collections.emptySet; import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatException; +import static org.assertj.core.api.Assertions.assertThatExceptionOfType; import static org.junit.jupiter.api.Assertions.assertAll; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertFalse; @@ -25,6 +29,7 @@ import static org.junit.jupiter.api.Assertions.assertNull; import static org.junit.jupiter.api.Assertions.assertThrows; import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.junit.jupiter.api.Named.named; import static org.mockito.ArgumentMatchers.anyList; import static org.mockito.ArgumentMatchers.anyString; import static org.mockito.ArgumentMatchers.eq; @@ -34,8 +39,15 @@ import static org.mockito.Mockito.verify; import com.mongodb.MongoBulkWriteException; +import com.mongodb.MongoException; +import com.mongodb.MongoExecutionTimeoutException; +import com.mongodb.MongoOperationTimeoutException; +import com.mongodb.MongoSocketReadTimeoutException; +import com.mongodb.MongoSocketWriteTimeoutException; +import com.mongodb.MongoTimeoutException; import com.mongodb.ServerAddress; import com.mongodb.bulk.BulkWriteError; +import com.mongodb.bulk.BulkWriteInsert; import com.mongodb.bulk.BulkWriteResult; import com.mongodb.client.AggregateIterable; import com.mongodb.client.ClientSession; @@ -52,13 +64,14 @@ import java.sql.ResultSet; import java.sql.SQLException; import java.sql.SQLSyntaxErrorException; +import java.sql.SQLTimeoutException; import java.sql.Time; import java.sql.Timestamp; import java.sql.Types; import java.util.Calendar; import java.util.List; -import java.util.Set; import java.util.function.Consumer; +import java.util.stream.Stream; import org.bson.BsonArray; import org.bson.BsonBoolean; import org.bson.BsonDocument; @@ -72,6 +85,9 @@ import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.ExtendWith; import org.junit.jupiter.api.function.Executable; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.MethodSource; import org.mockito.ArgumentCaptor; import org.mockito.Captor; import org.mockito.Mock; @@ -164,117 +180,252 @@ void testSuccess() throws SQLException { } @Nested - class ExecuteBatchThrows { - - static final String BULK_WRITE_ERROR_MESSAGE = "Test message"; + class ExecuteMethodThrowsSqlExceptionTests { + private static final String DUMMY_EXCEPTION_MESSAGE = "Test message"; + private static final ServerAddress DUMMY_SERVER_ADDRESS = new ServerAddress("localhost"); + + private static final BulkWriteError BULK_WRITE_ERROR = + new BulkWriteError(10, DUMMY_EXCEPTION_MESSAGE, new BsonDocument(), 0); + private static final BulkWriteResult BULK_WRITE_RESULT = BulkWriteResult.acknowledged( + 1, 0, 2, 3, emptyList(), List.of(new BulkWriteInsert(0, new BsonObjectId(new ObjectId(1, 2))))); + private static final MongoBulkWriteException MONGO_BULK_WRITE_EXCEPTION_WITH_ERRORS = + new MongoBulkWriteException( + BULK_WRITE_RESULT, List.of(BULK_WRITE_ERROR), null, DUMMY_SERVER_ADDRESS, emptySet()); + private static final MongoBulkWriteException MONGO_BULK_WRITE_EXCEPTION_NO_ERRORS = + new MongoBulkWriteException(BULK_WRITE_RESULT, emptyList(), null, DUMMY_SERVER_ADDRESS, emptySet()); + + private static final String MQL_ITEMS_AGGREGATE = + """ + { + aggregate: "items", + pipeline: [ + { $match: { _id: 1 } }, + { $project: { _id: 0 } } + ] + } + """; - @Mock - BulkWriteResult bulkWriteResult; + private static final String MQL_ITEMS_INSERT = + """ + { + insert: "items", + documents: [ + { _id: 1 } + ] + } + """; + private static final String MQL_ITEMS_UPDATE = + """ + { + update: "items", + updates: [ + { q: { _id: 1 }, u: { $set: { touched: true } }, multi: false } + ] + } + """; + private static final String MQL_ITEMS_DELETE = + """ + { + delete: "items", + deletes: [ + { q: { _id: 1 }, limit: 1 } + ] + } + """; @BeforeEach void beforeEach() { doReturn(mongoCollection).when(mongoDatabase).getCollection(anyString(), eq(BsonDocument.class)); - BulkWriteError bulkWriteError = new BulkWriteError(10, BULK_WRITE_ERROR_MESSAGE, new BsonDocument(), 0); - List writeErrors = List.of(bulkWriteError); - - MongoBulkWriteException mongoBulkWriteException = new MongoBulkWriteException( - bulkWriteResult, writeErrors, null, new ServerAddress("localhost"), Set.of("label")); + } - doThrow(mongoBulkWriteException).when(mongoCollection).bulkWrite(eq(clientSession), anyList()); + private static Stream exceptions() { + return Stream.of( + Arguments.of(new MongoException(DUMMY_EXCEPTION_MESSAGE)), + Arguments.of(new RuntimeException(DUMMY_EXCEPTION_MESSAGE))); } - @Test - void testBatchInsert() throws SQLException { - doReturn(1).when(bulkWriteResult).getInsertedCount(); + @ParameterizedTest(name = "testExecuteQueryThrowsSqlException: exception={0}") + @MethodSource("exceptions") + void testExecuteQueryThrowsSqlExceptionWhenExceptionOccurs(Exception exception) throws SQLException { + doThrow(exception).when(mongoCollection).aggregate(eq(clientSession), anyList()); + assertExecuteThrowsSqlException(MQL_ITEMS_AGGREGATE, MongoPreparedStatement::executeQuery, exception); + } - String mql = - """ - { - insert: "items", - documents: [ - { _id: { $undefined: true } } - ] - } - """; + @ParameterizedTest(name = "testExecuteUpdateThrowsSqlException: exception={0}") + @MethodSource("exceptions") + void testExecuteUpdateThrowsSqlExceptionWhenExceptionOccurs(Exception exception) throws SQLException { + doThrow(exception).when(mongoCollection).bulkWrite(eq(clientSession), anyList()); + assertExecuteThrowsSqlException(MQL_ITEMS_INSERT, MongoPreparedStatement::executeUpdate, exception); + } - try (MongoPreparedStatement mongoPreparedStatement = createMongoPreparedStatement(mql)) { - mongoPreparedStatement.setInt(1, 1); - mongoPreparedStatement.addBatch(); - mongoPreparedStatement.addBatch(); + @ParameterizedTest(name = "testExecuteUpdateThrowsSqlException: exception={0}") + @MethodSource("exceptions") + void testExecuteBatchThrowsSqlExceptionWhenExceptionOccurs(Exception exception) throws SQLException { + doThrow(exception).when(mongoCollection).bulkWrite(eq(clientSession), anyList()); + assertExecuteThrowsSqlException( + MQL_ITEMS_INSERT, + mongoPreparedStatement -> { + mongoPreparedStatement.addBatch(); + mongoPreparedStatement.executeBatch(); + }, + exception); + } - BatchUpdateException batchUpdateException = - assertThrows(BatchUpdateException.class, mongoPreparedStatement::executeBatch); + private static Stream argumentsForExecuteUpdate() { + return Stream.of( + Arguments.of(named("insert", MQL_ITEMS_INSERT), MONGO_BULK_WRITE_EXCEPTION_NO_ERRORS), + Arguments.of(named("update", MQL_ITEMS_UPDATE), MONGO_BULK_WRITE_EXCEPTION_NO_ERRORS), + Arguments.of(named("delete", MQL_ITEMS_DELETE), MONGO_BULK_WRITE_EXCEPTION_NO_ERRORS), + Arguments.of(named("insert", MQL_ITEMS_INSERT), MONGO_BULK_WRITE_EXCEPTION_WITH_ERRORS), + Arguments.of(named("update", MQL_ITEMS_UPDATE), MONGO_BULK_WRITE_EXCEPTION_WITH_ERRORS), + Arguments.of(named("delete", MQL_ITEMS_DELETE), MONGO_BULK_WRITE_EXCEPTION_WITH_ERRORS)); + } - int[] updateCounts = batchUpdateException.getUpdateCounts(); - assertEquals(1, updateCounts.length); - assertTrue(batchUpdateException.getMessage().contains(BULK_WRITE_ERROR_MESSAGE)); - assertEquals(0, batchUpdateException.getErrorCode()); - assertUpdateCounts(updateCounts); + @ParameterizedTest(name = "testUpdateThrowsSqlException: commandName={0}, exception={1}") + @MethodSource("argumentsForExecuteUpdate") + void testExecuteUpdateThrowsSqlExceptionWhenMongoBulkWriteExceptionOccurs( + String mql, MongoBulkWriteException mongoBulkWriteException) throws SQLException { + doThrow(mongoBulkWriteException).when(mongoCollection).bulkWrite(eq(clientSession), anyList()); + try (MongoPreparedStatement mongoPreparedStatement = createMongoPreparedStatement(mql)) { + assertThatExceptionOfType(SQLException.class) + .isThrownBy(mongoPreparedStatement::executeUpdate) + .withCause(mongoBulkWriteException) + .satisfies(sqlException -> { + Integer vendorCodeError = getVendorCodeError(mongoBulkWriteException); + assertAll( + () -> assertNull(sqlException.getSQLState()), + () -> assertEquals(vendorCodeError, sqlException.getErrorCode())); + }); } } - @Test - void testBatchUpdate() throws SQLException { - doReturn(1).when(bulkWriteResult).getModifiedCount(); - - String mql = - """ - { - update: "items", - updates: [ - { q: { _id: { $undefined: true } }, u: { $set: { touched: true } }, multi: false } - ] - } - """; + private static Stream argumentsForExecuteBatch() { + return Stream.of( + Arguments.of( + named("insert", MQL_ITEMS_INSERT), + MONGO_BULK_WRITE_EXCEPTION_NO_ERRORS, + BULK_WRITE_RESULT.getInsertedCount()), + Arguments.of( + named("update", MQL_ITEMS_UPDATE), + MONGO_BULK_WRITE_EXCEPTION_NO_ERRORS, + BULK_WRITE_RESULT.getModifiedCount()), + Arguments.of( + named("delete", MQL_ITEMS_DELETE), + MONGO_BULK_WRITE_EXCEPTION_NO_ERRORS, + BULK_WRITE_RESULT.getDeletedCount()), + Arguments.of( + named("insert", MQL_ITEMS_INSERT), + MONGO_BULK_WRITE_EXCEPTION_WITH_ERRORS, + BULK_WRITE_RESULT.getInsertedCount()), + Arguments.of( + named("update", MQL_ITEMS_UPDATE), + MONGO_BULK_WRITE_EXCEPTION_WITH_ERRORS, + BULK_WRITE_RESULT.getModifiedCount()), + Arguments.of( + named("delete", MQL_ITEMS_DELETE), + MONGO_BULK_WRITE_EXCEPTION_WITH_ERRORS, + BULK_WRITE_RESULT.getDeletedCount())); + } + @ParameterizedTest(name = "testBatchUpdateException: commandName={0}, exception={1}") + @MethodSource("argumentsForExecuteBatch") + void testExecuteBatchThrowsBatchUpdateExceptionWhenMongoBulkWriteExceptionOccurs( + String mql, MongoBulkWriteException mongoBulkWriteException, int expectedUpdateCountLength) + throws SQLException { + doThrow(mongoBulkWriteException).when(mongoCollection).bulkWrite(eq(clientSession), anyList()); try (MongoPreparedStatement mongoPreparedStatement = createMongoPreparedStatement(mql)) { - mongoPreparedStatement.setInt(1, 1); - mongoPreparedStatement.addBatch(); mongoPreparedStatement.addBatch(); + assertThatExceptionOfType(BatchUpdateException.class) + .isThrownBy(mongoPreparedStatement::executeBatch) + .withCause(mongoBulkWriteException) + .satisfies(batchUpdateException -> { + Integer vendorCodeError = getVendorCodeError(mongoBulkWriteException); + assertAll( + () -> assertUpdateCounts( + batchUpdateException.getUpdateCounts(), expectedUpdateCountLength), + () -> assertNull(batchUpdateException.getSQLState()), + () -> assertEquals(vendorCodeError, batchUpdateException.getErrorCode())); + }); + } + } - BatchUpdateException batchUpdateException = - assertThrows(BatchUpdateException.class, mongoPreparedStatement::executeBatch); + private static Stream timeoutExceptions() { + RuntimeException dummyCause = new RuntimeException(); + return Stream.of( + new MongoExecutionTimeoutException(DUMMY_EXCEPTION_MESSAGE), + new MongoSocketReadTimeoutException(DUMMY_EXCEPTION_MESSAGE, DUMMY_SERVER_ADDRESS, dummyCause), + new MongoSocketWriteTimeoutException(DUMMY_EXCEPTION_MESSAGE, DUMMY_SERVER_ADDRESS, dummyCause), + new MongoTimeoutException(DUMMY_EXCEPTION_MESSAGE), + new MongoOperationTimeoutException(DUMMY_EXCEPTION_MESSAGE)); + } - int[] updateCounts = batchUpdateException.getUpdateCounts(); - assertEquals(1, updateCounts.length); - assertTrue(batchUpdateException.getMessage().contains(BULK_WRITE_ERROR_MESSAGE)); - assertEquals(0, batchUpdateException.getErrorCode()); - assertUpdateCounts(updateCounts); + @ParameterizedTest( + name = + "SQLTimeoutException is thrown when timeout exception occurs in executeQuery. Parameters: timeoutException={0}") + @MethodSource("timeoutExceptions") + void testExecuteQuerySqlTimeoutException(MongoException timeoutException) throws SQLException { + doThrow(timeoutException).when(mongoCollection).aggregate(eq(clientSession), anyList()); + try (MongoPreparedStatement mongoPreparedStatement = createMongoPreparedStatement(MQL_ITEMS_AGGREGATE)) { + assertThatException() + .isThrownBy(mongoPreparedStatement::executeQuery) + .isInstanceOf(SQLTimeoutException.class) + .withCause(timeoutException); } } - @Test - void testBatchDelete() throws SQLException { - doReturn(1).when(bulkWriteResult).getDeletedCount(); - - String mql = - """ - { - delete: "items", - deletes: [ - { q: { _id: { $undefined: true } }, limit: 1 } - ] - } - """; + @ParameterizedTest( + name = + "SQLTimeoutException is thrown when timeout exception occurs in executeUpdate. Parameters: timeoutException={0}") + @MethodSource("timeoutExceptions") + void testExecuteUpdateSqlTimeoutException(MongoException timeoutException) throws SQLException { + doThrow(timeoutException).when(mongoCollection).bulkWrite(eq(clientSession), anyList()); + try (MongoPreparedStatement mongoPreparedStatement = createMongoPreparedStatement(MQL_ITEMS_INSERT)) { + assertThatException() + .isThrownBy(mongoPreparedStatement::executeUpdate) + .isInstanceOf(SQLTimeoutException.class) + .withCause(timeoutException); + } + } - try (MongoPreparedStatement mongoPreparedStatement = createMongoPreparedStatement(mql)) { - mongoPreparedStatement.setInt(1, 1); + @ParameterizedTest( + name = + "SQLTimeoutException is thrown when timeout exception occurs in executeBatch. Parameters: timeoutException={0}") + @MethodSource("timeoutExceptions") + void testExecuteBatchSqlTimeoutException(MongoException timeoutException) throws SQLException { + doThrow(timeoutException).when(mongoCollection).bulkWrite(eq(clientSession), anyList()); + try (MongoPreparedStatement mongoPreparedStatement = createMongoPreparedStatement(MQL_ITEMS_INSERT)) { mongoPreparedStatement.addBatch(); - mongoPreparedStatement.addBatch(); - - BatchUpdateException batchUpdateException = - assertThrows(BatchUpdateException.class, mongoPreparedStatement::executeBatch); + assertThatException() + .isThrownBy(mongoPreparedStatement::executeBatch) + .isInstanceOf(SQLTimeoutException.class) + .withCause(timeoutException); + } + } - int[] updateCounts = batchUpdateException.getUpdateCounts(); - assertEquals(1, updateCounts.length); - assertTrue(batchUpdateException.getMessage().contains(BULK_WRITE_ERROR_MESSAGE)); - assertNull(batchUpdateException.getSQLState()); - assertEquals(0, batchUpdateException.getErrorCode()); - assertUpdateCounts(updateCounts); + private void assertExecuteThrowsSqlException( + String mql, SqlConsumer executeConsumer, Exception expectedCause) + throws SQLException { + try (MongoPreparedStatement mongoPreparedStatement = createMongoPreparedStatement(mql)) { + assertThatExceptionOfType(SQLException.class) + .isThrownBy(() -> executeConsumer.accept(mongoPreparedStatement)) + .withCause(expectedCause) + .satisfies(sqlException -> { + assertAll( + () -> assertNull(sqlException.getSQLState()), + () -> assertEquals(0, sqlException.getErrorCode())); + }); } } - private static void assertUpdateCounts(final int[] updateCounts) { + private static Integer getVendorCodeError(final MongoBulkWriteException mongoBulkWriteException) { + return mongoBulkWriteException.getWriteErrors().stream() + .map(BulkWriteError::getCode) + .findFirst() + .orElse(0); + } + + private static void assertUpdateCounts(final int[] updateCounts, int expectedUpdateCountsLength) { + assertEquals(expectedUpdateCountsLength, updateCounts.length); for (int count : updateCounts) { assertEquals(SUCCESS_NO_INFO, count); } @@ -417,4 +568,8 @@ private static void assertThrowsClosedException(Executable executable) { var exception = assertThrows(SQLException.class, executable); assertThat(exception.getMessage()).isEqualTo("MongoPreparedStatement has been closed"); } + + interface SqlConsumer { + void accept(T t) throws SQLException; + } } diff --git a/src/test/java/com/mongodb/hibernate/jdbc/MongoStatementTests.java b/src/test/java/com/mongodb/hibernate/jdbc/MongoStatementTests.java index 7256fd98..c7a2a25e 100644 --- a/src/test/java/com/mongodb/hibernate/jdbc/MongoStatementTests.java +++ b/src/test/java/com/mongodb/hibernate/jdbc/MongoStatementTests.java @@ -23,7 +23,6 @@ import static org.junit.jupiter.api.Assertions.assertFalse; import static org.junit.jupiter.api.Assertions.assertThrows; import static org.junit.jupiter.api.Assertions.assertTrue; -import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.anyList; import static org.mockito.ArgumentMatchers.anyString; import static org.mockito.ArgumentMatchers.eq; @@ -31,6 +30,7 @@ import static org.mockito.Mockito.doReturn; import static org.mockito.Mockito.doThrow; +import com.mongodb.bulk.BulkWriteResult; import com.mongodb.client.AggregateIterable; import com.mongodb.client.ClientSession; import com.mongodb.client.MongoCollection; @@ -43,7 +43,6 @@ import java.util.List; import java.util.function.BiConsumer; import org.bson.BsonDocument; -import org.bson.Document; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Nested; import org.junit.jupiter.api.Test; @@ -58,6 +57,9 @@ class MongoStatementTests { @Mock private MongoDatabase mongoDatabase; + @Mock + private MongoCollection mongoCollection; + @Mock private ClientSession clientSession; @@ -132,15 +134,15 @@ class ExecuteMethodClosesLastOpenResultSetTests { ] }"""; - @Mock - MongoCollection mongoCollection; - @Mock AggregateIterable aggregateIterable; @Mock MongoCursor mongoCursor; + @Mock + BulkWriteResult bulkWriteResult; + private ResultSet lastOpenResultSet; @BeforeEach @@ -161,9 +163,10 @@ void testExecuteQuery() throws SQLException { @Test void testExecuteUpdate() throws SQLException { - doReturn(Document.parse("{n: 10}")) - .when(mongoDatabase) - .runCommand(eq(clientSession), any(BsonDocument.class)); + doReturn(mongoCollection).when(mongoDatabase).getCollection(anyString(), eq(BsonDocument.class)); + doReturn(bulkWriteResult).when(mongoCollection).bulkWrite(eq(clientSession), anyList()); + doReturn(10).when(bulkWriteResult).getModifiedCount(); + mongoStatement.executeUpdate(exampleUpdateMql); assertTrue(lastOpenResultSet.isClosed()); } @@ -173,12 +176,6 @@ void testExecute() throws SQLException { assertThrows(SQLFeatureNotSupportedException.class, () -> mongoStatement.execute(exampleUpdateMql)); assertTrue(lastOpenResultSet.isClosed()); } - - @Test - void testExecuteBatch() throws SQLException { - assertThrows(SQLFeatureNotSupportedException.class, () -> mongoStatement.executeBatch()); - assertTrue(lastOpenResultSet.isClosed()); - } } @Test @@ -227,9 +224,9 @@ void testSQLExceptionThrownWhenCalledWithInvalidMql() { @Test void testSQLExceptionThrownWhenDBAccessFailed() { - var dbAccessException = new RuntimeException(); - doThrow(dbAccessException).when(mongoDatabase).runCommand(same(clientSession), any(BsonDocument.class)); + doReturn(mongoCollection).when(mongoDatabase).getCollection(anyString(), eq(BsonDocument.class)); + doThrow(dbAccessException).when(mongoCollection).bulkWrite(eq(clientSession), anyList()); String mql = """ { @@ -281,9 +278,6 @@ private void checkMethodsWithOpenPrecondition() { () -> assertThrowsClosedException(mongoStatement::getResultSet), () -> assertThrowsClosedException(mongoStatement::getMoreResults), () -> assertThrowsClosedException(mongoStatement::getUpdateCount), - () -> assertThrowsClosedException(() -> mongoStatement.addBatch(exampleUpdateMql)), - () -> assertThrowsClosedException(mongoStatement::clearBatch), - () -> assertThrowsClosedException(mongoStatement::executeBatch), () -> assertThrowsClosedException(mongoStatement::getConnection), () -> assertThrowsClosedException(() -> mongoStatement.isWrapperFor(MongoStatement.class))); } From ded2bb775415ff78e3f69404dc9ce0f3da498414 Mon Sep 17 00:00:00 2001 From: "slav.babanin" Date: Thu, 2 Oct 2025 15:03:59 -0700 Subject: [PATCH 03/37] Rename tests. --- ...ongoPreparedStatementIntegrationTests.java | 2 +- .../jdbc/MongoPreparedStatement.java | 20 +-- .../hibernate/jdbc/MongoStatement.java | 65 +++++---- .../jdbc/MongoPreparedStatementTests.java | 125 +++++++----------- 4 files changed, 97 insertions(+), 115 deletions(-) diff --git a/src/integrationTest/java/com/mongodb/hibernate/jdbc/MongoPreparedStatementIntegrationTests.java b/src/integrationTest/java/com/mongodb/hibernate/jdbc/MongoPreparedStatementIntegrationTests.java index 3cb88160..f9ad1fc9 100644 --- a/src/integrationTest/java/com/mongodb/hibernate/jdbc/MongoPreparedStatementIntegrationTests.java +++ b/src/integrationTest/java/com/mongodb/hibernate/jdbc/MongoPreparedStatementIntegrationTests.java @@ -640,7 +640,7 @@ void testDelete() { }"""))); } - @ParameterizedTest(name = "testNotSupportedCommands. Parameters: {0}") + @ParameterizedTest(name = "test not supported commands. Parameters: {0}") @ValueSource(strings = {"findAndModify", "aggregate", "bulkWrite"}) void testNotSupportedCommands(String commandName) { doWorkAwareOfAutoCommit(connection -> { diff --git a/src/main/java/com/mongodb/hibernate/jdbc/MongoPreparedStatement.java b/src/main/java/com/mongodb/hibernate/jdbc/MongoPreparedStatement.java index 9c23e338..06a3c36e 100644 --- a/src/main/java/com/mongodb/hibernate/jdbc/MongoPreparedStatement.java +++ b/src/main/java/com/mongodb/hibernate/jdbc/MongoPreparedStatement.java @@ -55,7 +55,7 @@ final class MongoPreparedStatement extends MongoStatement implements PreparedSta private static final int[] EMPTY_BATCH_RESULT = new int[0]; private final BsonDocument command; - private final List batchCommands; + private final List commandBatch; private final List parameterValueSetters; MongoPreparedStatement( @@ -63,7 +63,7 @@ final class MongoPreparedStatement extends MongoStatement implements PreparedSta throws SQLSyntaxErrorException { super(mongoDatabase, clientSession, mongoConnection); command = MongoStatement.parse(mql); - batchCommands = new ArrayList<>(); + commandBatch = new ArrayList<>(); parameterValueSetters = new ArrayList<>(); parseParameters(command, parameterValueSetters); } @@ -81,7 +81,7 @@ public int executeUpdate() throws SQLException { checkClosed(); closeLastOpenResultSet(); checkAllParametersSet(); - checkUpdateOperation(command); + checkSupportedUpdateCommand(command); return executeUpdateCommand(command); } @@ -206,31 +206,31 @@ public void setObject(int parameterIndex, Object x, int targetSqlType) throws SQ public void addBatch() throws SQLException { checkClosed(); checkAllParametersSet(); - checkUpdateOperation(command); - batchCommands.add(command.clone()); + commandBatch.add(command.clone()); } @Override public void clearBatch() throws SQLException { checkClosed(); - batchCommands.clear(); + commandBatch.clear(); } @Override public int[] executeBatch() throws SQLException { checkClosed(); closeLastOpenResultSet(); - if (batchCommands.isEmpty()) { + if (commandBatch.isEmpty()) { return EMPTY_BATCH_RESULT; } + checkSupportedBatchCommand(commandBatch.get(0)); try { - executeBulkWrite(batchCommands, ExecutionType.BATCH); - var rowCounts = new int[batchCommands.size()]; + executeBulkWrite(commandBatch, ExecutionType.BATCH); + var rowCounts = new int[commandBatch.size()]; // We cannot determine the actual number of rows affected for each command in the batch. Arrays.fill(rowCounts, Statement.SUCCESS_NO_INFO); return rowCounts; } finally { - batchCommands.clear(); + commandBatch.clear(); } } diff --git a/src/main/java/com/mongodb/hibernate/jdbc/MongoStatement.java b/src/main/java/com/mongodb/hibernate/jdbc/MongoStatement.java index 4aec7e1f..368b9652 100644 --- a/src/main/java/com/mongodb/hibernate/jdbc/MongoStatement.java +++ b/src/main/java/com/mongodb/hibernate/jdbc/MongoStatement.java @@ -19,7 +19,6 @@ import static com.mongodb.assertions.Assertions.fail; import static com.mongodb.hibernate.internal.MongoAssertions.assertNotNull; import static com.mongodb.hibernate.internal.MongoConstants.ID_FIELD_NAME; -import static com.mongodb.hibernate.internal.MongoConstants.MONGO_DBMS_NAME; import static com.mongodb.hibernate.internal.VisibleForTesting.AccessModifier.PRIVATE; import static java.lang.String.format; import static java.util.Collections.singletonList; @@ -30,7 +29,6 @@ import com.mongodb.MongoSocketReadTimeoutException; import com.mongodb.MongoSocketWriteTimeoutException; import com.mongodb.MongoTimeoutException; -import com.mongodb.bulk.BulkWriteError; import com.mongodb.bulk.BulkWriteResult; import com.mongodb.client.ClientSession; import com.mongodb.client.MongoCollection; @@ -94,11 +92,11 @@ void closeLastOpenResultSet() throws SQLException { } ResultSet executeQueryCommand(BsonDocument command) throws SQLException { - var mongoCommand = getCommandType(command); + var commandType = getCommandType(command); + checkSupportedQueryCommand(command); try { startTransactionIfNeeded(); - checkQueryOperation(command); - var collection = getCollection(mongoCommand, command); + var collection = getCollection(commandType, command); var pipeline = command.getArray("pipeline").stream() .map(BsonValue::asDocument) .toList(); @@ -107,8 +105,8 @@ ResultSet executeQueryCommand(BsonDocument command) throws SQLException { return resultSet = new MongoResultSet( collection.aggregate(clientSession, pipeline).cursor(), fieldNames); - } catch (Exception exception) { - throw handleException(exception, mongoCommand, ExecutionType.QUERY); + } catch (RuntimeException exception) { + throw handleException(exception, commandType, ExecutionType.QUERY); } } @@ -147,7 +145,7 @@ public int executeUpdate(String mql) throws SQLException { checkClosed(); closeLastOpenResultSet(); var command = parse(mql); - checkUpdateOperation(command); + checkSupportedUpdateCommand(command); return executeUpdateCommand(command); } @@ -167,7 +165,7 @@ int executeBulkWrite(List commandBatch, ExecutionType ex } var bulkWriteResult = collection.bulkWrite(clientSession, writeModels); return getUpdateCount(commandType, bulkWriteResult); - } catch (Exception exception) { + } catch (RuntimeException exception) { throw handleException(exception, commandType, executionType); } } @@ -247,16 +245,19 @@ void checkClosed() throws SQLException { } } - private void checkQueryOperation(BsonDocument command) throws SQLFeatureNotSupportedException { - CommandType commandType = getCommandType(command); + private void checkSupportedQueryCommand(BsonDocument command) throws SQLFeatureNotSupportedException { + var commandType = getCommandType(command); if (commandType != CommandType.AGGREGATE) { throw new SQLFeatureNotSupportedException( format("Unsupported command for query operation: %s", commandType.getCommandName())); } } - void checkUpdateOperation(BsonDocument command) throws SQLException { - CommandType commandType = getCommandType(command); + void checkSupportedUpdateCommand(BsonDocument command) throws SQLException { + checkSupportedUpdateCommand(getCommandType(command)); + } + + private void checkSupportedUpdateCommand(CommandType commandType) throws SQLException { if (commandType != CommandType.INSERT && commandType != CommandType.UPDATE && commandType != CommandType.DELETE) { @@ -265,6 +266,21 @@ void checkUpdateOperation(BsonDocument command) throws SQLException { } } + void checkSupportedBatchCommand(BsonDocument command) throws SQLException { + var commandType = getCommandType(command); + if (commandType == CommandType.AGGREGATE) { + // The method executeBatch throws a BatchUpdateException if any of the commands in the batch attempts to + // return a result set. + throw new BatchUpdateException( + format( + "Commands returning result set are not supported. Received command: %s", + commandType.getCommandName()), + null, + new int[0]); + } + checkSupportedUpdateCommand(commandType); + } + static BsonDocument parse(String mql) throws SQLSyntaxErrorException { try { return BsonDocument.parse(mql); @@ -295,17 +311,12 @@ private MongoCollection getCollection(CommandType commandType, Bso } private static SQLException handleException( - Exception exception, CommandType commandType, ExecutionType executionType) { - - if (exception instanceof SQLException sqlException) { - return sqlException; - } + RuntimeException exception, CommandType commandType, ExecutionType executionType) { if (exception instanceof MongoSocketReadTimeoutException || exception instanceof MongoSocketWriteTimeoutException || exception instanceof MongoTimeoutException || exception instanceof MongoExecutionTimeoutException) { - return new SQLTimeoutException( - format("Timeout while waiting for %s operation to complete", MONGO_DBMS_NAME), exception); + return new SQLTimeoutException("Timeout while waiting for operation to complete", exception); } if (exception instanceof MongoBulkWriteException mongoBulkWriteException) { @@ -325,16 +336,16 @@ private static SQLException handleException( static BatchUpdateException createBatchUpdateException( MongoBulkWriteException mongoBulkWriteException, CommandType commandType) { - int updateCount = getUpdateCount(commandType, mongoBulkWriteException.getWriteResult()); - int[] updateCounts = new int[updateCount]; + var updateCount = getUpdateCount(commandType, mongoBulkWriteException.getWriteResult()); + var updateCounts = new int[updateCount]; Arrays.fill(updateCounts, Statement.SUCCESS_NO_INFO); - int code = getErrorCode(mongoBulkWriteException); + var code = getErrorCode(mongoBulkWriteException); return new BatchUpdateException( EXCEPTION_MESSAGE_FAILED_TO_EXECUTE_OPERATION, null, code, updateCounts, mongoBulkWriteException); } private static int getErrorCode(final MongoBulkWriteException mongoBulkWriteException) { - List writeErrors = mongoBulkWriteException.getWriteErrors(); + var writeErrors = mongoBulkWriteException.getWriteErrors(); // Since we are executing an ordered bulk write, there will be at most one BulkWriteError. return writeErrors.isEmpty() ? 0 : writeErrors.get(0).getCode(); } @@ -397,8 +408,7 @@ private static WriteModel createUpdateModel(final BsonDocument upd private static void checkDeleteElements(final BsonDocument deleteDocument) throws SQLFeatureNotSupportedException { if (deleteDocument.size() > SUPPORTED_DELETE_COMMAND_ELEMENTS.size()) { - List unSupportedElements = - getUnsupportedElements(deleteDocument, SUPPORTED_DELETE_COMMAND_ELEMENTS); + var unSupportedElements = getUnsupportedElements(deleteDocument, SUPPORTED_DELETE_COMMAND_ELEMENTS); throw new SQLFeatureNotSupportedException( format("Unsupported elements in delete command: %s", unSupportedElements)); } @@ -406,8 +416,7 @@ private static void checkDeleteElements(final BsonDocument deleteDocument) throw private static void checkUpdateElements(final BsonDocument updateDocument) throws SQLFeatureNotSupportedException { if (updateDocument.size() > SUPPORTED_UPDATE_COMMAND_ELEMENTS.size()) { - List unSupportedElements = - getUnsupportedElements(updateDocument, SUPPORTED_UPDATE_COMMAND_ELEMENTS); + var unSupportedElements = getUnsupportedElements(updateDocument, SUPPORTED_UPDATE_COMMAND_ELEMENTS); throw new SQLFeatureNotSupportedException( format("Unsupported elements in update command: %s", unSupportedElements)); } diff --git a/src/test/java/com/mongodb/hibernate/jdbc/MongoPreparedStatementTests.java b/src/test/java/com/mongodb/hibernate/jdbc/MongoPreparedStatementTests.java index b9751cac..40fec870 100644 --- a/src/test/java/com/mongodb/hibernate/jdbc/MongoPreparedStatementTests.java +++ b/src/test/java/com/mongodb/hibernate/jdbc/MongoPreparedStatementTests.java @@ -20,7 +20,6 @@ import static java.util.Collections.emptyList; import static java.util.Collections.emptySet; import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.api.Assertions.assertThatException; import static org.assertj.core.api.Assertions.assertThatExceptionOfType; import static org.junit.jupiter.api.Assertions.assertAll; import static org.junit.jupiter.api.Assertions.assertEquals; @@ -239,39 +238,59 @@ void beforeEach() { } private static Stream exceptions() { + var dummyCause = new RuntimeException(); return Stream.of( - Arguments.of(new MongoException(DUMMY_EXCEPTION_MESSAGE)), - Arguments.of(new RuntimeException(DUMMY_EXCEPTION_MESSAGE))); + Arguments.of(new MongoException(DUMMY_EXCEPTION_MESSAGE), SQLException.class), + Arguments.of(new RuntimeException(DUMMY_EXCEPTION_MESSAGE), SQLException.class), + Arguments.of( + new MongoExecutionTimeoutException(DUMMY_EXCEPTION_MESSAGE), SQLTimeoutException.class), + Arguments.of( + new MongoSocketReadTimeoutException( + DUMMY_EXCEPTION_MESSAGE, DUMMY_SERVER_ADDRESS, dummyCause), + SQLTimeoutException.class), + Arguments.of( + new MongoSocketWriteTimeoutException( + DUMMY_EXCEPTION_MESSAGE, DUMMY_SERVER_ADDRESS, dummyCause), + SQLTimeoutException.class), + Arguments.of(new MongoTimeoutException(DUMMY_EXCEPTION_MESSAGE), SQLTimeoutException.class), + Arguments.of( + new MongoOperationTimeoutException(DUMMY_EXCEPTION_MESSAGE), SQLTimeoutException.class)); } - @ParameterizedTest(name = "testExecuteQueryThrowsSqlException: exception={0}") + @ParameterizedTest(name = "test executeQuery throws SQLException. Parameters: exception={0}") @MethodSource("exceptions") - void testExecuteQueryThrowsSqlExceptionWhenExceptionOccurs(Exception exception) throws SQLException { - doThrow(exception).when(mongoCollection).aggregate(eq(clientSession), anyList()); - assertExecuteThrowsSqlException(MQL_ITEMS_AGGREGATE, MongoPreparedStatement::executeQuery, exception); + void testExecuteQueryThrowsSqlException(Exception thrownException, Class expectedType) + throws SQLException { + doThrow(thrownException).when(mongoCollection).aggregate(eq(clientSession), anyList()); + assertExecuteThrowsSqlException( + MQL_ITEMS_AGGREGATE, MongoPreparedStatement::executeQuery, thrownException, expectedType); } - @ParameterizedTest(name = "testExecuteUpdateThrowsSqlException: exception={0}") + @ParameterizedTest(name = "test executeUpdate throws SQLException. Parameters: exception={0}") @MethodSource("exceptions") - void testExecuteUpdateThrowsSqlExceptionWhenExceptionOccurs(Exception exception) throws SQLException { - doThrow(exception).when(mongoCollection).bulkWrite(eq(clientSession), anyList()); - assertExecuteThrowsSqlException(MQL_ITEMS_INSERT, MongoPreparedStatement::executeUpdate, exception); + void testExecuteUpdateThrowsSqlException(Exception thrownException, Class expectedType) + throws SQLException { + doThrow(thrownException).when(mongoCollection).bulkWrite(eq(clientSession), anyList()); + assertExecuteThrowsSqlException( + MQL_ITEMS_INSERT, MongoPreparedStatement::executeUpdate, thrownException, expectedType); } - @ParameterizedTest(name = "testExecuteUpdateThrowsSqlException: exception={0}") + @ParameterizedTest(name = "test executeUpdate throws SQLException. Parameters: exception={0}") @MethodSource("exceptions") - void testExecuteBatchThrowsSqlExceptionWhenExceptionOccurs(Exception exception) throws SQLException { - doThrow(exception).when(mongoCollection).bulkWrite(eq(clientSession), anyList()); + void testExecuteBatchThrowsSqlException(Exception thrownException, Class expectedType) + throws SQLException { + doThrow(thrownException).when(mongoCollection).bulkWrite(eq(clientSession), anyList()); assertExecuteThrowsSqlException( MQL_ITEMS_INSERT, mongoPreparedStatement -> { mongoPreparedStatement.addBatch(); mongoPreparedStatement.executeBatch(); }, - exception); + thrownException, + expectedType); } - private static Stream argumentsForExecuteUpdate() { + private static Stream bulkWriteExceptionsForExecuteUpdate() { return Stream.of( Arguments.of(named("insert", MQL_ITEMS_INSERT), MONGO_BULK_WRITE_EXCEPTION_NO_ERRORS), Arguments.of(named("update", MQL_ITEMS_UPDATE), MONGO_BULK_WRITE_EXCEPTION_NO_ERRORS), @@ -281,8 +300,10 @@ private static Stream argumentsForExecuteUpdate() { Arguments.of(named("delete", MQL_ITEMS_DELETE), MONGO_BULK_WRITE_EXCEPTION_WITH_ERRORS)); } - @ParameterizedTest(name = "testUpdateThrowsSqlException: commandName={0}, exception={1}") - @MethodSource("argumentsForExecuteUpdate") + @ParameterizedTest( + name = "test executeUpdate throws SQLException when MongoBulkWriteException occurs." + + " Parameters: commandName={0}, exception={1}") + @MethodSource("bulkWriteExceptionsForExecuteUpdate") void testExecuteUpdateThrowsSqlExceptionWhenMongoBulkWriteExceptionOccurs( String mql, MongoBulkWriteException mongoBulkWriteException) throws SQLException { doThrow(mongoBulkWriteException).when(mongoCollection).bulkWrite(eq(clientSession), anyList()); @@ -299,7 +320,7 @@ void testExecuteUpdateThrowsSqlExceptionWhenMongoBulkWriteExceptionOccurs( } } - private static Stream argumentsForExecuteBatch() { + private static Stream bulkWriteExceptionsForExecuteBatch() { return Stream.of( Arguments.of( named("insert", MQL_ITEMS_INSERT), @@ -327,8 +348,10 @@ private static Stream argumentsForExecuteBatch() { BULK_WRITE_RESULT.getDeletedCount())); } - @ParameterizedTest(name = "testBatchUpdateException: commandName={0}, exception={1}") - @MethodSource("argumentsForExecuteBatch") + @ParameterizedTest( + name = "test executeBatch throws BatchUpdateException when MongoBulkWriteException occurs." + + " Parameters: commandName={0}, exception={1}") + @MethodSource("bulkWriteExceptionsForExecuteBatch") void testExecuteBatchThrowsBatchUpdateExceptionWhenMongoBulkWriteExceptionOccurs( String mql, MongoBulkWriteException mongoBulkWriteException, int expectedUpdateCountLength) throws SQLException { @@ -349,64 +372,14 @@ void testExecuteBatchThrowsBatchUpdateExceptionWhenMongoBulkWriteExceptionOccurs } } - private static Stream timeoutExceptions() { - RuntimeException dummyCause = new RuntimeException(); - return Stream.of( - new MongoExecutionTimeoutException(DUMMY_EXCEPTION_MESSAGE), - new MongoSocketReadTimeoutException(DUMMY_EXCEPTION_MESSAGE, DUMMY_SERVER_ADDRESS, dummyCause), - new MongoSocketWriteTimeoutException(DUMMY_EXCEPTION_MESSAGE, DUMMY_SERVER_ADDRESS, dummyCause), - new MongoTimeoutException(DUMMY_EXCEPTION_MESSAGE), - new MongoOperationTimeoutException(DUMMY_EXCEPTION_MESSAGE)); - } - - @ParameterizedTest( - name = - "SQLTimeoutException is thrown when timeout exception occurs in executeQuery. Parameters: timeoutException={0}") - @MethodSource("timeoutExceptions") - void testExecuteQuerySqlTimeoutException(MongoException timeoutException) throws SQLException { - doThrow(timeoutException).when(mongoCollection).aggregate(eq(clientSession), anyList()); - try (MongoPreparedStatement mongoPreparedStatement = createMongoPreparedStatement(MQL_ITEMS_AGGREGATE)) { - assertThatException() - .isThrownBy(mongoPreparedStatement::executeQuery) - .isInstanceOf(SQLTimeoutException.class) - .withCause(timeoutException); - } - } - - @ParameterizedTest( - name = - "SQLTimeoutException is thrown when timeout exception occurs in executeUpdate. Parameters: timeoutException={0}") - @MethodSource("timeoutExceptions") - void testExecuteUpdateSqlTimeoutException(MongoException timeoutException) throws SQLException { - doThrow(timeoutException).when(mongoCollection).bulkWrite(eq(clientSession), anyList()); - try (MongoPreparedStatement mongoPreparedStatement = createMongoPreparedStatement(MQL_ITEMS_INSERT)) { - assertThatException() - .isThrownBy(mongoPreparedStatement::executeUpdate) - .isInstanceOf(SQLTimeoutException.class) - .withCause(timeoutException); - } - } - - @ParameterizedTest( - name = - "SQLTimeoutException is thrown when timeout exception occurs in executeBatch. Parameters: timeoutException={0}") - @MethodSource("timeoutExceptions") - void testExecuteBatchSqlTimeoutException(MongoException timeoutException) throws SQLException { - doThrow(timeoutException).when(mongoCollection).bulkWrite(eq(clientSession), anyList()); - try (MongoPreparedStatement mongoPreparedStatement = createMongoPreparedStatement(MQL_ITEMS_INSERT)) { - mongoPreparedStatement.addBatch(); - assertThatException() - .isThrownBy(mongoPreparedStatement::executeBatch) - .isInstanceOf(SQLTimeoutException.class) - .withCause(timeoutException); - } - } - private void assertExecuteThrowsSqlException( - String mql, SqlConsumer executeConsumer, Exception expectedCause) + String mql, + SqlConsumer executeConsumer, + Exception expectedCause, + Class expectedExceptionType) throws SQLException { try (MongoPreparedStatement mongoPreparedStatement = createMongoPreparedStatement(mql)) { - assertThatExceptionOfType(SQLException.class) + assertThatExceptionOfType(expectedExceptionType) .isThrownBy(() -> executeConsumer.accept(mongoPreparedStatement)) .withCause(expectedCause) .satisfies(sqlException -> { From 92d9c41ee79b18d3301f05e6c5c63b99d107ab95 Mon Sep 17 00:00:00 2001 From: "slav.babanin" Date: Thu, 2 Oct 2025 15:24:14 -0700 Subject: [PATCH 04/37] Rename variables. --- .../jdbc/MongoPreparedStatement.java | 6 ++-- .../hibernate/jdbc/MongoStatement.java | 12 ++++--- .../jdbc/MongoPreparedStatementTests.java | 34 +++++++++++-------- .../hibernate/jdbc/MongoStatementTests.java | 12 ++++--- 4 files changed, 38 insertions(+), 26 deletions(-) diff --git a/src/main/java/com/mongodb/hibernate/jdbc/MongoPreparedStatement.java b/src/main/java/com/mongodb/hibernate/jdbc/MongoPreparedStatement.java index 06a3c36e..61f3a3f8 100644 --- a/src/main/java/com/mongodb/hibernate/jdbc/MongoPreparedStatement.java +++ b/src/main/java/com/mongodb/hibernate/jdbc/MongoPreparedStatement.java @@ -225,10 +225,10 @@ public int[] executeBatch() throws SQLException { checkSupportedBatchCommand(commandBatch.get(0)); try { executeBulkWrite(commandBatch, ExecutionType.BATCH); - var rowCounts = new int[commandBatch.size()]; + var updateCounts = new int[commandBatch.size()]; // We cannot determine the actual number of rows affected for each command in the batch. - Arrays.fill(rowCounts, Statement.SUCCESS_NO_INFO); - return rowCounts; + Arrays.fill(updateCounts, Statement.SUCCESS_NO_INFO); + return updateCounts; } finally { commandBatch.clear(); } diff --git a/src/main/java/com/mongodb/hibernate/jdbc/MongoStatement.java b/src/main/java/com/mongodb/hibernate/jdbc/MongoStatement.java index 368b9652..72afc6e5 100644 --- a/src/main/java/com/mongodb/hibernate/jdbc/MongoStatement.java +++ b/src/main/java/com/mongodb/hibernate/jdbc/MongoStatement.java @@ -312,10 +312,7 @@ private MongoCollection getCollection(CommandType commandType, Bso private static SQLException handleException( RuntimeException exception, CommandType commandType, ExecutionType executionType) { - if (exception instanceof MongoSocketReadTimeoutException - || exception instanceof MongoSocketWriteTimeoutException - || exception instanceof MongoTimeoutException - || exception instanceof MongoExecutionTimeoutException) { + if (isTimeoutException(exception)) { return new SQLTimeoutException("Timeout while waiting for operation to complete", exception); } @@ -334,6 +331,13 @@ private static SQLException handleException( return new SQLException(EXCEPTION_MESSAGE_FAILED_TO_EXECUTE_OPERATION, exception); } + private static boolean isTimeoutException(final RuntimeException exception) { + return exception instanceof MongoSocketReadTimeoutException + || exception instanceof MongoSocketWriteTimeoutException + || exception instanceof MongoTimeoutException + || exception instanceof MongoExecutionTimeoutException; + } + static BatchUpdateException createBatchUpdateException( MongoBulkWriteException mongoBulkWriteException, CommandType commandType) { var updateCount = getUpdateCount(commandType, mongoBulkWriteException.getWriteResult()); diff --git a/src/test/java/com/mongodb/hibernate/jdbc/MongoPreparedStatementTests.java b/src/test/java/com/mongodb/hibernate/jdbc/MongoPreparedStatementTests.java index 40fec870..49cfc254 100644 --- a/src/test/java/com/mongodb/hibernate/jdbc/MongoPreparedStatementTests.java +++ b/src/test/java/com/mongodb/hibernate/jdbc/MongoPreparedStatementTests.java @@ -257,36 +257,38 @@ private static Stream exceptions() { new MongoOperationTimeoutException(DUMMY_EXCEPTION_MESSAGE), SQLTimeoutException.class)); } - @ParameterizedTest(name = "test executeQuery throws SQLException. Parameters: exception={0}") + @ParameterizedTest(name = "test executeQuery throws SQLException. Parameters: exception={0}, expectedType={1}") @MethodSource("exceptions") - void testExecuteQueryThrowsSqlException(Exception thrownException, Class expectedType) + void testExecuteQueryThrowsSqlException(Exception exceptionToThrow, Class expectedType) throws SQLException { - doThrow(thrownException).when(mongoCollection).aggregate(eq(clientSession), anyList()); - assertExecuteThrowsSqlException( - MQL_ITEMS_AGGREGATE, MongoPreparedStatement::executeQuery, thrownException, expectedType); + doThrow(exceptionToThrow).when(mongoCollection).aggregate(eq(clientSession), anyList()); + + assertExecuteThrowsSqlException(MQL_ITEMS_AGGREGATE, MongoPreparedStatement::executeQuery, exceptionToThrow, expectedType); } - @ParameterizedTest(name = "test executeUpdate throws SQLException. Parameters: exception={0}") + @ParameterizedTest(name = "test executeUpdate throws SQLException. Parameters: exception={0}, expectedType={1}") @MethodSource("exceptions") - void testExecuteUpdateThrowsSqlException(Exception thrownException, Class expectedType) + void testExecuteUpdateThrowsSqlException(Exception exceptionToThrow, Class expectedType) throws SQLException { - doThrow(thrownException).when(mongoCollection).bulkWrite(eq(clientSession), anyList()); + doThrow(exceptionToThrow).when(mongoCollection).bulkWrite(eq(clientSession), anyList()); + assertExecuteThrowsSqlException( - MQL_ITEMS_INSERT, MongoPreparedStatement::executeUpdate, thrownException, expectedType); + MQL_ITEMS_INSERT, MongoPreparedStatement::executeUpdate, exceptionToThrow, expectedType); } - @ParameterizedTest(name = "test executeUpdate throws SQLException. Parameters: exception={0}") + @ParameterizedTest(name = "test executeUpdate throws SQLException. Parameters: exception={0}, expectedType={1}") @MethodSource("exceptions") - void testExecuteBatchThrowsSqlException(Exception thrownException, Class expectedType) + void testExecuteBatchThrowsSqlException(Exception exceptionToThrow, Class expectedType) throws SQLException { - doThrow(thrownException).when(mongoCollection).bulkWrite(eq(clientSession), anyList()); + doThrow(exceptionToThrow).when(mongoCollection).bulkWrite(eq(clientSession), anyList()); + assertExecuteThrowsSqlException( MQL_ITEMS_INSERT, mongoPreparedStatement -> { mongoPreparedStatement.addBatch(); mongoPreparedStatement.executeBatch(); }, - thrownException, + exceptionToThrow, expectedType); } @@ -307,12 +309,13 @@ private static Stream bulkWriteExceptionsForExecuteUpdate() { void testExecuteUpdateThrowsSqlExceptionWhenMongoBulkWriteExceptionOccurs( String mql, MongoBulkWriteException mongoBulkWriteException) throws SQLException { doThrow(mongoBulkWriteException).when(mongoCollection).bulkWrite(eq(clientSession), anyList()); + Integer vendorCodeError = getVendorCodeError(mongoBulkWriteException); + try (MongoPreparedStatement mongoPreparedStatement = createMongoPreparedStatement(mql)) { assertThatExceptionOfType(SQLException.class) .isThrownBy(mongoPreparedStatement::executeUpdate) .withCause(mongoBulkWriteException) .satisfies(sqlException -> { - Integer vendorCodeError = getVendorCodeError(mongoBulkWriteException); assertAll( () -> assertNull(sqlException.getSQLState()), () -> assertEquals(vendorCodeError, sqlException.getErrorCode())); @@ -356,13 +359,14 @@ void testExecuteBatchThrowsBatchUpdateExceptionWhenMongoBulkWriteExceptionOccurs String mql, MongoBulkWriteException mongoBulkWriteException, int expectedUpdateCountLength) throws SQLException { doThrow(mongoBulkWriteException).when(mongoCollection).bulkWrite(eq(clientSession), anyList()); + Integer vendorCodeError = getVendorCodeError(mongoBulkWriteException); + try (MongoPreparedStatement mongoPreparedStatement = createMongoPreparedStatement(mql)) { mongoPreparedStatement.addBatch(); assertThatExceptionOfType(BatchUpdateException.class) .isThrownBy(mongoPreparedStatement::executeBatch) .withCause(mongoBulkWriteException) .satisfies(batchUpdateException -> { - Integer vendorCodeError = getVendorCodeError(mongoBulkWriteException); assertAll( () -> assertUpdateCounts( batchUpdateException.getUpdateCounts(), expectedUpdateCountLength), diff --git a/src/test/java/com/mongodb/hibernate/jdbc/MongoStatementTests.java b/src/test/java/com/mongodb/hibernate/jdbc/MongoStatementTests.java index c7a2a25e..a06c20a1 100644 --- a/src/test/java/com/mongodb/hibernate/jdbc/MongoStatementTests.java +++ b/src/test/java/com/mongodb/hibernate/jdbc/MongoStatementTests.java @@ -16,6 +16,7 @@ package com.mongodb.hibernate.jdbc; +import static java.util.Collections.emptyList; import static org.assertj.core.api.Assertions.assertThat; import static org.junit.jupiter.api.Assertions.assertAll; import static org.junit.jupiter.api.Assertions.assertDoesNotThrow; @@ -30,6 +31,7 @@ import static org.mockito.Mockito.doReturn; import static org.mockito.Mockito.doThrow; +import com.mongodb.bulk.BulkWriteInsert; import com.mongodb.bulk.BulkWriteResult; import com.mongodb.client.AggregateIterable; import com.mongodb.client.ClientSession; @@ -43,6 +45,8 @@ import java.util.List; import java.util.function.BiConsumer; import org.bson.BsonDocument; +import org.bson.BsonObjectId; +import org.bson.types.ObjectId; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Nested; import org.junit.jupiter.api.Test; @@ -140,8 +144,9 @@ class ExecuteMethodClosesLastOpenResultSetTests { @Mock MongoCursor mongoCursor; - @Mock - BulkWriteResult bulkWriteResult; + private static final BulkWriteResult BULK_WRITE_RESULT = BulkWriteResult.acknowledged( + 1, 0, 2, 3, emptyList(), + List.of(new BulkWriteInsert(0, new BsonObjectId(new ObjectId(1, 1))))); private ResultSet lastOpenResultSet; @@ -164,8 +169,7 @@ void testExecuteQuery() throws SQLException { @Test void testExecuteUpdate() throws SQLException { doReturn(mongoCollection).when(mongoDatabase).getCollection(anyString(), eq(BsonDocument.class)); - doReturn(bulkWriteResult).when(mongoCollection).bulkWrite(eq(clientSession), anyList()); - doReturn(10).when(bulkWriteResult).getModifiedCount(); + doReturn(BULK_WRITE_RESULT).when(mongoCollection).bulkWrite(eq(clientSession), anyList()); mongoStatement.executeUpdate(exampleUpdateMql); assertTrue(lastOpenResultSet.isClosed()); From b78cb7796fc4b7fbb7a774d63b34430934a2ddd0 Mon Sep 17 00:00:00 2001 From: "slav.babanin" Date: Thu, 2 Oct 2025 16:56:15 -0700 Subject: [PATCH 05/37] Rename methods, add tests. --- ...ongoPreparedStatementIntegrationTests.java | 317 ++++++++++-------- .../query/AbstractQueryIntegrationTests.java | 6 +- .../mutation/BatchUpdateIntegrationTests.java | 6 +- ...imitOffsetFetchClauseIntegrationTests.java | 2 +- .../hibernate/jdbc/MongoStatement.java | 3 +- .../jdbc/MongoPreparedStatementTests.java | 3 +- .../hibernate/jdbc/MongoStatementTests.java | 3 +- 7 files changed, 184 insertions(+), 156 deletions(-) diff --git a/src/integrationTest/java/com/mongodb/hibernate/jdbc/MongoPreparedStatementIntegrationTests.java b/src/integrationTest/java/com/mongodb/hibernate/jdbc/MongoPreparedStatementIntegrationTests.java index f9ad1fc9..1c83ec6f 100644 --- a/src/integrationTest/java/com/mongodb/hibernate/jdbc/MongoPreparedStatementIntegrationTests.java +++ b/src/integrationTest/java/com/mongodb/hibernate/jdbc/MongoPreparedStatementIntegrationTests.java @@ -23,6 +23,7 @@ import static java.lang.String.format; import static java.sql.Statement.SUCCESS_NO_INFO; import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatExceptionOfType; import static org.junit.jupiter.api.Assertions.assertAll; import static org.junit.jupiter.api.Assertions.assertArrayEquals; import static org.junit.jupiter.api.Assertions.assertEquals; @@ -37,6 +38,7 @@ import com.mongodb.hibernate.junit.InjectMongoCollection; import com.mongodb.hibernate.junit.MongoExtension; import java.math.BigDecimal; +import java.sql.BatchUpdateException; import java.sql.Connection; import java.sql.PreparedStatement; import java.sql.SQLException; @@ -251,22 +253,39 @@ class ExecuteBatchTests { ] }"""; + @Test + @DisplayName("executeBatch throws a BatchUpdateException for command returning ResultSet") + void testQueriesReturningResult() { + doWorkAwareOfAutoCommit(connection -> { + try (var pstm = connection.prepareStatement( + """ + { + aggregate: "books", + pipeline: [ + { $match: { _id: 1 } } + ] + }""")) { + pstm.addBatch(); + assertThatExceptionOfType(BatchUpdateException.class) + .isThrownBy(pstm::executeBatch) + .satisfies(batchUpdateException -> { + assertNull(batchUpdateException.getUpdateCounts()); + assertNull(batchUpdateException.getSQLState()); + assertEquals(0, batchUpdateException.getErrorCode()); + }); + } catch (SQLException e) { + throw new RuntimeException(e); + } + }); + + assertThat(mongoCollection.find()).isEmpty(); + } + @Test void testEmptyBatch() { doWorkAwareOfAutoCommit(connection -> { - try { - var pstmt = (MongoPreparedStatement) - connection.prepareStatement( - """ - { - insert: "books", - documents: [ - { - _id: 1 - } - ] - }"""); - int[] updateCounts = pstmt.executeBatch(); + try (var pstmt = connection.prepareStatement(INSERT_MQL)) { + var updateCounts = pstmt.executeBatch(); assertEquals(0, updateCounts.length); } catch (SQLException e) { throw new RuntimeException(e); @@ -280,25 +299,26 @@ void testEmptyBatch() { @DisplayName("Test statement’s batch queue is reset once executeBatch returns") void testBatchQueueIsResetAfterExecute() { doWorkAwareOfAutoCommit(connection -> { - var pstmt = (MongoPreparedStatement) - connection.prepareStatement( - """ + try (var pstmt = connection.prepareStatement( + """ + { + insert: "books", + documents: [ { - insert: "books", - documents: [ - { - _id: {$undefined: true}, - title: {$undefined: true} - } - ] - }"""); - - pstmt.setInt(1, 1); - pstmt.setString(2, "War and Peace"); - pstmt.addBatch(); - assertExecuteBatch(pstmt, 1); + _id: {$undefined: true}, + title: {$undefined: true} + } + ] + }""")) { - assertExecuteBatch(pstmt, 0); + pstmt.setInt(1, 1); + pstmt.setString(2, "War and Peace"); + pstmt.addBatch(); + assertExecuteBatch(pstmt, 1); + assertExecuteBatch(pstmt, 0); + } catch (SQLException e) { + throw new RuntimeException(e); + } }); assertThat(mongoCollection.find()) @@ -315,28 +335,30 @@ void testBatchQueueIsResetAfterExecute() { @DisplayName("Test values set for the parameter markers of PreparedStatement are not reset when it is executed") void testBatchParametersReuse() { doWorkAwareOfAutoCommit(connection -> { - var pstmt = (MongoPreparedStatement) - connection.prepareStatement( - """ + try (var pstmt = connection.prepareStatement( + """ + { + insert: "books", + documents: [ { - insert: "books", - documents: [ - { - _id: {$undefined: true}, - title: {$undefined: true} - } - ] - }"""); + _id: {$undefined: true}, + title: {$undefined: true} + } + ] + }""")) { - pstmt.setInt(1, 1); - pstmt.setString(2, "War and Peace"); - pstmt.addBatch(); - assertExecuteBatch(pstmt, 1); - - pstmt.setInt(1, 2); - // No need to set title again, it should be reused from the previous execution - pstmt.addBatch(); - assertExecuteBatch(pstmt, 1); + pstmt.setInt(1, 1); + pstmt.setString(2, "War and Peace"); + pstmt.addBatch(); + assertExecuteBatch(pstmt, 1); + + pstmt.setInt(1, 2); + // No need to set title again, it should be reused from the previous execution + pstmt.addBatch(); + assertExecuteBatch(pstmt, 1); + } catch (SQLException e) { + throw new RuntimeException(e); + } }); assertThat(mongoCollection.find()) @@ -357,25 +379,27 @@ void testBatchParametersReuse() { @Test void testBatchInsert() { - int batchCount = 3; + var batchCount = 3; doWorkAwareOfAutoCommit(connection -> { - var pstmt = (MongoPreparedStatement) - connection.prepareStatement( - """ - { - insert: "books", - documents: [{ - _id: {$undefined: true}, - title: {$undefined: true} - }] - }"""); - - for (int i = 1; i <= batchCount; i++) { - pstmt.setInt(1, i); - pstmt.setString(2, "Book " + i); - pstmt.addBatch(); + try (var pstmt = connection.prepareStatement( + """ + { + insert: "books", + documents: [{ + _id: {$undefined: true}, + title: {$undefined: true} + }] + }""")) { + + for (int i = 1; i <= batchCount; i++) { + pstmt.setInt(1, i); + pstmt.setString(2, "Book " + i); + pstmt.addBatch(); + } + assertExecuteBatch(pstmt, batchCount); + } catch (SQLException e) { + throw new RuntimeException(e); } - assertExecuteBatch(pstmt, batchCount); }); var expectedDocs = new ArrayList(); @@ -394,26 +418,27 @@ void testBatchInsert() { @Test void testBatchUpdate() { insertTestData(session, INSERT_MQL); - - int batchCount = 3; + var batchCount = 3; doWorkAwareOfAutoCommit(connection -> { - var pstmt = (MongoPreparedStatement) - connection.prepareStatement( - """ - { - update: "books", - updates: [{ - q: { _id: { $undefined: true } }, - u: { $set: { title: { $undefined: true } } }, - multi: true - }] - }"""); - for (int i = 1; i <= batchCount; i++) { - pstmt.setInt(1, i); - pstmt.setString(2, "Book " + i); - pstmt.addBatch(); + try (var pstmt = connection.prepareStatement( + """ + { + update: "books", + updates: [{ + q: { _id: { $undefined: true } }, + u: { $set: { title: { $undefined: true } } }, + multi: true + }] + }""")) { + for (int i = 1; i <= batchCount; i++) { + pstmt.setInt(1, i); + pstmt.setString(2, "Book " + i); + pstmt.addBatch(); + } + assertExecuteBatch(pstmt, batchCount); + } catch (SQLException e) { + throw new RuntimeException(e); } - assertExecuteBatch(pstmt, batchCount); }); var expectedDocs = new ArrayList(); @@ -432,32 +457,34 @@ void testBatchUpdate() { @Test void testBatchDelete() { insertTestData(session, INSERT_MQL); - - int batchCount = 3; + var batchCount = 3; doWorkAwareOfAutoCommit(connection -> { - var pstmt = (MongoPreparedStatement) - connection.prepareStatement( - """ - { - delete: "books", - deletes: [{ - q: { _id: { $undefined: true } }, - limit: 0 - }] - }"""); - for (int i = 1; i <= batchCount; i++) { - pstmt.setInt(1, i); - pstmt.addBatch(); + try (var pstmt = connection.prepareStatement( + """ + { + delete: "books", + deletes: [{ + q: { _id: { $undefined: true } }, + limit: 0 + }] + }""")) { + for (int i = 1; i <= batchCount; i++) { + pstmt.setInt(1, i); + pstmt.addBatch(); + } + assertExecuteBatch(pstmt, batchCount); + } catch (SQLException e) { + throw new RuntimeException(e); } - assertExecuteBatch(pstmt, batchCount); }); assertThat(mongoCollection.find()).isEmpty(); } - private void assertExecuteBatch(MongoPreparedStatement pstmt, int expectedBatchResultSize) throws SQLException { + private static void assertExecuteBatch(PreparedStatement pstmt, int expectedUpdateCountsSize) + throws SQLException { int[] updateCounts = pstmt.executeBatch(); - assertEquals(expectedBatchResultSize, updateCounts.length); + assertEquals(expectedUpdateCountsSize, updateCounts.length); for (int updateCount : updateCounts) { assertEquals(SUCCESS_NO_INFO, updateCount); } @@ -498,7 +525,7 @@ class ExecuteUpdateTests { @Test void testInsert() { - Function pstmtProvider = connection -> { + Function pstmtProvider = connection -> { try { var pstmt = (MongoPreparedStatement) connection.prepareStatement( @@ -570,26 +597,25 @@ void testUpdate() { outOfStock: false, tags: [ "classic", "dostoevsky", "literature" ] }""")); - Function pstmtProvider = connection -> { + Function pstmtProvider = connection -> { try { - var pstmt = (MongoPreparedStatement) - connection.prepareStatement( - """ + var pstmt = connection.prepareStatement( + """ + { + update: "books", + updates: [ { - update: "books", - updates: [ - { - q: { author: { $undefined: true } }, - u: { - $set: { - outOfStock: { $undefined: true } - }, - $push: { tags: { $undefined: true } } - }, - multi: true - } - ] - }"""); + q: { author: { $undefined: true } }, + u: { + $set: { + outOfStock: { $undefined: true } + }, + $push: { tags: { $undefined: true } } + }, + multi: true + } + ] + }"""); pstmt.setString(1, "Leo Tolstoy"); pstmt.setBoolean(2, true); pstmt.setString(3, "literature"); @@ -605,20 +631,19 @@ void testUpdate() { void testDelete() { insertTestData(session, INSERT_MQL); - Function pstmtProvider = connection -> { + Function pstmtProvider = connection -> { try { - var pstmt = (MongoPreparedStatement) - connection.prepareStatement( - """ + var pstmt = connection.prepareStatement( + """ + { + delete: "books", + deletes: [ { - delete: "books", - deletes: [ - { - q: { author: { $undefined: true } }, - limit: 0 - } - ] - }"""); + q: { author: { $undefined: true } }, + limit: 0 + } + ] + }"""); pstmt.setString(1, "Leo Tolstoy"); return pstmt; } catch (SQLException e) { @@ -657,8 +682,9 @@ void testNotSupportedCommands(String commandName) { }); } - @Test - void testNotSupportedUpdateElements() { + @ParameterizedTest(name = "test not supported update elements. Parameters: option={0}") + @ValueSource(strings = {"hint", "collation", "arrayFilters", "sort", "upsert", "c"}) + void testNotSupportedUpdateElements(String unsupportedElement) { doWorkAwareOfAutoCommit(connection -> { try (PreparedStatement pstm = connection.prepareStatement( format( @@ -670,19 +696,20 @@ void testNotSupportedUpdateElements() { q: { author: { $eq: "Leo Tolstoy" } }, u: { $set: { outOfStock: true } }, multi: true, - hint: { _id: 1 } + %s: { _id: 1 } } ] - }"""))) { + }""", unsupportedElement))) { SQLFeatureNotSupportedException exception = assertThrows(SQLFeatureNotSupportedException.class, pstm::executeUpdate); - assertThat(exception.getMessage()).isEqualTo("Unsupported elements in update command: [hint]"); + assertThat(exception.getMessage()).isEqualTo(format("Unsupported elements in update command: [%s]", unsupportedElement)); } }); } - @Test - void testNotSupportedDeleteElements() { + @ParameterizedTest(name = "test not supported delete elements. Parameters: option={0}") + @ValueSource(strings = {"hint", "collation"}) + void testNotSupportedDeleteElements(String unsupportedElement) { doWorkAwareOfAutoCommit(connection -> { try (PreparedStatement pstm = connection.prepareStatement( format( @@ -693,19 +720,19 @@ void testNotSupportedDeleteElements() { { q: { author: { $eq: "Leo Tolstoy" } }, limit: 0, - hint: { _id: 1 } + %s: { _id: 1 } } ] - }"""))) { + }""", unsupportedElement))) { SQLFeatureNotSupportedException exception = assertThrows(SQLFeatureNotSupportedException.class, pstm::executeUpdate); - assertThat(exception.getMessage()).isEqualTo("Unsupported elements in delete command: [hint]"); + assertThat(exception.getMessage()).isEqualTo(format("Unsupported elements in delete command: [%s]", unsupportedElement)); } }); } private void assertExecuteUpdate( - Function pstmtProvider, + Function pstmtProvider, int expectedUpdatedRowCount, List expectedDocuments) { doWorkAwareOfAutoCommit(connection -> { diff --git a/src/integrationTest/java/com/mongodb/hibernate/query/AbstractQueryIntegrationTests.java b/src/integrationTest/java/com/mongodb/hibernate/query/AbstractQueryIntegrationTests.java index 5a5007b5..3d08c506 100644 --- a/src/integrationTest/java/com/mongodb/hibernate/query/AbstractQueryIntegrationTests.java +++ b/src/integrationTest/java/com/mongodb/hibernate/query/AbstractQueryIntegrationTests.java @@ -128,7 +128,7 @@ protected void assertSelectionQuery( } var resultList = selectionQuery.getResultList(); - assertActualCommand(BsonDocument.parse(expectedMql)); + assertActualCommandsInOrder(BsonDocument.parse(expectedMql)); resultListVerifier.accept(resultList); @@ -178,7 +178,7 @@ protected void assertSelectQueryFailure( expectedExceptionMessageParameters); } - protected void assertActualCommand(BsonDocument... expectedCommands) { + protected void assertActualCommandsInOrder(BsonDocument... expectedCommands) { var capturedCommands = testCommandListener.getStartedCommands(); assertThat(capturedCommands).hasSize(expectedCommands.length); for (int i = 0; i < expectedCommands.length; i++) { @@ -201,7 +201,7 @@ protected void assertMutationQuery( queryPostProcessor.accept(query); } var mutationCount = query.executeUpdate(); - assertActualCommand(BsonDocument.parse(expectedMql)); + assertActualCommandsInOrder(BsonDocument.parse(expectedMql)); assertThat(mutationCount).isEqualTo(expectedMutationCount); }); assertThat(collection.find()).containsExactlyElementsOf(expectedDocuments); diff --git a/src/integrationTest/java/com/mongodb/hibernate/query/mutation/BatchUpdateIntegrationTests.java b/src/integrationTest/java/com/mongodb/hibernate/query/mutation/BatchUpdateIntegrationTests.java index b037d1e3..544821b1 100644 --- a/src/integrationTest/java/com/mongodb/hibernate/query/mutation/BatchUpdateIntegrationTests.java +++ b/src/integrationTest/java/com/mongodb/hibernate/query/mutation/BatchUpdateIntegrationTests.java @@ -57,7 +57,7 @@ void testBatchInsert() { session.persist(new Item(i, String.valueOf(i))); } session.flush(); - assertActualCommand( + assertActualCommandsInOrder( parse( """ { @@ -101,7 +101,7 @@ void testBatchUpdate() { item.string = "u" + i; } session.flush(); - assertActualCommand( + assertActualCommandsInOrder( parse( """ { @@ -145,7 +145,7 @@ void testBatchDelete() { session.remove(item); } session.flush(); - assertActualCommand( + assertActualCommandsInOrder( parse( """ { diff --git a/src/integrationTest/java/com/mongodb/hibernate/query/select/LimitOffsetFetchClauseIntegrationTests.java b/src/integrationTest/java/com/mongodb/hibernate/query/select/LimitOffsetFetchClauseIntegrationTests.java index c70e9544..9bed452f 100644 --- a/src/integrationTest/java/com/mongodb/hibernate/query/select/LimitOffsetFetchClauseIntegrationTests.java +++ b/src/integrationTest/java/com/mongodb/hibernate/query/select/LimitOffsetFetchClauseIntegrationTests.java @@ -614,7 +614,7 @@ private void setQueryOptionsAndQuery( query.getResultList(); if (expectedMql != null) { var expectedCommand = BsonDocument.parse(expectedMql); - assertActualCommand(expectedCommand); + assertActualCommandsInOrder(expectedCommand); } } } diff --git a/src/main/java/com/mongodb/hibernate/jdbc/MongoStatement.java b/src/main/java/com/mongodb/hibernate/jdbc/MongoStatement.java index 72afc6e5..d34062e3 100644 --- a/src/main/java/com/mongodb/hibernate/jdbc/MongoStatement.java +++ b/src/main/java/com/mongodb/hibernate/jdbc/MongoStatement.java @@ -276,7 +276,8 @@ void checkSupportedBatchCommand(BsonDocument command) throws SQLException { "Commands returning result set are not supported. Received command: %s", commandType.getCommandName()), null, - new int[0]); + 0, + null); } checkSupportedUpdateCommand(commandType); } diff --git a/src/test/java/com/mongodb/hibernate/jdbc/MongoPreparedStatementTests.java b/src/test/java/com/mongodb/hibernate/jdbc/MongoPreparedStatementTests.java index 49cfc254..0154eaac 100644 --- a/src/test/java/com/mongodb/hibernate/jdbc/MongoPreparedStatementTests.java +++ b/src/test/java/com/mongodb/hibernate/jdbc/MongoPreparedStatementTests.java @@ -263,7 +263,8 @@ void testExecuteQueryThrowsSqlException(Exception exceptionToThrow, Class mongoCursor; private static final BulkWriteResult BULK_WRITE_RESULT = BulkWriteResult.acknowledged( - 1, 0, 2, 3, emptyList(), - List.of(new BulkWriteInsert(0, new BsonObjectId(new ObjectId(1, 1))))); + 1, 0, 2, 3, emptyList(), List.of(new BulkWriteInsert(0, new BsonObjectId(new ObjectId(1, 1))))); private ResultSet lastOpenResultSet; From a7f7792e10a80dba7dcd65a40f9316b80964b60c Mon Sep 17 00:00:00 2001 From: "slav.babanin" Date: Thu, 2 Oct 2025 16:58:36 -0700 Subject: [PATCH 06/37] Revert to private. --- src/main/java/com/mongodb/hibernate/jdbc/MongoStatement.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/main/java/com/mongodb/hibernate/jdbc/MongoStatement.java b/src/main/java/com/mongodb/hibernate/jdbc/MongoStatement.java index d34062e3..65bfe658 100644 --- a/src/main/java/com/mongodb/hibernate/jdbc/MongoStatement.java +++ b/src/main/java/com/mongodb/hibernate/jdbc/MongoStatement.java @@ -294,7 +294,7 @@ static BsonDocument parse(String mql) throws SQLSyntaxErrorException { * Starts transaction for the first {@link java.sql.Statement} executing if * {@linkplain MongoConnection#getAutoCommit() auto-commit} is disabled. */ - void startTransactionIfNeeded() throws SQLException { + private void startTransactionIfNeeded() throws SQLException { if (!mongoConnection.getAutoCommit() && !clientSession.hasActiveTransaction()) { clientSession.startTransaction(); } From 7372d56fd1f706017a1e7056a04c64c4e680af1c Mon Sep 17 00:00:00 2001 From: "slav.babanin" Date: Mon, 6 Oct 2025 20:23:57 -0700 Subject: [PATCH 07/37] Add JDBC exception handling. --- .../jdbc/MongoPreparedStatement.java | 1 - .../hibernate/jdbc/MongoStatement.java | 221 +++++--- .../jdbc/MongoPreparedStatementTests.java | 495 ++++++++++++++---- 3 files changed, 560 insertions(+), 157 deletions(-) diff --git a/src/main/java/com/mongodb/hibernate/jdbc/MongoPreparedStatement.java b/src/main/java/com/mongodb/hibernate/jdbc/MongoPreparedStatement.java index 61f3a3f8..55532a0f 100644 --- a/src/main/java/com/mongodb/hibernate/jdbc/MongoPreparedStatement.java +++ b/src/main/java/com/mongodb/hibernate/jdbc/MongoPreparedStatement.java @@ -53,7 +53,6 @@ final class MongoPreparedStatement extends MongoStatement implements PreparedStatementAdapter { - private static final int[] EMPTY_BATCH_RESULT = new int[0]; private final BsonDocument command; private final List commandBatch; private final List parameterValueSetters; diff --git a/src/main/java/com/mongodb/hibernate/jdbc/MongoStatement.java b/src/main/java/com/mongodb/hibernate/jdbc/MongoStatement.java index 65bfe658..7f08e9ec 100644 --- a/src/main/java/com/mongodb/hibernate/jdbc/MongoStatement.java +++ b/src/main/java/com/mongodb/hibernate/jdbc/MongoStatement.java @@ -16,15 +16,9 @@ package com.mongodb.hibernate.jdbc; -import static com.mongodb.assertions.Assertions.fail; -import static com.mongodb.hibernate.internal.MongoAssertions.assertNotNull; -import static com.mongodb.hibernate.internal.MongoConstants.ID_FIELD_NAME; -import static com.mongodb.hibernate.internal.VisibleForTesting.AccessModifier.PRIVATE; -import static java.lang.String.format; -import static java.util.Collections.singletonList; -import static java.util.stream.Collectors.toCollection; - +import com.mongodb.ErrorCategory; import com.mongodb.MongoBulkWriteException; +import com.mongodb.MongoException; import com.mongodb.MongoExecutionTimeoutException; import com.mongodb.MongoSocketReadTimeoutException; import com.mongodb.MongoSocketWriteTimeoutException; @@ -41,13 +35,19 @@ import com.mongodb.client.model.WriteModel; import com.mongodb.hibernate.internal.FeatureNotSupportedException; import com.mongodb.hibernate.internal.VisibleForTesting; +import org.bson.BsonDocument; +import org.bson.BsonValue; +import org.jspecify.annotations.Nullable; + import java.sql.BatchUpdateException; import java.sql.Connection; import java.sql.ResultSet; import java.sql.SQLException; import java.sql.SQLFeatureNotSupportedException; +import java.sql.SQLIntegrityConstraintViolationException; import java.sql.SQLSyntaxErrorException; import java.sql.SQLTimeoutException; +import java.sql.SQLTransientException; import java.sql.SQLWarning; import java.sql.Statement; import java.util.ArrayList; @@ -55,15 +55,25 @@ import java.util.Collection; import java.util.List; import java.util.Map; -import org.bson.BsonDocument; -import org.bson.BsonValue; -import org.jspecify.annotations.Nullable; + +import static com.mongodb.assertions.Assertions.fail; +import static com.mongodb.hibernate.internal.MongoAssertions.assertNotNull; +import static com.mongodb.hibernate.internal.MongoConstants.ID_FIELD_NAME; +import static com.mongodb.hibernate.internal.VisibleForTesting.AccessModifier.PRIVATE; +import static java.lang.Math.max; +import static java.lang.String.format; +import static java.util.Collections.singletonList; +import static java.util.stream.Collectors.toCollection; class MongoStatement implements StatementAdapter { private static final List SUPPORTED_UPDATE_COMMAND_ELEMENTS = List.of("q", "u", "multi"); private static final List SUPPORTED_DELETE_COMMAND_ELEMENTS = List.of("q", "limit"); - private static final String EXCEPTION_MESSAGE_FAILED_TO_EXECUTE_OPERATION = "Failed to execute operation"; + private static final String EXCEPTION_MESSAGE_OPERATION_FAILED = "Failed to execute operation"; + private static final String EXCEPTION_MESSAGE_BATCH_FAILED = "Batch execution failed"; + private static final String EXCEPTION_MESSAGE_TIMEOUT = "Timeout while waiting for operation to complete"; + private static final int DEFAULT_ERROR_CODE = 0; + static final int[] EMPTY_BATCH_RESULT = new int[DEFAULT_ERROR_CODE]; private final MongoDatabase mongoDatabase; private final MongoConnection mongoConnection; private final ClientSession clientSession; @@ -127,7 +137,7 @@ private static boolean isExcludeProjectSpecification(Map.Entry commandBatch, ExecutionType executionType) throws SQLException { - var firstDocumentInBatch = commandBatch.get(0); + var firstDocumentInBatch = commandBatch.get(DEFAULT_ERROR_CODE); var commandType = getCommandType(firstDocumentInBatch); var collection = getCollection(commandType, firstDocumentInBatch); try { @@ -276,7 +286,7 @@ void checkSupportedBatchCommand(BsonDocument command) throws SQLException { "Commands returning result set are not supported. Received command: %s", commandType.getCommandName()), null, - 0, + DEFAULT_ERROR_CODE, null); } checkSupportedUpdateCommand(commandType); @@ -311,50 +321,6 @@ private MongoCollection getCollection(CommandType commandType, Bso return mongoDatabase.getCollection(collectionName, BsonDocument.class); } - private static SQLException handleException( - RuntimeException exception, CommandType commandType, ExecutionType executionType) { - if (isTimeoutException(exception)) { - return new SQLTimeoutException("Timeout while waiting for operation to complete", exception); - } - - if (exception instanceof MongoBulkWriteException mongoBulkWriteException) { - if (executionType == ExecutionType.BATCH) { - return createBatchUpdateException(mongoBulkWriteException, commandType); - } else { - return new SQLException( - EXCEPTION_MESSAGE_FAILED_TO_EXECUTE_OPERATION, - null, - getErrorCode(mongoBulkWriteException), - mongoBulkWriteException); - } - } - - return new SQLException(EXCEPTION_MESSAGE_FAILED_TO_EXECUTE_OPERATION, exception); - } - - private static boolean isTimeoutException(final RuntimeException exception) { - return exception instanceof MongoSocketReadTimeoutException - || exception instanceof MongoSocketWriteTimeoutException - || exception instanceof MongoTimeoutException - || exception instanceof MongoExecutionTimeoutException; - } - - static BatchUpdateException createBatchUpdateException( - MongoBulkWriteException mongoBulkWriteException, CommandType commandType) { - var updateCount = getUpdateCount(commandType, mongoBulkWriteException.getWriteResult()); - var updateCounts = new int[updateCount]; - Arrays.fill(updateCounts, Statement.SUCCESS_NO_INFO); - var code = getErrorCode(mongoBulkWriteException); - return new BatchUpdateException( - EXCEPTION_MESSAGE_FAILED_TO_EXECUTE_OPERATION, null, code, updateCounts, mongoBulkWriteException); - } - - private static int getErrorCode(final MongoBulkWriteException mongoBulkWriteException) { - var writeErrors = mongoBulkWriteException.getWriteErrors(); - // Since we are executing an ordered bulk write, there will be at most one BulkWriteError. - return writeErrors.isEmpty() ? 0 : writeErrors.get(0).getCode(); - } - private static void convertToWriteModels( CommandType commandType, BsonDocument command, Collection> writeModels) throws SQLFeatureNotSupportedException { @@ -443,6 +409,143 @@ static int getUpdateCount(CommandType commandType, BulkWriteResult bulkWriteResu }; } + private static SQLException handleException(RuntimeException exception, + CommandType commandType, + ExecutionType executionType) { + int errorCode = getErrorCode(exception); + return switch (executionType) { + case BATCH -> handleBatchException(exception, commandType, errorCode); + case QUERY, UPDATE -> { + if (exception instanceof MongoException mongoException) { + Exception handledException = handleMongoException(mongoException, errorCode); + yield toSqlException( + errorCode, + handledException); + } + yield toSqlException(DEFAULT_ERROR_CODE, exception); + } + }; + } + + private static SQLException handleBatchException(RuntimeException exception, + CommandType commandType, + int errorCode) { + if (exception instanceof MongoException mongo) { + Exception cause = handleMongoException(mongo, errorCode); + if (exception instanceof MongoBulkWriteException bulkWriteException) { + return createBatchUpdateException( + cause, + bulkWriteException.getWriteResult(), + errorCode, + commandType); + } + return toBatchUpdateException(errorCode, cause); + } + return toBatchUpdateException(DEFAULT_ERROR_CODE, exception); + } + + private static int getErrorCode(final RuntimeException runtimeException) { + if (runtimeException instanceof MongoBulkWriteException bulk) { + return getErrorCode(bulk); + } + if (runtimeException instanceof MongoException mongoException) { + return max(DEFAULT_ERROR_CODE, mongoException.getCode()); + } + return DEFAULT_ERROR_CODE; + } + + private static SQLTransientException toSqlTransientException(final int errorCode, Exception cause) { + return withCause(new SQLTransientException("Transient exception occurred", null, errorCode), cause); + } + + private static SQLException toSqlException(final int errorCode, final Exception exception) { + if (exception instanceof SQLException sqlException) { + return sqlException; + } + return new SQLException(EXCEPTION_MESSAGE_OPERATION_FAILED, null, errorCode, exception); + } + + private static Exception handleMongoException(final MongoException exceptionToHandle, + final int errorCode) { + Exception exception; + if (isTimeoutException(exceptionToHandle)) { + exception = new SQLTimeoutException(EXCEPTION_MESSAGE_TIMEOUT, + null, + errorCode, + exceptionToHandle); + } else { + exception = handleByErrorCode(exceptionToHandle, errorCode); + } + if (exceptionToHandle.hasErrorLabel(MongoException.TRANSIENT_TRANSACTION_ERROR_LABEL)) { + return toSqlTransientException(errorCode, exception); + } + return exception; + } + + private static SQLException toBatchUpdateException(final int errorCode, final Exception exception) { + return withCause(new BatchUpdateException( + EXCEPTION_MESSAGE_BATCH_FAILED, + null, + errorCode, + EMPTY_BATCH_RESULT), + exception); + } + + private static T withCause(T exception, final Exception cause) { + exception.initCause(cause); + if (exception instanceof SQLException sqlException) { + sqlException.setNextException(sqlException); + } + return exception; + } + + private static Exception handleByErrorCode(final MongoException mongoException, + int errorCode) { + ErrorCategory errorCategory = ErrorCategory.fromErrorCode(errorCode); + return switch (errorCategory) { + case DUPLICATE_KEY -> new SQLIntegrityConstraintViolationException( + EXCEPTION_MESSAGE_OPERATION_FAILED, + null, + errorCode, + mongoException); + case EXECUTION_TIMEOUT -> new SQLTimeoutException( + EXCEPTION_MESSAGE_TIMEOUT, + null, + errorCode, + mongoException); + case UNCATEGORIZED -> mongoException; + }; + } + + private static boolean isTimeoutException(final MongoException exception) { + return exception instanceof MongoSocketReadTimeoutException + || exception instanceof MongoSocketWriteTimeoutException + || exception instanceof MongoTimeoutException + || exception instanceof MongoExecutionTimeoutException; + } + + private static BatchUpdateException createBatchUpdateException( + Exception cause, + BulkWriteResult bulkWriteResult, + int errorCode, + CommandType commandType) { + var updateCount = getUpdateCount(commandType, bulkWriteResult); + var updateCounts = new int[updateCount]; + Arrays.fill(updateCounts, Statement.SUCCESS_NO_INFO); + return withCause(new BatchUpdateException( + EXCEPTION_MESSAGE_BATCH_FAILED, + null, + errorCode, + updateCounts), + cause); + } + + private static int getErrorCode(final MongoBulkWriteException mongoBulkWriteException) { + var writeErrors = mongoBulkWriteException.getWriteErrors(); + // Since we are executing an ordered bulk write, there will be at most one BulkWriteError. + return writeErrors.isEmpty() ? DEFAULT_ERROR_CODE : writeErrors.get(DEFAULT_ERROR_CODE).getCode(); + } + enum CommandType { INSERT("insert"), UPDATE("update"), diff --git a/src/test/java/com/mongodb/hibernate/jdbc/MongoPreparedStatementTests.java b/src/test/java/com/mongodb/hibernate/jdbc/MongoPreparedStatementTests.java index 0154eaac..1ff3eaf4 100644 --- a/src/test/java/com/mongodb/hibernate/jdbc/MongoPreparedStatementTests.java +++ b/src/test/java/com/mongodb/hibernate/jdbc/MongoPreparedStatementTests.java @@ -16,27 +16,6 @@ package com.mongodb.hibernate.jdbc; -import static java.sql.Statement.SUCCESS_NO_INFO; -import static java.util.Collections.emptyList; -import static java.util.Collections.emptySet; -import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.api.Assertions.assertThatExceptionOfType; -import static org.junit.jupiter.api.Assertions.assertAll; -import static org.junit.jupiter.api.Assertions.assertEquals; -import static org.junit.jupiter.api.Assertions.assertFalse; -import static org.junit.jupiter.api.Assertions.assertInstanceOf; -import static org.junit.jupiter.api.Assertions.assertNull; -import static org.junit.jupiter.api.Assertions.assertThrows; -import static org.junit.jupiter.api.Assertions.assertTrue; -import static org.junit.jupiter.api.Named.named; -import static org.mockito.ArgumentMatchers.anyList; -import static org.mockito.ArgumentMatchers.anyString; -import static org.mockito.ArgumentMatchers.eq; -import static org.mockito.ArgumentMatchers.same; -import static org.mockito.Mockito.doReturn; -import static org.mockito.Mockito.doThrow; -import static org.mockito.Mockito.verify; - import com.mongodb.MongoBulkWriteException; import com.mongodb.MongoException; import com.mongodb.MongoExecutionTimeoutException; @@ -56,21 +35,7 @@ import com.mongodb.client.model.InsertOneModel; import com.mongodb.client.model.WriteModel; import com.mongodb.hibernate.internal.type.ObjectIdJdbcType; -import java.math.BigDecimal; -import java.sql.Array; -import java.sql.BatchUpdateException; -import java.sql.Date; -import java.sql.ResultSet; -import java.sql.SQLException; -import java.sql.SQLSyntaxErrorException; -import java.sql.SQLTimeoutException; -import java.sql.Time; -import java.sql.Timestamp; -import java.sql.Types; -import java.util.Calendar; -import java.util.List; -import java.util.function.Consumer; -import java.util.stream.Stream; +import org.assertj.core.api.ThrowingConsumer; import org.bson.BsonArray; import org.bson.BsonBoolean; import org.bson.BsonDocument; @@ -93,6 +58,46 @@ import org.mockito.Mockito; import org.mockito.junit.jupiter.MockitoExtension; +import java.math.BigDecimal; +import java.sql.Array; +import java.sql.BatchUpdateException; +import java.sql.Date; +import java.sql.ResultSet; +import java.sql.SQLException; +import java.sql.SQLIntegrityConstraintViolationException; +import java.sql.SQLSyntaxErrorException; +import java.sql.SQLTimeoutException; +import java.sql.SQLTransientException; +import java.sql.Time; +import java.sql.Timestamp; +import java.sql.Types; +import java.util.Calendar; +import java.util.List; +import java.util.function.Consumer; +import java.util.stream.Stream; + +import static java.lang.Math.max; +import static java.sql.Statement.SUCCESS_NO_INFO; +import static java.util.Collections.emptyList; +import static java.util.Collections.emptySet; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatExceptionOfType; +import static org.junit.jupiter.api.Assertions.assertAll; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertInstanceOf; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.junit.jupiter.api.Named.named; +import static org.mockito.ArgumentMatchers.anyList; +import static org.mockito.ArgumentMatchers.anyString; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.ArgumentMatchers.same; +import static org.mockito.Mockito.doReturn; +import static org.mockito.Mockito.doThrow; +import static org.mockito.Mockito.verify; + @ExtendWith(MockitoExtension.class) class MongoPreparedStatementTests { @@ -237,60 +242,345 @@ void beforeEach() { doReturn(mongoCollection).when(mongoDatabase).getCollection(anyString(), eq(BsonDocument.class)); } - private static Stream exceptions() { + private static Stream timeoutExceptions() { var dummyCause = new RuntimeException(); return Stream.of( - Arguments.of(new MongoException(DUMMY_EXCEPTION_MESSAGE), SQLException.class), - Arguments.of(new RuntimeException(DUMMY_EXCEPTION_MESSAGE), SQLException.class), - Arguments.of( - new MongoExecutionTimeoutException(DUMMY_EXCEPTION_MESSAGE), SQLTimeoutException.class), - Arguments.of( - new MongoSocketReadTimeoutException( - DUMMY_EXCEPTION_MESSAGE, DUMMY_SERVER_ADDRESS, dummyCause), - SQLTimeoutException.class), - Arguments.of( - new MongoSocketWriteTimeoutException( - DUMMY_EXCEPTION_MESSAGE, DUMMY_SERVER_ADDRESS, dummyCause), - SQLTimeoutException.class), - Arguments.of(new MongoTimeoutException(DUMMY_EXCEPTION_MESSAGE), SQLTimeoutException.class), - Arguments.of( - new MongoOperationTimeoutException(DUMMY_EXCEPTION_MESSAGE), SQLTimeoutException.class)); + new MongoExecutionTimeoutException(1, DUMMY_EXCEPTION_MESSAGE), + new MongoSocketReadTimeoutException(DUMMY_EXCEPTION_MESSAGE, DUMMY_SERVER_ADDRESS, dummyCause), + new MongoSocketWriteTimeoutException( + DUMMY_EXCEPTION_MESSAGE, DUMMY_SERVER_ADDRESS, dummyCause), + new MongoTimeoutException(DUMMY_EXCEPTION_MESSAGE), + new MongoOperationTimeoutException(DUMMY_EXCEPTION_MESSAGE), + new MongoException(50, DUMMY_EXCEPTION_MESSAGE) // 50 is a timeout error code + ); } - @ParameterizedTest(name = "test executeQuery throws SQLException. Parameters: exception={0}, expectedType={1}") - @MethodSource("exceptions") - void testExecuteQueryThrowsSqlException(Exception exceptionToThrow, Class expectedType) - throws SQLException { - doThrow(exceptionToThrow).when(mongoCollection).aggregate(eq(clientSession), anyList()); + private static Stream transientTimeoutExceptions() { + return timeoutExceptions().peek(mongoException -> mongoException.addLabel(MongoException.TRANSIENT_TRANSACTION_ERROR_LABEL)); + } + + private static Stream constraintViolationExceptions() { + return Stream.of( + new MongoException(11000, "Duplicate key error"), + new MongoException(11001, "Duplicate key error"), + new MongoException(12582, "Duplicate key error") + ); + } + + private static Stream genericMongoExceptions() { + return Stream.of( + new MongoException(-3, DUMMY_EXCEPTION_MESSAGE), + new MongoException(5000, DUMMY_EXCEPTION_MESSAGE) + ); + } + + private static Stream genericTransientMongoExceptions() { + return Stream.of( + new MongoException(-3, DUMMY_EXCEPTION_MESSAGE), + new MongoException(5000, DUMMY_EXCEPTION_MESSAGE) + ).peek(mongoException -> mongoException.addLabel(MongoException.TRANSIENT_TRANSACTION_ERROR_LABEL)); + } + + @ParameterizedTest(name = "test executeBatch throws SQLException. Exception: {0}") + @MethodSource("genericMongoExceptions") + void testExecuteBatchMongoException(MongoException mongoException) throws SQLException { + int expectedErrorCode = max(0, mongoException.getCode()); + doThrow(mongoException).when(mongoCollection).bulkWrite(eq(clientSession), anyList()); + assertExecuteBatchThrowsSqlException(batchUpdateException -> { + assertAll( + () -> assertEquals(mongoException, batchUpdateException.getCause()), + () -> assertEquals(expectedErrorCode, batchUpdateException.getErrorCode()), + () -> assertUpdateCounts(batchUpdateException.getUpdateCounts(), 0), + () -> assertNull(batchUpdateException.getSQLState()) + ); + }); + } + + @ParameterizedTest(name = "test executeBatch throws SQLException. Exception: {0}") + @MethodSource("genericMongoExceptions") + void testExecuteUpdateMongoException(MongoException mongoException) throws SQLException { + doThrow(mongoException).when(mongoCollection).bulkWrite(eq(clientSession), anyList()); + assertExecuteUpdateThrowsSqlException( + sqlException -> assertGenericMongoException(mongoException, sqlException)); + } - assertExecuteThrowsSqlException( - MQL_ITEMS_AGGREGATE, MongoPreparedStatement::executeQuery, exceptionToThrow, expectedType); + @ParameterizedTest(name = "test executeBatch throws SQLException. Exception: {0}") + @MethodSource("genericMongoExceptions") + void testExecuteQueryMongoException(MongoException mongoException) throws SQLException { + doThrow(mongoException).when(mongoCollection).aggregate(eq(clientSession), anyList()); + assertExecuteQueryThrowsSqlException( + sqlException -> assertGenericMongoException(mongoException, sqlException)); } - @ParameterizedTest(name = "test executeUpdate throws SQLException. Parameters: exception={0}, expectedType={1}") - @MethodSource("exceptions") - void testExecuteUpdateThrowsSqlException(Exception exceptionToThrow, Class expectedType) + @ParameterizedTest(name = "test executeBatch throws SQLException. Exception: {0}") + @MethodSource("genericTransientMongoExceptions") + void testExecuteBatchTransientMongoException(MongoException mongoException) throws SQLException { + int expectedErrorCode = max(0, mongoException.getCode()); + doThrow(mongoException).when(mongoCollection).bulkWrite(eq(clientSession), anyList()); + assertExecuteBatchThrowsSqlException(batchUpdateException -> { + assertAll( + () -> { + SQLTransientException sqlTransientException = + assertInstanceOf(SQLTransientException.class, batchUpdateException.getCause()); + assertEquals(mongoException, sqlTransientException.getCause()); + }, + () -> assertEquals(expectedErrorCode, batchUpdateException.getErrorCode()), + () -> assertUpdateCounts(batchUpdateException.getUpdateCounts(), 0), + () -> assertNull(batchUpdateException.getSQLState()) + ); + }); + } + + @ParameterizedTest(name = "test executeUpdate throws SQLException. Exception: {0}") + @MethodSource("genericTransientMongoExceptions") + void testExecuteUpdateTransientMongoException(MongoException mongoException) throws SQLException { + doThrow(mongoException).when(mongoCollection).bulkWrite(eq(clientSession), anyList()); + assertExecuteUpdateThrowsSqlException(sqlException -> { + assertGenericTransientMongoException(mongoException, sqlException); + }); + } + + @ParameterizedTest(name = "test executeQuery throws SQLException. Exception: {0}") + @MethodSource("genericTransientMongoExceptions") + void testExecuteQueryTransientMongoException(MongoException mongoException) throws SQLException { + doThrow(mongoException).when(mongoCollection).aggregate(eq(clientSession), anyList()); + assertExecuteQueryThrowsSqlException(sqlException -> { + assertGenericTransientMongoException(mongoException, sqlException); + }); + } + + private static void assertGenericTransientMongoException(final MongoException mongoException, final SQLException sqlException) { + int expectedErrorCode = max(0, mongoException.getCode()); + assertAll( + () -> assertEquals(expectedErrorCode, sqlException.getErrorCode()), + () -> { + SQLTransientException sqlTransientException = + assertInstanceOf(SQLTransientException.class, sqlException); + assertEquals(expectedErrorCode, sqlTransientException.getErrorCode()); + assertEquals(mongoException, sqlTransientException.getCause()); + }, + () -> assertNull(sqlException.getSQLState())); + } + + private static void assertGenericMongoException(final MongoException mongoException, final SQLException sqlException) { + int expectedErrorCode = max(0, mongoException.getCode()); + assertAll( + () -> assertThat((Throwable) sqlException).isExactlyInstanceOf(SQLException.class), + () -> assertEquals(expectedErrorCode, sqlException.getErrorCode()), + () -> assertEquals(mongoException, sqlException.getCause()), + () -> assertNull(sqlException.getSQLState())); + } + + @ParameterizedTest(name = "test executeUpdate throws SQLException. Parameters: exception={0}, expectedSqlExceptionTypeCause={1}") + @MethodSource("timeoutExceptions") + void testExecuteUpdateTimeoutException(MongoException mongoTimeoutException) throws SQLException { - doThrow(exceptionToThrow).when(mongoCollection).bulkWrite(eq(clientSession), anyList()); + doThrow(mongoTimeoutException).when(mongoCollection).bulkWrite(eq(clientSession), anyList()); + assertExecuteUpdateThrowsSqlException( + sqlException -> + assertTimeoutException(mongoTimeoutException, sqlException)); + } + + @ParameterizedTest(name = "test executeUpdate throws SQLException. Parameters: exception={0}, expectedSqlExceptionTypeCause={1}") + @MethodSource("timeoutExceptions") + void testExecuteQueryTimeoutException(MongoException mongoTimeoutException) + throws SQLException { + doThrow(mongoTimeoutException).when(mongoCollection).aggregate(eq(clientSession), anyList()); + assertExecuteQueryThrowsSqlException( + sqlException -> + assertTimeoutException(mongoTimeoutException, sqlException)); + } - assertExecuteThrowsSqlException( - MQL_ITEMS_INSERT, MongoPreparedStatement::executeUpdate, exceptionToThrow, expectedType); + private static void assertTimeoutException(final MongoException mongoTimeoutException, final SQLException sqlException) { + int expectedErrorCode = max(0, mongoTimeoutException.getCode()); + assertAll( + () -> assertEquals(expectedErrorCode, sqlException.getErrorCode()), + () -> { + SQLTimeoutException sqlTimeoutException = assertInstanceOf(SQLTimeoutException.class, sqlException); + assertEquals(expectedErrorCode, sqlTimeoutException.getErrorCode()); + assertEquals(mongoTimeoutException, sqlTimeoutException.getCause()); + }, + () -> assertNull(sqlException.getSQLState())); } - @ParameterizedTest(name = "test executeUpdate throws SQLException. Parameters: exception={0}, expectedType={1}") - @MethodSource("exceptions") - void testExecuteBatchThrowsSqlException(Exception exceptionToThrow, Class expectedType) + @ParameterizedTest(name = "test executeUpdate throws SQLException. Parameters: exception={0}, expectedSqlExceptionTypeCause={1}") + @MethodSource("transientTimeoutExceptions") + void testExecuteUpdateTransientTimeoutException(MongoException mongoTransientTimeoutException) throws SQLException { - doThrow(exceptionToThrow).when(mongoCollection).bulkWrite(eq(clientSession), anyList()); + doThrow(mongoTransientTimeoutException).when(mongoCollection).bulkWrite(eq(clientSession), anyList()); + assertExecuteUpdateThrowsSqlException( + sqlException -> + assertTransientTimeoutException(mongoTransientTimeoutException, sqlException)); + } + + @ParameterizedTest(name = "test executeUpdate throws SQLException. Parameters: exception={0}, expectedSqlExceptionTypeCause={1}") + @MethodSource("transientTimeoutExceptions") + void testExecuteQueryTransientTimeoutException(MongoException mongoTransientTimeoutException) + throws SQLException { + doThrow(mongoTransientTimeoutException).when(mongoCollection).aggregate(eq(clientSession), anyList()); + assertExecuteQueryThrowsSqlException( + sqlException -> + assertTransientTimeoutException(mongoTransientTimeoutException, sqlException)); + } - assertExecuteThrowsSqlException( - MQL_ITEMS_INSERT, - mongoPreparedStatement -> { - mongoPreparedStatement.addBatch(); - mongoPreparedStatement.executeBatch(); + private static void assertTransientTimeoutException(final MongoException mongoTransientTimeoutException, + final SQLException sqlException) { + int expectedErrorCode = max(0, mongoTransientTimeoutException.getCode()); + assertAll( + () -> assertEquals(expectedErrorCode, sqlException.getErrorCode()), + () -> { + SQLTransientException sqlTransientException = + assertInstanceOf(SQLTransientException.class, sqlException); + assertEquals(expectedErrorCode, sqlTransientException.getErrorCode()); + SQLTimeoutException sqlTimeoutException = + assertInstanceOf(SQLTimeoutException.class, sqlTransientException.getCause()); + assertEquals(expectedErrorCode, sqlTimeoutException.getErrorCode()); + assertEquals(mongoTransientTimeoutException, sqlTimeoutException.getCause()); }, - exceptionToThrow, - expectedType); + () -> assertNull(sqlException.getSQLState())); + } + + @ParameterizedTest(name = "test executeUpdate constraint violation. Exception code={0}") + @MethodSource("constraintViolationExceptions") + void testExecuteUpdateConstraintViolationException(MongoException mongoException) throws SQLException { + int expectedErrorCode = mongoException.getCode(); + doThrow(mongoException).when(mongoCollection).bulkWrite(eq(clientSession), anyList()); + assertExecuteUpdateThrowsSqlException( + sqlException -> { + assertConstraintViolationException(mongoException, sqlException, expectedErrorCode); + }); + } + + @ParameterizedTest(name = "test executeUpdate constraint violation. Exception code={0}") + @MethodSource("constraintViolationExceptions") + void testExecuteQueryConstraintViolationException(MongoException mongoException) throws SQLException { + int expectedErrorCode = mongoException.getCode(); + doThrow(mongoException).when(mongoCollection).aggregate(eq(clientSession), anyList()); + assertExecuteQueryThrowsSqlException( + sqlException -> { + assertConstraintViolationException(mongoException, sqlException, expectedErrorCode); + }); + } + + private static void assertConstraintViolationException(final MongoException mongoException, + final SQLException sqlException, + final int expectedErrorCode) { + assertAll( + () -> assertEquals(expectedErrorCode, sqlException.getErrorCode()), + () -> { + SQLIntegrityConstraintViolationException sqlIntegrityConstraintViolationException = + assertInstanceOf(SQLIntegrityConstraintViolationException.class, sqlException); + assertEquals(expectedErrorCode, sqlIntegrityConstraintViolationException.getErrorCode()); + assertEquals(mongoException, sqlIntegrityConstraintViolationException.getCause()); + }, + () -> assertNull(sqlException.getSQLState())); + } + + + @ParameterizedTest(name = "test executeUpdate throws SQLException. Parameters: exception={0}, expectedSqlExceptionTypeCause={1}") + @MethodSource("timeoutExceptions") + void testExecuteBatchTimeoutException(MongoException mongoTimeoutException) + throws SQLException { + doThrow(mongoTimeoutException).when(mongoCollection).bulkWrite(eq(clientSession), anyList()); + assertExecuteBatchThrowsSqlException( + batchUpdateException -> { + int expectedErrorCode = max(0, mongoTimeoutException.getCode()); + assertAll( + () -> assertEquals(expectedErrorCode, batchUpdateException.getErrorCode()), + () -> { + SQLTimeoutException sqlTimeoutException = + assertInstanceOf(SQLTimeoutException.class, batchUpdateException.getCause()); + assertEquals(expectedErrorCode, sqlTimeoutException.getErrorCode()); + assertEquals(mongoTimeoutException, sqlTimeoutException.getCause()); + }, + () -> assertUpdateCounts(batchUpdateException.getUpdateCounts(), 0), + () -> assertNull(batchUpdateException.getSQLState())); + }); + } + + + @ParameterizedTest(name = "test executeUpdate throws SQLException. Parameters: exception={0}, expectedSqlExceptionTypeCause={1}") + @MethodSource("transientTimeoutExceptions") + void testExecuteBatchTransientTimeoutException(MongoException mongoTimeoutException) + throws SQLException { + doThrow(mongoTimeoutException).when(mongoCollection).bulkWrite(eq(clientSession), anyList()); + assertExecuteBatchThrowsSqlException( + batchUpdateException -> { + int expectedErrorCode = max(0, mongoTimeoutException.getCode()); + assertAll( + () -> assertEquals(expectedErrorCode, batchUpdateException.getErrorCode()), + () -> { + SQLTransientException sqlTransientException = + assertInstanceOf(SQLTransientException.class, batchUpdateException.getCause()); + assertEquals(expectedErrorCode, sqlTransientException.getErrorCode()); + SQLTimeoutException sqlTimeoutException = + assertInstanceOf(SQLTimeoutException.class, sqlTransientException.getCause()); + assertEquals(expectedErrorCode, sqlTimeoutException.getErrorCode()); + assertEquals(mongoTimeoutException, sqlTimeoutException.getCause()); + }, + () -> assertUpdateCounts(batchUpdateException.getUpdateCounts(), 0), + () -> assertNull(batchUpdateException.getSQLState())); + }); + } + + @ParameterizedTest(name = "test executeBatch constraint violation. Exception code={0}") + @MethodSource("constraintViolationExceptions") + void testExecuteBatchConstraintViolationException(MongoException mongoException) throws SQLException { + int expectedErrorCode = mongoException.getCode(); + doThrow(mongoException).when(mongoCollection).bulkWrite(eq(clientSession), anyList()); + assertExecuteBatchThrowsSqlException( + batchUpdateException -> { + assertAll( + () -> assertEquals(expectedErrorCode, batchUpdateException.getErrorCode()), + () -> { + SQLIntegrityConstraintViolationException sqlIntegrityConstraintViolationException = + assertInstanceOf(SQLIntegrityConstraintViolationException.class, + batchUpdateException.getCause()); + assertEquals(expectedErrorCode, sqlIntegrityConstraintViolationException.getErrorCode()); + assertEquals(mongoException, sqlIntegrityConstraintViolationException.getCause()); + }, + () -> assertUpdateCounts(batchUpdateException.getUpdateCounts(), 0), + () -> assertNull(batchUpdateException.getSQLState())); + }); + } + + @Test + void testExecuteBatchRuntimeExceptionCause() + throws SQLException { + RuntimeException runtimeException = new RuntimeException(); + doThrow(runtimeException).when(mongoCollection).bulkWrite(eq(clientSession), anyList()); + assertExecuteBatchThrowsSqlException(batchUpdateException -> { + assertAll( + () -> assertEquals(runtimeException, batchUpdateException.getCause()), + () -> assertEquals(0, batchUpdateException.getErrorCode()), + () -> assertUpdateCounts(batchUpdateException.getUpdateCounts(), 0), + () -> assertNull(batchUpdateException.getSQLState()) + ); + }); + } + + @Test + void testExecuteUpdateRuntimeExceptionCause() + throws SQLException { + RuntimeException runtimeException = new RuntimeException(); + doThrow(runtimeException).when(mongoCollection).bulkWrite(eq(clientSession), anyList()); + assertExecuteUpdateThrowsSqlException(sqlException -> assertGenericException(sqlException, runtimeException)); + } + + @Test + void testExecuteQueryRuntimeExceptionCause() + throws SQLException { + RuntimeException runtimeException = new RuntimeException(); + doThrow(runtimeException).when(mongoCollection).aggregate(eq(clientSession), anyList()); + assertExecuteQueryThrowsSqlException(sqlException -> assertGenericException(sqlException, runtimeException)); + } + + private static void assertGenericException(final SQLException sqlException, RuntimeException cause) { + assertAll( + () -> assertThat((Throwable) sqlException).isExactlyInstanceOf(SQLException.class), + () -> assertEquals(cause, sqlException.getCause()), + () -> assertEquals(0, sqlException.getErrorCode()), + () -> assertNull(sqlException.getSQLState())); } private static Stream bulkWriteExceptionsForExecuteUpdate() { @@ -307,7 +597,7 @@ private static Stream bulkWriteExceptionsForExecuteUpdate() { name = "test executeUpdate throws SQLException when MongoBulkWriteException occurs." + " Parameters: commandName={0}, exception={1}") @MethodSource("bulkWriteExceptionsForExecuteUpdate") - void testExecuteUpdateThrowsSqlExceptionWhenMongoBulkWriteExceptionOccurs( + void testExecuteUpdateMongoBulkWriteException( String mql, MongoBulkWriteException mongoBulkWriteException) throws SQLException { doThrow(mongoBulkWriteException).when(mongoCollection).bulkWrite(eq(clientSession), anyList()); Integer vendorCodeError = getVendorCodeError(mongoBulkWriteException); @@ -316,11 +606,9 @@ void testExecuteUpdateThrowsSqlExceptionWhenMongoBulkWriteExceptionOccurs( assertThatExceptionOfType(SQLException.class) .isThrownBy(mongoPreparedStatement::executeUpdate) .withCause(mongoBulkWriteException) - .satisfies(sqlException -> { - assertAll( - () -> assertNull(sqlException.getSQLState()), - () -> assertEquals(vendorCodeError, sqlException.getErrorCode())); - }); + .satisfies(sqlException -> + assertAll(() -> assertNull(sqlException.getSQLState()), + () -> assertEquals(vendorCodeError, sqlException.getErrorCode()))); } } @@ -356,7 +644,7 @@ private static Stream bulkWriteExceptionsForExecuteBatch() { name = "test executeBatch throws BatchUpdateException when MongoBulkWriteException occurs." + " Parameters: commandName={0}, exception={1}") @MethodSource("bulkWriteExceptionsForExecuteBatch") - void testExecuteBatchThrowsBatchUpdateExceptionWhenMongoBulkWriteExceptionOccurs( + void testExecuteBatchMongoBulkWriteException( String mql, MongoBulkWriteException mongoBulkWriteException, int expectedUpdateCountLength) throws SQLException { doThrow(mongoBulkWriteException).when(mongoCollection).bulkWrite(eq(clientSession), anyList()); @@ -377,21 +665,34 @@ void testExecuteBatchThrowsBatchUpdateExceptionWhenMongoBulkWriteExceptionOccurs } } - private void assertExecuteThrowsSqlException( - String mql, - SqlConsumer executeConsumer, - Exception expectedCause, - Class expectedExceptionType) + private void assertExecuteBatchThrowsSqlException( + ThrowingConsumer asserter) throws SQLException { - try (MongoPreparedStatement mongoPreparedStatement = createMongoPreparedStatement(mql)) { - assertThatExceptionOfType(expectedExceptionType) - .isThrownBy(() -> executeConsumer.accept(mongoPreparedStatement)) - .withCause(expectedCause) - .satisfies(sqlException -> { - assertAll( - () -> assertNull(sqlException.getSQLState()), - () -> assertEquals(0, sqlException.getErrorCode())); - }); + try (MongoPreparedStatement mongoPreparedStatement = createMongoPreparedStatement(MQL_ITEMS_INSERT)) { + mongoPreparedStatement.addBatch(); + assertThatExceptionOfType(BatchUpdateException.class) + .isThrownBy(mongoPreparedStatement::executeBatch) + .satisfies(asserter); + } + } + + private void assertExecuteUpdateThrowsSqlException( + ThrowingConsumer asserter) + throws SQLException { + try (MongoPreparedStatement mongoPreparedStatement = createMongoPreparedStatement(MQL_ITEMS_INSERT)) { + assertThatExceptionOfType(SQLException.class) + .isThrownBy(mongoPreparedStatement::executeUpdate) + .satisfies(asserter); + } + } + + private void assertExecuteQueryThrowsSqlException( + ThrowingConsumer asserter) + throws SQLException { + try (MongoPreparedStatement mongoPreparedStatement = createMongoPreparedStatement(MQL_ITEMS_AGGREGATE)) { + assertThatExceptionOfType(SQLException.class) + .isThrownBy(mongoPreparedStatement::executeQuery) + .satisfies(asserter); } } From 330de0fccbb85fab5505520dd5c4e4dd06d49641 Mon Sep 17 00:00:00 2001 From: "slav.babanin" Date: Mon, 6 Oct 2025 20:26:57 -0700 Subject: [PATCH 08/37] Apply spotless. --- .../hibernate/jdbc/MongoStatement.java | 95 ++---- .../jdbc/MongoPreparedStatementTests.java | 321 +++++++++--------- 2 files changed, 189 insertions(+), 227 deletions(-) diff --git a/src/main/java/com/mongodb/hibernate/jdbc/MongoStatement.java b/src/main/java/com/mongodb/hibernate/jdbc/MongoStatement.java index 7f08e9ec..2c44d409 100644 --- a/src/main/java/com/mongodb/hibernate/jdbc/MongoStatement.java +++ b/src/main/java/com/mongodb/hibernate/jdbc/MongoStatement.java @@ -16,6 +16,15 @@ package com.mongodb.hibernate.jdbc; +import static com.mongodb.assertions.Assertions.fail; +import static com.mongodb.hibernate.internal.MongoAssertions.assertNotNull; +import static com.mongodb.hibernate.internal.MongoConstants.ID_FIELD_NAME; +import static com.mongodb.hibernate.internal.VisibleForTesting.AccessModifier.PRIVATE; +import static java.lang.Math.max; +import static java.lang.String.format; +import static java.util.Collections.singletonList; +import static java.util.stream.Collectors.toCollection; + import com.mongodb.ErrorCategory; import com.mongodb.MongoBulkWriteException; import com.mongodb.MongoException; @@ -35,10 +44,6 @@ import com.mongodb.client.model.WriteModel; import com.mongodb.hibernate.internal.FeatureNotSupportedException; import com.mongodb.hibernate.internal.VisibleForTesting; -import org.bson.BsonDocument; -import org.bson.BsonValue; -import org.jspecify.annotations.Nullable; - import java.sql.BatchUpdateException; import java.sql.Connection; import java.sql.ResultSet; @@ -55,15 +60,9 @@ import java.util.Collection; import java.util.List; import java.util.Map; - -import static com.mongodb.assertions.Assertions.fail; -import static com.mongodb.hibernate.internal.MongoAssertions.assertNotNull; -import static com.mongodb.hibernate.internal.MongoConstants.ID_FIELD_NAME; -import static com.mongodb.hibernate.internal.VisibleForTesting.AccessModifier.PRIVATE; -import static java.lang.Math.max; -import static java.lang.String.format; -import static java.util.Collections.singletonList; -import static java.util.stream.Collectors.toCollection; +import org.bson.BsonDocument; +import org.bson.BsonValue; +import org.jspecify.annotations.Nullable; class MongoStatement implements StatementAdapter { @@ -409,35 +408,27 @@ static int getUpdateCount(CommandType commandType, BulkWriteResult bulkWriteResu }; } - private static SQLException handleException(RuntimeException exception, - CommandType commandType, - ExecutionType executionType) { + private static SQLException handleException( + RuntimeException exception, CommandType commandType, ExecutionType executionType) { int errorCode = getErrorCode(exception); return switch (executionType) { case BATCH -> handleBatchException(exception, commandType, errorCode); case QUERY, UPDATE -> { if (exception instanceof MongoException mongoException) { Exception handledException = handleMongoException(mongoException, errorCode); - yield toSqlException( - errorCode, - handledException); + yield toSqlException(errorCode, handledException); } yield toSqlException(DEFAULT_ERROR_CODE, exception); } }; } - private static SQLException handleBatchException(RuntimeException exception, - CommandType commandType, - int errorCode) { + private static SQLException handleBatchException( + RuntimeException exception, CommandType commandType, int errorCode) { if (exception instanceof MongoException mongo) { Exception cause = handleMongoException(mongo, errorCode); if (exception instanceof MongoBulkWriteException bulkWriteException) { - return createBatchUpdateException( - cause, - bulkWriteException.getWriteResult(), - errorCode, - commandType); + return createBatchUpdateException(cause, bulkWriteException.getWriteResult(), errorCode, commandType); } return toBatchUpdateException(errorCode, cause); } @@ -465,14 +456,10 @@ private static SQLException toSqlException(final int errorCode, final Exception return new SQLException(EXCEPTION_MESSAGE_OPERATION_FAILED, null, errorCode, exception); } - private static Exception handleMongoException(final MongoException exceptionToHandle, - final int errorCode) { + private static Exception handleMongoException(final MongoException exceptionToHandle, final int errorCode) { Exception exception; if (isTimeoutException(exceptionToHandle)) { - exception = new SQLTimeoutException(EXCEPTION_MESSAGE_TIMEOUT, - null, - errorCode, - exceptionToHandle); + exception = new SQLTimeoutException(EXCEPTION_MESSAGE_TIMEOUT, null, errorCode, exceptionToHandle); } else { exception = handleByErrorCode(exceptionToHandle, errorCode); } @@ -483,11 +470,8 @@ private static Exception handleMongoException(final MongoException exceptionToHa } private static SQLException toBatchUpdateException(final int errorCode, final Exception exception) { - return withCause(new BatchUpdateException( - EXCEPTION_MESSAGE_BATCH_FAILED, - null, - errorCode, - EMPTY_BATCH_RESULT), + return withCause( + new BatchUpdateException(EXCEPTION_MESSAGE_BATCH_FAILED, null, errorCode, EMPTY_BATCH_RESULT), exception); } @@ -499,20 +483,14 @@ private static T withCause(T exception, final Exception ca return exception; } - private static Exception handleByErrorCode(final MongoException mongoException, - int errorCode) { + private static Exception handleByErrorCode(final MongoException mongoException, int errorCode) { ErrorCategory errorCategory = ErrorCategory.fromErrorCode(errorCode); return switch (errorCategory) { - case DUPLICATE_KEY -> new SQLIntegrityConstraintViolationException( - EXCEPTION_MESSAGE_OPERATION_FAILED, - null, - errorCode, - mongoException); - case EXECUTION_TIMEOUT -> new SQLTimeoutException( - EXCEPTION_MESSAGE_TIMEOUT, - null, - errorCode, - mongoException); + case DUPLICATE_KEY -> + new SQLIntegrityConstraintViolationException( + EXCEPTION_MESSAGE_OPERATION_FAILED, null, errorCode, mongoException); + case EXECUTION_TIMEOUT -> + new SQLTimeoutException(EXCEPTION_MESSAGE_TIMEOUT, null, errorCode, mongoException); case UNCATEGORIZED -> mongoException; }; } @@ -525,25 +503,20 @@ private static boolean isTimeoutException(final MongoException exception) { } private static BatchUpdateException createBatchUpdateException( - Exception cause, - BulkWriteResult bulkWriteResult, - int errorCode, - CommandType commandType) { + Exception cause, BulkWriteResult bulkWriteResult, int errorCode, CommandType commandType) { var updateCount = getUpdateCount(commandType, bulkWriteResult); var updateCounts = new int[updateCount]; Arrays.fill(updateCounts, Statement.SUCCESS_NO_INFO); - return withCause(new BatchUpdateException( - EXCEPTION_MESSAGE_BATCH_FAILED, - null, - errorCode, - updateCounts), - cause); + return withCause( + new BatchUpdateException(EXCEPTION_MESSAGE_BATCH_FAILED, null, errorCode, updateCounts), cause); } private static int getErrorCode(final MongoBulkWriteException mongoBulkWriteException) { var writeErrors = mongoBulkWriteException.getWriteErrors(); // Since we are executing an ordered bulk write, there will be at most one BulkWriteError. - return writeErrors.isEmpty() ? DEFAULT_ERROR_CODE : writeErrors.get(DEFAULT_ERROR_CODE).getCode(); + return writeErrors.isEmpty() + ? DEFAULT_ERROR_CODE + : writeErrors.get(DEFAULT_ERROR_CODE).getCode(); } enum CommandType { diff --git a/src/test/java/com/mongodb/hibernate/jdbc/MongoPreparedStatementTests.java b/src/test/java/com/mongodb/hibernate/jdbc/MongoPreparedStatementTests.java index 1ff3eaf4..08a0f39e 100644 --- a/src/test/java/com/mongodb/hibernate/jdbc/MongoPreparedStatementTests.java +++ b/src/test/java/com/mongodb/hibernate/jdbc/MongoPreparedStatementTests.java @@ -16,6 +16,28 @@ package com.mongodb.hibernate.jdbc; +import static java.lang.Math.max; +import static java.sql.Statement.SUCCESS_NO_INFO; +import static java.util.Collections.emptyList; +import static java.util.Collections.emptySet; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatExceptionOfType; +import static org.junit.jupiter.api.Assertions.assertAll; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertInstanceOf; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.junit.jupiter.api.Named.named; +import static org.mockito.ArgumentMatchers.anyList; +import static org.mockito.ArgumentMatchers.anyString; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.ArgumentMatchers.same; +import static org.mockito.Mockito.doReturn; +import static org.mockito.Mockito.doThrow; +import static org.mockito.Mockito.verify; + import com.mongodb.MongoBulkWriteException; import com.mongodb.MongoException; import com.mongodb.MongoExecutionTimeoutException; @@ -35,6 +57,23 @@ import com.mongodb.client.model.InsertOneModel; import com.mongodb.client.model.WriteModel; import com.mongodb.hibernate.internal.type.ObjectIdJdbcType; +import java.math.BigDecimal; +import java.sql.Array; +import java.sql.BatchUpdateException; +import java.sql.Date; +import java.sql.ResultSet; +import java.sql.SQLException; +import java.sql.SQLIntegrityConstraintViolationException; +import java.sql.SQLSyntaxErrorException; +import java.sql.SQLTimeoutException; +import java.sql.SQLTransientException; +import java.sql.Time; +import java.sql.Timestamp; +import java.sql.Types; +import java.util.Calendar; +import java.util.List; +import java.util.function.Consumer; +import java.util.stream.Stream; import org.assertj.core.api.ThrowingConsumer; import org.bson.BsonArray; import org.bson.BsonBoolean; @@ -58,46 +97,6 @@ import org.mockito.Mockito; import org.mockito.junit.jupiter.MockitoExtension; -import java.math.BigDecimal; -import java.sql.Array; -import java.sql.BatchUpdateException; -import java.sql.Date; -import java.sql.ResultSet; -import java.sql.SQLException; -import java.sql.SQLIntegrityConstraintViolationException; -import java.sql.SQLSyntaxErrorException; -import java.sql.SQLTimeoutException; -import java.sql.SQLTransientException; -import java.sql.Time; -import java.sql.Timestamp; -import java.sql.Types; -import java.util.Calendar; -import java.util.List; -import java.util.function.Consumer; -import java.util.stream.Stream; - -import static java.lang.Math.max; -import static java.sql.Statement.SUCCESS_NO_INFO; -import static java.util.Collections.emptyList; -import static java.util.Collections.emptySet; -import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.api.Assertions.assertThatExceptionOfType; -import static org.junit.jupiter.api.Assertions.assertAll; -import static org.junit.jupiter.api.Assertions.assertEquals; -import static org.junit.jupiter.api.Assertions.assertFalse; -import static org.junit.jupiter.api.Assertions.assertInstanceOf; -import static org.junit.jupiter.api.Assertions.assertNull; -import static org.junit.jupiter.api.Assertions.assertThrows; -import static org.junit.jupiter.api.Assertions.assertTrue; -import static org.junit.jupiter.api.Named.named; -import static org.mockito.ArgumentMatchers.anyList; -import static org.mockito.ArgumentMatchers.anyString; -import static org.mockito.ArgumentMatchers.eq; -import static org.mockito.ArgumentMatchers.same; -import static org.mockito.Mockito.doReturn; -import static org.mockito.Mockito.doThrow; -import static org.mockito.Mockito.verify; - @ExtendWith(MockitoExtension.class) class MongoPreparedStatementTests { @@ -247,38 +246,35 @@ private static Stream timeoutExceptions() { return Stream.of( new MongoExecutionTimeoutException(1, DUMMY_EXCEPTION_MESSAGE), new MongoSocketReadTimeoutException(DUMMY_EXCEPTION_MESSAGE, DUMMY_SERVER_ADDRESS, dummyCause), - new MongoSocketWriteTimeoutException( - DUMMY_EXCEPTION_MESSAGE, DUMMY_SERVER_ADDRESS, dummyCause), + new MongoSocketWriteTimeoutException(DUMMY_EXCEPTION_MESSAGE, DUMMY_SERVER_ADDRESS, dummyCause), new MongoTimeoutException(DUMMY_EXCEPTION_MESSAGE), new MongoOperationTimeoutException(DUMMY_EXCEPTION_MESSAGE), new MongoException(50, DUMMY_EXCEPTION_MESSAGE) // 50 is a timeout error code - ); + ); } private static Stream transientTimeoutExceptions() { - return timeoutExceptions().peek(mongoException -> mongoException.addLabel(MongoException.TRANSIENT_TRANSACTION_ERROR_LABEL)); + return timeoutExceptions() + .peek(mongoException -> mongoException.addLabel(MongoException.TRANSIENT_TRANSACTION_ERROR_LABEL)); } private static Stream constraintViolationExceptions() { return Stream.of( new MongoException(11000, "Duplicate key error"), new MongoException(11001, "Duplicate key error"), - new MongoException(12582, "Duplicate key error") - ); + new MongoException(12582, "Duplicate key error")); } private static Stream genericMongoExceptions() { return Stream.of( - new MongoException(-3, DUMMY_EXCEPTION_MESSAGE), - new MongoException(5000, DUMMY_EXCEPTION_MESSAGE) - ); + new MongoException(-3, DUMMY_EXCEPTION_MESSAGE), new MongoException(5000, DUMMY_EXCEPTION_MESSAGE)); } private static Stream genericTransientMongoExceptions() { return Stream.of( - new MongoException(-3, DUMMY_EXCEPTION_MESSAGE), - new MongoException(5000, DUMMY_EXCEPTION_MESSAGE) - ).peek(mongoException -> mongoException.addLabel(MongoException.TRANSIENT_TRANSACTION_ERROR_LABEL)); + new MongoException(-3, DUMMY_EXCEPTION_MESSAGE), + new MongoException(5000, DUMMY_EXCEPTION_MESSAGE)) + .peek(mongoException -> mongoException.addLabel(MongoException.TRANSIENT_TRANSACTION_ERROR_LABEL)); } @ParameterizedTest(name = "test executeBatch throws SQLException. Exception: {0}") @@ -291,8 +287,7 @@ void testExecuteBatchMongoException(MongoException mongoException) throws SQLExc () -> assertEquals(mongoException, batchUpdateException.getCause()), () -> assertEquals(expectedErrorCode, batchUpdateException.getErrorCode()), () -> assertUpdateCounts(batchUpdateException.getUpdateCounts(), 0), - () -> assertNull(batchUpdateException.getSQLState()) - ); + () -> assertNull(batchUpdateException.getSQLState())); }); } @@ -326,8 +321,7 @@ void testExecuteBatchTransientMongoException(MongoException mongoException) thro }, () -> assertEquals(expectedErrorCode, batchUpdateException.getErrorCode()), () -> assertUpdateCounts(batchUpdateException.getUpdateCounts(), 0), - () -> assertNull(batchUpdateException.getSQLState()) - ); + () -> assertNull(batchUpdateException.getSQLState())); }); } @@ -349,7 +343,8 @@ void testExecuteQueryTransientMongoException(MongoException mongoException) thro }); } - private static void assertGenericTransientMongoException(final MongoException mongoException, final SQLException sqlException) { + private static void assertGenericTransientMongoException( + final MongoException mongoException, final SQLException sqlException) { int expectedErrorCode = max(0, mongoException.getCode()); assertAll( () -> assertEquals(expectedErrorCode, sqlException.getErrorCode()), @@ -362,7 +357,8 @@ private static void assertGenericTransientMongoException(final MongoException mo () -> assertNull(sqlException.getSQLState())); } - private static void assertGenericMongoException(final MongoException mongoException, final SQLException sqlException) { + private static void assertGenericMongoException( + final MongoException mongoException, final SQLException sqlException) { int expectedErrorCode = max(0, mongoException.getCode()); assertAll( () -> assertThat((Throwable) sqlException).isExactlyInstanceOf(SQLException.class), @@ -371,60 +367,64 @@ private static void assertGenericMongoException(final MongoException mongoExcept () -> assertNull(sqlException.getSQLState())); } - @ParameterizedTest(name = "test executeUpdate throws SQLException. Parameters: exception={0}, expectedSqlExceptionTypeCause={1}") + @ParameterizedTest( + name = + "test executeUpdate throws SQLException. Parameters: exception={0}, expectedSqlExceptionTypeCause={1}") @MethodSource("timeoutExceptions") - void testExecuteUpdateTimeoutException(MongoException mongoTimeoutException) - throws SQLException { + void testExecuteUpdateTimeoutException(MongoException mongoTimeoutException) throws SQLException { doThrow(mongoTimeoutException).when(mongoCollection).bulkWrite(eq(clientSession), anyList()); assertExecuteUpdateThrowsSqlException( - sqlException -> - assertTimeoutException(mongoTimeoutException, sqlException)); + sqlException -> assertTimeoutException(mongoTimeoutException, sqlException)); } - @ParameterizedTest(name = "test executeUpdate throws SQLException. Parameters: exception={0}, expectedSqlExceptionTypeCause={1}") + @ParameterizedTest( + name = + "test executeUpdate throws SQLException. Parameters: exception={0}, expectedSqlExceptionTypeCause={1}") @MethodSource("timeoutExceptions") - void testExecuteQueryTimeoutException(MongoException mongoTimeoutException) - throws SQLException { + void testExecuteQueryTimeoutException(MongoException mongoTimeoutException) throws SQLException { doThrow(mongoTimeoutException).when(mongoCollection).aggregate(eq(clientSession), anyList()); assertExecuteQueryThrowsSqlException( - sqlException -> - assertTimeoutException(mongoTimeoutException, sqlException)); + sqlException -> assertTimeoutException(mongoTimeoutException, sqlException)); } - private static void assertTimeoutException(final MongoException mongoTimeoutException, final SQLException sqlException) { + private static void assertTimeoutException( + final MongoException mongoTimeoutException, final SQLException sqlException) { int expectedErrorCode = max(0, mongoTimeoutException.getCode()); assertAll( () -> assertEquals(expectedErrorCode, sqlException.getErrorCode()), () -> { - SQLTimeoutException sqlTimeoutException = assertInstanceOf(SQLTimeoutException.class, sqlException); + SQLTimeoutException sqlTimeoutException = + assertInstanceOf(SQLTimeoutException.class, sqlException); assertEquals(expectedErrorCode, sqlTimeoutException.getErrorCode()); assertEquals(mongoTimeoutException, sqlTimeoutException.getCause()); }, () -> assertNull(sqlException.getSQLState())); } - @ParameterizedTest(name = "test executeUpdate throws SQLException. Parameters: exception={0}, expectedSqlExceptionTypeCause={1}") + @ParameterizedTest( + name = + "test executeUpdate throws SQLException. Parameters: exception={0}, expectedSqlExceptionTypeCause={1}") @MethodSource("transientTimeoutExceptions") void testExecuteUpdateTransientTimeoutException(MongoException mongoTransientTimeoutException) throws SQLException { doThrow(mongoTransientTimeoutException).when(mongoCollection).bulkWrite(eq(clientSession), anyList()); assertExecuteUpdateThrowsSqlException( - sqlException -> - assertTransientTimeoutException(mongoTransientTimeoutException, sqlException)); + sqlException -> assertTransientTimeoutException(mongoTransientTimeoutException, sqlException)); } - @ParameterizedTest(name = "test executeUpdate throws SQLException. Parameters: exception={0}, expectedSqlExceptionTypeCause={1}") + @ParameterizedTest( + name = + "test executeUpdate throws SQLException. Parameters: exception={0}, expectedSqlExceptionTypeCause={1}") @MethodSource("transientTimeoutExceptions") void testExecuteQueryTransientTimeoutException(MongoException mongoTransientTimeoutException) throws SQLException { doThrow(mongoTransientTimeoutException).when(mongoCollection).aggregate(eq(clientSession), anyList()); assertExecuteQueryThrowsSqlException( - sqlException -> - assertTransientTimeoutException(mongoTransientTimeoutException, sqlException)); + sqlException -> assertTransientTimeoutException(mongoTransientTimeoutException, sqlException)); } - private static void assertTransientTimeoutException(final MongoException mongoTransientTimeoutException, - final SQLException sqlException) { + private static void assertTransientTimeoutException( + final MongoException mongoTransientTimeoutException, final SQLException sqlException) { int expectedErrorCode = max(0, mongoTransientTimeoutException.getCode()); assertAll( () -> assertEquals(expectedErrorCode, sqlException.getErrorCode()), @@ -445,10 +445,9 @@ private static void assertTransientTimeoutException(final MongoException mongoTr void testExecuteUpdateConstraintViolationException(MongoException mongoException) throws SQLException { int expectedErrorCode = mongoException.getCode(); doThrow(mongoException).when(mongoCollection).bulkWrite(eq(clientSession), anyList()); - assertExecuteUpdateThrowsSqlException( - sqlException -> { - assertConstraintViolationException(mongoException, sqlException, expectedErrorCode); - }); + assertExecuteUpdateThrowsSqlException(sqlException -> { + assertConstraintViolationException(mongoException, sqlException, expectedErrorCode); + }); } @ParameterizedTest(name = "test executeUpdate constraint violation. Exception code={0}") @@ -456,15 +455,13 @@ void testExecuteUpdateConstraintViolationException(MongoException mongoException void testExecuteQueryConstraintViolationException(MongoException mongoException) throws SQLException { int expectedErrorCode = mongoException.getCode(); doThrow(mongoException).when(mongoCollection).aggregate(eq(clientSession), anyList()); - assertExecuteQueryThrowsSqlException( - sqlException -> { - assertConstraintViolationException(mongoException, sqlException, expectedErrorCode); - }); + assertExecuteQueryThrowsSqlException(sqlException -> { + assertConstraintViolationException(mongoException, sqlException, expectedErrorCode); + }); } - private static void assertConstraintViolationException(final MongoException mongoException, - final SQLException sqlException, - final int expectedErrorCode) { + private static void assertConstraintViolationException( + final MongoException mongoException, final SQLException sqlException, final int expectedErrorCode) { assertAll( () -> assertEquals(expectedErrorCode, sqlException.getErrorCode()), () -> { @@ -476,51 +473,49 @@ private static void assertConstraintViolationException(final MongoException mong () -> assertNull(sqlException.getSQLState())); } - - @ParameterizedTest(name = "test executeUpdate throws SQLException. Parameters: exception={0}, expectedSqlExceptionTypeCause={1}") + @ParameterizedTest( + name = + "test executeUpdate throws SQLException. Parameters: exception={0}, expectedSqlExceptionTypeCause={1}") @MethodSource("timeoutExceptions") - void testExecuteBatchTimeoutException(MongoException mongoTimeoutException) - throws SQLException { + void testExecuteBatchTimeoutException(MongoException mongoTimeoutException) throws SQLException { doThrow(mongoTimeoutException).when(mongoCollection).bulkWrite(eq(clientSession), anyList()); - assertExecuteBatchThrowsSqlException( - batchUpdateException -> { - int expectedErrorCode = max(0, mongoTimeoutException.getCode()); - assertAll( - () -> assertEquals(expectedErrorCode, batchUpdateException.getErrorCode()), - () -> { - SQLTimeoutException sqlTimeoutException = - assertInstanceOf(SQLTimeoutException.class, batchUpdateException.getCause()); - assertEquals(expectedErrorCode, sqlTimeoutException.getErrorCode()); - assertEquals(mongoTimeoutException, sqlTimeoutException.getCause()); - }, - () -> assertUpdateCounts(batchUpdateException.getUpdateCounts(), 0), - () -> assertNull(batchUpdateException.getSQLState())); - }); - } - - - @ParameterizedTest(name = "test executeUpdate throws SQLException. Parameters: exception={0}, expectedSqlExceptionTypeCause={1}") + assertExecuteBatchThrowsSqlException(batchUpdateException -> { + int expectedErrorCode = max(0, mongoTimeoutException.getCode()); + assertAll( + () -> assertEquals(expectedErrorCode, batchUpdateException.getErrorCode()), + () -> { + SQLTimeoutException sqlTimeoutException = + assertInstanceOf(SQLTimeoutException.class, batchUpdateException.getCause()); + assertEquals(expectedErrorCode, sqlTimeoutException.getErrorCode()); + assertEquals(mongoTimeoutException, sqlTimeoutException.getCause()); + }, + () -> assertUpdateCounts(batchUpdateException.getUpdateCounts(), 0), + () -> assertNull(batchUpdateException.getSQLState())); + }); + } + + @ParameterizedTest( + name = + "test executeUpdate throws SQLException. Parameters: exception={0}, expectedSqlExceptionTypeCause={1}") @MethodSource("transientTimeoutExceptions") - void testExecuteBatchTransientTimeoutException(MongoException mongoTimeoutException) - throws SQLException { + void testExecuteBatchTransientTimeoutException(MongoException mongoTimeoutException) throws SQLException { doThrow(mongoTimeoutException).when(mongoCollection).bulkWrite(eq(clientSession), anyList()); - assertExecuteBatchThrowsSqlException( - batchUpdateException -> { - int expectedErrorCode = max(0, mongoTimeoutException.getCode()); - assertAll( - () -> assertEquals(expectedErrorCode, batchUpdateException.getErrorCode()), - () -> { - SQLTransientException sqlTransientException = - assertInstanceOf(SQLTransientException.class, batchUpdateException.getCause()); - assertEquals(expectedErrorCode, sqlTransientException.getErrorCode()); - SQLTimeoutException sqlTimeoutException = - assertInstanceOf(SQLTimeoutException.class, sqlTransientException.getCause()); - assertEquals(expectedErrorCode, sqlTimeoutException.getErrorCode()); - assertEquals(mongoTimeoutException, sqlTimeoutException.getCause()); - }, - () -> assertUpdateCounts(batchUpdateException.getUpdateCounts(), 0), - () -> assertNull(batchUpdateException.getSQLState())); - }); + assertExecuteBatchThrowsSqlException(batchUpdateException -> { + int expectedErrorCode = max(0, mongoTimeoutException.getCode()); + assertAll( + () -> assertEquals(expectedErrorCode, batchUpdateException.getErrorCode()), + () -> { + SQLTransientException sqlTransientException = + assertInstanceOf(SQLTransientException.class, batchUpdateException.getCause()); + assertEquals(expectedErrorCode, sqlTransientException.getErrorCode()); + SQLTimeoutException sqlTimeoutException = + assertInstanceOf(SQLTimeoutException.class, sqlTransientException.getCause()); + assertEquals(expectedErrorCode, sqlTimeoutException.getErrorCode()); + assertEquals(mongoTimeoutException, sqlTimeoutException.getCause()); + }, + () -> assertUpdateCounts(batchUpdateException.getUpdateCounts(), 0), + () -> assertNull(batchUpdateException.getSQLState())); + }); } @ParameterizedTest(name = "test executeBatch constraint violation. Exception code={0}") @@ -528,25 +523,24 @@ void testExecuteBatchTransientTimeoutException(MongoException mongoTimeoutExcept void testExecuteBatchConstraintViolationException(MongoException mongoException) throws SQLException { int expectedErrorCode = mongoException.getCode(); doThrow(mongoException).when(mongoCollection).bulkWrite(eq(clientSession), anyList()); - assertExecuteBatchThrowsSqlException( - batchUpdateException -> { - assertAll( - () -> assertEquals(expectedErrorCode, batchUpdateException.getErrorCode()), - () -> { - SQLIntegrityConstraintViolationException sqlIntegrityConstraintViolationException = - assertInstanceOf(SQLIntegrityConstraintViolationException.class, - batchUpdateException.getCause()); - assertEquals(expectedErrorCode, sqlIntegrityConstraintViolationException.getErrorCode()); - assertEquals(mongoException, sqlIntegrityConstraintViolationException.getCause()); - }, - () -> assertUpdateCounts(batchUpdateException.getUpdateCounts(), 0), - () -> assertNull(batchUpdateException.getSQLState())); - }); + assertExecuteBatchThrowsSqlException(batchUpdateException -> { + assertAll( + () -> assertEquals(expectedErrorCode, batchUpdateException.getErrorCode()), + () -> { + SQLIntegrityConstraintViolationException sqlIntegrityConstraintViolationException = + assertInstanceOf( + SQLIntegrityConstraintViolationException.class, + batchUpdateException.getCause()); + assertEquals(expectedErrorCode, sqlIntegrityConstraintViolationException.getErrorCode()); + assertEquals(mongoException, sqlIntegrityConstraintViolationException.getCause()); + }, + () -> assertUpdateCounts(batchUpdateException.getUpdateCounts(), 0), + () -> assertNull(batchUpdateException.getSQLState())); + }); } @Test - void testExecuteBatchRuntimeExceptionCause() - throws SQLException { + void testExecuteBatchRuntimeExceptionCause() throws SQLException { RuntimeException runtimeException = new RuntimeException(); doThrow(runtimeException).when(mongoCollection).bulkWrite(eq(clientSession), anyList()); assertExecuteBatchThrowsSqlException(batchUpdateException -> { @@ -554,25 +548,24 @@ void testExecuteBatchRuntimeExceptionCause() () -> assertEquals(runtimeException, batchUpdateException.getCause()), () -> assertEquals(0, batchUpdateException.getErrorCode()), () -> assertUpdateCounts(batchUpdateException.getUpdateCounts(), 0), - () -> assertNull(batchUpdateException.getSQLState()) - ); + () -> assertNull(batchUpdateException.getSQLState())); }); } @Test - void testExecuteUpdateRuntimeExceptionCause() - throws SQLException { + void testExecuteUpdateRuntimeExceptionCause() throws SQLException { RuntimeException runtimeException = new RuntimeException(); doThrow(runtimeException).when(mongoCollection).bulkWrite(eq(clientSession), anyList()); - assertExecuteUpdateThrowsSqlException(sqlException -> assertGenericException(sqlException, runtimeException)); + assertExecuteUpdateThrowsSqlException( + sqlException -> assertGenericException(sqlException, runtimeException)); } @Test - void testExecuteQueryRuntimeExceptionCause() - throws SQLException { + void testExecuteQueryRuntimeExceptionCause() throws SQLException { RuntimeException runtimeException = new RuntimeException(); doThrow(runtimeException).when(mongoCollection).aggregate(eq(clientSession), anyList()); - assertExecuteQueryThrowsSqlException(sqlException -> assertGenericException(sqlException, runtimeException)); + assertExecuteQueryThrowsSqlException( + sqlException -> assertGenericException(sqlException, runtimeException)); } private static void assertGenericException(final SQLException sqlException, RuntimeException cause) { @@ -597,8 +590,8 @@ private static Stream bulkWriteExceptionsForExecuteUpdate() { name = "test executeUpdate throws SQLException when MongoBulkWriteException occurs." + " Parameters: commandName={0}, exception={1}") @MethodSource("bulkWriteExceptionsForExecuteUpdate") - void testExecuteUpdateMongoBulkWriteException( - String mql, MongoBulkWriteException mongoBulkWriteException) throws SQLException { + void testExecuteUpdateMongoBulkWriteException(String mql, MongoBulkWriteException mongoBulkWriteException) + throws SQLException { doThrow(mongoBulkWriteException).when(mongoCollection).bulkWrite(eq(clientSession), anyList()); Integer vendorCodeError = getVendorCodeError(mongoBulkWriteException); @@ -606,9 +599,9 @@ void testExecuteUpdateMongoBulkWriteException( assertThatExceptionOfType(SQLException.class) .isThrownBy(mongoPreparedStatement::executeUpdate) .withCause(mongoBulkWriteException) - .satisfies(sqlException -> - assertAll(() -> assertNull(sqlException.getSQLState()), - () -> assertEquals(vendorCodeError, sqlException.getErrorCode()))); + .satisfies(sqlException -> assertAll( + () -> assertNull(sqlException.getSQLState()), + () -> assertEquals(vendorCodeError, sqlException.getErrorCode()))); } } @@ -665,8 +658,7 @@ void testExecuteBatchMongoBulkWriteException( } } - private void assertExecuteBatchThrowsSqlException( - ThrowingConsumer asserter) + private void assertExecuteBatchThrowsSqlException(ThrowingConsumer asserter) throws SQLException { try (MongoPreparedStatement mongoPreparedStatement = createMongoPreparedStatement(MQL_ITEMS_INSERT)) { mongoPreparedStatement.addBatch(); @@ -676,8 +668,7 @@ private void assertExecuteBatchThrowsSqlException( } } - private void assertExecuteUpdateThrowsSqlException( - ThrowingConsumer asserter) + private void assertExecuteUpdateThrowsSqlException(ThrowingConsumer asserter) throws SQLException { try (MongoPreparedStatement mongoPreparedStatement = createMongoPreparedStatement(MQL_ITEMS_INSERT)) { assertThatExceptionOfType(SQLException.class) @@ -686,9 +677,7 @@ private void assertExecuteUpdateThrowsSqlException( } } - private void assertExecuteQueryThrowsSqlException( - ThrowingConsumer asserter) - throws SQLException { + private void assertExecuteQueryThrowsSqlException(ThrowingConsumer asserter) throws SQLException { try (MongoPreparedStatement mongoPreparedStatement = createMongoPreparedStatement(MQL_ITEMS_AGGREGATE)) { assertThatExceptionOfType(SQLException.class) .isThrownBy(mongoPreparedStatement::executeQuery) From 5c81496567a87400427e54d8f1edce897040e7e3 Mon Sep 17 00:00:00 2001 From: "slav.babanin" Date: Mon, 6 Oct 2025 21:44:32 -0700 Subject: [PATCH 09/37] Apply spotless. --- ...ongoPreparedStatementIntegrationTests.java | 64 +++--- .../hibernate/jdbc/MongoStatement.java | 19 +- .../jdbc/MongoPreparedStatementTests.java | 202 +++++++++--------- 3 files changed, 141 insertions(+), 144 deletions(-) diff --git a/src/integrationTest/java/com/mongodb/hibernate/jdbc/MongoPreparedStatementIntegrationTests.java b/src/integrationTest/java/com/mongodb/hibernate/jdbc/MongoPreparedStatementIntegrationTests.java index 1c83ec6f..77faac90 100644 --- a/src/integrationTest/java/com/mongodb/hibernate/jdbc/MongoPreparedStatementIntegrationTests.java +++ b/src/integrationTest/java/com/mongodb/hibernate/jdbc/MongoPreparedStatementIntegrationTests.java @@ -682,51 +682,53 @@ void testNotSupportedCommands(String commandName) { }); } - @ParameterizedTest(name = "test not supported update elements. Parameters: option={0}") + @ParameterizedTest(name = "test not supported update elements. Parameters: option={0}") @ValueSource(strings = {"hint", "collation", "arrayFilters", "sort", "upsert", "c"}) void testNotSupportedUpdateElements(String unsupportedElement) { doWorkAwareOfAutoCommit(connection -> { - try (PreparedStatement pstm = connection.prepareStatement( - format( - """ - { - update: "books", - updates: [ - { - q: { author: { $eq: "Leo Tolstoy" } }, - u: { $set: { outOfStock: true } }, - multi: true, - %s: { _id: 1 } - } - ] - }""", unsupportedElement))) { + try (PreparedStatement pstm = connection.prepareStatement(format( + """ + { + update: "books", + updates: [ + { + q: { author: { $eq: "Leo Tolstoy" } }, + u: { $set: { outOfStock: true } }, + multi: true, + %s: { _id: 1 } + } + ] + }""", + unsupportedElement))) { SQLFeatureNotSupportedException exception = assertThrows(SQLFeatureNotSupportedException.class, pstm::executeUpdate); - assertThat(exception.getMessage()).isEqualTo(format("Unsupported elements in update command: [%s]", unsupportedElement)); + assertThat(exception.getMessage()) + .isEqualTo(format("Unsupported elements in update command: [%s]", unsupportedElement)); } }); } - @ParameterizedTest(name = "test not supported delete elements. Parameters: option={0}") + @ParameterizedTest(name = "test not supported delete elements. Parameters: option={0}") @ValueSource(strings = {"hint", "collation"}) void testNotSupportedDeleteElements(String unsupportedElement) { doWorkAwareOfAutoCommit(connection -> { - try (PreparedStatement pstm = connection.prepareStatement( - format( - """ - { - delete: "books", - deletes: [ - { - q: { author: { $eq: "Leo Tolstoy" } }, - limit: 0, - %s: { _id: 1 } - } - ] - }""", unsupportedElement))) { + try (PreparedStatement pstm = connection.prepareStatement(format( + """ + { + delete: "books", + deletes: [ + { + q: { author: { $eq: "Leo Tolstoy" } }, + limit: 0, + %s: { _id: 1 } + } + ] + }""", + unsupportedElement))) { SQLFeatureNotSupportedException exception = assertThrows(SQLFeatureNotSupportedException.class, pstm::executeUpdate); - assertThat(exception.getMessage()).isEqualTo(format("Unsupported elements in delete command: [%s]", unsupportedElement)); + assertThat(exception.getMessage()) + .isEqualTo(format("Unsupported elements in delete command: [%s]", unsupportedElement)); } }); } diff --git a/src/main/java/com/mongodb/hibernate/jdbc/MongoStatement.java b/src/main/java/com/mongodb/hibernate/jdbc/MongoStatement.java index 2c44d409..cdd3e19f 100644 --- a/src/main/java/com/mongodb/hibernate/jdbc/MongoStatement.java +++ b/src/main/java/com/mongodb/hibernate/jdbc/MongoStatement.java @@ -425,8 +425,8 @@ private static SQLException handleException( private static SQLException handleBatchException( RuntimeException exception, CommandType commandType, int errorCode) { - if (exception instanceof MongoException mongo) { - Exception cause = handleMongoException(mongo, errorCode); + if (exception instanceof MongoException mongoException) { + Exception cause = handleMongoException(mongoException, errorCode); if (exception instanceof MongoBulkWriteException bulkWriteException) { return createBatchUpdateException(cause, bulkWriteException.getWriteResult(), errorCode, commandType); } @@ -436,8 +436,8 @@ private static SQLException handleBatchException( } private static int getErrorCode(final RuntimeException runtimeException) { - if (runtimeException instanceof MongoBulkWriteException bulk) { - return getErrorCode(bulk); + if (runtimeException instanceof MongoBulkWriteException mongoBulkWriteException) { + return getErrorCode(mongoBulkWriteException); } if (runtimeException instanceof MongoException mongoException) { return max(DEFAULT_ERROR_CODE, mongoException.getCode()); @@ -461,7 +461,7 @@ private static Exception handleMongoException(final MongoException exceptionToHa if (isTimeoutException(exceptionToHandle)) { exception = new SQLTimeoutException(EXCEPTION_MESSAGE_TIMEOUT, null, errorCode, exceptionToHandle); } else { - exception = handleByErrorCode(exceptionToHandle, errorCode); + exception = handleByErrorCode(errorCode, exceptionToHandle); } if (exceptionToHandle.hasErrorLabel(MongoException.TRANSIENT_TRANSACTION_ERROR_LABEL)) { return toSqlTransientException(errorCode, exception); @@ -483,15 +483,14 @@ private static T withCause(T exception, final Exception ca return exception; } - private static Exception handleByErrorCode(final MongoException mongoException, int errorCode) { + private static Exception handleByErrorCode(int errorCode, final MongoException cause) { ErrorCategory errorCategory = ErrorCategory.fromErrorCode(errorCode); return switch (errorCategory) { case DUPLICATE_KEY -> new SQLIntegrityConstraintViolationException( - EXCEPTION_MESSAGE_OPERATION_FAILED, null, errorCode, mongoException); - case EXECUTION_TIMEOUT -> - new SQLTimeoutException(EXCEPTION_MESSAGE_TIMEOUT, null, errorCode, mongoException); - case UNCATEGORIZED -> mongoException; + EXCEPTION_MESSAGE_OPERATION_FAILED, null, errorCode, cause); + case EXECUTION_TIMEOUT -> new SQLTimeoutException(EXCEPTION_MESSAGE_TIMEOUT, null, errorCode, cause); + case UNCATEGORIZED -> cause; }; } diff --git a/src/test/java/com/mongodb/hibernate/jdbc/MongoPreparedStatementTests.java b/src/test/java/com/mongodb/hibernate/jdbc/MongoPreparedStatementTests.java index 08a0f39e..1155c298 100644 --- a/src/test/java/com/mongodb/hibernate/jdbc/MongoPreparedStatementTests.java +++ b/src/test/java/com/mongodb/hibernate/jdbc/MongoPreparedStatementTests.java @@ -260,9 +260,9 @@ private static Stream transientTimeoutExceptions() { private static Stream constraintViolationExceptions() { return Stream.of( - new MongoException(11000, "Duplicate key error"), - new MongoException(11001, "Duplicate key error"), - new MongoException(12582, "Duplicate key error")); + new MongoException(11000, DUMMY_EXCEPTION_MESSAGE), + new MongoException(11001, DUMMY_EXCEPTION_MESSAGE), + new MongoException(12582, DUMMY_EXCEPTION_MESSAGE)); } private static Stream genericMongoExceptions() { @@ -277,21 +277,22 @@ private static Stream genericTransientMongoExceptions() { .peek(mongoException -> mongoException.addLabel(MongoException.TRANSIENT_TRANSACTION_ERROR_LABEL)); } - @ParameterizedTest(name = "test executeBatch throws SQLException. Exception: {0}") + @ParameterizedTest(name = "test executeBatch MongoException. Parameters: Parameters: exception: {0}") @MethodSource("genericMongoExceptions") void testExecuteBatchMongoException(MongoException mongoException) throws SQLException { int expectedErrorCode = max(0, mongoException.getCode()); doThrow(mongoException).when(mongoCollection).bulkWrite(eq(clientSession), anyList()); + assertExecuteBatchThrowsSqlException(batchUpdateException -> { assertAll( - () -> assertEquals(mongoException, batchUpdateException.getCause()), () -> assertEquals(expectedErrorCode, batchUpdateException.getErrorCode()), - () -> assertUpdateCounts(batchUpdateException.getUpdateCounts(), 0), - () -> assertNull(batchUpdateException.getSQLState())); + () -> assertNull(batchUpdateException.getSQLState()), + () -> assertEquals(mongoException, batchUpdateException.getCause()), + () -> assertUpdateCounts(batchUpdateException.getUpdateCounts(), 0)); }); } - @ParameterizedTest(name = "test executeBatch throws SQLException. Exception: {0}") + @ParameterizedTest(name = "test executeUpdate MongoException. Parameters: Parameters: exception: {0}") @MethodSource("genericMongoExceptions") void testExecuteUpdateMongoException(MongoException mongoException) throws SQLException { doThrow(mongoException).when(mongoCollection).bulkWrite(eq(clientSession), anyList()); @@ -299,7 +300,7 @@ void testExecuteUpdateMongoException(MongoException mongoException) throws SQLEx sqlException -> assertGenericMongoException(mongoException, sqlException)); } - @ParameterizedTest(name = "test executeBatch throws SQLException. Exception: {0}") + @ParameterizedTest(name = "test executeUQuery MongoException. Parameters: Parameters: exception: {0}") @MethodSource("genericMongoExceptions") void testExecuteQueryMongoException(MongoException mongoException) throws SQLException { doThrow(mongoException).when(mongoCollection).aggregate(eq(clientSession), anyList()); @@ -307,11 +308,12 @@ void testExecuteQueryMongoException(MongoException mongoException) throws SQLExc sqlException -> assertGenericMongoException(mongoException, sqlException)); } - @ParameterizedTest(name = "test executeBatch throws SQLException. Exception: {0}") + @ParameterizedTest(name = "test executeBatch transient MongoException. Parameters: Parameters: exception: {0}") @MethodSource("genericTransientMongoExceptions") void testExecuteBatchTransientMongoException(MongoException mongoException) throws SQLException { int expectedErrorCode = max(0, mongoException.getCode()); doThrow(mongoException).when(mongoCollection).bulkWrite(eq(clientSession), anyList()); + assertExecuteBatchThrowsSqlException(batchUpdateException -> { assertAll( () -> { @@ -325,7 +327,7 @@ void testExecuteBatchTransientMongoException(MongoException mongoException) thro }); } - @ParameterizedTest(name = "test executeUpdate throws SQLException. Exception: {0}") + @ParameterizedTest(name = "test executeUpdate transient MongoException. Parameters: Parameters: exception: {0}") @MethodSource("genericTransientMongoExceptions") void testExecuteUpdateTransientMongoException(MongoException mongoException) throws SQLException { doThrow(mongoException).when(mongoCollection).bulkWrite(eq(clientSession), anyList()); @@ -334,7 +336,7 @@ void testExecuteUpdateTransientMongoException(MongoException mongoException) thr }); } - @ParameterizedTest(name = "test executeQuery throws SQLException. Exception: {0}") + @ParameterizedTest(name = "test executeQuery transient MongoException. Parameters: Parameters: exception: {0}") @MethodSource("genericTransientMongoExceptions") void testExecuteQueryTransientMongoException(MongoException mongoException) throws SQLException { doThrow(mongoException).when(mongoCollection).aggregate(eq(clientSession), anyList()); @@ -343,33 +345,7 @@ void testExecuteQueryTransientMongoException(MongoException mongoException) thro }); } - private static void assertGenericTransientMongoException( - final MongoException mongoException, final SQLException sqlException) { - int expectedErrorCode = max(0, mongoException.getCode()); - assertAll( - () -> assertEquals(expectedErrorCode, sqlException.getErrorCode()), - () -> { - SQLTransientException sqlTransientException = - assertInstanceOf(SQLTransientException.class, sqlException); - assertEquals(expectedErrorCode, sqlTransientException.getErrorCode()); - assertEquals(mongoException, sqlTransientException.getCause()); - }, - () -> assertNull(sqlException.getSQLState())); - } - - private static void assertGenericMongoException( - final MongoException mongoException, final SQLException sqlException) { - int expectedErrorCode = max(0, mongoException.getCode()); - assertAll( - () -> assertThat((Throwable) sqlException).isExactlyInstanceOf(SQLException.class), - () -> assertEquals(expectedErrorCode, sqlException.getErrorCode()), - () -> assertEquals(mongoException, sqlException.getCause()), - () -> assertNull(sqlException.getSQLState())); - } - - @ParameterizedTest( - name = - "test executeUpdate throws SQLException. Parameters: exception={0}, expectedSqlExceptionTypeCause={1}") + @ParameterizedTest(name = "test executeUpdate timeout exception. Parameters: Parameters: exception: {0}") @MethodSource("timeoutExceptions") void testExecuteUpdateTimeoutException(MongoException mongoTimeoutException) throws SQLException { doThrow(mongoTimeoutException).when(mongoCollection).bulkWrite(eq(clientSession), anyList()); @@ -377,9 +353,7 @@ void testExecuteUpdateTimeoutException(MongoException mongoTimeoutException) thr sqlException -> assertTimeoutException(mongoTimeoutException, sqlException)); } - @ParameterizedTest( - name = - "test executeUpdate throws SQLException. Parameters: exception={0}, expectedSqlExceptionTypeCause={1}") + @ParameterizedTest(name = "test executeQuery timeout exception. Parameters: exception: {0}") @MethodSource("timeoutExceptions") void testExecuteQueryTimeoutException(MongoException mongoTimeoutException) throws SQLException { doThrow(mongoTimeoutException).when(mongoCollection).aggregate(eq(clientSession), anyList()); @@ -387,23 +361,7 @@ void testExecuteQueryTimeoutException(MongoException mongoTimeoutException) thro sqlException -> assertTimeoutException(mongoTimeoutException, sqlException)); } - private static void assertTimeoutException( - final MongoException mongoTimeoutException, final SQLException sqlException) { - int expectedErrorCode = max(0, mongoTimeoutException.getCode()); - assertAll( - () -> assertEquals(expectedErrorCode, sqlException.getErrorCode()), - () -> { - SQLTimeoutException sqlTimeoutException = - assertInstanceOf(SQLTimeoutException.class, sqlException); - assertEquals(expectedErrorCode, sqlTimeoutException.getErrorCode()); - assertEquals(mongoTimeoutException, sqlTimeoutException.getCause()); - }, - () -> assertNull(sqlException.getSQLState())); - } - - @ParameterizedTest( - name = - "test executeUpdate throws SQLException. Parameters: exception={0}, expectedSqlExceptionTypeCause={1}") + @ParameterizedTest(name = "test executeUpdate transient timeout exception. Parameters: exception: {0}") @MethodSource("transientTimeoutExceptions") void testExecuteUpdateTransientTimeoutException(MongoException mongoTransientTimeoutException) throws SQLException { @@ -412,9 +370,7 @@ void testExecuteUpdateTransientTimeoutException(MongoException mongoTransientTim sqlException -> assertTransientTimeoutException(mongoTransientTimeoutException, sqlException)); } - @ParameterizedTest( - name = - "test executeUpdate throws SQLException. Parameters: exception={0}, expectedSqlExceptionTypeCause={1}") + @ParameterizedTest(name = "test executeQuery transient timeout exception. Parameters: exception: {0}") @MethodSource("transientTimeoutExceptions") void testExecuteQueryTransientTimeoutException(MongoException mongoTransientTimeoutException) throws SQLException { @@ -423,38 +379,23 @@ void testExecuteQueryTransientTimeoutException(MongoException mongoTransientTime sqlException -> assertTransientTimeoutException(mongoTransientTimeoutException, sqlException)); } - private static void assertTransientTimeoutException( - final MongoException mongoTransientTimeoutException, final SQLException sqlException) { - int expectedErrorCode = max(0, mongoTransientTimeoutException.getCode()); - assertAll( - () -> assertEquals(expectedErrorCode, sqlException.getErrorCode()), - () -> { - SQLTransientException sqlTransientException = - assertInstanceOf(SQLTransientException.class, sqlException); - assertEquals(expectedErrorCode, sqlTransientException.getErrorCode()); - SQLTimeoutException sqlTimeoutException = - assertInstanceOf(SQLTimeoutException.class, sqlTransientException.getCause()); - assertEquals(expectedErrorCode, sqlTimeoutException.getErrorCode()); - assertEquals(mongoTransientTimeoutException, sqlTimeoutException.getCause()); - }, - () -> assertNull(sqlException.getSQLState())); - } - - @ParameterizedTest(name = "test executeUpdate constraint violation. Exception code={0}") + @ParameterizedTest(name = "test executeUpdate constraint violation. Parameters: exception: {0}") @MethodSource("constraintViolationExceptions") void testExecuteUpdateConstraintViolationException(MongoException mongoException) throws SQLException { int expectedErrorCode = mongoException.getCode(); doThrow(mongoException).when(mongoCollection).bulkWrite(eq(clientSession), anyList()); + assertExecuteUpdateThrowsSqlException(sqlException -> { assertConstraintViolationException(mongoException, sqlException, expectedErrorCode); }); } - @ParameterizedTest(name = "test executeUpdate constraint violation. Exception code={0}") + @ParameterizedTest(name = "test executeQuery constraint violation. Parameters: exception: {0}") @MethodSource("constraintViolationExceptions") void testExecuteQueryConstraintViolationException(MongoException mongoException) throws SQLException { int expectedErrorCode = mongoException.getCode(); doThrow(mongoException).when(mongoCollection).aggregate(eq(clientSession), anyList()); + assertExecuteQueryThrowsSqlException(sqlException -> { assertConstraintViolationException(mongoException, sqlException, expectedErrorCode); }); @@ -473,9 +414,7 @@ private static void assertConstraintViolationException( () -> assertNull(sqlException.getSQLState())); } - @ParameterizedTest( - name = - "test executeUpdate throws SQLException. Parameters: exception={0}, expectedSqlExceptionTypeCause={1}") + @ParameterizedTest(name = "test executeBatch timeout exception. Parameters: exception: {0}") @MethodSource("timeoutExceptions") void testExecuteBatchTimeoutException(MongoException mongoTimeoutException) throws SQLException { doThrow(mongoTimeoutException).when(mongoCollection).bulkWrite(eq(clientSession), anyList()); @@ -494,9 +433,7 @@ void testExecuteBatchTimeoutException(MongoException mongoTimeoutException) thro }); } - @ParameterizedTest( - name = - "test executeUpdate throws SQLException. Parameters: exception={0}, expectedSqlExceptionTypeCause={1}") + @ParameterizedTest(name = "test executeBatch transient timeout exception. Parameters: exception: {0}") @MethodSource("transientTimeoutExceptions") void testExecuteBatchTransientTimeoutException(MongoException mongoTimeoutException) throws SQLException { doThrow(mongoTimeoutException).when(mongoCollection).bulkWrite(eq(clientSession), anyList()); @@ -518,11 +455,12 @@ void testExecuteBatchTransientTimeoutException(MongoException mongoTimeoutExcept }); } - @ParameterizedTest(name = "test executeBatch constraint violation. Exception code={0}") + @ParameterizedTest(name = "test executeBatch constraint violation. Parameters: exception: {0}") @MethodSource("constraintViolationExceptions") void testExecuteBatchConstraintViolationException(MongoException mongoException) throws SQLException { int expectedErrorCode = mongoException.getCode(); doThrow(mongoException).when(mongoCollection).bulkWrite(eq(clientSession), anyList()); + assertExecuteBatchThrowsSqlException(batchUpdateException -> { assertAll( () -> assertEquals(expectedErrorCode, batchUpdateException.getErrorCode()), @@ -568,14 +506,6 @@ void testExecuteQueryRuntimeExceptionCause() throws SQLException { sqlException -> assertGenericException(sqlException, runtimeException)); } - private static void assertGenericException(final SQLException sqlException, RuntimeException cause) { - assertAll( - () -> assertThat((Throwable) sqlException).isExactlyInstanceOf(SQLException.class), - () -> assertEquals(cause, sqlException.getCause()), - () -> assertEquals(0, sqlException.getErrorCode()), - () -> assertNull(sqlException.getSQLState())); - } - private static Stream bulkWriteExceptionsForExecuteUpdate() { return Stream.of( Arguments.of(named("insert", MQL_ITEMS_INSERT), MONGO_BULK_WRITE_EXCEPTION_NO_ERRORS), @@ -587,8 +517,7 @@ private static Stream bulkWriteExceptionsForExecuteUpdate() { } @ParameterizedTest( - name = "test executeUpdate throws SQLException when MongoBulkWriteException occurs." - + " Parameters: commandName={0}, exception={1}") + name = "test executeUpdate MongoBulkWriteException. Parameters: commandName={0}, exception={1}") @MethodSource("bulkWriteExceptionsForExecuteUpdate") void testExecuteUpdateMongoBulkWriteException(String mql, MongoBulkWriteException mongoBulkWriteException) throws SQLException { @@ -600,8 +529,8 @@ void testExecuteUpdateMongoBulkWriteException(String mql, MongoBulkWriteExceptio .isThrownBy(mongoPreparedStatement::executeUpdate) .withCause(mongoBulkWriteException) .satisfies(sqlException -> assertAll( - () -> assertNull(sqlException.getSQLState()), - () -> assertEquals(vendorCodeError, sqlException.getErrorCode()))); + () -> assertEquals(vendorCodeError, sqlException.getErrorCode()), + () -> assertNull(sqlException.getSQLState()))); } } @@ -634,8 +563,7 @@ private static Stream bulkWriteExceptionsForExecuteBatch() { } @ParameterizedTest( - name = "test executeBatch throws BatchUpdateException when MongoBulkWriteException occurs." - + " Parameters: commandName={0}, exception={1}") + name = "test executeBatch MongoBulkWriteException. Parameters: commandName={0}, exception={1}") @MethodSource("bulkWriteExceptionsForExecuteBatch") void testExecuteBatchMongoBulkWriteException( String mql, MongoBulkWriteException mongoBulkWriteException, int expectedUpdateCountLength) @@ -650,14 +578,82 @@ void testExecuteBatchMongoBulkWriteException( .withCause(mongoBulkWriteException) .satisfies(batchUpdateException -> { assertAll( + () -> assertEquals(vendorCodeError, batchUpdateException.getErrorCode()), + () -> assertNull(batchUpdateException.getSQLState()), () -> assertUpdateCounts( batchUpdateException.getUpdateCounts(), expectedUpdateCountLength), - () -> assertNull(batchUpdateException.getSQLState()), () -> assertEquals(vendorCodeError, batchUpdateException.getErrorCode())); }); } } + private static void assertGenericException(final SQLException sqlException, RuntimeException cause) { + assertAll( + () -> assertThat((Throwable) sqlException).isExactlyInstanceOf(SQLException.class), + () -> assertEquals(cause, sqlException.getCause()), + () -> assertEquals(0, sqlException.getErrorCode()), + () -> assertNull(sqlException.getSQLState())); + } + + private static void assertTransientTimeoutException( + final MongoException mongoTransientTimeoutException, final SQLException sqlException) { + int expectedErrorCode = max(0, mongoTransientTimeoutException.getCode()); + assertAll( + () -> assertEquals(expectedErrorCode, sqlException.getErrorCode()), + () -> { + SQLTransientException sqlTransientException = + assertInstanceOf(SQLTransientException.class, sqlException); + assertEquals(expectedErrorCode, sqlTransientException.getErrorCode()); + assertNull(sqlTransientException.getSQLState()); + SQLTimeoutException sqlTimeoutException = + assertInstanceOf(SQLTimeoutException.class, sqlTransientException.getCause()); + assertEquals(expectedErrorCode, sqlTimeoutException.getErrorCode()); + assertNull(sqlTimeoutException.getSQLState()); + assertEquals(mongoTransientTimeoutException, sqlTimeoutException.getCause()); + }, + () -> assertNull(sqlException.getSQLState())); + } + + private static void assertGenericTransientMongoException( + final MongoException mongoException, final SQLException sqlException) { + int expectedErrorCode = max(0, mongoException.getCode()); + assertAll( + () -> assertEquals(expectedErrorCode, sqlException.getErrorCode()), + () -> { + SQLTransientException sqlTransientException = + assertInstanceOf(SQLTransientException.class, sqlException); + assertEquals(expectedErrorCode, sqlTransientException.getErrorCode()); + assertNull(sqlTransientException.getSQLState()); + assertEquals(mongoException, sqlTransientException.getCause()); + }, + () -> assertNull(sqlException.getSQLState())); + } + + private static void assertGenericMongoException( + final MongoException mongoException, final SQLException sqlException) { + int expectedErrorCode = max(0, mongoException.getCode()); + assertAll( + () -> assertEquals(expectedErrorCode, sqlException.getErrorCode()), + () -> assertNull(sqlException.getSQLState()), + () -> assertEquals(mongoException, sqlException.getCause()), + () -> assertThat((Throwable) sqlException).isExactlyInstanceOf(SQLException.class)); + } + + private static void assertTimeoutException( + final MongoException mongoTimeoutException, final SQLException sqlException) { + int expectedErrorCode = max(0, mongoTimeoutException.getCode()); + assertAll( + () -> assertEquals(expectedErrorCode, sqlException.getErrorCode()), + () -> assertNull(sqlException.getSQLState()), + () -> { + SQLTimeoutException sqlTimeoutException = + assertInstanceOf(SQLTimeoutException.class, sqlException); + assertEquals(expectedErrorCode, sqlTimeoutException.getErrorCode()); + assertNull(sqlTimeoutException.getSQLState()); + assertEquals(mongoTimeoutException, sqlTimeoutException.getCause()); + }); + } + private void assertExecuteBatchThrowsSqlException(ThrowingConsumer asserter) throws SQLException { try (MongoPreparedStatement mongoPreparedStatement = createMongoPreparedStatement(MQL_ITEMS_INSERT)) { From e4d7afdd3e376e51f5beeab53dc2cfc3bef0e2d3 Mon Sep 17 00:00:00 2001 From: "slav.babanin" Date: Wed, 8 Oct 2025 16:07:23 -0700 Subject: [PATCH 10/37] Revert SqlTransientException handling. --- .../hibernate/jdbc/MongoStatement.java | 25 +--- .../jdbc/MongoPreparedStatementTests.java | 124 ------------------ 2 files changed, 7 insertions(+), 142 deletions(-) diff --git a/src/main/java/com/mongodb/hibernate/jdbc/MongoStatement.java b/src/main/java/com/mongodb/hibernate/jdbc/MongoStatement.java index cdd3e19f..40ed7307 100644 --- a/src/main/java/com/mongodb/hibernate/jdbc/MongoStatement.java +++ b/src/main/java/com/mongodb/hibernate/jdbc/MongoStatement.java @@ -52,7 +52,6 @@ import java.sql.SQLIntegrityConstraintViolationException; import java.sql.SQLSyntaxErrorException; import java.sql.SQLTimeoutException; -import java.sql.SQLTransientException; import java.sql.SQLWarning; import java.sql.Statement; import java.util.ArrayList; @@ -445,10 +444,6 @@ private static int getErrorCode(final RuntimeException runtimeException) { return DEFAULT_ERROR_CODE; } - private static SQLTransientException toSqlTransientException(final int errorCode, Exception cause) { - return withCause(new SQLTransientException("Transient exception occurred", null, errorCode), cause); - } - private static SQLException toSqlException(final int errorCode, final Exception exception) { if (exception instanceof SQLException sqlException) { return sqlException; @@ -457,16 +452,10 @@ private static SQLException toSqlException(final int errorCode, final Exception } private static Exception handleMongoException(final MongoException exceptionToHandle, final int errorCode) { - Exception exception; if (isTimeoutException(exceptionToHandle)) { - exception = new SQLTimeoutException(EXCEPTION_MESSAGE_TIMEOUT, null, errorCode, exceptionToHandle); - } else { - exception = handleByErrorCode(errorCode, exceptionToHandle); + return new SQLTimeoutException(EXCEPTION_MESSAGE_TIMEOUT, null, errorCode, exceptionToHandle); } - if (exceptionToHandle.hasErrorLabel(MongoException.TRANSIENT_TRANSACTION_ERROR_LABEL)) { - return toSqlTransientException(errorCode, exception); - } - return exception; + return handleByErrorCode(errorCode, exceptionToHandle); } private static SQLException toBatchUpdateException(final int errorCode, final Exception exception) { @@ -475,12 +464,12 @@ private static SQLException toBatchUpdateException(final int errorCode, final Ex exception); } - private static T withCause(T exception, final Exception cause) { - exception.initCause(cause); - if (exception instanceof SQLException sqlException) { - sqlException.setNextException(sqlException); + private static T withCause(T sqlException, final Exception cause) { + sqlException.initCause(cause); + if (cause instanceof SQLException sqlExceptionCause) { + sqlException.setNextException(sqlExceptionCause); } - return exception; + return sqlException; } private static Exception handleByErrorCode(int errorCode, final MongoException cause) { diff --git a/src/test/java/com/mongodb/hibernate/jdbc/MongoPreparedStatementTests.java b/src/test/java/com/mongodb/hibernate/jdbc/MongoPreparedStatementTests.java index 1155c298..5468ecee 100644 --- a/src/test/java/com/mongodb/hibernate/jdbc/MongoPreparedStatementTests.java +++ b/src/test/java/com/mongodb/hibernate/jdbc/MongoPreparedStatementTests.java @@ -66,7 +66,6 @@ import java.sql.SQLIntegrityConstraintViolationException; import java.sql.SQLSyntaxErrorException; import java.sql.SQLTimeoutException; -import java.sql.SQLTransientException; import java.sql.Time; import java.sql.Timestamp; import java.sql.Types; @@ -253,11 +252,6 @@ private static Stream timeoutExceptions() { ); } - private static Stream transientTimeoutExceptions() { - return timeoutExceptions() - .peek(mongoException -> mongoException.addLabel(MongoException.TRANSIENT_TRANSACTION_ERROR_LABEL)); - } - private static Stream constraintViolationExceptions() { return Stream.of( new MongoException(11000, DUMMY_EXCEPTION_MESSAGE), @@ -270,13 +264,6 @@ private static Stream genericMongoExceptions() { new MongoException(-3, DUMMY_EXCEPTION_MESSAGE), new MongoException(5000, DUMMY_EXCEPTION_MESSAGE)); } - private static Stream genericTransientMongoExceptions() { - return Stream.of( - new MongoException(-3, DUMMY_EXCEPTION_MESSAGE), - new MongoException(5000, DUMMY_EXCEPTION_MESSAGE)) - .peek(mongoException -> mongoException.addLabel(MongoException.TRANSIENT_TRANSACTION_ERROR_LABEL)); - } - @ParameterizedTest(name = "test executeBatch MongoException. Parameters: Parameters: exception: {0}") @MethodSource("genericMongoExceptions") void testExecuteBatchMongoException(MongoException mongoException) throws SQLException { @@ -308,43 +295,6 @@ void testExecuteQueryMongoException(MongoException mongoException) throws SQLExc sqlException -> assertGenericMongoException(mongoException, sqlException)); } - @ParameterizedTest(name = "test executeBatch transient MongoException. Parameters: Parameters: exception: {0}") - @MethodSource("genericTransientMongoExceptions") - void testExecuteBatchTransientMongoException(MongoException mongoException) throws SQLException { - int expectedErrorCode = max(0, mongoException.getCode()); - doThrow(mongoException).when(mongoCollection).bulkWrite(eq(clientSession), anyList()); - - assertExecuteBatchThrowsSqlException(batchUpdateException -> { - assertAll( - () -> { - SQLTransientException sqlTransientException = - assertInstanceOf(SQLTransientException.class, batchUpdateException.getCause()); - assertEquals(mongoException, sqlTransientException.getCause()); - }, - () -> assertEquals(expectedErrorCode, batchUpdateException.getErrorCode()), - () -> assertUpdateCounts(batchUpdateException.getUpdateCounts(), 0), - () -> assertNull(batchUpdateException.getSQLState())); - }); - } - - @ParameterizedTest(name = "test executeUpdate transient MongoException. Parameters: Parameters: exception: {0}") - @MethodSource("genericTransientMongoExceptions") - void testExecuteUpdateTransientMongoException(MongoException mongoException) throws SQLException { - doThrow(mongoException).when(mongoCollection).bulkWrite(eq(clientSession), anyList()); - assertExecuteUpdateThrowsSqlException(sqlException -> { - assertGenericTransientMongoException(mongoException, sqlException); - }); - } - - @ParameterizedTest(name = "test executeQuery transient MongoException. Parameters: Parameters: exception: {0}") - @MethodSource("genericTransientMongoExceptions") - void testExecuteQueryTransientMongoException(MongoException mongoException) throws SQLException { - doThrow(mongoException).when(mongoCollection).aggregate(eq(clientSession), anyList()); - assertExecuteQueryThrowsSqlException(sqlException -> { - assertGenericTransientMongoException(mongoException, sqlException); - }); - } - @ParameterizedTest(name = "test executeUpdate timeout exception. Parameters: Parameters: exception: {0}") @MethodSource("timeoutExceptions") void testExecuteUpdateTimeoutException(MongoException mongoTimeoutException) throws SQLException { @@ -361,24 +311,6 @@ void testExecuteQueryTimeoutException(MongoException mongoTimeoutException) thro sqlException -> assertTimeoutException(mongoTimeoutException, sqlException)); } - @ParameterizedTest(name = "test executeUpdate transient timeout exception. Parameters: exception: {0}") - @MethodSource("transientTimeoutExceptions") - void testExecuteUpdateTransientTimeoutException(MongoException mongoTransientTimeoutException) - throws SQLException { - doThrow(mongoTransientTimeoutException).when(mongoCollection).bulkWrite(eq(clientSession), anyList()); - assertExecuteUpdateThrowsSqlException( - sqlException -> assertTransientTimeoutException(mongoTransientTimeoutException, sqlException)); - } - - @ParameterizedTest(name = "test executeQuery transient timeout exception. Parameters: exception: {0}") - @MethodSource("transientTimeoutExceptions") - void testExecuteQueryTransientTimeoutException(MongoException mongoTransientTimeoutException) - throws SQLException { - doThrow(mongoTransientTimeoutException).when(mongoCollection).aggregate(eq(clientSession), anyList()); - assertExecuteQueryThrowsSqlException( - sqlException -> assertTransientTimeoutException(mongoTransientTimeoutException, sqlException)); - } - @ParameterizedTest(name = "test executeUpdate constraint violation. Parameters: exception: {0}") @MethodSource("constraintViolationExceptions") void testExecuteUpdateConstraintViolationException(MongoException mongoException) throws SQLException { @@ -433,28 +365,6 @@ void testExecuteBatchTimeoutException(MongoException mongoTimeoutException) thro }); } - @ParameterizedTest(name = "test executeBatch transient timeout exception. Parameters: exception: {0}") - @MethodSource("transientTimeoutExceptions") - void testExecuteBatchTransientTimeoutException(MongoException mongoTimeoutException) throws SQLException { - doThrow(mongoTimeoutException).when(mongoCollection).bulkWrite(eq(clientSession), anyList()); - assertExecuteBatchThrowsSqlException(batchUpdateException -> { - int expectedErrorCode = max(0, mongoTimeoutException.getCode()); - assertAll( - () -> assertEquals(expectedErrorCode, batchUpdateException.getErrorCode()), - () -> { - SQLTransientException sqlTransientException = - assertInstanceOf(SQLTransientException.class, batchUpdateException.getCause()); - assertEquals(expectedErrorCode, sqlTransientException.getErrorCode()); - SQLTimeoutException sqlTimeoutException = - assertInstanceOf(SQLTimeoutException.class, sqlTransientException.getCause()); - assertEquals(expectedErrorCode, sqlTimeoutException.getErrorCode()); - assertEquals(mongoTimeoutException, sqlTimeoutException.getCause()); - }, - () -> assertUpdateCounts(batchUpdateException.getUpdateCounts(), 0), - () -> assertNull(batchUpdateException.getSQLState())); - }); - } - @ParameterizedTest(name = "test executeBatch constraint violation. Parameters: exception: {0}") @MethodSource("constraintViolationExceptions") void testExecuteBatchConstraintViolationException(MongoException mongoException) throws SQLException { @@ -595,40 +505,6 @@ private static void assertGenericException(final SQLException sqlException, Runt () -> assertNull(sqlException.getSQLState())); } - private static void assertTransientTimeoutException( - final MongoException mongoTransientTimeoutException, final SQLException sqlException) { - int expectedErrorCode = max(0, mongoTransientTimeoutException.getCode()); - assertAll( - () -> assertEquals(expectedErrorCode, sqlException.getErrorCode()), - () -> { - SQLTransientException sqlTransientException = - assertInstanceOf(SQLTransientException.class, sqlException); - assertEquals(expectedErrorCode, sqlTransientException.getErrorCode()); - assertNull(sqlTransientException.getSQLState()); - SQLTimeoutException sqlTimeoutException = - assertInstanceOf(SQLTimeoutException.class, sqlTransientException.getCause()); - assertEquals(expectedErrorCode, sqlTimeoutException.getErrorCode()); - assertNull(sqlTimeoutException.getSQLState()); - assertEquals(mongoTransientTimeoutException, sqlTimeoutException.getCause()); - }, - () -> assertNull(sqlException.getSQLState())); - } - - private static void assertGenericTransientMongoException( - final MongoException mongoException, final SQLException sqlException) { - int expectedErrorCode = max(0, mongoException.getCode()); - assertAll( - () -> assertEquals(expectedErrorCode, sqlException.getErrorCode()), - () -> { - SQLTransientException sqlTransientException = - assertInstanceOf(SQLTransientException.class, sqlException); - assertEquals(expectedErrorCode, sqlTransientException.getErrorCode()); - assertNull(sqlTransientException.getSQLState()); - assertEquals(mongoException, sqlTransientException.getCause()); - }, - () -> assertNull(sqlException.getSQLState())); - } - private static void assertGenericMongoException( final MongoException mongoException, final SQLException sqlException) { int expectedErrorCode = max(0, mongoException.getCode()); From 183aafcb47f6971032b7f42475cd990478ab7123 Mon Sep 17 00:00:00 2001 From: "slav.babanin" Date: Wed, 8 Oct 2025 16:42:11 -0700 Subject: [PATCH 11/37] Remove SQlConsumer. --- .../mongodb/hibernate/jdbc/MongoPreparedStatementTests.java | 4 ---- 1 file changed, 4 deletions(-) diff --git a/src/test/java/com/mongodb/hibernate/jdbc/MongoPreparedStatementTests.java b/src/test/java/com/mongodb/hibernate/jdbc/MongoPreparedStatementTests.java index 5468ecee..4abcb996 100644 --- a/src/test/java/com/mongodb/hibernate/jdbc/MongoPreparedStatementTests.java +++ b/src/test/java/com/mongodb/hibernate/jdbc/MongoPreparedStatementTests.java @@ -708,8 +708,4 @@ private static void assertThrowsClosedException(Executable executable) { var exception = assertThrows(SQLException.class, executable); assertThat(exception.getMessage()).isEqualTo("MongoPreparedStatement has been closed"); } - - interface SqlConsumer { - void accept(T t) throws SQLException; - } } From 9577603cfd3aec7f8d4076c7f71870dfe9a3dd46 Mon Sep 17 00:00:00 2001 From: "slav.babanin" Date: Tue, 21 Oct 2025 16:42:57 -0700 Subject: [PATCH 12/37] Fix issues. --- ...ongoPreparedStatementIntegrationTests.java | 8 +- .../jdbc/MongoPreparedStatement.java | 17 + .../hibernate/jdbc/MongoStatement.java | 299 +++++++++--------- .../jdbc/MongoPreparedStatementTests.java | 166 +++++----- 4 files changed, 240 insertions(+), 250 deletions(-) diff --git a/src/integrationTest/java/com/mongodb/hibernate/jdbc/MongoPreparedStatementIntegrationTests.java b/src/integrationTest/java/com/mongodb/hibernate/jdbc/MongoPreparedStatementIntegrationTests.java index 77faac90..df4041d3 100644 --- a/src/integrationTest/java/com/mongodb/hibernate/jdbc/MongoPreparedStatementIntegrationTests.java +++ b/src/integrationTest/java/com/mongodb/hibernate/jdbc/MongoPreparedStatementIntegrationTests.java @@ -268,11 +268,9 @@ void testQueriesReturningResult() { pstm.addBatch(); assertThatExceptionOfType(BatchUpdateException.class) .isThrownBy(pstm::executeBatch) - .satisfies(batchUpdateException -> { - assertNull(batchUpdateException.getUpdateCounts()); - assertNull(batchUpdateException.getSQLState()); - assertEquals(0, batchUpdateException.getErrorCode()); - }); + .returns(null, BatchUpdateException::getUpdateCounts) + .returns(null, BatchUpdateException::getSQLState) + .returns(0, BatchUpdateException::getErrorCode); } catch (SQLException e) { throw new RuntimeException(e); } diff --git a/src/main/java/com/mongodb/hibernate/jdbc/MongoPreparedStatement.java b/src/main/java/com/mongodb/hibernate/jdbc/MongoPreparedStatement.java index 55532a0f..1b569f6e 100644 --- a/src/main/java/com/mongodb/hibernate/jdbc/MongoPreparedStatement.java +++ b/src/main/java/com/mongodb/hibernate/jdbc/MongoPreparedStatement.java @@ -28,6 +28,7 @@ import com.mongodb.hibernate.internal.type.ObjectIdJdbcType; import java.math.BigDecimal; import java.sql.Array; +import java.sql.BatchUpdateException; import java.sql.Date; import java.sql.JDBCType; import java.sql.PreparedStatement; @@ -233,6 +234,22 @@ public int[] executeBatch() throws SQLException { } } + private void checkSupportedBatchCommand(BsonDocument command) throws SQLException { + var commandType = getCommandType(command); + if (commandType == CommandType.AGGREGATE) { + // The method executeBatch throws a BatchUpdateException if any of the commands in the batch attempts to + // return a result set. + throw new BatchUpdateException( + format( + "Commands returning result set are not supported. Received command: %s", + commandType.getCommandName()), + null, + NO_ERROR_CODE, + null); + } + checkSupportedUpdateCommand(commandType); + } + @Override public void setArray(int parameterIndex, Array x) throws SQLException { checkClosed(); diff --git a/src/main/java/com/mongodb/hibernate/jdbc/MongoStatement.java b/src/main/java/com/mongodb/hibernate/jdbc/MongoStatement.java index 40ed7307..db4cfb02 100644 --- a/src/main/java/com/mongodb/hibernate/jdbc/MongoStatement.java +++ b/src/main/java/com/mongodb/hibernate/jdbc/MongoStatement.java @@ -64,14 +64,11 @@ import org.jspecify.annotations.Nullable; class MongoStatement implements StatementAdapter { - - private static final List SUPPORTED_UPDATE_COMMAND_ELEMENTS = List.of("q", "u", "multi"); - private static final List SUPPORTED_DELETE_COMMAND_ELEMENTS = List.of("q", "limit"); private static final String EXCEPTION_MESSAGE_OPERATION_FAILED = "Failed to execute operation"; private static final String EXCEPTION_MESSAGE_BATCH_FAILED = "Batch execution failed"; private static final String EXCEPTION_MESSAGE_TIMEOUT = "Timeout while waiting for operation to complete"; - private static final int DEFAULT_ERROR_CODE = 0; - static final int[] EMPTY_BATCH_RESULT = new int[DEFAULT_ERROR_CODE]; + static final int NO_ERROR_CODE = 0; + static final int[] EMPTY_BATCH_RESULT = new int[0]; private final MongoDatabase mongoDatabase; private final MongoConnection mongoConnection; private final ClientSession clientSession; @@ -135,7 +132,7 @@ private static boolean isExcludeProjectSpecification(Map.Entry commandBatch, ExecutionType executionType) throws SQLException { - var firstDocumentInBatch = commandBatch.get(DEFAULT_ERROR_CODE); + var firstDocumentInBatch = commandBatch.get(0); var commandType = getCommandType(firstDocumentInBatch); var collection = getCollection(commandType, firstDocumentInBatch); try { startTransactionIfNeeded(); var writeModels = new ArrayList>(commandBatch.size()); for (var command : commandBatch) { - convertToWriteModels(commandType, command, writeModels); + WriteModelConverter.convertToWriteModels(commandType, command, writeModels); } var bulkWriteResult = collection.bulkWrite(clientSession, writeModels); return getUpdateCount(commandType, bulkWriteResult); @@ -265,7 +262,7 @@ void checkSupportedUpdateCommand(BsonDocument command) throws SQLException { checkSupportedUpdateCommand(getCommandType(command)); } - private void checkSupportedUpdateCommand(CommandType commandType) throws SQLException { + void checkSupportedUpdateCommand(CommandType commandType) throws SQLException { if (commandType != CommandType.INSERT && commandType != CommandType.UPDATE && commandType != CommandType.DELETE) { @@ -274,22 +271,6 @@ private void checkSupportedUpdateCommand(CommandType commandType) throws SQLExce } } - void checkSupportedBatchCommand(BsonDocument command) throws SQLException { - var commandType = getCommandType(command); - if (commandType == CommandType.AGGREGATE) { - // The method executeBatch throws a BatchUpdateException if any of the commands in the batch attempts to - // return a result set. - throw new BatchUpdateException( - format( - "Commands returning result set are not supported. Received command: %s", - commandType.getCommandName()), - null, - DEFAULT_ERROR_CODE, - null); - } - checkSupportedUpdateCommand(commandType); - } - static BsonDocument parse(String mql) throws SQLSyntaxErrorException { try { return BsonDocument.parse(mql); @@ -319,86 +300,7 @@ private MongoCollection getCollection(CommandType commandType, Bso return mongoDatabase.getCollection(collectionName, BsonDocument.class); } - private static void convertToWriteModels( - CommandType commandType, BsonDocument command, Collection> writeModels) - throws SQLFeatureNotSupportedException { - switch (commandType) { - case INSERT: - var documents = command.getArray("documents"); - for (var insertDocument : documents) { - writeModels.add(createInsertModel(insertDocument.asDocument())); - } - break; - case UPDATE: - var updates = command.getArray("updates").getValues(); - for (var updateDocument : updates) { - writeModels.add(createUpdateModel(updateDocument.asDocument())); - } - break; - case DELETE: - var deletes = command.getArray("deletes"); - for (var deleteDocument : deletes) { - writeModels.add(createDeleteModel(deleteDocument.asDocument())); - } - break; - default: - throw fail(); - } - } - - private static WriteModel createInsertModel(final BsonDocument insertDocument) { - return new InsertOneModel<>(insertDocument); - } - - private static WriteModel createDeleteModel(final BsonDocument deleteDocument) - throws SQLFeatureNotSupportedException { - checkDeleteElements(deleteDocument); - var isSingleDelete = deleteDocument.getNumber("limit").intValue() == 1; - var queryFilter = deleteDocument.getDocument("q"); - - if (isSingleDelete) { - return new DeleteOneModel<>(queryFilter); - } - return new DeleteManyModel<>(queryFilter); - } - - private static WriteModel createUpdateModel(final BsonDocument updateDocument) - throws SQLFeatureNotSupportedException { - checkUpdateElements(updateDocument); - var isMulti = updateDocument.getBoolean("multi").getValue(); - var queryFilter = updateDocument.getDocument("q"); - var updatePipeline = updateDocument.getDocument("u"); - - if (isMulti) { - return new UpdateManyModel<>(queryFilter, updatePipeline); - } - return new UpdateOneModel<>(queryFilter, updatePipeline); - } - - private static void checkDeleteElements(final BsonDocument deleteDocument) throws SQLFeatureNotSupportedException { - if (deleteDocument.size() > SUPPORTED_DELETE_COMMAND_ELEMENTS.size()) { - var unSupportedElements = getUnsupportedElements(deleteDocument, SUPPORTED_DELETE_COMMAND_ELEMENTS); - throw new SQLFeatureNotSupportedException( - format("Unsupported elements in delete command: %s", unSupportedElements)); - } - } - - private static void checkUpdateElements(final BsonDocument updateDocument) throws SQLFeatureNotSupportedException { - if (updateDocument.size() > SUPPORTED_UPDATE_COMMAND_ELEMENTS.size()) { - var unSupportedElements = getUnsupportedElements(updateDocument, SUPPORTED_UPDATE_COMMAND_ELEMENTS); - throw new SQLFeatureNotSupportedException( - format("Unsupported elements in update command: %s", unSupportedElements)); - } - } - - private static List getUnsupportedElements( - final BsonDocument deleteDocument, final List supportedElements) { - return deleteDocument.keySet().stream() - .filter((key) -> !supportedElements.contains(key)) - .toList(); - } - - static int getUpdateCount(CommandType commandType, BulkWriteResult bulkWriteResult) { + private static int getUpdateCount(CommandType commandType, BulkWriteResult bulkWriteResult) { return switch (commandType) { case INSERT -> bulkWriteResult.getInsertedCount(); case UPDATE -> bulkWriteResult.getModifiedCount(); @@ -408,40 +310,64 @@ static int getUpdateCount(CommandType commandType, BulkWriteResult bulkWriteResu } private static SQLException handleException( - RuntimeException exception, CommandType commandType, ExecutionType executionType) { - int errorCode = getErrorCode(exception); + RuntimeException exceptionToHandle, CommandType commandType, ExecutionType executionType) { + var errorCode = getErrorCode(exceptionToHandle); return switch (executionType) { - case BATCH -> handleBatchException(exception, commandType, errorCode); + case BATCH -> handleBatchException(exceptionToHandle, commandType, errorCode); case QUERY, UPDATE -> { - if (exception instanceof MongoException mongoException) { - Exception handledException = handleMongoException(mongoException, errorCode); + if (exceptionToHandle instanceof MongoException mongoException) { + var handledException = handleMongoException(mongoException, errorCode); yield toSqlException(errorCode, handledException); } - yield toSqlException(DEFAULT_ERROR_CODE, exception); + yield toSqlException(NO_ERROR_CODE, exceptionToHandle); } }; } private static SQLException handleBatchException( - RuntimeException exception, CommandType commandType, int errorCode) { - if (exception instanceof MongoException mongoException) { - Exception cause = handleMongoException(mongoException, errorCode); - if (exception instanceof MongoBulkWriteException bulkWriteException) { + RuntimeException exceptionToHandle, CommandType commandType, int errorCode) { + if (exceptionToHandle instanceof MongoException mongoException) { + var cause = handleMongoException(mongoException, errorCode); + if (exceptionToHandle instanceof MongoBulkWriteException bulkWriteException) { return createBatchUpdateException(cause, bulkWriteException.getWriteResult(), errorCode, commandType); } - return toBatchUpdateException(errorCode, cause); + return createBatchUpdateException(errorCode, cause); } - return toBatchUpdateException(DEFAULT_ERROR_CODE, exception); + return createBatchUpdateException(NO_ERROR_CODE, exceptionToHandle); + } + + private static Exception handleMongoException(final MongoException exceptionToHandle, final int errorCode) { + if (isTimeoutException(exceptionToHandle)) { + return new SQLTimeoutException(EXCEPTION_MESSAGE_TIMEOUT, null, errorCode, exceptionToHandle); + } + return handleByErrorCode(errorCode, exceptionToHandle); + } + + private static Exception handleByErrorCode(int errorCode, final MongoException exceptionToHandle) { + var errorCategory = ErrorCategory.fromErrorCode(errorCode); + return switch (errorCategory) { + case DUPLICATE_KEY -> + new SQLIntegrityConstraintViolationException( + EXCEPTION_MESSAGE_OPERATION_FAILED, null, errorCode, exceptionToHandle); + case EXECUTION_TIMEOUT -> + new SQLTimeoutException(EXCEPTION_MESSAGE_TIMEOUT, null, errorCode, exceptionToHandle); + case UNCATEGORIZED -> exceptionToHandle; + }; } private static int getErrorCode(final RuntimeException runtimeException) { if (runtimeException instanceof MongoBulkWriteException mongoBulkWriteException) { return getErrorCode(mongoBulkWriteException); + } else if (runtimeException instanceof MongoException mongoException) { + return max(NO_ERROR_CODE, mongoException.getCode()); } - if (runtimeException instanceof MongoException mongoException) { - return max(DEFAULT_ERROR_CODE, mongoException.getCode()); - } - return DEFAULT_ERROR_CODE; + return NO_ERROR_CODE; + } + + private static int getErrorCode(final MongoBulkWriteException mongoBulkWriteException) { + var writeErrors = mongoBulkWriteException.getWriteErrors(); + // Since we are executing an ordered bulk write, there will be at most one BulkWriteError. + return writeErrors.isEmpty() ? NO_ERROR_CODE : writeErrors.get(0).getCode(); } private static SQLException toSqlException(final int errorCode, final Exception exception) { @@ -451,17 +377,18 @@ private static SQLException toSqlException(final int errorCode, final Exception return new SQLException(EXCEPTION_MESSAGE_OPERATION_FAILED, null, errorCode, exception); } - private static Exception handleMongoException(final MongoException exceptionToHandle, final int errorCode) { - if (isTimeoutException(exceptionToHandle)) { - return new SQLTimeoutException(EXCEPTION_MESSAGE_TIMEOUT, null, errorCode, exceptionToHandle); - } - return handleByErrorCode(errorCode, exceptionToHandle); + private static SQLException createBatchUpdateException(final int errorCode, final Exception cause) { + return withCause( + new BatchUpdateException(EXCEPTION_MESSAGE_BATCH_FAILED, null, errorCode, EMPTY_BATCH_RESULT), cause); } - private static SQLException toBatchUpdateException(final int errorCode, final Exception exception) { + private static BatchUpdateException createBatchUpdateException( + Exception cause, BulkWriteResult bulkWriteResult, int errorCode, CommandType commandType) { + var updateCount = getUpdateCount(commandType, bulkWriteResult); + var updateCounts = new int[updateCount]; + Arrays.fill(updateCounts, Statement.SUCCESS_NO_INFO); return withCause( - new BatchUpdateException(EXCEPTION_MESSAGE_BATCH_FAILED, null, errorCode, EMPTY_BATCH_RESULT), - exception); + new BatchUpdateException(EXCEPTION_MESSAGE_BATCH_FAILED, null, errorCode, updateCounts), cause); } private static T withCause(T sqlException, final Exception cause) { @@ -472,17 +399,6 @@ private static T withCause(T sqlException, final Except return sqlException; } - private static Exception handleByErrorCode(int errorCode, final MongoException cause) { - ErrorCategory errorCategory = ErrorCategory.fromErrorCode(errorCode); - return switch (errorCategory) { - case DUPLICATE_KEY -> - new SQLIntegrityConstraintViolationException( - EXCEPTION_MESSAGE_OPERATION_FAILED, null, errorCode, cause); - case EXECUTION_TIMEOUT -> new SQLTimeoutException(EXCEPTION_MESSAGE_TIMEOUT, null, errorCode, cause); - case UNCATEGORIZED -> cause; - }; - } - private static boolean isTimeoutException(final MongoException exception) { return exception instanceof MongoSocketReadTimeoutException || exception instanceof MongoSocketWriteTimeoutException @@ -490,23 +406,6 @@ private static boolean isTimeoutException(final MongoException exception) { || exception instanceof MongoExecutionTimeoutException; } - private static BatchUpdateException createBatchUpdateException( - Exception cause, BulkWriteResult bulkWriteResult, int errorCode, CommandType commandType) { - var updateCount = getUpdateCount(commandType, bulkWriteResult); - var updateCounts = new int[updateCount]; - Arrays.fill(updateCounts, Statement.SUCCESS_NO_INFO); - return withCause( - new BatchUpdateException(EXCEPTION_MESSAGE_BATCH_FAILED, null, errorCode, updateCounts), cause); - } - - private static int getErrorCode(final MongoBulkWriteException mongoBulkWriteException) { - var writeErrors = mongoBulkWriteException.getWriteErrors(); - // Since we are executing an ordered bulk write, there will be at most one BulkWriteError. - return writeErrors.isEmpty() - ? DEFAULT_ERROR_CODE - : writeErrors.get(DEFAULT_ERROR_CODE).getCode(); - } - enum CommandType { INSERT("insert"), UPDATE("update"), @@ -539,4 +438,90 @@ enum ExecutionType { BATCH, QUERY } + + private static class WriteModelConverter { + private static final List SUPPORTED_UPDATE_COMMAND_ELEMENTS = List.of("q", "u", "multi"); + private static final List SUPPORTED_DELETE_COMMAND_ELEMENTS = List.of("q", "limit"); + + static void convertToWriteModels( + CommandType commandType, BsonDocument command, Collection> writeModels) + throws SQLFeatureNotSupportedException { + switch (commandType) { + case INSERT: + var documents = command.getArray("documents"); + for (var insertDocument : documents) { + writeModels.add(createInsertModel(insertDocument.asDocument())); + } + break; + case UPDATE: + var updates = command.getArray("updates").getValues(); + for (var updateDocument : updates) { + writeModels.add(createUpdateModel(updateDocument.asDocument())); + } + break; + case DELETE: + var deletes = command.getArray("deletes"); + for (var deleteDocument : deletes) { + writeModels.add(createDeleteModel(deleteDocument.asDocument())); + } + break; + default: + throw fail(); + } + } + + private static WriteModel createInsertModel(final BsonDocument insertDocument) { + return new InsertOneModel<>(insertDocument); + } + + private static WriteModel createDeleteModel(final BsonDocument deleteDocument) + throws SQLFeatureNotSupportedException { + checkDeleteElements(deleteDocument); + var isSingleDelete = deleteDocument.getNumber("limit").intValue() == 1; + var queryFilter = deleteDocument.getDocument("q"); + + if (isSingleDelete) { + return new DeleteOneModel<>(queryFilter); + } + return new DeleteManyModel<>(queryFilter); + } + + private static WriteModel createUpdateModel(final BsonDocument updateDocument) + throws SQLFeatureNotSupportedException { + checkUpdateElements(updateDocument); + var isMulti = updateDocument.getBoolean("multi").getValue(); + var queryFilter = updateDocument.getDocument("q"); + var updatePipeline = updateDocument.getDocument("u"); + + if (isMulti) { + return new UpdateManyModel<>(queryFilter, updatePipeline); + } + return new UpdateOneModel<>(queryFilter, updatePipeline); + } + + private static void checkDeleteElements(final BsonDocument deleteDocument) + throws SQLFeatureNotSupportedException { + if (deleteDocument.size() > SUPPORTED_DELETE_COMMAND_ELEMENTS.size()) { + var unSupportedElements = getUnsupportedElements(deleteDocument, SUPPORTED_DELETE_COMMAND_ELEMENTS); + throw new SQLFeatureNotSupportedException( + format("Unsupported elements in delete command: %s", unSupportedElements)); + } + } + + private static void checkUpdateElements(final BsonDocument updateDocument) + throws SQLFeatureNotSupportedException { + if (updateDocument.size() > SUPPORTED_UPDATE_COMMAND_ELEMENTS.size()) { + var unSupportedElements = getUnsupportedElements(updateDocument, SUPPORTED_UPDATE_COMMAND_ELEMENTS); + throw new SQLFeatureNotSupportedException( + format("Unsupported elements in update command: %s", unSupportedElements)); + } + } + + private static List getUnsupportedElements( + final BsonDocument deleteDocument, final List supportedElements) { + return deleteDocument.keySet().stream() + .filter((key) -> !supportedElements.contains(key)) + .toList(); + } + } } diff --git a/src/test/java/com/mongodb/hibernate/jdbc/MongoPreparedStatementTests.java b/src/test/java/com/mongodb/hibernate/jdbc/MongoPreparedStatementTests.java index 4abcb996..2a376611 100644 --- a/src/test/java/com/mongodb/hibernate/jdbc/MongoPreparedStatementTests.java +++ b/src/test/java/com/mongodb/hibernate/jdbc/MongoPreparedStatementTests.java @@ -22,11 +22,12 @@ import static java.util.Collections.emptySet; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatExceptionOfType; +import static org.assertj.core.api.Assertions.assertThatObject; +import static org.assertj.core.api.InstanceOfAssertFactories.type; import static org.junit.jupiter.api.Assertions.assertAll; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertFalse; import static org.junit.jupiter.api.Assertions.assertInstanceOf; -import static org.junit.jupiter.api.Assertions.assertNull; import static org.junit.jupiter.api.Assertions.assertThrows; import static org.junit.jupiter.api.Assertions.assertTrue; import static org.junit.jupiter.api.Named.named; @@ -264,38 +265,41 @@ private static Stream genericMongoExceptions() { new MongoException(-3, DUMMY_EXCEPTION_MESSAGE), new MongoException(5000, DUMMY_EXCEPTION_MESSAGE)); } - @ParameterizedTest(name = "test executeBatch MongoException. Parameters: Parameters: exception: {0}") + @ParameterizedTest(name = "test executeBatch MongoException. Parameters: Parameters: mongoException: {0}") @MethodSource("genericMongoExceptions") void testExecuteBatchMongoException(MongoException mongoException) throws SQLException { int expectedErrorCode = max(0, mongoException.getCode()); doThrow(mongoException).when(mongoCollection).bulkWrite(eq(clientSession), anyList()); assertExecuteBatchThrowsSqlException(batchUpdateException -> { - assertAll( - () -> assertEquals(expectedErrorCode, batchUpdateException.getErrorCode()), - () -> assertNull(batchUpdateException.getSQLState()), - () -> assertEquals(mongoException, batchUpdateException.getCause()), - () -> assertUpdateCounts(batchUpdateException.getUpdateCounts(), 0)); + assertThatObject(batchUpdateException) + .returns(expectedErrorCode, BatchUpdateException::getErrorCode) + .returns(null, BatchUpdateException::getSQLState) + .returns(mongoException, SQLException::getCause) + .satisfies(exception -> { + assertUpdateCounts(batchUpdateException.getUpdateCounts(), 0); + }); }); } - @ParameterizedTest(name = "test executeUpdate MongoException. Parameters: Parameters: exception: {0}") + @ParameterizedTest(name = "test executeUpdate MongoException. Parameters: Parameters: mongoException: {0}") @MethodSource("genericMongoExceptions") void testExecuteUpdateMongoException(MongoException mongoException) throws SQLException { doThrow(mongoException).when(mongoCollection).bulkWrite(eq(clientSession), anyList()); assertExecuteUpdateThrowsSqlException( - sqlException -> assertGenericMongoException(mongoException, sqlException)); + sqlException -> assertGenericMongoException(sqlException, mongoException)); } - @ParameterizedTest(name = "test executeUQuery MongoException. Parameters: Parameters: exception: {0}") + @ParameterizedTest(name = "test executeUQuery MongoException. Parameters: Parameters: mongoException: {0}") @MethodSource("genericMongoExceptions") void testExecuteQueryMongoException(MongoException mongoException) throws SQLException { doThrow(mongoException).when(mongoCollection).aggregate(eq(clientSession), anyList()); assertExecuteQueryThrowsSqlException( - sqlException -> assertGenericMongoException(mongoException, sqlException)); + sqlException -> assertGenericMongoException(sqlException, mongoException)); } - @ParameterizedTest(name = "test executeUpdate timeout exception. Parameters: Parameters: exception: {0}") + @ParameterizedTest( + name = "test executeUpdate timeout exception. Parameters: Parameters: mongoTimeoutException: {0}") @MethodSource("timeoutExceptions") void testExecuteUpdateTimeoutException(MongoException mongoTimeoutException) throws SQLException { doThrow(mongoTimeoutException).when(mongoCollection).bulkWrite(eq(clientSession), anyList()); @@ -303,7 +307,7 @@ void testExecuteUpdateTimeoutException(MongoException mongoTimeoutException) thr sqlException -> assertTimeoutException(mongoTimeoutException, sqlException)); } - @ParameterizedTest(name = "test executeQuery timeout exception. Parameters: exception: {0}") + @ParameterizedTest(name = "test executeQuery timeout exception. Parameters: mongoTimeoutException: {0}") @MethodSource("timeoutExceptions") void testExecuteQueryTimeoutException(MongoException mongoTimeoutException) throws SQLException { doThrow(mongoTimeoutException).when(mongoCollection).aggregate(eq(clientSession), anyList()); @@ -311,7 +315,7 @@ void testExecuteQueryTimeoutException(MongoException mongoTimeoutException) thro sqlException -> assertTimeoutException(mongoTimeoutException, sqlException)); } - @ParameterizedTest(name = "test executeUpdate constraint violation. Parameters: exception: {0}") + @ParameterizedTest(name = "test executeUpdate constraint violation. Parameters: mongoException: {0}") @MethodSource("constraintViolationExceptions") void testExecuteUpdateConstraintViolationException(MongoException mongoException) throws SQLException { int expectedErrorCode = mongoException.getCode(); @@ -322,7 +326,7 @@ void testExecuteUpdateConstraintViolationException(MongoException mongoException }); } - @ParameterizedTest(name = "test executeQuery constraint violation. Parameters: exception: {0}") + @ParameterizedTest(name = "test executeQuery constraint violation. Parameters: mongoException: {0}") @MethodSource("constraintViolationExceptions") void testExecuteQueryConstraintViolationException(MongoException mongoException) throws SQLException { int expectedErrorCode = mongoException.getCode(); @@ -335,55 +339,49 @@ void testExecuteQueryConstraintViolationException(MongoException mongoException) private static void assertConstraintViolationException( final MongoException mongoException, final SQLException sqlException, final int expectedErrorCode) { - assertAll( - () -> assertEquals(expectedErrorCode, sqlException.getErrorCode()), - () -> { - SQLIntegrityConstraintViolationException sqlIntegrityConstraintViolationException = - assertInstanceOf(SQLIntegrityConstraintViolationException.class, sqlException); - assertEquals(expectedErrorCode, sqlIntegrityConstraintViolationException.getErrorCode()); - assertEquals(mongoException, sqlIntegrityConstraintViolationException.getCause()); - }, - () -> assertNull(sqlException.getSQLState())); - } - - @ParameterizedTest(name = "test executeBatch timeout exception. Parameters: exception: {0}") + assertThatObject(sqlException) + .asInstanceOf(type(SQLIntegrityConstraintViolationException.class)) + .returns(expectedErrorCode, SQLIntegrityConstraintViolationException::getErrorCode) + .returns(null, SQLIntegrityConstraintViolationException::getSQLState) + .returns(mongoException, SQLIntegrityConstraintViolationException::getCause); + } + + @ParameterizedTest(name = "test executeBatch timeout exception. Parameters: mongoTimeoutException: {0}") @MethodSource("timeoutExceptions") void testExecuteBatchTimeoutException(MongoException mongoTimeoutException) throws SQLException { doThrow(mongoTimeoutException).when(mongoCollection).bulkWrite(eq(clientSession), anyList()); assertExecuteBatchThrowsSqlException(batchUpdateException -> { int expectedErrorCode = max(0, mongoTimeoutException.getCode()); - assertAll( - () -> assertEquals(expectedErrorCode, batchUpdateException.getErrorCode()), - () -> { - SQLTimeoutException sqlTimeoutException = - assertInstanceOf(SQLTimeoutException.class, batchUpdateException.getCause()); - assertEquals(expectedErrorCode, sqlTimeoutException.getErrorCode()); - assertEquals(mongoTimeoutException, sqlTimeoutException.getCause()); - }, - () -> assertUpdateCounts(batchUpdateException.getUpdateCounts(), 0), - () -> assertNull(batchUpdateException.getSQLState())); + assertThatObject(batchUpdateException) + .returns(expectedErrorCode, BatchUpdateException::getErrorCode) + .returns(null, BatchUpdateException::getSQLState) + .satisfies(ex -> { + assertUpdateCounts(ex.getUpdateCounts(), 0); + }) + .extracting(SQLException::getCause) + .asInstanceOf(type(SQLTimeoutException.class)) + .returns(expectedErrorCode, SQLTimeoutException::getErrorCode) + .returns(mongoTimeoutException, SQLTimeoutException::getCause); }); } - @ParameterizedTest(name = "test executeBatch constraint violation. Parameters: exception: {0}") + @ParameterizedTest(name = "test executeBatch constraint violation. Parameters: mongoException: {0}") @MethodSource("constraintViolationExceptions") void testExecuteBatchConstraintViolationException(MongoException mongoException) throws SQLException { int expectedErrorCode = mongoException.getCode(); doThrow(mongoException).when(mongoCollection).bulkWrite(eq(clientSession), anyList()); assertExecuteBatchThrowsSqlException(batchUpdateException -> { - assertAll( - () -> assertEquals(expectedErrorCode, batchUpdateException.getErrorCode()), - () -> { - SQLIntegrityConstraintViolationException sqlIntegrityConstraintViolationException = - assertInstanceOf( - SQLIntegrityConstraintViolationException.class, - batchUpdateException.getCause()); - assertEquals(expectedErrorCode, sqlIntegrityConstraintViolationException.getErrorCode()); - assertEquals(mongoException, sqlIntegrityConstraintViolationException.getCause()); - }, - () -> assertUpdateCounts(batchUpdateException.getUpdateCounts(), 0), - () -> assertNull(batchUpdateException.getSQLState())); + assertThatObject(batchUpdateException) + .returns(expectedErrorCode, BatchUpdateException::getErrorCode) + .returns(null, BatchUpdateException::getSQLState) + .satisfies(ex -> { + assertUpdateCounts(ex.getUpdateCounts(), 0); + }) + .extracting(SQLException::getCause) + .asInstanceOf(type(SQLIntegrityConstraintViolationException.class)) + .returns(expectedErrorCode, SQLIntegrityConstraintViolationException::getErrorCode) + .returns(mongoException, SQLIntegrityConstraintViolationException::getCause); }); } @@ -392,11 +390,13 @@ void testExecuteBatchRuntimeExceptionCause() throws SQLException { RuntimeException runtimeException = new RuntimeException(); doThrow(runtimeException).when(mongoCollection).bulkWrite(eq(clientSession), anyList()); assertExecuteBatchThrowsSqlException(batchUpdateException -> { - assertAll( - () -> assertEquals(runtimeException, batchUpdateException.getCause()), - () -> assertEquals(0, batchUpdateException.getErrorCode()), - () -> assertUpdateCounts(batchUpdateException.getUpdateCounts(), 0), - () -> assertNull(batchUpdateException.getSQLState())); + assertThatObject(batchUpdateException) + .returns(0, BatchUpdateException::getErrorCode) + .returns(null, BatchUpdateException::getSQLState) + .returns(runtimeException, BatchUpdateException::getCause) + .satisfies(ex -> { + assertUpdateCounts(ex.getUpdateCounts(), 0); + }); }); } @@ -438,9 +438,8 @@ void testExecuteUpdateMongoBulkWriteException(String mql, MongoBulkWriteExceptio assertThatExceptionOfType(SQLException.class) .isThrownBy(mongoPreparedStatement::executeUpdate) .withCause(mongoBulkWriteException) - .satisfies(sqlException -> assertAll( - () -> assertEquals(vendorCodeError, sqlException.getErrorCode()), - () -> assertNull(sqlException.getSQLState()))); + .returns(vendorCodeError, SQLException::getErrorCode) + .returns(null, SQLException::getSQLState); } } @@ -486,48 +485,39 @@ void testExecuteBatchMongoBulkWriteException( assertThatExceptionOfType(BatchUpdateException.class) .isThrownBy(mongoPreparedStatement::executeBatch) .withCause(mongoBulkWriteException) - .satisfies(batchUpdateException -> { - assertAll( - () -> assertEquals(vendorCodeError, batchUpdateException.getErrorCode()), - () -> assertNull(batchUpdateException.getSQLState()), - () -> assertUpdateCounts( - batchUpdateException.getUpdateCounts(), expectedUpdateCountLength), - () -> assertEquals(vendorCodeError, batchUpdateException.getErrorCode())); + .returns(vendorCodeError, BatchUpdateException::getErrorCode) + .returns(null, BatchUpdateException::getSQLState) + .satisfies(ex -> { + assertUpdateCounts(ex.getUpdateCounts(), expectedUpdateCountLength); }); } } private static void assertGenericException(final SQLException sqlException, RuntimeException cause) { - assertAll( - () -> assertThat((Throwable) sqlException).isExactlyInstanceOf(SQLException.class), - () -> assertEquals(cause, sqlException.getCause()), - () -> assertEquals(0, sqlException.getErrorCode()), - () -> assertNull(sqlException.getSQLState())); + assertThatObject(sqlException) + .isExactlyInstanceOf(SQLException.class) + .returns(0, SQLException::getErrorCode) + .returns(null, SQLException::getSQLState) + .returns(cause, SQLException::getCause); } - private static void assertGenericMongoException( - final MongoException mongoException, final SQLException sqlException) { - int expectedErrorCode = max(0, mongoException.getCode()); - assertAll( - () -> assertEquals(expectedErrorCode, sqlException.getErrorCode()), - () -> assertNull(sqlException.getSQLState()), - () -> assertEquals(mongoException, sqlException.getCause()), - () -> assertThat((Throwable) sqlException).isExactlyInstanceOf(SQLException.class)); + private static void assertGenericMongoException(final SQLException sqlException, final MongoException cause) { + int expectedErrorCode = max(0, cause.getCode()); + assertThatObject(sqlException) + .isExactlyInstanceOf(SQLException.class) + .returns(expectedErrorCode, SQLException::getErrorCode) + .returns(null, SQLException::getSQLState) + .returns(cause, SQLException::getCause); } private static void assertTimeoutException( final MongoException mongoTimeoutException, final SQLException sqlException) { int expectedErrorCode = max(0, mongoTimeoutException.getCode()); - assertAll( - () -> assertEquals(expectedErrorCode, sqlException.getErrorCode()), - () -> assertNull(sqlException.getSQLState()), - () -> { - SQLTimeoutException sqlTimeoutException = - assertInstanceOf(SQLTimeoutException.class, sqlException); - assertEquals(expectedErrorCode, sqlTimeoutException.getErrorCode()); - assertNull(sqlTimeoutException.getSQLState()); - assertEquals(mongoTimeoutException, sqlTimeoutException.getCause()); - }); + assertThatObject(sqlException) + .asInstanceOf(type(SQLTimeoutException.class)) + .returns(expectedErrorCode, SQLTimeoutException::getErrorCode) + .returns(null, SQLTimeoutException::getSQLState) + .returns(mongoTimeoutException, SQLTimeoutException::getCause); } private void assertExecuteBatchThrowsSqlException(ThrowingConsumer asserter) From 76acbc5c87b3234a6abcfe71f7c2f85505fae1d7 Mon Sep 17 00:00:00 2001 From: Viacheslav Babanin Date: Fri, 24 Oct 2025 21:42:51 -0700 Subject: [PATCH 13/37] Update src/integrationTest/java/com/mongodb/hibernate/jdbc/MongoPreparedStatementIntegrationTests.java Co-authored-by: Valentin Kovalenko --- .../jdbc/MongoPreparedStatementIntegrationTests.java | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/integrationTest/java/com/mongodb/hibernate/jdbc/MongoPreparedStatementIntegrationTests.java b/src/integrationTest/java/com/mongodb/hibernate/jdbc/MongoPreparedStatementIntegrationTests.java index df4041d3..31454faf 100644 --- a/src/integrationTest/java/com/mongodb/hibernate/jdbc/MongoPreparedStatementIntegrationTests.java +++ b/src/integrationTest/java/com/mongodb/hibernate/jdbc/MongoPreparedStatementIntegrationTests.java @@ -673,9 +673,9 @@ void testNotSupportedCommands(String commandName) { %s: "books" }""", commandName))) { - SQLFeatureNotSupportedException exception = - assertThrows(SQLFeatureNotSupportedException.class, pstm::executeUpdate); - assertThat(exception.getMessage()).contains(commandName); + assertThatThrownBy(pstm::executeUpdate) + .isInstanceOf(SQLFeatureNotSupportedException.class) + .hasMessageContaining(commandName); } }); } From 39df511f8b55f4aaa69fadf25f0156c0bbe0728f Mon Sep 17 00:00:00 2001 From: Viacheslav Babanin Date: Fri, 24 Oct 2025 21:43:11 -0700 Subject: [PATCH 14/37] Update src/integrationTest/java/com/mongodb/hibernate/jdbc/MongoPreparedStatementIntegrationTests.java Co-authored-by: Valentin Kovalenko --- .../jdbc/MongoPreparedStatementIntegrationTests.java | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/src/integrationTest/java/com/mongodb/hibernate/jdbc/MongoPreparedStatementIntegrationTests.java b/src/integrationTest/java/com/mongodb/hibernate/jdbc/MongoPreparedStatementIntegrationTests.java index 31454faf..e1b92f1e 100644 --- a/src/integrationTest/java/com/mongodb/hibernate/jdbc/MongoPreparedStatementIntegrationTests.java +++ b/src/integrationTest/java/com/mongodb/hibernate/jdbc/MongoPreparedStatementIntegrationTests.java @@ -723,10 +723,9 @@ void testNotSupportedDeleteElements(String unsupportedElement) { ] }""", unsupportedElement))) { - SQLFeatureNotSupportedException exception = - assertThrows(SQLFeatureNotSupportedException.class, pstm::executeUpdate); - assertThat(exception.getMessage()) - .isEqualTo(format("Unsupported elements in delete command: [%s]", unsupportedElement)); + assertThatThrownBy(pstm::executeUpdate) + .isInstanceOf(SQLFeatureNotSupportedException.class) + .hasMessage("Unsupported elements in delete command: [%s]".formatted(unsupportedElement)); } }); } From 11a2b07794aa294548b3e08a25ab8c95b3846208 Mon Sep 17 00:00:00 2001 From: Viacheslav Babanin Date: Fri, 24 Oct 2025 21:43:17 -0700 Subject: [PATCH 15/37] Update src/integrationTest/java/com/mongodb/hibernate/jdbc/MongoPreparedStatementIntegrationTests.java Co-authored-by: Valentin Kovalenko --- .../jdbc/MongoPreparedStatementIntegrationTests.java | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/src/integrationTest/java/com/mongodb/hibernate/jdbc/MongoPreparedStatementIntegrationTests.java b/src/integrationTest/java/com/mongodb/hibernate/jdbc/MongoPreparedStatementIntegrationTests.java index e1b92f1e..5110984a 100644 --- a/src/integrationTest/java/com/mongodb/hibernate/jdbc/MongoPreparedStatementIntegrationTests.java +++ b/src/integrationTest/java/com/mongodb/hibernate/jdbc/MongoPreparedStatementIntegrationTests.java @@ -698,10 +698,9 @@ void testNotSupportedUpdateElements(String unsupportedElement) { ] }""", unsupportedElement))) { - SQLFeatureNotSupportedException exception = - assertThrows(SQLFeatureNotSupportedException.class, pstm::executeUpdate); - assertThat(exception.getMessage()) - .isEqualTo(format("Unsupported elements in update command: [%s]", unsupportedElement)); + assertThatThrownBy(pstm::executeUpdate) + .isInstanceOf(SQLFeatureNotSupportedException.class) + .hasMessage("Unsupported elements in update command: [%s]".formatted(unsupportedElement)); } }); } From 510bc2d21340e108773e284fc9ec10ed9efc508b Mon Sep 17 00:00:00 2001 From: Viacheslav Babanin Date: Mon, 27 Oct 2025 13:39:42 -0700 Subject: [PATCH 16/37] Update src/integrationTest/java/com/mongodb/hibernate/jdbc/MongoPreparedStatementIntegrationTests.java Co-authored-by: Valentin Kovalenko --- .../hibernate/jdbc/MongoPreparedStatementIntegrationTests.java | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/integrationTest/java/com/mongodb/hibernate/jdbc/MongoPreparedStatementIntegrationTests.java b/src/integrationTest/java/com/mongodb/hibernate/jdbc/MongoPreparedStatementIntegrationTests.java index 5110984a..e0481fb7 100644 --- a/src/integrationTest/java/com/mongodb/hibernate/jdbc/MongoPreparedStatementIntegrationTests.java +++ b/src/integrationTest/java/com/mongodb/hibernate/jdbc/MongoPreparedStatementIntegrationTests.java @@ -283,8 +283,7 @@ void testQueriesReturningResult() { void testEmptyBatch() { doWorkAwareOfAutoCommit(connection -> { try (var pstmt = connection.prepareStatement(INSERT_MQL)) { - var updateCounts = pstmt.executeBatch(); - assertEquals(0, updateCounts.length); + assertExecuteBatch(pstmt, 0); } catch (SQLException e) { throw new RuntimeException(e); } From 6faa9cb9f7a1bd693e2b352b2589ef372f87f329 Mon Sep 17 00:00:00 2001 From: "slav.babanin" Date: Mon, 27 Oct 2025 17:58:25 -0700 Subject: [PATCH 17/37] Fix batch update count reporting. Remove redundant tests. --- ...ongoPreparedStatementIntegrationTests.java | 389 ++++++++++++-- .../jdbc/MongoPreparedStatement.java | 32 +- .../hibernate/jdbc/MongoStatement.java | 473 +++++++++++------- .../jdbc/MongoPreparedStatementTests.java | 279 ++++++----- 4 files changed, 803 insertions(+), 370 deletions(-) diff --git a/src/integrationTest/java/com/mongodb/hibernate/jdbc/MongoPreparedStatementIntegrationTests.java b/src/integrationTest/java/com/mongodb/hibernate/jdbc/MongoPreparedStatementIntegrationTests.java index e0481fb7..3ccfd3cf 100644 --- a/src/integrationTest/java/com/mongodb/hibernate/jdbc/MongoPreparedStatementIntegrationTests.java +++ b/src/integrationTest/java/com/mongodb/hibernate/jdbc/MongoPreparedStatementIntegrationTests.java @@ -16,6 +16,7 @@ package com.mongodb.hibernate.jdbc; +import static com.mongodb.hibernate.internal.MongoConstants.EXTENDED_JSON_WRITER_SETTINGS; import static com.mongodb.hibernate.internal.MongoConstants.ID_FIELD_NAME; import static com.mongodb.hibernate.jdbc.MongoStatementIntegrationTests.doAndTerminateTransaction; import static com.mongodb.hibernate.jdbc.MongoStatementIntegrationTests.doWithSpecifiedAutoCommit; @@ -24,13 +25,14 @@ import static java.sql.Statement.SUCCESS_NO_INFO; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatExceptionOfType; +import static org.assertj.core.api.AssertionsForClassTypes.assertThatThrownBy; import static org.junit.jupiter.api.Assertions.assertAll; import static org.junit.jupiter.api.Assertions.assertArrayEquals; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertFalse; import static org.junit.jupiter.api.Assertions.assertNull; -import static org.junit.jupiter.api.Assertions.assertThrows; import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.junit.jupiter.params.provider.Arguments.of; import com.mongodb.client.MongoCollection; import com.mongodb.client.model.Sorts; @@ -43,10 +45,12 @@ import java.sql.PreparedStatement; import java.sql.SQLException; import java.sql.SQLFeatureNotSupportedException; +import java.sql.SQLSyntaxErrorException; import java.util.ArrayList; import java.util.List; import java.util.Random; import java.util.function.Function; +import java.util.stream.Stream; import org.bson.BsonDocument; import org.hibernate.Session; import org.hibernate.SessionFactory; @@ -60,6 +64,8 @@ import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.ExtendWith; import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.MethodSource; import org.junit.jupiter.params.provider.ValueSource; @ExtendWith(MongoExtension.class) @@ -231,6 +237,75 @@ private void assertRoundTrip(SqlConsumer executor) { }); } + @Test + void testNoCommandNameProvidedExecuteQuery() { + assertInvalidMql( + """ + {}""", + PreparedStatement::executeQuery, + "Invalid MQL. Command name is missing: [{}]"); + } + + @Test + void testNoCommandNameProvidedExecuteUpdate() { + assertInvalidMql( + """ + {}""", + PreparedStatement::executeUpdate, + "Invalid MQL. Command name is missing: [{}]"); + } + + @Test + void testNoCommandNameProvidedExecuteBatch() { + assertInvalidMql( + """ + {}""", + preparedStatement -> { + preparedStatement.addBatch(); + preparedStatement.executeBatch(); + }, + "Invalid MQL. Command name is missing: [{}]"); + } + + @Test + void testNoCollectionNameProvidedExecuteQuery() { + assertInvalidMql( + """ + { + insert: {} + }""", + PreparedStatement::executeQuery, + """ + Invalid MQL. Collection name is missing [{"insert": {}}]"""); + } + + @Test + void testNoCollectionNameProvidedExecuteUpdate() { + assertInvalidMql( + """ + { + insert: {} + }""", + PreparedStatement::executeUpdate, + """ + Invalid MQL. Collection name is missing [{"insert": {}}]"""); + } + + @Test + void testNoCollectionNameProvidedExecuteBatch() { + assertInvalidMql( + """ + { + insert: {} + }""", + preparedStatement -> { + preparedStatement.addBatch(); + preparedStatement.executeBatch(); + }, + """ + Invalid MQL. Collection name is missing [{"insert": {}}]"""); + } + @Nested class ExecuteBatchTests { private static final String INSERT_MQL = @@ -271,8 +346,6 @@ void testQueriesReturningResult() { .returns(null, BatchUpdateException::getUpdateCounts) .returns(null, BatchUpdateException::getSQLState) .returns(0, BatchUpdateException::getErrorCode); - } catch (SQLException e) { - throw new RuntimeException(e); } }); @@ -283,7 +356,8 @@ void testQueriesReturningResult() { void testEmptyBatch() { doWorkAwareOfAutoCommit(connection -> { try (var pstmt = connection.prepareStatement(INSERT_MQL)) { - assertExecuteBatch(pstmt, 0); + var updateCounts = pstmt.executeBatch(); + assertEquals(0, updateCounts.length); } catch (SQLException e) { throw new RuntimeException(e); } @@ -293,7 +367,7 @@ void testEmptyBatch() { } @Test - @DisplayName("Test statement’s batch queue is reset once executeBatch returns") + @DisplayName("Test statement’s batch of commands is reset once executeBatch returns") void testBatchQueueIsResetAfterExecute() { doWorkAwareOfAutoCommit(connection -> { try (var pstmt = connection.prepareStatement( @@ -313,8 +387,6 @@ void testBatchQueueIsResetAfterExecute() { pstmt.addBatch(); assertExecuteBatch(pstmt, 1); assertExecuteBatch(pstmt, 0); - } catch (SQLException e) { - throw new RuntimeException(e); } }); @@ -353,8 +425,6 @@ void testBatchParametersReuse() { // No need to set title again, it should be reused from the previous execution pstmt.addBatch(); assertExecuteBatch(pstmt, 1); - } catch (SQLException e) { - throw new RuntimeException(e); } }); @@ -376,7 +446,7 @@ void testBatchParametersReuse() { @Test void testBatchInsert() { - var batchCount = 3; + var batchSize = 3; doWorkAwareOfAutoCommit(connection -> { try (var pstmt = connection.prepareStatement( """ @@ -388,26 +458,24 @@ void testBatchInsert() { }] }""")) { - for (int i = 1; i <= batchCount; i++) { + for (int i = 1; i <= batchSize; i++) { pstmt.setInt(1, i); pstmt.setString(2, "Book " + i); pstmt.addBatch(); } - assertExecuteBatch(pstmt, batchCount); - } catch (SQLException e) { - throw new RuntimeException(e); + assertExecuteBatch(pstmt, batchSize); } }); var expectedDocs = new ArrayList(); - for (int i = 0; i < batchCount; i++) { + for (int i = 1; i <= batchSize; i++) { expectedDocs.add(BsonDocument.parse(format( """ { "_id": %d, "title": "Book %d" }""", - i + 1, i + 1))); + i, i))); } assertThat(mongoCollection.find()).containsExactlyElementsOf(expectedDocs); } @@ -415,7 +483,7 @@ void testBatchInsert() { @Test void testBatchUpdate() { insertTestData(session, INSERT_MQL); - var batchCount = 3; + var batchSize = 3; doWorkAwareOfAutoCommit(connection -> { try (var pstmt = connection.prepareStatement( """ @@ -427,26 +495,24 @@ void testBatchUpdate() { multi: true }] }""")) { - for (int i = 1; i <= batchCount; i++) { + for (int i = 1; i <= batchSize; i++) { pstmt.setInt(1, i); pstmt.setString(2, "Book " + i); pstmt.addBatch(); } - assertExecuteBatch(pstmt, batchCount); - } catch (SQLException e) { - throw new RuntimeException(e); + assertExecuteBatch(pstmt, batchSize); } }); var expectedDocs = new ArrayList(); - for (int i = 0; i < batchCount; i++) { + for (int i = 1; i <= batchSize; i++) { expectedDocs.add(BsonDocument.parse(format( """ { "_id": %d, "title": "Book %d" }""", - i + 1, i + 1))); + i, i))); } assertThat(mongoCollection.find()).containsExactlyElementsOf(expectedDocs); } @@ -454,7 +520,7 @@ void testBatchUpdate() { @Test void testBatchDelete() { insertTestData(session, INSERT_MQL); - var batchCount = 3; + var batchSize = 3; doWorkAwareOfAutoCommit(connection -> { try (var pstmt = connection.prepareStatement( """ @@ -465,13 +531,11 @@ void testBatchDelete() { limit: 0 }] }""")) { - for (int i = 1; i <= batchCount; i++) { + for (int i = 1; i <= batchSize; i++) { pstmt.setInt(1, i); pstmt.addBatch(); } - assertExecuteBatch(pstmt, batchCount); - } catch (SQLException e) { - throw new RuntimeException(e); + assertExecuteBatch(pstmt, batchSize); } }); @@ -666,12 +730,12 @@ void testDelete() { @ValueSource(strings = {"findAndModify", "aggregate", "bulkWrite"}) void testNotSupportedCommands(String commandName) { doWorkAwareOfAutoCommit(connection -> { - try (PreparedStatement pstm = connection.prepareStatement(format( + try (PreparedStatement pstm = connection.prepareStatement( """ { %s: "books" - }""", - commandName))) { + }""" + .formatted(commandName))) { assertThatThrownBy(pstm::executeUpdate) .isInstanceOf(SQLFeatureNotSupportedException.class) .hasMessageContaining(commandName); @@ -679,11 +743,168 @@ void testNotSupportedCommands(String commandName) { }); } - @ParameterizedTest(name = "test not supported update elements. Parameters: option={0}") - @ValueSource(strings = {"hint", "collation", "arrayFilters", "sort", "upsert", "c"}) - void testNotSupportedUpdateElements(String unsupportedElement) { + @ParameterizedTest(name = "test not supported update command field. Parameters: option={0}") + @ValueSource( + strings = { + "maxTimeMS: 1", + "writeConcern: {}", + "bypassDocumentValidation: true", + "comment: {}", + "ordered: true", + "let: {}" + }) + void testNotSupportedUpdateCommandField(String unsupportedField) { doWorkAwareOfAutoCommit(connection -> { - try (PreparedStatement pstm = connection.prepareStatement(format( + try (PreparedStatement pstm = connection.prepareStatement( + """ + { + update: "books", + updates: [ + { + q: { author: { $eq: "Leo Tolstoy" } }, + u: { $set: { outOfStock: true } }, + multi: true + } + ], + %s + }""" + .formatted(unsupportedField))) { + assertThatThrownBy(pstm::executeUpdate) + .isInstanceOf(SQLFeatureNotSupportedException.class) + .hasMessage("Unsupported field in update command: [%s]" + .formatted(getFieldName(unsupportedField))); + } + }); + } + + @Test + void testAbsentRequiredUpdateCommandField() { + doWorkAwareOfAutoCommit(connection -> { + String mql = + """ + { + update: "books" + }"""; + try (PreparedStatement pstm = connection.prepareStatement(mql)) { + assertThatThrownBy(pstm::executeUpdate) + .isInstanceOf(SQLSyntaxErrorException.class) + .hasMessage("Invalid MQL: [%s]".formatted(toExtendedJson(mql))) + .cause() + .hasMessageContaining("Document does not contain key updates"); + } + }); + } + + @ParameterizedTest(name = "test not supported delete command field. Parameters: option={0}") + @ValueSource(strings = {"maxTimeMS: 1", "writeConcern: {}", "comment: {}", "ordered: true", "let: {}"}) + void testNotSupportedDeleteCommandField(String unsupportedField) { + doWorkAwareOfAutoCommit(connection -> { + try (PreparedStatement pstm = connection.prepareStatement( + """ + { + delete: "books", + deletes: [ + { + q: { author: { $eq: "Leo Tolstoy" } }, + limit: 0 + } + ] + %s + }""" + .formatted(unsupportedField))) { + assertThatThrownBy(pstm::executeUpdate) + .isInstanceOf(SQLFeatureNotSupportedException.class) + .hasMessage("Unsupported field in delete command: [%s]" + .formatted(getFieldName(unsupportedField))); + } + }); + } + + @Test + void testAbsentRequiredDeleteCommandField() { + doWorkAwareOfAutoCommit(connection -> { + String mql = + """ + { + delete: "books" + }"""; + try (PreparedStatement pstm = connection.prepareStatement(mql)) { + assertThatThrownBy(pstm::executeUpdate) + .isInstanceOf(SQLSyntaxErrorException.class) + .hasMessage("Invalid MQL: [%s]".formatted(toExtendedJson(mql))) + .cause() + .hasMessageContaining("Document does not contain key deletes"); + } + }); + } + + @ParameterizedTest(name = "test not supported insert command field. Parameters: option={0}") + @ValueSource( + strings = { + "maxTimeMS: 1", + "writeConcern: {}", + "bypassDocumentValidation: true", + "comment: {}", + "ordered: true", + "let: {}" + }) + void testNotSupportedInsertCommandField(String unsupportedField) { + doWorkAwareOfAutoCommit(connection -> { + try (PreparedStatement pstm = connection.prepareStatement( + """ + { + insert: "books", + documents: [ + { + _id: 1 + } + ], + %s + }""" + .formatted(unsupportedField))) { + assertThatThrownBy(pstm::executeUpdate) + .isInstanceOf(SQLFeatureNotSupportedException.class) + .hasMessage("Unsupported field in insert command: [%s]" + .formatted(getFieldName(unsupportedField))); + } + }); + } + + @Test + void testAbsentRequiredInsertCommandField() { + doWorkAwareOfAutoCommit(connection -> { + String mql = + """ + { + insert: "books" + }"""; + try (PreparedStatement pstm = connection.prepareStatement(mql)) { + assertThatThrownBy(pstm::executeUpdate) + .isInstanceOf(SQLSyntaxErrorException.class) + .hasMessage("Invalid MQL: [%s]".formatted(toExtendedJson(mql))) + .cause() + .hasMessageContaining("Document does not contain key documents"); + } + }); + } + + private static Stream unsupportedUpdateStatementFields() { + return Stream.of( + of("hint: {}", "Unsupported field in update statement: [hint]"), + of("hint: \"a\"", "Unsupported field in update statement: [hint]"), + of("collation: {}", "Unsupported field in update statement: [collation]"), + of("arrayFilters: []", "Unsupported field in update statement: [arrayFilters]"), + of("sort: {}", "Unsupported field in update statement: [sort]"), + of("upsert: true", "Unsupported field in update statement: [upsert]"), + of("u: []", "Only document type is supported as value for field: [u]"), + of("c: {}", "Unsupported field in update statement: [c]")); + } + + @ParameterizedTest(name = "test not supported update statement field. Parameters: option={0}") + @MethodSource("unsupportedUpdateStatementFields") + void testNotSupportedUpdateStatemenField(String unsupportedField, String expectedMessage) { + doWorkAwareOfAutoCommit(connection -> { + try (PreparedStatement pstm = connection.prepareStatement( """ { update: "books", @@ -692,23 +913,50 @@ void testNotSupportedUpdateElements(String unsupportedElement) { q: { author: { $eq: "Leo Tolstoy" } }, u: { $set: { outOfStock: true } }, multi: true, - %s: { _id: 1 } + %s } ] - }""", - unsupportedElement))) { + }""" + .formatted(unsupportedField))) { assertThatThrownBy(pstm::executeUpdate) .isInstanceOf(SQLFeatureNotSupportedException.class) - .hasMessage("Unsupported elements in update command: [%s]".formatted(unsupportedElement)); + .hasMessage(expectedMessage); + } + }); + } + + @ParameterizedTest(name = "test absent required update statement field. Parameters: fieldToRemove={0}") + @ValueSource(strings = {"q: {}", "u: {}"}) + void testAbsentRequiredUpdateStatementField(String fieldToRemove) { + doWorkAwareOfAutoCommit(connection -> { + String mql = + """ + { + update: "books", + updates: [ + { + q: {}, + u: {}, + } + ] + }""" + .replace(fieldToRemove + ",", ""); + try (PreparedStatement pstm = connection.prepareStatement(mql)) { + assertThatThrownBy(pstm::executeUpdate) + .isInstanceOf(SQLSyntaxErrorException.class) + .hasMessage("Invalid MQL: [%s]".formatted(toExtendedJson(mql))) + .cause() + .hasMessageContaining( + "Document does not contain key %s".formatted(getFieldName(fieldToRemove))); } }); } - @ParameterizedTest(name = "test not supported delete elements. Parameters: option={0}") - @ValueSource(strings = {"hint", "collation"}) - void testNotSupportedDeleteElements(String unsupportedElement) { + @ParameterizedTest(name = "test not supported delete statement field. Parameters: option={0}") + @ValueSource(strings = {"hint: {}", "hint: \"a\"", "collation: {}"}) + void testNotSupportedDeleteStatementField(String unsupportedField) { doWorkAwareOfAutoCommit(connection -> { - try (PreparedStatement pstm = connection.prepareStatement(format( + try (PreparedStatement pstm = connection.prepareStatement( """ { delete: "books", @@ -716,14 +964,42 @@ void testNotSupportedDeleteElements(String unsupportedElement) { { q: { author: { $eq: "Leo Tolstoy" } }, limit: 0, - %s: { _id: 1 } + %s } ] - }""", - unsupportedElement))) { + }""" + .formatted(unsupportedField))) { assertThatThrownBy(pstm::executeUpdate) .isInstanceOf(SQLFeatureNotSupportedException.class) - .hasMessage("Unsupported elements in delete command: [%s]".formatted(unsupportedElement)); + .hasMessage("Unsupported field in delete statement: [%s]" + .formatted(getFieldName(unsupportedField))); + } + }); + } + + @ParameterizedTest(name = "test absent required update statement field. Parameters: fieldToRemove={0}") + @ValueSource(strings = {"q: {}", "limit: 0"}) + void testAbsentRequiredDeleteStatementField(String fieldToRemove) { + doWorkAwareOfAutoCommit(connection -> { + String mql = + """ + { + delete: "books", + deletes: [ + { + q: {}, + limit: 0, + } + ] + }""" + .replace(fieldToRemove + ",", ""); + try (PreparedStatement pstm = connection.prepareStatement(mql)) { + assertThatThrownBy(pstm::executeUpdate) + .isInstanceOf(SQLSyntaxErrorException.class) + .hasMessage("Invalid MQL: [%s]".formatted(toExtendedJson(mql))) + .cause() + .hasMessageContaining( + "Document does not contain key %s".formatted(getFieldName(fieldToRemove))); } }); } @@ -740,6 +1016,25 @@ private void assertExecuteUpdate( assertThat(mongoCollection.find().sort(Sorts.ascending(ID_FIELD_NAME))) .containsExactlyElementsOf(expectedDocuments); } + + private static String getFieldName(final String unsupportedField) { + return BsonDocument.parse("{" + unsupportedField + "}").getFirstKey(); + } + + private String toExtendedJson(final String mql) { + return BsonDocument.parse(mql).toJson(EXTENDED_JSON_WRITER_SETTINGS); + } + } + + private void assertInvalidMql( + final String mql, SqlConsumer executor, String expectedExceptionMessage) { + doWorkAwareOfAutoCommit(connection -> { + try (PreparedStatement pstm = connection.prepareStatement(mql)) { + assertThatThrownBy(() -> executor.accept(pstm)) + .isInstanceOf(SQLSyntaxErrorException.class) + .hasMessage(expectedExceptionMessage); + } + }); } private void doWorkAwareOfAutoCommit(Work work) { diff --git a/src/main/java/com/mongodb/hibernate/jdbc/MongoPreparedStatement.java b/src/main/java/com/mongodb/hibernate/jdbc/MongoPreparedStatement.java index 1b569f6e..d6c99fd1 100644 --- a/src/main/java/com/mongodb/hibernate/jdbc/MongoPreparedStatement.java +++ b/src/main/java/com/mongodb/hibernate/jdbc/MongoPreparedStatement.java @@ -73,7 +73,7 @@ public ResultSet executeQuery() throws SQLException { checkClosed(); closeLastOpenResultSet(); checkAllParametersSet(); - return executeQueryCommand(command); + return executeQuery(command); } @Override @@ -82,7 +82,7 @@ public int executeUpdate() throws SQLException { closeLastOpenResultSet(); checkAllParametersSet(); checkSupportedUpdateCommand(command); - return executeUpdateCommand(command); + return executeUpdate(command); } private void checkAllParametersSet() throws SQLException { @@ -218,13 +218,13 @@ public void clearBatch() throws SQLException { @Override public int[] executeBatch() throws SQLException { checkClosed(); - closeLastOpenResultSet(); - if (commandBatch.isEmpty()) { - return EMPTY_BATCH_RESULT; - } - checkSupportedBatchCommand(commandBatch.get(0)); try { - executeBulkWrite(commandBatch, ExecutionType.BATCH); + closeLastOpenResultSet(); + if (commandBatch.isEmpty()) { + return EMPTY_BATCH_RESULT; + } + checkSupportedBatchCommand(commandBatch.get(0)); + executeBatch(commandBatch); var updateCounts = new int[commandBatch.size()]; // We cannot determine the actual number of rows affected for each command in the batch. Arrays.fill(updateCounts, Statement.SUCCESS_NO_INFO); @@ -234,20 +234,22 @@ public int[] executeBatch() throws SQLException { } } + /** @throws BatchUpdateException if any of the commands in the batch attempts to return a result set. */ private void checkSupportedBatchCommand(BsonDocument command) throws SQLException { - var commandType = getCommandType(command); - if (commandType == CommandType.AGGREGATE) { - // The method executeBatch throws a BatchUpdateException if any of the commands in the batch attempts to - // return a result set. + var commandDescription = getCommandDescription(command); + if (commandDescription.returnsResultSet()) { throw new BatchUpdateException( format( - "Commands returning result set are not supported. Received command: %s", - commandType.getCommandName()), + "Commands returning result set are not allowed. Received command: %s", + commandDescription.getCommandName()), null, NO_ERROR_CODE, null); } - checkSupportedUpdateCommand(commandType); + if (!commandDescription.isUpdate()) { + throw new SQLFeatureNotSupportedException( + format("Unsupported command for batch operation: %s", commandDescription.getCommandName())); + } } @Override diff --git a/src/main/java/com/mongodb/hibernate/jdbc/MongoStatement.java b/src/main/java/com/mongodb/hibernate/jdbc/MongoStatement.java index db4cfb02..699377ec 100644 --- a/src/main/java/com/mongodb/hibernate/jdbc/MongoStatement.java +++ b/src/main/java/com/mongodb/hibernate/jdbc/MongoStatement.java @@ -16,22 +16,24 @@ package com.mongodb.hibernate.jdbc; +import static com.mongodb.assertions.Assertions.assertFalse; +import static com.mongodb.assertions.Assertions.assertNotNull; +import static com.mongodb.assertions.Assertions.assertTrue; import static com.mongodb.assertions.Assertions.fail; -import static com.mongodb.hibernate.internal.MongoAssertions.assertNotNull; +import static com.mongodb.hibernate.internal.MongoConstants.EXTENDED_JSON_WRITER_SETTINGS; import static com.mongodb.hibernate.internal.MongoConstants.ID_FIELD_NAME; import static com.mongodb.hibernate.internal.VisibleForTesting.AccessModifier.PRIVATE; -import static java.lang.Math.max; import static java.lang.String.format; -import static java.util.Collections.singletonList; import static java.util.stream.Collectors.toCollection; +import static org.bson.BsonBoolean.FALSE; import com.mongodb.ErrorCategory; import com.mongodb.MongoBulkWriteException; import com.mongodb.MongoException; -import com.mongodb.MongoExecutionTimeoutException; import com.mongodb.MongoSocketReadTimeoutException; import com.mongodb.MongoSocketWriteTimeoutException; import com.mongodb.MongoTimeoutException; +import com.mongodb.bulk.BulkWriteError; import com.mongodb.bulk.BulkWriteResult; import com.mongodb.client.ClientSession; import com.mongodb.client.MongoCollection; @@ -51,24 +53,31 @@ import java.sql.SQLFeatureNotSupportedException; import java.sql.SQLIntegrityConstraintViolationException; import java.sql.SQLSyntaxErrorException; -import java.sql.SQLTimeoutException; import java.sql.SQLWarning; import java.sql.Statement; import java.util.ArrayList; import java.util.Arrays; import java.util.Collection; +import java.util.Iterator; import java.util.List; import java.util.Map; +import java.util.NoSuchElementException; +import java.util.Set; +import org.bson.BSONException; import org.bson.BsonDocument; +import org.bson.BsonInvalidOperationException; +import org.bson.BsonString; import org.bson.BsonValue; import org.jspecify.annotations.Nullable; class MongoStatement implements StatementAdapter { private static final String EXCEPTION_MESSAGE_OPERATION_FAILED = "Failed to execute operation"; - private static final String EXCEPTION_MESSAGE_BATCH_FAILED = "Batch execution failed"; private static final String EXCEPTION_MESSAGE_TIMEOUT = "Timeout while waiting for operation to complete"; static final int NO_ERROR_CODE = 0; static final int[] EMPTY_BATCH_RESULT = new int[0]; + + @Nullable public static final String NULL_SQL_STATE = null; + private final MongoDatabase mongoDatabase; private final MongoConnection mongoConnection; private final ClientSession clientSession; @@ -87,7 +96,8 @@ public ResultSet executeQuery(String mql) throws SQLException { checkClosed(); closeLastOpenResultSet(); var command = parse(mql); - return executeQueryCommand(command); + checkSupportedQueryCommand(command); + return executeQuery(command); } void closeLastOpenResultSet() throws SQLException { @@ -96,12 +106,11 @@ void closeLastOpenResultSet() throws SQLException { } } - ResultSet executeQueryCommand(BsonDocument command) throws SQLException { - var commandType = getCommandType(command); - checkSupportedQueryCommand(command); + ResultSet executeQuery(BsonDocument command) throws SQLException { + var commandDescription = getCommandDescription(command); try { startTransactionIfNeeded(); - var collection = getCollection(commandType, command); + var collection = getCollection(commandDescription, command); var pipeline = command.getArray("pipeline").stream() .map(BsonValue::asDocument) .toList(); @@ -111,7 +120,7 @@ ResultSet executeQueryCommand(BsonDocument command) throws SQLException { return resultSet = new MongoResultSet( collection.aggregate(clientSession, pipeline).cursor(), fieldNames); } catch (RuntimeException exception) { - throw handleException(exception, commandType, ExecutionType.QUERY); + throw handleQueryOrUpdateException(exception); } } @@ -151,27 +160,39 @@ public int executeUpdate(String mql) throws SQLException { closeLastOpenResultSet(); var command = parse(mql); checkSupportedUpdateCommand(command); - return executeUpdateCommand(command); + return executeUpdate(command); } - int executeUpdateCommand(BsonDocument command) throws SQLException { - return executeBulkWrite(singletonList(command), ExecutionType.UPDATE); - } - - int executeBulkWrite(List commandBatch, ExecutionType executionType) throws SQLException { - var firstDocumentInBatch = commandBatch.get(0); - var commandType = getCommandType(firstDocumentInBatch); - var collection = getCollection(commandType, firstDocumentInBatch); + void executeBatch(List commandBatch) throws SQLException { + var firstCommandInBatch = commandBatch.get(0); + var commandDescription = getCommandDescription(firstCommandInBatch); + var collection = getCollection(commandDescription, firstCommandInBatch); + WriteModelsToCommandMapper writeModelsToCommandMapper = null; try { startTransactionIfNeeded(); var writeModels = new ArrayList>(commandBatch.size()); - for (var command : commandBatch) { - WriteModelConverter.convertToWriteModels(commandType, command, writeModels); + writeModelsToCommandMapper = new WriteModelsToCommandMapper(commandBatch.size()); + for (BsonDocument command : commandBatch) { + WriteModelConverter.convertToWriteModels(commandDescription, command, writeModels); + writeModelsToCommandMapper.add(writeModels.size()); } + collection.bulkWrite(clientSession, writeModels); + } catch (RuntimeException exception) { + throw handleBatchException(exception, writeModelsToCommandMapper); + } + } + + int executeUpdate(BsonDocument command) throws SQLException { + var commandDescription = getCommandDescription(command); + var collection = getCollection(commandDescription, command); + try { + startTransactionIfNeeded(); + var writeModels = new ArrayList>(); + WriteModelConverter.convertToWriteModels(commandDescription, command, writeModels); var bulkWriteResult = collection.bulkWrite(clientSession, writeModels); - return getUpdateCount(commandType, bulkWriteResult); + return getUpdateCount(commandDescription, bulkWriteResult); } catch (RuntimeException exception) { - throw handleException(exception, commandType, executionType); + throw handleQueryOrUpdateException(exception); } } @@ -250,24 +271,19 @@ void checkClosed() throws SQLException { } } - private void checkSupportedQueryCommand(BsonDocument command) throws SQLFeatureNotSupportedException { - var commandType = getCommandType(command); - if (commandType != CommandType.AGGREGATE) { + private void checkSupportedQueryCommand(BsonDocument command) throws SQLException { + var commandDescription = getCommandDescription(command); + if (commandDescription.isUpdate()) { throw new SQLFeatureNotSupportedException( - format("Unsupported command for query operation: %s", commandType.getCommandName())); + format("Unsupported command for query operation: %s", commandDescription.getCommandName())); } } void checkSupportedUpdateCommand(BsonDocument command) throws SQLException { - checkSupportedUpdateCommand(getCommandType(command)); - } - - void checkSupportedUpdateCommand(CommandType commandType) throws SQLException { - if (commandType != CommandType.INSERT - && commandType != CommandType.UPDATE - && commandType != CommandType.DELETE) { + CommandDescription commandDescription = getCommandDescription(command); + if (!commandDescription.isUpdate()) { throw new SQLFeatureNotSupportedException( - format("Unsupported command for batch operation: %s", commandType.getCommandName())); + "Unsupported command for update operation: %s".formatted(commandDescription.getCommandName())); } } @@ -275,13 +291,13 @@ static BsonDocument parse(String mql) throws SQLSyntaxErrorException { try { return BsonDocument.parse(mql); } catch (RuntimeException e) { - throw new SQLSyntaxErrorException("Invalid MQL: " + mql, e); + throw new SQLSyntaxErrorException("Invalid MQL: [%s]".formatted(mql), e); } } /** - * Starts transaction for the first {@link java.sql.Statement} executing if - * {@linkplain MongoConnection#getAutoCommit() auto-commit} is disabled. + * Starts transaction for the first {@link Statement} executing if {@linkplain MongoConnection#getAutoCommit() + * auto-commit} is disabled. */ private void startTransactionIfNeeded() throws SQLException { if (!mongoConnection.getAutoCommit() && !clientSession.hasActiveTransaction()) { @@ -289,19 +305,37 @@ private void startTransactionIfNeeded() throws SQLException { } } - static CommandType getCommandType(BsonDocument command) throws SQLFeatureNotSupportedException { - // The first key is always the command name, e.g. "insert", "update", "delete". - return CommandType.fromString(assertNotNull(command.getFirstKey())); + /** The first key is always the command name, e.g. "insert", "update", "delete". */ + static CommandDescription getCommandDescription(BsonDocument command) throws SQLException { + String commandName; + try { + commandName = command.getFirstKey(); + } catch (NoSuchElementException exception) { + throw new SQLSyntaxErrorException( + "Invalid MQL. Command name is missing: [%s]" + .formatted(command.toJson(EXTENDED_JSON_WRITER_SETTINGS)), + exception); + } + return CommandDescription.fromString(commandName); } - private MongoCollection getCollection(CommandType commandType, BsonDocument command) { - var collectionName = - assertNotNull(command.getString(commandType.getCommandName()).getValue()); - return mongoDatabase.getCollection(collectionName, BsonDocument.class); + private MongoCollection getCollection(CommandDescription commandDescription, BsonDocument command) + throws SQLSyntaxErrorException { + var commandName = commandDescription.getCommandName(); + BsonString collectionName; + try { + collectionName = command.getString(commandName); + } catch (BsonInvalidOperationException exception) { + throw new SQLSyntaxErrorException( + "Invalid MQL. Collection name is missing [%s]" + .formatted(command.toJson(EXTENDED_JSON_WRITER_SETTINGS)), + exception); + } + return mongoDatabase.getCollection(collectionName.getValue(), BsonDocument.class); } - private static int getUpdateCount(CommandType commandType, BulkWriteResult bulkWriteResult) { - return switch (commandType) { + private static int getUpdateCount(CommandDescription commandDescription, BulkWriteResult bulkWriteResult) { + return switch (commandDescription) { case INSERT -> bulkWriteResult.getInsertedCount(); case UPDATE -> bulkWriteResult.getModifiedCount(); case DELETE -> bulkWriteResult.getDeletedCount(); @@ -309,86 +343,88 @@ private static int getUpdateCount(CommandType commandType, BulkWriteResult bulkW }; } - private static SQLException handleException( - RuntimeException exceptionToHandle, CommandType commandType, ExecutionType executionType) { + private static SQLException handleBatchException( + RuntimeException exceptionToHandle, @Nullable WriteModelsToCommandMapper writeModelsToCommandMapper) { var errorCode = getErrorCode(exceptionToHandle); - return switch (executionType) { - case BATCH -> handleBatchException(exceptionToHandle, commandType, errorCode); - case QUERY, UPDATE -> { - if (exceptionToHandle instanceof MongoException mongoException) { - var handledException = handleMongoException(mongoException, errorCode); - yield toSqlException(errorCode, handledException); - } - yield toSqlException(NO_ERROR_CODE, exceptionToHandle); + if (exceptionToHandle instanceof MongoException mongoException) { + var cause = handleMongoException(errorCode, mongoException); + if (mongoException instanceof MongoBulkWriteException bulkWriteException) { + return createBatchUpdateException( + errorCode, cause, bulkWriteException, assertNotNull(writeModelsToCommandMapper)); } - }; + // TODO-HIBERNATE-132 BatchUpdateException is thrown when one of the commands fails to execute properly. + // When + // exception is not of MongoBulkWriteException, we are not sure if any command was executed successfully or + // failed. + return cause; + } + return new SQLException(EXCEPTION_MESSAGE_OPERATION_FAILED, null, errorCode, exceptionToHandle); } - private static SQLException handleBatchException( - RuntimeException exceptionToHandle, CommandType commandType, int errorCode) { + private static SQLException handleQueryOrUpdateException(RuntimeException exceptionToHandle) { + var errorCode = getErrorCode(exceptionToHandle); if (exceptionToHandle instanceof MongoException mongoException) { - var cause = handleMongoException(mongoException, errorCode); - if (exceptionToHandle instanceof MongoBulkWriteException bulkWriteException) { - return createBatchUpdateException(cause, bulkWriteException.getWriteResult(), errorCode, commandType); - } - return createBatchUpdateException(errorCode, cause); + return handleMongoException(errorCode, mongoException); } - return createBatchUpdateException(NO_ERROR_CODE, exceptionToHandle); + return new SQLException(EXCEPTION_MESSAGE_OPERATION_FAILED, null, errorCode, exceptionToHandle); } - private static Exception handleMongoException(final MongoException exceptionToHandle, final int errorCode) { + private static SQLException handleMongoException(final int errorCode, final MongoException exceptionToHandle) { if (isTimeoutException(exceptionToHandle)) { - return new SQLTimeoutException(EXCEPTION_MESSAGE_TIMEOUT, null, errorCode, exceptionToHandle); + return new SQLException(EXCEPTION_MESSAGE_TIMEOUT, null, errorCode, exceptionToHandle); } - return handleByErrorCode(errorCode, exceptionToHandle); - } - - private static Exception handleByErrorCode(int errorCode, final MongoException exceptionToHandle) { var errorCategory = ErrorCategory.fromErrorCode(errorCode); return switch (errorCategory) { case DUPLICATE_KEY -> new SQLIntegrityConstraintViolationException( EXCEPTION_MESSAGE_OPERATION_FAILED, null, errorCode, exceptionToHandle); - case EXECUTION_TIMEOUT -> - new SQLTimeoutException(EXCEPTION_MESSAGE_TIMEOUT, null, errorCode, exceptionToHandle); - case UNCATEGORIZED -> exceptionToHandle; + // TODO-HIBERNATE-132 EXECUTION_TIMEOUT code is returned from the server. Do we know how many commands were + // executed + // successfully so we can return it as BatchUpdateException? + case EXECUTION_TIMEOUT -> new SQLException(EXCEPTION_MESSAGE_TIMEOUT, null, errorCode, exceptionToHandle); + case UNCATEGORIZED -> + new SQLException(EXCEPTION_MESSAGE_OPERATION_FAILED, null, errorCode, exceptionToHandle); }; } private static int getErrorCode(final RuntimeException runtimeException) { if (runtimeException instanceof MongoBulkWriteException mongoBulkWriteException) { - return getErrorCode(mongoBulkWriteException); + var writeErrors = mongoBulkWriteException.getWriteErrors(); + if (writeErrors.isEmpty()) { + return NO_ERROR_CODE; + } + // Since we are executing an ordered bulk write, there will be at most one BulkWriteError. + assertTrue(writeErrors.size() == 1); + var code = writeErrors.get(0).getCode(); + assertFalse(code == NO_ERROR_CODE); + return code; } else if (runtimeException instanceof MongoException mongoException) { - return max(NO_ERROR_CODE, mongoException.getCode()); + var code = mongoException.getCode(); + assertFalse(code == NO_ERROR_CODE); + return code; } return NO_ERROR_CODE; } - private static int getErrorCode(final MongoBulkWriteException mongoBulkWriteException) { - var writeErrors = mongoBulkWriteException.getWriteErrors(); - // Since we are executing an ordered bulk write, there will be at most one BulkWriteError. - return writeErrors.isEmpty() ? NO_ERROR_CODE : writeErrors.get(0).getCode(); - } - - private static SQLException toSqlException(final int errorCode, final Exception exception) { - if (exception instanceof SQLException sqlException) { - return sqlException; - } - return new SQLException(EXCEPTION_MESSAGE_OPERATION_FAILED, null, errorCode, exception); - } - - private static SQLException createBatchUpdateException(final int errorCode, final Exception cause) { - return withCause( - new BatchUpdateException(EXCEPTION_MESSAGE_BATCH_FAILED, null, errorCode, EMPTY_BATCH_RESULT), cause); - } - private static BatchUpdateException createBatchUpdateException( - Exception cause, BulkWriteResult bulkWriteResult, int errorCode, CommandType commandType) { - var updateCount = getUpdateCount(commandType, bulkWriteResult); + int errorCode, + Exception cause, + MongoBulkWriteException mongoBulkWriteException, + WriteModelsToCommandMapper writeModelsToCommandMapper) { + List writeErrors = mongoBulkWriteException.getWriteErrors(); + var updateCount = 0; + var writeConcernError = mongoBulkWriteException.getWriteConcernError(); + if (writeConcernError == null) { + if (!writeErrors.isEmpty()) { + var failedModelIndex = writeErrors.get(0).getIndex(); + updateCount = writeModelsToCommandMapper.findCommandIndex(failedModelIndex); + } + } var updateCounts = new int[updateCount]; Arrays.fill(updateCounts, Statement.SUCCESS_NO_INFO); return withCause( - new BatchUpdateException(EXCEPTION_MESSAGE_BATCH_FAILED, null, errorCode, updateCounts), cause); + new BatchUpdateException(EXCEPTION_MESSAGE_OPERATION_FAILED, NULL_SQL_STATE, errorCode, updateCounts), + cause); } private static T withCause(T sqlException, final Exception cause) { @@ -400,29 +436,52 @@ private static T withCause(T sqlException, final Except } private static boolean isTimeoutException(final MongoException exception) { + // We do not check for `MongoExecutionTimeoutException` and `MongoOperationTimeoutException` here, + // because it is handled via error codes. return exception instanceof MongoSocketReadTimeoutException || exception instanceof MongoSocketWriteTimeoutException - || exception instanceof MongoTimeoutException - || exception instanceof MongoExecutionTimeoutException; + || exception instanceof MongoTimeoutException; } - enum CommandType { - INSERT("insert"), - UPDATE("update"), - DELETE("delete"), - AGGREGATE("aggregate"); + enum CommandDescription { + INSERT("insert", false, true), + UPDATE("update", false, true), + DELETE("delete", false, true), + AGGREGATE("aggregate", true, false); private final String commandName; + private final boolean returnsResultSet; + private final boolean isUpdate; - CommandType(String commandName) { + CommandDescription(String commandName, boolean returnsResultSet, boolean isUpdate) { this.commandName = commandName; + this.returnsResultSet = returnsResultSet; + this.isUpdate = isUpdate; } String getCommandName() { return commandName; } - static CommandType fromString(String commandName) throws SQLFeatureNotSupportedException { + /** + * Indicates whether the command is used in {@code executeUpdate(...)} or {@code executeBatch()} methods. + * + * @return true if the command is used in update operations, false if it is used in query operations. + */ + boolean isUpdate() { + return isUpdate; + } + + /** + * Indicates whether the command returns a {@link ResultSet}. + * + * @see #executeQuery(String) + */ + boolean returnsResultSet() { + return returnsResultSet; + } + + static CommandDescription fromString(String commandName) throws SQLFeatureNotSupportedException { return switch (commandName) { case "insert" -> INSERT; case "update" -> UPDATE; @@ -433,95 +492,159 @@ static CommandType fromString(String commandName) throws SQLFeatureNotSupportedE } } - enum ExecutionType { - UPDATE, - BATCH, - QUERY - } - private static class WriteModelConverter { - private static final List SUPPORTED_UPDATE_COMMAND_ELEMENTS = List.of("q", "u", "multi"); - private static final List SUPPORTED_DELETE_COMMAND_ELEMENTS = List.of("q", "limit"); - - static void convertToWriteModels( - CommandType commandType, BsonDocument command, Collection> writeModels) - throws SQLFeatureNotSupportedException { - switch (commandType) { - case INSERT: - var documents = command.getArray("documents"); - for (var insertDocument : documents) { - writeModels.add(createInsertModel(insertDocument.asDocument())); - } - break; - case UPDATE: - var updates = command.getArray("updates").getValues(); - for (var updateDocument : updates) { - writeModels.add(createUpdateModel(updateDocument.asDocument())); - } - break; - case DELETE: - var deletes = command.getArray("deletes"); - for (var deleteDocument : deletes) { - writeModels.add(createDeleteModel(deleteDocument.asDocument())); - } - break; - default: - throw fail(); + private static final String UNSUPPORTED_MESSAGE_STATEMENT_FIELD = "Unsupported field in %s statement: [%s]"; + private static final String UNSUPPORTED_MESSAGE_COMMAND_FIELD = "Unsupported field in %s command: [%s]"; + + private static final Set SUPPORTED_INSERT_COMMAND_FIELDS = Set.of("documents"); + + private static final Set SUPPORTED_UPDATE_COMMAND_FIELDS = Set.of("updates"); + private static final Set SUPPORTED_UPDATE_STATEMENT_FIELDS = Set.of("q", "u", "multi"); + + private static final Set SUPPORTED_DELETE_COMMAND_FIELDS = Set.of("deletes"); + private static final Set SUPPORTED_DELETE_STATEMENT_FIELDS = Set.of("q", "limit"); + + private WriteModelConverter() {} + + private static void convertToWriteModels( + CommandDescription commandDescription, + BsonDocument command, + Collection> writeModels) + throws SQLFeatureNotSupportedException, SQLSyntaxErrorException { + try { + switch (commandDescription) { + case INSERT: + checkCommandFields(command, commandDescription, SUPPORTED_INSERT_COMMAND_FIELDS); + var documentsToInsert = command.getArray("documents"); + for (var documentToInsert : documentsToInsert) { + writeModels.add(createInsertModel(documentToInsert.asDocument())); + } + break; + case UPDATE: + checkCommandFields(command, commandDescription, SUPPORTED_UPDATE_COMMAND_FIELDS); + var updateStatements = command.getArray("updates"); + for (var updateStatement : updateStatements) { + writeModels.add(createUpdateModel(updateStatement.asDocument(), commandDescription)); + } + break; + case DELETE: + checkCommandFields(command, commandDescription, SUPPORTED_DELETE_COMMAND_FIELDS); + var deleteStatements = command.getArray("deletes"); + for (var deleteStatement : deleteStatements) { + writeModels.add(createDeleteModel(deleteStatement.asDocument(), commandDescription)); + } + break; + default: + throw fail(commandDescription.toString()); + } + } catch (BSONException bsonException) { + throw new SQLSyntaxErrorException( + "Invalid MQL: [%s]".formatted(command.toJson(EXTENDED_JSON_WRITER_SETTINGS)), + NULL_SQL_STATE, + bsonException); } } - private static WriteModel createInsertModel(final BsonDocument insertDocument) { + private static WriteModel createInsertModel(BsonDocument insertDocument) { return new InsertOneModel<>(insertDocument); } - private static WriteModel createDeleteModel(final BsonDocument deleteDocument) + private static WriteModel createUpdateModel( + BsonDocument updateStatement, CommandDescription commandDescription) throws SQLFeatureNotSupportedException { - checkDeleteElements(deleteDocument); - var isSingleDelete = deleteDocument.getNumber("limit").intValue() == 1; - var queryFilter = deleteDocument.getDocument("q"); - - if (isSingleDelete) { - return new DeleteOneModel<>(queryFilter); + checkStatementFields(updateStatement, commandDescription, SUPPORTED_UPDATE_STATEMENT_FIELDS); + var isMulti = updateStatement.getBoolean("multi", FALSE).getValue(); + var filter = updateStatement.getDocument("q"); + var updateModification = updateStatement.get("u"); + if (updateModification == null) { + // We force exception here because the field is mandatory. + updateStatement.getDocument("u"); + } + if (!(updateModification instanceof BsonDocument uDocument)) { + throw new SQLFeatureNotSupportedException("Only document type is supported as value for field: [u]"); + } + if (isMulti) { + return new UpdateManyModel<>(filter, uDocument); } - return new DeleteManyModel<>(queryFilter); + return new UpdateOneModel<>(filter, uDocument); } - private static WriteModel createUpdateModel(final BsonDocument updateDocument) + private static WriteModel createDeleteModel( + BsonDocument deleteStatement, CommandDescription commandDescription) throws SQLFeatureNotSupportedException { - checkUpdateElements(updateDocument); - var isMulti = updateDocument.getBoolean("multi").getValue(); - var queryFilter = updateDocument.getDocument("q"); - var updatePipeline = updateDocument.getDocument("u"); + checkStatementFields(deleteStatement, commandDescription, SUPPORTED_DELETE_STATEMENT_FIELDS); + var isSingleDelete = deleteStatement.getNumber("limit").intValue() == 1; + var filter = deleteStatement.getDocument("q"); - if (isMulti) { - return new UpdateManyModel<>(queryFilter, updatePipeline); + if (isSingleDelete) { + return new DeleteOneModel<>(filter); } - return new UpdateOneModel<>(queryFilter, updatePipeline); + return new DeleteManyModel<>(filter); } - private static void checkDeleteElements(final BsonDocument deleteDocument) + private static void checkStatementFields( + BsonDocument statement, CommandDescription commandDescription, Set supportedStatementFields) throws SQLFeatureNotSupportedException { - if (deleteDocument.size() > SUPPORTED_DELETE_COMMAND_ELEMENTS.size()) { - var unSupportedElements = getUnsupportedElements(deleteDocument, SUPPORTED_DELETE_COMMAND_ELEMENTS); - throw new SQLFeatureNotSupportedException( - format("Unsupported elements in delete command: %s", unSupportedElements)); - } + checkFields( + commandDescription, + UNSUPPORTED_MESSAGE_STATEMENT_FIELD, + supportedStatementFields, + statement.keySet().iterator()); } - private static void checkUpdateElements(final BsonDocument updateDocument) + private static void checkCommandFields( + BsonDocument command, CommandDescription commandDescription, Set supportedCommandFields) throws SQLFeatureNotSupportedException { - if (updateDocument.size() > SUPPORTED_UPDATE_COMMAND_ELEMENTS.size()) { - var unSupportedElements = getUnsupportedElements(updateDocument, SUPPORTED_UPDATE_COMMAND_ELEMENTS); - throw new SQLFeatureNotSupportedException( - format("Unsupported elements in update command: %s", unSupportedElements)); + var iterator = command.keySet().iterator(); + iterator.next(); // skip the command name + checkFields(commandDescription, UNSUPPORTED_MESSAGE_COMMAND_FIELD, supportedCommandFields, iterator); + } + + private static void checkFields( + CommandDescription commandDescription, + String exceptionMessage, + Set supportedCommandFields, + Iterator fieldIterator) + throws SQLFeatureNotSupportedException { + while (fieldIterator.hasNext()) { + var field = fieldIterator.next(); + if (!supportedCommandFields.contains(field)) { + throw new SQLFeatureNotSupportedException( + exceptionMessage.formatted(commandDescription.getCommandName(), field)); + } } } + } - private static List getUnsupportedElements( - final BsonDocument deleteDocument, final List supportedElements) { - return deleteDocument.keySet().stream() - .filter((key) -> !supportedElements.contains(key)) - .toList(); + /** Maps write model indices to their corresponding command indices in batch of commands. */ + private static class WriteModelsToCommandMapper { + /** The cumulative counts of write models for each command in the batch (prefix sum). */ + private final int[] cumulativeCounts; + + private int index; + + private WriteModelsToCommandMapper(int commandCount) { + this.cumulativeCounts = new int[commandCount]; + this.index = 0; + } + + private void add(int cumulativeWriteModelCount) { + assertFalse(index >= cumulativeCounts.length); + cumulativeCounts[index++] = cumulativeWriteModelCount; + } + + private int findCommandIndex(int writeModelIndex) { + assertTrue(index >= cumulativeCounts.length); + int lo = 0, hi = cumulativeCounts.length; + while (lo < hi) { + var mid = (lo + hi) >>> 1; + if (cumulativeCounts[mid] >= writeModelIndex + 1) { + hi = mid; + } else { + lo = mid + 1; + } + } + return lo; } } } diff --git a/src/test/java/com/mongodb/hibernate/jdbc/MongoPreparedStatementTests.java b/src/test/java/com/mongodb/hibernate/jdbc/MongoPreparedStatementTests.java index 2a376611..325a3a3b 100644 --- a/src/test/java/com/mongodb/hibernate/jdbc/MongoPreparedStatementTests.java +++ b/src/test/java/com/mongodb/hibernate/jdbc/MongoPreparedStatementTests.java @@ -16,7 +16,6 @@ package com.mongodb.hibernate.jdbc; -import static java.lang.Math.max; import static java.sql.Statement.SUCCESS_NO_INFO; import static java.util.Collections.emptyList; import static java.util.Collections.emptySet; @@ -50,6 +49,7 @@ import com.mongodb.bulk.BulkWriteError; import com.mongodb.bulk.BulkWriteInsert; import com.mongodb.bulk.BulkWriteResult; +import com.mongodb.bulk.WriteConcernError; import com.mongodb.client.AggregateIterable; import com.mongodb.client.ClientSession; import com.mongodb.client.MongoCollection; @@ -66,7 +66,6 @@ import java.sql.SQLException; import java.sql.SQLIntegrityConstraintViolationException; import java.sql.SQLSyntaxErrorException; -import java.sql.SQLTimeoutException; import java.sql.Time; import java.sql.Timestamp; import java.sql.Types; @@ -84,6 +83,7 @@ import org.bson.types.ObjectId; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.Named; import org.junit.jupiter.api.Nested; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.ExtendWith; @@ -183,19 +183,26 @@ void testSuccess() throws SQLException { } @Nested - class ExecuteMethodThrowsSqlExceptionTests { + class ExecuteThrowsSqlExceptionTests { private static final String DUMMY_EXCEPTION_MESSAGE = "Test message"; - private static final ServerAddress DUMMY_SERVER_ADDRESS = new ServerAddress("localhost"); - - private static final BulkWriteError BULK_WRITE_ERROR = - new BulkWriteError(10, DUMMY_EXCEPTION_MESSAGE, new BsonDocument(), 0); + private static final ServerAddress DUMMY_SERVER_ADDRESS = new ServerAddress(); + private static final BsonDocument DUMMY_ERROR_DETAILS = new BsonDocument(); private static final BulkWriteResult BULK_WRITE_RESULT = BulkWriteResult.acknowledged( 1, 0, 2, 3, emptyList(), List.of(new BulkWriteInsert(0, new BsonObjectId(new ObjectId(1, 2))))); - private static final MongoBulkWriteException MONGO_BULK_WRITE_EXCEPTION_WITH_ERRORS = + private static final MongoBulkWriteException MONGO_BULK_WRITE_EXCEPTION_WITH_WRITE_ERRORS = + new MongoBulkWriteException( + BULK_WRITE_RESULT, + List.of(new BulkWriteError(10, DUMMY_EXCEPTION_MESSAGE, DUMMY_ERROR_DETAILS, 0)), + null, + DUMMY_SERVER_ADDRESS, + emptySet()); + private static final MongoBulkWriteException MONGO_BULK_WRITE_EXCEPTION_WITH_WRITE_CONCERN_EXCEPTION = new MongoBulkWriteException( - BULK_WRITE_RESULT, List.of(BULK_WRITE_ERROR), null, DUMMY_SERVER_ADDRESS, emptySet()); - private static final MongoBulkWriteException MONGO_BULK_WRITE_EXCEPTION_NO_ERRORS = - new MongoBulkWriteException(BULK_WRITE_RESULT, emptyList(), null, DUMMY_SERVER_ADDRESS, emptySet()); + BULK_WRITE_RESULT, + emptyList(), + new WriteConcernError(10, "No code name", DUMMY_EXCEPTION_MESSAGE, DUMMY_ERROR_DETAILS), + DUMMY_SERVER_ADDRESS, + emptySet()); private static final String MQL_ITEMS_AGGREGATE = """ @@ -214,6 +221,9 @@ class ExecuteMethodThrowsSqlExceptionTests { insert: "items", documents: [ { _id: 1 } + { _id: 2 } + { _id: 3 } + { _id: 4 } ] } """; @@ -223,6 +233,9 @@ class ExecuteMethodThrowsSqlExceptionTests { update: "items", updates: [ { q: { _id: 1 }, u: { $set: { touched: true } }, multi: false } + { q: { _id: 1 }, u: { $set: { touched: true } }, multi: false } + { q: { _id: 1 }, u: { $set: { touched: true } }, multi: false } + { q: { _id: 1 }, u: { $set: { touched: true } }, multi: false } ] } """; @@ -232,6 +245,9 @@ class ExecuteMethodThrowsSqlExceptionTests { delete: "items", deletes: [ { q: { _id: 1 }, limit: 1 } + { q: { _id: 1 }, limit: 1 } + { q: { _id: 1 }, limit: 1 } + { q: { _id: 1 }, limit: 1 } ] } """; @@ -241,6 +257,13 @@ void beforeEach() { doReturn(mongoCollection).when(mongoDatabase).getCollection(anyString(), eq(BsonDocument.class)); } + private static Stream> mqlCommands() { + return Stream.of( + named("insert", MQL_ITEMS_INSERT), + named("update", MQL_ITEMS_UPDATE), + named("delete", MQL_ITEMS_DELETE)); + } + private static Stream timeoutExceptions() { var dummyCause = new RuntimeException(); return Stream.of( @@ -253,11 +276,13 @@ private static Stream timeoutExceptions() { ); } + private static Stream constraintViolationErrorCodes() { + return Stream.of(11000, 11001, 12582); + } + private static Stream constraintViolationExceptions() { - return Stream.of( - new MongoException(11000, DUMMY_EXCEPTION_MESSAGE), - new MongoException(11001, DUMMY_EXCEPTION_MESSAGE), - new MongoException(12582, DUMMY_EXCEPTION_MESSAGE)); + return constraintViolationErrorCodes() + .map(errorCode -> new MongoException(errorCode, DUMMY_EXCEPTION_MESSAGE)); } private static Stream genericMongoExceptions() { @@ -268,22 +293,18 @@ private static Stream genericMongoExceptions() { @ParameterizedTest(name = "test executeBatch MongoException. Parameters: Parameters: mongoException: {0}") @MethodSource("genericMongoExceptions") void testExecuteBatchMongoException(MongoException mongoException) throws SQLException { - int expectedErrorCode = max(0, mongoException.getCode()); doThrow(mongoException).when(mongoCollection).bulkWrite(eq(clientSession), anyList()); - assertExecuteBatchThrowsSqlException(batchUpdateException -> { - assertThatObject(batchUpdateException) - .returns(expectedErrorCode, BatchUpdateException::getErrorCode) - .returns(null, BatchUpdateException::getSQLState) - .returns(mongoException, SQLException::getCause) - .satisfies(exception -> { - assertUpdateCounts(batchUpdateException.getUpdateCounts(), 0); - }); + assertExecuteBatchThrowsSqlException(sqlException -> { + assertThatObject(sqlException) + .returns(mongoException.getCode(), SQLException::getErrorCode) + .returns(null, SQLException::getSQLState) + .returns(mongoException, SQLException::getCause); }); } @ParameterizedTest(name = "test executeUpdate MongoException. Parameters: Parameters: mongoException: {0}") - @MethodSource("genericMongoExceptions") + @MethodSource({"genericMongoExceptions", "timeoutExceptions"}) void testExecuteUpdateMongoException(MongoException mongoException) throws SQLException { doThrow(mongoException).when(mongoCollection).bulkWrite(eq(clientSession), anyList()); assertExecuteUpdateThrowsSqlException( @@ -291,30 +312,13 @@ void testExecuteUpdateMongoException(MongoException mongoException) throws SQLEx } @ParameterizedTest(name = "test executeUQuery MongoException. Parameters: Parameters: mongoException: {0}") - @MethodSource("genericMongoExceptions") + @MethodSource({"genericMongoExceptions", "timeoutExceptions"}) void testExecuteQueryMongoException(MongoException mongoException) throws SQLException { doThrow(mongoException).when(mongoCollection).aggregate(eq(clientSession), anyList()); assertExecuteQueryThrowsSqlException( sqlException -> assertGenericMongoException(sqlException, mongoException)); } - @ParameterizedTest( - name = "test executeUpdate timeout exception. Parameters: Parameters: mongoTimeoutException: {0}") - @MethodSource("timeoutExceptions") - void testExecuteUpdateTimeoutException(MongoException mongoTimeoutException) throws SQLException { - doThrow(mongoTimeoutException).when(mongoCollection).bulkWrite(eq(clientSession), anyList()); - assertExecuteUpdateThrowsSqlException( - sqlException -> assertTimeoutException(mongoTimeoutException, sqlException)); - } - - @ParameterizedTest(name = "test executeQuery timeout exception. Parameters: mongoTimeoutException: {0}") - @MethodSource("timeoutExceptions") - void testExecuteQueryTimeoutException(MongoException mongoTimeoutException) throws SQLException { - doThrow(mongoTimeoutException).when(mongoCollection).aggregate(eq(clientSession), anyList()); - assertExecuteQueryThrowsSqlException( - sqlException -> assertTimeoutException(mongoTimeoutException, sqlException)); - } - @ParameterizedTest(name = "test executeUpdate constraint violation. Parameters: mongoException: {0}") @MethodSource("constraintViolationExceptions") void testExecuteUpdateConstraintViolationException(MongoException mongoException) throws SQLException { @@ -337,51 +341,33 @@ void testExecuteQueryConstraintViolationException(MongoException mongoException) }); } - private static void assertConstraintViolationException( - final MongoException mongoException, final SQLException sqlException, final int expectedErrorCode) { - assertThatObject(sqlException) - .asInstanceOf(type(SQLIntegrityConstraintViolationException.class)) - .returns(expectedErrorCode, SQLIntegrityConstraintViolationException::getErrorCode) - .returns(null, SQLIntegrityConstraintViolationException::getSQLState) - .returns(mongoException, SQLIntegrityConstraintViolationException::getCause); - } - @ParameterizedTest(name = "test executeBatch timeout exception. Parameters: mongoTimeoutException: {0}") @MethodSource("timeoutExceptions") void testExecuteBatchTimeoutException(MongoException mongoTimeoutException) throws SQLException { doThrow(mongoTimeoutException).when(mongoCollection).bulkWrite(eq(clientSession), anyList()); assertExecuteBatchThrowsSqlException(batchUpdateException -> { - int expectedErrorCode = max(0, mongoTimeoutException.getCode()); - assertThatObject(batchUpdateException) - .returns(expectedErrorCode, BatchUpdateException::getErrorCode) - .returns(null, BatchUpdateException::getSQLState) - .satisfies(ex -> { - assertUpdateCounts(ex.getUpdateCounts(), 0); - }) - .extracting(SQLException::getCause) - .asInstanceOf(type(SQLTimeoutException.class)) - .returns(expectedErrorCode, SQLTimeoutException::getErrorCode) - .returns(mongoTimeoutException, SQLTimeoutException::getCause); + assertGenericMongoException(batchUpdateException, mongoTimeoutException); }); } @ParameterizedTest(name = "test executeBatch constraint violation. Parameters: mongoException: {0}") - @MethodSource("constraintViolationExceptions") - void testExecuteBatchConstraintViolationException(MongoException mongoException) throws SQLException { - int expectedErrorCode = mongoException.getCode(); - doThrow(mongoException).when(mongoCollection).bulkWrite(eq(clientSession), anyList()); + @MethodSource("constraintViolationErrorCodes") + void testExecuteBatchConstraintViolationException(int errorCode) throws SQLException { + MongoBulkWriteException mongoBulkWriteException = createMongoBulkWriteException(errorCode, 0); - assertExecuteBatchThrowsSqlException(batchUpdateException -> { + doThrow(mongoBulkWriteException).when(mongoCollection).bulkWrite(eq(clientSession), anyList()); + + assertExecuteBatchThrowsBatchUpdateException(batchUpdateException -> { assertThatObject(batchUpdateException) - .returns(expectedErrorCode, BatchUpdateException::getErrorCode) + .returns(errorCode, BatchUpdateException::getErrorCode) .returns(null, BatchUpdateException::getSQLState) .satisfies(ex -> { assertUpdateCounts(ex.getUpdateCounts(), 0); }) .extracting(SQLException::getCause) .asInstanceOf(type(SQLIntegrityConstraintViolationException.class)) - .returns(expectedErrorCode, SQLIntegrityConstraintViolationException::getErrorCode) - .returns(mongoException, SQLIntegrityConstraintViolationException::getCause); + .returns(errorCode, SQLIntegrityConstraintViolationException::getErrorCode) + .returns(mongoBulkWriteException, SQLIntegrityConstraintViolationException::getCause); }); } @@ -389,14 +375,11 @@ void testExecuteBatchConstraintViolationException(MongoException mongoException) void testExecuteBatchRuntimeExceptionCause() throws SQLException { RuntimeException runtimeException = new RuntimeException(); doThrow(runtimeException).when(mongoCollection).bulkWrite(eq(clientSession), anyList()); - assertExecuteBatchThrowsSqlException(batchUpdateException -> { - assertThatObject(batchUpdateException) - .returns(0, BatchUpdateException::getErrorCode) - .returns(null, BatchUpdateException::getSQLState) - .returns(runtimeException, BatchUpdateException::getCause) - .satisfies(ex -> { - assertUpdateCounts(ex.getUpdateCounts(), 0); - }); + assertExecuteBatchThrowsSqlException(sqlException -> { + assertThatObject(sqlException) + .returns(0, SQLException::getErrorCode) + .returns(null, SQLException::getSQLState) + .returns(runtimeException, SQLException::getCause); }); } @@ -417,13 +400,10 @@ void testExecuteQueryRuntimeExceptionCause() throws SQLException { } private static Stream bulkWriteExceptionsForExecuteUpdate() { - return Stream.of( - Arguments.of(named("insert", MQL_ITEMS_INSERT), MONGO_BULK_WRITE_EXCEPTION_NO_ERRORS), - Arguments.of(named("update", MQL_ITEMS_UPDATE), MONGO_BULK_WRITE_EXCEPTION_NO_ERRORS), - Arguments.of(named("delete", MQL_ITEMS_DELETE), MONGO_BULK_WRITE_EXCEPTION_NO_ERRORS), - Arguments.of(named("insert", MQL_ITEMS_INSERT), MONGO_BULK_WRITE_EXCEPTION_WITH_ERRORS), - Arguments.of(named("update", MQL_ITEMS_UPDATE), MONGO_BULK_WRITE_EXCEPTION_WITH_ERRORS), - Arguments.of(named("delete", MQL_ITEMS_DELETE), MONGO_BULK_WRITE_EXCEPTION_WITH_ERRORS)); + return mqlCommands() + .flatMap(mqlCommand -> Stream.of( + Arguments.of(mqlCommand, MONGO_BULK_WRITE_EXCEPTION_WITH_WRITE_CONCERN_EXCEPTION), + Arguments.of(mqlCommand, MONGO_BULK_WRITE_EXCEPTION_WITH_WRITE_ERRORS))); } @ParameterizedTest( @@ -443,37 +423,35 @@ void testExecuteUpdateMongoBulkWriteException(String mql, MongoBulkWriteExceptio } } - private static Stream bulkWriteExceptionsForExecuteBatch() { - return Stream.of( - Arguments.of( - named("insert", MQL_ITEMS_INSERT), - MONGO_BULK_WRITE_EXCEPTION_NO_ERRORS, - BULK_WRITE_RESULT.getInsertedCount()), - Arguments.of( - named("update", MQL_ITEMS_UPDATE), - MONGO_BULK_WRITE_EXCEPTION_NO_ERRORS, - BULK_WRITE_RESULT.getModifiedCount()), - Arguments.of( - named("delete", MQL_ITEMS_DELETE), - MONGO_BULK_WRITE_EXCEPTION_NO_ERRORS, - BULK_WRITE_RESULT.getDeletedCount()), - Arguments.of( - named("insert", MQL_ITEMS_INSERT), - MONGO_BULK_WRITE_EXCEPTION_WITH_ERRORS, - BULK_WRITE_RESULT.getInsertedCount()), - Arguments.of( - named("update", MQL_ITEMS_UPDATE), - MONGO_BULK_WRITE_EXCEPTION_WITH_ERRORS, - BULK_WRITE_RESULT.getModifiedCount()), - Arguments.of( - named("delete", MQL_ITEMS_DELETE), - MONGO_BULK_WRITE_EXCEPTION_WITH_ERRORS, - BULK_WRITE_RESULT.getDeletedCount())); + private static Stream testExecuteBatchMongoBulkWriteException() { + return mqlCommands() + .flatMap(mqlCommand -> Stream.of( + // Error in command 1 + Arguments.of( + mqlCommand, // MQL command to execute + createMongoBulkWriteException(1), // failed model index + 0), // expected update count length + Arguments.of(mqlCommand, createMongoBulkWriteException(2), 0), + Arguments.of(mqlCommand, createMongoBulkWriteException(3), 0), + + // Error in command 2 + Arguments.of(mqlCommand, createMongoBulkWriteException(4), 1), + Arguments.of(mqlCommand, createMongoBulkWriteException(5), 1), + Arguments.of(mqlCommand, createMongoBulkWriteException(6), 1), + Arguments.of(mqlCommand, createMongoBulkWriteException(7), 1), + + // Error in command 3 + Arguments.of(mqlCommand, createMongoBulkWriteException(8), 2), + Arguments.of(mqlCommand, createMongoBulkWriteException(9), 2), + Arguments.of(mqlCommand, createMongoBulkWriteException(10), 2), + Arguments.of(mqlCommand, createMongoBulkWriteException(11), 2), + Arguments.of(mqlCommand, MONGO_BULK_WRITE_EXCEPTION_WITH_WRITE_CONCERN_EXCEPTION, 0))); } @ParameterizedTest( - name = "test executeBatch MongoBulkWriteException. Parameters: commandName={0}, exception={1}") - @MethodSource("bulkWriteExceptionsForExecuteBatch") + name = + "test executeBatch MongoBulkWriteException. Parameters: commandName={0}, exception={1}, expectedUpdateCountLength={2}") + @MethodSource("testExecuteBatchMongoBulkWriteException") void testExecuteBatchMongoBulkWriteException( String mql, MongoBulkWriteException mongoBulkWriteException, int expectedUpdateCountLength) throws SQLException { @@ -482,14 +460,20 @@ void testExecuteBatchMongoBulkWriteException( try (MongoPreparedStatement mongoPreparedStatement = createMongoPreparedStatement(mql)) { mongoPreparedStatement.addBatch(); + mongoPreparedStatement.addBatch(); + mongoPreparedStatement.addBatch(); + assertThatExceptionOfType(BatchUpdateException.class) .isThrownBy(mongoPreparedStatement::executeBatch) - .withCause(mongoBulkWriteException) .returns(vendorCodeError, BatchUpdateException::getErrorCode) .returns(null, BatchUpdateException::getSQLState) .satisfies(ex -> { assertUpdateCounts(ex.getUpdateCounts(), expectedUpdateCountLength); - }); + }) + .havingCause() + .isInstanceOf(SQLException.class) + .havingCause() + .isSameAs(mongoBulkWriteException); } } @@ -502,25 +486,23 @@ private static void assertGenericException(final SQLException sqlException, Runt } private static void assertGenericMongoException(final SQLException sqlException, final MongoException cause) { - int expectedErrorCode = max(0, cause.getCode()); assertThatObject(sqlException) .isExactlyInstanceOf(SQLException.class) - .returns(expectedErrorCode, SQLException::getErrorCode) + .returns(cause.getCode(), SQLException::getErrorCode) .returns(null, SQLException::getSQLState) .returns(cause, SQLException::getCause); } - private static void assertTimeoutException( - final MongoException mongoTimeoutException, final SQLException sqlException) { - int expectedErrorCode = max(0, mongoTimeoutException.getCode()); + private static void assertConstraintViolationException( + final MongoException mongoException, final SQLException sqlException, final int expectedErrorCode) { assertThatObject(sqlException) - .asInstanceOf(type(SQLTimeoutException.class)) - .returns(expectedErrorCode, SQLTimeoutException::getErrorCode) - .returns(null, SQLTimeoutException::getSQLState) - .returns(mongoTimeoutException, SQLTimeoutException::getCause); + .asInstanceOf(type(SQLIntegrityConstraintViolationException.class)) + .returns(expectedErrorCode, SQLIntegrityConstraintViolationException::getErrorCode) + .returns(null, SQLIntegrityConstraintViolationException::getSQLState) + .returns(mongoException, SQLIntegrityConstraintViolationException::getCause); } - private void assertExecuteBatchThrowsSqlException(ThrowingConsumer asserter) + private void assertExecuteBatchThrowsBatchUpdateException(ThrowingConsumer asserter) throws SQLException { try (MongoPreparedStatement mongoPreparedStatement = createMongoPreparedStatement(MQL_ITEMS_INSERT)) { mongoPreparedStatement.addBatch(); @@ -530,6 +512,16 @@ private void assertExecuteBatchThrowsSqlException(ThrowingConsumer asserter) throws SQLException { + try (MongoPreparedStatement mongoPreparedStatement = createMongoPreparedStatement(MQL_ITEMS_INSERT)) { + mongoPreparedStatement.addBatch(); + assertThatExceptionOfType(SQLException.class) + .isThrownBy(mongoPreparedStatement::executeBatch) + .isExactlyInstanceOf(SQLException.class) + .satisfies(asserter); + } + } + private void assertExecuteUpdateThrowsSqlException(ThrowingConsumer asserter) throws SQLException { try (MongoPreparedStatement mongoPreparedStatement = createMongoPreparedStatement(MQL_ITEMS_INSERT)) { @@ -547,19 +539,40 @@ private void assertExecuteQueryThrowsSqlException(ThrowingConsumer } } + private static void assertUpdateCounts(final int[] actualUpdateCounts, int expectedUpdateCountsLength) { + assertEquals(expectedUpdateCountsLength, actualUpdateCounts.length); + for (int count : actualUpdateCounts) { + assertEquals(SUCCESS_NO_INFO, count); + } + } + + private static MongoBulkWriteException createMongoBulkWriteException( + final int errorCode, final int failedModelIndex) { + return new MongoBulkWriteException( + BULK_WRITE_RESULT, + List.of(new BulkWriteError( + errorCode, DUMMY_EXCEPTION_MESSAGE, DUMMY_ERROR_DETAILS, failedModelIndex)), + null, + DUMMY_SERVER_ADDRESS, + emptySet()); + } + + private static MongoBulkWriteException createMongoBulkWriteException(final int failedModelIndex) { + return new MongoBulkWriteException( + BULK_WRITE_RESULT, + List.of(new BulkWriteError( + failedModelIndex, DUMMY_EXCEPTION_MESSAGE, DUMMY_ERROR_DETAILS, failedModelIndex)), + null, + DUMMY_SERVER_ADDRESS, + emptySet()); + } + private static Integer getVendorCodeError(final MongoBulkWriteException mongoBulkWriteException) { return mongoBulkWriteException.getWriteErrors().stream() .map(BulkWriteError::getCode) .findFirst() .orElse(0); } - - private static void assertUpdateCounts(final int[] updateCounts, int expectedUpdateCountsLength) { - assertEquals(expectedUpdateCountsLength, updateCounts.length); - for (int count : updateCounts) { - assertEquals(SUCCESS_NO_INFO, count); - } - } } @Test From 7c8a7026a03d573658870577aeb8d39c2cd74e84 Mon Sep 17 00:00:00 2001 From: "slav.babanin" Date: Mon, 27 Oct 2025 18:43:40 -0700 Subject: [PATCH 18/37] Make NULL_SQL_STATE private. --- .../java/com/mongodb/hibernate/jdbc/MongoStatement.java | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/src/main/java/com/mongodb/hibernate/jdbc/MongoStatement.java b/src/main/java/com/mongodb/hibernate/jdbc/MongoStatement.java index 699377ec..53cb5b9c 100644 --- a/src/main/java/com/mongodb/hibernate/jdbc/MongoStatement.java +++ b/src/main/java/com/mongodb/hibernate/jdbc/MongoStatement.java @@ -76,7 +76,7 @@ class MongoStatement implements StatementAdapter { static final int NO_ERROR_CODE = 0; static final int[] EMPTY_BATCH_RESULT = new int[0]; - @Nullable public static final String NULL_SQL_STATE = null; + @Nullable private static final String NULL_SQL_STATE = null; private final MongoDatabase mongoDatabase; private final MongoConnection mongoConnection; @@ -353,9 +353,8 @@ private static SQLException handleBatchException( errorCode, cause, bulkWriteException, assertNotNull(writeModelsToCommandMapper)); } // TODO-HIBERNATE-132 BatchUpdateException is thrown when one of the commands fails to execute properly. - // When - // exception is not of MongoBulkWriteException, we are not sure if any command was executed successfully or - // failed. + // When exception is not of MongoBulkWriteException, we are not sure if any command was executed + // successfully or failed. return cause; } return new SQLException(EXCEPTION_MESSAGE_OPERATION_FAILED, null, errorCode, exceptionToHandle); From 51103befcd735487eee966e9c158b1434759fc64 Mon Sep 17 00:00:00 2001 From: "slav.babanin" Date: Mon, 27 Oct 2025 18:48:26 -0700 Subject: [PATCH 19/37] Remove final in parameters. --- .../MongoPreparedStatementIntegrationTests.java | 6 +++--- .../mongodb/hibernate/jdbc/MongoStatement.java | 8 ++++---- .../jdbc/MongoPreparedStatementTests.java | 15 +++++++-------- 3 files changed, 14 insertions(+), 15 deletions(-) diff --git a/src/integrationTest/java/com/mongodb/hibernate/jdbc/MongoPreparedStatementIntegrationTests.java b/src/integrationTest/java/com/mongodb/hibernate/jdbc/MongoPreparedStatementIntegrationTests.java index 3ccfd3cf..89ac357b 100644 --- a/src/integrationTest/java/com/mongodb/hibernate/jdbc/MongoPreparedStatementIntegrationTests.java +++ b/src/integrationTest/java/com/mongodb/hibernate/jdbc/MongoPreparedStatementIntegrationTests.java @@ -1017,17 +1017,17 @@ private void assertExecuteUpdate( .containsExactlyElementsOf(expectedDocuments); } - private static String getFieldName(final String unsupportedField) { + private static String getFieldName(String unsupportedField) { return BsonDocument.parse("{" + unsupportedField + "}").getFirstKey(); } - private String toExtendedJson(final String mql) { + private String toExtendedJson(String mql) { return BsonDocument.parse(mql).toJson(EXTENDED_JSON_WRITER_SETTINGS); } } private void assertInvalidMql( - final String mql, SqlConsumer executor, String expectedExceptionMessage) { + String mql, SqlConsumer executor, String expectedExceptionMessage) { doWorkAwareOfAutoCommit(connection -> { try (PreparedStatement pstm = connection.prepareStatement(mql)) { assertThatThrownBy(() -> executor.accept(pstm)) diff --git a/src/main/java/com/mongodb/hibernate/jdbc/MongoStatement.java b/src/main/java/com/mongodb/hibernate/jdbc/MongoStatement.java index 53cb5b9c..cbb66771 100644 --- a/src/main/java/com/mongodb/hibernate/jdbc/MongoStatement.java +++ b/src/main/java/com/mongodb/hibernate/jdbc/MongoStatement.java @@ -368,7 +368,7 @@ private static SQLException handleQueryOrUpdateException(RuntimeException except return new SQLException(EXCEPTION_MESSAGE_OPERATION_FAILED, null, errorCode, exceptionToHandle); } - private static SQLException handleMongoException(final int errorCode, final MongoException exceptionToHandle) { + private static SQLException handleMongoException(int errorCode, MongoException exceptionToHandle) { if (isTimeoutException(exceptionToHandle)) { return new SQLException(EXCEPTION_MESSAGE_TIMEOUT, null, errorCode, exceptionToHandle); } @@ -386,7 +386,7 @@ private static SQLException handleMongoException(final int errorCode, final Mong }; } - private static int getErrorCode(final RuntimeException runtimeException) { + private static int getErrorCode(RuntimeException runtimeException) { if (runtimeException instanceof MongoBulkWriteException mongoBulkWriteException) { var writeErrors = mongoBulkWriteException.getWriteErrors(); if (writeErrors.isEmpty()) { @@ -426,7 +426,7 @@ private static BatchUpdateException createBatchUpdateException( cause); } - private static T withCause(T sqlException, final Exception cause) { + private static T withCause(T sqlException, Exception cause) { sqlException.initCause(cause); if (cause instanceof SQLException sqlExceptionCause) { sqlException.setNextException(sqlExceptionCause); @@ -434,7 +434,7 @@ private static T withCause(T sqlException, final Except return sqlException; } - private static boolean isTimeoutException(final MongoException exception) { + private static boolean isTimeoutException(MongoException exception) { // We do not check for `MongoExecutionTimeoutException` and `MongoOperationTimeoutException` here, // because it is handled via error codes. return exception instanceof MongoSocketReadTimeoutException diff --git a/src/test/java/com/mongodb/hibernate/jdbc/MongoPreparedStatementTests.java b/src/test/java/com/mongodb/hibernate/jdbc/MongoPreparedStatementTests.java index 743bb57f..947374a2 100644 --- a/src/test/java/com/mongodb/hibernate/jdbc/MongoPreparedStatementTests.java +++ b/src/test/java/com/mongodb/hibernate/jdbc/MongoPreparedStatementTests.java @@ -474,7 +474,7 @@ void testExecuteBatchMongoBulkWriteException( } } - private static void assertGenericException(final SQLException sqlException, RuntimeException cause) { + private static void assertGenericException(SQLException sqlException, RuntimeException cause) { assertThatObject(sqlException) .isExactlyInstanceOf(SQLException.class) .returns(0, SQLException::getErrorCode) @@ -482,7 +482,7 @@ private static void assertGenericException(final SQLException sqlException, Runt .returns(cause, SQLException::getCause); } - private static void assertGenericMongoException(final SQLException sqlException, final MongoException cause) { + private static void assertGenericMongoException(SQLException sqlException, MongoException cause) { assertThatObject(sqlException) .isExactlyInstanceOf(SQLException.class) .returns(cause.getCode(), SQLException::getErrorCode) @@ -491,7 +491,7 @@ private static void assertGenericMongoException(final SQLException sqlException, } private static void assertConstraintViolationException( - final MongoException mongoException, final SQLException sqlException, final int expectedErrorCode) { + MongoException mongoException, SQLException sqlException, int expectedErrorCode) { assertThatObject(sqlException) .asInstanceOf(type(SQLIntegrityConstraintViolationException.class)) .returns(expectedErrorCode, SQLIntegrityConstraintViolationException::getErrorCode) @@ -536,15 +536,14 @@ private void assertExecuteQueryThrowsSqlException(ThrowingConsumer } } - private static void assertUpdateCounts(final int[] actualUpdateCounts, int expectedUpdateCountsLength) { + private static void assertUpdateCounts(int[] actualUpdateCounts, int expectedUpdateCountsLength) { assertEquals(expectedUpdateCountsLength, actualUpdateCounts.length); for (int count : actualUpdateCounts) { assertEquals(SUCCESS_NO_INFO, count); } } - private static MongoBulkWriteException createMongoBulkWriteException( - final int errorCode, final int failedModelIndex) { + private static MongoBulkWriteException createMongoBulkWriteException(int errorCode, int failedModelIndex) { return new MongoBulkWriteException( BULK_WRITE_RESULT, List.of(new BulkWriteError( @@ -554,7 +553,7 @@ private static MongoBulkWriteException createMongoBulkWriteException( emptySet()); } - private static MongoBulkWriteException createMongoBulkWriteException(final int failedModelIndex) { + private static MongoBulkWriteException createMongoBulkWriteException(int failedModelIndex) { return new MongoBulkWriteException( BULK_WRITE_RESULT, List.of(new BulkWriteError( @@ -564,7 +563,7 @@ private static MongoBulkWriteException createMongoBulkWriteException(final int f emptySet()); } - private static Integer getVendorCodeError(final MongoBulkWriteException mongoBulkWriteException) { + private static Integer getVendorCodeError(MongoBulkWriteException mongoBulkWriteException) { return mongoBulkWriteException.getWriteErrors().stream() .map(BulkWriteError::getCode) .findFirst() From 4317e2f1cb0e8b8a704541d2de1967b6fe346873 Mon Sep 17 00:00:00 2001 From: "slav.babanin" Date: Tue, 28 Oct 2025 10:21:13 -0700 Subject: [PATCH 20/37] Remove catch. --- .../hibernate/jdbc/MongoPreparedStatementIntegrationTests.java | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/integrationTest/java/com/mongodb/hibernate/jdbc/MongoPreparedStatementIntegrationTests.java b/src/integrationTest/java/com/mongodb/hibernate/jdbc/MongoPreparedStatementIntegrationTests.java index 89ac357b..9b57b3d4 100644 --- a/src/integrationTest/java/com/mongodb/hibernate/jdbc/MongoPreparedStatementIntegrationTests.java +++ b/src/integrationTest/java/com/mongodb/hibernate/jdbc/MongoPreparedStatementIntegrationTests.java @@ -358,8 +358,6 @@ void testEmptyBatch() { try (var pstmt = connection.prepareStatement(INSERT_MQL)) { var updateCounts = pstmt.executeBatch(); assertEquals(0, updateCounts.length); - } catch (SQLException e) { - throw new RuntimeException(e); } }); From 2b5be783d0dd107d4f7448481ac85afd234543e6 Mon Sep 17 00:00:00 2001 From: "slav.babanin" Date: Tue, 28 Oct 2025 11:33:42 -0700 Subject: [PATCH 21/37] Use constants. --- .../MongoPreparedStatementIntegrationTests.java | 12 +++++------- .../hibernate/jdbc/MongoPreparedStatement.java | 16 ++++++++-------- .../mongodb/hibernate/jdbc/MongoStatement.java | 16 ++++++++-------- 3 files changed, 21 insertions(+), 23 deletions(-) diff --git a/src/integrationTest/java/com/mongodb/hibernate/jdbc/MongoPreparedStatementIntegrationTests.java b/src/integrationTest/java/com/mongodb/hibernate/jdbc/MongoPreparedStatementIntegrationTests.java index 9b57b3d4..20712a48 100644 --- a/src/integrationTest/java/com/mongodb/hibernate/jdbc/MongoPreparedStatementIntegrationTests.java +++ b/src/integrationTest/java/com/mongodb/hibernate/jdbc/MongoPreparedStatementIntegrationTests.java @@ -788,7 +788,7 @@ void testAbsentRequiredUpdateCommandField() { .isInstanceOf(SQLSyntaxErrorException.class) .hasMessage("Invalid MQL: [%s]".formatted(toExtendedJson(mql))) .cause() - .hasMessageContaining("Document does not contain key updates"); + .hasMessage("Document does not contain key updates"); } }); } @@ -831,7 +831,7 @@ void testAbsentRequiredDeleteCommandField() { .isInstanceOf(SQLSyntaxErrorException.class) .hasMessage("Invalid MQL: [%s]".formatted(toExtendedJson(mql))) .cause() - .hasMessageContaining("Document does not contain key deletes"); + .hasMessage("Document does not contain key deletes"); } }); } @@ -881,7 +881,7 @@ void testAbsentRequiredInsertCommandField() { .isInstanceOf(SQLSyntaxErrorException.class) .hasMessage("Invalid MQL: [%s]".formatted(toExtendedJson(mql))) .cause() - .hasMessageContaining("Document does not contain key documents"); + .hasMessage("Document does not contain key documents"); } }); } @@ -944,8 +944,7 @@ void testAbsentRequiredUpdateStatementField(String fieldToRemove) { .isInstanceOf(SQLSyntaxErrorException.class) .hasMessage("Invalid MQL: [%s]".formatted(toExtendedJson(mql))) .cause() - .hasMessageContaining( - "Document does not contain key %s".formatted(getFieldName(fieldToRemove))); + .hasMessage("Document does not contain key %s".formatted(getFieldName(fieldToRemove))); } }); } @@ -996,8 +995,7 @@ void testAbsentRequiredDeleteStatementField(String fieldToRemove) { .isInstanceOf(SQLSyntaxErrorException.class) .hasMessage("Invalid MQL: [%s]".formatted(toExtendedJson(mql))) .cause() - .hasMessageContaining( - "Document does not contain key %s".formatted(getFieldName(fieldToRemove))); + .hasMessage("Document does not contain key %s".formatted(getFieldName(fieldToRemove))); } }); } diff --git a/src/main/java/com/mongodb/hibernate/jdbc/MongoPreparedStatement.java b/src/main/java/com/mongodb/hibernate/jdbc/MongoPreparedStatement.java index c95c1f6c..bf4e9a90 100644 --- a/src/main/java/com/mongodb/hibernate/jdbc/MongoPreparedStatement.java +++ b/src/main/java/com/mongodb/hibernate/jdbc/MongoPreparedStatement.java @@ -54,14 +54,15 @@ final class MongoPreparedStatement extends MongoStatement implements PreparedSta private final BsonDocument command; private final List commandBatch; private final List parameterValueSetters; + private static final int[] EMPTY_BATCH_RESULT = new int[0]; MongoPreparedStatement( MongoDatabase mongoDatabase, ClientSession clientSession, MongoConnection mongoConnection, String mql) throws SQLSyntaxErrorException { super(mongoDatabase, clientSession, mongoConnection); - command = MongoStatement.parse(mql); - commandBatch = new ArrayList<>(); - parameterValueSetters = new ArrayList<>(); + this.command = MongoStatement.parse(mql); + this.commandBatch = new ArrayList<>(); + this.parameterValueSetters = new ArrayList<>(); parseParameters(command, parameterValueSetters); } @@ -218,12 +219,11 @@ private void checkSupportedBatchCommand(BsonDocument command) throws SQLExceptio var commandDescription = getCommandDescription(command); if (commandDescription.returnsResultSet()) { throw new BatchUpdateException( - format( - "Commands returning result set are not allowed. Received command: %s", - commandDescription.getCommandName()), - null, + "Commands returning result set are not allowed. Received command: %s" + .formatted(commandDescription.getCommandName()), + NULL_SQL_STATE, NO_ERROR_CODE, - null); + EMPTY_BATCH_RESULT); } if (!commandDescription.isUpdate()) { throw new SQLFeatureNotSupportedException( diff --git a/src/main/java/com/mongodb/hibernate/jdbc/MongoStatement.java b/src/main/java/com/mongodb/hibernate/jdbc/MongoStatement.java index cbb66771..da62dab1 100644 --- a/src/main/java/com/mongodb/hibernate/jdbc/MongoStatement.java +++ b/src/main/java/com/mongodb/hibernate/jdbc/MongoStatement.java @@ -74,9 +74,8 @@ class MongoStatement implements StatementAdapter { private static final String EXCEPTION_MESSAGE_OPERATION_FAILED = "Failed to execute operation"; private static final String EXCEPTION_MESSAGE_TIMEOUT = "Timeout while waiting for operation to complete"; static final int NO_ERROR_CODE = 0; - static final int[] EMPTY_BATCH_RESULT = new int[0]; - @Nullable private static final String NULL_SQL_STATE = null; + @Nullable static final String NULL_SQL_STATE = null; private final MongoDatabase mongoDatabase; private final MongoConnection mongoConnection; @@ -357,7 +356,7 @@ private static SQLException handleBatchException( // successfully or failed. return cause; } - return new SQLException(EXCEPTION_MESSAGE_OPERATION_FAILED, null, errorCode, exceptionToHandle); + return new SQLException(EXCEPTION_MESSAGE_OPERATION_FAILED, NULL_SQL_STATE, errorCode, exceptionToHandle); } private static SQLException handleQueryOrUpdateException(RuntimeException exceptionToHandle) { @@ -365,24 +364,25 @@ private static SQLException handleQueryOrUpdateException(RuntimeException except if (exceptionToHandle instanceof MongoException mongoException) { return handleMongoException(errorCode, mongoException); } - return new SQLException(EXCEPTION_MESSAGE_OPERATION_FAILED, null, errorCode, exceptionToHandle); + return new SQLException(EXCEPTION_MESSAGE_OPERATION_FAILED, NULL_SQL_STATE, errorCode, exceptionToHandle); } private static SQLException handleMongoException(int errorCode, MongoException exceptionToHandle) { if (isTimeoutException(exceptionToHandle)) { - return new SQLException(EXCEPTION_MESSAGE_TIMEOUT, null, errorCode, exceptionToHandle); + return new SQLException(EXCEPTION_MESSAGE_TIMEOUT, NULL_SQL_STATE, errorCode, exceptionToHandle); } var errorCategory = ErrorCategory.fromErrorCode(errorCode); return switch (errorCategory) { case DUPLICATE_KEY -> new SQLIntegrityConstraintViolationException( - EXCEPTION_MESSAGE_OPERATION_FAILED, null, errorCode, exceptionToHandle); + EXCEPTION_MESSAGE_OPERATION_FAILED, NULL_SQL_STATE, errorCode, exceptionToHandle); // TODO-HIBERNATE-132 EXECUTION_TIMEOUT code is returned from the server. Do we know how many commands were // executed // successfully so we can return it as BatchUpdateException? - case EXECUTION_TIMEOUT -> new SQLException(EXCEPTION_MESSAGE_TIMEOUT, null, errorCode, exceptionToHandle); + case EXECUTION_TIMEOUT -> + new SQLException(EXCEPTION_MESSAGE_TIMEOUT, NULL_SQL_STATE, errorCode, exceptionToHandle); case UNCATEGORIZED -> - new SQLException(EXCEPTION_MESSAGE_OPERATION_FAILED, null, errorCode, exceptionToHandle); + new SQLException(EXCEPTION_MESSAGE_OPERATION_FAILED, NULL_SQL_STATE, errorCode, exceptionToHandle); }; } From 40c75cf480e5b69caf021ffdd50646c7d0aa1056 Mon Sep 17 00:00:00 2001 From: "slav.babanin" Date: Tue, 28 Oct 2025 11:35:59 -0700 Subject: [PATCH 22/37] Remove accidental use of NO_ERROR_CODE. --- src/main/java/com/mongodb/hibernate/jdbc/MongoStatement.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/main/java/com/mongodb/hibernate/jdbc/MongoStatement.java b/src/main/java/com/mongodb/hibernate/jdbc/MongoStatement.java index da62dab1..9eda433d 100644 --- a/src/main/java/com/mongodb/hibernate/jdbc/MongoStatement.java +++ b/src/main/java/com/mongodb/hibernate/jdbc/MongoStatement.java @@ -140,7 +140,7 @@ private static boolean isExcludeProjectSpecification(Map.Entry Date: Tue, 28 Oct 2025 11:39:27 -0700 Subject: [PATCH 23/37] Consolidate batch execution logic in one place. --- .../hibernate/jdbc/MongoPreparedStatement.java | 8 +------- .../com/mongodb/hibernate/jdbc/MongoStatement.java | 11 ++++++++--- 2 files changed, 9 insertions(+), 10 deletions(-) diff --git a/src/main/java/com/mongodb/hibernate/jdbc/MongoPreparedStatement.java b/src/main/java/com/mongodb/hibernate/jdbc/MongoPreparedStatement.java index bf4e9a90..0dd1a387 100644 --- a/src/main/java/com/mongodb/hibernate/jdbc/MongoPreparedStatement.java +++ b/src/main/java/com/mongodb/hibernate/jdbc/MongoPreparedStatement.java @@ -35,11 +35,9 @@ import java.sql.SQLException; import java.sql.SQLFeatureNotSupportedException; import java.sql.SQLSyntaxErrorException; -import java.sql.Statement; import java.sql.Types; import java.time.Instant; import java.util.ArrayList; -import java.util.Arrays; import java.util.List; import java.util.Set; import java.util.function.Consumer; @@ -204,11 +202,7 @@ public int[] executeBatch() throws SQLException { return EMPTY_BATCH_RESULT; } checkSupportedBatchCommand(commandBatch.get(0)); - executeBatch(commandBatch); - var updateCounts = new int[commandBatch.size()]; - // We cannot determine the actual number of rows affected for each command in the batch. - Arrays.fill(updateCounts, Statement.SUCCESS_NO_INFO); - return updateCounts; + return executeBatch(commandBatch); } finally { commandBatch.clear(); } diff --git a/src/main/java/com/mongodb/hibernate/jdbc/MongoStatement.java b/src/main/java/com/mongodb/hibernate/jdbc/MongoStatement.java index 9eda433d..2dfe4345 100644 --- a/src/main/java/com/mongodb/hibernate/jdbc/MongoStatement.java +++ b/src/main/java/com/mongodb/hibernate/jdbc/MongoStatement.java @@ -162,20 +162,25 @@ public int executeUpdate(String mql) throws SQLException { return executeUpdate(command); } - void executeBatch(List commandBatch) throws SQLException { + int[] executeBatch(List commandBatch) throws SQLException { var firstCommandInBatch = commandBatch.get(0); + int commandBatchSize = commandBatch.size(); var commandDescription = getCommandDescription(firstCommandInBatch); var collection = getCollection(commandDescription, firstCommandInBatch); WriteModelsToCommandMapper writeModelsToCommandMapper = null; try { startTransactionIfNeeded(); - var writeModels = new ArrayList>(commandBatch.size()); - writeModelsToCommandMapper = new WriteModelsToCommandMapper(commandBatch.size()); + var writeModels = new ArrayList>(commandBatchSize); + writeModelsToCommandMapper = new WriteModelsToCommandMapper(commandBatchSize); for (BsonDocument command : commandBatch) { WriteModelConverter.convertToWriteModels(commandDescription, command, writeModels); writeModelsToCommandMapper.add(writeModels.size()); } collection.bulkWrite(clientSession, writeModels); + var updateCounts = new int[commandBatchSize]; + // We cannot determine the actual number of rows affected for each command in the batch. + Arrays.fill(updateCounts, Statement.SUCCESS_NO_INFO); + return updateCounts; } catch (RuntimeException exception) { throw handleBatchException(exception, writeModelsToCommandMapper); } From 9435785b635eab25462a67f41129f736838331f2 Mon Sep 17 00:00:00 2001 From: "slav.babanin" Date: Tue, 28 Oct 2025 13:01:40 -0700 Subject: [PATCH 24/37] Remove redundunt methods. --- .../jdbc/MongoPreparedStatement.java | 11 ++- .../hibernate/jdbc/MongoStatement.java | 83 ++++++++++--------- 2 files changed, 50 insertions(+), 44 deletions(-) diff --git a/src/main/java/com/mongodb/hibernate/jdbc/MongoPreparedStatement.java b/src/main/java/com/mongodb/hibernate/jdbc/MongoPreparedStatement.java index 0dd1a387..a4379479 100644 --- a/src/main/java/com/mongodb/hibernate/jdbc/MongoPreparedStatement.java +++ b/src/main/java/com/mongodb/hibernate/jdbc/MongoPreparedStatement.java @@ -46,13 +46,15 @@ import org.bson.BsonType; import org.bson.BsonValue; import org.bson.types.ObjectId; +import org.jspecify.annotations.Nullable; final class MongoPreparedStatement extends MongoStatement implements PreparedStatementAdapter { private final BsonDocument command; private final List commandBatch; private final List parameterValueSetters; - private static final int[] EMPTY_BATCH_RESULT = new int[0]; + + private static final int @Nullable [] NULL_UPDATE_COUNTS = null; MongoPreparedStatement( MongoDatabase mongoDatabase, ClientSession clientSession, MongoConnection mongoConnection, String mql) @@ -209,7 +211,8 @@ public int[] executeBatch() throws SQLException { } /** @throws BatchUpdateException if any of the commands in the batch attempts to return a result set. */ - private void checkSupportedBatchCommand(BsonDocument command) throws SQLException { + private void checkSupportedBatchCommand(BsonDocument command) + throws SQLFeatureNotSupportedException, BatchUpdateException, SQLSyntaxErrorException { var commandDescription = getCommandDescription(command); if (commandDescription.returnsResultSet()) { throw new BatchUpdateException( @@ -217,11 +220,11 @@ private void checkSupportedBatchCommand(BsonDocument command) throws SQLExceptio .formatted(commandDescription.getCommandName()), NULL_SQL_STATE, NO_ERROR_CODE, - EMPTY_BATCH_RESULT); + NULL_UPDATE_COUNTS); } if (!commandDescription.isUpdate()) { throw new SQLFeatureNotSupportedException( - format("Unsupported command for batch operation: %s", commandDescription.getCommandName())); + "Unsupported command for batch operation: %s".formatted(commandDescription.getCommandName())); } } diff --git a/src/main/java/com/mongodb/hibernate/jdbc/MongoStatement.java b/src/main/java/com/mongodb/hibernate/jdbc/MongoStatement.java index 2dfe4345..827ef742 100644 --- a/src/main/java/com/mongodb/hibernate/jdbc/MongoStatement.java +++ b/src/main/java/com/mongodb/hibernate/jdbc/MongoStatement.java @@ -33,7 +33,6 @@ import com.mongodb.MongoSocketReadTimeoutException; import com.mongodb.MongoSocketWriteTimeoutException; import com.mongodb.MongoTimeoutException; -import com.mongodb.bulk.BulkWriteError; import com.mongodb.bulk.BulkWriteResult; import com.mongodb.client.ClientSession; import com.mongodb.client.MongoCollection; @@ -74,6 +73,7 @@ class MongoStatement implements StatementAdapter { private static final String EXCEPTION_MESSAGE_OPERATION_FAILED = "Failed to execute operation"; private static final String EXCEPTION_MESSAGE_TIMEOUT = "Timeout while waiting for operation to complete"; static final int NO_ERROR_CODE = 0; + static final int[] EMPTY_BATCH_RESULT = new int[0]; @Nullable static final String NULL_SQL_STATE = null; @@ -177,15 +177,19 @@ int[] executeBatch(List commandBatch) throws SQLExceptio writeModelsToCommandMapper.add(writeModels.size()); } collection.bulkWrite(clientSession, writeModels); - var updateCounts = new int[commandBatchSize]; - // We cannot determine the actual number of rows affected for each command in the batch. - Arrays.fill(updateCounts, Statement.SUCCESS_NO_INFO); - return updateCounts; + return createUpdateCounts(commandBatchSize); } catch (RuntimeException exception) { throw handleBatchException(exception, writeModelsToCommandMapper); } } + private static int[] createUpdateCounts(int updateCountsSize) { + // We cannot determine the actual number of rows affected for each command in the batch. + var updateCounts = new int[updateCountsSize]; + Arrays.fill(updateCounts, Statement.SUCCESS_NO_INFO); + return updateCounts; + } + int executeUpdate(BsonDocument command) throws SQLException { var commandDescription = getCommandDescription(command); var collection = getCollection(commandDescription, command); @@ -275,15 +279,17 @@ void checkClosed() throws SQLException { } } - private void checkSupportedQueryCommand(BsonDocument command) throws SQLException { + private void checkSupportedQueryCommand(BsonDocument command) + throws SQLFeatureNotSupportedException, SQLSyntaxErrorException { var commandDescription = getCommandDescription(command); if (commandDescription.isUpdate()) { throw new SQLFeatureNotSupportedException( - format("Unsupported command for query operation: %s", commandDescription.getCommandName())); + "Unsupported command for query operation: %s".formatted(commandDescription.getCommandName())); } } - void checkSupportedUpdateCommand(BsonDocument command) throws SQLException { + void checkSupportedUpdateCommand(BsonDocument command) + throws SQLFeatureNotSupportedException, SQLSyntaxErrorException { CommandDescription commandDescription = getCommandDescription(command); if (!commandDescription.isUpdate()) { throw new SQLFeatureNotSupportedException( @@ -310,7 +316,8 @@ private void startTransactionIfNeeded() throws SQLException { } /** The first key is always the command name, e.g. "insert", "update", "delete". */ - static CommandDescription getCommandDescription(BsonDocument command) throws SQLException { + static CommandDescription getCommandDescription(BsonDocument command) + throws SQLFeatureNotSupportedException, SQLSyntaxErrorException { String commandName; try { commandName = command.getFirstKey(); @@ -382,8 +389,7 @@ private static SQLException handleMongoException(int errorCode, MongoException e new SQLIntegrityConstraintViolationException( EXCEPTION_MESSAGE_OPERATION_FAILED, NULL_SQL_STATE, errorCode, exceptionToHandle); // TODO-HIBERNATE-132 EXECUTION_TIMEOUT code is returned from the server. Do we know how many commands were - // executed - // successfully so we can return it as BatchUpdateException? + // executed successfully so we can return it as BatchUpdateException? case EXECUTION_TIMEOUT -> new SQLException(EXCEPTION_MESSAGE_TIMEOUT, NULL_SQL_STATE, errorCode, exceptionToHandle); case UNCATEGORIZED -> @@ -412,31 +418,28 @@ private static int getErrorCode(RuntimeException runtimeException) { private static BatchUpdateException createBatchUpdateException( int errorCode, - Exception cause, + SQLException sqlCause, MongoBulkWriteException mongoBulkWriteException, WriteModelsToCommandMapper writeModelsToCommandMapper) { - List writeErrors = mongoBulkWriteException.getWriteErrors(); - var updateCount = 0; + var updateCounts = calculateBatchUpdateCounts(mongoBulkWriteException, writeModelsToCommandMapper); + var batchUpdateException = new BatchUpdateException( + EXCEPTION_MESSAGE_OPERATION_FAILED, NULL_SQL_STATE, errorCode, updateCounts, sqlCause); + batchUpdateException.setNextException(sqlCause); + return batchUpdateException; + } + + private static int[] calculateBatchUpdateCounts( + MongoBulkWriteException mongoBulkWriteException, WriteModelsToCommandMapper writeModelsToCommandMapper) { + var writeErrors = mongoBulkWriteException.getWriteErrors(); var writeConcernError = mongoBulkWriteException.getWriteConcernError(); if (writeConcernError == null) { if (!writeErrors.isEmpty()) { var failedModelIndex = writeErrors.get(0).getIndex(); - updateCount = writeModelsToCommandMapper.findCommandIndex(failedModelIndex); + var commandIndex = writeModelsToCommandMapper.findCommandIndex(failedModelIndex); + return createUpdateCounts(commandIndex); } } - var updateCounts = new int[updateCount]; - Arrays.fill(updateCounts, Statement.SUCCESS_NO_INFO); - return withCause( - new BatchUpdateException(EXCEPTION_MESSAGE_OPERATION_FAILED, NULL_SQL_STATE, errorCode, updateCounts), - cause); - } - - private static T withCause(T sqlException, Exception cause) { - sqlException.initCause(cause); - if (cause instanceof SQLException sqlExceptionCause) { - sqlException.setNextException(sqlExceptionCause); - } - return sqlException; + return EMPTY_BATCH_RESULT; } private static boolean isTimeoutException(MongoException exception) { @@ -491,7 +494,7 @@ static CommandDescription fromString(String commandName) throws SQLFeatureNotSup case "update" -> UPDATE; case "delete" -> DELETE; case "aggregate" -> AGGREGATE; - default -> throw new SQLFeatureNotSupportedException(format("Unsupported command: %s", commandName)); + default -> throw new SQLFeatureNotSupportedException("Unsupported command: %s".formatted(commandName)); }; } } @@ -564,13 +567,13 @@ private static WriteModel createUpdateModel( // We force exception here because the field is mandatory. updateStatement.getDocument("u"); } - if (!(updateModification instanceof BsonDocument uDocument)) { + if (!(updateModification instanceof BsonDocument updateDocument)) { throw new SQLFeatureNotSupportedException("Only document type is supported as value for field: [u]"); } if (isMulti) { - return new UpdateManyModel<>(filter, uDocument); + return new UpdateManyModel<>(filter, updateDocument); } - return new UpdateOneModel<>(filter, uDocument); + return new UpdateOneModel<>(filter, updateDocument); } private static WriteModel createDeleteModel( @@ -608,10 +611,10 @@ private static void checkFields( CommandDescription commandDescription, String exceptionMessage, Set supportedCommandFields, - Iterator fieldIterator) + Iterator fieldNameIterator) throws SQLFeatureNotSupportedException { - while (fieldIterator.hasNext()) { - var field = fieldIterator.next(); + while (fieldNameIterator.hasNext()) { + var field = fieldNameIterator.next(); if (!supportedCommandFields.contains(field)) { throw new SQLFeatureNotSupportedException( exceptionMessage.formatted(commandDescription.getCommandName(), field)); @@ -625,20 +628,20 @@ private static class WriteModelsToCommandMapper { /** The cumulative counts of write models for each command in the batch (prefix sum). */ private final int[] cumulativeCounts; - private int index; + private int cumulativeCountIndex; private WriteModelsToCommandMapper(int commandCount) { this.cumulativeCounts = new int[commandCount]; - this.index = 0; + this.cumulativeCountIndex = 0; } private void add(int cumulativeWriteModelCount) { - assertFalse(index >= cumulativeCounts.length); - cumulativeCounts[index++] = cumulativeWriteModelCount; + assertFalse(cumulativeCountIndex >= cumulativeCounts.length); + cumulativeCounts[cumulativeCountIndex++] = cumulativeWriteModelCount; } private int findCommandIndex(int writeModelIndex) { - assertTrue(index >= cumulativeCounts.length); + assertTrue(cumulativeCountIndex == cumulativeCounts.length); int lo = 0, hi = cumulativeCounts.length; while (lo < hi) { var mid = (lo + hi) >>> 1; From 46d4682fbb35db0c0ad84bbf7c5944e80acc92ea Mon Sep 17 00:00:00 2001 From: Viacheslav Babanin Date: Tue, 28 Oct 2025 16:10:30 -0700 Subject: [PATCH 25/37] Update src/main/java/com/mongodb/hibernate/jdbc/MongoStatement.java Co-authored-by: Valentin Kovalenko --- src/main/java/com/mongodb/hibernate/jdbc/MongoStatement.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/main/java/com/mongodb/hibernate/jdbc/MongoStatement.java b/src/main/java/com/mongodb/hibernate/jdbc/MongoStatement.java index 827ef742..9cb1d478 100644 --- a/src/main/java/com/mongodb/hibernate/jdbc/MongoStatement.java +++ b/src/main/java/com/mongodb/hibernate/jdbc/MongoStatement.java @@ -279,7 +279,7 @@ void checkClosed() throws SQLException { } } - private void checkSupportedQueryCommand(BsonDocument command) + private static void checkSupportedQueryCommand(BsonDocument command) throws SQLFeatureNotSupportedException, SQLSyntaxErrorException { var commandDescription = getCommandDescription(command); if (commandDescription.isUpdate()) { From afd885300e8aa78341b2e9a5c706108a90fd8b1c Mon Sep 17 00:00:00 2001 From: Viacheslav Babanin Date: Tue, 28 Oct 2025 16:10:41 -0700 Subject: [PATCH 26/37] Update src/main/java/com/mongodb/hibernate/jdbc/MongoStatement.java Co-authored-by: Valentin Kovalenko --- src/main/java/com/mongodb/hibernate/jdbc/MongoStatement.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/main/java/com/mongodb/hibernate/jdbc/MongoStatement.java b/src/main/java/com/mongodb/hibernate/jdbc/MongoStatement.java index 9cb1d478..8a4f46fb 100644 --- a/src/main/java/com/mongodb/hibernate/jdbc/MongoStatement.java +++ b/src/main/java/com/mongodb/hibernate/jdbc/MongoStatement.java @@ -288,7 +288,7 @@ private static void checkSupportedQueryCommand(BsonDocument command) } } - void checkSupportedUpdateCommand(BsonDocument command) + static void checkSupportedUpdateCommand(BsonDocument command) throws SQLFeatureNotSupportedException, SQLSyntaxErrorException { CommandDescription commandDescription = getCommandDescription(command); if (!commandDescription.isUpdate()) { From 18a18f715f6f6d936a56bdf1f7ed605223b455ca Mon Sep 17 00:00:00 2001 From: Viacheslav Babanin Date: Wed, 29 Oct 2025 12:16:41 -0700 Subject: [PATCH 27/37] Apply suggestions from code review Co-authored-by: Valentin Kovalenko --- .../jdbc/MongoPreparedStatement.java | 3 +- .../hibernate/jdbc/MongoStatement.java | 30 +++++++++++-------- 2 files changed, 20 insertions(+), 13 deletions(-) diff --git a/src/main/java/com/mongodb/hibernate/jdbc/MongoPreparedStatement.java b/src/main/java/com/mongodb/hibernate/jdbc/MongoPreparedStatement.java index a4379479..38f66f01 100644 --- a/src/main/java/com/mongodb/hibernate/jdbc/MongoPreparedStatement.java +++ b/src/main/java/com/mongodb/hibernate/jdbc/MongoPreparedStatement.java @@ -71,6 +71,7 @@ public ResultSet executeQuery() throws SQLException { checkClosed(); closeLastOpenResultSet(); checkAllParametersSet(); + checkSupportedQueryCommand(command); return executeQuery(command); } @@ -211,7 +212,7 @@ public int[] executeBatch() throws SQLException { } /** @throws BatchUpdateException if any of the commands in the batch attempts to return a result set. */ - private void checkSupportedBatchCommand(BsonDocument command) + private static void checkSupportedBatchCommand(BsonDocument command) throws SQLFeatureNotSupportedException, BatchUpdateException, SQLSyntaxErrorException { var commandDescription = getCommandDescription(command); if (commandDescription.returnsResultSet()) { diff --git a/src/main/java/com/mongodb/hibernate/jdbc/MongoStatement.java b/src/main/java/com/mongodb/hibernate/jdbc/MongoStatement.java index 8a4f46fb..b28168cb 100644 --- a/src/main/java/com/mongodb/hibernate/jdbc/MongoStatement.java +++ b/src/main/java/com/mongodb/hibernate/jdbc/MongoStatement.java @@ -16,10 +16,10 @@ package com.mongodb.hibernate.jdbc; -import static com.mongodb.assertions.Assertions.assertFalse; -import static com.mongodb.assertions.Assertions.assertNotNull; -import static com.mongodb.assertions.Assertions.assertTrue; -import static com.mongodb.assertions.Assertions.fail; +import static com.mongodb.hibernate.internal.MongoAssertions.assertFalse; +import static com.mongodb.hibernate.internal.MongoAssertions.assertNotNull; +import static com.mongodb.hibernate.internal.MongoAssertions.assertTrue; +import static com.mongodb.hibernate.internal.MongoAssertions.fail; import static com.mongodb.hibernate.internal.MongoConstants.EXTENDED_JSON_WRITER_SETTINGS; import static com.mongodb.hibernate.internal.MongoConstants.ID_FIELD_NAME; import static com.mongodb.hibernate.internal.VisibleForTesting.AccessModifier.PRIVATE; @@ -75,7 +75,7 @@ class MongoStatement implements StatementAdapter { static final int NO_ERROR_CODE = 0; static final int[] EMPTY_BATCH_RESULT = new int[0]; - @Nullable static final String NULL_SQL_STATE = null; + static final String @Nullable NULL_SQL_STATE = null; private final MongoDatabase mongoDatabase; private final MongoConnection mongoConnection; @@ -164,7 +164,7 @@ public int executeUpdate(String mql) throws SQLException { int[] executeBatch(List commandBatch) throws SQLException { var firstCommandInBatch = commandBatch.get(0); - int commandBatchSize = commandBatch.size(); + var commandBatchSize = commandBatch.size(); var commandDescription = getCommandDescription(firstCommandInBatch); var collection = getCollection(commandDescription, firstCommandInBatch); WriteModelsToCommandMapper writeModelsToCommandMapper = null; @@ -172,7 +172,7 @@ int[] executeBatch(List commandBatch) throws SQLExceptio startTransactionIfNeeded(); var writeModels = new ArrayList>(commandBatchSize); writeModelsToCommandMapper = new WriteModelsToCommandMapper(commandBatchSize); - for (BsonDocument command : commandBatch) { + for (var command : commandBatch) { WriteModelConverter.convertToWriteModels(commandDescription, command, writeModels); writeModelsToCommandMapper.add(writeModels.size()); } @@ -444,16 +444,20 @@ private static int[] calculateBatchUpdateCounts( private static boolean isTimeoutException(MongoException exception) { // We do not check for `MongoExecutionTimeoutException` and `MongoOperationTimeoutException` here, - // because it is handled via error codes. + // because they are handled via error codes. return exception instanceof MongoSocketReadTimeoutException || exception instanceof MongoSocketWriteTimeoutException || exception instanceof MongoTimeoutException; } enum CommandDescription { + /** See {@code insert}. */ INSERT("insert", false, true), + /** See {@code update}. */ UPDATE("update", false, true), + /** See {@code delete}. */ DELETE("delete", false, true), + /** See {@code aggregate}. */ AGGREGATE("aggregate", true, false); private final String commandName; @@ -471,9 +475,9 @@ String getCommandName() { } /** - * Indicates whether the command is used in {@code executeUpdate(...)} or {@code executeBatch()} methods. + * Indicates whether the command may be used in {@code executeUpdate(...)} or {@code executeBatch()} methods. * - * @return true if the command is used in update operations, false if it is used in query operations. + * @return true if the command may be used in update operations, false if it is used in query operations. */ boolean isUpdate() { return isUpdate; @@ -500,8 +504,10 @@ static CommandDescription fromString(String commandName) throws SQLFeatureNotSup } private static class WriteModelConverter { - private static final String UNSUPPORTED_MESSAGE_STATEMENT_FIELD = "Unsupported field in %s statement: [%s]"; - private static final String UNSUPPORTED_MESSAGE_COMMAND_FIELD = "Unsupported field in %s command: [%s]"; + private static final String UNSUPPORTED_MESSAGE_TEMPLATE_STATEMENT_FIELD = + "Unsupported field in [%s] statement: [%s]"; + private static final String UNSUPPORTED_MESSAGE_TEMPLATE_COMMAND_FIELD = + "Unsupported field in [%s] command: [%s]"; private static final Set SUPPORTED_INSERT_COMMAND_FIELDS = Set.of("documents"); From 234921a1427221033c0e7d853807638b0a836246 Mon Sep 17 00:00:00 2001 From: "slav.babanin" Date: Wed, 29 Oct 2025 02:13:34 -0700 Subject: [PATCH 28/37] Move inside `try`, do not start transactions too early --- .../hibernate/jdbc/MongoStatement.java | 21 +++++++++---------- 1 file changed, 10 insertions(+), 11 deletions(-) diff --git a/src/main/java/com/mongodb/hibernate/jdbc/MongoStatement.java b/src/main/java/com/mongodb/hibernate/jdbc/MongoStatement.java index 487dbbbc..9116bfa2 100644 --- a/src/main/java/com/mongodb/hibernate/jdbc/MongoStatement.java +++ b/src/main/java/com/mongodb/hibernate/jdbc/MongoStatement.java @@ -108,16 +108,15 @@ void closeLastOpenResultSet() throws SQLException { } ResultSet executeQuery(BsonDocument command) throws SQLException { - var commandDescription = getCommandDescription(command); try { - startTransactionIfNeeded(); + var commandDescription = getCommandDescription(command); var collection = getCollection(commandDescription, command); var pipeline = command.getArray("pipeline").stream() .map(BsonValue::asDocument) .toList(); var fieldNames = getFieldNamesFromProjectStage( pipeline.get(pipeline.size() - 1).getDocument("$project")); - + startTransactionIfNeeded(); return resultSet = new MongoResultSet( collection.aggregate(clientSession, pipeline).cursor(), fieldNames); } catch (RuntimeException exception) { @@ -166,19 +165,19 @@ public int executeUpdate(String mql) throws SQLException { } int[] executeBatch(List commandBatch) throws SQLException { - var firstCommandInBatch = commandBatch.get(0); - var commandBatchSize = commandBatch.size(); - var commandDescription = getCommandDescription(firstCommandInBatch); - var collection = getCollection(commandDescription, firstCommandInBatch); WriteModelsToCommandMapper writeModelsToCommandMapper = null; try { - startTransactionIfNeeded(); + var firstCommandInBatch = commandBatch.get(0); + var commandBatchSize = commandBatch.size(); + var commandDescription = getCommandDescription(firstCommandInBatch); + var collection = getCollection(commandDescription, firstCommandInBatch); var writeModels = new ArrayList>(commandBatchSize); writeModelsToCommandMapper = new WriteModelsToCommandMapper(commandBatchSize); for (var command : commandBatch) { WriteModelConverter.convertToWriteModels(commandDescription, command, writeModels); writeModelsToCommandMapper.add(writeModels.size()); } + startTransactionIfNeeded(); collection.bulkWrite(clientSession, writeModels); return createUpdateCounts(commandBatchSize); } catch (RuntimeException exception) { @@ -194,12 +193,12 @@ private static int[] createUpdateCounts(int updateCountsSize) { } int executeUpdate(BsonDocument command) throws SQLException { - var commandDescription = getCommandDescription(command); - var collection = getCollection(commandDescription, command); try { - startTransactionIfNeeded(); + var commandDescription = getCommandDescription(command); + var collection = getCollection(commandDescription, command); var writeModels = new ArrayList>(); WriteModelConverter.convertToWriteModels(commandDescription, command, writeModels); + startTransactionIfNeeded(); var bulkWriteResult = collection.bulkWrite(clientSession, writeModels); return getUpdateCount(commandDescription, bulkWriteResult); } catch (RuntimeException exception) { From 7a17cd5d1f7ff7563aeec78fa24f4cccf481a099 Mon Sep 17 00:00:00 2001 From: "slav.babanin" Date: Thu, 30 Oct 2025 12:34:39 -0700 Subject: [PATCH 29/37] Remove SQLConstraintViolationException. --- ...ongoPreparedStatementIntegrationTests.java | 81 ++++++--- .../jdbc/MongoPreparedStatement.java | 11 +- .../hibernate/jdbc/MongoStatement.java | 157 +++++++++--------- .../jdbc/MongoPreparedStatementTests.java | 79 +-------- 4 files changed, 146 insertions(+), 182 deletions(-) diff --git a/src/integrationTest/java/com/mongodb/hibernate/jdbc/MongoPreparedStatementIntegrationTests.java b/src/integrationTest/java/com/mongodb/hibernate/jdbc/MongoPreparedStatementIntegrationTests.java index 20712a48..81b85d99 100644 --- a/src/integrationTest/java/com/mongodb/hibernate/jdbc/MongoPreparedStatementIntegrationTests.java +++ b/src/integrationTest/java/com/mongodb/hibernate/jdbc/MongoPreparedStatementIntegrationTests.java @@ -51,6 +51,7 @@ import java.util.Random; import java.util.function.Function; import java.util.stream.Stream; +import org.bson.BSONException; import org.bson.BsonDocument; import org.hibernate.Session; import org.hibernate.SessionFactory; @@ -272,11 +273,11 @@ void testNoCollectionNameProvidedExecuteQuery() { assertInvalidMql( """ { - insert: {} + aggregate: {} }""", PreparedStatement::executeQuery, """ - Invalid MQL. Collection name is missing [{"insert": {}}]"""); + Invalid MQL. Collection name is missing [{"aggregate": {}}]"""); } @Test @@ -306,6 +307,42 @@ void testNoCollectionNameProvidedExecuteBatch() { Invalid MQL. Collection name is missing [{"insert": {}}]"""); } + @Test + void testAbsentRequiredAggregateCommandField() { + doWorkAwareOfAutoCommit(connection -> { + String mql = + """ + { + aggregate: "books" + }"""; + try (PreparedStatement pstm = connection.prepareStatement(mql)) { + assertThatThrownBy(pstm::executeQuery) + .isInstanceOf(SQLSyntaxErrorException.class) + .hasMessage("Invalid MQL: [%s]".formatted(toExtendedJson(mql))) + .cause() + .isInstanceOf(BSONException.class) + .hasMessage("Document does not contain key pipeline"); + } + }); + } + + @Test + void testAbsentRequiredProjectAggregationPipelineStage() { + doWorkAwareOfAutoCommit(connection -> { + String mql = + """ + { + aggregate: "books", + "pipeline": [] + }"""; + try (PreparedStatement pstm = connection.prepareStatement(mql)) { + assertThatThrownBy(pstm::executeQuery) + .isInstanceOf(SQLSyntaxErrorException.class) + .hasMessage("Invalid MQL. $project stage is missing [%s]".formatted(toExtendedJson(mql))); + } + }); + } + @Nested class ExecuteBatchTests { private static final String INSERT_MQL = @@ -343,7 +380,7 @@ void testQueriesReturningResult() { pstm.addBatch(); assertThatExceptionOfType(BatchUpdateException.class) .isThrownBy(pstm::executeBatch) - .returns(null, BatchUpdateException::getUpdateCounts) + .returns(new int[0], BatchUpdateException::getUpdateCounts) .returns(null, BatchUpdateException::getSQLState) .returns(0, BatchUpdateException::getErrorCode); } @@ -769,7 +806,7 @@ void testNotSupportedUpdateCommandField(String unsupportedField) { .formatted(unsupportedField))) { assertThatThrownBy(pstm::executeUpdate) .isInstanceOf(SQLFeatureNotSupportedException.class) - .hasMessage("Unsupported field in update command: [%s]" + .hasMessage("Unsupported field in [update] command: [%s]" .formatted(getFieldName(unsupportedField))); } }); @@ -812,7 +849,7 @@ void testNotSupportedDeleteCommandField(String unsupportedField) { .formatted(unsupportedField))) { assertThatThrownBy(pstm::executeUpdate) .isInstanceOf(SQLFeatureNotSupportedException.class) - .hasMessage("Unsupported field in delete command: [%s]" + .hasMessage("Unsupported field in [delete] command: [%s]" .formatted(getFieldName(unsupportedField))); } }); @@ -862,7 +899,7 @@ void testNotSupportedInsertCommandField(String unsupportedField) { .formatted(unsupportedField))) { assertThatThrownBy(pstm::executeUpdate) .isInstanceOf(SQLFeatureNotSupportedException.class) - .hasMessage("Unsupported field in insert command: [%s]" + .hasMessage("Unsupported field in [insert] command: [%s]" .formatted(getFieldName(unsupportedField))); } }); @@ -888,14 +925,14 @@ void testAbsentRequiredInsertCommandField() { private static Stream unsupportedUpdateStatementFields() { return Stream.of( - of("hint: {}", "Unsupported field in update statement: [hint]"), - of("hint: \"a\"", "Unsupported field in update statement: [hint]"), - of("collation: {}", "Unsupported field in update statement: [collation]"), - of("arrayFilters: []", "Unsupported field in update statement: [arrayFilters]"), - of("sort: {}", "Unsupported field in update statement: [sort]"), - of("upsert: true", "Unsupported field in update statement: [upsert]"), + of("hint: {}", "Unsupported field in [update] statement: [hint]"), + of("hint: \"a\"", "Unsupported field in [update] statement: [hint]"), + of("collation: {}", "Unsupported field in [update] statement: [collation]"), + of("arrayFilters: []", "Unsupported field in [update] statement: [arrayFilters]"), + of("sort: {}", "Unsupported field in [update] statement: [sort]"), + of("upsert: true", "Unsupported field in [update] statement: [upsert]"), of("u: []", "Only document type is supported as value for field: [u]"), - of("c: {}", "Unsupported field in update statement: [c]")); + of("c: {}", "Unsupported field in [update] statement: [c]")); } @ParameterizedTest(name = "test not supported update statement field. Parameters: option={0}") @@ -968,7 +1005,7 @@ void testNotSupportedDeleteStatementField(String unsupportedField) { .formatted(unsupportedField))) { assertThatThrownBy(pstm::executeUpdate) .isInstanceOf(SQLFeatureNotSupportedException.class) - .hasMessage("Unsupported field in delete statement: [%s]" + .hasMessage("Unsupported field in [delete] statement: [%s]" .formatted(getFieldName(unsupportedField))); } }); @@ -1012,14 +1049,6 @@ private void assertExecuteUpdate( assertThat(mongoCollection.find().sort(Sorts.ascending(ID_FIELD_NAME))) .containsExactlyElementsOf(expectedDocuments); } - - private static String getFieldName(String unsupportedField) { - return BsonDocument.parse("{" + unsupportedField + "}").getFirstKey(); - } - - private String toExtendedJson(String mql) { - return BsonDocument.parse(mql).toJson(EXTENDED_JSON_WRITER_SETTINGS); - } } private void assertInvalidMql( @@ -1041,6 +1070,14 @@ void doAwareOfAutoCommit(Connection connection, SqlExecutable work) throws SQLEx doWithSpecifiedAutoCommit(false, connection, () -> doAndTerminateTransaction(connection, work)); } + private static String getFieldName(String unsupportedField) { + return BsonDocument.parse("{" + unsupportedField + "}").getFirstKey(); + } + + private String toExtendedJson(String mql) { + return BsonDocument.parse(mql).toJson(EXTENDED_JSON_WRITER_SETTINGS); + } + interface SqlConsumer { void accept(T t) throws SQLException; } diff --git a/src/main/java/com/mongodb/hibernate/jdbc/MongoPreparedStatement.java b/src/main/java/com/mongodb/hibernate/jdbc/MongoPreparedStatement.java index a238c139..5bdb5794 100644 --- a/src/main/java/com/mongodb/hibernate/jdbc/MongoPreparedStatement.java +++ b/src/main/java/com/mongodb/hibernate/jdbc/MongoPreparedStatement.java @@ -47,7 +47,6 @@ import org.bson.BsonType; import org.bson.BsonValue; import org.bson.types.ObjectId; -import org.jspecify.annotations.Nullable; final class MongoPreparedStatement extends MongoStatement implements PreparedStatementAdapter { @@ -55,8 +54,6 @@ final class MongoPreparedStatement extends MongoStatement implements PreparedSta private final List commandBatch; private final List parameterValueSetters; - private static final int @Nullable [] NULL_UPDATE_COUNTS = null; - MongoPreparedStatement( MongoDatabase mongoDatabase, ClientSession clientSession, MongoConnection mongoConnection, String mql) throws SQLSyntaxErrorException { @@ -204,7 +201,7 @@ public int[] executeBatch() throws SQLException { try { closeLastOpenResultSet(); if (commandBatch.isEmpty()) { - return EMPTY_BATCH_RESULT; + return EMPTY_UPDATE_COUNTS; } checkSupportedBatchCommand(commandBatch.get(0)); return executeBatch(commandBatch); @@ -217,17 +214,17 @@ public int[] executeBatch() throws SQLException { private static void checkSupportedBatchCommand(BsonDocument command) throws SQLFeatureNotSupportedException, BatchUpdateException, SQLSyntaxErrorException { var commandDescription = getCommandDescription(command); - if (commandDescription.returnsResultSet()) { + if (commandDescription.isQuery()) { throw new BatchUpdateException( "Commands returning result set are not allowed. Received command: %s" .formatted(commandDescription.getCommandName()), NULL_SQL_STATE, NO_ERROR_CODE, - NULL_UPDATE_COUNTS); + EMPTY_UPDATE_COUNTS); } if (!commandDescription.isUpdate()) { throw new SQLFeatureNotSupportedException( - "Unsupported command for batch operation: %s".formatted(commandDescription.getCommandName())); + "Unsupported command for executeBatch: %s".formatted(commandDescription.getCommandName())); } } diff --git a/src/main/java/com/mongodb/hibernate/jdbc/MongoStatement.java b/src/main/java/com/mongodb/hibernate/jdbc/MongoStatement.java index 9116bfa2..1054b63c 100644 --- a/src/main/java/com/mongodb/hibernate/jdbc/MongoStatement.java +++ b/src/main/java/com/mongodb/hibernate/jdbc/MongoStatement.java @@ -51,7 +51,6 @@ import java.sql.ResultSet; import java.sql.SQLException; import java.sql.SQLFeatureNotSupportedException; -import java.sql.SQLIntegrityConstraintViolationException; import java.sql.SQLSyntaxErrorException; import java.sql.SQLWarning; import java.sql.Statement; @@ -72,11 +71,13 @@ class MongoStatement implements StatementAdapter { private static final String EXCEPTION_MESSAGE_OPERATION_FAILED = "Failed to execute operation"; - private static final String EXCEPTION_MESSAGE_TIMEOUT = "Timeout while waiting for operation to complete"; + private static final String EXCEPTION_MESSAGE_OPERATION_TIMED_OUT = + "Timeout while waiting for operation to complete"; static final int NO_ERROR_CODE = 0; - static final int[] EMPTY_BATCH_RESULT = new int[0]; + static final int[] EMPTY_UPDATE_COUNTS = new int[0]; - static final String @Nullable NULL_SQL_STATE = null; + static final @Nullable String NULL_SQL_STATE = null; + private static final String EXCEPTION_MESSAGE_PREFIX_INVALID_MQL = "Invalid MQL"; private final MongoDatabase mongoDatabase; private final MongoConnection mongoConnection; @@ -114,13 +115,19 @@ ResultSet executeQuery(BsonDocument command) throws SQLException { var pipeline = command.getArray("pipeline").stream() .map(BsonValue::asDocument) .toList(); + var projectStageIndex = pipeline.size() - 1; + if (pipeline.isEmpty()) { + throw createSyntaxErrorException("%s. $project stage is missing [%s]", command); + } var fieldNames = getFieldNamesFromProjectStage( - pipeline.get(pipeline.size() - 1).getDocument("$project")); + pipeline.get(projectStageIndex).getDocument("$project")); startTransactionIfNeeded(); return resultSet = new MongoResultSet( collection.aggregate(clientSession, pipeline).cursor(), fieldNames); + } catch (BSONException bsonException) { + throw createSyntaxErrorException("%s: [%s]", command, bsonException); } catch (RuntimeException exception) { - throw handleQueryOrUpdateException(exception); + throw handleExecuteQueryOrUpdateException(exception); } } @@ -164,7 +171,7 @@ public int executeUpdate(String mql) throws SQLException { return executeUpdate(command); } - int[] executeBatch(List commandBatch) throws SQLException { + int[] executeBatch(List commandBatch) throws SQLException { WriteModelsToCommandMapper writeModelsToCommandMapper = null; try { var firstCommandInBatch = commandBatch.get(0); @@ -181,7 +188,7 @@ int[] executeBatch(List commandBatch) throws SQLExceptio collection.bulkWrite(clientSession, writeModels); return createUpdateCounts(commandBatchSize); } catch (RuntimeException exception) { - throw handleBatchException(exception, writeModelsToCommandMapper); + throw handleExecuteBatchException(exception, writeModelsToCommandMapper); } } @@ -202,7 +209,7 @@ int executeUpdate(BsonDocument command) throws SQLException { var bulkWriteResult = collection.bulkWrite(clientSession, writeModels); return getUpdateCount(commandDescription, bulkWriteResult); } catch (RuntimeException exception) { - throw handleQueryOrUpdateException(exception); + throw handleExecuteQueryOrUpdateException(exception); } } @@ -289,21 +296,21 @@ void checkClosed() throws SQLException { } } - private static void checkSupportedQueryCommand(BsonDocument command) + static void checkSupportedQueryCommand(BsonDocument command) throws SQLFeatureNotSupportedException, SQLSyntaxErrorException { var commandDescription = getCommandDescription(command); if (commandDescription.isUpdate()) { throw new SQLFeatureNotSupportedException( - "Unsupported command for query operation: %s".formatted(commandDescription.getCommandName())); + "Unsupported command for executeQuery: %s".formatted(commandDescription.getCommandName())); } } static void checkSupportedUpdateCommand(BsonDocument command) throws SQLFeatureNotSupportedException, SQLSyntaxErrorException { CommandDescription commandDescription = getCommandDescription(command); - if (!commandDescription.isUpdate()) { + if (commandDescription.isQuery()) { throw new SQLFeatureNotSupportedException( - "Unsupported command for update operation: %s".formatted(commandDescription.getCommandName())); + "Unsupported command for executeUpdate: %s".formatted(commandDescription.getCommandName())); } } @@ -311,7 +318,7 @@ static BsonDocument parse(String mql) throws SQLSyntaxErrorException { try { return BsonDocument.parse(mql); } catch (RuntimeException e) { - throw new SQLSyntaxErrorException("Invalid MQL: [%s]".formatted(mql), e); + throw new SQLSyntaxErrorException("%s: [%s]".formatted(EXCEPTION_MESSAGE_PREFIX_INVALID_MQL, mql), e); } } @@ -332,12 +339,9 @@ static CommandDescription getCommandDescription(BsonDocument command) try { commandName = command.getFirstKey(); } catch (NoSuchElementException exception) { - throw new SQLSyntaxErrorException( - "Invalid MQL. Command name is missing: [%s]" - .formatted(command.toJson(EXTENDED_JSON_WRITER_SETTINGS)), - exception); + throw createSyntaxErrorException("%s. Command name is missing: [%s]", command, exception); } - return CommandDescription.fromString(commandName); + return CommandDescription.of(commandName); } private MongoCollection getCollection(CommandDescription commandDescription, BsonDocument command) @@ -347,10 +351,7 @@ private MongoCollection getCollection(CommandDescription commandDe try { collectionName = command.getString(commandName); } catch (BsonInvalidOperationException exception) { - throw new SQLSyntaxErrorException( - "Invalid MQL. Collection name is missing [%s]" - .formatted(command.toJson(EXTENDED_JSON_WRITER_SETTINGS)), - exception); + throw createSyntaxErrorException("%s. Collection name is missing [%s]", command, exception); } return mongoDatabase.getCollection(collectionName.getValue(), BsonDocument.class); } @@ -364,46 +365,37 @@ private static int getUpdateCount(CommandDescription commandDescription, BulkWri }; } - private static SQLException handleBatchException( + private static SQLException handleExecuteBatchException( RuntimeException exceptionToHandle, @Nullable WriteModelsToCommandMapper writeModelsToCommandMapper) { var errorCode = getErrorCode(exceptionToHandle); - if (exceptionToHandle instanceof MongoException mongoException) { - var cause = handleMongoException(errorCode, mongoException); - if (mongoException instanceof MongoBulkWriteException bulkWriteException) { - return createBatchUpdateException( - errorCode, cause, bulkWriteException, assertNotNull(writeModelsToCommandMapper)); - } - // TODO-HIBERNATE-132 BatchUpdateException is thrown when one of the commands fails to execute properly. - // When exception is not of MongoBulkWriteException, we are not sure if any command was executed - // successfully or failed. - return cause; + String exceptionMessage = getExceptionMessage(errorCode, exceptionToHandle); + if (exceptionToHandle instanceof MongoBulkWriteException bulkWriteException) { + return createBatchUpdateException( + exceptionMessage, errorCode, bulkWriteException, assertNotNull(writeModelsToCommandMapper)); } - return new SQLException(EXCEPTION_MESSAGE_OPERATION_FAILED, NULL_SQL_STATE, errorCode, exceptionToHandle); + + // TODO-HIBERNATE-132 BatchUpdateException is thrown when one of the commands fails to execute properly. + // When exception is not of MongoBulkWriteException, we are not sure if any command was executed + // successfully or failed. + return new SQLException(exceptionMessage, NULL_SQL_STATE, errorCode, exceptionToHandle); } - private static SQLException handleQueryOrUpdateException(RuntimeException exceptionToHandle) { + private static SQLException handleExecuteQueryOrUpdateException(RuntimeException exceptionToHandle) { var errorCode = getErrorCode(exceptionToHandle); - if (exceptionToHandle instanceof MongoException mongoException) { - return handleMongoException(errorCode, mongoException); - } - return new SQLException(EXCEPTION_MESSAGE_OPERATION_FAILED, NULL_SQL_STATE, errorCode, exceptionToHandle); + String exceptionMessage = getExceptionMessage(errorCode, exceptionToHandle); + return new SQLException(exceptionMessage, NULL_SQL_STATE, errorCode, exceptionToHandle); } - private static SQLException handleMongoException(int errorCode, MongoException exceptionToHandle) { - if (isTimeoutException(exceptionToHandle)) { - return new SQLException(EXCEPTION_MESSAGE_TIMEOUT, NULL_SQL_STATE, errorCode, exceptionToHandle); + private static String getExceptionMessage(int errorCode, RuntimeException exceptionToHandle) { + if (exceptionToHandle instanceof MongoException mongoException && isTimeoutException(mongoException)) { + return EXCEPTION_MESSAGE_OPERATION_TIMED_OUT; } var errorCategory = ErrorCategory.fromErrorCode(errorCode); return switch (errorCategory) { - case DUPLICATE_KEY -> - new SQLIntegrityConstraintViolationException( - EXCEPTION_MESSAGE_OPERATION_FAILED, NULL_SQL_STATE, errorCode, exceptionToHandle); + case DUPLICATE_KEY, UNCATEGORIZED -> EXCEPTION_MESSAGE_OPERATION_FAILED; // TODO-HIBERNATE-132 EXECUTION_TIMEOUT code is returned from the server. Do we know how many commands were // executed successfully so we can return it as BatchUpdateException? - case EXECUTION_TIMEOUT -> - new SQLException(EXCEPTION_MESSAGE_TIMEOUT, NULL_SQL_STATE, errorCode, exceptionToHandle); - case UNCATEGORIZED -> - new SQLException(EXCEPTION_MESSAGE_OPERATION_FAILED, NULL_SQL_STATE, errorCode, exceptionToHandle); + case EXECUTION_TIMEOUT -> EXCEPTION_MESSAGE_OPERATION_TIMED_OUT; }; } @@ -426,15 +418,28 @@ private static int getErrorCode(RuntimeException runtimeException) { return NO_ERROR_CODE; } + private static SQLSyntaxErrorException createSyntaxErrorException( + String exceptionMessageTemplate, BsonDocument command, @Nullable Exception cause) { + return new SQLSyntaxErrorException( + exceptionMessageTemplate.formatted( + EXCEPTION_MESSAGE_PREFIX_INVALID_MQL, command.toJson(EXTENDED_JSON_WRITER_SETTINGS)), + command.toJson(), + cause); + } + + private static SQLSyntaxErrorException createSyntaxErrorException( + String exceptionMessageTemplate, BsonDocument command) { + return createSyntaxErrorException(exceptionMessageTemplate, command, null); + } + private static BatchUpdateException createBatchUpdateException( + String exceptionMessage, int errorCode, - SQLException sqlCause, MongoBulkWriteException mongoBulkWriteException, WriteModelsToCommandMapper writeModelsToCommandMapper) { var updateCounts = calculateBatchUpdateCounts(mongoBulkWriteException, writeModelsToCommandMapper); var batchUpdateException = new BatchUpdateException( - EXCEPTION_MESSAGE_OPERATION_FAILED, NULL_SQL_STATE, errorCode, updateCounts, sqlCause); - batchUpdateException.setNextException(sqlCause); + exceptionMessage, NULL_SQL_STATE, errorCode, updateCounts, mongoBulkWriteException); return batchUpdateException; } @@ -443,13 +448,12 @@ private static int[] calculateBatchUpdateCounts( var writeErrors = mongoBulkWriteException.getWriteErrors(); var writeConcernError = mongoBulkWriteException.getWriteConcernError(); if (writeConcernError == null) { - if (!writeErrors.isEmpty()) { - var failedModelIndex = writeErrors.get(0).getIndex(); - var commandIndex = writeModelsToCommandMapper.findCommandIndex(failedModelIndex); - return createUpdateCounts(commandIndex); - } + assertTrue(writeErrors.size() == 1); + var failedModelIndex = writeErrors.get(0).getIndex(); + var failedCommandIndexInBatch = writeModelsToCommandMapper.findCommandIndex(failedModelIndex); + return createUpdateCounts(failedCommandIndexInBatch); } - return EMPTY_BATCH_RESULT; + return EMPTY_UPDATE_COUNTS; } private static boolean isTimeoutException(MongoException exception) { @@ -471,12 +475,12 @@ enum CommandDescription { AGGREGATE("aggregate", true, false); private final String commandName; - private final boolean returnsResultSet; + private final boolean isQuery; private final boolean isUpdate; - CommandDescription(String commandName, boolean returnsResultSet, boolean isUpdate) { + CommandDescription(String commandName, boolean isQuery, boolean isUpdate) { this.commandName = commandName; - this.returnsResultSet = returnsResultSet; + this.isQuery = isQuery; this.isUpdate = isUpdate; } @@ -487,22 +491,22 @@ String getCommandName() { /** * Indicates whether the command may be used in {@code executeUpdate(...)} or {@code executeBatch()} methods. * - * @return true if the command may be used in update operations, false if it is used in query operations. + * @return true if the command may be used in update operations. */ boolean isUpdate() { return isUpdate; } /** - * Indicates whether the command returns a {@link ResultSet}. + * Indicates whether the command may be used in {@code executeQuery(...)} methods. * - * @see #executeQuery(String) + * @return true if the command may be used in query operations. */ - boolean returnsResultSet() { - return returnsResultSet; + boolean isQuery() { + return isQuery; } - static CommandDescription fromString(String commandName) throws SQLFeatureNotSupportedException { + static CommandDescription of(String commandName) throws SQLFeatureNotSupportedException { return switch (commandName) { case "insert" -> INSERT; case "update" -> UPDATE; @@ -561,10 +565,7 @@ private static void convertToWriteModels( throw fail(commandDescription.toString()); } } catch (BSONException bsonException) { - throw new SQLSyntaxErrorException( - "Invalid MQL: [%s]".formatted(command.toJson(EXTENDED_JSON_WRITER_SETTINGS)), - NULL_SQL_STATE, - bsonException); + throw createSyntaxErrorException("%s: [%s]", command, bsonException); } } @@ -610,7 +611,7 @@ private static void checkStatementFields( throws SQLFeatureNotSupportedException { checkFields( commandDescription, - UNSUPPORTED_MESSAGE_STATEMENT_FIELD, + UNSUPPORTED_MESSAGE_TEMPLATE_STATEMENT_FIELD, supportedStatementFields, statement.keySet().iterator()); } @@ -620,12 +621,13 @@ private static void checkCommandFields( throws SQLFeatureNotSupportedException { var iterator = command.keySet().iterator(); iterator.next(); // skip the command name - checkFields(commandDescription, UNSUPPORTED_MESSAGE_COMMAND_FIELD, supportedCommandFields, iterator); + checkFields( + commandDescription, UNSUPPORTED_MESSAGE_TEMPLATE_COMMAND_FIELD, supportedCommandFields, iterator); } private static void checkFields( CommandDescription commandDescription, - String exceptionMessage, + String exceptionMessageTemplate, Set supportedCommandFields, Iterator fieldNameIterator) throws SQLFeatureNotSupportedException { @@ -633,7 +635,7 @@ private static void checkFields( var field = fieldNameIterator.next(); if (!supportedCommandFields.contains(field)) { throw new SQLFeatureNotSupportedException( - exceptionMessage.formatted(commandDescription.getCommandName(), field)); + exceptionMessageTemplate.formatted(commandDescription.getCommandName(), field)); } } } @@ -658,7 +660,8 @@ private void add(int cumulativeWriteModelCount) { private int findCommandIndex(int writeModelIndex) { assertTrue(cumulativeCountIndex == cumulativeCounts.length); - int lo = 0, hi = cumulativeCounts.length; + var lo = 0; + var hi = cumulativeCounts.length; while (lo < hi) { var mid = (lo + hi) >>> 1; if (cumulativeCounts[mid] >= writeModelIndex + 1) { diff --git a/src/test/java/com/mongodb/hibernate/jdbc/MongoPreparedStatementTests.java b/src/test/java/com/mongodb/hibernate/jdbc/MongoPreparedStatementTests.java index e8004661..9ca43125 100644 --- a/src/test/java/com/mongodb/hibernate/jdbc/MongoPreparedStatementTests.java +++ b/src/test/java/com/mongodb/hibernate/jdbc/MongoPreparedStatementTests.java @@ -22,7 +22,6 @@ import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatExceptionOfType; import static org.assertj.core.api.Assertions.assertThatObject; -import static org.assertj.core.api.InstanceOfAssertFactories.type; import static org.junit.jupiter.api.Assertions.assertAll; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertFalse; @@ -63,7 +62,6 @@ import java.sql.BatchUpdateException; import java.sql.ResultSet; import java.sql.SQLException; -import java.sql.SQLIntegrityConstraintViolationException; import java.sql.SQLSyntaxErrorException; import java.sql.Types; import java.util.Calendar; @@ -273,18 +271,11 @@ private static Stream timeoutExceptions() { ); } - private static Stream constraintViolationErrorCodes() { - return Stream.of(11000, 11001, 12582); - } - - private static Stream constraintViolationExceptions() { - return constraintViolationErrorCodes() - .map(errorCode -> new MongoException(errorCode, DUMMY_EXCEPTION_MESSAGE)); - } - private static Stream genericMongoExceptions() { return Stream.of( - new MongoException(-3, DUMMY_EXCEPTION_MESSAGE), new MongoException(5000, DUMMY_EXCEPTION_MESSAGE)); + new MongoException(-3, DUMMY_EXCEPTION_MESSAGE), + new MongoException(11000, DUMMY_EXCEPTION_MESSAGE), + new MongoException(5000, DUMMY_EXCEPTION_MESSAGE)); } @ParameterizedTest(name = "test executeBatch MongoException. Parameters: Parameters: mongoException: {0}") @@ -316,28 +307,6 @@ void testExecuteQueryMongoException(MongoException mongoException) throws SQLExc sqlException -> assertGenericMongoException(sqlException, mongoException)); } - @ParameterizedTest(name = "test executeUpdate constraint violation. Parameters: mongoException: {0}") - @MethodSource("constraintViolationExceptions") - void testExecuteUpdateConstraintViolationException(MongoException mongoException) throws SQLException { - int expectedErrorCode = mongoException.getCode(); - doThrow(mongoException).when(mongoCollection).bulkWrite(eq(clientSession), anyList()); - - assertExecuteUpdateThrowsSqlException(sqlException -> { - assertConstraintViolationException(mongoException, sqlException, expectedErrorCode); - }); - } - - @ParameterizedTest(name = "test executeQuery constraint violation. Parameters: mongoException: {0}") - @MethodSource("constraintViolationExceptions") - void testExecuteQueryConstraintViolationException(MongoException mongoException) throws SQLException { - int expectedErrorCode = mongoException.getCode(); - doThrow(mongoException).when(mongoCollection).aggregate(eq(clientSession), anyList()); - - assertExecuteQueryThrowsSqlException(sqlException -> { - assertConstraintViolationException(mongoException, sqlException, expectedErrorCode); - }); - } - @ParameterizedTest(name = "test executeBatch timeout exception. Parameters: mongoTimeoutException: {0}") @MethodSource("timeoutExceptions") void testExecuteBatchTimeoutException(MongoException mongoTimeoutException) throws SQLException { @@ -347,27 +316,6 @@ void testExecuteBatchTimeoutException(MongoException mongoTimeoutException) thro }); } - @ParameterizedTest(name = "test executeBatch constraint violation. Parameters: mongoException: {0}") - @MethodSource("constraintViolationErrorCodes") - void testExecuteBatchConstraintViolationException(int errorCode) throws SQLException { - MongoBulkWriteException mongoBulkWriteException = createMongoBulkWriteException(errorCode, 0); - - doThrow(mongoBulkWriteException).when(mongoCollection).bulkWrite(eq(clientSession), anyList()); - - assertExecuteBatchThrowsBatchUpdateException(batchUpdateException -> { - assertThatObject(batchUpdateException) - .returns(errorCode, BatchUpdateException::getErrorCode) - .returns(null, BatchUpdateException::getSQLState) - .satisfies(ex -> { - assertUpdateCounts(ex.getUpdateCounts(), 0); - }) - .extracting(SQLException::getCause) - .asInstanceOf(type(SQLIntegrityConstraintViolationException.class)) - .returns(errorCode, SQLIntegrityConstraintViolationException::getErrorCode) - .returns(mongoBulkWriteException, SQLIntegrityConstraintViolationException::getCause); - }); - } - @Test void testExecuteBatchRuntimeExceptionCause() throws SQLException { RuntimeException runtimeException = new RuntimeException(); @@ -468,8 +416,6 @@ void testExecuteBatchMongoBulkWriteException( assertUpdateCounts(ex.getUpdateCounts(), expectedUpdateCountLength); }) .havingCause() - .isInstanceOf(SQLException.class) - .havingCause() .isSameAs(mongoBulkWriteException); } } @@ -490,25 +436,6 @@ private static void assertGenericMongoException(SQLException sqlException, Mongo .returns(cause, SQLException::getCause); } - private static void assertConstraintViolationException( - MongoException mongoException, SQLException sqlException, int expectedErrorCode) { - assertThatObject(sqlException) - .asInstanceOf(type(SQLIntegrityConstraintViolationException.class)) - .returns(expectedErrorCode, SQLIntegrityConstraintViolationException::getErrorCode) - .returns(null, SQLIntegrityConstraintViolationException::getSQLState) - .returns(mongoException, SQLIntegrityConstraintViolationException::getCause); - } - - private void assertExecuteBatchThrowsBatchUpdateException(ThrowingConsumer asserter) - throws SQLException { - try (MongoPreparedStatement mongoPreparedStatement = createMongoPreparedStatement(MQL_ITEMS_INSERT)) { - mongoPreparedStatement.addBatch(); - assertThatExceptionOfType(BatchUpdateException.class) - .isThrownBy(mongoPreparedStatement::executeBatch) - .satisfies(asserter); - } - } - private void assertExecuteBatchThrowsSqlException(ThrowingConsumer asserter) throws SQLException { try (MongoPreparedStatement mongoPreparedStatement = createMongoPreparedStatement(MQL_ITEMS_INSERT)) { mongoPreparedStatement.addBatch(); From 5cfa2ff1977e9c772aebb42704d76eb90225d2f0 Mon Sep 17 00:00:00 2001 From: "slav.babanin" Date: Thu, 30 Oct 2025 12:38:27 -0700 Subject: [PATCH 30/37] Make visibility of WriteModelConverter and WriteModelsToCommandMapper methods package-private. --- .../java/com/mongodb/hibernate/jdbc/MongoStatement.java | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/main/java/com/mongodb/hibernate/jdbc/MongoStatement.java b/src/main/java/com/mongodb/hibernate/jdbc/MongoStatement.java index 1054b63c..82c4776f 100644 --- a/src/main/java/com/mongodb/hibernate/jdbc/MongoStatement.java +++ b/src/main/java/com/mongodb/hibernate/jdbc/MongoStatement.java @@ -533,7 +533,7 @@ private static class WriteModelConverter { private WriteModelConverter() {} - private static void convertToWriteModels( + static void convertToWriteModels( CommandDescription commandDescription, BsonDocument command, Collection> writeModels) @@ -648,17 +648,17 @@ private static class WriteModelsToCommandMapper { private int cumulativeCountIndex; - private WriteModelsToCommandMapper(int commandCount) { + WriteModelsToCommandMapper(int commandCount) { this.cumulativeCounts = new int[commandCount]; this.cumulativeCountIndex = 0; } - private void add(int cumulativeWriteModelCount) { + void add(int cumulativeWriteModelCount) { assertFalse(cumulativeCountIndex >= cumulativeCounts.length); cumulativeCounts[cumulativeCountIndex++] = cumulativeWriteModelCount; } - private int findCommandIndex(int writeModelIndex) { + int findCommandIndex(int writeModelIndex) { assertTrue(cumulativeCountIndex == cumulativeCounts.length); var lo = 0; var hi = cumulativeCounts.length; From c7f884faff1b8701b796b5e4a827712eb684f891 Mon Sep 17 00:00:00 2001 From: "slav.babanin" Date: Thu, 30 Oct 2025 12:39:35 -0700 Subject: [PATCH 31/37] use fully qualified class name in TODO for clarity. --- src/main/java/com/mongodb/hibernate/jdbc/MongoStatement.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/main/java/com/mongodb/hibernate/jdbc/MongoStatement.java b/src/main/java/com/mongodb/hibernate/jdbc/MongoStatement.java index 82c4776f..e924d9f5 100644 --- a/src/main/java/com/mongodb/hibernate/jdbc/MongoStatement.java +++ b/src/main/java/com/mongodb/hibernate/jdbc/MongoStatement.java @@ -374,7 +374,7 @@ private static SQLException handleExecuteBatchException( exceptionMessage, errorCode, bulkWriteException, assertNotNull(writeModelsToCommandMapper)); } - // TODO-HIBERNATE-132 BatchUpdateException is thrown when one of the commands fails to execute properly. + // TODO-HIBERNATE-132 java.sql.BatchUpdateException is thrown when one of the commands fails to execute properly. // When exception is not of MongoBulkWriteException, we are not sure if any command was executed // successfully or failed. return new SQLException(exceptionMessage, NULL_SQL_STATE, errorCode, exceptionToHandle); From 9286eb7bde341d0480bb32360e45808cdcfa879b Mon Sep 17 00:00:00 2001 From: Viacheslav Babanin Date: Thu, 30 Oct 2025 12:45:28 -0700 Subject: [PATCH 32/37] Update src/main/java/com/mongodb/hibernate/jdbc/MongoPreparedStatement.java Co-authored-by: Valentin Kovalenko --- .../java/com/mongodb/hibernate/jdbc/MongoPreparedStatement.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/main/java/com/mongodb/hibernate/jdbc/MongoPreparedStatement.java b/src/main/java/com/mongodb/hibernate/jdbc/MongoPreparedStatement.java index 5bdb5794..58776028 100644 --- a/src/main/java/com/mongodb/hibernate/jdbc/MongoPreparedStatement.java +++ b/src/main/java/com/mongodb/hibernate/jdbc/MongoPreparedStatement.java @@ -186,7 +186,7 @@ public void setObject(int parameterIndex, Object x, int targetSqlType) throws SQ public void addBatch() throws SQLException { checkClosed(); checkAllParametersSet(); - commandBatch.add(command.clone()); + commandBatch.add(parameterValueSetters.isEmpty() ? command : command.clone()); } @Override From 515a8ac94b15981a5bf3bcb13bb5124351010691 Mon Sep 17 00:00:00 2001 From: "slav.babanin" Date: Thu, 30 Oct 2025 13:46:59 -0700 Subject: [PATCH 33/37] Craete helper method for syntax error exceptions. --- .../hibernate/jdbc/MongoStatement.java | 26 +++++++++---------- 1 file changed, 12 insertions(+), 14 deletions(-) diff --git a/src/main/java/com/mongodb/hibernate/jdbc/MongoStatement.java b/src/main/java/com/mongodb/hibernate/jdbc/MongoStatement.java index e924d9f5..9e3ffff2 100644 --- a/src/main/java/com/mongodb/hibernate/jdbc/MongoStatement.java +++ b/src/main/java/com/mongodb/hibernate/jdbc/MongoStatement.java @@ -70,6 +70,7 @@ import org.jspecify.annotations.Nullable; class MongoStatement implements StatementAdapter { + private static final String EXCEPTION_MESSAGE_PREFIX_INVALID_MQL = "Invalid MQL"; private static final String EXCEPTION_MESSAGE_OPERATION_FAILED = "Failed to execute operation"; private static final String EXCEPTION_MESSAGE_OPERATION_TIMED_OUT = "Timeout while waiting for operation to complete"; @@ -77,7 +78,6 @@ class MongoStatement implements StatementAdapter { static final int[] EMPTY_UPDATE_COUNTS = new int[0]; static final @Nullable String NULL_SQL_STATE = null; - private static final String EXCEPTION_MESSAGE_PREFIX_INVALID_MQL = "Invalid MQL"; private final MongoDatabase mongoDatabase; private final MongoConnection mongoConnection; @@ -117,7 +117,7 @@ ResultSet executeQuery(BsonDocument command) throws SQLException { .toList(); var projectStageIndex = pipeline.size() - 1; if (pipeline.isEmpty()) { - throw createSyntaxErrorException("%s. $project stage is missing [%s]", command); + throw createSyntaxErrorException("%s. $project stage is missing [%s]", command, null); } var fieldNames = getFieldNamesFromProjectStage( pipeline.get(projectStageIndex).getDocument("$project")); @@ -317,8 +317,8 @@ static void checkSupportedUpdateCommand(BsonDocument command) static BsonDocument parse(String mql) throws SQLSyntaxErrorException { try { return BsonDocument.parse(mql); - } catch (RuntimeException e) { - throw new SQLSyntaxErrorException("%s: [%s]".formatted(EXCEPTION_MESSAGE_PREFIX_INVALID_MQL, mql), e); + } catch (RuntimeException exception) { + throw createSyntaxErrorException("%s: [%s]", mql, exception); } } @@ -374,9 +374,9 @@ private static SQLException handleExecuteBatchException( exceptionMessage, errorCode, bulkWriteException, assertNotNull(writeModelsToCommandMapper)); } - // TODO-HIBERNATE-132 java.sql.BatchUpdateException is thrown when one of the commands fails to execute properly. - // When exception is not of MongoBulkWriteException, we are not sure if any command was executed - // successfully or failed. + // TODO-HIBERNATE-132 java.sql.BatchUpdateException is thrown when one of the + // commands fails to execute properly. When exception is not of MongoBulkWriteException, + // we are not sure if any command was executed successfully or failed. return new SQLException(exceptionMessage, NULL_SQL_STATE, errorCode, exceptionToHandle); } @@ -420,16 +420,14 @@ private static int getErrorCode(RuntimeException runtimeException) { private static SQLSyntaxErrorException createSyntaxErrorException( String exceptionMessageTemplate, BsonDocument command, @Nullable Exception cause) { - return new SQLSyntaxErrorException( - exceptionMessageTemplate.formatted( - EXCEPTION_MESSAGE_PREFIX_INVALID_MQL, command.toJson(EXTENDED_JSON_WRITER_SETTINGS)), - command.toJson(), - cause); + var mql = command.toJson(EXTENDED_JSON_WRITER_SETTINGS); + return createSyntaxErrorException(exceptionMessageTemplate, mql, cause); } private static SQLSyntaxErrorException createSyntaxErrorException( - String exceptionMessageTemplate, BsonDocument command) { - return createSyntaxErrorException(exceptionMessageTemplate, command, null); + String exceptionMessageTemplate, String mql, @Nullable Exception cause) { + return new SQLSyntaxErrorException( + exceptionMessageTemplate.formatted(EXCEPTION_MESSAGE_PREFIX_INVALID_MQL, mql), NULL_SQL_STATE, cause); } private static BatchUpdateException createBatchUpdateException( From 355edfeb90fad7cf70b702862020b6877d42e4ac Mon Sep 17 00:00:00 2001 From: Viacheslav Babanin Date: Thu, 30 Oct 2025 16:04:01 -0700 Subject: [PATCH 34/37] Update src/main/java/com/mongodb/hibernate/jdbc/MongoStatement.java Co-authored-by: Valentin Kovalenko --- src/main/java/com/mongodb/hibernate/jdbc/MongoStatement.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/main/java/com/mongodb/hibernate/jdbc/MongoStatement.java b/src/main/java/com/mongodb/hibernate/jdbc/MongoStatement.java index 9e3ffff2..dbfc790e 100644 --- a/src/main/java/com/mongodb/hibernate/jdbc/MongoStatement.java +++ b/src/main/java/com/mongodb/hibernate/jdbc/MongoStatement.java @@ -382,7 +382,7 @@ private static SQLException handleExecuteBatchException( private static SQLException handleExecuteQueryOrUpdateException(RuntimeException exceptionToHandle) { var errorCode = getErrorCode(exceptionToHandle); - String exceptionMessage = getExceptionMessage(errorCode, exceptionToHandle); + var exceptionMessage = getExceptionMessage(errorCode, exceptionToHandle); return new SQLException(exceptionMessage, NULL_SQL_STATE, errorCode, exceptionToHandle); } From 14098834faa17083f776059eca0996aadaeb22d9 Mon Sep 17 00:00:00 2001 From: Viacheslav Babanin Date: Thu, 30 Oct 2025 16:04:11 -0700 Subject: [PATCH 35/37] Update src/main/java/com/mongodb/hibernate/jdbc/MongoStatement.java Co-authored-by: Valentin Kovalenko --- src/main/java/com/mongodb/hibernate/jdbc/MongoStatement.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/main/java/com/mongodb/hibernate/jdbc/MongoStatement.java b/src/main/java/com/mongodb/hibernate/jdbc/MongoStatement.java index dbfc790e..b6e48815 100644 --- a/src/main/java/com/mongodb/hibernate/jdbc/MongoStatement.java +++ b/src/main/java/com/mongodb/hibernate/jdbc/MongoStatement.java @@ -368,7 +368,7 @@ private static int getUpdateCount(CommandDescription commandDescription, BulkWri private static SQLException handleExecuteBatchException( RuntimeException exceptionToHandle, @Nullable WriteModelsToCommandMapper writeModelsToCommandMapper) { var errorCode = getErrorCode(exceptionToHandle); - String exceptionMessage = getExceptionMessage(errorCode, exceptionToHandle); + var exceptionMessage = getExceptionMessage(errorCode, exceptionToHandle); if (exceptionToHandle instanceof MongoBulkWriteException bulkWriteException) { return createBatchUpdateException( exceptionMessage, errorCode, bulkWriteException, assertNotNull(writeModelsToCommandMapper)); From 5e4503b3aea2740688151e8fd03147a7b1a50c8a Mon Sep 17 00:00:00 2001 From: Viacheslav Babanin Date: Thu, 30 Oct 2025 16:04:31 -0700 Subject: [PATCH 36/37] Update src/main/java/com/mongodb/hibernate/jdbc/MongoStatement.java Co-authored-by: Valentin Kovalenko --- src/main/java/com/mongodb/hibernate/jdbc/MongoStatement.java | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/main/java/com/mongodb/hibernate/jdbc/MongoStatement.java b/src/main/java/com/mongodb/hibernate/jdbc/MongoStatement.java index b6e48815..c657d6b8 100644 --- a/src/main/java/com/mongodb/hibernate/jdbc/MongoStatement.java +++ b/src/main/java/com/mongodb/hibernate/jdbc/MongoStatement.java @@ -436,9 +436,8 @@ private static BatchUpdateException createBatchUpdateException( MongoBulkWriteException mongoBulkWriteException, WriteModelsToCommandMapper writeModelsToCommandMapper) { var updateCounts = calculateBatchUpdateCounts(mongoBulkWriteException, writeModelsToCommandMapper); - var batchUpdateException = new BatchUpdateException( + return new BatchUpdateException( exceptionMessage, NULL_SQL_STATE, errorCode, updateCounts, mongoBulkWriteException); - return batchUpdateException; } private static int[] calculateBatchUpdateCounts( From 0a3ec1af3f17e0beaadafaf47277325577320f41 Mon Sep 17 00:00:00 2001 From: "slav.babanin" Date: Thu, 30 Oct 2025 16:23:19 -0700 Subject: [PATCH 37/37] Change to SQLFeatureNotSupportedException. --- .../query/AbstractQueryIntegrationTests.java | 9 +++++---- .../select/SimpleSelectQueryIntegrationTests.java | 15 ++++++++++----- .../hibernate/jdbc/MongoPreparedStatement.java | 6 +++--- .../mongodb/hibernate/jdbc/MongoStatement.java | 13 ------------- .../hibernate/jdbc/MongoStatementTests.java | 1 - 5 files changed, 18 insertions(+), 26 deletions(-) diff --git a/src/integrationTest/java/com/mongodb/hibernate/query/AbstractQueryIntegrationTests.java b/src/integrationTest/java/com/mongodb/hibernate/query/AbstractQueryIntegrationTests.java index 6f01e881..b8aa009e 100644 --- a/src/integrationTest/java/com/mongodb/hibernate/query/AbstractQueryIntegrationTests.java +++ b/src/integrationTest/java/com/mongodb/hibernate/query/AbstractQueryIntegrationTests.java @@ -30,6 +30,7 @@ import com.mongodb.hibernate.junit.MongoExtension; import java.util.Set; import java.util.function.Consumer; +import org.assertj.core.api.AbstractThrowableAssert; import org.assertj.core.api.InstanceOfAssertFactories; import org.bson.BsonDocument; import org.hibernate.engine.jdbc.dialect.spi.DialectResolutionInfo; @@ -144,14 +145,14 @@ protected void assertSelectionQuery( assertSelectionQuery(hql, resultType, null, expectedMql, resultListVerifier, expectedAffectedCollections); } - protected void assertSelectQueryFailure( + protected AbstractThrowableAssert assertSelectQueryFailure( String hql, Class resultType, Consumer> queryPostProcessor, Class expectedExceptionType, - String expectedExceptionMessage, + String expectedExceptionMessageSubstring, Object... expectedExceptionMessageParameters) { - sessionFactoryScope.inTransaction(session -> assertThatThrownBy(() -> { + return sessionFactoryScope.fromTransaction(session -> assertThatThrownBy(() -> { var selectionQuery = session.createSelectionQuery(hql, resultType); if (queryPostProcessor != null) { queryPostProcessor.accept(selectionQuery); @@ -159,7 +160,7 @@ protected void assertSelectQueryFailure( selectionQuery.getResultList(); }) .isInstanceOf(expectedExceptionType) - .hasMessage(expectedExceptionMessage, expectedExceptionMessageParameters)); + .hasMessageContaining(expectedExceptionMessageSubstring, expectedExceptionMessageParameters)); } protected void assertSelectQueryFailure( diff --git a/src/integrationTest/java/com/mongodb/hibernate/query/select/SimpleSelectQueryIntegrationTests.java b/src/integrationTest/java/com/mongodb/hibernate/query/select/SimpleSelectQueryIntegrationTests.java index 52c91fcb..78c78813 100644 --- a/src/integrationTest/java/com/mongodb/hibernate/query/select/SimpleSelectQueryIntegrationTests.java +++ b/src/integrationTest/java/com/mongodb/hibernate/query/select/SimpleSelectQueryIntegrationTests.java @@ -29,9 +29,11 @@ import jakarta.persistence.Id; import jakarta.persistence.Table; import java.math.BigDecimal; +import java.sql.SQLFeatureNotSupportedException; import java.util.Arrays; import java.util.List; import java.util.Set; +import org.hibernate.JDBCException; import org.hibernate.query.SemanticException; import org.hibernate.testing.orm.junit.DomainModel; import org.junit.jupiter.api.BeforeEach; @@ -710,11 +712,14 @@ void testComparisonBetweenParametersNotSupported() { @Test void testNullParameterNotSupported() { assertSelectQueryFailure( - "from Contact where country != :country", - Contact.class, - q -> q.setParameter("country", null), - FeatureNotSupportedException.class, - "TODO-HIBERNATE-74 https://jira.mongodb.org/browse/HIBERNATE-74"); + "from Contact where country != :country", + Contact.class, + q -> q.setParameter("country", null), + JDBCException.class, + "JDBC exception executing SQL") + .cause() + .isInstanceOf(SQLFeatureNotSupportedException.class) + .hasMessage("TODO-HIBERNATE-74 https://jira.mongodb.org/browse/HIBERNATE-74"); } @Test diff --git a/src/main/java/com/mongodb/hibernate/jdbc/MongoPreparedStatement.java b/src/main/java/com/mongodb/hibernate/jdbc/MongoPreparedStatement.java index 58776028..b7e49ce9 100644 --- a/src/main/java/com/mongodb/hibernate/jdbc/MongoPreparedStatement.java +++ b/src/main/java/com/mongodb/hibernate/jdbc/MongoPreparedStatement.java @@ -23,7 +23,6 @@ import com.mongodb.client.ClientSession; import com.mongodb.client.MongoDatabase; -import com.mongodb.hibernate.internal.FeatureNotSupportedException; import com.mongodb.hibernate.internal.dialect.MongoAggregateSupport; import com.mongodb.hibernate.internal.type.MongoStructJdbcType; import com.mongodb.hibernate.internal.type.ObjectIdJdbcType; @@ -345,12 +344,13 @@ boolean isUsed() { * *

Note that only find expression is involved before HIBERNATE-74. TODO-HIBERNATE-74 delete this temporary method */ - private static void checkComparatorNotComparingWithNullValues(BsonDocument document) { + private static void checkComparatorNotComparingWithNullValues(BsonDocument document) + throws SQLFeatureNotSupportedException { var comparisonOperators = Set.of("$ne", "$gt", "$gte", "$lt", "$lte", "$in", "$nin"); for (var entry : document.entrySet()) { var value = entry.getValue(); if (value.isNull() && comparisonOperators.contains(entry.getKey())) { - throw new FeatureNotSupportedException( + throw new SQLFeatureNotSupportedException( "TODO-HIBERNATE-74 https://jira.mongodb.org/browse/HIBERNATE-74"); } if (value instanceof BsonDocument documentValue) { diff --git a/src/main/java/com/mongodb/hibernate/jdbc/MongoStatement.java b/src/main/java/com/mongodb/hibernate/jdbc/MongoStatement.java index c657d6b8..c6aa169b 100644 --- a/src/main/java/com/mongodb/hibernate/jdbc/MongoStatement.java +++ b/src/main/java/com/mongodb/hibernate/jdbc/MongoStatement.java @@ -223,12 +223,6 @@ public void close() throws SQLException { } } - @Override - public void cancel() throws SQLException { - checkClosed(); - throw new SQLFeatureNotSupportedException(); - } - @Override public @Nullable SQLWarning getWarnings() throws SQLException { checkClosed(); @@ -266,13 +260,6 @@ public int getUpdateCount() throws SQLException { throw new SQLFeatureNotSupportedException("TODO-HIBERNATE-66 https://jira.mongodb.org/browse/HIBERNATE-66"); } - @Override - public void addBatch(String mql) throws SQLException { - checkClosed(); - MongoAggregateSupport.checkSupported(mql); - throw new SQLFeatureNotSupportedException(); - } - @Override public Connection getConnection() throws SQLException { checkClosed(); diff --git a/src/test/java/com/mongodb/hibernate/jdbc/MongoStatementTests.java b/src/test/java/com/mongodb/hibernate/jdbc/MongoStatementTests.java index 388412ba..4b57d7b2 100644 --- a/src/test/java/com/mongodb/hibernate/jdbc/MongoStatementTests.java +++ b/src/test/java/com/mongodb/hibernate/jdbc/MongoStatementTests.java @@ -274,7 +274,6 @@ private void checkMethodsWithOpenPrecondition() { assertAll( () -> assertThrowsClosedException(() -> mongoStatement.executeQuery(exampleQueryMql)), () -> assertThrowsClosedException(() -> mongoStatement.executeUpdate(exampleUpdateMql)), - () -> assertThrowsClosedException(mongoStatement::cancel), () -> assertThrowsClosedException(mongoStatement::getWarnings), () -> assertThrowsClosedException(mongoStatement::clearWarnings), () -> assertThrowsClosedException(() -> mongoStatement.execute(exampleUpdateMql)),