Skip to content

Commit

Permalink
arff serialiser
Browse files Browse the repository at this point in the history
  • Loading branch information
gf712 committed May 22, 2019
1 parent 2ab615d commit 977b9f8
Show file tree
Hide file tree
Showing 2 changed files with 267 additions and 28 deletions.
210 changes: 207 additions & 3 deletions src/shogun/io/ARFFFile.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

#include <shogun/features/DenseFeatures.h>
#include <shogun/io/ARFFFile.h>
#include <shogun/lib/type_case.h>

#include <date/date.h>

Expand Down Expand Up @@ -132,8 +133,9 @@ void ARFFDeserializer::read_helper()
std::string name, type;
auto inner_string =
m_current_line.substr(strlen(m_attribute_string));
left_trim(
inner_string, [](const auto& val) { return !std::isspace(val); });
left_trim(inner_string, [](const auto& val) {
return !std::isspace(val);
});
auto it = inner_string.begin();
if (is_part_of(*it, "\"\'"))
{
Expand Down Expand Up @@ -540,4 +542,206 @@ void ARFFDeserializer::reserve_vector_memory(
VectorResizeVisitor visitor{line_count};
for (auto& vec : v)
shogun::visit(visitor, vec);
}
}

/**
* Very type unsafe, but no UB!
* @param obj
* @return
*/
std::vector<std::string> features_to_string(CSGObject* obj, Attribute att)
{
std::vector<std::string> result_string;
switch (att)
{
case Attribute::NUMERIC:
case Attribute::REAL:
{
auto mat_to_string = [&result_string](const auto& mat) {
result_string.reserve(mat.size());
for (int i = 0; i < mat.size(); ++i)
{
result_string.push_back(std::to_string(mat[i]));
}
};

for (const auto& param : obj->get_params())
if (param.first == "feature_matrix")
{
sg_any_dispatch(
param.second->get_value(), sg_matrix_typemap,
shogun::None{}, shogun::None{}, mat_to_string);
return result_string;
}
}
break;
case Attribute::INTEGER:
{
auto mat_to_string = [&result_string](const auto& mat) {
result_string.reserve(mat.size());
for (int i = 0; i < mat.size(); ++i)
{
result_string.push_back(
std::to_string(static_cast<int64_t>(mat[i])));
}
};

for (const auto& param : obj->get_params())
if (param.first == "feature_matrix")
{
sg_any_dispatch(
param.second->get_value(), sg_matrix_typemap,
shogun::None{}, shogun::None{}, mat_to_string);
return result_string;
}
}
break;
default:
SG_SERROR("Unsupported type: %d\n", static_cast<int>(att))
}
SG_SERROR("The provided feature object does not have a feature matrix!\n")
return std::vector<std::string>{};
}

std::vector<std::string> features_to_string(
CSGObject* obj, const std::vector<std::string>& nominal_values)
{
std::vector<std::string> result_string;
auto mat_to_string = [&result_string, &nominal_values](const auto& mat) {
result_string.reserve(mat.size());
for (int i = 0; i < mat.size(); ++i)
{
result_string.push_back(
"\"" + nominal_values[static_cast<size_t>(mat[i])] + "\"");
}
};

for (const auto& param : obj->get_params())
if (param.first == "feature_matrix")
{
sg_any_dispatch(
param.second->get_value(), sg_matrix_typemap, shogun::None{},
shogun::None{}, mat_to_string);
return result_string;
}
SG_SERROR("The provided feature object does not have a feature matrix!\n")
return std::vector<std::string>{};
}

std::unique_ptr<std::ostringstream> ARFFSerializer::write()
{
auto ss = std::make_unique<std::ostringstream>();

// @relation
*ss << ARFFDeserializer::m_relation_string << " " << m_name << "\n\n";

// @attribute
for (const auto& att : m_attributes)
{
switch (att.second)
{
case Attribute::NUMERIC:
*ss << ARFFDeserializer::m_attribute_string << " " << att.first
<< " numeric\n";
break;
case Attribute::INTEGER:
*ss << ARFFDeserializer::m_attribute_string << " " << att.first
<< " integer\n";
break;
case Attribute::REAL:
*ss << ARFFDeserializer::m_attribute_string << " " << att.first
<< " real\n";
break;
case Attribute::STRING:
*ss << ARFFDeserializer::m_attribute_string << " " << att.first
<< " string\n";
break;
case Attribute::DATE:
SG_SNOTIMPLEMENTED
break;
case Attribute::NOMINAL:
{
*ss << ARFFDeserializer::m_attribute_string << " " << att.first
<< " ";
auto nominal_values_vector = m_nominal_mapping.at(att.first);
std::string nominal_values_string = std::accumulate(
nominal_values_vector.begin(), nominal_values_vector.end(),
"{\"" + nominal_values_vector[0] + "\"",
[](std::string& lhs, const std::string& rhs) {
return lhs += ",\"" + rhs + "\"";
});
nominal_values_string.append("}\n");
*ss << nominal_values_string;
}
}
}

// @data
*ss << "\n" << ARFFDeserializer::m_data_string << "\n\n";

auto* obj = m_feature_list->get_first_element();
auto num_vectors = obj->as<CFeatures>()->get_num_vectors();
for (int i = 0; i < m_feature_list->get_num_elements(); ++i)
{
auto n_i = obj->as<CFeatures>()->get_num_vectors();
SG_UNREF(obj)
REQUIRE(
n_i == num_vectors,
"Expected all features to have the same number of examples!\n")
// in the last iteration this will be nullptr so don't need to deref
obj = m_feature_list->get_next_element();
}

std::vector<std::vector<std::string>> result;
auto att_iter = m_attributes.begin();

obj = m_feature_list->get_first_element();

for (int i = 0; i < m_feature_list->get_num_elements(); ++i)
{
switch (att_iter->second)
{
case Attribute::NUMERIC:
case Attribute::REAL:
case Attribute::INTEGER:
result.push_back(features_to_string(obj, att_iter->second));
break;
case Attribute::NOMINAL:
result.push_back(
features_to_string(obj, m_nominal_mapping.at(att_iter->first)));
break;
case Attribute::DATE:
case Attribute::STRING:
SG_SNOTIMPLEMENTED
}
SG_UNREF(obj)
obj = m_feature_list->get_next_element();
++att_iter;
}

std::vector<std::string> result_rows(num_vectors);

for (auto col = 0; col != result.size(); ++col)
{
if (col != result.size() - 1)
for (auto row = 0; row != num_vectors; ++row)
result_rows[row].append(result[col][row] + ",");
else
for (auto row = 0; row != num_vectors; ++row)
result_rows[row].append(result[col][row] + "\n");
}

for (const auto& row : result_rows)
*ss << row;

return ss;
}

void ARFFSerializer::write(const std::string& filename)
{
auto result = write();
std::ofstream myfile;
myfile.open(filename);
myfile << result->str();
myfile.close();
}
85 changes: 60 additions & 25 deletions src/shogun/io/ARFFFile.h
Original file line number Diff line number Diff line change
Expand Up @@ -340,6 +340,22 @@ namespace shogun
}
} // namespace arff_detail
#endif // SWIG

/**
* The attributes supported in the ARFF format
*/
enum class Attribute
{
NUMERIC = 0,
INTEGER = 1,
REAL = 2,
STRING = 3,
DATE = 4,
NOMINAL = 5
};

class ARFFSerializer;

/**
* ARFFDeserializer parses files in the ARFF format.
* For information about this format see
Expand All @@ -348,19 +364,7 @@ namespace shogun
class ARFFDeserializer
{
public:
/**
* The attributes supported in the ARFF format
*/
enum class Attribute
{
NUMERIC = 0,
INTEGER = 1,
REAL = 2,
STRING = 3,
DATE = 4,
NOMINAL = 5
};

friend class ARFFSerializer;
/**
* ARFFDeserializer constructor with a filename.
* Performs a check to see if a file can be streamed.
Expand Down Expand Up @@ -530,6 +534,20 @@ namespace shogun
return std::vector<std::string>{};
}

protected:
/** character used in file to comment out a line */
static const char* m_comment_string;
/** characters to declare relations, i.e. @relation */
static const char* m_relation_string;
/** characters to declare attributes, i.e. @attribute */
static const char* m_attribute_string;
/** characters to declare data fields, i.e. @data */
static const char* m_data_string;
/** the default C++ date format specified by the ARFF standard */
static const char* m_default_date_format;
/** missing data */
static const char* m_missing_value_string;

private:
/**
* Templated parser helper for string container primitive type.
Expand Down Expand Up @@ -621,18 +639,6 @@ namespace shogun
std::vector<ScalarType>,
std::vector<std::basic_string<CharType>>>>& v);

/** character used in file to comment out a line */
static const char* m_comment_string;
/** characters to declare relations, i.e. @relation */
static const char* m_relation_string;
/** characters to declare attributes, i.e. @attribute */
static const char* m_attribute_string;
/** characters to declare data fields, i.e. @data */
static const char* m_data_string;
/** the default C++ date format specified by the ARFF standard */
static const char* m_default_date_format;
/** missing data */
static const char* m_missing_value_string;
/** the name of the attributes */
std::vector<std::string> m_attribute_names;

Expand Down Expand Up @@ -668,6 +674,35 @@ namespace shogun
/** the parsed features */
std::vector<std::shared_ptr<CFeatures>> m_features;
};

class ARFFSerializer
{
public:
ARFFSerializer(
const std::string& name, CList* feature_list,
const std::vector<std::pair<std::string, Attribute>>& attributes,
const std::unordered_map<std::string, std::vector<std::string>>&
nominal_mapping)
: m_name(name), m_attributes(attributes), m_nominal_mapping(nominal_mapping)
{
SG_REF(feature_list)
m_feature_list = feature_list;
}

#ifndef SWIG
std::unique_ptr<std::ostringstream> write();
#endif // SWIG

void write(const std::string& filename);

private:
std::string m_name;
CList* m_feature_list;
std::vector<std::pair<std::string, Attribute>> m_attributes;
std::unordered_map<std::string, std::vector<std::string>>
m_nominal_mapping;
};

} // namespace shogun

#endif // SHOGUN_ARFFFILE_H

0 comments on commit 977b9f8

Please sign in to comment.