Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
<kubernetes-client.version>6.7.2</kubernetes-client.version>
<jackson.version>2.15.2</jackson.version>
<scalatest-maven-plugin.version>2.2.0</scalatest-maven-plugin.version>
<slf4j.version>2.0.7</slf4j.version>
<!-- SPARK-36796 for JDK-17 test-->
<extraJavaTestArgs>
-XX:+IgnoreUnrecognizedVMOptions
Expand Down Expand Up @@ -190,6 +191,11 @@
<artifactId>jackson-module-scala_${scala.binary.version}</artifactId>
<version>${jackson.version}</version>
</dependency>
<dependency>
<groupId>org.slf4j</groupId>
<artifactId>slf4j-api</artifactId>
<version>${slf4j.version}</version>
</dependency>
</dependencies>

<profiles>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,9 @@ import org.apache.spark.deploy.k8s.features.{
KubernetesFeatureConfigStep
}
import org.apache.spark.scheduler.cluster.SchedulerBackendUtils
import org.apache.spark.scheduler.cluster.k8s.KubernetesExecutorBuilder
import org.apache.spark.util.Utils
import org.slf4j.{Logger, LoggerFactory}

import scala.collection.mutable
import scala.jdk.CollectionConverters._
Expand Down Expand Up @@ -145,18 +147,12 @@ private[spark] object ArmadaClientApplication {
private val DEFAULT_NAMESPACE = "default"
private val DEFAULT_ARMADA_APP_ID = "armada-spark-app-id"
private val DEFAULT_RUN_AS_USER = 185

}

/** Main class and entry point of application submission in KUBERNETES mode.
*/
private[spark] class ArmadaClientApplication extends SparkApplication {
// FIXME: Find the real way to log properly.
private def log(msg: String): Unit = {
// scalastyle:off println
System.err.println(msg)
// scalastyle:on println
}
private val logger = LoggerFactory.getLogger(getClass)

override def start(args: Array[String], conf: SparkConf): Unit = {
val parsedArguments = ClientArguments.fromCommandLineArgs(args)
Expand All @@ -170,17 +166,17 @@ private[spark] class ArmadaClientApplication extends SparkApplication {
val armadaJobConfig = validateArmadaJobConfig(sparkConf, clientArguments)

val (host, port) = ArmadaUtils.parseMasterUrl(sparkConf.get("spark.master"))
log(s"Connecting to Armada Server - host: $host, port: $port")
logger.info(s"Connecting to Armada Server - host: $host, port: $port")

val armadaClient = ArmadaClient(host, port, useSsl = false, sparkConf.get(ARMADA_AUTH_TOKEN))
val healthTimeout =
Duration(sparkConf.get(ARMADA_HEALTH_CHECK_TIMEOUT), SECONDS)

log(s"Checking Armada Server health (timeout: $healthTimeout)")
logger.info(s"Checking Armada Server health (timeout: $healthTimeout)")
val healthResp = Await.result(armadaClient.submitHealth(), healthTimeout)

if (healthResp.status.isServing) {
log("Armada Server is serving requests!")
logger.info("Armada Server is serving requests!")
} else {
throw new RuntimeException(
"Armada health check failed - Armada Server is not serving requests!"
Expand All @@ -196,7 +192,7 @@ private[spark] class ArmadaClientApplication extends SparkApplication {
val lookoutURL =
s"$lookoutBaseURL/?page=0&sort[id]=jobId&sort[desc]=true&" +
s"ps=50&sb=$driverJobId&active=false&refresh=true"
log(s"Lookout URL for the driver job is $lookoutURL")
logger.info(s"Lookout URL for the driver job is $lookoutURL")

()
}
Expand Down Expand Up @@ -745,7 +741,7 @@ private[spark] class ArmadaClientApplication extends SparkApplication {
val error = Some(driverResponse.jobResponseItems.head.error)
.filter(_.nonEmpty)
.getOrElse("none")
log(
logger.info(
s"Submitted driver job with ID: $driverJobId, Error: $error"
)
driverJobId
Expand All @@ -760,7 +756,7 @@ private[spark] class ArmadaClientApplication extends SparkApplication {
val executorsResponse = armadaClient.submitJobs(queue, jobSetId, executors)
executorsResponse.jobResponseItems.map { item =>
val error = Some(item.error).filter(_.nonEmpty).getOrElse("none")
log(s"Submitted executor job with ID: ${item.jobId}, Error: $error")
logger.info(s"Submitted executor job with ID: ${item.jobId}, Error: $error")
item.jobId
}
}
Expand Down
7 changes: 7 additions & 0 deletions src/test/resources/log4j2-test.properties
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
rootLogger.level = debug
rootLogger.appenderRef.console.ref = console

appender.console.type = Console
appender.console.name = console
appender.console.layout.type = PatternLayout
appender.console.layout.pattern = %d{HH:mm:ss.SSS} %-5level %logger{20} - %msg%n
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import api.submit.Queue
import com.fasterxml.jackson.databind.{DeserializationFeature, ObjectMapper}
import com.fasterxml.jackson.dataformat.yaml.YAMLFactory
import com.fasterxml.jackson.module.scala.DefaultScalaModule
import org.slf4j.{Logger, LoggerFactory}

import scala.concurrent.{ExecutionContext, Future}
import scala.concurrent.duration._
Expand All @@ -46,6 +47,7 @@ object JobSetStatus {
* Armada server URL (default: "localhost:30002")
*/
class ArmadaClient(armadaUrl: String = "localhost:30002") {
private val logger = LoggerFactory.getLogger(getClass)
private val processTimeout = DefaultProcessTimeout

private val yamlMapper = {
Expand Down Expand Up @@ -94,12 +96,12 @@ class ArmadaClient(armadaUrl: String = "localhost:30002") {
def ensureQueueExists(name: String)(implicit ec: ExecutionContext): Future[Unit] = {
getQueue(name).flatMap {
case Some(_) =>
println(s"[QUEUE] Queue $name already exists")
logger.info(s"[QUEUE] Queue $name already exists")
Future.successful(())
case None =>
println(s"[QUEUE] Creating queue $name...")
logger.info(s"[QUEUE] Creating queue $name...")
createQueue(name).flatMap { _ =>
println(s"[QUEUE] Waiting for queue $name to become available...")
logger.info(s"[QUEUE] Waiting for queue $name to become available...")
Future {
var attempts = 0
val maxAttempts = 30
Expand All @@ -110,12 +112,12 @@ class ArmadaClient(armadaUrl: String = "localhost:30002") {
queueFound = getQueueSync(name).isDefined
attempts += 1
if (!queueFound && attempts % 5 == 0) {
println(s"[QUEUE] Still waiting for queue $name... (${attempts}s)")
logger.info(s"[QUEUE] Still waiting for queue $name... (${attempts}s)")
}
}

if (queueFound) {
println(s"[QUEUE] Queue $name is ready")
logger.info(s"[QUEUE] Queue $name is ready")
} else {
throw new RuntimeException(s"Queue $name not available after $attempts seconds")
}
Expand Down Expand Up @@ -164,7 +166,7 @@ class ArmadaClient(armadaUrl: String = "localhost:30002") {
val elapsed = (System.currentTimeMillis() - startTime) / 1000
// Log progress periodically for long-running jobs
if (elapsed % ProgressLogInterval.toSeconds == 0 && elapsed > 0) {
println(s"Still monitoring job - elapsed: ${elapsed}s")
logger.info(s"Still monitoring job - elapsed: ${elapsed}s")
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import org.scalatest.funsuite.AnyFunSuite
import org.scalatest.matchers.should.Matchers
import org.scalatest.concurrent.{Eventually, ScalaFutures}
import org.scalatest.time.{Seconds, Span}
import org.slf4j.LoggerFactory

import java.util.Properties
import scala.concurrent.ExecutionContext.Implicits.global
Expand All @@ -34,6 +35,8 @@ class ArmadaSparkE2E
with ScalaFutures
with Eventually {

private val logger = LoggerFactory.getLogger(getClass)

implicit override val patienceConfig: PatienceConfig = PatienceConfig(
timeout = Span(300, Seconds),
interval = Span(2, Seconds)
Expand Down Expand Up @@ -72,14 +75,14 @@ class ArmadaSparkE2E
sparkVersion = finalSparkVersion
)

println(s"Test configuration loaded: $baseConfig")
logger.info(s"Test configuration loaded: $baseConfig")

// Verify Armada cluster is ready before running tests
val clusterReadyTimeout = ClusterReadyTimeout.toSeconds.toInt
val testQueueName = s"${baseConfig.baseQueueName}-cluster-check-${System.currentTimeMillis()}"

println(s"[CLUSTER-CHECK] Verifying Armada cluster readiness...")
println(s"[CLUSTER-CHECK] Will retry for up to $clusterReadyTimeout seconds")
logger.info(s"[CLUSTER-CHECK] Verifying Armada cluster readiness...")
logger.info(s"[CLUSTER-CHECK] Will retry for up to $clusterReadyTimeout seconds")

val startTime = System.currentTimeMillis()
var clusterReady = false
Expand All @@ -88,29 +91,33 @@ class ArmadaSparkE2E

while (!clusterReady && (System.currentTimeMillis() - startTime) < clusterReadyTimeout * 1000) {
attempts += 1
println(
logger.info(
s"[CLUSTER-CHECK] Attempt #$attempts - Creating and verifying test queue: $testQueueName"
)

try {
armadaClient.ensureQueueExists(testQueueName).futureValue
println(s"[CLUSTER-CHECK] Queue creation and verification succeeded - cluster is ready!")
logger.info(
s"[CLUSTER-CHECK] Queue creation and verification succeeded - cluster is ready!"
)
clusterReady = true

try {
armadaClient.deleteQueue(testQueueName).futureValue
println(s"[CLUSTER-CHECK] Test queue cleaned up")
logger.info(s"[CLUSTER-CHECK] Test queue cleaned up")
} catch {
case _: Exception => // Ignore cleanup failures
}
} catch {
case ex: Exception =>
lastError = Some(ex)
val elapsed = (System.currentTimeMillis() - startTime) / 1000
println(s"[CLUSTER-CHECK] Attempt #$attempts failed after ${elapsed}s: ${ex.getMessage}")
logger.info(
s"[CLUSTER-CHECK] Attempt #$attempts failed after ${elapsed}s: ${ex.getMessage}"
)

if ((System.currentTimeMillis() - startTime) < clusterReadyTimeout * 1000) {
println(
logger.info(
s"[CLUSTER-CHECK] Waiting ${ClusterCheckRetryDelay.toSeconds} seconds before retry..."
)
Thread.sleep(ClusterCheckRetryDelay.toMillis)
Expand All @@ -126,7 +133,7 @@ class ArmadaSparkE2E
)
}

println(
logger.info(
s"[CLUSTER-CHECK] Cluster verified ready after $totalTime seconds ($attempts attempts)"
)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import scala.concurrent.{Future, ExecutionContext, blocking}
import scala.concurrent.duration._
import scala.sys.process._
import scala.util.{Failure, Success, Try}
import org.slf4j.{Logger, LoggerFactory}

case class ProcessResult(
exitCode: Int,
Expand All @@ -31,6 +32,7 @@ case class ProcessResult(
)

object ProcessExecutor {
private val logger = LoggerFactory.getLogger(getClass)

/** Execute command and always return ProcessResult, even on failure */
def executeWithResult(command: Seq[String], timeout: Duration): ProcessResult = {
Expand All @@ -42,14 +44,14 @@ object ProcessExecutor {
stdout.append(line).append("\n")
// Print docker/spark-submit output in real-time for debugging
if (command.headOption.contains("docker") && line.nonEmpty) {
println(s"[SPARK-SUBMIT] $line")
logger.info(s"[SPARK-SUBMIT] $line")
}
},
line => {
stderr.append(line).append("\n")
// Print docker/spark-submit errors in real-time for debugging
if (command.headOption.contains("docker") && line.nonEmpty) {
println(s"[SPARK-SUBMIT] $line")
logger.info(s"[SPARK-SUBMIT] $line")
}
}
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,13 @@

package org.apache.spark.deploy.armada.e2e

import org.slf4j.{Logger, LoggerFactory}
import scala.collection.mutable
import scala.concurrent.duration._

/** Simple pod monitoring that captures logs and events on failure */
class SimplePodMonitor(namespace: String) {
private val logger = LoggerFactory.getLogger(getClass)
private val capturedLogs = mutable.ArrayBuffer[String]()

/** Check if any pods have failed and capture their logs if so */
Expand All @@ -45,7 +47,7 @@ class SimplePodMonitor(namespace: String) {
val failedPodName = failedPods.head.split("\\s+").head

// Immediately capture logs for the failed pod
println(s"[MONITOR] Pod $failedPodName failed, capturing logs...")
logger.info(s"[MONITOR] Pod $failedPodName failed, capturing logs...")
try {
val logsCmd = Seq(
"kubectl",
Expand All @@ -58,8 +60,8 @@ class SimplePodMonitor(namespace: String) {
)
val logsResult = ProcessExecutor.executeWithResult(logsCmd, 10.seconds)
if (logsResult.exitCode == 0 && logsResult.stdout.nonEmpty) {
println(s"[MONITOR] Pod $failedPodName logs:")
println(logsResult.stdout)
logger.info(s"[MONITOR] Pod $failedPodName logs:")
logger.info(logsResult.stdout)
}

// Also try to describe the pod
Expand All @@ -69,33 +71,33 @@ class SimplePodMonitor(namespace: String) {
val lines = describeResult.stdout.split("\n")
val eventsIndex = lines.indexWhere(_.contains("Events:"))
if (eventsIndex >= 0) {
println(s"[MONITOR] Pod $failedPodName events:")
println(lines.slice(eventsIndex, eventsIndex + 20).mkString("\n"))
logger.info(s"[MONITOR] Pod $failedPodName events:")
logger.info(lines.slice(eventsIndex, eventsIndex + 20).mkString("\n"))
}
}
} catch {
case e: Exception =>
println(s"[MONITOR] Failed to capture logs for $failedPodName: ${e.getMessage}")
logger.info(s"[MONITOR] Failed to capture logs for $failedPodName: ${e.getMessage}")
}

Some(s"Pod $failedPodName failed in namespace $namespace")
} else {
None
}
} else {
println(s"[MONITOR] Failed to get pods: ${podsResult.stderr}")
logger.info(s"[MONITOR] Failed to get pods: ${podsResult.stderr}")
None
}
} catch {
case e: Exception =>
println(s"[MONITOR] Error checking pods: ${e.getMessage}")
logger.info(s"[MONITOR] Error checking pods: ${e.getMessage}")
None
}
}

/** Capture all logs and events for debugging */
def captureDebugInfo(): Unit = {
println(s"\n[DEBUG] Capturing debug info for namespace $namespace")
logger.debug(s"Capturing debug info for namespace $namespace")

try {
val podsCmd = Seq("kubectl", "get", "pods", "-n", namespace, "-o", "name")
Expand All @@ -107,7 +109,7 @@ class SimplePodMonitor(namespace: String) {

podNames.foreach { podName =>
try {
println(s"[DEBUG] Capturing logs for pod $podName")
logger.debug(s"Capturing logs for pod $podName")

val logsCmd = Seq(
"kubectl",
Expand Down Expand Up @@ -150,7 +152,7 @@ class SimplePodMonitor(namespace: String) {
}
} catch {
case e: Exception =>
println(s"[DEBUG] Failed to capture info for pod $podName: ${e.getMessage}")
logger.error(s"Failed to capture info for pod $podName: ${e.getMessage}")
}
}
}
Expand All @@ -163,16 +165,16 @@ class SimplePodMonitor(namespace: String) {

} catch {
case e: Exception =>
println(s"[DEBUG] Failed to capture debug info: ${e.getMessage}")
logger.error(s"Failed to capture debug info: ${e.getMessage}")
}
}

/** Print all captured logs */
def printCapturedLogs(): Unit = {
if (capturedLogs.nonEmpty) {
println(s"\n========== DEBUG INFO FOR NAMESPACE: $namespace ==========")
capturedLogs.foreach(println)
println(s"========== END DEBUG INFO ==========\n")
logger.info(s"\n========== DEBUG INFO FOR NAMESPACE: $namespace ==========")
capturedLogs.foreach(logger.info)
logger.info(s"========== END DEBUG INFO ==========\n")
}
}
}
Loading
Loading