pytorch-01

加载mnist数据集

one-hot编码实现

import numpy as np
import torch
x_train = np.load("../dataset/mnist/x_train.npy") # 从网站提前下载数据集,并解压缩
y_train_label = np.load("../dataset/mnist/y_train_label.npy")
x = torch.tensor(y_train_label[:5],dtype=torch.int64)  # 获取前5个样本的标签数据
# 定义一个张量输入,因为此时有 5 个数值,且最大值为9,类别数为10
# 所以我们可以得到 y 的输出结果的形状为 shape=(5,10),即5行12列
y = torch.nn.functional.one_hot(x, 10)  # 一个参数张量x,10为类别数
print(y)

对于拥有6000个样本的MNIST数据集来说,标签就是一个6000\times 10大小的矩阵张量。

多层感知机模型

#设定的多层感知机网络模型
class NeuralNetwork(torch.nn.Module):def __init__(self):super(NeuralNetwork, self).__init__()self.flatten = torch.nn.Flatten()  # 拉平图像矩阵self.linear_relu_stack = torch.nn.Sequential(torch.nn.Linear(28*28,312),   # 输入大小为28*28,输出大小为312维的线性变换层torch.nn.ReLU(),   # 激活函数层torch.nn.Linear(312, 256),torch.nn.ReLU(),torch.nn.Linear(256, 10)  # 最终输出大小为10,对应one-hot标签维度)def forward(self, input):   # 构建网络x = self.flatten(input)  #拉平矩阵为1维logits = self.linear_relu_stack(x) # 多层感知机return logits

损失函数

优化函数

model = NeuralNetwork()
loss_fu = torch.nn.CrossEntropyLoss() # 交叉熵损失函数,内置了softmax函数,
optimizer = torch.optim.Adam(model.parameters(), lr=2e-5)   #设定优化函数loss = loss_fu(pred,label_batch)  # 计算损失

完整模型

import os
os.environ['CUDA_VISIBLE_DEVICES'] = '0' #指定GPU编
import torch
import numpy as npbatch_size = 320                        #设定每次训练的批次数
epochs = 1024                           #设定训练次数#device = "cpu"                         #Pytorch的特性,需要指定计算的硬件,如果没有GPU的存在,就使用CPU进行计算
device = "cuda"                         #在这里读者默认使用GPU,如果读者出现运行问题可以将其改成cpu模式#设定的多层感知机网络模型
class NeuralNetwork(torch.nn.Module):def __init__(self):super(NeuralNetwork, self).__init__()self.flatten = torch.nn.Flatten()self.linear_relu_stack = torch.nn.Sequential(torch.nn.Linear(28*28,312),torch.nn.ReLU(),torch.nn.Linear(312, 256),torch.nn.ReLU(),torch.nn.Linear(256, 10))def forward(self, input):x = self.flatten(input)logits = self.linear_relu_stack(x)return logitsmodel = NeuralNetwork()
model = model.to(device)                #将计算模型传入GPU硬件等待计算
torch.save(model, './model.pth')
#model = torch.compile(model)            #Pytorch2.0的特性,加速计算速度
loss_fu = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=2e-5)   #设定优化函数#载入数据
x_train = np.load("../../dataset/mnist/x_train.npy")
y_train_label = np.load("../../dataset/mnist/y_train_label.npy")train_num = len(x_train)//batch_size#开始计算
for epoch in range(20):train_loss = 0for i in range(train_num):start = i * batch_sizeend = (i + 1) * batch_sizetrain_batch = torch.tensor(x_train[start:end]).to(device)label_batch = torch.tensor(y_train_label[start:end]).to(device)pred = model(train_batch)loss = loss_fu(pred,label_batch)optimizer.zero_grad()loss.backward()optimizer.step()train_loss += loss.item()  # 记录每个批次的损失值# 计算并打印损失值train_loss /= train_numaccuracy = (pred.argmax(1) == label_batch).type(torch.float32).sum().item() / batch_sizeprint("epoch:",epoch,"train_loss:", round(train_loss,2),"accuracy:",round(accuracy,2))

可视化模型结构和参数

model = NeuralNetwork()
print(model)

是对模型具体使用的函数及其对应的参数进行打印。

格式化显示:

param = list(model.parameters())
k=0
for i in param:l = 1print('该层结构:'+str(list(i.size())))for j in i.size():l*=jprint('该层参数和:'+str(l))k = k+l
print("总参数量:"+str(k))

模型保存

model = NeuralNetwork()
torch.save(model, './model.pth')

netron可视化

安装:pip install netron

运行:命令行输入netron

打开:通过网址http://localhost:8080打开

打开保存的模型文件model.pth:

 

 点击颜色块,可以显示详细信息:

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/364112.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【小程序静态页面】猜拳游戏大转盘积分游戏小程序前端模板源码

猜拳游戏大转盘积分游戏小程序前端模板源码, 一共五个静态页面,首页、任务列表、大转盘和猜拳等五个页面。 主要是通过做任务来获取积分,积分可以兑换商品,也可用来玩游戏;通过玩游戏既可能获取奖品或积分也可能会消…

一文速览Google的Gemma:从gemma1到gemma2(2代27B的能力接近llama3 70B)

前言 如此文《七月论文审稿GPT第3.2版和第3.5版:通过paper-review数据集分别微调Mistral、gemma》所讲 Google作为曾经的AI老大,我司自然紧密关注,所以当Google总算开源了一个gemma 7b,作为有技术追求、技术信仰的我司&#xff0…

maven安装jar和pom到本地仓库

举例子我们要将 elastic-job-spring-boot-starter安装到本地的maven仓库&#xff0c;如下&#xff1a; <dependency><groupId>com.github.yinjihuan</groupId><artifactId>elastic-job-spring-boot-starter</artifactId><version>1.0.5&l…

关于组织赴俄罗斯(莫斯科)第 28 届国际汽车零部件、汽车维修设备和商品展览会商务考察的通知

关于组织赴俄罗斯&#xff08;莫斯科&#xff09; 第 28 届国际汽车零部件、汽车维修设备和商品展览会商务考察的通知 展会名称&#xff1a;俄罗斯&#xff08;莫斯科&#xff09;第 28 届国际汽车零部件、汽车零部件、汽车维修设备和商品展览会 时间&#xff1a;2024 年 8 月…

day02-Spark集群及参数

一、Spark运行环境变量问题(了解) 1-pycharm远程开发运行时&#xff0c;执行的是服务器的代码 2-通过本地传递指令到远程服务器运行代码时&#xff0c;会加载对应环境变量数据&#xff0c;加载环境变量文件是用户目录下的.bashrc文件 在/etc/bashrc 1-1 在代码中添加 使用os模块…

文本编辑命令和正则表达式

一、 编辑文本的命令 正则表达式匹配的是文本内容&#xff0c;Linux的文本三剑客&#xff0c;都是针对文本内容。 文本三剑客 grep&#xff1a;过滤文本内容 sed&#xff1a;针对文本内容进行增删改查 &#xff08;本文不相关&#xff09; awk&#xff1a;按行取列 &#x…

【网络架构】keepalive

目录 一、keepalive基础 1.1 作用 1.2 原理 1.3 功能 二、keepalive安装 2.1 yum安装 2.2 编译安装 三、配置文件 3.1 keepalived相关文件 3.2 主配置的组成 3.2.1 全局配置 3.2.2 配置虚拟路由器 四、实际操作 4.1 lvskeepalived高可用群集 4.2 keepalivedngi…

element 问题整合

没关系&#xff0c;凡事发生必有利于我 文章目录 一、el-table 同级数据对齐及展开图标的位置问题二、el-table 勾选框为圆角及只能勾选一个三、el-tree 弹框打开&#xff0c;使得列表关闭&#xff0c;且弹框滚动条回到顶部 一、el-table 同级数据对齐及展开图标的位置问题 ele…

Facebook的投流技巧有哪些?

相信大家都知道Facebook拥有着巨大的用户群体和高转化率&#xff0c;在国外社交推广中的影响不言而喻。但随着Facebook广告的竞争越来越激烈&#xff0c;在Facebook广告上获得高投资回报率也变得越来越困难。IPIDEA代理IP今天就教大家如何在Facebook上投放广告的技巧&#xff0…

使用 Ubuntu x86_64 平台交叉编译适用于 Linux aarch64(arm64) 平台的 QT5(包含OpenGL/WebEngine支持) 库

使用 Ubuntu AMD64 平台交叉编译适用于 Linux ARM64 平台的 QT5(包含 OpenGL/WebEngine 支持) 库 目录 使用 Ubuntu AMD64 平台交叉编译适用于 Linux ARM64 平台的 QT5(包含 OpenGL/WebEngine 支持) 库写在前面前期准备编译全流程1. 环境搭建2. 复制源码包并解压&#xff0c;创…

响应式高端家居装修网站源码pbootcms模板

模板介绍 分享一款黄色的响应式高端家居装修网站源码pbootcms模板&#xff0c;该模板能自适应手机端&#xff0c;响应式的设计可让您自由编辑&#xff0c;适合任何关于装修&#xff0c;空间设计&#xff0c;家装&#xff0c;家居等业务的企业。 模板截图 源码下载 响应式高端…

C++——探索智能指针的设计原理

前言: RAII是资源获得即初始化&#xff0c; 是一种利用对象生命周期来控制程序资源地手段。 智能指针是在对象构造时获取资源&#xff0c; 并且在对象的声明周期内控制资源&#xff0c; 最后在对象析构的时候释放资源。注意&#xff0c; 本篇文章参考——C 智能指针 - 全部用法…

已解决问题 | 该扩展程序未列在 Chrome 网上应用店中,并可能是在您不知情的情况下添加的

在Chrome浏览器中&#xff0c;如果你看到“该扩展程序未列在 Chrome 网上应用店中&#xff0c;并可能是在您不知情的情况下添加的”这样的提示&#xff0c;通常是因为该扩展程序没有通过Chrome网上应用店进行安装。以下是解决这个问题的步骤&#xff1a; 解决办法&#xff1a;…

计算机网络知识整理笔记

目录 1.对网络协议的分层&#xff1f; 2.TCP/IP和UDP之间的区别&#xff1f; 3.建立TCP连接的三次握手&#xff1f; 4.断开TCP连接的四次挥手&#xff1f; 5.TCP协议如何保证可靠性传输&#xff1f; 6.什么是TCP的拥塞控制&#xff1f; 7.什么是HTTP协议&#xff1f; 8…

MySQL高级-SQL优化- limit优化(覆盖索引加子查询)

文章目录 0、limit 优化0.1、从表 tb_sku 中按照 id 列进行排序&#xff0c;然后跳过前 9000000 条记录0.2、通过子查询获取按照 id 排序后的第 9000000 条开始的 10 条记录的 id 值&#xff0c;然后在原表中根据这些 id 值获取对应的完整记录 1、上传5个sql文件到 /root2、查看…

【工具推荐】ONLYOFFICE 桌面编辑器 8.1:引入全新功能,提升文档处理体验

ONLYOFFICE 桌面编辑器 8.1 现已发布&#xff1a;功能完善的 PDF 编辑器、幻灯片版式、改进从右至左显示、新的本地化选项等 【工具推荐】ONLYOFFICE 桌面编辑器 8.1&#xff1a;引入全新功能&#xff0c;提升文档处理体验 一、什么是ONLYOFFICE&#xff1f; ONLYOFFICE 是…

PG备份与恢复

一、开启WAL归档 1、创建归档目录 我们除了存储数据目录pgdata之外&#xff0c;还要创建backups&#xff0c;scripts&#xff0c;archive_wals文件 mkdir -p /home/mydba/pgdata/arch mkdir -p /home/mydba/pgdata/scripts mkdir -p /home/mydba/backups chown -R mydba.myd…

API接口知识小结

应用程序接口API&#xff08;Application Programming Interface&#xff09;&#xff0c;是提供特定业务输出能力、连接不同系统的一种约定。这里包括外部系统与提供服务的系统&#xff08;中后台系统&#xff09;或后台不同系统之间的交互点。包括外部接口、内部接口&#xf…

ANSYS Electronics 电磁场仿真工具下载安装,ANSYS Electronics强大的功能和灵活性

ANSYS Electronics无疑是一款在电磁场仿真领域表现卓越的软件工具。它凭借强大的功能和灵活性&#xff0c;帮助用户在产品设计阶段就能精确预测和优化电磁场性能&#xff0c;从而极大地降低了实际测试成本&#xff0c;并显著提升了产品的可靠性。 这款软件不仅在电子设计领域有…

Python | Leetcode Python题解之第204题计数质数

题目&#xff1a; 题解&#xff1a; MX5000000 is_prime [1] * MX is_prime[0]is_prime[1]0 for i in range(2, MX):if is_prime[i]:for j in range(i * i, MX, i):#循环每次增加iis_prime[j] 0 class Solution:def countPrimes(self, n: int) -> int:return sum(is_prim…