|
6 | 6 | #include "duckdb/common/string_util.hpp"
|
7 | 7 | #include "duckdb/function/scalar_function.hpp"
|
8 | 8 | #include "duckdb/main/extension_util.hpp"
|
| 9 | +#include "duckdb/parser/parsed_data/create_macro_info.hpp" |
9 | 10 | #include <duckdb/parser/parsed_data/create_scalar_function_info.hpp>
|
| 11 | +#include "duckdb/common/exception/http_exception.hpp" |
10 | 12 |
|
11 |
| -// OpenSSL linked through vcpkg |
12 |
| -#include <openssl/opensslv.h> |
| 13 | +#define CPPHTTPLIB_OPENSSL_SUPPORT |
| 14 | +#include "httplib.hpp" |
| 15 | +#include "yyjson.hpp" |
13 | 16 |
|
14 | 17 | namespace duckdb {
|
15 | 18 |
|
16 |
| -inline void WebxtensionScalarFun(DataChunk &args, ExpressionState &state, Vector &result) { |
17 |
| - auto &name_vector = args.data[0]; |
18 |
| - UnaryExecutor::Execute<string_t, string_t>( |
19 |
| - name_vector, result, args.size(), |
20 |
| - [&](string_t name) { |
21 |
| - return StringVector::AddString(result, "Webxtension "+name.GetString()+" 🐥");; |
22 |
| - }); |
| 19 | +// Helper function to setup HTTP client |
| 20 | +static std::pair<duckdb_httplib_openssl::Client, std::string> SetupHttpClient(const std::string &url) { |
| 21 | + std::string scheme, domain, path; |
| 22 | + size_t pos = url.find("://"); |
| 23 | + std::string mod_url = url; |
| 24 | + if (pos != std::string::npos) { |
| 25 | + scheme = mod_url.substr(0, pos); |
| 26 | + mod_url.erase(0, pos + 3); |
| 27 | + } |
| 28 | + |
| 29 | + pos = mod_url.find("/"); |
| 30 | + if (pos != std::string::npos) { |
| 31 | + domain = mod_url.substr(0, pos); |
| 32 | + path = mod_url.substr(pos); |
| 33 | + } else { |
| 34 | + domain = mod_url; |
| 35 | + path = "/"; |
| 36 | + } |
| 37 | + |
| 38 | + duckdb_httplib_openssl::Client client(domain.c_str()); |
| 39 | + client.set_read_timeout(10, 0); |
| 40 | + client.set_follow_location(true); |
| 41 | + |
| 42 | + return std::make_pair(std::move(client), path); |
23 | 43 | }
|
24 | 44 |
|
25 |
| -inline void WebxtensionOpenSSLVersionScalarFun(DataChunk &args, ExpressionState &state, Vector &result) { |
26 |
| - auto &name_vector = args.data[0]; |
| 45 | +// Helper function to handle HTTP errors |
| 46 | +static void HandleHttpError(const duckdb_httplib_openssl::Result &res, const std::string &request_type) { |
| 47 | + std::string err_message = "HTTP " + request_type + " request failed. "; |
| 48 | + |
| 49 | + switch (res.error()) { |
| 50 | + case duckdb_httplib_openssl::Error::Connection: |
| 51 | + err_message += "Connection error."; |
| 52 | + break; |
| 53 | + case duckdb_httplib_openssl::Error::Read: |
| 54 | + err_message += "Error reading response."; |
| 55 | + break; |
| 56 | + default: |
| 57 | + err_message += "Unknown error."; |
| 58 | + break; |
| 59 | + } |
| 60 | + throw std::runtime_error(err_message); |
| 61 | +} |
| 62 | + |
| 63 | + |
| 64 | +static bool ContainsMacroDefinition(const std::string &content) { |
| 65 | + std::string upper_content = StringUtil::Upper(content); |
| 66 | + const char* patterns[] = { |
| 67 | + "CREATE MACRO", |
| 68 | + "CREATE OR REPLACE MACRO", |
| 69 | + "CREATE TEMP MACRO", |
| 70 | + "CREATE TEMPORARY MACRO", |
| 71 | + "CREATE OR REPLACE TEMP MACRO", |
| 72 | + "CREATE OR REPLACE TEMPORARY MACRO" |
| 73 | + }; |
| 74 | + |
| 75 | + for (const auto& pattern : patterns) { |
| 76 | + if (upper_content.find(pattern) != std::string::npos) { |
| 77 | + return true; |
| 78 | + } |
| 79 | + } |
| 80 | + return false; |
| 81 | +} |
| 82 | + |
| 83 | +// Function to fetch and create macro from URL |
| 84 | +static void LoadMacroFromUrlFunction(DataChunk &args, ExpressionState &state, Vector &result, DatabaseInstance *db_instance) { |
| 85 | + auto &context = state.GetContext(); |
| 86 | + |
27 | 87 | UnaryExecutor::Execute<string_t, string_t>(
|
28 |
| - name_vector, result, args.size(), |
29 |
| - [&](string_t name) { |
30 |
| - return StringVector::AddString(result, "Webxtension " + name.GetString() + |
31 |
| - ", my linked OpenSSL version is " + |
32 |
| - OPENSSL_VERSION_TEXT );; |
| 88 | + args.data[0], result, args.size(), |
| 89 | + [&](string_t url) { |
| 90 | + try { |
| 91 | + // Setup HTTP client |
| 92 | + auto client_and_path = SetupHttpClient(url.GetString()); |
| 93 | + auto &client = client_and_path.first; |
| 94 | + auto &path = client_and_path.second; |
| 95 | + |
| 96 | + // Make GET request |
| 97 | + auto res = client.Get(path.c_str()); |
| 98 | + if (!res) { |
| 99 | + HandleHttpError(res, "GET"); |
| 100 | + } |
| 101 | + |
| 102 | + if (res->status != 200) { |
| 103 | + throw std::runtime_error("HTTP error " + std::to_string(res->status) + ": " + res->reason); |
| 104 | + } |
| 105 | + |
| 106 | + // Get the SQL content |
| 107 | + std::string macro_sql = res->body; |
| 108 | + |
| 109 | + // Replace all \r\n with \n |
| 110 | + macro_sql = StringUtil::Replace(macro_sql, "\r\n", "\n"); |
| 111 | + // Replace any remaining \r with \n |
| 112 | + macro_sql = StringUtil::Replace(macro_sql, "\r", "\n"); |
| 113 | + // Normalize multiple newlines to single newlines |
| 114 | + macro_sql = StringUtil::Replace(macro_sql, "\n\n", "\n"); |
| 115 | + // Trim in place |
| 116 | + StringUtil::Trim(macro_sql); |
| 117 | + |
| 118 | + if (!ContainsMacroDefinition(macro_sql)) { |
| 119 | + throw std::runtime_error("URL content does not contain a valid macro definition"); |
| 120 | + } |
| 121 | + |
| 122 | + //std::cout << macro_sql << "\n"; |
| 123 | + Connection conn(*db_instance); |
| 124 | + |
| 125 | + // Execute the macro directly in the current context |
| 126 | + auto query_result = conn.Query(macro_sql); |
| 127 | + |
| 128 | + if (query_result->HasError()) { |
| 129 | + throw std::runtime_error("Failed loading Macro: " + query_result->GetError()); |
| 130 | + } |
| 131 | + |
| 132 | + return StringVector::AddString(result, "Successfully loaded macro"); |
| 133 | + |
| 134 | + } catch (std::exception &e) { |
| 135 | + std::string error_msg = "Error: " + std::string(e.what()); |
| 136 | + throw std::runtime_error(error_msg); |
| 137 | + } |
33 | 138 | });
|
34 | 139 | }
|
35 | 140 |
|
36 | 141 | static void LoadInternal(DatabaseInstance &instance) {
|
37 |
| - // Register a scalar function |
38 |
| - auto webxtension_scalar_function = ScalarFunction("webxtension", {LogicalType::VARCHAR}, LogicalType::VARCHAR, WebxtensionScalarFun); |
39 |
| - ExtensionUtil::RegisterFunction(instance, webxtension_scalar_function); |
40 |
| - |
41 |
| - // Register another scalar function |
42 |
| - auto webxtension_openssl_version_scalar_function = ScalarFunction("webxtension_openssl_version", {LogicalType::VARCHAR}, |
43 |
| - LogicalType::VARCHAR, WebxtensionOpenSSLVersionScalarFun); |
44 |
| - ExtensionUtil::RegisterFunction(instance, webxtension_openssl_version_scalar_function); |
| 142 | + // Create lambda to capture database instance |
| 143 | + auto load_macro_func = [&instance](DataChunk &args, ExpressionState &state, Vector &result) { |
| 144 | + LoadMacroFromUrlFunction(args, state, result, &instance); |
| 145 | + }; |
| 146 | + |
| 147 | + // Register function with captured database instance |
| 148 | + ExtensionUtil::RegisterFunction( |
| 149 | + instance, |
| 150 | + ScalarFunction("load_macro_from_url", {LogicalType::VARCHAR}, |
| 151 | + LogicalType::VARCHAR, load_macro_func) |
| 152 | + ); |
45 | 153 | }
|
46 | 154 |
|
47 | 155 | void WebxtensionExtension::Load(DuckDB &db) {
|
48 |
| - LoadInternal(*db.instance); |
| 156 | + LoadInternal(*db.instance); |
49 | 157 | }
|
| 158 | + |
50 | 159 | std::string WebxtensionExtension::Name() {
|
51 |
| - return "webxtension"; |
| 160 | + return "webxtension"; |
52 | 161 | }
|
53 | 162 |
|
54 | 163 | std::string WebxtensionExtension::Version() const {
|
55 | 164 | #ifdef EXT_VERSION_WEBXTENSION
|
56 |
| - return EXT_VERSION_WEBXTENSION; |
| 165 | + return EXT_VERSION_WEBXTENSION; |
57 | 166 | #else
|
58 |
| - return ""; |
| 167 | + return ""; |
59 | 168 | #endif
|
60 | 169 | }
|
61 | 170 |
|
62 | 171 | } // namespace duckdb
|
63 | 172 |
|
64 | 173 | extern "C" {
|
65 |
| - |
66 | 174 | DUCKDB_EXTENSION_API void webxtension_init(duckdb::DatabaseInstance &db) {
|
67 | 175 | duckdb::DuckDB db_wrapper(db);
|
68 | 176 | db_wrapper.LoadExtension<duckdb::WebxtensionExtension>();
|
69 | 177 | }
|
70 | 178 |
|
71 | 179 | DUCKDB_EXTENSION_API const char *webxtension_version() {
|
72 |
| - return duckdb::DuckDB::LibraryVersion(); |
| 180 | + return duckdb::DuckDB::LibraryVersion(); |
73 | 181 | }
|
74 | 182 | }
|
75 | 183 |
|
|
0 commit comments