diff --git a/classad/api_math_coverage_test.go b/classad/api_math_coverage_test.go new file mode 100644 index 0000000..83b4b59 --- /dev/null +++ b/classad/api_math_coverage_test.go @@ -0,0 +1,94 @@ +package classad + +import ( + "testing" + + "github.com/PelicanPlatform/classad/ast" +) + +func TestClassAdAPIAndMathEvaluation(t *testing.T) { + ad := New() + ad.InsertAttr("i", 2) + ad.InsertAttrFloat("f", 2.5) + ad.InsertAttrString("s", "hi") + ad.InsertAttrBool("b", true) + InsertAttrList(ad, "lst", []int64{1, 2, 3}) + child := New() + child.InsertAttr("c", 5) + ad.InsertAttrClassAd("child", child) + + if ad.Size() != 6 { + t.Fatalf("expected 6 attributes, got %d", ad.Size()) + } + + if v := GetOr[int64](ad, "missing", 9); v != 9 { + t.Fatalf("GetOr default expected 9, got %d", v) + } + if v := GetOr[int64](ad, "i", 0); v != 2 { + t.Fatalf("GetOr existing expected 2, got %d", v) + } + + if _, ok := ad.Lookup("s"); !ok { + t.Fatalf("expected to find attribute s") + } + if str, ok := ad.EvaluateAttrString("s"); !ok || str != "hi" { + t.Fatalf("expected EvaluateAttrString hi, got %q ok=%v", str, ok) + } + if b, ok := ad.EvaluateAttrBool("b"); !ok || !b { + t.Fatalf("expected EvaluateAttrBool true, got %v ok=%v", b, ok) + } + if num, ok := ad.EvaluateAttrNumber("f"); !ok || num != 2.5 { + t.Fatalf("expected EvaluateAttrNumber 2.5, got %v ok=%v", num, ok) + } + + expr := &ast.BinaryOp{ + Op: "+", + Left: &ast.AttributeReference{Name: "i"}, + Right: &ast.AttributeReference{Name: "f"}, + } + res := ad.EvaluateExpr(expr) + if !res.IsReal() { + t.Fatalf("expected real result from i+f") + } + if val, _ := res.RealValue(); val != 4.5 { + t.Fatalf("expected 4.5, got %f", val) + } + + expr2, err := ParseExpr("MY.i + TARGET.j + round(MY.f)") + if err != nil { + t.Fatalf("ParseExpr failed: %v", err) + } + target := New() + target.InsertAttr("j", 3) + res2 := ad.EvaluateExprWithTarget(expr2, target) + if !res2.IsInteger() { + t.Fatalf("expected integer from EvaluateExprWithTarget") 
+ } + if val, _ := res2.IntValue(); val != 8 { + t.Fatalf("expected 8 from MY/TARGET math, got %d", val) + } + + bin := ad.evaluateBinaryOp("-", NewIntValue(10), NewRealValue(3.5)) + if !bin.IsReal() { + t.Fatalf("binary op should yield real") + } + if val, _ := bin.RealValue(); val != 6.5 { + t.Fatalf("expected 6.5, got %f", val) + } + + unary := ad.evaluateUnaryOp("-", NewRealValue(1.25)) + if val, _ := unary.RealValue(); val != -1.25 { + t.Fatalf("expected -1.25 from unary op, got %f", val) + } + + if !ad.Delete("s") { + t.Fatalf("expected to delete attribute s") + } + ad.Clear() + if ad.Size() != 0 { + t.Fatalf("expected clear to remove all attributes") + } + if len(ad.GetAttributes()) != 0 { + t.Fatalf("expected no attribute names after clear") + } +} diff --git a/classad/classad.go b/classad/classad.go index 3859a3d..ca04992 100644 --- a/classad/classad.go +++ b/classad/classad.go @@ -1056,6 +1056,20 @@ func (c *ClassAd) valueToExpr(val Value) ast.Expr { if boolVal, err := val.BoolValue(); err == nil { return &ast.BooleanLiteral{Value: boolVal} } + case ListValue: + list, err := val.ListValue() + if err == nil { + elements := make([]ast.Expr, 0, len(list)) + for _, item := range list { + elements = append(elements, c.valueToExpr(item)) + } + return &ast.ListLiteral{Elements: elements} + } + case ClassAdValue: + adVal, err := val.ClassAdValue() + if err == nil && adVal != nil { + return &ast.RecordLiteral{ClassAd: adVal.ad} + } case UndefinedValue: return &ast.UndefinedLiteral{} case ErrorValue: diff --git a/classad/comparison_test.go b/classad/comparison_test.go new file mode 100644 index 0000000..ac11454 --- /dev/null +++ b/classad/comparison_test.go @@ -0,0 +1,78 @@ +package classad + +import "testing" + +func TestEvaluatorComparisonOperators(t *testing.T) { + cases := []struct { + expr string + expectErr bool + expectVal bool + }{ + {"3 > 2", false, true}, + {"2 <= 2", false, true}, + {"1 >= 2", false, false}, + {"\"b\" <= \"a\"", false, false}, + {"\"b\" 
>= \"a\"", false, true}, + {"true > false", true, false}, + {"1 >= \"a\"", true, false}, + {"\"c\" > \"b\"", false, true}, + } + + for _, tc := range cases { + val := evalBuiltin(t, tc.expr) + if tc.expectErr { + if !val.IsError() { + t.Fatalf("expected error for %s", tc.expr) + } + continue + } + + if val.IsError() || val.IsUndefined() { + t.Fatalf("unexpected non-boolean result for %s: %v", tc.expr, val.Type()) + } + b, _ := val.BoolValue() + if b != tc.expectVal { + t.Fatalf("unexpected value for %s: got %v want %v", tc.expr, b, tc.expectVal) + } + } +} + +func TestEvaluatorEqualityUndefined(t *testing.T) { + val := evalBuiltin(t, `1 == undefined`) + if !val.IsUndefined() { + t.Fatalf("expected undefined result when comparing to undefined, got %v", val.Type()) + } +} + +func TestEvaluatorEqualityNonScalar(t *testing.T) { + val := evalBuiltin(t, `{1,2} == {1,2}`) + if !val.IsError() { + t.Fatalf("expected error when comparing non-scalar values, got %v", val.Type()) + } +} + +func TestEvaluatorEqualityVariants(t *testing.T) { + if v := evalBuiltin(t, `undefined == undefined`); !v.IsBool() { + t.Fatalf("expected bool for undefined==undefined") + } else if b, _ := v.BoolValue(); !b { + t.Fatalf("expected undefined==undefined to be true") + } + + if v := evalBuiltin(t, `true == true`); !v.IsBool() { + t.Fatalf("expected bool for bool equality") + } else if b, _ := v.BoolValue(); !b { + t.Fatalf("expected true==true") + } + + if v := evalBuiltin(t, `1 == 1.0000000001`); !v.IsBool() { + t.Fatalf("expected bool for numeric near-equality") + } else if b, _ := v.BoolValue(); !b { + t.Fatalf("expected near-equal numeric values to compare true") + } + + if v := evalBuiltin(t, `"a" == 1`); !v.IsBool() { + t.Fatalf("expected bool for type-mismatch equality") + } else if b, _ := v.BoolValue(); b { + t.Fatalf("expected string vs int equality to be false") + } +} diff --git a/classad/coverage_misc_test.go b/classad/coverage_misc_test.go new file mode 100644 index 
0000000..e87da65
--- /dev/null
+++ b/classad/coverage_misc_test.go
@@ -0,0 +1,75 @@
+package classad
+
+import (
+	"testing"
+
+	"github.com/PelicanPlatform/classad/ast"
+)
+
+// TestClassAdStringEmpty checks that a zero-value ClassAd renders as "[]".
+func TestClassAdStringEmpty(t *testing.T) {
+	var ad ClassAd
+	if ad.String() != "[]" {
+		t.Fatalf("expected empty ClassAd string to be [] but got %q", ad.String())
+	}
+}
+
+// TestExprStringUndefined checks that a zero-value Expr renders as "undefined".
+func TestExprStringUndefined(t *testing.T) {
+	var expr Expr
+	if expr.String() != "undefined" {
+		t.Fatalf("expected undefined Expr string but got %q", expr.String())
+	}
+}
+
+// TestInsertExprNilNoop checks that InsertExpr with a nil Expr leaves the ad empty.
+func TestInsertExprNilNoop(t *testing.T) {
+	ad := &ClassAd{}
+	ad.InsertExpr("ignored", nil)
+	if ad.Size() != 0 {
+		t.Fatalf("expected insert of nil Expr to be ignored; size=%d", ad.Size())
+	}
+}
+
+// TestEvaluateExprWithTargetNil checks that evaluating a nil expression yields undefined.
+func TestEvaluateExprWithTargetNil(t *testing.T) {
+	ad := &ClassAd{}
+	result := ad.EvaluateExprWithTarget(nil, &ClassAd{})
+	if !result.IsUndefined() {
+		t.Fatalf("expected undefined when evaluating nil expr, got %v", result)
+	}
+}
+
+// TestFlattenBinaryWithMissingRef checks that flattening "A + B" with only A defined
+// keeps the binary op, with A replaced by its literal value and B left as a reference.
+func TestFlattenBinaryWithMissingRef(t *testing.T) {
+	ad := &ClassAd{}
+	ad.InsertAttr("A", 2)
+
+	expr, err := ParseExpr("A + B")
+	if err != nil {
+		t.Fatalf("failed to parse expression: %v", err)
+	}
+
+	flattened := ad.Flatten(expr)
+	if flattened == nil {
+		t.Fatalf("expected flattened expression, got nil")
+	}
+
+	bin, ok := flattened.internal().(*ast.BinaryOp)
+	if !ok {
+		t.Fatalf("expected binary op, got %T", flattened.internal())
+	}
+
+	left, ok := bin.Left.(*ast.IntegerLiteral)
+	if !ok || left.Value != 2 {
+		t.Fatalf("expected left literal 2, got %T with value %v", bin.Left, bin.Left)
+	}
+
+	right, ok := bin.Right.(*ast.AttributeReference)
+	if !ok || right.Name != "B" {
+		// right is nil when the assertion fails, so report bin.Right instead of
+		// dereferencing right.Name (which would panic and hide the real failure).
+		t.Fatalf("expected right attribute reference B, got %T with value %v", bin.Right, bin.Right)
+	}
+}
diff --git a/classad/eval_with_target_test.go b/classad/eval_with_target_test.go
new file mode 100644
index 0000000..5000c9c
--- /dev/null
+++ b/classad/eval_with_target_test.go
@@ -0,0 +1,23 @@
+package classad
+
+import "testing"
+
+func TestEvaluateExprWithTargetCustom(t *testing.T) { + scope := New() + scope.InsertAttr("A", 1) + target := New() + target.InsertAttr("B", 5) + + expr, err := ParseExpr("A + TARGET.B") + if err != nil { + t.Fatalf("parse expr failed: %v", err) + } + + val := scope.EvaluateExprWithTarget(expr, target) + if val.IsError() || val.IsUndefined() { + t.Fatalf("expected defined value, got %v", val.Type()) + } + if num, _ := val.IntValue(); num != 6 { + t.Fatalf("unexpected evaluation result: %d", num) + } +} diff --git a/classad/flatten_test.go b/classad/flatten_test.go new file mode 100644 index 0000000..ba55594 --- /dev/null +++ b/classad/flatten_test.go @@ -0,0 +1,115 @@ +package classad + +import ( + "strings" + "testing" +) + +func TestFlattenPartialEvaluation(t *testing.T) { + ad, err := Parse(`[A = 1; B = 2]`) + if err != nil { + t.Fatalf("failed to parse base ad: %v", err) + } + + expr, err := ParseExpr("A + B + C") + if err != nil { + t.Fatalf("parse expr failed: %v", err) + } + + flat := ad.Flatten(expr) + if !strings.Contains(flat.String(), "3") { + t.Fatalf("expected flattened expression to include computed constant, got %s", flat.String()) + } + + ad.InsertAttr("C", 4) + if val := flat.Eval(ad); !val.IsInteger() { + t.Fatalf("flattened eval not integer: %v", val.Type()) + } else if i, _ := val.IntValue(); i != 7 { + t.Fatalf("unexpected flattened eval result: %d", i) + } +} + +func TestFlattenUnaryPreservesUndefined(t *testing.T) { + ad, err := Parse(`[X = 5]`) + if err != nil { + t.Fatalf("parse ad failed: %v", err) + } + + expr, err := ParseExpr("-(UndefinedAttr) + X") + if err != nil { + t.Fatalf("parse expr failed: %v", err) + } + + flat := ad.Flatten(expr) + if strings.Contains(flat.String(), "-UndefinedAttr") == false { + t.Fatalf("expected undefined attribute to remain in expression, got %s", flat.String()) + } +} + +func TestFlattenListValue(t *testing.T) { + ad := New() + InsertAttrList(ad, "List", []int64{1, 2, 3}) + expr, err := ParseExpr("List") 
+ if err != nil { + t.Fatalf("parse expr failed: %v", err) + } + + flat := ad.Flatten(expr) + if flat.String() == "undefined" { + t.Fatalf("expected list flattening to produce list literal, got %s", flat.String()) + } + if !strings.Contains(flat.String(), "{1, 2, 3}") { + t.Fatalf("expected flattened list literal content, got %s", flat.String()) + } +} + +func TestFlattenBoolValue(t *testing.T) { + ad := New() + ad.InsertAttrBool("Flag", true) + + expr, err := ParseExpr("Flag") + if err != nil { + t.Fatalf("parse expr failed: %v", err) + } + + flat := ad.Flatten(expr) + if flat.String() != "true" { + t.Fatalf("expected boolean literal after flatten, got %s", flat.String()) + } +} + +func TestFlattenClassAdValue(t *testing.T) { + inner := New() + inner.InsertAttr("X", 1) + + outer := New() + outer.InsertAttrClassAd("Nested", inner) + + expr, err := ParseExpr("Nested") + if err != nil { + t.Fatalf("parse expr failed: %v", err) + } + + flat := outer.Flatten(expr) + if flat.String() == "undefined" { + t.Fatalf("expected flatten to yield record literal, got %s", flat.String()) + } + if !strings.Contains(flat.String(), "X = 1") { + t.Fatalf("expected flattened record content, got %s", flat.String()) + } +} + +func TestFlattenPreservesUnknownReference(t *testing.T) { + ad := New() + ad.InsertAttr("Known", 2) + + expr, err := ParseExpr("Unknown + 1") + if err != nil { + t.Fatalf("parse expr failed: %v", err) + } + + flat := ad.Flatten(expr) + if !strings.Contains(flat.String(), "Unknown + 1") { + t.Fatalf("expected unknown reference to remain, got %s", flat.String()) + } +} diff --git a/classad/functions_string_test.go b/classad/functions_string_test.go new file mode 100644 index 0000000..0520c24 --- /dev/null +++ b/classad/functions_string_test.go @@ -0,0 +1,305 @@ +package classad + +import "testing" + +// evalBuiltin is a helper to evaluate a raw ClassAd expression. +func evalBuiltin(t *testing.T, expr string) Value { + // Evaluate in an empty ClassAd scope. 
+ e, err := ParseExpr(expr) + if err != nil { + t.Fatalf("parse failed for %s: %v", expr, err) + } + return e.Eval(New()) +} + +func TestJoinVariants(t *testing.T) { + if val := evalBuiltin(t, "join()"); !val.IsError() { + t.Fatalf("expected error for empty join, got %v", val.Type()) + } + + noSep := evalBuiltin(t, `join({"a", undefined, 2, true, 1.5})`) + if s, _ := noSep.StringValue(); s != "a2true1.5" { + t.Fatalf("unexpected join(list) result: %q", s) + } + + listForm := evalBuiltin(t, `join(",", {"a", "b", "c"})`) + if s, _ := listForm.StringValue(); s != "a,b,c" { + t.Fatalf("unexpected join(sep,list) result: %q", s) + } + + variadic := evalBuiltin(t, `join("-", "x", 3, false)`) + if s, _ := variadic.StringValue(); s != "x-3-false" { + t.Fatalf("unexpected join variadic result: %q", s) + } + + if val := evalBuiltin(t, `join(123, "a")`); !val.IsError() { + t.Fatalf("expected error for non-string separator, got %v", val.Type()) + } +} + +func TestSplitAndSlots(t *testing.T) { + ws := evalBuiltin(t, `split("a b c")`) + fields, _ := ws.ListValue() + if len(fields) != 3 { + t.Fatalf("expected 3 fields, got %d", len(fields)) + } + + custom := evalBuiltin(t, `split("a,b;c", ",;")`) + parts, _ := custom.ListValue() + if len(parts) != 3 { + t.Fatalf("expected 3 fields with custom delimiter, got %d", len(parts)) + } + + if val := evalBuiltin(t, `split(123)`); !val.IsError() { + t.Fatalf("expected error for non-string split input") + } + + slot := evalBuiltin(t, `splitSlotName("slot1@machine")`) + slotParts, _ := slot.ListValue() + if len(slotParts) != 2 { + t.Fatalf("expected 2 slot parts, got %d", len(slotParts)) + } + left, _ := slotParts[0].StringValue() + right, _ := slotParts[1].StringValue() + if left != "slot1" || right != "machine" { + t.Fatalf("unexpected splitSlotName parts: %q, %q", left, right) + } + + noAt := evalBuiltin(t, `splitSlotName("machine")`) + noAtParts, _ := noAt.ListValue() + first, _ := noAtParts[0].StringValue() + second, _ := 
noAtParts[1].StringValue() + if first != "" || second != "machine" { + t.Fatalf("unexpected splitSlotName without @: %q, %q", first, second) + } + + user := evalBuiltin(t, `splitUserName("alice@example.com")`) + userParts, _ := user.ListValue() + if u, _ := userParts[0].StringValue(); u != "alice" { + t.Fatalf("unexpected username part: %q", u) + } + if dom, _ := userParts[1].StringValue(); dom != "example.com" { + t.Fatalf("unexpected domain part: %q", dom) + } +} + +func TestStringComparisons(t *testing.T) { + cmp := evalBuiltin(t, `strcmp("apple", "banana")`) + if v, _ := cmp.IntValue(); v >= 0 { + t.Fatalf("expected apple < banana, got %d", v) + } + + ci := evalBuiltin(t, `stricmp("Case", "case")`) + if v, _ := ci.IntValue(); v != 0 { + t.Fatalf("expected case-insensitive match, got %d", v) + } + + numeric := evalBuiltin(t, `strcmp(10, 2)`) + if v, _ := numeric.IntValue(); v >= 0 { + t.Fatalf("expected string compare 10 vs 2 to be negative, got %d", v) + } + + if errVal := evalBuiltin(t, `strcmp({1}, "x")`); !errVal.IsError() { + t.Fatalf("expected error for invalid strcmp types") + } +} + +func TestVersionComparisons(t *testing.T) { + vc := evalBuiltin(t, `versioncmp("1.2", "1.10")`) + if v, _ := vc.IntValue(); v >= 0 { + t.Fatalf("expected 1.2 < 1.10, got %d", v) + } + + if gt := evalBuiltin(t, `version_gt("8.9.1", "8.8.9")`); !gt.IsBool() { + t.Fatalf("version_gt should return bool") + } else if b, _ := gt.BoolValue(); !b { + t.Fatalf("expected version_gt to be true") + } + + inRange := evalBuiltin(t, `version_in_range("8.8.0", "8.7.0", "8.9.0")`) + if ok, _ := inRange.BoolValue(); !ok { + t.Fatalf("expected version in range") + } + + undef := evalBuiltin(t, `versioncmp(undefined, "1")`) + if !undef.IsUndefined() { + t.Fatalf("expected undefined for versioncmp with undefined input") + } +} + +func TestIntervalAndIdentity(t *testing.T) { + ival := evalBuiltin(t, `interval(3661)`) + if s, _ := ival.StringValue(); s != "1:01:01" { + t.Fatalf("unexpected 
interval formatting: %q", s) + } + + ident := evalBuiltin(t, `identicalMember(1, {1, 1.0, "1"})`) + if b, _ := ident.BoolValue(); !b { + t.Fatalf("expected identicalMember to find matching int") + } + + undefMatch := evalBuiltin(t, `identicalMember(undefined, {1, undefined})`) + if b, _ := undefMatch.BoolValue(); !b { + t.Fatalf("expected identicalMember to match undefined") + } + + if errVal := evalBuiltin(t, `identicalMember({1}, {1})`); !errVal.IsError() { + t.Fatalf("expected error when first arg is list") + } +} + +func TestAnyAllCompare(t *testing.T) { + anyResult := evalBuiltin(t, `anyCompare("<", {1,2,3}, 2)`) + if b, _ := anyResult.BoolValue(); !b { + t.Fatalf("expected anyCompare to be true") + } + + all := evalBuiltin(t, `allCompare(">=", {2,2,3}, 2)`) + if b, _ := all.BoolValue(); !b { + t.Fatalf("expected allCompare to be true") + } + + nonList := evalBuiltin(t, `allCompare("==", 1, 1)`) + if !nonList.IsError() { + t.Fatalf("expected error for non-list input") + } +} + +func TestStringListBuiltins(t *testing.T) { + size := evalBuiltin(t, `stringListSize("a, b, ,c")`) + if v, _ := size.IntValue(); v != 3 { + t.Fatalf("unexpected stringListSize: %d", v) + } + + sum := evalBuiltin(t, `stringListSum("1,2.5,bad")`) + if sum.IsInteger() { + t.Fatalf("expected real sum due to decimal input") + } + if v, _ := sum.RealValue(); v != 3.5 { + t.Fatalf("unexpected stringListSum: %g", v) + } + + avg := evalBuiltin(t, `stringListAvg("")`) + if v, _ := avg.RealValue(); v != 0.0 { + t.Fatalf("expected zero average for empty list, got %g", v) + } + + minVal := evalBuiltin(t, `stringListMin("5,3,7")`) + if v, _ := minVal.IntValue(); v != 3 { + t.Fatalf("unexpected stringListMin: %d", v) + } + + maxVal := evalBuiltin(t, `stringListMax("1.5,2.5,2")`) + if v, _ := maxVal.RealValue(); v != 2.5 { + t.Fatalf("unexpected stringListMax: %g", v) + } + + inter := evalBuiltin(t, `stringListsIntersect("a,b,c", "x,b,z")`) + if b, _ := inter.BoolValue(); !b { + t.Fatalf("expected 
intersection to be true") + } + + subset := evalBuiltin(t, `stringListSubsetMatch("a,b", "a,b,c")`) + if b, _ := subset.BoolValue(); !b { + t.Fatalf("expected subset match to be true") + } + + regexMember := evalBuiltin(t, `stringListRegexpMember("a.*", "abc,def")`) + if b, _ := regexMember.BoolValue(); !b { + t.Fatalf("expected regex member to match") + } + + listRegex := evalBuiltin(t, `regexpMember("foo", {"bar", "foobar"})`) + if b, _ := listRegex.BoolValue(); !b { + t.Fatalf("expected regexpMember to match list element") + } +} + +func TestRegexAndReplace(t *testing.T) { + replaceFirst := evalBuiltin(t, `replace("ab", "xxabyyab", "Q")`) + if s, _ := replaceFirst.StringValue(); s != "xxQyyab" { + t.Fatalf("unexpected replace result: %q", s) + } + + replaceAll := evalBuiltin(t, `replaceAll("ab", "xxabyyab", "Q")`) + if s, _ := replaceAll.StringValue(); s != "xxQyyQ" { + t.Fatalf("unexpected replaceAll result: %q", s) + } + + regexpsVal := evalBuiltin(t, `regexps("[0-9]+", "a1b2c", "#")`) + if s, _ := regexpsVal.StringValue(); s != "a#b#c" { + t.Fatalf("unexpected regexps result: %q", s) + } + + if errVal := evalBuiltin(t, `replace("(", "a", "b")`); !errVal.IsError() { + t.Fatalf("expected error for invalid regex pattern") + } +} + +func TestMembershipFunctions(t *testing.T) { + member := evalBuiltin(t, `member(2, {1, 2, 3})`) + if b, _ := member.BoolValue(); !b { + t.Fatalf("expected member to find value") + } + + nonMember := evalBuiltin(t, `member("x", {"a", "b"})`) + if b, _ := nonMember.BoolValue(); b { + t.Fatalf("expected member to be false for missing value") + } + + memberErr := evalBuiltin(t, `member(1, 2)`) + if !memberErr.IsError() { + t.Fatalf("expected error when member second argument is not a list") + } + + memberUndef := evalBuiltin(t, `member(undefined, {1})`) + if !memberUndef.IsUndefined() { + t.Fatalf("expected undefined when member element is undefined") + } + + strList := evalBuiltin(t, `stringListMember("foo", "bar, foo, baz")`) + if b, _ 
:= strList.BoolValue(); !b { + t.Fatalf("expected stringListMember to match") + } + + strListIgnore := evalBuiltin(t, `stringListIMember("FOO", "bar, foo")`) + if b, _ := strListIgnore.BoolValue(); !b { + t.Fatalf("expected case-insensitive member to match") + } + + avg := evalBuiltin(t, `stringListAvg("1,2,3")`) + if v, _ := avg.RealValue(); v != 2 { + t.Fatalf("unexpected average value: %g", v) + } + avgDelim := evalBuiltin(t, `stringListAvg("1.0|2.0", "|")`) + if v, _ := avgDelim.RealValue(); v != 1.5 { + t.Fatalf("unexpected average with delimiter: %g", v) + } + avgErr := evalBuiltin(t, `stringListAvg(123)`) + if !avgErr.IsError() { + t.Fatalf("expected error for non-string average input") + } + avgUndef := evalBuiltin(t, `stringListAvg(undefined)`) + if !avgUndef.IsUndefined() { + t.Fatalf("expected undefined when average input is undefined") + } + + stricmpInt := evalBuiltin(t, `stricmp(10, 10)`) + if v, _ := stricmpInt.IntValue(); v != 0 { + t.Fatalf("expected stricmp to treat integers as equal, got %d", v) + } + + if errVal := evalBuiltin(t, `stricmp(undefined, "x")`); !errVal.IsError() { + t.Fatalf("expected error when stricmp receives undefined") + } + + regexMember := evalBuiltin(t, `regexpMember("foo", {"FOO"}, "i")`) + if b, _ := regexMember.BoolValue(); !b { + t.Fatalf("expected regexpMember to match with options") + } + + noMatch := evalBuiltin(t, `replace("zz", "abc", "Q")`) + if s, _ := noMatch.StringValue(); s != "abc" { + t.Fatalf("expected replace to return original when no match, got %q", s) + } +} diff --git a/classad/functions_time_regex_test.go b/classad/functions_time_regex_test.go new file mode 100644 index 0000000..2bf8541 --- /dev/null +++ b/classad/functions_time_regex_test.go @@ -0,0 +1,69 @@ +package classad + +import ( + "strings" + "testing" +) + +func TestFormatTimeConversion(t *testing.T) { + val := evalBuiltin(t, `formatTime(0, "%a %A %b %B %c %d %H %I %j %m %M %p %S %U %w %x %X %y %Y %Z %q")`) + if val.IsError() { + 
t.Fatalf("formatTime returned error") + } + + s, _ := val.StringValue() + for _, piece := range []string{"Thu", "Thursday", "Jan", "1970", "UTC", "%q"} { + if !strings.Contains(s, piece) { + t.Fatalf("expected formatted time to include %q, got %q", piece, s) + } + } +} + +func TestRegexpsOptionsAndErrors(t *testing.T) { + repl := evalBuiltin(t, `regexps("foo", "FOO foo", "bar", "i")`) + if s, _ := repl.StringValue(); s != "bar bar" { + t.Fatalf("unexpected case-insensitive regexps result: %q", s) + } + + if errVal := evalBuiltin(t, `regexps("(", "a", "b")`); !errVal.IsError() { + t.Fatalf("expected error for invalid regexps pattern") + } + + member := evalBuiltin(t, `stringListRegexpMember("^foo$", "foo|bar", "|", "i")`) + if b, _ := member.BoolValue(); !b { + t.Fatalf("expected stringListRegexpMember to match with delimiter and options") + } + + undef := evalBuiltin(t, `stringListRegexpMember("foo", undefined)`) + if !undef.IsUndefined() { + t.Fatalf("expected undefined when list argument is undefined") + } +} + +func TestStringListSetOperations(t *testing.T) { + inter := evalBuiltin(t, `stringListsIntersect("", "", ",")`) + if b, _ := inter.BoolValue(); b { + t.Fatalf("expected no intersection for empty string lists") + } + + subset := evalBuiltin(t, `stringListSubsetMatch("a|b", "a|b|c", "|")`) + if b, _ := subset.BoolValue(); !b { + t.Fatalf("expected subset match with custom delimiter") + } + + undef := evalBuiltin(t, `stringListsIntersect(undefined, "a")`) + if !undef.IsUndefined() { + t.Fatalf("expected undefined when first list is undefined") + } +} + +func TestStricmpErrorAndBoolCompare(t *testing.T) { + if val := evalBuiltin(t, `stricmp(true, "x")`); !val.IsError() { + t.Fatalf("expected error for invalid stricmp argument types") + } + + boolCompare := evalBuiltin(t, `anyCompare("==", {true, false}, true)`) + if b, _ := boolCompare.BoolValue(); !b { + t.Fatalf("expected anyCompare to match boolean value") + } +} diff --git a/classad/matchclassad_test.go 
b/classad/matchclassad_test.go new file mode 100644 index 0000000..36d681f --- /dev/null +++ b/classad/matchclassad_test.go @@ -0,0 +1,92 @@ +package classad + +import "testing" + +func TestMatchClassAdBasic(t *testing.T) { + left, err := Parse(`[Requirements = TARGET.Memory >= 1024; Rank = 5; Memory = 2048; Name = "job"]`) + if err != nil { + t.Fatalf("parse left failed: %v", err) + } + right, err := Parse(`[Requirements = TARGET.Memory >= 2048; Rank = 3.5; Memory = 1024; Name = "machine"]`) + if err != nil { + t.Fatalf("parse right failed: %v", err) + } + + match := NewMatchClassAd(left, right) + + if l := match.GetLeftAd(); l == nil { + t.Fatalf("left ad should be set") + } + if r := match.GetRightAd(); r == nil { + t.Fatalf("right ad should be set") + } + + if !match.Match() { + t.Fatalf("expected match to succeed when both requirements true") + } + + if rank, ok := match.EvaluateRankLeft(); !ok || rank != 5 { + t.Fatalf("unexpected left rank: %v (ok=%v)", rank, ok) + } + if rank, ok := match.EvaluateRankRight(); !ok || rank != 3.5 { + t.Fatalf("unexpected right rank: %v (ok=%v)", rank, ok) + } + + expr, err := ParseExpr(`Memory + TARGET.Memory`) + if err != nil { + t.Fatalf("parse expr failed: %v", err) + } + if val := match.EvaluateExprLeft(expr.internal()); !val.IsNumber() { + t.Fatalf("expected numeric expression result, got %v", val.Type()) + } else if num, _ := val.NumberValue(); num != 3072 { + t.Fatalf("unexpected expression result: %g", num) + } + + exprRight, err := ParseExpr(`TARGET.Memory - Memory`) + if err != nil { + t.Fatalf("parse expr right failed: %v", err) + } + if val := match.EvaluateExprRight(exprRight.internal()); !val.IsNumber() { + t.Fatalf("expected numeric result from right expr") + } else if num, _ := val.NumberValue(); num != 1024 { + t.Fatalf("unexpected right expr result: %g", num) + } + + // Replace right with a stricter ad that fails symmetry. 
+ newRight, err := Parse(`[Requirements = false; Rank = 0; Memory = 512]`) + if err != nil { + t.Fatalf("parse new right failed: %v", err) + } + match.ReplaceRightAd(newRight) + if match.Match() { + t.Fatalf("expected match to fail after replacing right") + } + + // Replace left with nil and ensure evaluations become undefined. + match.ReplaceLeftAd(nil) + if val := match.EvaluateAttrLeft("Memory"); !val.IsUndefined() { + t.Fatalf("expected undefined when left ad is nil") + } + if val := match.EvaluateAttrRight("Memory"); val.IsUndefined() { + t.Fatalf("right ad should remain accessible") + } +} + +func TestMatchClassAdRankFallback(t *testing.T) { + left, err := Parse(`[Requirements = true; Rank = "high"]`) + if err != nil { + t.Fatalf("parse left failed: %v", err) + } + right, err := Parse(`[Requirements = true; Rank = 1]`) + if err != nil { + t.Fatalf("parse right failed: %v", err) + } + + match := NewMatchClassAd(left, right) + if _, ok := match.EvaluateRankLeft(); ok { + t.Fatalf("expected non-numeric left rank to return ok=false") + } + if rank, ok := match.EvaluateRankRight(); !ok || rank != 1 { + t.Fatalf("expected numeric right rank, got %v (ok=%v)", rank, ok) + } +} diff --git a/classad/math_eval_test.go b/classad/math_eval_test.go new file mode 100644 index 0000000..baa99f9 --- /dev/null +++ b/classad/math_eval_test.go @@ -0,0 +1,125 @@ +package classad + +import ( + "fmt" + "testing" +) + +func evalExpr(t *testing.T, expr string) Value { + t.Helper() + ad := New() + wrapped := fmt.Sprintf("[__tmp__ = %s]", expr) + val, err := ad.EvaluateExprString(wrapped) + if err != nil { + t.Fatalf("parse/eval failed for %q: %v", expr, err) + } + return val +} + +func TestEvaluateExprMathFunctions(t *testing.T) { + floorVal := evalExpr(t, "floor(3.7)") + if !floorVal.IsInteger() { + t.Fatalf("floor result not integer: %v", floorVal.Type()) + } + if got, _ := floorVal.IntValue(); got != 3 { + t.Fatalf("floor(3.7) = %d, want 3", got) + } + + ceilVal := evalExpr(t, 
"ceiling(3.2)") + if got, _ := ceilVal.IntValue(); got != 4 { + t.Fatalf("ceiling(3.2) = %d, want 4", got) + } + + roundVal := evalExpr(t, "round(3.5)") + if got, _ := roundVal.IntValue(); got != 4 { + t.Fatalf("round(3.5) = %d, want 4", got) + } + + powVal := evalExpr(t, "pow(2, -2)") + if !powVal.IsReal() { + t.Fatalf("pow negative exponent should be real") + } + if got, _ := powVal.RealValue(); got != 0.25 { + t.Fatalf("pow(2,-2) = %f, want 0.25", got) + } + + quantizeVal := evalExpr(t, "quantize(12, {5, 10, 15})") + if got, _ := quantizeVal.IntValue(); got != 15 { + t.Fatalf("quantize(12,{5,10,15}) = %d, want 15", got) + } + + sumVal := evalExpr(t, "sum({1, 2, 3.5})") + if !sumVal.IsReal() { + t.Fatalf("sum mixed numeric should be real") + } + if got, _ := sumVal.RealValue(); got != 6.5 { + t.Fatalf("sum({1,2,3.5}) = %f, want 6.5", got) + } +} + +func TestEvaluateExprMathErrorPaths(t *testing.T) { + errVal := evalExpr(t, "floor(\"x\")") + if !errVal.IsError() { + t.Fatalf("expected floor on string to be error, got %v", errVal.Type()) + } + + undefVal := evalExpr(t, "floor(undefined)") + if !undefVal.IsUndefined() { + t.Fatalf("expected undefined from floor(undefined), got %v", undefVal.Type()) + } + + powErr := evalExpr(t, "pow(\"x\", 2)") + if !powErr.IsError() { + t.Fatalf("expected error from pow with string base") + } + + quantizeErr := evalExpr(t, "quantize(\"x\", 3)") + if !quantizeErr.IsError() { + t.Fatalf("expected error from quantize non-numeric input") + } + + quantizeUndef := evalExpr(t, "quantize(12, {undefined})") + if !quantizeUndef.IsUndefined() { + t.Fatalf("expected undefined from quantize with all undefined list entries") + } +} + +func TestEvaluateExprUnaryAndDivide(t *testing.T) { + plusVal := evalExpr(t, "+2.5") + if !plusVal.IsReal() { + t.Fatalf("expected real from unary plus") + } + if got, _ := plusVal.RealValue(); got != 2.5 { + t.Fatalf("+2.5 = %f, want 2.5", got) + } + + minusVal := evalExpr(t, "-(4)") + if got, _ := 
minusVal.IntValue(); got != -4 { + t.Fatalf("-(4) = %d, want -4", got) + } + + notErr := evalExpr(t, "!1") + if !notErr.IsError() { + t.Fatalf("expected error from logical not on non-bool") + } + + plErr := evalExpr(t, "+\"str\"") + if !plErr.IsError() { + t.Fatalf("expected error from unary plus on string") + } + + boolNot := evalExpr(t, "!false") + if got, _ := boolNot.BoolValue(); !got { + t.Fatalf("!false expected true, got %v", got) + } + + cmpVal := evalExpr(t, "\"b\" < \"c\"") + if got, _ := cmpVal.BoolValue(); !got { + t.Fatalf("expected string comparison to be true") + } + + divZero := evalExpr(t, "10 / 0") + if !divZero.IsError() { + t.Fatalf("division by zero should be error") + } +} diff --git a/classad/reader.go b/classad/reader.go index 04b56f1..f935847 100644 --- a/classad/reader.go +++ b/classad/reader.go @@ -5,11 +5,16 @@ import ( "io" "iter" "strings" + + "github.com/PelicanPlatform/classad/parser" ) // Reader provides an iterator for parsing multiple ClassAds from an io.Reader. // It supports both new-style (bracketed) and old-style (newline-delimited) formats. +// New-style ClassAds can be concatenated without delimiters or whitespace. type Reader struct { + reader *bufio.Reader + parser *parser.ReaderParser scanner *bufio.Scanner oldStyle bool err error @@ -17,14 +22,16 @@ type Reader struct { } // NewReader creates a new Reader for parsing new-style ClassAds (with brackets). -// Each ClassAd should be on its own, separated by whitespace or comments. +// ClassAds may be concatenated directly; whitespace and comments are optional. 
// Example format: // // [Foo = 1; Bar = 2] // [Baz = 3; Qux = 4] func NewReader(r io.Reader) *Reader { + br := bufio.NewReader(r) return &Reader{ - scanner: bufio.NewScanner(r), + reader: br, + parser: parser.NewReaderParser(br), oldStyle: false, } } @@ -61,69 +68,17 @@ func (r *Reader) Next() bool { // nextNew reads the next new-style ClassAd (with brackets) func (r *Reader) nextNew() bool { - var lines []string - inClassAd := false - bracketDepth := 0 - - for r.scanner.Scan() { - line := strings.TrimSpace(r.scanner.Text()) - - // Skip empty lines and comments outside of ClassAds - if !inClassAd && (line == "" || strings.HasPrefix(line, "//") || strings.HasPrefix(line, "/*")) { - continue - } - - // Check if this line starts a ClassAd - if !inClassAd && strings.HasPrefix(line, "[") { - inClassAd = true - } - - if inClassAd { - lines = append(lines, line) - - // Count brackets to handle nested ClassAds - for _, ch := range line { - switch ch { - case '[': - bracketDepth++ - case ']': - bracketDepth-- - } - } - - // If we've closed all brackets, we have a complete ClassAd - if bracketDepth == 0 { - classAdStr := strings.Join(lines, "\n") - ad, err := Parse(classAdStr) - if err != nil { - r.err = err - return false - } - r.current = ad - return true - } + ad, err := r.parser.ParseClassAd() + if err != nil { + if err == io.EOF { + return false } - } - - // Check for scanner errors - if err := r.scanner.Err(); err != nil { r.err = err return false } - // If we have accumulated lines but hit EOF, try to parse them - if len(lines) > 0 { - classAdStr := strings.Join(lines, "\n") - ad, err := Parse(classAdStr) - if err != nil { - r.err = err - return false - } - r.current = ad - return true - } - - return false + r.current = &ClassAd{ad: ad} + return true } // nextOld reads the next old-style ClassAd (newline-delimited, separated by blank lines) diff --git a/classad/reader_test.go b/classad/reader_test.go index 3e0ff48..f9e6a2d 100644 --- a/classad/reader_test.go +++ 
b/classad/reader_test.go @@ -82,6 +82,53 @@ func TestNewReader_MultipleClassAds(t *testing.T) { } } +func TestNewReader_ConcatenatedClassAds(t *testing.T) { + input := `[ID = 1][ID = 2][ID = 3]` + reader := NewReader(strings.NewReader(input)) + + ids := []int64{} + for reader.Next() { + id, ok := reader.ClassAd().EvaluateAttrInt("ID") + if !ok { + t.Fatalf("expected ID attribute") + } + ids = append(ids, id) + } + + if err := reader.Err(); err != nil { + t.Fatalf("Unexpected error: %v", err) + } + + expected := []int64{1, 2, 3} + if len(ids) != len(expected) { + t.Fatalf("Expected %d ClassAds, got %d", len(expected), len(ids)) + } + for i, v := range expected { + if ids[i] != v { + t.Fatalf("Expected ID=%d at position %d, got %d", v, i, ids[i]) + } + } +} + +func TestNewReader_ConcatenatedWithComments(t *testing.T) { + input := `[ID = 1]/*block*/[ID = 2]// trailing comment +[ID = 3]` + reader := NewReader(strings.NewReader(input)) + + count := 0 + for reader.Next() { + count++ + } + + if err := reader.Err(); err != nil { + t.Fatalf("Unexpected error: %v", err) + } + + if count != 3 { + t.Fatalf("Expected 3 ClassAds, got %d", count) + } +} + func TestNewReader_WithComments(t *testing.T) { input := ` // This is a comment @@ -105,27 +152,6 @@ func TestNewReader_WithComments(t *testing.T) { } } -func TestNewReader_NestedClassAds(t *testing.T) { - input := `[Outer = [Inner = 42]; Value = 10]` - reader := NewReader(strings.NewReader(input)) - - if !reader.Next() { - t.Fatalf("Expected ClassAd, got error: %v", reader.Err()) - } - - ad := reader.ClassAd() - outerVal := ad.EvaluateAttr("Outer") - if !outerVal.IsClassAd() { - t.Error("Expected Outer to be a ClassAd") - } - - innerAd, _ := outerVal.ClassAdValue() - inner, _ := innerAd.EvaluateAttrInt("Inner") - if inner != 42 { - t.Errorf("Expected Inner=42, got %v", inner) - } -} - func TestNewReader_MultilineClassAd(t *testing.T) { input := `[ Foo = 1; @@ -157,6 +183,32 @@ func TestNewReader_EmptyInput(t *testing.T) { } } 
+func TestNewReader_ErrorAfterFirstAd(t *testing.T) { + input := `[Ok = 1][Broken = ]` + reader := NewReader(strings.NewReader(input)) + + if !reader.Next() { + t.Fatalf("expected first ClassAd, got error: %v", reader.Err()) + } + + ok, _ := reader.ClassAd().EvaluateAttrInt("Ok") + if ok != 1 { + t.Fatalf("expected Ok=1, got %d", ok) + } + + if reader.Next() { + t.Fatalf("expected failure on second ClassAd") + } + + if reader.Err() == nil { + t.Fatalf("expected error after malformed second ClassAd") + } + + if reader.Next() { + t.Fatalf("expected no further ClassAds after error") + } +} + func TestNewReader_InvalidClassAd(t *testing.T) { input := `[Foo = ]` // Invalid syntax reader := NewReader(strings.NewReader(input)) @@ -503,6 +555,83 @@ func TestAllWithIndex(t *testing.T) { } } +// TestRangeIterator demonstrates the Go 1.23+ range-over-func syntax for All. +func TestRangeIterator(t *testing.T) { + input := `[ID = 1] +[ID = 2] +[ID = 3]` + + ids := []int64{} + + for ad := range All(strings.NewReader(input)) { + id, ok := ad.EvaluateAttrInt("ID") + if !ok { + t.Fatalf("expected ID attribute") + } + ids = append(ids, id) + } + + expected := []int64{1, 2, 3} + if len(ids) != len(expected) { + t.Fatalf("expected %d ids, got %d", len(expected), len(ids)) + } + for i, v := range expected { + if ids[i] != v { + t.Fatalf("expected ID=%d at index %d, got %d", v, i, ids[i]) + } + } +} + +func TestAllEarlyStop(t *testing.T) { + input := `[V = 1][V = 2][V = 3]` + count := 0 + + All(strings.NewReader(input))(func(ad *ClassAd) bool { + count++ + return false + }) + + if count != 1 { + t.Fatalf("expected to stop after first element, got %d", count) + } +} + +func TestAllWithIndexEarlyStop(t *testing.T) { + input := `[V = 1][V = 2][V = 3]` + count := 0 + + AllWithIndex(strings.NewReader(input))(func(i int, ad *ClassAd) bool { + if i != 0 { + t.Fatalf("expected first index 0, got %d", i) + } + count++ + return false + }) + + if count != 1 { + t.Fatalf("expected to stop after 
first element, got %d", count) + } +} + +func TestAllWithErrorEarlyStop(t *testing.T) { + input := `[V = 1][V = 2][V = 3]` + count := 0 + var err error + + AllWithError(strings.NewReader(input), &err)(func(ad *ClassAd) bool { + count++ + return false + }) + + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if count != 1 { + t.Fatalf("expected to stop after first element, got %d", count) + } +} + // TestAllWithError tests error handling in iterator func TestAllWithError(t *testing.T) { // Valid input @@ -566,3 +695,60 @@ Name = "third"` t.Errorf("Expected 3 ClassAds, got %d", count) } } + +func TestAllOldWithError(t *testing.T) { + input := `A = 1 + +B = 2` + var err error + count := 0 + + for ad := range AllOldWithError(strings.NewReader(input), &err) { + if _, ok := ad.EvaluateAttrInt("A"); ok { + count++ + } + } + + if err != nil { + t.Fatalf("unexpected error on valid input: %v", err) + } + if count != 1 { + t.Fatalf("expected 1 ClassAd from valid input, got %d", count) + } + + broken := `Good = 1 + +Broken =` + count = 0 + err = nil + + for range AllOldWithError(strings.NewReader(broken), &err) { + count++ + } + + if err == nil { + t.Fatalf("expected parse error for broken input") + } + if count != 1 { + t.Fatalf("expected one yielded ad before error, got %d", count) + } + + earliestop := `X = 1 + +Y = 2` + err = nil + count = 0 + + for ad := range AllOldWithError(strings.NewReader(earliestop), &err) { + count++ + _ = ad + break // stop early to ensure no error is recorded + } + + if err != nil { + t.Fatalf("did not expect error when stopping early: %v", err) + } + if count != 1 { + t.Fatalf("expected single iteration before early stop, got %d", count) + } +} diff --git a/classad/refs_test.go b/classad/refs_test.go new file mode 100644 index 0000000..a1dda79 --- /dev/null +++ b/classad/refs_test.go @@ -0,0 +1,30 @@ +package classad + +import "testing" + +func TestCollectRefsCoversCompositeNodes(t *testing.T) { + expr, err := ParseExpr(`({AttrA, 
AttrB}[1] + (-AttrC)) + (AttrD ? AttrE : AttrF) + (AttrG ?: AttrH) + strcat(AttrI, [Inner = AttrJ].Inner)`) + if err != nil { + t.Fatalf("failed to parse expression: %v", err) + } + + ad := New() + ad.InsertAttr("AttrA", 1) + ad.InsertAttr("AttrD", 1) + + external := ad.ExternalRefs(expr) + expected := []string{"AttrB", "AttrC", "AttrE", "AttrF", "AttrG", "AttrH", "AttrI"} + + for _, ref := range expected { + found := false + for _, got := range external { + if got == ref { + found = true + break + } + } + if !found { + t.Fatalf("expected external reference %q not found in %v", ref, external) + } + } +} diff --git a/classad/struct_test.go b/classad/struct_test.go index 5e30cab..e4a89f0 100644 --- a/classad/struct_test.go +++ b/classad/struct_test.go @@ -1,6 +1,7 @@ package classad import ( + "strings" "testing" ) @@ -153,6 +154,36 @@ func TestMarshal_WithOmitEmpty(t *testing.T) { } } +func TestMarshal_OmitEmptyVariants(t *testing.T) { + ptrVal := 5 + type Example struct { + S string `classad:"s,omitempty"` + I int `classad:"i,omitempty"` + B bool `classad:"b,omitempty"` + L []int `classad:"l,omitempty"` + P *int `classad:"p,omitempty"` + F float64 `classad:"f,omitempty"` + } + + value := Example{I: 7, P: &ptrVal} + ad, err := Marshal(value) + if err != nil { + t.Fatalf("unexpected marshal error: %v", err) + } + + if !strings.Contains(ad, "i = 7") { + t.Fatalf("expected integer field to be marshaled, got %s", ad) + } + if !strings.Contains(ad, "p = 5") { + t.Fatalf("expected pointer field to be marshaled, got %s", ad) + } + for _, unwanted := range []string{"s =", "b =", "l =", "f ="} { + if strings.Contains(ad, unwanted) { + t.Fatalf("unexpected omitted field %q present in %s", unwanted, ad) + } + } +} + func TestMarshal_SkipField(t *testing.T) { type Job struct { ID int diff --git a/classad/unparse_extra_test.go b/classad/unparse_extra_test.go new file mode 100644 index 0000000..fd29358 --- /dev/null +++ b/classad/unparse_extra_test.go @@ -0,0 +1,15 @@ +package 
classad + +import "testing" + +func TestUnparseInvalidArgument(t *testing.T) { + ad, err := Parse(`[X = 1; Y = unparse(1)]`) + if err != nil { + t.Fatalf("parse failed: %v", err) + } + + val := ad.EvaluateAttr("Y") + if !val.IsError() { + t.Fatalf("expected error value for unparse with non-attribute argument, got %v", val.Type()) + } +} diff --git a/parser/additional_parser_test.go b/parser/additional_parser_test.go new file mode 100644 index 0000000..0fda572 --- /dev/null +++ b/parser/additional_parser_test.go @@ -0,0 +1,377 @@ +package parser + +import ( + "errors" + "io" + "strings" + "testing" + + "github.com/PelicanPlatform/classad/ast" +) + +// errReader always returns an error on read. +type errReader struct{} + +func (errReader) Read(p []byte) (int, error) { + return 0, errors.New("boom") +} + +// failAfterReader emits a fixed set of byte chunks and then returns an error. +type failAfterReader struct { + chunks [][]byte + idx int +} + +func (r *failAfterReader) Read(p []byte) (int, error) { + if r.idx >= len(r.chunks) { + return 0, errors.New("boom") + } + n := copy(p, r.chunks[r.idx]) + r.idx++ + return n, nil +} + +func TestStreamingLexerPropagatesReadError(t *testing.T) { + lex := NewStreamingLexer(errReader{}) + lval := &yySymType{} + if tok := lex.Lex(lval); tok != 0 { + t.Fatalf("expected tok=0, got %d", tok) + } + if lex.err == nil || lex.err.Error() != "boom" { + t.Fatalf("expected read error 'boom', got %v", lex.err) + } +} + +func TestStreamingLexerOperators(t *testing.T) { + cases := []struct { + name string + input string + expected []int + }{ + {"IS", "a=?=b", []int{IDENTIFIER, IS, IDENTIFIER}}, + {"UnreadQuestion", "=?x", []int{int('='), int('?')}}, + {"UnreadBang", "=!x", []int{int('='), int('!')}}, + {"URSHIFT", "1>>>2", []int{INTEGER_LITERAL, URSHIFT, INTEGER_LITERAL}}, + {"LSHIFT", "1<<2", []int{INTEGER_LITERAL, LSHIFT, INTEGER_LITERAL}}, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + lex := 
NewStreamingLexer(strings.NewReader(tc.input)) + lex.stopAfterClassAd = false + lval := &yySymType{} + var tokens []int + for { + tok := lex.Lex(lval) + if tok == 0 { + break + } + tokens = append(tokens, tok) + } + if lex.err != nil { + t.Fatalf("unexpected error: %v", lex.err) + } + if len(tokens) != len(tc.expected) { + t.Fatalf("tokens len=%d want=%d: %v", len(tokens), len(tc.expected), tokens) + } + for i, tok := range tc.expected { + if tokens[i] != tok { + t.Fatalf("token %d = %d want %d", i, tokens[i], tok) + } + } + }) + } +} + +func TestParseClassAdReturnsNilForNonClassAd(t *testing.T) { + ad, err := ParseClassAd("true") + if err == nil { + t.Fatalf("expected parse error for bare expression, got nil") + } + if ad != nil { + t.Fatalf("expected nil ClassAd result for expression input") + } +} + +func TestReaderParserEOF(t *testing.T) { + p := NewReaderParser(strings.NewReader(" \n\t")) + ad, err := p.ParseClassAd() + if err != io.EOF { + t.Fatalf("expected io.EOF, got %v (ad=%v)", err, ad) + } +} + +func TestReaderParserNonClassAd(t *testing.T) { + p := NewReaderParser(strings.NewReader("true")) + ad, err := p.ParseClassAd() + if err == nil { + t.Fatalf("expected error for non-ClassAd input, got nil (ad=%v)", ad) + } +} + +func TestReaderParserReadError(t *testing.T) { + p := NewReaderParser(errReader{}) + _, err := p.ParseClassAd() + if err == nil { + t.Fatalf("expected read error, got nil") + } +} + +func TestConvertOldToNewFormatBranches(t *testing.T) { + // Semicolon already present + out := convertOldToNewFormat("Foo = 1;\n// comment") + if !strings.Contains(out, "Foo = 1;\n") { + t.Fatalf("expected semicolon branch to preserve suffix, got %q", out) + } + // Non-assignment line preserved + out = convertOldToNewFormat("Req = A\n && B") + if !strings.Contains(out, " && B\n") { + t.Fatalf("expected non-assignment line to be kept, got %q", out) + } +} + +func TestParseOldClassAdError(t *testing.T) { + _, err := ParseOldClassAd("Foo =") + if err == nil { + 
t.Fatalf("expected parse error for malformed old ClassAd") + } +} + +func TestStreamingLexerSingleCharAndOr(t *testing.T) { + cases := []struct { + name string + input string + want int + }{ + {"SingleAmpersand", "&", int('&')}, + {"SinglePipe", "|", int('|')}, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + lex := NewStreamingLexer(strings.NewReader(tc.input)) + lex.stopAfterClassAd = false + + tok := lex.Lex(&yySymType{}) + if tok != tc.want { + t.Fatalf("first token = %d want %d", tok, tc.want) + } + if lex.err != nil { + t.Fatalf("unexpected error: %v", lex.err) + } + if tok := lex.Lex(&yySymType{}); tok != 0 { + t.Fatalf("expected second token 0, got %d", tok) + } + }) + } +} + +func TestStreamingLexerSkipsComments(t *testing.T) { + cases := []struct { + name string + input string + expected []int + }{ + {"LineComment", "// comment\n[a=1]", []int{int('['), IDENTIFIER, int('='), INTEGER_LITERAL, int(']')}}, + {"BlockComment", "/*block*/[b=2]", []int{int('['), IDENTIFIER, int('='), INTEGER_LITERAL, int(']')}}, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + lex := NewStreamingLexer(strings.NewReader(tc.input)) + tokens := collectTokens(t, lex) + if lex.err != nil { + t.Fatalf("unexpected error: %v", lex.err) + } + if len(tokens) != len(tc.expected) { + t.Fatalf("token count %d want %d: %v", len(tokens), len(tc.expected), tokens) + } + for i, tok := range tc.expected { + if tokens[i] != tok { + t.Fatalf("token %d = %d want %d", i, tokens[i], tok) + } + } + }) + } +} + +func collectTokens(t *testing.T, lex *StreamingLexer) []int { + t.Helper() + var tokens []int + lval := &yySymType{} + for { + tok := lex.Lex(lval) + if tok == 0 { + break + } + tokens = append(tokens, tok) + } + return tokens +} + +func TestStreamingLexerUnreadRuneNoSeen(t *testing.T) { + lex := NewStreamingLexer(strings.NewReader("")) + lex.unreadRune(1) + if lex.hasPending { + t.Fatalf("expected no pending rune after unread with empty 
history") + } +} + +func TestStreamingLexerInvalidNumbers(t *testing.T) { + cases := []struct { + name string + input string + wantMessage string + }{ + {"InvalidReal", "1e", "invalid real number: 1e"}, + {"InvalidInteger", strings.Repeat("9", 30), "invalid integer"}, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + lex := NewStreamingLexer(strings.NewReader(tc.input)) + lex.stopAfterClassAd = false + tok := lex.Lex(&yySymType{}) + if tok != 0 { + t.Fatalf("expected tok=0 after error, got %d", tok) + } + if lex.err == nil || !strings.Contains(lex.err.Error(), tc.wantMessage) { + t.Fatalf("expected error containing %q, got %v", tc.wantMessage, lex.err) + } + }) + } +} + +func TestParseScopedIdentifierNoScope(t *testing.T) { + name, scope := ParseScopedIdentifier("Attr") + if name != "Attr" || scope != ast.NoScope { + t.Fatalf("expected unchanged identifier and NoScope, got %q and %v", name, scope) + } +} + +func TestParseScopedIdentifierParent(t *testing.T) { + name, scope := ParseScopedIdentifier("PARENT.child") + if name != "child" || scope != ast.ParentScope { + t.Fatalf("expected child/ParentScope, got %q and %v", name, scope) + } +} + +func TestStreamingLexerLineCommentError(t *testing.T) { + lex := NewStreamingLexer(&failAfterReader{chunks: [][]byte{[]byte("// no newline")}}) + tok := lex.Lex(&yySymType{}) + if tok != 0 { + t.Fatalf("expected token 0 on error, got %d", tok) + } + if lex.err == nil || !strings.Contains(lex.err.Error(), "boom") { + t.Fatalf("expected propagated read error, got %v", lex.err) + } +} + +func TestStreamingLexerBlockCommentError(t *testing.T) { + lex := NewStreamingLexer(&failAfterReader{chunks: [][]byte{[]byte("/* unterminated")}}) + if tok := lex.Lex(&yySymType{}); tok != 0 { + t.Fatalf("expected token 0 on error, got %d", tok) + } + if lex.err == nil || !strings.Contains(lex.err.Error(), "boom") { + t.Fatalf("expected propagated read error, got %v", lex.err) + } +} + +func 
TestStreamingLexerUnterminatedEscape(t *testing.T) { + lex := NewStreamingLexer(strings.NewReader("\"abc\\")) + lex.stopAfterClassAd = false + tok := lex.Lex(&yySymType{}) + if tok != STRING_LITERAL { + t.Fatalf("expected STRING_LITERAL, got %d", tok) + } + if lex.err == nil || !strings.Contains(lex.err.Error(), "unterminated escape sequence") { + t.Fatalf("expected unterminated escape error, got %v", lex.err) + } +} + +func TestStreamingLexerInvalidEscape(t *testing.T) { + lex := NewStreamingLexer(strings.NewReader("\"\\y\"")) + lex.stopAfterClassAd = false + _ = lex.Lex(&yySymType{}) + if lex.err == nil || !strings.Contains(lex.err.Error(), "invalid escape sequence") { + t.Fatalf("expected invalid escape sequence error, got %v", lex.err) + } +} + +func TestStreamingLexerLineCommentEOF(t *testing.T) { + lex := NewStreamingLexer(strings.NewReader("// trailing comment")) + if tok := lex.Lex(&yySymType{}); tok != 0 { + t.Fatalf("expected no tokens, got %d", tok) + } + if lex.err != nil { + t.Fatalf("unexpected error for comment at EOF: %v", lex.err) + } +} + +func TestStreamingLexerScopedIdentifierEOF(t *testing.T) { + lex := NewStreamingLexer(strings.NewReader("MY.attr")) + tok := lex.Lex(&yySymType{}) + if tok != IDENTIFIER { + t.Fatalf("expected IDENTIFIER token, got %d", tok) + } + if lex.err != nil { + t.Fatalf("unexpected error: %v", lex.err) + } +} + +func TestParserComplexExpressionCoverage(t *testing.T) { + input := `[ +A = (1 + 2 * 3 >= 4) ? 
strcat("x", "y") : {5, 6}; +B = [C = {1, 2}]; +D = (1 << 2) >> 1; +E = !true || false && ~0; +F = (1 % 2) + (-3); +]` + lex := NewLexer(input) + yyParse(lex) + if _, err := lex.Result(); err != nil { + t.Fatalf("unexpected parse error: %v", err) + } +} + +func TestParserSyntaxErrorPath(t *testing.T) { + lex := NewLexer("[A = ]") + yyParse(lex) + if _, err := lex.Result(); err == nil { + t.Fatalf("expected syntax error, got nil") + } +} + +func TestStreamingLexerSingleCharOperators(t *testing.T) { + cases := []struct { + input string + token int + }{ + {"+", int('+')}, + {"-", int('-')}, + {"*", int('*')}, + {"/", int('/')}, + {"%", int('%')}, + {"?", int('?')}, + {":", int(':')}, + {"^", int('^')}, + {"~", int('~')}, + {".", int('.')}, + } + + for _, tc := range cases { + lex := NewStreamingLexer(strings.NewReader(tc.input)) + lex.stopAfterClassAd = false + lval := &yySymType{} + tok := lex.Lex(lval) + if tok != tc.token { + t.Fatalf("input %q: got token %d want %d", tc.input, tok, tc.token) + } + if next := lex.Lex(lval); next != 0 { + t.Fatalf("input %q: expected EOF token 0, got %d", tc.input, next) + } + if lex.err != nil { + t.Fatalf("input %q: unexpected error %v", tc.input, lex.err) + } + } +} diff --git a/parser/lexer.go b/parser/lexer.go index 3f1ce3b..4ca4933 100644 --- a/parser/lexer.go +++ b/parser/lexer.go @@ -2,11 +2,8 @@ package parser import ( - "fmt" - "strconv" + "bufio" "strings" - "unicode" - "unicode/utf8" "github.com/PelicanPlatform/classad/ast" ) @@ -18,438 +15,56 @@ type Token struct { Pos int } -// Lexer represents a lexical scanner for ClassAd expressions. +// Lexer wraps the streaming lexer for string inputs while retaining the input +// and position fields used in existing tests. type Lexer struct { input string pos int + lex *StreamingLexer result ast.Node err error } // NewLexer creates a new lexer for the given input. 
func NewLexer(input string) *Lexer { + br := bufio.NewReader(strings.NewReader(input)) + lex := NewStreamingLexer(br) + lex.stopAfterClassAd = false return &Lexer{ input: input, pos: 0, + lex: lex, } } // Lex implements the goyacc Lexer interface. func (l *Lexer) Lex(lval *yySymType) int { - l.skipWhitespace() - - if l.pos >= len(l.input) { - return 0 // EOF - } - - // Check for operators and punctuation - switch l.peek() { - case '[': - l.advance() - return int('[') - case ']': - l.advance() - return int(']') - case '{': - l.advance() - return int('{') - case '}': - l.advance() - return int('}') - case '(': - l.advance() - return int('(') - case ')': - l.advance() - return int(')') - case ';': - l.advance() - return int(';') - case ',': - l.advance() - return int(',') - case '?': - l.advance() - return int('?') - case ':': - l.advance() - return int(':') - case '^': - l.advance() - return int('^') - case '~': - l.advance() - return int('~') - case '+': - l.advance() - return int('+') - case '-': - l.advance() - return int('-') - case '*': - l.advance() - return int('*') - case '%': - l.advance() - return int('%') - case '.': - l.advance() - return int('.') - case '"': - str := l.scanString() - lval.str = str - return STRING_LITERAL - case '=': - l.advance() - if l.peek() == '=' { - l.advance() - return EQ - } else if l.peek() == '?' { - l.advance() - if l.peek() == '=' { - l.advance() - return IS // =?= is an alias for 'is' - } - // Put back the '?' - l.pos-- - return int('=') - } else if l.peek() == '!' { - l.advance() - if l.peek() == '=' { - l.advance() - return ISNT // =!= is an alias for 'isnt' - } - // Put back the '!' 
- l.pos-- - return int('=') - } - return int('=') - case '!': - l.advance() - if l.peek() == '=' { - l.advance() - return NE - } - return int('!') - case '<': - l.advance() - if l.peek() == '=' { - l.advance() - return LE - } else if l.peek() == '<' { - l.advance() - return LSHIFT - } - return int('<') - case '>': - l.advance() - if l.peek() == '=' { - l.advance() - return GE - } else if l.peek() == '>' { - l.advance() - if l.peek() == '>' { - l.advance() - return URSHIFT - } - return RSHIFT - } - return int('>') - case '&': - l.advance() - if l.peek() == '&' { - l.advance() - return AND - } - return int('&') - case '|': - l.advance() - if l.peek() == '|' { - l.advance() - return OR - } - return int('|') - case '/': - l.advance() - if l.peek() == '/' { - // Line comment - l.skipLineComment() - return l.Lex(lval) - } else if l.peek() == '*' { - // Block comment - l.skipBlockComment() - return l.Lex(lval) - } - return int('/') - } - - // Check for numbers - ch := l.peek() - if unicode.IsDigit(ch) { - return l.scanNumber(lval) - } - - // Check for identifiers and keywords - if unicode.IsLetter(ch) || ch == '_' { - return l.scanIdentifierOrKeyword(lval) - } - - // Unknown character - l.Error(fmt.Sprintf("unexpected character: %c", ch)) - l.advance() - return l.Lex(lval) + tok := l.lex.Lex(lval) + // Mirror position for tests. + l.pos = l.lex.pos + return tok } // Error implements the goyacc Lexer interface. func (l *Lexer) Error(s string) { - l.err = fmt.Errorf("parse error at position %d: %s", l.pos, s) + l.lex.Error(s) + // Mirror error for tests that inspect Result. + l.err = l.lex.err } // Result returns the parsed result and any error. func (l *Lexer) Result() (ast.Node, error) { - return l.result, l.err + if l.result != nil || l.err != nil { + return l.result, l.err + } + res, err := l.lex.Result() + l.result = res + l.err = err + return res, err } // SetResult sets the parse result. 
func (l *Lexer) SetResult(node ast.Node) { + l.lex.SetResult(node) l.result = node } - -func (l *Lexer) peek() rune { - if l.pos >= len(l.input) { - return 0 - } - ch, _ := utf8.DecodeRuneInString(l.input[l.pos:]) - return ch -} - -func (l *Lexer) advance() { - if l.pos < len(l.input) { - _, size := utf8.DecodeRuneInString(l.input[l.pos:]) - l.pos += size - } -} - -func (l *Lexer) skipWhitespace() { - for l.pos < len(l.input) && unicode.IsSpace(l.peek()) { - l.advance() - } -} - -func (l *Lexer) skipLineComment() { - // Skip // - l.advance() - for l.pos < len(l.input) && l.peek() != '\n' { - l.advance() - } -} - -func (l *Lexer) skipBlockComment() { - // Skip /* - l.advance() - for l.pos < len(l.input) { - if l.peek() == '*' { - l.advance() - if l.peek() == '/' { - l.advance() - return - } - } else { - l.advance() - } - } -} - -func (l *Lexer) scanString() string { - // Skip opening quote - l.advance() - start := l.pos - var result strings.Builder - - for l.pos < len(l.input) { - ch := l.peek() - if ch == '"' { - l.advance() - return result.String() - } else if ch == '\\' { - l.advance() - if l.pos < len(l.input) { - escaped := l.peek() - l.advance() - switch escaped { - case 'b': - result.WriteRune('\b') // Backspace (8) - case 't': - result.WriteRune('\t') // Tab (9) - case 'n': - result.WriteRune('\n') // Newline (10) - case 'f': - result.WriteRune('\f') // Formfeed (12) - case 'r': - result.WriteRune('\r') // Carriage return (13) - case '\\': - result.WriteRune('\\') // Backslash (92) - case '"': - result.WriteRune('"') // Quote (34) - case '\'': - result.WriteRune('\'') // Apostrophe (39) - case '0', '1', '2', '3', '4', '5', '6', '7': - // Octal escape sequence - // If first digit is 0-3, read up to 3 digits - // If first digit is 4-7, read up to 2 digits - var octalStr strings.Builder - octalStr.WriteRune(escaped) - - maxDigits := 2 - if escaped >= '0' && escaped <= '3' { - maxDigits = 3 - } - - // Read additional octal digits - for i := 1; i < maxDigits && 
l.pos < len(l.input); i++ { - nextCh := l.peek() - if nextCh >= '0' && nextCh <= '7' { - octalStr.WriteRune(nextCh) - l.advance() - } else { - break - } - } - - // Convert octal string to integer - octalValue := int64(0) - for _, digit := range octalStr.String() { - octalValue = octalValue*8 + int64(digit-'0') - } - - // Check for null (value 0) which is not allowed - if octalValue == 0 { - l.Error(fmt.Sprintf("null character (\\%s) not allowed in string at position %d", octalStr.String(), l.pos-len(octalStr.String())-2)) - return result.String() - } - - result.WriteRune(rune(octalValue)) - default: - // Unknown escape sequence - this is an error according to spec - l.Error(fmt.Sprintf("invalid escape sequence \\%c at position %d", escaped, l.pos-2)) - result.WriteRune(escaped) - } - } - } else { - result.WriteRune(ch) - l.advance() - } - } - - l.Error(fmt.Sprintf("unterminated string starting at position %d", start-1)) - return result.String() -} - -func (l *Lexer) scanNumber(lval *yySymType) int { - start := l.pos - hasDecimal := false - hasExponent := false - - for l.pos < len(l.input) { - ch := l.peek() - if unicode.IsDigit(ch) { - l.advance() - } else if ch == '.' 
&& !hasDecimal && !hasExponent { - hasDecimal = true - l.advance() - } else if (ch == 'e' || ch == 'E') && !hasExponent { - hasExponent = true - hasDecimal = true // Exponent implies floating point - l.advance() - if l.peek() == '+' || l.peek() == '-' { - l.advance() - } - } else { - break - } - } - - text := l.input[start:l.pos] - if hasDecimal || hasExponent { - val, err := strconv.ParseFloat(text, 64) - if err != nil { - l.Error(fmt.Sprintf("invalid real number: %s", text)) - return 0 - } - lval.real = val - return REAL_LITERAL - } - - val, err := strconv.ParseInt(text, 10, 64) - if err != nil { - l.Error(fmt.Sprintf("invalid integer: %s", text)) - return 0 - } - lval.integer = val - return INTEGER_LITERAL -} - -func (l *Lexer) scanIdentifierOrKeyword(lval *yySymType) int { - start := l.pos - l.advance() - - for l.pos < len(l.input) { - ch := l.peek() - if unicode.IsLetter(ch) || unicode.IsDigit(ch) || ch == '_' { - l.advance() - } else { - break - } - } - - text := l.input[start:l.pos] - - // Check for scoped attribute references (MY., TARGET., PARENT.) - textUpper := strings.ToUpper(text) - if l.peek() == '.' 
{ - switch textUpper { - case "MY", "TARGET", "PARENT": - // Consume the dot - l.advance() - // Now scan the attribute name - if unicode.IsLetter(l.peek()) || l.peek() == '_' { - l.advance() - for l.pos < len(l.input) { - ch := l.peek() - if unicode.IsLetter(ch) || unicode.IsDigit(ch) || ch == '_' { - l.advance() - } else { - break - } - } - // Return the full scoped reference (e.g., "MY.Cpus") - scopedName := l.input[start:l.pos] - lval.str = scopedName - return IDENTIFIER - } - // If no valid identifier follows, put the dot back and continue - l.pos-- - } - } - - // Check for keywords - switch strings.ToLower(text) { - case "true": - lval.boolean = true - return BOOLEAN_LITERAL - case "false": - lval.boolean = false - return BOOLEAN_LITERAL - case "undefined": - return UNDEFINED - case "error": - return ERROR - case "is": - return IS - case "isnt": - return ISNT - } - - lval.str = text - return IDENTIFIER -} diff --git a/parser/lexer_test.go b/parser/lexer_test.go index bb2da04..f189817 100644 --- a/parser/lexer_test.go +++ b/parser/lexer_test.go @@ -176,6 +176,101 @@ func TestLexerError(t *testing.T) { } } +func TestLexerErrorFormattingUnexpectedChar(t *testing.T) { + lex := NewLexer("foo\n @") + lval := &yySymType{} + for { + if tok := lex.Lex(lval); tok == 0 { + break + } + } + _, err := lex.Result() + if err == nil { + t.Fatalf("expected error, got nil") + } + + const expected = "parse error at line 2, col 3: unexpected character: @\n @\n ^" + if err.Error() != expected { + t.Fatalf("unexpected error message:\n got: %q\nwant: %q", err.Error(), expected) + } +} + +func TestLexerErrorFormattingUnterminatedString(t *testing.T) { + lex := NewLexer("\"foo") + lval := &yySymType{} + for { + if tok := lex.Lex(lval); tok == 0 { + break + } + } + _, err := lex.Result() + if err == nil { + t.Fatalf("expected error, got nil") + } + + const expected = "parse error at line 1, col 4: unterminated string starting at byte 0\n\"foo\n ^" + if err.Error() != expected { + 
t.Fatalf("unexpected error message:\n got: %q\nwant: %q", err.Error(), expected) + } +} + +func TestLexerErrorFormattingUnterminatedBlockComment(t *testing.T) { + lex := NewLexer("/* unterminated") + lval := &yySymType{} + for { + if tok := lex.Lex(lval); tok == 0 { + break + } + } + _, err := lex.Result() + if err == nil { + t.Fatalf("expected error, got nil") + } + + const expected = "parse error at line 1, col 15: unterminated block comment\n/* unterminated\n ^" + if err.Error() != expected { + t.Fatalf("unexpected error message:\n got: %q\nwant: %q", err.Error(), expected) + } +} + +func TestLexerErrorFormattingNullOctalInString(t *testing.T) { + lex := NewLexer("\"\\000\"") + lval := &yySymType{} + for { + if tok := lex.Lex(lval); tok == 0 { + break + } + } + _, err := lex.Result() + if err == nil { + t.Fatalf("expected error, got nil") + } + + const expected = "parse error at line 1, col 6: unterminated string starting at byte 5\n\"\\000\"\n ^" + if err.Error() != expected { + t.Fatalf("unexpected error message:\n got: %q\nwant: %q", err.Error(), expected) + } +} + +func TestLexerErrorFormattingInvalidNumber(t *testing.T) { + lex := NewLexer("1e+") + lval := &yySymType{} + for { + if tok := lex.Lex(lval); tok == 0 { + break + } + } + _, err := lex.Result() + if err == nil { + t.Fatalf("expected error, got nil") + } + + const expected = "parse error at line 1, col 3: invalid real number: 1e+\n1e+\n ^" + if err.Error() != expected { + t.Fatalf("unexpected error message:\n got: %q\nwant: %q", err.Error(), expected) + } +} + func TestLexerResult(t *testing.T) { lex := NewLexer("test") expectedResult := &ast.ClassAd{} diff --git a/parser/parser.go b/parser/parser.go index f432cd6..4439cac 100644 --- a/parser/parser.go +++ b/parser/parser.go @@ -3,6 +3,8 @@ package parser import ( + "fmt" + "io" "strings" "github.com/PelicanPlatform/classad/ast" @@ -27,6 +29,43 @@ func ParseClassAd(input string) (*ast.ClassAd, error) { return nil, nil } +// ReaderParser parses 
consecutive ClassAds from a buffered reader without +// requiring delimiters between ads. It reuses a single streaming lexer instance +// for efficiency. +type ReaderParser struct { + lex *StreamingLexer +} + +// NewReaderParser creates a reusable parser that pulls consecutive ClassAds +// from the provided reader without requiring delimiters. Non-buffered readers +// are wrapped internally for efficiency. +func NewReaderParser(r io.Reader) *ReaderParser { + return &ReaderParser{lex: NewStreamingLexer(r)} +} + +// ParseClassAd parses the next ClassAd from the underlying reader. +// It reuses the same streaming lexer instance to avoid per-call allocations. +// It returns io.EOF when there is no more data to parse. +func (p *ReaderParser) ParseClassAd() (*ast.ClassAd, error) { + p.lex.resetForNext() + if err := p.lex.skipTrivia(); err != nil { + if err == io.EOF { + return nil, io.EOF + } + return nil, err + } + + yyParse(p.lex) + node, err := p.lex.Result() + if err != nil { + return nil, err + } + if classad, ok := node.(*ast.ClassAd); ok { + return classad, nil + } + return nil, fmt.Errorf("failed to parse ClassAd") +} + // ParseScopedIdentifier parses an identifier that may have a scope prefix. // Returns the attribute name and scope. 
func ParseScopedIdentifier(identifier string) (string, ast.AttributeScope) { diff --git a/parser/reader_parser_test.go b/parser/reader_parser_test.go new file mode 100644 index 0000000..224b24d --- /dev/null +++ b/parser/reader_parser_test.go @@ -0,0 +1,137 @@ +package parser + +import ( + "io" + "strings" + "testing" + + "github.com/PelicanPlatform/classad/ast" +) + +func getAttrInt(t *testing.T, ad *ast.ClassAd, name string) int64 { + t.Helper() + for _, attr := range ad.Attributes { + if attr.Name == name { + if lit, ok := attr.Value.(*ast.IntegerLiteral); ok { + return lit.Value + } + break + } + } + t.Fatalf("attribute %s not found or not integer", name) + return 0 +} + +func TestReaderParserConcatenatedAdsWithComments(t *testing.T) { + input := "/*lead*/[A=1]//line comment\n/*block\ncomment*/[B=2][C=3]/*tail*/" + p := NewReaderParser(strings.NewReader(input)) + + var values []int64 + for { + ad, err := p.ParseClassAd() + if err != nil { + if err == io.EOF { + break + } + t.Fatalf("unexpected error: %v", err) + } + if len(ad.Attributes) == 0 { + t.Fatalf("parsed empty ClassAd") + } + values = append(values, getAttrInt(t, ad, ad.Attributes[0].Name)) + } + + expected := []int64{1, 2, 3} + if len(values) != len(expected) { + t.Fatalf("expected %d ads, got %d", len(expected), len(values)) + } + for i, v := range expected { + if values[i] != v { + t.Fatalf("value %d: expected %d, got %d", i, v, values[i]) + } + } +} + +func TestReaderParserCommentsInsideAd(t *testing.T) { + input := `[A /*c1*/=/*c2*/ 5 /*c3*/; B= /*c4*/6]/*after*/` + p := NewReaderParser(strings.NewReader(input)) + + ad, err := p.ParseClassAd() + if err != nil { + t.Fatalf("unexpected error: %v, tokens=%v", err, dumpTokens(t, input)) + } + + a := getAttrInt(t, ad, "A") + b := getAttrInt(t, ad, "B") + + if a != 5 || b != 6 { + t.Fatalf("expected A=5 and B=6, got A=%d B=%d", a, b) + } + + if next, err := p.ParseClassAd(); err != io.EOF { + if err == nil { + t.Fatalf("expected EOF, got additional ad: 
%+v", next) + } + t.Fatalf("expected EOF, got error: %v", err) + } +} + +func dumpTokens(t *testing.T, input string) []int { + t.Helper() + lex := NewStreamingLexer(strings.NewReader(input)) + var tokens []int + lval := &yySymType{} + for { + tok := lex.Lex(lval) + if tok == 0 { + break + } + tokens = append(tokens, tok) + } + return tokens +} + +func TestReaderParserUnterminatedBlockComment(t *testing.T) { + input := "/* unterminated [A=1]" + p := NewReaderParser(strings.NewReader(input)) + + ad, err := p.ParseClassAd() + if err == nil { + t.Fatalf("expected error, got ad: %+v", ad) + } +} + +func TestReaderParserComplexExpression(t *testing.T) { + input := `[X = (Memory / 1024) + Disk]` + p := NewReaderParser(strings.NewReader(input)) + + ad, err := p.ParseClassAd() + if err != nil { + t.Fatalf("unexpected error: %v, tokens=%v", err, dumpTokens(t, input)) + } + + if len(ad.Attributes) != 1 { + t.Fatalf("expected 1 attribute, got %d", len(ad.Attributes)) + } + + attr := ad.Attributes[0] + if attr.Name != "X" { + t.Fatalf("expected attribute X, got %s", attr.Name) + } + + bin, ok := attr.Value.(*ast.BinaryOp) + if !ok { + t.Fatalf("expected binary op for X value, got %T", attr.Value) + } + if bin.Op != "+" { + t.Fatalf("expected + op at root, got %s", bin.Op) + } + + left, ok := bin.Left.(*ast.BinaryOp) + if !ok { + t.Fatalf("expected left child to be binary op, got %T", bin.Left) + } + if left.Op != "/" { + t.Fatalf("expected / in nested op, got %s", left.Op) + } +} diff --git a/parser/streaming_lexer.go b/parser/streaming_lexer.go new file mode 100644 index 0000000..8417f9c --- /dev/null +++ b/parser/streaming_lexer.go @@ -0,0 +1,701 @@ +package parser + +import ( + "bufio" + "errors" + "fmt" + "io" + "strconv" + "strings" + "unicode" + "unicode/utf8" + + "github.com/PelicanPlatform/classad/ast" +) + +// StreamingLexer tokenizes ClassAds directly from an io.Reader. 
// It stops producing tokens after the first complete ClassAd so the caller can
// parse multiple ads from a single stream.
type StreamingLexer struct {
	r      *bufio.Reader // source of runes
	pos    int           // byte offset of consumed input; advanced by recordRune
	result ast.Node      // parse result stored via SetResult
	err    error         // last error recorded by Error or by a failed read
	depth  int           // count of '[' seen minus ']' seen (never below zero)
	// started becomes true at the first '[' token.
	started bool
	// done makes every subsequent Lex call return 0 (end of input).
	done bool
	// stopAfterClassAd makes Lex set done after the ']' that returns depth to 0.
	stopAfterClassAd bool
	// pendingRune/pendingSize/hasPending form a one-rune lookahead slot that
	// readRune and peekRune serve before touching r.
	pendingRune rune
	pendingSize int
	hasPending  bool
	// seen holds the runes consumed so far; formatError uses it to compute
	// line/column and to print the offending line.
	seen []rune
}

// NewStreamingLexer creates a lexer that consumes tokens directly from a reader.
// It wraps non-buffered readers in a bufio.Reader for efficiency; an existing
// *bufio.Reader is used as-is. The lexer is created with stopAfterClassAd set.
func NewStreamingLexer(r io.Reader) *StreamingLexer {
	if br, ok := r.(*bufio.Reader); ok {
		return &StreamingLexer{r: br, stopAfterClassAd: true}
	}
	return &StreamingLexer{r: bufio.NewReader(r), stopAfterClassAd: true}
}

// resetForNext prepares the lexer to scan another ClassAd from the same reader.
// It preserves the current reader position but clears parsing state and result.
// The lookahead slot and the recorded runes are cleared too; pos and
// stopAfterClassAd are left untouched.
func (l *StreamingLexer) resetForNext() {
	l.result = nil
	l.err = nil
	l.depth = 0
	l.started = false
	l.done = false
	l.pendingRune = 0
	l.pendingSize = 0
	l.hasPending = false
	l.seen = l.seen[:0] // keep the backing array for reuse
}

// Lex implements the goyacc Lexer interface.
//
// It returns the next token, storing any literal value in lval, and returns 0
// at end of input, once done is set, or after a read error (which is stored
// in l.err). Single-character punctuation is returned as its own rune value;
// multi-character operators, literals, identifiers and keywords use the
// grammar's token constants.
func (l *StreamingLexer) Lex(lval *yySymType) int {
	if l.done {
		return 0
	}

	// Skip whitespace and comments; on success the first significant rune is
	// staged in the lookahead slot.
	if err := l.skipTrivia(); err != nil {
		if err == io.EOF {
			l.done = true
			return 0
		}
		l.err = err
		return 0
	}

	ch, _, err := l.readRune()
	if err != nil {
		if err == io.EOF {
			l.done = true
			// EOF inside an open '[' ... ']' is reported as an error.
			if l.started && l.depth > 0 {
				l.Error("unexpected EOF while parsing ClassAd")
			}
			return 0
		}
		l.err = err
		return 0
	}

	// Check for operators and punctuation
	switch ch {
	case '[':
		l.started = true
		l.depth++
		return int('[')
	case ']':
		if l.depth > 0 {
			l.depth--
		}
		if l.stopAfterClassAd && l.started && l.depth == 0 {
			// Signal EOF after this token so the parser stops at the first ClassAd.
			l.done = true
		}
		return int(']')
	case '{':
		return int('{')
	case '}':
		return int('}')
	case '(':
		return int('(')
	case ')':
		return int(')')
	case ';':
		return int(';')
	case ',':
		return int(',')
	case '?':
		return int('?')
	case ':':
		return int(':')
	case '^':
		return int('^')
	case '~':
		return int('~')
	case '+':
		return int('+')
	case '-':
		return int('-')
	case '*':
		return int('*')
	case '%':
		return int('%')
	case '.':
		return int('.')
	case '"':
		// The opening quote is already consumed; scanString reads through the
		// closing quote and returns the unescaped contents.
		str := l.scanString()
		lval.str = str
		return STRING_LITERAL
	case '=':
		// '=' may begin "==", "=?=" or "=!="; otherwise it is a plain '='.
		if next, err := l.peekRune(); err == nil {
			switch next {
			case '=':
				if err := l.discardRune(); err != nil {
					l.err = err
					return 0
				}
				return EQ
			case '?':
				if err := l.discardRune(); err != nil {
					l.err = err
					return 0
				}
				if peek, err := l.peekRune(); err == nil && peek == '=' {
					if err := l.discardRune(); err != nil {
						l.err = err
						return 0
					}
					return IS
				}
				// Not "=?=": push the consumed '?' back so it is lexed as
				// its own token after this '='.
				l.unreadRune(utf8.RuneLen('?'))
				return int('=')
			case '!':
				if err := l.discardRune(); err != nil {
					l.err = err
					return 0
				}
				if peek, err := l.peekRune(); err == nil && peek == '=' {
					if err := l.discardRune(); err != nil {
						l.err = err
						return 0
					}
					return ISNT
				}
				// Not "=!=": push the consumed '!' back.
				l.unreadRune(utf8.RuneLen('!'))
				return int('=')
			}
		}
		return int('=')
	case '!':
		// "!=" or a lone '!'.
		if next, err := l.peekRune(); err == nil && next == '=' {
			if err := l.discardRune(); err != nil {
				l.err = err
				return 0
			}
			return NE
		}
		return int('!')
	case '<':
		// "<=", "<<" or a lone '<'.
		if next, err := l.peekRune(); err == nil {
			switch next {
			case '=':
				if err := l.discardRune(); err != nil {
					l.err = err
					return 0
				}
				return LE
			case '<':
				if err := l.discardRune(); err != nil {
					l.err = err
					return 0
				}
				return LSHIFT
			}
		}
		return int('<')
	case '>':
		// ">=", ">>", ">>>" or a lone '>'.
		if next, err := l.peekRune(); err == nil {
			switch next {
			case '=':
				if err := l.discardRune(); err != nil {
					l.err = err
					return 0
				}
				return GE
			case '>':
				if err := l.discardRune(); err != nil {
					l.err = err
					return 0
				}
				if peek, err := l.peekRune(); err == nil && peek == '>' {
					if err := l.discardRune(); err != nil {
						l.err = err
						return 0
					}
					return URSHIFT
				}
				return RSHIFT
			}
		}
		return int('>')
	case '&':
		// "&&" or a lone '&'.
		if next, err := l.peekRune(); err == nil && next == '&' {
			if err := l.discardRune(); err != nil {
				l.err = err
				return 0
			}
			return AND
		}
		return int('&')
	case '|':
		// "||" or a lone '|'.
		if next, err := l.peekRune(); err == nil && next == '|' {
			if err := l.discardRune(); err != nil {
				l.err = err
				return 0
			}
			return OR
		}
		return int('|')
	case '/':
		// A '/' may open a line or block comment; after skipping one, lexing
		// restarts with a recursive call. Otherwise it is the division token.
		if next, err := l.peekRune(); err == nil {
			switch next {
			case '/':
				if err := l.discardRune(); err != nil {
					l.err = err
					return 0
				}
				if err := l.skipLineComment(); err != nil {
					l.err = err
					return 0
				}
				return l.Lex(lval)
			case '*':
				if err := l.discardRune(); err != nil {
					l.err = err
					return 0
				}
				if err := l.skipBlockComment(); err != nil {
					l.err = err
					return 0
				}
				return l.Lex(lval)
			}
		}
		return int('/')
	}

	// Numbers
	if unicode.IsDigit(ch) {
		return l.scanNumber(ch, lval)
	}

	// Identifiers and keywords
	if unicode.IsLetter(ch) || ch == '_' {
		return l.scanIdentifierOrKeyword(ch, lval)
	}

	// Unknown character: record the error, then keep lexing from the next rune.
	l.Error(fmt.Sprintf("unexpected character: %c", ch))
	return l.Lex(lval)
}

// Error implements the goyacc Lexer interface.
// It replaces any previously stored error with s formatted by formatError
// (line/column header, the current line's text, and a caret).
func (l *StreamingLexer) Error(s string) {
	l.err = errors.New(l.formatError(s))
}

// Result returns the parsed result and any error.
func (l *StreamingLexer) Result() (ast.Node, error) {
	return l.result, l.err
}

// SetResult sets the parse result.
+func (l *StreamingLexer) SetResult(node ast.Node) { + l.result = node +} + +func (l *StreamingLexer) readRune() (rune, int, error) { + if l.hasPending { + ch := l.pendingRune + size := l.pendingSize + l.hasPending = false + l.recordRune(ch, size) + return ch, size, nil + } + + ch, size, err := l.r.ReadRune() + if err != nil { + return 0, 0, err + } + l.recordRune(ch, size) + return ch, size, nil +} + +func (l *StreamingLexer) unreadRune(size int) { + if len(l.seen) == 0 { + return + } + last := l.seen[len(l.seen)-1] + l.seen = l.seen[:len(l.seen)-1] + l.pos -= size + l.pendingRune = last + l.pendingSize = size + l.hasPending = true +} + +func (l *StreamingLexer) peekRune() (rune, error) { + if l.hasPending { + return l.pendingRune, nil + } + + ch, size, err := l.r.ReadRune() + if err != nil { + return 0, err + } + // Do not advance pos/seen yet; stage as pending. + l.pendingRune = ch + l.pendingSize = size + l.hasPending = true + return ch, nil +} + +// discardRune consumes a rune and returns any read error. +func (l *StreamingLexer) discardRune() error { + _, _, err := l.readRune() + return err +} + +func (l *StreamingLexer) skipTrivia() error { + for { + ch, size, err := l.readRune() + if err != nil { + return err + } + + if unicode.IsSpace(ch) { + continue + } + + if ch == '/' { + next, err := l.peekRune() + if err == nil { + switch next { + case '/': + // Consume next '/' + if err := l.discardRune(); err != nil { + return err + } + if err := l.skipLineComment(); err != nil { + return err + } + continue + case '*': + if err := l.discardRune(); err != nil { + return err + } + if err := l.skipBlockComment(); err != nil { + return err + } + continue + } + } + } + + // Non-trivia rune; stage it for the lexer without consuming it. + l.pendingRune = ch + l.pendingSize = size + l.hasPending = true + // We recorded this rune in readRune, so roll back the position and seen to + // reflect that it is not yet consumed by the parser. 
+ l.pos -= size + if len(l.seen) > 0 { + l.seen = l.seen[:len(l.seen)-1] + } + return nil + } +} + +func (l *StreamingLexer) skipLineComment() error { + for { + ch, _, err := l.readRune() + if err != nil { + if err == io.EOF { + return nil + } + return err + } + if ch == '\n' { + return nil + } + } +} + +func (l *StreamingLexer) skipBlockComment() error { + for { + ch, _, err := l.readRune() + if err != nil { + if err == io.EOF { + l.Error("unterminated block comment") + } + return err + } + if ch == '*' { + next, err := l.peekRune() + if err == nil && next == '/' { + if err := l.discardRune(); err != nil { + return err + } + return nil + } + } + } +} + +func (l *StreamingLexer) scanString() string { + var result strings.Builder + startPos := l.pos - utf8.RuneLen('"') + + for { + ch, _, err := l.readRune() + if err != nil { + if err == io.EOF { + l.Error(fmt.Sprintf("unterminated string starting at byte %d", startPos)) + } + return result.String() + } + + if ch == '"' { + return result.String() + } + + if ch == '\\' { + escaped, _, err := l.readRune() + if err != nil { + l.Error(fmt.Sprintf("unterminated escape sequence in string starting at position %d", startPos)) + return result.String() + } + switch escaped { + case 'b': + result.WriteRune('\b') + case 't': + result.WriteRune('\t') + case 'n': + result.WriteRune('\n') + case 'f': + result.WriteRune('\f') + case 'r': + result.WriteRune('\r') + case '\\': + result.WriteRune('\\') + case '"': + result.WriteRune('"') + case '\'': + result.WriteRune('\'') + case '0', '1', '2', '3', '4', '5', '6', '7': + var octalStr strings.Builder + octalStr.WriteRune(escaped) + + maxDigits := 2 + if escaped >= '0' && escaped <= '3' { + maxDigits = 3 + } + + for i := 1; i < maxDigits; i++ { + next, err := l.peekRune() + if err != nil { + break + } + if next >= '0' && next <= '7' { + if err := l.discardRune(); err != nil { + return result.String() + } + octalStr.WriteRune(next) + } else { + break + } + } + + val, err := 
strconv.ParseInt(octalStr.String(), 8, 64) + if err != nil { + l.Error(fmt.Sprintf("invalid octal escape %s at position %d", octalStr.String(), l.pos)) + return result.String() + } + if val == 0 { + l.Error(fmt.Sprintf("null character (\\%s) not allowed in string at position %d", octalStr.String(), l.pos)) + return result.String() + } + result.WriteRune(rune(val)) + default: + l.Error(fmt.Sprintf("invalid escape sequence \\%c at position %d", escaped, l.pos-2)) + result.WriteRune(escaped) + } + continue + } + + result.WriteRune(ch) + } +} + +func (l *StreamingLexer) recordRune(ch rune, size int) { + l.pos += size + l.seen = append(l.seen, ch) +} + +func (l *StreamingLexer) formatError(msg string) string { + line, col := 1, 0 + for _, r := range l.seen { + if r == '\n' { + line++ + col = 0 + } else { + col++ + } + } + + runes := l.seen + lastNL := -1 + for i := len(runes) - 1; i >= 0; i-- { + if runes[i] == '\n' { + lastNL = i + break + } + } + start := lastNL + 1 + lineText := string(runes[start:]) + caret := strings.Repeat(" ", max(col-1, 0)) + "^" + + return fmt.Sprintf("parse error at line %d, col %d: %s\n%s\n%s", line, col, msg, lineText, caret) +} + +func (l *StreamingLexer) scanNumber(first rune, lval *yySymType) int { + var sb strings.Builder + sb.WriteRune(first) + + hasDecimal := false + hasExponent := false + + for { + ch, err := l.peekRune() + if err != nil { + break + } + + if unicode.IsDigit(ch) { + if err := l.discardRune(); err != nil { + return 0 + } + sb.WriteRune(ch) + continue + } + + if ch == '.' 
&& !hasDecimal && !hasExponent { + hasDecimal = true + if err := l.discardRune(); err != nil { + return 0 + } + sb.WriteRune(ch) + continue + } + + if (ch == 'e' || ch == 'E') && !hasExponent { + hasExponent = true + hasDecimal = true + if err := l.discardRune(); err != nil { + return 0 + } + sb.WriteRune(ch) + next, err := l.peekRune() + if err == nil && (next == '+' || next == '-') { + if err := l.discardRune(); err != nil { + return 0 + } + sb.WriteRune(next) + } + continue + } + + break + } + + text := sb.String() + + if hasDecimal || hasExponent { + val, err := strconv.ParseFloat(text, 64) + if err != nil { + l.Error(fmt.Sprintf("invalid real number: %s", text)) + return 0 + } + lval.real = val + return REAL_LITERAL + } + + val, err := strconv.ParseInt(text, 10, 64) + if err != nil { + l.Error(fmt.Sprintf("invalid integer: %s", text)) + return 0 + } + lval.integer = val + return INTEGER_LITERAL +} + +func (l *StreamingLexer) scanIdentifierOrKeyword(first rune, lval *yySymType) int { + var sb strings.Builder + sb.WriteRune(first) + + for { + ch, err := l.peekRune() + if err != nil { + break + } + if unicode.IsLetter(ch) || unicode.IsDigit(ch) || ch == '_' { + if err := l.discardRune(); err != nil { + return 0 + } + sb.WriteRune(ch) + continue + } + break + } + + text := sb.String() + textUpper := strings.ToUpper(text) + + if next, err := l.peekRune(); err == nil && next == '.' { + switch textUpper { + case "MY", "TARGET", "PARENT": + // Consume '.' 
+ if err := l.discardRune(); err != nil { + return 0 + } + peek, err := l.peekRune() + if err == nil && (unicode.IsLetter(peek) || peek == '_') { + if err := l.discardRune(); err != nil { + return 0 + } + sb.WriteRune('.') + sb.WriteRune(peek) + for { + nextCh, err := l.peekRune() + if err != nil { + break + } + if unicode.IsLetter(nextCh) || unicode.IsDigit(nextCh) || nextCh == '_' { + if err := l.discardRune(); err != nil { + return 0 + } + sb.WriteRune(nextCh) + } else { + break + } + } + } + } + } + + scoped := sb.String() + switch strings.ToLower(scoped) { + case "true": + lval.boolean = true + return BOOLEAN_LITERAL + case "false": + lval.boolean = false + return BOOLEAN_LITERAL + case "undefined": + return UNDEFINED + case "error": + return ERROR + case "is": + return IS + case "isnt": + return ISNT + } + + lval.str = scoped + return IDENTIFIER +}