Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(go/adbc/driver/snowflake): fix setting database and schema context after initial connection #2169

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions csharp/src/Apache.Arrow.Adbc/AdbcConnection11.cs
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ public abstract class AdbcConnection11 : IDisposable
, IAsyncDisposable
#endif
{

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Consider removing added whitespace

~AdbcConnection11() => Dispose(false);

/// <summary>
Expand Down
36 changes: 36 additions & 0 deletions csharp/test/Drivers/Interop/Snowflake/DriverTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,42 @@ private bool CurrentRoleIsExpectedRole(AdbcConnection cn, string expectedRole)
return expectedRole == stringArray.GetString(0);
}

[SkippableFact, Order(1)]
public void CanSetDatabase()
{
Skip.If(string.IsNullOrEmpty(_testConfiguration.Metadata.Catalog));

// connect without the parameter and ensure we get the DATABASE successfully
Dictionary<string, string> parameters = new Dictionary<string, string>();
Dictionary<string, string> options = new Dictionary<string, string>();

using AdbcDriver localSnowflakeDriver = SnowflakeTestingUtils.GetSnowflakeAdbcDriver(_testConfiguration, out parameters);
parameters.Remove(SnowflakeParameters.DATABASE);
using AdbcDatabase localDatabase = localSnowflakeDriver.Open(parameters);
using AdbcConnection localConnection = localDatabase.Connect(options);

localConnection.SetOption(AdbcOptions.Connection.CurrentCatalog, _testConfiguration.Metadata.Catalog);

Assert.True(CurrentDatabaseIsExpectedCatalog(localConnection, _testConfiguration.Metadata.Catalog));

localConnection.GetObjects(AdbcConnection.GetObjectsDepth.All, _testConfiguration.Metadata.Catalog, _testConfiguration.Metadata.Schema, _testConfiguration.Metadata.Table, _tableTypes, null);
}

private bool CurrentDatabaseIsExpectedCatalog(AdbcConnection cn, string expectedCatalog)
{
using AdbcStatement statement = cn.CreateStatement();
statement.SqlQuery = "SELECT CURRENT_DATABASE() as CURRENT_DATABASE;"; // GetOption doesn't exist in 1.0, only 1.1

QueryResult queryResult = statement.ExecuteQuery();
using RecordBatch? recordBatch = queryResult.Stream?.ReadNextRecordBatchAsync().Result;
Assert.True(recordBatch != null);

StringArray stringArray = (StringArray)recordBatch.Column("CURRENT_DATABASE");
Assert.True(stringArray.Length > 0);

return expectedCatalog == stringArray.GetString(0);
}

/// <summary>
/// Validates if the driver can connect to a live server and
/// parse the results.
Expand Down
26 changes: 24 additions & 2 deletions go/adbc/driver/snowflake/connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,24 @@ func (c *connectionImpl) GetObjects(ctx context.Context, depth adbc.ObjectDepth,
sql.Named("UNIQUE_QUERY_ID", uniqueQueryID),
}

// the connection that is used is not the same connection context where the database may have been set
// if the caller called SetCurrentCatalog() so need to ensure the database context is appropriate
if !isNilOrEmpty(catalog) {
_, e := conn.ExecContext(context.Background(), fmt.Sprintf("USE DATABASE %s;", quoteTblName(*catalog)), nil)
if e != nil {
return nil, errToAdbcErr(adbc.StatusIO, e)
}
}

// the connection that is used is not the same connection context where the schema may have been set
// if the caller called SetCurrentDbSchema() so need to ensure the schema context is appropriate
if !isNilOrEmpty(dbSchema) {
_, e2 := conn.ExecContext(context.Background(), fmt.Sprintf("USE SCHEMA %s;", quoteTblName(*dbSchema)), nil)
if e2 != nil {
return nil, errToAdbcErr(adbc.StatusIO, e2)
}
}

query := bldr.String()
rows, err := conn.QueryContext(ctx, query, args...)
if err != nil {
Expand Down Expand Up @@ -214,6 +232,10 @@ func (c *connectionImpl) GetObjects(ctx context.Context, depth adbc.ObjectDepth,
}
}

func isNilOrEmpty(str *string) bool {
return str == nil || *str == ""
}

// PrepareDriverInfo implements driverbase.DriverInfoPreparer.
func (c *connectionImpl) PrepareDriverInfo(ctx context.Context, infoCodes []adbc.InfoCode) error {
if err := c.ConnectionImplBase.DriverInfo.RegisterInfoCode(adbc.InfoVendorSql, true); err != nil {
Expand All @@ -239,13 +261,13 @@ func (c *connectionImpl) GetCurrentDbSchema() (string, error) {

// SetCurrentCatalog implements driverbase.CurrentNamespacer.
func (c *connectionImpl) SetCurrentCatalog(value string) error {
_, err := c.cn.ExecContext(context.Background(), "USE DATABASE ?", []driver.NamedValue{{Value: value}})
_, err := c.cn.ExecContext(context.Background(), fmt.Sprintf("USE DATABASE %s;", quoteTblName(value)), nil)
return err
}

// SetCurrentDbSchema implements driverbase.CurrentNamespacer.
func (c *connectionImpl) SetCurrentDbSchema(value string) error {
_, err := c.cn.ExecContext(context.Background(), "USE SCHEMA ?", []driver.NamedValue{{Value: value}})
_, err := c.cn.ExecContext(context.Background(), fmt.Sprintf("USE SCHEMA %s;", quoteTblName(value)), nil)
return err
}

Expand Down
33 changes: 33 additions & 0 deletions go/adbc/driver/snowflake/driver_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2111,3 +2111,36 @@ func TestIngestCancelContext(t *testing.T) {
require.Equal(t, "", buf.String())
})
}

func (suite *SnowflakeTests) TestChangeDatabaseAndGetObjects() {
// this test demonstrates:
// 1. changing the database context
// 2. being able to call GetObjects after changing the database context
// (this uses a different connection context but still executes successfully)

uri, ok := os.LookupEnv("SNOWFLAKE_URI")
if !ok {
suite.T().Skip("Cannot find the `SNOWFLAKE_URI` value")
}

newCatalog, ok := os.LookupEnv("SNOWFLAKE_NEW_CATALOG")
if !ok {
suite.T().Skip("Cannot find the `SNOWFLAKE_NEW_CATALOG` value")
}

getObjectsTable, ok := os.LookupEnv("SNOWFLAKE_TABLE_GETOBJECTS")
if !ok {
suite.T().Skip("Cannot find the `SNOWFLAKE_TABLE_GETOBJECTS` value")
}

cfg, err := gosnowflake.ParseDSN(uri)
suite.NoError(err)

cnxnopt, ok := suite.cnxn.(adbc.PostInitOptions)
suite.True(ok)
err = cnxnopt.SetOption(adbc.OptionKeyCurrentCatalog, newCatalog)
suite.NoError(err)

_, err2 := suite.cnxn.GetObjects(suite.ctx, adbc.ObjectDepthAll, &newCatalog, &cfg.Schema, &getObjectsTable, nil, nil)
suite.NoError(err2)
}
Loading