多目标跟踪之reid网络
- 1、准备数据集
- 2、reid网络的搭建(分类网络)
- 3、reid网络的训练
- 4、特征提取推理demo
1、准备数据集
按照分类任务那样准备数据集,即创建一个train文件夹,在这下面准备若干个子文件夹,表示一共有多少类别,每个文件夹中准备好对应类的图片即可。
2、reid网络的搭建(分类网络)
定义一个resnet18的分类网络(我们可以自定义的修改模块,也可以直接import torchvision中现成的特征提取网络):
import torch
import torch.nn as nn
import torch.nn.functional as F


class BasicBlock(nn.Module):
    """Residual basic block: two 3x3 convs with an identity/projection shortcut.

    When ``is_downsample`` is True the first conv uses stride 2 and the
    shortcut is a strided 1x1 conv; when the channel count changes without
    downsampling, the shortcut is a stride-1 1x1 projection.
    """

    def __init__(self, c_in, c_out, is_downsample=False):
        super(BasicBlock, self).__init__()
        self.is_downsample = is_downsample
        # first conv halves the spatial size only when downsampling
        first_stride = 2 if is_downsample else 1
        self.conv1 = nn.Conv2d(c_in, c_out, 3, stride=first_stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(c_out)
        self.relu = nn.ReLU(True)
        self.conv2 = nn.Conv2d(c_out, c_out, 3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(c_out)
        if is_downsample:
            self.downsample = nn.Sequential(
                nn.Conv2d(c_in, c_out, 1, stride=2, bias=False),
                nn.BatchNorm2d(c_out),
            )
        elif c_in != c_out:
            # channel mismatch without stride: project the shortcut and make
            # forward() route through it
            self.downsample = nn.Sequential(
                nn.Conv2d(c_in, c_out, 1, stride=1, bias=False),
                nn.BatchNorm2d(c_out),
            )
            self.is_downsample = True

    def forward(self, x):
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        if self.is_downsample:
            x = self.downsample(x)
        # residual add followed by in-place ReLU
        return F.relu(x.add(out), True)


def make_layers(c_in, c_out, repeat_times, is_downsample=False):
    """Stack ``repeat_times`` BasicBlocks; only the first block may change
    channels or downsample, the rest keep ``c_out`` -> ``c_out``."""
    blocks = []
    for i in range(repeat_times):
        if i == 0:
            blocks.append(BasicBlock(c_in, c_out, is_downsample=is_downsample))
        else:
            blocks.append(BasicBlock(c_out, c_out))
    return nn.Sequential(*blocks)


# ResNet18
class Net(nn.Module):
    """ResNet18-style ReID network.

    Trained as an ordinary classifier (``reid=False``); at tracking time the
    ``reid=True`` switch stops the forward pass after the feature layers and
    returns an L2-normalized 512-d embedding instead of class logits.

    Args:
        num_classes: number of identities in the training set.
        reid: if True, forward() returns the unit-norm embedding.
    """

    def __init__(self, num_classes=751, reid=False):
        super(Net, self).__init__()
        # expected input: 3 x 64 x 128 (the training transforms use
        # torchvision Resize((64, 128)), i.e. h=64, w=128)
        self.conv = nn.Sequential(
            nn.Conv2d(3, 64, 3, stride=1, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(3, 2, padding=1),
        )
        # 64 x 32 x 64 after the stem
        self.layer1 = make_layers(64, 64, 2, False)
        self.layer2 = make_layers(64, 128, 2, True)   # 128 x 16 x 32
        self.layer3 = make_layers(128, 256, 2, True)  # 256 x 8 x 16
        self.layer4 = make_layers(256, 512, 2, True)  # 512 x 4 x 8
        # Global average pool. The original hard-coded AvgPool2d((4, 8), 1),
        # which only fits a 64(h) x 128(w) input and crashes on the 128 x 64
        # orientation used elsewhere in this tutorial; AdaptiveAvgPool2d(1)
        # produces identical output for 64x128 inputs and generalizes to any
        # input size (no parameters, so old checkpoints still load).
        self.avgpool = nn.AdaptiveAvgPool2d(1)
        self.reid = reid
        self.classifier = nn.Sequential(
            nn.Linear(512, 256),
            nn.BatchNorm1d(256),
            nn.ReLU(inplace=True),
            nn.Dropout(),
            nn.Linear(256, num_classes),
        )

    def forward(self, x):
        x = self.conv(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = self.avgpool(x)
        x = x.view(x.size(0), -1)  # B x 512
        if self.reid:
            # inference switch: return the L2-normalized embedding for
            # cosine-similarity matching in the tracker
            x = x.div(x.norm(p=2, dim=1, keepdim=True))
            return x
        # training path: class logits
        x = self.classifier(x)
        return x
# net = Net()
# x = torch.randn(4, 3, 128, 64)
# y = net(x)
这里和普通分类网络不同的点在于我们在最后一层特征提取层和分类的全连接层之间加入了一个开关,训练时,我们将这个开关打开,以分类任务的形式,权重会更新到网络的所有层,当在目标跟踪中进行特征提取推理时,我们将这个开关关闭,前向传播只进行到特征提取层的最后一层,然后输出一个指定维度的特征向量,用于跟踪时计算track的向量库与当前帧detections之间的余弦相似度。
3、reid网络的训练
import argparse
import os
import timeimport numpy as np
import matplotlib.pyplot as plt
import torch
import torch.backends.cudnn as cudnn
import torchvision

from model import Net

# ---- command-line arguments ----
parser = argparse.ArgumentParser(description="Train on benign")
parser.add_argument("--data-dir", default='./dataset', type=str)
parser.add_argument("--no-cuda", action="store_true")
parser.add_argument("--gpu-id", default=0, type=int)
parser.add_argument("--lr", default=0.1, type=float)
parser.add_argument("--interval", '-i', default=20, type=int)
parser.add_argument('--resume', '-r', action='store_true')
args = parser.parse_args()

# ---- device selection ----
use_gpu = torch.cuda.is_available() and not args.no_cuda
device = "cuda:{}".format(args.gpu_id) if use_gpu else "cpu"
if use_gpu:
    # fixed input sizes, so let cudnn pick the fastest conv algorithms
    cudnn.benchmark = True

# data loading
root = args.data_dir
train_dir = os.path.join(root, "train")
test_dir = os.path.join(root, "test")

# training-time preprocessing: resize, pad-crop and flip augmentation,
# then ImageNet-statistics normalization
transform_train = torchvision.transforms.Compose([
    torchvision.transforms.Resize((64, 128)),
    torchvision.transforms.RandomCrop((64, 128), padding=4),
    torchvision.transforms.RandomHorizontalFlip(),
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])
# evaluation preprocessing: resize + normalize only
transform_test = torchvision.transforms.Compose([
    torchvision.transforms.Resize((64, 128)),
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])

# one sub-folder per identity, standard classification layout
trainloader = torch.utils.data.DataLoader(
    torchvision.datasets.ImageFolder(train_dir, transform=transform_train),
    batch_size=128, shuffle=True,
)
testloader = torch.utils.data.DataLoader(
    torchvision.datasets.ImageFolder(test_dir, transform=transform_test),
    batch_size=128, shuffle=True,
)
num_classes = max(len(trainloader.dataset.classes), len(testloader.dataset.classes))
print("num_classes = %s" % num_classes)

# ---- net definition ----
start_epoch = 0
net = Net(num_classes=num_classes)
# Optionally resume training from the checkpoint written by test() below.
# FIX: the original loaded "./checkpoint/ckpt.t7", but test() saves to
# "./checkpoint/reid.pt" (and the inference demo loads reid.pt), so a resumed
# run could never find its own checkpoint.
if args.resume:
    ckpt_path = "./checkpoint/reid.pt"
    assert os.path.isfile(ckpt_path), "Error: no checkpoint file found!"
    print('Loading from {}'.format(ckpt_path))
    checkpoint = torch.load(ckpt_path)
    net_dict = checkpoint['net_dict']
    net.load_state_dict(net_dict)
    best_acc = checkpoint['acc']
    start_epoch = checkpoint['epoch']
net.to(device)

# ---- loss and optimizer ----
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(net.parameters(), args.lr, momentum=0.9, weight_decay=5e-4)
# Best test accuracy seen so far. FIX: the original reset this to 0
# unconditionally, which clobbered the best_acc restored by the --resume
# branch above; only initialize it for a fresh run.
if not args.resume:
    best_acc = 0.
def train(epoch):
    """Run one training epoch over ``trainloader``.

    Returns:
        (mean loss over the epoch, top-1 error rate).
    """
    print("\nEpoch : %d" % (epoch + 1))
    net.train()
    training_loss = 0.  # loss accumulated since the last progress print
    train_loss = 0.     # loss accumulated over the whole epoch
    correct = 0
    total = 0
    interval = args.interval
    start = time.time()
    for idx, (inputs, labels) in enumerate(trainloader):
        inputs, labels = inputs.to(device), labels.to(device)
        # forward
        outputs = net(inputs)
        loss = criterion(outputs, labels)
        # backward
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # running statistics
        training_loss += loss.item()
        train_loss += loss.item()
        correct += outputs.max(dim=1)[1].eq(labels).sum().item()
        total += labels.size(0)
        # periodic progress line (loss averaged over the interval,
        # accuracy cumulative over the epoch)
        if (idx + 1) % interval == 0:
            end = time.time()
            print("[progress:{:.1f}%]time:{:.2f}s Loss:{:.5f} Correct:{}/{} Acc:{:.3f}%".format(
                100. * (idx + 1) / len(trainloader), end - start,
                training_loss / interval, correct, total,
                100. * correct / total))
            training_loss = 0.
            start = time.time()
    return train_loss / len(trainloader), 1. - correct / total


def test(epoch):
    """Evaluate on ``testloader`` and checkpoint the net when accuracy improves.

    Returns:
        (mean test loss, top-1 error rate).
    """
    global best_acc
    net.eval()
    test_loss = 0.
    correct = 0
    total = 0
    start = time.time()
    with torch.no_grad():
        for idx, (inputs, labels) in enumerate(testloader):
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = net(inputs)
            loss = criterion(outputs, labels)
            test_loss += loss.item()
            correct += outputs.max(dim=1)[1].eq(labels).sum().item()
            total += labels.size(0)
        print("Testing ...")
        end = time.time()
        print("[progress:{:.1f}%]time:{:.2f}s Loss:{:.5f} Correct:{}/{} Acc:{:.3f}%".format(
            100. * (idx + 1) / len(testloader), end - start,
            test_loss / len(testloader), correct, total,
            100. * correct / total))
    # save a checkpoint whenever test accuracy sets a new best
    acc = 100. * correct / total
    if acc > best_acc:
        best_acc = acc
        print("Saving parameters to checkpoint/reid.pt")
        checkpoint = {
            'net_dict': net.state_dict(),
            'acc': acc,
            'epoch': epoch,
        }
        if not os.path.isdir('checkpoint'):
            os.mkdir('checkpoint')
        torch.save(checkpoint, './checkpoint/reid.pt')
    return test_loss / len(testloader), 1. - correct / total


# plot figure
# ---- live training curves ----
x_epoch = []
record = {'train_loss': [], 'train_err': [], 'test_loss': [], 'test_err': []}
fig = plt.figure()
ax0 = fig.add_subplot(121, title="loss")
ax1 = fig.add_subplot(122, title="top1err")


def draw_curve(epoch, train_loss, train_err, test_loss, test_err):
    """Append this epoch's stats, redraw the loss/error curves, save train.jpg."""
    global record
    record['train_loss'].append(train_loss)
    record['train_err'].append(train_err)
    record['test_loss'].append(test_loss)
    record['test_err'].append(test_err)
    x_epoch.append(epoch)
    ax0.plot(x_epoch, record['train_loss'], 'bo-', label='train')
    ax0.plot(x_epoch, record['test_loss'], 'ro-', label='val')
    ax1.plot(x_epoch, record['train_err'], 'bo-', label='train')
    ax1.plot(x_epoch, record['test_err'], 'ro-', label='val')
    # add the legends only once, on the first epoch
    if epoch == 0:
        ax0.legend()
        ax1.legend()
    fig.savefig("train.jpg")


# lr decay
def lr_decay():
    """Multiply every param group's learning rate by 0.1 and log the new value."""
    global optimizer
    for params in optimizer.param_groups:
        params['lr'] *= 0.1
        lr = params['lr']
        print("Learning rate adjusted to {}".format(lr))


def main():
    total_epoches = 50
    for epoch in range(start_epoch, start_epoch + total_epoches):
        train_loss, train_err = train(epoch)
        test_loss, test_err = test(epoch)
        draw_curve(epoch, train_loss, train_err, test_loss, test_err)
        # NOTE(review): with this modulus the decay fires only after the very
        # last epoch of a fresh run, so it never influences training — a
        # smaller step (e.g. 20) was probably intended; kept as-is to
        # preserve behavior.
        if (epoch + 1) % (total_epoches) == 0:
            lr_decay()


if __name__ == '__main__':
    main()
4、特征提取推理demo
import torch
import torchvision.transforms as transforms
import numpy as np
import cv2
import logging

from model import Net

'''
Feature extractor:
Extracts a fixed-dimension embedding from each bounding-box crop; the
embedding represents that box and is used when computing similarity between
a track's feature gallery and the current frame's detections. The model is
trained as a conventional ReID classifier; at inference time Extractor
receives a list of images and returns their features.
'''


class Extractor(object):
    """Wraps the reid Net to map a list of RGB crops to L2-normalized embeddings."""

    def __init__(self, model_path, use_cuda=True):
        # reid=True: forward() stops after the feature layers and returns
        # the unit-norm embedding instead of class logits
        self.net = Net(reid=True)
        self.device = "cuda" if torch.cuda.is_available() and use_cuda else "cpu"
        state_dict = torch.load(model_path, map_location=lambda storage, loc: storage)['net_dict']
        self.net.load_state_dict(state_dict)
        logger = logging.getLogger("root.tracker")
        logger.info("Loading weights from {}... Done!".format(model_path))
        self.net.to(self.device)
        # cv2.resize takes (width, height). Training used torchvision
        # Resize((64, 128)), i.e. h=64 / w=128, and the net's pooling only
        # fits that orientation. FIX: the original (64, 128) here produced
        # 128(h) x 64(w) crops — the transposed orientation — so it is
        # corrected to (128, 64) to match the training distribution.
        self.size = (128, 64)
        self.norm = transforms.Compose([
            # RGB pixels are in [0, 255]; ToTensor scales them to [0, 1],
            # then Normalize applies (x - mean) / std with the ImageNet
            # statistics used at training time.
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ])

    def _preprocess(self, im_crops):
        """Scale crops to [0, 1] float32, resize to self.size, normalize,
        and stack into a single (N, 3, H, W) batch tensor."""
        def _resize(im, size):
            return cv2.resize(im.astype(np.float32) / 255., size)

        im_batch = torch.cat(
            [self.norm(_resize(im, self.size)).unsqueeze(0) for im in im_crops],
            dim=0).float()
        return im_batch

    # __call__ makes the instance directly callable — extractor(crops) —
    # instead of requiring a named method such as forward().
    def __call__(self, im_crops):
        """Return a (N, 512) numpy array of embeddings for a list of RGB crops."""
        im_batch = self._preprocess(im_crops)
        with torch.no_grad():
            im_batch = im_batch.to(self.device)
            features = self.net(im_batch)
        return features.cpu().numpy()


if __name__ == '__main__':
    # cv2.imread returns (h, w, c) in BGR order; [:, :, (2, 1, 0)] reverses
    # the channel axis to RGB before feature extraction.
    img = cv2.imread("0002_c1s1_000551_01.jpg")[:, :, (2, 1, 0)]
    extr = Extractor("checkpoint/reid.pt")
    feature = extr([img])
    print(feature.shape)