Skip to content

Commit

Permalink
updating Tensorflow versioning
Browse files Browse the repository at this point in the history
  • Loading branch information
zkaoudi committed Feb 21, 2025
1 parent 6980739 commit b2ec280
Showing 1 changed file with 6 additions and 11 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,7 @@
import org.apache.wayang.basic.model.op.Input;
import org.apache.wayang.basic.model.op.Op;
import org.apache.wayang.basic.model.optimizer.Optimizer;
import org.tensorflow.Graph;
import org.tensorflow.Operand;
import org.tensorflow.Session;
import org.tensorflow.Tensor;
import org.tensorflow.*;
import org.tensorflow.ndarray.*;
import org.tensorflow.ndarray.index.Indices;
import org.tensorflow.op.Ops;
Expand Down Expand Up @@ -122,14 +119,12 @@ void train(XT x, YT y, int epoch, int batchSize) {
if (accuracyCalculation != null) {
runner.fetch(accuracyCalculation.getName());
}
List<Tensor> ret = runner.run();
try (TFloat32 loss = (TFloat32) ret.get(0)) {
try (Result ret = runner.run()) {
TFloat32 loss = (TFloat32) ret.get(0);
System.out.printf("[epoch %d, batch %d] loss: %f ", i + 1, start / batchSize + 1, loss.getFloat());
}
if (accuracyCalculation != null) {
try (TFloat32 acc = (TFloat32) ret.get(1)) {
System.out.printf("accuracy: %f ", acc.getFloat());
}

TFloat32 acc = (TFloat32) ret.get(1);
System.out.printf("accuracy: %f ", acc.getFloat());
}
System.out.println();
}
Expand Down

0 comments on commit b2ec280

Please sign in to comment.