Commit: Update UCCommitCoordinator
Omar Elhadidy committed Feb 26, 2025
1 parent 8efcf17 commit 7557a1e
Showing 5 changed files with 250 additions and 5 deletions.
3 changes: 2 additions & 1 deletion run-integration-tests.py

@@ -368,7 +368,8 @@ def run_external_access_uc_managed_tables_integration_tests(root_dir, version, t
run_cmd(["build/sbt", "publishM2"])

test_dir = path.join(root_dir, \
path.join("spark", "src", "main", "java", "io", "delta", "commitcoordinator"))
path.join("spark", "src", "main", "scala", "org", "apache", "spark", "sql",
"delta", "coordinatedcommits", "integration_tests"))
test_files = [path.join(test_dir, f) for f in os.listdir(test_dir)
if path.isfile(path.join(test_dir, f)) and
f.endswith(".py") and not f.startswith("_")]

@@ -105,6 +105,7 @@ object CommitCoordinatorProvider {
}

private val initialCommitCoordinatorBuilders = Seq[CommitCoordinatorBuilder](
+    UCCommitCoordinatorBuilder,
new DynamoDBCommitCoordinatorClientBuilder()
)
initialCommitCoordinatorBuilders.foreach(registerBuilder)

@@ -0,0 +1,103 @@
#
# Copyright (2024) The Delta Lake Project Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import os
import sys
import threading
import json

from pyspark.sql import SparkSession
import time
import uuid



"""
Run this script from the root directory of the repository:
===== Mandatory input from user =====
export CATALOG_NAME=___
export CATALOG_URI=___
export PAT_TOKEN=___
export TABLE_NAME=___
export SCHEMA=___
./run-integration-tests.py --use-local --external-access-uc-managed-tables-integration-tests \
--packages io.unitycatalog:unitycatalog-spark_2.12:0.2.1,org.apache.spark:spark-hadoop-cloud_2.12:3.5.4
"""

CATALOG_NAME = os.environ.get("CATALOG_NAME")
TOKEN = os.environ.get("PAT_TOKEN")
CATALOG_URI = os.environ.get("CATALOG_URI")
TABLE_NAME = os.environ.get("TABLE_NAME")
SCHEMA = os.environ.get("SCHEMA")
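
# Optional sanity check, sketched here as an illustration (the _required and _missing
# names are not part of the original script): all five settings above are mandatory,
# so failing fast with a clear message avoids passing None into the catalog configs below.
_required = {
    "CATALOG_NAME": CATALOG_NAME,
    "PAT_TOKEN": TOKEN,
    "CATALOG_URI": CATALOG_URI,
    "TABLE_NAME": TABLE_NAME,
    "SCHEMA": SCHEMA,
}
_missing = [name for name, value in _required.items() if not value]
assert not _missing, f"Missing required environment variables: {', '.join(_missing)}"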

spark = SparkSession \
.builder \
.appName("coordinated_commit_tester") \
.master("local[*]") \
.config("spark.sql.extensions", "io.delta.sql.DeltaSparkSessionExtension") \
.config("spark.sql.catalog.spark_catalog", "io.unitycatalog.spark.UCSingleCatalog") \
.config(f"spark.sql.catalog.{CATALOG_NAME}", "io.unitycatalog.spark.UCSingleCatalog") \
.config(f"spark.sql.catalog.{CATALOG_NAME}.token", TOKEN) \
.config(f"spark.sql.catalog.{CATALOG_NAME}.uri", CATALOG_URI) \
.config(f"spark.sql.defaultCatalog", CATALOG_NAME) \
.config("spark.hadoop.fs.s3.impl", "org.apache.hadoop.fs.s3a.S3AFileSystem") \
.config("spark.databricks.delta.commitcoordinator.unity-catalog.impl", "org.delta.catalog.UCCoordinatedCommitClient") \
.getOrCreate()

expected_error_tag = "UNITY_CATALOG_EXTERNAL_COORDINATED_COMMITS_REQUEST_DENIED"

def create():
try:
spark.sql(f"CREATE TABLE {SCHEMA}.{TABLE_NAME} (a INT)")
except Exception:
print("[SUCCESS] Failed creating managed table using UC commit coordinator")

def insert():
try:
spark.sql(f"INSERT INTO {SCHEMA}.{TABLE_NAME} VALUES (1), (2)")
except Exception as error:
assert(expected_error_tag in str(error))
print("[SUCCESS] Failed writing to managed table using UC commit coordinator")

def update():
try:
spark.sql(f"UPDATE {SCHEMA}.{TABLE_NAME} SET a=4")
except Exception as error:
assert(expected_error_tag in str(error))
print("[SUCCESS] Failed updating managed table using UC commit coordinator")

def delete():
try:
spark.sql(f"DELETE FROM {SCHEMA}.{TABLE_NAME} where a=1")
except Exception as error:
assert(expected_error_tag in str(error))
print("[SUCCESS] Failed deleting from managed table using UC commit coordinator")

def read():
try:
res = spark.sql(f"SELECT * FROM {SCHEMA}.{TABLE_NAME}")
except Exception as error:
assert(expected_error_tag in str(error))
print("[SUCCESS] Failed reading from managed table using UC commit coordinator")

read()
insert()
update()
create()
delete()
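
The five checks above share the same try/except/assert pattern. As a sketch only, under the assumption that a single helper is acceptable here (the expect_denied name is hypothetical and not part of this change), the pattern could be factored out:

def expect_denied(description, statement):
    # Run a statement that the UC commit coordinator is expected to deny for
    # externally accessed managed tables, and check the expected error tag.
    try:
        spark.sql(statement)
        print(f"[FAILURE] {description} unexpectedly succeeded")
    except Exception as error:
        assert expected_error_tag in str(error), error
        print(f"[SUCCESS] {description} was denied by the UC commit coordinator")

expect_denied("INSERT", f"INSERT INTO {SCHEMA}.{TABLE_NAME} VALUES (1), (2)")

Note that create() above does not assert the error tag, so it would keep its own handler under such a refactor.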

@@ -71,10 +71,8 @@ class UCCommitCoordinatorClientSuite extends UCCommitCoordinatorClientSuiteBase
override protected def commit(
version: Long,
timestamp: Long,
-      tableCommitCoordinatorClient: TableCommitCoordinatorClient,
-      tableIdentifier: Option[TableIdentifier] = None): JCommit = {
-    val commitResult = super.commit(
-      version, timestamp, tableCommitCoordinatorClient, tableIdentifier)
+      tableCommitCoordinatorClient: TableCommitCoordinatorClient): JCommit = {
+    val commitResult = super.commit(version, timestamp, tableCommitCoordinatorClient)
// As backfilling for UC happens after every commit asynchronously, we block here until
// the current in-progress backfill has completed in order to make tests deterministic.
waitForBackfill(version, tableCommitCoordinatorClient)

@@ -0,0 +1,142 @@
package org.apache.spark.sql.delta.coordinatedcommits

import java.net.URI
import java.util.{Optional, UUID}

import scala.collection.JavaConverters._

// scalastyle:off import.ordering.noEmptyLine
import org.apache.spark.sql.delta.DeltaConfigs.{COORDINATED_COMMITS_COORDINATOR_CONF, COORDINATED_COMMITS_COORDINATOR_NAME, COORDINATED_COMMITS_TABLE_CONF}
import org.apache.spark.sql.delta.DeltaLog
import org.apache.spark.sql.delta.actions.{Metadata, Protocol}
import org.apache.spark.sql.delta.metering.DeltaLogging
import org.apache.spark.sql.delta.test.DeltaTestImplicits._
import org.apache.spark.sql.delta.util.{FileNames, JsonUtils}
import io.delta.storage.commit.{CoordinatedCommitsUtils => JCoordinatedCommitsUtils, GetCommitsResponse => JGetCommitsResponse}
import io.delta.storage.commit.uccommitcoordinator.{UCClient, UCCommitCoordinatorClient}
import org.apache.hadoop.fs.Path
import org.mockito.ArgumentMatchers.anyString
import org.mockito.Mock
import org.mockito.Mockito
import org.mockito.Mockito.{mock, when}
import org.scalatest.time.SpanSugar._

import org.apache.spark.sql.types.{IntegerType, StringType, StructField}

trait UCCommitCoordinatorClientSuiteBase extends CommitCoordinatorClientImplSuiteBase
{
/**
* A unique table ID for each test.
*/
protected var tableUUID = UUID.randomUUID()

/**
* A unique metastore ID for each test.
*/
protected var metastoreId = UUID.randomUUID()

protected var ucClient: UCClient = _

@Mock
protected val mockFactory: UCClientFactory = mock(classOf[UCClientFactory])

protected var ucCommitCoordinator: InMemoryUCCommitCoordinator = _

override def beforeEach(): Unit = {
super.beforeEach()
tableUUID = UUID.randomUUID()
UCCommitCoordinatorClient.BACKFILL_LISTING_OFFSET = 100
metastoreId = UUID.randomUUID()
DeltaLog.clearCache()
Mockito.reset(mockFactory)
CommitCoordinatorProvider.clearAllBuilders()
UCCommitCoordinatorBuilder.ucClientFactory = mockFactory
UCCommitCoordinatorBuilder.clearCache()
CommitCoordinatorProvider.registerBuilder(UCCommitCoordinatorBuilder)
ucCommitCoordinator = new InMemoryUCCommitCoordinator()
ucClient = new InMemoryUCClient(metastoreId.toString, ucCommitCoordinator)
when(mockFactory.createUCClient(anyString(), anyString())).thenReturn(ucClient)
}
override protected def createTableCommitCoordinatorClient(
deltaLog: DeltaLog): TableCommitCoordinatorClient = {
var commitCoordinatorClient = UCCommitCoordinatorBuilder
.build(spark, Map(UCCommitCoordinatorClient.UC_METASTORE_ID_KEY -> metastoreId.toString))
.asInstanceOf[UCCommitCoordinatorClient]
commitCoordinatorClient = new UCCommitCoordinatorClient(
commitCoordinatorClient.conf,
commitCoordinatorClient.ucClient) with DeltaLogging {
override def recordDeltaEvent(opType: String, data: Any, path: Path): Unit = {
data match {
case ref: AnyRef => recordDeltaEvent(null, opType = opType, data = ref, path = Some(path))
case _ => super.recordDeltaEvent(opType, data, path)
}
}
}
commitCoordinatorClient.registerTable(
deltaLog.logPath, Optional.empty(), -1L, initMetadata(), Protocol(1, 1))
TableCommitCoordinatorClient(
commitCoordinatorClient,
deltaLog,
Map(UCCommitCoordinatorClient.UC_TABLE_ID_KEY -> tableUUID.toString)
)
}

override protected def registerBackfillOp(
tableCommitCoordinatorClient: TableCommitCoordinatorClient,
deltaLog: DeltaLog,
version: Long): Unit = {
ucClient.commit(
tableUUID.toString,
JCoordinatedCommitsUtils.getTablePath(deltaLog.logPath).toUri,
Optional.empty(),
Optional.of(version),
false,
Optional.empty(),
Optional.empty())
}

override protected def validateBackfillStrategy(
tableCommitCoordinatorClient: TableCommitCoordinatorClient,
logPath: Path,
version: Long): Unit = {
val response = tableCommitCoordinatorClient.getCommits()
assert(response.getCommits.size == 1)
assert(response.getCommits.asScala.head.getVersion == version)
assert(response.getLatestTableVersion == version)
}

protected def validateGetCommitsResult(
response: JGetCommitsResponse,
startVersion: Option[Long],
endVersion: Option[Long],
maxVersion: Long): Unit = {
val expectedVersions = endVersion.map { _ => Seq.empty }.getOrElse(Seq(maxVersion))
assert(response.getCommits.asScala.map(_.getVersion) == expectedVersions)
assert(response.getLatestTableVersion == maxVersion)
}

override protected def initMetadata(): Metadata = {
// Ensure that the metadata that is passed to registerTable has the
// correct table conf set.
Metadata(configuration = Map(
COORDINATED_COMMITS_TABLE_CONF.key ->
JsonUtils.toJson(Map(UCCommitCoordinatorClient.UC_TABLE_ID_KEY -> tableUUID.toString)),
COORDINATED_COMMITS_COORDINATOR_NAME.key -> UCCommitCoordinatorBuilder.getName,
COORDINATED_COMMITS_COORDINATOR_CONF.key ->
JsonUtils.toJson(
Map(UCCommitCoordinatorClient.UC_METASTORE_ID_KEY -> metastoreId.toString))))
}

protected def waitForBackfill(
version: Long,
tableCommitCoordinatorClient: TableCommitCoordinatorClient): Unit = {
eventually(timeout(10.seconds)) {
val logPath = tableCommitCoordinatorClient.logPath
val log = DeltaLog.forTable(spark, JCoordinatedCommitsUtils.getTablePath(logPath))
val fs = logPath.getFileSystem(log.newDeltaHadoopConf())
assert(fs.exists(FileNames.unsafeDeltaFile(logPath, version)))
}
}
}
