Skip to content

Commit c94ac17

Browse files
committed
edited transfer learning tutorial
1 parent 62d02cc commit c94ac17

File tree

4 files changed

+10
-11
lines changed

4 files changed

+10
-11
lines changed

machine-learning/image-classifier-using-transfer-learning/test.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
# constructs the model
66
model = create_model(input_shape=IMAGE_SHAPE)
77
# load the optimal weights
8-
model.load_weights("results/MobileNetV2_finetune_last5_less_lr-loss-0.45-acc-0.86.h5")
8+
model.load_weights("results/MobileNetV2_finetune_last5-loss-0.66.h5")
99

1010
validation_steps_per_epoch = np.ceil(validation_generator.samples / batch_size)
1111
# print the validation loss & accuracy

machine-learning/image-classifier-using-transfer-learning/train.py

+9-10
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,10 @@
1-
21
import tensorflow as tf
3-
from keras.models import Model
4-
from keras.applications import MobileNetV2, ResNet50, InceptionV3 # try to use them and see which is better
5-
from keras.layers import Dense
6-
from keras.callbacks import ModelCheckpoint, TensorBoard
7-
from keras.utils import get_file
8-
from keras.preprocessing.image import ImageDataGenerator
2+
from tensorflow.keras.models import Model
3+
from tensorflow.keras.applications import MobileNetV2, ResNet50, InceptionV3 # try to use them and see which is better
4+
from tensorflow.keras.layers import Dense
5+
from tensorflow.keras.callbacks import ModelCheckpoint, TensorBoard
6+
from tensorflow.keras.utils import get_file
7+
from tensorflow.keras.preprocessing.image import ImageDataGenerator
98
import os
109
import pathlib
1110
import numpy as np
@@ -65,7 +64,7 @@ def create_model(input_shape):
6564
# print the summary of the model architecture
6665
model.summary()
6766

68-
# training the model using rmsprop optimizer
67+
# training the model using adam optimizer
6968
model.compile(loss="categorical_crossentropy", optimizer="adam", metrics=["accuracy"])
7069
return model
7170

@@ -81,8 +80,8 @@ def create_model(input_shape):
8180
model_name = "MobileNetV2_finetune_last5"
8281

8382
# some nice callbacks
84-
tensorboard = TensorBoard(log_dir=f"logs/{model_name}")
85-
checkpoint = ModelCheckpoint(f"results/{model_name}" + "-loss-{val_loss:.2f}-acc-{val_acc:.2f}.h5",
83+
tensorboard = TensorBoard(log_dir=os.path.join("logs", model_name))
84+
checkpoint = ModelCheckpoint(os.path.join("results", f"{model_name}" + "-loss-{val_loss:.2f}.h5"),
8685
save_best_only=True,
8786
verbose=1)
8887

0 commit comments

Comments
 (0)