From 8d502ae40624a59fd2e4af18e9a36ff9458621df Mon Sep 17 00:00:00 2001 From: Robert Yokota Date: Mon, 28 Oct 2024 19:01:19 -0700 Subject: [PATCH] Support transforming maps in Data Contract rules (#1324) * Support transforming maps in Data Contract rules * Enhance log msg --- schemaregistry/serde/avrov2/avro_test.go | 75 ++++++++++++++++++ schemaregistry/serde/avrov2/avro_util.go | 44 ++++++++--- .../serde/jsonschema/json_schema_test.go | 77 +++++++++++++++++++ .../serde/jsonschema/json_schema_util.go | 44 ++++++++--- 4 files changed, 216 insertions(+), 24 deletions(-) diff --git a/schemaregistry/serde/avrov2/avro_test.go b/schemaregistry/serde/avrov2/avro_test.go index 664f91247..8107fc40d 100644 --- a/schemaregistry/serde/avrov2/avro_test.go +++ b/schemaregistry/serde/avrov2/avro_test.go @@ -1181,6 +1181,81 @@ func TestAvroSerdeEncryption(t *testing.T) { serde.MaybeFail("deserialization", err, serde.Expect(newobj, &obj)) } +func TestAvroSerdeEncryptionWithSimpleMap(t *testing.T) { + serde.MaybeFail = serde.InitFailFunc(t) + var err error + + conf := schemaregistry.NewConfig("mock://") + + client, err := schemaregistry.NewClient(conf) + serde.MaybeFail("Schema Registry configuration", err) + + serConfig := NewSerializerConfig() + serConfig.AutoRegisterSchemas = false + serConfig.UseLatestVersion = true + serConfig.RuleConfig = map[string]string{ + "secret": "mysecret", + } + ser, err := NewSerializer(client, serde.ValueSerde, serConfig) + serde.MaybeFail("Serializer configuration", err) + + encRule := schemaregistry.Rule{ + Name: "test-encrypt", + Kind: "TRANSFORM", + Mode: "WRITEREAD", + Type: "ENCRYPT", + Tags: []string{"PII"}, + Params: map[string]string{ + "encrypt.kek.name": "kek1", + "encrypt.kms.type": "local-kms", + "encrypt.kms.key.id": "mykey", + }, + OnFailure: "ERROR,NONE", + } + ruleSet := schemaregistry.RuleSet{ + DomainRules: []schemaregistry.Rule{encRule}, + } + + info := schemaregistry.SchemaInfo{ + Schema: demoSchema, + SchemaType: "AVRO", + RuleSet: &ruleSet, + } + + id, err := client.Register("topic1-value", info, false) + serde.MaybeFail("Schema registration", err) + if id <= 0 { + t.Errorf("Expected valid schema id, found %d", id) + } + + obj := make(map[string]interface{}) + obj["IntField"] = 123 + obj["DoubleField"] = 45.67 + obj["StringField"] = "hi" + obj["BoolField"] = true + obj["BytesField"] = []byte{1, 2} + + bytes, err := ser.Serialize("topic1", &obj) + serde.MaybeFail("serialization", err) + + // Reset encrypted field + obj["StringField"] = "hi" + obj["BytesField"] = []byte{1, 2} + + deserConfig := NewDeserializerConfig() + deserConfig.RuleConfig = map[string]string{ + "secret": "mysecret", + } + deser, err := NewDeserializer(client, serde.ValueSerde, deserConfig) + serde.MaybeFail("Deserializer configuration", err) + deser.Client = ser.Client + deser.MessageFactory = testMessageFactory + + var newobj map[string]interface{} + err = deser.DeserializeInto("topic1", bytes, &newobj) + serde.MaybeFail("deserialization into", err, serde.Expect(newobj, obj)) +} + func TestAvroSerdeEncryptionDekRotation(t *testing.T) { f := fakeClock{now: time.Now().UnixMilli()} executor := encryption.RegisterWithClock(&f) diff --git a/schemaregistry/serde/avrov2/avro_util.go b/schemaregistry/serde/avrov2/avro_util.go index 97471e4d0..5075b83a7 100644 --- a/schemaregistry/serde/avrov2/avro_util.go +++ b/schemaregistry/serde/avrov2/avro_util.go @@ -76,19 +76,32 @@ func transform(ctx serde.RuleContext, resolver *avro.TypeResolver, schema avro.S return msg, nil case *avro.RecordSchema: val := deref(msg) - fieldByNames := fieldByNames(val) recordSchema := schema.(*avro.RecordSchema) - for _, avroField := range recordSchema.Fields() { - structField, ok := fieldByNames[avroField.Name()] - if !ok { - return nil, fmt.Errorf("avro: missing field %s", avroField.Name()) + if val.Kind() == reflect.Struct { + fieldByNames := fieldByNames(val) + for _, avroField := range recordSchema.Fields() { + structField, ok := fieldByNames[avroField.Name()] + if !ok { + return nil, fmt.Errorf("avro: missing field %s", avroField.Name()) + } + err := transformField(ctx, resolver, recordSchema, avroField, structField, val, fieldTransform) + if err != nil { + return nil, err + } } - err := transformField(ctx, resolver, recordSchema, avroField, structField, val, fieldTransform) - if err != nil { - return nil, err + return msg, nil + } else if val.Kind() == reflect.Map { + for _, avroField := range recordSchema.Fields() { + mapField := val.MapIndex(reflect.ValueOf(avroField.Name())) + err := transformField(ctx, resolver, recordSchema, avroField, &mapField, val, fieldTransform) + if err != nil { + return nil, err + } } + return msg, nil + } else { + return nil, fmt.Errorf("message of kind %s is not a struct or map", val.Kind()) } - return msg, nil default: if fieldCtx != nil { ruleTags := ctx.Rule.Tags @@ -137,9 +150,13 @@ func transformField(ctx serde.RuleContext, resolver *avro.TypeResolver, recordSc } } } else { - err = setField(structField, newVal) - if err != nil { - return err + if val.Kind() == reflect.Struct { + err = setField(structField, newVal) + if err != nil { + return err + } + } else { + val.SetMapIndex(reflect.ValueOf(avroField.Name()), *newVal) } } return nil @@ -203,6 +220,9 @@ func disjoint(slice1 []string, map1 map[string]bool) bool { } func getField(msg *reflect.Value, name string) (*reflect.Value, error) { + if msg.Kind() != reflect.Struct { + return nil, fmt.Errorf("message is not a struct") + } fieldVal := msg.FieldByName(name) return &fieldVal, nil } diff --git a/schemaregistry/serde/jsonschema/json_schema_test.go b/schemaregistry/serde/jsonschema/json_schema_test.go index 0187f43aa..35ccc2c51 100644 --- a/schemaregistry/serde/jsonschema/json_schema_test.go +++ b/schemaregistry/serde/jsonschema/json_schema_test.go @@ -798,6 +798,83 @@ func TestJSONSchemaSerdeWithCELFieldConditionFail(t *testing.T) { serde.MaybeFail("serialization", nil, serde.Expect(ruleErr, serde.RuleConditionErr{Rule: &encRule})) } +func TestJSONSchemaSerdeEncryptionWithSimpleMap(t *testing.T) { + serde.MaybeFail = serde.InitFailFunc(t) + var err error + + conf := schemaregistry.NewConfig("mock://") + + client, err := schemaregistry.NewClient(conf) + serde.MaybeFail("Schema Registry configuration", err) + + serConfig := NewSerializerConfig() + serConfig.AutoRegisterSchemas = false + serConfig.UseLatestVersion = true + serConfig.RuleConfig = map[string]string{ + "secret": "mysecret", + } + ser, err := NewSerializer(client, serde.ValueSerde, serConfig) + serde.MaybeFail("Serializer configuration", err) + + encRule := schemaregistry.Rule{ + Name: "test-encrypt", + Kind: "TRANSFORM", + Mode: "WRITEREAD", + Type: "ENCRYPT", + Tags: []string{"PII"}, + Params: map[string]string{ + "encrypt.kek.name": "kek1", + "encrypt.kms.type": "local-kms", + "encrypt.kms.key.id": "mykey", + }, + OnFailure: "ERROR,NONE", + } + ruleSet := schemaregistry.RuleSet{ + DomainRules: []schemaregistry.Rule{encRule}, + } + + info := schemaregistry.SchemaInfo{ + Schema: demoSchema, + SchemaType: "JSON", + RuleSet: &ruleSet, + } + + id, err := client.Register("topic1-value", info, false) + serde.MaybeFail("Schema registration", err) + if id <= 0 { + t.Errorf("Expected valid schema id, found %d", id) + } + + obj := make(map[string]interface{}) + obj["IntField"] = 123 + obj["DoubleField"] = 45.67 + obj["StringField"] = "hi" + obj["BoolField"] = true + obj["BytesField"] = base64.StdEncoding.EncodeToString([]byte{0, 0, 0, 1}) + + bytes, err := ser.Serialize("topic1", &obj) + serde.MaybeFail("serialization", err) + + // Reset encrypted field + obj["StringField"] = "hi" + obj["BytesField"] = base64.StdEncoding.EncodeToString([]byte{0, 0, 0, 1}) + + // JSON decoding produces floats + obj["IntField"] = 123.0 + + deserConfig := NewDeserializerConfig() + deserConfig.RuleConfig = map[string]string{ + "secret": "mysecret", + } + deser, err := NewDeserializer(client, serde.ValueSerde, deserConfig) + serde.MaybeFail("Deserializer configuration", err) + deser.Client = ser.Client + + var newobj map[string]interface{} + err = deser.DeserializeInto("topic1", bytes, &newobj) + serde.MaybeFail("deserialization", err, serde.Expect(newobj, obj)) +} + func TestJSONSchemaSerdeEncryption(t *testing.T) { serde.MaybeFail = serde.InitFailFunc(t) var err error diff --git a/schemaregistry/serde/jsonschema/json_schema_util.go b/schemaregistry/serde/jsonschema/json_schema_util.go index 8cb9e4894..2316f8eca 100644 --- a/schemaregistry/serde/jsonschema/json_schema_util.go +++ b/schemaregistry/serde/jsonschema/json_schema_util.go @@ -78,18 +78,31 @@ func transform(ctx serde.RuleContext, schema *jsonschema2.Schema, path string, m switch typ { case serde.TypeRecord: val := deref(msg) - fieldByNames := fieldByNames(val) - for propName, propSchema := range schema.Properties { - structField, ok := fieldByNames[propName] - if !ok { - return nil, fmt.Errorf("json: missing field %s", propName) + if val.Kind() == reflect.Struct { + fieldByNames := fieldByNames(val) + for propName, propSchema := range schema.Properties { + structField, ok := fieldByNames[propName] + if !ok { + return nil, fmt.Errorf("json: missing field %s", propName) + } + err := transformField(ctx, path, propName, structField, val, propSchema, fieldTransform) + if err != nil { + return nil, err + } } - err := transformField(ctx, path, propName, structField, val, propSchema, fieldTransform) - if err != nil { - return nil, err + return msg, nil + } else if val.Kind() == reflect.Map { + for propName, propSchema := range schema.Properties { + mapField := val.MapIndex(reflect.ValueOf(propName)) + err := transformField(ctx, path, propName, &mapField, val, propSchema, fieldTransform) + if err != nil { + return nil, err + } } + return msg, nil + } else { + return nil, fmt.Errorf("message of kind %s is not a struct or map", val.Kind()) } - return msg, nil case serde.TypeEnum, serde.TypeString, serde.TypeInt, serde.TypeDouble, serde.TypeBoolean: if fieldCtx != nil { ruleTags := ctx.Rule.Tags @@ -140,9 +153,13 @@ func transformField(ctx serde.RuleContext, path string, propName string, structF } } } else { - err = setField(structField, newVal) - if err != nil { - return err + if val.Kind() == reflect.Struct { + err = setField(structField, newVal) + if err != nil { + return err + } + } else if val.Kind() == reflect.Map { + val.SetMapIndex(reflect.ValueOf(propName), *newVal) } } return nil @@ -238,6 +255,9 @@ func disjoint(slice1 []string, map1 map[string]bool) bool { } func getField(msg *reflect.Value, name string) (*reflect.Value, error) { + if msg.Kind() != reflect.Struct { + return nil, fmt.Errorf("message is not a struct") + } fieldVal := msg.FieldByName(name) return &fieldVal, nil }