Skip to content

Commit

Permalink
[CWS] statically define what getter to be generated (#33657)
Browse files Browse the repository at this point in the history
  • Loading branch information
safchain authored Feb 3, 2025
1 parent a66e254 commit 9b81852
Show file tree
Hide file tree
Showing 8 changed files with 155 additions and 2,303 deletions.
37 changes: 24 additions & 13 deletions pkg/security/generators/accessors/accessors.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ import (
"os/exec"
"path"
"reflect"
"slices"
"strconv"
"strings"
"text/template"
Expand Down Expand Up @@ -65,9 +66,12 @@ func (af *AstFiles) LookupSymbol(symbol string) *ast.Object { //nolint:staticche
return nil
}

// GetSpecs gets specs
func (af *AstFiles) GetSpecs() []ast.Spec {
var specs []ast.Spec
// Parse extract data
func (af *AstFiles) Parse() ([]ast.Spec, []string) {
var (
specs []ast.Spec
getters []string
)

for _, file := range af.files {
for _, decl := range file.Decls {
Expand All @@ -80,7 +84,13 @@ func (af *AstFiles) GetSpecs() []ast.Spec {
for _, document := range decl.Doc.List {
if strings.Contains(document.Text, "genaccessors") {
genaccessors = true
break
}

if strings.Contains(document.Text, "gengetter") {
els := strings.Split(document.Text, ":")
if len(els) > 1 {
getters = append(getters, strings.TrimSpace(els[1]))
}
}
}

Expand All @@ -92,7 +102,7 @@ func (af *AstFiles) GetSpecs() []ast.Spec {
}
}

return specs
return specs, getters
}

func origTypeToBasicType(kind string) string {
Expand Down Expand Up @@ -163,7 +173,6 @@ func handleBasic(module *common.Module, field seclField, name, alias, aliasPrefi
Alias: alias,
AliasPrefix: aliasPrefix,
GettersOnly: field.gettersOnly,
GenGetters: field.genGetters,
Ref: field.ref,
RestrictedTo: restrictedTo,
}
Expand Down Expand Up @@ -194,7 +203,6 @@ func handleBasic(module *common.Module, field seclField, name, alias, aliasPrefi
Alias: alias,
AliasPrefix: aliasPrefix,
GettersOnly: field.gettersOnly,
GenGetters: field.genGetters,
Ref: field.ref,
RestrictedTo: restrictedTo,
}
Expand Down Expand Up @@ -254,7 +262,6 @@ func handleNonEmbedded(module *common.Module, field seclField, aliasPrefix, alia

func addLengthOpField(module *common.Module, alias string, field *common.StructField) *common.StructField {
lengthField := *field
lengthField.GenGetters = false
lengthField.IsLength = true
lengthField.Name += ".length"
lengthField.OrigType = "int"
Expand Down Expand Up @@ -337,7 +344,6 @@ func handleFieldWithHandler(module *common.Module, field seclField, aliasPrefix,
Alias: alias,
AliasPrefix: aliasPrefix,
GettersOnly: field.gettersOnly,
GenGetters: field.genGetters,
Ref: field.ref,
RestrictedTo: restrictedTo,
ReadOnly: field.readOnly,
Expand Down Expand Up @@ -394,7 +400,6 @@ type seclField struct {
exposedAtEventRootOnly bool // fields that should only be exposed at the root of an event, i.e. `parent` should not be exposed for an `ancestor` of a process
containerStructName string
gettersOnly bool // a field that is not exposed via SECL, but still has an accessor generated
genGetters bool
ref string
readOnly bool
}
Expand Down Expand Up @@ -447,8 +452,6 @@ func parseFieldDef(def string) (seclField, error) {
case "getters_only":
field.gettersOnly = true
field.exposedAtEventRootOnly = true
case "gen_getters":
field.genGetters = true
case "readonly":
field.readOnly = true
}
Expand Down Expand Up @@ -754,7 +757,10 @@ func parseFile(modelFile string, typesFile string, pkgName string) (*common.Modu
module.TargetPkg = path.Clean(path.Join(pkgName, path.Dir(output)))
}

for _, spec := range astFiles.GetSpecs() {
specs, getters := astFiles.Parse()
module.Getters = getters

for _, spec := range specs {
handleSpecRecursive(module, astFiles, spec, "", "", "", nil, nil, make(map[string]bool))
}

Expand Down Expand Up @@ -1075,6 +1081,10 @@ func isReadOnly(field *common.StructField) bool {
return field.IsLength || field.Helper || field.ReadOnly
}

func genGetter(getters []string, getter string) bool {
return slices.Contains(getters, "*") || slices.Contains(getters, getter)
}

var funcMap = map[string]interface{}{
"TrimPrefix": strings.TrimPrefix,
"TrimSuffix": strings.TrimSuffix,
Expand All @@ -1094,6 +1104,7 @@ var funcMap = map[string]interface{}{
"GetFieldReflectType": getFieldReflectType,
"GetSetHandler": getSetHandler,
"IsReadOnly": isReadOnly,
"GenGetter": genGetter,
}

//go:embed accessors.tmpl
Expand Down
1 change: 1 addition & 0 deletions pkg/security/generators/accessors/common/types.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ type Module struct {
Iterators map[string]*StructField
EventTypes map[string]*EventTypeMetadata
Mock bool
Getters []string
}

// StructField represents a structure field for which an accessor will be generated
Expand Down
10 changes: 5 additions & 5 deletions pkg/security/generators/accessors/field_accessors.tmpl
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,9 @@ var _ = eval.NewContext

{{range $Name, $Field := .Fields}}

{{if not $Field.GenGetters }}
{{ $getter := (PascalCaseFieldName $Name) | print "Get" }}

{{if not ($getter | GenGetter $.Getters) }}
{{continue}}
{{end}}

Expand All @@ -32,8 +34,6 @@ var _ = eval.NewContext
{{end}}
{{end}}

{{ $pascalCaseName := PascalCaseFieldName $Name }}

{{$accessorReturnType := $Field.OrigType}}
{{ if $Field.Handler}}
{{$accessorReturnType = $Field.ReturnType}}
Expand All @@ -43,8 +43,8 @@ var _ = eval.NewContext
{{$accessorReturnType = $accessorReturnType | printf "[]%s" }}
{{ end }}

// Get{{$pascalCaseName}} returns the value of the field, resolving if necessary
func (ev *Event) Get{{$pascalCaseName}}() {{ $accessorReturnType }} {
// {{$getter}} returns the value of the field, resolving if necessary
func (ev *Event) {{$getter}}() {{ $accessorReturnType }} {
{{if ne $Field.Event ""}}
if ev.GetEventType().String() != "{{$Field.Event}}" {
return {{ GetDefaultValueOfType $accessorReturnType}}
Expand Down
Loading

0 comments on commit 9b81852

Please sign in to comment.