@@ -48,7 +48,10 @@ use std::sync::Arc;
4848```"# ,
4949 standard_argument( name = "str" , prefix = "String" ) ,
5050 argument( name = "delimiter" , description = "String or character to split on." ) ,
51- argument( name = "pos" , description = "Position of the part to return." )
51+ argument(
52+ name = "pos" ,
53+ description = "Position of the part to return (counting from 1). Negative values count backward from the end of the string."
54+ )
5255) ]
5356#[ derive( Debug , PartialEq , Eq , Hash ) ]
5457pub struct SplitPartFunc {
@@ -233,7 +236,7 @@ where
233236 std:: cmp:: Ordering :: Less => {
234237 // Negative index: use rsplit().nth() to efficiently get from the end
235238 // rsplit iterates in reverse, so -1 means first from rsplit (index 0)
236- let idx: usize = ( -n - 1 ) . try_into ( ) . map_err ( |_| {
239+ let idx: usize = ( n . unsigned_abs ( ) - 1 ) . try_into ( ) . map_err ( |_| {
237240 exec_datafusion_err ! (
238241 "split_part index {n} exceeds minimum supported value"
239242 )
@@ -324,6 +327,20 @@ mod tests {
324327 Utf8 ,
325328 StringArray
326329 ) ;
330+ test_function ! (
331+ SplitPartFunc :: new( ) ,
332+ vec![
333+ ColumnarValue :: Scalar ( ScalarValue :: Utf8 ( Some ( String :: from(
334+ "abc~@~def~@~ghi"
335+ ) ) ) ) ,
336+ ColumnarValue :: Scalar ( ScalarValue :: Utf8 ( Some ( String :: from( "~@~" ) ) ) ) ,
337+ ColumnarValue :: Scalar ( ScalarValue :: Int64 ( Some ( i64 :: MIN ) ) ) ,
338+ ] ,
339+ Ok ( Some ( "" ) ) ,
340+ & str ,
341+ Utf8 ,
342+ StringArray
343+ ) ;
327344
328345 Ok ( ( ) )
329346 }
0 commit comments