深度学习框架:Pytorch与Keras的区别与使用方法

  

☁️主页 Nowl

🔥专栏《机器学习实战》 《机器学习》

📑君子坐而论道,少年起而行之 

文章目录

Pytorch与Keras介绍

Pytorch

模型定义

模型编译

模型训练

输入格式

完整代码

Keras

模型定义

模型编译

模型训练

输入格式

完整代码

区别与使用场景

结语


Pytorch与Keras介绍

pytorch和keras都是一种深度学习框架,使我们能很便捷地搭建各种神经网络,但它们在使用上有一些区别,也各自有其特性,我们一起来看看吧

Pytorch

模型定义

我们以最简单的网络定义来学习pytorch的基本使用方法,我们接下来要定义一个神经网络,包括一个输入层,一个隐藏层,一个输出层,这些层都是线性的,给隐藏层添加一个激活函数Relu,给输出层添加一个Sigmoid函数

import torch
import torch.nn as nnclass SimpleNet(nn.Module):def __init__(self):super(SimpleNet, self).__init__()self.fc1 = nn.Linear(1, 32)self.relu = nn.ReLU()self.fc2 = nn.Linear(32, 1)self.sigmoid = nn.Sigmoid()def forward(self, x):x = self.fc1(x)x = self.relu(x)x = self.fc2(x)x = self.Sigmoid(x)return x

模型编译

我们在之前的机器学习文章中反复提到过,模型的训练是怎么进行的呢,要有一个损失函数与优化方法,我们接下来看看在pytorch中怎么定义这些

import torch.optim as optim# 实例化模型对象
model = SimpleNet()
# 定义损失函数
criterion = nn.MSELoss()# 定义优化器
learning_rate = 0.01
optimizer = optim.SGD(model.parameters(), lr=learning_rate)

我们上面创建的神经网络是一个类,所以我们实例化一个对象model,然后定义损失函数为mse,优化器为随机梯度下降并设置学习率

模型训练

# 创建随机输入数据和目标数据
input_data = torch.randn((100, 1))  # 100个样本,每个样本有1个特征
target_data = torch.randn((100, 1))  # 100个样本,每个样本有1个目标值# 训练模型
epochs = 100for epoch in range(epochs):# 前向传播output = model(input_data)# 计算损失loss = criterion(output, target_data)# 反向传播和优化optimizer.zero_grad()loss.backward()optimizer.step()

以上步骤是先创建了一些随机样本,作为模型的训练集,然后定义训练轮次为100次,然后前向传播数据集,计算损失,再优化,如此反复

输入格式

关于输入格式是很多人在实战中容易出现问题的,对于pytorch创建的神经网络,我们的输入内容是一个torch张量,怎么创建呢

data = torch.Tensor([[1], [2], [3]])

很简单对吧,上面这个例子创建了一个torch张量,有三组数据,每组数据有1个特征

我们可以把这个数据输入到训练好的模型中,得到输出结果,如果输出不是torch张量,代码就会报错

完整代码

import torch
import torch.nn as nn
import torch.optim as optim# 定义一个简单的神经网络模型
class SimpleNet(nn.Module):def __init__(self):super(SimpleNet, self).__init__()self.fc1 = nn.Linear(1, 32)self.relu = nn.ReLU()self.fc2 = nn.Linear(32, 1)self.sigmoid = nn.Sigmoid()def forward(self, x):x = self.fc1(x)x = self.relu(x)x = self.fc2(x)x = self.sigmoid(x)return xmodel = SimpleNet()
criterion = nn.MSELoss()# 定义优化器
learning_rate = 0.01
optimizer = optim.SGD(model.parameters(), lr=learning_rate)# 创建随机输入数据和目标数据
input_data = torch.randn((100, 1))  # 100个样本,每个样本有1个特征
target_data = torch.randn((100, 1))  # 100个样本,每个样本有1个目标值# 训练模型
epochs = 100for epoch in range(epochs):# 前向传播output = model(input_data)# 计算损失loss = criterion(output, target_data)# 反向传播和优化optimizer.zero_grad()loss.backward()optimizer.step()data = torch.Tensor([[1], [2], [3]])
prediction = model(data)print(prediction)

可以看到模型输出了三个预测值

注意,这个任务本身没有意义,因为我们的训练集是随机生成的,这里主要学习框架的使用方法

Keras

我们在这里把和上面相同的神经网络结构使用keras框架实现一遍

模型定义

from keras.models import Sequential
from keras.layers import Densemodel = Sequential([Dense(32, input_dim=1, activation='relu'),Dense(1, activation='sigmoid')
])

注意这里也是一层输入层,一层隐藏层,一层输出层,和pytorch一样,输入层是隐式的,我们的输入数据就是输入层,上述代码定义了一个隐藏层,输入维度是1,输出维度是32,还定义了一个输出层,输入维度是32,输出维度是1,和pytorch环节的模型结构是一样的 

模型编译

那么在Keras中模型又是怎么编译的呢

model.compile(loss='mse', optimizer='sgd')

非常简单,只需要这一行代码 ,设置损失函数为mse,优化器为随机梯度下降

模型训练

模型的训练也非常简单

# 训练模型
model.fit(input_data, target_data, epochs=100)

 因为我们已经编译好了损失函数和优化器,在fit里只需要输入数据,输出数据和训练轮次这些参数就可以训练了

输入格式

对于Keras模型的输入,我们要把它转化为numpy数组,不然会报错

data = np.array([[1], [2], [3]])

完整代码

from keras.models import Sequential
from keras.layers import Dense
import numpy as np# 定义模型
model = Sequential([Dense(32, input_dim=1, activation='relu'),Dense(1, activation='sigmoid')
])# 创建随机输入数据和目标数据
input_data = np.random.randn(100, 1)  # 100个样本,每个样本有10个特征
target_data = np.random.randn(100, 1)  # 100个样本,每个样本有5个目标值# 编译模型
model.compile(loss='mse', optimizer='sgd')
# 训练模型
model.fit(input_data, target_data, epochs=10)data = np.array([[1], [2], [3]])prediction = model(data)
print(prediction)

可以看到,同样的任务,Keras的代码量小很多

区别与使用场景

Keras代码量少,使用便捷,适用于快速实验和快速神经网络设计

而pytorch由于结构是由类定义的,可以更加灵活地组建神经网络层,这对于要求细节的任务更有利,同时,pytorch还采用动态计算图,使得模型的结构可以在运行时根据输入数据动态调整,但这个特点我还没有接触到,之后可能会详细讲解

结语

Keras和Pytorch都各有各的优点,请读者根据需求选择,同时有些深度学习教程偏向于使用某一种框架,最好都学习一点,以适应不同的场景

 

感谢阅读,觉得有用的话就订阅下本专栏吧 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/205625.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

4G5G防爆执法记录仪、防爆智能安全帽赋能智慧燃气,可视化巡检巡线,安全生产管控

随着燃气使用的普及,燃气安全问题日益突出。传统应急安全问题处理方式暴露出以下问题: 应急预案不完善:目前一些燃气企业的应急预案存在实用性不高、流程不清晰等问题,导致在紧急情况下难以迅速启动和有效执行。 部门协同不流畅…

Less的函数的介绍

文章目录 前言描述style.less输出后言 前言 hello world欢迎来到前端的新世界 😜当前文章系列专栏:Sass和Less 🐱‍👓博主在前端领域还有很多知识和技术需要掌握,正在不断努力填补技术短板。(如果出现错误,…

数据仓库数据管理模型

数据仓库分为贴源层、数据仓库层、数据服务层,有人叫做数仓数据模型,或者叫"数据管理模型”。 我们为什么要进行数据分层管理,下图的优点介绍已经说得比较明确,再补充几点: 保障数据一致性:上层的数…

C#图像处理OpenCV开发指南(CVStar,03)——基于.NET 6的图像处理桌面程序开发实践第一步

1 Visual Studio 2022 开发基于.NET 6的OpenCV桌面程序 1.1 为什么选择.NET 6开发桌面应用? 选择 .NET 6(最早称为 .NET Core)而非 Frameworks.NET 的理由是:(1)跨平台;已经支持Windows,Linux…

Redis 事件轮询

1 Redis 为什么快 数据存在内存中, 直接操作内存中的数据单线程处理业务请求避免了多线的上下文切换, 锁竞争等弊端使用 IO 多路复用支撑更高的网络请求使用事件驱动模型, 通过事件通知模式, 减少不必要的等待… 这些都是 Redis 快的原因。 但是这些到了代码层面是如何实现的呢…

【UGUI】中Content Size Fitter)组件-使 UI 元素适应其内容的大小

官方文档:使 UI 元素适应其内容的大小 - Unity 手册 必备组件:Content Size Fitter 通常,在使用矩形变换定位 UI 元素时,应手动指定其位置和大小(可选择性地包括使用父矩形变换进行拉伸的行为)。 但是&a…

PHP项目用docker一键部署

公司新项目依赖较多,扩展版本参差不一,搭建环境复杂缓慢,所以搭建了一键部署的功能。 docker-compose build 构建docker docker-compose up 更新docker docker-compose up -d 后台运行docker docker exec -it docker-php-1 /bin/bas…

idea方法注释模版设置

方法上面的注释模版: Template text: ** Description $desc$ $param$ $return$* Aauthor yimeng* date $DATE$ $TIME$ **/param: groovyScript("def result ;def params \"${_1}\".replaceAll([\\\\[|\\\\]|\\\\s], ).split(,).toLis…

.net core 连接数据库,通过数据库生成Modell

1、安装EF Core Power Tools:打开Vs开发工具→扩展→管理扩展 2、(切记执行这步之前确保自己的代码不存在编写或者编译错误!)安装完成后在你需要创建数据库实体的项目文件夹上面单击右键,找到EF Core 工具(必须安装扩展之和才会有…

redis的集群,主从复制,哨兵

redis的高可用 在Redis中,实现高可用的技术主要包括持久化、主从复制、哨兵和集群,下面分别说明它们的作用,以及解决了什么样的问题。 持久化: 持久化是最简单的高可用方法(有时甚至不被归为高可用的手段)…

手机笔记工具怎么加密?

选择用手机笔记工具记事,大家可以记录很多学习笔记、读书笔记、私密日记等,手机作为随身携带的设备,记录相关的笔记比较快捷且方便,当手机笔记中记录的内容比较私密时,大家担心手机笔记会被别人误看,这时候…

面试篇Flink

一:为什么学习flink? 相比较spark,flink对于实时这块,使用过流的方式进行实现。 spark是通过批流的方式实现,通过减少批的时间间隔来实现流的功能。 二:什么是flink? flink是一个针对于实时进…

C++ 通过CryptoPP计算Hash值

Crypto (CryptoPP) 是一个用于密码学和加密的 C 库。它是一个开源项目,提供了大量的密码学算法和功能,包括对称加密、非对称加密、哈希函数、消息认证码 (MAC)、数字签名等。Crypto 的目标是提供高性能和可靠的密码学工具,以满足软件开发中对…

基于UDP的网络聊天室

客户端 #include <myhead.h> //定义存储信息结构体 typedef struct _MSG {char code; //操作码&#xff1a;L表示登录C表示群聊S表示系统消息S表示退出char name[128]; char txt[256];}msg_t;//定义保存客户端网络信息的链表 typedef struct _ADDR {struct sockaddr_i…

玄学调参实践篇 | 深度学习模型 + 预训练模型 + 大模型LLM

&#x1f60d; 这篇主要简单记录一些调参实践&#xff0c;无聊时会不定期更新~ 文章目录 0、学习率与batch_size判断1、Epoch数判断2、判断模型架构是否有问题3、大模型 - 计算量、模型、和数据大小的关系4、大模型调参相关论文经验总结5、训练时模型的保存 0、学习率与batch_s…

Spring不再支持Java8了

在今天新建模块的时候发现了没有java8的选项了&#xff0c;结果一查发现在11月24日&#xff0c;Spring不再支持8了&#xff0c;这可怎么办呢&#xff1f;我们可以设置来源为阿里云https://start.aliyun.com/ 。 java8没了 设置URL为阿里云的地址

c++——string字符串____迭代器.范围for.修改遍历容量操作

在成为大人的路上喘口气. 目录 &#x1f393;标准库类型string &#x1f393;定义和初始化string对象 &#x1f4bb;string类对象的常见构造 &#x1f4bb;string类对象的不常见构造 &#x1f4bb;读写string对象 &#x1f393; string类对象的修改操作 &#x1f4…

爬虫http代理有什么用处?怎么高效使用HTTP代理?

在进行网络爬虫工作时&#xff0c;我们有时会遇到一些限制&#xff0c;比如访问频率限制、IP被封等问题。这时&#xff0c;使用HTTP代理可以有效地解决这些问题&#xff0c;提高爬虫的工作效率。本文将介绍爬虫HTTP代理的用处以及如何高效地使用HTTP代理。 一、爬虫HTTP代理的用…

【数据结构】单链表---C语言版

【数据结构】单链表---C语言版 一、顺序表的缺陷二、链表的概念和结构1.概念&#xff1a; 三、链表的分类四、链表的实现1.头文件&#xff1a;SList.h2.链表函数&#xff1a;SList.c3.测试函数&#xff1a;test.c 五、链表应用OJ题1.移除链表元素&#xff08;1&#xff09;题目…

京东数据产品推荐-京东数据挖掘-京东平台2023年10月滑雪装备销售数据分析

如今&#xff0c;滑雪正成为新一代年轻人的新兴娱乐方式&#xff0c;借助北京冬奥会带来的发展机遇&#xff0c;我国冰雪经济已逐渐实现从小众竞技运动到大众时尚生活方式的升级。由此也带动滑雪相关生意的增长&#xff0c;从滑雪服靴到周边设备&#xff0c;样样都需要消费者掏…