交叉熵Loss多分类问题实战(手写数字)

1、import所需要的torch库和包
在这里插入图片描述
2、加载mnist手写数字数据集,划分训练集和测试集,转化数据格式,batch_size设置为200在这里插入图片描述
3、定义三层线性网络参数w,b,设置求导信息
在这里插入图片描述
4、初始化参数,这一步比较关键,是否初始化影响到数据质量以及后续网络学习效果
在这里插入图片描述
5、自定义三层线性网络
在这里插入图片描述
6、选定优化器激活函数和loss函数
在这里插入图片描述
7、训练及测试,并记录每轮训练的loss变化和在测试集上的效果。第一轮就达到了98的准确度,判断是初始化效果较好,在前几次测试中根据初始化的情况不同,初始准确率为50%-85%不等
在这里插入图片描述
完整代码:

import torch
import torchvision
import torch.nn.functional as Ftrain_loader = torch.utils.data.DataLoader(torchvision.datasets.MNIST('mnist_data', train=True, download=True,transform=torchvision.transforms.Compose([torchvision.transforms.ToTensor(),torchvision.transforms.Normalize((0.1307, ), (0.3081, ))])),batch_size=200, shuffle=True)test_loader = torch.utils.data.DataLoader(torchvision.datasets.MNIST('mnist_data', train=False, download=True,transform=torchvision.transforms.Compose([torchvision.transforms.ToTensor(),torchvision.transforms.Normalize((0.1307, ), (0.3081, ))])),batch_size=200, shuffle=True)w1 = torch.randn(200, 784, requires_grad=True)
b1 = torch.randn(200, requires_grad=True)
w2 = torch.randn(200, 200, requires_grad=True)
b2 = torch.randn(200, requires_grad=True)
w3 = torch.randn(10, 200, requires_grad=True)
b3 = torch.randn(10, requires_grad=True)torch.nn.init.kaiming_normal_(w1)
torch.nn.init.kaiming_normal_(w2)
torch.nn.init.kaiming_normal_(w3)def forward(x):x = x@w1.t() +b1x = F.relu(x)x = x@w2.t() +b2x = F.relu(x)x = x@w3.t() +b3x = F.relu(x)return xoptimizer = torch.optim.Adam([w1, b1, w2, b2, w3, b3], lr=0.001)
criterion = torch.nn.CrossEntropyLoss()for epoch in range(10):for batch_idx, (data, target) in enumerate(train_loader):data = data.view(-1, 28*28)logits = forward(data)loss = criterion(logits, target)optimizer.zero_grad()loss.backward()optimizer.step()if (batch_idx+1) % 150 == 0:print('Train Epoch:{} [{}/{}({:.0f}%)]\tLoss:{:.6f}'.format(epoch, (batch_idx+1) * len(data), len(train_loader.dataset),100. * (batch_idx+1) / len(train_loader), loss.item()))test_loss = 0correct = 0for data, target in test_loader:data = data.view(-1, 28*28)logits = forward(data)test_loss += criterion(logits, target).item()pred = logits.data.max(1)[1]correct += pred.eq(target.data).sum()test_loss /= len(test_loader)print('\nTest Set:Average Loss:{:.4f}, Accuracy:{}/{}({:.0f}%)\n'.format(test_loss, correct, len(test_loader.dataset),100. * correct / len(test_loader.dataset)))

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/158067.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

强化学习(Reinforcement Learning)与策略梯度(Policy Gradient)

写在前面:本篇博文的内容来自李宏毅机器学习课程与自己的理解,同时还参考了一些其他博客(懒得放链接)。博文的内容主要用于自己学习与记录。 1 强化学习的基本框架 强化学习(Reinforcement Learning, RL)主要由智能体(Agent/Actor)、环境(Environment)、…

【学习笔记】项目进行过程中遇到有关composer的问题

composer.json内容详解 以项目中的composer.json为例,参考文档。 name:composer包名type:包的类型,project和library两种keywords:关键词,方便别人在安装时通过关键词检索(没试过,好…

《C++ Primer》练习9.52:使用栈实现四则运算

栈可以用来使用四则运算,是一个稍微有点复杂的题目,去学习了一下用栈实现四则运算的原理,用C实现了一下。首先要把常见的中缀表达式改成后缀表达式,然后通过后缀表达式计算,具体的原理可以参考这位博主的文章&#xff…

抖音直播招聘小程序可以增加职位展示,提升转化率,增加曝光度

抖音直播招聘报白是指进入抖音的白名单,允许在直播间或小视频中发布招聘或找工作等关键词。否则会断播、不推流、限流。抖音已成为短视频流量最大的平台,但招聘企业数量较少。抖音招聘的优势在于职位以视频、直播方式展示,留存联系方式更加精…

到底什么是5G-R?

近日,工信部向中国国家铁路集团有限公司(以下简称“国铁集团”)批复5G-R试验频率的消息,引起了行业内的广泛关注。 究竟什么是5G-R?为什么工信部会在此时批复5G-R的试验频率? 今天,小枣君就通过…

公司文件防泄密软件——「天锐绿盾」@德人合科技

天锐绿盾是一款企业级数据安全解决方案,主要用于保护企业的知识产权、客户资料、财务数据、技术图纸、应用系统等机密信息化数据不外泄。 PC访问地址: 🔗isite.baidu.com/site/wjz012xr/2eae091d-1b97-4276-90bc-6757c5dfedee 该软件解决方案…

vue单页面应用使用 history模式路由时刷新页面404的一种可能性

原先使用的是 hash模式路由,因为要结合qiankun进行微前端改造,改成了 history模式,结果页面刷新之后没有正确渲染组件。按照一般思路检查 nginx配置 try_files $uri $uri/ /index.html;也配置上了,还是有问题。 页面异常显示 问题…

电脑重做系统---win10

电脑重做系统---win10 前言制作启动U盘材料方法打开网址下载启动盘制作工具参照官方说明进行制作使用U盘重做系统 常用软件官网地址 前言 记得最早学习装电脑还是04年左右,最为一个啥也不知道的大一傻白胖,花了几百大洋在电脑版把了个“电脑组装与维修”…

使用 KubeSkoop exporter 监测和定位容器网络抖动问题

作者:遐宇、溪恒 本文是 8 月 17 日直播的文字稿整理,文末可观看直播回放。除去文章内容外,还包括针对实际网络问题的实战环节。 容器网络抖动问题发生频率低,时间短,是网络问题中最难定位和解决的问题之一。 不仅如…

CVE-2017-15715 apache换行解析文件上传漏洞

影响范围 httpd 2.4.0~2.4.29 复现环境 vulhub/httpd/CVE-2017-15715 docker-compose 漏洞原理 在apache2的配置文件: /etc/apache2/conf-available/docker-php.conf 中,php的文件匹配以正则形式表达 ".php$"的正则匹配模式意味着以.ph…

见微知著:从企业售后技术支持看云计算发展

作者:余凯 售后业务中的细微变化 作为阿里云企业容器技术支持的一员,每天会面对全球各地企业级客户提出的关于容器的各种问题,通过这几年的技术支持的经历,逐步发现容器问题客户的一些惯性,哪些是重度用户&#xff0…

将Excel表中数据导入MySQL数据库

1、准备好Excel表: 2、数据库建表case2 字段信息与表格对应建表: 3、实现代码 import pymysql import pandas as pd import openpyxl 从excel表里读取数据后,再存入到mysql数据库。 需要安装openpyxl pip install openpyxl# 读入数据&#x…

【面试高频题】难度 1/5,经典树的搜索(多语言)

题目描述 这是 LeetCode 上的 「109. 有序链表转换二叉搜索树」 ,难度为 「中等」 Tag : 「二叉树」、「树的搜索」、「分治」、「中序遍历」 给定一个单链表的头节点 head,其中的元素 按升序排序 ,将其转换为高度平衡的二叉搜索树。 本题中&…

基于SSM的校园音乐平台系统

基于SSM的校园音乐平台系统~ 开发语言:Java数据库:MySQL技术:SpringSpringMVCMyBatisVue工具:IDEA/Ecilpse、Navicat、Maven 系统展示 主页 登录界面 管理员界面 歌手管理 歌曲管理 摘要 校园音乐平台系统(Campus Mu…

基于知识图谱建模、全文检索的智能知识管理库(源码)

一、项目介绍 一款全源码,可二开,可基于云部署、私有部署的企业级知识库云平台,一款让企业知识变为实打实的数字财富的系统,应用在需要进行文档整理、分类、归集、检索、分析的场景。 知识图谱提供了一种从海量文本和图像中抽取结…

【C++】:初阶模板

朋友们、伙计们,我们又见面了,本期来给大家解读一下有关Linux的基础知识点,如果看完之后对你有一定的启发,那么请留下你的三连,祝大家心想事成! C 语 言 专 栏:C语言:从入门到精通 数…

【算法-动态规划】最长公共子串

💝💝💝欢迎来到我的博客,很高兴能够在这里和您见面!希望您在这里可以感受到一份轻松愉快的氛围,不仅可以获得有趣的内容和知识,也可以畅所欲言、分享您的想法和见解。 推荐:kuan 的首页,持续学…

磁盘非跨盘访问算法实现

1. 背景说明 本算法基于已将磁盘分布合并并排序为升序线性表。实现示例为:磁盘扇区大小:32(可自定义),待拆分磁盘内存: [0 - 50],[60 - 100](可增加)。示意图如下&#x…

RISC-V 特权级架构

特权级别 级别的数值越大,特权级越高,掌控硬件的能力越强,在CPU硬件层面,M模式必须存在,其它模式可以不存在 执行环境调用 ecall ,这是一种很特殊的陷入类的指令, 相邻两特权级软件之间的接口正…