diff --git a/decode.go b/decode.go index 4126022..5105645 100644 --- a/decode.go +++ b/decode.go @@ -433,6 +433,10 @@ func (d *decoder) prepare(n *Node, out reflect.Value) (newout reflect.Value, unm out = out.Elem() again = true } + if out.Kind() == reflect.Map && out.IsNil() { + out.Set(reflect.MakeMap(out.Type())) + again = true + } if out.CanAddr() { outi := out.Addr().Interface() if u, ok := outi.(Unmarshaler); ok { diff --git a/decode_test.go b/decode_test.go index 2df6c7a..413ed12 100644 --- a/decode_test.go +++ b/decode_test.go @@ -1716,6 +1716,81 @@ func TestObsoleteUnmarshalerRetry(t *testing.T) { a.Equal(obsoleteSliceUnmarshaler([]int{1}), su) } +type inlineMapUnmarshaler map[string]string + +func (m inlineMapUnmarshaler) UnmarshalYAML(node *yaml.Node) error { + if node.Kind != yaml.MappingNode || len(node.Content)%2 != 0 { + return errors.New("not a map") + } + for i := 0; i < len(node.Content); i += 2 { + key := node.Content[i].Value + value := node.Content[i+1].Value + m[key] = value + } + return nil +} + +type ptrInlineMapUnmarshaler map[string]string + +func (m *ptrInlineMapUnmarshaler) UnmarshalYAML(node *yaml.Node) error { + if node.Kind != yaml.MappingNode || len(node.Content)%2 != 0 { + return errors.New("not a map") + } + *m = make(map[string]string) + for i := 0; i < len(node.Content); i += 2 { + key := node.Content[i].Value + value := node.Content[i+1].Value + (*m)[key] = value + } + return nil +} + +// Check UnmarshalYAML is called on inline maps. +// +// See https://github.com/go-yaml/yaml/issues/742. +func TestUnmarshalerInlineMap(t *testing.T) { + a := require.New(t) + + var ( + direct struct { + A int `yaml:"a"` + InlineMap inlineMapUnmarshaler `yaml:",inline"` + } + directByPtr struct { + A int `yaml:"a"` + InlineMap *inlineMapUnmarshaler `yaml:",inline"` + } + ptr struct { + A int `yaml:"a"` + InlineMap ptrInlineMapUnmarshaler `yaml:",inline"` + } + ptrByPtr struct { + A int `yaml:"a"` + InlineMap *ptrInlineMapUnmarshaler `yaml:",inline"` + } + ) + input := "a: 1\nb: 2\nc: 3" + + a.NoError(yaml.Unmarshal([]byte(input), &direct)) + a.NoError(yaml.Unmarshal([]byte(input), &directByPtr)) + a.NoError(yaml.Unmarshal([]byte(input), &ptr)) + a.NoError(yaml.Unmarshal([]byte(input), &ptrByPtr)) + + a.Equal(direct.A, 1) + a.Equal(directByPtr.A, 1) + a.Equal(ptr.A, 1) + a.Equal(ptrByPtr.A, 1) + + // Inline map semantic does not apply here, since the field implements Unmarshaler. + // + // So, the decoder does not exclude the "a" field from the node. + expect := map[string]string{"a": "1", "b": "2", "c": "3"} + a.Equal(inlineMapUnmarshaler(expect), direct.InlineMap) + a.Equal(inlineMapUnmarshaler(expect), *directByPtr.InlineMap) + a.Equal(ptrInlineMapUnmarshaler(expect), ptr.InlineMap) + a.Equal(ptrInlineMapUnmarshaler(expect), *ptrByPtr.InlineMap) +} + // From http://yaml.org/type/merge.html var mergeTests = ` anchors: diff --git a/yaml.go b/yaml.go index c2f4ada..0cbf65c 100644 --- a/yaml.go +++ b/yaml.go @@ -120,23 +120,24 @@ func getStructInfo(st reflect.Type) (*structInfo, error) { } if inline { - switch field.Type.Kind() { + ftype := field.Type + for ftype.Kind() == reflect.Ptr { + ftype = ftype.Elem() + } + switch ftype.Kind() { case reflect.Map: + if reflect.PtrTo(ftype).Implements(unmarshalerType) { + inlineUnmarshalers = append(inlineUnmarshalers, []int{i}) + break + } if inlineMap >= 0 { return nil, errors.New("multiple ,inline maps in struct " + st.String()) } - if field.Type.Key() != reflect.TypeOf("") { + if ftype.Key() != reflect.TypeOf("") { return nil, errors.New("option ,inline needs a map with string keys in struct " + st.String()) } inlineMap = info.Num - case reflect.Struct, reflect.Ptr: - ftype := field.Type - for ftype.Kind() == reflect.Ptr { - ftype = ftype.Elem() - } - if ftype.Kind() != reflect.Struct { - return nil, errors.New("option ,inline may only be used on a struct or map field") - } + case reflect.Struct: if reflect.PtrTo(ftype).Implements(unmarshalerType) { inlineUnmarshalers = append(inlineUnmarshalers, []int{i}) } else {