Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,11 @@ If you're looking to integrate Cedar into a production system, please be sure th

## Change log

### 1.4.0
#### New Features
- Add an `UnmarshalCedar` method on `types.EntityUID`s
- Implement the `encoding.BinaryMarshaler` and `encoding.BinaryUnmarshaler` interfaces for `types.EntityUID`s

### 1.2.5
#### New Features
- Adds experimental support for parsing Cedar schema in both Cedar and JSON formats
Expand Down
42 changes: 42 additions & 0 deletions types/entity_uid.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,13 @@ package types

import (
"encoding/json"
"errors"
"hash/fnv"
"strconv"
"strings"

"github.com/cedar-policy/cedar-go/internal/mapset"
"github.com/cedar-policy/cedar-go/internal/rust"
)

// Path is a series of idents separated by ::
Expand Down Expand Up @@ -46,6 +49,37 @@ func (e EntityUID) MarshalCedar() []byte {
return []byte(e.String())
}

var errInvalidUID = errors.New("invalid EntityUID")

// UnmarshalCedar parses a Cedar language representation of an EntityUID.
func (e *EntityUID) UnmarshalCedar(data []byte) error {
// NB: In a perfect world we'd use the full parsing from internal/parser, but
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for writing this up. It's good to document our warts.

// today that imports cedar-go/types (this pkg) which means we'd need to carve
// it out to reuse it. Given that NewEntityUID(.,.) does zero validation
// itself, the juice is not worth the squeeze today.
s := string(data)
idx := strings.Index(s, "::\"")
if idx <= 0 {
// If idx == 0, the entity has no type, which is invalid.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

tiny nit, == should be <=

Alternatively, just remove the comment :)

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh, I actually was only referring to the == 0 case for this comment. If we do find the ::" string but it's the first thing then there is no entity type which is also invalid. (There a test that exercises this scenario)

return errInvalidUID
}

typ := EntityType(s[:idx])
quoted := s[idx+2:] // include the leading `"`

if len(quoted) < 2 || quoted[0] != '"' || quoted[len(quoted)-1] != '"' {
return errInvalidUID
}

id, _, err := rust.Unquote([]byte(quoted[1:len(quoted)-1]), false)
if err != nil {
return errInvalidUID
}

*e = NewEntityUID(typ, String(id))
return nil
}

func (e *EntityUID) UnmarshalJSON(b []byte) error {
// TODO: review after adding support for schemas
var res entityValueJSON
Expand Down Expand Up @@ -74,6 +108,14 @@ func (e EntityUID) MarshalJSON() ([]byte, error) {
})
}

func (e EntityUID) MarshalBinary() ([]byte, error) {
return e.MarshalCedar(), nil
}

func (e *EntityUID) UnmarshalBinary(data []byte) error {
return e.UnmarshalCedar(data)
}

func (e EntityUID) hash() uint64 {
h := fnv.New64()
_, _ = h.Write([]byte(e.Type))
Expand Down
85 changes: 83 additions & 2 deletions types/entity_uid_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package types_test
import (
"testing"

"github.com/cedar-policy/cedar-go"
"github.com/cedar-policy/cedar-go/internal/testutil"
"github.com/cedar-policy/cedar-go/types"
)
Expand All @@ -26,9 +27,89 @@ func TestEntity(t *testing.T) {
testutil.Equals(t, types.EntityUID{Type: "namespace::type", ID: "id"}.String(), `namespace::type::"id"`)
})

t.Run("MarshalCedar", func(t *testing.T) {
t.Run("Marshal EntityUID round trip", func(t *testing.T) {
t.Parallel()
testutil.Equals(t, string(types.EntityUID{"type", "id"}.MarshalCedar()), `type::"id"`)
testCases := []struct {
typ, id, bin string
}{
{"namespace::type", "id", `namespace::type::"id"`},
{"namespace::type", "", `namespace::type::""`},
{"X::Y", "abc::", `X::Y::"abc::"`},
{"Search::Algorithm", "A*", `Search::Algorithm::"A*"`},
{"Super", "*", `Super::"*"`},
}
marshalFuncs := []struct {
name string
marshal func(types.EntityUID) ([]byte, error)
unmarshal func([]byte, *types.EntityUID) error
}{
{
name: "MarshalBinary",
marshal: func(uid types.EntityUID) ([]byte, error) { return uid.MarshalBinary() },
unmarshal: func(bin []byte, uid *types.EntityUID) error {
return uid.UnmarshalBinary(bin)
},
},
{
name: "MarshalCedar",
marshal: func(uid types.EntityUID) ([]byte, error) { return uid.MarshalCedar(), nil },
unmarshal: func(bin []byte, uid *types.EntityUID) error {
return (uid).UnmarshalCedar(bin)
},
},
}

for _, marshalFunc := range marshalFuncs {
t.Run(marshalFunc.name, func(t *testing.T) {
t.Parallel()
for _, testCase := range testCases {
t.Run(testCase.bin, func(t *testing.T) {
t.Parallel()
uid := types.NewEntityUID(cedar.EntityType(testCase.typ), cedar.String(testCase.id))
gotBin, err := marshalFunc.marshal(uid)
testutil.OK(t, err)

wantBin := []byte(testCase.bin)
testutil.Equals(t, gotBin, wantBin)

want := types.NewEntityUID(cedar.EntityType(testCase.typ), cedar.String(testCase.id))
got := types.EntityUID{}
err = marshalFunc.unmarshal(gotBin, &got)
testutil.OK(t, err)
testutil.Equals(t, got.String(), want.String())
testutil.FatalIf(t, !uid.Equal(got), "expected %v to not equal %v", got, want)
})
}
})
}
})

t.Run("UnmarshalCedar invalid", func(t *testing.T) {
t.Parallel()

tests := []struct {
name string
input string
}{
{"unquoted string", `Type::id`},
{"missing double colon", `Type"id"`},
{"missing a quote at beginning", `Type::id"`},
{"missing a quote at end", `Type::"id`},
{"empty input", ``},
{"just quoted string", `"id"`},
{"no type", `::"id"`},
{"partial", `Type::"`},
{"unescaped unicode", `Type::"\u00ab"`},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
var got types.EntityUID
err := got.UnmarshalCedar([]byte(tt.input))
testutil.FatalIf(t, err == nil, "expected error for input %q, got nil", tt.input)
})
}
})
}

Expand Down