所用数据集为两类卫星图像:airplane(飞机)和 lake(湖泊),每类各 700 张,均为 256×256 的 RGB 三通道彩色卫星影像。本文搭建深度卷积神经网络,实现卫星影像的二分类。
数据下载:百度网盘(原文未附链接),提取码:cq47
1、查看 TensorFlow 版本及可用设备
import tensorflow as tf

# Report the installed TensorFlow version and the physical devices it can see.
print('Tensorflow Version:{}'.format(tf.__version__))
print(tf.config.list_physical_devices())
2、加载并显示训练数据
从文件夹中获取所有数据路径
import glob
import random

# Collect every JPEG exactly one directory below the dataset root; that
# sub-directory name is the image's class. glob is more concise than pathlib
# for a simple wildcard pattern like this.
all_image_path = glob.glob('./data/air_lake_dataset/*/*.jpg')
# Shuffle in place so the later head/tail train/test split draws from both
# classes (glob output is otherwise typically grouped by directory).
random.shuffle(all_image_path)
读取并处理图像
def load_and_preprocess_image(path):
    """Read a JPEG file and return it as a normalized float image tensor.

    Args:
        path: Path to a JPEG image (Python str or scalar string tensor).

    Returns:
        A float32 tensor of shape (256, 256, 3) with values in [0, 1].
    """
    img_raw = tf.io.read_file(path)
    # channels=3 forces RGB output even for grayscale JPEGs.
    img_tensor = tf.image.decode_jpeg(img_raw, channels=3)
    # Resize to the network's fixed input size.
    img_tensor = tf.image.resize(img_tensor, [256, 256])
    img_tensor = tf.cast(img_tensor, tf.float32)
    # Scale pixel values from [0, 255] down to [0, 1].
    return img_tensor / 255
处理标签
import os

# Mapping between class (directory) names and integer labels, in both directions.
label_to_index = {'airplane': 0, 'lake': 1}
index_to_label = {index: name for name, index in label_to_index.items()}
# The class of each image is the name of the directory containing it. Using
# os.path instead of split('/')[3] works for any dataset-root depth and for
# Windows paths (glob returns backslash separators there). Indexing with []
# raises KeyError on an unexpected folder instead of silently yielding None.
labels = [label_to_index[os.path.basename(os.path.dirname(img))]
          for img in all_image_path]
3、显示卫星影像
import matplotlib.pyplot as plt


def plot_images_lables(all_image_path, labels, start_idx, num=5):
    """Show `num` consecutive images in one row, titled with their labels.

    Args:
        all_image_path: List of image file paths.
        labels: Integer labels aligned index-by-index with `all_image_path`.
        start_idx: Index of the first image to display.
        num: Number of images to display (default 5).
    """
    fig = plt.gcf()
    fig.set_size_inches(12, 14)
    # Slice by `num` (not a fixed count) so exactly the displayed images are
    # decoded and num > 5 does not run past the loaded list.
    images = [load_and_preprocess_image(img_path)
              for img_path in all_image_path[start_idx:start_idx + num]]
    for i, image in enumerate(images):
        ax = plt.subplot(1, num, 1 + i)
        ax.imshow(image)
        title = 'label=' + index_to_label.get(labels[start_idx + i])
        ax.set_title(title, fontsize=10)
        ax.set_xticks([])
        ax.set_yticks([])
    plt.show()


plot_images_lables(all_image_path, labels, 0, 5)
4、使用tf.data.Dataset制作训练/测试数据
制作 Dataset
# Path dataset -> decoded, resized, normalized image tensors. Decoding runs in
# parallel; tf.data keeps map output in input order by default
# (Options.deterministic is True), so images stay aligned with their labels.
img_ds = tf.data.Dataset.from_tensor_slices(all_image_path)
img_ds = img_ds.map(load_and_preprocess_image,
                    num_parallel_calls=tf.data.AUTOTUNE)
label_ds = tf.data.Dataset.from_tensor_slices(labels)
# Pair images with labels: each element is an (image, label) tuple.
img_label_ds = tf.data.Dataset.zip((img_ds, label_ds))
训练集、测试集划分
# Hold out 20% of the (already shuffled) samples: the first test_count
# elements form the test set and the remainder the training set.
test_count = int(len(labels) * 0.2)
train_count = len(labels) - test_count
train_ds = img_label_ds.skip(test_count)
test_ds = img_label_ds.take(test_count)
分批次加载数据
BATCH_SIZE = 16
# Both datasets repeat indefinitely; model.fit bounds each epoch with
# steps_per_epoch / validation_steps. Training samples are shuffled through a
# 100-element buffer. prefetch lets the input pipeline prepare the next batch
# while the model is still training on the current one.
train_ds = (train_ds.repeat()
            .shuffle(100)
            .batch(BATCH_SIZE)
            .prefetch(tf.data.AUTOTUNE))
test_ds = test_ds.repeat().batch(BATCH_SIZE).prefetch(tf.data.AUTOTUNE)
5、CNN模型构建
from keras.layers import Input, Dense, Dropout
from keras.layers import Conv2D, MaxPool2D, GlobalAvgPool2D

# Four stages of two 3x3 'same'-padded ReLU convolutions with filter counts
# 64 -> 128 -> 256 -> 512. The first three stages end with max pooling and
# dropout. Global average pooling feeds a dense head that ends in a single
# sigmoid unit (binary classification).
model = tf.keras.Sequential([
    Input(shape=(256, 256, 3)),
    # More filters increase the model's capacity to fit the data.
    Conv2D(filters=64, kernel_size=(3, 3), activation='relu', padding='same'),
    Conv2D(filters=64, kernel_size=(3, 3), activation='relu', padding='same'),
    MaxPool2D(),  # default 2x2 pool; widens the receptive field of later layers
    Dropout(0.2),  # regularization against overfitting
    Conv2D(filters=128, kernel_size=(3, 3), activation='relu', padding='same'),
    Conv2D(filters=128, kernel_size=(3, 3), activation='relu', padding='same'),
    MaxPool2D(),
    Dropout(0.2),
    Conv2D(filters=256, kernel_size=(3, 3), activation='relu', padding='same'),
    Conv2D(filters=256, kernel_size=(3, 3), activation='relu', padding='same'),
    MaxPool2D(),
    Dropout(0.2),
    Conv2D(filters=512, kernel_size=(3, 3), activation='relu', padding='same'),
    Conv2D(filters=512, kernel_size=(3, 3), activation='relu', padding='same'),
    GlobalAvgPool2D(),  # global average pooling: one value per channel
    Dense(1024, activation='relu'),
    Dense(256, activation='relu'),
    Dense(1, activation='sigmoid'),
])
model.summary()
6、模型编译与训练
model.compile(
    optimizer=tf.keras.optimizers.Adam(0.0001),
    # The output layer already applies sigmoid, so the loss gets probabilities.
    loss=tf.keras.losses.BinaryCrossentropy(from_logits=False),
    metrics=['acc'],
)
# The datasets repeat forever, so fit needs explicit per-epoch step counts.
# max(1, ...) keeps at least one step when a split is smaller than a batch.
steps_per_epoch = max(1, train_count // BATCH_SIZE)
val_step = max(1, test_count // BATCH_SIZE)
H = model.fit(train_ds,
              epochs=10,
              steps_per_epoch=steps_per_epoch,
              validation_data=test_ds,
              validation_steps=val_step,
              verbose=1)
7、模型评估
import matplotlib.pyplot as plt

# Plot the training vs. validation curves recorded by model.fit: loss on the
# left, accuracy ('acc' metric) on the right.
fig = plt.gcf()
fig.set_size_inches(12, 4)
for position, metric in enumerate(('loss', 'acc'), start=1):
    plt.subplot(1, 2, position)
    plt.plot(H.epoch, H.history[metric], label=metric)
    plt.plot(H.epoch, H.history['val_' + metric], label='val_' + metric)
    plt.legend()
    plt.title(metric)
plt.show()
8、模型预测
import os


def pred_img(img_path):
    """Predict the class name of a single image file.

    Args:
        img_path: Path to a JPEG image.

    Returns:
        The class name from `index_to_label`: index 1 when the model's
        sigmoid output exceeds 0.5, otherwise index 0.
    """
    img = load_and_preprocess_image(img_path)
    # The model consumes batches, so add a leading batch axis of size 1.
    img = tf.expand_dims(img, axis=0)
    prob = model.predict(img)
    # prob has shape (1, 1); threshold the single sigmoid output at 0.5.
    return index_to_label.get(int(prob[0][0] > 0.5))


img_path = './data/air_lake_dataset/airplane/airplane_240.jpg'
pred = pred_img(img_path)
img_tensor = load_and_preprocess_image(img_path)
plt.imshow(img_tensor)
# The true class is the name of the image's parent directory; os.path makes
# this independent of path depth and separator style.
true_label = os.path.basename(os.path.dirname(img_path)).strip()
title = 'label=' + true_label + ', pred=' + pred
plt.title(title)
plt.show()