Skip to content

Commit

Permalink
feat(go/adbc/driver/snowflake): improve GetObjects performance and semantics (#2254)

Browse files Browse the repository at this point in the history

Fixes #2171 

Improves the channel handling and query building for metadata conversion
to Arrow for better performance.

For all cases except when retrieving Column metadata we'll now utilize
`SHOW` queries and build the patterns into those queries. This allows
those `GetObjects` calls with appropriate depths to be called without
having to specify a current database or schema.
  • Loading branch information
zeroshade authored Oct 17, 2024
1 parent 0366632 commit 5471d95
Show file tree
Hide file tree
Showing 13 changed files with 365 additions and 280 deletions.
1 change: 1 addition & 0 deletions c/driver/flightsql/sqlite_flightsql_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,7 @@ class SqliteFlightSqlQuirks : public adbc_validation::DriverQuirks {
bool supports_get_objects() const override { return true; }
bool supports_partitioned_data() const override { return true; }
bool supports_dynamic_parameter_binding() const override { return true; }
// NOTE(review): missing `override` — DriverQuirks::catalog() is virtual
// (see adbc_validation.h); add it for consistency with the overrides above
// and so the compiler catches signature drift.
std::string catalog() const { return "main"; }
};

class SqliteFlightSqlTest : public ::testing::Test, public adbc_validation::DatabaseTest {
Expand Down
12 changes: 11 additions & 1 deletion c/driver/snowflake/snowflake_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ class SnowflakeQuirks : public adbc_validation::DriverQuirks {
adbc_validation::Handle<struct AdbcStatement> statement;
CHECK_OK(AdbcStatementNew(connection, &statement.value, error));

std::string create = "CREATE TABLE \"";
std::string create = "CREATE OR REPLACE TABLE \"";
create += name;
create += "\" (int64s INT, strings TEXT)";
CHECK_OK(AdbcStatementSetSqlQuery(&statement.value, create.c_str(), error));
Expand Down Expand Up @@ -131,7 +131,13 @@ class SnowflakeQuirks : public adbc_validation::DriverQuirks {
return NANOARROW_TYPE_DOUBLE;
case NANOARROW_TYPE_STRING:
case NANOARROW_TYPE_LARGE_STRING:
case NANOARROW_TYPE_LIST:
case NANOARROW_TYPE_LARGE_LIST:
return NANOARROW_TYPE_STRING;
case NANOARROW_TYPE_BINARY:
case NANOARROW_TYPE_LARGE_BINARY:
case NANOARROW_TYPE_FIXED_SIZE_BINARY:
return NANOARROW_TYPE_BINARY;
default:
return ingest_type;
}
Expand All @@ -149,7 +155,11 @@ class SnowflakeQuirks : public adbc_validation::DriverQuirks {
bool supports_dynamic_parameter_binding() const override { return true; }
bool supports_error_on_incompatible_schema() const override { return false; }
bool ddl_implicit_commit_txn() const override { return true; }
bool supports_ingest_view_types() const override { return false; }
bool supports_ingest_float16() const override { return false; }

std::string db_schema() const override { return schema_; }
std::string catalog() const override { return "ADBC_TESTING"; }

const char* uri_;
bool skip_{false};
Expand Down
6 changes: 6 additions & 0 deletions c/validation/adbc_validation.h
Original file line number Diff line number Diff line change
Expand Up @@ -238,6 +238,12 @@ class DriverQuirks {
/// column matching.
virtual bool supports_error_on_incompatible_schema() const { return true; }

/// \brief Whether ingestion supports StringView/BinaryView types
virtual bool supports_ingest_view_types() const { return true; }

/// \brief Whether ingestion supports Float16
virtual bool supports_ingest_float16() const { return true; }

/// \brief Default catalog to use for tests
virtual std::string catalog() const { return ""; }

Expand Down
27 changes: 15 additions & 12 deletions c/validation/adbc_validation_connection.cc
Original file line number Diff line number Diff line change
Expand Up @@ -744,27 +744,30 @@ void ConnectionTest::TestMetadataGetObjectsColumns() {

// One GetObjects scenario: an optional column-name filter pattern
// (passed as the column-name argument of AdbcConnectionGetObjects,
// e.g. "in%") and the columns expected back for the sample table.
struct TestCase {
std::optional<std::string> filter;
// Expected column names (lowercased by the test before comparison).
std::vector<std::string> column_names;
// Expected ordinal positions; 1-based, matching column_names by index.
std::vector<int32_t> ordinal_positions;
// the pair is column name & ordinal position of the column
std::vector<std::pair<std::string, int32_t>> columns;
};

std::vector<TestCase> test_cases;
test_cases.push_back({std::nullopt, {"int64s", "strings"}, {1, 2}});
test_cases.push_back({"in%", {"int64s"}, {1}});
test_cases.push_back({std::nullopt, {{"int64s", 1}, {"strings", 2}}});
test_cases.push_back({"in%", {{"int64s", 1}}});

const std::string catalog = quirks()->catalog();

for (const auto& test_case : test_cases) {
std::string scope = "Filter: ";
scope += test_case.filter ? *test_case.filter : "(no filter)";
SCOPED_TRACE(scope);

StreamReader reader;
std::vector<std::pair<std::string, int32_t>> columns;
std::vector<std::string> column_names;
std::vector<int32_t> ordinal_positions;

ASSERT_THAT(
AdbcConnectionGetObjects(
&connection, ADBC_OBJECT_DEPTH_COLUMNS, nullptr, nullptr, nullptr, nullptr,
test_case.filter.has_value() ? test_case.filter->c_str() : nullptr,
&connection, ADBC_OBJECT_DEPTH_COLUMNS, catalog.c_str(), nullptr, nullptr,
nullptr, test_case.filter.has_value() ? test_case.filter->c_str() : nullptr,
&reader.stream.value, &error),
IsOkStatus(&error));
ASSERT_NO_FATAL_FAILURE(reader.GetSchema());
Expand Down Expand Up @@ -834,10 +837,9 @@ void ConnectionTest::TestMetadataGetObjectsColumns() {
std::string temp(name.data, name.size_bytes);
std::transform(temp.begin(), temp.end(), temp.begin(),
[](unsigned char c) { return std::tolower(c); });
column_names.push_back(std::move(temp));
ordinal_positions.push_back(
static_cast<int32_t>(ArrowArrayViewGetIntUnsafe(
table_columns->children[1], columns_index)));
columns.emplace_back(std::move(temp),
static_cast<int32_t>(ArrowArrayViewGetIntUnsafe(
table_columns->children[1], columns_index)));
}
}
}
Expand All @@ -847,8 +849,9 @@ void ConnectionTest::TestMetadataGetObjectsColumns() {
} while (reader.array->release);

ASSERT_TRUE(found_expected_table) << "Did (not) find table in metadata";
ASSERT_EQ(test_case.column_names, column_names);
ASSERT_EQ(test_case.ordinal_positions, ordinal_positions);
// metadata columns do not guarantee the order they are returned in, just
// validate all the elements are there.
ASSERT_THAT(columns, testing::UnorderedElementsAreArray(test_case.columns));
}
}

Expand Down
16 changes: 14 additions & 2 deletions c/validation/adbc_validation_statement.cc
Original file line number Diff line number Diff line change
Expand Up @@ -246,6 +246,10 @@ void StatementTest::TestSqlIngestInt64() {
}

void StatementTest::TestSqlIngestFloat16() {
  // Half-precision float ingestion is optional; honor the driver's quirks
  // and skip the test entirely for drivers that opted out.
  const bool supported = quirks()->supports_ingest_float16();
  if (!supported) {
    GTEST_SKIP();
  }

  ASSERT_NO_FATAL_FAILURE(TestSqlIngestNumericType<float>(NANOARROW_TYPE_HALF_FLOAT));
}

Expand All @@ -268,6 +272,10 @@ void StatementTest::TestSqlIngestLargeString() {
}

void StatementTest::TestSqlIngestStringView() {
if (!quirks()->supports_ingest_view_types()) {
GTEST_SKIP();
}

ASSERT_NO_FATAL_FAILURE(TestSqlIngestType<std::string>(
NANOARROW_TYPE_STRING_VIEW, {std::nullopt, "", "", "longer than 12 bytes", ""},
false));
Expand Down Expand Up @@ -302,6 +310,10 @@ void StatementTest::TestSqlIngestFixedSizeBinary() {
}

void StatementTest::TestSqlIngestBinaryView() {
if (!quirks()->supports_ingest_view_types()) {
GTEST_SKIP();
}

ASSERT_NO_FATAL_FAILURE(TestSqlIngestType<std::vector<std::byte>>(
NANOARROW_TYPE_LARGE_BINARY,
{std::nullopt, std::vector<std::byte>{},
Expand Down Expand Up @@ -2218,15 +2230,15 @@ void StatementTest::TestSqlBind() {

ASSERT_THAT(
AdbcStatementSetSqlQuery(
&statement, "SELECT * FROM bindtest ORDER BY \"col1\" ASC NULLS FIRST", &error),
&statement, "SELECT * FROM bindtest ORDER BY col1 ASC NULLS FIRST", &error),
IsOkStatus(&error));
{
StreamReader reader;
ASSERT_THAT(AdbcStatementExecuteQuery(&statement, &reader.stream.value,
&reader.rows_affected, &error),
IsOkStatus(&error));
ASSERT_THAT(reader.rows_affected,
::testing::AnyOf(::testing::Eq(0), ::testing::Eq(-1)));
::testing::AnyOf(::testing::Eq(3), ::testing::Eq(-1)));

ASSERT_NO_FATAL_FAILURE(reader.GetSchema());
ASSERT_NO_FATAL_FAILURE(reader.Next());
Expand Down
51 changes: 32 additions & 19 deletions go/adbc/driver/internal/driverbase/connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -349,14 +349,17 @@ func (cnxn *connection) GetObjects(ctx context.Context, depth adbc.ObjectDepth,

bufferSize := len(catalogs)
addCatalogCh := make(chan GetObjectsInfo, bufferSize)
for _, cat := range catalogs {
addCatalogCh <- GetObjectsInfo{CatalogName: Nullable(cat)}
}

close(addCatalogCh)
errCh := make(chan error, 1)
go func() {
defer close(addCatalogCh)
for _, cat := range catalogs {
addCatalogCh <- GetObjectsInfo{CatalogName: Nullable(cat)}
}
}()

if depth == adbc.ObjectDepthCatalogs {
return BuildGetObjectsRecordReader(cnxn.Base().Alloc, addCatalogCh)
close(errCh)
return BuildGetObjectsRecordReader(cnxn.Base().Alloc, addCatalogCh, errCh)
}

g, ctxG := errgroup.WithContext(ctx)
Expand Down Expand Up @@ -386,7 +389,7 @@ func (cnxn *connection) GetObjects(ctx context.Context, depth adbc.ObjectDepth,
g.Go(func() error { defer close(addDbSchemasCh); return gSchemas.Wait() })

if depth == adbc.ObjectDepthDBSchemas {
rdr, err := BuildGetObjectsRecordReader(cnxn.Base().Alloc, addDbSchemasCh)
rdr, err := BuildGetObjectsRecordReader(cnxn.Base().Alloc, addDbSchemasCh, errCh)
return rdr, errors.Join(err, g.Wait())
}

Expand Down Expand Up @@ -432,7 +435,7 @@ func (cnxn *connection) GetObjects(ctx context.Context, depth adbc.ObjectDepth,

g.Go(func() error { defer close(addTablesCh); return gTables.Wait() })

rdr, err := BuildGetObjectsRecordReader(cnxn.Base().Alloc, addTablesCh)
rdr, err := BuildGetObjectsRecordReader(cnxn.Base().Alloc, addTablesCh, errCh)
return rdr, errors.Join(err, g.Wait())
}

Expand Down Expand Up @@ -621,20 +624,20 @@ type ColumnInfo struct {
// TableInfo is a structured representation of adbc.TableSchema.
//
// NOTE(review): the source text declared TableColumns and TableConstraints
// twice (with and without `omitempty`), which is illegal in Go; the
// non-omitempty variants are kept so empty lists serialize explicitly
// rather than being dropped from the JSON — confirm against consumers.
type TableInfo struct {
	TableName        string           `json:"table_name"`
	TableType        string           `json:"table_type"`
	TableColumns     []ColumnInfo     `json:"table_columns"`
	TableConstraints []ConstraintInfo `json:"table_constraints"`
}

// DBSchemaInfo is a structured representation of adbc.DBSchemaSchema.
//
// NOTE(review): DbSchemaTables was declared twice (with and without
// `omitempty`) in the source text — illegal in Go; the non-omitempty
// variant is kept so an empty table list serializes explicitly.
// DbSchemaName keeps `omitempty`, matching its nullable pointer type.
type DBSchemaInfo struct {
	DbSchemaName   *string     `json:"db_schema_name,omitempty"`
	DbSchemaTables []TableInfo `json:"db_schema_tables"`
}

// GetObjectsInfo is a structured representation of adbc.GetObjectsSchema.
//
// NOTE(review): CatalogDbSchemas was declared twice (with and without
// `omitempty`) in the source text — illegal in Go; the non-omitempty
// variant is kept so an empty schema list serializes explicitly.
type GetObjectsInfo struct {
	CatalogName      *string        `json:"catalog_name,omitempty"`
	CatalogDbSchemas []DBSchemaInfo `json:"catalog_db_schemas"`
}

// Scan implements sql.Scanner.
Expand All @@ -659,23 +662,33 @@ func (g *GetObjectsInfo) Scan(src any) error {
// BuildGetObjectsRecordReader constructs a RecordReader for the GetObjects ADBC method.
// It accepts a channel of GetObjectsInfo to allow concurrent retrieval of metadata and
// serialization to Arrow record.
func BuildGetObjectsRecordReader(mem memory.Allocator, in chan GetObjectsInfo) (array.RecordReader, error) {
func BuildGetObjectsRecordReader(mem memory.Allocator, in <-chan GetObjectsInfo, errCh <-chan error) (array.RecordReader, error) {
bldr := array.NewRecordBuilder(mem, adbc.GetObjectsSchema)
defer bldr.Release()

for catalog := range in {
b, err := json.Marshal(catalog)
if err != nil {
return nil, err
}
CATALOGLOOP:
for {
select {
case catalog, ok := <-in:
if !ok {
break CATALOGLOOP
}
b, err := json.Marshal(catalog)
if err != nil {
return nil, err
}

if err := json.Unmarshal(b, bldr); err != nil {
if err := json.Unmarshal(b, bldr); err != nil {
return nil, err
}
case err := <-errCh:
return nil, err
}
}

rec := bldr.NewRecord()
defer rec.Release()

return array.NewRecordReader(adbc.GetObjectsSchema, []arrow.Record{rec})
}

Expand Down
Loading

0 comments on commit 5471d95

Please sign in to comment.