diff --git a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/drop.kt b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/drop.kt
index ddd3eda148..74eb40b046 100644
--- a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/drop.kt
+++ b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/drop.kt
@@ -16,6 +16,7 @@ import org.jetbrains.kotlinx.dataframe.columns.size
 import org.jetbrains.kotlinx.dataframe.documentation.CommonTakeAndDropDocs
 import org.jetbrains.kotlinx.dataframe.documentation.CommonTakeAndDropWhileDocs
 import org.jetbrains.kotlinx.dataframe.documentation.TakeAndDropColumnsSelectionDslGrammar
+import org.jetbrains.kotlinx.dataframe.impl.api.GroupByEntryImpl
 import org.jetbrains.kotlinx.dataframe.impl.columns.transform
 import org.jetbrains.kotlinx.dataframe.impl.columns.transformSingle
 import org.jetbrains.kotlinx.dataframe.index
@@ -73,6 +74,42 @@ public inline fun DataFrame.dropWhile(predicate: RowFilter): DataFrame
 
 // endregion
 
+// region GroupBy
+
+public inline fun GroupBy.dropEntries(crossinline predicate: GroupByEntryFilter): GroupBy =
+    filterEntries { !predicate(it, it) } // keeps only the entries for which [predicate] returns false
+
+/**
+ * Returns an adjusted [GroupBy] containing all entries except the first [n] entries.
+ *
+ * @throws IllegalArgumentException if [n] is negative.
+ */
+public fun GroupBy.dropEntries(n: Int): GroupBy {
+    require(n >= 0) { "Requested rows count $n is less than zero." }
+    return toDataFrame().drop(n).asGroupBy(groups.name()).cast()
+}
+
+/**
+ * Returns an adjusted [GroupBy] containing all entries except the last [n] entries.
+ *
+ * @throws IllegalArgumentException if [n] is negative.
+ */
+public fun GroupBy.dropLastEntries(n: Int): GroupBy {
+    require(n >= 0) { "Requested rows count $n is less than zero." }
+    return toDataFrame().dropLast(n).asGroupBy(groups.name()).cast() // dropLast, not drop: removes from the end
+}
+
+/**
+ * Returns an adjusted [GroupBy] containing all entries except the first entries that satisfy the given [predicate].
+ */
+public inline fun GroupBy.dropEntriesWhile(predicate: GroupByEntryFilter): GroupBy =
+    toDataFrame().dropWhile {
+        val entry = GroupByEntryImpl(this@dropEntriesWhile.keys[it.index()], groups) // keys-only row, as in entries()
+        predicate(entry, entry)
+    }.asGroupBy(groups.name()).cast()
+
+// endregion
+
 // region ColumnsSelectionDsl
 
 /**
diff --git a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/forEach.kt b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/forEach.kt
index 155ad6c142..f825c7ed8f 100644
--- a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/forEach.kt
+++ b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/forEach.kt
@@ -3,7 +3,6 @@ package org.jetbrains.kotlinx.dataframe.api
 import org.jetbrains.kotlinx.dataframe.DataColumn
 import org.jetbrains.kotlinx.dataframe.DataFrame
 import org.jetbrains.kotlinx.dataframe.RowExpression
-import org.jetbrains.kotlinx.dataframe.columns.values
 
 // region DataColumn
 
@@ -21,10 +20,16 @@ public inline fun DataFrame.forEach(action: RowExpression): Unit
 
 // region GroupBy
 
+@Deprecated(
+    "Replaced with forEachEntry",
+    ReplaceWith("forEachEntry { val key = it\nval group = it.group()\nbody(key, group) }"),
+)
 public inline fun GroupBy.forEach(body: (GroupBy.Entry) -> Unit): Unit =
     keys.forEach { key ->
         val group = groups[key.index()]
         body(GroupBy.Entry(key, group))
     }
 
+public inline fun GroupBy.forEachEntry(body: (GroupByEntry) -> Unit): Unit =
+    entriesAsSequence().forEach(body) // calls [body] once for every keys+group entry
 // endregion
diff --git a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/groupBy.kt b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/groupBy.kt
index 5c670a77ef..edb16175f7 100644
--- a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/groupBy.kt
+++ b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/groupBy.kt
@@ -13,6 +13,7 @@ import org.jetbrains.kotlinx.dataframe.annotations.Refine
 import org.jetbrains.kotlinx.dataframe.columns.FrameColumn
 import org.jetbrains.kotlinx.dataframe.columns.toColumnSet
 import
org.jetbrains.kotlinx.dataframe.impl.aggregation.PivotImpl
+import org.jetbrains.kotlinx.dataframe.impl.api.GroupByEntryImpl
 import org.jetbrains.kotlinx.dataframe.impl.api.getPivotColumnPaths
 import org.jetbrains.kotlinx.dataframe.impl.api.groupByImpl
 import org.jetbrains.kotlinx.dataframe.util.DEPRECATED_ACCESS_API
@@ -71,10 +72,13 @@ public fun Pivot.groupByOther(): PivotGroupBy {
 
 // endregion
 
+@Deprecated("Replaced by GroupByEntrySelector")
 public typealias GroupedRowSelector = GroupedDataRow.(GroupedDataRow) -> R
 
+@Deprecated("Replaced by GroupByEntryFilter")
 public typealias GroupedRowFilter = GroupedRowSelector
 
+@Deprecated("Replaced by GroupByEntry")
 public interface GroupedDataRow : DataRow {
 
     public fun group(): DataFrame
@@ -82,6 +86,32 @@ public interface GroupedDataRow : DataRow {
 public val GroupedDataRow.group: DataFrame
     get() = group()
 
+/**
+ * Represents a single keys+group combination (one entry) of a [GroupBy] instance.
+ *
+ * `this` is a [DataRow] holding the key values of the entry, while the [group()][group]
+ * function returns the group that corresponds to those keys.
+ *
+ * For example:
+ * ```kotlin
+ * df.groupBy { name and age }.forEachEntry { // this|it: GroupByEntry ->
+ *     println("There are \${group().rowsCount()} instances of \$name")
+ * }
+ * ```
+ */
+public interface GroupByEntry : DataRow {
+
+    /** Returns the [DataFrame] representing the group that corresponds to the keys of this entry. */
+    public fun group(): DataFrame
+
+    // TODO? NOTE(review): the open question is not stated; presumably whether keys() should be public API - confirm
+    public fun keys(): Map = this.toMap()
+}
+
+public typealias GroupByEntrySelector = GroupByEntry.(GroupByEntry) -> R
+public typealias GroupByEntryFilter = GroupByEntrySelector
+
+@Deprecated("Replaced by GroupByEntry")
 public data class GroupWithKey(val key: DataRow, val group: DataFrame)
 
 public interface GroupBy : Grouped {
@@ -92,12 +122,16 @@ public interface GroupBy : Grouped {
 
     public fun updateGroups(transform: Selector, DataFrame>): GroupBy
 
+    @Deprecated("Replaced by filterEntries", ReplaceWith("filterEntries(predicate)"))
     public fun filter(predicate: GroupedRowFilter): GroupBy
 
+    public fun filterEntries(predicate: GroupByEntryFilter): GroupBy
+
     @Refine
     @Interpretable("GroupByToDataFrame")
     public fun toDataFrame(groupedColumnName: String? = null): DataFrame
 
+    @Deprecated("Replaced by GroupByEntry")
     public data class Entry(val key: DataRow, val group: DataFrame)
 
     public companion object {
@@ -117,3 +151,25 @@ public class ReducedGroupBy(
 @PublishedApi
 internal fun GroupBy.reduce(reducer: Selector, DataRow?>): ReducedGroupBy =
     ReducedGroupBy(this, reducer)
+
+/**
+ * Returns the number of entries (keys+group combinations) of this [GroupBy], i.e. the size of its groups column.
+ *
+ * @return The number of entries in this [GroupBy].
+ */
+public fun GroupBy<*, *>.rowsCount(): Int = groups.size()
+
+/**
+ * Retrieves all keys+group [entries][GroupByEntry] inside this [GroupBy]-[DataFrame].
+ * @see entriesAsSequence
+ */
+public fun GroupBy.entries(): List> = entriesAsSequence().toList()
+
+/**
+ * Retrieves all keys+group [entries][GroupByEntry] inside this [GroupBy]-[DataFrame] as a [Sequence].
+ * @see entries
+ */
+public fun GroupBy.entriesAsSequence(): Sequence> =
+    keys.asSequence().map {
+        GroupByEntryImpl(it, groups) // pairs each key row with the column that holds all groups
+    }
diff --git a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/map.kt b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/map.kt
index b030ade7b9..68efd76d9a 100644
--- a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/map.kt
+++ b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/map.kt
@@ -12,6 +12,7 @@ import org.jetbrains.kotlinx.dataframe.annotations.Interpretable
 import org.jetbrains.kotlinx.dataframe.annotations.Refine
 import org.jetbrains.kotlinx.dataframe.columns.ColumnReference
 import org.jetbrains.kotlinx.dataframe.columns.FrameColumn
+import org.jetbrains.kotlinx.dataframe.impl.api.GroupByEntryImpl
 import org.jetbrains.kotlinx.dataframe.impl.columnName
 import org.jetbrains.kotlinx.dataframe.impl.columns.createComputedColumnReference
 import org.jetbrains.kotlinx.dataframe.impl.columns.newColumn
@@ -141,6 +142,10 @@ public inline fun DataFrame.mapToFrame(body: AddDsl.() -> Unit): AnyFr
 
 // region GroupBy
 
+@Deprecated(
+    "Replaced by mapEntries",
+    ReplaceWith("mapEntries { val key = it\nval group = it.group()\nbody(key, group) }"),
+)
 public inline fun GroupBy.map(body: Selector, R>): List =
     keys.rows().mapIndexedNotNull { index, row ->
         val group = groups[index]
@@ -148,10 +153,31 @@ public inline fun GroupBy.map(body: Selector,
         body(g, g)
     }
 
+@Deprecated(
+    "Replaced by mapEntriesToRows",
+    ReplaceWith("mapEntriesToRows { val key = it\nval group = it.group()\nbody(key, group) }"),
+)
 public fun GroupBy.mapToRows(body: Selector, DataRow?>): DataFrame =
     map(body).concat()
 
+@Deprecated(
+    "Replaced by mapEntriesToFrames",
+    ReplaceWith("mapEntriesToFrames { val key = it\nval group = it.group()\nbody(key, group) }"),
+)
 public fun GroupBy.mapToFrames(body: Selector, DataFrame>): FrameColumn =
     DataColumn.createFrameColumn(groups.name, map(body))
 
+public inline fun GroupBy.mapEntries(body:
GroupByEntrySelector): List =
+    keys.rows().mapNotNull { row -> // entries for which [body] returns null are left out of the result
+        val entry = GroupByEntryImpl(row, groups)
+        body(entry, entry)
+    }
+
+public fun GroupBy.mapEntriesToRows(body: GroupByEntrySelector?>): DataFrame =
+    mapEntries(body).concat() // concatenates the rows produced per entry into one DataFrame
+
+public fun GroupBy.mapEntriesToFrames(
+    body: GroupByEntrySelector>,
+): FrameColumn = DataColumn.createFrameColumn(groups.name, mapEntries(body)) // keeps the groups column name
+
 // endregion
diff --git a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/take.kt b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/take.kt
index 1a95b426b3..4c472876c1 100644
--- a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/take.kt
+++ b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/take.kt
@@ -7,7 +7,6 @@ import org.jetbrains.kotlinx.dataframe.DataRow
 import org.jetbrains.kotlinx.dataframe.RowFilter
 import org.jetbrains.kotlinx.dataframe.annotations.AccessApiOverload
 import org.jetbrains.kotlinx.dataframe.annotations.Interpretable
-import org.jetbrains.kotlinx.dataframe.api.ColumnsSelectionDsl
 import org.jetbrains.kotlinx.dataframe.columns.ColumnPath
 import org.jetbrains.kotlinx.dataframe.columns.ColumnSet
 import org.jetbrains.kotlinx.dataframe.columns.ColumnWithPath
@@ -16,9 +15,9 @@ import org.jetbrains.kotlinx.dataframe.columns.size
 import org.jetbrains.kotlinx.dataframe.documentation.CommonTakeAndDropDocs
 import org.jetbrains.kotlinx.dataframe.documentation.CommonTakeAndDropWhileDocs
 import org.jetbrains.kotlinx.dataframe.documentation.TakeAndDropColumnsSelectionDslGrammar
+import org.jetbrains.kotlinx.dataframe.impl.api.GroupByEntryImpl
 import org.jetbrains.kotlinx.dataframe.impl.columns.transform
 import org.jetbrains.kotlinx.dataframe.impl.columns.transformSingle
-import org.jetbrains.kotlinx.dataframe.index
 import org.jetbrains.kotlinx.dataframe.nrow
 import org.jetbrains.kotlinx.dataframe.util.DEPRECATED_ACCESS_API
 import kotlin.reflect.KProperty
@@ -66,6 +65,39 @@ public inline fun DataFrame.takeWhile(predicate: RowFilter):
DataFrame
 
 // endregion
 
+// region GroupBy
+
+/**
+ * Returns an adjusted [GroupBy] containing the first [n] entries.
+ *
+ * @throws IllegalArgumentException if [n] is negative.
+ */
+public fun GroupBy.takeEntries(n: Int): GroupBy {
+    require(n >= 0) { "Requested rows count $n is less than zero." }
+    return toDataFrame().take(n).asGroupBy(groups.name()).cast()
+}
+
+/**
+ * Returns an adjusted [GroupBy] containing the last [n] entries.
+ *
+ * @throws IllegalArgumentException if [n] is negative.
+ */
+public fun GroupBy.takeLastEntries(n: Int): GroupBy {
+    require(n >= 0) { "Requested rows count $n is less than zero." }
+    return toDataFrame().takeLast(n).asGroupBy(groups.name()).cast()
+}
+
+/**
+ * Returns an adjusted [GroupBy] containing the first entries that satisfy the given [predicate].
+ */
+public inline fun GroupBy.takeEntriesWhile(predicate: GroupByEntryFilter): GroupBy =
+    toDataFrame().takeWhile {
+        val entry = GroupByEntryImpl(this@takeEntriesWhile.keys[it.index()], groups) // keys-only row, as in entries()
+        predicate(entry, entry)
+    }.asGroupBy(groups.name()).cast()
+
+// endregion
+
 // region ColumnsSelectionDsl
 
 /**
diff --git a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/GroupByImpl.kt b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/GroupByImpl.kt
index eab8acff76..f964f93ec5 100644
--- a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/GroupByImpl.kt
+++ b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/GroupByImpl.kt
@@ -7,6 +7,7 @@ import org.jetbrains.kotlinx.dataframe.Selector
 import org.jetbrains.kotlinx.dataframe.aggregation.AggregateGroupedBody
 import org.jetbrains.kotlinx.dataframe.aggregation.NamedValue
 import org.jetbrains.kotlinx.dataframe.api.GroupBy
+import org.jetbrains.kotlinx.dataframe.api.GroupByEntryFilter
 import org.jetbrains.kotlinx.dataframe.api.GroupedRowFilter
 import org.jetbrains.kotlinx.dataframe.api.asGroupBy
 import org.jetbrains.kotlinx.dataframe.api.concat
@@ -18,11 +19,13 @@ import org.jetbrains.kotlinx.dataframe.api.isColumnGroup
 import
org.jetbrains.kotlinx.dataframe.api.pathOf
 import org.jetbrains.kotlinx.dataframe.api.remove
 import org.jetbrains.kotlinx.dataframe.api.rename
+import org.jetbrains.kotlinx.dataframe.api.with
 import org.jetbrains.kotlinx.dataframe.columns.FrameColumn
 import org.jetbrains.kotlinx.dataframe.impl.aggregation.AggregatableInternal
 import org.jetbrains.kotlinx.dataframe.impl.aggregation.GroupByReceiverImpl
 import org.jetbrains.kotlinx.dataframe.impl.api.AggregatedPivot
 import org.jetbrains.kotlinx.dataframe.impl.api.ColumnToInsert
+import org.jetbrains.kotlinx.dataframe.impl.api.GroupByEntryImpl
 import org.jetbrains.kotlinx.dataframe.impl.api.GroupedDataRowImpl
 import org.jetbrains.kotlinx.dataframe.impl.api.insertImpl
 import org.jetbrains.kotlinx.dataframe.impl.api.removeImpl
@@ -41,29 +44,40 @@ internal class GroupByImpl(
 ) : GroupBy,
     AggregatableInternal {
 
-    override val keys by lazy { df.remove(groups) }
+    override val keys by lazy { df.remove { groups } } // df without the groups frame column, built on first access
 
-    override fun updateGroups(transform: Selector, DataFrame>) =
-        df.convert(groups) { transform(it, it) }.asGroupBy(groups.name()) as GroupBy
+    @Suppress("UNCHECKED_CAST")
+    override fun updateGroups(transform: Selector, DataFrame>): GroupBy =
+        df.convert { groups }.with { transform(it, it) } // replaces every group with its transformed frame
+            .asGroupBy { frameCol(groups.name()) } // re-selects the groups column by name in the converted frame
 
     override fun toString() = df.toString()
 
     override fun remainingColumnsSelector(): ColumnsSelector<*, *> =
         keyColumnsInGroups.toColumnSet().let { groupCols -> { all().except(groupCols) } }
 
+    @Deprecated("Replaced by filterEntries")
     override fun filter(predicate: GroupedRowFilter): GroupBy {
         val indices = (0 until df.nrow).filter {
-            val row = GroupedDataRowImpl(df.get(it), groups)
+            val row = GroupedDataRowImpl(df[it], groups)
             predicate(row, row)
         }
-        return df[indices].asGroupBy(groups)
+        return df[indices].asGroupBy { frameCol(groups.name()) } // keeps only the rows whose index passed the predicate
+    }
+
+    override fun filterEntries(predicate: GroupByEntryFilter): GroupBy {
+        val indices = (0 until df.nrow).filter {
+            val row =
GroupByEntryImpl(keys[it], groups) // keys-only row, matching GroupByEntryImpl's keysRow and entries()
+            predicate(row, row)
+        }
+        return df[indices].asGroupBy { frameCol(groups.name()) }
     }
 
     override fun toDataFrame(groupedColumnName: String?): DataFrame =
         if (groupedColumnName == null || groupedColumnName == groups.name()) {
             df
         } else {
-            df.rename(groups).into(groupedColumnName)
+            df.rename { groups }.into(groupedColumnName)
         }
 }
diff --git a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/api/groupBy.kt b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/api/groupBy.kt
index 6902b9a795..c10adae749 100644
--- a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/api/groupBy.kt
+++ b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/api/groupBy.kt
@@ -4,6 +4,7 @@ import org.jetbrains.kotlinx.dataframe.ColumnsSelector
 import org.jetbrains.kotlinx.dataframe.DataFrame
 import org.jetbrains.kotlinx.dataframe.DataRow
 import org.jetbrains.kotlinx.dataframe.api.GroupBy
+import org.jetbrains.kotlinx.dataframe.api.GroupByEntry
 import org.jetbrains.kotlinx.dataframe.api.GroupedDataRow
 import org.jetbrains.kotlinx.dataframe.api.cast
 import org.jetbrains.kotlinx.dataframe.api.getColumnsWithPaths
@@ -13,7 +14,9 @@ import org.jetbrains.kotlinx.dataframe.api.pathOf
 import org.jetbrains.kotlinx.dataframe.columns.FrameColumn
 import org.jetbrains.kotlinx.dataframe.impl.GroupByImpl
 import org.jetbrains.kotlinx.dataframe.impl.nameGenerator
+import org.jetbrains.kotlinx.dataframe.io.renderToString
 
+@Deprecated("Replaced by GroupByEntryImpl")
 internal class GroupedDataRowImpl(private val row: DataRow, private val frameCol: FrameColumn) :
     GroupedDataRow,
     DataRow by row {
@@ -21,6 +24,15 @@ internal class GroupedDataRowImpl(private val row: DataRow, private val
     override fun group() = frameCol[row.index()]
 }
 
+@PublishedApi
+internal class GroupByEntryImpl(private val keysRow: DataRow, internal val allGroups: FrameColumn) :
+    GroupByEntry,
+    DataRow by keysRow {
+    override fun group() = allGroups[keysRow.index()] // the group stored at the same index as the key row
+
+    override fun
toString(): String = "GroupByEntry(keysRow=${renderToString()}, group()=${group()})"
+}
+
 @PublishedApi
 internal fun DataFrame.groupByImpl(moveToTop: Boolean, columns: ColumnsSelector): GroupBy {
     val nameGenerator = nameGenerator(GroupBy.groupedColumnAccessor.name())
diff --git a/core/src/test/kotlin/org/jetbrains/kotlinx/dataframe/api/groupBy.kt b/core/src/test/kotlin/org/jetbrains/kotlinx/dataframe/api/groupBy.kt
index 2632a649b4..74d78391f6 100644
--- a/core/src/test/kotlin/org/jetbrains/kotlinx/dataframe/api/groupBy.kt
+++ b/core/src/test/kotlin/org/jetbrains/kotlinx/dataframe/api/groupBy.kt
@@ -2,11 +2,13 @@ package org.jetbrains.kotlinx.dataframe.api
 
 import io.kotest.matchers.shouldBe
 import org.jetbrains.kotlinx.dataframe.AnyFrame
+import org.jetbrains.kotlinx.dataframe.testSets.person.BaseTest
+import org.jetbrains.kotlinx.dataframe.testSets.person.age
 import org.junit.Test
 import kotlin.reflect.typeOf
 
 @Suppress("ktlint:standard:argument-list-wrapping")
-class GroupByTests {
+class GroupByTests : BaseTest() {
 
     @Test
     fun `groupBy values with nulls`() {
@@ -56,4 +58,80 @@ class GroupByTests {
             getFrameColumn("d") into "e"
         }["e"].type() shouldBe typeOf>()
     }
+
+    @Test
+    fun `groupBy forEachEntry`() {
+        val grouped = typed.groupBy { age }
+        val entries1 = buildList {
+            grouped.forEach { (key, group) ->
+                add(key.toMap() to group)
+            }
+        }
+        val entries2 = buildList {
+            grouped.forEachEntry {
+                add(it.toMap() to it.group())
+            }
+        }
+
+        entries1 shouldBe entries2 // deprecated forEach and new forEachEntry must yield the same keys and groups
+    }
+
+    @Test
+    fun `groupBy mapEntries`() {
+        // the old mapToRows and mapToFrames stick to the same return type, so let's make the types Any?
for the test
+        val grouped: GroupBy = df.groupBy { age }
+        val entries1 = grouped.map { (key, group) ->
+            key.toMap() to group
+        }
+        val entries2 = grouped.mapEntries {
+            it.toMap() to it.group()
+        }
+        entries1 shouldBe entries2 // deprecated map vs. mapEntries
+
+        val entries3 = grouped.mapToRows { (key, group) ->
+            listOf(key.toMap() to group).toDataFrame().single()
+        }
+        val entries4 = grouped.mapEntriesToRows {
+            listOf(it.toMap() to it.group()).toDataFrame().single()
+        }
+        entries3 shouldBe entries4 // deprecated mapToRows vs. mapEntriesToRows
+
+        val entries5 = grouped.mapToFrames { (key, group) ->
+            listOf(key.toMap() to group).toDataFrame()
+        }
+        val entries6 = grouped.mapEntriesToFrames {
+            listOf(it.toMap() to it.group()).toDataFrame()
+        }
+        entries5 shouldBe entries6 // deprecated mapToFrames vs. mapEntriesToFrames
+
+        // let's test the -Entries variants with typed versions
+        val grouped2 = typed.groupBy { age }
+
+        val entries7 = grouped2.mapEntries {
+            it.toMap() to it.group()
+        }
+        val entries8 = grouped2.mapEntriesToRows {
+            listOf(it.toMap() to it.group()).toDataFrame().single()
+        }.toList()
+        val entries9 = grouped2.mapEntriesToFrames {
+            listOf(it.toMap() to it.group()).toDataFrame()
+        }.map { it[0][0] to it[0][1] }.toList()
+        entries7 shouldBe entries8 // all three -Entries variants must agree on the typed GroupBy
+        entries8 shouldBe entries9
+    }
+
+    @Test
+    fun `groupBy filterEntries`() {
+        val grouped = typed.groupBy { age }
+
+        val entries1 = grouped.filter { age == 20 }
+            .mapEntries {
+                it.toMap() to it.group()
+            }
+        val entries2 = grouped.filterEntries { age == 20 }
+            .mapEntries {
+                it.toMap() to it.group()
+            }
+        entries1 shouldBe entries2 // deprecated filter and new filterEntries must keep the same entries
+    }
 }