Skip to content

Commit b2e8f25

Browse files
authored
GH-47711: [C++][FlightRPC] Enable ODBC query execution (#48032)
### Rationale for this change Enable query execution in ODBC. ### What changes are included in this PR? - Extract SQLExecDirect, SQLExecute, SQLPrepare implementation & tests ### Are these changes tested? - Tested on local MSVC ### Are there any user-facing changes? N/A * GitHub Issue: #47711 Authored-by: Alina (Xi) Li <[email protected]> Signed-off-by: David Li <[email protected]>
1 parent e90bacd commit b2e8f25

File tree

4 files changed

+119
-11
lines changed

4 files changed

+119
-11
lines changed

cpp/src/arrow/flight/sql/odbc/odbc_api.cc

Lines changed: 33 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1005,22 +1005,49 @@ SQLRETURN SQLExecDirect(SQLHSTMT stmt, SQLWCHAR* query_text, SQLINTEGER text_len
10051005
ARROW_LOG(DEBUG) << "SQLExecDirectW called with stmt: " << stmt
10061006
<< ", query_text: " << static_cast<const void*>(query_text)
10071007
<< ", text_length: " << text_length;
1008-
// GH-47711 TODO: Implement SQLExecDirect
1009-
return SQL_INVALID_HANDLE;
1008+
1009+
using ODBC::ODBCStatement;
1010+
// The driver is built to handle SELECT statements only.
1011+
return ODBCStatement::ExecuteWithDiagnostics(stmt, SQL_ERROR, [=]() {
1012+
ODBCStatement* statement = reinterpret_cast<ODBCStatement*>(stmt);
1013+
std::string query = ODBC::SqlWcharToString(query_text, text_length);
1014+
1015+
statement->Prepare(query);
1016+
statement->ExecutePrepared();
1017+
1018+
return SQL_SUCCESS;
1019+
});
10101020
}
10111021

10121022
SQLRETURN SQLPrepare(SQLHSTMT stmt, SQLWCHAR* query_text, SQLINTEGER text_length) {
10131023
ARROW_LOG(DEBUG) << "SQLPrepareW called with stmt: " << stmt
10141024
<< ", query_text: " << static_cast<const void*>(query_text)
10151025
<< ", text_length: " << text_length;
1016-
// GH-47712 TODO: Implement SQLPrepare
1017-
return SQL_INVALID_HANDLE;
1026+
1027+
using ODBC::ODBCStatement;
1028+
// The driver is built to handle SELECT statements only.
1029+
return ODBCStatement::ExecuteWithDiagnostics(stmt, SQL_ERROR, [=]() {
1030+
ODBCStatement* statement = reinterpret_cast<ODBCStatement*>(stmt);
1031+
std::string query = ODBC::SqlWcharToString(query_text, text_length);
1032+
1033+
statement->Prepare(query);
1034+
1035+
return SQL_SUCCESS;
1036+
});
10181037
}
10191038

10201039
SQLRETURN SQLExecute(SQLHSTMT stmt) {
10211040
ARROW_LOG(DEBUG) << "SQLExecute called with stmt: " << stmt;
1022-
// GH-47712 TODO: Implement SQLExecute
1023-
return SQL_INVALID_HANDLE;
1041+
1042+
using ODBC::ODBCStatement;
1043+
// The driver is built to handle SELECT statements only.
1044+
return ODBCStatement::ExecuteWithDiagnostics(stmt, SQL_ERROR, [=]() {
1045+
ODBCStatement* statement = reinterpret_cast<ODBCStatement*>(stmt);
1046+
1047+
statement->ExecutePrepared();
1048+
1049+
return SQL_SUCCESS;
1050+
});
10241051
}
10251052

10261053
SQLRETURN SQLFetch(SQLHSTMT stmt) {

cpp/src/arrow/flight/sql/odbc/odbc_impl/flight_sql_statement.cc

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,10 @@ FlightSqlStatement::FlightSqlStatement(const Diagnostics& diagnostics,
6969
call_options_.timeout = TimeoutDuration{-1};
7070
}
7171

72+
FlightSqlStatement::~FlightSqlStatement() {
73+
ClosePreparedStatementIfAny(prepared_statement_, call_options_);
74+
}
75+
7276
bool FlightSqlStatement::SetAttribute(StatementAttributeId attribute,
7377
const Attribute& value) {
7478
switch (attribute) {
@@ -119,7 +123,6 @@ bool FlightSqlStatement::ExecutePrepared() {
119123

120124
Result<std::shared_ptr<FlightInfo>> result =
121125
prepared_statement_->Execute(call_options_);
122-
123126
ThrowIfNotOK(result.status());
124127

125128
current_result_set_ = std::make_shared<FlightSqlResultSet>(

cpp/src/arrow/flight/sql/odbc/odbc_impl/flight_sql_statement.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@ class FlightSqlStatement : public Statement {
4949
FlightSqlStatement(const Diagnostics& diagnostics, FlightSqlClient& sql_client,
5050
FlightClientOptions client_options, FlightCallOptions call_options,
5151
const MetadataSettings& metadata_settings);
52+
~FlightSqlStatement();
5253

5354
bool SetAttribute(StatementAttributeId attribute, const Attribute& value) override;
5455

cpp/src/arrow/flight/sql/odbc/tests/statement_test.cc

Lines changed: 81 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -37,9 +37,86 @@ class StatementRemoteTest : public FlightSQLODBCRemoteTestBase {};
3737
using TestTypes = ::testing::Types<StatementMockTest, StatementRemoteTest>;
3838
TYPED_TEST_SUITE(StatementTest, TestTypes);
3939

40+
TYPED_TEST(StatementTest, TestSQLExecDirectSimpleQuery) {
41+
std::wstring wsql = L"SELECT 1;";
42+
std::vector<SQLWCHAR> sql0(wsql.begin(), wsql.end());
43+
44+
ASSERT_EQ(SQL_SUCCESS,
45+
SQLExecDirect(this->stmt, &sql0[0], static_cast<SQLINTEGER>(sql0.size())));
46+
47+
// GH-47713 TODO: Uncomment call to SQLFetch SQLGetData after implementation
48+
/*
49+
ASSERT_EQ(SQL_SUCCESS, SQLFetch(this->stmt));
50+
51+
SQLINTEGER val;
52+
53+
ASSERT_EQ(SQL_SUCCESS, SQLGetData(this->stmt, 1, SQL_C_LONG, &val, 0, 0));
54+
// Verify 1 is returned
55+
EXPECT_EQ(1, val);
56+
57+
ASSERT_EQ(SQL_NO_DATA, SQLFetch(this->stmt));
58+
59+
ASSERT_EQ(SQL_ERROR, SQLGetData(this->stmt, 1, SQL_C_LONG, &val, 0, 0));
60+
// Invalid cursor state
61+
VerifyOdbcErrorState(SQL_HANDLE_STMT, this->stmt, kErrorState24000);
62+
*/
63+
}
64+
65+
TYPED_TEST(StatementTest, TestSQLExecDirectInvalidQuery) {
66+
std::wstring wsql = L"SELECT;";
67+
std::vector<SQLWCHAR> sql0(wsql.begin(), wsql.end());
68+
69+
ASSERT_EQ(SQL_ERROR,
70+
SQLExecDirect(this->stmt, &sql0[0], static_cast<SQLINTEGER>(sql0.size())));
71+
// ODBC provides generic error code HY000 to all statement errors
72+
VerifyOdbcErrorState(SQL_HANDLE_STMT, this->stmt, kErrorStateHY000);
73+
}
74+
75+
TYPED_TEST(StatementTest, TestSQLExecuteSimpleQuery) {
76+
std::wstring wsql = L"SELECT 1;";
77+
std::vector<SQLWCHAR> sql0(wsql.begin(), wsql.end());
78+
79+
ASSERT_EQ(SQL_SUCCESS,
80+
SQLPrepare(this->stmt, &sql0[0], static_cast<SQLINTEGER>(sql0.size())));
81+
82+
ASSERT_EQ(SQL_SUCCESS, SQLExecute(this->stmt));
83+
84+
// GH-47713 TODO: Uncomment call to SQLFetch SQLGetData after implementation
85+
/*
86+
// Fetch data
87+
ASSERT_EQ(SQL_SUCCESS, SQLFetch(this->stmt));
88+
89+
SQLINTEGER val;
90+
ASSERT_EQ(SQL_SUCCESS, SQLGetData(this->stmt, 1, SQL_C_LONG, &val, 0, 0));
91+
92+
// Verify 1 is returned
93+
EXPECT_EQ(1, val);
94+
95+
ASSERT_EQ(SQL_NO_DATA, SQLFetch(this->stmt));
96+
97+
ASSERT_EQ(SQL_ERROR, SQLGetData(this->stmt, 1, SQL_C_LONG, &val, 0, 0));
98+
// Invalid cursor state
99+
VerifyOdbcErrorState(SQL_HANDLE_STMT, this->stmt, kErrorState24000);
100+
*/
101+
}
102+
103+
TYPED_TEST(StatementTest, TestSQLPrepareInvalidQuery) {
104+
std::wstring wsql = L"SELECT;";
105+
std::vector<SQLWCHAR> sql0(wsql.begin(), wsql.end());
106+
107+
ASSERT_EQ(SQL_ERROR,
108+
SQLPrepare(this->stmt, &sql0[0], static_cast<SQLINTEGER>(sql0.size())));
109+
// ODBC provides generic error code HY000 to all statement errors
110+
VerifyOdbcErrorState(SQL_HANDLE_STMT, this->stmt, kErrorStateHY000);
111+
112+
ASSERT_EQ(SQL_ERROR, SQLExecute(this->stmt));
113+
// Verify function sequence error state is returned
114+
VerifyOdbcErrorState(SQL_HANDLE_STMT, this->stmt, kErrorStateHY010);
115+
}
116+
40117
TYPED_TEST(StatementTest, TestSQLNativeSqlReturnsInputString) {
41118
SQLWCHAR buf[1024];
42-
SQLINTEGER buf_char_len = sizeof(buf) / ODBC::GetSqlWCharSize();
119+
SQLINTEGER buf_char_len = sizeof(buf) / GetSqlWCharSize();
43120
SQLWCHAR input_str[] = L"SELECT * FROM mytable WHERE id == 1";
44121
SQLINTEGER input_char_len = static_cast<SQLINTEGER>(wcslen(input_str));
45122
SQLINTEGER output_char_len = 0;
@@ -58,7 +135,7 @@ TYPED_TEST(StatementTest, TestSQLNativeSqlReturnsInputString) {
58135

59136
TYPED_TEST(StatementTest, TestSQLNativeSqlReturnsNTSInputString) {
60137
SQLWCHAR buf[1024];
61-
SQLINTEGER buf_char_len = sizeof(buf) / ODBC::GetSqlWCharSize();
138+
SQLINTEGER buf_char_len = sizeof(buf) / GetSqlWCharSize();
62139
SQLWCHAR input_str[] = L"SELECT * FROM mytable WHERE id == 1";
63140
SQLINTEGER input_char_len = static_cast<SQLINTEGER>(wcslen(input_str));
64141
SQLINTEGER output_char_len = 0;
@@ -95,7 +172,7 @@ TYPED_TEST(StatementTest, TestSQLNativeSqlReturnsInputStringLength) {
95172
TYPED_TEST(StatementTest, TestSQLNativeSqlReturnsTruncatedString) {
96173
const SQLINTEGER small_buf_size_in_char = 11;
97174
SQLWCHAR small_buf[small_buf_size_in_char];
98-
SQLINTEGER small_buf_char_len = sizeof(small_buf) / ODBC::GetSqlWCharSize();
175+
SQLINTEGER small_buf_char_len = sizeof(small_buf) / GetSqlWCharSize();
99176
SQLWCHAR input_str[] = L"SELECT * FROM mytable WHERE id == 1";
100177
SQLINTEGER input_char_len = static_cast<SQLINTEGER>(wcslen(input_str));
101178
SQLINTEGER output_char_len = 0;
@@ -122,7 +199,7 @@ TYPED_TEST(StatementTest, TestSQLNativeSqlReturnsTruncatedString) {
122199

123200
TYPED_TEST(StatementTest, TestSQLNativeSqlReturnsErrorOnBadInputs) {
124201
SQLWCHAR buf[1024];
125-
SQLINTEGER buf_char_len = sizeof(buf) / ODBC::GetSqlWCharSize();
202+
SQLINTEGER buf_char_len = sizeof(buf) / GetSqlWCharSize();
126203
SQLWCHAR input_str[] = L"SELECT * FROM mytable WHERE id == 1";
127204
SQLINTEGER input_char_len = static_cast<SQLINTEGER>(wcslen(input_str));
128205
SQLINTEGER output_char_len = 0;

0 commit comments

Comments
 (0)