diff --git a/.github/workflows/integration.yml b/.github/workflows/integration.yml index c293e0f0d6..c8cf5cdd45 100644 --- a/.github/workflows/integration.yml +++ b/.github/workflows/integration.yml @@ -137,7 +137,7 @@ jobs: - name: Start SQLite server and Dremio shell: bash -l {0} run: | - docker-compose up -d golang-sqlite-flightsql dremio dremio-init + docker-compose up -d flightsql-test flightsql-sqlite-test dremio dremio-init - name: Build FlightSQL Driver shell: bash -l {0} @@ -155,6 +155,7 @@ jobs: ADBC_DREMIO_FLIGHTSQL_USER: "dremio" ADBC_DREMIO_FLIGHTSQL_PASS: "dremio123" ADBC_SQLITE_FLIGHTSQL_URI: "grpc+tcp://localhost:8080" + ADBC_TEST_FLIGHTSQL_URI: "grpc+tcp://localhost:41414" run: | ./ci/scripts/cpp_build.sh "$(pwd)" "$(pwd)/build" ./ci/scripts/cpp_test.sh "$(pwd)/build" @@ -174,6 +175,7 @@ jobs: ADBC_DREMIO_FLIGHTSQL_URI: "grpc+tcp://localhost:32010" ADBC_DREMIO_FLIGHTSQL_USER: "dremio" ADBC_DREMIO_FLIGHTSQL_PASS: "dremio123" + ADBC_TEST_FLIGHTSQL_URI: "grpc+tcp://localhost:41414" run: | ./ci/scripts/python_test.sh "$(pwd)" "$(pwd)/build" - name: Stop SQLite server and Dremio diff --git a/.github/workflows/java.yml b/.github/workflows/java.yml index 10b98e1c71..d955ab2d88 100644 --- a/.github/workflows/java.yml +++ b/.github/workflows/java.yml @@ -69,7 +69,7 @@ jobs: - name: Start SQLite server shell: bash -l {0} run: | - docker-compose up -d golang-sqlite-flightsql + docker-compose up -d flightsql-sqlite-test - name: Build/Test env: ADBC_SQLITE_FLIGHTSQL_URI: "grpc+tcp://localhost:8080" diff --git a/.github/workflows/native-unix.yml b/.github/workflows/native-unix.yml index b12aa46d51..e3a1bac3d1 100644 --- a/.github/workflows/native-unix.yml +++ b/.github/workflows/native-unix.yml @@ -658,7 +658,7 @@ jobs: if: matrix.config.pkg == 'adbcpostgresql' && runner.os == 'Linux' run: | cd r/adbcpostgresql - docker compose up --detach postgres_test + docker compose up --detach postgres-test ADBC_POSTGRESQL_TEST_URI="postgresql://localhost:5432/postgres?user=postgres&password=password" echo "ADBC_POSTGRESQL_TEST_URI=${ADBC_POSTGRESQL_TEST_URI}" >> $GITHUB_ENV @@ -666,7 +666,7 @@ jobs: if: matrix.config.pkg == 'adbcflightsql' && runner.os == 'Linux' run: | cd r/adbcpostgresql - docker compose up --detach golang-sqlite-flightsql + docker compose up --detach flightsql-sqlite-test ADBC_FLIGHTSQL_TEST_URI="grpc://localhost:8080" echo "ADBC_FLIGHTSQL_TEST_URI=${ADBC_FLIGHTSQL_TEST_URI}" >> $GITHUB_ENV diff --git a/c/driver/postgresql/README.md b/c/driver/postgresql/README.md index cc5a3dfe03..8ccffb6845 100644 --- a/c/driver/postgresql/README.md +++ b/c/driver/postgresql/README.md @@ -54,9 +54,9 @@ Alternatively use the `docker compose` provided by ADBC to manage the test database container. ```shell -$ docker compose up postgres_test +$ docker compose up postgres-test # When finished: -# docker compose down postgres_test +# docker compose down postgres-test ``` Then, to run the tests, set the environment variable specifying the diff --git a/c/driver_manager/adbc_driver_manager.cc b/c/driver_manager/adbc_driver_manager.cc index 516bf9bbf7..e4287534df 100644 --- a/c/driver_manager/adbc_driver_manager.cc +++ b/c/driver_manager/adbc_driver_manager.cc @@ -642,8 +642,12 @@ const struct AdbcError* AdbcErrorFromArrayStream(struct ArrowArrayStream* stream return nullptr; } auto* private_data = reinterpret_cast(stream->private_data); - return private_data->private_driver->ErrorFromArrayStream(&private_data->stream, - status); + auto* error = + private_data->private_driver->ErrorFromArrayStream(&private_data->stream, status); + if (error) { + const_cast(error)->private_driver = private_data->private_driver; + } + return error; } #define INIT_ERROR(ERROR, SOURCE) \ diff --git a/ci/docker/flightsql-test.dockerfile b/ci/docker/flightsql-test.dockerfile new file mode 100644 index 0000000000..7c67b06533 --- /dev/null +++ b/ci/docker/flightsql-test.dockerfile @@ -0,0 +1,20 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +ARG GO +FROM golang:${GO} +EXPOSE 41414 diff --git a/docker-compose.yml b/docker-compose.yml index dd0ef2f53f..2c77d72198 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -107,15 +107,6 @@ services: ###################### Test database environments ############################ - postgres_test: - container_name: adbc_postgres_test - image: postgres:latest - environment: - POSTGRES_USER: postgres - POSTGRES_PASSWORD: password - ports: - - "5432:5432" - dremio: container_name: adbc-dremio image: dremio/dremio-oss:latest @@ -150,7 +141,23 @@ services: volumes: - "./ci/scripts/integration/dremio:/init" - golang-sqlite-flightsql: + flightsql-test: + image: ${REPO}:adbc-flightsql-test + build: + context: . + cache_from: + - ${REPO}:adbc-flightsql-test + dockerfile: ci/docker/flightsql-test.dockerfile + args: + GO: ${GO} + ports: + - "41414:41414" + volumes: + - .:/adbc:delegated + command: >- + /bin/bash -c "cd /adbc/go/adbc && go run ./driver/flightsql/cmd/testserver -host 0.0.0.0 -port 41414" + + flightsql-sqlite-test: image: ${REPO}:golang-${GO}-sqlite-flightsql build: context: . @@ -162,3 +169,12 @@ services: ARROW_MAJOR_VERSION: ${ARROW_MAJOR_VERSION} ports: - 8080:8080 + + postgres-test: + container_name: adbc_postgres_test + image: postgres:latest + environment: + POSTGRES_USER: postgres + POSTGRES_PASSWORD: password + ports: + - "5432:5432" diff --git a/go/adbc/adbc.go b/go/adbc/adbc.go index ad6194f240..b0737fe02a 100644 --- a/go/adbc/adbc.go +++ b/go/adbc/adbc.go @@ -131,7 +131,12 @@ type Error struct { } func (e Error) Error() string { - return fmt.Sprintf("%s: SqlState: %s, msg: %s", e.Code, string(e.SqlState[:]), e.Msg) + // Don't include a NUL in the string since C Data Interface uses char* (and + // don't include the extra cruft if not needed in the first place) + if e.SqlState[0] != 0 { + return fmt.Sprintf("%s: %s (%s)", e.Code, e.Msg, string(e.SqlState[:])) + } + return fmt.Sprintf("%s: %s", e.Code, e.Msg) } // Status represents an error code for operations that may fail diff --git a/go/adbc/driver/flightsql/cmd/testserver/main.go b/go/adbc/driver/flightsql/cmd/testserver/main.go new file mode 100644 index 0000000000..6e0ca4ffa8 --- /dev/null +++ b/go/adbc/driver/flightsql/cmd/testserver/main.go @@ -0,0 +1,161 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +// A server intended specifically for testing the Flight SQL driver. Unlike +// the upstream SQLite example, which tries to be functional, this server +// tries to be useful. + +package main + +import ( + "bytes" + "context" + "flag" + "fmt" + "log" + "net" + "os" + "strconv" + "strings" + + "github.com/apache/arrow/go/v13/arrow" + "github.com/apache/arrow/go/v13/arrow/array" + "github.com/apache/arrow/go/v13/arrow/flight" + "github.com/apache/arrow/go/v13/arrow/flight/flightsql" + "github.com/apache/arrow/go/v13/arrow/memory" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" +) + +type ExampleServer struct { + flightsql.BaseServer +} + +func (srv *ExampleServer) ClosePreparedStatement(ctx context.Context, request flightsql.ActionClosePreparedStatementRequest) error { + return nil +} + +func (srv *ExampleServer) CreatePreparedStatement(ctx context.Context, req flightsql.ActionCreatePreparedStatementRequest) (result flightsql.ActionCreatePreparedStatementResult, err error) { + result.Handle = []byte(req.GetQuery()) + return +} + +func (srv *ExampleServer) GetFlightInfoPreparedStatement(_ context.Context, cmd flightsql.PreparedStatementQuery, desc *flight.FlightDescriptor) (*flight.FlightInfo, error) { + if bytes.Equal(cmd.GetPreparedStatementHandle(), []byte("error_do_get")) || bytes.Equal(cmd.GetPreparedStatementHandle(), []byte("error_do_get_stream")) { + schema := arrow.NewSchema([]arrow.Field{{Name: "ints", Type: arrow.PrimitiveTypes.Int32, Nullable: true}}, nil) + return &flight.FlightInfo{ + Endpoint: []*flight.FlightEndpoint{{Ticket: &flight.Ticket{Ticket: desc.Cmd}}}, + FlightDescriptor: desc, + TotalRecords: -1, + TotalBytes: -1, + Schema: flight.SerializeSchema(schema, srv.Alloc), + }, nil + } + + return &flight.FlightInfo{ + Endpoint: []*flight.FlightEndpoint{{Ticket: &flight.Ticket{Ticket: desc.Cmd}}}, + FlightDescriptor: desc, + TotalRecords: -1, + TotalBytes: -1, + }, nil +} + +func (srv *ExampleServer) GetFlightInfoStatement(ctx context.Context, cmd flightsql.StatementQuery, desc *flight.FlightDescriptor) (*flight.FlightInfo, error) { + ticket, err := flightsql.CreateStatementQueryTicket(desc.Cmd) + if err != nil { + return nil, err + } + + return &flight.FlightInfo{ + Endpoint: []*flight.FlightEndpoint{{Ticket: &flight.Ticket{Ticket: ticket}}}, + FlightDescriptor: desc, + TotalRecords: -1, + TotalBytes: -1, + }, nil +} + +func (srv *ExampleServer) DoGetPreparedStatement(ctx context.Context, cmd flightsql.PreparedStatementQuery) (schema *arrow.Schema, out <-chan flight.StreamChunk, err error) { + log.Printf("DoGetPreparedStatement: %v", cmd.GetPreparedStatementHandle()) + if bytes.Equal(cmd.GetPreparedStatementHandle(), []byte("error_do_get")) { + err = status.Error(codes.InvalidArgument, "expected error") + return + } + + schema = arrow.NewSchema([]arrow.Field{{Name: "ints", Type: arrow.PrimitiveTypes.Int32, Nullable: true}}, nil) + rec, _, err := array.RecordFromJSON(memory.DefaultAllocator, schema, strings.NewReader(`[{"a": 5}]`)) + + ch := make(chan flight.StreamChunk) + go func() { + defer close(ch) + ch <- flight.StreamChunk{ + Data: rec, + Desc: nil, + Err: nil, + } + if bytes.Equal(cmd.GetPreparedStatementHandle(), []byte("error_do_get_stream")) { + ch <- flight.StreamChunk{ + Data: nil, + Desc: nil, + Err: status.Error(codes.InvalidArgument, "expected error"), + } + } + }() + out = ch + return +} + +func (srv *ExampleServer) DoGetStatement(ctx context.Context, cmd flightsql.StatementQueryTicket) (schema *arrow.Schema, out <-chan flight.StreamChunk, err error) { + schema = arrow.NewSchema([]arrow.Field{{Name: "ints", Type: arrow.PrimitiveTypes.Int32, Nullable: true}}, nil) + rec, _, err := array.RecordFromJSON(memory.DefaultAllocator, schema, strings.NewReader(`[{"ints": 5}]`)) + + ch := make(chan flight.StreamChunk) + go func() { + defer close(ch) + ch <- flight.StreamChunk{ + Data: rec, + Desc: nil, + Err: nil, + } + }() + out = ch + return +} + +func main() { + var ( + host = flag.String("host", "localhost", "hostname to bind to") + port = flag.Int("port", 0, "port to bind to") + ) + + flag.Parse() + + srv := &ExampleServer{} + srv.Alloc = memory.DefaultAllocator + + server := flight.NewServerWithMiddleware(nil) + server.RegisterFlightService(flightsql.NewFlightServer(srv)) + if err := server.Init(net.JoinHostPort(*host, strconv.Itoa(*port))); err != nil { + log.Fatal(err) + } + server.SetShutdownOnSignals(os.Interrupt, os.Kill) + + fmt.Println("Starting testing Flight SQL Server on", server.Addr(), "...") + + if err := server.Serve(); err != nil { + log.Fatal(err) + } +} diff --git a/go/adbc/drivermgr/adbc_driver_manager.cc b/go/adbc/drivermgr/adbc_driver_manager.cc index 516bf9bbf7..e4287534df 100644 --- a/go/adbc/drivermgr/adbc_driver_manager.cc +++ b/go/adbc/drivermgr/adbc_driver_manager.cc @@ -642,8 +642,12 @@ const struct AdbcError* AdbcErrorFromArrayStream(struct ArrowArrayStream* stream return nullptr; } auto* private_data = reinterpret_cast(stream->private_data); - return private_data->private_driver->ErrorFromArrayStream(&private_data->stream, - status); + auto* error = + private_data->private_driver->ErrorFromArrayStream(&private_data->stream, status); + if (error) { + const_cast(error)->private_driver = private_data->private_driver; + } + return error; } #define INIT_ERROR(ERROR, SOURCE) \ diff --git a/go/adbc/pkg/_tmpl/driver.go.tmpl b/go/adbc/pkg/_tmpl/driver.go.tmpl index fc489a4016..24c15f3960 100644 --- a/go/adbc/pkg/_tmpl/driver.go.tmpl +++ b/go/adbc/pkg/_tmpl/driver.go.tmpl @@ -263,7 +263,7 @@ func (cStream *cArrayStream) maybeError() C.int { if cStream.adbcErr != nil { C.{{.Prefix}}errRelease(cStream.adbcErr) } else { - cStream.adbcErr = (*C.struct_AdbcError)(C.malloc(C.ADBC_ERROR_1_1_0_SIZE)) + cStream.adbcErr = (*C.struct_AdbcError)(C.calloc(1, C.ADBC_ERROR_1_1_0_SIZE)) } cStream.adbcErr.vendor_code = C.ADBC_ERROR_VENDOR_CODE_PRIVATE_DATA cStream.status = C.AdbcStatusCode(errToAdbcErr(cStream.adbcErr, err)) diff --git a/go/adbc/pkg/_tmpl/utils.h.tmpl b/go/adbc/pkg/_tmpl/utils.h.tmpl index ce3dba9dbc..d73f4bad71 100644 --- a/go/adbc/pkg/_tmpl/utils.h.tmpl +++ b/go/adbc/pkg/_tmpl/utils.h.tmpl @@ -83,7 +83,10 @@ AdbcStatusCode {{.Prefix}}StatementSetSubstraitPlan(struct AdbcStatement* stmt, AdbcStatusCode {{.Prefix}}DriverInit(int version, void* rawDriver, struct AdbcError* err); static inline void {{.Prefix}}errRelease(struct AdbcError* error) { - error->release(error); + if (error->release) { + error->release(error); + error->release = NULL; + } } void {{.Prefix}}_release_error(struct AdbcError* error); diff --git a/go/adbc/pkg/flightsql/driver.go b/go/adbc/pkg/flightsql/driver.go index 925fd8658e..46e096952c 100644 --- a/go/adbc/pkg/flightsql/driver.go +++ b/go/adbc/pkg/flightsql/driver.go @@ -267,7 +267,7 @@ func (cStream *cArrayStream) maybeError() C.int { if cStream.adbcErr != nil { C.FlightSQLerrRelease(cStream.adbcErr) } else { - cStream.adbcErr = (*C.struct_AdbcError)(C.malloc(C.ADBC_ERROR_1_1_0_SIZE)) + cStream.adbcErr = (*C.struct_AdbcError)(C.calloc(1, C.ADBC_ERROR_1_1_0_SIZE)) } cStream.adbcErr.vendor_code = C.ADBC_ERROR_VENDOR_CODE_PRIVATE_DATA cStream.status = C.AdbcStatusCode(errToAdbcErr(cStream.adbcErr, err)) diff --git a/go/adbc/pkg/flightsql/utils.h b/go/adbc/pkg/flightsql/utils.h index e3b22fb737..fbdbe89a8a 100644 --- a/go/adbc/pkg/flightsql/utils.h +++ b/go/adbc/pkg/flightsql/utils.h @@ -153,7 +153,12 @@ AdbcStatusCode FlightSQLStatementSetSubstraitPlan(struct AdbcStatement* stmt, AdbcStatusCode FlightSQLDriverInit(int version, void* rawDriver, struct AdbcError* err); -static inline void FlightSQLerrRelease(struct AdbcError* error) { error->release(error); } +static inline void FlightSQLerrRelease(struct AdbcError* error) { + if (error->release) { + error->release(error); + error->release = NULL; + } +} void FlightSQL_release_error(struct AdbcError* error); diff --git a/go/adbc/pkg/panicdummy/driver.go b/go/adbc/pkg/panicdummy/driver.go index d1c143a762..c99153ccb5 100644 --- a/go/adbc/pkg/panicdummy/driver.go +++ b/go/adbc/pkg/panicdummy/driver.go @@ -267,7 +267,7 @@ func (cStream *cArrayStream) maybeError() C.int { if cStream.adbcErr != nil { C.PanicDummyerrRelease(cStream.adbcErr) } else { - cStream.adbcErr = (*C.struct_AdbcError)(C.malloc(C.ADBC_ERROR_1_1_0_SIZE)) + cStream.adbcErr = (*C.struct_AdbcError)(C.calloc(1, C.ADBC_ERROR_1_1_0_SIZE)) } cStream.adbcErr.vendor_code = C.ADBC_ERROR_VENDOR_CODE_PRIVATE_DATA cStream.status = C.AdbcStatusCode(errToAdbcErr(cStream.adbcErr, err)) diff --git a/go/adbc/pkg/panicdummy/utils.h b/go/adbc/pkg/panicdummy/utils.h index 91d8294c4e..b8db59c227 100644 --- a/go/adbc/pkg/panicdummy/utils.h +++ b/go/adbc/pkg/panicdummy/utils.h @@ -156,7 +156,10 @@ AdbcStatusCode PanicDummyStatementSetSubstraitPlan(struct AdbcStatement* stmt, AdbcStatusCode PanicDummyDriverInit(int version, void* rawDriver, struct AdbcError* err); static inline void PanicDummyerrRelease(struct AdbcError* error) { - error->release(error); + if (error->release) { + error->release(error); + error->release = NULL; + } } void PanicDummy_release_error(struct AdbcError* error); diff --git a/go/adbc/pkg/snowflake/driver.go b/go/adbc/pkg/snowflake/driver.go index 6ca09646d4..4804e32e38 100644 --- a/go/adbc/pkg/snowflake/driver.go +++ b/go/adbc/pkg/snowflake/driver.go @@ -267,7 +267,7 @@ func (cStream *cArrayStream) maybeError() C.int { if cStream.adbcErr != nil { C.SnowflakeerrRelease(cStream.adbcErr) } else { - cStream.adbcErr = (*C.struct_AdbcError)(C.malloc(C.ADBC_ERROR_1_1_0_SIZE)) + cStream.adbcErr = (*C.struct_AdbcError)(C.calloc(1, C.ADBC_ERROR_1_1_0_SIZE)) } cStream.adbcErr.vendor_code = C.ADBC_ERROR_VENDOR_CODE_PRIVATE_DATA cStream.status = C.AdbcStatusCode(errToAdbcErr(cStream.adbcErr, err)) diff --git a/go/adbc/pkg/snowflake/utils.h b/go/adbc/pkg/snowflake/utils.h index 23391dfd70..c679316232 100644 --- a/go/adbc/pkg/snowflake/utils.h +++ b/go/adbc/pkg/snowflake/utils.h @@ -153,7 +153,12 @@ AdbcStatusCode SnowflakeStatementSetSubstraitPlan(struct AdbcStatement* stmt, AdbcStatusCode SnowflakeDriverInit(int version, void* rawDriver, struct AdbcError* err); -static inline void SnowflakeerrRelease(struct AdbcError* error) { error->release(error); } +static inline void SnowflakeerrRelease(struct AdbcError* error) { + if (error->release) { + error->release(error); + error->release = NULL; + } +} void Snowflake_release_error(struct AdbcError* error); diff --git a/python/adbc_driver_flightsql/tests/conftest.py b/python/adbc_driver_flightsql/tests/conftest.py index 4ca9508d07..b4eb181105 100644 --- a/python/adbc_driver_flightsql/tests/conftest.py +++ b/python/adbc_driver_flightsql/tests/conftest.py @@ -71,3 +71,13 @@ def dremio_dbapi(dremio_uri, dremio_user, dremio_pass): }, ) as conn: yield conn + + +@pytest.fixture +def test_dbapi(): + uri = os.environ.get("ADBC_TEST_FLIGHTSQL_URI") + if not uri: + pytest.skip("Set ADBC_TEST_FLIGHTSQL_URI to run tests") + + with adbc_driver_flightsql.dbapi.connect(uri) as conn: + yield conn diff --git a/python/adbc_driver_flightsql/tests/test_dbapi.py b/python/adbc_driver_flightsql/tests/test_dbapi.py index 0918fc7a93..e199035473 100644 --- a/python/adbc_driver_flightsql/tests/test_dbapi.py +++ b/python/adbc_driver_flightsql/tests/test_dbapi.py @@ -33,6 +33,21 @@ def test_query_error(dremio_dbapi): assert exc.args[0].startswith("INVALID_ARGUMENT: [FlightSQL] ") +def test_query_error_fetch(test_dbapi): + with test_dbapi.cursor() as cur: + cur.execute("error_do_get") + with pytest.raises(Exception, match="expected error"): + cur.fetch_arrow_table() + + +def test_query_error_stream(test_dbapi): + with test_dbapi.cursor() as cur: + cur.execute("error_do_get_stream") + with pytest.raises(Exception, match="expected error"): + cur.fetchone() + cur.fetchone() + + def test_query_trivial(dremio_dbapi): with dremio_dbapi.cursor() as cur: cur.execute("SELECT 1")