【深度学习|目标跟踪】多目标跟踪之训练reid网络

多目标跟踪之reid网络

  • 1、准备数据集
  • 2、reid网络的搭建(分类网络)
  • 3、reid网络的训练
  • 4、特征提取推理demo

1、准备数据集

  按照分类任务那样准备数据集,即创建一个train文件夹,在这下面准备若干个子文件夹,表示一共有多少类别,每个文件夹中准备好对应类的图片即可。

2、reid网络的搭建(分类网络)

  定义一个resnet18的分类网络(我们可以自定义的修改模块,也可以直接import torchvision中现成的特征提取网络):

import torch
import torch.nn as nn
import torch.nn.functional as Fclass BasicBlock(nn.Module):def __init__(self, c_in, c_out, is_downsample=False):super(BasicBlock, self).__init__()self.is_downsample = is_downsampleif is_downsample:self.conv1 = nn.Conv2d(c_in, c_out, 3, stride=2, padding=1, bias=False)else:self.conv1 = nn.Conv2d(c_in, c_out, 3, stride=1, padding=1, bias=False)self.bn1 = nn.BatchNorm2d(c_out)self.relu = nn.ReLU(True)self.conv2 = nn.Conv2d(c_out, c_out, 3, stride=1, padding=1, bias=False)self.bn2 = nn.BatchNorm2d(c_out)if is_downsample:self.downsample = nn.Sequential(nn.Conv2d(c_in, c_out, 1, stride=2, bias=False),nn.BatchNorm2d(c_out))elif c_in != c_out:self.downsample = nn.Sequential(nn.Conv2d(c_in, c_out, 1, stride=1, bias=False),nn.BatchNorm2d(c_out))self.is_downsample = Truedef forward(self, x):y = self.conv1(x)y = self.bn1(y)y = self.relu(y)y = self.conv2(y)y = self.bn2(y)if self.is_downsample:x = self.downsample(x)return F.relu(x.add(y), True)def make_layers(c_in, c_out, repeat_times, is_downsample=False):blocks = []for i in range(repeat_times):if i == 0:blocks += [BasicBlock(c_in, c_out, is_downsample=is_downsample), ]else:blocks += [BasicBlock(c_out, c_out), ]return nn.Sequential(*blocks)# ResNet18
class Net(nn.Module):def __init__(self, num_classes=751, reid=False):super(Net, self).__init__()# 3 128 64self.conv = nn.Sequential(nn.Conv2d(3, 64, 3, stride=1, padding=1),nn.BatchNorm2d(64),nn.ReLU(inplace=True),# nn.Conv2d(32,32,3,stride=1,padding=1),# nn.BatchNorm2d(32),# nn.ReLU(inplace=True),nn.MaxPool2d(3, 2, padding=1),)# 32 64 32self.layer1 = make_layers(64, 64, 2, False)# 32 64 32self.layer2 = make_layers(64, 128, 2, True)# 64 32 16self.layer3 = make_layers(128, 256, 2, True)# 128 16 8self.layer4 = make_layers(256, 512, 2, True)# 256 8 4self.avgpool = nn.AvgPool2d((4, 8), 1)# 256 1 1self.reid = reidself.classifier = nn.Sequential(nn.Linear(512, 256),nn.BatchNorm1d(256),nn.ReLU(inplace=True),nn.Dropout(),nn.Linear(256, num_classes),)def forward(self, x):x = self.conv(x)x = self.layer1(x)x = self.layer2(x)x = self.layer3(x)x = self.layer4(x)x = self.avgpool(x)x = x.view(x.size(0), -1)# B x 128if self.reid:x = x.div(x.norm(p=2, dim=1, keepdim=True))return x# classifierx = self.classifier(x)return x# if __name__ == '__main__':
#    net = Net()
#    x = torch.randn(4, 3, 128, 64)
#    y = net(x)

  这里和普通分类网络不同的点在于我们在最后一层特征提取层和分类的全连接层之间加入了一个开关,训练时,我们将这个开关打开,以分类任务的形式,权重会更新到网络的所有层,当在目标跟踪中进行特征提取推理时,我们将这个开关关闭,前向传播只进行到特征提取层的最后一层,然后输出一个指定维度的特征向量,用于跟踪时计算track的向量库与当前帧detections之间的余弦相似度。

3、reid网络的训练

import argparse
import os
import timeimport numpy as np
import matplotlib.pyplot as plt
import torch
import torch.backends.cudnn as cudnn
import torchvisionfrom model import Netparser = argparse.ArgumentParser(description="Train on benign")
parser.add_argument("--data-dir", default='./dataset', type=str)
parser.add_argument("--no-cuda", action="store_true")
parser.add_argument("--gpu-id", default=0, type=int)
parser.add_argument("--lr", default=0.1, type=float)
parser.add_argument("--interval", '-i', default=20, type=int)
parser.add_argument('--resume', '-r', action='store_true')
args = parser.parse_args()# device
device = "cuda:{}".format(args.gpu_id) if torch.cuda.is_available() and not args.no_cuda else "cpu"
if torch.cuda.is_available() and not args.no_cuda:cudnn.benchmark = True# data loading
root = args.data_dir
train_dir = os.path.join(root, "train")
test_dir = os.path.join(root, "test")# 训练的数据处理
transform_train = torchvision.transforms.Compose([torchvision.transforms.Resize((64, 128)),torchvision.transforms.RandomCrop((64, 128), padding=4),torchvision.transforms.RandomHorizontalFlip(),torchvision.transforms.ToTensor(),torchvision.transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])# 测试的数据处理
transform_test = torchvision.transforms.Compose([torchvision.transforms.Resize((64, 128)),torchvision.transforms.ToTensor(),torchvision.transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])trainloader = torch.utils.data.DataLoader(torchvision.datasets.ImageFolder(train_dir, transform=transform_train),batch_size=128, shuffle=True
)testloader = torch.utils.data.DataLoader(torchvision.datasets.ImageFolder(test_dir, transform=transform_test),batch_size=128, shuffle=True
)
num_classes = max(len(trainloader.dataset.classes), len(testloader.dataset.classes))
print("num_classes = %s" % num_classes)# net definition
start_epoch = 0# 实例化一个分类网络
net = Net(num_classes=num_classes)# 是否加载之前的权重
if args.resume:assert os.path.isfile("./checkpoint/ckpt.t7"), "Error: no checkpoint file found!"print('Loading from checkpoint/ckpt.t7')checkpoint = torch.load("./checkpoint/ckpt.t7")# import ipdb; ipdb.set_trace()net_dict = checkpoint['net_dict']net.load_state_dict(net_dict)best_acc = checkpoint['acc']start_epoch = checkpoint['epoch']
net.to(device)# loss and optimizer
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(net.parameters(), args.lr, momentum=0.9, weight_decay=5e-4)
best_acc = 0.# train function for each epoch
def train(epoch):print("\nEpoch : %d" % (epoch + 1))net.train()training_loss = 0.train_loss = 0.correct = 0total = 0interval = args.intervalstart = time.time()for idx, (inputs, labels) in enumerate(trainloader):# forwardinputs, labels = inputs.to(device), labels.to(device)outputs = net(inputs)loss = criterion(outputs, labels)# backwardoptimizer.zero_grad()loss.backward()optimizer.step()# accumuratingtraining_loss += loss.item()train_loss += loss.item()correct += outputs.max(dim=1)[1].eq(labels).sum().item()total += labels.size(0)# printif (idx + 1) % interval == 0:end = time.time()print("[progress:{:.1f}%]time:{:.2f}s Loss:{:.5f} Correct:{}/{} Acc:{:.3f}%".format(100. * (idx + 1) / len(trainloader), end - start, training_loss / interval, correct, total,100. * correct / total))training_loss = 0.start = time.time()return train_loss / len(trainloader), 1. - correct / totaldef test(epoch):global best_accnet.eval()test_loss = 0.correct = 0total = 0start = time.time()with torch.no_grad():for idx, (inputs, labels) in enumerate(testloader):inputs, labels = inputs.to(device), labels.to(device)outputs = net(inputs)loss = criterion(outputs, labels)test_loss += loss.item()correct += outputs.max(dim=1)[1].eq(labels).sum().item()total += labels.size(0)print("Testing ...")end = time.time()print("[progress:{:.1f}%]time:{:.2f}s Loss:{:.5f} Correct:{}/{} Acc:{:.3f}%".format(100. * (idx + 1) / len(testloader), end - start, test_loss / len(testloader), correct, total,100. * correct / total))# saving checkpointacc = 100. * correct / totalif acc > best_acc:best_acc = accprint("Saving parameters to checkpoint/reid.pt")checkpoint = {'net_dict': net.state_dict(),'acc': acc,'epoch': epoch,}if not os.path.isdir('checkpoint'):os.mkdir('checkpoint')torch.save(checkpoint, './checkpoint/reid.pt')return test_loss / len(testloader), 1. - correct / total# plot figure
x_epoch = []
record = {'train_loss': [], 'train_err': [], 'test_loss': [], 'test_err': []}
fig = plt.figure()
ax0 = fig.add_subplot(121, title="loss")
ax1 = fig.add_subplot(122, title="top1err")def draw_curve(epoch, train_loss, train_err, test_loss, test_err):global recordrecord['train_loss'].append(train_loss)record['train_err'].append(train_err)record['test_loss'].append(test_loss)record['test_err'].append(test_err)x_epoch.append(epoch)ax0.plot(x_epoch, record['train_loss'], 'bo-', label='train')ax0.plot(x_epoch, record['test_loss'], 'ro-', label='val')ax1.plot(x_epoch, record['train_err'], 'bo-', label='train')ax1.plot(x_epoch, record['test_err'], 'ro-', label='val')if epoch == 0:ax0.legend()ax1.legend()fig.savefig("train.jpg")# lr decay
def lr_decay():global optimizerfor params in optimizer.param_groups:params['lr'] *= 0.1lr = params['lr']print("Learning rate adjusted to {}".format(lr))def main():total_epoches = 50for epoch in range(start_epoch, start_epoch + total_epoches):train_loss, train_err = train(epoch)test_loss, test_err = test(epoch)draw_curve(epoch, train_loss, train_err, test_loss, test_err)if (epoch + 1) % (total_epoches) == 0:lr_decay()if __name__ == '__main__':main()

4、特征提取推理demo

import torch
import torchvision.transforms as transforms
import numpy as np
import cv2
import loggingfrom model import Net'''
特征提取器:
提取对应bounding box中的特征, 得到一个固定维度的embedding作为该bounding box的代表,
供计算相似度时使用。模型训练是按照传统ReID的方法进行,使用Extractor类的时候输入为一个list的图片,得到图片对应的特征。
'''class Extractor(object):def __init__(self, model_path, use_cuda=True):self.net = Net(reid=True)self.device = "cuda" if torch.cuda.is_available() and use_cuda else "cpu"state_dict = torch.load(model_path, map_location=lambda storage, loc: storage)['net_dict']self.net.load_state_dict(state_dict)logger = logging.getLogger("root.tracker")logger.info("Loading weights from {}... Done!".format(model_path))self.net.to(self.device)self.size = (64, 128)self.norm = transforms.Compose([# RGB图片数据范围是[0-255],需要先经过ToTensor除以255归一化到[0,1]之后,# 再通过Normalize计算(x - mean)/std后,将数据归一化到[-1,1]。transforms.ToTensor(),# mean=[0.485, 0.456, 0.406] and std=[0.229, 0.224, 0.225]是从imagenet训练集中算出来的transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),])def _preprocess(self, im_crops):"""TODO:1. to float with scale from 0 to 12. resize to (64, 128) as Market1501 dataset did3. concatenate to a numpy array3. to torch Tensor4. normalize"""def _resize(im, size):return cv2.resize(im.astype(np.float32)/255., size)im_batch = torch.cat([self.norm(_resize(im, self.size)).unsqueeze(0) for im in im_crops], dim=0).float()return im_batch# __call__()是一个非常特殊的实例方法。该方法的功能类似于在类中重载 () 运算符,
# 使得类实例对象可以像调用普通函数那样,以“对象名()”的形式使用,而无需再定义一个forward函数或者其他名称的执行函数。def __call__(self, im_crops):im_batch = self._preprocess(im_crops)with torch.no_grad():im_batch = im_batch.to(self.device)features = self.net(im_batch)return features.cpu().numpy()if __name__ == '__main__':# 默认图像维度按照(h,w,c)排列,将图像的通道转成rgb通道, 第一个:表示选择所有行,第二个:表示选择所有列,(2,1,0)表示让通道交换顺序img = cv2.imread("0002_c1s1_000551_01.jpg")[:,:,(2,1,0)]extr = Extractor("checkpoint/reid.pt")feature = extr([img])print(feature.shape)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/475989.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Spring Cloud Stream实现数据流处理

1.什么是Spring Cloud Stream? 我看很多回答都是“为了屏蔽消息队列的差异,使我们在使用消息队列的时候能够用统一的一套API,无需关心具体的消息队列实现”。 这样理解是有些不全面的,Spring Cloud Stream的核心是Stream&#xf…

OpenMMlab导出Mask R-CNN模型并用onnxruntime和tensorrt推理

onnxruntime推理 使用mmdeploy导出onnx模型: from mmdeploy.apis import torch2onnx from mmdeploy.backend.sdk.export_info import export2SDKimg demo.JPEG work_dir ./work_dir/onnx/mask_rcnn save_file ./end2end.onnx deploy_cfg mmdeploy/configs/mmd…

【大语言模型】ACL2024论文-19 SportsMetrics: 融合文本和数值数据以理解大型语言模型中的信息融合

【大语言模型】ACL2024论文-19 SportsMetrics: 融合文本和数值数据以理解大型语言模型中的信息融合 https://arxiv.org/pdf/2402.10979 目录 文章目录 【大语言模型】ACL2024论文-19 SportsMetrics: 融合文本和数值数据以理解大型语言模型中的信息融合目录摘要研究背景问题与挑…

39页PDF | 毕马威_数据资产运营白皮书(限免下载)

一、前言 《毕马威数据资产运营白皮书》探讨了数据作为新型生产要素在企业数智化转型中的重要性,提出了数据资产运营的“三要素”(组织与意识、流程与规范、平台与工具)和“四重奏”(数据资产盘点、评估、治理、共享)…

【UE5】使用基元数据对材质传参,从而避免新建材质实例

在项目中,经常会遇到这样的需求:多个模型(例如 100 个)使用相同的材质,但每个模型需要不同的参数设置,比如不同的颜色或随机种子等。 在这种情况下,创建 100 个实例材质不是最佳选择。正确的做…

[STBC]

空时分组编码STBC(Space Time Block Coding): //一个数据流通过多个天线发射发送,硬件编码器 STBC概念是从MIMO技术衍生出来的,目的是在多天线系统中提高数据传输的可靠性和传输距离。在rx(接收天线)和tx&…

241120学习日志——[CSDIY] [InternStudio] 大模型训练营 [09]

CSDIY:这是一个非科班学生的努力之路,从今天开始这个系列会长期更新,(最好做到日更),我会慢慢把自己目前对CS的努力逐一上传,帮助那些和我一样有着梦想的玩家取得胜利!!&…

PCB 间接雷击模拟

雷击是一种危险的静电放电事件,其中两个带电区域会瞬间释放高达 1 千兆焦耳的能量。雷击就像一个短暂而巨大的电流脉冲,会对建筑物和电子设备造成严重损坏。雷击可分为直接和间接两类,其中间接影响是由于感应能量耦合到靠近雷击位置的物体。间…

IDEA2019搭建Springboot项目基于java1.8 解决Spring Initializr无法创建jdk1.8项目 注释乱码

后端界面搭建 将 https://start.spring.io/ 替换https://start.aliyun.com/ 报错 打开设置 修改如下在这里插入代码片 按此方法无果 翻阅治疗后得知 IDEA2019无法按照网上教程修改此问题因此更新最新idea2024或利用插件Alibaba Clouod Toolkit 换用IDEA2024创建项目 下一步…

单向C to DP视频传输解决方案 | LDR6500

LDR6500D如何通过Type-C接口实现手机到DP接口的单向视频传输 在当今数字化浪潮中,投屏技术作为连接设备、共享视觉内容的桥梁,其重要性日益凸显。PD(Power Delivery)芯片,特别是集成了Type-C接口与DisplayPort&#xf…

Leetcode 第 143 场双周赛题解

Leetcode 第 143 场双周赛题解 Leetcode 第 143 场双周赛题解题目1:3345. 最小可整除数位乘积 I思路代码复杂度分析 题目2:3346. 执行操作后元素的最高频率 I思路代码复杂度分析 题目3:3347. 执行操作后元素的最高频率 II题目4:33…

Spark 之 Aggregate

Aggregate 参考链接: https://github.com/PZXWHU/SparkSQL-Kernel-Profiling 完整的聚合查询的关键字包括 group by、 cube、 grouping sets 和 rollup 4 种 。 分组语句 group by 后面可以是一个或多个分组表达式( groupingExpressions )…

【IDEA】解决总是自动导入全部类(.*)问题

文章目录 问题描述解决方法 我是一名立志把细节说清楚的博主,欢迎【关注】🎉 ~ 原创不易, 如果有帮助 ,记得【点赞】【收藏】 哦~ ❥(^_-)~ 如有错误、疑惑,欢迎【评论】指正探讨,我会尽可能第一时间回复…

如何快速将Excel数据导入到SQL Server数据库

工作中,我们经常需要将Excel数据导入到数据库,但是对于数据库小白来说,这可能并非易事;对于数据库专家来说,这又可能非常繁琐。 这篇文章将介绍如何帮助您快速的将Excel数据导入到sql server数据库。 准备工作 这里&…

在centos7中安装SqlDeveloper的Oracle可视化工具

1.下载安装包 (1)在SqlDeveloper官网下载(Oracle SQL Developer Release 19.2 - Get Started)对应版本的安装包即可(安装包和安装命令如下): (2)执行完上述命令后&#x…

【动手学深度学习Pytorch】4. 神经网络基础

模型构造 回顾一下感知机。 nn.Sequential():定义了一种特殊的module。 torch.rand():用于生成具有均匀分布的随机数,这些随机数的范围在[0, 1)之间。它接受一个形状参数(shape),返回一个指定形状的张量&am…

Spring Boot + Vue 基于 RSA 的用户身份认证加密机制实现

Spring Boot Vue 基于 RSA 的用户身份认证加密机制实现 什么是RSA?安全需求介绍前后端交互流程前端使用 RSA 加密密码安装 jsencrypt库实现敏感信息加密 服务器端生成RSA的公私钥文件Windows环境 生成rsa的公私钥文件Linux环境 生成rsa的公私钥文件 后端代码实现返…

一键部署 200+ 开源软件的 Websoft9 面板,Github 2k+ 星星

Websoft9面板是一款基于Web的PaaS/Linux面板,可用于在自己的服务器上一键部署200多种热门开源应用,在Github上获得了2k星星。 特点与优势 丰富的开源软件集成:涵盖数据库、Web服务器、企业建站、电商系统、教育系统、中间件、大数据工具等多…

NLP论文速读(MPO)|通过混合偏好优化提高多模态大型语言模型的推理能力

论文速读|Dynamic Rewarding with Prompt Optimization Enables Tuning-free Self-Alignment of Language Models 论文信息: 简介: 本文探讨的背景是多模态大型语言模型(MLLMs)在多模态推理能力上的局限性,尤其是在链式…

动态规划子数组系列一>等差数列划分

题目&#xff1a; 解析&#xff1a; 代码&#xff1a; public int numberOfArithmeticSlices(int[] nums) {int n nums.length;int[] dp new int[n];int ret 0;for(int i 2; i < n; i){dp[i] nums[i] - nums[i-1] nums[i-1] - nums[i-2] ? dp[i-1]1 : 0;ret dp[i…