diff --git a/src/shogun/io/ARFFFile.cpp b/src/shogun/io/ARFFFile.cpp index b8c8f28245a..f2606dc5732 100644 --- a/src/shogun/io/ARFFFile.cpp +++ b/src/shogun/io/ARFFFile.cpp @@ -6,6 +6,7 @@ #include #include +#include #include @@ -132,8 +133,9 @@ void ARFFDeserializer::read_helper() std::string name, type; auto inner_string = m_current_line.substr(strlen(m_attribute_string)); - left_trim( - inner_string, [](const auto& val) { return !std::isspace(val); }); + left_trim(inner_string, [](const auto& val) { + return !std::isspace(val); + }); auto it = inner_string.begin(); if (is_part_of(*it, "\"\'")) { @@ -540,4 +542,206 @@ void ARFFDeserializer::reserve_vector_memory( VectorResizeVisitor visitor{line_count}; for (auto& vec : v) shogun::visit(visitor, vec); -} \ No newline at end of file +} + +/** + * Very type unsafe, but no UB! + * @param obj + * @return + */ +std::vector features_to_string(CSGObject* obj, Attribute att) +{ + std::vector result_string; + switch (att) + { + case Attribute::NUMERIC: + case Attribute::REAL: + { + auto mat_to_string = [&result_string](const auto& mat) { + result_string.reserve(mat.size()); + for (int i = 0; i < mat.size(); ++i) + { + result_string.push_back(std::to_string(mat[i])); + } + }; + + for (const auto& param : obj->get_params()) + if (param.first == "feature_matrix") + { + sg_any_dispatch( + param.second->get_value(), sg_matrix_typemap, + shogun::None{}, shogun::None{}, mat_to_string); + return result_string; + } + } + break; + case Attribute::INTEGER: + { + auto mat_to_string = [&result_string](const auto& mat) { + result_string.reserve(mat.size()); + for (int i = 0; i < mat.size(); ++i) + { + result_string.push_back( + std::to_string(static_cast(mat[i]))); + } + }; + + for (const auto& param : obj->get_params()) + if (param.first == "feature_matrix") + { + sg_any_dispatch( + param.second->get_value(), sg_matrix_typemap, + shogun::None{}, shogun::None{}, mat_to_string); + return result_string; + } + } + break; + default: + SG_SERROR("Unsupported type: %d\n", static_cast(att)) + } + SG_SERROR("The provided feature object does not have a feature matrix!\n") + return std::vector{}; +} + +std::vector features_to_string( + CSGObject* obj, const std::vector& nominal_values) +{ + std::vector result_string; + auto mat_to_string = [&result_string, &nominal_values](const auto& mat) { + result_string.reserve(mat.size()); + for (int i = 0; i < mat.size(); ++i) + { + result_string.push_back( + "\"" + nominal_values[static_cast(mat[i])] + "\""); + } + }; + + for (const auto& param : obj->get_params()) + if (param.first == "feature_matrix") + { + sg_any_dispatch( + param.second->get_value(), sg_matrix_typemap, shogun::None{}, + shogun::None{}, mat_to_string); + return result_string; + } + SG_SERROR("The provided feature object does not have a feature matrix!\n") + return std::vector{}; +} + +std::unique_ptr ARFFSerializer::write() +{ + auto ss = std::make_unique(); + + // @relation + *ss << ARFFDeserializer::m_relation_string << " " << m_name << "\n\n"; + + // @attribute + for (const auto& att : m_attributes) + { + switch (att.second) + { + case Attribute::NUMERIC: + *ss << ARFFDeserializer::m_attribute_string << " " << att.first + << " numeric\n"; + break; + case Attribute::INTEGER: + *ss << ARFFDeserializer::m_attribute_string << " " << att.first + << " integer\n"; + break; + case Attribute::REAL: + *ss << ARFFDeserializer::m_attribute_string << " " << att.first + << " real\n"; + break; + case Attribute::STRING: + *ss << ARFFDeserializer::m_attribute_string << " " << att.first + << " string\n"; + break; + case Attribute::DATE: + SG_SNOTIMPLEMENTED + break; + case Attribute::NOMINAL: + { + *ss << ARFFDeserializer::m_attribute_string << " " << att.first + << " "; + auto nominal_values_vector = m_nominal_mapping.at(att.first); + std::string nominal_values_string = std::accumulate( + nominal_values_vector.begin(), nominal_values_vector.end(), + "{\"" + nominal_values_vector[0] + "\"", + [](std::string& lhs, const std::string& rhs) { + return lhs += ",\"" + rhs + "\""; + }); + nominal_values_string.append("}\n"); + *ss << nominal_values_string; + } + } + } + + // @data + *ss << "\n" << ARFFDeserializer::m_data_string << "\n\n"; + + auto* obj = m_feature_list->get_first_element(); + auto num_vectors = obj->as()->get_num_vectors(); + for (int i = 0; i < m_feature_list->get_num_elements(); ++i) + { + auto n_i = obj->as()->get_num_vectors(); + SG_UNREF(obj) + REQUIRE( + n_i == num_vectors, + "Expected all features to have the same number of examples!\n") + // in the last iteration this will be nullptr so don't need to deref + obj = m_feature_list->get_next_element(); + } + + std::vector> result; + auto att_iter = m_attributes.begin(); + + obj = m_feature_list->get_first_element(); + + for (int i = 0; i < m_feature_list->get_num_elements(); ++i) + { + switch (att_iter->second) + { + case Attribute::NUMERIC: + case Attribute::REAL: + case Attribute::INTEGER: + result.push_back(features_to_string(obj, att_iter->second)); + break; + case Attribute::NOMINAL: + result.push_back( + features_to_string(obj, m_nominal_mapping.at(att_iter->first))); + break; + case Attribute::DATE: + case Attribute::STRING: + SG_SNOTIMPLEMENTED + } + SG_UNREF(obj) + obj = m_feature_list->get_next_element(); + ++att_iter; + } + + std::vector result_rows(num_vectors); + + for (auto col = 0; col != result.size(); ++col) + { + if (col != result.size() - 1) + for (auto row = 0; row != num_vectors; ++row) + result_rows[row].append(result[col][row] + ","); + else + for (auto row = 0; row != num_vectors; ++row) + result_rows[row].append(result[col][row] + "\n"); + } + + for (const auto& row : result_rows) + *ss << row; + + return ss; +} + +void ARFFSerializer::write(const std::string& filename) +{ + auto result = write(); + std::ofstream myfile; + myfile.open(filename); + myfile << result->str(); + myfile.close(); +} diff --git a/src/shogun/io/ARFFFile.h b/src/shogun/io/ARFFFile.h index 520f186e645..839f436847a 100644 --- a/src/shogun/io/ARFFFile.h +++ b/src/shogun/io/ARFFFile.h @@ -340,6 +340,22 @@ namespace shogun } } // namespace arff_detail #endif // SWIG + + /** + * The attributes supported in the ARFF format + */ + enum class Attribute + { + NUMERIC = 0, + INTEGER = 1, + REAL = 2, + STRING = 3, + DATE = 4, + NOMINAL = 5 + }; + + class ARFFSerializer; + /** * ARFFDeserializer parses files in the ARFF format. * For information about this format see @@ -348,19 +364,7 @@ namespace shogun class ARFFDeserializer { public: - /** - * The attributes supported in the ARFF format - */ - enum class Attribute - { - NUMERIC = 0, - INTEGER = 1, - REAL = 2, - STRING = 3, - DATE = 4, - NOMINAL = 5 - }; - + friend class ARFFSerializer; /** * ARFFDeserializer constructor with a filename. * Performs a check to see if a file can be streamed. @@ -530,6 +534,20 @@ namespace shogun return std::vector{}; } + protected: + /** character used in file to comment out a line */ + static const char* m_comment_string; + /** characters to declare relations, i.e. @relation */ + static const char* m_relation_string; + /** characters to declare attributes, i.e. @attribute */ + static const char* m_attribute_string; + /** characters to declare data fields, i.e. @data */ + static const char* m_data_string; + /** the default C++ date format specified by the ARFF standard */ + static const char* m_default_date_format; + /** missing data */ + static const char* m_missing_value_string; + private: /** * Templated parser helper for string container primitive type. @@ -621,18 +639,6 @@ namespace shogun std::vector, std::vector>>>& v); - /** character used in file to comment out a line */ - static const char* m_comment_string; - /** characters to declare relations, i.e. @relation */ - static const char* m_relation_string; - /** characters to declare attributes, i.e. @attribute */ - static const char* m_attribute_string; - /** characters to declare data fields, i.e. @data */ - static const char* m_data_string; - /** the default C++ date format specified by the ARFF standard */ - static const char* m_default_date_format; - /** missing data */ - static const char* m_missing_value_string; /** the name of the attributes */ std::vector m_attribute_names; @@ -668,6 +674,35 @@ namespace shogun /** the parsed features */ std::vector> m_features; }; + + class ARFFSerializer + { + public: + ARFFSerializer( + const std::string& name, CList* feature_list, + const std::vector>& attributes, + const std::unordered_map>& + nominal_mapping) + : m_name(name), m_attributes(attributes), m_nominal_mapping(nominal_mapping) + { + SG_REF(feature_list) + m_feature_list = feature_list; + } + +#ifndef SWIG + std::unique_ptr write(); +#endif // SWIG + + void write(const std::string& filename); + + private: + std::string m_name; + CList* m_feature_list; + std::vector> m_attributes; + std::unordered_map> + m_nominal_mapping; + }; + } // namespace shogun #endif // SHOGUN_ARFFFILE_H