Skip to content

Commit c3d2472

Browse files
committed
Support prepared statements and parameters
1 parent bd381c1 commit c3d2472

18 files changed

+894
-326
lines changed

CMakeLists.txt

+7-1
Original file line numberDiff line numberDiff line change
@@ -16,12 +16,16 @@ include_directories(
1616
# Embed ./src/assets/index.html as a C++ header
1717
add_custom_command(
1818
OUTPUT ${CMAKE_CURRENT_BINARY_DIR}/playground.hpp
19-
COMMAND ${CMAKE_COMMAND} -P ${PROJECT_SOURCE_DIR}/embed.cmake ${PROJECT_SOURCE_DIR}/src/assets/index.html ${CMAKE_CURRENT_BINARY_DIR}/playground.hpp playgroundContent
19+
COMMAND ${CMAKE_COMMAND} -P ${PROJECT_SOURCE_DIR}/embed.cmake ${PROJECT_SOURCE_DIR}/src/assets/index.html ${CMAKE_CURRENT_BINARY_DIR}/httpserver_extension/http_handler/playground.hpp playgroundContent
2020
DEPENDS ${PROJECT_SOURCE_DIR}/src/assets/index.html
2121
)
2222

2323
set(EXTENSION_SOURCES
2424
src/httpserver_extension.cpp
25+
src/http_handler/authentication.cpp
26+
src/http_handler/bindings.cpp
27+
src/http_handler/handler.cpp
28+
src/http_handler/response_serializer.cpp
2529
${CMAKE_CURRENT_BINARY_DIR}/playground.hpp
2630
)
2731

@@ -37,7 +41,9 @@ build_loadable_extension(${TARGET_NAME} " " ${EXTENSION_SOURCES})
3741

3842
include_directories(${OPENSSL_INCLUDE_DIR})
3943
target_link_libraries(${LOADABLE_EXTENSION_NAME} duckdb_mbedtls ${OPENSSL_LIBRARIES})
44+
set_property(TARGET ${LOADABLE_EXTENSION_NAME} PROPERTY CXX_STANDARD 17)
4045
target_link_libraries(${EXTENSION_NAME} duckdb_mbedtls ${OPENSSL_LIBRARIES})
46+
set_property(TARGET ${EXTENSION_NAME} PROPERTY CXX_STANDARD 17)
4147

4248
if(MINGW)
4349
set(WIN_LIBS crypt32 ws2_32 wsock32)

src/http_handler/authentication.cpp

+55
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
#include "httpserver_extension/http_handler/common.hpp"
2+
#include "httpserver_extension/state.hpp"
3+
#include <string>
4+
#include <vector>
5+
6+
#define CPPHTTPLIB_OPENSSL_SUPPORT
7+
#include "httplib.hpp"
8+
9+
namespace duckdb_httpserver {
10+
11+
// Base64 decoding function
12+
static std::string base64_decode(const std::string &in) {
13+
std::string out;
14+
std::vector<int> T(256, -1);
15+
for (int i = 0; i < 64; i++)
16+
T["ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/"[i]] = i;
17+
18+
int val = 0, valb = -8;
19+
for (unsigned char c : in) {
20+
if (T[c] == -1) break;
21+
val = (val << 6) + T[c];
22+
valb += 6;
23+
if (valb >= 0) {
24+
out.push_back(char((val >> valb) & 0xFF));
25+
valb -= 8;
26+
}
27+
}
28+
return out;
29+
}
30+
31+
// Check authentication
32+
void CheckAuthentication(const duckdb_httplib_openssl::Request& req) {
33+
if (global_state.auth_token.empty()) {
34+
return; // No authentication required if no token is set
35+
}
36+
37+
// Check for X-API-Key header
38+
auto api_key = req.get_header_value("X-API-Key");
39+
if (!api_key.empty() && api_key == global_state.auth_token) {
40+
return;
41+
}
42+
43+
// Check for Basic Auth
44+
auto auth = req.get_header_value("Authorization");
45+
if (!auth.empty() && auth.compare(0, 6, "Basic ") == 0) {
46+
std::string decoded_auth = base64_decode(auth.substr(6));
47+
if (decoded_auth == global_state.auth_token) {
48+
return;
49+
}
50+
}
51+
52+
throw HttpHandlerException(401, "Unauthorized");
53+
}
54+
55+
} // namespace duckdb_httpserver

src/http_handler/bindings.cpp

+69
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
#include "httpserver_extension/http_handler/common.hpp"
2+
#include "duckdb.hpp"
3+
#include "yyjson.hpp"
4+
#include <string>
5+
6+
#define CPPHTTPLIB_OPENSSL_SUPPORT
7+
#include "httplib.hpp"
8+
9+
using namespace duckdb;
10+
using namespace duckdb_yyjson;
11+
12+
namespace duckdb_httpserver {
13+
14+
static BoundParameterData ExtractQueryParameter(const std::string& key, yyjson_val* parameterVal) {
15+
if (!yyjson_is_obj(parameterVal)) {
16+
throw HttpHandlerException(400, "The parameter `" + key + "` must be an object");
17+
}
18+
19+
auto typeVal = yyjson_obj_get(parameterVal, "type");
20+
if (!typeVal) {
21+
throw HttpHandlerException(400, "The parameter `" + key + "` does not have a `type` field");
22+
}
23+
if (!yyjson_is_str(typeVal)) {
24+
throw HttpHandlerException(400, "The field `type` for the parameter `" + key + "` must be a string");
25+
}
26+
auto type = std::string(yyjson_get_str(typeVal));
27+
28+
auto valueVal = yyjson_obj_get(parameterVal, "value");
29+
if (!valueVal) {
30+
throw HttpHandlerException(400, "The parameter `" + key + "` does not have a `value` field");
31+
}
32+
33+
if (type == "TEXT") {
34+
if (!yyjson_is_str(valueVal)) {
35+
throw HttpHandlerException(400, "The field `value` for the parameter `" + key + "` must be a string");
36+
}
37+
38+
return BoundParameterData(Value(yyjson_get_str(valueVal)));
39+
}
40+
else if (type == "BOOLEAN") {
41+
if (!yyjson_is_bool(valueVal)) {
42+
throw HttpHandlerException(400, "The field `value` for the parameter `" + key + "` must be a boolean");
43+
}
44+
45+
return BoundParameterData(Value(bool(yyjson_get_bool(valueVal))));
46+
}
47+
48+
throw HttpHandlerException(400, "Unsupported type " + type + " the parameter `" + key + "`");
49+
}
50+
51+
case_insensitive_map_t<BoundParameterData> ExtractQueryParameters(yyjson_val* parametersVal) {
52+
if (!parametersVal || !yyjson_is_obj(parametersVal)) {
53+
throw HttpHandlerException(400, "The `parameters` field must be an object");
54+
}
55+
56+
case_insensitive_map_t<BoundParameterData> named_values;
57+
58+
size_t idx, max;
59+
yyjson_val *parameterKeyVal, *parameterVal;
60+
yyjson_obj_foreach(parametersVal, idx, max, parameterKeyVal, parameterVal) {
61+
auto parameterKeyString = std::string(yyjson_get_str(parameterKeyVal));
62+
63+
named_values[parameterKeyString] = ExtractQueryParameter(parameterKeyString, parameterVal);
64+
}
65+
66+
return named_values;
67+
}
68+
69+
} // namespace duckdb_httpserver

src/http_handler/handler.cpp

+219
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,219 @@
1+
#include "httpserver_extension/http_handler/authentication.hpp"
2+
#include "httpserver_extension/http_handler/bindings.hpp"
3+
#include "httpserver_extension/http_handler/common.hpp"
4+
#include "httpserver_extension/http_handler/handler.hpp"
5+
#include "httpserver_extension/http_handler/playground.hpp"
6+
#include "httpserver_extension/http_handler/response_serializer.hpp"
7+
#include "httpserver_extension/state.hpp"
8+
#include "duckdb.hpp"
9+
#include "yyjson.hpp"
10+
11+
#include <string>
12+
#include <vector>
13+
14+
#define CPPHTTPLIB_OPENSSL_SUPPORT
15+
#include "httplib.hpp"
16+
17+
using namespace duckdb;
18+
using namespace duckdb_yyjson;
19+
20+
namespace duckdb_httpserver {
21+
22+
// Handle both GET and POST requests
23+
void HttpHandler(const duckdb_httplib_openssl::Request& req, duckdb_httplib_openssl::Response& res) {
24+
try {
25+
// CORS allow
26+
res.set_header("Access-Control-Allow-Origin", "*");
27+
res.set_header("Access-Control-Allow-Methods", "GET, POST, OPTIONS, PUT");
28+
res.set_header("Access-Control-Allow-Headers", "*");
29+
res.set_header("Access-Control-Allow-Credentials", "true");
30+
res.set_header("Access-Control-Max-Age", "86400");
31+
32+
// Handle preflight OPTIONS request
33+
if (req.method == "OPTIONS") {
34+
res.status = 204; // No content
35+
return;
36+
}
37+
38+
CheckAuthentication(req);
39+
40+
auto queryApiParameters = ExtractQueryApiParameters(req);
41+
42+
if (!queryApiParameters.sqlQueryOpt.has_value()) {
43+
res.status = 200;
44+
res.set_content(reinterpret_cast<char const*>(playgroundContent), sizeof(playgroundContent), "text/html");
45+
return;
46+
}
47+
48+
if (!global_state.db_instance) {
49+
throw IOException("Database instance not initialized");
50+
}
51+
52+
auto start = std::chrono::system_clock::now();
53+
auto result = ExecuteQuery(req, queryApiParameters);
54+
auto end = std::chrono::system_clock::now();
55+
auto elapsed = std::chrono::duration_cast<std::chrono::milliseconds>(end - start);
56+
57+
QueryExecStats stats {
58+
.elapsed_sec = static_cast<float>(elapsed.count()) / 1000,
59+
.read_bytes = 0,
60+
.read_rows = 0
61+
};
62+
63+
// Format output
64+
if (queryApiParameters.outputFormat == OutputFormat::Ndjson) {
65+
std::string json_output = ConvertResultToNDJSON(*result);
66+
res.set_content(json_output, "application/x-ndjson");
67+
}
68+
else {
69+
auto json_output = ConvertResultToJSON(*result, stats);
70+
res.set_content(json_output, "application/json");
71+
}
72+
}
73+
catch (const HttpHandlerException& ex) {
74+
res.status = ex.status;
75+
res.set_content(ex.message, "text/plain");
76+
}
77+
catch (const std::exception& ex) {
78+
res.status = 500;
79+
std::string error_message = "Code: 59, e.displayText() = DB::Exception: " + std::string(ex.what());
80+
res.set_content(error_message, "text/plain");
81+
}
82+
}
83+
84+
// Execute query (optionally using a prepared statement)
85+
std::unique_ptr<MaterializedQueryResult> ExecuteQuery(
86+
const duckdb_httplib_openssl::Request& req,
87+
const QueryApiParameters& queryApiParameters
88+
) {
89+
Connection con(*global_state.db_instance);
90+
std::unique_ptr<MaterializedQueryResult> result;
91+
auto query = queryApiParameters.sqlQueryOpt.value();
92+
93+
auto use_prepared_stmt =
94+
queryApiParameters.sqlParametersOpt.has_value() &&
95+
queryApiParameters.sqlParametersOpt.value().empty() == false;
96+
97+
if (use_prepared_stmt) {
98+
auto prepared_stmt = con.Prepare(query);
99+
if (prepared_stmt->HasError()) {
100+
throw HttpHandlerException(500, prepared_stmt->GetError());
101+
}
102+
103+
auto named_values = queryApiParameters.sqlParametersOpt.value();
104+
105+
auto prepared_stmt_result = prepared_stmt->Execute(named_values);
106+
D_ASSERT(prepared_stmt_result->type == QueryResultType::STREAM_RESULT);
107+
result = unique_ptr_cast<QueryResult, StreamQueryResult>(std::move(prepared_stmt_result))->Materialize();
108+
} else {
109+
result = con.Query(query);
110+
}
111+
112+
if (result->HasError()) {
113+
throw HttpHandlerException(500, result->GetError());
114+
}
115+
116+
return result;
117+
}
118+
119+
QueryApiParameters ExtractQueryApiParameters(const duckdb_httplib_openssl::Request& req) {
120+
if (req.method == "POST" && req.has_header("Content-Type") && req.get_header_value("Content-Type") == "application/json") {
121+
return ExtractQueryApiParametersComplex(req);
122+
}
123+
else {
124+
return QueryApiParameters {
125+
.sqlQueryOpt = ExtractSqlQuerySimple(req),
126+
.outputFormat = ExtractOutputFormatSimple(req),
127+
};
128+
}
129+
}
130+
131+
std::optional<std::string> ExtractSqlQuerySimple(const duckdb_httplib_openssl::Request& req) {
132+
// Check if the query is in the URL parameters
133+
if (req.has_param("query")) {
134+
return req.get_param_value("query");
135+
}
136+
else if (req.has_param("q")) {
137+
return req.get_param_value("q");
138+
}
139+
140+
// If not in URL, and it's a POST request, check the body
141+
else if (req.method == "POST" && !req.body.empty()) {
142+
return req.body;
143+
}
144+
145+
return std::nullopt;
146+
}
147+
148+
OutputFormat ExtractOutputFormatSimple(const duckdb_httplib_openssl::Request& req) {
149+
// Check for format in URL parameter or header
150+
if (req.has_param("default_format")) {
151+
return ParseOutputFormat(req.get_param_value("default_format"));
152+
}
153+
else if (req.has_header("X-ClickHouse-Format")) {
154+
return ParseOutputFormat(req.get_header_value("X-ClickHouse-Format"));
155+
}
156+
else if (req.has_header("format")) {
157+
return ParseOutputFormat(req.get_header_value("format"));
158+
}
159+
else {
160+
return OutputFormat::Ndjson;
161+
}
162+
}
163+
164+
OutputFormat ParseOutputFormat(const std::string& formatStr) {
165+
if (formatStr == "JSONEachRow" || formatStr == "ndjson" || formatStr == "jsonl") {
166+
return OutputFormat::Ndjson;
167+
}
168+
else if (formatStr == "JSONCompact") {
169+
return OutputFormat::Json;
170+
}
171+
else {
172+
throw HttpHandlerException(400, "Unknown format");
173+
}
174+
}
175+
176+
QueryApiParameters ExtractQueryApiParametersComplex(const duckdb_httplib_openssl::Request& req) {
177+
yyjson_doc *bodyDoc = nullptr;
178+
179+
try {
180+
auto bodyJson = req.body;
181+
auto bodyJsonCStr = bodyJson.c_str();
182+
bodyDoc = yyjson_read(bodyJsonCStr, strlen(bodyJsonCStr), 0);
183+
184+
return ExtractQueryApiParametersComplexImpl(bodyDoc);
185+
}
186+
catch (const std::exception& exception) {
187+
yyjson_doc_free(bodyDoc);
188+
throw;
189+
}
190+
}
191+
192+
QueryApiParameters ExtractQueryApiParametersComplexImpl(yyjson_doc* bodyDoc) {
193+
if (!bodyDoc) {
194+
throw HttpHandlerException(400, "Unable to parse the request body");
195+
}
196+
197+
auto bodyRoot = yyjson_doc_get_root(bodyDoc);
198+
if (!yyjson_is_obj(bodyRoot)) {
199+
throw HttpHandlerException(400, "The request body must be an object");
200+
}
201+
202+
auto queryVal = yyjson_obj_get(bodyRoot, "query");
203+
if (!queryVal || !yyjson_is_str(queryVal)) {
204+
throw HttpHandlerException(400, "The `query` field must be a string");
205+
}
206+
207+
auto formatVal = yyjson_obj_get(bodyRoot, "format");
208+
if (!formatVal || !yyjson_is_str(formatVal)) {
209+
throw HttpHandlerException(400, "The `format` field must be a string");
210+
}
211+
212+
return QueryApiParameters {
213+
.sqlQueryOpt = std::string(yyjson_get_str(queryVal)),
214+
.sqlParametersOpt = ExtractQueryParameters(yyjson_obj_get(bodyRoot, "parameters")),
215+
.outputFormat = ParseOutputFormat(std::string(yyjson_get_str(formatVal))),
216+
};
217+
}
218+
219+
} // namespace duckdb_httpserver

0 commit comments

Comments
 (0)