Commit
Update UCCommitCoordinator
Omar Elhadidy committed Feb 25, 2025
1 parent 9c932bf commit 8efcf17
Showing 18 changed files with 2,977 additions and 4 deletions.
1 change: 1 addition & 0 deletions build.sbt
@@ -446,6 +446,7 @@ lazy val spark = (project in file("spark"))
"org.apache.spark" %% "spark-core" % sparkVersion.value % "test" classifier "tests",
"org.apache.spark" %% "spark-sql" % sparkVersion.value % "test" classifier "tests",
"org.apache.spark" %% "spark-hive" % sparkVersion.value % "test" classifier "tests",
"org.mockito" % "mockito-inline" % "4.11.0" % "test",
),
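// NOTE: mockito-inline above (unlike mockito-core) can also mock final classes and
// static methods; assumed to be what the new UC commit coordinator tests rely on.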
Compile / packageBin / mappings := (Compile / packageBin / mappings).value ++
listPythonFiles(baseDirectory.value.getParentFile / "python"),
57 changes: 53 additions & 4 deletions run-integration-tests.py
@@ -358,6 +358,43 @@ def run_pip_installation_tests(root_dir, version, use_testpypi, use_localpypi, e
print("Failed pip installation tests in %s" % (test_file))
raise

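# The helper below optionally publishes local Delta artifacts (use_local), collects the
# Python test scripts under the commit-coordinator directory, and submits each one via
# spark-submit with the resolved Delta package (plus any extra packages).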
def run_external_access_uc_managed_tables_integration_tests(root_dir, version, test_name, use_local, extra_packages):
print(
"\n\n##### Running External UC managed tables integration tests on version %s #####" % str(version)
)

if use_local:
clear_artifact_cache()
run_cmd(["build/sbt", "publishM2"])

test_dir = path.join(root_dir, \
path.join("spark", "src", "main", "java", "io", "delta", "commitcoordinator"))
test_files = [path.join(test_dir, f) for f in os.listdir(test_dir)
if path.isfile(path.join(test_dir, f)) and
f.endswith(".py") and not f.startswith("_")]

print("\n\nCollected %d test files\n\n" % len(test_files))

python_root_dir = path.join(root_dir, "python")
extra_class_path = path.join(python_root_dir, path.join("delta", "testing"))
packages = "io.delta:delta-%s_2.12:%s" % (get_artifact_name(version), version)
if extra_packages:
packages += "," + extra_packages

for test_file in test_files:
if test_name is not None and test_name not in test_file:
print("\nSkipping External UC managed tables integration tests in %s\n============" % test_file)
continue
try:
cmd = ["spark-submit",
"--driver-class-path=%s" % extra_class_path, # for less verbose logging
"--packages", packages] + [test_file]
print("\nRunning External UC managed tables integration tests in %s\n=============" % test_file)
print("Command: %s" % " ".join(cmd))
run_cmd(cmd, stream_output=True)
except:
print("Failed External UC managed tables integration tests in %s" % (test_file))
raise

def clear_artifact_cache():
print("Clearing Delta artifacts from ivy2 and mvn cache")
@@ -495,10 +532,10 @@ def __exit__(self, tpe, value, traceback):
action="store_true",
help="Run the DynamoDB integration tests (and only them)")
parser.add_argument(
"--dbb-packages",
"--packages",
required=False,
default=None,
help="Additional packages required for Dynamodb logstore integration tests")
help="Additional packages required for integration tests")
parser.add_argument(
"--dbb-conf",
required=False,
@@ -544,6 +581,13 @@ def __exit__(self, tpe, value, traceback):
default="0.15.0",
help="Hudi library version"
)
parser.add_argument(
"--external-access-uc-managed-tables-integration-tests",
required=False,
default=False,
action="store_true",
help="Run the external access UC managed tables integration tests (and only them)"
)

args = parser.parse_args()

@@ -574,18 +618,23 @@ def __exit__(self, tpe, value, traceback):

if args.run_storage_s3_dynamodb_integration_tests:
run_dynamodb_logstore_integration_tests(root_dir, args.version, args.test, args.maven_repo,
args.dbb_packages, args.dbb_conf, args.use_local)
args.packages, args.dbb_conf, args.use_local)
quit()

if args.run_dynamodb_commit_coordinator_integration_tests:
run_dynamodb_commit_coordinator_integration_tests(root_dir, args.version, args.test, args.maven_repo,
args.dbb_packages, args.dbb_conf, args.use_local)
args.packages, args.dbb_conf, args.use_local)
quit()

if args.s3_log_store_util_only:
run_s3_log_store_util_integration_tests()
quit()

if args.external_access_uc_managed_tables_integration_tests:
run_external_access_uc_managed_tables_integration_tests(root_dir, args.version, args.test, args.use_local,
args.packages)
quit()
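# Hypothetical invocation of the new test suite (flag names taken from the parser above):
#   ./run-integration-tests.py --external-access-uc-managed-tables-integration-tests \
#       --use-local --packages <extra-maven-coordinates>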

if run_scala:
run_scala_integration_tests(root_dir, args.version, args.test, args.maven_repo,
args.scala_version, args.use_local)

@@ -100,6 +100,10 @@ object CommitCoordinatorProvider {
nameToBuilderMapping.retain((k, _) => initialCommitCoordinatorNames.contains(k))
}

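// Clears every registered builder, including the built-in ones. Exposed as
// private[delta]; presumably intended for tests that need a clean registry.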
private[delta] def clearAllBuilders(): Unit = synchronized {
nameToBuilderMapping.clear()
}

private val initialCommitCoordinatorBuilders = Seq[CommitCoordinatorBuilder](
new DynamoDBCommitCoordinatorClientBuilder()
)

@@ -0,0 +1,241 @@
/*
* Copyright (2021) The Delta Lake Project Authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.delta.coordinatedcommits

import java.net.{URI, URISyntaxException}
import java.util.concurrent.ConcurrentHashMap

import scala.collection.JavaConverters._
import scala.util.control.NonFatal

import org.apache.spark.sql.delta.logging.DeltaLogKeys
import org.apache.spark.sql.delta.metering.DeltaLogging
import io.delta.storage.commit.CommitCoordinatorClient
import io.delta.storage.commit.uccommitcoordinator.{UCClient, UCCommitCoordinatorClient, UCTokenBasedRestClient}

import org.apache.spark.internal.MDC
import org.apache.spark.sql.SparkSession

/**
* Builder for Unity Catalog Commit Coordinator Clients.
*
* This builder is responsible for creating and caching UCCommitCoordinatorClient instances
* based on the provided metastore IDs and catalog configurations.
*
* It caches the UCCommitCoordinatorClient instance for a given metastore ID upon its first access.
*/
object UCCommitCoordinatorBuilder extends CommitCoordinatorBuilder with DeltaLogging {

/** Prefix for Spark SQL catalog configurations. */
final private val SPARK_SQL_CATALOG_PREFIX = "spark.sql.catalog."

/** Connector class name for filtering relevant Unity Catalog catalogs. */
final private val UNITY_CATALOG_CONNECTOR_CLASS: String =
"io.unitycatalog.connectors.spark.UCSingleCatalog"

/** Suffix for the URI configuration of a catalog. */
final private val URI_SUFFIX = "uri"

/** Suffix for the token configuration of a catalog. */
final private val TOKEN_SUFFIX = "token"

/** Cache for UCCommitCoordinatorClient instances. */
private val commitCoordinatorClientCache =
new ConcurrentHashMap[String, UCCommitCoordinatorClient]()

// Helper cache from (uri, token) to metastoreId, to avoid redundant getMetastoreId calls to
// the catalog.
private val uriTokenToMetastoreIdCache = new ConcurrentHashMap[(String, String), String]()

// Use a var instead of a val so tests can easily inject a different UCClientFactory.
private[delta] var ucClientFactory: UCClientFactory = UCTokenBasedRestClientFactory
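// Hypothetical test usage (stubFactory is an assumed test double, not part of this file):
//   UCCommitCoordinatorBuilder.ucClientFactory = stubFactory
//   ... exercise build(...) and assert on the returned coordinator ...
//   UCCommitCoordinatorBuilder.ucClientFactory = UCTokenBasedRestClientFactory  // restore default
//   UCCommitCoordinatorBuilder.clearCache()  // drop cached clients between tests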

override def getName: String = "unity-catalog"

override def build(spark: SparkSession, conf: Map[String, String]): CommitCoordinatorClient = {
val metastoreId = conf.getOrElse(
UCCommitCoordinatorClient.UC_METASTORE_ID_KEY,
throw new IllegalArgumentException(
s"UC metastore ID not found in the provided coordinator conf: $conf"))

commitCoordinatorClientCache.computeIfAbsent(
metastoreId,
_ => new UCCommitCoordinatorClient(conf.asJava, getMatchingUCClient(spark, metastoreId))
)
}

/**
* Finds and returns a UCClient that matches the given metastore ID.
*
* This method iterates through all configured catalogs in the SparkSession, resolves the
* metastore ID for each distinct (uri, token) pair, and returns a UCClient for the pair whose
* metastore ID matches the provided one.
* If no matching catalog is found or if multiple matching catalogs are found, it throws an
* appropriate exception.
*/
private def getMatchingUCClient(spark: SparkSession, metastoreId: String): UCClient = {
val matchingClients: List[(String, String)] = getCatalogConfigs(spark)
.map { case (name, uri, token) => (uri, token) }
.distinct // Remove duplicates since multiple catalogs can have the same uri and token
.filter { case (uri, token) => getMetastoreId(uri, token).contains(metastoreId) }

matchingClients match {
case Nil => throw noMatchingCatalogException(metastoreId)
case (uri, token) :: Nil => ucClientFactory.createUCClient(uri, token)
case multiple => throw multipleMatchingCatalogs(metastoreId, multiple.map(_._1))
}
}

/**
* Retrieves the metastore ID for a given URI and token.
*
* This method creates a UCClient using the provided URI and token, then retrieves its metastore
* ID. The result is cached to avoid unnecessary getMetastoreId requests in future calls. If
* there's an error, it returns None and logs a warning.
*/
private def getMetastoreId(uri: String, token: String): Option[String] = {
try {
val metastoreId = uriTokenToMetastoreIdCache.computeIfAbsent(
(uri, token),
_ => {
val ucClient = ucClientFactory.createUCClient(uri, token)
try {
ucClient.getMetastoreId
} finally {
safeClose(ucClient, uri)
}
})
Some(metastoreId)
} catch {
case NonFatal(e) =>
logWarning(log"Failed to get metastore ID for ${MDC(DeltaLogKeys.URI, uri)}", e)
None
}
}

private def noMatchingCatalogException(metastoreId: String) = {
new IllegalStateException(
s"No matching catalog found for UC metastore ID $metastoreId. " +
"Please ensure the catalog is configured correctly by setting " +
"`spark.sql.catalog.<catalog-name>`, `spark.sql.catalog.<catalog-name>.uri` and " +
"`spark.sql.catalog.<catalog-name>.token`. Note that the matching process involves " +
"retrieving the metastoreId using the provided `<uri, token>` pairs in Spark " +
"Session configs.")
}

private def multipleMatchingCatalogs(metastoreId: String, uris: List[String]) = {
new IllegalStateException(
s"Found multiple catalogs for UC metastore ID $metastoreId at $uris. " +
"Please ensure the catalog is configured correctly by setting " +
"`spark.sql.catalog.<catalog-name>`, `spark.sql.catalog.<catalog-name>.uri` and " +
"`spark.sql.catalog.<catalog-name>.token`. Note that the matching process involves " +
"retrieving the metastoreId using the provided `<uri, token>` pairs in Spark " +
"Session configs.")
}

/**
* Retrieves the catalog configurations from the SparkSession.
*
* Example: Given the following Spark configurations:
* spark.sql.catalog.catalog1 = "io.unitycatalog.connectors.spark.UCSingleCatalog"
* spark.sql.catalog.catalog1.uri = "https://dbc-123abc.databricks.com"
* spark.sql.catalog.catalog1.token = "dapi1234567890"
*
* spark.sql.catalog.catalog2 = "io.unitycatalog.connectors.spark.UCSingleCatalog"
* spark.sql.catalog.catalog2.uri = "https://dbc-456def.databricks.com"
* spark.sql.catalog.catalog2.token = "dapi0987654321"
*
* spark.sql.catalog.catalog3 = "io.unitycatalog.connectors.spark.UCSingleCatalog"
* spark.sql.catalog.catalog3.uri = "https://dbc-789ghi.databricks.com"
*
* spark.sql.catalog.catalog4 = "com.databricks.sql.lakehouse.catalog3"
* spark.sql.catalog.catalog4.uri = "https://dbc-456def.databricks.com"
* spark.sql.catalog.catalog4.token = "dapi0987654321"
*
* spark.sql.catalog.catalog5 = "io.unitycatalog.connectors.spark.UCSingleCatalog"
* spark.sql.catalog.catalog5.uri = "random-string"
* spark.sql.catalog.catalog5.token = "dapi0987654321"
*
* This method would return:
* List(
* ("catalog1", "https://dbc-123abc.databricks.com", "dapi1234567890"),
* ("catalog2", "https://dbc-456def.databricks.com", "dapi0987654321")
* )
*
* Note: catalog3 is not included in the result because it's missing the token configuration.
* Note: catalog4 is not included in the result because it's not a UCSingleCatalog connector.
* Note: catalog5 is not included in the result because its URI is not a valid URI.
*
* @return
* A list of tuples containing (catalogName, uri, token) for each properly configured catalog
*/
private[delta] def getCatalogConfigs(spark: SparkSession): List[(String, String, String)] = {
val catalogConfigs = spark.conf.getAll.filterKeys(_.startsWith(SPARK_SQL_CATALOG_PREFIX))

catalogConfigs
.keys
.map(_.split("\\."))
.filter(_.length == 4)
.map(_(3))
.filter { catalogName: String =>
val connector = catalogConfigs.get(s"$SPARK_SQL_CATALOG_PREFIX$catalogName")
connector.contains(UNITY_CATALOG_CONNECTOR_CLASS)}
.flatMap { catalogName: String =>
val uri = catalogConfigs.get(s"$SPARK_SQL_CATALOG_PREFIX$catalogName.$URI_SUFFIX")
val token = catalogConfigs.get(s"$SPARK_SQL_CATALOG_PREFIX$catalogName.$TOKEN_SUFFIX")
(uri, token) match {
case (Some(u), Some(t)) =>
try {
new URI(u) // Validate the URI
Some((catalogName, u, t))
} catch {
case _: URISyntaxException =>
logWarning(log"Skipping catalog ${MDC(DeltaLogKeys.CATALOG, catalogName)} as it " +
log"does not have a valid URI ${MDC(DeltaLogKeys.URI, u)}.")
None
}
case _ =>
logWarning(log"Skipping catalog ${MDC(DeltaLogKeys.CATALOG, catalogName)} as it does " +
log"not have both uri and token configured in Spark Session.")
None
}}
.toList
}

private def safeClose(ucClient: UCClient, uri: String): Unit = {
try {
ucClient.close()
} catch {
case NonFatal(e) =>
logWarning(log"Failed to close UCClient for uri ${MDC(DeltaLogKeys.URI, uri)}", e)
}
}

def clearCache(): Unit = {
commitCoordinatorClientCache.clear()
uriTokenToMetastoreIdCache.clear()
}
}

trait UCClientFactory {
def createUCClient(uri: String, token: String): UCClient
}

object UCTokenBasedRestClientFactory extends UCClientFactory {
override def createUCClient(uri: String, token: String): UCClient =
new UCTokenBasedRestClient(uri, token)
}

@@ -48,6 +48,7 @@ trait DeltaLogKeysBase {
case object APP_ID extends LogKeyShims
case object BATCH_ID extends LogKeyShims
case object BATCH_SIZE extends LogKeyShims
case object CATALOG extends LogKeyShims
case object CLONE_SOURCE_DESC extends LogKeyShims
case object CONFIG extends LogKeyShims
case object CONFIG_KEY extends LogKeyShims