心血来潮,想从零开始编写一个相对完整的深度学习小项目。想到就做,那么首先要考虑的问题是,写什么?
思量再三,我决定写一个宠物识别系统,即给定一张图片,判断图片上的宠物是什么。宠物种类暂定为四类——猫、狗、鼠、兔。之所以想到做这个,是因为在不使用公开数据集的情况下,宠物图片数据集获取的难度相对低一些。
小项目分为如下几个部分:
- 爬虫。从网络上下载宠物图片,构建训练用的数据集。
- 模型构建、训练和调优。鉴于我们的数据比较少,这部分需要做迁移学习。
- 模型部署和Web服务。将训练好的模型部署成web接口,并使用Vue.js + Element UI编写测试页面。
好嘞,开搞吧!
本文涉及到的所有代码,均已上传到GitHub:
pets_classifer (https://github.com/AaronJny/pets_classifer)
转载请注明来源:https://blog.csdn.net/aaronjny/article/details/103605988
一、爬虫
训练模型肯定是需要数据集的,那么数据集从哪来?因为是从零开始嘛,假设我们做的这个问题,业内没有公开的数据集,我们需要自己制作数据集。
一个很简单的想法是,利用搜索引擎搜索相关图片,使用爬虫批量下载,然后人工去除不正确的图片。举个例子,我们先处理猫的图片,步骤如下:
- 1.使用搜索引擎搜索猫的图片。
- 2.使用爬虫将搜索出的猫的图片批量下载到本地,放到一个名为
cats
的文件夹里面。 - 3.人工浏览一遍图片,将“不包含猫”的图片和“除猫外还包含其他宠物(狗、鼠、兔)”的图片从文件夹中删除。
这样,猫的图片我们就搜集完成了,其他几个类别的图片也是类似的操作。不用担心人工过滤图片花费的时间较长,全部过一遍也就二十多分钟吧。
然后是搜索引擎的选择。搜索引擎用的比较多的无非两种——Google和百度。我分别使用Google和百度进行了图片搜索,发现百度的搜索结果远不如Google准确,于是就选择了Google,所以我的爬虫代码是基于Google编写的,运行我的爬虫代码需要你的网络能够访问Google。
如果你的网络不能访问Google,可以考虑自行实现基于百度的爬虫程序,逻辑都是相通的。
因为想让项目轻量级一些,故没有使用scrapy框架。爬虫使用requests+beautifulsoup4实现,并发使用gevent实现。
# -*- coding: utf-8 -*-
# @File : spider.py
# @Author : AaronJny
# @Time : 2019/12/16
# @Desc : 从谷歌下载指定图片
from gevent import monkeymonkey.patch_all()
import functools
import logging
import os
from bs4 import BeautifulSoup
from gevent.pool import Pool
import requests
import settings# 设置日志输出格式
logging.basicConfig(format='%(asctime)s - %(pathname)s[line:%(lineno)d] - %(levelname)s: %(message)s',level=logging.INFO)# 搜索关键词字典
keywords_map = settings.IMAGE_CLASS_KEYWORD_MAP# 图片保存根目录
images_root = settings.IMAGES_ROOT
# 每个类别下载多少页图片
download_pages = settings.SPIDER_DOWNLOAD_PAGES
# 图片编号字典,每种图片都从0开始编号,然后递增
images_index_map = dict(zip(keywords_map.keys(), [0 for _ in keywords_map]))
# 图片去重器
duplication_filter = set()# 请求头
headers = {'accept-encoding': 'gzip, deflate, br','accept-language': 'zh-CN,zh;q=0.9','user-agent': 'Mozilla/5.0 (Linux; Android 4.0.4; Galaxy Nexus Build/IMM76B) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/46.0.2490.76 Mobile Safari/537.36','accept': '*/*','referer': 'https://www.google.com/','authority': 'www.google.com',
}# 重试装饰器
def try_again_while_except(max_times=3):"""当出现异常时,自动重试。连续失败max_times次后放弃。"""def decorator(func):@functools.wraps(func)def wrapper(*args, **kwargs):error_cnt = 0error_msg = ''while error_cnt < max_times:try:return func(*args, **kwargs)except Exception as e:error_msg = str(e)error_cnt += 1if error_msg:logging.error(error_msg)return wrapperreturn decorator@try_again_while_except()
def download_image(session, image_url, image_class):"""从给定的url中下载图片,并保存到指定路径"""# 下载图片resp = session.get(image_url, timeout=20)# 检查图片是否下载成功if resp.status_code != 200:raise Exception('Response Status Code {}!'.format(resp.status_code))# 分配一个图片编号image_index = images_index_map.get(image_class, 0)# 更新待分配编号images_index_map[image_class] = image_index + 1# 拼接图片路径image_path = os.path.join(images_root, image_class, '{}.jpg'.format(image_index))# 保存图片with open(image_path, 'wb') as f:f.write(resp.content)# 成功写入了一张图片return True@try_again_while_except()
def get_and_analysis_google_search_page(session, page, image_class, keyword):"""使用google进行搜索,下载搜索结果页面,解析其中的图片地址,并对有效图片进一步发起请求"""logging.info('Class:{} Page:{} Processing...'.format(image_class, page + 1))# 记录从本页成功下载的图片数量downloaded_cnt = 0# 构建请求参数params = (('q', keyword),('tbm', 'isch'),('async', '_id:islrg_c,_fmt:html'),('asearch', 'ichunklite'),('start', str(page * 100)),('ijn', str(page)),)# 进行搜索resp = requests.get('https://www.google.com/search', params=params, timeout=20)# 解析搜索结果bsobj = BeautifulSoup(resp.content, 'lxml')divs = bsobj.find_all('div', {'class': 'islrtb isv-r'})for div in divs:image_url = div.get('data-ou')# 只有当图片以'.jpg','.jpeg','.png'结尾时才下载图片if image_url.endswith('.jpg') or image_url.endswith('.jpeg') or image_url.endswith('.png'):# 过滤掉相同图片if image_url not in duplication_filter:# 使用去重器记录duplication_filter.add(image_url)# 下载图片flag = download_image(session, image_url, image_class)if flag:downloaded_cnt += 1logging.info('Class:{} Page:{} Done. {} images downloaded.'.format(image_class, page + 1, downloaded_cnt))def search_with_google(image_class, keyword):"""通过google下载数据集"""# 创建session对象session = requests.session()session.headers.update(headers)# 每个类别下载10页数据for page in range(download_pages):get_and_analysis_google_search_page(session, page, image_class, keyword)def run():# 首先,创建数据文件夹if not os.path.exists(images_root):os.mkdir(images_root)for sub_images_dir in keywords_map.keys():# 对于每个图片类别都创建一个单独的文件夹保存sub_path = os.path.join(images_root, sub_images_dir)if not os.path.exists(sub_path):os.mkdir(sub_path)# 开始下载,这里使用gevent的协程池进行并发pool = Pool(len(keywords_map))for image_class, keyword in keywords_map.items():pool.spawn(search_with_google, image_class, keyword)pool.join()if __name__ == '__main__':run()
项目中涉及到的所有配置参数,都提取到了settings.py
中,内容如下,以供查阅:
# -*- coding: utf-8 -*-
# @File : settings.py
# @Author : AaronJny
# @Time : 2019/12/16
# @Desc :# ##########爬虫############# 图片类别和搜索关键词的映射关系
IMAGE_CLASS_KEYWORD_MAP = {'cats': '宠物猫','dogs': '宠物狗','mouses': '宠物鼠','rabbits': '宠物兔'
}
# 图片保存根目录
IMAGES_ROOT = './images'
# 爬虫每个类别下载多少页图片
SPIDER_DOWNLOAD_PAGES = 20# #########数据############ 每个类别选取的图片数量
SAMPLES_PER_CLASS = 345
# 参与训练的类别
CLASSES = ['cats', 'dogs', 'mouses', 'rabbits']
# 参与训练的类别数量
CLASS_NUM = len(CLASSES)
# 类别->编号的映射
CLASS_CODE_MAP = {'cats': 0,'dogs': 1,'mouses': 2,'rabbits': 3
}
# 编号->类别的映射
CODE_CLASS_MAP = {0: '猫',1: '狗',2: '鼠',3: '兔'
}
# 随机数种子
RANDOM_SEED = 13 # 四个类别时样本较为均衡的随机数种子
# RANDOM_SEED = 19 # 三个类别时样本较为均衡的随机数种子
# 训练集比例
TRAIN_DATASET = 0.6
# 开发集比例
DEV_DATASET = 0.2
# 测试集比例
TEST_DATASET = 0.2
# mini_batch大小
BATCH_SIZE = 16
# imagenet数据集均值
IMAGE_MEAN = [0.485, 0.456, 0.406]
# imagenet数据集标准差
IMAGE_STD = [0.299, 0.224, 0.225]# #########训练########## 学习率
LEARNING_RATE = 0.001
# 训练epoch数
TRAIN_EPOCHS = 30
# 保存训练模型的路径
MODEL_PATH = './model.h5'# ########Web########## Web服务端口
WEB_PORT = 5000
爬虫使用Google进行图片搜索,每个宠物搜索10页,下载其中的所有图片。当爬虫运行完成后,项目下会多出一个images
文件夹,点进去有四个子文件夹,分别为cats
、dogs
、mouses
、rabbits
。每一个子文件夹里面是对应类别的宠物图片。
其中猫图片600+张,狗图片600+张,鼠图片400+张,兔图片500+张。花二十多分钟时间,过一遍全部图片,剔除其中不符合要求的图片。注意,这一步是必做的,而且要认真对待,我吃了亏的= =
进行一轮筛选后,剩下图片张数:
宠物 | 图片数量 |
---|---|
猫 | 521 |
狗 | 526 |
鼠 | 346 |
兔 | 345 |
考虑各类别样本均衡的问题,无非是过采样和欠采样。因为是图片数据,也可以使用数据增强的手段,为图片数量较少的类别生成一些图片,使样本数量均衡。但出于如下原因考虑,我直接做了欠采样,即每个类别只选取了345张样本:
- 使用数据增强的话,需要在原图片的基础上,重新生成一份数据集,嫌麻烦……
- 使用数据增强后,样本数量比较多,无法同时读取到内存里面,只能写个生成器,处理哪一部分的时候,实时从硬盘读取。弊端有俩:①频繁读取硬盘,肯定比不上所有数据都放在内存里面,会拖慢训练速度;②还是嫌麻烦……
说到底就是自己太懒了……当然,可想而知,使用数据增强(在这里,数据增强可以作为一种过采样的方式)使数据样本都达到526,训练的效果肯定会更好,能好多少就不知道了,有兴趣的可以自行实现,没啥难点,就是麻烦点。
下面该对数据做预处理了。很多经典的模型接收的输入格式都为(None,224,224,3),由于我们的样本较少,不可避免地需要用到迁移学习,所以我们的数据格式与经典模型保持一致,也使用(None,224,224,3),下面是预处理过程:
# -*- coding: utf-8 -*-
# @File : data.py
# @Author : AaronJny
# @Time : 2019/12/16
# @Desc :
import os
import random
import tensorflow as tf
import settings# 每个类别选取的图片数量
samples_per_class = settings.SAMPLES_PER_CLASS
# 图片根目录
images_root = settings.IMAGES_ROOT
# 类别->编码的映射
class_code_map = settings.CLASS_CODE_MAP# 我们准备使用经典网络在imagenet数据集上的与训练权重,所以归一化时也要使用imagenet的平均值和标准差
image_mean = tf.constant(settings.IMAGE_MEAN)
image_std = tf.constant(settings.IMAGE_STD)def normalization(x):"""对输入图片x进行归一化,返回归一化的值"""return (x - image_mean) / image_stddef train_preprocess(x, y):"""对训练数据进行预处理。注意,这里的参数x是图片的路径,不是图片本身;y是图片的标签值"""# 读取图片x = tf.io.read_file(x)# 解码成张量x = tf.image.decode_jpeg(x, channels=3)# 将图片缩放到[244,244],比输入[224,224]稍大一些,方便后面数据增强x = tf.image.resize(x, [244, 244])# 随机决定是否左右镜像if random.choice([0, 1]):x = tf.image.random_flip_left_right(x)# 随机从x中剪裁出(224,224,3)大小的图片x = tf.image.random_crop(x, [224, 224, 3])# 读完上面的代码可以发现,这里的数据增强并不增加图片数量,一张图片经过变换后,# 仍然只是一张图片,跟我们前面说的增加图片数量的逻辑不太一样。# 这么做主要是应对我们的数据集里可能会存在相同图片的情况。# 将图片的像素值缩放到[0,1]之间x = tf.cast(x, dtype=tf.float32) / 255.# 归一化x = normalization(x)# 将标签转成one-hot形式y = tf.cast(y, dtype=tf.int32)y = tf.one_hot(y, settings.CLASS_NUM)return x, ydef dev_preprocess(x, y):"""对验证集和测试集进行数据预处理的方法。和train_preprocess的主要区别在于,不进行数据增强,以保证验证结果的稳定性。"""# 读取并缩放图片x = tf.io.read_file(x)x = tf.image.decode_jpeg(x, channels=3)x = tf.image.resize(x, [224, 224])# 归一化x = tf.cast(x, dtype=tf.float32) / 255.x = normalization(x)# 将标签转成one-hot形式y = tf.cast(y, dtype=tf.int32)y = tf.one_hot(y, settings.CLASS_NUM)return x, y# (图片路径,标签)的列表
image_path_and_labels = []
# 排序,保证每次拿到的顺序都一样
sub_images_dir_list = sorted(list(os.listdir(images_root)))
# 遍历每一个子目录
for sub_images_dir in sub_images_dir_list:sub_path = os.path.join(images_root, sub_images_dir)# 如果给定路径是文件夹,并且这个类别参与训练if os.path.isdir(sub_path) and sub_images_dir in settings.CLASSES:# 获取当前类别的编码current_label = class_code_map.get(sub_images_dir)# 获取子目录下的全部图片名称images = sorted(list(os.listdir(sub_path)))# 随机打乱(排序和置随机数种子都是为了保证每次的结果都一样)random.seed(settings.RANDOM_SEED)random.shuffle(images)# 保留前settings.SAMPLES_PER_CLASS个images = images[:samples_per_class]# 构建(x,y)对for image_name in images:abs_image_path = os.path.join(sub_path, image_name)image_path_and_labels.append((abs_image_path, current_label))
# 计算各数据集样例数
total_samples = len(image_path_and_labels) # 总样例数
train_samples = int(total_samples * settings.TRAIN_DATASET) # 训练集样例数
dev_samples = int(total_samples * settings.DEV_DATASET) # 开发集样例数
test_samples = total_samples - train_samples - dev_samples # 测试集样例数
# 打乱数据集
random.seed(settings.RANDOM_SEED)
random.shuffle(image_path_and_labels)
# 将图片数据和标签数据分开,此时它们仍是一一对应的
x_data = tf.constant([img for img, label in image_path_and_labels])
y_data = tf.constant([label for img, label in image_path_and_labels])
# 开始划分数据集
# 训练集
train_db = tf.data.Dataset.from_tensor_slices((x_data[:train_samples], y_data[:train_samples]))
# 打乱顺序,数据预处理,设置批大小
train_db = train_db.shuffle(10000).map(train_preprocess).batch(settings.BATCH_SIZE)
# 开发集(验证集)
dev_db = tf.data.Dataset.from_tensor_slices((x_data[train_samples:train_samples + dev_samples], y_data[train_samples:train_samples + dev_samples]))
# 数据预处理,设置批大小
dev_db = dev_db.map(dev_preprocess).batch(settings.BATCH_SIZE)
# 测试集
test_db = tf.data.Dataset.from_tensor_slices((x_data[train_samples + dev_samples:], y_data[train_samples + dev_samples:]))
# 数据预处理,设置批大小
test_db = test_db.map(dev_preprocess).batch(settings.BATCH_SIZE)
二、模型构建、训练和调优
数据已经全部处理完毕,该考虑模型了。首先,我们数据集太小了,直接构建自己的网络并训练,并不是一个好方案。因为这几种宠物其实挺难区分的,所以模型需要有一定复杂度,才能很好拟合这些数据,但我们的数据又太少了,最后的结果一定是过拟合,而且还是救不回来的那种= =所以我们考虑从迁移学习入手。
什么是迁移学习?懒得重新组织语言的我,默默地从之前写的博文里面摘了一段:
一般认为,深度卷积神经网络的训练是对数据集特征的一步步抽取的过程,从简单的特征,到复杂的特征。
训练好的模型学习到的是对图像特征的抽取方法,所以在imagenet数据集上训练好的模型理论上来说,也可以直接用于抽取其他图像的特征,这也是迁移学习的基础。自然,这样的效果往往没有在新数据上重新训练的效果好,但能够节省大量的训练时间,在特定情况下非常有用。
上面说的特定情况
也包括我们面临的这一种——用于实际问题的数据集过小。
说到迁移学习,我最先想到的是VGG16,就先用VGG16搞了一波。使用在imagenet数据集上预训练的VGG16网络,去除顶部的全连接层,冻结全部参数,使它们在接下来的训练中不会改变。然后加上自己的全连接层,最后的输出层节点为4,对应于我们的四分类问题。开始训练。
模型在训练集上的误差很快降到5%以下,但是在验证集上的准确率基本在70+%,很明显,过拟合了。好嘛,盘它!主要使用如下方法尝试解决过拟合问题:
- 调节全连接层的层数和每层的节点数
- 添加BN层(虽说不是为了解决过拟合问题诞生的,但一定程度上是有效果的)
- 添加Dropout层
- 调节Dropout Rate
- 添加l2正则
一顿操作猛如虎,回头一看0-5。这些方法确实对过拟合有所缓解,验证集上的准确率也确实有所提升,但只能达到81%左右。
然后我尝试了Resnet50,当然也过拟合了,盘它!最后验证集accuracy能达到83%左右。
很明显了,在全连接层的调整意义不大,究其根本,在于VGG16和ResNet50去除了全连接层之后,参数的数量也达到了20M+。两千万的参数使得模型严重过拟合,所以我们需要换一个参数少一点的模型。
于是,我盯上了DenseNet121,它的参数数量只有7M。继续盘它!果然,在一段时间的调优后,模型的性能有了明显的提升,验证集上的accuracy达到了87%左右。虽然和ResNet相比,准确率只高了4%,但相比于ResNet50 96%的训练accuracy而言,DenseNet121的训练accuracy只有90%左右。也就是说,对于DenseNet121而言,这个问题已经不再是过拟合问题了(相差3%我是可以接受的),而是欠拟合了。
然而淡腾的是,再怎么调参,模型都很难继续拟合了,调小学习率也不行。模型本身没啥问题的话,我开始怀疑数据集有没有问题,毕竟这种无法拟合的问题有很大概率是数据导致的。于是我就去检查了一下数据集……
这就是我前面强调认真过一遍数据集的原因了,我当时只是花个几分钟粗略地过了一下,删除掉一些明显不对的图片。我第二次认真过数据集的时候才发现,有很多异常图片没有过滤掉,比如猫的目录下有狗的图片,狗的目录下有猫的图片,还有一些不同动物同框的图片,以及我自己都认不出来的图片……
文章第一部分中各类图片数量的表格,其实就是我第二遍过滤后的结果统计。
过滤完成后,模型的性能有了明显的提升,训练accuracy约为93%-94%,验证accuracy为94%,测试accuracy为92%.我们先来看一下代码,后面会对这个结果再进行分析。
首先,是模型的构建:
# -*- coding: utf-8 -*-
# @File : models.py
# @Author : AaronJny
# @Time : 2019/12/16
# @Desc :
import tensorflow as tf
import settingsdef my_densenet():"""创建并返回一个基于densenet的Model对象"""# 获取densenet网络,使用在imagenet上训练的参数值,移除头部的全连接网络,池化层使用max_poolingdensenet = tf.keras.applications.DenseNet121(include_top=False, weights='imagenet', pooling='max')# 冻结预训练的参数,在之后的模型训练中不会改变它们densenet.trainable = False# 构建模型model = tf.keras.Sequential([# 输入层,shape为(None,224,224,3)tf.keras.layers.Input((224, 224, 3)),# 输入到DenseNet121中densenet,# 将DenseNet121的输出展平,以作为全连接层的输入tf.keras.layers.Flatten(),# 添加BN层tf.keras.layers.BatchNormalization(),# 随机失活tf.keras.layers.Dropout(0.5),# 第一个全连接层,激活函数relutf.keras.layers.Dense(512, activation=tf.nn.relu),# BN层tf.keras.layers.BatchNormalization(),# 随机失活tf.keras.layers.Dropout(0.5),# 第二个全连接层,激活函数relutf.keras.layers.Dense(64, activation=tf.nn.relu),# BN层tf.keras.layers.BatchNormalization(),# 输出层,为了保证输出结果的稳定,这里就不添加Dropout层了tf.keras.layers.Dense(settings.CLASS_NUM, activation=tf.nn.softmax)])return modelif __name__ == '__main__':model = my_densenet()model.summary()
网络的summary:
Model: "sequential"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
densenet121 (Model) (None, 1024) 7037504
_________________________________________________________________
flatten (Flatten) (None, 1024) 0
_________________________________________________________________
batch_normalization (BatchNo (None, 1024) 4096
_________________________________________________________________
dropout (Dropout) (None, 1024) 0
_________________________________________________________________
dense (Dense) (None, 512) 524800
_________________________________________________________________
batch_normalization_1 (Batch (None, 512) 2048
_________________________________________________________________
dropout_1 (Dropout) (None, 512) 0
_________________________________________________________________
dense_1 (Dense) (None, 64) 32832
_________________________________________________________________
batch_normalization_2 (Batch (None, 64) 256
_________________________________________________________________
dense_2 (Dense) (None, 4) 260
=================================================================
Total params: 7,601,796
Trainable params: 561,092
Non-trainable params: 7,040,704
_________________________________________________________________
参数总量7601796个,其中可训练参数561092个 。
模型和数据都已准备完毕,可以开始训练了。让我们编写一个训练用的脚本:
# -*- coding: utf-8 -*-
# @File : train.py
# @Author : AaronJny
# @Time : 2019/12/17
# @Desc :
import tensorflow as tf
from data import train_db, dev_db
import models
import settings# 从models文件中导入模型
model = models.my_densenet()
model.summary()# 配置优化器、损失函数、以及监控指标
model.compile(tf.keras.optimizers.Adam(settings.LEARNING_RATE), loss=tf.keras.losses.categorical_crossentropy,metrics=['accuracy'])# 在每个epoch结束后尝试保存模型参数,只有当前参数的val_accuracy比之前保存的更优时,才会覆盖掉之前保存的参数
model_check_point = tf.keras.callbacks.ModelCheckpoint(filepath=settings.MODEL_PATH, monitor='val_accuracy',save_best_only=True)
# 使用tf.keras的高级接口进行训练
model.fit_generator(train_db, epochs=settings.TRAIN_EPOCHS, validation_data=dev_db, callbacks=[model_check_point])
现在,我们可以运行脚本进行训练了,最优的参数将被保存在settings.MODEL_PATH
。训练完成后,我们需要调用验证脚本,验证下模型在验证集和测试集上的表现:
# -*- coding: utf-8 -*-
# @File : eval.py
# @Author : AaronJny
# @Time : 2019/12/17
# @Desc :
import tensorflow as tf
from data import dev_db, test_db
from models import my_densenet
import settings# 创建模型
model = my_densenet()
# 加载参数
model.load_weights(settings.MODEL_PATH)
# 因为想用tf.keras的高级接口做验证,所以还是需要编译模型
model.compile(tf.keras.optimizers.Adam(settings.LEARNING_RATE), loss=tf.keras.losses.categorical_crossentropy,metrics=['accuracy'])
# 验证集accuracy
print('dev', model.evaluate(dev_db))
# 测试集accuracy
print('test', model.evaluate(test_db))
输出如下:
18/18 [==============================] - 5s 304ms/step - loss: 0.1936 - accuracy: 0.9457
dev [0.19364455559601387, 0.9456522]
18/18 [==============================] - 1s 64ms/step - loss: 0.2666 - accuracy: 0.9203
test [0.26657224384446937, 0.9202899]
能够看到,模型在验证集上的准确率为94.57%,在测试集上的准确率为92.03%,已经达到我的心里预期了,毕竟这么少的数据,还要啥自行车?
随着训练epoch的增多,模型的训练accuracy始终在[0.92,0.95]左右徘徊不定,没法继续拟合。究其原因,应该还是数据的锅。我们看一下识别错的样本,在eval.py脚本中,增加下面这一段程序:
# 查看识别错误的数据
for x, y in test_db:y_pred = model(x)y_pred = tf.argmax(y_pred, axis=1).numpy()y_true = tf.argmax(y, axis=1).numpy()batch_size = y_pred.shape[0]for i in range(batch_size):if y_pred[i] != y_true[i]:print('{} 被错误识别成 {}!'.format(settings.CODE_CLASS_MAP[y_true[i]], settings.CODE_CLASS_MAP[y_pred[i]]))
重新跑一下eval.py脚本,输出如下:
18/18 [==============================] - 5s 291ms/step - loss: 0.1936 - accuracy: 0.9457
dev [0.19364455559601387, 0.9456522]
18/18 [==============================] - 1s 64ms/step - loss: 0.2666 - accuracy: 0.9203
test [0.26657224384446937, 0.9202899]
狗 被错误识别成 兔!
狗 被错误识别成 兔!
狗 被错误识别成 兔!
鼠 被错误识别成 兔!
狗 被错误识别成 猫!
鼠 被错误识别成 猫!
狗 被错误识别成 兔!
狗 被错误识别成 鼠!
鼠 被错误识别成 兔!
狗 被错误识别成 兔!
猫 被错误识别成 兔!
猫 被错误识别成 鼠!
猫 被错误识别成 兔!
鼠 被错误识别成 兔!
狗 被错误识别成 兔!
狗 被错误识别成 猫!
鼠 被错误识别成 兔!
狗 被错误识别成 兔!
鼠 被错误识别成 兔!
狗 被错误识别成 猫!
鼠 被错误识别成 兔!
狗 被错误识别成 兔!
来,跟我一起唱——都是兔子惹的祸~
能够看到,出错的大部分都是被误识别成兔子了。对应到数据集上,虽然已经删掉了部分问题比较大的图片,但兔子的图片确实不好认。有很多兔子图片我人工分辨都认不出是兔子(捂脸.jpg)。然后,有些兔子图片看起来很像猫,有些看起来很像狗,有些看起来很像鼠……
如果我们把兔子图片去掉,将系统改为三分类问题,准确度将大幅度提高。当然了,按理说识别的类别数量变了,除了调整输出层的节点数量外,要想取得最佳效果,模型的其他参数也需要做相应调整的。我自己已经实测了,但限于篇幅,就不演示了,如果有兴趣的话,可以直接在settings.py
里进行调整,将它变为三分类问题。改这两个地方:
# 参与训练的类别
CLASSES = ['cats', 'dogs', 'mouses', 'rabbits']
# 随机数种子
RANDOM_SEED = 13 # 四个类别时样本较为均衡的随机数种子
改成:
# 参与训练的类别
CLASSES = ['cats', 'dogs', 'mouses']
# 随机数种子
RANDOM_SEED = 19 # 三个类别时样本较为均衡的随机数种子
然后重新训练和验证即可。这只是一个插曲,本文仍然以四分类问题继续说明后续内容。
三、Web接口编写
模型训练好了,我们要把它应用起来。我准备编写一个Web服务,用户可以通过浏览器上传一张图片,服务器判断此图片的类别后,返回相关数据给用户。Web后端使用Flask,小而轻,前端则选用Vue.js + Element-UI实现。
先写后端:
# -*- coding: utf-8 -*-
# @File : app.py
# @Author : AaronJny
# @Time : 2019/12/18
# @Desc :
from flask import Flask
from flask import jsonify
from flask import request, render_template
import tensorflow as tf
from models import my_densenet
import settingsapp = Flask(__name__)# 导入模型
model = my_densenet()
# 加载训练好的参数
model.load_weights(settings.MODEL_PATH)@app.route('/', methods=['GET'])
def index():"""首页,vue入口"""return render_template('index.html')@app.route('/api/v1/pets_classify/', methods=['POST'])
def pets_classify():"""宠物图片分类接口,上传一张图片,返回此图片上的宠物是那种类别,概率多少"""# 获取用户上传的图片img_str = request.files.get('file').read()# 进行数据预处理x = tf.image.decode_image(img_str, channels=3)x = tf.image.resize(x, (224, 224))x = x / 255.x = (x - tf.constant(settings.IMAGE_MEAN)) / tf.constant(settings.IMAGE_STD)x = tf.reshape(x, (1, 224, 224, 3))# 预测y_pred = model(x)pet_cls_code = tf.argmax(y_pred, axis=1).numpy()[0]pet_cls_prob = float(y_pred.numpy()[0][pet_cls_code])pet_cls_prob = '{}%'.format(int(pet_cls_prob * 100))pet_class = settings.CODE_CLASS_MAP.get(pet_cls_code)# 将预测结果组织成jsonres = {'code': 0,'data': {'pet_cls': pet_class,'probability': pet_cls_prob,'msg': '<br><br><strong style="font-size: 48px;">{}</strong> <span style="font-size: 24px;"''>概率<strong>{}</strong></span>'.format(pet_class, pet_cls_prob),}}# 返回json数据return jsonify(res)if __name__ == '__main__':app.run(port=settings.WEB_PORT)
后端脚本app.py
很简单,主要就两个方法。其中index
方法会返回首页的html源码,是用户在浏览器端的访问入口;另一个方法pets_classify
则提供了计算给定图片类别的功能。
前端文件index.html
主要是提供了一个照片墙,用户上传图片到照片墙,服务器就会计算图片类别并返回相关数据。代码如下:
<!DOCTYPE html>
<html>
<head><meta charset="UTF-8"><!-- import CSS --><link rel="stylesheet" href="https://unpkg.com/element-ui/lib/theme-chalk/index.css">
</head>
<body>
<div id="app"><el-card class="box-card"><div slot="header" class="clearfix"><h1>宠物识别Demo</h1></div><el-uploadaction="http://localhost:5000/api/v1/pets_classify/"list-type="picture-card":on-preview="handlePictureCardPreview":on-success="handleUploadSuccess":on-remove="handleRemove"><i class="el-icon-plus"></i></el-upload><el-dialog :visible.sync="dialogVisible"><img width="100%" :src="dialogImageUrl" alt=""></el-dialog></el-card>
</div>
</body><!-- import Vue before Element -->
<script src="https://unpkg.com/vue/dist/vue.js"></script>
<!-- import JavaScript -->
<script src="https://unpkg.com/element-ui/lib/index.js"></script>
<script>new Vue({el: '#app',data() {return {dialogImageUrl: '',dialogVisible: false};},methods: {handleRemove(file, fileList) {console.log(file, fileList);console.log(this.dialogImageUrl);},handlePictureCardPreview(file) {this.dialogImageUrl = file.url;this.dialogVisible = true;},handleUploadSuccess(response, file, fileList) {this.$notify({title: '识别结果',message: response.data.msg,dangerouslyUseHTMLString: true,type: 'success',duration: 3000});}}})
</script>
</html>
让我们试试效果。首先,运行app.py
脚本,启动web服务,当你看到如下输出时,说明服务启动成功了:
* Running on http://127.0.0.1:5000/ (Press CTRL+C to quit)* Serving Flask app "app" (lazy loading)* Environment: productionWARNING: This is a development server. Do not use it in a production deployment.Use a production WSGI server instead.* Debug mode: off
因为只是开发环境,这么启动就可以了。如果是生产环境,请不要这么做,可以选择使用nginx + gunicorn +uWSGI + gevent进行部署。
四、测试
打开浏览器,输入 http://localhost:5000 进入index页面。页面长这个样子:
点击网页中的上传框,我们可以选择图片上传并识别:
当然了,这里不选择我们数据集里的图片更好,哪怕是测试集里的。你可以去网上下载、或者通过其他渠道获取这四种动物的图片来测试,这里我只做演示,就不搞那么麻烦了,直接从数据集里随便选几张照片。我们可以继续上传图片给服务器识别:
OK,演示到此为止,如果有兴趣的话可以自行测试。
结语
文章到此结束,如果您喜欢的话,给我点个赞呗~
菜鸟一只,欢迎大佬们拍砖~