pytorch学习笔记

文章目录

  • 前言
  • 一、What is PyTorch
  • 二、Training Neural Networks
  • 三、Training&Testing Neural Networks
  • 四、Tensors
  • 五、Training&Testing Neural Networks
  • 六、torch.nn
  • 七、Neural Network Training Setup
  • 总结


前言

PyTorch 是一个流行的深度学习框架,具有动态计算图的特性,广泛应用于研究和生产环境。它的灵活性、简单的接口和与 Python 的深度集成使其成为构建机器学习模型的理想工具。


一、What is PyTorch

PyTorch是python中机器学习的框架

两个主要特点:
将高维的矩阵运算用GPU进行运算
用于训练深度神经网络的自动微分

在这里插入图片描述
在训练神经网络时需要计算一些梯度的部分,PyTorch已经将这些部分打包可以很方便的使用

二、Training Neural Networks

在这里插入图片描述

训练神经网络的三个步骤

机器学习笔记-1

机器学习笔记-2

三、Training&Testing Neural Networks

在这里插入图片描述

数据集&数据加载器

在这里插入图片描述

Dataset:Dataset 是数据样本和标签的集合。在 PyTorch 中,Dataset 用于存储数据样本和对应的期望输出值。它是所有数据处理的基础。
Dataloader:Dataloader 是 PyTorch 中用于将数据分批(batch)加载的工具,它将 Dataset 中的数据按照设定的批量大小(batch_size)进行分组,并能启用多线程来提高数据加载的效率。

代码示例:
Dataset:

dataset = MyDataset(file)

这里 MyDataset 是自定义的 Dataset 类,传入的数据文件 file 包含了所需的数据。

Dataloader:

dataloader = DataLoader(dataset, batch_size, shuffle=True)

这个 DataLoader 将数据集 dataset 载入,设定的 batch_size 是每次训练中处理的样本数量,并且设置 shuffle=True 代表在训练过程中数据会随机打乱。

Training:当训练模型时,shuffle=True 是常见的设置,以确保数据的顺序不会影响模型训练。
Testing:在测试时,通常不打乱数据,所以 shuffle=False。

在这里插入图片描述
导入模块

from torch.utils.data import Dataset, DataLoader

PyTorch 提供了 Dataset 和 DataLoader 模块,Dataset 是用于存储和管理数据的基类,DataLoader 则用于批量加载数据。

自定义 Dataset 类

class MyDataset(Dataset):def __init__(self, file):self.data = ...

__init__ 方法:在初始化时,读取并处理数据。这里会从 file 中加载数据并存储在 self.data 中。

在实际应用中,这部分可能包括加载 CSV 文件、图像文件或其他格式的数据,并进行必要的预处理。

实现 __getitem__ 方法

def __getitem__(self, index):return self.data[index]

__getitem__:用于通过索引获取数据集中的一个样本。每次调用时,返回指定索引位置的样本。Dataloader 会在训练时多次调用这个方法。

实现 __len__ 方法

def __len__(self):return len(self.data)

__len__:返回数据集的样本总数。这个方法在使用 Dataloader 时用于确定一共有多少个批次(batch)。

PyTorch 中 DataLoader 从 Dataset 中按批次加载数据的过程。

在这里插入图片描述

四、Tensors

Tensors的形状及其维度概念,并通过 .shape() 方法来检查张量的形状。

在这里插入图片描述

Creat Tensors

在这里插入图片描述

Common Operations

张量的一些常见运算

在这里插入图片描述

import torch# 创建张量
x = torch.tensor([1, 2, 3])
y = torch.tensor([4, 5, 6])# 加法
z_add = x + y
print(z_add)  # 输出: tensor([5, 7, 9])# 减法
z_sub = x - y
print(z_sub)  # 输出: tensor([-3, -3, -3])# 求平方
z_pow = x.pow(2)
print(z_pow)  # 输出: tensor([1, 4, 9])# 求和
z_sum = x.sum()
print(z_sum)  # 输出: tensor(6)# 求平均值
z_mean = x.mean()
print(z_mean)  # 输出: tensor(2.)

在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述

Data Type

在这里插入图片描述

PyTorch v.s. NumPy

在这里插入图片描述
在这里插入图片描述

Device

在这里插入图片描述
在这里插入图片描述

Gradient Calculation

使用 自动求导(autograd)来计算张量的梯度。

在这里插入图片描述

五、Training&Testing Neural Networks

在PyTorch中训练和测试神经网络的流程,包括几个关键步骤:定义网络结构、损失函数、优化算法,以及训练、验证、测试的整体流程。

在这里插入图片描述

六、torch.nn

nn.Linear(in_features, out_features) 是 PyTorch 中的全连接层,也叫线性层(Linear Layer)。它将输入张量的每一个特征映射到输出特征空间。

  • in_features:输入的特征数,即输入张量的最后一维大小。
  • out_features:输出的特征数,即输出张量的最后一维大小。

在这里插入图片描述

  • nn.Linear 是 PyTorch 中用于全连接层的模块,用于将输入特征空间线性映射到输出特征空间。
  • 输入张量的形状:输入张量的最后一个维度必须与 in_features 匹配,其他维度可以任意。
  • 输出张量的形状:输出张量的最后一个维度为 out_features,其他维度与输入张量相同。

Network Layers

在这里插入图片描述
在这里插入图片描述

  • 全连接层:每个输入特征与每个输出神经元都有一个权重连接,计算方式为输入张量乘以权重矩阵加上偏置向量。
  • 矩阵表示:通过矩阵乘法快速实现大规模输入特征到输出特征的映射,PyTorch 中的 nn.Linear 实现了这一操作。
  • 输入输出维度:输入维度与权重矩阵的列数相同,输出维度与权重矩阵的行数相同。

Network Parameters

PyTorch中nn.Linear全连接层的网络参数——权重和偏置

在这里插入图片描述
W 是权重矩阵,大小为[64,32]。
x 是输入张量,大小为 [32]。
b 是偏置向量,大小为 [64]。
y 是输出张量,大小为 [64]。

Non-Linear Activation Functions

两种非线性激活函数:Sigmoid 和 ReLU。

在这里插入图片描述
激活函数引入了非线性,使神经网络能够拟合更复杂的模式。线性层(如 nn.Linear)本身是线性的,如果不加入非线性激活函数,整个神经网络的表现将退化为线性变换,无法处理复杂的非线性问题。

Build your own neural network

用 PyTorch 的 torch.nn 模块来自定义一个神经网络模型。

在这里插入图片描述
导入必要的模块:

import torch.nn as nn
  • torch.nn 是 PyTorch 中的神经网络模块,提供了构建神经网络模型所需的所有层和功能。

定义一个继承自 nn.Module 的模型类:

class MyModel(nn.Module):def __init__(self):super(MyModel, self).__init__()self.net = nn.Sequential(nn.Linear(10, 32),nn.Sigmoid(),nn.Linear(32, 1))
  • 这里定义了一个名为 MyModel 的类,继承自 nn.Module,这是 PyTorch 中所有神经网络的基类。每个自定义模型都需要继承这个类。
    init() 方法中:
    使用 super() 来初始化父类(nn.Module)。
    self.net 定义了网络的层次结构:
    nn.Linear(10, 32):一个全连接层,输入维度为 10,输出维度为 32。
    nn.Sigmoid():使用 Sigmoid 激活函数。
    nn.Linear(32, 1):另一个全连接层,将输出维度从 32 映射到 1。
    这些层被组合在 nn.Sequential 中,这是一个容器,用来按顺序堆叠层。

定义前向传播(forward)函数:

def forward(self, x):return self.net(x)
  • forward() 函数是 PyTorch 中所有模型的核心,用来定义数据的前向传播过程。
  • 在这个函数中,输入数据 x 通过 self.net(即模型的所有层)进行计算,并返回最终的输出结果。

nn.Module 是所有神经网络的基类,定义自定义模型时需要继承它。
nn.Sequential 是一个按顺序堆叠多个层的容器,可以简化神经网络的定义。
forward() 函数 定义了前向传播的逻辑,用于计算模型的输出。

在这里插入图片描述

nn.Sequential:通过一个容器,将各层按顺序定义,前向传播时一次传递输入。使用方便,适合网络层数较少且没有复杂结构的模型。
手动定义各层:显式地定义每一层并在 forward 函数中逐层调用,灵活性更高。适合在前向传播中有更复杂操作或跳跃连接(如 ResNet 中的残差连接)。

Loss Functions

在这里插入图片描述

优化算法模块 torch.optim

在这里插入图片描述
在这里插入图片描述
优化算法:用于通过梯度更新模型的参数,减少误差。
随机梯度下降(SGD):经典的优化算法,简单有效。
PyTorch 实现:通过 torch.optim.SGD 或其他优化器进行参数更新。
学习率和动量:是优化器的重要超参数,影响模型的收敛速度和稳定性。

七、Neural Network Training Setup

在这里插入图片描述
加载数据集

dataset = MyDataset(file)
  • 使用自定义的 MyDataset 类来读取数据,file 是数据文件的路径或名称。MyDataset 是一个继承自torch.utils.data.Dataset 的类,包含了数据的读取和预处理逻辑。

创建 DataLoader

tr_set = DataLoader(dataset, batch_size=16, shuffle=True)
  • DataLoader 是 PyTorch 中用于分批次加载数据的工具。这里将 dataset 数据集以批量大小为 16 进行分组,设置shuffle=True 使数据在每个 epoch 开始时被随机打乱。
  • batch_size=16:每次从数据集中加载 16 个样本。
  • shuffle=True:表示在每个 epoch 训练前打乱数据,以防止模型过拟合数据的顺序。

构建模型并将其移动到设备

model = MyModel().to(device)
  • 构建了一个自定义的神经网络模型 MyModel,并将模型移动到指定的设备(如 GPU 或 CPU)。device 可以是 torch.device(‘cuda’) 或 torch.device(‘cpu’),根据你的硬件设置决定。
  • .to(device):确保模型和数据被移动到同一设备(如 GPU),以加速计算。

设置损失函数

criterion = nn.MSELoss()
  • criterion 定义了损失函数,这里使用的是 均方误差损失(MSELoss),常用于回归任务。它计算模型预测值与目标值之间的平方差异。
  • nn.MSELoss():适用于回归任务,输出是预测值和目标值之间的均方误差。

设置优化器

optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
  • optimizer 定义了优化算法,这里使用的是 随机梯度下降(SGD) 优化器。它负责根据计算出的梯度更新模型参数,以最小化损失函数。
  • model.parameters():返回模型的所有可训练参数。
  • lr=0.1:学习率,决定每次参数更新时的步长。较大的学习率可能加速收敛,但过大会导致不稳定。

在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述

在这里插入图片描述


总结

PyTorch 是一个灵活、强大的深度学习框架,广泛应用于研究和实际生产中。它提供了动态计算图、GPU 加速、易用的 API 和丰富的社区资源,适合从初学者到专家的各种需求。通过 PyTorch,你可以轻松构建、训练和优化各种复杂的神经网络模型。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/447475.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

实战篇:(三)项目实战Vue 3 + WebGL 创建一个简单的 3D 渲染应用

Vue 3 WebGL 创建一个简单的 3D 渲染应用 我们将使用 Vue 3 和 WebGL 创建一个简单的 3D 渲染应用。项目将展示如何在 Vue 组件中集成 WebGL,并渲染一个旋转的立方体。 1. 项目准备 首先,确保你已经安装了 Node.js 和 Vue CLI。如果还没有安装&#x…

解密京东详情 API 接口:获取与运用指南

一、什么是京东详情API接口? 京东详情API接口是京东开放平台提供的一种服务,允许开发者通过编程方式获取商品的详细信息。通过调用这个接口,你可以获取到商品的基本信息、价格、库存、评价等数据。这些数据可以帮助你更好地了解商品的情况&a…

基于WebSocket实现简易即时通讯功能

代码实现 pom.xml <dependency><groupId>org.springframework.boot</groupId><artifactId>spring-boot-starter-web</artifactId> </dependency> <dependency><groupId>org.springframework.boot</groupId><artifa…

2024最新分别用sklearn和NumPy设计k-近邻法对鸢尾花数据集进行分类(包含详细注解与可视化结果)

本文章代码实现以下功能&#xff1a; 利用sklearn设计实现k-近邻法。 利用NumPy设计实现k-近邻法。 将设计的k-近邻法对鸢尾花数据集进行分类&#xff0c;通过准确率来验证所设计算法的正确性&#xff0c;并将分类结果可视化。 评估k取不同值时算法的精度&#xff0c;并通过…

HTML CSS 基础

HTML & CSS 基础 HTML一、HTML简介1、网页1.1 什么是网页1.2 什么是HTML1.3 网页的形成1.4总结 2、web标准2.1 为什么需要web标准2.2 Web 标准的构成 二、HTML 标签1、HTML 语法规范1.1基本语法概述1.2 标签关系 2、 HTML 基本结构标签2.1 第一个 HTML 网页2.2 基本结构标签…

uniapp 游戏 - 使用 uniapp 实现的扫雷游戏

0. 思路 1. 效果图 2. 游戏规则 扫雷的规则很简单。盘面上有许多方格,方格中随机分布着一些雷。你的目标是避开雷,打开其他所有格子。一个非雷格中的数字表示其相邻 8 格子中的雷数,你可以利用这个信息推导出安全格和雷的位置。你可以用右键在你认为是雷的地方插旗(称为标…

中华春节符号·世界品牌——粤港澳企(实)业协会商会经济合作座谈会成功举办

日前&#xff0c;一场旨在推动粤港澳三地经济深度合作的盛会——《粤港澳企&#xff08;实&#xff09;业协会商会经济合作座谈会》在广州市天河区时代TIT广场2栋801车陂社区文化中心隆重举行。此次活动由泰康之家粤园与广东经贸文化促进会联合主办&#xff0c;吸引了全球华人企…

Dubbo SpringBoot应用创建和K8S部署

推荐阅读&#xff1a;Dubbo 快速入门-CSDN博客 创建基于Spring Boot的微服务应用 以下文档将引导您从头创建一个基于 Spring Boot 的 Dubbo 应用&#xff0c;并为应用配置 Triple 通信协议、服务发现等微服务基础能力。 快速创建应用 以下是我们为您提前准备好的示例项目&am…

AI大模型开发架构设计(12)——以真实场景案例驱动深度剖析 AIGC 新时代 IT 人的能力模型

文章目录 以真实场景案例驱动深度剖析 AIGC 新时代 IT 人的能力模型1 AI 工具以及大模型会给我们带来哪些实际收益?业务研发流程环节为什么 LLM 大模型不是万能的?LLM 大模型带来实际收益 2 新时代IT人的能力模型会发生哪些变化?古典互联网架构师能力模型IT人能力模型变化 以…

这都能封!开发者行为导致Google账号关联?

从去年10月开始&#xff0c;在AI加持下&#xff0c;Google Play不断更新和迭代审查机制&#xff0c;Google Play在最近一年的时间真是杀疯了&#xff0c;封号的声音响彻整个行业&#xff0c;尤其是一些敏感品类行业。账号关联&#xff0c;恶意软件&#xff0c;欺骗行为&#xf…

小红书新ID保持项目StoryMaker,面部特征、服装、发型和身体特征都能保持一致!(已开源)

继之前和大家介绍的小红书在ID保持以及风格转换方面相关的优秀工作&#xff0c;感兴趣的小伙伴可以点击以下链接阅读~ 近期&#xff0c;小红书又新开源了一款文生图身份保持项目&#xff1a;StoryMaker&#xff0c;是一种个性化解决方案&#xff0c;它不仅保留了面部的一致性&…

贪吃蛇游戏(代码篇)

我们并不是为了满足别人的期待而活着。 前言 这是我自己做的第五个小项目---贪吃蛇游戏&#xff08;代码篇&#xff09;。后期我会继续制作其他小项目并开源至博客上。 上一小项目是贪吃蛇游戏&#xff08;必备知识篇&#xff09;&#xff0c;没看过的同学可以去看看&#xf…

Angular Count-To 项目教程

Angular Count-To 项目教程 angular-count-to Angular directive to animate counting to a number 项目地址: https://gitcode.com/gh_mirrors/an/angular-count-to 1. 项目介绍 Angular Count-To 是一个用于 AngularJS 的动画计数器指令。该指令可以在指定的时间内从…

Lfsr32

首先分析 Lfsr5 首先要理解什么是抽头点&#xff08;tap&#xff09;&#xff0c;注意到图中有两个触发器的输入为前级输出与q[0]的异或&#xff0c;这些位置被称为 tap position.通过观察上图&#xff0c;所谓抽头点指的就是第5个&#xff0c;第3个寄存器的输入经过了异或逻辑…

利用C++封装鼠标轨迹算法为DLL:游戏行为检测的利器

在现代软件开发中&#xff0c;鼠标轨迹模拟技术因其在自动化测试、游戏脚本编写等领域的广泛应用而备受青睐。本文将介绍如何使用C语言将鼠标轨迹算法封装为DLL&#xff08;动态链接库&#xff09;&#xff0c;以便在多种编程环境中实现高效调用&#xff0c;同时探讨其在游戏行…

cudnn8编译caffe过程(保姆级图文全过程,涵盖各种报错及解决办法)

众所周知,caffe是个较老的框架,而且只支持到cudnn7,但是笔者在复现ds-slam过程中又必须编译caffe,我的cuda版本是11.4,最低只支持到8.2.4,故没办法,只能编译了 在此记录过程、报错及解决办法如下; 首先安装依赖: sudo apt-get install git sudo apt-get install lib…

【IEEE独立出版 | 厦门大学主办】第四届人工智能、机器人和通信国际会议(ICAIRC 2024)

【IEEE独立出版 | 厦门大学主办】 第四届人工智能、机器人和通信国际会议&#xff08;ICAIRC 2024&#xff09; 2024 4th International Conference on Artificial Intelligence, Robotics, and Communication 2024年12月27-29日 | 中国厦门 >>往届均已成功见刊检索…

【Kubernetes① 基础】一、容器基础

目录 一、进程二、隔离与限制三、容器镜像总结参考书籍 一、进程 容器技术的兴起源于PaaS技术(平台即服务)的普及&#xff1b;Docker公司发布的Docker项目具有里程碑式的意义&#xff1b;Docker项目通过“容器镜像”解决了应用打包这个根本性难题(CloudFoundry)。 容器本身的价…

【QAMISRA】解决导入commands.json时报错问题

1、 文档目标 解决导入commands.json时报错“Could not obtain system-wide includes and defines”的问题。 2、 问题场景 客户导入commands.json时报错“Could not obtain system-wide includes and defines”。 3、软硬件环境 1、软件版本&#xff1a; QA-MISRA23.04 2、…

【电路笔记】-运算放大器多谐振荡器

运算放大器多谐振荡器 文章目录 运算放大器多谐振荡器1、概述2、施密特触发器3、运算放大器稳态多谐振荡器4、运算放大器单稳态多谐振荡器5、运算放大器双稳态多谐振荡器6、总结1、概述 本文将重点介绍通常称为多谐振荡器的配置,特别是基于运算放大器的电路。 事实上,多谐振…