使用pytorch利用神经网络原理进行图片的训练(持续学习中....)

1.做这件事的目的
语言只是工具,使用python训练图片数据,最终会得到.pth的训练文件,java有使用这个文件进行图片识别的工具,顺便整合,我觉得Neo4J正确率太低了,草莓都能识别成为苹果,而且速度慢,不能持续识别视频帧

2.什么是神经网络?(其实就是数学的排列组合最终得到统计结果的概率)

1.先把二维数组转为一维
2.通过公式得到节点个数和值
3…同2
4.通过节点得到概率(softmax归一化公式)
5.对比模型的和 差值=原始概率-目标结果概率
6.不断优化原来模型的概率
5.激活函数,激活某个节点的函数,可以引入非线性的(因为所有问题不可能是线性的比如 很少图片识别一定可以识别出绝对的正方形,他可能中间有一定弯曲或者线在中心短开了)

在这里插入图片描述
在这里插入图片描述

3.训练的代码
//环境python3.8 最好使用conda进行版本管理,不然每个版本都可能不兼容,到处碰壁

 #安装依赖pip install numpy torch torchvision matplotlib

#文件夹结构,图片一定要是28x28的
在这里插入图片描述

import torch
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import MNIST
import matplotlib.pyplot as plt
from torchvision.datasets.folder import ImageFolderclass Net(torch.nn.Module):def __init__(self):super().__init__()self.fc1 = torch.nn.Linear(28 * 28, 64)self.fc2 = torch.nn.Linear(64, 64)self.fc3 = torch.nn.Linear(64, 64)self.fc4 = torch.nn.Linear(64, 10)def forward(self, x):x = torch.nn.functional.relu(self.fc1(x))x = torch.nn.functional.relu(self.fc2(x))x = torch.nn.functional.relu(self.fc3(x))x = torch.nn.functional.log_softmax(self.fc4(x), dim=1)return x#导入数据
def get_data_loader(is_train):#张量,多维数组to_tensor = transforms.Compose([transforms.ToTensor()])# 下载数据集 下载目录data_set = MNIST("", is_train, transform=to_tensor, download=True)#一个批次15张,顺序打乱return DataLoader(data_set, batch_size=15, shuffle=True)def get_image_loader(folder_path):to_tensor = transforms.Compose([transforms.ToTensor()])data_set = ImageFolder(folder_path, transform=to_tensor)return DataLoader(data_set, batch_size=1)#评估准确率
def evaluate(test_data, net):n_correct = 0n_total = 0with torch.no_grad():#按批次取数据for (x, y) in test_data:#计算神经网络预测值outputs = net.forward(x.view(-1, 28 * 28))for i, output in enumerate(outputs):#比较预测结果和测试集结果if torch.argmax(output) == y[i]:#统计正确预测结果数n_correct += 1#统计全部预测结果n_total += 1#返回准确率=正确/全部的return n_correct / n_totaldef main():#加载训练集train_data = get_data_loader(is_train=True)#加载测试集test_data = get_data_loader(is_train=False)#初始化神经网络net = Net()#打印测试网络的准确率 0.1print("initial accuracy:", evaluate(test_data, net))#训练神经网络optimizer = torch.optim.Adam(net.parameters(), lr=0.001)#重复利用数据集 2次for epoch in range(100):for (x, y) in train_data:#初始化 固定写法net.zero_grad()#正向传播output = net.forward(x.view(-1, 28 * 28))#计算差值loss = torch.nn.functional.nll_loss(output, y)#反向误差传播loss.backward()#优化网络参数optimizer.step()print("epoch", epoch, "accuracy:", evaluate(test_data, net))# #使用3张图片进行预测# for (n, (x, _)) in enumerate(test_data):#     if n > 3:#         break#     predict = torch.argmax(net.forward(x[0].view(-1, 28 * 28)))#     plt.figure(n)#     plt.imshow(x[0].view(28, 28))#     plt.title("prediction: " + str(int(predict)))# plt.show()image_loader = get_image_loader("aa")for (n, (x, _)) in enumerate(image_loader):if n > 2:breakpredict = torch.argmax(net.forward(x.view(-1, 28 * 28)))plt.figure(n)plt.imshow(x[0].permute(1, 2, 0))plt.title("prediction: " + str(int(predict)))plt.show()if __name__ == "__main__":main()

#运行结果 弹框出现图片和识别结果

4.测试电脑的cuda是否安装成功,不成功不能运行下面的代码

import torchdevice = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('Using device:', device)
print('CUDA version:', torch.version.cuda)
print('PyTorch version:', torch.__version__)

5.在gpu上运行,需要去官网下载cuda安装
https://developer.nvidia.com/cuda-toolkit-archive
#并且需要安装和torch对应的版本,我的电脑是1660ti的所以安装了10.2的cuda
#安装torchgpu版本

pip install torch==1.9.0+cu102 -f
https://download.pytorch.org/whl/cu102/torch_stable.html

import torch
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import MNIST
import matplotlib.pyplot as plt
from torchvision.datasets.folder import ImageFolderdevice = torch.device("cuda" if torch.cuda.is_available() else "cpu")class Net(torch.nn.Module):def __init__(self):super().__init__()self.fc1 = torch.nn.Linear(28 * 28, 64)self.fc2 = torch.nn.Linear(64, 64)self.fc3 = torch.nn.Linear(64, 64)self.fc4 = torch.nn.Linear(64, 10)def forward(self, x):x = torch.nn.functional.relu(self.fc1(x))x = torch.nn.functional.relu(self.fc2(x))x = torch.nn.functional.relu(self.fc3(x))x = torch.nn.functional.log_softmax(self.fc4(x), dim=1)return xdef get_data_loader(is_train):to_tensor = transforms.Compose([transforms.ToTensor()])data_set = MNIST("", is_train, transform=to_tensor, download=True)return DataLoader(data_set, batch_size=15, shuffle=True)def get_image_loader(folder_path):to_tensor = transforms.Compose([transforms.ToTensor()])data_set = ImageFolder(folder_path, transform=to_tensor)return DataLoader(data_set, batch_size=1)def evaluate(test_data, net):n_correct = 0n_total = 0with torch.no_grad():for (x, y) in test_data:x, y = x.to(device), y.to(device)outputs = net.forward(x.view(-1, 28 * 28))for i, output in enumerate(outputs):if torch.argmax(output.cpu()) == y[i].cpu():n_correct += 1n_total += 1return n_correct / n_totaldef main():train_data = get_data_loader(is_train=True)test_data = get_data_loader(is_train=False)net = Net().to(device)print("initial accuracy:", evaluate(test_data, net))optimizer = torch.optim.Adam(net.parameters(), lr=0.001)for epoch in range(100):for (x, y) in train_data:x, y = x.to(device), y.to(device)net.zero_grad()output = net.forward(x.view(-1, 28 * 28))loss = torch.nn.functional.nll_loss(output, y)loss.backward()optimizer.step()print("epoch", epoch, "accuracy:", evaluate(test_data, net))image_loader = get_image_loader("aa")for (n, (x, _)) in enumerate(image_loader):if n > 2:breakx = x.to(device)predict = torch.argmax(net.forward(x.view(-1, 28 * 28)).cpu())plt.figure(n)plt.imshow(x[0].permute(1, 2, 0).cpu())plt.title("prediction: " + str(int(predict)))plt.show()if __name__ == "__main__":main()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/200068.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

ElementUI及ElementUI Plus Axure RP高保真交互元件库及模板库

基于ElementUI2.0及ElementUI Plus3.0二次创作的ElementUI 元件库。2个版本的原型图内容会有所不同,ElementUI Plus3.0的交互更加丰富和高级。你可以同时使用这两个版本。 不仅包含Element UI 2.0版,还包含Element Plus 3版本。Element 2版支持Axure 8&…

接口自动化项目落地之HTTPBin网站

原文:https://www.cnblogs.com/df888/p/16011061.html 接口自动化项目落地系列 找个开源网站或开源项目,用tep实现整套pytest接口自动化项目落地,归档到电子书,作为tep完整教程的项目篇一部分。自从tep完整教程发布以后&#…

基于单片机的公共场所马桶设计(论文+源码)

1.系统设计 本课题为公共场所的马桶设计,其整个系统架构如图2.1所示,其采用STC89C52单片机为核心控制器,结合HC-SR04人体检测模块,压力传感器,LCD1602液晶,蜂鸣器,L298驱动电路等构成整个系统&…

RedisInsight——redis的桌面UI工具使用实践

下载 官网下载安装。下载地址在这里 填个邮箱地址就可以下载了。 安装使用。 安装成功后开始使用。 1. 你可以add一个地址。或者登录redis cloud 去auto-discover 2 . 新增你的redis库地址。注意index的取值 3。现在可以登录到redis了。看看结果 这是现在 在服务器上执行…

C#核心笔记——(二)C#语言基础

一、C#程序 1.1 基础程序 using System; //引入命名空间namespace CsharpTest //将以下类定义在CsharpTest命名空间中 {internal class TestProgram //定义TestProgram类{public void Test() { }//定义Test方法} }方法是C#中的诸多种类的函数之一。另一种函数*,还…

BLIP-2:冻结现有视觉模型和大语言模型的预训练模型

Li J, Li D, Savarese S, et al. Blip-2: Bootstrapping language-image pre-training with frozen image encoders and large language models[J]. arXiv preprint arXiv:2301.12597, 2023. BLIP-2,是 BLIP 系列的第二篇,同样出自 Salesforce 公司&…

数据仓库高级面试题

数仓高内聚低耦合是怎么做的 定义 高内聚:强调模块内部的相对独立性,要求模块内部的元素尽可能的完成一个功能,不混杂其他功能,从而使模块保持简洁,易于理解和管理。 低耦合:模块之间的耦合度要尽可能的…

python实现炫酷的屏幕保护程序

shigen日更文章的博客写手,擅长Java、python、vue、shell等编程语言和各种应用程序、脚本的开发。记录成长,分享认知,留住感动。 上次的文章如何实现一个下班倒计时程序的阅读量很高,觉得也很实用酷炫,下边是昨天的体验…

基于STC12C5A60S2系列1T 8051单片机的液晶显示器LCD1602显示两行常规字符应用

基于基于STC12C5A60S2系列1T 8051单片机的液晶显示器LCD1602显示两行常规字符应用 STC12C5A60S2系列1T 8051单片机管脚图STC12C5A60S2系列1T 8051单片机I/O口各种不同工作模式及配置STC12C5A60S2系列1T 8051单片机I/O口各种不同工作模式介绍液晶显示器LCD1602简单介绍通过液晶显…

Java实现象棋算法

象棋算法包括搜索算法、评估函数和剪枝算法。以下是一个简单的实现: 搜索算法:使用极大极小值算法,即每个玩家都会做出最好的选择,考虑到对方也会做出最好的选择,所以需要搜索多层。 public int search(int depth, i…

(二)pytest自动化测试框架之添加测试用例步骤(@allure.step())

前言 在编写自动化测试用例的时候经常会遇到需要编写流程性测试用例的场景,一般流程性的测试用例的测试步骤比较多,我们在测试用例中添加详细的步骤会提高测试用例的可阅读性。 allure提供的装饰器allure.step()是allure测试报告框架非常有用的功能&am…

Linux下安装两个版本python

1 python下载: 官网地址:Download Python | Python.org 第一:点击下载如下图: 第二:找到对应的python版本源码包: 点击右键复制下载地址,如下图 例如我的是:https://www.python.org/…

Java(四)(多态,final,常量,抽象类,接口)

目录 多态 基本概念: 使用多态的好处 类型转换 遇到的问题 解决方法 强制类型转换的一个注意事项 final 常量 抽象类 啥是个抽象类? 抽象类的注意事项,特点 抽象类的场景和好处 抽象类的常见应用场景: 模板方法设计模式 接口 基本概念 接口的好处 JDK8开始,接…

太累了,是时候让AI数字人来帮我干活了(走,上教程)

阿酷TONY,原创文章,长沙,2023.11.21 关 键 词:AI数字人,生成式AI,智能数字分身适用场景:培训数字人,演讲授课数字人,直播带货数字人特别说明:教程用的是国内…

学习Rust适合写什么练手项目?【云驻共创】

Rust是一门备受关注的系统级编程语言,因其出色的内存安全性、高性能和并发性能而备受赞誉。对于那些希望学习和掌握Rust编程语言的人来说,练手项目是一个不可或缺的环节。通过实际动手完成项目,你可以加深对Rust语言特性和最佳实践的理解&…

如何修改百科内容?百度百科内容怎么修改?

百科词条创建上去是相当不易的,同时修改也是如此,一般情况下,百科词条是不需要修改的,但是很多时候企业或是人物在近期收获了更多成就或是有更多的变动,这个时候就需要补充维护词条了,如何修改百科内容&…

猫12分类:使用yolov5训练检测模型

前言: 在使用yolov5之前,尝试过到百度飞桨平台(小白不建议)、AutoDL平台(这个比较友好,经济实惠)训练模型。但还是没有本地训练模型来的舒服。因此远程了一台学校电脑来搭建自己的检测模型。配置…

C++引用

目录 一.概念 二. 引用特性 三. 常引用 四. 使用场景 1,做参数 2,做返回值 五. 传值,传引用效率比较 5.1 值和引用的作为参数的性能比较 5.2 值和引用的作为返回值类型的性能比较 六. 引用和指针的区别 一.概念 引用不是新定义一…

16位 (MCU) R7F101G6G3CSP、R7F101G6E3CSP、R7F101G6G2DSP、R7F101G6E2DSP是新一代RL78通用微控制器

产品描述 RL78/G24微控制器具有RL78系列MCU的最高处理性能,CPU工作频率高达48MHz,设有灵活的应用加速器 (FAA)。FAA是一款专门用于算法运算的协处理器,可以独立于CPU运行,提供更高处理能力。RL78/G24 MCU具有增强的模拟功能和大量…

构建和应用卡尔曼滤波器 (KF)--扩展卡尔曼滤波器 (EKF)

作为一名数据科学家,我们偶尔会遇到需要对趋势进行建模以预测未来值的情况。虽然人们倾向于关注基于统计或机器学习的算法,但我在这里提出一个不同的选择:卡尔曼滤波器(KF)。 1960 年代初期,Rudolf E. Kal…