From b0e8c201c781a2ff27702d53885093aefb15b623 Mon Sep 17 00:00:00 2001 From: Emil Davtyan Date: Wed, 7 Mar 2018 16:51:38 +0100 Subject: [PATCH] Fix slice of structs TextUnmarshaler. (#103) Fix handling of situation where a slice of structs implments the encoding.TextUnmarshaler interface, previously it would return "invalid path" error. Includes, some minor refactoring and documentation to clarify `isUnmarshaler` output. --- cache.go | 36 +++++++++++++--------- decoder.go | 79 +++++++++++++++++++++++++++++++++---------------- decoder_test.go | 48 +++++++++++++++++++++++++++++- encoder_test.go | 1 - 4 files changed, 122 insertions(+), 42 deletions(-) diff --git a/cache.go b/cache.go index afa20a3..73b75f4 100644 --- a/cache.go +++ b/cache.go @@ -63,7 +63,7 @@ func (c *cache) parsePath(p string, t reflect.Type) ([]pathPart, error) { } // Valid field. Append index. path = append(path, field.name) - if field.ss { + if field.isSliceOfStructs && (!field.unmarshalerInfo.IsValid || (field.unmarshalerInfo.IsValid && field.unmarshalerInfo.IsSliceElement)) { // Parse a special case: slices of structs. // i+1 must be the slice index. // @@ -142,7 +142,7 @@ func (c *cache) create(t reflect.Type, info *structInfo) *structInfo { c.create(ft, info) for _, fi := range info.fields[bef:len(info.fields)] { // exclude required check because duplicated to embedded field - fi.required = false + fi.isRequired = false } } } @@ -162,6 +162,7 @@ func (c *cache) createField(field reflect.StructField, info *structInfo) { // First let's get the basic type. isSlice, isStruct := false, false ft := field.Type + m := isTextUnmarshaler(reflect.Zero(ft)) if ft.Kind() == reflect.Ptr { ft = ft.Elem() } @@ -185,12 +186,13 @@ func (c *cache) createField(field reflect.StructField, info *structInfo) { } info.fields = append(info.fields, &fieldInfo{ - typ: field.Type, - name: field.Name, - ss: isSlice && isStruct, - alias: alias, - anon: field.Anonymous, - required: options.Contains("required"), + typ: field.Type, + name: field.Name, + alias: alias, + unmarshalerInfo: m, + isSliceOfStructs: isSlice && isStruct, + isAnonymous: field.Anonymous, + isRequired: options.Contains("required"), }) } @@ -215,12 +217,18 @@ func (i *structInfo) get(alias string) *fieldInfo { } type fieldInfo struct { - typ reflect.Type - name string // field name in the struct. - ss bool // true if this is a slice of structs. - alias string - anon bool // is an embedded field - required bool // tag option + typ reflect.Type + // name is the field name in the struct. + name string + alias string + // unmarshalerInfo contains information regarding the + // encoding.TextUnmarshaler implementation of the field type. + unmarshalerInfo unmarshaler + // isSliceOfStructs indicates if the field type is a slice of structs. + isSliceOfStructs bool + // isAnonymous indicates whether the field is embedded in the struct. + isAnonymous bool + isRequired bool } type pathPart struct { diff --git a/decoder.go b/decoder.go index e49b53c..16ece12 100644 --- a/decoder.go +++ b/decoder.go @@ -106,7 +106,7 @@ func (d *Decoder) checkRequired(t reflect.Type, src map[string][]string, prefix if f.typ.Kind() == reflect.Struct { err := d.checkRequired(f.typ, src, prefix+f.alias+".") if err != nil { - if !f.anon { + if !f.isAnonymous { return err } // check embedded parent field. @@ -116,7 +116,7 @@ func (d *Decoder) checkRequired(t reflect.Type, src map[string][]string, prefix } } } - if f.required { + if f.isRequired { key := f.alias if prefix != "" { key = prefix + key @@ -185,7 +185,7 @@ func (d *Decoder) decode(v reflect.Value, path string, parts []pathPart, values // Get the converter early in case there is one for a slice type. conv := d.cache.converter(t) m := isTextUnmarshaler(v) - if conv == nil && t.Kind() == reflect.Slice && m.IsSlice { + if conv == nil && t.Kind() == reflect.Slice && m.IsSliceElement { var items []reflect.Value elemT := t.Elem() isPtrElem := elemT.Kind() == reflect.Ptr @@ -211,7 +211,7 @@ func (d *Decoder) decode(v reflect.Value, path string, parts []pathPart, values } } else if m.IsValid { u := reflect.New(elemT) - if m.IsPtr { + if m.IsSliceElementPtr { u = reflect.New(reflect.PtrTo(elemT).Elem()) } if err := u.Interface().(encoding.TextUnmarshaler).UnmarshalText([]byte(value)); err != nil { @@ -222,7 +222,7 @@ func (d *Decoder) decode(v reflect.Value, path string, parts []pathPart, values Err: err, } } - if m.IsPtr { + if m.IsSliceElementPtr { items = append(items, u.Elem().Addr()) } else if u.Kind() == reflect.Ptr { items = append(items, u.Elem()) @@ -298,14 +298,27 @@ func (d *Decoder) decode(v reflect.Value, path string, parts []pathPart, values } } } else if m.IsValid { - // If the value implements the encoding.TextUnmarshaler interface - // apply UnmarshalText as the converter - if err := m.Unmarshaler.UnmarshalText([]byte(val)); err != nil { - return ConversionError{ - Key: path, - Type: t, - Index: -1, - Err: err, + if m.IsPtr { + u := reflect.New(v.Type()) + if err := u.Interface().(encoding.TextUnmarshaler).UnmarshalText([]byte(val)); err != nil { + return ConversionError{ + Key: path, + Type: t, + Index: -1, + Err: err, + } + } + v.Set(reflect.Indirect(u)) + } else { + // If the value implements the encoding.TextUnmarshaler interface + // apply UnmarshalText as the converter + if err := m.Unmarshaler.UnmarshalText([]byte(val)); err != nil { + return ConversionError{ + Key: path, + Type: t, + Index: -1, + Err: err, + } } } } else if conv := builtinConverters[t.Kind()]; conv != nil { @@ -326,18 +339,18 @@ func (d *Decoder) decode(v reflect.Value, path string, parts []pathPart, values } func isTextUnmarshaler(v reflect.Value) unmarshaler { - // Create a new unmarshaller instance m := unmarshaler{} - - // As the UnmarshalText function should be applied - // to the pointer of the type, we convert the value to pointer. - if v.CanAddr() { - v = v.Addr() - } if m.Unmarshaler, m.IsValid = v.Interface().(encoding.TextUnmarshaler); m.IsValid { return m } + // As the UnmarshalText function should be applied to the pointer of the + // type, we check that type to see if it implements the necessary + // method. + if m.Unmarshaler, m.IsValid = reflect.New(v.Type()).Interface().(encoding.TextUnmarshaler); m.IsValid { + m.IsPtr = true + return m + } // if v is []T or *[]T create new T t := v.Type() @@ -345,12 +358,17 @@ func isTextUnmarshaler(v reflect.Value) unmarshaler { t = t.Elem() } if t.Kind() == reflect.Slice { - // if t is a pointer slice, check if it implements encoding.TextUnmarshaler - m.IsSlice = true + // Check if the slice implements encoding.TextUnmarshaller + if m.Unmarshaler, m.IsValid = v.Interface().(encoding.TextUnmarshaler); m.IsValid { + return m + } + // If t is a pointer slice, check if its elements implement + // encoding.TextUnmarshaler + m.IsSliceElement = true if t = t.Elem(); t.Kind() == reflect.Ptr { t = reflect.PtrTo(t.Elem()) v = reflect.Zero(t) - m.IsPtr = true + m.IsSliceElementPtr = true m.Unmarshaler, m.IsValid = v.Interface().(encoding.TextUnmarshaler) return m } @@ -365,9 +383,18 @@ func isTextUnmarshaler(v reflect.Value) unmarshaler { // unmarshaller contains information about a TextUnmarshaler type type unmarshaler struct { Unmarshaler encoding.TextUnmarshaler - IsSlice bool - IsPtr bool - IsValid bool + // IsValid indicates whether the resolved type indicated by the other + // flags implements the encoding.TextUnmarshaler interface. + IsValid bool + // IsPtr indicates that the resolved type is the pointer of the original + // type. + IsPtr bool + // IsSliceElement indicates that the resolved type is a slice element of + // the original type. + IsSliceElement bool + // IsSliceElementPtr indicates that the resolved type is a pointer to a + // slice element of the original type. + IsSliceElementPtr bool } // Errors --------------------------------------------------------------------- diff --git a/decoder_test.go b/decoder_test.go index 18b3b32..ac4898a 100644 --- a/decoder_test.go +++ b/decoder_test.go @@ -1684,10 +1684,56 @@ func TestTextUnmarshalerTypeSlice(t *testing.T) { }{} decoder := NewDecoder() if err := decoder.Decode(&s, data); err != nil { - t.Error("Error while decoding:", err) + t.Fatal("Error while decoding:", err) } expected := S20{"a", "b", "c"} if !reflect.DeepEqual(expected, s.Value) { t.Errorf("Expected %v errors, got %v", expected, s.Value) } } + +type S21E struct{ ElementValue string } + +func (e *S21E) UnmarshalText(text []byte) error { + *e = S21E{"x"} + return nil +} + +type S21 []S21E + +func (s *S21) UnmarshalText(text []byte) error { + *s = S21{{"a"}} + return nil +} + +type S21B []S21E + +// Test to ensure that if custom type base on a slice of structs implements an +// encoding.TextUnmarshaler interface it is unaffected by the special path +// requirements imposed on a slice of structs. +func TestTextUnmarshalerTypeSliceOfStructs(t *testing.T) { + data := map[string][]string{ + "Value": []string{"raw a"}, + } + // Implements encoding.TextUnmarshaler, should not throw invalid path + // error. + s := struct { + Value S21 + }{} + decoder := NewDecoder() + if err := decoder.Decode(&s, data); err != nil { + t.Fatal("Error while decoding:", err) + } + expected := S21{{"a"}} + if !reflect.DeepEqual(expected, s.Value) { + t.Errorf("Expected %v errors, got %v", expected, s.Value) + } + // Does not implement encoding.TextUnmarshaler, should throw invalid + // path error. + sb := struct { + Value S21B + }{} + if err := decoder.Decode(&sb, data); err == invalidPath { + t.Fatal("Expecting invalid path error", err) + } +} diff --git a/encoder_test.go b/encoder_test.go index ac4cd61..e03b2d0 100644 --- a/encoder_test.go +++ b/encoder_test.go @@ -415,6 +415,5 @@ func TestRegisterEncoderCustomArrayType(t *testing.T) { }) encoder.Encode(s, vals) - t.Log(vals) } }