Skip to content

Commit

Permalink
fix
Browse files Browse the repository at this point in the history
  • Loading branch information
utahta committed Sep 12, 2024
1 parent 350073c commit 8be4415
Showing 1 changed file with 46 additions and 42 deletions.
88 changes: 46 additions & 42 deletions grpc/federation/cel.go
Original file line number Diff line number Diff line change
Expand Up @@ -744,7 +744,7 @@ func createCELAst(req *EvalCELRequest, env *cel.Env) (*cel.Ast, error) {
if err != nil {
return nil, err
}
if newNullValueFuncReplacer().Replace(checkedExpr) {
if newComparingNullResolver().Resolve(checkedExpr) {
ca, err := celast.ToAST(checkedExpr)
if err != nil {
return nil, err
Expand Down Expand Up @@ -860,20 +860,19 @@ func SetCELValue[T any](ctx context.Context, param *SetCELValueParam[T]) error {
return nil
}

// nullValueFuncReplacer is a feature that allows to compare typed null and null value correctly.
// comparingNullResolver is a feature that allows to compare typed null and null value correctly.
// It parses the expression and wraps the message with grpc.federation.cast.null_value if the message is compared to null.
type nullValueFuncReplacer struct {
type comparingNullResolver struct {
checkedExpr *exprpb.CheckedExpr
lastID int64
replaced bool
unsupported bool
resolved bool
}

func newNullValueFuncReplacer() *nullValueFuncReplacer {
return &nullValueFuncReplacer{}
func newComparingNullResolver() *comparingNullResolver {
return &comparingNullResolver{}
}

func (r *nullValueFuncReplacer) init(checkedExpr *exprpb.CheckedExpr) {
func (r *comparingNullResolver) init(checkedExpr *exprpb.CheckedExpr) {
var lastID int64
for k := range checkedExpr.GetReferenceMap() {
if lastID < k {
Expand All @@ -887,30 +886,54 @@ func (r *nullValueFuncReplacer) init(checkedExpr *exprpb.CheckedExpr) {
}
r.checkedExpr = checkedExpr
r.lastID = lastID
r.replaced = false
r.unsupported = false
r.resolved = false
}

func (r *nullValueFuncReplacer) nextID() int64 {
func (r *comparingNullResolver) nextID() int64 {
r.lastID++
return r.lastID
}

func (r *nullValueFuncReplacer) Replace(checkedExpr *exprpb.CheckedExpr) bool {
func (r *comparingNullResolver) Resolve(checkedExpr *exprpb.CheckedExpr) bool {
r.init(checkedExpr)
newExprVisitor().Visit(checkedExpr.GetExpr(), func(e *exprpb.Expr) {
switch e.GetExprKind().(type) {
case *exprpb.Expr_CallExpr:
r.replaceCall(e)
case *exprpb.Expr_ComprehensionExpr:
// Comprehension is not supported by parser.Unparse.
r.unsupported = true
r.resolveCallExpr(e)
}
})
return r.replaced && !r.unsupported
return r.resolved
}

func (r *nullValueFuncReplacer) replaceCall(e *exprpb.Expr) {
func (r *comparingNullResolver) resolveCallExpr(e *exprpb.Expr) {
target := r.lookupComparingNullCallExpr(e)
if target == nil {
return
}
newID := r.nextID()
newExprKind := &exprpb.Expr_CallExpr{
CallExpr: &exprpb.Expr_Call{
Function: grpcfedcel.CastNullValueFunc,
Args: []*exprpb.Expr{
{
Id: target.GetId(),
ExprKind: target.GetExprKind(),
},
},
},
}
target.Id = newID
target.ExprKind = newExprKind
r.checkedExpr.GetReferenceMap()[newID] = &exprpb.Reference{
OverloadId: []string{grpcfedcel.CastNullValueFunc},
}
r.checkedExpr.GetTypeMap()[newID] = &exprpb.Type{
TypeKind: &exprpb.Type_Dyn{},
}
r.resolved = true
}

func (r *comparingNullResolver) lookupComparingNullCallExpr(e *exprpb.Expr) *exprpb.Expr {
call := e.GetCallExpr()
fnName := call.GetFunction()
if fnName == operators.Equals || fnName == operators.NotEquals {
Expand All @@ -923,38 +946,19 @@ func (r *nullValueFuncReplacer) replaceCall(e *exprpb.Expr) {
if _, ok := rhs.GetConstExpr().GetConstantKind().(*exprpb.Constant_NullValue); ok {
if target != nil {
// maybe null == null
return
return nil
}
target = lhs
}
if target == nil {
return
return nil
}
if target.GetCallExpr() != nil && target.GetCallExpr().GetFunction() == grpcfedcel.CastNullValueFunc {
return
}
newID := r.nextID()
newExprKind := &exprpb.Expr_CallExpr{
CallExpr: &exprpb.Expr_Call{
Function: grpcfedcel.CastNullValueFunc,
Args: []*exprpb.Expr{
{
Id: target.GetId(),
ExprKind: target.GetExprKind(),
},
},
},
}
target.Id = newID
target.ExprKind = newExprKind
r.checkedExpr.GetReferenceMap()[newID] = &exprpb.Reference{
OverloadId: []string{grpcfedcel.CastNullValueFunc},
return nil
}
r.checkedExpr.GetTypeMap()[newID] = &exprpb.Type{
TypeKind: &exprpb.Type_Dyn{},
}
r.replaced = true
return target
}
return nil
}

type exprVisitor struct {
Expand Down

0 comments on commit 8be4415

Please sign in to comment.