不用训练,直接用在ImageNet上训练的RestNet网络就可以做一个简单的以图搜图功能
1、RestNet网络
上面是resnet网络的结构图,ImageNet是一个有1000类的数据集,我们可以把在该数据集上训练过的Resnet网络当成特征提取器,用来提取图片的特征,然后比对特征的欧式距离,判定两种图片的相似性。
average pool的输出是512x1x1,将其reshape为512x1当作该图片的特征。
2、生成RestNet网络
arch_name = “resnet18"
pretrained = True,加载预训练模型
self.retriever_net = torchvision.models.__dict__[arch_name](pretrained = pretrained)
3、提取average pool层的输出
feature_layer_name = ‘avgpool’
feature_index_in_module = 0
register_forward_hook函数,forward时负责保存某个模块的输出
self.feature_layer_name = feature_layer_name
self.feature_index_in_module = feature_index_in_module self.retriever_net._modules.get(self.feature_layer_name).register_forward_hook(self.hook_feature)
4、比对两个特征
使用欧式距离
dist = F.pairwise_distance(contrast_features, retrieved_features, p=2)
5、贴一张我简陋的gui界面,要捂脸啦
左边的load是加载一张对比图片,右边的load是加载一个文件夹,点Retriever,开始在加载的文件夹中查找和对比图片top1相似的图片,然后显示出来。
两天里面挤时间写的,所以功能很简单,gui很简陋,自己倒是觉得挺有意思的,给大家抛砖引玉吧,全部代码在这里。
image_retrieval_with_gui