@@ -40,6 +40,13 @@ struct HttpServerState {
4040
4141static HttpServerState global_state;
4242
43+ struct HttpServerException : public std ::exception {
44+ int status;
45+ std::string message;
46+
47+ HttpServerException (int status, const std::string& message) : message(message), status(status) {}
48+ };
49+
4350std::string GetColumnType (MaterializedQueryResult &result, idx_t column) {
4451 if (result.RowCount () == 0 ) {
4552 return " String" ;
@@ -152,28 +159,28 @@ std::string base64_decode(const std::string &in) {
152159 return out;
153160}
154161
155- // Auth Check
156- bool IsAuthenticated (const duckdb_httplib_openssl::Request& req) {
162+ // Check authentication
163+ void CheckAuthentication (const duckdb_httplib_openssl::Request& req) {
157164 if (global_state.auth_token .empty ()) {
158- return true ; // No authentication required if no token is set
165+ return ; // No authentication required if no token is set
159166 }
160167
161168 // Check for X-API-Key header
162169 auto api_key = req.get_header_value (" X-API-Key" );
163170 if (!api_key.empty () && api_key == global_state.auth_token ) {
164- return true ;
171+ return ;
165172 }
166173
167174 // Check for Basic Auth
168175 auto auth = req.get_header_value (" Authorization" );
169176 if (!auth.empty () && auth.compare (0 , 6 , " Basic " ) == 0 ) {
170177 std::string decoded_auth = base64_decode (auth.substr (6 ));
171178 if (decoded_auth == global_state.auth_token ) {
172- return true ;
179+ return ;
173180 }
174181 }
175182
176- return false ;
183+ throw HttpServerException ( 401 , " Unauthorized " ) ;
177184}
178185
179186// Convert the query result to NDJSON (JSONEachRow) format
@@ -217,49 +224,130 @@ static std::string ConvertResultToNDJSON(MaterializedQueryResult &result) {
217224 return ndjson_output;
218225}
219226
220- // Handle both GET and POST requests
221- void HandleHttpRequest (const duckdb_httplib_openssl::Request& req, duckdb_httplib_openssl::Response& res) {
222- std::string query;
227+ BoundParameterData ExtractQueryParameter (const std::string& key, yyjson_val* parameterVal) {
228+ if (!yyjson_is_obj (parameterVal)) {
229+ throw HttpServerException (400 , " The parameter `" + key + " ` parameter must be an object" );
230+ }
223231
224- // Check authentication
225- if (!IsAuthenticated (req)) {
226- res.status = 401 ;
227- res.set_content (" Unauthorized" , " text/plain" );
228- return ;
232+ auto typeVal = yyjson_obj_get (parameterVal, " type" );
233+ if (!typeVal) {
234+ throw HttpServerException (400 , " The parameter `" + key + " ` does not have a `type` field" );
229235 }
236+ if (!yyjson_is_str (typeVal)) {
237+ throw HttpServerException (400 , " The field `type` for the parameter `" + key + " ` must be a string" );
238+ }
239+ auto type = std::string (yyjson_get_str (typeVal));
230240
231- // CORS allow
232- res.set_header (" Access-Control-Allow-Origin" , " *" );
233- res.set_header (" Access-Control-Allow-Methods" , " GET, POST, OPTIONS, PUT" );
234- res.set_header (" Access-Control-Allow-Headers" , " *" );
235- res.set_header (" Access-Control-Allow-Credentials" , " true" );
236- res.set_header (" Access-Control-Max-Age" , " 86400" );
241+ auto valueVal = yyjson_obj_get (parameterVal, " value" );
242+ if (!valueVal) {
243+ throw HttpServerException (400 , " The parameter `" + key + " ` does not have a `value` field" );
244+ }
237245
238- // Handle preflight OPTIONS request
239- if (req.method == " OPTIONS" ) {
240- res.status = 204 ; // No content
241- return ;
246+ if (type == " TEXT" ) {
247+ if (!yyjson_is_str (valueVal)) {
248+ throw HttpServerException (400 , " The field `value` for the parameter `" + key + " ` must be a string" );
249+ }
250+
251+ return BoundParameterData (Value (yyjson_get_str (valueVal)));
252+ }
253+ else if (type == " BOOLEAN" ) {
254+ if (!yyjson_is_bool (valueVal)) {
255+ throw HttpServerException (400 , " The field `value` for the parameter `" + key + " ` must be a boolean" );
256+ }
257+
258+ return BoundParameterData (Value (bool (yyjson_get_bool (valueVal))));
259+ }
260+
261+ throw HttpServerException (400 , " Unsupported type " + type + " the parameter `" + key + " `" );
262+ }
263+
264+ case_insensitive_map_t <BoundParameterData> ExtractQueryParameters (yyjson_doc* parametersDoc) {
265+ if (!parametersDoc) {
266+ throw HttpServerException (400 , " Unable to parse the `parameters` parameter" );
267+ }
268+
269+ auto parametersRoot = yyjson_doc_get_root (parametersDoc);
270+ if (!yyjson_is_obj (parametersRoot)) {
271+ throw HttpServerException (400 , " The `parameters` parameter must be an object" );
242272 }
243273
274+ case_insensitive_map_t <BoundParameterData> named_values;
275+
276+ size_t idx, max;
277+ yyjson_val *parameterKeyVal, *parameterVal;
278+ yyjson_obj_foreach (parametersRoot, idx, max, parameterKeyVal, parameterVal) {
279+ auto parameterKeyString = std::string (yyjson_get_str (parameterKeyVal));
280+
281+ named_values[parameterKeyString] = ExtractQueryParameter (parameterKeyString, parameterVal);
282+ }
283+
284+ return named_values;
285+ }
286+
287+ case_insensitive_map_t <BoundParameterData> ExtractQueryParametersWrapper (const duckdb_httplib_openssl::Request& req) {
288+ yyjson_doc *parametersDoc = nullptr ;
289+
290+ try {
291+ auto parametersJson = req.get_param_value (" parameters" );
292+ auto parametersJsonCStr = parametersJson.c_str ();
293+ parametersDoc = yyjson_read (parametersJsonCStr, strlen (parametersJsonCStr), 0 );
294+ return ExtractQueryParameters (parametersDoc);
295+ }
296+ catch (const Exception& exception) {
297+ yyjson_doc_free (parametersDoc);
298+
299+ throw exception;
300+ }
301+ }
302+
303+ // Execute query (optionally using a prepared statement)
304+ std::unique_ptr<MaterializedQueryResult> ExecuteQuery (
305+ const duckdb_httplib_openssl::Request& req,
306+ const std::string& query
307+ ) {
308+ Connection con (*global_state.db_instance );
309+ std::unique_ptr<MaterializedQueryResult> result;
310+
311+ if (req.has_param (" parameters" )) {
312+ auto prepared_stmt = con.Prepare (query);
313+ if (prepared_stmt->HasError ()) {
314+ throw HttpServerException (500 , prepared_stmt->GetError ());
315+ }
316+
317+ auto named_values = ExtractQueryParametersWrapper (req);
318+
319+ auto prepared_stmt_result = prepared_stmt->Execute (named_values);
320+ D_ASSERT (prepared_stmt_result->type == QueryResultType::STREAM_RESULT);
321+ result = unique_ptr_cast<QueryResult, StreamQueryResult>(std::move (prepared_stmt_result))->Materialize ();
322+ } else {
323+ result = con.Query (query);
324+ }
325+
326+ if (result->HasError ()) {
327+ throw HttpServerException (500 , result->GetError ());
328+ }
329+
330+ return result;
331+ }
332+
333+ std::string ExtractQuery (const duckdb_httplib_openssl::Request& req) {
244334 // Check if the query is in the URL parameters
245335 if (req.has_param (" query" )) {
246- query = req.get_param_value (" query" );
336+ return req.get_param_value (" query" );
247337 }
248338 else if (req.has_param (" q" )) {
249- query = req.get_param_value (" q" );
339+ return req.get_param_value (" q" );
250340 }
341+
251342 // If not in URL, and it's a POST request, check the body
252343 else if (req.method == " POST" && !req.body .empty ()) {
253- query = req.body ;
254- }
255- // If no query found, return an error
256- else {
257- res.status = 200 ;
258- res.set_content (playContent, " text/html" );
259- return ;
344+ return req.body ;
260345 }
261346
262- // Set default format to JSONCompact
347+ throw HttpServerException (200 , playContent);
348+ }
349+
350+ std::string ExtractFormat (const duckdb_httplib_openssl::Request& req) {
263351 std::string format = " JSONEachRow" ;
264352
265353 // Check for format in URL parameter or header
@@ -271,24 +359,39 @@ void HandleHttpRequest(const duckdb_httplib_openssl::Request& req, duckdb_httpli
271359 format = req.get_header_value (" format" );
272360 }
273361
362+ return format;
363+ }
364+
365+ // Handle both GET and POST requests
366+ void HandleHttpRequest (const duckdb_httplib_openssl::Request& req, duckdb_httplib_openssl::Response& res) {
367+ CheckAuthentication (req);
368+
369+ // CORS allow
370+ res.set_header (" Access-Control-Allow-Origin" , " *" );
371+ res.set_header (" Access-Control-Allow-Methods" , " GET, POST, OPTIONS, PUT" );
372+ res.set_header (" Access-Control-Allow-Headers" , " *" );
373+ res.set_header (" Access-Control-Allow-Credentials" , " true" );
374+ res.set_header (" Access-Control-Max-Age" , " 86400" );
375+
376+ // Handle preflight OPTIONS request
377+ if (req.method == " OPTIONS" ) {
378+ res.status = 204 ; // No content
379+ return ;
380+ }
381+
382+ auto query = ExtractQuery (req);
383+ auto format = ExtractFormat (req);
384+
274385 try {
275386 if (!global_state.db_instance ) {
276387 throw IOException (" Database instance not initialized" );
277388 }
278389
279- Connection con (*global_state.db_instance );
280390 auto start = std::chrono::system_clock::now ();
281- auto result = con. Query ( query);
391+ auto result = ExecuteQuery (req, query);
282392 auto end = std::chrono::system_clock::now ();
283393 auto elapsed = std::chrono::duration_cast<std::chrono::milliseconds>(end - start);
284394
285- if (result->HasError ()) {
286- res.status = 500 ;
287- res.set_content (result->GetError (), " text/plain" );
288- return ;
289- }
290-
291-
292395 ReqStats stats{
293396 static_cast <float >(elapsed.count ()) / 1000 ,
294397 0 ,
@@ -308,7 +411,12 @@ void HandleHttpRequest(const duckdb_httplib_openssl::Request& req, duckdb_httpli
308411 res.set_content (json_output, " application/x-ndjson" );
309412 }
310413
311- } catch (const Exception& ex) {
414+ }
415+ catch (const HttpServerException& ex) {
416+ res.status = ex.status ;
417+ res.set_content (ex.message , " text/plain" );
418+ }
419+ catch (const Exception& ex) {
312420 res.status = 500 ;
313421 std::string error_message = " Code: 59, e.displayText() = DB::Exception: " + std::string (ex.what ());
314422 res.set_content (error_message, " text/plain" );
0 commit comments