TensorFlow 与 PyTorch 的直观区别

背景

TensorFlow 与 PyTorch 都是比较流行的深度学习框架。tf 由谷歌在 2015 年发布,而 PyTorch 则是 Facecbook AI 研究团队 2016 年在原来 Torch 的基础上发布的。

tf 采用的是静态计算图。这意味着在执行任何计算之前,你需要先定义好整个计算图,之后再执行。这种方式适合大规模生产环境,可以优化计算图以提高效率。tf 的早期版本比较复杂,但在集成 Keras 库之后相当容易上手。

PyTorch 的设计目标是提供一个易于使用、灵活且高效的框架,所以采用的是动态图,特别适合研究人员和开发人员进行快速实验和原型设计。它强调灵活性和易用性,采用了动态图机制,使得代码更接近于 Python 原生风格,便于调试和修改。PyTorch 使用更加像原来的 Python 代码。

总体来说,TensorFlow 更加容易上手,PyTorch 更加灵活且需要自己操作,例如 tf 提供了训练的方法,而 PyTorch 则需要手动训练:

# TensorFlow
model.fit(train_images, train_labels, epochs=5, batch_size=128)

而 PyTorch 需要先手动将数据分批,然后自己编写训练和测试函数,函数详细内容后面会写:

# PyTorch
epochs = 5
for t in range(epochs):print(f"Epoch {t+1}\n-------------------------------")train(train_dataloader, model, loss_fn, optimizer)test(test_dataloader, model, loss_fn)
print("Done!")

示例

MNIST 数据集

对于两者的示例,仍然使用 MNIST 手写数字集来做演示。MNIST 是 28 * 28 大小的单通道(黑白)手写数字图片,每个像素亮度值为 0 ~ 255。

首先加载数据集:

# TensorFlow
(train_images, train_labels), (test_images, test_labels) = mnist.load_data()

PyTorch 除了加载数据,还需要定义 DataLoader,因为其提供的框架更加底层,需要自己定义加载器,包括数据打包,转换等更加灵活的功能。

# PyTorch
to_tensor = transforms.Compose([transforms.ToTensor()])
training_data = datasets.MNIST(root="data", train=True, download=True, transform=to_tensor)
test_data = datasets.MNIST(root="data", train=False, download=True, transform=to_tensor)train_dataloader = DataLoader(training_data, batch_size=128, shuffle=True)
test_dataloader = DataLoader(test_data, batch_size=128, shuffle=True)

定义神经网络

TensorFlow 集成了 Keras,这里可以看见对神经网络的定义非常简洁明了:

# TensorFlow
model = keras.Sequential([layers.Flatten(input_shape=(28, 28)),  # Flatten the input image to a vector of size 784layers.Dense(512, activation="relu"),layers.Dense(10, activation="softmax")
])

在 PyTorch 中,更倾向于将神经网络打包成一个类,这个类由框架提供的网络模型继承。

# PyTorch
class NeuralNetwork(nn.Module):def __init__(self):super(NeuralNetwork, self).__init__()self.flatten = nn.Flatten()self.fc1 = nn.Linear(in_features=28 * 28, out_features=512)# Output layer with 10 neurous for classificationself.fc2 = nn.Linear(in_features=512, out_features=10)def forward(self, x):x = self.flatten(x) # Flatten the input tensorx = nn.functional.relu(self.fc1(x)) # ReLU activation after first layerx = self.fc2(x)return xprint(model)

PyTorch 可以检查神经网络模型

NeuralNetwork((flatten): Flatten(start_dim=1, end_dim=-1)(fc1): Linear(in_features=784, out_features=512, bias=True)(fc2): Linear(in_features=512, out_features=10, bias=True)
)

训练

TensorFlow 在网络模型定义完成后,指定损失函数和优化器,来使模型训练让参数收敛。

model.compile(optimizer="rmsprop",loss="sparse_categorical_crossentropy",metrics=["accuracy"])model.fit(train_images, train_labels, epochs=5, batch_size=128)
Epoch 1/5
469/469 [==============================] - 3s 5ms/step - loss: 5.4884 - accuracy: 0.8992
Epoch 2/5
469/469 [==============================] - 2s 4ms/step - loss: 0.6828 - accuracy: 0.9538
Epoch 3/5
469/469 [==============================] - 2s 4ms/step - loss: 0.4634 - accuracy: 0.9662
Epoch 4/5
469/469 [==============================] - 2s 4ms/step - loss: 0.3742 - accuracy: 0.9730
Epoch 5/5
469/469 [==============================] - 2s 4ms/step - loss: 0.2930 - accuracy: 0.9774

而在 PyTorch 中则更加复杂,需要自己定义训练函数和测试函数,并不断训练,框架只提供了一些基础的训练所需函数:

# Define loss function and optimizer
loss_fn = nn.CrossEntropyLoss()
optimizer = optim.RMSprop(model.parameters(), lr=0.001)def train(dataloader, model, loss_fn, optimizer):size = len(dataloader.dataset)model.train()for batch, (X, y) in enumerate(dataloader):X, y = X.to(device), y.to(device)# Compute prediction errorpred = model(X)loss = loss_fn(pred, y)# Backpropagationloss.backward()optimizer.step()optimizer.zero_grad()if batch % 100 == 0:loss, current = loss.item(), (batch + 1) * len(X)print(f"loss: {loss:>7f}  [{current:>5d}/{size:>5d}]")def test(dataloader, model, loss_fn):size = len(dataloader.dataset)num_batches = len(dataloader)model.eval()test_loss, correct = 0, 0with torch.no_grad():for X, y in dataloader:X, y = X.to(device), y.to(device)pred = model(X)test_loss += loss_fn(pred, y).item()correct += (pred.argmax(1) == y).type(torch.float).sum().item()test_loss /= num_batchescorrect /= sizeprint(f"Test Error: \n Accuracy: {100*(correct):>0.1f}%, Avg loss: {test_loss:>8f}\n")epochs = 5
for t in range(epochs):print(f"Epoch {t+1}\n-------------------------------")train(train_dataloader, model, loss_fn, optimizer)test(test_dataloader, model, loss_fn)
print("Done!")
Epoch 1
-------------------------------
loss: 2.319496  [  128/60000]
loss: 0.443893  [12928/60000]
loss: 0.253097  [25728/60000]
loss: 0.106967  [38528/60000]
loss: 0.208099  [51328/60000]
Test Error: Accuracy: 95.9%, Avg loss: 0.136413Epoch 2
-------------------------------
loss: 0.102781  [  128/60000]
loss: 0.089506  [12928/60000]
loss: 0.177988  [25728/60000]
loss: 0.058250  [38528/60000]
loss: 0.131542  [51328/60000]
Test Error: Accuracy: 97.3%, Avg loss: 0.087681Epoch 3
-------------------------------
loss: 0.100185  [  128/60000]
loss: 0.021117  [12928/60000]
loss: 0.058108  [25728/60000]
loss: 0.070415  [38528/60000]
loss: 0.050509  [51328/60000]
Test Error: Accuracy: 97.7%, Avg loss: 0.075040Epoch 4
-------------------------------
loss: 0.051223  [  128/60000]
loss: 0.049627  [12928/60000]
loss: 0.025712  [25728/60000]
loss: 0.090960  [38528/60000]
loss: 0.046523  [51328/60000]
Test Error: Accuracy: 97.9%, Avg loss: 0.066997Epoch 5
-------------------------------
loss: 0.012129  [  128/60000]
loss: 0.019118  [12928/60000]
loss: 0.057839  [25728/60000]
loss: 0.031959  [38528/60000]
loss: 0.020570  [51328/60000]
Test Error: Accuracy: 98.0%, Avg loss: 0.062022Done!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/12462.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

MySQL常用数据类型和表的操作

文章目录 (一)常用数据类型1.数值类2.字符串类型3.二进制类型4.日期类型 (二)表的操作1查看指定库中所有表2.创建表3.查看表结构和查看表的创建语句4.修改表5.删除表 (三)总代码 (一)常用数据类型 1.数值类 BIT([M]) 大小:bit M表示每个数的位数,取值范围为1~64,若…

DeepSeekMoE:迈向混合专家语言模型的终极专业化

一、结论写在前面 论文提出了MoE语言模型的DeepSeekMoE架构,目的是实现终极的专家专业化(expert specialization)。通过细粒度的专家分割和共享专家隔离,DeepSeekMoE相比主流的MoE架构实现了显著更高的专家专业化和性能。从较小的2B参数规模开始&#x…

寻迹传感器模块使用说明

产品用途: 1、电度表脉冲数据采样 2、传真机碎纸机纸张检测 3、障碍检测 4、黑白线检测 产品介绍: 1、采用 TCRT5000 红外反射传感器 2、检测反射距离:1mm~25mm 适用 3、比较器输出,信号干净,波形好,驱…

java项目验证码登录

1.依赖 导入hutool工具包用于创建验证码 <dependency><groupId>cn.hutool</groupId><artifactId>hutool-all</artifactId><version>5.5.2</version></dependency> 2.测试 生成一个验证码图片&#xff08;生成的图片浏览器可…

Baklib探讨如何通过内容中台提升组织敏捷性与市场竞争力

内容概要 在数字化转型的浪潮中&#xff0c;内容中台已经成为企业提升市场响应速度和竞争力的关键所在。内容中台不仅是信息处理的集结地&#xff0c;更是促进资源高效整合和灵活应用的重要平台。通过构建一个高效的内容中台架构&#xff0c;企业能够更好地应对不断变化的市场…

Java基础——分层解耦——IOC和DI入门

目录 三层架构 Controller Service Dao ​编辑 调用过程 面向接口编程 分层解耦 耦合 内聚 软件设计原则 控制反转 依赖注入 Bean对象 如何将类产生的对象交给IOC容器管理&#xff1f; 容器怎样才能提供依赖的bean对象呢&#xff1f; 三层架构 Controller 控制…

Spring中@Conditional注解详解:条件装配的终极指南

一、为什么要用条件装配&#xff1f; 在实际开发中&#xff0c;我们经常需要根据不同的运行环境、配置参数或依赖情况动态决定是否注册某个Bean。例如&#xff1a; 开发环境使用内存数据库&#xff0c;生产环境连接真实数据库 当存在某个类时才启用特定功能 根据配置文件开关…

Redis代金卷(优惠卷)秒杀案例-多应用版

Redis代金卷(优惠卷)秒杀案例-单应用版-CSDN博客 上面这种方案,在多应用时候会出现问题,原因是你通过用户ID加锁 但是在多应用情况下,会出现两个应用的用户都有机会进去 让多个JVM使用同一把锁 这样就需要使用分布式锁 每个JVM都会有一个锁监视器,多个JVM就会有多个锁监视器…

ros 发布Topic

1、确定话题名称和消息类型 自定义话题名称&#xff0c;消息类型根据发送消息需要从std_msgs中查找确定 2、在main函数中通过NodeHander发布话题 // 创建一个NodeHandle对象&#xff0c;用于与ROS系统进行交互ros::NodeHandle nh;// 创建一个Publisher对象&#xff0c;用于发…

86.(2)攻防世界 WEB PHP2

之前做过&#xff0c;回顾一遍&#xff0c;详解见下面这篇博客 29.攻防世界PHP2-CSDN博客 既然是代码审计题目&#xff0c;打开后又不显示代码&#xff0c;肯定在文件里 <?php // 首先检查通过 GET 请求传递的名为 "id" 的参数值是否严格等于字符串 "admi…

毕业设计:基于深度学习的高压线周边障碍物自动识别与监测系统

目录 前言 课题背景和意义 实现技术思路 一、算法理论基础 1.1 卷积神经网络 1.2 目标检测算法 1.3 注意力机制 二、 数据集 2.1 数据采集 2.2 数据标注 三、实验及结果分析 3.1 实验环境搭建 3.2 模型训练 3.2 结果分析 最后 前言 &#x1f4c5;大四是整个大学…

AI取代人类?

每周跟踪AI热点新闻动向和震撼发展 想要探索生成式人工智能的前沿进展吗&#xff1f;订阅我们的简报&#xff0c;深入解析最新的技术突破、实际应用案例和未来的趋势。与全球数同行一同&#xff0c;从行业内部的深度分析和实用指南中受益。不要错过这个机会&#xff0c;成为AI领…

刷题记录 动态规划-7: 63. 不同路径 II

题目&#xff1a;63. 不同路径 II 难度&#xff1a;中等 给定一个 m x n 的整数数组 grid。一个机器人初始位于 左上角&#xff08;即 grid[0][0]&#xff09;。机器人尝试移动到 右下角&#xff08;即 grid[m - 1][n - 1]&#xff09;。机器人每次只能向下或者向右移动一步。…

深度求索DeepSeek横空出世

真正的强者从来不是无所不能&#xff0c;而是尽我所能。多少有关输赢胜负的缠斗&#xff0c;都是直面本心的搏击。所有令人骄傲振奋的突破和成就&#xff0c;看似云淡风轻寥寥数语&#xff0c;背后都是数不尽的焚膏继晷、汗流浃背。每一次何去何从的困惑&#xff0c;都可能通向…

51c视觉~CV~合集10

我自己的原文哦~ https://blog.51cto.com/whaosoft/13241694 一、CV创建自定义图像滤镜 热图滤镜 这组滤镜提供了各种不同的艺术和风格化光学图像捕捉方法。例如&#xff0c;热滤镜会将图像转换为“热图”&#xff0c;而卡通滤镜则提供生动的图像&#xff0c;这些图像看起来…

【论文复现】粘菌算法在最优经济排放调度中的发展与应用

目录 1.摘要2.黏菌算法SMA原理3.改进策略4.结果展示5.参考文献6.代码获取 1.摘要 本文提出了一种改进粘菌算法&#xff08;ISMA&#xff09;&#xff0c;并将其应用于考虑阀点效应的单目标和双目标经济与排放调度&#xff08;EED&#xff09;问题。为提升传统粘菌算法&#xf…

C++基础(2)

目录 1. 引用 1.1 引用的概念和定义 1.2 引用的特性 1.3 引用的使用 2. 常引用 3. 指针和引用的关系 4. 内联函数inline 5. nullptr 1. 引用 1.1 引用的概念和定义 引用不是新定义一个变量&#xff0c;而是给已存在变量取了一个别名&#xff0c;编译器不会为引用变量开…

【Numpy核心编程攻略:Python数据处理、分析详解与科学计算】2.29 NumPy+Scikit-learn(sklearn):机器学习基石揭秘

2.29 NumPyScikit-learn&#xff1a;机器学习基石揭秘 目录 #mermaid-svg-46l4lBcsNWrqVkRd {font-family:"trebuchet ms",verdana,arial,sans-serif;font-size:16px;fill:#333;}#mermaid-svg-46l4lBcsNWrqVkRd .error-icon{fill:#552222;}#mermaid-svg-46l4lBcsNWr…

圆上取点(例题)

Protecting The Earth &#xff08;圆内取点&#xff09; 题目描述&#xff1a; 给定 K (地球上的人数)&#xff0c;你必须制作一个保护罩来保护他们。(地球上的人数&#xff09;&#xff0c;你必须制作一个保护罩来保护他们。 已知一个人只能站在整数的坐标上&#xff0c…

【Numpy核心编程攻略:Python数据处理、分析详解与科学计算】2.19 线性代数核武器:BLAS/LAPACK深度集成

2.19 线性代数核武器&#xff1a;BLAS/LAPACK深度集成 目录 #mermaid-svg-yVixkwXWUEZuu02L {font-family:"trebuchet ms",verdana,arial,sans-serif;font-size:16px;fill:#333;}#mermaid-svg-yVixkwXWUEZuu02L .error-icon{fill:#552222;}#mermaid-svg-yVixkwXWUEZ…