-
-
Notifications
You must be signed in to change notification settings - Fork 1k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
openml integration in shogun #4628
base: develop
Are you sure you want to change the base?
Changes from 1 commit
e6d5a20
7cf1d10
b70398d
0500504
bbb493a
4c988bb
859786b
beb0acf
86c06b8
d40785a
7b4a365
40fda20
677c1e7
d06fe2f
b045a8a
753f8ba
f190940
9b331a6
21d2a00
134cf2e
07b07d0
3cd470b
f1546a6
d5ac051
df1a8d6
c615896
c5c9d93
f9f9c79
37e8c63
f17bbe1
c2cf37b
6616ecf
6d6c254
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,100 @@ | ||
/* | ||
* This software is distributed under BSD 3-clause license (see LICENSE file). | ||
* | ||
* Authors: Gil Hoben | ||
*/ | ||
|
||
#ifdef HAVE_CURL | ||
|
||
#include <shogun/io/OpenmlFlow.h> | ||
#include "OpenmlFlow.h" | ||
|
||
|
||
using namespace shogun; | ||
|
||
/**
 * libcurl write callback that collects a response body in a std::string.
 *
 * libcurl calls this function once per received chunk, and a single response
 * may arrive in several chunks, so every chunk is appended to the buffer.
 * Returning a value different from the number of bytes handed in tells
 * libcurl to abort the transfer with a write error.
 *
 * adapted from https://stackoverflow.com/a/5780603
 *
 * @param data pointer to the received bytes (not null-terminated)
 * @param size size of one element
 * @param nmemb number of elements in data
 * @param buffer_in string the bytes are appended to
 * @return the number of bytes appended, or 0 if no buffer was supplied
 */
size_t writer(char* data, size_t size, size_t nmemb, std::string* buffer_in)
{
    // Without a destination nothing can be stored; 0 reports the failure.
    if (buffer_in == nullptr)
    {
        return 0;
    }

    const size_t num_bytes = size * nmemb;

    // Append unconditionally: refusing data once the buffer is non-empty
    // would truncate every response that spans more than one chunk.
    if (num_bytes > 0)
    {
        buffer_in->append(data, num_bytes);
    }

    return num_bytes;
}
|
||
// Base URLs of the OpenML REST API, one per response format.
const char* OpenMLReader::xml_server = "https://www.openml.org/api/v1/xml";
const char* OpenMLReader::json_server = "https://www.openml.org/api/v1/json";

// Request paths appended to a base URL. A trailing "{}" marks the position
// where OpenMLReader::get() inserts its additional arguments.
const char* OpenMLReader::dataset_description = "/data/{}";
const char* OpenMLReader::list_data_qualities = "/data/qualities/list";
const char* OpenMLReader::data_features = "/data/features/{}";
const char* OpenMLReader::list_dataset_qualities = "/data/qualities/{}";
const char* OpenMLReader::list_dataset_filter = "/data/list/{}";
const char* OpenMLReader::flow_file = "/flow/{}";

// Maps the format name accepted by get() to the matching base URL.
const std::unordered_map<std::string, std::string>
    OpenMLReader::m_format_options{
        {"xml", xml_server},
        {"json", json_server},
    };

// Maps the request name accepted by get() to the matching request path.
const std::unordered_map<std::string, std::string>
    OpenMLReader::m_request_options{
        {"dataset_description", dataset_description},
        {"list_data_qualities", list_data_qualities},
        {"data_features", data_features},
        {"list_dataset_qualities", list_dataset_qualities},
        {"list_dataset_filter", list_dataset_filter},
        {"flow_file", flow_file},
    };
|
||
OpenMLReader::OpenMLReader(const std::string& api_key) : m_api_key(api_key) | ||
{ | ||
} | ||
|
||
/**
 * Placeholder for sending data to the server.
 *
 * @param request currently ignored
 * @param data currently ignored
 */
void OpenMLReader::post(const std::string& request, const std::string& data)
{
    // TODO: not implemented yet; the call has no effect.
}
|
||
void OpenMLReader::openml_curl_request_helper(const std::string& url) | ||
{ | ||
CURL* curl_handle = nullptr; | ||
|
||
curl_handle = curl_easy_init(); | ||
|
||
if (!curl_handle) | ||
{ | ||
SG_SERROR("Failed to initialise curl handle.") | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. maybe some infos on what happened? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I am not sure what the exact error would be... From the docs: "If this function [curl_easy_init] returns NULL, something went wrong and you cannot use the other curl functions." There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Any value in using this other global init instead? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. ah thanks, yes, that seems helpful! |
||
return; | ||
} | ||
|
||
curl_easy_setopt(curl_handle, CURLOPT_URL, url.c_str()); | ||
curl_easy_setopt(curl_handle, CURLOPT_HTTPGET,1); | ||
curl_easy_setopt(curl_handle, CURLOPT_WRITEFUNCTION, writer); | ||
curl_easy_setopt(curl_handle, CURLOPT_WRITEDATA, &m_curl_response_buffer); | ||
|
||
CURLcode res = curl_easy_perform(curl_handle); | ||
|
||
openml_curl_error_helper(res); | ||
|
||
curl_easy_cleanup(curl_handle); | ||
} | ||
|
||
/**
 * Reports a failed libcurl transfer instead of silently ignoring it.
 *
 * Only the libcurl result code is inspected here; the HTTP status code sent
 * by the server is not checked.
 *
 * @param code the code returned by the query
 */
void OpenMLReader::openml_curl_error_helper(CURLcode code)
{
    if (code != CURLE_OK)
    {
        // curl_easy_strerror() turns the code into a readable description.
        SG_SERROR("Connection error: %s.\n", curl_easy_strerror(code))
    }
}
|
||
|
||
void OpenMLFlow::download_flow() | ||
{ | ||
|
||
auto reader = OpenMLReader(m_api_key); | ||
auto return_string = reader.get("flow_file", "json", m_flow_id); | ||
} | ||
|
||
/**
 * Placeholder for uploading a flow.
 *
 * @param flow currently ignored
 */
void OpenMLFlow::upload_flow(const OpenMLFlow& flow)
{
    // TODO: not implemented yet; the call has no effect.
}
|
||
#endif // HAVE_CURL |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,154 @@ | ||
/* | ||
* This software is distributed under BSD 3-clause license (see LICENSE file). | ||
* | ||
* Authors: Gil Hoben | ||
*/ | ||
|
||
#ifndef SHOGUN_OPENMLFLOW_H | ||
#define SHOGUN_OPENMLFLOW_H | ||
|
||
#ifdef HAVE_CURL | ||
|
||
#include <shogun/io/SGIO.h> | ||
|
||
#include <curl/curl.h> | ||
#include <numeric> | ||
#include <string> | ||
#include <unordered_map> | ||
#include <vector> | ||
#include <iostream> | ||
|
||
namespace shogun | ||
{ | ||
class OpenMLReader | ||
{ | ||
|
||
public: | ||
explicit OpenMLReader(const std::string& api_key); | ||
|
||
/** | ||
* Returns a string returned by the server given a request. | ||
* Raises an error if the returned code is not 200. | ||
* Additional arguments can be passed to the request, | ||
* which are then concatenated with a "/" character. | ||
* | ||
* @tparam Args argument type pack, should all be std::string | ||
* @param request the request name, see m_request_options | ||
* @param format the format to return the data in, see m_format_options | ||
* @param args the additional arguments to be passed to request | ||
* @return the returned stream from the server if the return code is 200 | ||
*/ | ||
template <typename... Args> | ||
std::string | ||
get(const std::string& request, const std::string& format, Args... args) | ||
{ | ||
auto find_format = m_format_options.find(format); | ||
if (find_format == m_format_options.end()) | ||
{ | ||
SG_SERROR( | ||
"The provided format \"%s\" is not available\n", | ||
format.c_str()) | ||
} | ||
auto find_request = m_request_options.find(request); | ||
if (find_request == m_request_options.end()) | ||
{ | ||
SG_SERROR( | ||
"Could not find a way to solve the request \"%s\"\n", | ||
request.c_str()) | ||
} | ||
std::string request_format = find_format->second; | ||
std::string request_path = find_request->second; | ||
|
||
// get additional args and concatenate them with "/" | ||
if (sizeof...(Args) > 0) | ||
{ | ||
if (request_path.substr(request_path.size() - 2) == "{}") | ||
{ | ||
request_path = | ||
request_path.substr(0, request_path.size() - 2); | ||
} | ||
else | ||
{ | ||
SG_SERROR( | ||
"The provided request \"%s\" cannot handle additional " | ||
"args.\n", | ||
request.c_str()) | ||
} | ||
std::vector<std::string> args_vec = {args...}; | ||
std::string args_string = std::accumulate( | ||
args_vec.begin() + 1, args_vec.end(), args_vec.front(), | ||
[](std::string s0, std::string& s1) { | ||
return s0 += "/" + s1; | ||
}); | ||
request_path += args_string; | ||
} | ||
|
||
std::string url = request_format + request_path + "?" + m_api_key; | ||
|
||
openml_curl_request_helper(url); | ||
|
||
return m_curl_response_buffer; | ||
} | ||
|
||
void post(const std::string& request, const std::string& data); | ||
|
||
private: | ||
|
||
std::string m_curl_response_buffer; | ||
|
||
/** | ||
* Initialises CURL session and gets the data. | ||
* This function also handles the response code from the server. | ||
* | ||
* @param url the url to query | ||
*/ | ||
void openml_curl_request_helper(const std::string& url); | ||
|
||
/** | ||
* Handles all possible codes | ||
* | ||
* @param code the code returned by the query | ||
*/ | ||
void openml_curl_error_helper(CURLcode code); | ||
|
||
std::string m_api_key; | ||
|
||
static const char* xml_server; | ||
static const char* json_server; | ||
|
||
static const std::unordered_map<std::string, std::string> | ||
m_format_options; | ||
static const std::unordered_map<std::string, std::string> | ||
m_request_options; | ||
|
||
/* DATA API */ | ||
static const char* dataset_description; | ||
static const char* list_data_qualities; | ||
static const char* data_features; | ||
static const char* list_dataset_qualities; | ||
static const char* list_dataset_filter; | ||
|
||
/* FLOW API */ | ||
static const char* flow_file; | ||
}; | ||
|
||
/**
 * Handle for a single OpenML flow, identified by a flow id and accessed
 * with an API key.
 */
class OpenMLFlow
{
public:
    /**
     * @param api_key the key used when talking to the server
     * @param flow_id the id of the flow this object refers to
     */
    explicit OpenMLFlow(const std::string& api_key, const std::string& flow_id)
        : m_api_key(api_key), m_flow_id(flow_id)
    {
    }

    /** Downloads the flow identified by the stored flow id. */
    void download_flow();

    /**
     * Uploads the given flow.
     *
     * @param flow the flow to upload
     */
    static void upload_flow(const OpenMLFlow& flow);

private:
    /** Key used when talking to the server */
    std::string m_api_key;
    /** Id of the flow this object refers to */
    std::string m_flow_id;
};
} // namespace shogun | ||
#endif // HAVE_CURL | ||
|
||
#endif // SHOGUN_OPENMLFLOW_H |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
should these things be hard-coded here?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
yes