Skip to content

Commit

Permalink
Update api
Browse files Browse the repository at this point in the history
  • Loading branch information
mn-mikke committed Mar 21, 2022
1 parent 4a42b99 commit 5ecba2f
Show file tree
Hide file tree
Showing 5 changed files with 266 additions and 120 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package ai.h2o.sparkling.ml.metrics

import hex.ModelMetricsBinomial.IndependentMetricBuilderBinomial
import hex.genmodel.utils.DistributionFamily
import org.apache.spark.sql.DataFrame

object H2OBinomialMetrics extends MetricCalculation {

def calculate(
dataFrame: DataFrame,
domain: Array[String],
predictionProbabilitiesCol: String = "detailed_prediction.probabilities",
labelCol: String = "label",
weightColOption: Option[String] = None,
offsetColOption: Option[String] = None,
distributionFamily: String = "AUTO"): H2OBinomialMetrics = {
val domainFamilyEnum = DistributionFamily.valueOf(distributionFamily)
val getMetricBuilder = () => new IndependentMetricBuilderBinomial[_](domain, domainFamilyEnum)

val gson = getMetricGson(
getMetricBuilder,
dataFrame,
predictionProbabilitiesCol,
labelCol,
offsetColOption,
weightColOption,
domain)
val result = new H2OBinomialMetrics()
result.setMetrics(gson, "H2OBinomialMetrics.calculate")
result
}

def calculate(
dataFrame: DataFrame,
domain: Array[String],
predictionProbabilitiesCol: String,
labelCol: String,
weightCol: String,
offsetCol: String,
distributionFamily: String): Unit = {
calculate(
dataFrame,
domain,
predictionProbabilitiesCol,
labelCol,
Option(weightCol),
Option(offsetCol),
distributionFamily)
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package ai.h2o.sparkling.ml.metrics

import ai.h2o.sparkling.ml.metrics.H2OBinomialMetrics.getMetricGson
import hex.ModelMetricsMultinomial.IndependentMetricBuilderMultinomial
import hex.MultinomialAucType
import org.apache.spark.sql.DataFrame

object H2OMultinomialMetrics {
def calculate(
dataFrame: DataFrame,
domain: Array[String],
predictionProbabilitiesCol: String = "detailed_prediction.probabilities",
labelCol: String = "label",
weightColOption: Option[String] = None,
offsetColOption: Option[String] = None,
priorDistributionOption: Option[Array[Double]] = None,
aucType: String = "AUTO"): H2OMultinomialMetrics = {

val aucTypeEnum = MultinomialAucType.valueOf(aucType)
val nclasses = domain.length
val priorDistribution = priorDistributionOption match {
case Some(x) => x
case None => null
}
val getMetricBuilder =
() => new IndependentMetricBuilderMultinomial[_](nclasses, domain, aucTypeEnum, priorDistribution)

val gson = getMetricGson(
getMetricBuilder,
dataFrame,
predictionProbabilitiesCol,
labelCol,
offsetColOption,
weightColOption,
domain)
val result = new H2OMultinomialMetrics()
result.setMetrics(gson, "H2OMultinomialMetrics.calculate")
result
}

def calculate(
dataFrame: DataFrame,
domain: Array[String],
predictionProbabilitiesCol: String,
labelCol: String,
weightCol: String,
offsetCol: String,
priorDistribution: Array[Double],
aucType: String): H2OMultinomialMetrics = {
calculate(
dataFrame,
domain,
predictionProbabilitiesCol,
labelCol,
Option(weightCol),
Option(offsetCol),
Option(priorDistribution),
aucType)
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package ai.h2o.sparkling.ml.metrics

import ai.h2o.sparkling.ml.metrics.H2OBinomialMetrics.getMetricGson
import hex.DistributionFactory
import hex.ModelMetricsRegression.IndependentMetricBuilderRegression
import hex.genmodel.utils.DistributionFamily
import org.apache.spark.sql.DataFrame

object H2ORegressionMetrics {

def calculate(
dataFrame: DataFrame,
predictionCol: String = "prediction",
labelCol: String = "label",
weightColOption: Option[String] = None,
offsetColOption: Option[String] = None,
distributionFamily: String = "AUTO"): H2ORegressionMetrics = {
val domainFamilyEnum = DistributionFamily.valueOf(distributionFamily)
val distribution= DistributionFactory.getDistribution(domainFamilyEnum)
val getMetricBuilder = () => new IndependentMetricBuilderRegression[_](distribution)

val gson = getMetricGson(
getMetricBuilder,
dataFrame,
predictionCol,
labelCol,
offsetColOption,
weightColOption,
null)
val result = new H2ORegressionMetrics()
result.setMetrics(gson, "H2ORegressionMetrics.calculate")
result
}

def calculate(
dataFrame: DataFrame,
predictionCol: String,
labelCol: String,
weightCol: String,
offsetCol: String,
distributionFamily: String): H2ORegressionMetrics = {
calculate(
dataFrame,
predictionCol,
labelCol,
Option(weightCol),
Option(offsetCol),
distributionFamily)
}
}
Loading

0 comments on commit 5ecba2f

Please sign in to comment.