【动手学深度学习Pytorch】6. LeNet实现代码

        LeNet(LeNet-5)由两个部分组成:卷积编码器和全连接层密集块

x.view(): 对tensor进行reshape

import torch
from torch import nn
from d2l import torch as d2lclass Reshape(torch.nn.Module):def forward(self, x):return x.view(-1, 1, 28, 28)net = torch.nn.Sequential(Reshape(), nn.Conv2d(1, 6, kernel_size = 5, padding = 2), nn.Sigmoid(),nn.AvgPool2d(2, stride = 2),nn.Conv2d(6, 16, kernel_size = 5), nn.Sigmoid(),nn.AvgPool2d(kernel_size=2, stride=2), nn.Flatten(),nn.Linear(16 * 5 * 5, 120), nn.Sigmoid(),nn.Linear(120, 84), nn.Sigmoid(),nn.Linear(84, 10)
)

torch.rand():生成随机数

layer.__class__.__name__:用self.__class__将实例变量指向类,然后再去调用__name__类属性

X = torch.rand(size=(1, 1, 28, 28), dtype=torch.float32)
for layer in net:X = layer(X)print(layer.__class__.__name__, 'output shape: \t', X.shape)

batch_size = 256
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size=batch_size)

net.eval(): 变成评估模式

d2l.accuracy(y_hat, y):用于计算模型预测的准确性,即预测值与真实标签的匹配程度。

# 使用gpu进行训练
def evaluate_accuracy_gpu(net, data_iter, device=None):if isinstance(net, torch.nn.Module):net.eval()if not device: #如果没有给定device,遍历网络参数进行设备选择device = next(iter(net.parameters())).device#  如果device未提供,获取net中参数的设备,自动决定是在CPU还是GPU上运行metric = d2l.Accumulator(2) # 累计预测正确的数量(分子)和样本总数(分母)for X, y in data_iter:if isinstance(X, list):X = [x.to(device) for x in X]else:X = X.to(device)y = y.to(device)metric.add(d2l.accuracy(net(X), y), y.numel())# 将预测正确的数量和样本总数分别累加# 获取当前批次中标签的数量(样本总数)return metric[0] / metric[1]

d2l.Timer(): 创建一个计时器实例,用于记录时间

timer.stop(): 停止计时,并返回从启动到停止的总运行时间

def train_ch6(net, train_iter, test_iter, num_epoch2, lr, device):def init_weights(m):if type(m) == nn.Linear or type(m) == nn.Conv2d:nn.init.xavier_uniform_(m.weight)net.apply(init_weights)print('training on', device)net.to(device)optimizer = torch.optim.SGD(net.parameters(), lr=lr)loss = nn.CrossEntropyLoss()animator=d2l.Animator(xlabel='eopch', xlim=[1, num_epochs],legend=['train loss','train acc','test acc'])timer, num_batches = d2l.Timer(), len(train_iter)for epoch in range(num_epochs):metric = d2l.Accumulator(3)net.train()for i, (X, y) in enumerate(train_iter):timer.start()optimizer.zero_grad()X, y = X.to(device), y.to(device)y_hat = net(X)l = loss(y_hat, y)l.backward()optimizer.step()metric.add(1*X.shape[0], d2l.accuracy(y_hat, y),X.shape[0])timer.stop()train_l = metric[0] / metric[2]train_acc = metric[1] / metric[2]if(i+1)%(num_batches//5)==0 or i ==num_batches-1:animator.add(epoch+(i+1)/num_batches, (train_l, train_acc, None))test_acc = evaluate_accuracy_gpu(net, test_iter)animator.add(epoch+1, (None, None, test_acc))print(f'loss{train_l:.3f},train acc{train_acc:.3f},'f'test acc{test_acc:.3f}')print(f'{metric[2]*num_epochs/timer.sum():.1f} examples/sec'f'on{str(device)}')

d2l.try_gpu():如果有可用 GPU,则返回 GPU 设备;否则返回 CPU 设备。

lr, num_epochs = 0.9, 10
train_ch6(net, train_iter, test_iter, num_epochs, lr, d2l.try_gpu())

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/474553.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

AI工具百宝箱|任意选择与Chatgpt、gemini、Claude等主流模型聊天的Anychat,等你来体验!

文章推荐 AI工具百宝箱|使用Deep Live Cam,上传一张照片就可以实现实时视频换脸...简直太逆天! Anychat 这是一款可以与任何模型聊天 (chatgpt、gemini、perplexity、claude、metal llama、grok 等)的应用。 在页面…

Excel数据动态获取与映射

处理代码 动态映射 动态读取 excel 中的数据,并通过 json 配置 指定对应列的值映射到模板中的什么字段上 private void GetFreightFeeByExcel(string filePath) {// 文件名需要以快递公司命名 便于映射查询string fileName Path.GetFileNameWithoutExtension(fi…

SRP 实现 Cook-Torrance BRDF

写的很乱! BRDF(Bidirectional Reflectance Distribution Function)全称双向反射分布函数。辐射量单位非常多,这里为方便直观理解,会用非常不严谨的光照强度来解释说明。 BRDF光照模型,上反射率公式&#…

[代码随想录Day16打卡] 找树左下角的值 路径总和 从中序与后序遍历序列构造二叉树

找树左下角的值 定义:二叉树中最后一行最靠左侧的值。 前序,中序,后序遍历都是先遍历左然后遍历右。 因为优先遍历左节点,所以递归中因为深度增加更新result的时候,更新的值是当前深度最左侧的值,到最后就…

【第七节】在RadAsm中使用OllyDBG调试器

前言 接着本专栏上一节,我们虽然已经用上RadAsm进行编写x86汇编代码并编译运行,但是想进行断点调试怎么办。RadAsm里面找不到断点调试,下面我们来介绍如何在RadAsm上联合调试器OllyDBG进行调试代码。 OllyDBG的介绍与下载 OllyDBG 是一款功能…

WPF MVVM框架

一、MVVM简介 MVC Model View Control MVP MVVM即Model-View-ViewModel,MVVM模式与MVP(Model-View-Presenter)模式相似,主要目的是分离视图(View)和模型(Model),具有低…

PH热榜 | 2024-11-19

DevNow 是一个精简的开源技术博客项目模版,支持 Vercel 一键部署,支持评论、搜索等功能,欢迎大家体验。 在线预览 1. Layer 标语:受大脑启发的规划器 介绍:体验一下这款新一代的任务和项目管理系统吧!它…

【ArcGISPro】使用AI模型提取要素-提取车辆(目标识别)

示例数据下载 栅格数据从网上随便找一个带有车辆的栅格数据 f094a6b1e205cd4d30a2e0f816f0c6af.jpg (1200799) (588ku.com) 添加数据

联通光猫(烽火通信设备)改桥接教程

一、获得超级密码 1.打开telnet连接权限 http://192.168.1.1/telnet?enable1&key9070D3BECD70(MAC地址)2.连接光猫获取密码 telnet 192.168.1.1 用户名:admin 密码:Fh9070D3BECD70连接成功后 load_cli factory show admin_…

掌握SEO提升网站流量的关键在于长尾关键词的有效运用

内容概要 在现代数字营销中,搜索引擎优化(SEO)被广泛视为提升网站流量的核心策略之一,而其中长尾关键词的运用显得尤为重要。长尾关键词通常由三个或更多个词组成,具有更高的针对性和精确度,可以更好地满足…

【期权懂|个股期权中的备兑开仓策略是如何进行的?

期权小懂每日分享期权知识,帮助期权新手及时有效地掌握即市趋势与新资讯! 个股期权中的备兑开仓策略是如何进行的? 个股期权备兑开仓的优点和风险‌: ‌(1)优点‌:备兑开仓可以增强持股收益&…

汽车安全再进化 - SemiDrive X9HP 与环景影像系统 AVM 的系统整合

当今汽车工业正面临著前所未有的挑战与机遇,随著自动驾驶技术的迅速发展,汽车的安全性与性能需求日益提高。在这样的背景下,汽车 AVM(Automotive Visual Monitoring)标准应运而生,成为促进汽车智能化和安全…

MongoDB聚合操作

管道的聚合 管道在Unix和Linux中一般用于将当前命令的输出结果作为下一个命令的参数。 MongoDB的聚合管道将MongoDB文档在一个管道处理完毕后将结果传递给下一个管道处理。管道操作是可以重复的。 表达式:处理输入文档并输出。表达式是无状态的,只能用…

向量数据库FAISS之五:原理(LSH、PQ、HNSW、IVF)

1.Locality Sensitive Hashing (LSH) 使用 Shingling MinHashing 进行查找 左侧是字典,右侧是 LSH。目的是把足够相似的索引放在同一个桶内。 LSH 有很多的版本,很灵活,这里先介绍第一个版本,也是原始版本 Shingling one-hot …

https(day30)

1.配置需要配置端口为443 2.配置需要配置证书 ssl_certificate /path/to/your/fullchain.pem; # 证书文件 ssl_certificate_key /path/to/your/private.key; # 私钥文件 3.其他优化

【WPF】Prism学习(七)

Prism Dependency Injection 1.注册类型(Registering Types) 1.1. Prism中的服务生命周期: Transient(瞬态):每次请求服务或类型时,都会获得一个新的实例。Singleton(单例&#xf…

服务器数据恢复—热备盘未激活导致硬盘掉线的raid5阵列崩溃的数据恢复案例

服务器数据恢复环境: 某品牌X3850服务器中有一组由数块SAS硬盘组建的RAID5阵列,该阵列中有一块盘是热备盘。操作系统为linux redhat,上面跑着一个基于oracle数据库的oa。 服务器故障: 服务器raid5阵列中有一块硬盘离线&#xff0…

ADS 2022软件下载与安装教程

“ 本文以最新的Advanced Design System 2022为例介绍ADS软件的安装及crack教程 ” ADS 简介 先进设计系统 Advanced Design system(ADS)Agilent Technologies 是领先的电子设计自动化软件,适用于射频、微波和信号完整性应用。ADS 是获得商…

Chrome 浏览器 131 版本新特性

Chrome 浏览器 131 版本新特性 一、Chrome 浏览器 131 版本更新 1. 在 iOS 上使用 Google Lens 搜索 自 Chrome 126 版本以来,用户可以通过 Google Lens 搜索屏幕上看到的任何图片或文字。 要使用此功能,请访问网站,并点击聚焦时出现在地…

Unity 编辑器下 Android 平台 Addressable 加载模型粉红色,类似材质丢失

Unity 编辑器下 Android 平台 Addressable 加载模型粉红色,类似材质丢失 Addressable Play Mode Script加载模式 选择 Use Existiing Build 1.Unity 切换到 PC 平台,执行 Addressable Build 运行,加载 bundle 内的预制体 显示正常 2.Unit…