2-5 softmax 回归的简洁实现

我们发现通过深度学习框架的高级API能够使实现线性回归变得更加容易。 同样,通过深度学习框架的高级API也能更方便地实现softmax回归模型。 本节如在上节中一样, 继续使用Fashion-MNIST数据集,并保持批量大小为256。

import torch
from torch import nn  # 通过pytorch的nn的module
from d2l import torch as d2lbatch_size = 256
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size)

初始化模型参数

softmax回归的输出层是一个全连接层。 因此,为了实现我们的模型, 我们只需在Sequential中添加一个带有10个输出的全连接层。 同样,在这里Sequential并不是必要的, 但它是实现深度模型的基础。 我们仍然以均值0和标准差0.01随机初始化权重。

# PyTorch不会隐式地调整输入的形状。因此,
# 我们在线性层前定义了展平层(flatten),来调整网络输入的形状
net = nn.Sequential(nn.Flatten(), nn.Linear(784, 10)) # 这里使用了PyTorch中的nn.Sequential来构建一个顺序容器,将层按顺序添加到网络中
# nn.Flatten():这一层的作用是将输入的数据展平成一维。假设输入的数据是一个28x28的二维图像,展平后将变成一个784(28*28)长度的一维向量。
# nn.Linear(784, 10):这是一个全连接层(线性层),输入大小为784(展平后的图像向量),输出大小为10(假设有10个类别)。def init_weights(m):  # 这里定义了一个函数init_weights,用于初始化网络中的权重。if type(m) == nn.Linear: # 这行代码检查传入的层是否为nn.Linear类型,即全连接层。nn.init.normal_(m.weight, std=0.01) # 如果该层是全连接层,则使用nn.init.normal_方法将该层的权重初始化为均值为0,标准差为0.01的正态分布随机值。net.apply(init_weights);  # net.apply方法会遍历网络中的每一层,并将init_weights函数应用到每一层上,完成权重的初始化。

在交叉熵损失函数中传递未归一化的预测,并同时softmax及其对数

loss = nn.CrossEntropyLoss(reduction='none')

在这里,我们使用学习率为0.1的小批量随机梯度下降作为优化算法。 这与我们在线性回归例子中的相同,这说明了优化器的普适性。

trainer = torch.optim.SGD(net.parameters(), lr=0.1)

训练

接下来我们调用之前定义的训练函数来训练模型。

num_epochs = 10
d2l.train_ch3(net, train_iter, test_iter, loss, num_epochs, trainer)

在这里插入图片描述


本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/370235.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【基础篇】第4章 Elasticsearch 查询与过滤

在Elasticsearch的世界里,高效地从海量数据中检索出所需信息是其核心价值所在。本章将深入解析查询与过滤的机制,从基础查询到复合查询,再到全文搜索与分析器的定制,为你揭开数据检索的神秘面纱。 4.1 基本查询 4.1.1 Match查询…

多语言版在线出租车预订完整源码+用户应用程序+管理员 Laravel 面板+ 司机应用程序最新版源码

源码带PHP后台客户端源码 Flutter 是 Google 开发的一款开源移动应用开发 SDK。它用于开发 Android 和 iOS 应用,也是为 Google Fuchsia 创建应用的主要方法。Flutter 小部件整合了所有关键的平台差异,例如滚动、导航、图标和字体,可在 iOS 和…

如何在前端网页实现live2d的动态效果

React如何在前端网页实现live2d的动态效果 业务需求: 因为公司需要做机器人相关的业务,主要是聊天形式的内容,所以需要一个虚拟的卡通形象。而且为了更直观的展示用户和机器人对话的状态,该live2d动画的嘴型需要根据播放的内容来…

昇思25天学习打卡营第08天 | 模型训练

昇思25天学习打卡营第08天 | 模型训练 文章目录 昇思25天学习打卡营第08天 | 模型训练超参数损失函数优化器优化过程 训练与评估总结打卡 模型训练一般遵循四个步骤: 构建数据集定义神经网络模型定义超参数、损失函数和优化器输入数据集进行训练和评估 构建数据集和…

【Excel、RStudio计算T检测的具体操作步骤】

目录 一、基础知识1.1 显著性检验1.2 等方差T检验、异方差T检验1.3 单尾p、双尾p1.3.1 检验目的不同1.3.2 用法不同1.3.3 如何选择 二、Excel2.1 统计分析工具2.1.1 添加统计分析工具2.1.2 数据分析 2.2 公式 -> 插入函数 -> T.TEST 三、RStudio 一、基础知识 参考: 1.…

无人机运营合格证及无人机驾驶员合格证(AOPA)技术详解

无人机运营合格证及无人机驾驶员合格证(AOPA)技术详解如下: 一、无人机运营合格证 无人机运营合格证是无人机运营企业或个人必须获得的证书,以确保无人机在运营过程中符合相关法规和标准。对于无人机运营合格证的具体要求和申请…

【React】React18 Hooks之useState

目录 useState案例1(直接修改状态)案例2(函数式更新)案例3(受控表单绑定)注意事项1:set函数不会改变正在运行的代码的状态注意事项2:set函数自动批量处理注意事项3:在下次…

【反悔堆 优先队列 临项交换 决策包容性】630. 课程表 III

本文涉及知识点 贪心 反悔堆 优先队列 临项交换 Leetcode630. 课程表 III 这里有 n 门不同的在线课程,按从 1 到 n 编号。给你一个数组 courses ,其中 courses[i] [durationi, lastDayi] 表示第 i 门课将会 持续 上 durationi 天课,并且必…

E4.【C语言】练习:while和getchar的理解

#include <stdio.h> int main() {int ch 0;while ((ch getchar()) ! EOF){if (ch < 0 || ch>9)continue;putchar(ch);}return 0; } 理解上述代码 0-->48 9-->57 if行判断是否为数字&#xff0c;打印数字&#xff0c;不打印非数字

Selenium 切换 frame/iframe

环境&#xff1a; Python 3.8 selenium3.141.0 urllib31.26.19说明&#xff1a; driver.switch_to.frame() # 将当前定位的主体切换为frame/iframe表单的内嵌页面中 driver.switch_to.default_content() # 跳回最外层的页面# 判断元素是否在 frame/ifame 中 # 126 邮箱为例 # …

k8s公网集群安装(1.23.0)

网上搜到的公网搭建k8s都不太一致, 要么说的太复杂, 要么镜像无法下载, 所以写了一个简洁版,小白也能一次搭建成功 使用的都是centos7,k8s版本为1.23.0 使用二台机器搭建的, 三台也是一样的思路1.所有节点分别设置对应主机名 hostnamectl set-hostname master hostnamectl set…

javaSwing图书管理系统

一、 引言 图书管理系统是一个用于图书馆或书店管理图书信息、借阅记录和读者信息的应用程序。本系统使用Java Swing框架进行开发&#xff0c;提供直观的用户界面&#xff0c;方便图书馆管理员或书店工作人员对图书信息进行管理。以下是系统的设计、功能和实现的详细报告。 二…

Matplotlib Artist 1 概览

Matplotlib API中有三层 matplotlib.backend_bases.FigureCanvas&#xff1a;绘制区域matplotlib.backend_bases.Renderer&#xff1a;控制如何在FigureCanvas上绘制matplotlib.artist.Artist&#xff1a;控制render如何进行绘制 开发者95%的时间都是在使用Artist。Artist有两…

【C语言】自定义类型:联合和枚举

前言 前面我们学习了一种自定义类型&#xff0c;结构体&#xff0c;现在我们学习另外两种自定义类型&#xff0c;联合 和 枚举。 目录 一、联合体 1. 联合体类型的声明 2. 联合体的特点 3. 相同成员联合体和结构体对比 4. 联合体大小的计算 5. 用联合体判断当前机…

常见算法和Lambda

常见算法和Lambda 文章目录 常见算法和Lambda常见算法查找算法基本查找&#xff08;顺序查找&#xff09;二分查找/折半查找插值查找斐波那契查找分块查找扩展的分块查找&#xff08;无规律的数据&#xff09; 常见排序算法冒泡排序选择排序插入排序快速排序递归快速排序 Array…

hdu物联网硬件实验2 GPIO亮灯

学院 班级 学号 姓名 日期 成绩 实验题目 GPIO亮灯 实验目的 点亮三个灯闪烁频率为一秒 硬件原理 无 关键代码及注释 const int ledPin1 GREEN_LED; // the number of the LED pin const int ledPin2 YELLOW_LED; const int ledPin3 RED…

githup开了代理push不上去

你们好&#xff0c;我是金金金。 场景 git push出错 解决 cmd查看 git config --global http.proxy git config --global https.proxy 如果什么都没有&#xff0c;代表没设置全局代理&#xff0c;此时如果你开了代理&#xff0c;则执行如下&#xff0c;设置代理 git con…

【深度学习】图形模型基础(5):线性回归模型第四部分:预测与贝叶斯推断

1.引言 贝叶斯推断超越了传统估计方法&#xff0c;它包含三个关键步骤&#xff1a;结合数据和模型形成后验分布&#xff0c;通过模拟传播不确定性&#xff0c;以及利用先验分布整合额外信息。本文将通过实际案例阐释这些步骤&#xff0c;展示它们在预测和推断中的挑战和应用。…

CSS中 实现四角边框效果

效果图 关键代码 border-radius:10rpx ;background: linear-gradient(#fff, #fff) left top,linear-gradient(#fff, #fff) left top,linear-gradient(#fff, #fff) right top,linear-gradient(#fff, #fff) right top,linear-gradient(#fff, #fff) left bottom,linear-gradient(…

12 Dockerfile详解

目录 1. Dockerfile 2. Dockerfile构建过程 2.1. Dockerfile编写规则&#xff1a; 2.2. Docker执行Dockerfile的大致流程 2.3. 总结 3. Dockerfile指令 3.1. FROM 3.2. MAINTAINER 3.3. RUN 3.4. EXPOSE 3.5. WORKDIR 3.6. USER 3.7. ENV 3.8. VOLUME 3.9. ADD …