Skip to content

Commit dc1a4b7

Browse files
authored
Merge pull request #1467 from Kotlin/cumSum-fixes
Cum sum fixes
2 parents bef0c40 + c8af8d1 commit dc1a4b7

File tree

3 files changed

+59
-11
lines changed
  • core/src

3 files changed

+59
-11
lines changed

core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/cumSum.kt

Lines changed: 19 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -24,15 +24,19 @@ import kotlin.reflect.KProperty
2424
* from the first cell to the last cell.
2525
*
2626
* __NOTE:__ If the column contains nullable values and [skipNA\] is set to `true`,
27-
* null and NaN values are skipped when computing the cumulative sum.
28-
* When false, all values after the first NA will be NaN (for Double and Float columns)
29-
* or null (for integer columns).
27+
* `null` and `NaN` values are skipped when computing the cumulative sum.
28+
* When `false`, all values after the first `NA` will be `NaN` (for [Double] and [Float] columns)
29+
* or `null` (for other columns).
3030
*
31-
* {@get [CumSumDocs.CUMSUM_PARAM] @param [columns\]
32-
* The names of the columns to apply cumSum operation.}
31+
* `cumSum` only works on columns that contain solely primitive numbers.
3332
*
34-
* @param [skipNA\] Whether to skip null and NaN values (default: `true`).
33+
* Similar to [sum][sum], [Byte][Byte]- and [Short][Short]-columns are converted to [Int][Int].
3534
*
35+
* {@get [CumSumDocs.CUMSUM_PARAM] @param [columns\] The selection of the columns to apply the `cumSum` operation to.
36+
* If not provided, `cumSum` will be applied to all primitive columns [at any depth][ColumnsSelectionDsl.colsAtAnyDepth].
37+
* }
38+
*
39+
* @param [skipNA\] Whether to skip `null` and `NaN` values (default: `true`).
3640
* @return A new {@get [CumSumDocs.DATA_TYPE]} of the same type with the cumulative sums.
3741
*
3842
* {@get [CumSumDocs.CUMSUM_PARAM] @see [Selecting Columns][SelectSelectingOptions].}
@@ -41,8 +45,11 @@ import kotlin.reflect.KProperty
4145
@ExcludeFromSources
4246
@Suppress("ClassName")
4347
private interface CumSumDocs {
48+
49+
// Can be emptied to disable information about selecting columns
4450
interface CUMSUM_PARAM
4551

52+
// Either [DataColumn] or [DataFrame]
4653
interface DATA_TYPE
4754
}
4855

@@ -157,10 +164,11 @@ public fun <T> DataFrame<T>.cumSum(
157164
* {@set [CumSumDocs.DATA_TYPE] [DataFrame]}
158165
* {@set [CumSumDocs.CUMSUM_PARAM]}
159166
*/
167+
@Refine
168+
@Interpretable("DataFrameCumSum0")
160169
public fun <T> DataFrame<T>.cumSum(skipNA: Boolean = defaultCumSumSkipNA): DataFrame<T> =
161170
cumSum(skipNA) {
162-
// TODO keep at any depth?
163-
colsAtAnyDepth().filter { it.isNumber() }.cast()
171+
colsAtAnyDepth().filter { it.isPrimitiveOrMixedNumber() }.cast()
164172
}
165173

166174
// endregion
@@ -212,10 +220,11 @@ public fun <T, G> GroupBy<T, G>.cumSum(
212220
* {@set [CumSumDocs.DATA_TYPE] [GroupBy]}
213221
* {@set [CumSumDocs.CUMSUM_PARAM]}
214222
*/
223+
@Refine
224+
@Interpretable("GroupByCumSum0")
215225
public fun <T, G> GroupBy<T, G>.cumSum(skipNA: Boolean = defaultCumSumSkipNA): GroupBy<T, G> =
216226
cumSum(skipNA) {
217-
// TODO keep at any depth?
218-
colsAtAnyDepth().filter { it.isNumber() }.cast()
227+
colsAtAnyDepth().filter { it.isPrimitiveOrMixedNumber() }.cast()
219228
}
220229

221230
// endregion

core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/math/cumsum.kt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -292,7 +292,7 @@ internal fun DataColumn<Long?>.cumSumImpl(skipNA: Boolean): DataColumn<Long?> {
292292
* T : Number(?) -> T(?)
293293
*/
294294
public val cumSumTypeConversion: CalculateReturnType = { type, _ ->
295-
when (val type = type.withNullability(false)) {
295+
when (type.withNullability(false)) {
296296
// type changes to Int, carrying nullability
297297
typeOf<Short>(), typeOf<Byte>() -> typeOf<Int>().withNullability(type.isMarkedNullable)
298298

core/src/test/kotlin/org/jetbrains/kotlinx/dataframe/statistics/cumsum.kt

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
package org.jetbrains.kotlinx.dataframe.statistics
22

3+
import io.kotest.assertions.throwables.shouldThrow
34
import io.kotest.matchers.shouldBe
45
import org.jetbrains.kotlinx.dataframe.DataColumn
56
import org.jetbrains.kotlinx.dataframe.api.columnOf
@@ -8,7 +9,10 @@ import org.jetbrains.kotlinx.dataframe.api.cumSum
89
import org.jetbrains.kotlinx.dataframe.api.dataFrameOf
910
import org.jetbrains.kotlinx.dataframe.api.groupBy
1011
import org.jetbrains.kotlinx.dataframe.api.map
12+
import org.jetbrains.kotlinx.dataframe.impl.nullableNothingType
13+
import org.jetbrains.kotlinx.dataframe.math.cumSumTypeConversion
1114
import org.junit.Test
15+
import kotlin.reflect.typeOf
1216

1317
@Suppress("ktlint:standard:argument-list-wrapping")
1418
class CumsumTests {
@@ -92,4 +96,39 @@ class CumsumTests {
9296
"c", 4,
9397
)
9498
}
99+
100+
@Test
101+
fun `df cumSum default`() {
102+
val df = dataFrameOf(
103+
"doubles" to columnOf(1.0, 2.0, null),
104+
"shorts" to columnOf(1.toShort(), 2.toShort(), null),
105+
"bigInts" to columnOf(1.toBigInteger(), 2.toBigInteger(), null),
106+
"mixed" to columnOf<Number?>(1.0, 2, null),
107+
)
108+
109+
val res = df.cumSum()
110+
111+
// works for Doubles, turns nulls into NaNs
112+
res["doubles"].values() shouldBe columnOf(1.0, 3.0, Double.NaN).values()
113+
// works for Shorts, turns into Ints, skips nulls
114+
res["shorts"].values() shouldBe columnOf(1, 3, null).values()
115+
// does not work for big numbers, keeps them as is
116+
res["bigInts"].values() shouldBe columnOf(1.toBigInteger(), 2.toBigInteger(), null).values()
117+
// works for mixed columns of primitives, number-unifies them; in this case to Doubles
118+
res["mixed"].values() shouldBe columnOf(1.0, 3.0, Double.NaN).values()
119+
}
120+
121+
@Test
122+
fun `cumSumTypeConversion tests`() {
123+
cumSumTypeConversion(typeOf<Int>(), false) shouldBe typeOf<Int>()
124+
cumSumTypeConversion(typeOf<Long?>(), false) shouldBe typeOf<Long?>()
125+
cumSumTypeConversion(typeOf<Short?>(), false) shouldBe typeOf<Int?>()
126+
cumSumTypeConversion(typeOf<Byte>(), false) shouldBe typeOf<Int>()
127+
cumSumTypeConversion(typeOf<Float?>(), false) shouldBe typeOf<Float>()
128+
cumSumTypeConversion(typeOf<Double?>(), false) shouldBe typeOf<Double>()
129+
cumSumTypeConversion(typeOf<Double>(), false) shouldBe typeOf<Double>()
130+
cumSumTypeConversion(nullableNothingType, false) shouldBe nullableNothingType
131+
132+
shouldThrow<IllegalStateException> { cumSumTypeConversion(typeOf<String>(), false) }
133+
}
95134
}

0 commit comments

Comments
 (0)