Pytorch 三小时极限入门教程

一、引言

在当今的人工智能领域,深度学习占据了举足轻重的地位。而 Pytorch 作为一款广受欢迎的深度学习框架,以其简洁、灵活的特性,吸引了大量开发者投身其中。无论是科研人员探索前沿的神经网络架构,还是工程师将深度学习技术落地到实际项目,Pytorch 都提供了强大的支持。本教程将带你从零基础开始,一步步深入了解 Pytorch 的核心知识,助你顺利踏上深度学习的征程。

二、Pytorch 基础环境搭建

安装 Anaconda

Anaconda 是一个强大的 Python 包管理器和环境管理器,方便我们创建独立的 Python 开发环境。首先,从 Anaconda 官方网站下载对应操作系统的安装包,一路默认安装即可。安装完成后,打开终端(Linux/Mac)或命令提示符(Windows),输入 conda --version 验证是否安装成功。

创建虚拟环境

使用 conda create -n pytorch_env python=3.8 创建一个名为 pytorch_env 的虚拟环境,这里指定 Python 版本为 3.8,你可以根据实际需求调整。激活虚拟环境,在 Linux/Mac 下使用 source activate pytorch_env,Windows 下使用 activate pytorch_env。

安装 Pytorch

访问 Pytorch 官方网站,根据你的系统配置(如 CUDA 是否可用)选择合适的安装命令。例如,如果你的电脑有 NVIDIA GPU 且支持 CUDA 11.3,安装命令可能为 conda install pytorch torchvision torchaudio cudatoolkit=11.3 -c pytorch。如果没有 GPU,则选择 CPU 版本的安装命令,如 conda install pytorch torchvision torchaudio cpuonly -c pytorch。安装完成后,在 Python 交互式环境中输入 import torch,没有报错则说明安装成功。

三、张量(Tensor):深度学习的基石

张量的定义与创建

张量是 Pytorch 中最基本的数据结构,类似于 NumPy 中的数组,但具有更强的功能。可以使用 torch.tensor() 函数从 Python 列表或 NumPy 数组创建张量,例如:

import torchimport numpy as np# 从列表创建张量data_list = [1, 2, 3, 4]tensor_from_list = torch.tensor(data_list)# 从 NumPy 数组创建张量np_array = np.array([5, 6, 7, 8])tensor_from_numpy = torch.from_numpy(np_array)

还可以使用 torch.zeros()、torch.ones()、torch.rand() 等函数创建具有特定形状的全 0、全 1 或随机值张量。

张量的属性与操作

张量具有形状(shape)、数据类型(dtype)等属性。可以通过 .shape 和 .dtype 来访问,例如:

tensor = torch.rand(3, 4)print(tensor.shape)print(tensor.dtype)

张量支持丰富的数学运算,如加法、减法、乘法、除法等,操作符重载使得代码简洁直观:

a = torch.rand(2, 3)b = torch.rand(2, 3)c = a + bd = a * b

同时,也有大量的函数可供调用,像 torch.sum()、torch.mean() 等用于统计计算。

四、自动求导(Autograd):神经网络训练的关键

自动求导原理简介

在深度学习中,模型训练的核心是反向传播算法,而 Pytorch 的自动求导机制极大地简化了这一过程。当创建一个张量时,如果设置 requires_grad=True,Pytorch 会记录该张量上的所有操作,构建一个计算图。在反向传播时,利用这个计算图自动计算梯度。

示例:简单函数求导

x = torch.tensor([2.], requires_grad=True)y = x ** 2 + 3 * xy.backward()print(x.grad)

这里定义了一个简单的函数 ,对 x 求导后,x.grad 存储了梯度值,即 在 时的值 7。

 复杂模型中的应用

在构建神经网络时,模型参数都设置为 requires_grad=True。在每一次前向传播计算损失后,通过 loss.backward() 反向传播梯度,然后使用优化器(如 SGD、Adam 等)根据梯度更新参数,实现模型的训练。

五、神经网络模块(nn.Module):构建模型的利器

自定义神经网络

继承 nn.Module 类可以方便地自定义神经网络。首先在 __init__() 函数中定义模型的层结构,如全连接层 nn.Linear,卷积层 nn.Conv2d 等,然后在 forward() 函数中定义数据的前向传播路径。

import torch.nn as nnclass SimpleNet(nn.Module):def __init__(self):super(SimpleNet, self).__init__()self.fc1 = nn.Linear(10, 20)self.fc2 = nn.Linear(20, 1)def forward(self, x):x = torch.relu(self.fc1(x))x = self.fc2(x)return x

这里定义了一个简单的两层全连接神经网络,输入维度为 10,中间层维度为 20,输出维度为 1,中间使用 ReLU 作为激活函数。

预训练模型的使用与微调

Pytorch 提供了丰富的预训练模型,如 ResNet、VGG 等经典的图像分类模型。可以通过 torchvision.models 模块加载预训练模型,然后根据自己的任务需求,修改最后几层的结构并进行微调。例如:

import torchvision.models as modelsresnet = models.resnet18(pretrained=True)# 修改最后一层输出维度为自定义类别数resnet.fc = nn.Linear(resnet.fc.in_features, 10)

这使得在数据量有限的情况下,也能利用预训练模型的强大特征提取能力,快速搭建高性能模型。

六、数据加载与预处理(DataLoader)

数据集类的构建

要使用自己的数据训练模型,需要构建自定义数据集类,继承 torch.utils.data.Dataset。在类中实现 __getitem__() 方法用于获取单个样本及其标签,__len__() 方法返回数据集的大小。例如,对于图像分类数据集:

from torch.utils.data import Datasetimport osimport cv2class ImageDataset(Dataset):def __init__(self, root_dir, transform=None):self.root_dir = root_dirself.image_files = os.listdir(root_dir)self.transform = transformdef __getitem__(self, index):image_path = os.path.join(self.root_dir, self.image_files[index])image = cv2.imread(image_path)image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)label = int(self.image_files[index].split('.')[0])if self.transform:image = self.transform(image)return image, labeldef __len__(self):return len(self.image_files)

数据加载器的使用

使用 torch.utils.data.DataLoader 将数据集封装成可迭代的数据加载器,方便在训练过程中批量获取数据。可以设置批量大小(batch_size)、是否打乱数据(shuffle)等参数,例如:

from torch.utils.data import DataLoaderdataset = ImageDataset(root_dir='data/images', transform=transforms.ToTensor())dataloader = DataLoader(dataset, batch_size=32, shuffle=True)

在训练循环中,通过遍历数据加载器获取批量数据,送入模型进行训练。

七、模型训练与评估

训练循环

模型训练通常包括多个 epoch,每个 epoch 遍历一遍整个数据集。在每个 epoch 内,按批次获取数据,前向传播计算损失,反向传播更新参数。以下是一个简单的训练循环示例:

model = SimpleNet()criterion = nn.MSELoss()optimizer = torch.optim.SGD(model.parameters(), lr=0.01)for epoch in range(10):running_loss = 0.0for i, (inputs, labels) in enumerate(dataloader):optimizer.zero_grad()outputs = model(inputs)loss = criterion(outputs, labels)loss.backward()optimizer.step()running_loss += loss.item()print(f'Epoch {epoch + 1}, Loss: {running_loss / len(dataloader)}')

评估指标与方法

根据任务不同,评估指标各异。对于分类任务,常用准确率(Accuracy),可以通过比较模型预测结果与真实标签计算得出:

correct = 0total = 0with torch.no_grad():for inputs, labels in dataloader:outputs = model(inputs)_, predicted = torch.max(outputs.data, 1)total += labels.size(0)correct += (predicted == labels).sum().item()accuracy = correct / totalprint(f'Accuracy: {accuracy}')

对于回归任务,可能使用均方误差(MSE)、平均绝对误差(MAE)等指标。

八、模型保存与加载

保存模型

可以使用 torch.save() 保存模型的参数或整个模型结构,例如保存模型参数:

torch.save(model.state_dict(), 'model.pth')

若要保存整个模型,包括结构和参数:

torch.save(model, 'whole_model.pth')

加载模型

加载模型参数时,先创建模型实例,再使用 model.load_state_dict(torch.load('model.pth')) 加载。若加载整个模型,则直接 model = torch.load('whole_model.pth')。加载后,模型即可用于预测或继续训练。

九、可视化工具(TensorBoard)

安装与配置

TensorBoard 是一个强大的可视化工具,用于监控模型训练过程。使用 pip install tensorboard 安装,在 Pytorch 代码中引入相关模块:

from torch.utils.tensorboard import SummaryWriter

创建一个 SummaryWriter 实例,指定日志目录,如 writer = SummaryWriter('logs')。

可视化训练过程

在训练过程中,可以使用 writer.add_scalar() 记录损失、准确率等指标随 epoch 的变化:

for epoch in range(10):# 训练代码...writer.add_scalar('Loss', running_loss / len(dataloader), epoch)writer.add_scalar('Accuracy', accuracy, epoch)writer.close()

运行 tensorboard --logdir=logs 命令后,在浏览器中打开相应地址,即可查看可视化图表,直观了解模型训练动态。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/503038.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

欧科云链研究院:ChatGPT 眼中的 Web3

编辑|OKG Research 转眼间,2024年已经进入尾声,Web3 行业经历了热闹非凡的一年。今年注定也是属于AI的重要一年,OKG Research 决定拉上 ChatGPT 这位“最懂归纳的AI拍档”,尝试把一整年的研究内容浓缩成精华。我们一共…

.NET 9.0 WebApi 发布到 IIS 详细步骤

微软表示,.NET 9 是迄今为止性能最高的 .NET 版本,对运行时、工作负载和语言方面进行了 1,000 多项与性能相关的改进,并采用了更高效的算法来生成更好的代码。 .NET 9 是 .NET 8 的继任者,特别侧重于云原生应用和性能。 作为标准期…

【通识安全】煤气中毒急救的处置

1.煤气中毒的主要症状与体征一氧化碳中毒,其中毒症状一般分为轻、中、重三种。 (1)轻度:仅有头晕、头痛、眼花、心慌、胸闷、恶心等症状。如迅速打开门窗,或将病人移出中毒环境,使之吸入新鲜空气和休息,给些热饮料&am…

ECCV`24 | 首次解决文本到3D NeRFs分解问题!港中文等提出DreamDissector

论文链接:https://arxiv.org/abs/2407.16260 亮点直击 据作者所知,作者是第一个解决文本到3D NeRFs分解问题的团队。 为了解决这个问题,本文引入了一个名为DreamDissector的新颖框架,包括一种新颖的神经类别场(NeCF&a…

nginx-灰度发布策略(split_clients)

一. 简述: 基于客户端的灰度发布(也称为蓝绿部署或金丝雀发布)是一种逐步将新版本的服务或应用暴露给部分用户,以确保在出现问题时可以快速回滚并最小化影响的技术。对于 Nginx,可以通过配置和使用不同的模块来实现基于…

PCL点云库入门——PCL库点云特征之PFH点特征直方图(Point Feature Histograms -PHF)

1、算法原理 PFH点(Point Feature Histogram)特征直方图的原理涉及利用参数化查询点与邻域点之间的空间差异,并构建一个多维直方图以捕捉点的k邻域几何属性。这个高维超空间为特征表示提供了一个可度量的信息空间,对于点云对应曲面…

qml PathView详解

1、概述 PathView 是 Qt Quick 中一个非常强大的视图组件,它基于一个 Path 来展示视图项(如 Item、Rectangle 等)。PathView 可以让你按照定义的路径动态地显示多个元素,并且支持动画、滑动等功能。这个视图控件的最大特点是能够…

网络协议安全的攻击手法

1.使用SYN Flood泛洪攻击: SYN Flood(半开放攻击)是最经典的ddos攻击之一,他利用了TCP协议的三次握手机制,攻击者通常利用工具或控制僵尸主机向服务器发送海量的变源端口的TCP SYN报文,服务器响应了这些报文后就会生成大量的半连…

前端学习DAY31(子元素溢出父元素)

.box1{width: 200px;height: 200px;background-color: chocolate;} 子元素是在父元素的内容区中排列的,如果子元素的大小超过了父元素,则子元素会从 父元素中溢出,使用overflow属性设置父元素如何处理溢出的子元素 可选值:visible…

机器人手眼标定

机器人手眼标定 一、机器人手眼标定1. 眼在手上标定基本原理2. 眼在手外标定基本原理 二、眼在手外标定实验三、标定精度分析 一、机器人手眼标定 要实现由图像目标点到实际物体上抓取点之间的坐标转换,就必须拥有准确的相机内外参信息。其中内参是相机内部的基本参…

【前端下拉框】获取国家国旗

一、先看效果 二、代码实现&#xff08;含国旗&#xff09; <!DOCTYPE html> <html lang"zh"> <head><meta charset"UTF-8"><meta name"viewport" content"widthdevice-width, initial-scale1.0"><…

Timer、Ticker使用及其注意事项

Timer、Ticker使用及其注意事项 在刚开始学习golang语言的时候就听说Timer、Ticker的使用要尤其注意&#xff0c;很容易出现问题&#xff0c;这次就来一探究竟。 本文主要脉络&#xff1a; 介绍定时器体系&#xff0c;并介绍常用使用方式和错误使用方式源码解读 timer、tic…

C++11——2:可变模板参数

一.前言 C11引入了可变模板参数&#xff08;variadic template parameters&#xff09;的概念&#xff0c;它允许我们在模板定义中使用可变数量的参数。这样&#xff0c;我们就可以处理任意数量的参数&#xff0c;而不仅限于固定数量的参数。 二.可变模板参数 我们早在C语言…

君正T41交叉编译ffmpeg、opencv并做h264软解,利用君正SDK做h264硬件编码

目录 1 交叉编译ffmpeg----错误解决过程&#xff0c;不要看 1.1 下载源码 1.2 配置 1.3 编译 安装 1.3.1 报错&#xff1a;libavfilter/libavfilter.so: undefined reference to fminf 1.3.2 报错&#xff1a;error: unknown type name HEVCContext; did you mean HEVCPr…

感知器的那些事

感知器的那些事 历史背景Rosenblatt和Minsky关于感知机的争论弗兰克罗森布拉特简介提出感知器算法Mark I感知机争议与分歧马文明斯基简介单层感知器工作原理训练过程多层感知器工作原理单层感知机 vs 多层感知机感知器模型(Perceptron),是由心理学家Frank Rosenblatt在1957年…

C语言:枚举类型

一、枚举类型的声明 枚举顾名思义就是一一列举。我们可以把可能的取值一一列举。比如我们现实生活中&#xff1a; 星期一到星期日是有限的7天&#xff0c;可以一一列举 &#xff1b;性别有&#xff1a;男、女、保密&#xff0c;也可以一一列举 &#xff1b;月份有12个月&#x…

25/1/6 算法笔记<强化学习> 初玩V-REP

我们安装V-REP之后&#xff0c;使用的是下面Git克隆的项目。 git clone https://github.com/deep-reinforcement-learning_book/Chapter16-Robot-Learning-in-Simulation.git 项目中直接组装好了一个机械臂。 我们先来分析下它的对象树 DefaultCamera:摄像机&#xff0c;用于…

CODESYS MODBUS TCP通信(AM400PLC作为主站通信)

禾川Q1 PLC MODBUS-TCP通信 禾川Q1 PLC MODBUS-TCP通信(CODESYS平台完整配置+代码)-CSDN博客文章浏览阅读17次。MATLAB和S7-1200PLC水箱液位高度PID控制联合仿真(MODBUSTCP通信)_将matlab仿真导入plc-CSDN博客文章浏览阅读722次。本文详细介绍了如何使用MATLAB与S7-1200PLC进行…

OSPF - 影响OSPF邻居建立的因素

总结为这么10种 routerID 冲突区域id不一致认证MA网络掩码需一致区域类型(特殊区域)hello、dead时间MTU(如果开启检查)静默接口网络类型不匹配MA网络中路由器接口优先级全为0 如何建立邻居可以查看上一篇文章&#xff0c;可以直接专栏找&#xff08;&#x1f92b;挂链接会没流…

【大数据】(选修)实验4 安装熟悉HBase数据库并实践

实验4 安装熟悉HBase数据库并实践 1、实验目的 (1)理解HBase在Hadoop体系结构中的角色; (2)熟练使用HBase操作常用的Shell命令; (3)熟悉HBase操作常用的Java API。 2、实验平台 操作系统:Linux Hadoop版本:2.6.0或以上版本 HBase版本:1.1.2或以上版本 JDK版…