@@ -40,6 +40,13 @@ struct HttpServerState {
40
40
41
41
static HttpServerState global_state;
42
42
43
+ struct HttpServerException : public std ::exception {
44
+ int status;
45
+ std::string message;
46
+
47
+ HttpServerException (int status, const std::string& message) : message(message), status(status) {}
48
+ };
49
+
43
50
std::string GetColumnType (MaterializedQueryResult &result, idx_t column) {
44
51
if (result.RowCount () == 0 ) {
45
52
return " String" ;
@@ -152,28 +159,28 @@ std::string base64_decode(const std::string &in) {
152
159
return out;
153
160
}
154
161
155
- // Auth Check
156
- bool IsAuthenticated (const duckdb_httplib_openssl::Request& req) {
162
+ // Check authentication
163
+ void CheckAuthentication (const duckdb_httplib_openssl::Request& req) {
157
164
if (global_state.auth_token .empty ()) {
158
- return true ; // No authentication required if no token is set
165
+ return ; // No authentication required if no token is set
159
166
}
160
167
161
168
// Check for X-API-Key header
162
169
auto api_key = req.get_header_value (" X-API-Key" );
163
170
if (!api_key.empty () && api_key == global_state.auth_token ) {
164
- return true ;
171
+ return ;
165
172
}
166
173
167
174
// Check for Basic Auth
168
175
auto auth = req.get_header_value (" Authorization" );
169
176
if (!auth.empty () && auth.compare (0 , 6 , " Basic " ) == 0 ) {
170
177
std::string decoded_auth = base64_decode (auth.substr (6 ));
171
178
if (decoded_auth == global_state.auth_token ) {
172
- return true ;
179
+ return ;
173
180
}
174
181
}
175
182
176
- return false ;
183
+ throw HttpServerException ( 401 , " Unauthorized " ) ;
177
184
}
178
185
179
186
// Convert the query result to NDJSON (JSONEachRow) format
@@ -217,49 +224,130 @@ static std::string ConvertResultToNDJSON(MaterializedQueryResult &result) {
217
224
return ndjson_output;
218
225
}
219
226
220
- // Handle both GET and POST requests
221
- void HandleHttpRequest (const duckdb_httplib_openssl::Request& req, duckdb_httplib_openssl::Response& res) {
222
- std::string query;
227
+ BoundParameterData ExtractQueryParameter (const std::string& key, yyjson_val* parameterVal) {
228
+ if (!yyjson_is_obj (parameterVal)) {
229
+ throw HttpServerException (400 , " The parameter `" + key + " ` parameter must be an object" );
230
+ }
223
231
224
- // Check authentication
225
- if (!IsAuthenticated (req)) {
226
- res.status = 401 ;
227
- res.set_content (" Unauthorized" , " text/plain" );
228
- return ;
232
+ auto typeVal = yyjson_obj_get (parameterVal, " type" );
233
+ if (!typeVal) {
234
+ throw HttpServerException (400 , " The parameter `" + key + " ` does not have a `type` field" );
229
235
}
236
+ if (!yyjson_is_str (typeVal)) {
237
+ throw HttpServerException (400 , " The field `type` for the parameter `" + key + " ` must be a string" );
238
+ }
239
+ auto type = std::string (yyjson_get_str (typeVal));
230
240
231
- // CORS allow
232
- res.set_header (" Access-Control-Allow-Origin" , " *" );
233
- res.set_header (" Access-Control-Allow-Methods" , " GET, POST, OPTIONS, PUT" );
234
- res.set_header (" Access-Control-Allow-Headers" , " *" );
235
- res.set_header (" Access-Control-Allow-Credentials" , " true" );
236
- res.set_header (" Access-Control-Max-Age" , " 86400" );
241
+ auto valueVal = yyjson_obj_get (parameterVal, " value" );
242
+ if (!valueVal) {
243
+ throw HttpServerException (400 , " The parameter `" + key + " ` does not have a `value` field" );
244
+ }
237
245
238
- // Handle preflight OPTIONS request
239
- if (req.method == " OPTIONS" ) {
240
- res.status = 204 ; // No content
241
- return ;
246
+ if (type == " TEXT" ) {
247
+ if (!yyjson_is_str (valueVal)) {
248
+ throw HttpServerException (400 , " The field `value` for the parameter `" + key + " ` must be a string" );
249
+ }
250
+
251
+ return BoundParameterData (Value (yyjson_get_str (valueVal)));
252
+ }
253
+ else if (type == " BOOLEAN" ) {
254
+ if (!yyjson_is_bool (valueVal)) {
255
+ throw HttpServerException (400 , " The field `value` for the parameter `" + key + " ` must be a boolean" );
256
+ }
257
+
258
+ return BoundParameterData (Value (bool (yyjson_get_bool (valueVal))));
259
+ }
260
+
261
+ throw HttpServerException (400 , " Unsupported type " + type + " the parameter `" + key + " `" );
262
+ }
263
+
264
+ case_insensitive_map_t <BoundParameterData> ExtractQueryParameters (yyjson_doc* parametersDoc) {
265
+ if (!parametersDoc) {
266
+ throw HttpServerException (400 , " Unable to parse the `parameters` parameter" );
267
+ }
268
+
269
+ auto parametersRoot = yyjson_doc_get_root (parametersDoc);
270
+ if (!yyjson_is_obj (parametersRoot)) {
271
+ throw HttpServerException (400 , " The `parameters` parameter must be an object" );
242
272
}
243
273
274
+ case_insensitive_map_t <BoundParameterData> named_values;
275
+
276
+ size_t idx, max;
277
+ yyjson_val *parameterKeyVal, *parameterVal;
278
+ yyjson_obj_foreach (parametersRoot, idx, max, parameterKeyVal, parameterVal) {
279
+ auto parameterKeyString = std::string (yyjson_get_str (parameterKeyVal));
280
+
281
+ named_values[parameterKeyString] = ExtractQueryParameter (parameterKeyString, parameterVal);
282
+ }
283
+
284
+ return named_values;
285
+ }
286
+
287
+ case_insensitive_map_t <BoundParameterData> ExtractQueryParametersWrapper (const duckdb_httplib_openssl::Request& req) {
288
+ yyjson_doc *parametersDoc = nullptr ;
289
+
290
+ try {
291
+ auto parametersJson = req.get_param_value (" parameters" );
292
+ auto parametersJsonCStr = parametersJson.c_str ();
293
+ parametersDoc = yyjson_read (parametersJsonCStr, strlen (parametersJsonCStr), 0 );
294
+ return ExtractQueryParameters (parametersDoc);
295
+ }
296
+ catch (const Exception& exception ) {
297
+ yyjson_doc_free (parametersDoc);
298
+
299
+ throw exception ;
300
+ }
301
+ }
302
+
303
+ // Execute query (optionally using a prepared statement)
304
+ std::unique_ptr<MaterializedQueryResult> ExecuteQuery (
305
+ const duckdb_httplib_openssl::Request& req,
306
+ const std::string& query
307
+ ) {
308
+ Connection con (*global_state.db_instance );
309
+ std::unique_ptr<MaterializedQueryResult> result;
310
+
311
+ if (req.has_param (" parameters" )) {
312
+ auto prepared_stmt = con.Prepare (query);
313
+ if (prepared_stmt->HasError ()) {
314
+ throw HttpServerException (500 , prepared_stmt->GetError ());
315
+ }
316
+
317
+ auto named_values = ExtractQueryParametersWrapper (req);
318
+
319
+ auto prepared_stmt_result = prepared_stmt->Execute (named_values);
320
+ D_ASSERT (prepared_stmt_result->type == QueryResultType::STREAM_RESULT);
321
+ result = unique_ptr_cast<QueryResult, StreamQueryResult>(std::move (prepared_stmt_result))->Materialize ();
322
+ } else {
323
+ result = con.Query (query);
324
+ }
325
+
326
+ if (result->HasError ()) {
327
+ throw HttpServerException (500 , result->GetError ());
328
+ }
329
+
330
+ return result;
331
+ }
332
+
333
+ std::string ExtractQuery (const duckdb_httplib_openssl::Request& req) {
244
334
// Check if the query is in the URL parameters
245
335
if (req.has_param (" query" )) {
246
- query = req.get_param_value (" query" );
336
+ return req.get_param_value (" query" );
247
337
}
248
338
else if (req.has_param (" q" )) {
249
- query = req.get_param_value (" q" );
339
+ return req.get_param_value (" q" );
250
340
}
341
+
251
342
// If not in URL, and it's a POST request, check the body
252
343
else if (req.method == " POST" && !req.body .empty ()) {
253
- query = req.body ;
254
- }
255
- // If no query found, return an error
256
- else {
257
- res.status = 200 ;
258
- res.set_content (playContent, " text/html" );
259
- return ;
344
+ return req.body ;
260
345
}
261
346
262
- // Set default format to JSONCompact
347
+ throw HttpServerException (200 , playContent);
348
+ }
349
+
350
+ std::string ExtractFormat (const duckdb_httplib_openssl::Request& req) {
263
351
std::string format = " JSONEachRow" ;
264
352
265
353
// Check for format in URL parameter or header
@@ -271,24 +359,39 @@ void HandleHttpRequest(const duckdb_httplib_openssl::Request& req, duckdb_httpli
271
359
format = req.get_header_value (" format" );
272
360
}
273
361
362
+ return format;
363
+ }
364
+
365
+ // Handle both GET and POST requests
366
+ void HandleHttpRequest (const duckdb_httplib_openssl::Request& req, duckdb_httplib_openssl::Response& res) {
367
+ CheckAuthentication (req);
368
+
369
+ // CORS allow
370
+ res.set_header (" Access-Control-Allow-Origin" , " *" );
371
+ res.set_header (" Access-Control-Allow-Methods" , " GET, POST, OPTIONS, PUT" );
372
+ res.set_header (" Access-Control-Allow-Headers" , " *" );
373
+ res.set_header (" Access-Control-Allow-Credentials" , " true" );
374
+ res.set_header (" Access-Control-Max-Age" , " 86400" );
375
+
376
+ // Handle preflight OPTIONS request
377
+ if (req.method == " OPTIONS" ) {
378
+ res.status = 204 ; // No content
379
+ return ;
380
+ }
381
+
382
+ auto query = ExtractQuery (req);
383
+ auto format = ExtractFormat (req);
384
+
274
385
try {
275
386
if (!global_state.db_instance ) {
276
387
throw IOException (" Database instance not initialized" );
277
388
}
278
389
279
- Connection con (*global_state.db_instance );
280
390
auto start = std::chrono::system_clock::now ();
281
- auto result = con. Query ( query);
391
+ auto result = ExecuteQuery (req, query);
282
392
auto end = std::chrono::system_clock::now ();
283
393
auto elapsed = std::chrono::duration_cast<std::chrono::milliseconds>(end - start);
284
394
285
- if (result->HasError ()) {
286
- res.status = 500 ;
287
- res.set_content (result->GetError (), " text/plain" );
288
- return ;
289
- }
290
-
291
-
292
395
ReqStats stats{
293
396
static_cast <float >(elapsed.count ()) / 1000 ,
294
397
0 ,
@@ -308,7 +411,12 @@ void HandleHttpRequest(const duckdb_httplib_openssl::Request& req, duckdb_httpli
308
411
res.set_content (json_output, " application/x-ndjson" );
309
412
}
310
413
311
- } catch (const Exception& ex) {
414
+ }
415
+ catch (const HttpServerException& ex) {
416
+ res.status = ex.status ;
417
+ res.set_content (ex.message , " text/plain" );
418
+ }
419
+ catch (const Exception& ex) {
312
420
res.status = 500 ;
313
421
std::string error_message = " Code: 59, e.displayText() = DB::Exception: " + std::string (ex.what ());
314
422
res.set_content (error_message, " text/plain" );
0 commit comments