From 8be4415d0f39cd794de227f9299a11e3cbd46dc3 Mon Sep 17 00:00:00 2001 From: utahta Date: Thu, 12 Sep 2024 10:42:15 +0900 Subject: [PATCH] fix --- grpc/federation/cel.go | 88 ++++++++++++++++++++++-------------------- 1 file changed, 46 insertions(+), 42 deletions(-) diff --git a/grpc/federation/cel.go b/grpc/federation/cel.go index 2fca43d4..34c463ec 100644 --- a/grpc/federation/cel.go +++ b/grpc/federation/cel.go @@ -744,7 +744,7 @@ func createCELAst(req *EvalCELRequest, env *cel.Env) (*cel.Ast, error) { if err != nil { return nil, err } - if newNullValueFuncReplacer().Replace(checkedExpr) { + if newComparingNullResolver().Resolve(checkedExpr) { ca, err := celast.ToAST(checkedExpr) if err != nil { return nil, err @@ -860,20 +860,19 @@ func SetCELValue[T any](ctx context.Context, param *SetCELValueParam[T]) error { return nil } -// nullValueFuncReplacer is a feature that allows to compare typed null and null value correctly. +// comparingNullResolver is a feature that allows to compare typed null and null value correctly. // It parses the expression and wraps the message with grpc.federation.cast.null_value if the message is compared to null. -type nullValueFuncReplacer struct { +type comparingNullResolver struct { checkedExpr *exprpb.CheckedExpr lastID int64 - replaced bool - unsupported bool + resolved bool } -func newNullValueFuncReplacer() *nullValueFuncReplacer { - return &nullValueFuncReplacer{} +func newComparingNullResolver() *comparingNullResolver { + return &comparingNullResolver{} } -func (r *nullValueFuncReplacer) init(checkedExpr *exprpb.CheckedExpr) { +func (r *comparingNullResolver) init(checkedExpr *exprpb.CheckedExpr) { var lastID int64 for k := range checkedExpr.GetReferenceMap() { if lastID < k { @@ -887,30 +886,54 @@ func (r *nullValueFuncReplacer) init(checkedExpr *exprpb.CheckedExpr) { } r.checkedExpr = checkedExpr r.lastID = lastID - r.replaced = false - r.unsupported = false + r.resolved = false } -func (r *nullValueFuncReplacer) nextID() int64 { +func (r *comparingNullResolver) nextID() int64 { r.lastID++ return r.lastID } -func (r *nullValueFuncReplacer) Replace(checkedExpr *exprpb.CheckedExpr) bool { +func (r *comparingNullResolver) Resolve(checkedExpr *exprpb.CheckedExpr) bool { r.init(checkedExpr) newExprVisitor().Visit(checkedExpr.GetExpr(), func(e *exprpb.Expr) { switch e.GetExprKind().(type) { case *exprpb.Expr_CallExpr: - r.replaceCall(e) - case *exprpb.Expr_ComprehensionExpr: - // Comprehension is not supported by parser.Unparse. - r.unsupported = true + r.resolveCallExpr(e) } }) - return r.replaced && !r.unsupported + return r.resolved } -func (r *nullValueFuncReplacer) replaceCall(e *exprpb.Expr) { +func (r *comparingNullResolver) resolveCallExpr(e *exprpb.Expr) { + target := r.lookupComparingNullCallExpr(e) + if target == nil { + return + } + newID := r.nextID() + newExprKind := &exprpb.Expr_CallExpr{ + CallExpr: &exprpb.Expr_Call{ + Function: grpcfedcel.CastNullValueFunc, + Args: []*exprpb.Expr{ + { + Id: target.GetId(), + ExprKind: target.GetExprKind(), + }, + }, + }, + } + target.Id = newID + target.ExprKind = newExprKind + r.checkedExpr.GetReferenceMap()[newID] = &exprpb.Reference{ + OverloadId: []string{grpcfedcel.CastNullValueFunc}, + } + r.checkedExpr.GetTypeMap()[newID] = &exprpb.Type{ + TypeKind: &exprpb.Type_Dyn{}, + } + r.resolved = true +} + +func (r *comparingNullResolver) lookupComparingNullCallExpr(e *exprpb.Expr) *exprpb.Expr { call := e.GetCallExpr() fnName := call.GetFunction() if fnName == operators.Equals || fnName == operators.NotEquals { @@ -923,38 +946,19 @@ func (r *nullValueFuncReplacer) replaceCall(e *exprpb.Expr) { if _, ok := rhs.GetConstExpr().GetConstantKind().(*exprpb.Constant_NullValue); ok { if target != nil { // maybe null == null - return + return nil } target = lhs } if target == nil { - return + return nil } if target.GetCallExpr() != nil && target.GetCallExpr().GetFunction() == grpcfedcel.CastNullValueFunc { - return - } - newID := r.nextID() - newExprKind := &exprpb.Expr_CallExpr{ - CallExpr: &exprpb.Expr_Call{ - Function: grpcfedcel.CastNullValueFunc, - Args: []*exprpb.Expr{ - { - Id: target.GetId(), - ExprKind: target.GetExprKind(), - }, - }, - }, - } - target.Id = newID - target.ExprKind = newExprKind - r.checkedExpr.GetReferenceMap()[newID] = &exprpb.Reference{ - OverloadId: []string{grpcfedcel.CastNullValueFunc}, + return nil } - r.checkedExpr.GetTypeMap()[newID] = &exprpb.Type{ - TypeKind: &exprpb.Type_Dyn{}, - } - r.replaced = true + return target } + return nil } type exprVisitor struct {