Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
207 changes: 187 additions & 20 deletions decode.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ type Decoder struct {
parsedFile *ast.File
streamIndex int
decodeDepth int
initer Initer
}

// NewDecoder returns a new decoder that reads from r.
Expand All @@ -64,6 +65,7 @@ func NewDecoder(r io.Reader, opts ...DecodeOption) *Decoder {
disallowUnknownField: false,
allowDuplicateMapKey: false,
useOrderedMap: false,
initer: DefaultInitier,
}
}

Expand Down Expand Up @@ -476,21 +478,44 @@ func (d *Decoder) nodeToValue(ctx context.Context, node ast.Node) (any, error) {
}
iter := value.MapRange()
if d.useOrderedMap {
m := MapSlice{}
t := reflect.TypeOf(MapSlice{})
m, err := d.initer(ctx, d, n, t, reflect.Zero(t))
if err != nil {
return nil, err
}

mc, ok := m.(MapSlice)
if !ok {
return nil, errors.ErrTypeMismatch(t, reflect.TypeOf(m), n.GetToken())
}

for iter.Next() {
if err := d.setToOrderedMapValue(ctx, iter.KeyValue(), &m); err != nil {
if err := d.setToOrderedMapValue(ctx, iter.KeyValue(), &mc); err != nil {
return nil, err
}
}
return m, nil

return mc, nil
}

t := reflect.TypeOf(map[string]interface{}{})
m, err := d.initer(ctx, d, n, t, reflect.Zero(t))
if err != nil {
return nil, err
}
m := make(map[string]any)

mc, ok := m.(map[string]interface{})
if !ok {
return nil, errors.ErrTypeMismatch(t, reflect.TypeOf(m), n.GetToken())
}

for iter.Next() {
if err := d.setToMapValue(ctx, iter.KeyValue(), m); err != nil {
if err := d.setToMapValue(ctx, iter.KeyValue(), mc); err != nil {
return nil, err
}
}
return m, nil

return mc, nil
}
key, err := d.mapKeyNodeToString(ctx, n.Key)
if err != nil {
Expand All @@ -501,40 +526,101 @@ func (d *Decoder) nodeToValue(ctx context.Context, node ast.Node) (any, error) {
if err != nil {
return nil, err
}
return MapSlice{{Key: key, Value: v}}, nil

t := reflect.TypeOf(MapSlice{})
m, err := d.initer(ctx, d, n, t, reflect.Zero(t))
if err != nil {
return nil, err
}

mc, ok := m.(MapSlice)
if !ok {
return nil, errors.ErrTypeMismatch(t, reflect.TypeOf(m), n.GetToken())
}

return append(mc, MapItem{Key: key, Value: v}), nil
}

v, err := d.nodeToValue(ctx, n.Value)
if err != nil {
return nil, err
}
return map[string]interface{}{key: v}, nil

t := reflect.TypeOf(map[string]interface{}{})
m, err := d.initer(ctx, d, n, t, reflect.Zero(t))
if err != nil {
return nil, err
}

mc, ok := m.(map[string]interface{})
if !ok {
return nil, errors.ErrTypeMismatch(t, reflect.TypeOf(m), n.GetToken())
}

mc[key] = v

return mc, nil
case *ast.MappingNode:
if d.useOrderedMap {
m := make(MapSlice, 0, len(n.Values))
t := reflect.TypeOf(MapSlice{})
m, err := d.initer(ctx, d, n, t, reflect.Zero(t))
if err != nil {
return nil, err
}

mc, ok := m.(MapSlice)
if !ok {
return nil, errors.ErrTypeMismatch(t, reflect.TypeOf(m), n.GetToken())
}

for _, value := range n.Values {
if err := d.setToOrderedMapValue(ctx, value, &m); err != nil {
if err := d.setToOrderedMapValue(ctx, value, &mc); err != nil {
return nil, err
}
}
return m, nil

return mc, nil
}

t := reflect.TypeOf(map[string]interface{}{})
m, err := d.initer(ctx, d, n, t, reflect.Zero(t))
if err != nil {
return nil, err
}

mc, ok := m.(map[string]interface{})
if !ok {
return nil, errors.ErrTypeMismatch(t, reflect.TypeOf(m), n.GetToken())
}
m := make(map[string]interface{}, len(n.Values))

for _, value := range n.Values {
if err := d.setToMapValue(ctx, value, m); err != nil {
if err := d.setToMapValue(ctx, value, mc); err != nil {
return nil, err
}
}
return m, nil

return mc, nil
case *ast.SequenceNode:
v := make([]interface{}, 0, len(n.Values))
t := reflect.TypeOf([]interface{}{})
v, err := d.initer(ctx, d, n, t, reflect.Zero(t))
if err != nil {
return nil, err
}

vc, ok := v.([]interface{})
if !ok {
return nil, errors.ErrTypeMismatch(t, reflect.TypeOf(v), n.GetToken())
}

for _, value := range n.Values {
vv, err := d.nodeToValue(ctx, value)
if err != nil {
return nil, err
}
v = append(v, vv)
vc = append(vc, vv)
}
return v, nil

return vc, nil
}
return nil, nil
}
Expand Down Expand Up @@ -895,6 +981,22 @@ func (d *Decoder) decodeValue(ctx context.Context, dst reflect.Value, src ast.No
return nil
}

if !dst.IsZero() {
val, err := d.initer(ctx, d, src, dst.Type(), dst)
if err != nil {
return err
}

vOf := reflect.ValueOf(val)
if val != nil {
if vOf.Type() != dst.Type() {
return errors.ErrTypeMismatch(dst.Type(), vOf.Type(), src.GetToken())
}

dst.Set(vOf)
}
}

if src.Type() == ast.AnchorType {
anchor, _ := src.(*ast.AnchorNode)
anchorName := anchor.Name.GetToken().Value
Expand Down Expand Up @@ -1325,6 +1427,11 @@ func (d *Decoder) decodeStruct(ctx context.Context, dst reflect.Value, src ast.N
dst.Set(srcValue)
return nil
}
c, err := d.initer(ctx, d, src, structType, dst)
if err != nil {
return err
}
dst.Set(reflect.ValueOf(c))
structFieldMap, err := structFieldMap(structType)
if err != nil {
return err
Expand Down Expand Up @@ -1574,9 +1681,19 @@ func (d *Decoder) decodeSlice(ctx context.Context, dst reflect.Value, src ast.No
if arrayNode == nil {
return nil
}
iter := arrayNode.ArrayRange()

sliceType := dst.Type()
sliceValue := reflect.MakeSlice(sliceType, 0, iter.Len())
v, err := d.initer(ctx, d, arrayNode, sliceType, dst)
if err != nil {
return err
}

sliceValue := reflect.ValueOf(v)
if sliceValue.Type() != sliceType {
return errors.ErrTypeMismatch(sliceType, sliceValue.Type(), src.GetToken())
}

iter := arrayNode.ArrayRange()
elemType := sliceType.Elem()

var foundErr error
Expand Down Expand Up @@ -1711,7 +1828,17 @@ func (d *Decoder) decodeMap(ctx context.Context, dst reflect.Value, src ast.Node
return err
}
mapType := dst.Type()
mapValue := reflect.MakeMap(mapType)

v, err := d.initer(ctx, d, mapNode, mapType, dst)
if err != nil {
return err
}

mapValue := reflect.ValueOf(v)
if mapValue.Type() != mapType {
return errors.ErrTypeMismatch(mapType, mapValue.Type(), src.GetToken())
}

keyType := mapValue.Type().Key()
valueType := mapValue.Type().Elem()
mapIter := mapNode.MapRange()
Expand Down Expand Up @@ -2024,3 +2151,43 @@ func (d *Decoder) DecodeFromNodeContext(ctx context.Context, node ast.Node, v in
}
return nil
}

func (d *Decoder) IsRecursiveDir() bool {
return d.isRecursiveDir
}

func (d *Decoder) IsResolvedReference() bool {
return d.isResolvedReference
}

func (d *Decoder) Validator() StructValidator {
return d.validator
}

func (d *Decoder) DisallowUnknownField() bool {
return d.disallowUnknownField
}

func (d *Decoder) AllowDuplicateMapKey() bool {
return d.allowDuplicateMapKey
}

func (d *Decoder) UseOrderedMap() bool {
return d.useOrderedMap
}

func (d *Decoder) UseJSONUnmarshaler() bool {
return d.useJSONUnmarshaler
}

func (d *Decoder) ParsedFileName() string {
return d.parsedFile.Name
}

func (d *Decoder) StreamIndex() int {
return d.streamIndex
}

func (d *Decoder) DecodeDepth() int {
return d.decodeDepth
}
71 changes: 71 additions & 0 deletions option.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ import (
"reflect"

"github.com/goccy/go-yaml/ast"
"github.com/goccy/go-yaml/internal/errors"
"github.com/goccy/go-yaml/token"
)

// DecodeOption functional option type for Decoder
Expand Down Expand Up @@ -341,3 +343,72 @@ func CommentToMap(cm CommentMap) DecodeOption {
return nil
}
}

// Initializes a value before parsing. This is especially useful when using default values ​​for structures, initializing slices, or the like.
type Initer = func(ctx context.Context, d *Decoder, node interface{}, t reflect.Type, dst reflect.Value) (interface{}, error)

func DefaultInitier(ctx context.Context, d *Decoder, node interface{}, t reflect.Type, dst reflect.Value) (interface{}, error) {
switch t.Kind() {
case reflect.Slice:
var l int
switch n := node.(type) {
case *ast.SequenceNode:
l = len(n.Values)
case ast.ArrayNode:
iter := n.ArrayRange()
if iter != nil {
l = iter.Len()
}

}

return reflect.MakeSlice(t, 0, l).Interface(), nil
case reflect.Map:
if d.UseOrderedMap() {
var l int
if n, ok := node.(*ast.MappingNode); ok {
l = len(n.Values)
}

return make(MapSlice, 0, l), nil
}

return reflect.MakeMap(t).Interface(), nil
case reflect.Struct:
return dst.Interface(), nil
case reflect.Ptr:
el := dst.Elem()
res, err := d.initer(ctx, d, node, el.Type(), el)
if err != nil {
return nil, err
}

vOf := reflect.ValueOf(res)
if vOf.Type() != el.Type() {
var token *token.Token
if n, ok := node.(ast.Node); ok {
token = n.GetToken()
}

return nil, errors.ErrTypeMismatch(el.Type(), vOf.Type(), token)
}

el.Set(vOf)

return dst.Interface(), nil
}

return reflect.New(t).Elem().Interface(), nil
}

func WithIniter(i Initer) DecodeOption {
return func(d *Decoder) error {
if i == nil {
i = DefaultInitier
}

d.initer = i

return nil
}
}