AI开发学习之——PyTorch框架

PyTorch 简介

PyTorch (Python torch)是由 Facebook AI 研究团队开发的开源机器学习库,广泛应用于深度学习研究和生产。它以动态计算图和易用性著称,支持 GPU 加速计算,并提供丰富的工具和模块。

PyTorch的主要特点

  1. 动态计算图:PyTorch 使用动态计算图(Autograd),允许在运行时修改图结构,便于调试和实验。
  2. GPU 加速:支持 CUDA,能够利用 GPU 进行高效计算。
  3. 模块化设计:提供 torch.nn 等模块,便于构建和训练神经网络。
  4. 丰富的生态系统:包括 TorchVision、TorchText 和 TorchAudio 等,支持多种任务。、

PyTorch的安装

通过以下命令安装 PyTorch:

pip install torch torchvision

如果国内的速度慢,可以使用-i 参数使用国内的仓库源。

pip3 install torch -i https://pypi.tuna.tsinghua.edu.cn/simple

除了清华的源之外,也可以使用科大或是北外的数据源。

  • https://mirrors.bfsu.edu.cn/pypi/web/simple

  • https://mirrors.ustc.edu.cn/pypi/web/simple

使用示例

1. 张量操作
import torch# 创建张量
x = torch.tensor([1.0, 2.0, 3.0])
y = torch.tensor([4.0, 5.0, 6.0])# 加法
z = x + y
print(z)  # 输出: tensor([5., 7., 9.])

这里的输出为什么不是 tensor([5.0, 7.0, 9.0])呢?
在Python的浮点数表示中,.0后缀通常用于明确表示一个数是浮点数(float),而不是整数(int)。然而,在大多数情况下,Python和许多库(包括PyTorch,这里提到的tensor是由PyTorch生成的)在打印浮点数时,如果小数点后没有额外的数字,它们可能会省略.0后缀以简化输出。

当使用科学计算库如NumPy或PyTorch时,它们通常有统一的输出格式,尤其是在处理数组或tensor时。在你的例子中,tensor([5., 7., 9.])tensor([5.0, 7.0, 9.0])在数值上是完全相同的,只是表示形式略有不同。PyTorch选择省略小数点后没有数字的.0后缀,以使输出更简洁。

这种输出格式的选择主要是出于可读性和简洁性的考虑,并不影响tensor中存储的实际数值。在数值计算中,5.5.0都被视为浮点数,并且在计算中没有任何区别。

2. 自动求导
import torch# 创建需要梯度的张量
x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)# 定义函数
y = x * 2
z = y.mean()# 反向传播
z.backward()# 查看梯度
print(x.grad)  # 输出: tensor([0.6667, 0.6667, 0.6667])

这里的结果是怎么来的呢?

这段代码演示了 PyTorch 中的**自动微分(Autograd)**机制,通过计算梯度来实现反向传播。我们来逐步分析代码的运算过程。


1. 创建需要梯度的张量
x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
  • x 是一个包含 [1.0, 2.0, 3.0] 的 1 阶张量(向量)。
  • requires_grad=True 表示 PyTorch 需要跟踪对 x 的所有操作,以便后续计算梯度。

2. 定义函数
y = x * 2
z = y.mean()
  • y = x * 2:对 x 逐元素乘以 2,得到 y = [2.0, 4.0, 6.0]
  • z = y.mean():计算 y 的均值,即:
    在这里插入图片描述

3. 反向传播
z.backward()
  • z.backward() 表示从 z 开始反向传播,计算 zx 的梯度。
  • 由于 z 是一个标量(单个值),PyTorch 会自动计算 zx 的梯度。

4. 梯度计算

PyTorch 通过链式法则计算梯度。具体步骤如下:

(1)计算 zy 的梯度

  • z = y.mean() 可以写成:
    在这里插入图片描述

  • 因此,zy 的梯度为:
    在这里插入图片描述

(2)计算 yx 的梯度

  • y = x * 2 可以写成:
    yi​=2xi​
  • 因此,yx 的梯度为:
    在这里插入图片描述

(3)计算 zx 的梯度
根据链式法则:
在这里插入图片描述

将结果代入:
在这里插入图片描述


5. 查看梯度
print(x.grad)  # 输出: tensor([0.6667, 0.6667, 0.6667])
  • x.grad 存储了 zx 的梯度,结果为:
    在这里插入图片描述

总结

这段代码的运算过程如下:

  1. 创建需要梯度的张量 x
  2. 定义函数 y = x * 2z = y.mean()
  3. 通过 z.backward() 计算 zx 的梯度。
  4. 根据链式法则,梯度计算结果为 [0.6667, 0.6667, 0.6667]

PyTorch 的自动微分机制使得梯度计算变得非常简单,尤其是在深度学习模型中,这种机制可以自动计算损失函数对模型参数的梯度,从而支持梯度下降等优化算法。

3. 简单神经网络
import torch
import torch.nn as nn
import torch.optim as optim# 定义网络
class SimpleNet(nn.Module):def __init__(self):super(SimpleNet, self).__init__()self.fc = nn.Linear(1, 1)def forward(self, x):return self.fc(x)# 创建网络、损失函数和优化器
model = SimpleNet()
criterion = nn.MSELoss()
optimizer = optim.SGD(model.parameters(), lr=0.01)# 训练数据
x = torch.tensor([[1.0], [2.0], [3.0], [4.0]])
y = torch.tensor([[2.0], [4.0], [6.0], [8.0]])# 训练过程
for epoch in range(100):optimizer.zero_grad()outputs = model(x)loss = criterion(outputs, y)loss.backward()optimizer.step()if (epoch+1) % 10 == 0:print(f'Epoch [{epoch+1}/100], Loss: {loss.item():.4f}')
4. 使用 GPU
import torch# 检查 GPU 是否可用
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')# 创建张量并移动到 GPU
x = torch.tensor([1.0, 2.0, 3.0]).to(device)
y = torch.tensor([4.0, 5.0, 6.0]).to(device)# 在 GPU 上执行加法
z = x + y
print(z)  # 输出: tensor([5., 7., 9.], device='cuda:0')

torchtorchvisiontorchaudio

torchtorchvisiontorchaudio 是 PyTorch 生态系统中的三个核心库,分别用于通用深度学习、计算机视觉和音频处理任务。以下是它们的详细介绍和作用:


1. torch

torch 是 PyTorch 的核心库,提供了深度学习的基础功能,包括张量操作、自动求导、神经网络模块等。

主要功能:
  • 张量操作:支持高效的张量计算(如加法、乘法、矩阵运算等)。
  • 自动求导:通过 Autograd 模块实现自动微分,便于梯度计算和优化。
  • 神经网络模块:提供 torch.nn 模块,包含各种层(如全连接层、卷积层)和损失函数。
  • 优化器:提供 torch.optim 模块,包含 SGD、Adam 等优化算法。
  • GPU 加速:支持 CUDA,可以利用 GPU 进行高性能计算。
使用场景:
  • 构建和训练深度学习模型。
  • 实现自定义的数学运算和算法。
  • 进行张量计算和数值模拟。
示例:
import torch# 创建张量
x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)# 定义计算
y = x * 2
z = y.mean()# 自动求导
z.backward()# 查看梯度
print(x.grad)  # 输出: tensor([0.6667, 0.6667, 0.6667])

2. torchvision

torchvision 是 PyTorch 的计算机视觉库,提供了常用的数据集、模型架构和图像处理工具。

主要功能:
  • 数据集:提供常用的计算机视觉数据集(如 MNIST、CIFAR-10、ImageNet)。
  • 模型架构:包含预训练的经典模型(如 ResNet、VGG、AlexNet)。
  • 图像处理工具:提供数据增强和转换工具(如裁剪、旋转、归一化)。
  • 实用工具:包括可视化工具和评估指标。
使用场景:
  • 图像分类、目标检测、分割等计算机视觉任务。
  • 加载和处理图像数据。
  • 使用预训练模型进行迁移学习。
示例:
import torchvision
import torchvision.transforms as transforms
from torchvision.models import resnet18# 数据预处理
transform = transforms.Compose([transforms.Resize(256),transforms.CenterCrop(224),transforms.ToTensor(),transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])# 加载数据集
dataset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)# 加载预训练模型
model = resnet18(pretrained=True)

3. torchaudio

torchaudio 是 PyTorch 的音频处理库,提供了音频数据的加载、处理和转换工具。

主要功能:
  • 音频加载和保存:支持多种音频格式(如 WAV、MP3)。
  • 音频处理:提供音频信号处理工具(如重采样、频谱图生成)。
  • 数据集:包含常用的音频数据集(如 LibriSpeech、VoxCeleb)。
  • 特征提取:支持提取 MFCC、Mel 频谱等音频特征。
使用场景:
  • 语音识别、语音合成、音频分类等任务。
  • 音频数据的预处理和特征提取。
  • 加载和处理音频数据集。
示例:
import torchaudio
import torchaudio.transforms as T# 加载音频文件
waveform, sample_rate = torchaudio.load('example.wav')# 重采样
resampler = T.Resample(orig_freq=sample_rate, new_freq=16000)
resampled_waveform = resampler(waveform)# 提取 Mel 频谱
mel_spectrogram = T.MelSpectrogram(sample_rate=16000)(resampled_waveform)

三者的关系

  • torch 是核心库,提供基础功能(如张量计算、自动求导、神经网络模块)。
  • torchvision 是基于 torch 的扩展库,专注于计算机视觉任务。
  • torchaudio 是基于 torch 的扩展库,专注于音频处理任务。

三者可以结合使用,例如:

  • 使用 torchvision 处理图像数据,用 torch 构建和训练模型。
  • 使用 torchaudio 处理音频数据,用 torch 构建语音识别模型。

安装

可以通过以下命令安装这三个库:

pip install torch torchvision torchaudio

总结

  • torch:核心库,提供深度学习的基础功能。
  • torchvision:计算机视觉库,提供数据集、模型和图像处理工具。
  • torchaudio:音频处理库,提供音频加载、处理和特征提取工具。

三者共同构成了 PyTorch 的完整生态系统,适用于各种深度学习任务。



本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/11232.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

纯后训练做出benchmark超过DeepseekV3的模型?

论文地址 https://arxiv.org/pdf/2411.15124 模型是AI2的,他们家也是玩开源的 先看benchmark,几乎是纯用llama3 405B后训练去硬刚出一个gpt4o等级的LLamA405 我们先看之前的机遇Lllama3.1 405B进行全量微调的模型 Hermes 3,看着还没缘模型…

像接口契约文档 这种工件,在需求 分析 设计 工作流里面 属于哪一个工作流

οゞ浪漫心情ゞο(20***328) 2016/2/18 10:26:47 请教一下,像接口契约文档 这种工件,在需求 分析 设计 工作流里面 属于哪一个工作流? 潘加宇(35***47) 17:17:28 你这相当于问用例图、序列图属于哪个工作流,看内容。 如果你的&quo…

代码随想录刷题笔记

数组 二分查找 ● 704.二分查找 tips:两种方法,左闭右开和左闭右闭,要注意区间不变性,在判断mid的值时要看mid当前是否使用过 ● 35.搜索插入位置 ● 34.在排序数组中查找元素的第一个和最后一个位置 tips:寻找左右边…

PyTorch框架——基于深度学习YOLOv8神经网络学生课堂行为检测识别系统

基于YOLOv8深度学习的学生课堂行为检测识别系统,其能识别三种学生课堂行为:names: [举手, 读书, 写字] 具体图片见如下: 第一步:YOLOv8介绍 YOLOv8 是 ultralytics 公司在 2023 年 1月 10 号开源的 YOLOv5 的下一个重大更新版本…

【Elasticsearch】实现气象数据存储与查询系统

🧑 博主简介:CSDN博客专家,历代文学网(PC端可以访问:https://literature.sinhy.com/#/literature?__c1000,移动端可微信小程序搜索“历代文学”)总架构师,15年工作经验,…

python-leetcode-相同的树

100. 相同的树 - 力扣(LeetCode) # Definition for a binary tree node. # class TreeNode: # def __init__(self, val0, leftNone, rightNone): # self.val val # self.left left # self.right right class Solution:de…

IM 即时通讯系统-50-[特殊字符]cim(cross IM) 适用于开发者的分布式即时通讯系统

IM 开源系列 IM 即时通讯系统-41-开源 野火IM 专注于即时通讯实时音视频技术,提供优质可控的IMRTC能力 IM 即时通讯系统-42-基于netty实现的IM服务端,提供客户端jar包,可集成自己的登录系统 IM 即时通讯系统-43-简单的仿QQ聊天安卓APP IM 即时通讯系统-44-仿QQ即…

2025年1月22日(网络编程 udp)

系统信息: ubuntu 16.04LTS Raspberry Pi Zero 2W 系统版本: 2024-10-22-raspios-bullseye-armhf Python 版本:Python 3.9.2 已安装 pip3 支持拍摄 1080p 30 (1092*1080), 720p 60 (1280*720), 60/90 (640*480) 已安装 vim 已安装 git 学习…

基于微信小程序的电子商城购物系统设计与实现(LW+源码+讲解)

专注于大学生项目实战开发,讲解,毕业答疑辅导,欢迎高校老师/同行前辈交流合作✌。 技术范围:SpringBoot、Vue、SSM、HLMT、小程序、Jsp、PHP、Nodejs、Python、爬虫、数据可视化、安卓app、大数据、物联网、机器学习等设计与开发。 主要内容:…

MP4分析工具

在实际应用中,我们经常需要对MP4文件进行分析。分析MP4封装格式的工具比较多,下面介绍几款常用的工具: 1、mp4info 优点: 带界面的可视化工具可以清晰看到各个box的组成和层次同时可以分离里面的音视频文件可以看到音视频的时间…

傅里叶分析之掐死教程

https://zhuanlan.zhihu.com/p/19763358 要让读者在不看任何数学公式的情况下理解傅里叶分析。 傅里叶分析 不仅仅是一个数学工具,更是一种可以彻底颠覆一个人以前世界观的思维模式。但不幸的是,傅里叶分析的公式看起来太复杂了,所以很多…

想品客老师的第天:类

类是一个优化js面向对象的工具 类的声明 //1、class User{}console.log(typeof User)//function//2、let Hdclass{}//其实跟1差不多class Stu{show(){}//注意这里不用加逗号,对象才加逗号get(){console.log(后盾人)}}let hdnew Stu()hd.get()//后盾人 类的原理 类…

【Git】初识Git Git基本操作详解

文章目录 学习目标Ⅰ. 初始 Git💥注意事项 Ⅱ. Git 安装Linux-centos安装Git Ⅲ. Git基本操作一、创建git本地仓库 -- git init二、配置 Git -- git config三、认识工作区、暂存区、版本库① 工作区② 暂存区③ 版本库④ 三者的关系 四、添加、提交更改、查看提交日…

基于单片机的盲人智能水杯系统(论文+源码)

1 总体方案设计 本次基于单片机的盲人智能水杯设计,采用的是DS18B20实现杯中水温的检测,采用HX711及应力片实现杯中水里的检测,采用DS1302实现时钟计时功能,采用TTS语音模块实现语音播报的功能,并结合STC89C52单片机作…

深入解析“legit”的地道用法——从俚语到正式表达:Sam Altman用来形容DeepSeek: legit invigorating(真的令人振奋)

深入解析“legit”的地道用法——从俚语到正式表达 一、引言 在社交媒体、科技圈甚至日常对话中,我们经常会看到或听到“legit”这个词。比如最近 Sam Altman 在 X(原 Twitter)上发的一条帖子中写道: we will obviously deliver …

微机原理与接口技术期末大作业——4位抢答器仿真

在微机原理与接口技术的学习旅程中,期末大作业成为了检验知识掌握程度与实践能力的关键环节。本次我选择设计并仿真一个 4 位抢答器系统,通过这个项目,深入探索 8086CPU 及其接口技术的实际应用。附完整压缩包下载。 一、系统设计思路 &…

【大模型LLM面试合集】大语言模型架构_MHA_MQA_GQA

MHA_MQA_GQA 1.总结 在 MHA(Multi Head Attention) 中,每个头有自己单独的 key-value 对;标准的多头注意力机制,h个Query、Key 和 Value 矩阵。在 MQA(Multi Query Attention) 中只会有一组 k…

【Transformer】手撕Attention

import torch from torch import nn import torch.functional as F import mathX torch.randn(16,64,512) # B,T,Dd_model 512 # 模型的维度 n_head 8 # 注意力头的数量多头注意力机制 class multi_head_attention(nn.Module): def __init__(self, d_model, n_hea…

【Linux】 冯诺依曼体系与计算机系统架构全解

Linux相关知识点可以通过点击以下链接进行学习一起加油!初识指令指令进阶权限管理yum包管理与vim编辑器GCC/G编译器make与Makefile自动化构建GDB调试器与Git版本控制工具Linux下进度条 冯诺依曼体系是现代计算机设计的基石,其统一存储和顺序执行理念推动…

冯·诺依曼体系结构

目录 冯诺依曼体系结构推导 内存提高冯诺依曼体系结构效率的方法 你使用QQ和朋友聊天时,整个数据流是怎么流动的(不考虑网络情况) 与冯诺依曼体系结构相关的一些知识 冯诺依曼体系结构推导 计算机的存在就是为了解决问题,而解…