Skip to content

Commit ac93207

Browse files
committed
Support prepared statements and parameters
1 parent a95d01a commit ac93207

File tree

1 file changed

+152
-44
lines changed

1 file changed

+152
-44
lines changed

src/httpserver_extension.cpp

+152-44
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,13 @@ struct HttpServerState {
4040

4141
static HttpServerState global_state;
4242

43+
struct HttpServerException: public std::exception {
44+
int status;
45+
std::string message;
46+
47+
HttpServerException(int status, const std::string& message) : message(message), status(status) {}
48+
};
49+
4350
std::string GetColumnType(MaterializedQueryResult &result, idx_t column) {
4451
if (result.RowCount() == 0) {
4552
return "String";
@@ -152,28 +159,28 @@ std::string base64_decode(const std::string &in) {
152159
return out;
153160
}
154161

155-
// Auth Check
156-
bool IsAuthenticated(const duckdb_httplib_openssl::Request& req) {
162+
// Check authentication
163+
void CheckAuthentication(const duckdb_httplib_openssl::Request& req) {
157164
if (global_state.auth_token.empty()) {
158-
return true; // No authentication required if no token is set
165+
return; // No authentication required if no token is set
159166
}
160167

161168
// Check for X-API-Key header
162169
auto api_key = req.get_header_value("X-API-Key");
163170
if (!api_key.empty() && api_key == global_state.auth_token) {
164-
return true;
171+
return;
165172
}
166173

167174
// Check for Basic Auth
168175
auto auth = req.get_header_value("Authorization");
169176
if (!auth.empty() && auth.compare(0, 6, "Basic ") == 0) {
170177
std::string decoded_auth = base64_decode(auth.substr(6));
171178
if (decoded_auth == global_state.auth_token) {
172-
return true;
179+
return;
173180
}
174181
}
175182

176-
return false;
183+
throw HttpServerException(401, "Unauthorized");
177184
}
178185

179186
// Convert the query result to NDJSON (JSONEachRow) format
@@ -217,49 +224,130 @@ static std::string ConvertResultToNDJSON(MaterializedQueryResult &result) {
217224
return ndjson_output;
218225
}
219226

220-
// Handle both GET and POST requests
221-
void HandleHttpRequest(const duckdb_httplib_openssl::Request& req, duckdb_httplib_openssl::Response& res) {
222-
std::string query;
227+
BoundParameterData ExtractQueryParameter(const std::string& key, yyjson_val* parameterVal) {
228+
if (!yyjson_is_obj(parameterVal)) {
229+
throw HttpServerException(400, "The parameter `" + key + "` parameter must be an object");
230+
}
223231

224-
// Check authentication
225-
if (!IsAuthenticated(req)) {
226-
res.status = 401;
227-
res.set_content("Unauthorized", "text/plain");
228-
return;
232+
auto typeVal = yyjson_obj_get(parameterVal, "type");
233+
if (!typeVal) {
234+
throw HttpServerException(400, "The parameter `" + key + "` does not have a `type` field");
229235
}
236+
if (!yyjson_is_str(typeVal)) {
237+
throw HttpServerException(400, "The field `type` for the parameter `" + key + "` must be a string");
238+
}
239+
auto type = std::string(yyjson_get_str(typeVal));
230240

231-
// CORS allow
232-
res.set_header("Access-Control-Allow-Origin", "*");
233-
res.set_header("Access-Control-Allow-Methods", "GET, POST, OPTIONS, PUT");
234-
res.set_header("Access-Control-Allow-Headers", "*");
235-
res.set_header("Access-Control-Allow-Credentials", "true");
236-
res.set_header("Access-Control-Max-Age", "86400");
241+
auto valueVal = yyjson_obj_get(parameterVal, "value");
242+
if (!valueVal) {
243+
throw HttpServerException(400, "The parameter `" + key + "` does not have a `value` field");
244+
}
237245

238-
// Handle preflight OPTIONS request
239-
if (req.method == "OPTIONS") {
240-
res.status = 204; // No content
241-
return;
246+
if (type == "TEXT") {
247+
if (!yyjson_is_str(valueVal)) {
248+
throw HttpServerException(400, "The field `value` for the parameter `" + key + "` must be a string");
249+
}
250+
251+
return BoundParameterData(Value(yyjson_get_str(valueVal)));
252+
}
253+
else if (type == "BOOLEAN") {
254+
if (!yyjson_is_bool(valueVal)) {
255+
throw HttpServerException(400, "The field `value` for the parameter `" + key + "` must be a boolean");
256+
}
257+
258+
return BoundParameterData(Value(bool(yyjson_get_bool(valueVal))));
259+
}
260+
261+
throw HttpServerException(400, "Unsupported type " + type + " the parameter `" + key + "`");
262+
}
263+
264+
case_insensitive_map_t<BoundParameterData> ExtractQueryParameters(yyjson_doc* parametersDoc) {
265+
if (!parametersDoc) {
266+
throw HttpServerException(400, "Unable to parse the `parameters` parameter");
267+
}
268+
269+
auto parametersRoot = yyjson_doc_get_root(parametersDoc);
270+
if (!yyjson_is_obj(parametersRoot)) {
271+
throw HttpServerException(400, "The `parameters` parameter must be an object");
242272
}
243273

274+
case_insensitive_map_t<BoundParameterData> named_values;
275+
276+
size_t idx, max;
277+
yyjson_val *parameterKeyVal, *parameterVal;
278+
yyjson_obj_foreach(parametersRoot, idx, max, parameterKeyVal, parameterVal) {
279+
auto parameterKeyString = std::string(yyjson_get_str(parameterKeyVal));
280+
281+
named_values[parameterKeyString] = ExtractQueryParameter(parameterKeyString, parameterVal);
282+
}
283+
284+
return named_values;
285+
}
286+
287+
case_insensitive_map_t<BoundParameterData> ExtractQueryParametersWrapper(const duckdb_httplib_openssl::Request& req) {
288+
yyjson_doc *parametersDoc = nullptr;
289+
290+
try {
291+
auto parametersJson = req.get_param_value("parameters");
292+
auto parametersJsonCStr = parametersJson.c_str();
293+
parametersDoc = yyjson_read(parametersJsonCStr, strlen(parametersJsonCStr), 0);
294+
return ExtractQueryParameters(parametersDoc);
295+
}
296+
catch (const Exception& exception) {
297+
yyjson_doc_free(parametersDoc);
298+
299+
throw exception;
300+
}
301+
}
302+
303+
// Execute query (optionally using a prepared statement)
304+
std::unique_ptr<MaterializedQueryResult> ExecuteQuery(
305+
const duckdb_httplib_openssl::Request& req,
306+
const std::string& query
307+
) {
308+
Connection con(*global_state.db_instance);
309+
std::unique_ptr<MaterializedQueryResult> result;
310+
311+
if (req.has_param("parameters")) {
312+
auto prepared_stmt = con.Prepare(query);
313+
if (prepared_stmt->HasError()) {
314+
throw HttpServerException(500, prepared_stmt->GetError());
315+
}
316+
317+
auto named_values = ExtractQueryParametersWrapper(req);
318+
319+
auto prepared_stmt_result = prepared_stmt->Execute(named_values);
320+
D_ASSERT(prepared_stmt_result->type == QueryResultType::STREAM_RESULT);
321+
result = unique_ptr_cast<QueryResult, StreamQueryResult>(std::move(prepared_stmt_result))->Materialize();
322+
} else {
323+
result = con.Query(query);
324+
}
325+
326+
if (result->HasError()) {
327+
throw HttpServerException(500, result->GetError());
328+
}
329+
330+
return result;
331+
}
332+
333+
std::string ExtractQuery(const duckdb_httplib_openssl::Request& req) {
244334
// Check if the query is in the URL parameters
245335
if (req.has_param("query")) {
246-
query = req.get_param_value("query");
336+
return req.get_param_value("query");
247337
}
248338
else if (req.has_param("q")) {
249-
query = req.get_param_value("q");
339+
return req.get_param_value("q");
250340
}
341+
251342
// If not in URL, and it's a POST request, check the body
252343
else if (req.method == "POST" && !req.body.empty()) {
253-
query = req.body;
254-
}
255-
// If no query found, return an error
256-
else {
257-
res.status = 200;
258-
res.set_content(playContent, "text/html");
259-
return;
344+
return req.body;
260345
}
261346

262-
// Set default format to JSONCompact
347+
throw HttpServerException(200, playContent);
348+
}
349+
350+
std::string ExtractFormat(const duckdb_httplib_openssl::Request& req) {
263351
std::string format = "JSONEachRow";
264352

265353
// Check for format in URL parameter or header
@@ -271,24 +359,39 @@ void HandleHttpRequest(const duckdb_httplib_openssl::Request& req, duckdb_httpli
271359
format = req.get_header_value("format");
272360
}
273361

362+
return format;
363+
}
364+
365+
// Handle both GET and POST requests
366+
void HandleHttpRequest(const duckdb_httplib_openssl::Request& req, duckdb_httplib_openssl::Response& res) {
367+
CheckAuthentication(req);
368+
369+
// CORS allow
370+
res.set_header("Access-Control-Allow-Origin", "*");
371+
res.set_header("Access-Control-Allow-Methods", "GET, POST, OPTIONS, PUT");
372+
res.set_header("Access-Control-Allow-Headers", "*");
373+
res.set_header("Access-Control-Allow-Credentials", "true");
374+
res.set_header("Access-Control-Max-Age", "86400");
375+
376+
// Handle preflight OPTIONS request
377+
if (req.method == "OPTIONS") {
378+
res.status = 204; // No content
379+
return;
380+
}
381+
382+
auto query = ExtractQuery(req);
383+
auto format = ExtractFormat(req);
384+
274385
try {
275386
if (!global_state.db_instance) {
276387
throw IOException("Database instance not initialized");
277388
}
278389

279-
Connection con(*global_state.db_instance);
280390
auto start = std::chrono::system_clock::now();
281-
auto result = con.Query(query);
391+
auto result = ExecuteQuery(req, query);
282392
auto end = std::chrono::system_clock::now();
283393
auto elapsed = std::chrono::duration_cast<std::chrono::milliseconds>(end - start);
284394

285-
if (result->HasError()) {
286-
res.status = 500;
287-
res.set_content(result->GetError(), "text/plain");
288-
return;
289-
}
290-
291-
292395
ReqStats stats{
293396
static_cast<float>(elapsed.count()) / 1000,
294397
0,
@@ -308,7 +411,12 @@ void HandleHttpRequest(const duckdb_httplib_openssl::Request& req, duckdb_httpli
308411
res.set_content(json_output, "application/x-ndjson");
309412
}
310413

311-
} catch (const Exception& ex) {
414+
}
415+
catch (const HttpServerException& ex) {
416+
res.status = ex.status;
417+
res.set_content(ex.message, "text/plain");
418+
}
419+
catch (const Exception& ex) {
312420
res.status = 500;
313421
std::string error_message = "Code: 59, e.displayText() = DB::Exception: " + std::string(ex.what());
314422
res.set_content(error_message, "text/plain");

0 commit comments

Comments
 (0)