使用 Faiss 创建 Index,利用 <id, feature vector> 数据生成索引。
针对待检索图片,使用模型提取图片特征向量,然后使用 Index 检索 TopK 相似图片的 id。
可视化检索结果
1. 导包
import os
import time

import faiss
import matplotlib.pyplot as plt
import numpy as np
import torch
from PIL import Image
from torch.utils.data import DataLoader, Dataset
from torchvision import models, transforms

%matplotlib inline
GPU 加速
# Select the compute device: use the GPU when CUDA is available, otherwise
# fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)  # e.g. "cuda" on a machine with a usable GPU
2. 自定义数据集
# Preprocessing pipeline applied to every image: resize to 256x256, center-crop
# to 224x224, convert to a tensor, then normalize each channel with the given
# per-channel mean and standard deviation.
transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])


class MyDataset(Dataset):
    """Dataset of the RGB images listed in ``<data_path>/img.txt``.

    ``img.txt`` is read line by line; each non-empty line is treated as an
    image path relative to ``data_path``. Only images whose PIL mode is
    ``'RGB'`` are kept, so the dataset index refers to positions in that
    filtered list, not to line numbers in ``img.txt``.

    Args:
        data_path: Directory containing ``img.txt`` and the image files.
        transform: Optional callable applied to each PIL image in
            ``__getitem__``. When falsy, the PIL image is returned unchanged.
    """

    def __init__(self, data_path, transform=None):
        super().__init__()
        self.transform = transform
        self.data_path = data_path
        self.data = []  # relative paths of the images that passed the RGB filter

        img_path = os.path.join(data_path, 'img.txt')
        with open(img_path, 'r', encoding='utf-8') as f:
            for line in f:
                line = line.strip()
                if not line:
                    # A blank line would resolve to the data directory itself
                    # and make Image.open fail, so skip it.
                    continue
                img_name = os.path.join(data_path, line)
                # Only the mode is needed here; the context manager closes the
                # file right away instead of leaking one handle per image.
                with Image.open(img_name) as img:
                    if img.mode == 'RGB':
                        self.data.append(line)

    def __getitem__(self, idx):
        """Return ``{'index': idx, 'img': image}`` for the sample at ``idx``."""
        # Look up the sample's path by its index in the filtered list.
        img_path = os.path.join(self.data_path, self.data[idx])
        # Read the image.
        img = Image.open(img_path)
        # Apply the transform, if one was supplied.
        if self.transform:
            img = self.transform(img)
        # Return the image together with its index.
        dict_data = {'index': idx, 'img': img}
        return dict_data

    def __len__(self):
        """Return the number of RGB images kept from ``img.txt``."""
        return len(self.data)