Skip to content

Commit

Permalink
Fix slice of structs TextUnmarshaler. (#103)
Browse files Browse the repository at this point in the history
Fix handling of situation where a slice of structs implments the
encoding.TextUnmarshaler interface, previously it would return "invalid
path" error.

Includes, some minor refactoring and documentation to clarify
`isUnmarshaler` output.
  • Loading branch information
emil2k authored and kisielk committed Mar 7, 2018
1 parent afe7739 commit b0e8c20
Show file tree
Hide file tree
Showing 4 changed files with 122 additions and 42 deletions.
36 changes: 22 additions & 14 deletions cache.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ func (c *cache) parsePath(p string, t reflect.Type) ([]pathPart, error) {
}
// Valid field. Append index.
path = append(path, field.name)
if field.ss {
if field.isSliceOfStructs && (!field.unmarshalerInfo.IsValid || (field.unmarshalerInfo.IsValid && field.unmarshalerInfo.IsSliceElement)) {
// Parse a special case: slices of structs.
// i+1 must be the slice index.
//
Expand Down Expand Up @@ -142,7 +142,7 @@ func (c *cache) create(t reflect.Type, info *structInfo) *structInfo {
c.create(ft, info)
for _, fi := range info.fields[bef:len(info.fields)] {
// exclude required check because duplicated to embedded field
fi.required = false
fi.isRequired = false
}
}
}
Expand All @@ -162,6 +162,7 @@ func (c *cache) createField(field reflect.StructField, info *structInfo) {
// First let's get the basic type.
isSlice, isStruct := false, false
ft := field.Type
m := isTextUnmarshaler(reflect.Zero(ft))
if ft.Kind() == reflect.Ptr {
ft = ft.Elem()
}
Expand All @@ -185,12 +186,13 @@ func (c *cache) createField(field reflect.StructField, info *structInfo) {
}

info.fields = append(info.fields, &fieldInfo{
typ: field.Type,
name: field.Name,
ss: isSlice && isStruct,
alias: alias,
anon: field.Anonymous,
required: options.Contains("required"),
typ: field.Type,
name: field.Name,
alias: alias,
unmarshalerInfo: m,
isSliceOfStructs: isSlice && isStruct,
isAnonymous: field.Anonymous,
isRequired: options.Contains("required"),
})
}

Expand All @@ -215,12 +217,18 @@ func (i *structInfo) get(alias string) *fieldInfo {
}

type fieldInfo struct {
typ reflect.Type
name string // field name in the struct.
ss bool // true if this is a slice of structs.
alias string
anon bool // is an embedded field
required bool // tag option
typ reflect.Type
// name is the field name in the struct.
name string
alias string
// unmarshalerInfo contains information regarding the
// encoding.TextUnmarshaler implementation of the field type.
unmarshalerInfo unmarshaler
// isSliceOfStructs indicates if the field type is a slice of structs.
isSliceOfStructs bool
// isAnonymous indicates whether the field is embedded in the struct.
isAnonymous bool
isRequired bool
}

type pathPart struct {
Expand Down
79 changes: 53 additions & 26 deletions decoder.go
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ func (d *Decoder) checkRequired(t reflect.Type, src map[string][]string, prefix
if f.typ.Kind() == reflect.Struct {
err := d.checkRequired(f.typ, src, prefix+f.alias+".")
if err != nil {
if !f.anon {
if !f.isAnonymous {
return err
}
// check embedded parent field.
Expand All @@ -116,7 +116,7 @@ func (d *Decoder) checkRequired(t reflect.Type, src map[string][]string, prefix
}
}
}
if f.required {
if f.isRequired {
key := f.alias
if prefix != "" {
key = prefix + key
Expand Down Expand Up @@ -185,7 +185,7 @@ func (d *Decoder) decode(v reflect.Value, path string, parts []pathPart, values
// Get the converter early in case there is one for a slice type.
conv := d.cache.converter(t)
m := isTextUnmarshaler(v)
if conv == nil && t.Kind() == reflect.Slice && m.IsSlice {
if conv == nil && t.Kind() == reflect.Slice && m.IsSliceElement {
var items []reflect.Value
elemT := t.Elem()
isPtrElem := elemT.Kind() == reflect.Ptr
Expand All @@ -211,7 +211,7 @@ func (d *Decoder) decode(v reflect.Value, path string, parts []pathPart, values
}
} else if m.IsValid {
u := reflect.New(elemT)
if m.IsPtr {
if m.IsSliceElementPtr {
u = reflect.New(reflect.PtrTo(elemT).Elem())
}
if err := u.Interface().(encoding.TextUnmarshaler).UnmarshalText([]byte(value)); err != nil {
Expand All @@ -222,7 +222,7 @@ func (d *Decoder) decode(v reflect.Value, path string, parts []pathPart, values
Err: err,
}
}
if m.IsPtr {
if m.IsSliceElementPtr {
items = append(items, u.Elem().Addr())
} else if u.Kind() == reflect.Ptr {
items = append(items, u.Elem())
Expand Down Expand Up @@ -298,14 +298,27 @@ func (d *Decoder) decode(v reflect.Value, path string, parts []pathPart, values
}
}
} else if m.IsValid {
// If the value implements the encoding.TextUnmarshaler interface
// apply UnmarshalText as the converter
if err := m.Unmarshaler.UnmarshalText([]byte(val)); err != nil {
return ConversionError{
Key: path,
Type: t,
Index: -1,
Err: err,
if m.IsPtr {
u := reflect.New(v.Type())
if err := u.Interface().(encoding.TextUnmarshaler).UnmarshalText([]byte(val)); err != nil {
return ConversionError{
Key: path,
Type: t,
Index: -1,
Err: err,
}
}
v.Set(reflect.Indirect(u))
} else {
// If the value implements the encoding.TextUnmarshaler interface
// apply UnmarshalText as the converter
if err := m.Unmarshaler.UnmarshalText([]byte(val)); err != nil {
return ConversionError{
Key: path,
Type: t,
Index: -1,
Err: err,
}
}
}
} else if conv := builtinConverters[t.Kind()]; conv != nil {
Expand All @@ -326,31 +339,36 @@ func (d *Decoder) decode(v reflect.Value, path string, parts []pathPart, values
}

func isTextUnmarshaler(v reflect.Value) unmarshaler {

// Create a new unmarshaller instance
m := unmarshaler{}

// As the UnmarshalText function should be applied
// to the pointer of the type, we convert the value to pointer.
if v.CanAddr() {
v = v.Addr()
}
if m.Unmarshaler, m.IsValid = v.Interface().(encoding.TextUnmarshaler); m.IsValid {
return m
}
// As the UnmarshalText function should be applied to the pointer of the
// type, we check that type to see if it implements the necessary
// method.
if m.Unmarshaler, m.IsValid = reflect.New(v.Type()).Interface().(encoding.TextUnmarshaler); m.IsValid {
m.IsPtr = true
return m
}

// if v is []T or *[]T create new T
t := v.Type()
if t.Kind() == reflect.Ptr {
t = t.Elem()
}
if t.Kind() == reflect.Slice {
// if t is a pointer slice, check if it implements encoding.TextUnmarshaler
m.IsSlice = true
// Check if the slice implements encoding.TextUnmarshaller
if m.Unmarshaler, m.IsValid = v.Interface().(encoding.TextUnmarshaler); m.IsValid {
return m
}
// If t is a pointer slice, check if its elements implement
// encoding.TextUnmarshaler
m.IsSliceElement = true
if t = t.Elem(); t.Kind() == reflect.Ptr {
t = reflect.PtrTo(t.Elem())
v = reflect.Zero(t)
m.IsPtr = true
m.IsSliceElementPtr = true
m.Unmarshaler, m.IsValid = v.Interface().(encoding.TextUnmarshaler)
return m
}
Expand All @@ -365,9 +383,18 @@ func isTextUnmarshaler(v reflect.Value) unmarshaler {
// unmarshaller contains information about a TextUnmarshaler type
type unmarshaler struct {
Unmarshaler encoding.TextUnmarshaler
IsSlice bool
IsPtr bool
IsValid bool
// IsValid indicates whether the resolved type indicated by the other
// flags implements the encoding.TextUnmarshaler interface.
IsValid bool
// IsPtr indicates that the resolved type is the pointer of the original
// type.
IsPtr bool
// IsSliceElement indicates that the resolved type is a slice element of
// the original type.
IsSliceElement bool
// IsSliceElementPtr indicates that the resolved type is a pointer to a
// slice element of the original type.
IsSliceElementPtr bool
}

// Errors ---------------------------------------------------------------------
Expand Down
48 changes: 47 additions & 1 deletion decoder_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1684,10 +1684,56 @@ func TestTextUnmarshalerTypeSlice(t *testing.T) {
}{}
decoder := NewDecoder()
if err := decoder.Decode(&s, data); err != nil {
t.Error("Error while decoding:", err)
t.Fatal("Error while decoding:", err)
}
expected := S20{"a", "b", "c"}
if !reflect.DeepEqual(expected, s.Value) {
t.Errorf("Expected %v errors, got %v", expected, s.Value)
}
}

type S21E struct{ ElementValue string }

func (e *S21E) UnmarshalText(text []byte) error {
*e = S21E{"x"}
return nil
}

type S21 []S21E

func (s *S21) UnmarshalText(text []byte) error {
*s = S21{{"a"}}
return nil
}

type S21B []S21E

// Test to ensure that if custom type base on a slice of structs implements an
// encoding.TextUnmarshaler interface it is unaffected by the special path
// requirements imposed on a slice of structs.
func TestTextUnmarshalerTypeSliceOfStructs(t *testing.T) {
data := map[string][]string{
"Value": []string{"raw a"},
}
// Implements encoding.TextUnmarshaler, should not throw invalid path
// error.
s := struct {
Value S21
}{}
decoder := NewDecoder()
if err := decoder.Decode(&s, data); err != nil {
t.Fatal("Error while decoding:", err)
}
expected := S21{{"a"}}
if !reflect.DeepEqual(expected, s.Value) {
t.Errorf("Expected %v errors, got %v", expected, s.Value)
}
// Does not implement encoding.TextUnmarshaler, should throw invalid
// path error.
sb := struct {
Value S21B
}{}
if err := decoder.Decode(&sb, data); err == invalidPath {
t.Fatal("Expecting invalid path error", err)
}
}
1 change: 0 additions & 1 deletion encoder_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -415,6 +415,5 @@ func TestRegisterEncoderCustomArrayType(t *testing.T) {
})

encoder.Encode(s, vals)
t.Log(vals)
}
}

0 comments on commit b0e8c20

Please sign in to comment.