机器学习预测-CNN手写字识别

介绍

这段代码是使用PyTorch实现的卷积神经网络(CNN),用于在MNIST数据集上进行图像分类。让我一步步解释:

  1. 导入库:代码导入了必要的库,包括PyTorch(torch)、神经网络模块(torch.nn)、函数模块(torch.nn.functional)、图像数据集(torchvision)以及数据处理(torch.utils.data)和可视化(matplotlib.pyplot)的工具。

  2. 设置超参数:定义了超参数,如批大小(Batch_size)、epoch数量(Epoch)和学习率(Lr)。

  3. 加载MNIST数据集:使用torchvision.datasets.MNIST加载MNIST数据集。该数据集包含了0到9的手写数字的灰度图像。transform=torchvision.transforms.ToTensor()将PIL图像转换为PyTorch张量。

  4. 可视化样本数据:打印数据集的大小,并显示数据集中的第一张图像及其相应的标签。

  5. 准备测试数据:准备测试数据与训练数据类似。加载MNIST测试数据集,并选择前2000个图像进行测试。

  6. 创建数据加载器:使用torch.utils.data.DataLoader创建训练数据的数据加载器。它有助于在训练过程中对数据进行分批和混洗。

  7. 定义CNN架构:通过子类化nn.Module来定义CNN类。该架构包括两个卷积层(self.con1self.con2),后面跟有ReLU激活函数和最大池化层。卷积层的输出被展平并馈入全连接层(self.out),产生最终输出。

  8. 初始化CNN:创建CNN类的实例。

  9. 定义损失函数和优化器:使用交叉熵损失(nn.CrossEntropyLoss)作为损失函数,使用随机梯度下降(torch.optim.SGD)作为优化器。

  10. 训练CNN:在指定的epoch数量循环内训练模型。在循环内,将训练数据通过模型,计算损失,进行梯度反向传播,并由优化器更新模型参数。

  11. 测试模型:每50次迭代训练时,对测试数据集进行评估。将测试预测与真实标签进行比较,计算准确率。

  12. 打印结果:训练结束后,打印模型预测及前10个测试样本的真实标签。

总的来说,这段代码训练了一个CNN模型,用于在MNIST数据集上对手写数字进行分类,并在单独的测试数据集上评估其性能。

代码

# -*- coding: utf-8 -*-
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torch.utils.data as Data
import matplotlib.pyplot as plt# define hyper parameters
Batch_size = 100
Epoch = 1
Lr = 0.5
#DOWNLOAD_MNIST = True # 若没有数据,用此生成数据# define train data and test data
train_data = torchvision.datasets.MNIST(root='./mnist',train=True,download=False,transform=torchvision.transforms.ToTensor()
)
print(train_data.data.size())
print(train_data.targets.size())
print(train_data.data[0])
# 画一个图片显示出来
plt.imshow(train_data.data[0].numpy(),cmap='gray')
plt.title('%i'%train_data.targets[0])
plt.show()
# print(train_data.data.shape)           # torch.Size([60000, 28, 28])
# print(train_data.targets.size())        # torch.Size([60000])
# print(train_data.data[0].size())        # torch.Size([28, 28])
# plt.imshow(train_data.data[0].numpy(), cmap='gray')
# plt.show()
test_data = torchvision.datasets.MNIST(root='./mnist',train=False,# transform=torchvision.transforms.ToTensor()
)
test_x = torch.unsqueeze(test_data.data, dim=1).type(torch.FloatTensor)[:2000]
test_y = test_data.targets[:2000]
# print(test_x.shape)                         # torch.Size([2000, 1, 28, 28])
# print(test_y.shape)                         # torch.Size([2000])
train_loader = Data.DataLoader(dataset=train_data,shuffle=True,batch_size=Batch_size,
)# define network structure
class CNN(nn.Module):def __init__(self):super(CNN, self).__init__()self.con1 = nn.Sequential(nn.Conv2d(in_channels=1, out_channels=16, kernel_size=5, stride=1, padding=2),nn.ReLU(),nn.MaxPool2d(kernel_size=2))self.con2 = nn.Sequential(nn.Conv2d(16, 32, 5, 1, 2),nn.ReLU(),nn.MaxPool2d(2))self.out = nn.Linear(32 * 7 * 7, 10)def forward(self, x):x = self.con1(x)            # (batch, 16, 14, 14)x = self.con2(x)            # (batch, 32, 7, 7)x = x.view(x.size(0), -1)out = self.out(x)             # (batch_size, 10)return outcnn = CNN()
# print(cnn)
optimizer = torch.optim.SGD(cnn.parameters(), lr=Lr)
loss_fun = nn.CrossEntropyLoss()for epoch in range(Epoch):for i, (x, y) in enumerate(train_loader):output = cnn(x)loss = loss_fun(output, y)optimizer.zero_grad()loss.backward()optimizer.step()if i % 50 == 0:test_output = torch.max(cnn(test_x), dim=1)[1]loss = loss_fun(cnn(test_x), test_y).item()accuracy = torch.sum(torch.eq(test_output, test_y)).item() / test_y.numpy().sizeprint('Epoch:', Epoch, '|loss:%.4f' % loss, '|accuracy:%.4f' % accuracy)print('real value', test_data.targets[: 10].numpy())
print('train value', torch.max(cnn(test_x)[: 10], dim=1)[1].numpy())

结果

real value [7 2 1 0 4 1 4 9 5 9]
train value [7 2 1 0 4 1 4 9 5 9]

 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/332660.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

queue学习

std::queue 类是一种容器适配器,它提供队列的功能——尤其是 FIFO(先进先出)数据结构。此类模板用处为底层容器的包装器——只提供特定的函数集合。queue 在底层容器尾端推入元素,从首端弹出元素。 元素访问 front 访问第一个元素…

Shell环境变量深入:自定义系统环境变量

Shell环境变量深入:自定义系统环境变量 目标 能够自定义系统级环境变量 全局配置文件/etc/profile应用场景 当前用户进入Shell环境初始化的时候会加载全局配置文件/etc/profile里面的环境变量, 供给所有Shell程序使用 以后只要是所有Shell程序或命令使用的变量…

0基础认识C语言

为了给0基础一个舒服的学习路径,就有了这个专栏希望带大家一起进步。 话不多说,开始正题。 一、C语言的一段小历史 C语言的设计要追溯到20世纪60年代末和70年代初,在那个时代美国有这么一号人叫做丹尼斯.里奇,他和同事肯.汤普逊…

【vue】el-select选择器实现宽度自适应

选择器的宽度根据内容长度进行变化 <div class"Space_content"><el-selectv-model"value":placeholder"$t(bot.roommessage)"class"select"size"small"style"margin-right: 10px"change"selectcha…

【MySQL】库的基础操作

&#x1f30e;库的操作 文章目录&#xff1a; 库的操作 创建删除数据库 数据库编码集和校验集 数据库的增删查改       数据库查找       数据库修改 备份和恢复 查看数据库连接情况 总结 前言&#xff1a;   数据库操作是软件开发中不可或缺的一部分&#xff0…

若依框架对于后端返回异常后怎么处理?

1、后端返回自定义异常serviceException 2、触发该异常后返回json数据 因为若依对请求和响应都封装了&#xff0c;所以根据返回值response获取不到Code值但若依提供了一个catch方法用来捕获返回异常的数据 3、处理的方法

如何利用已有数据对模型进行微调

1.langchain整合llm做知识问答 利用LangChain的能力来结合检索和生成&#xff0c;形成一个知识增强的问答系统&#xff08;不涉及对模型的微调&#xff09;&#xff0c;而是利用llm从文档检索到问题解答。 langchain整合llm做知识检索 2.微调llm模型 1、首先是我们的数据集&…

python数据分析——数据可视化(图形绘制)

数据可视化&#xff08;图形绘制基础&#xff09; 前言一、图形绘制基础Matplotlib简介使用过程sin函数示例 二、常用图形绘制折线图的绘制plot示例 散点图的绘制scatter()示例 柱状图的绘制bar示例 箱型图绘制plot.box示例 饼状图的绘制pie示例 三、图形绘制的组合情况多个折线…

HTML静态网页成品作业(HTML+CSS)——宠物狗介绍网页(3个页面)

&#x1f389;不定期分享源码&#xff0c;关注不丢失哦 文章目录 一、作品介绍二、作品演示三、代码目录四、网站代码HTML部分代码 五、源码获取 一、作品介绍 &#x1f3f7;️本套采用HTMLCSS&#xff0c;未使用Javacsript代码&#xff0c;共有3个页面。 二、作品演示 三、代…

HCIP-Datacom-ARST自选题库__BGP判断【20道题】

1.传统的BGP-4只能管理IPV4单播路由信息&#xff0c;MP-BGP为了提供对多种网络层协议的支持&#xff0c;对BGP-4进行了扩展。其中MP-BGP对IPv6单播网络的支持特性称为BGP4&#xff0c;BGP4通过Next Hop属性携带路由下一跳地址信息。 2.BGP4通过Update报文中的Next Hop属性携带…

大模型额外篇章二:基于chalm3或Llama2-7b训练酒店助手模型

文章目录 一、代码部分讲解二、实际部署步骤(CHALM3训练步骤)1)注册AutoDL官网实名认证2)花费额度挑选GPU3)准备实验环境4)开始执行脚本5)从浏览器访问6)可以开始提问7)开始微调模型8)测试训练后的模型三、基于Llama2-7b的训练四、额外补充1)修改参数后2)如果需要访问…

告别红色波浪线:tsconfig.json 配置详解

使用PC端的朋友&#xff0c;请将页面缩小到最小比例&#xff0c;阅读最佳&#xff01; tsconfig.json 文件用于配置 TypeScript 项目的编译选项。如果配不对&#xff0c;就会在项目中显示一波又一波的红色波浪线&#xff0c;警告你这些地方的类型声明存在问题。 一般我们遇到这…

C语言对一阶指针 二阶指针的本质理解

代码&#xff1a; #include <stdio.h>char a 2; char* p &a; char** d &p;int main(){printf("a -> %d, &a -> %p\n", a, &a);printf("*p -> %d, p -> %p, &p -> %p\n", *p, p, &p);printf(&qu…

Java Swing + MySQL图书借阅管理系统

系列文章目录 Java Swing MySQL 图书管理系统 Java Swing MySQL 图书借阅管理系统 文章目录 系列文章目录前言一、项目展示二、部分代码1.Book2.BookDao3.DBUtil4.BookAddInternalFrame5.Login 三、配置 前言 项目是使用Java swing开发&#xff0c;界面设计比较简洁、适合作…

ic基础|时钟篇05:芯片中buffer到底是干嘛的?一文带你了解buffer的作用

大家好&#xff0c;我是数字小熊饼干&#xff0c;一个练习时长两年半的ic打工人。我在两年前通过自学跨行社招加入了IC行业。现在我打算将这两年的工作经验和当初面试时最常问的一些问题进行总结&#xff0c;并通过汇总成文章的形式进行输出&#xff0c;相信无论你是在职的还是…

计算机毕业设计 | SpringBoot招投标 任务发布网站(附源码)

1&#xff0c;绪论 在市场范围内&#xff0c;任务发布网站很受欢迎&#xff0c;有很多开发者以及其他领域的牛人&#xff0c;更倾向于选择工作时间、工作场景更自由的零工市场寻求零散单子来补贴家用。 如今市场上&#xff0c;任务发布网站鱼龙混杂&#xff0c;用户需要找一个…

电机转速计算(基于码盘和IO外部中断)

目录 概述 1 硬件介绍 1.1 整体硬件结构 1.2 模块功能介绍 2 测速框架介绍 2.1 测速原理 2.2 软件框架结构 3 使用STM32Cube配置Project 3.1 准备环境 3.2 配置参数 3.3 生成Project 4 功能实现 4.1 电机控制代码 4.2 测试代码 4.3 速度计算 5 测试 5.1 编写测…

搭建CMS系统

搭建CMS系统 1 介绍 内容管理系统&#xff08;Content Management System&#xff0c;CMS&#xff09;是一种用于管理、发布和修改网站内容的系统。开源的CMS系统有WordPress、帝国CMS等&#xff0c;国产的Halo很不错。 WordPress参考地址 # 官网 https://wordpress.org/# …

OrangePi KunPengPro | 开发板开箱测评之学习与使用

OrangePi KunPengPro | 开发板开箱测评之学习与使用 时间&#xff1a;2024年5月23日20:51:12 文章目录 OrangePi KunPengPro | 开发板开箱测评之学习与使用概述1.参考2.资料、工具3.使用3-1.通过串口登录系统3-2.通过SSH登录系统3-3.安装交叉编译工具链3-4.复制文件到设备3-5.第…

SpringMVC:创建一个简单的SpringMVC框架S

目录 一、框架介绍 两个重要的xml文件 SpringMVC执行流程 二、Vscode搭建SpringMVC框架 1、maven创建webapp原型项目 2、pom.xml下添加springmvc的相关依赖 3、在web.xml配置 4、springmvc.xml的配置 5、编写Controller控制器类 6、 编写JSP界面 7、项目结构图 一…