5252import io .cdap .plugin .gcp .spanner .SpannerArrayConstants ;
5353import io .cdap .plugin .gcp .spanner .SpannerConstants ;
5454import io .cdap .plugin .gcp .spanner .common .SpannerUtil ;
55+ import org .apache .commons .lang3 .StringUtils ;
5556import org .apache .hadoop .conf .Configuration ;
5657import org .apache .hadoop .io .NullWritable ;
5758import org .slf4j .Logger ;
6263import java .io .ObjectOutputStream ;
6364import java .util .ArrayList ;
6465import java .util .Base64 ;
66+ import java .util .HashMap ;
6567import java .util .List ;
68+ import java .util .Map ;
6669import java .util .concurrent .TimeUnit ;
70+ import java .util .regex .Pattern ;
6771import java .util .stream .Collectors ;
6872import javax .annotation .Nullable ;
6973
@@ -82,6 +86,8 @@ public class SpannerSource extends BatchSource<NullWritable, ResultSet, Structur
8286 private static final Statement .Builder SCHEMA_STATEMENT_BUILDER = Statement .newBuilder (
8387 String .format ("SELECT t.column_name,t.spanner_type, t.is_nullable FROM information_schema.columns AS t WHERE " +
8488 " t.table_catalog = '' AND t.table_schema = '' AND t.table_name = @%s" , TABLE_NAME ));
89+ private static final String LIMIT = "limit" ;
90+
8591 public static final String NAME = "Spanner" ;
8692 private final SpannerSourceConfig config ;
8793 private Schema schema ;
@@ -243,26 +249,55 @@ private Schema getSchema(FailureCollector collector) {
243249 projectId )) {
244250 DatabaseClient databaseClient =
245251 spanner .getDatabaseClient (DatabaseId .of (projectId , config .instance , config .database ));
246- Statement getTableSchemaStatement = SCHEMA_STATEMENT_BUILDER .bind (TABLE_NAME ).to (config .table ).build ();
247- try (ResultSet resultSet = databaseClient .singleUse ().executeQuery (getTableSchemaStatement )) {
248- List <Schema .Field > schemaFields = new ArrayList <>();
249- while (resultSet .next ()) {
250- String columnName = resultSet .getString ("column_name" );
251- String spannerType = resultSet .getString ("spanner_type" );
252- String nullable = resultSet .getString ("is_nullable" );
253- boolean isNullable = "YES" .equals (nullable );
254- Schema typeSchema = parseSchemaFromSpannerTypeString (columnName , spannerType , collector );
255- if (typeSchema == null ) {
256- // this means there were failures added to failure collector. Continue to collect more failures
257- continue ;
252+ if (Strings .isNullOrEmpty (config .importQuery )) {
253+ Statement getTableSchemaStatement = SCHEMA_STATEMENT_BUILDER .bind (TABLE_NAME ).to (config .table ).build ();
254+ try (ResultSet resultSet = databaseClient .singleUse ().executeQuery (getTableSchemaStatement )) {
255+ List <Schema .Field > schemaFields = new ArrayList <>();
256+ while (resultSet .next ()) {
257+ String columnName = resultSet .getString ("column_name" );
258+ String spannerType = resultSet .getString ("spanner_type" );
259+ String nullable = resultSet .getString ("is_nullable" );
260+ boolean isNullable = "YES" .equals (nullable );
261+ Schema typeSchema = parseSchemaFromSpannerTypeString (columnName , spannerType , collector );
262+ if (typeSchema == null ) {
263+ // this means there were failures added to failure collector. Continue to collect more failures
264+ continue ;
265+ }
266+ Schema fieldSchema = isNullable ? Schema .nullableOf (typeSchema ) : typeSchema ;
267+ schemaFields .add (Schema .Field .of (columnName , fieldSchema ));
268+ }
269+ if (schemaFields .isEmpty () && !collector .getValidationFailures ().isEmpty ()) {
270+ collector .getOrThrowException ();
258271 }
259- Schema fieldSchema = isNullable ? Schema .nullableOf (typeSchema ) : typeSchema ;
260- schemaFields .add (Schema .Field .of (columnName , fieldSchema ));
272+ return Schema .recordOf ("outputSchema" , schemaFields );
261273 }
262- if (schemaFields .isEmpty () && !collector .getValidationFailures ().isEmpty ()) {
263- collector .getOrThrowException ();
274+ } else {
275+ final Map <String , Boolean > nullableFields = getFieldsNullability (databaseClient );
276+ Statement importQueryStatement = getStatementForOneRow (config .importQuery );
277+ List <Schema .Field > schemaFields = new ArrayList <>();
278+ try (ResultSet resultSet = databaseClient .singleUse ().executeQuery (importQueryStatement )) {
279+ while (resultSet .next ()) {
280+ final List <Type .StructField > structFields = resultSet .getCurrentRowAsStruct ().getType ().getStructFields ();
281+ for (Type .StructField structField : structFields ) {
282+ final Type fieldSpannerType = structField .getType ();
283+ final String columnName = structField .getName ();
284+ // there are cases when column name is not in metadata table such as "Select FirstName as name",
285+ // so fallback is nullable
286+ final boolean isNullable = nullableFields .getOrDefault (columnName , true );
287+ final Schema typeSchema = parseSchemaFromSpannerType (fieldSpannerType , columnName , collector );
288+ if (typeSchema == null ) {
289+ // this means there were failures added to failure collector. Continue to collect more failures
290+ continue ;
291+ }
292+ Schema fieldSchema = isNullable ? Schema .nullableOf (typeSchema ) : typeSchema ;
293+ schemaFields .add (Schema .Field .of (columnName , fieldSchema ));
294+ }
295+ }
296+ if (schemaFields .isEmpty () && !collector .getValidationFailures ().isEmpty ()) {
297+ collector .getOrThrowException ();
298+ }
299+ return Schema .recordOf ("outputSchema" , schemaFields );
264300 }
265- return Schema .recordOf ("outputSchema" , schemaFields );
266301 }
267302 } catch (IOException e ) {
268303 collector .addFailure ("Unable to get Spanner Client: " + e .getMessage (), null )
@@ -274,6 +309,22 @@ private Schema getSchema(FailureCollector collector) {
274309
275310 }
276311
312+ private Statement getStatementForOneRow (String importQuery ) {
313+ String query ;
314+ // Matches any String containing the word 'limit' followed by a number
315+ // ex: SELECT NAME FROM TABLE LIMIT 15
316+ String regex = "^(?:[^;']|(?:'[^']+'))+ LIMIT +\\ d+(.*)" ;
317+ Pattern pattern = Pattern .compile (regex , Pattern .MULTILINE | Pattern .CASE_INSENSITIVE );
318+ if (pattern .matcher (importQuery ).matches ()) {
319+ int index = StringUtils .lastIndexOf (importQuery , LIMIT );
320+ String substringToReplace = importQuery .substring (index );
321+ query = importQuery .replace (substringToReplace , "limit 1" );
322+ } else {
323+ query = String .format ("%s limit 1" , importQuery );
324+ }
325+ return Statement .newBuilder (query ).build ();
326+ }
327+
277328 @ Nullable
278329 private Schema parseSchemaFromSpannerTypeString (String columnName ,
279330 String spannerType , FailureCollector collector ) {
@@ -323,4 +374,69 @@ private Schema parseSchemaFromSpannerTypeString(String columnName,
323374 }
324375 return null ;
325376 }
377+
378+ @ Nullable
379+ Schema parseSchemaFromSpannerType (Type spannerType , String columnName , FailureCollector collector ) {
380+ final Type .Code code = spannerType .getCode ();
381+
382+ if (code == Type .Code .ARRAY ) {
383+ final Type arrayElementType = spannerType .getArrayElementType ();
384+ final Type .Code arrayElementTypeCode = arrayElementType .getCode ();
385+ switch (arrayElementTypeCode ) {
386+ case BOOL :
387+ return Schema .arrayOf (Schema .of (Schema .Type .BOOLEAN ));
388+ case INT64 :
389+ return Schema .arrayOf (Schema .of (Schema .Type .LONG ));
390+ case FLOAT64 :
391+ return Schema .arrayOf (Schema .of (Schema .Type .DOUBLE ));
392+ case STRING :
393+ return Schema .arrayOf (Schema .of (Schema .Type .STRING ));
394+ case BYTES :
395+ return Schema .arrayOf (Schema .of (Schema .Type .BYTES ));
396+ case TIMESTAMP :
397+ return Schema .arrayOf (Schema .of (Schema .LogicalType .TIMESTAMP_MICROS ));
398+ case DATE :
399+ return Schema .arrayOf (Schema .of (Schema .LogicalType .DATE ));
400+ default :
401+ collector .addFailure (String .format ("Column '%s' has unsupported type '%s'." , columnName , spannerType ), null );
402+ return null ;
403+ }
404+ } else {
405+ switch (code ) {
406+ case BOOL :
407+ return Schema .of (Schema .Type .BOOLEAN );
408+ case INT64 :
409+ return Schema .of (Schema .Type .LONG );
410+ case FLOAT64 :
411+ return Schema .of (Schema .Type .DOUBLE );
412+ case STRING :
413+ return Schema .of (Schema .Type .STRING );
414+ case BYTES :
415+ return Schema .of (Schema .Type .BYTES );
416+ case TIMESTAMP :
417+ return Schema .of (Schema .LogicalType .TIMESTAMP_MICROS );
418+ case DATE :
419+ return Schema .of (Schema .LogicalType .DATE );
420+ default :
421+ collector .addFailure (String .format ("Column '%s' has unsupported type '%s'." , columnName , spannerType ), null );
422+ return null ;
423+ }
424+ }
425+ }
426+
427+ /** Get from table metadata nullability for each field
428+ * @param databaseClient Database Client
429+ * @return Map where key is field name and value is nullability true or false
430+ */
431+ private Map <String , Boolean > getFieldsNullability (DatabaseClient databaseClient ) {
432+ Statement tableMetadataStatement = SCHEMA_STATEMENT_BUILDER .bind (TABLE_NAME ).to (config .table ).build ();
433+ Map <String , Boolean > nullableState = new HashMap <>();
434+ ResultSet resultSet = databaseClient .singleUse ().executeQuery (tableMetadataStatement );
435+ while (resultSet .next ()) {
436+ String columnName = resultSet .getString ("column_name" );
437+ String nullable = resultSet .getString ("is_nullable" );
438+ nullableState .put (columnName , "YES" .equals (nullable ));
439+ }
440+ return nullableState ;
441+ }
326442}
0 commit comments