【人工智能】Python常用库-PyTorch常用方法教程

PyTorch 是一个强大的开源深度学习框架,以其灵活性和动态计算图而广受欢迎。以下是 PyTorch 的详细教程,涵盖从基础到实际应用的使用方法。


1. 安装与导入

1.1 安装 PyTorch

访问 PyTorch 官方网站,根据系统、Python 版本和 CUDA 支持选择安装命令。

常用安装命令:

pip install torch torchvision torchaudio
1.2 导入库
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms

2. PyTorch 基础

2.1 张量(Tensor)

张量是 PyTorch 的核心数据结构,可以看作是一个高维数组。

# 创建张量
a = torch.tensor([1.0, 2.0, 3.0])
b = torch.tensor([4.0, 5.0, 6.0])# 基本运算
c = a + b
print(c)  # 输出 tensor([5., 7., 9.])# 随机张量
random_tensor = torch.rand((2, 3))  # 2行3列随机数
print(random_tensor)

输出结果

tensor([5., 7., 9.])
tensor([[0.9980, 0.2970, 0.5257],[0.8807, 0.0471, 0.7896]])
2.2 自动求导

PyTorch 提供动态计算图支持自动求导。

x = torch.tensor(2.0, requires_grad=True)
y = x**2 + 3*x + 4y.backward()  # 自动求导
print(x.grad)  # 输出 dy/dx = 2*x + 3 = 7.0

输出结果

tensor(7.)

3. 数据加载

PyTorch 提供强大的数据加载功能。

import torchvision.transforms as transforms
from torchvision.datasets import MNIST
from torch.utils.data import DataLoader# 下载并加载 MNIST 数据集
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
train_data = MNIST(root='./data', train=True, download=True, transform=transform)
train_loader = DataLoader(train_data, batch_size=32, shuffle=True)

4. 构建神经网络

4.1 使用 nn.Module 构建模型
import torch.nn as nnclass SimpleNN(nn.Module):def __init__(self):super(SimpleNN, self).__init__()self.fc1 = nn.Linear(28 * 28, 128)self.relu = nn.ReLU()self.fc2 = nn.Linear(128, 10)self.softmax = nn.Softmax(dim=1)def forward(self, x):x = x.view(-1, 28 * 28)  # 展平输入x = self.relu(self.fc1(x))x = self.softmax(self.fc2(x))return xmodel = SimpleNN()print(model)

输出结果

SimpleNN((fc1): Linear(in_features=784, out_features=128, bias=True)(relu): ReLU()(fc2): Linear(in_features=128, out_features=10, bias=True)(softmax): Softmax(dim=1)
)

5. 模型训练

5.1 定义损失函数和优化器
criterion = nn.CrossEntropyLoss()  # 交叉熵损失
optimizer = optim.Adam(model.parameters(), lr=0.001)
5.2 训练循环
for epoch in range(5):for images, labels in train_loader:optimizer.zero_grad()  # 梯度清零outputs = model(images)loss = criterion(outputs, labels)  # 计算损失loss.backward()  # 反向传播optimizer.step()  # 更新权重print(f"Epoch {epoch+1}, Loss: {loss.item()}")

完整代码

from torch import nn, optim
import torchvision.transforms as transforms
from torchvision.datasets import MNIST
from torch.utils.data import DataLoaderclass SimpleNN(nn.Module):def __init__(self):super(SimpleNN, self).__init__()self.fc1 = nn.Linear(28 * 28, 128)self.relu = nn.ReLU()self.fc2 = nn.Linear(128, 10)self.softmax = nn.Softmax(dim=1)def forward(self, x):x = x.view(-1, 28 * 28)  # 展平输入x = self.relu(self.fc1(x))x = self.softmax(self.fc2(x))return xmodel = SimpleNN()# 下载并加载 MNIST 数据集
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
train_data = MNIST(root='./data', train=True, download=True, transform=transform)
train_loader = DataLoader(train_data, batch_size=32, shuffle=True)criterion = nn.CrossEntropyLoss()  # 交叉熵损失
optimizer = optim.Adam(model.parameters(), lr=0.001)for epoch in range(5):for images, labels in train_loader:optimizer.zero_grad()  # 梯度清零outputs = model(images)loss = criterion(outputs, labels)  # 计算损失loss.backward()  # 反向传播optimizer.step()  # 更新权重print(f"Epoch {epoch + 1}, Loss: {loss.item()}")

输出结果

Epoch 1, Loss: 1.482284665107727
Epoch 2, Loss: 1.4968496561050415
Epoch 3, Loss: 1.5289227962493896
Epoch 4, Loss: 1.4832825660705566
Epoch 5, Loss: 1.5070817470550537

6. 模型评估

6.1 在测试集上评估
test_data = MNIST(root='./data', train=False, transform=transform)
test_loader = DataLoader(test_data, batch_size=32, shuffle=False)correct = 0
total = 0
with torch.no_grad():  # 禁用梯度计算for images, labels in test_loader:outputs = model(images)_, predicted = torch.max(outputs, 1)total += labels.size(0)correct += (predicted == labels).sum().item()print(f"Test Accuracy: {correct / total * 100:.2f}%")

输出结果

Test Accuracy: 10.32%

7. GPU 加速

PyTorch 支持使用 GPU 加速。

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)# 将数据也移动到 GPU
for images, labels in train_loader:images, labels = images.to(device), labels.to(device)outputs = model(images)

8. 保存与加载模型

8.1 保存模型
torch.save(model.state_dict(), 'model.pth')
8.2 加载模型
model = SimpleNN()
model.load_state_dict(torch.load('model.pth'))
model.eval()  # 切换到评估模式

9. 实际案例

9.1 CIFAR-10 图像分类
from torch import nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision.datasets import CIFAR10
from torchvision.transforms import transforms# CIFAR-10 数据集
transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
train_data = CIFAR10(root='./data', train=True, download=True, transform=transform)
train_loader = DataLoader(train_data, batch_size=32, shuffle=True)class CNN(nn.Module):def __init__(self):super(CNN, self).__init__()self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1)self.pool = nn.MaxPool2d(kernel_size=2, stride=2)self.fc1 = nn.Linear(16 * 16 * 16, 10)def forward(self, x):x = self.pool(F.relu(self.conv1(x)))x = x.view(-1, 16 * 16 * 16)x = self.fc1(x)return xmodel = CNN()
# 后续训练步骤类似

10. PyTorch 优势总结

  1. 动态计算图:支持动态构建与修改模型。
  2. 灵活性:适合研究和开发,易于调试。
  3. 强大的社区支持:广泛的教程、示例和扩展工具。

通过实践,PyTorch 能够帮助用户更好地理解和实现深度学习算法!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/479558.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

打开windows 的字符映射表

快捷键 win R 打开资源管理器 输入: charmap 点击确定

Elasticsearch对于大数据量(上亿量级)的聚合如何实现?

大家好,我是锋哥。今天分享关于【Elasticsearch对于大数据量(上亿量级)的聚合如何实现?】面试题。希望对大家有帮助; Elasticsearch对于大数据量(上亿量级)的聚合如何实现? 1000道 …

解决首次加载数据空指针异常

起初效果&#xff1a; 使用async...await异步加载数据 最终效果&#xff1a; 代码&#xff1a; <template><div class"user-list-container"><!-- 加载状态 --><div v-if"loading" class"loading">正在加载用户数据..…

RTR Chaptor10 上

局部光照 面光源光泽材质一般光源形状 环境光照球面函数和半球函数简单表格形式球面基底球面径向基函数球面高斯函数球谐函数其他球面表示 半球基底AHD 基底辐射法向映射/《半条命2 》基底半球谐波 / H-Basis 在第9章中&#xff0c;我们讨论了基于物理的材质的相关理论&#xf…

若依框架部署在网站一个子目录下(/admin)问题(

部署在子目录下首先修改vue.config.js文件&#xff1a; 问题一&#xff1a;登陆之后跳转到了404页面问题&#xff0c;解决办法如下&#xff1a; src/router/index.js 把404页面直接变成了首页&#xff08;大佬有啥优雅的解决办法求告知&#xff09; 问题二&#xff1a;退出登录…

3DMAX带孔绞线插件使用方法详解

3DMAX带孔绞线插件&#xff0c;一键创建自定义形状孔洞的绞线。 【版本要求】 3dMax 2016及更高 【安装方法】 解压缩后将插件文件&#xff08;.mcg&#xff09;拖动到3dMax视口中&#xff0c;自动完成安装。 【使用方法】 1.用样条线绘制孔洞的1/2形状。 2.点击3dMax“…

阿里发布 EchoMimicV2 :从数字脸扩展到数字人 可以通过图片+音频生成半身动画视频

EchoMimicV2 是由阿里蚂蚁集团推出的开源数字人项目&#xff0c;旨在生成高质量的数字人半身动画视频。以下是该项目的简介&#xff1a; 主要功能&#xff1a; 音频驱动的动画生成&#xff1a;EchoMimicV2 能够使用音频剪辑驱动人物的面部表情和身体动作&#xff0c;实现音频与…

urllib3只支持OpenSSL1.1.1

1 现象 urllib3 v2.0 only supports OpenSSL 1.1.1, currently the ssl module is compiled with OpenSSL 1.1.0j 20 Nov 2018.2 解决方法 降低urllib3的版本。 从pycharm中&#xff0c;先卸载原有的urllib3版本。 菜单“File|Settings|Project:python|Project Interprete…

spark 写入mysql 中文数据 显示?? 或者 乱码

目录 前言 Spark报错&#xff1a; 解决办法&#xff1a; 总结一下&#xff1a; 报错&#xff1a; 解决&#xff1a; 前言 用spark写入mysql中&#xff0c;查看中文数据 显示?? 或者 乱码 Spark报错&#xff1a; Sat Nov 23 19:15:59 CST 2024 WARN: Establishing SSL…

微信小程序条件渲染与列表渲染的全面教程

微信小程序条件渲染与列表渲染的全面教程 引言 在微信小程序的开发中,条件渲染和列表渲染是构建动态用户界面的重要技术。通过条件渲染,我们可以根据不同的状态展示不同的内容,而列表渲染则使得我们能够高效地展示一组数据。本文将详细讲解这两种渲染方式的用法,结合实例…

ctfshow

1,web153 大小写绕过失败 使用.user.ini 来构造后⻔ php.ini是php的⼀个全局配置⽂件&#xff0c;对整个web服务起作⽤&#xff1b;⽽.user.ini和.htaccess⼀样是⽬录的配置⽂件&#xff0c;.user.ini就是⽤户⾃定义的⼀个php.ini&#xff0c;我们可以利⽤这个⽂件来构造后⻔和…

【大数据学习 | Spark-SQL】Spark-SQL编程

上面的是SparkSQL的API操作。 1. 将RDD转化为DataFrame对象 DataFrame&#xff1a; DataFrame是一种以RDD为基础的分布式数据集&#xff0c;类似于传统数据库中的二维表格。带有schema元信息&#xff0c;即DataFrame所表示的二维表数据集的每一列都带有名称和类型。这样的数…

DINO-X:一种用于开放世界目标检测与理解的统一视觉模型

摘要 本文介绍了由IDEA Research开发的DINO-X&#xff0c;这是一个统一的以对象为中心的视觉模型&#xff0c;具有迄今为止最佳的开放世界对象检测性能。DINO-X采用了与Grounding DINO 1.5 [47]相同的基于Transformer的编码器-解码器架构&#xff0c;以追求面向开放世界对象理…

MySQL系列之远程管理(安全)

导览 前言Q&#xff1a;如何保障远程登录安全一、远程登录的主要方式1. 用户名/口令2. SSH3. SSL/TLS 二、使用TLS协议加密连接1. 服务端2. 客户端 结语精彩回放 前言 在我们的学习或工作过程中&#xff0c;作为开发、测试或运维人员&#xff0c;经常会通过各类客户端软件&…

扫振牙刷设计思路以及技术解析

市面上目前常见的就两种&#xff1a;扫振牙刷和超声波牙刷 为了防水&#xff0c;表面还涂上了一层防水漆 一开始的电池管理芯片&#xff0c;可以让充电更加均衡。 如TP4056 第一阶段以恒流充电&#xff1b;当电压达到预定值时转入第二阶段进行恒压充电&#xff0c;此时电流逐…

Hot100 - 除自身以外数组的乘积

Hot100 - 除自身以外数组的乘积 最佳思路&#xff1a; 此问题的关键在于通过两次遍历&#xff0c;分别计算从左侧和右侧开始的累积乘积&#xff0c;以此避免使用额外的除法操作。 时间复杂度&#xff1a; 该算法的时间复杂度为 O(n)&#xff0c;因为我们只需要遍历数组两次。…

一个vue项目如何运行在docker

将 Vue.js 应用程序通过 Docker 发布是一个非常常见的做法&#xff0c;它可以帮助你轻松地部署应用到不同的环境中。下面是一个简单的指南&#xff0c;介绍如何为 Vue.js 项目创建 Dockerfile 并进行构建和运行。 第一步&#xff1a;安装 Docker 确保你的开发机器上已经安装了…

【公益接口】不定时新增接口,仅供学习

文章日期&#xff1a;2024.11.24 使用工具&#xff1a;Python 文章类型&#xff1a;公益接口 文章全程已做去敏处理&#xff01;&#xff01;&#xff01; 【需要做的可联系我】 AES解密处理&#xff08;直接解密即可&#xff09;&#xff08;crypto-js.js 标准算法&#xff…

使用phpStudy小皮面板模拟后端服务器,搭建H5网站运行生产环境

一.下载安装小皮 小皮面板官网下载网址&#xff1a;小皮面板(phpstudy) - 让天下没有难配的服务器环境&#xff01; 安装说明&#xff08;特别注意&#xff09; 1. 安装路径不能包含“中文”或者“空格”&#xff0c;否则会报错&#xff08;例如错误提示&#xff1a;Cant cha…

DolphinDB 登陆伦敦!携手中英人工智能协会共话 AI 未来

11 月 9 日&#xff0c;DolphinDB 联合中英人工智能协会&#xff08;CBAIA&#xff09;在全球人工智能中心、今年三位诺贝尔奖得主的诞生地——伦敦盖茨比计算神经科学中心举办 AI 技术交流会。来自人工智能、量化投资等领域的 150 多位全球专家齐聚一堂&#xff0c;共同探讨人…