Pytorch-MLP-Mnist

文章目录

  • model.py
  • main.py
  • 参数设置
  • 注意事项
    • 初始化权重
    • 如果发现loss和acc不变
    • 关于数据下载
    • 关于输出格式
  • 运行图

model.py

import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as initclass MLP_cls(nn.Module):def __init__(self,in_dim=28*28):super(MLP_cls,self).__init__()self.lin1 = nn.Linear(in_dim,128)self.lin2 = nn.Linear(128,64)self.lin3 = nn.Linear(64,10)self.relu = nn.ReLU()init.xavier_uniform_(self.lin1.weight)init.xavier_uniform_(self.lin2.weight)init.xavier_uniform_(self.lin3.weight)def forward(self,x):x = x.view(-1,28*28)x = self.lin1(x)x = self.relu(x)x = self.lin2(x)x = self.relu(x)x = self.lin3(x)x = self.relu(x)return x

main.py

import torch
import torch.nn as nn
import torchvision
from torch.utils.data import DataLoader
import torch.optim as optim
from model import MLP_clsseed = 42
torch.manual_seed(seed)
batch_size_train = 64
batch_size_test  = 64
epochs = 10
learning_rate = 0.01
momentum = 0.5
mlp_net = MLP_cls()train_loader = torch.utils.data.DataLoader(torchvision.datasets.MNIST('./data/', train=True, download=True,transform=torchvision.transforms.Compose([torchvision.transforms.ToTensor(),torchvision.transforms.Normalize((0.5,), (0.5,))])),batch_size=batch_size_train, shuffle=True)
test_loader = torch.utils.data.DataLoader(torchvision.datasets.MNIST('./data/', train=False, download=True,transform=torchvision.transforms.Compose([torchvision.transforms.ToTensor(),torchvision.transforms.Normalize((0.5,), (0.5,))])),batch_size=batch_size_test, shuffle=True)optimizer = optim.SGD(mlp_net.parameters(), lr=learning_rate,momentum=momentum)
criterion = nn.CrossEntropyLoss()print("****************Begin Training****************")
mlp_net.train()
for epoch in range(epochs):run_loss = 0correct_num = 0for batch_idx, (data, target) in enumerate(train_loader):out = mlp_net(data)_,pred = torch.max(out,dim=1)optimizer.zero_grad()loss = criterion(out,target)loss.backward()run_loss += lossoptimizer.step()correct_num  += torch.sum(pred==target)print('epoch',epoch,'loss {:.2f}'.format(run_loss.item()/len(train_loader)),'accuracy {:.2f}'.format(correct_num.item()/(len(train_loader)*batch_size_train)))print("****************Begin Testing****************")
mlp_net.eval()
test_loss = 0
test_correct_num = 0
for batch_idx, (data, target) in enumerate(test_loader):out = mlp_net(data)_,pred = torch.max(out,dim=1)test_loss += criterion(out,target)test_correct_num  += torch.sum(pred==target)
print('loss {:.2f}'.format(test_loss.item()/len(test_loader)),'accuracy {:.2f}'.format(test_correct_num.item()/(len(test_loader)*batch_size_test)))

参数设置

'./data/' #数据保存路径
seed = 42 #随机种子
batch_size_train = 64
batch_size_test  = 64
epochs = 10optim --> SGD
learning_rate = 0.01
momentum = 0.5

注意事项

初始化权重

这里使用这种方式

        init.xavier_uniform_(self.lin1.weight)init.xavier_uniform_(self.lin2.weight)init.xavier_uniform_(self.lin3.weight)

如果发现loss和acc不变

检查一下是不是忘记写optimizer.step()了

关于数据下载

数据在download=True时,会下载在./data文件夹下

关于输出格式

这里用‘xxx {:.2f}'.format(xxx),保留两位小数。注意中间的空格,区分:.2f和%2f

运行图

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/135039.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

快递、外卖、网购自动定位及模糊检索收/发件地址功能实现

概述 目前快递、外卖、团购、网购等行业 :为了简化用户在收发件地址填写时的体验感,使用辅助定位及模糊地址检索来丰富用户的体验 本次demo分享给大家;让大家理解辅助定位及模糊地址检索的功能实现过程,以及开发出自己理想的作品…

【C++初阶】C++STL详解(四)—— vector的模拟实现

​ ​📝个人主页:Sherry的成长之路 🏠学习社区:Sherry的成长之路(个人社区) 📖专栏链接:C初阶 🎯长路漫漫浩浩,万事皆有期待 【C初阶】CSTL详解(三…

Python 文件写入操作

视频版教程 Python3零基础7天入门实战视频教程 w模式是写入,通过write方法写入内容。 # 打开文件 模式w写入,文件不存在,则自动创建 f open("D:/测试3.txt", "w", encoding"UTF-8")# write写入操作 内容写入…

C++---继承

继承 前言继承的概念及定义继承的概念继承定义继承关系和访问限定符 基类和派生类对象赋值转换继承中的作用域派生类的默认成员函数继承与友元继承与静态成员**多重继承**多继承下的类作用域菱形继承虚继承使用虚基类 支持向基类的常规类型转换 前言 在需要写Father类和Mother…

Python爬虫实战案例——第五例

文章中所有内容仅供学习交流使用,不用于其他任何目的!严禁将文中内容用于任何商业与非法用途,由此产生的一切后果与作者无关。若有侵权,请联系删除。 目标:采集三国杀官网的精美壁纸 地址:aHR0cHM6Ly93d3…

Qt/C++音视频开发54-视频监控控件的极致设计

一、前言 跌跌撞撞摸爬滚打一步步迭代完善到今天,这个视频监控控件的设计,在现阶段水平上个人认为是做的最棒的(稍微自恋一下),理论上来说应该可以用5年不用推翻重写,推翻重写当然也是程序员爱干的事情&am…

Visual Studio2019报错

1- Visual Studio2019报错 错误 MSB8036 找不到 Windows SDK 版本 10.0.19041.0的解决方法 小伙伴们在更新到Visual Studio2019后编译项目时可能遇到过这个错误:“ 错误 MSB8036 找不到 Windows SDK 版本 10.0.19041.0的解决方法”,但是我们明明安装了该…

Linux多线程【线程控制】

✨个人主页: 北 海 🎉所属专栏: Linux学习之旅 🎃操作环境: CentOS 7.6 阿里云远程服务器 文章目录 🌇前言🏙️正文1、线程知识补充1.2、线程私有资源1.3、线程共享资源1.4、原生线程库 2、线程…

安卓机型固件系统分区的基础组成 手机启动规律初步常识 各分区的基本含义与说明

此贴为基本常识。感兴趣的友友可以了解手机的启动顺序和各模式的基本操作与意义。另外了解手机系统分区各文件夹的含义 分区说明对应贴:安卓机型固件中分区对应说明 手机开机基本启动顺序 当我们按下手机开机键的时候。基本的启动顺序为 注意:该结构图…

Learn Prompt-“标准“提示

在前面的教程中,我们介绍了指令输入的简单提示,提供实例的提示和角色扮演类的提示,那么是否有一个公式来列出提示的各个部分,并将其组合成一个标准化的提示?答案是肯定的。 角色扮演(Role) 指令…

SQL 性能优化总结

文章目录 一、性能优化策略二、索引创建规则三、查询优化总结 一、性能优化策略 1. SQL 语句中 IN 包含的值不应过多 MySQL 将 IN中的常量全部存储在一个排好序的数组里面,但是如果数值较多,产生的消耗也是比较大的。所以对于连续的数值,能用…

如何用在线模版快速制作活动海报?

在时代的发展和信息传播的快速发展下,活动海报成为了宣传活动的重要方式之一。设计一张吸引眼球的活动海报,不仅能够有效传递信息,还能够吸引人们的注意力。那么,在这里我将教会大家如何设计活动海报,只需要三分钟&…

12.(Python数模)(相关性分析一)相关系数矩阵

相关系数矩阵 相关系数矩阵是用于衡量多个变量之间关系强度和方向的统计工具。它是一个对称矩阵,其中每个元素表示对应变量之间的相关系数。 要计算相关系数矩阵,首先需要计算每对变量之间的相关系数。常用的相关系数包括皮尔逊相关系数和斯皮尔曼相关…

【JAVA-Day14】深入了解 Java 中的 while 循环语句

深入了解 Java 中的 while 循环语句 深入了解 Java 中的 while 循环语句摘要引言一、什么是 while 循环语句二、while 循环语句的语法和使用场景使用场景 三、while 循环的优势和使用场景优势使用建议 四、总结参考资料 博主 默语带您 Go to New World. ✍ 个人主页—— 默语 的…

Mysql详解Explain索引优化最佳实践

目录 1 Explain工具介绍2 explain 两个变种3 explain中的列3.1 id列3.2 select_type列3.3 table列3.4. type列3.5 possible_keys列3.6 key列3.7 key_len列3.8 ref列3.9 rows列3.10 Extra列 4 索引最佳实践4.1.全值匹配4.2.最左前缀法则4.3.不在索引列上做任何操作(计…

stringBuffer.append(analyze);使用这个拼接时候如何在字符串参数字符串参数整数参数字符串数组参数内容之间添加空格

stringBuffer.append(analyze);使用这个拼接时候如何在字符串参数字符串参数整数参数字符串数组参数内容之间添加空格? 在添加参数到 StringBuffer 时,你可以在每次添加参数之后都添加一个空格,如下所示: StringBuffer stringBu…

零信任:基于Apisix构建认证网关

最终效果 基于身份认证的零信任网关 - 知乎 背景 零信任一直是我们未来主攻的一个方向,全球加速,SD-WAN组网都是一些非常成熟的产品,全球加速是我们所有产品的底座,SD-WAN解决的是多个网络打通的问题,而零信任则主打…

『PyQt5-Qt Designer篇』| 09 Qt Designer中分割线和间隔如何使用?

09 Qt Designer中分割线和间隔如何使用? 1 间隔1.1 水平间隔1.2 垂直间隔2 分割线2.1 水平线2.2 垂直线3 保存并执行1 间隔 间隔有水平间隔和垂直间隔: 1.1 水平间隔 拖动4个按钮,并设置为水平布局: 在第一个按钮的右边添加一个水平间隔: 设置其sizeType为Fixed,宽度为20…

c++ 函数的参数是否可以为auto

(1)在vs2019开到 cpp20 的语法规范,是可以的 (2)但网上和文心一言和书上说不可以 (2) 再附上一种auto 的很炫酷的写法:

HTML+CSS画一个卡通中秋月饼

HTMLCSS画一个卡通中秋月饼🥮🥮🥮 中秋活动水个文章 整个divcss实现个月饼,给前端初学者一个练手的demo 效果图 思路 HTMl 先来个轮廓画脸上的东西:眼睛、眉毛、腮红、嘴巴眼睛丰富下瞳孔画20个花瓣 CSS 轮廓是要外…