Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -62,3 +62,9 @@ __pycache__
.coverage*
*.jar
.python-version

# Ignore SBT lock files
project/.boot/**/sbt.boot.lock
project/.boot/**/sbt.components.lock
project/.ivy/.sbt.ivy.lock
project/.sbtboot/**/.sbt.cache.lock
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,95 @@ import scala.reflect.{ClassTag, classTag}

/** Can be useful for non Scala types and for complicated case classes with implicits in the constructor. */
object ManualTypedEncoder {
/** Constructs StaticInvoke via reflection to handle Spark 3.4/3.5 constructor differences. */
private def staticInvokeSafely(
targetClass: Class[_],
dataType: DataType,
functionName: String,
arguments: Seq[Expression],
propagateNull: Boolean,
returnNullable: Boolean
): InvokeLike = {
val ctors = classOf[StaticInvoke].getConstructors
val boxedPropagateNull = Boolean.box(propagateNull)
val boxedReturnNullable = Boolean.box(returnNullable)
val TRUE = Boolean.box(true)

val ctor = ctors.maxBy(_.getParameterTypes.length)
val argTypes: Seq[DataType] = arguments.map(_.dataType)
val targetModuleClass: Class[_] = {
val moduleName = targetClass.getName + "$"
try Class.forName(moduleName)
catch { case _: ClassNotFoundException => targetClass }
}

def tryInvoke(onClass: Class[_]): InvokeLike = ctor.getParameterTypes.length match {
case 9 =>
// (Class, DataType, String, Seq, Seq, boolean, boolean, boolean, Option)
ctor.newInstance(
onClass,
dataType,
functionName,
arguments,
argTypes,
boxedPropagateNull,
boxedReturnNullable,
TRUE,
None
).asInstanceOf[InvokeLike]
case 8 =>
// (Class, DataType, String, Seq, Seq, boolean, boolean, boolean)
ctor.newInstance(
onClass,
dataType,
functionName,
arguments,
argTypes,
boxedPropagateNull,
boxedReturnNullable,
TRUE
).asInstanceOf[InvokeLike]
case _ =>
throw new NotImplementedError("StaticInvoke constructor has unexpected shape")
}
ctor.getParameterTypes.length match {
case 9 | 8 =>
// Try on the class first (top-level case classes have static forwarders), then on module
val firstError = try {
return tryInvoke(targetClass)
} catch { case t: Throwable => t }
tryInvoke(targetModuleClass)
case _ =>
throw new NotImplementedError("StaticInvoke constructor has unexpected shape")
}
}

/** Detect whether a static forwarder for `apply` of given arity exists on the given class. */
private def hasStaticApply(onClass: Class[_], arity: Int): Boolean = {
import java.lang.reflect.Modifier
onClass.getMethods.exists { m =>
m.getName == "apply" && m.getParameterCount == arity && Modifier.isStatic(m.getModifiers)
}
}

/** Invokes apply from the companion object. */
def staticInvoke[T: ClassTag](
fields: List[RecordEncoderField],
fieldNameModify: String => String = identity,
isNullable: Boolean = true
): TypedEncoder[T] = apply[T](fields, { (classTag, newArgs, jvmRepr) => StaticInvoke(classTag.runtimeClass, jvmRepr, "apply", newArgs, propagateNull = true, returnNullable = false) }, fieldNameModify, isNullable)
): TypedEncoder[T] = apply[T](fields, { (classTag, newArgs, jvmRepr) =>
val target = classTag.runtimeClass
val moduleName = target.getName + "$"
val moduleClass = try Class.forName(moduleName) catch { case _: ClassNotFoundException => null }
val arity = newArgs.length
if ((hasStaticApply(target, arity)) || (moduleClass != null && hasStaticApply(moduleClass, arity))) {
staticInvokeSafely(target, jvmRepr, "apply", newArgs, propagateNull = true, returnNullable = false)
}
else {
// Fall back to directly invoking the primary constructor
NewInstance(target, newArgs, jvmRepr, propagateNull = true)
}
}, fieldNameModify, isNullable)

/** Invokes object constructor. */
def newInstance[T: ClassTag](
Expand Down