使用vgg16遷移學習達成訓練目標-2013kaggle

在深度學習的路上,手寫辨識被稱作深度學習的hello world,那2013年貓狗分類的比賽,也是很多人會遇到的題目,那如何獲得高準確率並且不用很多張圖片呢?
這時候遷移學習(Transfer Learning)就上場了,什麼是遷移學習呢?簡單來說就是拿別人訓練好的模型,把輸出加以修改,來達成你要學習的東西,因此如果你的訓練樣本很少,遷移學習或許是個好選擇。
vgg被用來訓練辨識ImageNet的1000種目標,那我們這次要辨識的目標是貓跟狗,貓跟狗也有在ImageNet的1000種目標裡面,所以很適合用遷移學習來完成。
這次的訓練集為:貓500張,狗500張,這麼少的資料就可以達到90%的準確率


程式碼如下:

from keras import models
from keras import layers
from keras.applications import VGG16
from keras.preprocessing.image import ImageDataGenerator
from keras import optimizers
import matplotlib.pyplot as plt
train_dir = './data/train/'
test_dir = './data/validation/'  
conv_base=VGG16(weights='imagenet',include_top=False,input_shape=(150,150,3)) #載入VGG16

model = models.Sequential()  
model.add(conv_base)
model.add(layers.Flatten())
model.add(layers.Dense(output_dim = 256, activation = 'relu'))
model.add(layers.Dense(output_dim = 1, activation = 'sigmoid')) #Dense輸出有幾個output_dim就要改成那個數字,不過輸出只要不是1,activation(激活函數)就必須修改成softmax
conv_base.trainable=True
set_trainable=False
for layer in conv_base.layers:  #解凍block5
    if layer.name=='block5_conv1':
        set_trainable=True
    if set_trainable:
        layer.trainable=True
    else:
        layer.trainable=False
conv_base.summary()
#======================================================================================================

train_datagen = ImageDataGenerator(rescale = 1./255, 
                                   rotation_range=40,#rotation range的作用是用戶指定旋轉角度範圍
                                   shear_range = 0.2, #shear_range就是錯切變換
                                   zoom_range = 0.2,#zoom_range參數可以讓圖片在長或寬的方向進行放大
                                   width_shift_range=0.2,#width_shift_range是水平位置平移
                                   height_shift_range=0.2,#height_shift_range 是上下位置平移
                                   horizontal_flip = True,  #horizo​​ntal_flip的作用是隨機對圖片執行水平翻轉操作
                                   fill_mode='nearest')  #fill_mode為填充模式
train_generator=train_datagen.flow_from_directory(train_dir,
                                                  target_size=(150, 150),
                                                  batch_size=20,
                                                  class_mode = 'binary')
#====================================================================================
test_datagen = ImageDataGenerator(rescale = 1./255)
validation_generator=test_datagen.flow_from_directory(test_dir,
                                                  target_size=(150, 150),
                                                  batch_size=20,
                                                  class_mode = 'binary')
#=====================================================================================
model.compile(optimizer = optimizers.RMSprop(lr=1e-5), loss = 'binary_crossentropy', metrics = ['acc'])
his=model.fit_generator(train_generator, 
                        steps_per_epoch = 50,
                        epochs = 10, #迭代次數
                        validation_data = validation_generator,#驗證集
                        validation_steps=50)  #*2
#==========================================================================================
plt.plot(his.history['acc'])
plt.plot(his.history['val_acc'])
plt.title('model_accuracy')
plt.ylabel('accuracy')
plt.xlabel('epoch')
plt.legend(['train', 'test'], loc='upper left')
plt.show()
plt.plot(his.history['loss'])
plt.plot(his.history['val_loss'])
plt.ylabel('loss')
plt.xlabel('epoch')
plt.legend(['train', 'test'], loc='upper left')
plt.show()
#model.save('catanddog2000_vgg16_epochs.50-ferr3conv.h5')

留言

這個網誌中的熱門文章

使用DLIB函式庫達成即時人臉辨識功能

以dlib實現人臉辨識打卡系統

使用Python達成影像形態學處理(不使用Opencv函式庫)