[Spark] Support committing multiple column changes in one txn #4186
Open · wants to merge 1 commit into base: master
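For context, here is a rough sketch of how several column changes can reach `DeltaCatalog.alterTable` in a single call through Spark's `TableCatalog` API. It is not code from this PR: the table name, the specific changes, and the catalog lookup are illustrative assumptions. Before this change, each column update in such a call ran as its own `AlterTableChangeColumnDeltaCommand` (one Delta commit per column); with the batching added here they should be committed in one transaction.

```scala
// Sketch only (not from this PR): assumes a Delta-enabled SparkSession and an
// existing Delta table `default.events`. The package declaration is only there
// to reach the internal catalogManager for illustration purposes.
package org.apache.spark.sql

import org.apache.spark.sql.connector.catalog.{Identifier, TableCatalog, TableChange}

object MultiColumnAlterSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .appName("multi-column-alter-sketch")
      .config("spark.sql.extensions", "io.delta.sql.DeltaSparkSessionExtension")
      .config("spark.sql.catalog.spark_catalog",
        "org.apache.spark.sql.delta.catalog.DeltaCatalog")
      .getOrCreate()

    // The session catalog ("spark_catalog") is backed by DeltaCatalog here.
    val catalog = spark.sessionState.catalogManager
      .catalog("spark_catalog").asInstanceOf[TableCatalog]
    val ident = Identifier.of(Array("default"), "events")

    // Three column changes arrive in one alterTable() call; with this PR they are
    // folded into a single AlterTableChangeColumnDeltaCommand and one Delta commit,
    // instead of one command (and commit) per changed column.
    catalog.alterTable(
      ident,
      TableChange.updateColumnComment(Array("id"), "primary identifier"),
      TableChange.updateColumnNullability(Array("ts"), false),
      TableChange.renameColumn(Array("payload"), "body"))

    spark.stop()
  }
}
```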
@@ -565,6 +565,23 @@ object DeltaOperations {

override val isInPlaceFileMetadataUpdate: Option[Boolean] = Some(false)
}

/** Recorded when columns are changed in bulk. */
case class ChangeColumns(columns: Seq[ChangeColumn]) extends Operation("CHANGE COLUMNS") {

override val parameters: Map[String, Any] = Map(
"columns" -> JsonUtils.toJson(
columns.map(col =>
structFieldToMap(col.columnPath, col.newColumn) ++ col.colPosition.map("position" -> _))
)
)

// This operation shouldn't be introducing AddFile actions at all. This check should be trivial.
override def checkAddFileWithDeletionVectorStatsAreNotTightBounds: Boolean = true

override val isInPlaceFileMetadataUpdate: Option[Boolean] = Some(false)
}

/** Recorded when columns are replaced. */
case class ReplaceColumns(
columns: Seq[StructField]) extends Operation("REPLACE COLUMNS") {
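As a rough illustration of what the new operation could surface to users, the sketch below reads the most recent history entry after a batched column change. Reading history via `DeltaTable` is standard Delta API, but the table name and the exact JSON shape of `operationParameters("columns")` (built by `structFieldToMap` above) are assumptions, not verified output.

```scala
// Illustration only: assumes a SparkSession `spark` and the `default.events`
// table from the earlier sketch, after a batched ALTER of several columns.
import io.delta.tables.DeltaTable

val lastEntry = DeltaTable.forName(spark, "default.events")
  .history(1)
  .select("operation", "operationParameters")

lastEntry.show(truncate = false)
// Expected operation name, per the case class above: "CHANGE COLUMNS".
// operationParameters("columns") should hold a JSON array with one element per
// changed column (the column definition plus an optional "position" entry).
```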
@@ -28,7 +28,7 @@ import scala.util.control.NonFatal

import com.databricks.spark.util.TagDefinitions.TAG_LOG_STORE_CLASS
import org.apache.spark.sql.delta.ClassicColumnConversions._
import org.apache.spark.sql.delta.DeltaOperations.{ChangeColumn, CreateTable, Operation, ReplaceColumns, ReplaceTable, UpdateSchema}
import org.apache.spark.sql.delta.DeltaOperations.{ChangeColumn, ChangeColumns, CreateTable, Operation, ReplaceColumns, ReplaceTable, UpdateSchema}
import org.apache.spark.sql.delta.RowId.RowTrackingMetadataDomain
import org.apache.spark.sql.delta.actions._
import org.apache.spark.sql.delta.catalog.DeltaTableV2
@@ -2751,6 +2751,9 @@ trait OptimisticTransactionImpl extends DeltaTransaction
case change: ChangeColumn if usesDefaults(change.newColumn) =>
throwError("WRONG_COLUMN_DEFAULTS_FOR_DELTA_FEATURE_NOT_ENABLED",
Array("ALTER TABLE"))
case changes: ChangeColumns if changes.columns.exists(c => usesDefaults(c.newColumn)) =>
throwError("WRONG_COLUMN_DEFAULTS_FOR_DELTA_FEATURE_NOT_ENABLED",
Array("ALTER TABLE"))
case create: CreateTable if create.metadata.schema.fields.exists(usesDefaults) =>
throwError("WRONG_COLUMN_DEFAULTS_FOR_DELTA_FEATURE_NOT_ENABLED",
Array("CREATE TABLE"))
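The new `ChangeColumns` case simply extends the existing column-defaults guard to the batched path. As a hedged illustration of what that guard protects against (the table name is an assumption; the property value follows the Delta `allowColumnDefaults` table feature):

```scala
// Illustration only: assumes a Delta-enabled SparkSession `spark` and table `default.events`.
// Assigning a column default is rejected unless the table feature is enabled first.
spark.sql("ALTER TABLE default.events ALTER COLUMN id SET DEFAULT 0")
// -> WRONG_COLUMN_DEFAULTS_FOR_DELTA_FEATURE_NOT_ENABLED (for "ALTER TABLE")

// Enabling the feature first should allow the default to be set:
spark.sql(
  "ALTER TABLE default.events " +
    "SET TBLPROPERTIES ('delta.feature.allowColumnDefaults' = 'supported')")
spark.sql("ALTER TABLE default.events ALTER COLUMN id SET DEFAULT 0")
```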
@@ -660,9 +660,7 @@ class DeltaCatalog extends DelegatingCatalogExtension
case _ => return super.alterTable(ident, changes: _*)
}

// Whether this is an ALTER TABLE ALTER COLUMN SYNC IDENTITY command.
var syncIdentity = false
val columnUpdates = new mutable.HashMap[Seq[String], (StructField, Option[ColumnPosition])]()
val columnUpdates = new mutable.HashMap[Seq[String], DeltaChangeColumnSpec]()
val isReplaceColumnsCommand = grouped.get(classOf[DeleteColumn]) match {
case Some(deletes) if grouped.contains(classOf[AddColumn]) =>
// Convert to Seq so that contains method works
@@ -745,7 +743,8 @@ class DeltaCatalog extends DelegatingCatalogExtension
} else {
snapshotSchema
}
def getColumn(fieldNames: Seq[String]): (StructField, Option[ColumnPosition]) = {
def getColumn(fieldNames: Seq[String])
: DeltaChangeColumnSpec = {
columnUpdates.getOrElseUpdate(fieldNames, {
val colName = UnresolvedAttribute(fieldNames).name
val fieldOpt = schema.findNestedField(fieldNames, includeCollections = true,
@@ -754,7 +753,12 @@ class DeltaCatalog extends DelegatingCatalogExtension
val field = fieldOpt.getOrElse {
throw DeltaErrors.nonExistentColumnInSchema(colName, schema.treeString)
}
field -> None
DeltaChangeColumnSpec(
fieldNames.init,
fieldNames.last,
field,
colPosition = None,
syncIdentity = false)
})
}

@@ -768,59 +772,63 @@ class DeltaCatalog extends DelegatingCatalogExtension
disallowedColumnChangesOnIdentityColumns.foreach {
case change: ColumnChange =>
val field = change.fieldNames()
val (existingField, _) = getColumn(field)
if (ColumnWithDefaultExprUtils.isIdentityColumn(existingField)) {
val spec = getColumn(field)
if (ColumnWithDefaultExprUtils.isIdentityColumn(spec.newColumn)) {
throw DeltaErrors.identityColumnAlterColumnNotSupported()
}
}

columnChanges.foreach {
case comment: UpdateColumnComment =>
val field = comment.fieldNames()
val (oldField, pos) = getColumn(field)
columnUpdates(field) = oldField.withComment(comment.newComment()) -> pos
val spec = getColumn(field)
columnUpdates(field) = spec.copy(
newColumn = spec.newColumn.withComment(comment.newComment()))

case dataType: UpdateColumnType =>
val field = dataType.fieldNames()
val (oldField, pos) = getColumn(field)
columnUpdates(field) = oldField.copy(dataType = dataType.newDataType()) -> pos
val spec = getColumn(field)
columnUpdates(field) = spec.copy(
newColumn = spec.newColumn.copy(dataType = dataType.newDataType()))

case position: UpdateColumnPosition =>
val field = position.fieldNames()
val (oldField, pos) = getColumn(field)
columnUpdates(field) = oldField -> Option(position.position())
val spec = getColumn(field)
columnUpdates(field) = spec.copy(colPosition = Option(position.position()))

case nullability: UpdateColumnNullability =>
val field = nullability.fieldNames()
val (oldField, pos) = getColumn(field)
columnUpdates(field) = oldField.copy(nullable = nullability.nullable()) -> pos
val spec = getColumn(field)
columnUpdates(field) = spec.copy(
newColumn = spec.newColumn.copy(nullable = nullability.nullable()))

case rename: RenameColumn =>
val field = rename.fieldNames()
val (oldField, pos) = getColumn(field)
columnUpdates(field) = oldField.copy(name = rename.newName()) -> pos
val spec = getColumn(field)
columnUpdates(field) = spec.copy(
newColumn = spec.newColumn.copy(name = rename.newName()))

case sync: SyncIdentity =>
syncIdentity = true
val field = sync.fieldNames
val (oldField, pos) = getColumn(field)
if (!ColumnWithDefaultExprUtils.isIdentityColumn(oldField)) {
val spec = getColumn(field).copy(syncIdentity = true)
columnUpdates(field) = spec
if (!ColumnWithDefaultExprUtils.isIdentityColumn(spec.newColumn)) {
throw DeltaErrors.identityColumnAlterNonIdentityColumnError()
}
// If the IDENTITY column does not allow explicit insert, high water mark should
// always be sync'ed and this is an no-op.
if (IdentityColumn.allowExplicitInsert(oldField)) {
columnUpdates(field) = oldField.copy() -> pos
// always be sync'ed and this is a no-op.
if (IdentityColumn.allowExplicitInsert(spec.newColumn)) {
columnUpdates(field) = spec
}

case updateDefault: UpdateColumnDefaultValue =>
val field = updateDefault.fieldNames()
val (oldField, pos) = getColumn(field)
val spec = getColumn(field)
val updatedField = updateDefault.newDefaultValue() match {
case "" => oldField.clearCurrentDefaultValue()
case newDefault => oldField.withCurrentDefaultValue(newDefault)
case "" => spec.newColumn.clearCurrentDefaultValue()
case newDefault => spec.newColumn.withCurrentDefaultValue(newDefault)
}
columnUpdates(field) = updatedField -> pos
columnUpdates(field) = spec.copy(newColumn = updatedField)

case other =>
throw DeltaErrors.unrecognizedColumnChange(s"${other.getClass}")
@@ -872,14 +880,8 @@ class DeltaCatalog extends DelegatingCatalogExtension
}
}

columnUpdates.foreach { case (fieldNames, (newField, newPositionOpt)) =>
AlterTableChangeColumnDeltaCommand(
table,
fieldNames.dropRight(1),
fieldNames.last,
newField,
newPositionOpt,
syncIdentity = syncIdentity).run(spark)
if (columnUpdates.nonEmpty) {
AlterTableChangeColumnDeltaCommand(table, columnUpdates.values.toSeq).run(spark)
}

loadTable(ident)
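The diff references `DeltaChangeColumnSpec` without showing its definition. Judging purely from how it is constructed and copied above, it presumably carries roughly the following fields; the exact definition, field order, and types are inferred, not confirmed by this diff.

```scala
// Inferred sketch, not part of this diff: the real definition lives elsewhere in the PR.
import org.apache.spark.sql.connector.catalog.TableChange.ColumnPosition
import org.apache.spark.sql.types.StructField

case class DeltaChangeColumnSpec(
    columnPath: Seq[String],              // parent path, built as fieldNames.init above
    columnName: String,                   // leaf name, built as fieldNames.last above
    newColumn: StructField,               // the updated column definition
    colPosition: Option[ColumnPosition],  // set by UpdateColumnPosition changes
    syncIdentity: Boolean)                // set by SyncIdentity changes
```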