Skip to content

Commit

Permalink
Merge pull request #508 from Manas-Dikshit/main
Browse files (browse the repository at this point in the history)
[SPARK] Implement Correct fit() and transform() in SparkKMeansOperator
  • Loading branch information
zkaoudi authored Feb 27, 2025
2 parents 14ffb58 + 5509aac commit 1d5736f
Showing 1 changed file with 28 additions and 33 deletions.
Original file line number | Diff line number | Diff line change
@@ -1,11 +1,11 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
Expand All @@ -24,13 +24,8 @@
import org.apache.spark.ml.linalg.Vector;
import org.apache.spark.ml.linalg.VectorUDT;
import org.apache.spark.ml.linalg.Vectors;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;
import org.apache.spark.sql.*;
import org.apache.spark.sql.types.*;
import org.apache.wayang.basic.data.Tuple2;
import org.apache.wayang.basic.operators.KMeansOperator;
import org.apache.wayang.core.optimizer.OptimizationContext;
Expand All @@ -49,15 +44,15 @@

public class SparkKMeansOperator extends KMeansOperator implements SparkExecutionOperator {

private static final StructType schema = DataTypes.createStructType(
private static final StructType SCHEMA = DataTypes.createStructType(
new StructField[]{
DataTypes.createStructField(Attr.FEATURES, new VectorUDT(), false)
DataTypes.createStructField("features", new VectorUDT(), false)
}
);

private static Dataset<Row> data2Row(JavaRDD<double[]> inputRdd) {
final JavaRDD<Row> rowRdd = inputRdd.map(e -> RowFactory.create(Vectors.dense(e)));
return SparkSession.builder().getOrCreate().createDataFrame(rowRdd, schema);
private static Dataset<Row> convertToDataFrame(JavaRDD<double[]> inputRdd) {
JavaRDD<Row> rowRdd = inputRdd.map(e -> RowFactory.create(Vectors.dense(e)));
return SparkSession.builder().getOrCreate().createDataFrame(rowRdd, SCHEMA);
}

public SparkKMeansOperator(int k) {
Expand Down Expand Up @@ -87,17 +82,17 @@ public Tuple<Collection<ExecutionLineageNode>, Collection<ChannelInstance>> eval
assert inputs.length == this.getNumInputs();
assert outputs.length == this.getNumOutputs();

final RddChannel.Instance input = (RddChannel.Instance) inputs[0];
final CollectionChannel.Instance output = (CollectionChannel.Instance) outputs[0];
RddChannel.Instance input = (RddChannel.Instance) inputs[0];
CollectionChannel.Instance output = (CollectionChannel.Instance) outputs[0];

final JavaRDD<double[]> inputRdd = input.provideRdd();
final Dataset<Row> df = data2Row(inputRdd);
final KMeansModel model = new KMeans()
JavaRDD<double[]> inputRdd = input.provideRdd();
Dataset<Row> df = convertToDataFrame(inputRdd);
KMeansModel model = new KMeans()
.setK(this.k)
.setFeaturesCol(Attr.FEATURES)
.setPredictionCol(Attr.PREDICTION)
.setFeaturesCol("features")
.setPredictionCol("prediction")
.fit(df);
final Model outputModel = new Model(model);
Model outputModel = new Model(model);
output.accept(Collections.singletonList(outputModel));

return ExecutionOperator.modelLazyExecution(inputs, outputs, operatorContext);
Expand Down Expand Up @@ -127,18 +122,18 @@ public double[][] getClusterCenters() {

@Override
public JavaRDD<Tuple2<double[], Integer>> transform(JavaRDD<double[]> input) {
final Dataset<Row> df = data2Row(input);
final Dataset<Row> transform = model.transform(df);
return transform.toJavaRDD()
.map(row -> new Tuple2<>(row.<Vector>getAs(Attr.FEATURES).toArray(), row.<Integer>getAs(Attr.PREDICTION)));
Dataset<Row> df = convertToDataFrame(input);
Dataset<Row> transformed = model.transform(df);
return transformed.toJavaRDD()
.map(row -> new Tuple2<>(row.<Vector>getAs("features").toArray(), row.<Integer>getAs("prediction")));
}

@Override
public JavaRDD<Integer> predict(JavaRDD<double[]> input) {
final Dataset<Row> df = data2Row(input);
final Dataset<Row> transform = model.transform(df);
return transform.toJavaRDD()
.map(row -> row.<Integer>getAs(Attr.PREDICTION));
Dataset<Row> df = convertToDataFrame(input);
Dataset<Row> transformed = model.transform(df);
return transformed.toJavaRDD()
.map(row -> row.<Integer>getAs("prediction"));
}
}
}

0 comments on commit 1d5736f

Please sign in to comment.