Skip to content

Commit

Permalink
test: Creating v1 yaml tests from existing v0 tests (open-policy-agen…
Browse files Browse the repository at this point in the history
…t#6924)

Fixes: open-policy-agent#6864

Signed-off-by: Johan Fylling <[email protected]>
  • Loading branch information
johanfylling authored Aug 20, 2024
1 parent 25f4cb6 commit 3e7e6a0
Show file tree
Hide file tree
Showing 2,255 changed files with 40,019 additions and 78 deletions.
13 changes: 13 additions & 0 deletions ast/parser.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,19 @@ func (v RegoVersion) Int() int {
return 0
}

func (v RegoVersion) String() string {
switch v {
case RegoV0:
return "v0"
case RegoV1:
return "v1"
case RegoV0CompatV1:
return "v0v1"
default:
return "unknown"
}
}

func RegoVersionFromInt(i int) RegoVersion {
if i == 1 {
return RegoV1
Expand Down
17 changes: 12 additions & 5 deletions format/format.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,9 @@ type Opts struct {

// RegoVersion is the version of Rego to format code for.
RegoVersion ast.RegoVersion

// ParserOptions is the parser options used when parsing the module to be formatted.
ParserOptions *ast.ParserOptions
}

// defaultLocationFile is the file name used in `Ast()` for terms
Expand All @@ -43,11 +46,15 @@ func Source(filename string, src []byte) ([]byte, error) {
}

func SourceWithOpts(filename string, src []byte, opts Opts) ([]byte, error) {
parserOpts := ast.ParserOptions{}
if opts.RegoVersion == ast.RegoV1 {
// If the rego version is V1, wee need to parse it as such, to allow for future keywords not being imported.
// Otherwise, we'll default to RegoV0
parserOpts.RegoVersion = ast.RegoV1
var parserOpts ast.ParserOptions
if opts.ParserOptions != nil {
parserOpts = *opts.ParserOptions
} else {
if opts.RegoVersion == ast.RegoV1 {
// If the rego version is V1, we need to parse it as such, to allow for future keywords not being imported.
// Otherwise, we'll default to RegoV0
parserOpts.RegoVersion = ast.RegoV1
}
}

module, err := ast.ParseModuleWithOpts(filename, string(src), parserOpts)
Expand Down
101 changes: 54 additions & 47 deletions internal/wasm/sdk/test/e2e/external_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ import (

const opaRootDir = "../../../../../"

var caseDir = flag.String("case-dir", filepath.Join(opaRootDir, "test/cases/testdata"), "set directory to load test cases from")
var caseDir = flag.String("case-dir", filepath.Join(opaRootDir, "test/cases/testdata/"), "set directory to load test cases from")
var exceptionsFile = flag.String("exceptions", "./exceptions.yaml", "set file to load a list of test names to exclude")

var exceptions map[string]string
Expand Down Expand Up @@ -57,52 +57,59 @@ func TestWasmE2E(t *testing.T) {

ctx := context.Background()

for _, tc := range cases.MustLoad(*caseDir).Sorted().Cases {
name := fmt.Sprintf("%s/%s", strings.TrimPrefix(tc.Filename, opaRootDir), tc.Note)
t.Run(name, func(t *testing.T) {

if shouldSkip(t, tc) {
t.SkipNow()
}

for k, v := range tc.Env {
t.Setenv(k, v)
}

opts := []func(*rego.Rego){
rego.Query(tc.Query),
}
for i := range tc.Modules {
opts = append(opts, rego.Module(fmt.Sprintf("module-%d.rego", i), tc.Modules[i]))
}
if testing.Verbose() {
opts = append(opts, rego.Dump(os.Stderr))
}
cr, err := rego.New(opts...).Compile(ctx)
if err != nil {
t.Fatal(err)
}
o := opa.New().WithPolicyBytes(cr.Bytes)
if tc.Data != nil {
o = o.WithDataJSON(tc.Data)
}
o, err = o.Init()
if err != nil {
t.Fatal(err)
}

var input *interface{}

if tc.InputTerm != nil {
var x interface{} = ast.MustParseTerm(*tc.InputTerm)
input = &x
} else if tc.Input != nil {
input = tc.Input
}

result, err := o.Eval(ctx, opa.EvalOpts{Input: input})
assert(t, tc, result, err)
})
regoVersions := map[string]ast.RegoVersion{
"v0": ast.RegoV0,
"v1": ast.RegoV1,
}
for versionName, regoVersion := range regoVersions {
for _, tc := range cases.MustLoad(filepath.Join(*caseDir, versionName)).Sorted().Cases {
name := fmt.Sprintf("%s/%s", strings.TrimPrefix(tc.Filename, opaRootDir), tc.Note)
t.Run(name, func(t *testing.T) {

if shouldSkip(t, tc) {
t.SkipNow()
}

for k, v := range tc.Env {
t.Setenv(k, v)
}

opts := []func(*rego.Rego){
rego.Query(tc.Query),
rego.SetRegoVersion(regoVersion),
}
for i := range tc.Modules {
opts = append(opts, rego.Module(fmt.Sprintf("module-%d.rego", i), tc.Modules[i]))
}
if testing.Verbose() {
opts = append(opts, rego.Dump(os.Stderr))
}
cr, err := rego.New(opts...).Compile(ctx)
if err != nil {
t.Fatal(err)
}
o := opa.New().WithPolicyBytes(cr.Bytes)
if tc.Data != nil {
o = o.WithDataJSON(tc.Data)
}
o, err = o.Init()
if err != nil {
t.Fatal(err)
}

var input *interface{}

if tc.InputTerm != nil {
var x interface{} = ast.MustParseTerm(*tc.InputTerm)
input = &x
} else if tc.Input != nil {
input = tc.Input
}

result, err := o.Eval(ctx, opa.EvalOpts{Input: input})
assert(t, tc, result, err)
})
}
}
}

Expand Down
32 changes: 18 additions & 14 deletions test/cases/cases.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,10 @@ import (
"github.com/open-policy-agent/opa/util"
)

// Create v1 test cases from v0 test cases.
// //go:generate ../../build/gen-run-go.sh internal/fmtcases/main.go testdata/v0 0 testdata/v1 1
//go:generate ../../build/gen-run-go.sh internal/fmtcases/main.go testdata/v1 0 testdata/v1_2 1

// Set represents a collection of test cases.
type Set struct {
Cases []TestCase `json:"cases"`
Expand All @@ -31,20 +35,20 @@ func (s Set) Sorted() Set {

// TestCase represents a single test case.
type TestCase struct {
Filename string `json:"-"` // name of file that case was loaded from
Note string `json:"note"` // globally unique identifier for this test case
Query string `json:"query"` // policy query to execute
Modules []string `json:"modules,omitempty"` // policies to test against
Data *map[string]interface{} `json:"data,omitempty"` // data to test against
Input *interface{} `json:"input,omitempty"` // parsed input data to use
InputTerm *string `json:"input_term,omitempty"` // raw input data (serialized as a string, overrides input)
WantDefined *bool `json:"want_defined,omitempty"` // expect query result to be defined (or not)
WantResult *[]map[string]interface{} `json:"want_result,omitempty"` // expect query result (overrides defined)
WantErrorCode *string `json:"want_error_code,omitempty"` // expect query error code (overrides result)
WantError *string `json:"want_error,omitempty"` // expect query error message (overrides error code)
SortBindings bool `json:"sort_bindings,omitempty"` // indicates that binding values should be treated as sets
StrictError bool `json:"strict_error,omitempty"` // indicates that the error depends on strict builtin error mode
Env map[string]string `json:"env,omitempty"` // environment variables to be set during the test
Filename string `json:"-" yaml:"-"` // name of file that case was loaded from
Note string `json:"note" yaml:"note"` // globally unique identifier for this test case
Query string `json:"query" yaml:"query"` // policy query to execute
Modules []string `json:"modules,omitempty" yaml:"modules,omitempty"` // policies to test against
Data *map[string]interface{} `json:"data,omitempty" yaml:"data,omitempty"` // data to test against
Input *interface{} `json:"input,omitempty" yaml:"input,omitempty"` // parsed input data to use
InputTerm *string `json:"input_term,omitempty" yaml:"input_term,omitempty"` // raw input data (serialized as a string, overrides input)
WantDefined *bool `json:"want_defined,omitempty" yaml:"want_defined,omitempty"` // expect query result to be defined (or not)
WantResult *[]map[string]interface{} `json:"want_result,omitempty" yaml:"want_result,omitempty"` // expect query result (overrides defined)
WantErrorCode *string `json:"want_error_code,omitempty" yaml:"want_error_code,omitempty"` // expect query error code (overrides result)
WantError *string `json:"want_error,omitempty" yaml:"want_error,omitempty"` // expect query error message (overrides error code)
SortBindings bool `json:"sort_bindings,omitempty" yaml:"sort_bindings,omitempty"` // indicates that binding values should be treated as sets
StrictError bool `json:"strict_error,omitempty" yaml:"strict_error,omitempty"` // indicates that the error depends on strict builtin error mode
Env map[string]string `json:"env,omitempty" yaml:"env,omitempty"` // environment variables to be set during the test
}

// Load returns a set of built-in test cases.
Expand Down
124 changes: 124 additions & 0 deletions test/cases/internal/fmtcases/main.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@
// Copyright 2024 The OPA Authors. All rights reserved.
// Use of this source code is governed by an Apache2
// license that can be found in the LICENSE file.

package main

import (
"bytes"
"fmt"
"os"
"path/filepath"
"strconv"

"github.com/open-policy-agent/opa/ast"
"github.com/open-policy-agent/opa/format"
"github.com/open-policy-agent/opa/test/cases"
"github.com/open-policy-agent/opa/util"
"gopkg.in/yaml.v3"
)

func main() {
if len(os.Args) < 5 {
fmt.Println("Usage: main <source-dir> <source-version> <target-dir> <target-version>")
os.Exit(1)
}

s := os.Args[1]
sv, err := strconv.Atoi(os.Args[2])
if err != nil {
fmt.Println("Version must be an integer")
os.Exit(1)
}
t := os.Args[3]
tv, err := strconv.Atoi(os.Args[4])
if err != nil {
fmt.Println("Version must be an integer")
os.Exit(1)
}
sourceRegoVersion := ast.RegoVersionFromInt(sv)
targetRegoVersion := ast.RegoVersionFromInt(tv)

fmt.Printf("Formatting test cases '%s'->'%s' to rego-version %s\n", s, t, targetRegoVersion)

es, err := os.ReadDir(s)
if err != nil {
fmt.Println("Error reading source directory:", err)
os.Exit(1)
}
for _, e := range es {
if err := copyEntry(s, sourceRegoVersion, e, t, targetRegoVersion); err != nil {
fmt.Println("Error handling source entry:", err)
os.Exit(1)
}
}
}

func copyEntry(sourceRoot string, sourceRegoVersion ast.RegoVersion, e os.DirEntry, targetRoot string, targetRegoVersion ast.RegoVersion) error {
i, err := e.Info()
if err != nil {
return err
}

if i.IsDir() {
err = os.MkdirAll(filepath.Join(targetRoot, e.Name()), i.Mode())
if err != nil {
return err
}
childSourceRoot := filepath.Join(sourceRoot, e.Name())
childTargetRoot := filepath.Join(targetRoot, e.Name())
es, err := os.ReadDir(childSourceRoot)
if err != nil {
return err
}
for _, c := range es {
if err := copyEntry(childSourceRoot, sourceRegoVersion, c, childTargetRoot, targetRegoVersion); err != nil {
return err
}
}
} else {
path := filepath.Join(sourceRoot, i.Name())
bs, err := os.ReadFile(path)
if err != nil {
return err
}

var testCases cases.Set
if err := util.Unmarshal(bs, &testCases); err != nil {
return err
}

// Format test modules
for _, testCase := range testCases.Cases {
for i, module := range testCase.Modules {
bs, err := format.SourceWithOpts(fmt.Sprintf("mod%d.rego", i), []byte(module),
format.Opts{
ParserOptions: &ast.ParserOptions{
RegoVersion: sourceRegoVersion,
},
RegoVersion: targetRegoVersion,
})
if err != nil {
fmt.Printf("Error formatting module %s %s:%d: %v\n", path, testCase.Note, i, err)
} else {
testCase.Modules[i] = string(bs)
}
}
}

// Write formatted test cases to target directory
targetPath := filepath.Join(targetRoot, i.Name())
var buf bytes.Buffer
enc := yaml.NewEncoder(&buf)
enc.SetIndent(2)
if err := enc.Encode(testCases); err != nil {
return err
}

text := fmt.Sprintf("---\n%s", buf.String())
if err := os.WriteFile(targetPath, []byte(text), i.Mode()); err != nil {
return err
}
}
return nil
}
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@ cases:
- |
package generated
p = 3
p = 3.0
note: "completedoc/number: 3.0"
query: data.generated.p = x
want_result:
- x: 3
- x: 3.0
Loading

0 comments on commit 3e7e6a0

Please sign in to comment.