基于 R 语言的深度学习——简单回归案例

近年来深度学习在人工智能领域飞速发展,各行业的学者、研究人员纷纷涌入研究热潮。本文将从 R 语言角度来介绍深度学习并解决以下几个问题:

  • 什么是深度学习?

  • 相关深度学习包有哪些?

  • 如何配置工作环境?

  • 如何使用神经网络建立模型?

本文主要解决:如何使用神经网络建立模型?,其他推文可见:基于 R 语言的深度学习——配置环境;基于 R 语言的深度学习——简介及资料分享。

简单神经网络建模

本节将从一个简单的回归例子[1]来介绍如何在 R 中使用 keras 包进行深度学习。

该案例是在 CPU 下进行的。如果你的设备有 GPU,并想用 GPU 训练模型。你不需要修改以下的代码,只需前期安装 GPU 版本的 TensorFlow,默认情况下,运算会优先使用 GPU。

知识点包括:

  1. 数据导入与数据处理。

  2. 构建神经网络。

  3. 训练神经网络。

  4. 评估模型的准确性。

  5. 保存并恢复创建的模型。

加载包

library(keras)
library(mlbench) #使用内部数据
library(dplyr)
library(magrittr)

加载数据

使用 1970 年波士顿 506 个人口普查区的住房数据作为例子。该数据集一共有 14 列,506 行。其中,因变量为 medv(自有住房的中位数报价, 单位 1000 美元),自变量为其他 13 个变量,包括:CRIM (城镇人均犯罪率)、ZN(占地面积超过 25000 平方英尺的住宅用地比例)、INDUS (每个城镇非零售业务的比例)等。

data("BostonHousing")
data <- BostonHousing
data %<>% mutate_if(is.factor, as.numeric)
knitr::kable(head(data[,1:12])) #由于呈现不了所有列,这里只展示 12 列的前 6 行数据。
crimzninduschasnoxrmagedisradtaxptratiob
0.0063182.3110.5386.57565.24.090129615.3396.9
0.027307.0710.4696.42178.94.967224217.8396.9
0.027307.0710.4697.18561.14.967224217.8392.8
0.032402.1810.4586.99845.86.062322218.7394.6
0.069002.1810.4587.14754.26.062322218.7396.9
0.029902.1810.4586.43058.76.062322218.7394.1

数据处理

首先,对 506 条数据进行划分。随机选择其中的 70% 数据作为训练样本,另外 30% 数据作为测试样本。

# 构建矩阵
data <- as.matrix(data)
dimnames(data) <- NULL# 数据集划分
set.seed(1234)
ind <- sample(2, nrow(data), replace = T, prob = c(.7, .3)) #从 1,2 中有放回抽取一个数,概率分别为(0.7,0.3)。
training <- data[ind==1,1:13]
test <- data[ind==2, 1:13]
trainingtarget <- data[ind==1, 14]
testtarget <- data[ind==2, 14]

此外,由于各个特征的数据范围不同,直接输入到神经网络中,会让网络学习变得困难。所以在进行网络训练之前,先将该数据集进行特征标准化:输入数据中的每个特征,将其减去特征平均值并除以标准差,使得特征值以 0 为中心,且具有单位标准差。在 R 中可以使用 scale() 函数实现该效果。

数据集 BostonHousing 也可以直接通过 keras 包中的 dataset_boston_housing() 进行加载,并且已经提前划分好了训练集和测试集。本文使用的是 mlbench 包中数据集进行加载,主要是呈现划分数据集的过程。

# 数据标准化
m <- colMeans(training)
s <- apply(training, 2, sd)
training <- scale(training, center = m, scale = s)
test <- scale(test, center = m, scale = s)

构建模型

由于可用样本量很少,这里构建一个非常小的网络。使用 keras_model_sequential() 定义模型,并设置了 1 个隐藏层和 1 个输出层。激活函数为 relu。

model <- keras_model_sequential() %>% layer_dense(units = 10, activation = 'relu', input_shape = c(13)) %>%layer_dense(units = 1)

通过 summary() 查看模型个层形状和参数,可以看到,总共包含 151 个参数。

summary(model)
## Model: "sequential"
## ________________________________________________________________________________
##  Layer (type)                       Output Shape                    Param #     
## ================================================================================
##  dense_1 (Dense)                    (None, 10)                      140         
##                                                                                 
##  dense (Dense)                      (None, 1)                       11          
##                                                                                 
## ================================================================================
## Total params: 151
## Trainable params: 151
## Non-trainable params: 0
## ________________________________________________________________________________

编译模型

编译主要需要设定三个部分:

  1. 损失函数:训练期间需要最小化的目标函数;

  2. 优化器:对数据和损失函数进行自我更新;

  3. 监控度量:训练和测试期间的评价标准。

该例子是一个典型的回归问题,我们使用的损失函数是均方误差(Mean Square Error,MSE),即预测和目标之间差异的平方。使用均方根传播方法(Root Mean Squared Propagation,RMSProp)作为该模型的优化器。使用平均绝对误差(Mean Absolute Error,MAE)来监控网络。

model %>% compile(loss = 'mse', #损失函数optimizer = 'rmsprop', #优化器metrics = 'mae'#监控度量)

优化器有很多种,详细介绍可参考:理论[2]、实践[3];损失函数和评价度量的选择,可以参考这篇博客[4]。

拟合模型

拟合模型时,RStudio 的 Viewer 会出现:随着迭代变化的损失函数值。如下所示:

mymodel <- model %>%fit(training,trainingtarget,epochs = 200,batch_size = 32,validation_split = 0.2)

图片

图中的 loss 是指损失函数,val_loss 是指验证集下的损失函数(代码中设置的验证集划分比例为 0.2)。mae 表示平均绝对误差,而 val_mae 表示验证集下的平均绝对误差。图中可以看到,随着训练轮数的增加,mae 与 loss 在不断减小并趋于稳定。

评估模型

使用 evaluate() 评估模型,给出预测结果。计算真实值和预测值的均方误差。

model %>% evaluate(test, testtarget)
##   loss    mae 
## 42.571  4.185
pred <- predict(model,test) #预测结果
mean((testtarget-pred)^2) #计算均方误差
## [1] 42.57

通过 ggplot2[5] 包将预测结果和真实结果可视化。

library(ggplot2)
library(viridis)
library(ggsci)
ev_data = data.frame("Item" = seq(1,length(pred)),"Value" = c(testtarget,pred),"Class" = rep(c("True","Pred"),each = length(pred)))
ggplot(ev_data) +geom_line(aes(Item,Value,col = Class,lty = Class)) +scale_color_aaas() +theme_bw() + theme(panel.grid = element_blank())

图片

总体来看,预测结果还算不错,但是也有一些预测结果和真实值相差甚远。主要原因是,我们没有调整参数来使模型达到最优的效果。读者可以使用 K 折验证的方法来寻找最有的参数,例如:训练轮数,神经网络层数,各层神经元数等。具体案例可以见 《Deep Learning with R》的第 3.6.4 节[6]。

存储 / 加载模型

为了保存 Keras 模型以供未来使用,使用 save_model_tf() 函数保存模型。

save_model_tf(object = model, filepath = "BostonHousing_model") #保存模型

使用 load_model_tf() 函数加载模型,并对新数据集(下面使用测试集)进行预测。

reloaded_model <- load_model_tf("BostonHousing_model") #加载模型
predict(reloaded_model, test) #对新数据集进行预测

相关拓展

以上例子介绍了如何使用神经网络来处理简单问题(数据量较小的回归问题),但在实际过程中可能面临种种困难,包括:如何对数据进行预处理,如何进行特征筛选,如何解决过拟合问题,如何调整参数等。

由于笔者时间和能力有限,这篇推文不能一一给出系统的解决方案。相关资源可见:基于 R 语言的深度学习——简介及资料分享,以供读者翻阅。

该系列还会继续写下去,欢迎来我的公众号《庄闪闪的 R 语言手册》关注新内容。

参考资料

[1]

回归例子: https://github.com/fmmattioni/deep-learning-with-r-notebooks/blob/master/notebooks/3.6-predicting-house-prices.Rmd

[2]

理论: http://www.cs.toronto.edu/~tijmen/csc321/slides/lecture_slides_lec6.pdf

[3]

实践: https://keras.io/api/optimizers/

[4]

博客: https://machinelearningmastery.com/how-to-choose-loss-functions-when-training-deep-learning-neural-networks/

[5]

ggplot2: https://ggplot2.tidyverse.org/

[6]

3.6.4 节: https://github.com/fmmattioni/deep-learning-with-r-notebooks/blob/master/notebooks/3.6-predicting-house-prices.Rmd

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/415141.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

微服务--认识微服务

微服务架构的演变 1. 单体架构&#xff08;Monolithic&#xff09; 阶段描述&#xff1a;在单体应用时代&#xff0c;整个应用程序被设计为一个项目&#xff0c;并在一个进程内运行。这种架构方式开发简单&#xff0c;便于集中管理&#xff0c;但随着应用的复杂化&#xff0c…

【算法 动态规划 简单多状态 dp 问题】打家劫舍题型

打家劫舍题型 按摩师 (easy)解题思路代码 打家劫舍II &#xff08;medium&#xff09;解题思路代码 删除并获得点数&#xff08;medium&#xff09;解题思路代码 按摩师 (easy) 题目链接 该题是打家劫舍的变形 解题思路 状态表示 分析: 注意题目, 对于当天的预约, 可以接受…

C语言遇见的一些小问题

问题如下&#xff1a; 1&#xff1a;为什么这样的代码为报错 #define _CRT_SECURE_NO_WARNINGS #include <iostream> #include <algorithm> #include <cstdio> #include<string> #include<stdlib.h> using namespace std; int main() {int i …

mysql5.7 TIMESTAMP NOT NULL DEFAULT ‘0000-00-00 00:00:00‘ 换版8版本 引发的问题

mysql5.7 TIMESTAMP NOT NULL DEFAULT 0000-00-00 00:00:00 换版引发的问题 问题背景sql_mode上机演示5.78.4 问题背景 在项目mysql版本由5.7 换版到8.4版本后&#xff0c;我们进行回归测试时&#xff0c;却发现一个积年代码报错了&#xff0c;是数据库插入报的错 xxx can not…

华为IS-IS实验及配置

AR1配置 #进入ISIS进程 isis 1 #配置设备类型为Level-1is-level level-1 #定义区域和System-ID等信息network-entity 49.0001.0010.0000.0001.00 #ISIS邻居命名is-name AR1 #接口配置IP和启用ISIS interface GigabitEthernet0/0/0ip address 10.1.12.1 255.255.255.0 isis ena…

数仓基础(六):离线与实时数仓区别和建设思路

文章目录 离线与实时数仓区别和建设思路 一、离线数仓与实时数仓区别 二、实时数仓建设思路 离线与实时数仓区别和建设思路 ​​​​​​​一、离线数仓与实时数仓区别 离线数据与实时数仓区别如下&#xff1a; 对比方面 离线数仓 实时数仓 架构选择 传统大数据架构 …

Ribbon负载均衡底层原理

springcloude服务实例与服务实例之间发送请求&#xff0c;首先根据服务名注册到nacos&#xff0c;然后发送请求&#xff0c;nacos可以根据服务名找到对应的服务实例。 SpringCloudRibbon的底层采用了一个拦截器&#xff0c;拦截了openfeign发出的请求&#xff0c;对地址做了修…

【C++】将myString类中能够实现的操作都实现一遍

myString.h #ifndef MYSTERAM_H #define MYSTERAM_H #include <iostream> #include<cstring> using namespace std; class myString { private:char *str; //字符串int size; //字符串容量char error[20] "error"; public://无参构造myString():siz…

深入解析FPGA在SOC设计中的核心作用

在集成系统SoC设计中&#xff0c;CPU核心的嵌入至关重要&#xff0c;以实现软硬件的有效交互。此过程中涉及到功能IP&#xff08;如图像处理、无线通信等&#xff09;的验证和整合&#xff0c;利用硬件描述语言(HDL)来实现可编程逻辑(FPGA)&#xff0c;并通过验证技术提高设计效…

Django Admin对自定义的计算字段进行排序

通常&#xff0c;Django会为模型属性字段&#xff0c;自动添加排序功能。当你添加计算字段时&#xff0c;Django不知道如何执行order_by&#xff0c;因此它不会在该字段上添加排序功能。 如果要在计算字段上添加排序&#xff0c;则必须告诉Django需要排序的内容。你可以通过在…

【机器学习】神经网络、隐藏层的基本概念、如何选择隐藏层数量以及胶囊网络对神经网络的影响

引言 神经网络是机器学习的一种方法&#xff0c;它通过模拟人脑神经元的工作原理来构建算法 文章目录 引言一、神经网络1.1 定义1.2 主要组成部分1.3 工作原理1.4 应用1.5 类型1.6 优化算法1.7 总结 二、隐藏层2.1 定义2.2 隐藏层的作用2.3 隐藏层的数量和大小2.4 隐藏层的结构…

单片机-STM32 时钟(六)

1.时钟的概念 在我们单片机中&#xff0c;时钟主要是用于 提供一个工作的频率&#xff0c;时钟信号越大&#xff0c;设备执行的速度就越快。 时钟---处理器运行的频率---72MHZ Dbus--数据总线 AHB--总线桥 APB2--外设总线2&#xff08;时钟&#xff09; APB1--外设总线1&a…

基于Java+SpringBoot+Vue的植物健康系统

基于JavaSpringBootVue的植物健康系统 前言 ✌全网粉丝20W,csdn特邀作者、博客专家、CSDN[新星计划]导师、java领域优质创作者,博客之星、掘金/华为云/阿里云/InfoQ等平台优质作者、专注于Java技术领域和毕业项目实战✌ &#x1f345; 某信 gzh 搜索【智能编程小助手】获取项…

[C++11#44](一) 统一的列表初始化 | 声明 | STL中一些变化 | emplace的优化 | move

目录 0. 回忆 1.隐式类型转化 特性 1.统一的列表初始化 1.{}初始化 2.2 std::initializer_list 二、声明 1.auto 2.decltype 3.nullptr 宏定义的例子 使用 const 和 enum 替代宏 4. 范围 for 循环 5.final与override final 关键字 override 关键字 示例代码 智…

wpf prism 《3》 弹窗 IOC

传统的弹窗 这种耦合度高 new 窗体() . Show(); new 窗体() . ShowDialog(); 利用Prism 自动的 IOC 弹窗的 必须 必须 必须 页面控件 弹窗的 必须 必须 必须 页面控件 弹窗的 必须 必须 必须 页面控件 弹窗的 必须 必须 必须 页面控件 弹窗的 必须 必须 必须 页面控件 》》否…

惠中科技光伏清洗剂:绿色清洁,引领光伏行业新潮流

在当今全球能源转型的大潮中&#xff0c;光伏产业作为绿色能源的重要组成部分&#xff0c;正以前所未有的速度蓬勃发展。然而&#xff0c;随着光伏板在户外环境的长时间暴露&#xff0c;其表面不可避免地会积累灰尘、鸟粪、油污等污染物&#xff0c;严重影响光伏板的透光率和发…

tiny_qemu模拟qemu虚拟化原理

一、模仿一个x86平台虚机 cpu虚拟化原理来源于Linux虚拟化KVM-Qemu分析&#xff08;四&#xff09;之CPU虚拟化&#xff08;2&#xff09; 笔者就实现了下相关操作。看汇编是在x86平台下操作的&#xff0c;其中两个文件分别是 1.tiny_kernel.S start: /* Hello */ mov …

数据安全与个人信息保护的辨析

文章目录 前言一、合规1、合规的目标导向原则2、监管平衡的原则二、基础设施1、公共基础设施2、企业基础设施三、数据流通1、数据生产要素是数字化时代生产要素的变革理论2、数据产品的保护源自于数据产品的价值四、产品与服务1、数据安全与网络安全2、数据安全的分类分级与数据…

Unity(2022.3.41LTS) - UI详细介绍-Scrollbar(滚动条)

目录 零.简介 一、基本功能与用途 二、组件介绍 三、使用方法 四、优化和注意事项 五.和滑动条的区别 零.简介 在 Unity 中&#xff0c;滚动条&#xff08;Scrollbar&#xff09;是一种用于实现滚动功能的 UI 组件。 一、基本功能与用途 滚动内容&#xff1a;主要用于…

NeRF原理学习

一个2020年的工作我现在才来学习并总结它的原理&#xff0c;颇有种“时过境迁”的感觉。这篇总结是基于NeRF原文 NeRF: Representing Scenes as Neural Radiance Fields for View Synthesis 阅读理解后写的&#xff0c;作用是以后如果记不太清了可以回忆。 目的&应用 先说…