Skip to content

Commit 8310c77

Browse files
committed
Support prepared statements and parameters
1 parent bd381c1 commit 8310c77

File tree

5 files changed

+411
-69
lines changed

5 files changed

+411
-69
lines changed

src/httpserver_extension.cpp

+161-46
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,13 @@ struct HttpServerState {
4040

4141
static HttpServerState global_state;
4242

43+
struct HttpServerException: public std::exception {
44+
int status;
45+
std::string message;
46+
47+
HttpServerException(int status, const std::string& message) : message(message), status(status) {}
48+
};
49+
4350
std::string GetColumnType(MaterializedQueryResult &result, idx_t column) {
4451
if (result.RowCount() == 0) {
4552
return "String";
@@ -152,28 +159,28 @@ std::string base64_decode(const std::string &in) {
152159
return out;
153160
}
154161

155-
// Auth Check
156-
bool IsAuthenticated(const duckdb_httplib_openssl::Request& req) {
162+
// Check authentication
163+
void CheckAuthentication(const duckdb_httplib_openssl::Request& req) {
157164
if (global_state.auth_token.empty()) {
158-
return true; // No authentication required if no token is set
165+
return; // No authentication required if no token is set
159166
}
160167

161168
// Check for X-API-Key header
162169
auto api_key = req.get_header_value("X-API-Key");
163170
if (!api_key.empty() && api_key == global_state.auth_token) {
164-
return true;
171+
return;
165172
}
166173

167174
// Check for Basic Auth
168175
auto auth = req.get_header_value("Authorization");
169176
if (!auth.empty() && auth.compare(0, 6, "Basic ") == 0) {
170177
std::string decoded_auth = base64_decode(auth.substr(6));
171178
if (decoded_auth == global_state.auth_token) {
172-
return true;
179+
return;
173180
}
174181
}
175182

176-
return false;
183+
throw HttpServerException(401, "Unauthorized");
177184
}
178185

179186
// Convert the query result to NDJSON (JSONEachRow) format
@@ -217,49 +224,131 @@ static std::string ConvertResultToNDJSON(MaterializedQueryResult &result) {
217224
return ndjson_output;
218225
}
219226

220-
// Handle both GET and POST requests
221-
void HandleHttpRequest(const duckdb_httplib_openssl::Request& req, duckdb_httplib_openssl::Response& res) {
222-
std::string query;
227+
BoundParameterData ExtractQueryParameter(const std::string& key, yyjson_val* parameterVal) {
228+
if (!yyjson_is_obj(parameterVal)) {
229+
throw HttpServerException(400, "The parameter `" + key + "` parameter must be an object");
230+
}
223231

224-
// Check authentication
225-
if (!IsAuthenticated(req)) {
226-
res.status = 401;
227-
res.set_content("Unauthorized", "text/plain");
228-
return;
232+
auto typeVal = yyjson_obj_get(parameterVal, "type");
233+
if (!typeVal) {
234+
throw HttpServerException(400, "The parameter `" + key + "` does not have a `type` field");
235+
}
236+
if (!yyjson_is_str(typeVal)) {
237+
throw HttpServerException(400, "The field `type` for the parameter `" + key + "` must be a string");
238+
}
239+
auto type = std::string(yyjson_get_str(typeVal));
240+
241+
auto valueVal = yyjson_obj_get(parameterVal, "value");
242+
if (!valueVal) {
243+
throw HttpServerException(400, "The parameter `" + key + "` does not have a `value` field");
229244
}
230245

231-
// CORS allow
232-
res.set_header("Access-Control-Allow-Origin", "*");
233-
res.set_header("Access-Control-Allow-Methods", "GET, POST, OPTIONS, PUT");
234-
res.set_header("Access-Control-Allow-Headers", "*");
235-
res.set_header("Access-Control-Allow-Credentials", "true");
236-
res.set_header("Access-Control-Max-Age", "86400");
246+
if (type == "TEXT") {
247+
if (!yyjson_is_str(valueVal)) {
248+
throw HttpServerException(400, "The field `value` for the parameter `" + key + "` must be a string");
249+
}
237250

238-
// Handle preflight OPTIONS request
239-
if (req.method == "OPTIONS") {
240-
res.status = 204; // No content
241-
return;
251+
return BoundParameterData(Value(yyjson_get_str(valueVal)));
252+
}
253+
else if (type == "BOOLEAN") {
254+
if (!yyjson_is_bool(valueVal)) {
255+
throw HttpServerException(400, "The field `value` for the parameter `" + key + "` must be a boolean");
256+
}
257+
258+
return BoundParameterData(Value(bool(yyjson_get_bool(valueVal))));
259+
}
260+
261+
throw HttpServerException(400, "Unsupported type " + type + " the parameter `" + key + "`");
262+
}
263+
264+
case_insensitive_map_t<BoundParameterData> ExtractQueryParameters(yyjson_doc* parametersDoc) {
265+
if (!parametersDoc) {
266+
throw HttpServerException(400, "Unable to parse the `parameters` parameter");
267+
}
268+
269+
auto parametersRoot = yyjson_doc_get_root(parametersDoc);
270+
if (!yyjson_is_obj(parametersRoot)) {
271+
throw HttpServerException(400, "The `parameters` parameter must be an object");
272+
}
273+
274+
case_insensitive_map_t<BoundParameterData> named_values;
275+
276+
size_t idx, max;
277+
yyjson_val *parameterKeyVal, *parameterVal;
278+
yyjson_obj_foreach(parametersRoot, idx, max, parameterKeyVal, parameterVal) {
279+
auto parameterKeyString = std::string(yyjson_get_str(parameterKeyVal));
280+
281+
named_values[parameterKeyString] = ExtractQueryParameter(parameterKeyString, parameterVal);
282+
}
283+
284+
return named_values;
285+
}
286+
287+
case_insensitive_map_t<BoundParameterData> ExtractQueryParametersWrapper(const duckdb_httplib_openssl::Request& req) {
288+
yyjson_doc *parametersDoc = nullptr;
289+
290+
try {
291+
auto parametersJson = req.get_param_value("parameters");
292+
auto parametersJsonCStr = parametersJson.c_str();
293+
parametersDoc = yyjson_read(parametersJsonCStr, strlen(parametersJsonCStr), 0);
294+
return ExtractQueryParameters(parametersDoc);
295+
}
296+
catch (const Exception& exception) {
297+
yyjson_doc_free(parametersDoc);
298+
299+
throw exception;
300+
}
301+
}
302+
303+
// Execute query (optionally using a prepared statement)
304+
std::unique_ptr<MaterializedQueryResult> ExecuteQuery(
305+
const duckdb_httplib_openssl::Request& req,
306+
const std::string& query
307+
) {
308+
Connection con(*global_state.db_instance);
309+
std::unique_ptr<MaterializedQueryResult> result;
310+
311+
if (req.has_param("parameters")) {
312+
auto prepared_stmt = con.Prepare(query);
313+
if (prepared_stmt->HasError()) {
314+
throw HttpServerException(500, prepared_stmt->GetError());
315+
}
316+
317+
auto named_values = ExtractQueryParametersWrapper(req);
318+
319+
auto prepared_stmt_result = prepared_stmt->Execute(named_values);
320+
D_ASSERT(prepared_stmt_result->type == QueryResultType::STREAM_RESULT);
321+
result = unique_ptr_cast<QueryResult, StreamQueryResult>(std::move(prepared_stmt_result))->Materialize();
322+
} else {
323+
result = con.Query(query);
242324
}
243325

326+
if (result->HasError()) {
327+
throw HttpServerException(500, result->GetError());
328+
}
329+
330+
return result;
331+
}
332+
333+
std::string ExtractQuery(const duckdb_httplib_openssl::Request& req) {
244334
// Check if the query is in the URL parameters
245335
if (req.has_param("query")) {
246-
query = req.get_param_value("query");
336+
return req.get_param_value("query");
247337
}
248338
else if (req.has_param("q")) {
249-
query = req.get_param_value("q");
339+
return req.get_param_value("q");
250340
}
341+
251342
// If not in URL, and it's a POST request, check the body
252343
else if (req.method == "POST" && !req.body.empty()) {
253-
query = req.body;
254-
}
255-
// If no query found, return an error
256-
else {
257-
res.status = 200;
258-
res.set_content(reinterpret_cast<char const*>(playgroundContent), "text/html");
259-
return;
344+
return req.body;
260345
}
261346

262-
// Set default format to JSONCompact
347+
// std::optional is not available for this project
348+
return "";
349+
}
350+
351+
std::string ExtractFormat(const duckdb_httplib_openssl::Request& req) {
263352
std::string format = "JSONEachRow";
264353

265354
// Check for format in URL parameter or header
@@ -271,24 +360,45 @@ void HandleHttpRequest(const duckdb_httplib_openssl::Request& req, duckdb_httpli
271360
format = req.get_header_value("format");
272361
}
273362

363+
return format;
364+
}
365+
366+
// Handle both GET and POST requests
367+
void HandleHttpRequest(const duckdb_httplib_openssl::Request& req, duckdb_httplib_openssl::Response& res) {
274368
try {
369+
CheckAuthentication(req);
370+
371+
// CORS allow
372+
res.set_header("Access-Control-Allow-Origin", "*");
373+
res.set_header("Access-Control-Allow-Methods", "GET, POST, OPTIONS, PUT");
374+
res.set_header("Access-Control-Allow-Headers", "*");
375+
res.set_header("Access-Control-Allow-Credentials", "true");
376+
res.set_header("Access-Control-Max-Age", "86400");
377+
378+
// Handle preflight OPTIONS request
379+
if (req.method == "OPTIONS") {
380+
res.status = 204; // No content
381+
return;
382+
}
383+
384+
auto query = ExtractQuery(req);
385+
auto format = ExtractFormat(req);
386+
387+
if (query == "") {
388+
res.status = 200;
389+
res.set_content(reinterpret_cast<char const*>(playgroundContent), sizeof(playgroundContent), "text/html");
390+
return;
391+
}
392+
275393
if (!global_state.db_instance) {
276394
throw IOException("Database instance not initialized");
277395
}
278396

279-
Connection con(*global_state.db_instance);
280397
auto start = std::chrono::system_clock::now();
281-
auto result = con.Query(query);
398+
auto result = ExecuteQuery(req, query);
282399
auto end = std::chrono::system_clock::now();
283400
auto elapsed = std::chrono::duration_cast<std::chrono::milliseconds>(end - start);
284401

285-
if (result->HasError()) {
286-
res.status = 500;
287-
res.set_content(result->GetError(), "text/plain");
288-
return;
289-
}
290-
291-
292402
ReqStats stats{
293403
static_cast<float>(elapsed.count()) / 1000,
294404
0,
@@ -308,7 +418,12 @@ void HandleHttpRequest(const duckdb_httplib_openssl::Request& req, duckdb_httpli
308418
res.set_content(json_output, "application/x-ndjson");
309419
}
310420

311-
} catch (const Exception& ex) {
421+
}
422+
catch (const HttpServerException& ex) {
423+
res.status = ex.status;
424+
res.set_content(ex.message, "text/plain");
425+
}
426+
catch (const Exception& ex) {
312427
res.status = 500;
313428
std::string error_message = "Code: 59, e.displayText() = DB::Exception: " + std::string(ex.what());
314429
res.set_content(error_message, "text/plain");
@@ -325,9 +440,9 @@ void HttpServerStart(DatabaseInstance& db, string_t host, int32_t port, string_t
325440
global_state.is_running = true;
326441
global_state.auth_token = auth.GetString();
327442

328-
// Custom basepath, defaults to root /
443+
// Custom basepath, defaults to root /
329444
const char* base_path_env = std::getenv("DUCKDB_HTTPSERVER_BASEPATH");
330-
std::string base_path = "/";
445+
std::string base_path = "/";
331446

332447
if (base_path_env && base_path_env[0] == '/' && strlen(base_path_env) > 1) {
333448
base_path = std::string(base_path_env);

0 commit comments

Comments
 (0)