如何使用C#实现Padim算法的训练和推理

目录

说明

项目背景

算法实现

预处理模块——图像预处理

主要模块——训练:Resnet层信息提取

主要模块——信息处理,计算Anomaly Map

主要模块——评估

主要模块——评估:门限值的确定

主要模块——推理

写在最后

项目下载链接


说明

作者:来瓶霸王防脱发

项目地址:

https://github.com/IntptrMax/PadimSharp

原文地址:

https://blog.csdn.net/qq_30270773/article/details/143029865

项目背景

缺陷检测(Anomaly Detection)算法是一个区分正常类别与异常类别的二分类问题,但在工业场景中大多数数据都为良品,不良数据难以获取,更难枚举,所以训练一个全监督的模型是不切实际的。因此,异常检测模型通常以单类别学习的模式。Padim算法是一种十分优秀的缺陷检测算法,直接上图可以看一下这个算法的效果。

良品图片

图片

不良品图片

图片

检测效果

图片

C#是一种十分受欢迎的编程语言,这种编程语言在工业场景下使用也是十分广泛的。在一些AI领域,会在Python下将模型转化为onnx形式,通过onnxruntime加载使用,进行推理。但是在onnx形式下进行训练十分困难。很多C#开发者不太熟悉Python环境,或者某些条件下希望在纯粹的C#环境下进行深度学习的训练和使用。这个还是有一定的困难的。

目前搜索了Github和CSDN排名靠前的几十条数据,还没有Padim算法在除Python平台下的训练+推理的相关项目或资料。本文就是在C#平台实现了Padim的训练+推理过程,应该在相关领域也算是独一份了。

算法实现

Padim算法的“训练”过程其实并没有涉及到真正的训练,而是使用Resnet18算法提取关键信息加以处理,在推理时再次使用,因此“训练”过程速度非常快,这也是这个算法的优点之一。Padim算法的具体实现还请参考相关论文:PaDiM: a Patch Distribution Modeling Framework for Anomaly Detection and Localization

https://arxiv.org/abs/2011.08785

如果论文看起来困难,还有一些大佬对该算法在Python平台下的解读,也可以参考:PaDiM 原理与代码解析

https://blog.csdn.net/ooooocj/article/details/127601035

预处理模块——图像预处理

图像预处理使用的方法比较常规,使用了缩放等方式,此处并没有使用LetterBox,也可以达到预期效果:

var transformers = torchvision.transforms.Compose([torchvision.transforms.Resize(resizeHeight,resizeWidth),
torchvision.transforms.CenterCrop(cropHeight,cropWidth),
torchvision.transforms.Normalize(means, stdevs)]);

主要模块——训练:Resnet层信息提取

使用Resnet模型进行推理,并提取Layer1、Layer2、Layer3层的信息,并进行了拼接(EmbeddingConcat)。注意:这里提取时使用了钩子,钩子在使用时会有资源释放,因此这里使用了比较迂回的方式记录结果

实现代码如下:

public List<(string, Tensor)> Forward(Tensor input)
{List<(string, Tensor)> outputs = new List<(string, Tensor)>();List<TempTensor> tempTensors = new List<TempTensor>();foreach (var named_module in model.named_children()){string name = named_module.name;if (name == "layer1" || name == "layer2" || name == "layer3"){((Sequential)named_module.module).register_forward_hook((Module, input, output) =>{tempTensors.Add(new TempTensor{Data = output.data<float>().ToArray(),Name = name,Shape = output.shape,});return null;});}}model.forward(input);var layer1output = tempTensors.Find(a => a.Name == "layer1");var layer2output = tempTensors.Find(a => a.Name == "layer2");var layer3output = tempTensors.Find(a => a.Name == "layer3");Tensor l1 = torch.tensor(layer1output.Data, layer1output.Shape, device: input.device);Tensor l2 = torch.tensor(layer2output.Data, layer2output.Shape, device: input.device);Tensor l3 = torch.tensor(layer3output.Data, layer3output.Shape, device: input.device);outputs.Add(new("layer1", l1));outputs.Add(new("layer2", l2));outputs.Add(new("layer3", l3));GC.Collect();return outputs;
}
private Tensor EmbeddingConcat(Tensor[] features)
{var embeddings = features[0];for (int i = 1; i < features.Length; i++){var layerEmbedding = features[i];layerEmbedding = torch.nn.functional.interpolate(layerEmbedding, size: [embeddings.shape[2], embeddings.shape[2]], mode: InterpolationMode.Nearest);embeddings = torch.cat([embeddings, layerEmbedding], 1);}return embeddings;
}

主要模块——信息处理,计算Anomaly Map

这一块主要对信息进行处理,获取矩阵的mean和cov(协方差矩阵),代码如下:

public Tensor ComputeAnomalyMapInternal(Tensor embedding, Tensor mean, Tensor covariance)
{var scoreMap = ComputeDistance(embedding, mean, covariance);var upSampledScoreMap = UpSample(scoreMap);var smoothedAnomalyMap = SmoothAnomalyMap(upSampledScoreMap);return smoothedAnomalyMap;
}public Tensor ComputeAnomalyMap(List<(string, Tensor)> outputs, Tensor mean, Tensor covariance, Tensor idx)
{Tensor embedding = GetEmbedding(outputs);var embeddingVectors = torch.index_select(embedding, 1, idx);return ComputeAnomalyMapInternal(embeddingVectors, mean, covariance);
}

主要模块——评估

与训练过程开始部分相似,也是获取图像的Embeddings,然后利用之前获取的Cov和mean计算马氏距离,以此评估图像的异常情况。马氏距离的计算方法如下:

private Tensor ComputeDistance(Tensor embedding, Tensor mean, Tensor covariance)
{long batch = embedding.shape[0];long channel = embedding.shape[1];long height = embedding.shape[2];long width = embedding.shape[3];Tensor inv_covariance = covariance.permute(2, 0, 1).inverse();var embedding_reshaped = embedding.reshape(batch, channel, height * width);var delta = (embedding_reshaped - mean).permute(2, 0, 1);var distances = (torch.matmul(delta, inv_covariance) * delta).sum(2).permute(1, 0);distances = distances.reshape(batch, 1, height, width);distances = distances.clamp(0).sqrt();return distances;
}

主要模块——评估:门限值的确定

这里需要确定图像的评估门限和像素值的评估门限。如果在评估时有负向样本,这个值会更准确,如果只有正向样本也是可以的。在Python下有个precision_recall_curve包,可以计算相关参数,但是在C#下时没有的,因此在此处仍旧只能造轮子,代码如下:

private (float[] precisions, float[] recalls, float[] thresholds) _precision_recall_curve_compute_single_class(Tensor yTrue, Tensor yScores, int pos_label = 1)
{var (fps, tps, thresholds) = BinaryClfCurve(yScores, yTrue, pos_label);var precision = tps / (tps + fps);var recall = tps / tps[-1];var lastInd = torch.where(tps == tps[-1])[0][0].ToInt32();int[] sl = new int[lastInd + 1];for (int i = 0; i < sl.Length; i++){sl[i] = i;}var reversedPrecision = precision[sl].flip(0);var reversedRecall = recall[sl].flip(0);var reversedThresholds = thresholds[sl].flip(0);precision = torch.cat(new Tensor[] { reversedPrecision, torch.ones(1, dtype: precision.dtype, device: precision.device) });recall = torch.cat(new Tensor[] { reversedRecall, torch.zeros(1, dtype: recall.dtype, device: recall.device) });return (precision.data<float>().ToArray(), recall.data<float>().ToArray(), reversedThresholds.data<float>().ToArray());
}private (Tensor fps, Tensor tps, Tensor thresholds) BinaryClfCurve(Tensor preds, Tensor target, int posLabel = 1)
{using (torch.no_grad()){if (preds.ndim > target.ndim){preds = preds[TensorIndex.Ellipsis, 0];}var descScoreIndices = torch.argsort(preds, descending: true);preds = preds[descScoreIndices];target = target[descScoreIndices];Tensor weight = torch.tensor(1.0f);var distinctValueIndices = torch.nonzero(preds[1..] - preds[..^1]).squeeze();var thresholdIdxs = torch.cat(new Tensor[] { distinctValueIndices, torch.tensor(new long[] { target.shape[0] - 1 }, device: preds.device) });target = (target == posLabel).to_type(ScalarType.Int64);var tps = torch.cumsum(target * weight, dim: 0)[thresholdIdxs];Tensor fps = 1 + thresholdIdxs - tps;return (fps, tps, preds[thresholdIdxs]);}
}

主要模块——推理

这个过程与上面过程也十分相似,正向计算出图像的Anomaly Map后,取出这个张量中最大的值,与图像的门限值进行比较,即可评估图像是否是良品。然后对这个张量中每个元素与像素门限值做对比,即可得到按像素的异常区域,以便绘制Mask和热力图。

Tensor orgImg = tensors["orgImage"].clone().to(device);
Tensor t = anomaly_map > pixel_threshold;
anomaly_map = (anomaly_map * t).squeeze(0);
anomaly_map = torchvision.transforms.functional.resize(anomaly_map, (int)orgImg.size(2), (int)orgImg.size(1));
Tensor heatmapNormalized = (anomaly_map - anomaly_map.min()) / (anomaly_map.max() - anomaly_map.min());
Tensor coloredHeatmap = torch.zeros([3, (int)orgImg.size(2), (int)orgImg.size(1)],device:anomaly_map.device);coloredHeatmap[0] = heatmapNormalized.squeeze(0);float alpha = 0.3f;
Tensor blendedImage = (1 - alpha) * (orgImg / 255.0f) + alpha * coloredHeatmap;
var imageTensor = blendedImage.clamp(0, 1).mul(255).to(ScalarType.Byte);torchvision.io.write_jpeg(imageTensor.cpu(), "result.jpg");

写在最后

使用C#开发深度学习项目,尤其是训练的项目,是一个十分困难的过程。或者说除了Python平台,训练都十分困难。C#进行深度学习训练这个方向在国内基本很少有人开展,所以能查得到的资料很少。本人十分喜爱C#这门语言,又十分喜爱深度学习,因此仅半年一直在这方面努力。遇到了很多困难,也收获了很多。

这条路走的不容易,希望能有更多人能加入进来,一起开发,一起学习。

我在Github上已经将完整的代码发布了,项目地址为:

https://github.com/IntptrMax/PadimSharp

,期待你能在Github上送我一颗小星星。在我的Github里还GGMLSharp这个项目,这个项目也是C#平台下深度学习的开发包,希望能得到你的支持。

项目下载链接

https://download.csdn.net/download/qq_30270773/89897710

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/450683.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【即见未来,为何不拜】聊聊分布式系统中的故障监测机制——Phi Accrual failure detector

前言 昨天在看tcp拥塞控制中的BBR(Bottleneck Bandwidth and Round-trip propagation time)算法时&#xff0c;发现了这一特点&#xff1a; 在BBR以前的拥塞控制算法中(如Reno、Cubic、Vegas)&#xff0c;都依赖于丢包事件的发生&#xff0c;在高并发时则会看到网络波动的现象…

【含开题报告+文档+PPT+源码】基于SSM的景行天下旅游网站的设计与实现

开题报告 随着互联网的快速发展&#xff0c;旅游业也逐渐进入了数字化时代。作为一个旅游目的地&#xff0c;云浮市意识到了互联网在促进旅游业发展方面的巨大潜力。为了更好地推广云浮的旅游资源&#xff0c;提高旅游服务质量&#xff0c;云浮市决定开发一个专门的旅游网站。…

深入理解计算机系统--计算机系统漫游

对于一段最基础代码的文件hello.c&#xff0c;解释程序的运行 #include <stdio.h>int main() {printf ( "Hello, world\n") ;return 0; }1.1、信息就是位上下文 源程序是由值 0 和 1 组成的位&#xff08;比特&#xff09;序列&#xff0c;8 个位被组织成一组…

梯度下降算法优化—随机梯度下降、小批次、动量、Adagrad等方法pytorch实现

现有不足 现有调整网络的方法是借助成本函数的梯度下降方法&#xff0c;也就是给函数作切线&#xff0c;不断逼近最优点&#xff0c;即成本函数为零的点。 梯度下降的一般公式为&#xff1a; 即根据每个节点成本函数的梯度进行更新&#xff0c;使用该方法有一些问题&#xff…

探索OpenCV的人脸检测:用Haar特征分类器识别图片中的人脸

目录 简介 OpenCV和Haar特征分类器 实现人脸检测 1. 导入所需库 2. 加载图片和Haar特征分类器 3. 检测人脸 4. 标注人脸 5. 显示 6、结果展示 结论 简介 在计算机视觉和图像处理领域&#xff0c;人脸识别是一项重要的技术。它不仅应用于安全监控、人机交互&#xff0…

10秒钟用Midjourney画出国风味的变形金刚

上魔咒 Optimus Prime comes from the movie Transformers, Chinese style, Wu ShanMing, Ink Painting Halo Dyeing, Conceptual of the Digita Art, MasterComposition, Romantic Ancient Style, Inspired by traditional patterns and symbols, Minimalism, do not con…

day01 -- MybatisPlus

1. MybatisPlus简介 有基础的同学可结合资源中的代码一起看 MyBatis 的增强工具&#xff0c;在 MyBatis 的基础上只做增强不做改变&#xff0c;为简化开发、提高效率而生 特性 通用的 CRUD 操作&#xff1a;内置通用 Mapper、通用 Service&#xff0c;仅仅通过少量配置即可实…

私有化部署大模型最佳解决方案 Ollama (8B)模型

私有化部署大模型Ollama 为什么需要私有化部署大模型一、Ollama本地部署Llama3大模型二、Langchain4j调用Ollama本地部署模型API三、Ollama本地部署nomic向量模型四、Spring AI调用Ollama本地部署模型API 为什么需要私有化部署大模型 企业考虑成本和数据隐私问题&#xff0c;会…

021_Thermal_Transient_in_Matlab统一偏微分框架之热传导问题

Matlab求解有限元专题系列 固体热传导方程 固体热传导的方程为&#xff1a; ρ C p ( ∂ T ∂ t u t r a n s ⋅ ∇ T ) ∇ ⋅ ( q q r ) − α T d S d t Q \rho C_p \left( \frac{\partial T}{\partial t} \mathbf{u}_{\mathtt{trans}} \cdot \nabla T \right) \nab…

BM算法(手算版)

BM 算法 BM 算法是一种字符串匹配的算法。 与 KMP 相比&#xff0c;BM 算法不扫描全部输入字符&#xff0c;平均匹配时间 c・n, 常量 c <1 (随机或真实文本), 但最坏情况是 O (n・m). 可以将 BM 算法的最坏情况改进到 O (n)&#xff1a;通过记录文本后缀中最…

计算机系统简介

一、计算机的软硬件概念 1.硬件&#xff1a;计算机的实体&#xff0c;如主机、外设、硬盘、显卡等。 2.软件&#xff1a;由具有各类特殊功能的信息&#xff08;程序&#xff09;组成。 系统软件&#xff1a;用来管理整个计算机系统&#xff0c;如语言处理程序、操作系统、服…

群晖前面加了雷池社区版,安装失败,然后无法识别出用户真实访问IP

有nas的相信对公网都不模式&#xff0c;在现在基础上传带宽能有100兆的时代&#xff0c;有公网代表着家里有一个小服务器&#xff0c;像百度网盘&#xff0c;优酷这种在线服务都能部署为私有化服务。但现在运营商几乎不可能提供公网ip&#xff0c;要么自己买个云服务器做内网穿…

通过github创建自己网页链接的方法

文章目录 要使用GitHub创建静态网页链接&#xff0c;可以按照以下详细步骤进行操作&#xff1a;一、准备阶段二、创建仓库并配置三、准备并上传静态网站文件四、配置GitHub Pages五、访问和更新你的静态网页 要使用GitHub创建静态网页链接&#xff0c;可以按照以下详细步骤进行…

uniapp微信小程序调用百度OCR

uniapp编写微信小程序调用百度OCR 公司有一个识别行驶证需求&#xff0c;调用百度ocr识别 使用了image-tools这个插件&#xff0c;因为百度ocr接口用图片的base64 这里只是简单演示&#xff0c;accesstoken获取接口还是要放在服务器端&#xff0c;不然就暴露了自己的百度项目k…

Xshell使用密钥远程登录Ubuntu 22.04报错:所选的用户密钥未在远程主机上注册。请再试一次

报错截图如下&#xff1a; 问题原因&#xff1a; Ubuntu 22.04 不支持 Xshell使用的私钥。 查看系统支持的私钥&#xff1a;sudo sshd -T | egrep "pubkey" ~$ sudo sshd -T | egrep "pubkey" pubkeyauthentication yes pubkeyacceptedalgorithms ssh-ed…

基于SpringBoot+Vue的旅游服务平台【提供源码+答辩PPT+参考文档+项目部署】

&#x1f4a5; ① 前言&#xff1a;这两年毕业设计和毕业答辩的要求和难度不断提升&#xff0c;传统的JavaWeb项目缺少创新和亮点&#xff0c;往往达不到毕业答辩的要求&#xff01; ❗② 如何解决这类问题&#xff1f; 让我们能够顺利通过毕业&#xff0c;我也一直在不断思考、…

ROS 的 urdf 中 link 和 joint 的子标签中 origin 的含义

主要参考文章——主要文章&#xff0c;官方关于urdf的介绍和官方文档的翻译解析 link标签里面的origin含义 link标签里面有三个主要的子标签&#xff0c;分别是visual——连杆的外观和坐标系&#xff0c;collisoin——连杆的碰撞属性和inertial——连杆的惯性设置 首先&…

【AIGC】AI如何匹配RAG知识库: Embedding实践,语义搜索

引言 RAG作为减少模型幻觉和让模型分析、回答私域相关知识最简单高效的方式&#xff0c;我们除了使用之外可以尝试了解其是如何实现的。在实现RAG的过程中Embedding是非常重要的手段。本文将带你简单地了解AI工具都是如何通过Embedding去完成语义分析匹配的。 Embedding技术简…

低空经济发展迅猛,无人机设计制造技术详解

低空经济的迅猛发展&#xff0c;为无人机设计制造技术带来了新的机遇和挑战。无人机作为低空经济中的重要组成部分&#xff0c;其设计制造技术直接关系到无人机的性能、安全性和应用场景的拓展。以下是对无人机设计制造技术的详细解析&#xff1a; 一、无人机设计技术 1. 气动…

【HTML + CSS 魔法秀】打造惊艳 3D 旋转卡片

HTML结构 box 类是整个组件的容器。item-wrap 类是每个旋转卡片的包装器&#xff0c;每个都有一个内联样式–i&#xff0c;用于控制动画的延迟。item类是实际的卡片内容&#xff0c;包含一个图片。 <template><div class"box"><div class"item…