Skip to content

Commit f803e12

Browse files
committed
Support prepared statements and parameters
1 parent bd381c1 commit f803e12

18 files changed

+823
-326
lines changed

CMakeLists.txt

+5-1
Original file line numberDiff line numberDiff line change
@@ -16,12 +16,16 @@ include_directories(
1616
# Embed ./src/assets/index.html as a C++ header
1717
add_custom_command(
1818
OUTPUT ${CMAKE_CURRENT_BINARY_DIR}/playground.hpp
19-
COMMAND ${CMAKE_COMMAND} -P ${PROJECT_SOURCE_DIR}/embed.cmake ${PROJECT_SOURCE_DIR}/src/assets/index.html ${CMAKE_CURRENT_BINARY_DIR}/playground.hpp playgroundContent
19+
COMMAND ${CMAKE_COMMAND} -P ${PROJECT_SOURCE_DIR}/embed.cmake ${PROJECT_SOURCE_DIR}/src/assets/index.html ${CMAKE_CURRENT_BINARY_DIR}/httpserver_extension/http_handler/playground.hpp playgroundContent
2020
DEPENDS ${PROJECT_SOURCE_DIR}/src/assets/index.html
2121
)
2222

2323
set(EXTENSION_SOURCES
2424
src/httpserver_extension.cpp
25+
src/http_handler/authentication.cpp
26+
src/http_handler/bindings.cpp
27+
src/http_handler/handler.cpp
28+
src/http_handler/response_serializer.cpp
2529
${CMAKE_CURRENT_BINARY_DIR}/playground.hpp
2630
)
2731

src/http_handler/authentication.cpp

+55
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
#include "httpserver_extension/http_handler/common.hpp"
2+
#include "httpserver_extension/state.hpp"
3+
#include <string>
4+
#include <vector>
5+
6+
#define CPPHTTPLIB_OPENSSL_SUPPORT
7+
#include "httplib.hpp"
8+
9+
namespace duckdb_httpserver {
10+
11+
// Base64 decoding function
12+
static std::string base64_decode(const std::string &in) {
13+
std::string out;
14+
std::vector<int> T(256, -1);
15+
for (int i = 0; i < 64; i++)
16+
T["ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/"[i]] = i;
17+
18+
int val = 0, valb = -8;
19+
for (unsigned char c : in) {
20+
if (T[c] == -1) break;
21+
val = (val << 6) + T[c];
22+
valb += 6;
23+
if (valb >= 0) {
24+
out.push_back(char((val >> valb) & 0xFF));
25+
valb -= 8;
26+
}
27+
}
28+
return out;
29+
}
30+
31+
// Check authentication
32+
void CheckAuthentication(const duckdb_httplib_openssl::Request& req) {
33+
if (global_state.auth_token.empty()) {
34+
return; // No authentication required if no token is set
35+
}
36+
37+
// Check for X-API-Key header
38+
auto api_key = req.get_header_value("X-API-Key");
39+
if (!api_key.empty() && api_key == global_state.auth_token) {
40+
return;
41+
}
42+
43+
// Check for Basic Auth
44+
auto auth = req.get_header_value("Authorization");
45+
if (!auth.empty() && auth.compare(0, 6, "Basic ") == 0) {
46+
std::string decoded_auth = base64_decode(auth.substr(6));
47+
if (decoded_auth == global_state.auth_token) {
48+
return;
49+
}
50+
}
51+
52+
throw HttpHandlerException(401, "Unauthorized");
53+
}
54+
55+
} // namespace duckdb_httpserver

src/http_handler/bindings.cpp

+91
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,91 @@
1+
#include "httpserver_extension/http_handler/common.hpp"
2+
#include "duckdb.hpp"
3+
4+
#include "yyjson.hpp"
5+
#include <string>
6+
7+
#define CPPHTTPLIB_OPENSSL_SUPPORT
8+
#include "httplib.hpp"
9+
10+
using namespace duckdb;
11+
using namespace duckdb_yyjson;
12+
13+
namespace duckdb_httpserver {
14+
15+
static BoundParameterData ExtractQueryParameter(const std::string& key, yyjson_val* parameterVal) {
16+
if (!yyjson_is_obj(parameterVal)) {
17+
throw HttpHandlerException(400, "The parameter `" + key + "` parameter must be an object");
18+
}
19+
20+
auto typeVal = yyjson_obj_get(parameterVal, "type");
21+
if (!typeVal) {
22+
throw HttpHandlerException(400, "The parameter `" + key + "` does not have a `type` field");
23+
}
24+
if (!yyjson_is_str(typeVal)) {
25+
throw HttpHandlerException(400, "The field `type` for the parameter `" + key + "` must be a string");
26+
}
27+
auto type = std::string(yyjson_get_str(typeVal));
28+
29+
auto valueVal = yyjson_obj_get(parameterVal, "value");
30+
if (!valueVal) {
31+
throw HttpHandlerException(400, "The parameter `" + key + "` does not have a `value` field");
32+
}
33+
34+
if (type == "TEXT") {
35+
if (!yyjson_is_str(valueVal)) {
36+
throw HttpHandlerException(400, "The field `value` for the parameter `" + key + "` must be a string");
37+
}
38+
39+
return BoundParameterData(Value(yyjson_get_str(valueVal)));
40+
}
41+
else if (type == "BOOLEAN") {
42+
if (!yyjson_is_bool(valueVal)) {
43+
throw HttpHandlerException(400, "The field `value` for the parameter `" + key + "` must be a boolean");
44+
}
45+
46+
return BoundParameterData(Value(bool(yyjson_get_bool(valueVal))));
47+
}
48+
49+
throw HttpHandlerException(400, "Unsupported type " + type + " the parameter `" + key + "`");
50+
}
51+
52+
static case_insensitive_map_t<BoundParameterData> ExtractQueryParametersImpl(yyjson_doc* parametersDoc) {
53+
if (!parametersDoc) {
54+
throw HttpHandlerException(400, "Unable to parse the `parameters` parameter");
55+
}
56+
57+
auto parametersRoot = yyjson_doc_get_root(parametersDoc);
58+
if (!yyjson_is_obj(parametersRoot)) {
59+
throw HttpHandlerException(400, "The `parameters` parameter must be an object");
60+
}
61+
62+
case_insensitive_map_t<BoundParameterData> named_values;
63+
64+
size_t idx, max;
65+
yyjson_val *parameterKeyVal, *parameterVal;
66+
yyjson_obj_foreach(parametersRoot, idx, max, parameterKeyVal, parameterVal) {
67+
auto parameterKeyString = std::string(yyjson_get_str(parameterKeyVal));
68+
69+
named_values[parameterKeyString] = ExtractQueryParameter(parameterKeyString, parameterVal);
70+
}
71+
72+
return named_values;
73+
}
74+
75+
case_insensitive_map_t<BoundParameterData> ExtractQueryParameters(const duckdb_httplib_openssl::Request& req) {
76+
yyjson_doc *parametersDoc = nullptr;
77+
78+
try {
79+
auto parametersJson = req.get_param_value("parameters");
80+
auto parametersJsonCStr = parametersJson.c_str();
81+
parametersDoc = yyjson_read(parametersJsonCStr, strlen(parametersJsonCStr), 0);
82+
return ExtractQueryParametersImpl(parametersDoc);
83+
}
84+
catch (const Exception& exception) {
85+
yyjson_doc_free(parametersDoc);
86+
87+
throw exception;
88+
}
89+
}
90+
91+
} // namespace duckdb_httpserver

src/http_handler/handler.cpp

+150
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,150 @@
1+
#include "httpserver_extension/http_handler/authentication.hpp"
2+
#include "httpserver_extension/http_handler/bindings.hpp"
3+
#include "httpserver_extension/http_handler/common.hpp"
4+
#include "httpserver_extension/http_handler/handler.hpp"
5+
#include "httpserver_extension/http_handler/playground.hpp"
6+
#include "httpserver_extension/http_handler/response_serializer.hpp"
7+
#include "httpserver_extension/state.hpp"
8+
#include "duckdb.hpp"
9+
10+
#include <string>
11+
#include <vector>
12+
13+
#define CPPHTTPLIB_OPENSSL_SUPPORT
14+
#include "httplib.hpp"
15+
16+
using namespace duckdb;
17+
18+
namespace duckdb_httpserver {
19+
20+
// Handle both GET and POST requests
21+
void HttpHandler(const duckdb_httplib_openssl::Request& req, duckdb_httplib_openssl::Response& res) {
22+
try {
23+
CheckAuthentication(req);
24+
25+
// CORS allow
26+
res.set_header("Access-Control-Allow-Origin", "*");
27+
res.set_header("Access-Control-Allow-Methods", "GET, POST, OPTIONS, PUT");
28+
res.set_header("Access-Control-Allow-Headers", "*");
29+
res.set_header("Access-Control-Allow-Credentials", "true");
30+
res.set_header("Access-Control-Max-Age", "86400");
31+
32+
// Handle preflight OPTIONS request
33+
if (req.method == "OPTIONS") {
34+
res.status = 204; // No content
35+
return;
36+
}
37+
38+
auto query = ExtractQuery(req);
39+
auto format = ExtractFormat(req);
40+
41+
if (query == "") {
42+
res.status = 200;
43+
res.set_content(reinterpret_cast<char const*>(playgroundContent), sizeof(playgroundContent), "text/html");
44+
return;
45+
}
46+
47+
if (!global_state.db_instance) {
48+
throw IOException("Database instance not initialized");
49+
}
50+
51+
auto start = std::chrono::system_clock::now();
52+
auto result = ExecuteQuery(req, query);
53+
auto end = std::chrono::system_clock::now();
54+
auto elapsed = std::chrono::duration_cast<std::chrono::milliseconds>(end - start);
55+
56+
QueryExecStats stats{
57+
static_cast<float>(elapsed.count()) / 1000,
58+
0,
59+
0
60+
};
61+
62+
// Format Options
63+
if (format == "JSONEachRow") {
64+
std::string json_output = ConvertResultToNDJSON(*result);
65+
res.set_content(json_output, "application/x-ndjson");
66+
} else if (format == "JSONCompact") {
67+
std::string json_output = ConvertResultToJSON(*result, stats);
68+
res.set_content(json_output, "application/json");
69+
} else {
70+
// Default to NDJSON for DuckDB's own queries
71+
std::string json_output = ConvertResultToNDJSON(*result);
72+
res.set_content(json_output, "application/x-ndjson");
73+
}
74+
75+
}
76+
catch (const HttpHandlerException& ex) {
77+
res.status = ex.status;
78+
res.set_content(ex.message, "text/plain");
79+
}
80+
catch (const Exception& ex) {
81+
res.status = 500;
82+
std::string error_message = "Code: 59, e.displayText() = DB::Exception: " + std::string(ex.what());
83+
res.set_content(error_message, "text/plain");
84+
}
85+
}
86+
87+
// Execute query (optionally using a prepared statement)
88+
std::unique_ptr<MaterializedQueryResult> ExecuteQuery(
89+
const duckdb_httplib_openssl::Request& req,
90+
const std::string& query
91+
) {
92+
Connection con(*global_state.db_instance);
93+
std::unique_ptr<MaterializedQueryResult> result;
94+
95+
if (req.has_param("parameters")) {
96+
auto prepared_stmt = con.Prepare(query);
97+
if (prepared_stmt->HasError()) {
98+
throw HttpHandlerException(500, prepared_stmt->GetError());
99+
}
100+
101+
auto named_values = ExtractQueryParameters(req);
102+
103+
auto prepared_stmt_result = prepared_stmt->Execute(named_values);
104+
D_ASSERT(prepared_stmt_result->type == QueryResultType::STREAM_RESULT);
105+
result = unique_ptr_cast<QueryResult, StreamQueryResult>(std::move(prepared_stmt_result))->Materialize();
106+
} else {
107+
result = con.Query(query);
108+
}
109+
110+
if (result->HasError()) {
111+
throw HttpHandlerException(500, result->GetError());
112+
}
113+
114+
return result;
115+
}
116+
117+
std::string ExtractQuery(const duckdb_httplib_openssl::Request& req) {
118+
// Check if the query is in the URL parameters
119+
if (req.has_param("query")) {
120+
return req.get_param_value("query");
121+
}
122+
else if (req.has_param("q")) {
123+
return req.get_param_value("q");
124+
}
125+
126+
// If not in URL, and it's a POST request, check the body
127+
else if (req.method == "POST" && !req.body.empty()) {
128+
return req.body;
129+
}
130+
131+
// std::optional is not available for this project
132+
return "";
133+
}
134+
135+
std::string ExtractFormat(const duckdb_httplib_openssl::Request& req) {
136+
std::string format = "JSONEachRow";
137+
138+
// Check for format in URL parameter or header
139+
if (req.has_param("default_format")) {
140+
format = req.get_param_value("default_format");
141+
} else if (req.has_header("X-ClickHouse-Format")) {
142+
format = req.get_header_value("X-ClickHouse-Format");
143+
} else if (req.has_header("format")) {
144+
format = req.get_header_value("format");
145+
}
146+
147+
return format;
148+
}
149+
150+
} // namespace duckdb_httpserver

0 commit comments

Comments
 (0)