【计算机视觉】siamfc论文复现实现目标追踪

什么是目标跟踪

使用视频序列第一帧的图像(包括bounding box的位置),来找出目标出现在后序帧位置的一种方法。

什么是孪生网络结构

孪生网络结构其思想是将一个训练样本(已知类别)和一个测试样本(未知类别)输入到两个CNN(这两个CNN往往是权值共享的)中,从而获得两个特征向量,然后通过计算这两个特征向量的的相似度,相似度越高表明其越可能是同一个类别。

在这里插入图片描述

给你一张我的正脸照(没有经过美颜处理的),你该如何在人群中找到我呢?一种最直观的方案就是:“谁长得最像就是谁”。但是对于计算机来说,如何衡量“长得像”,并不是个简单的问题。这就涉及一种基本的运算——互相关(cross-correlation)。互相关运算可以用来度量两个信号之间的相似性。互相关得到的响应图中每个像素的响应高低代表着每个位置相似度的高低。

在这里插入图片描述

在目标领域中,最早利用这种思想的是SiamFC,其网络结构如上图。图中的φ就是CNN编码器,上下两个分支使用的CNN不仅结构相同,参数也是完全共享的(说白了就是同一个网络,并不存在孪生兄弟那样的设定)。z和x分别是要跟踪的目标模版图像(尺寸为127x127)和新的一帧中的搜索范围(尺寸为255x255)。二者经过同样的编码器后得到各自的特征图,对二者进行互相关运算后则会同样得到一个响应图(尺寸为17x17),其每一个像素的值对应了x中与z等大的一个对应区域出现跟踪目标的概率。

互相关运算的步骤,像极了我们手里拿着一张目标的照片(模板图像),然后把这个照片按在需要寻找目标的图片上(搜索图像)进行移动,然后求重叠部分相似度,从而找到这个目标,只不过为了计算机计算的方便,使用AlexNet对图像数据进行了编码/特征提取

下面这个版本中有一些动图,还是会帮助理解的:https://github.com/rafellerc/Pytorch-SiamFC

SiamFC代码分析

我们对siamese的结构大致就讲完了,还有一些内容结合代码来讲,效果更好。

3.1 training

3.1.1图像预处理

小超up给出训练的框图如下。训练过程中,首先要获取训练数据集的所有视频序列(每个视频序列的所有帧),我采用的是GOT-10k数据集训练;获取数据集之后进行图像预处理,对每一个视频序列抽取两帧图像并作数据增强处理(包括裁剪、resize等过程),分别作为目标模板图像和搜索图像;把经过图像处理的所有图像对加载并以batch_size输入网络得到预测输出;建立标签和损失函数,损失函数的输入是预测输出,目标是标签;设置优化策略,梯度下降损失,最终得到网络模型。

在这里插入图片描述

先贴代码,再分析:

def train(data_dir, net_path=None,save_dir='pretrained'):#从文件中读取图像数据集seq_dataset = GOT10k(data_dir,subset='train',return_meta=False)#定义图像预处理方法transforms = SiamFCTransforms(  exemplar_sz=cfg.exemplar_sz, #127instance_sz=cfg.instance_sz, #255context=cfg.context) #0.5#从读取的数据集每个视频序列配对训练图像并进行预处理,裁剪等train_dataset = GOT10kDataset(seq_dataset,transforms)

data_dir是存放GOT-10k数据集的文件路径,GOT-10k一共有9335个训练视频序列,seq_dataset返回的是所有视频序列的图片路径列表seq_dirs及对应groundtruth列表anno_files及一些其他信息,如下:

img

接下来是定义好图像预处理方法,在GOT10kDataset方法中对每个视频序列配对两帧图像,并使用定义好的图像处理方法,接下来直接进入该方法分析代码,GOT10kDataset的代码如下:

class GOT10kDataset(Dataset): #继承了torch.utils.data的Dataset类def __init__(self, seqs, transforms=None,pairs_per_seq=1):def __getitem__(self, index): #通过_sample_pair方法得到索引返回item=(z,x,box_z,box_x),然后经过transforms处理def __len__(self): #返回9335*pairs_per_seq对def _sample_pair(self, indices): #随机挑选两个索引,这里取的间隔不超过T=100def _filter(self, img0, anno, vis_ratios=None): #通过该函数筛选符合条件的有效索引val_indices

这里最重要的方法就是__getitem__,该方法最终返回处理后的图像,在内部首先调用了_sample_pair方法,用于提取两帧有效图片(有效的定义是图片目标的面积和高宽等有约束条件)的索引,在得到这两帧图片和对应groundtruth之后通过定义好的transforms进行处理,transforms是SiamFCTransforms类的实例化对象,该类中主要继承了resize图片大小和各种裁剪方式等,如代码所示:

class SiamFCTransforms(object):def __init__(self, exemplar_sz=127, instance_sz=255, context=0.5):self.exemplar_sz = exemplar_szself.instance_sz = instance_szself.context = context#transforms_z/x是数据增强方法self.transforms_z = Compose([RandomStretch(),     #随机resize图片大小,变化再[1 1.05]之内CenterCrop(instance_sz - 8),  #中心裁剪 裁剪为255-8RandomCrop(instance_sz - 2 * 8),   #随机裁剪  255-8->255-8-8CenterCrop(exemplar_sz),   #中心裁剪 255-8-8->127ToTensor()])                        #图片的数据格式从numpy转换成torch张量形式self.transforms_x = Compose([RandomStretch(),                   #s随机resize图片CenterCrop(instance_sz - 8),      #中心裁剪 裁剪为255-8RandomCrop(instance_sz - 2 * 8),  #随机裁剪 255-8->255-8-8ToTensor()])                      #图片数据格式转化为torch张量def __call__(self, z, x, box_z, box_x): #z,x表示传进来的图像z = self._crop(z, box_z, self.instance_sz)       #对z(x类似)图像 1、box转换(l,t,w,h)->(y,x,h,w),并且数据格式转为float32,得到center[y,x],和target_sz[h,w]x = self._crop(x, box_x, self.instance_sz)       #2、得到size=((h+(h+w)/2)*(w+(h+2)/2))^0.5*255(instance_sz)/127z = self.transforms_z(z)                         #3、进入crop_and_resize:传入z作为图片img,center,size,outsize=255(instance_sz),随机选方式填充,均值填充x = self.transforms_x(x)                         #   以center为中心裁剪一块边长为size大小的正方形框(注意裁剪时的padd边框填充问题),再resize成out_size=255(instance_sz)return z, x

实例化对象后,直接从__call__开始运行代码,首先关注的应该是_crop函数,该函数将原始的两帧图片分别以目标为中心,裁剪一块包含上下文信息的patch,patch的边长定义如下:

在这里插入图片描述

式中,w、h分别表示目标的宽和高。下面具体讲里面的_crop函数:

    def _crop(self, img, box, out_size):# convert box to 0-indexed and center based [y, x, h, w]box = np.array([box[1] - 1 + (box[3] - 1) / 2,box[0] - 1 + (box[2] - 1) / 2,box[3], box[2]], dtype=np.float32)center, target_sz = box[:2], box[2:]context = self.context * np.sum(target_sz)size = np.sqrt(np.prod(target_sz + context))size *= out_size / self.exemplar_szavg_color = np.mean(img, axis=(0, 1), dtype=float)interp = np.random.choice([cv2.INTER_LINEAR,cv2.INTER_CUBIC,cv2.INTER_AREA,cv2.INTER_NEAREST,cv2.INTER_LANCZOS4])patch = ops.crop_and_resize(img, center, size, out_size,border_value=avg_color, interp=interp)return patch

因为GOT-10k里面对于目标的bbox是以ltwh(即left, top, weight, height)形式给出的,上述代码一开始就先把输入的box变成center based,坐标形式变为[y, x, h, w],结合下面这幅图就非常好理解在这里插入图片描述img

crop_and_resize:

def crop_and_resize(img, center, size, out_size,border_type=cv2.BORDER_CONSTANT,border_value=(0, 0, 0),interp=cv2.INTER_LINEAR):# convert box to corners (0-indexed)size = round(size)  # the size of square cropcorners = np.concatenate((np.round(center - (size - 1) / 2),np.round(center - (size - 1) / 2) + size))corners = np.round(corners).astype(int)# pad image if necessarypads = np.concatenate((-corners[:2], corners[2:] - img.shape[:2]))npad = max(0, int(pads.max()))if npad > 0:img = cv2.copyMakeBorder(img, npad, npad, npad, npad,border_type, value=border_value)# crop image patchcorners = (corners + npad).astype(int)patch = img[corners[0]:corners[2], corners[1]:corners[3]]# resize to out_sizepatch = cv2.resize(patch, (out_size, out_size),interpolation=interp)return patch

在裁剪过程中会出现越界的情况,需要对原始图像边缘填充,填充值固定为图像的RGB均值,填充大小根据图像边缘越界最大值作为填充值,具体实现过程由以下代码完成。

# padding操作#corners表示目标的[ymin,xmin,ymax,xmax]pads = np.concatenate((-corners[:2], corners[2:] - img.shape[:2]))npad = max(0, int(pads.max())) #得到上下左右4个越界值中最大的与0对比,<0代表无越界if npad > 0:img = cv2.copyMakeBorder(img, npad, npad, npad, npad,cv2.BORDER_CONSTANT, value=img_average)

实验结果:

img

3.1.2加载训练数据、标签及损失函数

图像预处理完成后,得到了用与训练的9335对图像,将图像加载批量加载输入网络得到输出结果作为损失函数的input,损失函数的target是制定好的labels。

#加载训练数据集loader_dataset = DataLoader( dataset = train_dataset,batch_size=cfg.batch_size,shuffle=True,num_workers=cfg.num_workers,pin_memory=True,drop_last=True, )#初始化训练网络cuda = torch.cuda.is_available()  #支持GPU为Truedevice = torch.device('cuda:0' if cuda else 'cpu')  #cuda设备号为0model = AlexNet(init_weight=True)corr = _corr()model = model.to(device)corr = corr.to(device)# 设置损失函数和标签logist_loss = BalancedLoss()labels = _create_labels(size=[cfg.batch_size, 1, cfg.response_sz - 2, cfg.response_sz - 2])labels = torch.from_numpy(labels).to(device).float()

本小节主要讲网络输出的labels和损失函数,接下来只是小超up个人的一些理解,代码与论文理论部分形式不一致,但效果一样。先上图,论文中labels以及损失函数如下图:

img

img

然而代码中的labels值却是1和0,损失函数使用的是二值交叉熵损失函数F.binary_cross_entropy_with_logits,如下图推导所示,解释了为什么代码实现部分真正使用的labels值是1和0,而理论部分使用的是1和-1。

img

利用下面代码的这个_creat_labels方法可以得到标签。

def _create_labels(size):def logistic_labels(x, y, r_pos):# x^2+y^2<4 的位置设为为1,其他为0dist = np.sqrt(x ** 2 + y ** 2)labels = np.where(dist <= r_pos,    #r_os=2np.ones_like(x),  #np.ones_like(x),用1填充xnp.zeros_like(x)) #np.zeros_like(x),用0填充xreturn labels#获取标签的参数n, c, h, w = size  # [8,1,15,15]x = np.arange(w) - (w - 1) / 2  #x=[-7 -6 ....0....6 7]y = np.arange(h) - (h - 1) / 2  #y=[-7 -6 ....0....6 7]x, y = np.meshgrid(x, y)       #建立标签r_pos = cfg.r_pos / cfg.total_stride  # 16/8labels = logistic_labels(x, y, r_pos)#重复batch_size个label,因为网络输出是batch_size张response maplabels = labels.reshape((1, 1, h, w))   #[1,1,15,15]labels = np.tile(labels, (n, c, 1, 1))  #将labels扩展[8,1,15,15]return labels

验证结果如下图,只截取了部分labels,得到的labels对应输入,大小都是[8,1,15,15]

if __name__ == '__main__':labels = _create_labels([8,1,15,15])  #返回的label.shape=(8,1,15,15)

其中关于np.tile、np.meshgrid、np.where函数的使用可以去看这篇博客,最后出来的一个batch下某一个通道下的label就是下面这样的

img

3.1.3 优化策略

这里主要说一下学习率lr,随着训练次数epoch增多而减小,具体值如下公式image-20240721105017162,式中,initial为初始学习率,gamma是定义的超参,epoch为训练次数。整个优化器及学习率调整实现代码如下:

#建立优化器,设置指数变化的学习率optimizer = optim.SGD(model.parameters(),lr=cfg.initial_lr,              #初始化的学习率,后续会不断更新weight_decay=cfg.weight_decay,  #λ=5e-4,正则化momentum=cfg.momentum)          #v(now)=dx∗lr+v(last)∗momemtumgamma = np.power(                   #np.power(a,b) 返回a^bcfg.ultimate_lr / cfg.initial_lr,1.0 / cfg.epoch_num)lr_scheduler = ExponentialLR(optimizer, gamma)  #指数形式衰减,lr=initial_lr*(gamma^epoch)=
3.1.4 模型的训练与保存

一切准备工作就绪后,就开始训练了。代码中设定epoch_num为50次,训练时密切加上model.train(),告诉网络处于训练状态,这样,网络运行时就会利用pytorch的自动求导机制求导;在测试时,改为model.eval(),关闭自动求导。模型训练的步骤如代码所示:

# loop over epochs
for epoch in range(self.cfg.epoch_num):# update lr at each epochself.lr_scheduler.step(epoch=epoch)# loop over dataloaderfor it, batch in enumerate(dataloader):loss = self.train_step(batch, backward=True)print('Epoch: {} [{}/{}] Loss: {:.5f}'.format(epoch + 1, it + 1, len(dataloader), loss))sys.stdout.flush()# save checkpointif not os.path.exists(save_dir):os.makedirs(save_dir)net_path = os.path.join(save_dir, 'siamfc_alexnet_e%d.pth' % (epoch + 1))torch.save(self.net.state_dict(), net_path)

至此此份repo的训练应该差不多结束了

参考文档

siameseFC论文和代码解析

SiamFC 学习(论文、总结与分析)

siamfc-pytorch代码讲解(一):backbone&head

siamfc-pytorch代码讲解(二):train&siamfc

SiamFC代码分析(architecture、training、test)

http://www.360doc.com/content/19/0801/10/32196507_852333196.shtml

视频推荐

目标跟踪零基础代码入门(一):SiamFC_哔哩哔哩_bilibili

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/380699.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

jmeter部署

一、windows环境下部署 1、安装jdk并配置jdk的环境变量 (1) 安装jdk jdk下载完成后双击安装包&#xff1a;无限点击"下一步"直到完成&#xff0c;默认路径即可。 (2) jdk安装完成后配置jdk的环境变量 找到环境变量中的系统变量&#xff1a;此电脑 --> 右键属性 …

Figma 中文版指南:获取和安装汉化插件

Figma是一种主流的在线团队合作设计工具&#xff0c;也是一种基于 Web 端的设计工具。在当今的设计时代&#xff0c;Figma 的使用满足了每个人的设计需求&#xff0c;不仅可以实现在线编辑&#xff0c;还可以方便日常管理&#xff0c;有效提高工作效率。然而&#xff0c;相信很…

RPM、YUM 安装 xtrabackup 8 (mysql 热备系列一)包含rpm安装 mysql 8 配置主从

RPM安装 percona-xtrabackup-80-8.0.35-30.1.el7.x86_64.rpm 官网&#xff1a; https://www.percona.com/ 下载地址&#xff1a; https://www.percona.com/downloads wget https://downloads.percona.com/downloads/percona-distribution-mysql-ps/percona-distribution-mysq…

【Vue3】响应式数据

【Vue3】响应式数据 背景简介开发环境基本数据类型对象数据类型使用 reactive 定义对象类型响应式数据使用 ref 定义对象类型响应式数据 ref 和 reactive 的对比使用原则建议 背景 随着年龄的增长&#xff0c;很多曾经烂熟于心的技术原理已被岁月摩擦得愈发模糊起来&#xff0…

Linux系统之部署扫雷小游戏(三)

Linux系统之部署扫雷小游戏(三) 一、小游戏介绍1.1 小游戏简介1.2 项目预览二、本次实践介绍2.1 本地环境规划2.2 本次实践介绍三、检查本地环境3.1 检查系统版本3.2 检查系统内核版本3.3 检查软件源四、安装Apache24.1 安装Apache2软件4.2 启动apache2服务4.3 查看apache2服…

项目管理进阶之RACI矩阵

前言 项目管理进阶系列续新篇。 RACI&#xff1f;这个是什么矩阵&#xff0c;有什么用途&#xff1f; 在项目管理过程中&#xff0c;如Team规模超5以上时&#xff0c;则有必要采用科学的管理方式&#xff0c;满足工作需要。否则可能事倍功半。 Q&#xff1a;什么是RACI矩阵 …

【HarmonyOS】HarmonyOS NEXT学习日记:五、交互与状态管理

【HarmonyOS】HarmonyOS NEXT学习日记&#xff1a;五、交互与状态管理 在之前我们已经学习了页面布局相关的知识&#xff0c;绘制静态页面已经问题不大。那么今天来学习一下如何让页面动起来、并且结合所学完成一个代码实例。 交互 如果是为移动端开发应用&#xff0c;那么交…

Unity Apple Vision Pro 开发(四):体积相机 Volume Camera

文章目录 &#x1f4d5;教程说明&#x1f4d5;教程内容概括&#x1f4d5;体积相机作用&#x1f4d5;创建体积相机&#x1f4d5;添加体积相机配置文件&#x1f4d5;体积相机配置文件参数&#x1f4d5;体积相机的边界盒大小&#x1f4d5;体积相机边界盒大小和应用边界盒大小的区别…

弹性网络回归(Elastic Net Regression)

弹性网络回归&#xff08;Elastic Net Regression&#xff09;的详细理论知识推导 理论背景 弹性网络回归结合了岭回归&#xff08;Ridge Regression&#xff09;和Lasso回归&#xff08;Lasso Regression&#xff09;的优点&#xff0c;通过引入两个正则化参数来实现特征选择…

kubernetes k8s Deployment 控制器配置管理 k8s 红蓝部署 金丝雀发布

目录 1、Deployment控制器&#xff1a;概念、原理解读 1.1 Deployment概述 1.2 Deployment工作原理&#xff1a;如何管理rs和Pod&#xff1f; 2、Deployment资源清单文件编写技巧 3、Deployment使用案例&#xff1a;创建一个web站点 4、Deployment管理pod&#xff1a;扩…

通义千问AI模型对接飞书机器人-模型配置(2-1)

一 背景 根据业务或者使用场景搭建自定义的智能ai模型机器人&#xff0c;可以较少我们人工回答的沟通成本&#xff0c;而且可以更加便捷的了解业务需求给出大家设定的业务范围的回答&#xff0c;目前基于阿里云的通义千问模型研究。 二 模型研究 参考阿里云帮助文档&#xf…

持续集成04--Jenkins结合Gitee创建项目

前言 在持续集成/持续部署&#xff08;CI/CD&#xff09;的旅途中&#xff0c;Jenkins与版本控制系统的紧密集成是不可或缺的一环。本篇“持续集成03--Jenkins结合Gitee创建项目”将引导如何将Jenkins与Gitee&#xff08;一个流行的Git代码托管平台&#xff09;相结合&#xff…

修改了mybatis的xml中的sql不重启服务器如何动态加载更新

目录 一、背景 二、注意 三、代码 四、使用示例 五、其他参考博客 一、背景 开发一个报表功能&#xff0c;好几百行sql&#xff0c;每次修改完想自测下都要重启服务器&#xff0c;启动一次服务器就要3分钟&#xff0c;重启10次就要半小时&#xff0c;耗不起时间呀。于是在…

【Android】Activity的生命周期

Activity的生命周期 1.返回栈 其实Android是使用任务&#xff08;task&#xff09;来管理Activity的&#xff0c;一个任务就是一组存放在栈里的Activity的集合&#xff0c;这个栈也被称作返回栈&#xff08;back stack&#xff09;。栈是一种后进先出的数据结构&#xff0c;在…

自动驾驶-预测概览

通过生成一条路径来预测一个物体的行为&#xff0c;在每一个时间段内&#xff0c;为每一辆汽车重新计算预测他们新生成的路径&#xff0c;这些预测路径为规划阶段做出决策提供了必要信息 预测路径有实时性的要求&#xff0c;预测模块能够学习新的行为。我们可以使用多源的数据…

【Unity实战100例】Unity声音可视化多种显示效果

目录 一、技术背景 二、界面搭建 三、 实现 UIAudioVisualizer 基类 四、实现 AudioSampler 类 五、实现 IAudioSample 接口 六、实现MusicAudioVisualizer 七、实现 MicrophoneAudioManager 类 八、实现 MicrophoneAudioVisualizer 类 九、源码下载 Unity声音可视化四…

系统架构设计师教程 第3章 信息系统基础知识-3.6 办公自动化系统(OAS)-解读

系统架构设计师教程 第3章 信息系统基础知识-3.6 办公自动化系统&#xff08;OAS&#xff09; 3.6.1 办公自动化系统的概念3.6.1.1 办公活动3.6.1.1 办公自动化的概念 3.6.2 办公自动化系统的功能3.6.2.1 事务处理3.6.2.1.1 单机系统3.6.2.1.2 多机系统 3.6.2.2 信息管理3.6.2.…

YOLOV8/V7/V5的PCB缺陷检测:可视化界面+GUI+目标计数+视频目标检测与跟踪

在本文中&#xff0c;我将介绍如何使用PyQt5创建一个YOLOv8V7/V5目标检测的可视化界面&#xff0c;可以根据需求选择YOLOv8V7/V5的权重。 该可视化界面的功能丰富&#xff0c;包含内容&#xff1a; 1.GUI目标计数视频目标检测与跟踪 2.完整的OLO数据格式制作流程以及代码 3…

逆向案例二十三——请求头参数加密,某区块链交易逆向

网址&#xff1a;aHR0cHM6Ly93d3cub2tsaW5rLmNvbS96aC1oYW5zL2J0Yy90eC1saXN0L3BhZ2UvNAo 抓包分析&#xff0c;发现请求头有X-Apikey参数加密&#xff0c;其他表单和返回内容没有加密。 直接搜索关键字&#xff0c;X-Apikey&#xff0c;找到疑似加密位置&#xff0c;注意这里…

Spring Boot 中使用 Resilience4j 实现弹性微服务的简单了解

1. 引言 在微服务架构中&#xff0c;服务的弹性是非常重要的。Resilience4j 是一个轻量级的容错库&#xff0c;专为函数式编程设计&#xff0c;提供了断路器、重试、舱壁、限流器和限时器等功能。 这里不做过多演示&#xff0c;只是查看一下官方案例并换成maven构建相关展示&…