diff --git a/wayang-platforms/wayang-tensorflow/src/main/java/org/apache/wayang/tensorflow/model/TensorflowModel.java b/wayang-platforms/wayang-tensorflow/src/main/java/org/apache/wayang/tensorflow/model/TensorflowModel.java index 80113ffbf..e881f95de 100644 --- a/wayang-platforms/wayang-tensorflow/src/main/java/org/apache/wayang/tensorflow/model/TensorflowModel.java +++ b/wayang-platforms/wayang-tensorflow/src/main/java/org/apache/wayang/tensorflow/model/TensorflowModel.java @@ -22,10 +22,7 @@ import org.apache.wayang.basic.model.op.Input; import org.apache.wayang.basic.model.op.Op; import org.apache.wayang.basic.model.optimizer.Optimizer; -import org.tensorflow.Graph; -import org.tensorflow.Operand; -import org.tensorflow.Session; -import org.tensorflow.Tensor; +import org.tensorflow.*; import org.tensorflow.ndarray.*; import org.tensorflow.ndarray.index.Indices; import org.tensorflow.op.Ops; @@ -122,14 +119,12 @@ void train(XT x, YT y, int epoch, int batchSize) { if (accuracyCalculation != null) { runner.fetch(accuracyCalculation.getName()); } - List ret = runner.run(); - try (TFloat32 loss = (TFloat32) ret.get(0)) { + try (Result ret = runner.run()) { + TFloat32 loss = (TFloat32) ret.get(0); System.out.printf("[epoch %d, batch %d] loss: %f ", i + 1, start / batchSize + 1, loss.getFloat()); - } - if (accuracyCalculation != null) { - try (TFloat32 acc = (TFloat32) ret.get(1)) { - System.out.printf("accuracy: %f ", acc.getFloat()); - } + + TFloat32 acc = (TFloat32) ret.get(1); + System.out.printf("accuracy: %f ", acc.getFloat()); } System.out.println(); }