@@ -40,6 +40,13 @@ struct HttpServerState {
40
40
41
41
static HttpServerState global_state;
42
42
43
+ struct HttpServerException : public std ::exception {
44
+ int status;
45
+ std::string message;
46
+
47
+ HttpServerException (int status, const std::string& message) : message(message), status(status) {}
48
+ };
49
+
43
50
std::string GetColumnType (MaterializedQueryResult &result, idx_t column) {
44
51
if (result.RowCount () == 0 ) {
45
52
return " String" ;
@@ -152,28 +159,28 @@ std::string base64_decode(const std::string &in) {
152
159
return out;
153
160
}
154
161
155
- // Auth Check
156
- bool IsAuthenticated (const duckdb_httplib_openssl::Request& req) {
162
+ // Check authentication
163
+ void CheckAuthentication (const duckdb_httplib_openssl::Request& req) {
157
164
if (global_state.auth_token .empty ()) {
158
- return true ; // No authentication required if no token is set
165
+ return ; // No authentication required if no token is set
159
166
}
160
167
161
168
// Check for X-API-Key header
162
169
auto api_key = req.get_header_value (" X-API-Key" );
163
170
if (!api_key.empty () && api_key == global_state.auth_token ) {
164
- return true ;
171
+ return ;
165
172
}
166
173
167
174
// Check for Basic Auth
168
175
auto auth = req.get_header_value (" Authorization" );
169
176
if (!auth.empty () && auth.compare (0 , 6 , " Basic " ) == 0 ) {
170
177
std::string decoded_auth = base64_decode (auth.substr (6 ));
171
178
if (decoded_auth == global_state.auth_token ) {
172
- return true ;
179
+ return ;
173
180
}
174
181
}
175
182
176
- return false ;
183
+ throw HttpServerException ( 401 , " Unauthorized " ) ;
177
184
}
178
185
179
186
// Convert the query result to NDJSON (JSONEachRow) format
@@ -217,49 +224,131 @@ static std::string ConvertResultToNDJSON(MaterializedQueryResult &result) {
217
224
return ndjson_output;
218
225
}
219
226
220
- // Handle both GET and POST requests
221
- void HandleHttpRequest (const duckdb_httplib_openssl::Request& req, duckdb_httplib_openssl::Response& res) {
222
- std::string query;
227
+ BoundParameterData ExtractQueryParameter (const std::string& key, yyjson_val* parameterVal) {
228
+ if (!yyjson_is_obj (parameterVal)) {
229
+ throw HttpServerException (400 , " The parameter `" + key + " ` parameter must be an object" );
230
+ }
223
231
224
- // Check authentication
225
- if (!IsAuthenticated (req)) {
226
- res.status = 401 ;
227
- res.set_content (" Unauthorized" , " text/plain" );
228
- return ;
232
+ auto typeVal = yyjson_obj_get (parameterVal, " type" );
233
+ if (!typeVal) {
234
+ throw HttpServerException (400 , " The parameter `" + key + " ` does not have a `type` field" );
235
+ }
236
+ if (!yyjson_is_str (typeVal)) {
237
+ throw HttpServerException (400 , " The field `type` for the parameter `" + key + " ` must be a string" );
238
+ }
239
+ auto type = std::string (yyjson_get_str (typeVal));
240
+
241
+ auto valueVal = yyjson_obj_get (parameterVal, " value" );
242
+ if (!valueVal) {
243
+ throw HttpServerException (400 , " The parameter `" + key + " ` does not have a `value` field" );
229
244
}
230
245
231
- // CORS allow
232
- res.set_header (" Access-Control-Allow-Origin" , " *" );
233
- res.set_header (" Access-Control-Allow-Methods" , " GET, POST, OPTIONS, PUT" );
234
- res.set_header (" Access-Control-Allow-Headers" , " *" );
235
- res.set_header (" Access-Control-Allow-Credentials" , " true" );
236
- res.set_header (" Access-Control-Max-Age" , " 86400" );
246
+ if (type == " TEXT" ) {
247
+ if (!yyjson_is_str (valueVal)) {
248
+ throw HttpServerException (400 , " The field `value` for the parameter `" + key + " ` must be a string" );
249
+ }
237
250
238
- // Handle preflight OPTIONS request
239
- if (req.method == " OPTIONS" ) {
240
- res.status = 204 ; // No content
241
- return ;
251
+ return BoundParameterData (Value (yyjson_get_str (valueVal)));
252
+ }
253
+ else if (type == " BOOLEAN" ) {
254
+ if (!yyjson_is_bool (valueVal)) {
255
+ throw HttpServerException (400 , " The field `value` for the parameter `" + key + " ` must be a boolean" );
256
+ }
257
+
258
+ return BoundParameterData (Value (bool (yyjson_get_bool (valueVal))));
259
+ }
260
+
261
+ throw HttpServerException (400 , " Unsupported type " + type + " the parameter `" + key + " `" );
262
+ }
263
+
264
+ case_insensitive_map_t <BoundParameterData> ExtractQueryParameters (yyjson_doc* parametersDoc) {
265
+ if (!parametersDoc) {
266
+ throw HttpServerException (400 , " Unable to parse the `parameters` parameter" );
267
+ }
268
+
269
+ auto parametersRoot = yyjson_doc_get_root (parametersDoc);
270
+ if (!yyjson_is_obj (parametersRoot)) {
271
+ throw HttpServerException (400 , " The `parameters` parameter must be an object" );
272
+ }
273
+
274
+ case_insensitive_map_t <BoundParameterData> named_values;
275
+
276
+ size_t idx, max;
277
+ yyjson_val *parameterKeyVal, *parameterVal;
278
+ yyjson_obj_foreach (parametersRoot, idx, max, parameterKeyVal, parameterVal) {
279
+ auto parameterKeyString = std::string (yyjson_get_str (parameterKeyVal));
280
+
281
+ named_values[parameterKeyString] = ExtractQueryParameter (parameterKeyString, parameterVal);
282
+ }
283
+
284
+ return named_values;
285
+ }
286
+
287
+ case_insensitive_map_t <BoundParameterData> ExtractQueryParametersWrapper (const duckdb_httplib_openssl::Request& req) {
288
+ yyjson_doc *parametersDoc = nullptr ;
289
+
290
+ try {
291
+ auto parametersJson = req.get_param_value (" parameters" );
292
+ auto parametersJsonCStr = parametersJson.c_str ();
293
+ parametersDoc = yyjson_read (parametersJsonCStr, strlen (parametersJsonCStr), 0 );
294
+ return ExtractQueryParameters (parametersDoc);
295
+ }
296
+ catch (const Exception& exception ) {
297
+ yyjson_doc_free (parametersDoc);
298
+
299
+ throw exception ;
300
+ }
301
+ }
302
+
303
+ // Execute query (optionally using a prepared statement)
304
+ std::unique_ptr<MaterializedQueryResult> ExecuteQuery (
305
+ const duckdb_httplib_openssl::Request& req,
306
+ const std::string& query
307
+ ) {
308
+ Connection con (*global_state.db_instance );
309
+ std::unique_ptr<MaterializedQueryResult> result;
310
+
311
+ if (req.has_param (" parameters" )) {
312
+ auto prepared_stmt = con.Prepare (query);
313
+ if (prepared_stmt->HasError ()) {
314
+ throw HttpServerException (500 , prepared_stmt->GetError ());
315
+ }
316
+
317
+ auto named_values = ExtractQueryParametersWrapper (req);
318
+
319
+ auto prepared_stmt_result = prepared_stmt->Execute (named_values);
320
+ D_ASSERT (prepared_stmt_result->type == QueryResultType::STREAM_RESULT);
321
+ result = unique_ptr_cast<QueryResult, StreamQueryResult>(std::move (prepared_stmt_result))->Materialize ();
322
+ } else {
323
+ result = con.Query (query);
242
324
}
243
325
326
+ if (result->HasError ()) {
327
+ throw HttpServerException (500 , result->GetError ());
328
+ }
329
+
330
+ return result;
331
+ }
332
+
333
+ std::string ExtractQuery (const duckdb_httplib_openssl::Request& req) {
244
334
// Check if the query is in the URL parameters
245
335
if (req.has_param (" query" )) {
246
- query = req.get_param_value (" query" );
336
+ return req.get_param_value (" query" );
247
337
}
248
338
else if (req.has_param (" q" )) {
249
- query = req.get_param_value (" q" );
339
+ return req.get_param_value (" q" );
250
340
}
341
+
251
342
// If not in URL, and it's a POST request, check the body
252
343
else if (req.method == " POST" && !req.body .empty ()) {
253
- query = req.body ;
254
- }
255
- // If no query found, return an error
256
- else {
257
- res.status = 200 ;
258
- res.set_content (reinterpret_cast <char const *>(playgroundContent), " text/html" );
259
- return ;
344
+ return req.body ;
260
345
}
261
346
262
- // Set default format to JSONCompact
347
+ // std::optional is not available for this project
348
+ return " " ;
349
+ }
350
+
351
+ std::string ExtractFormat (const duckdb_httplib_openssl::Request& req) {
263
352
std::string format = " JSONEachRow" ;
264
353
265
354
// Check for format in URL parameter or header
@@ -271,24 +360,45 @@ void HandleHttpRequest(const duckdb_httplib_openssl::Request& req, duckdb_httpli
271
360
format = req.get_header_value (" format" );
272
361
}
273
362
363
+ return format;
364
+ }
365
+
366
+ // Handle both GET and POST requests
367
+ void HandleHttpRequest (const duckdb_httplib_openssl::Request& req, duckdb_httplib_openssl::Response& res) {
274
368
try {
369
+ CheckAuthentication (req);
370
+
371
+ // CORS allow
372
+ res.set_header (" Access-Control-Allow-Origin" , " *" );
373
+ res.set_header (" Access-Control-Allow-Methods" , " GET, POST, OPTIONS, PUT" );
374
+ res.set_header (" Access-Control-Allow-Headers" , " *" );
375
+ res.set_header (" Access-Control-Allow-Credentials" , " true" );
376
+ res.set_header (" Access-Control-Max-Age" , " 86400" );
377
+
378
+ // Handle preflight OPTIONS request
379
+ if (req.method == " OPTIONS" ) {
380
+ res.status = 204 ; // No content
381
+ return ;
382
+ }
383
+
384
+ auto query = ExtractQuery (req);
385
+ auto format = ExtractFormat (req);
386
+
387
+ if (query == " " ) {
388
+ res.status = 200 ;
389
+ res.set_content (reinterpret_cast <char const *>(playgroundContent), sizeof (playgroundContent), " text/html" );
390
+ return ;
391
+ }
392
+
275
393
if (!global_state.db_instance ) {
276
394
throw IOException (" Database instance not initialized" );
277
395
}
278
396
279
- Connection con (*global_state.db_instance );
280
397
auto start = std::chrono::system_clock::now ();
281
- auto result = con. Query ( query);
398
+ auto result = ExecuteQuery (req, query);
282
399
auto end = std::chrono::system_clock::now ();
283
400
auto elapsed = std::chrono::duration_cast<std::chrono::milliseconds>(end - start);
284
401
285
- if (result->HasError ()) {
286
- res.status = 500 ;
287
- res.set_content (result->GetError (), " text/plain" );
288
- return ;
289
- }
290
-
291
-
292
402
ReqStats stats{
293
403
static_cast <float >(elapsed.count ()) / 1000 ,
294
404
0 ,
@@ -308,7 +418,12 @@ void HandleHttpRequest(const duckdb_httplib_openssl::Request& req, duckdb_httpli
308
418
res.set_content (json_output, " application/x-ndjson" );
309
419
}
310
420
311
- } catch (const Exception& ex) {
421
+ }
422
+ catch (const HttpServerException& ex) {
423
+ res.status = ex.status ;
424
+ res.set_content (ex.message , " text/plain" );
425
+ }
426
+ catch (const Exception& ex) {
312
427
res.status = 500 ;
313
428
std::string error_message = " Code: 59, e.displayText() = DB::Exception: " + std::string (ex.what ());
314
429
res.set_content (error_message, " text/plain" );
@@ -325,9 +440,9 @@ void HttpServerStart(DatabaseInstance& db, string_t host, int32_t port, string_t
325
440
global_state.is_running = true ;
326
441
global_state.auth_token = auth.GetString ();
327
442
328
- // Custom basepath, defaults to root /
443
+ // Custom basepath, defaults to root /
329
444
const char * base_path_env = std::getenv (" DUCKDB_HTTPSERVER_BASEPATH" );
330
- std::string base_path = " /" ;
445
+ std::string base_path = " /" ;
331
446
332
447
if (base_path_env && base_path_env[0 ] == ' /' && strlen (base_path_env) > 1 ) {
333
448
base_path = std::string (base_path_env);
0 commit comments