Skip to content

Commit 88d25a3

Browse files
committed
Support prepared statements and parameters
1 parent bd381c1 commit 88d25a3

18 files changed

+895
-326
lines changed

CMakeLists.txt

+7-1
Original file line numberDiff line numberDiff line change
@@ -16,12 +16,16 @@ include_directories(
1616
# Embed ./src/assets/index.html as a C++ header
1717
add_custom_command(
1818
OUTPUT ${CMAKE_CURRENT_BINARY_DIR}/playground.hpp
19-
COMMAND ${CMAKE_COMMAND} -P ${PROJECT_SOURCE_DIR}/embed.cmake ${PROJECT_SOURCE_DIR}/src/assets/index.html ${CMAKE_CURRENT_BINARY_DIR}/playground.hpp playgroundContent
19+
COMMAND ${CMAKE_COMMAND} -P ${PROJECT_SOURCE_DIR}/embed.cmake ${PROJECT_SOURCE_DIR}/src/assets/index.html ${CMAKE_CURRENT_BINARY_DIR}/httpserver_extension/http_handler/playground.hpp playgroundContent
2020
DEPENDS ${PROJECT_SOURCE_DIR}/src/assets/index.html
2121
)
2222

2323
set(EXTENSION_SOURCES
2424
src/httpserver_extension.cpp
25+
src/http_handler/authentication.cpp
26+
src/http_handler/bindings.cpp
27+
src/http_handler/handler.cpp
28+
src/http_handler/response_serializer.cpp
2529
${CMAKE_CURRENT_BINARY_DIR}/playground.hpp
2630
)
2731

@@ -37,7 +41,9 @@ build_loadable_extension(${TARGET_NAME} " " ${EXTENSION_SOURCES})
3741

3842
include_directories(${OPENSSL_INCLUDE_DIR})
3943
target_link_libraries(${LOADABLE_EXTENSION_NAME} duckdb_mbedtls ${OPENSSL_LIBRARIES})
44+
set_property(TARGET ${LOADABLE_EXTENSION_NAME} PROPERTY CXX_STANDARD 17)
4045
target_link_libraries(${EXTENSION_NAME} duckdb_mbedtls ${OPENSSL_LIBRARIES})
46+
set_property(TARGET ${EXTENSION_NAME} PROPERTY CXX_STANDARD 17)
4147

4248
if(MINGW)
4349
set(WIN_LIBS crypt32 ws2_32 wsock32)

src/http_handler/authentication.cpp

+55
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
#include "httpserver_extension/http_handler/common.hpp"
2+
#include "httpserver_extension/state.hpp"
3+
#include <string>
4+
#include <vector>
5+
6+
#define CPPHTTPLIB_OPENSSL_SUPPORT
7+
#include "httplib.hpp"
8+
9+
namespace duckdb_httpserver {
10+
11+
// Base64 decoding function
12+
static std::string base64_decode(const std::string &in) {
13+
std::string out;
14+
std::vector<int> T(256, -1);
15+
for (int i = 0; i < 64; i++)
16+
T["ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/"[i]] = i;
17+
18+
int val = 0, valb = -8;
19+
for (unsigned char c : in) {
20+
if (T[c] == -1) break;
21+
val = (val << 6) + T[c];
22+
valb += 6;
23+
if (valb >= 0) {
24+
out.push_back(char((val >> valb) & 0xFF));
25+
valb -= 8;
26+
}
27+
}
28+
return out;
29+
}
30+
31+
// Check authentication
32+
void CheckAuthentication(const duckdb_httplib_openssl::Request& req) {
33+
if (global_state.auth_token.empty()) {
34+
return; // No authentication required if no token is set
35+
}
36+
37+
// Check for X-API-Key header
38+
auto api_key = req.get_header_value("X-API-Key");
39+
if (!api_key.empty() && api_key == global_state.auth_token) {
40+
return;
41+
}
42+
43+
// Check for Basic Auth
44+
auto auth = req.get_header_value("Authorization");
45+
if (!auth.empty() && auth.compare(0, 6, "Basic ") == 0) {
46+
std::string decoded_auth = base64_decode(auth.substr(6));
47+
if (decoded_auth == global_state.auth_token) {
48+
return;
49+
}
50+
}
51+
52+
throw HttpHandlerException(401, "Unauthorized");
53+
}
54+
55+
} // namespace duckdb_httpserver

src/http_handler/bindings.cpp

+69
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
#include "httpserver_extension/http_handler/common.hpp"
2+
#include "duckdb.hpp"
3+
#include "yyjson.hpp"
4+
#include <string>
5+
6+
#define CPPHTTPLIB_OPENSSL_SUPPORT
7+
#include "httplib.hpp"
8+
9+
using namespace duckdb;
10+
using namespace duckdb_yyjson;
11+
12+
namespace duckdb_httpserver {
13+
14+
static BoundParameterData ExtractQueryParameter(const std::string& key, yyjson_val* parameterVal) {
15+
if (!yyjson_is_obj(parameterVal)) {
16+
throw HttpHandlerException(400, "The parameter `" + key + "` must be an object");
17+
}
18+
19+
auto typeVal = yyjson_obj_get(parameterVal, "type");
20+
if (!typeVal) {
21+
throw HttpHandlerException(400, "The parameter `" + key + "` does not have a `type` field");
22+
}
23+
if (!yyjson_is_str(typeVal)) {
24+
throw HttpHandlerException(400, "The field `type` for the parameter `" + key + "` must be a string");
25+
}
26+
auto type = std::string(yyjson_get_str(typeVal));
27+
28+
auto valueVal = yyjson_obj_get(parameterVal, "value");
29+
if (!valueVal) {
30+
throw HttpHandlerException(400, "The parameter `" + key + "` does not have a `value` field");
31+
}
32+
33+
if (type == "TEXT") {
34+
if (!yyjson_is_str(valueVal)) {
35+
throw HttpHandlerException(400, "The field `value` for the parameter `" + key + "` must be a string");
36+
}
37+
38+
return BoundParameterData(Value(yyjson_get_str(valueVal)));
39+
}
40+
else if (type == "BOOLEAN") {
41+
if (!yyjson_is_bool(valueVal)) {
42+
throw HttpHandlerException(400, "The field `value` for the parameter `" + key + "` must be a boolean");
43+
}
44+
45+
return BoundParameterData(Value(bool(yyjson_get_bool(valueVal))));
46+
}
47+
48+
throw HttpHandlerException(400, "Unsupported type " + type + " the parameter `" + key + "`");
49+
}
50+
51+
case_insensitive_map_t<BoundParameterData> ExtractQueryParameters(yyjson_val* parametersVal) {
52+
if (!parametersVal || !yyjson_is_obj(parametersVal)) {
53+
throw HttpHandlerException(400, "The `parameters` field must be an object");
54+
}
55+
56+
case_insensitive_map_t<BoundParameterData> named_values;
57+
58+
size_t idx, max;
59+
yyjson_val *parameterKeyVal, *parameterVal;
60+
yyjson_obj_foreach(parametersVal, idx, max, parameterKeyVal, parameterVal) {
61+
auto parameterKeyString = std::string(yyjson_get_str(parameterKeyVal));
62+
63+
named_values[parameterKeyString] = ExtractQueryParameter(parameterKeyString, parameterVal);
64+
}
65+
66+
return named_values;
67+
}
68+
69+
} // namespace duckdb_httpserver

src/http_handler/handler.cpp

+220
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,220 @@
1+
#include "httpserver_extension/http_handler/authentication.hpp"
2+
#include "httpserver_extension/http_handler/bindings.hpp"
3+
#include "httpserver_extension/http_handler/common.hpp"
4+
#include "httpserver_extension/http_handler/handler.hpp"
5+
#include "httpserver_extension/http_handler/playground.hpp"
6+
#include "httpserver_extension/http_handler/response_serializer.hpp"
7+
#include "httpserver_extension/state.hpp"
8+
#include "duckdb.hpp"
9+
#include "yyjson.hpp"
10+
11+
#include <string>
12+
#include <vector>
13+
14+
#define CPPHTTPLIB_OPENSSL_SUPPORT
15+
#include "httplib.hpp"
16+
17+
using namespace duckdb;
18+
using namespace duckdb_yyjson;
19+
20+
namespace duckdb_httpserver {
21+
22+
// Handle both GET and POST requests
23+
void HttpHandler(const duckdb_httplib_openssl::Request& req, duckdb_httplib_openssl::Response& res) {
24+
try {
25+
// CORS allow
26+
res.set_header("Access-Control-Allow-Origin", "*");
27+
res.set_header("Access-Control-Allow-Methods", "GET, POST, OPTIONS, PUT");
28+
res.set_header("Access-Control-Allow-Headers", "*");
29+
res.set_header("Access-Control-Allow-Credentials", "true");
30+
res.set_header("Access-Control-Max-Age", "86400");
31+
32+
// Handle preflight OPTIONS request
33+
if (req.method == "OPTIONS") {
34+
res.status = 204; // No content
35+
return;
36+
}
37+
38+
CheckAuthentication(req);
39+
40+
auto queryApiParameters = ExtractQueryApiParameters(req);
41+
42+
if (!queryApiParameters.sqlQueryOpt.has_value()) {
43+
res.status = 200;
44+
res.set_content(reinterpret_cast<char const*>(playgroundContent), sizeof(playgroundContent), "text/html");
45+
return;
46+
}
47+
48+
if (!global_state.db_instance) {
49+
throw IOException("Database instance not initialized");
50+
}
51+
52+
auto start = std::chrono::system_clock::now();
53+
auto result = ExecuteQuery(req, queryApiParameters);
54+
auto end = std::chrono::system_clock::now();
55+
auto elapsed = std::chrono::duration_cast<std::chrono::milliseconds>(end - start);
56+
57+
QueryExecStats stats {
58+
static_cast<float>(elapsed.count()) / 1000,
59+
0,
60+
0
61+
};
62+
63+
// Format output
64+
if (queryApiParameters.outputFormat == OutputFormat::Ndjson) {
65+
std::string json_output = ConvertResultToNDJSON(*result);
66+
res.set_content(json_output, "application/x-ndjson");
67+
}
68+
else {
69+
auto json_output = ConvertResultToJSON(*result, stats);
70+
res.set_content(json_output, "application/json");
71+
}
72+
}
73+
catch (const HttpHandlerException& ex) {
74+
res.status = ex.status;
75+
res.set_content(ex.message, "text/plain");
76+
}
77+
catch (const std::exception& ex) {
78+
res.status = 500;
79+
std::string error_message = "Code: 59, e.displayText() = DB::Exception: " + std::string(ex.what());
80+
res.set_content(error_message, "text/plain");
81+
}
82+
}
83+
84+
// Execute query (optionally using a prepared statement)
85+
std::unique_ptr<MaterializedQueryResult> ExecuteQuery(
86+
const duckdb_httplib_openssl::Request& req,
87+
const QueryApiParameters& queryApiParameters
88+
) {
89+
Connection con(*global_state.db_instance);
90+
std::unique_ptr<MaterializedQueryResult> result;
91+
auto query = queryApiParameters.sqlQueryOpt.value();
92+
93+
auto use_prepared_stmt =
94+
queryApiParameters.sqlParametersOpt.has_value() &&
95+
queryApiParameters.sqlParametersOpt.value().empty() == false;
96+
97+
if (use_prepared_stmt) {
98+
auto prepared_stmt = con.Prepare(query);
99+
if (prepared_stmt->HasError()) {
100+
throw HttpHandlerException(500, prepared_stmt->GetError());
101+
}
102+
103+
auto named_values = queryApiParameters.sqlParametersOpt.value();
104+
105+
auto prepared_stmt_result = prepared_stmt->Execute(named_values);
106+
D_ASSERT(prepared_stmt_result->type == QueryResultType::STREAM_RESULT);
107+
result = unique_ptr_cast<QueryResult, StreamQueryResult>(std::move(prepared_stmt_result))->Materialize();
108+
} else {
109+
result = con.Query(query);
110+
}
111+
112+
if (result->HasError()) {
113+
throw HttpHandlerException(500, result->GetError());
114+
}
115+
116+
return result;
117+
}
118+
119+
QueryApiParameters ExtractQueryApiParameters(const duckdb_httplib_openssl::Request& req) {
120+
if (req.method == "POST" && req.has_header("Content-Type") && req.get_header_value("Content-Type") == "application/json") {
121+
return ExtractQueryApiParametersComplex(req);
122+
}
123+
else {
124+
return QueryApiParameters {
125+
ExtractSqlQuerySimple(req),
126+
std::nullopt,
127+
ExtractOutputFormatSimple(req),
128+
};
129+
}
130+
}
131+
132+
std::optional<std::string> ExtractSqlQuerySimple(const duckdb_httplib_openssl::Request& req) {
133+
// Check if the query is in the URL parameters
134+
if (req.has_param("query")) {
135+
return req.get_param_value("query");
136+
}
137+
else if (req.has_param("q")) {
138+
return req.get_param_value("q");
139+
}
140+
141+
// If not in URL, and it's a POST request, check the body
142+
else if (req.method == "POST" && !req.body.empty()) {
143+
return req.body;
144+
}
145+
146+
return std::nullopt;
147+
}
148+
149+
OutputFormat ExtractOutputFormatSimple(const duckdb_httplib_openssl::Request& req) {
150+
// Check for format in URL parameter or header
151+
if (req.has_param("default_format")) {
152+
return ParseOutputFormat(req.get_param_value("default_format"));
153+
}
154+
else if (req.has_header("X-ClickHouse-Format")) {
155+
return ParseOutputFormat(req.get_header_value("X-ClickHouse-Format"));
156+
}
157+
else if (req.has_header("format")) {
158+
return ParseOutputFormat(req.get_header_value("format"));
159+
}
160+
else {
161+
return OutputFormat::Ndjson;
162+
}
163+
}
164+
165+
OutputFormat ParseOutputFormat(const std::string& formatStr) {
166+
if (formatStr == "JSONEachRow" || formatStr == "ndjson" || formatStr == "jsonl") {
167+
return OutputFormat::Ndjson;
168+
}
169+
else if (formatStr == "JSONCompact") {
170+
return OutputFormat::Json;
171+
}
172+
else {
173+
throw HttpHandlerException(400, "Unknown format");
174+
}
175+
}
176+
177+
QueryApiParameters ExtractQueryApiParametersComplex(const duckdb_httplib_openssl::Request& req) {
178+
yyjson_doc *bodyDoc = nullptr;
179+
180+
try {
181+
auto bodyJson = req.body;
182+
auto bodyJsonCStr = bodyJson.c_str();
183+
bodyDoc = yyjson_read(bodyJsonCStr, strlen(bodyJsonCStr), 0);
184+
185+
return ExtractQueryApiParametersComplexImpl(bodyDoc);
186+
}
187+
catch (const std::exception& exception) {
188+
yyjson_doc_free(bodyDoc);
189+
throw;
190+
}
191+
}
192+
193+
QueryApiParameters ExtractQueryApiParametersComplexImpl(yyjson_doc* bodyDoc) {
194+
if (!bodyDoc) {
195+
throw HttpHandlerException(400, "Unable to parse the request body");
196+
}
197+
198+
auto bodyRoot = yyjson_doc_get_root(bodyDoc);
199+
if (!yyjson_is_obj(bodyRoot)) {
200+
throw HttpHandlerException(400, "The request body must be an object");
201+
}
202+
203+
auto queryVal = yyjson_obj_get(bodyRoot, "query");
204+
if (!queryVal || !yyjson_is_str(queryVal)) {
205+
throw HttpHandlerException(400, "The `query` field must be a string");
206+
}
207+
208+
auto formatVal = yyjson_obj_get(bodyRoot, "format");
209+
if (!formatVal || !yyjson_is_str(formatVal)) {
210+
throw HttpHandlerException(400, "The `format` field must be a string");
211+
}
212+
213+
return QueryApiParameters {
214+
std::string(yyjson_get_str(queryVal)),
215+
ExtractQueryParameters(yyjson_obj_get(bodyRoot, "parameters")),
216+
ParseOutputFormat(std::string(yyjson_get_str(formatVal))),
217+
};
218+
}
219+
220+
} // namespace duckdb_httpserver

0 commit comments

Comments
 (0)