@@ -23,14 +23,15 @@ import org.apache.spark.internal.{Logging, MDC}
23
23
import org .apache .spark .internal .LogKeys .{FUNCTION_NAME , FUNCTION_PARAM }
24
24
import org .apache .spark .sql .AnalysisException
25
25
import org .apache .spark .sql .catalyst .{InternalRow , SQLConfHelper }
26
- import org .apache .spark .sql .catalyst .analysis .NoSuchFunctionException
26
+ import org .apache .spark .sql .catalyst .analysis .{ NoSuchFunctionException , UnresolvedAttribute }
27
27
import org .apache .spark .sql .catalyst .encoders .EncoderUtils
28
28
import org .apache .spark .sql .catalyst .expressions .objects .{Invoke , StaticInvoke }
29
29
import org .apache .spark .sql .catalyst .plans .logical .LogicalPlan
30
30
import org .apache .spark .sql .connector .catalog .{FunctionCatalog , Identifier }
31
31
import org .apache .spark .sql .connector .catalog .functions ._
32
32
import org .apache .spark .sql .connector .catalog .functions .ScalarFunction .MAGIC_METHOD_NAME
33
- import org .apache .spark .sql .connector .expressions .{BucketTransform , Expression => V2Expression , FieldReference , IdentityTransform , Literal => V2Literal , NamedReference , NamedTransform , NullOrdering => V2NullOrdering , SortDirection => V2SortDirection , SortOrder => V2SortOrder , SortValue , Transform }
33
+ import org .apache .spark .sql .connector .expressions .{BucketTransform , Cast => V2Cast , Expression => V2Expression , FieldReference , GeneralScalarExpression , IdentityTransform , Literal => V2Literal , NamedReference , NamedTransform , NullOrdering => V2NullOrdering , SortDirection => V2SortDirection , SortOrder => V2SortOrder , SortValue , Transform }
34
+ import org .apache .spark .sql .connector .expressions .filter .{AlwaysFalse , AlwaysTrue }
34
35
import org .apache .spark .sql .errors .DataTypeErrors .toSQLId
35
36
import org .apache .spark .sql .errors .QueryCompilationErrors
36
37
import org .apache .spark .sql .types ._
@@ -205,4 +206,171 @@ object V2ExpressionUtils extends SQLConfHelper with Logging {
205
206
None
206
207
}
207
208
}
209
+
210
+ def toCatalyst (expr : V2Expression ): Option [Expression ] = expr match {
211
+ case _ : AlwaysTrue => Some (Literal .TrueLiteral )
212
+ case _ : AlwaysFalse => Some (Literal .FalseLiteral )
213
+ case l : V2Literal [_] => Some (Literal (l.value, l.dataType))
214
+ case r : NamedReference => Some (UnresolvedAttribute (r.fieldNames.toImmutableArraySeq))
215
+ case c : V2Cast => toCatalyst(c.expression).map(Cast (_, c.dataType, ansiEnabled = true ))
216
+ case e : GeneralScalarExpression => convertScalarExpr(e)
217
+ case _ => None
218
+ }
219
+
220
+ private def convertScalarExpr (expr : GeneralScalarExpression ): Option [Expression ] = {
221
+ convertPredicate(expr)
222
+ .orElse(convertConditionalFunc(expr))
223
+ .orElse(convertMathFunc(expr))
224
+ .orElse(convertBitwiseFunc(expr))
225
+ .orElse(convertTrigonometricFunc(expr))
226
+ }
227
+
228
+ private def convertPredicate (expr : GeneralScalarExpression ): Option [Expression ] = {
229
+ expr.name match {
230
+ case " IS_NULL" => convertUnaryExpr(expr, IsNull )
231
+ case " IS_NOT_NULL" => convertUnaryExpr(expr, IsNotNull )
232
+ case " NOT" => convertUnaryExpr(expr, Not )
233
+ case " =" => convertBinaryExpr(expr, EqualTo )
234
+ case " <=>" => convertBinaryExpr(expr, EqualNullSafe )
235
+ case " >" => convertBinaryExpr(expr, GreaterThan )
236
+ case " >=" => convertBinaryExpr(expr, GreaterThanOrEqual )
237
+ case " <" => convertBinaryExpr(expr, LessThan )
238
+ case " <=" => convertBinaryExpr(expr, LessThanOrEqual )
239
+ case " <>" => convertBinaryExpr(expr, (left, right) => Not (EqualTo (left, right)))
240
+ case " AND" => convertBinaryExpr(expr, And )
241
+ case " OR" => convertBinaryExpr(expr, Or )
242
+ case " STARTS_WITH" => convertBinaryExpr(expr, StartsWith )
243
+ case " ENDS_WITH" => convertBinaryExpr(expr, EndsWith )
244
+ case " CONTAINS" => convertBinaryExpr(expr, Contains )
245
+ case " IN" => convertExpr(expr, children => In (children.head, children.tail))
246
+ case _ => None
247
+ }
248
+ }
249
+
250
+ private def convertConditionalFunc (expr : GeneralScalarExpression ): Option [Expression ] = {
251
+ expr.name match {
252
+ case " CASE_WHEN" =>
253
+ convertExpr(expr, children =>
254
+ if (children.length % 2 == 0 ) {
255
+ val branches = children.grouped(2 ).map { case Seq (c, v) => (c, v) }.toSeq
256
+ CaseWhen (branches, None )
257
+ } else {
258
+ val (pairs, last) = children.splitAt(children.length - 1 )
259
+ val branches = pairs.grouped(2 ).map { case Seq (c, v) => (c, v) }.toSeq
260
+ CaseWhen (branches, Some (last.head))
261
+ })
262
+ case _ => None
263
+ }
264
+ }
265
+
266
+ private def convertMathFunc (expr : GeneralScalarExpression ): Option [Expression ] = {
267
+ expr.name match {
268
+ case " +" => convertBinaryExpr(expr, Add (_, _, evalMode = EvalMode .ANSI ))
269
+ case " -" =>
270
+ if (expr.children.length == 1 ) {
271
+ convertUnaryExpr(expr, UnaryMinus (_, failOnError = true ))
272
+ } else if (expr.children.length == 2 ) {
273
+ convertBinaryExpr(expr, Subtract (_, _, evalMode = EvalMode .ANSI ))
274
+ } else {
275
+ None
276
+ }
277
+ case " *" => convertBinaryExpr(expr, Multiply (_, _, evalMode = EvalMode .ANSI ))
278
+ case " /" => convertBinaryExpr(expr, Divide (_, _, evalMode = EvalMode .ANSI ))
279
+ case " %" => convertBinaryExpr(expr, Remainder (_, _, evalMode = EvalMode .ANSI ))
280
+ case " ABS" => convertUnaryExpr(expr, Abs (_, failOnError = true ))
281
+ case " COALESCE" => convertExpr(expr, Coalesce )
282
+ case " GREATEST" => convertExpr(expr, Greatest )
283
+ case " LEAST" => convertExpr(expr, Least )
284
+ case " RAND" =>
285
+ if (expr.children.isEmpty) {
286
+ Some (new Rand ())
287
+ } else if (expr.children.length == 1 ) {
288
+ convertUnaryExpr(expr, new Rand (_))
289
+ } else {
290
+ None
291
+ }
292
+ case " LOG" => convertBinaryExpr(expr, Logarithm )
293
+ case " LOG10" => convertUnaryExpr(expr, Log10 )
294
+ case " LOG2" => convertUnaryExpr(expr, Log2 )
295
+ case " LN" => convertUnaryExpr(expr, Log )
296
+ case " EXP" => convertUnaryExpr(expr, Exp )
297
+ case " POWER" => convertBinaryExpr(expr, Pow )
298
+ case " SQRT" => convertUnaryExpr(expr, Sqrt )
299
+ case " FLOOR" => convertUnaryExpr(expr, Floor )
300
+ case " CEIL" => convertUnaryExpr(expr, Ceil )
301
+ case " ROUND" => convertBinaryExpr(expr, Round (_, _, ansiEnabled = true ))
302
+ case " CBRT" => convertUnaryExpr(expr, Cbrt )
303
+ case " DEGREES" => convertUnaryExpr(expr, ToDegrees )
304
+ case " RADIANS" => convertUnaryExpr(expr, ToRadians )
305
+ case " SIGN" => convertUnaryExpr(expr, Signum )
306
+ case " WIDTH_BUCKET" =>
307
+ convertExpr(
308
+ expr,
309
+ children => WidthBucket (children(0 ), children(1 ), children(2 ), children(3 )))
310
+ case _ => None
311
+ }
312
+ }
313
+
314
+ private def convertTrigonometricFunc (expr : GeneralScalarExpression ): Option [Expression ] = {
315
+ expr.name match {
316
+ case " SIN" => convertUnaryExpr(expr, Sin )
317
+ case " SINH" => convertUnaryExpr(expr, Sinh )
318
+ case " COS" => convertUnaryExpr(expr, Cos )
319
+ case " COSH" => convertUnaryExpr(expr, Cosh )
320
+ case " TAN" => convertUnaryExpr(expr, Tan )
321
+ case " TANH" => convertUnaryExpr(expr, Tanh )
322
+ case " COT" => convertUnaryExpr(expr, Cot )
323
+ case " ASIN" => convertUnaryExpr(expr, Asin )
324
+ case " ASINH" => convertUnaryExpr(expr, Asinh )
325
+ case " ACOS" => convertUnaryExpr(expr, Acos )
326
+ case " ACOSH" => convertUnaryExpr(expr, Acosh )
327
+ case " ATAN" => convertUnaryExpr(expr, Atan )
328
+ case " ATANH" => convertUnaryExpr(expr, Atanh )
329
+ case " ATAN2" => convertBinaryExpr(expr, Atan2 )
330
+ case _ => None
331
+ }
332
+ }
333
+
334
+ private def convertBitwiseFunc (expr : GeneralScalarExpression ): Option [Expression ] = {
335
+ expr.name match {
336
+ case " ~" => convertUnaryExpr(expr, BitwiseNot )
337
+ case " &" => convertBinaryExpr(expr, BitwiseAnd )
338
+ case " |" => convertBinaryExpr(expr, BitwiseOr )
339
+ case " ^" => convertBinaryExpr(expr, BitwiseXor )
340
+ case _ => None
341
+ }
342
+ }
343
+
344
+ private def convertUnaryExpr (
345
+ expr : GeneralScalarExpression ,
346
+ catalystExprBuilder : Expression => Expression ): Option [Expression ] = {
347
+ expr.children match {
348
+ case Array (child) => toCatalyst(child).map(catalystExprBuilder)
349
+ case _ => None
350
+ }
351
+ }
352
+
353
+ private def convertBinaryExpr (
354
+ expr : GeneralScalarExpression ,
355
+ catalystExprBuilder : (Expression , Expression ) => Expression ): Option [Expression ] = {
356
+ expr.children match {
357
+ case Array (left, right) =>
358
+ for {
359
+ catalystLeft <- toCatalyst(left)
360
+ catalystRight <- toCatalyst(right)
361
+ } yield catalystExprBuilder(catalystLeft, catalystRight)
362
+ case _ => None
363
+ }
364
+ }
365
+
366
+ private def convertExpr (
367
+ expr : GeneralScalarExpression ,
368
+ catalystExprBuilder : Seq [Expression ] => Expression ): Option [Expression ] = {
369
+ val catalystChildren = expr.children.flatMap(toCatalyst).toImmutableArraySeq
370
+ if (expr.children.length == catalystChildren.length) {
371
+ Some (catalystExprBuilder(catalystChildren))
372
+ } else {
373
+ None
374
+ }
375
+ }
208
376
}
0 commit comments