diff --git a/classad/classad.go b/classad/classad.go index ca04992..84a2f2f 100644 --- a/classad/classad.go +++ b/classad/classad.go @@ -3,9 +3,12 @@ package classad import ( + "bytes" "encoding/json" "fmt" + "math" "reflect" + "sort" "strings" "github.com/PelicanPlatform/classad/ast" @@ -19,6 +22,17 @@ type Expr struct { expr ast.Expr } +// Equal reports structural equality between two expressions. +func (e *Expr) Equal(other *Expr) bool { + if e == nil && other == nil { + return true + } + if e == nil || other == nil { + return false + } + return exprEqual(e.internal(), other.internal()) +} + // ParseExpr parses a ClassAd expression string and returns an Expr object. // This allows you to work with expressions without evaluating them immediately. // @@ -133,9 +147,101 @@ func (e *Expr) EvalWithContext(scope, target *ClassAd) Value { // ClassAd represents a ClassAd with attributes that can be evaluated. // This is the main type for working with ClassAds. type ClassAd struct { - ad *ast.ClassAd - parent *ClassAd - target *ClassAd + ad *ast.ClassAd + parent *ClassAd + target *ClassAd + index map[string]*ast.Expr + attrsDirty bool // true when attributes changed since last sort +} + +// Equal reports whether two ClassAds have the same attributes and values, ignoring +// attribute order and casing of attribute names. +func (c *ClassAd) Equal(other *ClassAd) bool { + if c == nil && other == nil { + return true + } + if c == nil || other == nil { + return false + } + + c.ensureSorted() + other.ensureSorted() + + if len(c.ad.Attributes) != len(other.ad.Attributes) { + return false + } + + for i := range c.ad.Attributes { + left := c.ad.Attributes[i] + right := other.ad.Attributes[i] + if normalizeName(left.Name) != normalizeName(right.Name) { + return false + } + if !exprEqual(left.Value, right.Value) { + return false + } + } + + return true +} + +// normalizeName returns a case-insensitive key for attribute lookups. +func normalizeName(name string) string { + return strings.ToLower(name) +} + +// rebuildIndex recreates the fast lookup map from the underlying attributes. +func (c *ClassAd) rebuildIndex() { + if c.ad == nil { + c.index = nil + return + } + if c.index == nil { + c.index = make(map[string]*ast.Expr, len(c.ad.Attributes)) + } else { + for k := range c.index { + delete(c.index, k) + } + } + for i := range c.ad.Attributes { + attr := c.ad.Attributes[i] + c.index[normalizeName(attr.Name)] = &attr.Value + } +} + +// ensureSorted sorts attributes in-place by normalized name when dirty, +// then rebuilds the index so pointers remain valid. +func (c *ClassAd) ensureSorted() { + if c.ad == nil { + c.attrsDirty = false + return + } + if !c.attrsDirty { + c.ensureIndex() + return + } + + sort.SliceStable(c.ad.Attributes, func(i, j int) bool { + iName := normalizeName(c.ad.Attributes[i].Name) + jName := normalizeName(c.ad.Attributes[j].Name) + if iName == jName { + return c.ad.Attributes[i].Name < c.ad.Attributes[j].Name + } + return iName < jName + }) + c.attrsDirty = false + c.rebuildIndex() +} + +func (c *ClassAd) markDirty() { + c.attrsDirty = true +} + +// ensureIndex lazily initializes the lookup map if needed. +func (c *ClassAd) ensureIndex() { + if c.index == nil { + c.rebuildIndex() + } } // New creates a new empty ClassAd. @@ -144,6 +250,8 @@ func New() *ClassAd { ad: &ast.ClassAd{ Attributes: []*ast.AttributeAssignment{}, }, + index: map[string]*ast.Expr{}, + attrsDirty: false, } } @@ -156,7 +264,9 @@ func Parse(input string) (*ClassAd, error) { if ad == nil { return nil, fmt.Errorf("failed to parse ClassAd") } - return &ClassAd{ad: ad}, nil + obj := &ClassAd{ad: ad, attrsDirty: true} + obj.rebuildIndex() + return obj, nil } // ParseOld parses a ClassAd in the "old" HTCondor format and returns a ClassAd object. @@ -174,7 +284,9 @@ func ParseOld(input string) (*ClassAd, error) { if ad == nil { return nil, fmt.Errorf("failed to parse old ClassAd") } - return &ClassAd{ad: ad}, nil + obj := &ClassAd{ad: ad, attrsDirty: true} + obj.rebuildIndex() + return obj, nil } // String returns the string representation of the ClassAd. @@ -182,7 +294,20 @@ func (c *ClassAd) String() string { if c.ad == nil { return "[]" } - return c.ad.String() + + c.ensureSorted() + var b strings.Builder + b.WriteByte('[') + for i, attr := range c.ad.Attributes { + if i > 0 { + b.WriteString("; ") + } + b.WriteString(attr.Name) + b.WriteString(" = ") + b.WriteString(attr.Value.String()) + } + b.WriteByte(']') + return b.String() } // ToOldFormat serializes the ClassAd to old HTCondor format (newline-delimited). @@ -198,6 +323,7 @@ func (c *ClassAd) MarshalOld() string { return "" } + c.ensureSorted() result := "" for i, attr := range c.ad.Attributes { if i > 0 { @@ -213,13 +339,13 @@ func (c *ClassAd) Insert(name string, expr ast.Expr) { if c.ad == nil { c.ad = &ast.ClassAd{Attributes: []*ast.AttributeAssignment{}} } + c.ensureIndex() + c.markDirty() - // Check if attribute already exists and update it - for i, attr := range c.ad.Attributes { - if attr.Name == name { - c.ad.Attributes[i].Value = expr - return - } + normalized := normalizeName(name) + if ptr, ok := c.index[normalized]; ok { + *ptr = expr + return } // Add new attribute @@ -227,6 +353,7 @@ func (c *ClassAd) Insert(name string, expr ast.Expr) { Name: name, Value: expr, }) + c.index[normalized] = &c.ad.Attributes[len(c.ad.Attributes)-1].Value } // InsertExpr inserts an attribute with an Expr value into the ClassAd. @@ -352,6 +479,8 @@ func (c *ClassAd) InsertListElement(name string, element *Expr) { if c.ad == nil { c.ad = &ast.ClassAd{Attributes: []*ast.AttributeAssignment{}} } + c.ensureIndex() + c.markDirty() var astExpr ast.Expr if element == nil { @@ -360,18 +489,13 @@ func (c *ClassAd) InsertListElement(name string, element *Expr) { astExpr = element.internal() } - // Check if attribute already exists - for i, attr := range c.ad.Attributes { - if attr.Name == name { - // If it's a list, append to it - if list, ok := attr.Value.(*ast.ListLiteral); ok { - list.Elements = append(list.Elements, astExpr) - return - } - // Otherwise, replace with a new list containing the element - c.ad.Attributes[i].Value = &ast.ListLiteral{Elements: []ast.Expr{astExpr}} + if ptr, ok := c.index[normalizeName(name)]; ok { + if list, ok := (*ptr).(*ast.ListLiteral); ok { + list.Elements = append(list.Elements, astExpr) return } + *ptr = &ast.ListLiteral{Elements: []ast.Expr{astExpr}} + return } // Add new list attribute @@ -379,6 +503,7 @@ func (c *ClassAd) InsertListElement(name string, element *Expr) { Name: name, Value: &ast.ListLiteral{Elements: []ast.Expr{astExpr}}, }) + c.index[normalizeName(name)] = &c.ad.Attributes[len(c.ad.Attributes)-1].Value } // Lookup returns the unevaluated expression for an attribute. @@ -396,11 +521,10 @@ func (c *ClassAd) Lookup(name string) (*Expr, bool) { if c.ad == nil { return nil, false } + c.ensureIndex() - for _, attr := range c.ad.Attributes { - if attr.Name == name { - return &Expr{expr: attr.Value}, true - } + if ptr, ok := c.index[normalizeName(name)]; ok { + return &Expr{expr: *ptr}, true } return nil, false } @@ -412,11 +536,10 @@ func (c *ClassAd) lookupInternal(name string) ast.Expr { if c.ad == nil { return nil } + c.ensureIndex() - for _, attr := range c.ad.Attributes { - if attr.Name == name { - return attr.Value - } + if ptr, ok := c.index[normalizeName(name)]; ok { + return *ptr } return nil } @@ -556,10 +679,20 @@ func (c *ClassAd) Delete(name string) bool { if c.ad == nil { return false } + c.ensureIndex() + c.markDirty() - for i, attr := range c.ad.Attributes { - if attr.Name == name { + normalized := normalizeName(name) + ptr, ok := c.index[normalized] + if !ok { + return false + } + + // Find the matching attribute by pointer equality on Value. + for i := range c.ad.Attributes { + if &c.ad.Attributes[i].Value == ptr { c.ad.Attributes = append(c.ad.Attributes[:i], c.ad.Attributes[i+1:]...) + delete(c.index, normalized) return true } } @@ -579,6 +712,8 @@ func (c *ClassAd) Clear() { if c.ad != nil { c.ad.Attributes = []*ast.AttributeAssignment{} } + c.index = map[string]*ast.Expr{} + c.attrsDirty = false } // GetAttributes returns a list of all attribute names. @@ -910,6 +1045,44 @@ func (c *ClassAd) flattenExpr(expr ast.Expr) ast.Expr { leftVal := c.exprToValue(left) rightVal := c.exprToValue(right) + // Apply boolean short-circuiting when either side is a literal bool. + if v.Op == "&&" || v.Op == "||" { + if leftVal.IsBool() { + boolVal, err := leftVal.BoolValue() + if err != nil { + return &ast.ErrorLiteral{} + } + if v.Op == "&&" { + if !boolVal { + return &ast.BooleanLiteral{Value: false} + } + return right + } + // v.Op == "||" + if boolVal { + return &ast.BooleanLiteral{Value: true} + } + return right + } + if rightVal.IsBool() { + boolVal, err := rightVal.BoolValue() + if err != nil { + return &ast.ErrorLiteral{} + } + if v.Op == "&&" { + if !boolVal { + return &ast.BooleanLiteral{Value: false} + } + return left + } + // v.Op == "||" + if boolVal { + return &ast.BooleanLiteral{Value: true} + } + return left + } + } + if !leftVal.IsUndefined() && !rightVal.IsUndefined() { // Try to compute the operation result := c.evaluateBinaryOp(v.Op, leftVal, rightVal) @@ -982,11 +1155,24 @@ func (c *ClassAd) flattenExpr(expr ast.Expr) ast.Expr { for i, arg := range v.Args { args[i] = c.flattenExpr(arg) } - return &ast.FunctionCall{ - Name: v.Name, - Args: args, + + // Fold ifThenElse when the condition is a literal boolean after flattening. + if strings.EqualFold(v.Name, "ifThenElse") && len(args) == 3 { + condVal := c.exprToValue(args[0]) + if condVal.IsBool() { + boolVal, err := condVal.BoolValue() + if err != nil { + return &ast.ErrorLiteral{} + } + if boolVal { + return args[1] + } + return args[2] + } } + return &ast.FunctionCall{Name: v.Name, Args: args} + case *ast.ListLiteral: elements := make([]ast.Expr, len(v.Elements)) for i, elem := range v.Elements { @@ -1078,6 +1264,111 @@ func (c *ClassAd) valueToExpr(val Value) ast.Expr { return &ast.UndefinedLiteral{} } +// exprEqual compares two ast expressions for structural equality. +func exprEqual(a, b ast.Expr) bool { + if a == nil && b == nil { + return true + } + if a == nil || b == nil { + return false + } + + switch av := a.(type) { + case *ast.IntegerLiteral: + bv, ok := b.(*ast.IntegerLiteral) + return ok && av.Value == bv.Value + case *ast.RealLiteral: + bv, ok := b.(*ast.RealLiteral) + return ok && floatEqual(av.Value, bv.Value) + case *ast.StringLiteral: + bv, ok := b.(*ast.StringLiteral) + return ok && av.Value == bv.Value + case *ast.BooleanLiteral: + bv, ok := b.(*ast.BooleanLiteral) + return ok && av.Value == bv.Value + case *ast.UndefinedLiteral: + _, ok := b.(*ast.UndefinedLiteral) + return ok + case *ast.ErrorLiteral: + _, ok := b.(*ast.ErrorLiteral) + return ok + case *ast.AttributeReference: + bv, ok := b.(*ast.AttributeReference) + return ok && av.Scope == bv.Scope && strings.EqualFold(av.Name, bv.Name) + case *ast.BinaryOp: + bv, ok := b.(*ast.BinaryOp) + return ok && av.Op == bv.Op && exprEqual(av.Left, bv.Left) && exprEqual(av.Right, bv.Right) + case *ast.UnaryOp: + bv, ok := b.(*ast.UnaryOp) + return ok && av.Op == bv.Op && exprEqual(av.Expr, bv.Expr) + case *ast.ConditionalExpr: + bv, ok := b.(*ast.ConditionalExpr) + return ok && exprEqual(av.Condition, bv.Condition) && exprEqual(av.TrueExpr, bv.TrueExpr) && exprEqual(av.FalseExpr, bv.FalseExpr) + case *ast.ElvisExpr: + bv, ok := b.(*ast.ElvisExpr) + return ok && exprEqual(av.Left, bv.Left) && exprEqual(av.Right, bv.Right) + case *ast.FunctionCall: + bv, ok := b.(*ast.FunctionCall) + if !ok || !strings.EqualFold(av.Name, bv.Name) || len(av.Args) != len(bv.Args) { + return false + } + for i := range av.Args { + if !exprEqual(av.Args[i], bv.Args[i]) { + return false + } + } + return true + case *ast.ListLiteral: + bv, ok := b.(*ast.ListLiteral) + if !ok || len(av.Elements) != len(bv.Elements) { + return false + } + for i := range av.Elements { + if !exprEqual(av.Elements[i], bv.Elements[i]) { + return false + } + } + return true + case *ast.RecordLiteral: + bv, ok := b.(*ast.RecordLiteral) + if !ok { + return false + } + left := &ClassAd{ad: av.ClassAd, attrsDirty: true} + left.rebuildIndex() + right := &ClassAd{ad: bv.ClassAd, attrsDirty: true} + right.rebuildIndex() + return left.Equal(right) + case *ast.SelectExpr: + bv, ok := b.(*ast.SelectExpr) + return ok && strings.EqualFold(av.Attr, bv.Attr) && exprEqual(av.Record, bv.Record) + case *ast.SubscriptExpr: + bv, ok := b.(*ast.SubscriptExpr) + return ok && exprEqual(av.Container, bv.Container) && exprEqual(av.Index, bv.Index) + default: + return false + } +} + +// floatEqual compares two float64 values with a relative tolerance to account for +// floating point rounding. NaN only equals NaN; +Inf/-Inf must match exactly. +func floatEqual(a, b float64) bool { + if math.IsNaN(a) || math.IsNaN(b) { + return math.IsNaN(a) && math.IsNaN(b) + } + if math.IsInf(a, 0) || math.IsInf(b, 0) { + return math.IsInf(a, 1) == math.IsInf(b, 1) && math.IsInf(a, -1) == math.IsInf(b, -1) + } + + const relTol = 1e-9 + diff := math.Abs(a - b) + if diff == 0 { + return true + } + mag := math.Max(math.Abs(a), math.Abs(b)) + return diff <= relTol*mag +} + // Helper functions for evaluating operations during flattening func (c *ClassAd) evaluateBinaryOp(op string, left, right Value) Value { // Create a temporary evaluator to use its operator logic @@ -1118,22 +1409,32 @@ func (c *ClassAd) MarshalJSON() ([]byte, error) { return []byte("{}"), nil } - result := make(map[string]interface{}) - for _, attr := range c.ad.Attributes { + c.ensureSorted() + var buf bytes.Buffer + buf.WriteByte('{') + for i, attr := range c.ad.Attributes { + if i > 0 { + buf.WriteByte(',') + } + keyBytes, err := json.Marshal(attr.Name) + if err != nil { + return nil, err + } value, err := c.marshalValue(attr.Value) if err != nil { return nil, fmt.Errorf("failed to marshal attribute %s: %w", attr.Name, err) } - result[attr.Name] = value - } - - jsonBytes, err := json.Marshal(result) - if err != nil { - return nil, err + valBytes, err := json.Marshal(value) + if err != nil { + return nil, err + } + buf.Write(keyBytes) + buf.WriteByte(':') + buf.Write(valBytes) } + buf.WriteByte('}') - // Post-process to escape forward slashes in /Expr(...)/ patterns - // Go's json.Marshal doesn't escape / by default, but we prefer \/ for expressions + jsonBytes := buf.Bytes() jsonBytes = []byte(strings.ReplaceAll(string(jsonBytes), "\"/Expr(", "\"\\/Expr(")) jsonBytes = []byte(strings.ReplaceAll(string(jsonBytes), ")/\"", ")\\/\"")) @@ -1165,17 +1466,14 @@ func (c *ClassAd) marshalValue(expr ast.Expr) (interface{}, error) { } return list, nil case *ast.RecordLiteral: - // Nested ClassAd - nested := &ClassAd{ad: v.ClassAd} - nestedMap := make(map[string]interface{}) - for _, attr := range v.ClassAd.Attributes { - val, err := nested.marshalValue(attr.Value) - if err != nil { - return nil, err - } - nestedMap[attr.Name] = val + // Nested ClassAd: serialize deterministically and embed as raw JSON. + nested := &ClassAd{ad: v.ClassAd, attrsDirty: true} + nested.rebuildIndex() + nestedBytes, err := nested.MarshalJSON() + if err != nil { + return nil, err } - return nestedMap, nil + return json.RawMessage(nestedBytes), nil default: // Complex expression - serialize as string with special markers // Format: /Expr()/ @@ -1212,11 +1510,27 @@ func (c *ClassAd) UnmarshalJSON(data []byte) error { Value: expr, }) } + sortAttributeAssignments(attributes) c.ad = &ast.ClassAd{Attributes: attributes} + c.attrsDirty = true + c.rebuildIndex() return nil } +// sortAttributeAssignments provides deterministic ordering by case-insensitive name with +// a secondary case-sensitive tie-breaker to preserve stable behavior. +func sortAttributeAssignments(attrs []*ast.AttributeAssignment) { + sort.SliceStable(attrs, func(i, j int) bool { + iName := normalizeName(attrs[i].Name) + jName := normalizeName(attrs[j].Name) + if iName == jName { + return attrs[i].Name < attrs[j].Name + } + return iName < jName + }) +} + // unmarshalValue converts a JSON value back into an AST expression. func (c *ClassAd) unmarshalValue(value interface{}) (ast.Expr, error) { switch v := value.(type) { @@ -1232,8 +1546,7 @@ func (c *ClassAd) unmarshalValue(value interface{}) (ast.Expr, error) { } return &ast.RealLiteral{Value: v}, nil case string: - // Check if it's an expression string - // Only accept the format /Expr(...)/ + // Check if it's an expression string. if strings.HasPrefix(v, "/Expr(") && strings.HasSuffix(v, ")/") { exprStr := v[6 : len(v)-2] // Remove "/Expr(" and ")/" return c.parseExpression(exprStr) @@ -1264,6 +1577,7 @@ func (c *ClassAd) unmarshalValue(value interface{}) (ast.Expr, error) { Value: expr, }) } + sortAttributeAssignments(attributes) return &ast.RecordLiteral{ ClassAd: &ast.ClassAd{Attributes: attributes}, }, nil diff --git a/classad/equal_test.go b/classad/equal_test.go new file mode 100644 index 0000000..cb8bb09 --- /dev/null +++ b/classad/equal_test.go @@ -0,0 +1,240 @@ +package classad + +import ( + "encoding/json" + "math" + "testing" + + "github.com/PelicanPlatform/classad/ast" +) + +func TestExprEqual(t *testing.T) { + a, err := ParseExpr("Foo + bar") + if err != nil { + t.Fatalf("parse expr a: %v", err) + } + b, err := ParseExpr("foo + BAR") + if err != nil { + t.Fatalf("parse expr b: %v", err) + } + c, err := ParseExpr("foo + baz") + if err != nil { + t.Fatalf("parse expr c: %v", err) + } + + if !a.Equal(b) { + t.Fatalf("expected expressions to be equal") + } + if a.Equal(c) { + t.Fatalf("expected expressions to differ") + } +} + +func TestClassAdEqual(t *testing.T) { + ad1, err := Parse(`[ + A = 1; + list = {1, 2}; + nested = [b = 2; a = 1]; + ]`) + if err != nil { + t.Fatalf("parse ad1: %v", err) + } + + ad2, err := Parse(`[ + nested = [a = 1; b = 2]; + list = {1, 2}; + a = 1; + ]`) + if err != nil { + t.Fatalf("parse ad2: %v", err) + } + + ad3, err := Parse(`[ + nested = [a = 1; b = 3]; + list = {1, 2}; + a = 1; + ]`) + if err != nil { + t.Fatalf("parse ad3: %v", err) + } + + if !ad1.Equal(ad2) { + t.Fatalf("expected ad1 and ad2 to be equal") + } + if ad1.Equal(ad3) { + t.Fatalf("expected ad1 and ad3 to differ") + } +} + +func TestClassAdEqualAfterJSONRoundTrip(t *testing.T) { + original, err := Parse(`[ + Z = 1; + nested = [b = 2; A = 1; c = 3]; + alpha = [y = 20; x = 10]; + outer = [ + inner = [b = 5; a = 4]; + num = 7 + ]; + list = { [z = 9; y = 8; x = 7], 5, 4 } + ]`) + if err != nil { + t.Fatalf("parse original: %v", err) + } + + bytes1, err := json.Marshal(original) + if err != nil { + t.Fatalf("marshal original: %v", err) + } + + var roundTripped ClassAd + if err := json.Unmarshal(bytes1, &roundTripped); err != nil { + t.Fatalf("unmarshal copy: %v", err) + } + + if !original.Equal(&roundTripped) { + t.Fatalf("expected original and copy to be equal after round-trip") + } +} + +func TestCaseInsensitiveAttributeReferences(t *testing.T) { + ad, err := Parse(`[Foo = bar; BAR = 3]`) + if err != nil { + t.Fatalf("parse failed: %v", err) + } + + val := ad.EvaluateAttr("Foo") + if !val.IsInteger() { + t.Fatalf("expected Foo to evaluate to integer, got %v", val.Type()) + } + intVal, _ := val.IntValue() + if intVal != 3 { + t.Fatalf("expected Foo to evaluate to 3, got %d", intVal) + } +} + +func exprFromAST(e ast.Expr) *Expr { + return &Expr{expr: e} +} + +func TestExprEqualAllCases(t *testing.T) { + int1 := &ast.IntegerLiteral{Value: 1} + int2 := &ast.IntegerLiteral{Value: 2} + realNearA := &ast.RealLiteral{Value: 1.0} + realNearB := &ast.RealLiteral{Value: 1.0 + 1e-10} + realFar := &ast.RealLiteral{Value: 1.0 + 1e-6} + realNaN := &ast.RealLiteral{Value: math.NaN()} + realPosInf := &ast.RealLiteral{Value: math.Inf(1)} + realNegInf := &ast.RealLiteral{Value: math.Inf(-1)} + strFoo := &ast.StringLiteral{Value: "foo"} + strBar := &ast.StringLiteral{Value: "bar"} + boolTrue := &ast.BooleanLiteral{Value: true} + boolFalse := &ast.BooleanLiteral{Value: false} + undef := &ast.UndefinedLiteral{} + errLit := &ast.ErrorLiteral{} + attrFoo := &ast.AttributeReference{Name: "Foo"} + attrfoo := &ast.AttributeReference{Name: "foo"} + binAdd1 := &ast.BinaryOp{Op: "+", Left: int1, Right: int2} + binAdd2 := &ast.BinaryOp{Op: "+", Left: int1, Right: int2} + binSub := &ast.BinaryOp{Op: "-", Left: int1, Right: int2} + unaryNeg1 := &ast.UnaryOp{Op: "-", Expr: int1} + unaryNeg2 := &ast.UnaryOp{Op: "-", Expr: int1} + cond1 := &ast.ConditionalExpr{Condition: boolTrue, TrueExpr: int1, FalseExpr: int2} + cond2 := &ast.ConditionalExpr{Condition: boolTrue, TrueExpr: int1, FalseExpr: int2} + elvis1 := &ast.ElvisExpr{Left: undef, Right: int1} + elvis2 := &ast.ElvisExpr{Left: undef, Right: int1} + fn1 := &ast.FunctionCall{Name: "toUpper", Args: []ast.Expr{strFoo}} + fn2 := &ast.FunctionCall{Name: "TOUPPER", Args: []ast.Expr{strFoo}} + list12 := &ast.ListLiteral{Elements: []ast.Expr{int1, int2}} + list21 := &ast.ListLiteral{Elements: []ast.Expr{int2, int1}} + record1 := &ast.RecordLiteral{ClassAd: &ast.ClassAd{Attributes: []*ast.AttributeAssignment{ + {Name: "x", Value: int1}, + {Name: "y", Value: int2}, + }}} + record2 := &ast.RecordLiteral{ClassAd: &ast.ClassAd{Attributes: []*ast.AttributeAssignment{ + {Name: "y", Value: int2}, + {Name: "x", Value: int1}, + }}} + select1 := &ast.SelectExpr{Record: attrFoo, Attr: "Bar"} + select2 := &ast.SelectExpr{Record: attrfoo, Attr: "bar"} + subscript1 := &ast.SubscriptExpr{Container: list12, Index: int1} + subscript2 := &ast.SubscriptExpr{Container: list12, Index: int1} + + cases := []struct { + name string + left ast.Expr + right ast.Expr + equal bool + }{ + {name: "int equal", left: int1, right: &ast.IntegerLiteral{Value: 1}, equal: true}, + {name: "int not equal", left: int1, right: int2, equal: false}, + {name: "real within tol", left: realNearA, right: realNearB, equal: true}, + {name: "real outside tol", left: realNearA, right: realFar, equal: false}, + {name: "real nan vs nan", left: realNaN, right: &ast.RealLiteral{Value: math.NaN()}, equal: true}, + {name: "real nan vs num", left: realNaN, right: realNearA, equal: false}, + {name: "real inf match", left: realPosInf, right: &ast.RealLiteral{Value: math.Inf(1)}, equal: true}, + {name: "real inf mismatch", left: realPosInf, right: realNegInf, equal: false}, + {name: "string equal", left: strFoo, right: &ast.StringLiteral{Value: "foo"}, equal: true}, + {name: "string diff", left: strFoo, right: strBar, equal: false}, + {name: "bool equal", left: boolTrue, right: &ast.BooleanLiteral{Value: true}, equal: true}, + {name: "bool diff", left: boolTrue, right: boolFalse, equal: false}, + {name: "undefined equal", left: undef, right: &ast.UndefinedLiteral{}, equal: true}, + {name: "undefined vs int", left: undef, right: int1, equal: false}, + {name: "error equal", left: errLit, right: &ast.ErrorLiteral{}, equal: true}, + {name: "attr case-insensitive", left: attrFoo, right: attrfoo, equal: true}, + {name: "binary equal", left: binAdd1, right: binAdd2, equal: true}, + {name: "binary diff op", left: binAdd1, right: binSub, equal: false}, + {name: "unary equal", left: unaryNeg1, right: unaryNeg2, equal: true}, + {name: "unary diff op", left: unaryNeg1, right: &ast.UnaryOp{Op: "!", Expr: int1}, equal: false}, + {name: "conditional equal", left: cond1, right: cond2, equal: true}, + {name: "conditional diff", left: cond1, right: &ast.ConditionalExpr{Condition: boolFalse, TrueExpr: int1, FalseExpr: int2}, equal: false}, + {name: "elvis equal", left: elvis1, right: elvis2, equal: true}, + {name: "elvis diff", left: elvis1, right: &ast.ElvisExpr{Left: int1, Right: int2}, equal: false}, + {name: "func name case-insensitive", left: fn1, right: fn2, equal: true}, + {name: "func args diff", left: fn1, right: &ast.FunctionCall{Name: "toUpper", Args: []ast.Expr{strBar}}, equal: false}, + {name: "list equal", left: list12, right: &ast.ListLiteral{Elements: []ast.Expr{int1, int2}}, equal: true}, + {name: "list order diff", left: list12, right: list21, equal: false}, + {name: "record equal order-insensitive", left: record1, right: record2, equal: true}, + {name: "record diff value", left: record1, right: &ast.RecordLiteral{ClassAd: &ast.ClassAd{Attributes: []*ast.AttributeAssignment{{Name: "x", Value: int2}}}}, equal: false}, + {name: "select attr case-insensitive", left: select1, right: select2, equal: true}, + {name: "subscript equal", left: subscript1, right: subscript2, equal: true}, + {name: "subscript diff", left: subscript1, right: &ast.SubscriptExpr{Container: list12, Index: int2}, equal: false}, + } + + for _, tc := range cases { + if got := exprFromAST(tc.left).Equal(exprFromAST(tc.right)); got != tc.equal { + t.Fatalf("%s: expected %v, got %v", tc.name, tc.equal, got) + } + } +} + +func TestExprEqualNilHandling(t *testing.T) { + var left *Expr + var right *Expr + if !left.Equal(right) { + t.Fatalf("nil vs nil should be equal") + } + + nonNil := exprFromAST(&ast.IntegerLiteral{Value: 1}) + if left.Equal(nonNil) { + t.Fatalf("nil vs non-nil should not be equal") + } + if nonNil.Equal(left) { + t.Fatalf("non-nil vs nil should not be equal") + } +} + +func TestClassAdEqualNilHandling(t *testing.T) { + var left *ClassAd + var right *ClassAd + if !left.Equal(right) { + t.Fatalf("nil vs nil ClassAd should be equal") + } + + nonNil := New() + if left.Equal(nonNil) { + t.Fatalf("nil vs non-nil ClassAd should not be equal") + } + if nonNil.Equal(left) { + t.Fatalf("non-nil vs nil ClassAd should not be equal") + } +} diff --git a/classad/flatten_test.go b/classad/flatten_test.go index ba55594..e9c5af6 100644 --- a/classad/flatten_test.go +++ b/classad/flatten_test.go @@ -113,3 +113,141 @@ func TestFlattenPreservesUnknownReference(t *testing.T) { t.Fatalf("expected unknown reference to remain, got %s", flat.String()) } } + +func TestFlattenBooleanAndShortCircuitFalse(t *testing.T) { + ad := New() + ad.InsertAttrString("User", "brian") + + expr, err := ParseExpr(`User == "bbockelm" && (Unknown > 0) && true`) + if err != nil { + t.Fatalf("parse expr failed: %v", err) + } + + flat := ad.Flatten(expr) + if flat.String() != "false" { + t.Fatalf("expected short-circuited false, got %s", flat.String()) + } +} + +func TestFlattenBooleanAndPropagatesRightWhenTrue(t *testing.T) { + ad := New() + ad.InsertAttrBool("Flag", true) + + expr, err := ParseExpr(`Flag && (Unknown > 0)`) + if err != nil { + t.Fatalf("parse expr failed: %v", err) + } + + flat := ad.Flatten(expr) + if !strings.Contains(flat.String(), "Unknown > 0") { + t.Fatalf("expected right-hand expression to remain, got %s", flat.String()) + } +} + +func TestFlattenBooleanOrShortCircuitTrue(t *testing.T) { + ad := New() + ad.InsertAttrBool("Flag", true) + + expr, err := ParseExpr(`Flag || (Unknown > 0)`) + if err != nil { + t.Fatalf("parse expr failed: %v", err) + } + + flat := ad.Flatten(expr) + if flat.String() != "true" { + t.Fatalf("expected short-circuited true, got %s", flat.String()) + } +} + +func TestFlattenBooleanOrPropagatesRightWhenFalse(t *testing.T) { + ad := New() + ad.InsertAttrBool("Flag", false) + + expr, err := ParseExpr(`Flag || (Unknown > 0)`) + if err != nil { + t.Fatalf("parse expr failed: %v", err) + } + + flat := ad.Flatten(expr) + if !strings.Contains(flat.String(), "Unknown > 0") { + t.Fatalf("expected right-hand expression to remain, got %s", flat.String()) + } +} + +func TestFlattenBooleanAndRightLiteralFalse(t *testing.T) { + ad := New() + ad.InsertAttrBool("Flag", true) + + expr, err := ParseExpr(`Unknown > 0 && false`) + if err != nil { + t.Fatalf("parse expr failed: %v", err) + } + + flat := ad.Flatten(expr) + if flat.String() != "false" { + t.Fatalf("expected false due to right-hand literal, got %s", flat.String()) + } +} + +func TestFlattenBooleanOrRightLiteralTrue(t *testing.T) { + ad := New() + ad.InsertAttrBool("Flag", false) + + expr, err := ParseExpr(`Unknown > 0 || true`) + if err != nil { + t.Fatalf("parse expr failed: %v", err) + } + + flat := ad.Flatten(expr) + if flat.String() != "true" { + t.Fatalf("expected true due to right-hand literal, got %s", flat.String()) + } +} + +func TestFlattenConditionalLiteralCondition(t *testing.T) { + ad := New() + ad.InsertAttr("X", 1) + + expr, err := ParseExpr(`true ? X : Unknown`) + if err != nil { + t.Fatalf("parse expr failed: %v", err) + } + + flat := ad.Flatten(expr) + if flat.String() != "1" { + t.Fatalf("expected true branch to evaluate to literal, got %s", flat.String()) + } + + expr2, err := ParseExpr(`false ? X : Unknown`) + if err != nil { + t.Fatalf("parse expr failed: %v", err) + } + flat2 := ad.Flatten(expr2) + if flat2.String() != "Unknown" { + t.Fatalf("expected false branch to remain, got %s", flat2.String()) + } +} + +func TestFlattenIfThenElseLiteralCondition(t *testing.T) { + ad := New() + ad.InsertAttr("X", 1) + + expr, err := ParseExpr(`ifThenElse(true, X, Unknown)`) + if err != nil { + t.Fatalf("parse expr failed: %v", err) + } + + flat := ad.Flatten(expr) + if flat.String() != "1" { + t.Fatalf("expected true branch to evaluate to literal, got %s", flat.String()) + } + + expr2, err := ParseExpr(`ifThenElse(false, X, Unknown)`) + if err != nil { + t.Fatalf("parse expr failed: %v", err) + } + flat2 := ad.Flatten(expr2) + if flat2.String() != "Unknown" { + t.Fatalf("expected false branch to remain, got %s", flat2.String()) + } +} diff --git a/classad/generic_api_test.go b/classad/generic_api_test.go index 0f35037..ad09589 100644 --- a/classad/generic_api_test.go +++ b/classad/generic_api_test.go @@ -142,6 +142,40 @@ func TestSet_Struct(t *testing.T) { } } +func TestCaseInsensitiveLookupAndEvaluate(t *testing.T) { + ad := New() + ad.InsertAttr("Foo", 1) + + if _, ok := ad.Lookup("foo"); !ok { + t.Fatalf("Lookup should be case-insensitive") + } + + if got, ok := ad.EvaluateAttrInt("FOO"); !ok || got != 1 { + t.Fatalf("EvaluateAttrInt should be case-insensitive, got %d", got) + } +} + +func TestSetCaseInsensitivePreserveCase(t *testing.T) { + ad := New() + + if err := ad.Set("BaR", int64(3)); err != nil { + t.Fatalf("initial Set failed: %v", err) + } + + if err := ad.Set("bar", int64(5)); err != nil { + t.Fatalf("second Set failed: %v", err) + } + + attrs := ad.GetAttributes() + if len(attrs) != 1 || attrs[0] != "BaR" { + t.Fatalf("expected case to be preserved, got %+v", attrs) + } + + if got, ok := ad.EvaluateAttrInt("BAR"); !ok || got != 5 { + t.Fatalf("expected updated value with case-insensitive lookup, got %d", got) + } +} + func TestGetAs_BasicTypes(t *testing.T) { ad := New() ad.InsertAttr("int", 42) diff --git a/classad/json_test.go b/classad/json_test.go index 6de0da9..87f7120 100644 --- a/classad/json_test.go +++ b/classad/json_test.go @@ -1,6 +1,7 @@ package classad import ( + "bytes" "encoding/json" "testing" ) @@ -155,6 +156,22 @@ func TestMarshalJSON_NestedClassAd(t *testing.T) { } } +func TestUnmarshalJSON_SortsKeys(t *testing.T) { + jsonStr := `{"b": 1, "a": 2, "c": 3}` + var ad ClassAd + if err := json.Unmarshal([]byte(jsonStr), &ad); err != nil { + t.Fatalf("UnmarshalJSON failed: %v", err) + } + + if ad.String() != "[a = 2; b = 1; c = 3]" { + t.Fatalf("expected sorted keys, got %s", ad.String()) + } + + if ad.MarshalOld() != "a = 2\nb = 1\nc = 3" { + t.Fatalf("expected sorted keys in old format, got %s", ad.MarshalOld()) + } +} + func TestUnmarshalJSON_SimpleValues(t *testing.T) { jsonStr := `{ "x": 5, @@ -333,6 +350,53 @@ func TestRoundTrip_ComplexClassAd(t *testing.T) { } } +func TestRoundTrip_JSONIdentityWithNested(t *testing.T) { + // Intentionally unsorted keys to verify deterministic ordering survives round-trip. + source := `[ + Z = 1; + nested = [b = 2; A = 1; c = 3]; + alpha = [y = 20; x = 10]; + outer = [ + inner = [b = 5; a = 4]; + num = 7 + ]; + list = { [z = 9; y = 8; x = 7], 5, 4 } + ]` + + ad1, err := Parse(source) + if err != nil { + t.Fatalf("failed to parse source ClassAd: %v", err) + } + + jsonBytes1, err := json.Marshal(ad1) + if err != nil { + t.Fatalf("marshal failed: %v", err) + } + + var ad2 ClassAd + err = json.Unmarshal(jsonBytes1, &ad2) + if err != nil { + t.Fatalf("unmarshal failed: %v", err) + } + + // Structural identity: string and old-format renderings must match. + if ad1.String() != ad2.String() { + t.Fatalf("String mismatch after round-trip:\norig: %s\nnew: %s", ad1.String(), ad2.String()) + } + if ad1.MarshalOld() != ad2.MarshalOld() { + t.Fatalf("MarshalOld mismatch after round-trip:\norig:\n%s\nnew:\n%s", ad1.MarshalOld(), ad2.MarshalOld()) + } + + // Deterministic JSON: re-marshal should produce identical bytes. + jsonBytes2, marshalErr := json.Marshal(&ad2) + if marshalErr != nil { + t.Fatalf("second marshal failed: %v", marshalErr) + } + if !bytes.Equal(jsonBytes1, jsonBytes2) { + t.Fatalf("JSON not deterministic after round-trip:\nfirst: %s\nsecond: %s", string(jsonBytes1), string(jsonBytes2)) + } +} + func TestMarshalJSON_UndefinedValue(t *testing.T) { ad, err := Parse(`[x = undefined]`) if err != nil {