从零开始编写一个宠物识别系统(爬虫、模型训练和调优、模型部署、Web服务)

心血来潮,想从零开始编写一个相对完整的深度学习小项目。想到就做,那么首先要考虑的问题是,写什么?

思量再三,我决定写一个宠物识别系统,即给定一张图片,判断图片上的宠物是什么。宠物种类暂定为四类——猫、狗、鼠、兔。之所以想到做这个,是因为在不使用公开数据集的情况下,宠物图片数据集获取的难度相对低一些。

小项目分为如下几个部分:

  • 爬虫。从网络上下载宠物图片,构建训练用的数据集。
  • 模型构建、训练和调优。鉴于我们的数据比较少,这部分需要做迁移学习。
  • 模型部署和Web服务。将训练好的模型部署成web接口,并使用Vue.js + Element UI编写测试页面。

好嘞,开搞吧!

本文涉及到的所有代码,均已上传到GitHub:

pets_classifer (https://github.com/AaronJny/pets_classifer)

转载请注明来源:https://blog.csdn.net/aaronjny/article/details/103605988

一、爬虫

训练模型肯定是需要数据集的,那么数据集从哪来?因为是从零开始嘛,假设我们做的这个问题,业内没有公开的数据集,我们需要自己制作数据集。

一个很简单的想法是,利用搜索引擎搜索相关图片,使用爬虫批量下载,然后人工去除不正确的图片。举个例子,我们先处理猫的图片,步骤如下:

  • 1.使用搜索引擎搜索猫的图片。
  • 2.使用爬虫将搜索出的猫的图片批量下载到本地,放到一个名为cats的文件夹里面。
  • 3.人工浏览一遍图片,将“不包含猫”的图片和“除猫外还包含其他宠物(狗、鼠、兔)”的图片从文件夹中删除。

这样,猫的图片我们就搜集完成了,其他几个类别的图片也是类似的操作。不用担心人工过滤图片花费的时间较长,全部过一遍也就二十多分钟吧。

然后是搜索引擎的选择。搜索引擎用的比较多的无非两种——Google和百度。我分别使用Google和百度进行了图片搜索,发现百度的搜索结果远不如Google准确,于是就选择了Google,所以我的爬虫代码是基于Google编写的,运行我的爬虫代码需要你的网络能够访问Google。

如果你的网络不能访问Google,可以考虑自行实现基于百度的爬虫程序,逻辑都是相通的。

因为想让项目轻量级一些,故没有使用scrapy框架。爬虫使用requests+beautifulsoup4实现,并发使用gevent实现。

# -*- coding: utf-8 -*-
# @File    : spider.py
# @Author  : AaronJny
# @Time    : 2019/12/16
# @Desc    : 从谷歌下载指定图片
from gevent import monkeymonkey.patch_all()
import functools
import logging
import os
from bs4 import BeautifulSoup
from gevent.pool import Pool
import requests
import settings# 设置日志输出格式
logging.basicConfig(format='%(asctime)s - %(pathname)s[line:%(lineno)d] - %(levelname)s: %(message)s',level=logging.INFO)# 搜索关键词字典
keywords_map = settings.IMAGE_CLASS_KEYWORD_MAP# 图片保存根目录
images_root = settings.IMAGES_ROOT
# 每个类别下载多少页图片
download_pages = settings.SPIDER_DOWNLOAD_PAGES
# 图片编号字典,每种图片都从0开始编号,然后递增
images_index_map = dict(zip(keywords_map.keys(), [0 for _ in keywords_map]))
# 图片去重器
duplication_filter = set()# 请求头
headers = {'accept-encoding': 'gzip, deflate, br','accept-language': 'zh-CN,zh;q=0.9','user-agent': 'Mozilla/5.0 (Linux; Android 4.0.4; Galaxy Nexus Build/IMM76B) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/46.0.2490.76 Mobile Safari/537.36','accept': '*/*','referer': 'https://www.google.com/','authority': 'www.google.com',
}# 重试装饰器
def try_again_while_except(max_times=3):"""当出现异常时,自动重试。连续失败max_times次后放弃。"""def decorator(func):@functools.wraps(func)def wrapper(*args, **kwargs):error_cnt = 0error_msg = ''while error_cnt < max_times:try:return func(*args, **kwargs)except Exception as e:error_msg = str(e)error_cnt += 1if error_msg:logging.error(error_msg)return wrapperreturn decorator@try_again_while_except()
def download_image(session, image_url, image_class):"""从给定的url中下载图片,并保存到指定路径"""# 下载图片resp = session.get(image_url, timeout=20)# 检查图片是否下载成功if resp.status_code != 200:raise Exception('Response Status Code {}!'.format(resp.status_code))# 分配一个图片编号image_index = images_index_map.get(image_class, 0)# 更新待分配编号images_index_map[image_class] = image_index + 1# 拼接图片路径image_path = os.path.join(images_root, image_class, '{}.jpg'.format(image_index))# 保存图片with open(image_path, 'wb') as f:f.write(resp.content)# 成功写入了一张图片return True@try_again_while_except()
def get_and_analysis_google_search_page(session, page, image_class, keyword):"""使用google进行搜索,下载搜索结果页面,解析其中的图片地址,并对有效图片进一步发起请求"""logging.info('Class:{} Page:{} Processing...'.format(image_class, page + 1))# 记录从本页成功下载的图片数量downloaded_cnt = 0# 构建请求参数params = (('q', keyword),('tbm', 'isch'),('async', '_id:islrg_c,_fmt:html'),('asearch', 'ichunklite'),('start', str(page * 100)),('ijn', str(page)),)# 进行搜索resp = requests.get('https://www.google.com/search', params=params, timeout=20)# 解析搜索结果bsobj = BeautifulSoup(resp.content, 'lxml')divs = bsobj.find_all('div', {'class': 'islrtb isv-r'})for div in divs:image_url = div.get('data-ou')# 只有当图片以'.jpg','.jpeg','.png'结尾时才下载图片if image_url.endswith('.jpg') or image_url.endswith('.jpeg') or image_url.endswith('.png'):# 过滤掉相同图片if image_url not in duplication_filter:# 使用去重器记录duplication_filter.add(image_url)# 下载图片flag = download_image(session, image_url, image_class)if flag:downloaded_cnt += 1logging.info('Class:{} Page:{} Done. {} images downloaded.'.format(image_class, page + 1, downloaded_cnt))def search_with_google(image_class, keyword):"""通过google下载数据集"""# 创建session对象session = requests.session()session.headers.update(headers)# 每个类别下载10页数据for page in range(download_pages):get_and_analysis_google_search_page(session, page, image_class, keyword)def run():# 首先,创建数据文件夹if not os.path.exists(images_root):os.mkdir(images_root)for sub_images_dir in keywords_map.keys():# 对于每个图片类别都创建一个单独的文件夹保存sub_path = os.path.join(images_root, sub_images_dir)if not os.path.exists(sub_path):os.mkdir(sub_path)# 开始下载,这里使用gevent的协程池进行并发pool = Pool(len(keywords_map))for image_class, keyword in keywords_map.items():pool.spawn(search_with_google, image_class, keyword)pool.join()if __name__ == '__main__':run()

项目中涉及到的所有配置参数,都提取到了settings.py中,内容如下,以供查阅:

# -*- coding: utf-8 -*-
# @File    : settings.py
# @Author  : AaronJny
# @Time    : 2019/12/16
# @Desc    :# ##########爬虫############# 图片类别和搜索关键词的映射关系
IMAGE_CLASS_KEYWORD_MAP = {'cats': '宠物猫','dogs': '宠物狗','mouses': '宠物鼠','rabbits': '宠物兔'
}
# 图片保存根目录
IMAGES_ROOT = './images'
# 爬虫每个类别下载多少页图片
SPIDER_DOWNLOAD_PAGES = 20# #########数据############ 每个类别选取的图片数量
SAMPLES_PER_CLASS = 345
# 参与训练的类别
CLASSES = ['cats', 'dogs', 'mouses', 'rabbits']
# 参与训练的类别数量
CLASS_NUM = len(CLASSES)
# 类别->编号的映射
CLASS_CODE_MAP = {'cats': 0,'dogs': 1,'mouses': 2,'rabbits': 3
}
# 编号->类别的映射
CODE_CLASS_MAP = {0: '猫',1: '狗',2: '鼠',3: '兔'
}
# 随机数种子
RANDOM_SEED = 13  # 四个类别时样本较为均衡的随机数种子
# RANDOM_SEED = 19  # 三个类别时样本较为均衡的随机数种子
# 训练集比例
TRAIN_DATASET = 0.6
# 开发集比例
DEV_DATASET = 0.2
# 测试集比例
TEST_DATASET = 0.2
# mini_batch大小
BATCH_SIZE = 16
# imagenet数据集均值
IMAGE_MEAN = [0.485, 0.456, 0.406]
# imagenet数据集标准差
IMAGE_STD = [0.299, 0.224, 0.225]# #########训练########## 学习率
LEARNING_RATE = 0.001
# 训练epoch数
TRAIN_EPOCHS = 30
# 保存训练模型的路径
MODEL_PATH = './model.h5'# ########Web########## Web服务端口
WEB_PORT = 5000

爬虫使用Google进行图片搜索,每个宠物搜索10页,下载其中的所有图片。当爬虫运行完成后,项目下会多出一个images文件夹,点进去有四个子文件夹,分别为catsdogsmousesrabbits。每一个子文件夹里面是对应类别的宠物图片。

image.png

其中猫图片600+张,狗图片600+张,鼠图片400+张,兔图片500+张。花二十多分钟时间,过一遍全部图片,剔除其中不符合要求的图片。注意,这一步是必做的,而且要认真对待,我吃了亏的= =

进行一轮筛选后,剩下图片张数:

宠物图片数量
521
526
346
345

考虑各类别样本均衡的问题,无非是过采样和欠采样。因为是图片数据,也可以使用数据增强的手段,为图片数量较少的类别生成一些图片,使样本数量均衡。但出于如下原因考虑,我直接做了欠采样,即每个类别只选取了345张样本:

  • 使用数据增强的话,需要在原图片的基础上,重新生成一份数据集,嫌麻烦……
  • 使用数据增强后,样本数量比较多,无法同时读取到内存里面,只能写个生成器,处理哪一部分的时候,实时从硬盘读取。弊端有俩:①频繁读取硬盘,肯定比不上所有数据都放在内存里面,会拖慢训练速度;②还是嫌麻烦……

说到底就是自己太懒了……当然,可想而知,使用数据增强(在这里,数据增强可以作为一种过采样的方式)使数据样本都达到526,训练的效果肯定会更好,能好多少就不知道了,有兴趣的可以自行实现,没啥难点,就是麻烦点。

下面该对数据做预处理了。很多经典的模型接收的输入格式都为(None,224,224,3),由于我们的样本较少,不可避免地需要用到迁移学习,所以我们的数据格式与经典模型保持一致,也使用(None,224,224,3),下面是预处理过程:

# -*- coding: utf-8 -*-
# @File    : data.py
# @Author  : AaronJny
# @Time    : 2019/12/16
# @Desc    :
import os
import random
import tensorflow as tf
import settings# 每个类别选取的图片数量
samples_per_class = settings.SAMPLES_PER_CLASS
# 图片根目录
images_root = settings.IMAGES_ROOT
# 类别->编码的映射
class_code_map = settings.CLASS_CODE_MAP# 我们准备使用经典网络在imagenet数据集上的与训练权重,所以归一化时也要使用imagenet的平均值和标准差
image_mean = tf.constant(settings.IMAGE_MEAN)
image_std = tf.constant(settings.IMAGE_STD)def normalization(x):"""对输入图片x进行归一化,返回归一化的值"""return (x - image_mean) / image_stddef train_preprocess(x, y):"""对训练数据进行预处理。注意,这里的参数x是图片的路径,不是图片本身;y是图片的标签值"""# 读取图片x = tf.io.read_file(x)# 解码成张量x = tf.image.decode_jpeg(x, channels=3)# 将图片缩放到[244,244],比输入[224,224]稍大一些,方便后面数据增强x = tf.image.resize(x, [244, 244])# 随机决定是否左右镜像if random.choice([0, 1]):x = tf.image.random_flip_left_right(x)# 随机从x中剪裁出(224,224,3)大小的图片x = tf.image.random_crop(x, [224, 224, 3])# 读完上面的代码可以发现,这里的数据增强并不增加图片数量,一张图片经过变换后,# 仍然只是一张图片,跟我们前面说的增加图片数量的逻辑不太一样。# 这么做主要是应对我们的数据集里可能会存在相同图片的情况。# 将图片的像素值缩放到[0,1]之间x = tf.cast(x, dtype=tf.float32) / 255.# 归一化x = normalization(x)# 将标签转成one-hot形式y = tf.cast(y, dtype=tf.int32)y = tf.one_hot(y, settings.CLASS_NUM)return x, ydef dev_preprocess(x, y):"""对验证集和测试集进行数据预处理的方法。和train_preprocess的主要区别在于,不进行数据增强,以保证验证结果的稳定性。"""# 读取并缩放图片x = tf.io.read_file(x)x = tf.image.decode_jpeg(x, channels=3)x = tf.image.resize(x, [224, 224])# 归一化x = tf.cast(x, dtype=tf.float32) / 255.x = normalization(x)# 将标签转成one-hot形式y = tf.cast(y, dtype=tf.int32)y = tf.one_hot(y, settings.CLASS_NUM)return x, y# (图片路径,标签)的列表
image_path_and_labels = []
# 排序,保证每次拿到的顺序都一样
sub_images_dir_list = sorted(list(os.listdir(images_root)))
# 遍历每一个子目录
for sub_images_dir in sub_images_dir_list:sub_path = os.path.join(images_root, sub_images_dir)# 如果给定路径是文件夹,并且这个类别参与训练if os.path.isdir(sub_path) and sub_images_dir in settings.CLASSES:# 获取当前类别的编码current_label = class_code_map.get(sub_images_dir)# 获取子目录下的全部图片名称images = sorted(list(os.listdir(sub_path)))# 随机打乱(排序和置随机数种子都是为了保证每次的结果都一样)random.seed(settings.RANDOM_SEED)random.shuffle(images)# 保留前settings.SAMPLES_PER_CLASS个images = images[:samples_per_class]# 构建(x,y)对for image_name in images:abs_image_path = os.path.join(sub_path, image_name)image_path_and_labels.append((abs_image_path, current_label))
# 计算各数据集样例数
total_samples = len(image_path_and_labels)  # 总样例数
train_samples = int(total_samples * settings.TRAIN_DATASET)  # 训练集样例数
dev_samples = int(total_samples * settings.DEV_DATASET)  # 开发集样例数
test_samples = total_samples - train_samples - dev_samples  # 测试集样例数
# 打乱数据集
random.seed(settings.RANDOM_SEED)
random.shuffle(image_path_and_labels)
# 将图片数据和标签数据分开,此时它们仍是一一对应的
x_data = tf.constant([img for img, label in image_path_and_labels])
y_data = tf.constant([label for img, label in image_path_and_labels])
# 开始划分数据集
# 训练集
train_db = tf.data.Dataset.from_tensor_slices((x_data[:train_samples], y_data[:train_samples]))
# 打乱顺序,数据预处理,设置批大小
train_db = train_db.shuffle(10000).map(train_preprocess).batch(settings.BATCH_SIZE)
# 开发集(验证集)
dev_db = tf.data.Dataset.from_tensor_slices((x_data[train_samples:train_samples + dev_samples], y_data[train_samples:train_samples + dev_samples]))
# 数据预处理,设置批大小
dev_db = dev_db.map(dev_preprocess).batch(settings.BATCH_SIZE)
# 测试集
test_db = tf.data.Dataset.from_tensor_slices((x_data[train_samples + dev_samples:], y_data[train_samples + dev_samples:]))
# 数据预处理,设置批大小
test_db = test_db.map(dev_preprocess).batch(settings.BATCH_SIZE)

二、模型构建、训练和调优

数据已经全部处理完毕,该考虑模型了。首先,我们数据集太小了,直接构建自己的网络并训练,并不是一个好方案。因为这几种宠物其实挺难区分的,所以模型需要有一定复杂度,才能很好拟合这些数据,但我们的数据又太少了,最后的结果一定是过拟合,而且还是救不回来的那种= =所以我们考虑从迁移学习入手。

什么是迁移学习?懒得重新组织语言的我,默默地从之前写的博文里面摘了一段:

一般认为,深度卷积神经网络的训练是对数据集特征的一步步抽取的过程,从简单的特征,到复杂的特征。
训练好的模型学习到的是对图像特征的抽取方法,所以在imagenet数据集上训练好的模型理论上来说,也可以直接用于抽取其他图像的特征,这也是迁移学习的基础。自然,这样的效果往往没有在新数据上重新训练的效果好,但能够节省大量的训练时间,在特定情况下非常有用。

上面说的特定情况也包括我们面临的这一种——用于实际问题的数据集过小。

说到迁移学习,我最先想到的是VGG16,就先用VGG16搞了一波。使用在imagenet数据集上预训练的VGG16网络,去除顶部的全连接层,冻结全部参数,使它们在接下来的训练中不会改变。然后加上自己的全连接层,最后的输出层节点为4,对应于我们的四分类问题。开始训练。

模型在训练集上的误差很快降到5%以下,但是在验证集上的准确率基本在70+%,很明显,过拟合了。好嘛,盘它!主要使用如下方法尝试解决过拟合问题:

  • 调节全连接层的层数和每层的节点数
  • 添加BN层(虽说不是为了解决过拟合问题诞生的,但一定程度上是有效果的)
  • 添加Dropout层
  • 调节Dropout Rate
  • 添加l2正则

一顿操作猛如虎,回头一看0-5。这些方法确实对过拟合有所缓解,验证集上的准确率也确实有所提升,但只能达到81%左右。

然后我尝试了Resnet50,当然也过拟合了,盘它!最后验证集accuracy能达到83%左右。

很明显了,在全连接层的调整意义不大,究其根本,在于VGG16和ResNet50去除了全连接层之后,参数的数量也达到了20M+。两千万的参数使得模型严重过拟合,所以我们需要换一个参数少一点的模型。

于是,我盯上了DenseNet121,它的参数数量只有7M。继续盘它!果然,在一段时间的调优后,模型的性能有了明显的提升,验证集上的accuracy达到了87%左右。虽然和ResNet相比,准确率只高了4%,但相比于ResNet50 96%的训练accuracy而言,DenseNet121的训练accuracy只有90%左右。也就是说,对于DenseNet121而言,这个问题已经不再是过拟合问题了(相差3%我是可以接受的),而是欠拟合了。

然而淡腾的是,再怎么调参,模型都很难继续拟合了,调小学习率也不行。模型本身没啥问题的话,我开始怀疑数据集有没有问题,毕竟这种无法拟合的问题有很大概率是数据导致的。于是我就去检查了一下数据集……

这就是我前面强调认真过一遍数据集的原因了,我当时只是花个几分钟粗略地过了一下,删除掉一些明显不对的图片。我第二次认真过数据集的时候才发现,有很多异常图片没有过滤掉,比如猫的目录下有狗的图片,狗的目录下有猫的图片,还有一些不同动物同框的图片,以及我自己都认不出来的图片……

image.png

文章第一部分中各类图片数量的表格,其实就是我第二遍过滤后的结果统计。

过滤完成后,模型的性能有了明显的提升,训练accuracy约为93%-94%,验证accuracy为94%,测试accuracy为92%.我们先来看一下代码,后面会对这个结果再进行分析。

首先,是模型的构建:

# -*- coding: utf-8 -*-
# @File    : models.py
# @Author  : AaronJny
# @Time    : 2019/12/16
# @Desc    :
import tensorflow as tf
import settingsdef my_densenet():"""创建并返回一个基于densenet的Model对象"""# 获取densenet网络,使用在imagenet上训练的参数值,移除头部的全连接网络,池化层使用max_poolingdensenet = tf.keras.applications.DenseNet121(include_top=False, weights='imagenet', pooling='max')# 冻结预训练的参数,在之后的模型训练中不会改变它们densenet.trainable = False# 构建模型model = tf.keras.Sequential([# 输入层,shape为(None,224,224,3)tf.keras.layers.Input((224, 224, 3)),# 输入到DenseNet121中densenet,# 将DenseNet121的输出展平,以作为全连接层的输入tf.keras.layers.Flatten(),# 添加BN层tf.keras.layers.BatchNormalization(),# 随机失活tf.keras.layers.Dropout(0.5),# 第一个全连接层,激活函数relutf.keras.layers.Dense(512, activation=tf.nn.relu),# BN层tf.keras.layers.BatchNormalization(),# 随机失活tf.keras.layers.Dropout(0.5),# 第二个全连接层,激活函数relutf.keras.layers.Dense(64, activation=tf.nn.relu),# BN层tf.keras.layers.BatchNormalization(),# 输出层,为了保证输出结果的稳定,这里就不添加Dropout层了tf.keras.layers.Dense(settings.CLASS_NUM, activation=tf.nn.softmax)])return modelif __name__ == '__main__':model = my_densenet()model.summary()

网络的summary:

Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
densenet121 (Model)          (None, 1024)              7037504   
_________________________________________________________________
flatten (Flatten)            (None, 1024)              0         
_________________________________________________________________
batch_normalization (BatchNo (None, 1024)              4096      
_________________________________________________________________
dropout (Dropout)            (None, 1024)              0         
_________________________________________________________________
dense (Dense)                (None, 512)               524800    
_________________________________________________________________
batch_normalization_1 (Batch (None, 512)               2048      
_________________________________________________________________
dropout_1 (Dropout)          (None, 512)               0         
_________________________________________________________________
dense_1 (Dense)              (None, 64)                32832     
_________________________________________________________________
batch_normalization_2 (Batch (None, 64)                256       
_________________________________________________________________
dense_2 (Dense)              (None, 4)                 260       
=================================================================
Total params: 7,601,796
Trainable params: 561,092
Non-trainable params: 7,040,704
_________________________________________________________________

参数总量7601796个,其中可训练参数561092个 。

模型和数据都已准备完毕,可以开始训练了。让我们编写一个训练用的脚本:

# -*- coding: utf-8 -*-
# @File    : train.py
# @Author  : AaronJny
# @Time    : 2019/12/17
# @Desc    :
import tensorflow as tf
from data import train_db, dev_db
import models
import settings# 从models文件中导入模型
model = models.my_densenet()
model.summary()# 配置优化器、损失函数、以及监控指标
model.compile(tf.keras.optimizers.Adam(settings.LEARNING_RATE), loss=tf.keras.losses.categorical_crossentropy,metrics=['accuracy'])# 在每个epoch结束后尝试保存模型参数,只有当前参数的val_accuracy比之前保存的更优时,才会覆盖掉之前保存的参数
model_check_point = tf.keras.callbacks.ModelCheckpoint(filepath=settings.MODEL_PATH, monitor='val_accuracy',save_best_only=True)
# 使用tf.keras的高级接口进行训练
model.fit_generator(train_db, epochs=settings.TRAIN_EPOCHS, validation_data=dev_db, callbacks=[model_check_point])

现在,我们可以运行脚本进行训练了,最优的参数将被保存在settings.MODEL_PATH。训练完成后,我们需要调用验证脚本,验证下模型在验证集和测试集上的表现:

# -*- coding: utf-8 -*-
# @File    : eval.py
# @Author  : AaronJny
# @Time    : 2019/12/17
# @Desc    :
import tensorflow as tf
from data import dev_db, test_db
from models import my_densenet
import settings# 创建模型
model = my_densenet()
# 加载参数
model.load_weights(settings.MODEL_PATH)
# 因为想用tf.keras的高级接口做验证,所以还是需要编译模型
model.compile(tf.keras.optimizers.Adam(settings.LEARNING_RATE), loss=tf.keras.losses.categorical_crossentropy,metrics=['accuracy'])
# 验证集accuracy
print('dev', model.evaluate(dev_db))
# 测试集accuracy
print('test', model.evaluate(test_db))

输出如下:

18/18 [==============================] - 5s 304ms/step - loss: 0.1936 - accuracy: 0.9457
dev [0.19364455559601387, 0.9456522]
18/18 [==============================] - 1s 64ms/step - loss: 0.2666 - accuracy: 0.9203
test [0.26657224384446937, 0.9202899]

能够看到,模型在验证集上的准确率为94.57%,在测试集上的准确率为92.03%,已经达到我的心里预期了,毕竟这么少的数据,还要啥自行车?

随着训练epoch的增多,模型的训练accuracy始终在[0.92,0.95]左右徘徊不定,没法继续拟合。究其原因,应该还是数据的锅。我们看一下识别错的样本,在eval.py脚本中,增加下面这一段程序:

# 查看识别错误的数据
for x, y in test_db:y_pred = model(x)y_pred = tf.argmax(y_pred, axis=1).numpy()y_true = tf.argmax(y, axis=1).numpy()batch_size = y_pred.shape[0]for i in range(batch_size):if y_pred[i] != y_true[i]:print('{} 被错误识别成 {}!'.format(settings.CODE_CLASS_MAP[y_true[i]], settings.CODE_CLASS_MAP[y_pred[i]]))

重新跑一下eval.py脚本,输出如下:

18/18 [==============================] - 5s 291ms/step - loss: 0.1936 - accuracy: 0.9457
dev [0.19364455559601387, 0.9456522]
18/18 [==============================] - 1s 64ms/step - loss: 0.2666 - accuracy: 0.9203
test [0.26657224384446937, 0.9202899]
狗 被错误识别成 兔!
狗 被错误识别成 兔!
狗 被错误识别成 兔!
鼠 被错误识别成 兔!
狗 被错误识别成 猫!
鼠 被错误识别成 猫!
狗 被错误识别成 兔!
狗 被错误识别成 鼠!
鼠 被错误识别成 兔!
狗 被错误识别成 兔!
猫 被错误识别成 兔!
猫 被错误识别成 鼠!
猫 被错误识别成 兔!
鼠 被错误识别成 兔!
狗 被错误识别成 兔!
狗 被错误识别成 猫!
鼠 被错误识别成 兔!
狗 被错误识别成 兔!
鼠 被错误识别成 兔!
狗 被错误识别成 猫!
鼠 被错误识别成 兔!
狗 被错误识别成 兔!

来,跟我一起唱——都是兔子惹的祸~

能够看到,出错的大部分都是被误识别成兔子了。对应到数据集上,虽然已经删掉了部分问题比较大的图片,但兔子的图片确实不好认。有很多兔子图片我人工分辨都认不出是兔子(捂脸.jpg)。然后,有些兔子图片看起来很像猫,有些看起来很像狗,有些看起来很像鼠……

如果我们把兔子图片去掉,将系统改为三分类问题,准确度将大幅度提高。当然了,按理说识别的类别数量变了,除了调整输出层的节点数量外,要想取得最佳效果,模型的其他参数也需要做相应调整的。我自己已经实测了,但限于篇幅,就不演示了,如果有兴趣的话,可以直接在settings.py里进行调整,将它变为三分类问题。改这两个地方:

# 参与训练的类别
CLASSES = ['cats', 'dogs', 'mouses', 'rabbits']
# 随机数种子
RANDOM_SEED = 13  # 四个类别时样本较为均衡的随机数种子

改成:

# 参与训练的类别
CLASSES = ['cats', 'dogs', 'mouses']
# 随机数种子
RANDOM_SEED = 19  # 三个类别时样本较为均衡的随机数种子

然后重新训练和验证即可。这只是一个插曲,本文仍然以四分类问题继续说明后续内容。

三、Web接口编写

模型训练好了,我们要把它应用起来。我准备编写一个Web服务,用户可以通过浏览器上传一张图片,服务器判断此图片的类别后,返回相关数据给用户。Web后端使用Flask,小而轻,前端则选用Vue.js + Element-UI实现。

先写后端:

# -*- coding: utf-8 -*-
# @File    : app.py
# @Author  : AaronJny
# @Time    : 2019/12/18
# @Desc    :
from flask import Flask
from flask import jsonify
from flask import request, render_template
import tensorflow as tf
from models import my_densenet
import settingsapp = Flask(__name__)# 导入模型
model = my_densenet()
# 加载训练好的参数
model.load_weights(settings.MODEL_PATH)@app.route('/', methods=['GET'])
def index():"""首页,vue入口"""return render_template('index.html')@app.route('/api/v1/pets_classify/', methods=['POST'])
def pets_classify():"""宠物图片分类接口,上传一张图片,返回此图片上的宠物是那种类别,概率多少"""# 获取用户上传的图片img_str = request.files.get('file').read()# 进行数据预处理x = tf.image.decode_image(img_str, channels=3)x = tf.image.resize(x, (224, 224))x = x / 255.x = (x - tf.constant(settings.IMAGE_MEAN)) / tf.constant(settings.IMAGE_STD)x = tf.reshape(x, (1, 224, 224, 3))# 预测y_pred = model(x)pet_cls_code = tf.argmax(y_pred, axis=1).numpy()[0]pet_cls_prob = float(y_pred.numpy()[0][pet_cls_code])pet_cls_prob = '{}%'.format(int(pet_cls_prob * 100))pet_class = settings.CODE_CLASS_MAP.get(pet_cls_code)# 将预测结果组织成jsonres = {'code': 0,'data': {'pet_cls': pet_class,'probability': pet_cls_prob,'msg': '<br><br><strong style="font-size: 48px;">{}</strong> <span style="font-size: 24px;"''>概率<strong>{}</strong></span>'.format(pet_class, pet_cls_prob),}}# 返回json数据return jsonify(res)if __name__ == '__main__':app.run(port=settings.WEB_PORT)

后端脚本app.py很简单,主要就两个方法。其中index方法会返回首页的html源码,是用户在浏览器端的访问入口;另一个方法pets_classify则提供了计算给定图片类别的功能。

前端文件index.html主要是提供了一个照片墙,用户上传图片到照片墙,服务器就会计算图片类别并返回相关数据。代码如下:

<!DOCTYPE html>
<html>
<head><meta charset="UTF-8"><!-- import CSS --><link rel="stylesheet" href="https://unpkg.com/element-ui/lib/theme-chalk/index.css">
</head>
<body>
<div id="app"><el-card class="box-card"><div slot="header" class="clearfix"><h1>宠物识别Demo</h1></div><el-uploadaction="http://localhost:5000/api/v1/pets_classify/"list-type="picture-card":on-preview="handlePictureCardPreview":on-success="handleUploadSuccess":on-remove="handleRemove"><i class="el-icon-plus"></i></el-upload><el-dialog :visible.sync="dialogVisible"><img width="100%" :src="dialogImageUrl" alt=""></el-dialog></el-card>
</div>
</body><!-- import Vue before Element -->
<script src="https://unpkg.com/vue/dist/vue.js"></script>
<!-- import JavaScript -->
<script src="https://unpkg.com/element-ui/lib/index.js"></script>
<script>new Vue({el: '#app',data() {return {dialogImageUrl: '',dialogVisible: false};},methods: {handleRemove(file, fileList) {console.log(file, fileList);console.log(this.dialogImageUrl);},handlePictureCardPreview(file) {this.dialogImageUrl = file.url;this.dialogVisible = true;},handleUploadSuccess(response, file, fileList) {this.$notify({title: '识别结果',message: response.data.msg,dangerouslyUseHTMLString: true,type: 'success',duration: 3000});}}})
</script>
</html>

让我们试试效果。首先,运行app.py脚本,启动web服务,当你看到如下输出时,说明服务启动成功了:

 * Running on http://127.0.0.1:5000/ (Press CTRL+C to quit)* Serving Flask app "app" (lazy loading)* Environment: productionWARNING: This is a development server. Do not use it in a production deployment.Use a production WSGI server instead.* Debug mode: off

因为只是开发环境,这么启动就可以了。如果是生产环境,请不要这么做,可以选择使用nginx + gunicorn +uWSGI + gevent进行部署。

四、测试

打开浏览器,输入 http://localhost:5000 进入index页面。页面长这个样子:

image.png

点击网页中的上传框,我们可以选择图片上传并识别:

image.png

image.png

当然了,这里不选择我们数据集里的图片更好,哪怕是测试集里的。你可以去网上下载、或者通过其他渠道获取这四种动物的图片来测试,这里我只做演示,就不搞那么麻烦了,直接从数据集里随便选几张照片。我们可以继续上传图片给服务器识别:

image.png
image.png
image.png
image.png
image.png
image.png
image.png
image.png
image.png
image.png

OK,演示到此为止,如果有兴趣的话可以自行测试。

结语

文章到此结束,如果您喜欢的话,给我点个赞呗~

菜鸟一只,欢迎大佬们拍砖~

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/26875.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

python:tflearn训练的猫狗识别模型及其使用

需要下载&#xff1a;pip install tflearn 一些理论知识在前一篇文章中&#xff1a;可以一起阅读学习 https://blog.csdn.net/m0_64596200/article/details/126918240?spm1001.2014.3001.5501 已经处理好的.npy文件&#xff1a; https://download.csdn.net/download/m0_645962…

基于Pytorch实现猫狗分类

文章目录 一、环境配置二、数据集的准备三、猫狗分类的实例四、实现分类预测测试五、参考资料 一、环境配置 安装Anaconda 具体安装过程&#xff0c;请自行百度配置Pytorchpip install -i https://pypi.tuna.tsinghua.edu.cn/simple torch pip install -i https://pypi.tuna.t…

猫狗训练单张图片的测试

猫狗训练的训练模型的建立&#xff0c;模型在整个预测集上的预测效果的测试的程序代码网上或一些书籍上都可查阅&#xff0c;但是对单张或某些图片的分类测试程序不多&#xff0c;这里通过参考博客&#xff1a;https://blog.csdn.net/baidu_35113561/article/details/79371716 …

宠物鼻纹识别及面部识别进一步在城市养犬登记场景落地

最近安阳狗咬人事件造成了极其恶劣的社会影响&#xff0c;大型禁养犬类伤人成为城市治安管理不容忽视的隐患&#xff0c;正威胁人们的生命安全&#xff0c;养犬热潮也给城市管理带来了不小的挑战&#xff0c;粗放式的养犬管理不再适应时代的需求&#xff0c;城市养犬管理改革已…

借助互联网,“宝贝它”欲打造线上宠物交易与服务平台

作为人类忠实的朋友&#xff0c;宠物一直伴随着很多家庭的成长。而随着人们生活节奏的不断加快&#xff0c;电子商务正成为越来越多传统垂直领域的解决方案&#xff0c;宠物交易与服务同样也不例外。上海创业公司宝贝它希望借助互联网&#xff0c;打造线上宠物交易及服务平台。…

小动物领养网站/宠物救助网站

摘 要 本论文对小动物领养网站的开发过程进行了较为详细的论述&#xff0c;采用B/S架构、ssm 框架和 java 开发的 Web 框架&#xff0c;eclipse开发工具。 小动物领养网站&#xff0c;主要的模块包括管理员&#xff1b;首页、个人中心、用户管理、动物展示管理、动物分类管理…

语音合成工具Coqui TTS安装及体验

先介绍两种免费的语音合成工具 balabolka 官网 http://balabolka.site/balabolka.htm 是一种基于微软Speech API (SAPI)的免费语音合成工具&#xff0c;只是简单的发音合成&#xff0c;效果比较生硬 Coqui TTS 官网 https://coqui.ai/ 是基于深度学习的语音合成软件&#x…

音视频进阶教程|如何实现游戏中的实时语音

1 游戏实时语音功能简介 1.1 游戏实时语音概念解释 范围&#xff1a;收听者接收音频的范围。方位&#xff1a;指收听者在游戏世界坐标中的位置和朝向&#xff0c;详情可参考 5.5 初始化设置 中的“步骤 1”。收听者&#xff1a;房间内接收音频的用户发声者&#xff1a;房间内…

通过实时语音驱动人像模拟真人说话

元宇宙的火热让人们对未来虚拟世界的形态充满了幻想&#xff0c;此前我们为大家揭秘了声网自研的 3D 空间音频技术如何在虚拟世界中完美模拟现实听觉体验&#xff0c;增加玩家沉浸感。今天我们暂时离开元宇宙&#xff0c;回到现实世界&#xff0c;来聊聊声网自研的 Agora Lipsy…

聊天语音APP开发|聊天语音软件开发-实时音视频技术

聊天语音软件的开发应该是一个以视频和语音直播为核心的社交系统。对于用户来说&#xff0c;更好的视频和语音直播功能可以增强用户的接受感&#xff0c;让用户持续使用。为了方便视频和语音直播的采用体验&#xff0c;减少直播的延时&#xff0c;聊天语音软件的开发将采用实时…

拿到offer提出离职,公司拖30天才放人,但下家公司等不了30天,怎么办?

拿到offer想跳槽&#xff0c;向公司提出了离职&#xff0c;但公司要拖30天才放人&#xff0c;新公司又等不了30天&#xff0c;offer可能就没有了&#xff0c;这就是一位网友面临的两难局面&#xff0c;这种情况有没有什么解决的好办法呢&#xff1f; 有人安慰楼主&#xff0c;下…

怎么说离职原因新的公司比较能接受?

怎么说离职原因新的公司比较能接受&#xff1f; 我来提供一些格式化的应对方法&#xff1b; 1.实际原因&#xff1a;原单位工资太少。离职原因&#xff1a;我认为我自己已经具备了一定的积累&#xff0c;希望可以迈向一个新的台阶。 2.实际原因&#xff1a;跟同事出不来。离…

我提了离职,公司给我涨薪了,还能待下去吗?

金三银四到了&#xff0c;相信不少同学又开始在物色新的公司。 不少同学反映&#xff0c;在提出离职后&#xff0c;公司给自己加了薪&#xff0c;虽然不多。 那“在职员工&#xff0c;提出辞职被挽留&#xff0c;应该留下吗&#xff1f;” 为什么想要离职&#xff1f; 这个问…

是的,我离职了

终于可以敞开说这件事情了&#xff0c;年后的这一个月&#xff0c;我彻底停更了&#xff0c;并不是偷懒了&#xff0c;而是我要找工作。大家也都知道18年的寒冬&#xff0c;很多大厂开始裁员&#xff0c;所以我要更加认真的学习&#xff0c;毕竟跟大厂出来的相比&#xff0c;自…

办理离职手续流程的详细流程(离职交接的标准流程)

1、正式员工办理离职手续流程 若员工自离&#xff0c;需提前一个月向部门领导提出辞职申请&#xff08;即时聊天工具或邮件&#xff09;和《解除劳动合同申请》。 1&#xff09;面谈&#xff1a;一般领导都会先谈话&#xff0c;确定你离职的时间及安排交接人员进行工作交接。 2…

程序员新公司入职被拒 只因离职证明多了一句话!

程序猿&#xff08;微信号&#xff1a;imkuqin&#xff09; 猿妹 整编 新闻报道来自&#xff1a;成都商报 近日&#xff0c;成都一名程序员被新应聘的公司通知入职&#xff0c;然而因为原公司给他出具的一份离职证明上&#xff0c;记载了一句“该员工在项目未完成情况下因个人原…

提交辞职申请时,领导极力挽留,还答应加薪,要不要留下来?

提交辞职申请时&#xff0c;领导极力挽留&#xff0c;还答应加薪&#xff0c;要不要留下来&#xff1f;张工是一名程序员&#xff0c;最近他向领导提交了辞职申请表后&#xff0c;却被领导极力挽留&#xff0c;领导不仅打感情牌&#xff0c;还打加薪牌。就是希望张工能够留下来…

医学影像处理与识别,应用AI模型,探索疾病辅助诊断!

关注公众号&#xff0c;发现CV技术之美 今天&#xff08;2023.1.9&#xff09; arXiv.CV 上有7篇医学影像处理与识别相关论文。不过粗略看来&#xff0c;医学影像类的论文&#xff0c;很多都是直接使用已有模型&#xff08;甚至都不是最先进的模型&#xff09;&#xff0c;加以…

【react从入门到精通】初识React

文章目录 人工智能福利文章前言React技能树什么是 React&#xff1f;安装和配置 React创建 React 组件渲染 React 组件使用 JSX传递属性&#xff08;Props&#xff09;处理组件状态&#xff08;State&#xff09;处理用户输入&#xff08;事件处理&#xff09;组合和嵌套组件写…

JWT续期问题,ChatGPT解决方案

JWT&#xff08;JSON Web Token&#xff09;通常是在用户登录后签发的&#xff0c;用于验证用户身份和授权。JWT 的有效期限&#xff08;或称“过期时间”&#xff09;通常是一段时间&#xff08;例如1小时&#xff09;&#xff0c;过期后用户需要重新登录以获取新的JWT。然而&…