Skip to content

Commit

Permalink
[SW-2646] Calculate Metrics on Arbitrary Dataset
Browse files Browse the repository at this point in the history
  • Loading branch information
mn-mikke committed Mar 21, 2022
1 parent d7e8f01 commit 4a42b99
Show file tree
Hide file tree
Showing 38 changed files with 1,932 additions and 146 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ class AlgorithmConfigurations extends MultipleAlgorithmsConfiguration {
type KMeansParamsV3 = KMeansV3.KMeansParametersV3

val explicitDefaultValues =
Map[String, Any]("max_w2" -> 3.402823e38f, "response_column" -> "label", "model_id" -> null, "lambda" -> null)
Map[String, Any]("max_w2" -> 3.402823e38f, "response_column" -> "label", "model_id" -> null)

val noDeprecation = Seq.empty

Expand Down Expand Up @@ -173,25 +173,34 @@ class AlgorithmConfigurations extends MultipleAlgorithmsConfiguration {

type IFParameters = IsolationForestParameters

val algorithms = Seq[(String, Class[_], String, Seq[String], Option[String])](
("H2OXGBoost", classOf[XGBoostParameters], treeSupervised, Seq(withDistribution), None),
("H2OGBM", classOf[GBMParameters], treeSupervised, Seq(withDistribution), None),
("H2ODRF", classOf[DRFParameters], treeSupervised, Seq(withDistribution), None),
("H2OGLM", classOf[GLMParameters], cvSupervised, Seq(withFamily), Some("H2OGLMMetrics")),
("H2OGAM", classOf[GAMParameters], cvSupervised, Seq(withFamily), None),
("H2ODeepLearning", classOf[DeepLearningParameters], cvSupervised, Seq(withDistribution), None),
("H2ORuleFit", classOf[RuleFitParameters], supervised, Seq(withDistribution), None),
("H2OKMeans", classOf[KMeansParameters], unsupervised, Seq("H2OKMeansExtras"), Some("H2OClusteringMetrics")),
("H2OCoxPH", classOf[CoxPHParameters], supervised, Seq.empty, Some("H2ORegressionCoxPHMetrics")),
("H2OIsolationForest", classOf[IFParameters], treeUnsupervised, Seq.empty, Some("H2OAnomalyMetrics")))

for ((entityName, h2oParametersClass: Class[_], algorithmType, extraParents, metricsClass) <- algorithms)
val none = Seq.empty

val algorithms = Seq[(String, Class[_], String, Seq[String], Seq[String], Option[String])](
("H2OXGBoost", classOf[XGBoostParameters], treeSupervised, Seq(withDistribution), none, None),
("H2OGBM", classOf[GBMParameters], treeSupervised, Seq(withDistribution), none, None),
("H2ODRF", classOf[DRFParameters], treeSupervised, Seq(withDistribution), none, None),
("H2OGLM", classOf[GLMParameters], cvSupervised, Seq(withFamily), none, Some("H2OGLMMetrics")),
("H2OGAM", classOf[GAMParameters], cvSupervised, Seq(withFamily), none, None),
("H2ODeepLearning", classOf[DeepLearningParameters], cvSupervised, Seq(withDistribution), none, None),
("H2ORuleFit", classOf[RuleFitParameters], supervised, Seq(withDistribution), none, None),
(
"H2OKMeans",
classOf[KMeansParameters],
unsupervised,
Seq("H2OKMeansExtras"),
Seq("KmeansMetricCalculation"),
Some("H2OClusteringMetrics")),
("H2OCoxPH", classOf[CoxPHParameters], supervised, none, none, Some("H2ORegressionCoxPHMetrics")),
("H2OIsolationForest", classOf[IFParameters], treeUnsupervised, none, none, Some("H2OAnomalyMetrics")))

for ((entityName, h2oParametersClass: Class[_], algorithmType, extraParents, extraMOJOParents, metricsClass) <- algorithms)
yield AlgorithmSubstitutionContext(
namespace = "ai.h2o.sparkling.ml.algos",
entityName,
h2oParametersClass,
algorithmType,
extraParents,
extraMOJOParents,
specificMetricsClass = metricsClass)
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ case class AlgorithmSubstitutionContext(
h2oSchemaClass: Class[_],
algorithmType: String,
extraInheritedEntities: Seq[String] = Seq.empty,
extraInheritedEntitiesOnMOJO: Seq[String] = Seq.empty,
constructorMethods: Boolean = true,
specificMetricsClass: Option[String] = None)
extends SubstitutionContextBase
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ class AutoMLConfiguration extends SingleAlgorithmConfiguration {
null,
"H2OSupervisedAlgorithmWithFoldColumn",
Seq("H2OAutoMLExtras"),
false))
constructorMethods = false))
}

override def problemSpecificAlgorithmConfiguration: Seq[ProblemSpecificAlgorithmSubstitutionContext] = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,7 @@ class FeatureEstimatorConfigurations extends MultipleAlgorithmsConfiguration {

override def algorithmConfiguration: Seq[AlgorithmSubstitutionContext] = {

def none = Seq.empty[String]
val algorithms = Seq[(String, Class[_], String, Option[String])](
("H2OAutoEncoder", classOf[DeepLearningParameters], "H2OAutoEncoderBase", Some("H2OAutoEncoderMetrics")),
("H2OPCA", classOf[PCAParameters], "H2ODimReductionEstimator", Some("H2OPCAMetrics")),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,6 @@ class GridSearchConfiguration extends SingleAlgorithmConfiguration {
null,
"H2OAlgorithm",
Seq("H2OGridSearchExtras"),
false))
constructorMethods = false))
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ object MOJOModelTemplate
val imports = Seq(
"com.google.gson.JsonObject",
"ai.h2o.sparkling.ml.params.ParameterConstructorMethods",
"ai.h2o.sparkling.ml.metrics._",
"hex.genmodel.MojoModel",
"org.apache.spark.expose.Logging",
"ai.h2o.sparkling.utils.DataFrameSerializationWrappers._") ++
Expand All @@ -59,7 +60,9 @@ object MOJOModelTemplate
.replace("Estimator", "MOJOModel")
.replaceFirst("Base$", "MOJOBase"),
"ParameterConstructorMethods",
"Logging") ++ explicitFieldImplementations
"Logging") ++
explicitFieldImplementations ++
algorithmSubstitutionContext.extraInheritedEntitiesOnMOJO

val entityName = algorithmSubstitutionContext.entityName
val entityParameters = "(override val uid: String)"
Expand Down Expand Up @@ -212,6 +215,11 @@ object MOJOModelTemplate
| override def getCrossValidationMetricsObject(): $metrics = {
| val value = super.getCrossValidationMetricsObject()
| if (value == null) null else value.asInstanceOf[$metrics]
| }
|
| override def getMetricsObject(dataFrame: org.apache.spark.sql.DataFrame): $metrics = {
| val value = super.getMetricsObject(dataFrame)
| if (value == null) null else value.asInstanceOf[$metrics]
| }""".stripMargin
}
}
46 changes: 45 additions & 1 deletion core/src/test/scala/ai/h2o/sparkling/TestUtils.scala
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ import org.apache.spark.mllib
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.encoders.{ExpressionEncoder, RowEncoder}
import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema
import org.apache.spark.sql.functions.{lit, rand}
import org.apache.spark.sql.functions.{lit, rand, col, abs}
import org.apache.spark.sql.types._
import org.apache.spark.sql.{DataFrame, Dataset, Row, SparkSession}
import org.scalatest.Matchers
Expand Down Expand Up @@ -100,6 +100,50 @@ object TestUtils extends Matchers {
""".stripMargin)
}

def assertDataFramesAreEqual(
expected: DataFrame,
produced: DataFrame,
identityColumn: String,
tolerance: Double): Unit = {
val tolerances = expected.schema.fields
.filterNot(_.name == identityColumn)
.filter(_.dataType.isInstanceOf[NumericType])
.map(_.name -> tolerance)
.toMap
assertDataFramesAreEqual(expected, produced, identityColumn, tolerances)
}

def assertDataFramesAreEqual(
expected: DataFrame,
produced: DataFrame,
identityColumn: String,
tolerances: Map[String, Double] = Map.empty): Unit = {
expected.schema shouldEqual produced.schema
val intersection = expected.as("expected").join(produced.as("produced"), identityColumn)
intersection.count() shouldEqual expected.count()
intersection.count() shouldEqual produced.count()
val isEqualExpression = expected.columns.foldLeft(lit(true)) {
case (partialExpression, columnName) =>
val columnComparision = if (tolerances.contains(columnName)) {
val difference = abs(col(s"expected.$columnName") - col(s"produced.$columnName"))
difference <= lit(tolerances(columnName))
} else if (columnName == identityColumn) {
lit(true)
} else {
col(s"expected.$columnName") === col(s"produced.$columnName")
}
partialExpression && columnComparision
}
val withComparisonDF = intersection.withColumn("isEqual", isEqualExpression)
val differentRowsDF = withComparisonDF
.filter(col("isEqual") === lit(false))
.select(col(s"expected.$identityColumn") as "id")
val differentIds = differentRowsDF.collect().map(_.get(0))
assert(
differentIds.length == 0,
s"The rows of ids($identityColumn) [${differentIds.mkString(", ")}] are not equal.")
}

def assertDatasetBasicProperties[T <: Product](
ds: Dataset[T],
df: H2OFrame,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,7 @@ object Runner {
}
} else {
val metricClasses = getParamClasses("ai.h2o.sparkling.ml.metrics")
.filter(_.getSimpleName.endsWith("Metrics"))
writeResultToFile(MetricsTocTreeTemplate(metricClasses), "metrics", destinationDir)
for (metricClass <- metricClasses) {
val content = MetricsTemplate(metricClass)
Expand Down
13 changes: 11 additions & 2 deletions doc/src/site/sphinx/deployment/load_mojo.rst
Original file line number Diff line number Diff line change
Expand Up @@ -362,8 +362,8 @@ Obtaining Scoring History
The method ``getScoringHistory`` returns a data frame describing how the model evolved during the training process according to
a certain training and validation metrics.

Obtaining Metrics
^^^^^^^^^^^^^^^^^
Obtaining Pre-calculated Metrics
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

There are two sets of methods to obtain metrics from the MOJO model.

Expand All @@ -389,6 +389,15 @@ the metrics could be also of a complex type. (see :ref:`metrics` for details)

There is also the method ``getCurrentMetricsObject()`` working a similar way as ``getCurrentMetrics()``.

Calculation of Metrics on Arbitrary Dataset
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
The below two methods calculate metrics on a provided dataset.

- ``getMetrics(dataFrame)`` - Returns a map with basic metrics of double type

- ``getMetricsObject(dataFrame)`` - Returns an object with basic and more complex metrics available via getter methods.
(see :ref:`metrics` for details)

Obtaining Cross Validation Metrics Summary
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
The ``getCrossValidationMetricsSummary`` method returns data frame with information about performance of individual folds
Expand Down
1 change: 1 addition & 0 deletions extensions/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ dependencies {
compileOnly("org.scala-lang:scala-library:${scalaVersion}")

compileOnly("ai.h2o:h2o-core:${h2oVersion}")
compileOnly("ai.h2o:h2o-algos:${h2oVersion}")
compileOnly("javax.servlet:servlet-api:2.5")

testImplementation("org.scala-lang:scala-library:${scalaVersion}")
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
hex.MetricsCalculationTypeExtensions
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
package hex;

import hex.glm.IndependentGLMMetricBuilder;
import hex.glrm.ModelMetricsGLRM;
import hex.pca.ModelMetricsPCA;
import hex.tree.isofor.ModelMetricsAnomaly;
import java.util.Arrays;
import water.TypeMapExtension;
import water.api.ModelMetricsPCAV3;
import water.api.schemas3.*;

public class MetricsCalculationTypeExtensions implements TypeMapExtension {
public static final String[] MODEL_BUILDER_CLASSES = {
ModelMetrics.IndependentMetricBuilder.class.getName(),
ModelMetricsSupervised.IndependentMetricBuilderSupervised.class.getName(),
ModelMetricsUnsupervised.IndependentMetricBuilderUnsupervised.class.getName(),
ModelMetricsBinomial.IndependentMetricBuilderBinomial.class.getName(),
AUC2.AUCBuilder.class.getName(),
ModelMetricsRegression.IndependentMetricBuilderRegression.class.getName(),
Distribution.class.getName(),
GaussianDistribution.class.getName(),
BernoulliDistribution.class.getName(),
QuasibinomialDistribution.class.getName(),
ModifiedHuberDistribution.class.getName(),
MultinomialDistribution.class.getName(),
PoissonDistribution.class.getName(),
GammaDistribution.class.getName(),
TweedieDistribution.class.getName(),
HuberDistribution.class.getName(),
LaplaceDistribution.class.getName(),
QuantileDistribution.class.getName(),
CustomDistribution.class.getName(),
CustomDistributionWrapper.class.getName(),
LinkFunction.class.getName(),
IdentityFunction.class.getName(),
InverseFunction.class.getName(),
LogFunction.class.getName(),
LogitFunction.class.getName(),
OlogitFunction.class.getName(),
OloglogFunction.class.getName(),
OprobitFunction.class.getName(),
ModelMetricsMultinomial.IndependentMetricBuilderMultinomial.class.getName(),
ModelMetricsOrdinal.IndependentMetricBuilderOrdinal.class.getName(),
ModelMetricsClustering.IndependentMetricBuilderClustering.class.getName(),
ModelMetricsHGLM.IndependentMetricBuilderHGLM.class.getName(),
ModelMetricsGLRM.IndependentGLRMModelMetricsBuilder.class.getName(),
ModelMetricsAnomaly.IndependentMetricBuilderAnomaly.class.getName(),
IndependentGLMMetricBuilder.class.getName(),
hex.glm.GLMModel.GLMWeightsFun.class.getName(),
ModelMetricsAutoEncoder.IndependentAutoEncoderMetricBuilder.class.getName(),
ModelMetricsPCA.IndependentPCAMetricBuilder.class.getName()
};

public static final String[] SCHEMA_CLASSES = {
ModelMetricsBaseV3.class.getName(),
ModelMetricsBinomialGLMV3.class.getName(),
ModelMetricsBinomialV3.class.getName(),
ModelMetricsMultinomialGLMV3.class.getName(),
ModelMetricsMultinomialV3.class.getName(),
ModelMetricsOrdinalGLMV3.class.getName(),
ModelMetricsOrdinalV3.class.getName(),
ModelMetricsRegressionGLMV3.class.getName(),
ModelMetricsRegressionCoxPHV3.class.getName(),
ModelMetricsRegressionV3.class.getName(),
ModelMetricsAutoEncoderV3.class.getName(),
ModelMetricsPCAV3.class.getName(),
ModelMetricsHGLMV3.class.getName(),
ModelMetricsClusteringV3.class.getName(),
ConfusionMatrixV3.class.getName(),
TwoDimTableV3.class.getName(),
TwoDimTableV3.ColumnSpecsBase.class.getName()
};

@Override
public String[] getBoostrapClasses() {
String[] result =
Arrays.copyOf(MODEL_BUILDER_CLASSES, MODEL_BUILDER_CLASSES.length + SCHEMA_CLASSES.length);
System.arraycopy(
SCHEMA_CLASSES, 0, result, MODEL_BUILDER_CLASSES.length, SCHEMA_CLASSES.length);
return result;
}
}
8 changes: 4 additions & 4 deletions gradle.properties
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
# artifacts group
group=ai.h2o
# Major version of H2O release
h2oMajorVersion=3.36.0
h2oMajorVersion=3.37.0
# Name of H2O major version
h2oMajorName=zorn
# H2O Build version, defined here to be overriden by -P option
h2oBuild=3
h2oBuild=1-SNAPSHOT
# Version of Mojo Pipeline library
mojoPipelineVersion=2.7.5
# Defines whether to run tests with Driverless AI mojo pipelines
Expand All @@ -29,11 +29,11 @@ pythonEnvironments=2.7 3.6 3.7 3.8
# Select for which Spark version is Sparkling Water built by default
spark=3.2
# Sparkling Water Version
version=3.38.0.1-1-SNAPSHOT
version=3.38.0.1-199-SNAPSHOT
# Spark version from which is Kubernetes Supported
kubernetesSupportSinceSpark=2.4
databricksTestSinceSpark=2.4
spotlessModern=true
testH2OBranch=master
testH2OBranch=mn/PUBDEV-8373
makeBooklet=false
testingBaseImage="harbor.h2o.ai/opsh2oai/h2o-3-hadoop-cdh-6.3:84"
Original file line number Diff line number Diff line change
Expand Up @@ -183,57 +183,4 @@ class BinomialPredictionTestSuite extends FunSuite with Matchers with SharedH2OT
assert(schema == expectedSchema)
assert(schema == expectedSchemaByTransform)
}

private def assertMetrics[T](model: H2OMOJOModel): Unit = {
assertMetrics[T](model.getTrainingMetricsObject(), model.getTrainingMetrics())
assertMetrics[T](model.getValidationMetricsObject(), model.getValidationMetrics())
assert(model.getCrossValidationMetricsObject() == null)
assert(model.getCrossValidationMetrics() == Map())
}

private def assertMetrics[T](metricsObject: H2OMetrics, metrics: Map[String, Double]): Unit = {
metricsObject.isInstanceOf[T] should be(true)
MetricsAssertions.assertMetricsObjectAgainstMetricsMap(metricsObject, metrics)
val binomialObject = metricsObject.asInstanceOf[H2OBinomialMetrics]
binomialObject.getConfusionMatrix().count() > 0
binomialObject.getConfusionMatrix().columns.length > 0
binomialObject.getGainsLiftTable().count() > 0
binomialObject.getGainsLiftTable().columns.length > 0
binomialObject.getMaxCriteriaAndMetricScores().count() > 0
binomialObject.getMaxCriteriaAndMetricScores().columns.length > 0
binomialObject.getThresholdsAndMetricScores().count() > 0
binomialObject.getThresholdsAndMetricScores().columns.length > 0
}

test("test binomial metric objects") {
val algo = new H2OGBM()
.setSplitRatio(0.8)
.setSeed(1)
.setFeaturesCols("sepal_len", "sepal_wid")
.setColumnsToCategorical("class")
.setLabelCol("class")

val model = algo.fit(dataset)
assertMetrics[H2OBinomialMetrics](model)

model.write.overwrite().save("ml/build/gbm_binomial_model_metrics")
val loadedModel = H2OGBMMOJOModel.load("ml/build/gbm_binomial_model_metrics")
assertMetrics[H2OBinomialMetrics](loadedModel)
}

test("test binomial glm metric objects") {
val algo = new H2OGLM()
.setSplitRatio(0.8)
.setSeed(1)
.setFeaturesCols("sepal_len", "sepal_wid")
.setColumnsToCategorical("class")
.setLabelCol("class")

val model = algo.fit(dataset)
assertMetrics[H2OBinomialGLMMetrics](model)

model.write.overwrite().save("ml/build/glm_binomial_model_metrics")
val loadedModel = H2OGLMMOJOModel.load("ml/build/glm_binomial_model_metrics")
assertMetrics[H2OBinomialGLMMetrics](loadedModel)
}
}
Loading

0 comments on commit 4a42b99

Please sign in to comment.