From 668c17e7eaa1e7abf62d6b303fdfa33ac6c20b49 Mon Sep 17 00:00:00 2001 From: Tristan Cartledge Date: Mon, 4 Nov 2024 14:24:53 +1000 Subject: [PATCH] fix: fixed issue with core models not being updated on syncing changes from newly created models --- arazzo/arazzo.go | 10 +- arazzo/arazzo_test.go | 72 +++++++++- arazzo/core/arazzo.go | 5 + arazzo/core/criterion.go | 4 +- arazzo/core/reusable.go | 17 +-- arazzo/criterion/criterion.go | 4 +- arazzo/parameter.go | 2 +- arazzo/testdata/speakeasybar.arazzo.yaml | 10 +- arazzo/testdata/test.arazzo.yaml | 3 +- jsonschema/oas31/core/value.go | 4 +- jsonschema/oas31/core/value_test.go | 2 +- marshaller/node.go | 2 +- marshaller/syncer.go | 172 ++++++++++++++--------- marshaller/syncer_test.go | 126 ++++++++++++++--- marshaller/unmarshaller.go | 6 +- sequencedmap/map.go | 4 +- 16 files changed, 316 insertions(+), 127 deletions(-) diff --git a/arazzo/arazzo.go b/arazzo/arazzo.go index 0672b69..b42d53d 100644 --- a/arazzo/arazzo.go +++ b/arazzo/arazzo.go @@ -123,11 +123,19 @@ func (a *Arazzo) GetCore() *core.Arazzo { return &a.core } +// Sync will sync any changes made to the Arazzo document models back to the core models. +func (a *Arazzo) Sync(ctx context.Context) error { + if _, err := marshaller.SyncValue(ctx, a, &a.core, nil, false); err != nil { + return err + } + return nil +} + // Marshal will marshal the Arazzo document to the provided io.Writer. func (a *Arazzo) Marshal(ctx context.Context, w io.Writer) error { ctx = yml.ContextWithConfig(ctx, a.core.Config) - if _, err := marshaller.SyncValue(ctx, a, &a.core, nil); err != nil { + if _, err := marshaller.SyncValue(ctx, a, &a.core, nil, false); err != nil { return err } diff --git a/arazzo/arazzo_test.go b/arazzo/arazzo_test.go index 3e3c134..a0b1f2a 100644 --- a/arazzo/arazzo_test.go +++ b/arazzo/arazzo_test.go @@ -179,7 +179,7 @@ var testArazzoInstance = &arazzo.Arazzo{ Type: arazzo.FailureActionTypeRetry, RetryAfter: pointer.From(10.0), RetryLimit: pointer.From(3), - Criteria: []criterion.Criterion{{Condition: "$statusCode == 500", Type: criterion.CriterionTypeUnion{ + Criteria: []criterion.Criterion{{Context: pointer.From(expression.Expression("$statusCode")), Condition: "$statusCode == 500", Type: criterion.CriterionTypeUnion{ Type: pointer.From(criterion.CriterionTypeSimple), }}}, Valid: true, @@ -320,6 +320,9 @@ func TestArazzo_Mutate_Success(t *testing.T) { err = arazzo.Marshal(ctx, a, outBuf) require.NoError(t, err) + errs := a.Validate(ctx) + require.Empty(t, errs) + assert.Equal(t, `arazzo: 1.0.0 info: title: My updated workflow title @@ -383,7 +386,8 @@ components: retryAfter: 10 retryLimit: 3 criteria: - - condition: $statusCode == 500 + - context: $statusCode + condition: $statusCode == 500 x-test: some-value `, outBuf.String()) } @@ -391,9 +395,14 @@ x-test: some-value func TestArazzo_Create_Success(t *testing.T) { outBuf := bytes.NewBuffer([]byte{}) - err := arazzo.Marshal(context.Background(), testArazzoInstance, outBuf) + ctx := context.Background() + + err := arazzo.Marshal(ctx, testArazzoInstance, outBuf) require.NoError(t, err) + errs := testArazzoInstance.Validate(ctx) + require.Empty(t, errs) + data, err := os.ReadFile("testdata/test.arazzo.yaml") require.NoError(t, err) @@ -561,15 +570,66 @@ var stressTests = []struct { wantTitle: "", }, { - name: "DevAttila87 Example", + name: "Itarazzo Library Example", args: args{ - location: "https://raw.githubusercontent.com/devAttila87/arazzo/24dd4c896f98b942e61831f3529fe538089baedf/application-integration-test/src/test/resources/arazzo.yaml", + location: "https://raw.githubusercontent.com/leidenheit/itarazzo-library/3b335e1c4293444add52b5f2476420e2d871b1a5/src/test/resources/test.arazzo.yaml", validationIgnores: []string{ - "only one of operationId, operationPath or workflowId can be set", // legit issue + "expression is not valid, must begin with $: 4711Chocolate", // legit issue }, }, wantTitle: "A cookie eating workflow", }, + { + name: "Itarazzo Client Pet Store Example", + args: args{ + location: "https://raw.githubusercontent.com/leidenheit/itarazzo-client/b744ca1ca3a036964ae30be601f10a25b14dc52d/src/test/resources/pet-store.arazzo.yaml", + validationIgnores: []string{ + "jsonpointer must start with /: $.status", // legit issues TODO: improve the error returned as it is wrong + "jsonpointer must start with /: $.id", // legit issues TODO: improve the error returned as it is wrong + }, + }, + wantTitle: "PetStore - Example of Workflows", + }, + { + name: "Ritza build-a-bot workflow", + args: args{ + location: "https://raw.githubusercontent.com/ritza-co/e2e-testing-arazzo/c0615c3708a1e4c0fcaeb79edae78ddc4eb5ba82/arazzo.yaml", + validationIgnores: []string{}, + }, + wantTitle: "Build-a-Bot Workflow", + }, + { + name: " API-Flows adyen-giving workflow", + args: args{ + location: "https://raw.githubusercontent.com/API-Flows/openapi-workflow-registry/3d85d79232fa8f42993b2f5bd47e273b9369dc2d/root/adyen/adyen-giving.yaml", + validationIgnores: []string{ + "in must be one of [path, query, header, cookie] but was body", + }, + }, + wantTitle: "Adyen Giving", + }, + { + name: "API-Flows simple workflow", + args: args{ + location: "https://raw.githubusercontent.com/API-Flows/openapi-workflow-parser/6b28ba4def262969c5a96bc54d08433e6c336643/src/test/resources/1.0.0/simple.workflow.yaml", + validationIgnores: []string{}, + }, + wantTitle: "simple", + }, + // Disabled for now as it is currently failing round tripping due to missing conditions + // { + // name: "Kartikhub swap tokens workflow", + // args: args{ + // location: "https://raw.githubusercontent.com/Kartikhub/web3-basics/d95bc51bb935ef07d627e52c6fdfe18aaea69e18/swap-react/docs/swap-transaction-arazzo.yaml", + // validationIgnores: []string{ // All valid issues + // "field condition is missing", + // "condition is required", + // "field value is missing", + // "expression is not valid, must begin with $", + // }, + // }, + // wantTitle: "Swap Tokens", + // }, } func TestArazzo_StressTests_Validate(t *testing.T) { diff --git a/arazzo/core/arazzo.go b/arazzo/core/arazzo.go index 8c97668..01d6137 100644 --- a/arazzo/core/arazzo.go +++ b/arazzo/core/arazzo.go @@ -2,6 +2,7 @@ package core import ( "context" + "errors" "fmt" "io" @@ -32,6 +33,10 @@ func Unmarshal(ctx context.Context, doc io.Reader) (*Arazzo, error) { return nil, fmt.Errorf("failed to read Arazzo document: %w", err) } + if len(data) == 0 { + return nil, errors.New("empty document") + } + var root yaml.Node if err := yaml.Unmarshal(data, &root); err != nil { return nil, fmt.Errorf("failed to unmarshal Arazzo document: %w", err) diff --git a/arazzo/core/criterion.go b/arazzo/core/criterion.go index 25edb65..e28b3b4 100644 --- a/arazzo/core/criterion.go +++ b/arazzo/core/criterion.go @@ -64,12 +64,12 @@ func (c *CriterionTypeUnion) SyncChanges(ctx context.Context, model any, valueNo tf := mv.FieldByName("Type") ef := mv.FieldByName("ExpressionType") - tv, err := marshaller.SyncValue(ctx, tf.Interface(), &c.Type, valueNode) + tv, err := marshaller.SyncValue(ctx, tf.Interface(), &c.Type, valueNode, false) if err != nil { return nil, err } - ev, err := marshaller.SyncValue(ctx, ef.Interface(), &c.ExpressionType, valueNode) + ev, err := marshaller.SyncValue(ctx, ef.Interface(), &c.ExpressionType, valueNode, false) if err != nil { return nil, err } diff --git a/arazzo/core/reusable.go b/arazzo/core/reusable.go index f7fc311..c828c1d 100644 --- a/arazzo/core/reusable.go +++ b/arazzo/core/reusable.go @@ -58,27 +58,14 @@ func (r *Reusable[T]) SyncChanges(ctx context.Context, model any, valueNode *yam of := mv.FieldByName("Object") if of.IsZero() { - type reusable[T any] struct { - Reference marshaller.Node[*Expression] `key:"reference"` - Value marshaller.Node[Value] `key:"value"` - - RootNode *yaml.Node - } - - rl := reusable[T]{ - Reference: r.Reference, - Value: r.Value, - RootNode: r.RootNode, - } - var err error - valueNode, err = marshaller.SyncValue(ctx, model, &rl, valueNode) + valueNode, err = marshaller.SyncValue(ctx, model, r, valueNode, true) if err != nil { return nil, err } } else { var err error - valueNode, err = marshaller.SyncValue(ctx, of.Interface(), &r.Object, valueNode) + valueNode, err = marshaller.SyncValue(ctx, of.Interface(), &r.Object, valueNode, false) if err != nil { return nil, err } diff --git a/arazzo/criterion/criterion.go b/arazzo/criterion/criterion.go index 7d88668..301e932 100644 --- a/arazzo/criterion/criterion.go +++ b/arazzo/criterion/criterion.go @@ -263,8 +263,8 @@ func (c *Criterion) Validate(opts ...validation.Option) []error { func (c *Criterion) validateCondition(opts ...validation.Option) []error { errs := []error{} - conditionLine := c.core.Condition.GetValueNodeOrRoot(c.core.RootNode).Line - conditionColumn := c.core.Condition.GetValueNodeOrRoot(c.core.RootNode).Column + conditionLine := c.GetCore().Condition.GetValueNodeOrRoot(c.GetCore().RootNode).Line + conditionColumn := c.GetCore().Condition.GetValueNodeOrRoot(c.GetCore().RootNode).Column switch c.Type.GetType() { case CriterionTypeSimple: diff --git a/arazzo/parameter.go b/arazzo/parameter.go index 0eea6fb..228f663 100644 --- a/arazzo/parameter.go +++ b/arazzo/parameter.go @@ -91,7 +91,7 @@ func (p *Parameter) Validate(ctx context.Context, opts ...validation.Option) []e if in != "" { errs = append(errs, &validation.Error{ - Message: fmt.Sprintf("in must be one of [%s]", strings.Join([]string{string(InPath), string(InQuery), string(InHeader), string(InCookie)}, ", ")), + Message: fmt.Sprintf("in must be one of [%s] but was %s", strings.Join([]string{string(InPath), string(InQuery), string(InHeader), string(InCookie)}, ", "), in), Line: p.core.In.GetValueNodeOrRoot(p.core.RootNode).Line, Column: p.core.In.GetValueNodeOrRoot(p.core.RootNode).Column, }) diff --git a/arazzo/testdata/speakeasybar.arazzo.yaml b/arazzo/testdata/speakeasybar.arazzo.yaml index f815a8e..324f7e5 100644 --- a/arazzo/testdata/speakeasybar.arazzo.yaml +++ b/arazzo/testdata/speakeasybar.arazzo.yaml @@ -76,13 +76,13 @@ workflows: operationId: createOrder parameters: - name: orderType - in: body + in: query value: $inputs.orderType - name: productCode - in: body + in: query value: $inputs.productCode - name: quantity - in: body + in: query value: $inputs.quantity outputs: orderNumber: $response.body#/orderNumber @@ -176,9 +176,9 @@ components: value: $steps.authenticate.outputs.token username: name: username - in: body + in: query value: $inputs.username password: name: password - in: body + in: query value: $inputs.password diff --git a/arazzo/testdata/test.arazzo.yaml b/arazzo/testdata/test.arazzo.yaml index caf6c90..9c207df 100755 --- a/arazzo/testdata/test.arazzo.yaml +++ b/arazzo/testdata/test.arazzo.yaml @@ -68,5 +68,6 @@ components: retryAfter: 10 retryLimit: 3 criteria: - - condition: $statusCode == 500%s + - context: $statusCode + condition: $statusCode == 500%s x-test: some-value diff --git a/jsonschema/oas31/core/value.go b/jsonschema/oas31/core/value.go index 36e118c..840e828 100644 --- a/jsonschema/oas31/core/value.go +++ b/jsonschema/oas31/core/value.go @@ -53,12 +53,12 @@ func (v *EitherValue[L, R]) SyncChanges(ctx context.Context, model any, valueNod lf := mv.FieldByName("Left") rf := mv.FieldByName("Right") - lv, err := marshaller.SyncValue(ctx, lf.Interface(), &v.Left, valueNode) + lv, err := marshaller.SyncValue(ctx, lf.Interface(), &v.Left, valueNode, false) if err != nil { return nil, err } - rv, err := marshaller.SyncValue(ctx, rf.Interface(), &v.Right, valueNode) + rv, err := marshaller.SyncValue(ctx, rf.Interface(), &v.Right, valueNode, false) if err != nil { return nil, err } diff --git a/jsonschema/oas31/core/value_test.go b/jsonschema/oas31/core/value_test.go index 411c12a..6a7a5c7 100644 --- a/jsonschema/oas31/core/value_test.go +++ b/jsonschema/oas31/core/value_test.go @@ -23,7 +23,7 @@ func TestEitherValue_SyncChanges_Success(t *testing.T) { Left: pointer.From("some-value"), } var target EitherValue[string, string] - outNode, err := marshaller.SyncValue(ctx, source, &target, nil) + outNode, err := marshaller.SyncValue(ctx, source, &target, nil, false) require.NoError(t, err) assert.Equal(t, testutils.CreateStringYamlNode("some-value", 0, 0), outNode) assert.Equal(t, "some-value", *target.Left) diff --git a/marshaller/node.go b/marshaller/node.go index 4870eb7..a52483a 100644 --- a/marshaller/node.go +++ b/marshaller/node.go @@ -51,7 +51,7 @@ func (n Node[V]) GetValueType() reflect.Type { func (n *Node[V]) SyncValue(ctx context.Context, key string, value any) (*yaml.Node, *yaml.Node, error) { n.Key = key n.KeyNode = yml.CreateOrUpdateKeyNode(ctx, key, n.KeyNode) - valueNode, err := SyncValue(ctx, value, &n.Value, n.ValueNode) + valueNode, err := SyncValue(ctx, value, &n.Value, n.ValueNode, false) if err != nil { return nil, nil, err } diff --git a/marshaller/syncer.go b/marshaller/syncer.go index ddd97a9..5469dc6 100644 --- a/marshaller/syncer.go +++ b/marshaller/syncer.go @@ -4,6 +4,7 @@ import ( "context" "fmt" "reflect" + "unsafe" "github.com/speakeasy-api/openapi/yml" "gopkg.in/yaml.v3" @@ -14,75 +15,71 @@ type Syncer interface { } type SyncerWithSyncFunc interface { - SyncChangesWithSyncFunc(ctx context.Context, model any, valueNode *yaml.Node, syncFunc func(context.Context, any, any, *yaml.Node) (*yaml.Node, error)) (*yaml.Node, error) + SyncChangesWithSyncFunc(ctx context.Context, model any, valueNode *yaml.Node, syncFunc func(context.Context, any, any, *yaml.Node, bool) (*yaml.Node, error)) (*yaml.Node, error) } -func SyncValue(ctx context.Context, source any, target any, valueNode *yaml.Node) (node *yaml.Node, err error) { +func SyncValue(ctx context.Context, source any, target any, valueNode *yaml.Node, skipCustomSyncer bool) (node *yaml.Node, err error) { s := reflect.ValueOf(source) - st := reflect.TypeOf(source) t := reflect.ValueOf(target) - tt := reflect.TypeOf(target) - if s.Kind() == reflect.Ptr { - if s.IsNil() { - t.Elem().Set(reflect.Zero(t.Type().Elem())) - return nil, nil - } - - s, st = fullyDereference(s, st) + if t.Kind() != reflect.Ptr { + return nil, fmt.Errorf("SyncValue expected target to be a pointer, got %s", t.Kind()) } - t, tt = dereferenceToLastPtr(t, tt) + s = dereferenceToLastPtr(s) + t = dereferenceAndInitializeIfNeededToLastPtr(t, reflect.ValueOf(source)) - if tt.Kind() != reflect.Ptr { - return nil, fmt.Errorf("SyncValue expected pointer, got %s", tt) + if s.Kind() == reflect.Ptr && s.IsNil() { + if !t.IsZero() { + t.Elem().Set(reflect.Zero(t.Type().Elem())) + } + return nil, nil } - if tt.Elem().Kind() != st.Kind() { - return nil, fmt.Errorf("SyncValue expected target to be %s, got %s", st.Kind(), tt.Elem().Kind()) + sUnderlying := getUnderlyingValue(s) + tUnderlyingType := dereferenceType(t.Type()) + + if sUnderlying.Kind() != tUnderlyingType.Kind() { + return nil, fmt.Errorf("SyncValue expected target to be %s, got %s", sUnderlying.Kind(), tUnderlyingType.Kind()) } switch { - case s.Kind() == reflect.Struct && t.Type() == reflect.TypeOf((*yaml.Node)(nil)): - t.Elem().Set(s) + case sUnderlying.Kind() == reflect.Struct && t.Type() == reflect.TypeOf((*yaml.Node)(nil)): + t.Set(s) return t.Interface().(*yaml.Node), nil - case s.Kind() == reflect.Struct: - syncer, ok := t.Interface().(Syncer) - if ok { - sv := s.Interface() - if s.CanAddr() { - sv = s.Addr().Interface() + case sUnderlying.Kind() == reflect.Struct: + if !skipCustomSyncer { + syncer, ok := t.Interface().(Syncer) + if ok { + sv := s.Interface() + + return syncer.SyncChanges(ctx, sv, valueNode) } - return syncer.SyncChanges(ctx, sv, valueNode) - } + syncerWithSyncFunc, ok := t.Interface().(SyncerWithSyncFunc) + if ok { + sv := s.Interface() - syncerWithSyncFunc, ok := t.Interface().(SyncerWithSyncFunc) - if ok { - sv := s.Interface() - if s.CanAddr() { - sv = s.Addr().Interface() + return syncerWithSyncFunc.SyncChangesWithSyncFunc(ctx, sv, valueNode, SyncValue) } - - return syncerWithSyncFunc.SyncChangesWithSyncFunc(ctx, sv, valueNode, SyncValue) } return syncChanges(ctx, s.Interface(), t.Interface(), valueNode) - case s.Kind() == reflect.Map: + case sUnderlying.Kind() == reflect.Map: // TODO call sync changes on each value panic("not implemented") - case s.Kind() == reflect.Slice, s.Kind() == reflect.Array: - return syncArraySlice(ctx, s.Interface(), t.Interface(), valueNode) + case sUnderlying.Kind() == reflect.Slice, sUnderlying.Kind() == reflect.Array: + return syncArraySlice(ctx, sUnderlying.Interface(), t.Interface(), valueNode) default: - if st != tt.Elem() { + if sUnderlying.Type() != tUnderlyingType { // Cast the value to the target type - s = s.Convert(tt.Elem()) + sUnderlying = sUnderlying.Convert(tUnderlyingType) } if !t.Elem().IsValid() { - t.Set(reflect.New(tt.Elem())) + t.Set(reflect.New(tUnderlyingType)) } - t.Elem().Set(s) - out := yml.CreateOrUpdateScalarNode(ctx, s.Interface(), valueNode) + t.Elem().Set(sUnderlying) + out := yml.CreateOrUpdateScalarNode(ctx, sUnderlying.Interface(), valueNode) return out, nil } } @@ -91,20 +88,26 @@ func syncChanges(ctx context.Context, source any, target any, valueNode *yaml.No s := reflect.ValueOf(source) t := reflect.ValueOf(target) - if s.Kind() == reflect.Ptr { - if s.IsNil() { - panic("not implemented") - } - s = s.Elem() + if s.Kind() != reflect.Ptr { + return nil, fmt.Errorf("syncChanges expected source to be a pointer, got %s", s.Kind()) } - if t.Kind() == reflect.Ptr && t.IsNil() { - t.Set(reflect.New(t.Type().Elem())) + + if t.Kind() != reflect.Ptr { + return nil, fmt.Errorf("syncChanges expected target to be a pointer, got %s", t.Kind()) } - if t.Kind() == reflect.Ptr { - t = t.Elem() + + if s.IsNil() { + panic("not implemented") } - if s.Kind() != reflect.Struct { + if t.IsNil() { + t.Set(reflect.New(t.Elem().Type())) + } + + sUnderlying := getUnderlyingValue(s) + t = getUnderlyingValue(t) + + if sUnderlying.Kind() != reflect.Struct { return nil, fmt.Errorf("syncChanges expected struct, got %s", s.Type()) } @@ -113,7 +116,7 @@ func syncChanges(ctx context.Context, source any, target any, valueNode *yaml.No if !field.IsExported() { continue } - sourceVal := s.FieldByName(field.Name) + sourceVal := sUnderlying.FieldByName(field.Name) key := field.Tag.Get("key") if key == "" { @@ -144,7 +147,7 @@ func syncChanges(ctx context.Context, source any, target any, valueNode *yaml.No } targetInt := target.Interface() - sourceInt := sourceVal.Interface() + sourceInt := sourceVal.Addr().Interface() nodeMutator, ok := targetInt.(NodeMutator) if !ok { @@ -163,6 +166,7 @@ func syncChanges(ctx context.Context, source any, target any, valueNode *yaml.No } } + // Populate the RootNode of the target with the result rn, ok := t.Type().FieldByName("RootNode") if !ok { return nil, fmt.Errorf("SyncChanges expected a RootNode field on the target %s", t.Type()) @@ -170,6 +174,13 @@ func syncChanges(ctx context.Context, source any, target any, valueNode *yaml.No t.FieldByIndex(rn.Index).Set(reflect.ValueOf(valueNode)) + // Update the core of the source with the updated value + cf, ok := sUnderlying.Type().FieldByName("core") + if ok { + sf := sUnderlying.FieldByIndex(cf.Index) + reflect.NewAt(sf.Type(), unsafe.Pointer(sf.UnsafeAddr())).Elem().Set(t) + } + return valueNode, nil } @@ -233,7 +244,15 @@ func syncArraySlice(ctx context.Context, source any, target any, valueNode *yaml } var err error - currentElementNode, err = SyncValue(ctx, sourceVal.Index(i).Interface(), targetVal.Index(i).Addr().Interface(), currentElementNode) + + var sourceValAtIdx any + if sourceVal.Index(i).CanAddr() { + sourceValAtIdx = sourceVal.Index(i).Addr().Interface() + } else { + sourceValAtIdx = sourceVal.Index(i).Interface() + } + + currentElementNode, err = SyncValue(ctx, sourceValAtIdx, targetVal.Index(i).Addr().Interface(), currentElementNode, false) if err != nil { return nil, err } @@ -248,22 +267,45 @@ func syncArraySlice(ctx context.Context, source any, target any, valueNode *yaml return yml.CreateOrUpdateSliceNode(ctx, elements, valueNode), nil } -func fullyDereference(val reflect.Value, typ reflect.Type) (reflect.Value, reflect.Type) { - if typ.Kind() == reflect.Ptr { - return fullyDereference(val.Elem(), typ.Elem()) +// will dereference the last ptr in the type while initializing any higher level pointers +func dereferenceAndInitializeIfNeededToLastPtr(val reflect.Value, source reflect.Value) reflect.Value { + if val.Kind() == reflect.Ptr && val.IsNil() { + if (source.Kind() == reflect.Ptr && !source.IsNil()) || (source.Kind() != reflect.Ptr && source.IsValid()) { + val.Set(reflect.New(val.Type().Elem())) + } } + if val.Kind() == reflect.Ptr && val.Elem().Kind() == reflect.Ptr { + sourceVal := source + if sourceVal.Kind() == reflect.Ptr { + sourceVal = sourceVal.Elem() + } - return val, typ + return dereferenceAndInitializeIfNeededToLastPtr(val.Elem(), sourceVal) + } + + return val } -// will dereference the last ptr in the type while initializing any higher level pointers -func dereferenceToLastPtr(val reflect.Value, typ reflect.Type) (reflect.Value, reflect.Type) { - if typ.Kind() == reflect.Ptr && val.IsNil() { - val.Set(reflect.New(typ.Elem())) +// will dereference the last ptr in the type +func dereferenceToLastPtr(val reflect.Value) reflect.Value { + if val.Kind() == reflect.Ptr && val.Elem().Kind() == reflect.Ptr { + return dereferenceToLastPtr(val.Elem()) + } + + return val +} + +func getUnderlyingValue(v reflect.Value) reflect.Value { + for v.Kind() == reflect.Ptr || v.Kind() == reflect.Interface { + v = v.Elem() } - if typ.Kind() == reflect.Ptr && typ.Elem().Kind() == reflect.Ptr { - return dereferenceToLastPtr(val.Elem(), typ.Elem()) + return v +} + +func dereferenceType(typ reflect.Type) reflect.Type { + for typ.Kind() == reflect.Ptr { + return dereferenceType(typ.Elem()) } - return val, typ + return typ } diff --git a/marshaller/syncer_test.go b/marshaller/syncer_test.go index 30e4c05..5387afa 100644 --- a/marshaller/syncer_test.go +++ b/marshaller/syncer_test.go @@ -15,7 +15,7 @@ import ( func TestSyncValue_String(t *testing.T) { target := "" - outNode, err := SyncValue(context.Background(), "some-value", &target, nil) + outNode, err := SyncValue(context.Background(), "some-value", &target, nil, false) require.NoError(t, err) assert.Equal(t, testutils.CreateStringYamlNode("some-value", 0, 0), outNode) assert.Equal(t, "some-value", target) @@ -23,7 +23,7 @@ func TestSyncValue_String(t *testing.T) { func TestSyncValue_StringPtrSet(t *testing.T) { target := pointer.From("") - outNode, err := SyncValue(context.Background(), pointer.From("some-value"), &target, nil) + outNode, err := SyncValue(context.Background(), pointer.From("some-value"), &target, nil, false) require.NoError(t, err) assert.Equal(t, testutils.CreateStringYamlNode("some-value", 0, 0), outNode) assert.Equal(t, "some-value", *target) @@ -31,7 +31,7 @@ func TestSyncValue_StringPtrSet(t *testing.T) { func TestSyncValue_StringPtrNil(t *testing.T) { var target *string - outNode, err := SyncValue(context.Background(), pointer.From("some-value"), &target, nil) + outNode, err := SyncValue(context.Background(), pointer.From("some-value"), &target, nil, false) require.NoError(t, err) assert.Equal(t, testutils.CreateStringYamlNode("some-value", 0, 0), outNode) assert.Equal(t, "some-value", *target) @@ -57,7 +57,7 @@ func (t *TestStructSyncerCore[T]) SyncChanges(ctx context.Context, model any, va } var err error - t.RootNode, err = SyncValue(ctx, mv.FieldByName("Val").Interface(), &t.Val, valueNode) + t.RootNode, err = SyncValue(ctx, mv.FieldByName("Val").Interface(), &t.Val, valueNode, false) return t.RootNode, err } @@ -66,7 +66,7 @@ func TestSyncValue_StructPtr_CustomSyncer(t *testing.T) { source := &TestStructSyncer[int]{Val: pointer.From(1)} - outNode, err := SyncValue(context.Background(), source, &target, nil) + outNode, err := SyncValue(context.Background(), source, &target, nil, false) require.NoError(t, err) node := testutils.CreateIntYamlNode(1, 0, 0) assert.Equal(t, node, outNode) @@ -79,7 +79,7 @@ func TestSyncValue_Struct_CustomSyncer(t *testing.T) { source := TestStructSyncer[int]{Val: pointer.From(1)} - outNode, err := SyncValue(context.Background(), source, &target, nil) + outNode, err := SyncValue(context.Background(), source, &target, nil, false) require.NoError(t, err) node := testutils.CreateIntYamlNode(1, 0, 0) assert.Equal(t, node, outNode) @@ -91,6 +91,8 @@ type TestStruct struct { Str string StrPtr *string BoolPtr *bool + + core TestStructCore } type TestStructCore struct { @@ -103,8 +105,6 @@ type TestStructCore struct { } func TestSyncChanges_Struct(t *testing.T) { - var target TestStructCore - source := TestStruct{ Int: 1, Str: "some-string", @@ -112,7 +112,7 @@ func TestSyncChanges_Struct(t *testing.T) { BoolPtr: pointer.From(true), } - outNode, err := SyncValue(context.Background(), source, &target, nil) + outNode, err := SyncValue(context.Background(), &source, &source.core, nil, false) require.NoError(t, err) node := testutils.CreateMapYamlNode([]*yaml.Node{ @@ -127,16 +127,38 @@ func TestSyncChanges_Struct(t *testing.T) { }, 0, 0) assert.Equal(t, node, outNode) - assert.Equal(t, node, target.RootNode) - assert.Equal(t, 1, target.Int.Value) - assert.Equal(t, "some-string", target.Str.Value) - assert.Equal(t, "some-string-ptr", *target.StrPtr.Value) - assert.Equal(t, true, *target.BoolPtr.Value) + assert.Equal(t, node, source.core.RootNode) + assert.Equal(t, 1, source.core.Int.Value) + assert.Equal(t, "some-string", source.core.Str.Value) + assert.Equal(t, "some-string-ptr", *source.core.StrPtr.Value) + assert.Equal(t, true, *source.core.BoolPtr.Value) } -func TestSyncChanges_StructPtr(t *testing.T) { - var target *TestStructCore +func TestSyncChanges_StructWithOptionalsUnset(t *testing.T) { + source := TestStruct{ + Int: 1, + Str: "some-string", + } + + outNode, err := SyncValue(context.Background(), &source, &source.core, nil, false) + require.NoError(t, err) + + node := testutils.CreateMapYamlNode([]*yaml.Node{ + testutils.CreateStringYamlNode("int", 0, 0), + testutils.CreateIntYamlNode(1, 0, 0), + testutils.CreateStringYamlNode("str", 0, 0), + testutils.CreateStringYamlNode("some-string", 0, 0), + }, 0, 0) + assert.Equal(t, node, outNode) + assert.Equal(t, node, source.core.RootNode) + assert.Equal(t, 1, source.core.Int.Value) + assert.Equal(t, "some-string", source.core.Str.Value) + assert.Nil(t, source.core.StrPtr.Value) + assert.Nil(t, source.core.BoolPtr.Value) +} + +func TestSyncChanges_StructPtr(t *testing.T) { source := &TestStruct{ Int: 1, Str: "some-string", @@ -144,7 +166,7 @@ func TestSyncChanges_StructPtr(t *testing.T) { BoolPtr: pointer.From(true), } - outNode, err := SyncValue(context.Background(), source, &target, nil) + outNode, err := SyncValue(context.Background(), &source, &source.core, nil, false) require.NoError(t, err) node := testutils.CreateMapYamlNode([]*yaml.Node{ @@ -159,9 +181,69 @@ func TestSyncChanges_StructPtr(t *testing.T) { }, 0, 0) assert.Equal(t, node, outNode) - assert.Equal(t, node, target.RootNode) - assert.Equal(t, 1, target.Int.Value) - assert.Equal(t, "some-string", target.Str.Value) - assert.Equal(t, "some-string-ptr", *target.StrPtr.Value) - assert.Equal(t, true, *target.BoolPtr.Value) + assert.Equal(t, node, source.core.RootNode) + assert.Equal(t, 1, source.core.Int.Value) + assert.Equal(t, "some-string", source.core.Str.Value) + assert.Equal(t, "some-string-ptr", *source.core.StrPtr.Value) + assert.Equal(t, true, *source.core.BoolPtr.Value) +} + +type TestStructNested struct { + TestStruct TestStruct + + core TestStructNestedCore +} + +type TestStructNestedCore struct { + TestStruct Node[TestStructCore] `key:"testStruct"` + + RootNode *yaml.Node +} + +func TestSyncChanges_NestedStruct(t *testing.T) { + source := TestStructNested{ + TestStruct: TestStruct{ + Int: 1, + Str: "some-string", + StrPtr: pointer.From("some-string-ptr"), + BoolPtr: pointer.From(true), + }, + } + + outNode, err := SyncValue(context.Background(), &source, &source.core, nil, false) + require.NoError(t, err) + + nestedNode := testutils.CreateMapYamlNode([]*yaml.Node{ + testutils.CreateStringYamlNode("int", 0, 0), + testutils.CreateIntYamlNode(1, 0, 0), + testutils.CreateStringYamlNode("str", 0, 0), + testutils.CreateStringYamlNode("some-string", 0, 0), + testutils.CreateStringYamlNode("strPtr", 0, 0), + testutils.CreateStringYamlNode("some-string-ptr", 0, 0), + testutils.CreateStringYamlNode("boolPtr", 0, 0), + testutils.CreateBoolYamlNode(true, 0, 0), + }, 0, 0) + + node := testutils.CreateMapYamlNode([]*yaml.Node{ + testutils.CreateStringYamlNode("testStruct", 0, 0), + nestedNode, + }, 0, 0) + + assert.Equal(t, node, outNode) + assert.Equal(t, node, source.core.RootNode) + assert.Equal(t, nestedNode, source.TestStruct.core.RootNode) + assert.Equal(t, 1, source.core.TestStruct.Value.Int.Value) + assert.Equal(t, "some-string", source.core.TestStruct.Value.Str.Value) + assert.Equal(t, "some-string-ptr", *source.core.TestStruct.Value.StrPtr.Value) + assert.Equal(t, true, *source.core.TestStruct.Value.BoolPtr.Value) +} + +type TestInt int + +func TestSyncValue_TypeDefinition(t *testing.T) { + var target TestInt + outNode, err := SyncValue(context.Background(), 1, &target, nil, false) + require.NoError(t, err) + assert.Equal(t, testutils.CreateIntYamlNode(1, 0, 0), outNode) + assert.Equal(t, TestInt(1), target) } diff --git a/marshaller/unmarshaller.go b/marshaller/unmarshaller.go index d6ffbfa..8dddc1f 100644 --- a/marshaller/unmarshaller.go +++ b/marshaller/unmarshaller.go @@ -149,9 +149,13 @@ func UnmarshalStruct(ctx context.Context, node *yaml.Node, structPtr any) error } func unmarshal(ctx context.Context, node *yaml.Node, out reflect.Value) error { - if out.Type() == reflect.TypeOf((*yaml.Node)(nil)) { + switch { + case out.Type() == reflect.TypeOf((*yaml.Node)(nil)): out.Set(reflect.ValueOf(node)) return nil + case out.Type() == reflect.TypeOf(yaml.Node{}): + out.Set(reflect.ValueOf(*node)) + return nil } if isUnmarshallable(out) { diff --git a/sequencedmap/map.go b/sequencedmap/map.go index 65f4fcc..44c5a85 100644 --- a/sequencedmap/map.go +++ b/sequencedmap/map.go @@ -338,7 +338,7 @@ type mapGetter interface { AllUntyped() iter.Seq2[any, any] } -func (m *Map[K, V]) SyncChangesWithSyncFunc(ctx context.Context, model any, valueNode *yaml.Node, syncFunc func(context.Context, any, any, *yaml.Node) (*yaml.Node, error)) (*yaml.Node, error) { +func (m *Map[K, V]) SyncChangesWithSyncFunc(ctx context.Context, model any, valueNode *yaml.Node, syncFunc func(context.Context, any, any, *yaml.Node, bool) (*yaml.Node, error)) (*yaml.Node, error) { m.Init() mg, ok := (model).(mapGetter) @@ -357,7 +357,7 @@ func (m *Map[K, V]) SyncChangesWithSyncFunc(ctx context.Context, model any, valu kn, vn, _ := yml.GetMapElementNodes(ctx, valueNode, keyStr) - vn, err := syncFunc(ctx, v, &lv, vn) + vn, err := syncFunc(ctx, v, &lv, vn, false) if err != nil { return nil, err }