Skip to content

Commit 622ac37

Browse files
committed
-add unit test
1 parent 2e6edfd commit 622ac37

File tree

4 files changed

+107
-29
lines changed

4 files changed

+107
-29
lines changed

athena-snowflake/src/main/java/com/amazonaws/athena/connectors/snowflake/SnowflakeMetadataHandler.java

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -459,14 +459,13 @@ private GetSplitsResponse handleS3ExportSplits(GetSplitsRequest request)
459459

460460
// Get the SQL statement which was created in getPartitions
461461
FieldReader fieldReaderQid = request.getPartitions().getFieldReader(S3_PATH_PREFIX);
462-
String s3Path_prefix = fieldReaderQid.readText().toString();
462+
String s3PathPrefix = fieldReaderQid.readText().toString();
463463

464464
FieldReader fieldReaderPreparedStmt = request.getPartitions().getFieldReader(PREPARED_STMT);
465465
String preparedStmt = fieldReaderPreparedStmt.readText().toString();
466466
LOGGER.debug("doGetSplits: Catalog {}, table {}, s3ExportBucketPath:{}, preparedStmt:{}", request.getTableName().getSchemaName(),
467467
request.getTableName().getTableName(), s3ExportBucketPath, preparedStmt);
468468

469-
470469
try (Connection connection = getJdbcConnectionFactory().getConnection(getCredentialProvider());
471470
PreparedStatement preparedStatement = new PreparedStatementBuilder()
472471
.withConnection(connection)
@@ -475,7 +474,7 @@ private GetSplitsResponse handleS3ExportSplits(GetSplitsRequest request)
475474
request.getTableName().getTableName()))
476475
.build()) {
477476
// get S3 URI by combining s3ExportPath with AthenaQueryId.
478-
URI uri = URI.create(s3Path_prefix);
477+
URI uri = URI.create(s3PathPrefix);
479478
S3Uri s3Uri = amazonS3.utilities().parseUri(uri);
480479

481480
// List S3 Object summary first, in case same table has been reference multiple time
@@ -487,7 +486,8 @@ private GetSplitsResponse handleS3ExportSplits(GetSplitsRequest request)
487486
if (s3ObjectSummaries.isEmpty()) {
488487
try {
489488
preparedStatement.execute();
490-
} catch (SnowflakeSQLException snowflakeSQLException) {
489+
}
490+
catch (SnowflakeSQLException snowflakeSQLException) {
491491
// handle race condition on another splits already start the copy into statement
492492
if (!snowflakeSQLException.getMessage().contains("Files already existing")) {
493493
throw new RuntimeException("Exception in execution export statement " + snowflakeSQLException.getMessage(), snowflakeSQLException);
@@ -564,7 +564,8 @@ protected List<TableName> getPaginatedTables(Connection connection, String datab
564564
/*
565565
* Get the list of all the exported S3 objects
566566
*/
567-
private List<S3Object> getlistExportedObjects(String s3ExportBucketName, String prefix)
567+
@VisibleForTesting
568+
List<S3Object> getlistExportedObjects(String s3ExportBucketName, String prefix)
568569
{
569570
ListObjectsResponse listObjectsResponse;
570571
try {

athena-snowflake/src/main/java/com/amazonaws/athena/connectors/snowflake/SnowflakeQueryStringBuilder.java

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
import com.amazonaws.athena.connectors.jdbc.manager.FederationExpressionParser;
2525
import com.amazonaws.athena.connectors.jdbc.manager.JdbcSplitQueryBuilder;
2626
import com.amazonaws.athena.connectors.jdbc.manager.TypeAndValue;
27+
import com.google.common.annotations.VisibleForTesting;
2728
import com.google.common.base.Strings;
2829
import org.apache.arrow.vector.types.Types;
2930
import org.apache.arrow.vector.types.pojo.Field;
@@ -97,7 +98,8 @@ public String getBaseExportSQLString(
9798
return sqlBaseString;
9899
}
99100

100-
private String expandSql(String sql, List<TypeAndValue> accumulator)
101+
@VisibleForTesting
102+
String expandSql(String sql, List<TypeAndValue> accumulator)
101103
{
102104
if (sql == null) {
103105
return null;

athena-snowflake/src/test/java/com/amazonaws/athena/connectors/snowflake/SnowflakeMetadataHandlerTest.java

Lines changed: 34 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -297,9 +297,10 @@ public void getPartitions() throws Exception {
297297
GetTableLayoutResponse res = snowflakeMetadataHandlerMocked.doGetTableLayout(allocator, req);
298298
Block partitions = res.getPartitions();
299299

300+
assertNotNull(partitions);
300301
assertTrue(partitions.getRowCount() > 0);
301302
assertNotNull(partitions.getFieldVector("preparedStmt"));
302-
assertNotNull(partitions.getFieldVector("queryId"));
303+
assertNotNull(partitions.getFieldVector("s3_path"));
303304
}
304305

305306
@Test
@@ -561,4 +562,36 @@ public void testGetPartitionsForView() throws Exception {
561562
assertEquals(1, partitions.getRowCount());
562563
assertEquals("*", partitions.getFieldVector(BLOCK_PARTITION_COLUMN_NAME).getObject(0).toString());
563564
}
565+
566+
@Test
567+
public void testGetlistExportedObjects_S3Path() {
568+
System.setProperty("aws_region", "us-east-1");
569+
List<S3Object> objectList = new ArrayList<>();
570+
S3Object obj1 = S3Object.builder().key("queryId123/file1.parquet").build();
571+
S3Object obj2 = S3Object.builder().key("queryId123/file2.parquet").build();
572+
objectList.add(obj1);
573+
objectList.add(obj2);
574+
575+
ListObjectsResponse response = ListObjectsResponse.builder()
576+
.contents(objectList)
577+
.build();
578+
579+
when(mockS3.listObjects(any(ListObjectsRequest.class))).thenReturn(response);
580+
581+
List<S3Object> result = snowflakeMetadataHandler.getlistExportedObjects("test-bucket", "queryId123");
582+
assertEquals(2, result.size());
583+
assertEquals("queryId123/file1.parquet", result.get(0).key());
584+
assertEquals("queryId123/file2.parquet", result.get(1).key());
585+
}
586+
587+
@Test(expected = RuntimeException.class)
588+
public void testGetlistExportedObjects_S3Exception() {
589+
System.setProperty("aws_region", "us-east-1");
590+
when(mockS3.listObjects(any(ListObjectsRequest.class)))
591+
.thenThrow(software.amazon.awssdk.services.s3.model.S3Exception.builder()
592+
.message("Access denied")
593+
.build());
594+
595+
snowflakeMetadataHandler.getlistExportedObjects("test-bucket", "queryId123");
596+
}
564597
}

athena-snowflake/src/test/java/com/amazonaws/athena/connectors/snowflake/SnowflakeQueryStringBuilderTest.java

Lines changed: 64 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,8 @@
2222
import com.amazonaws.athena.connector.lambda.domain.predicate.Constraints;
2323
import com.amazonaws.athena.connectors.jdbc.manager.TypeAndValue;
2424
import org.apache.arrow.vector.types.DateUnit;
25+
import org.apache.arrow.vector.types.FloatingPointPrecision;
26+
import org.apache.arrow.vector.types.TimeUnit;
2527
import org.apache.arrow.vector.types.pojo.ArrowType;
2628
import org.apache.arrow.vector.types.pojo.Field;
2729
import org.apache.arrow.vector.types.pojo.FieldType;
@@ -68,28 +70,6 @@ void testGetPartitionWhereClauses() {
6870
assertTrue(result.isEmpty());
6971
}
7072

71-
// @Test
72-
// void testBuildSqlString_NoConstraints() throws SQLException {
73-
// Schema tableSchema = new Schema(List.of(new Field("id", new FieldType(true, new ArrowType.Int(32, true), null), null)));
74-
//
75-
// Constraints constraints = mock(Constraints.class);
76-
// when(constraints.getLimit()).thenReturn(0L);
77-
//
78-
// String sql = queryBuilder.buildSqlString(mockConnection, null, "public", "users", tableSchema, constraints, null);
79-
// assertTrue(sql.contains("SELECT \"id\" FROM \"public\".\"users\" "));
80-
// }
81-
82-
// @Test
83-
// void testBuildSqlString_WithConstraints() throws SQLException {
84-
// Schema tableSchema = new Schema(List.of(new Field("id", new FieldType(true, new ArrowType.Int(32, true), null), null)));
85-
//
86-
// Constraints constraints = mock(Constraints.class);
87-
// when(constraints.getLimit()).thenReturn(10L);
88-
//
89-
// String sql = queryBuilder.buildSqlString(mockConnection, null, "public", "users", tableSchema, constraints, null);
90-
// assertTrue(sql.contains("LIMIT 10"));
91-
// }
92-
9373
@Test
9474
void testQuote() {
9575
String result = queryBuilder.quote("users");
@@ -101,4 +81,66 @@ void testSingleQuote() {
10181
String result = queryBuilder.singleQuote("O'Reilly");
10282
assertEquals("'O''Reilly'", result);
10383
}
84+
85+
@Test
86+
void testExpandSql_NullInput() {
87+
String result = queryBuilder.expandSql(null, new ArrayList<>());
88+
assertEquals(null, result);
89+
}
90+
91+
@Test
92+
void testExpandSql_NullValue() {
93+
List<TypeAndValue> accumulator = new ArrayList<>();
94+
// Skip null value test as TypeAndValue constructor doesn't allow null values
95+
String result = queryBuilder.expandSql("SELECT * FROM table WHERE col = ?", accumulator);
96+
assertEquals("SELECT * FROM table WHERE col = ?", result);
97+
}
98+
99+
@Test
100+
void testExpandSql_NumberTypes() {
101+
List<TypeAndValue> accumulator = new ArrayList<>();
102+
accumulator.add(new TypeAndValue(new ArrowType.Int(32, true), 123));
103+
accumulator.add(new TypeAndValue(new ArrowType.FloatingPoint(FloatingPointPrecision.DOUBLE), 45.67));
104+
105+
String result = queryBuilder.expandSql("SELECT * FROM table WHERE col1 = ? AND col2 = ?", accumulator);
106+
assertEquals("SELECT * FROM table WHERE col1 = 123 AND col2 = 45.67", result);
107+
}
108+
109+
@Test
110+
void testExpandSql_DateDay() {
111+
List<TypeAndValue> accumulator = new ArrayList<>();
112+
accumulator.add(new TypeAndValue(new ArrowType.Date(DateUnit.DAY), 19000L)); // Days since epoch
113+
114+
String result = queryBuilder.expandSql("SELECT * FROM table WHERE date_col = ?", accumulator);
115+
assertTrue(result.contains("DATE '"));
116+
}
117+
118+
@Test
119+
void testExpandSql_DateMilli() {
120+
List<TypeAndValue> accumulator = new ArrayList<>();
121+
accumulator.add(new TypeAndValue(new ArrowType.Timestamp(TimeUnit.MILLISECOND, null), 1642464000000L)); // Millis since epoch
122+
123+
String result = queryBuilder.expandSql("SELECT * FROM table WHERE timestamp_col = ?", accumulator);
124+
// The timestamp gets converted to string format, not TIMESTAMP literal
125+
assertTrue(result.contains("'1642464000000'"));
126+
}
127+
128+
@Test
129+
void testExpandSql_TimestampTypes() {
130+
List<TypeAndValue> accumulator = new ArrayList<>();
131+
accumulator.add(new TypeAndValue(new ArrowType.Timestamp(TimeUnit.MILLISECOND, "UTC"), 1642464000000L));
132+
accumulator.add(new TypeAndValue(new ArrowType.Timestamp(TimeUnit.MICROSECOND, "UTC"), new java.sql.Timestamp(1642464000000L)));
133+
134+
String result = queryBuilder.expandSql("SELECT * FROM table WHERE ts1 = ? AND ts2 = ?", accumulator);
135+
assertTrue(result.contains("TIMESTAMP '"));
136+
}
137+
138+
@Test
139+
void testExpandSql_StringValue() {
140+
List<TypeAndValue> accumulator = new ArrayList<>();
141+
accumulator.add(new TypeAndValue(new ArrowType.Utf8(), "test'value"));
142+
143+
String result = queryBuilder.expandSql("SELECT * FROM table WHERE str_col = ?", accumulator);
144+
assertEquals("SELECT * FROM table WHERE str_col = 'test''value'", result);
145+
}
104146
}

0 commit comments

Comments
 (0)