Skip to content

Commit

Permalink
AVRO-4033: [C++] Filter out redundant union classes generated by avro…
Browse files Browse the repository at this point in the history
…gencpp (#3088)

* AVRO-4033: [C++] Filter out redundant union classes generated by avrogencpp. For each unique list of union branches, only one class is generated. This can reduce the header size for schemas with many unions.

* AVRO-4033: [C++] Align parameter names for UnionCodeTracker::setTraitsGenerated to be more consistent (#3088)

---------

Co-authored-by: hwse <[email protected]>
  • Loading branch information
hwse and hwse authored Aug 22, 2024
1 parent 885c62d commit ea2c54b
Show file tree
Hide file tree
Showing 4 changed files with 187 additions and 21 deletions.
3 changes: 2 additions & 1 deletion lang/c++/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,7 @@ gen (primitivetypes pt)
gen (cpp_reserved_words cppres)
gen (cpp_reserved_words_union_typedef cppres_union)
gen (big_union big_union)
gen (union_redundant_types redundant_types)

add_executable (avrogencpp impl/avrogencpp.cc)
target_link_libraries (avrogencpp avrocpp_s)
Expand Down Expand Up @@ -227,7 +228,7 @@ add_dependencies (AvrogencppTests bigrecord_hh bigrecord_r_hh bigrecord2_hh
union_array_union_hh union_map_union_hh union_conflict_hh
recursive_hh reuse_hh circulardep_hh tree1_hh tree2_hh crossref_hh
primitivetypes_hh empty_record_hh cpp_reserved_words_union_typedef_hh
union_empty_record_hh big_union_hh)
union_empty_record_hh big_union_hh union_redundant_types_hh)

include (InstallRequiredSystemLibraries)

Expand Down
80 changes: 60 additions & 20 deletions lang/c++/impl/avrogencpp.cc
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
#include <fstream>
#include <iostream>
#include <map>
#include <optional>
#include <set>

#include <boost/algorithm/string.hpp>
Expand Down Expand Up @@ -67,8 +68,22 @@ struct PendingConstructor {
PendingConstructor(string sn, string n, bool im) : structName(std::move(sn)), memberName(std::move(n)), initMember(im) {}
};

// Bookkeeping helper for union code generation: hands out union class names,
// remembers which list of union branches each name was generated for, and
// records the union classes whose codec traits have already been emitted.
class UnionCodeTracker {
// Schema file path; its last path component is the basis of generated names.
std::string schemaFile_;
// Running counter embedded in each newly generated union name.
size_t unionNumber_ = 0;
// Maps a list of union branches (as passed in by the caller) to the union
// class name generated for it.
std::map<std::vector<std::string>, std::string> unionBranchNameMapping_;
// Union class names that were passed to setTraitsGenerated().
std::set<std::string> generatedUnionTraits_;

public:
// Stores a copy of schemaFile for later name generation.
explicit UnionCodeTracker(const std::string &schemaFile);
// Returns the name previously generated for exactly this branch list, or
// std::nullopt when no such name has been registered.
std::optional<std::string> getExistingUnionName(const std::vector<std::string> &unionBranches) const;
// Creates the next union name, registers it for unionBranches and returns it.
// NOTE(review): callers are expected to check getExistingUnionName() first;
// an already registered branch list keeps its original mapping.
std::string generateNewUnionName(const std::vector<std::string> &unionBranches);
// True when setTraitsGenerated() was already called with unionClassName.
bool unionTraitsAlreadyGenerated(const std::string &unionClassName) const;
// Marks unionClassName as having had its traits generated.
void setTraitsGenerated(const std::string &unionClassName);
};

class CodeGen {
size_t unionNumber_;
UnionCodeTracker unionTracker_;
std::ostream &os_;
bool inNamespace_;
const std::string ns_;
Expand All @@ -90,7 +105,6 @@ class CodeGen {
std::string generateEnumType(const NodePtr &n);
std::string cppTypeOf(const NodePtr &n);
std::string generateRecordType(const NodePtr &n);
std::string unionName();
std::string generateUnionType(const NodePtr &n);
std::string generateType(const NodePtr &n);
std::string generateDeclaration(const NodePtr &n);
Expand All @@ -106,7 +120,7 @@ class CodeGen {
CodeGen(std::ostream &os, std::string ns,
std::string schemaFile, std::string headerFile,
std::string guardString,
std::string includePrefix, bool noUnion) : unionNumber_(0), os_(os), inNamespace_(false), ns_(std::move(ns)),
std::string includePrefix, bool noUnion) : unionTracker_(schemaFile), os_(os), inNamespace_(false), ns_(std::move(ns)),
schemaFile_(std::move(schemaFile)), headerFile_(std::move(headerFile)),
includePrefix_(std::move(includePrefix)), noUnion_(noUnion),
guardString_(std::move(guardString)),
Expand Down Expand Up @@ -295,17 +309,6 @@ void makeCanonical(string &s, bool foldCase) {
}
}

// Builds the next union class name: the canonicalized tail of the schema file
// path (starting at the last '/' or '\\', if any), followed by "_Union__",
// the current union counter and "__". The counter is advanced on every call.
string CodeGen::unionName() {
    const string::size_type lastSeparator = schemaFile_.find_last_of("/\\");
    string prefix = (lastSeparator == string::npos) ? schemaFile_ : schemaFile_.substr(lastSeparator);
    makeCanonical(prefix, false);

    const string ordinal = boost::lexical_cast<string>(unionNumber_++);
    return prefix + "_Union__" + ordinal + "__";
}

static void generateGetterAndSetter(ostream &os,
const string &structName, const string &type, const string &name,
size_t idx) {
Expand Down Expand Up @@ -386,7 +389,11 @@ string CodeGen::generateUnionType(const NodePtr &n) {
return done[n];
}

auto result = unionName();
// re-use existing union types that have the exact same branches
if (const auto existingName = unionTracker_.getExistingUnionName(types); existingName.has_value()) {
return existingName.value();
}
const std::string result = unionTracker_.generateNewUnionName(types);

os_ << "struct " << result << " {\n"
<< "private:\n"
Expand Down Expand Up @@ -643,16 +650,18 @@ void CodeGen::generateRecordTraits(const NodePtr &n) {
}

void CodeGen::generateUnionTraits(const NodePtr &n) {
const string name = done[n];
const string fn = fullname(name);
if (unionTracker_.unionTraitsAlreadyGenerated(fn)) {
return;
}
size_t c = n->leaves();

for (size_t i = 0; i < c; ++i) {
const NodePtr &nn = n->leafAt(i);
generateTraits(nn);
}

string name = done[n];
string fn = fullname(name);

os_ << "template<> struct codec_traits<" << fn << "> {\n"
<< " static void encode(Encoder& e, " << fn << " v) {\n"
<< " e.encodeUnionIndex(v.idx());\n"
Expand Down Expand Up @@ -696,6 +705,8 @@ void CodeGen::generateUnionTraits(const NodePtr &n) {
os_ << " }\n"
<< " }\n"
<< "};\n\n";

unionTracker_.setTraitsGenerated(fn);
}

void CodeGen::generateTraits(const NodePtr &n) {
Expand Down Expand Up @@ -808,8 +819,6 @@ void CodeGen::generate(const ValidSchema &schema) {

os_ << "namespace avro {\n";

unionNumber_ = 0;

generateTraits(root);

os_ << "}\n";
Expand Down Expand Up @@ -915,3 +924,34 @@ int main(int argc, char **argv) {
return 1;
}
}

// Keeps a copy of the schema file path; generated union names are derived
// from it later on.
UnionCodeTracker::UnionCodeTracker(const std::string &schemaFile)
    : schemaFile_(schemaFile) {}

// Looks up the union name registered for exactly this list of branches.
// Returns std::nullopt when the branch list has not been registered yet.
std::optional<std::string> UnionCodeTracker::getExistingUnionName(const std::vector<std::string> &unionBranches) const {
    const auto known = unionBranchNameMapping_.find(unionBranches);
    if (known == unionBranchNameMapping_.end()) {
        return std::nullopt;
    }
    return known->second;
}

std::string UnionCodeTracker::generateNewUnionName(const std::vector<std::string> &unionBranches) {
string s = schemaFile_;
string::size_type n = s.find_last_of("/\\");
if (n != string::npos) {
s = s.substr(n);
}
makeCanonical(s, false);

std::string result = s + "_Union__" + boost::lexical_cast<string>(unionNumber_++) + "__";
unionBranchNameMapping_.emplace(unionBranches, result);
return result;
}

// Reports whether unionClassName was previously recorded via setTraitsGenerated().
bool UnionCodeTracker::unionTraitsAlreadyGenerated(const std::string &unionClassName) const {
    return generatedUnionTraits_.count(unionClassName) > 0;
}

// Records unionClassName so that later unionTraitsAlreadyGenerated() calls
// for the same name return true. Recording the same name twice is harmless.
void UnionCodeTracker::setTraitsGenerated(const std::string &unionClassName) {
    generatedUnionTraits_.emplace(unionClassName);
}
22 changes: 22 additions & 0 deletions lang/c++/jsonschemas/union_redundant_types
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
{
"type": "record",
"name": "RedundantUnionSchema",
"doc": "Schema to test the generation of redundant union types in avrogencpp",
"fields" : [
{"name": "null_string_1", "type": ["null", "string"]},
{"name": "null_string_2", "type": ["null", "string"]},
{"name": "string_null_1", "type": ["string", "null"]},
{"name": "string_null_2", "type": ["string", "null"]},
{"name": "null_string_int", "type": ["string", "null", "int"]},
{"name": "null_Empty_1", "type": ["null", {"type": "record", "name": "Empty", "fields": []}]},
{"name": "null_Empty_2", "type": ["null", "Empty"]},
{"name": "null_namespace_record_1", "type": ["null", {"type": "record", "namespace": "example_namespace", "name": "Record", "fields": []}]},
{"name": "null_namespace_record_2", "type": ["null", "example_namespace.Record"]},
{"name": "null_fixed_8", "type": ["null", {"type": "fixed", "size": 8, "name": "fixed_8"}]},
{"name": "null_fixed_16", "type": ["null", {"type": "fixed", "size": 16, "name": "fixed_16"}]},
{"name": "fixed_8_fixed_16", "type": ["fixed_8", "fixed_16"]},
{"name": "null_int_map_1", "type": ["null", {"type": "map", "values": "int"}]},
{"name": "null_int_map_2", "type": ["null", {"type": "map", "values": "int"}]},
{"name": "null_long_map", "type": ["null", {"type": "map", "values": "long"}]}
]
}
103 changes: 103 additions & 0 deletions lang/c++/test/AvrogencppTests.cc
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,9 @@
#include "union_array_union.hh"
#include "union_empty_record.hh"
#include "union_map_union.hh"
#include "union_redundant_types.hh"

#include <array>
#include <boost/test/included/unit_test.hpp>

#ifdef min
Expand Down Expand Up @@ -408,6 +410,105 @@ void testUnionBranchEnum() {
BOOST_CHECK_EQUAL(record.big_union.branch(), Branch::_Int);
}

// Teaches Boost.Test how to print a std::type_info, which enables passing
// typeid() results to BOOST_CHECK_EQUAL / BOOST_CHECK_NE.
template<>
struct boost::test_tools::tt_detail::print_log_value<std::type_info> {
    void operator()(std::ostream &out, const std::type_info &info) const {
        out << "std::type_info{.name=";
        out << info.name();
        out << "}";
    }
};

// Compile-time-shape check of the generated header: compares the typeid of the
// union-typed members of RedundantUnionSchema to verify which fields share a
// generated union class and which do not.
void testNoRedundantUnionTypes() {
redundant_types::RedundantUnionSchema record;
// ensure only one class is generated for same union
// (each pair below is declared with the same branch list in the schema)
BOOST_CHECK_EQUAL(typeid(record.null_string_1), typeid(record.null_string_2));
BOOST_CHECK_EQUAL(typeid(record.string_null_1), typeid(record.string_null_2));
BOOST_CHECK_EQUAL(typeid(record.null_Empty_1), typeid(record.null_Empty_2));
BOOST_CHECK_EQUAL(typeid(record.null_namespace_record_1), typeid(record.null_namespace_record_2));
BOOST_CHECK_EQUAL(typeid(record.null_int_map_1), typeid(record.null_int_map_2));

// different union types should have different class
// (differing branch order, branch count, or branch type must not be merged)
BOOST_CHECK_NE(typeid(record.null_string_1), typeid(record.string_null_1));
BOOST_CHECK_NE(typeid(record.null_string_1), typeid(record.null_string_int));
BOOST_CHECK_NE(typeid(record.null_fixed_8), typeid(record.null_fixed_16));
BOOST_CHECK_NE(typeid(record.null_int_map_1), typeid(record.null_long_map));
}

void testNoRedundantUnionTypesEncodeDecode() {
redundant_types::RedundantUnionSchema input_record;
input_record.null_string_1.set_string("null_string_1");
input_record.null_string_2.set_string("null_string_2");
input_record.string_null_1.set_string("string_null_1");
input_record.string_null_2.set_string("string_null_2");
input_record.null_string_int.set_string("null_string_int");
input_record.null_Empty_1.set_Empty({});
input_record.null_Empty_2.set_Empty({});
input_record.null_namespace_record_1.set_Record({});
input_record.null_namespace_record_2.set_Record({});
input_record.null_fixed_8.set_fixed_8({8});
input_record.null_fixed_16.set_fixed_16({16});
input_record.fixed_8_fixed_16.set_fixed_16({16});
input_record.null_int_map_1.set_map({{"null_int_map_1", 1}});
input_record.null_int_map_2.set_map({{"null_int_map_2", 1}});
input_record.null_long_map.set_map({{"null_long_map", 1}});

ValidSchema s;
ifstream ifs("jsonschemas/union_redundant_types");
compileJsonSchema(ifs, s);

unique_ptr<OutputStream> os = memoryOutputStream();
EncoderPtr e = validatingEncoder(s, binaryEncoder());
e->init(*os);
avro::encode(*e, input_record);
e->flush();

DecoderPtr d = validatingDecoder(s, binaryDecoder());
unique_ptr<InputStream> is = memoryInputStream(*os);
d->init(*is);
redundant_types::RedundantUnionSchema result_record;
avro::decode(*d, result_record);

BOOST_CHECK_EQUAL(result_record.null_string_1.get_string(), "null_string_1");
BOOST_CHECK_EQUAL(result_record.null_string_2.get_string(), "null_string_2");
BOOST_CHECK_EQUAL(result_record.string_null_1.get_string(), "string_null_1");
BOOST_CHECK_EQUAL(result_record.string_null_2.get_string(), "string_null_2");
BOOST_CHECK_EQUAL(result_record.null_string_int.get_string(), "null_string_int");
BOOST_CHECK(!result_record.null_Empty_1.is_null());
BOOST_CHECK(!result_record.null_Empty_2.is_null());
BOOST_CHECK(!result_record.null_namespace_record_1.is_null());
BOOST_CHECK(!result_record.null_namespace_record_2.is_null());
{
const auto actual = result_record.null_fixed_8.get_fixed_8();
const std::array<uint8_t, 8> expected{8};
BOOST_CHECK_EQUAL_COLLECTIONS(actual.begin(), actual.end(), expected.begin(), expected.end());
}
{
const auto actual = result_record.null_fixed_16.get_fixed_16();
const std::array<uint8_t, 16> expected{16};
BOOST_CHECK_EQUAL_COLLECTIONS(actual.begin(), actual.end(), expected.begin(), expected.end());
}
{
const auto actual = result_record.fixed_8_fixed_16.get_fixed_16();
const std::array<uint8_t, 16> expected{16};
BOOST_CHECK_EQUAL_COLLECTIONS(actual.begin(), actual.end(), expected.begin(), expected.end());
}
{
const auto actual = result_record.null_int_map_1.get_map();
BOOST_CHECK_EQUAL(actual.size(), 1);
BOOST_CHECK_EQUAL(actual.at("null_int_map_1"), 1);
}
{
const auto actual = result_record.null_int_map_2.get_map();
BOOST_CHECK_EQUAL(actual.size(), 1);
BOOST_CHECK_EQUAL(actual.at("null_int_map_2"), 1);
}
{
const auto actual = result_record.null_long_map.get_map();
BOOST_CHECK_EQUAL(actual.size(), 1);
BOOST_CHECK_EQUAL(actual.at("null_long_map"), 1);
}
}

boost::unit_test::test_suite *init_unit_test_suite(int /*argc*/, char * /*argv*/[]) {
auto *ts = BOOST_TEST_SUITE("Code generator tests");
ts->add(BOOST_TEST_CASE(testEncoding));
Expand All @@ -418,5 +519,7 @@ boost::unit_test::test_suite *init_unit_test_suite(int /*argc*/, char * /*argv*/
ts->add(BOOST_TEST_CASE(testEmptyRecord));
ts->add(BOOST_TEST_CASE(testUnionMethods));
ts->add(BOOST_TEST_CASE(testUnionBranchEnum));
ts->add(BOOST_TEST_CASE(testNoRedundantUnionTypes));
ts->add(BOOST_TEST_CASE(testNoRedundantUnionTypesEncodeDecode));
return ts;
}

0 comments on commit ea2c54b

Please sign in to comment.