Skip to content

Commit

Permalink
Support transforming maps in Data Contract rules (confluentinc#1324)
Browse files Browse the repository at this point in the history
* Support transforming maps in Data Contract rules

* Enhance log msg
  • Loading branch information
rayokota authored Oct 29, 2024
1 parent 77ae7fa commit 8d502ae
Show file tree
Hide file tree
Showing 4 changed files with 216 additions and 24 deletions.
75 changes: 75 additions & 0 deletions schemaregistry/serde/avrov2/avro_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1181,6 +1181,81 @@ func TestAvroSerdeEncryption(t *testing.T) {
serde.MaybeFail("deserialization", err, serde.Expect(newobj, &obj))
}

func TestAvroSerdeEncryptionWithSimpleMap(t *testing.T) {
serde.MaybeFail = serde.InitFailFunc(t)
var err error

conf := schemaregistry.NewConfig("mock://")

client, err := schemaregistry.NewClient(conf)
serde.MaybeFail("Schema Registry configuration", err)

serConfig := NewSerializerConfig()
serConfig.AutoRegisterSchemas = false
serConfig.UseLatestVersion = true
serConfig.RuleConfig = map[string]string{
"secret": "mysecret",
}
ser, err := NewSerializer(client, serde.ValueSerde, serConfig)
serde.MaybeFail("Serializer configuration", err)

encRule := schemaregistry.Rule{
Name: "test-encrypt",
Kind: "TRANSFORM",
Mode: "WRITEREAD",
Type: "ENCRYPT",
Tags: []string{"PII"},
Params: map[string]string{
"encrypt.kek.name": "kek1",
"encrypt.kms.type": "local-kms",
"encrypt.kms.key.id": "mykey",
},
OnFailure: "ERROR,NONE",
}
ruleSet := schemaregistry.RuleSet{
DomainRules: []schemaregistry.Rule{encRule},
}

info := schemaregistry.SchemaInfo{
Schema: demoSchema,
SchemaType: "AVRO",
RuleSet: &ruleSet,
}

id, err := client.Register("topic1-value", info, false)
serde.MaybeFail("Schema registration", err)
if id <= 0 {
t.Errorf("Expected valid schema id, found %d", id)
}

obj := make(map[string]interface{})
obj["IntField"] = 123
obj["DoubleField"] = 45.67
obj["StringField"] = "hi"
obj["BoolField"] = true
obj["BytesField"] = []byte{1, 2}

bytes, err := ser.Serialize("topic1", &obj)
serde.MaybeFail("serialization", err)

// Reset encrypted field
obj["StringField"] = "hi"
obj["BytesField"] = []byte{1, 2}

deserConfig := NewDeserializerConfig()
deserConfig.RuleConfig = map[string]string{
"secret": "mysecret",
}
deser, err := NewDeserializer(client, serde.ValueSerde, deserConfig)
serde.MaybeFail("Deserializer configuration", err)
deser.Client = ser.Client
deser.MessageFactory = testMessageFactory

var newobj map[string]interface{}
err = deser.DeserializeInto("topic1", bytes, &newobj)
serde.MaybeFail("deserialization into", err, serde.Expect(newobj, obj))
}

func TestAvroSerdeEncryptionDekRotation(t *testing.T) {
f := fakeClock{now: time.Now().UnixMilli()}
executor := encryption.RegisterWithClock(&f)
Expand Down
44 changes: 32 additions & 12 deletions schemaregistry/serde/avrov2/avro_util.go
Original file line number Diff line number Diff line change
Expand Up @@ -76,19 +76,32 @@ func transform(ctx serde.RuleContext, resolver *avro.TypeResolver, schema avro.S
return msg, nil
case *avro.RecordSchema:
val := deref(msg)
fieldByNames := fieldByNames(val)
recordSchema := schema.(*avro.RecordSchema)
for _, avroField := range recordSchema.Fields() {
structField, ok := fieldByNames[avroField.Name()]
if !ok {
return nil, fmt.Errorf("avro: missing field %s", avroField.Name())
if val.Kind() == reflect.Struct {
fieldByNames := fieldByNames(val)
for _, avroField := range recordSchema.Fields() {
structField, ok := fieldByNames[avroField.Name()]
if !ok {
return nil, fmt.Errorf("avro: missing field %s", avroField.Name())
}
err := transformField(ctx, resolver, recordSchema, avroField, structField, val, fieldTransform)
if err != nil {
return nil, err
}
}
err := transformField(ctx, resolver, recordSchema, avroField, structField, val, fieldTransform)
if err != nil {
return nil, err
return msg, nil
} else if val.Kind() == reflect.Map {
for _, avroField := range recordSchema.Fields() {
mapField := val.MapIndex(reflect.ValueOf(avroField.Name()))
err := transformField(ctx, resolver, recordSchema, avroField, &mapField, val, fieldTransform)
if err != nil {
return nil, err
}
}
return msg, nil
} else {
return nil, fmt.Errorf("message of kind %s is not a struct or map", val.Kind())
}
return msg, nil
default:
if fieldCtx != nil {
ruleTags := ctx.Rule.Tags
Expand Down Expand Up @@ -137,9 +150,13 @@ func transformField(ctx serde.RuleContext, resolver *avro.TypeResolver, recordSc
}
}
} else {
err = setField(structField, newVal)
if err != nil {
return err
if val.Kind() == reflect.Struct {
err = setField(structField, newVal)
if err != nil {
return err
}
} else {
val.SetMapIndex(reflect.ValueOf(avroField.Name()), *newVal)
}
}
return nil
Expand Down Expand Up @@ -203,6 +220,9 @@ func disjoint(slice1 []string, map1 map[string]bool) bool {
}

func getField(msg *reflect.Value, name string) (*reflect.Value, error) {
if msg.Kind() != reflect.Struct {
return nil, fmt.Errorf("message is not a struct")
}
fieldVal := msg.FieldByName(name)
return &fieldVal, nil
}
Expand Down
77 changes: 77 additions & 0 deletions schemaregistry/serde/jsonschema/json_schema_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -798,6 +798,83 @@ func TestJSONSchemaSerdeWithCELFieldConditionFail(t *testing.T) {
serde.MaybeFail("serialization", nil, serde.Expect(ruleErr, serde.RuleConditionErr{Rule: &encRule}))
}

func TestJSONSchemaSerdeEncryptionWithSimpleMap(t *testing.T) {
serde.MaybeFail = serde.InitFailFunc(t)
var err error

conf := schemaregistry.NewConfig("mock://")

client, err := schemaregistry.NewClient(conf)
serde.MaybeFail("Schema Registry configuration", err)

serConfig := NewSerializerConfig()
serConfig.AutoRegisterSchemas = false
serConfig.UseLatestVersion = true
serConfig.RuleConfig = map[string]string{
"secret": "mysecret",
}
ser, err := NewSerializer(client, serde.ValueSerde, serConfig)
serde.MaybeFail("Serializer configuration", err)

encRule := schemaregistry.Rule{
Name: "test-encrypt",
Kind: "TRANSFORM",
Mode: "WRITEREAD",
Type: "ENCRYPT",
Tags: []string{"PII"},
Params: map[string]string{
"encrypt.kek.name": "kek1",
"encrypt.kms.type": "local-kms",
"encrypt.kms.key.id": "mykey",
},
OnFailure: "ERROR,NONE",
}
ruleSet := schemaregistry.RuleSet{
DomainRules: []schemaregistry.Rule{encRule},
}

info := schemaregistry.SchemaInfo{
Schema: demoSchema,
SchemaType: "JSON",
RuleSet: &ruleSet,
}

id, err := client.Register("topic1-value", info, false)
serde.MaybeFail("Schema registration", err)
if id <= 0 {
t.Errorf("Expected valid schema id, found %d", id)
}

obj := make(map[string]interface{})
obj["IntField"] = 123
obj["DoubleField"] = 45.67
obj["StringField"] = "hi"
obj["BoolField"] = true
obj["BytesField"] = base64.StdEncoding.EncodeToString([]byte{0, 0, 0, 1})

bytes, err := ser.Serialize("topic1", &obj)
serde.MaybeFail("serialization", err)

// Reset encrypted field
obj["StringField"] = "hi"
obj["BytesField"] = base64.StdEncoding.EncodeToString([]byte{0, 0, 0, 1})

// JSON decoding produces floats
obj["IntField"] = 123.0

deserConfig := NewDeserializerConfig()
deserConfig.RuleConfig = map[string]string{
"secret": "mysecret",
}
deser, err := NewDeserializer(client, serde.ValueSerde, deserConfig)
serde.MaybeFail("Deserializer configuration", err)
deser.Client = ser.Client

var newobj map[string]interface{}
err = deser.DeserializeInto("topic1", bytes, &newobj)
serde.MaybeFail("deserialization", err, serde.Expect(newobj, obj))
}

func TestJSONSchemaSerdeEncryption(t *testing.T) {
serde.MaybeFail = serde.InitFailFunc(t)
var err error
Expand Down
44 changes: 32 additions & 12 deletions schemaregistry/serde/jsonschema/json_schema_util.go
Original file line number Diff line number Diff line change
Expand Up @@ -78,18 +78,31 @@ func transform(ctx serde.RuleContext, schema *jsonschema2.Schema, path string, m
switch typ {
case serde.TypeRecord:
val := deref(msg)
fieldByNames := fieldByNames(val)
for propName, propSchema := range schema.Properties {
structField, ok := fieldByNames[propName]
if !ok {
return nil, fmt.Errorf("json: missing field %s", propName)
if val.Kind() == reflect.Struct {
fieldByNames := fieldByNames(val)
for propName, propSchema := range schema.Properties {
structField, ok := fieldByNames[propName]
if !ok {
return nil, fmt.Errorf("json: missing field %s", propName)
}
err := transformField(ctx, path, propName, structField, val, propSchema, fieldTransform)
if err != nil {
return nil, err
}
}
err := transformField(ctx, path, propName, structField, val, propSchema, fieldTransform)
if err != nil {
return nil, err
return msg, nil
} else if val.Kind() == reflect.Map {
for propName, propSchema := range schema.Properties {
mapField := val.MapIndex(reflect.ValueOf(propName))
err := transformField(ctx, path, propName, &mapField, val, propSchema, fieldTransform)
if err != nil {
return nil, err
}
}
return msg, nil
} else {
return nil, fmt.Errorf("message of kind %s is not a struct or map", val.Kind())
}
return msg, nil
case serde.TypeEnum, serde.TypeString, serde.TypeInt, serde.TypeDouble, serde.TypeBoolean:
if fieldCtx != nil {
ruleTags := ctx.Rule.Tags
Expand Down Expand Up @@ -140,9 +153,13 @@ func transformField(ctx serde.RuleContext, path string, propName string, structF
}
}
} else {
err = setField(structField, newVal)
if err != nil {
return err
if val.Kind() == reflect.Struct {
err = setField(structField, newVal)
if err != nil {
return err
}
} else if val.Kind() == reflect.Map {
val.SetMapIndex(reflect.ValueOf(propName), *newVal)
}
}
return nil
Expand Down Expand Up @@ -238,6 +255,9 @@ func disjoint(slice1 []string, map1 map[string]bool) bool {
}

func getField(msg *reflect.Value, name string) (*reflect.Value, error) {
if msg.Kind() != reflect.Struct {
return nil, fmt.Errorf("message is not a struct")
}
fieldVal := msg.FieldByName(name)
return &fieldVal, nil
}
Expand Down

0 comments on commit 8d502ae

Please sign in to comment.