From 5471d955c3e94f68aced979ddf8ab6d98fd9a098 Mon Sep 17 00:00:00 2001
From: Matt Topol <zotthewizard@gmail.com>
Date: Thu, 17 Oct 2024 12:44:26 -0600
Subject: [PATCH] feat(go/adbc/driver/snowflake): improve GetObjects
 performance and semantics (#2254)

Fixes #2171

Improves the channel handling and query building for metadata conversion
to Arrow for better performance. For all cases except retrieving Column
metadata we now utilize `SHOW` queries and build the patterns into those
queries. This allows `GetObjects` calls at those depths to be made without
having to specify a current database or schema.
---
 c/driver/flightsql/sqlite_flightsql_test.cc   |   1 +
 c/driver/snowflake/snowflake_test.cc          |  12 +-
 c/validation/adbc_validation.h                |   6 +
 c/validation/adbc_validation_connection.cc    |  27 +-
 c/validation/adbc_validation_statement.cc     |  16 +-
 .../driver/internal/driverbase/connection.go  |  51 ++-
 go/adbc/driver/snowflake/connection.go        | 334 ++++++++++++------
 go/adbc/driver/snowflake/driver_test.go       |  11 +-
 .../snowflake/queries/get_objects_all.sql     |  18 +-
 .../queries/get_objects_catalogs.sql          |  25 --
 .../queries/get_objects_dbschemas.sql         |  28 +-
 .../snowflake/queries/get_objects_tables.sql  | 112 ++----
 go/adbc/driver/snowflake/statement.go         |   4 +-
 13 files changed, 365 insertions(+), 280 deletions(-)
 delete mode 100644 go/adbc/driver/snowflake/queries/get_objects_catalogs.sql

diff --git a/c/driver/flightsql/sqlite_flightsql_test.cc b/c/driver/flightsql/sqlite_flightsql_test.cc
index 454ea02977..40601e2803 100644
--- a/c/driver/flightsql/sqlite_flightsql_test.cc
+++ b/c/driver/flightsql/sqlite_flightsql_test.cc
@@ -121,6 +121,7 @@ class SqliteFlightSqlQuirks : public adbc_validation::DriverQuirks {
   bool supports_get_objects() const override { return true; }
   bool supports_partitioned_data() const override { return true; }
   bool supports_dynamic_parameter_binding() const override { return true; }
+  std::string catalog() const { return "main"; }
 };
 
 class SqliteFlightSqlTest : public ::testing::Test, public adbc_validation::DatabaseTest {
diff --git a/c/driver/snowflake/snowflake_test.cc b/c/driver/snowflake/snowflake_test.cc
index 60003353da..262286192a 100644
--- a/c/driver/snowflake/snowflake_test.cc
+++ b/c/driver/snowflake/snowflake_test.cc
@@ -99,7 +99,7 @@ class SnowflakeQuirks : public adbc_validation::DriverQuirks {
     adbc_validation::Handle<struct AdbcStatement> statement;
     CHECK_OK(AdbcStatementNew(connection, &statement.value, error));
 
-    std::string create = "CREATE TABLE \"";
+    std::string create = "CREATE OR REPLACE TABLE \"";
     create += name;
     create += "\" (int64s INT, strings TEXT)";
     CHECK_OK(AdbcStatementSetSqlQuery(&statement.value, create.c_str(), error));
@@ -131,7 +131,13 @@ class SnowflakeQuirks : public adbc_validation::DriverQuirks {
         return NANOARROW_TYPE_DOUBLE;
       case NANOARROW_TYPE_STRING:
       case NANOARROW_TYPE_LARGE_STRING:
+      case NANOARROW_TYPE_LIST:
+      case NANOARROW_TYPE_LARGE_LIST:
         return NANOARROW_TYPE_STRING;
+      case NANOARROW_TYPE_BINARY:
+      case NANOARROW_TYPE_LARGE_BINARY:
+      case NANOARROW_TYPE_FIXED_SIZE_BINARY:
+        return NANOARROW_TYPE_BINARY;
       default:
         return ingest_type;
     }
@@ -149,7 +155,11 @@ class SnowflakeQuirks : public adbc_validation::DriverQuirks {
   bool supports_dynamic_parameter_binding() const override { return true; }
   bool supports_error_on_incompatible_schema() const override { return false; }
   bool ddl_implicit_commit_txn() const override { return true; }
+  bool supports_ingest_view_types() const override { return false; }
+  bool supports_ingest_float16() const override { return false; }
+
   std::string db_schema() const override { return schema_; }
+  std::string catalog() const override { return "ADBC_TESTING"; }
 
   const char* uri_;
   bool skip_{false};
diff --git a/c/validation/adbc_validation.h b/c/validation/adbc_validation.h
index fa3c1cdccb..f8ef350cc2 100644
--- a/c/validation/adbc_validation.h
+++ b/c/validation/adbc_validation.h
@@ -238,6 +238,12 @@ class DriverQuirks {
   /// column matching.
   virtual bool supports_error_on_incompatible_schema() const { return true; }
 
+  /// \brief Whether ingestion supports StringView/BinaryView types
+  virtual bool supports_ingest_view_types() const { return true; }
+
+  /// \brief Whether ingestion supports Float16
+  virtual bool supports_ingest_float16() const { return true; }
+
   /// \brief Default catalog to use for tests
   virtual std::string catalog() const { return ""; }
 
diff --git a/c/validation/adbc_validation_connection.cc b/c/validation/adbc_validation_connection.cc
index a885fa2c86..032f1d328f 100644
--- a/c/validation/adbc_validation_connection.cc
+++ b/c/validation/adbc_validation_connection.cc
@@ -744,13 +744,15 @@ void ConnectionTest::TestMetadataGetObjectsColumns() {
 
   struct TestCase {
     std::optional<std::string> filter;
-    std::vector<std::string> column_names;
-    std::vector<int32_t> ordinal_positions;
+    // the pair is column name & ordinal position of the column
+    std::vector<std::pair<std::string, int32_t>> columns;
   };
 
   std::vector<TestCase> test_cases;
-  test_cases.push_back({std::nullopt, {"int64s", "strings"}, {1, 2}});
-  test_cases.push_back({"in%", {"int64s"}, {1}});
+  test_cases.push_back({std::nullopt, {{"int64s", 1}, {"strings", 2}}});
+  test_cases.push_back({"in%", {{"int64s", 1}}});
+
+  const std::string catalog = quirks()->catalog();
 
   for (const auto& test_case : test_cases) {
     std::string scope = "Filter: ";
@@ -758,13 +760,14 @@ void ConnectionTest::TestMetadataGetObjectsColumns() {
     SCOPED_TRACE(scope);
 
     StreamReader reader;
+    std::vector<std::pair<std::string, int32_t>> columns;
     std::vector<std::string> column_names;
    std::vector<int32_t> ordinal_positions;
 
     ASSERT_THAT(
         AdbcConnectionGetObjects(
-            &connection, ADBC_OBJECT_DEPTH_COLUMNS, nullptr, nullptr, nullptr, nullptr,
-            test_case.filter.has_value() ? test_case.filter->c_str() : nullptr,
+            &connection, ADBC_OBJECT_DEPTH_COLUMNS, catalog.c_str(), nullptr, nullptr,
+            nullptr, test_case.filter.has_value() ? test_case.filter->c_str() : nullptr,
             &reader.stream.value, &error),
         IsOkStatus(&error));
     ASSERT_NO_FATAL_FAILURE(reader.GetSchema());
@@ -834,10 +837,9 @@ void ConnectionTest::TestMetadataGetObjectsColumns() {
               std::string temp(name.data, name.size_bytes);
               std::transform(temp.begin(), temp.end(), temp.begin(),
                              [](unsigned char c) { return std::tolower(c); });
-              column_names.push_back(std::move(temp));
-              ordinal_positions.push_back(
-                  static_cast<int32_t>(ArrowArrayViewGetIntUnsafe(
-                      table_columns->children[1], columns_index)));
+              columns.emplace_back(std::move(temp),
+                                   static_cast<int32_t>(ArrowArrayViewGetIntUnsafe(
+                                       table_columns->children[1], columns_index)));
             }
           }
         }
@@ -847,8 +849,9 @@ void ConnectionTest::TestMetadataGetObjectsColumns() {
     } while (reader.array->release);
 
     ASSERT_TRUE(found_expected_table) << "Did (not) find table in metadata";
-    ASSERT_EQ(test_case.column_names, column_names);
-    ASSERT_EQ(test_case.ordinal_positions, ordinal_positions);
+    // metadata columns do not guarantee the order they are returned in, just
+    // validate all the elements are there.
+    ASSERT_THAT(columns, testing::UnorderedElementsAreArray(test_case.columns));
   }
 }
 
diff --git a/c/validation/adbc_validation_statement.cc b/c/validation/adbc_validation_statement.cc
index 07ab0b22af..150aabf327 100644
--- a/c/validation/adbc_validation_statement.cc
+++ b/c/validation/adbc_validation_statement.cc
@@ -246,6 +246,10 @@ void StatementTest::TestSqlIngestInt64() {
 }
 
 void StatementTest::TestSqlIngestFloat16() {
+  if (!quirks()->supports_ingest_float16()) {
+    GTEST_SKIP();
+  }
+
   ASSERT_NO_FATAL_FAILURE(TestSqlIngestNumericType<float>(NANOARROW_TYPE_HALF_FLOAT));
 }
 
@@ -268,6 +272,10 @@ void StatementTest::TestSqlIngestLargeString() {
 }
 
 void StatementTest::TestSqlIngestStringView() {
+  if (!quirks()->supports_ingest_view_types()) {
+    GTEST_SKIP();
+  }
+
   ASSERT_NO_FATAL_FAILURE(TestSqlIngestType<std::string>(
       NANOARROW_TYPE_STRING_VIEW, {std::nullopt, "", "", "longer than 12 bytes", "例"},
      false));
@@ -302,6 +310,10 @@ void StatementTest::TestSqlIngestFixedSizeBinary() {
 }
 
 void StatementTest::TestSqlIngestBinaryView() {
+  if (!quirks()->supports_ingest_view_types()) {
+    GTEST_SKIP();
+  }
+
   ASSERT_NO_FATAL_FAILURE(TestSqlIngestType<std::vector<std::byte>>(
       NANOARROW_TYPE_LARGE_BINARY,
       {std::nullopt, std::vector<std::byte>{},
@@ -2218,7 +2230,7 @@ void StatementTest::TestSqlBind() {
 
   ASSERT_THAT(
       AdbcStatementSetSqlQuery(
-          &statement, "SELECT * FROM bindtest ORDER BY \"col1\" ASC NULLS FIRST", &error),
+          &statement, "SELECT * FROM bindtest ORDER BY col1 ASC NULLS FIRST", &error),
       IsOkStatus(&error));
   {
     StreamReader reader;
@@ -2226,7 +2238,7 @@ void StatementTest::TestSqlBind() {
     ASSERT_THAT(AdbcStatementExecuteQuery(&statement, &reader.stream.value,
                                           &reader.rows_affected, &error),
                 IsOkStatus(&error));
     ASSERT_THAT(reader.rows_affected,
-                ::testing::AnyOf(::testing::Eq(0), ::testing::Eq(-1)));
+                ::testing::AnyOf(::testing::Eq(3), ::testing::Eq(-1)));
 
     ASSERT_NO_FATAL_FAILURE(reader.GetSchema());
     ASSERT_NO_FATAL_FAILURE(reader.Next());
diff --git a/go/adbc/driver/internal/driverbase/connection.go b/go/adbc/driver/internal/driverbase/connection.go
index 6e78816351..b09f74e301 100644
--- a/go/adbc/driver/internal/driverbase/connection.go
+++ b/go/adbc/driver/internal/driverbase/connection.go
@@ -349,14 +349,17 @@ func (cnxn *connection) GetObjects(ctx context.Context, depth adbc.ObjectDepth,
     bufferSize := len(catalogs)
     addCatalogCh := make(chan GetObjectsInfo, bufferSize)
-    for _, cat := range catalogs {
-        addCatalogCh <- GetObjectsInfo{CatalogName: Nullable(cat)}
-    }
-
-    close(addCatalogCh)
+    errCh := make(chan error, 1)
+    go func() {
+        defer close(addCatalogCh)
+        for _, cat := range catalogs {
+            addCatalogCh <- GetObjectsInfo{CatalogName: Nullable(cat)}
+        }
+    }()
 
     if depth == adbc.ObjectDepthCatalogs {
-        return BuildGetObjectsRecordReader(cnxn.Base().Alloc, addCatalogCh)
+        close(errCh)
+        return BuildGetObjectsRecordReader(cnxn.Base().Alloc, addCatalogCh, errCh)
     }
 
     g, ctxG := errgroup.WithContext(ctx)
@@ -386,7 +389,7 @@ func (cnxn *connection) GetObjects(ctx context.Context, depth adbc.ObjectDepth,
     g.Go(func() error { defer close(addDbSchemasCh); return gSchemas.Wait() })
 
     if depth == adbc.ObjectDepthDBSchemas {
-        rdr, err := BuildGetObjectsRecordReader(cnxn.Base().Alloc, addDbSchemasCh)
+        rdr, err := BuildGetObjectsRecordReader(cnxn.Base().Alloc, addDbSchemasCh, errCh)
         return rdr, errors.Join(err, g.Wait())
     }
 
@@ -432,7 +435,7 @@ func (cnxn *connection) GetObjects(ctx context.Context, depth adbc.ObjectDepth,
 
     g.Go(func() error { defer close(addTablesCh); return gTables.Wait() })
 
-    rdr, err := BuildGetObjectsRecordReader(cnxn.Base().Alloc, addTablesCh)
+    rdr, err := BuildGetObjectsRecordReader(cnxn.Base().Alloc, addTablesCh, errCh)
     return rdr, errors.Join(err, g.Wait())
 }
 
@@ -621,20 +624,20 @@ type ColumnInfo struct {
 type TableInfo struct {
     TableName        string           `json:"table_name"`
     TableType        string           `json:"table_type"`
-    TableColumns     []ColumnInfo     `json:"table_columns,omitempty"`
-    TableConstraints []ConstraintInfo `json:"table_constraints,omitempty"`
+    TableColumns     []ColumnInfo     `json:"table_columns"`
+    TableConstraints []ConstraintInfo `json:"table_constraints"`
 }
 
 // DBSchemaInfo is a structured representation of adbc.DBSchemaSchema
 type DBSchemaInfo struct {
     DbSchemaName   *string     `json:"db_schema_name,omitempty"`
-    DbSchemaTables []TableInfo `json:"db_schema_tables,omitempty"`
+    DbSchemaTables []TableInfo `json:"db_schema_tables"`
 }
 
 // GetObjectsInfo is a structured representation of adbc.GetObjectsSchema
 type GetObjectsInfo struct {
     CatalogName      *string        `json:"catalog_name,omitempty"`
-    CatalogDbSchemas []DBSchemaInfo `json:"catalog_db_schemas,omitempty"`
+    CatalogDbSchemas []DBSchemaInfo `json:"catalog_db_schemas"`
 }
 
 // Scan implements sql.Scanner.
@@ -659,23 +662,33 @@ func (g *GetObjectsInfo) Scan(src any) error {
 // BuildGetObjectsRecordReader constructs a RecordReader for the GetObjects ADBC method.
 // It accepts a channel of GetObjectsInfo to allow concurrent retrieval of metadata and
 // serialization to Arrow record.
-func BuildGetObjectsRecordReader(mem memory.Allocator, in chan GetObjectsInfo) (array.RecordReader, error) {
+func BuildGetObjectsRecordReader(mem memory.Allocator, in <-chan GetObjectsInfo, errCh <-chan error) (array.RecordReader, error) {
     bldr := array.NewRecordBuilder(mem, adbc.GetObjectsSchema)
     defer bldr.Release()
 
-    for catalog := range in {
-        b, err := json.Marshal(catalog)
-        if err != nil {
-            return nil, err
-        }
+CATALOGLOOP:
+    for {
+        select {
+        case catalog, ok := <-in:
+            if !ok {
+                break CATALOGLOOP
+            }
+            b, err := json.Marshal(catalog)
+            if err != nil {
+                return nil, err
+            }
 
-        if err := json.Unmarshal(b, bldr); err != nil {
+            if err := json.Unmarshal(b, bldr); err != nil {
+                return nil, err
+            }
+        case err := <-errCh:
             return nil, err
         }
     }
 
     rec := bldr.NewRecord()
     defer rec.Release()
+
     return array.NewRecordReader(adbc.GetObjectsSchema, []arrow.Record{rec})
 }
diff --git a/go/adbc/driver/snowflake/connection.go b/go/adbc/driver/snowflake/connection.go
index a8361a3653..190426c7f9 100644
--- a/go/adbc/driver/snowflake/connection.go
+++ b/go/adbc/driver/snowflake/connection.go
@@ -24,7 +24,9 @@ import (
     "embed"
     "fmt"
     "io"
+    "io/fs"
     "path"
+    "runtime"
     "strconv"
     "strings"
     "time"
@@ -42,7 +44,6 @@ const (
     defaultPrefetchConcurrency = 10
 
     queryTemplateGetObjectsAll           = "get_objects_all.sql"
-    queryTemplateGetObjectsCatalogs      = "get_objects_catalogs.sql"
     queryTemplateGetObjectsDbSchemas     = "get_objects_dbschemas.sql"
     queryTemplateGetObjectsTables        = "get_objects_tables.sql"
     queryTemplateGetObjectsTerseCatalogs = "get_objects_terse_catalogs.sql"
@@ -73,9 +74,105 @@ type connectionImpl struct {
     useHighPrecision bool
 }
 
-func (c *connectionImpl) GetObjects(ctx context.Context, depth adbc.ObjectDepth, catalog *string, dbSchema *string, tableName *string, columnName *string, tableType []string) (array.RecordReader, error) {
+func escapeSingleQuoteForLike(arg string) string {
+    if len(arg) == 0 {
+        return arg
+    }
+
+    idx := strings.IndexByte(arg, '\'')
+    if idx == -1 {
+        return arg
+    }
+
+    var b strings.Builder
+    b.Grow(len(arg))
+
+    for {
+        before, after, found := strings.Cut(arg, `'`)
+        b.WriteString(before)
+        if !found {
+            return b.String()
+        }
+
+        if before[len(before)-1] != '\\' {
+            b.WriteByte('\\')
+        }
+        b.WriteByte('\'')
+        arg = after
+    }
+}
+
+func getQueryID(ctx context.Context, query string, driverConn any) (string, error) {
+    rows, err := driverConn.(driver.QueryerContext).QueryContext(ctx, query, nil)
+    if err != nil {
+        return "", err
+    }
+
+    return rows.(gosnowflake.SnowflakeRows).GetQueryID(), rows.Close()
+}
+
+const (
+    objSchemas   = "SCHEMAS"
+    objDatabases = "DATABASES"
+    objViews     = "VIEWS"
+    objTables    = "TABLES"
+    objObjects   = "OBJECTS"
+)
+
+func addLike(query string, pattern *string) string {
+    if pattern != nil && len(*pattern) > 0 && *pattern != "%" && *pattern != ".*" {
+        query += " LIKE '" + escapeSingleQuoteForLike(*pattern) + "'"
+    }
+    return query
+}
+
+func goGetQueryID(ctx context.Context, conn *sql.Conn, grp *errgroup.Group, objType string, catalog, dbSchema, tableName *string, outQueryID *string) {
+    grp.Go(func() error {
+        return conn.Raw(func(driverConn any) (err error) {
+            query := "SHOW TERSE /* ADBC:getObjects */ " + objType
+            switch objType {
+            case objDatabases:
+                query = addLike(query, catalog)
+                query += " IN ACCOUNT"
+            case objSchemas:
+                query = addLike(query, dbSchema)
+
+                if catalog == nil || isWildcardStr(*catalog) {
+                    query += " IN ACCOUNT"
+                } else {
+                    query += " IN DATABASE " + quoteTblName(*catalog)
+                }
+            case objViews, objTables, objObjects:
+                query = addLike(query, tableName)
+
+                if catalog == nil || isWildcardStr(*catalog) {
+                    query += " IN ACCOUNT"
+                } else {
+                    escapedCatalog := quoteTblName(*catalog)
+                    if dbSchema == nil || isWildcardStr(*dbSchema) {
+                        query += " IN DATABASE " + escapedCatalog
+                    } else {
+                        query += " IN SCHEMA " + escapedCatalog + "." + quoteTblName(*dbSchema)
+                    }
+                }
+            default:
+                return fmt.Errorf("unimplemented object type")
+            }
+
+            *outQueryID, err = getQueryID(ctx, query, driverConn)
+            return
+        })
+    })
+}
+
+func isWildcardStr(ident string) bool {
+    return strings.ContainsAny(ident, "_%")
+}
+
+func (c *connectionImpl) GetObjects(ctx context.Context, depth adbc.ObjectDepth, catalog, dbSchema, tableName, columnName *string, tableType []string) (array.RecordReader, error) {
     var (
         pkQueryID, fkQueryID, uniqueQueryID, terseDbQueryID string
+        showSchemaQueryID, tableQueryID                     string
     )
 
     conn, err := c.sqldb.Conn(ctx)
@@ -84,81 +181,117 @@ func (c *connectionImpl) GetObjects(ctx context.Context, depth adbc.ObjectDepth,
     }
     defer conn.Close()
 
+    var hasViews, hasTables bool
+    for _, t := range tableType {
+        if strings.EqualFold("VIEW", t) {
+            hasViews = true
+        } else if strings.EqualFold("TABLE", t) {
+            hasTables = true
+        }
+    }
+
+    // force empty result from SHOW TABLES if tableType list is not empty
+    // and does not contain TABLE or VIEW in the list.
+    // we need this because we should have non-null db_schema_tables when
+    // depth is Tables, Columns or All.
+    var badTableType = "tabletypedoesnotexist"
+    if len(tableType) > 0 && depth >= adbc.ObjectDepthTables && !hasViews && !hasTables {
+        tableName = &badTableType
+        tableType = []string{"TABLE"}
+    }
+
     gQueryIDs, gQueryIDsCtx := errgroup.WithContext(ctx)
     queryFile := queryTemplateGetObjectsAll
     switch depth {
     case adbc.ObjectDepthCatalogs:
-        if catalog == nil {
-            queryFile = queryTemplateGetObjectsTerseCatalogs
-            // if the catalog is null, show the terse databases
-            // which doesn't require a database context
-            gQueryIDs.Go(func() error {
-                return conn.Raw(func(driverConn any) error {
-                    rows, err := driverConn.(driver.QueryerContext).QueryContext(gQueryIDsCtx, "SHOW TERSE DATABASES", nil)
-                    if err != nil {
-                        return err
-                    }
-
-                    terseDbQueryID = rows.(gosnowflake.SnowflakeRows).GetQueryID()
-                    return rows.Close()
-                })
-            })
-        } else {
-            queryFile = queryTemplateGetObjectsCatalogs
-        }
+        queryFile = queryTemplateGetObjectsTerseCatalogs
+        goGetQueryID(gQueryIDsCtx, conn, gQueryIDs, objDatabases,
+            catalog, dbSchema, tableName, &terseDbQueryID)
     case adbc.ObjectDepthDBSchemas:
         queryFile = queryTemplateGetObjectsDbSchemas
+        goGetQueryID(gQueryIDsCtx, conn, gQueryIDs, objSchemas,
+            catalog, dbSchema, tableName, &showSchemaQueryID)
+        goGetQueryID(gQueryIDsCtx, conn, gQueryIDs, objDatabases,
+            catalog, dbSchema, tableName, &terseDbQueryID)
     case adbc.ObjectDepthTables:
         queryFile = queryTemplateGetObjectsTables
-        fallthrough
+        goGetQueryID(gQueryIDsCtx, conn, gQueryIDs, objSchemas,
+            catalog, dbSchema, tableName, &showSchemaQueryID)
+        goGetQueryID(gQueryIDsCtx, conn, gQueryIDs, objDatabases,
+            catalog, dbSchema, tableName, &terseDbQueryID)
+
+        objType := objObjects
+        if len(tableType) == 1 {
+            if strings.EqualFold("VIEW", tableType[0]) {
+                objType = objViews
+            } else if strings.EqualFold("TABLE", tableType[0]) {
+                objType = objTables
+            }
+        }
+
+        goGetQueryID(gQueryIDsCtx, conn, gQueryIDs, objType,
+            catalog, dbSchema, tableName, &tableQueryID)
     default:
+        var suffix string
+        if catalog == nil || isWildcardStr(*catalog) {
+            suffix = " IN ACCOUNT"
+        } else {
+            escapedCatalog := quoteTblName(*catalog)
+            if dbSchema == nil || isWildcardStr(*dbSchema) {
+                suffix = " IN DATABASE " + escapedCatalog
+            } else {
+                escapedSchema := quoteTblName(*dbSchema)
+                if tableName == nil || isWildcardStr(*tableName) {
+                    suffix = " IN SCHEMA " + escapedCatalog + "." + escapedSchema
+                } else {
+                    escapedTable := quoteTblName(*tableName)
+                    suffix = " IN TABLE " + escapedCatalog + "." + escapedSchema + "." + escapedTable
+                }
+            }
+        }
+
         // Detailed constraint info not available in information_schema
         // Need to dispatch SHOW queries and use conn.Raw to extract the queryID for reuse in GetObjects query
         gQueryIDs.Go(func() error {
-            return conn.Raw(func(driverConn any) error {
-                rows, err := driverConn.(driver.QueryerContext).QueryContext(gQueryIDsCtx, "SHOW PRIMARY KEYS", nil)
-                if err != nil {
-                    return err
-                }
-
-                pkQueryID = rows.(gosnowflake.SnowflakeRows).GetQueryID()
-                return rows.Close()
+            return conn.Raw(func(driverConn any) (err error) {
+                pkQueryID, err = getQueryID(gQueryIDsCtx, "SHOW PRIMARY KEYS /* ADBC:getObjectsTables */"+suffix, driverConn)
+                return err
             })
         })
         gQueryIDs.Go(func() error {
-            return conn.Raw(func(driverConn any) error {
-                rows, err := driverConn.(driver.QueryerContext).QueryContext(gQueryIDsCtx, "SHOW IMPORTED KEYS", nil)
-                if err != nil {
-                    return err
-                }
-
-                fkQueryID = rows.(gosnowflake.SnowflakeRows).GetQueryID()
-                return rows.Close()
+            return conn.Raw(func(driverConn any) (err error) {
+                fkQueryID, err = getQueryID(gQueryIDsCtx, "SHOW IMPORTED KEYS /* ADBC:getObjectsTables */"+suffix, driverConn)
+                return err
             })
         })
         gQueryIDs.Go(func() error {
-            return conn.Raw(func(driverConn any) error {
-                rows, err := driverConn.(driver.QueryerContext).QueryContext(gQueryIDsCtx, "SHOW UNIQUE KEYS", nil)
-                if err != nil {
-                    return err
-                }
-
-                uniqueQueryID = rows.(gosnowflake.SnowflakeRows).GetQueryID()
-                return rows.Close()
+            return conn.Raw(func(driverConn any) (err error) {
+                uniqueQueryID, err = getQueryID(gQueryIDsCtx, "SHOW UNIQUE KEYS /* ADBC:getObjectsTables */"+suffix, driverConn)
+                return err
             })
         })
-    }
 
-    f, err := queryTemplates.Open(path.Join("queries", queryFile))
-    if err != nil {
-        return nil, err
+        goGetQueryID(gQueryIDsCtx, conn, gQueryIDs, objDatabases,
+            catalog, dbSchema, tableName, &terseDbQueryID)
+        goGetQueryID(gQueryIDsCtx, conn, gQueryIDs, objSchemas,
+            catalog, dbSchema, tableName, &showSchemaQueryID)
+
+        objType := objObjects
+        if len(tableType) == 1 {
+            if strings.EqualFold("VIEW", tableType[0]) {
+                objType = objViews
+            } else if strings.EqualFold("TABLE", tableType[0]) {
+                objType = objTables
+            }
+        }
+        goGetQueryID(gQueryIDsCtx, conn, gQueryIDs, objType,
+            catalog, dbSchema, tableName, &tableQueryID)
     }
-    defer f.Close()
 
-    var bldr strings.Builder
-    if _, err := io.Copy(&bldr, f); err != nil {
+    queryBytes, err := fs.ReadFile(queryTemplates, path.Join("queries", queryFile))
+    if err != nil {
         return nil, err
     }
@@ -180,80 +313,71 @@ func (c *connectionImpl) GetObjects(ctx context.Context, depth adbc.ObjectDepth,
         sql.Named("FK_QUERY_ID", fkQueryID),
         sql.Named("UNIQUE_QUERY_ID", uniqueQueryID),
         sql.Named("SHOW_DB_QUERY_ID", terseDbQueryID),
+        sql.Named("SHOW_SCHEMA_QUERY_ID", showSchemaQueryID),
+        sql.Named("SHOW_TABLE_QUERY_ID", tableQueryID),
     }
 
-    // the connection that is used is not the same connection context where the database may have been set
-    // if the caller called SetCurrentCatalog() so need to ensure the database context is appropriate
-    if !isNilOrEmpty(catalog) {
-        _, e := conn.ExecContext(context.Background(), fmt.Sprintf("USE DATABASE %s;", quoteTblName(*catalog)), nil)
-        if e != nil {
-            return nil, errToAdbcErr(adbc.StatusIO, e)
+    // currently only the Columns / all case still requires a current database/schema
+    // to be propagated. The rest of the cases all solely use SHOW queries for the metadata
+    // just as done by the snowflake JDBC driver. In those cases we don't need to propagate
+    // the current session database/schema.
+    if depth == adbc.ObjectDepthColumns || depth == adbc.ObjectDepthAll {
+        dbname, err := c.GetCurrentCatalog()
+        if err != nil {
+            return nil, errToAdbcErr(adbc.StatusIO, err)
         }
-    }
 
-    // the connection that is used is not the same connection context where the schema may have been set
-    // if the caller called SetCurrentDbSchema() so need to ensure the schema context is appropriate
-    if !isNilOrEmpty(dbSchema) {
-        _, e2 := conn.ExecContext(context.Background(), fmt.Sprintf("USE SCHEMA %s;", quoteTblName(*dbSchema)), nil)
-        if e2 != nil {
-            return nil, errToAdbcErr(adbc.StatusIO, e2)
+        schemaname, err := c.GetCurrentDbSchema()
+        if err != nil {
+            return nil, errToAdbcErr(adbc.StatusIO, err)
+        }
+
+        // the connection that is used is not the same connection context where the database may have been set
+        // if the caller called SetCurrentCatalog() so need to ensure the database context is appropriate
+        multiCtx, _ := gosnowflake.WithMultiStatement(ctx, 2)
+        _, err = conn.ExecContext(multiCtx, fmt.Sprintf("USE DATABASE %s; USE SCHEMA %s;", quoteTblName(dbname), quoteTblName(schemaname)))
+        if err != nil {
+            return nil, errToAdbcErr(adbc.StatusIO, err)
         }
     }
 
-    query := bldr.String()
+    query := string(queryBytes)
     rows, err := conn.QueryContext(ctx, query, args...)
     if err != nil {
         return nil, errToAdbcErr(adbc.StatusIO, err)
     }
     defer rows.Close()
 
-    catalogCh := make(chan driverbase.GetObjectsInfo, 1)
-    readerCh := make(chan array.RecordReader)
+    catalogCh := make(chan driverbase.GetObjectsInfo, runtime.NumCPU())
     errCh := make(chan error)
 
     go func() {
-        rdr, err := driverbase.BuildGetObjectsRecordReader(c.Alloc, catalogCh)
-        if err != nil {
-            errCh <- err
-        }
-
-        readerCh <- rdr
-        close(readerCh)
-    }()
-
-    for rows.Next() {
-        var getObjectsCatalog driverbase.GetObjectsInfo
-        if err := rows.Scan(&getObjectsCatalog); err != nil {
-            return nil, errToAdbcErr(adbc.StatusInvalidData, err)
-        }
+        defer close(catalogCh)
+        for rows.Next() {
+            var getObjectsCatalog driverbase.GetObjectsInfo
+            if err := rows.Scan(&getObjectsCatalog); err != nil {
+                errCh <- errToAdbcErr(adbc.StatusInvalidData, err)
+                return
+            }
 
-        // A few columns need additional processing outside of Snowflake
-        for i, sch := range getObjectsCatalog.CatalogDbSchemas {
-            for j, tab := range sch.DbSchemaTables {
-                for k, col := range tab.TableColumns {
-                    field := c.toArrowField(col)
-                    xdbcDataType := driverbase.ToXdbcDataType(field.Type)
+            // A few columns need additional processing outside of Snowflake
+            for i, sch := range getObjectsCatalog.CatalogDbSchemas {
+                for j, tab := range sch.DbSchemaTables {
+                    for k, col := range tab.TableColumns {
+                        field := c.toArrowField(col)
+                        xdbcDataType := driverbase.ToXdbcDataType(field.Type)
 
-                    getObjectsCatalog.CatalogDbSchemas[i].DbSchemaTables[j].TableColumns[k].XdbcDataType = driverbase.Nullable(int16(field.Type.ID()))
-                    getObjectsCatalog.CatalogDbSchemas[i].DbSchemaTables[j].TableColumns[k].XdbcSqlDataType = driverbase.Nullable(int16(xdbcDataType))
+                        getObjectsCatalog.CatalogDbSchemas[i].DbSchemaTables[j].TableColumns[k].XdbcDataType = driverbase.Nullable(int16(field.Type.ID()))
+                        getObjectsCatalog.CatalogDbSchemas[i].DbSchemaTables[j].TableColumns[k].XdbcSqlDataType = driverbase.Nullable(int16(xdbcDataType))
+                    }
                 }
             }
-        }
-
-        catalogCh <- getObjectsCatalog
-    }
-    close(catalogCh)
 
-    select {
-    case rdr := <-readerCh:
-        return rdr, nil
-    case err := <-errCh:
-        return nil, err
-    }
-}
+            catalogCh <- getObjectsCatalog
+        }
+    }()
 
-func isNilOrEmpty(str *string) bool {
-    return str == nil || *str == ""
+    return driverbase.BuildGetObjectsRecordReader(c.Alloc, catalogCh, errCh)
 }
 
 // PrepareDriverInfo implements driverbase.DriverInfoPreparer.
@@ -266,7 +390,7 @@ func (c *connectionImpl) PrepareDriverInfo(ctx context.Context, infoCodes []adbc
 
 // ListTableTypes implements driverbase.TableTypeLister.
 func (*connectionImpl) ListTableTypes(ctx context.Context) ([]string, error) {
-    return []string{"BASE TABLE", "TEMPORARY TABLE", "VIEW"}, nil
+    return []string{"TABLE", "VIEW"}, nil
 }
 
 // GetCurrentCatalog implements driverbase.CurrentNamespacer.
diff --git a/go/adbc/driver/snowflake/driver_test.go b/go/adbc/driver/snowflake/driver_test.go
index 895015ffd7..c67389ca14 100644
--- a/go/adbc/driver/snowflake/driver_test.go
+++ b/go/adbc/driver/snowflake/driver_test.go
@@ -1215,15 +1215,15 @@ func (suite *SnowflakeTests) TestSqlIngestMapType() {
     [
         {
             "col_int64": 1,
-            "col_map": "{\n  \"key_value\": [\n    {\n      \"key\": \"key1\",\n      \"value\": 1\n    }\n  ]\n}"
+            "col_map": "{\n  \"key1\": 1\n}"
         },
         {
             "col_int64": 2,
-            "col_map": "{\n  \"key_value\": [\n    {\n      \"key\": \"key2\",\n      \"value\": 2\n    }\n  ]\n}"
+            "col_map": "{\n  \"key2\": 2\n}"
         },
         {
             "col_int64": 3,
-            "col_map": "{\n  \"key_value\": [\n    {\n      \"key\": \"key3\",\n      \"value\": 3\n    }\n  ]\n}"
+            "col_map": "{\n  \"key3\": 3\n}"
         }
     ]
     `)))
@@ -2161,6 +2161,9 @@ func (suite *SnowflakeTests) TestGetSetClientConfigFile() {
 
 func (suite *SnowflakeTests) TestGetObjectsWithNilCatalog() {
     // this test demonstrates calling GetObjects with the catalog depth and a nil catalog
-    _, err := suite.cnxn.GetObjects(suite.ctx, adbc.ObjectDepthCatalogs, nil, nil, nil, nil, nil)
+    rdr, err := suite.cnxn.GetObjects(suite.ctx, adbc.ObjectDepthCatalogs, nil, nil, nil, nil, nil)
     suite.NoError(err)
+    // test suite validates memory allocator so we need to make sure we call
+    // release on the result reader
+    rdr.Release()
 }
diff --git a/go/adbc/driver/snowflake/queries/get_objects_all.sql b/go/adbc/driver/snowflake/queries/get_objects_all.sql
index 45b807f15e..7fc10f2e24 100644
--- a/go/adbc/driver/snowflake/queries/get_objects_all.sql
+++ b/go/adbc/driver/snowflake/queries/get_objects_all.sql
@@ -86,12 +86,12 @@ constraints AS (
         table_catalog,
         table_schema,
         table_name,
-        ARRAY_AGG({
+        ARRAY_AGG(NULLIF({
             'constraint_name': constraint_name,
            'constraint_type': constraint_type,
             'constraint_column_names': constraint_column_names,
             'constraint_column_usage': constraint_column_usage
-        }) table_constraints,
+        }, {})) table_constraints,
     FROM (
         SELECT * FROM pk_constraints
         UNION ALL
@@ -105,12 +105,12 @@ tables AS (
     SELECT
         table_catalog catalog_name,
         table_schema schema_name,
-        ARRAY_AGG({
+        ARRAY_AGG(NULLIF({
            'table_name': table_name,
            'table_type': table_type,
-            'table_columns': table_columns,
-            'table_constraints': table_constraints
-        }) db_schema_tables
+            'table_columns': COALESCE(table_columns, []),
+            'table_constraints': COALESCE(table_constraints, [])
+        }, {})) db_schema_tables
     FROM information_schema.tables
     LEFT JOIN columns
     USING (table_catalog, table_schema, table_name)
@@ -123,7 +123,7 @@ db_schemas AS (
     SELECT
         catalog_name,
         schema_name,
-        db_schema_tables,
+        COALESCE(db_schema_tables, []) db_schema_tables,
     FROM information_schema.schemata
     LEFT JOIN tables
    USING (catalog_name, schema_name)
@@ -132,10 +132,10 @@ db_schemas AS (
 SELECT
     {
         'catalog_name': database_name,
-        'catalog_db_schemas': ARRAY_AGG({
+        'catalog_db_schemas': ARRAY_AGG(NULLIF({
             'db_schema_name': schema_name,
             'db_schema_tables': db_schema_tables
-        })
+        }, {}))
     } get_objects
 FROM
     information_schema.databases
diff --git a/go/adbc/driver/snowflake/queries/get_objects_catalogs.sql b/go/adbc/driver/snowflake/queries/get_objects_catalogs.sql
deleted file mode 100644
index ec2cef5157..0000000000
--- a/go/adbc/driver/snowflake/queries/get_objects_catalogs.sql
+++ /dev/null
@@ -1,25 +0,0 @@
--- Licensed to the Apache Software Foundation (ASF) under one
--- or more contributor license agreements.  See the NOTICE file
--- distributed with this work for additional information
--- regarding copyright ownership.  The ASF licenses this file
--- to you under the Apache License, Version 2.0 (the
--- "License"); you may not use this file except in compliance
--- with the License.  You may obtain a copy of the License at
---
----   http://www.apache.org/licenses/LICENSE-2.0
---
--- Unless required by applicable law or agreed to in writing,
--- software distributed under the License is distributed on an
--- "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
--- KIND, either express or implied.  See the License for the
--- specific language governing permissions and limitations
--- under the License.
-
-SELECT
-    {
-        'catalog_name': database_name,
-        'catalog_db_schemas': null
-    } get_objects
-FROM
-    information_schema.databases
-WHERE database_name ILIKE :CATALOG;
diff --git a/go/adbc/driver/snowflake/queries/get_objects_dbschemas.sql b/go/adbc/driver/snowflake/queries/get_objects_dbschemas.sql
index 360a6d0837..872118f7c7 100644
--- a/go/adbc/driver/snowflake/queries/get_objects_dbschemas.sql
+++ b/go/adbc/driver/snowflake/queries/get_objects_dbschemas.sql
@@ -17,22 +17,26 @@
 
 WITH db_schemas AS (
     SELECT
-        catalog_name,
-        schema_name,
-    FROM information_schema.schemata
-    WHERE catalog_name ILIKE :CATALOG AND schema_name ILIKE :DB_SCHEMA
+        "database_name" as "catalog_name",
+        "name" as "schema_name"
+    FROM table(RESULT_SCAN(:SHOW_SCHEMA_QUERY_ID))
+    WHERE "database_name" ILIKE :CATALOG
+), db_info AS (
+    SELECT "name" AS "database_name"
+    FROM table(RESULT_SCAN(:SHOW_DB_QUERY_ID))
+    WHERE "name" ILIKE :CATALOG
 )
 SELECT
     {
-        'catalog_name': database_name,
-        'catalog_db_schemas': ARRAY_AGG({
-            'db_schema_name': schema_name,
+        'catalog_name': "database_name",
+        'catalog_db_schemas': ARRAY_AGG(NULLIF({
+            'db_schema_name': "schema_name",
             'db_schema_tables': null
-        })
+        }, {}))
     } get_objects
 FROM
-    information_schema.databases
+    db_info
 LEFT JOIN db_schemas
-ON database_name = catalog_name
-WHERE database_name ILIKE :CATALOG
-GROUP BY database_name;
+ON "database_name" = "catalog_name"
+WHERE "database_name" ILIKE :CATALOG
+GROUP BY "database_name";
diff --git a/go/adbc/driver/snowflake/queries/get_objects_tables.sql b/go/adbc/driver/snowflake/queries/get_objects_tables.sql
index b3b16ff515..9d6ce36ed8 100644
--- a/go/adbc/driver/snowflake/queries/get_objects_tables.sql
+++ b/go/adbc/driver/snowflake/queries/get_objects_tables.sql
@@ -15,107 +15,41 @@
 -- specific language governing permissions and limitations
 -- under the License.
 
-WITH pk_constraints AS (
-    SELECT
-        "database_name" table_catalog,
-        "schema_name" table_schema,
-        "table_name" table_name,
-        "constraint_name" constraint_name,
-        'PRIMARY KEY' constraint_type,
-        ARRAY_AGG("column_name") WITHIN GROUP (ORDER BY "key_sequence") constraint_column_names,
-        [] constraint_column_usage,
-    FROM TABLE(RESULT_SCAN(:PK_QUERY_ID))
-    WHERE table_catalog ILIKE :CATALOG AND table_schema ILIKE :DB_SCHEMA AND table_name ILIKE :TABLE
-    GROUP BY table_catalog, table_schema, table_name, "constraint_name"
-),
-unique_constraints AS (
-    SELECT
-        "database_name" table_catalog,
-        "schema_name" table_schema,
-        "table_name" table_name,
-        "constraint_name" constraint_name,
-        'UNIQUE' constraint_type,
-        ARRAY_AGG("column_name") WITHIN GROUP (ORDER BY "key_sequence") constraint_column_names,
-        [] constraint_column_usage,
-    FROM TABLE(RESULT_SCAN(:UNIQUE_QUERY_ID))
-    WHERE table_catalog ILIKE :CATALOG AND table_schema ILIKE :DB_SCHEMA AND table_name ILIKE :TABLE
-    GROUP BY table_catalog, table_schema, table_name, "constraint_name"
-),
-fk_constraints AS (
-    SELECT
-        "fk_database_name" table_catalog,
-        "fk_schema_name" table_schema,
-        "fk_table_name" table_name,
-        "fk_name" constraint_name,
-        'FOREIGN KEY' constraint_type,
-        ARRAY_AGG("fk_column_name") WITHIN GROUP (ORDER BY "key_sequence") constraint_column_names,
-        ARRAY_AGG({
-            'fk_catalog': "pk_database_name",
-            'fk_db_schema': "pk_schema_name",
-            'fk_table': "pk_table_name",
-            'fk_column_name': "pk_column_name"
-        }) WITHIN GROUP (ORDER BY "key_sequence") constraint_column_usage,
-    FROM TABLE(RESULT_SCAN(:FK_QUERY_ID))
-    WHERE table_catalog ILIKE :CATALOG AND table_schema ILIKE :DB_SCHEMA AND table_name ILIKE :TABLE
-    GROUP BY table_catalog, table_schema, table_name, constraint_name
-),
-constraints AS (
-    SELECT
-        table_catalog,
-        table_schema,
-        table_name,
-        ARRAY_AGG({
-            'constraint_name': constraint_name,
-            'constraint_type': constraint_type,
-            'constraint_column_names': constraint_column_names,
-            'constraint_column_usage': constraint_column_usage
-        }) table_constraints,
-    FROM (
-        SELECT * FROM pk_constraints
-        UNION ALL
-        SELECT * FROM unique_constraints
-        UNION ALL
-        SELECT * FROM fk_constraints
-    )
-    GROUP BY table_catalog, table_schema, table_name
-),
-tables AS (
+WITH tables AS (
     SELECT
-        table_catalog catalog_name,
-        table_schema schema_name,
+        "database_name" "catalog_name",
+        "schema_name" "schema_name",
         ARRAY_AGG({
-            'table_name': table_name,
-            'table_type': table_type,
-            'table_constraints': table_constraints,
+            'table_name': "name",
+            'table_type': "kind",
+            'table_constraints': null,
             'table_columns': null
         }) db_schema_tables
-    FROM information_schema.tables
-    LEFT JOIN constraints
-    USING (table_catalog, table_schema, table_name)
-    WHERE table_catalog ILIKE :CATALOG AND table_schema ILIKE :DB_SCHEMA AND table_name ILIKE :TABLE
-    GROUP BY table_catalog, table_schema
+    FROM TABLE(RESULT_SCAN(:SHOW_TABLE_QUERY_ID))
+    WHERE "database_name" ILIKE :CATALOG AND "schema_name" ILIKE :DB_SCHEMA AND "name" ILIKE :TABLE
+    GROUP BY "database_name", "schema_name"
 ),
 db_schemas AS (
     SELECT
-        catalog_name,
-        schema_name,
-        db_schema_tables,
-    FROM information_schema.schemata
+        "database_name" "catalog_name",
+        "name" "schema_name",
+        COALESCE(db_schema_tables, []) db_schema_tables,
+    FROM TABLE(RESULT_SCAN(:SHOW_SCHEMA_QUERY_ID))
     LEFT JOIN tables
-    USING (catalog_name, schema_name)
-    WHERE catalog_name ILIKE :CATALOG AND schema_name ILIKE :DB_SCHEMA
+    ON "database_name" = "catalog_name" AND "name" = tables."schema_name"
+    WHERE "database_name" ILIKE :CATALOG AND "name" ILIKE :DB_SCHEMA
 )
 SELECT
     {
-        'catalog_name': database_name,
-        'catalog_db_schemas': ARRAY_AGG({
-            'db_schema_name': schema_name,
+        'catalog_name': "name",
+        'catalog_db_schemas': ARRAY_AGG(NULLIF({
+            'db_schema_name': db_schemas."schema_name",
             'db_schema_tables': db_schema_tables
-        })
+        }, {}))
     } get_objects
 FROM
-    information_schema.databases
+    TABLE(RESULT_SCAN(:SHOW_DB_QUERY_ID))
 LEFT JOIN db_schemas
-ON database_name = catalog_name
-WHERE database_name ILIKE :CATALOG
-GROUP BY database_name;
+ON "name" = "catalog_name"
+WHERE "name" ILIKE :CATALOG
+GROUP BY "name";
diff --git a/go/adbc/driver/snowflake/statement.go b/go/adbc/driver/snowflake/statement.go
index 1fd1f658fe..574e390453 100644
--- a/go/adbc/driver/snowflake/statement.go
+++ b/go/adbc/driver/snowflake/statement.go
@@ -321,9 +321,9 @@ func toSnowflakeType(dt arrow.DataType) string {
     case arrow.DECIMAL, arrow.DECIMAL256:
         dec := dt.(arrow.DecimalType)
         return fmt.Sprintf("NUMERIC(%d,%d)", dec.GetPrecision(), dec.GetScale())
-    case arrow.STRING, arrow.LARGE_STRING:
+    case arrow.STRING, arrow.LARGE_STRING, arrow.STRING_VIEW:
         return "text"
-    case arrow.BINARY, arrow.LARGE_BINARY:
+    case arrow.BINARY, arrow.LARGE_BINARY, arrow.BINARY_VIEW:
         return "binary"
     case arrow.FIXED_SIZE_BINARY:
         fsb := dt.(*arrow.FixedSizeBinaryType)