吴恩达机器学习COURSE2 WEEK2

COURSE2 WEEK2

模型训练的细节

  • 定义模型,即指定如何在给定输入特征 x x x 以及参数 w w w b b b 的情况下计算输出

  • 指定损失函数 L ( f w ⃗ , b ( x ⃗ ) , y ) L(f_{\vec w, b}(\vec x),y) L(fw ,b(x ),y)

    指定成本函数 J ( w ⃗ , b ) = 1 m ∑ i = 1 m L ( f w ⃗ , b ( x ⃗ ( i ) ) , y ( i ) ) J(\vec w, b) = \frac{1}{m} \sum_{i=1}^mL(f_{\vec w, b}(\vec x^{(i)}),y^{(i)}) J(w ,b)=m1i=1mL(fw ,b(x (i)),y(i))

  • 使用算法来进行训练,找到最小化成本函数对应的参数

定义模型

该步骤指定了神经网络的整个框架

  • 通过Dense用来定义神经网络层
    • units表示这个层的神经元个数
    • activation表示使用的激活函数
  • 通过Sequential来把各个神经网络层连接起来
import tensorflow as tf
from tensorflow.keras import Sequential
from tensorflow.keras.layers import Dense##### 定义模型 #####
model = Sequential([Dense(units=25, activation='sigmoid'),Dense(units=15, activation='sigmoid'),Dense(units=1, activation='sigmoid')
])

指定损失函数

由于是分类任务,所以使用交叉熵函数(二元交叉熵)

使用compile()定义

##### 定义损失函数 #####
from tensorflow.keras.losses import BinaryCrossentropymodel.compile(loss=BinaryCrossentropy())

keras中的losses集成了很多损失函数:

  • BinaryCrossentropy() 二元交叉熵损失函数
  • MeanSquaredError() 均方误差函数

训练模型

一般而言,我们使用梯度下降来进行模型的训练

直接调用Tensorflow中的fit函数

##### 模型训练 #####
model.fit(X, y, epochs = 100)  # epoch 为迭代次数

激活函数

常见激活函数

使用激活函数,加在每一层神经网络层之间,可以使得我们的模型具有更强大的能力

Sigmoid激活函数和Relu激活函数

在这里插入图片描述

加入激活函数后,可以使得我们的模型具有非线性的特点,从而使得我们的模型更加复杂,能够学习到更复杂的模型,拟合结果更加准确

激活函数的使用

对于激活函数的选择,要根据具体的输出等要适当的选择

对于输出层:

  • 如果是二分类问题,可以使用Sigmoid函数
  • 如果是回归任务,预测房价等,因为房价永远不会是负值,因此可以使用Relu函数
  • 对于更一般的回归任务,即输出可正可负,可以使用线性激活函数

对于隐藏层:

  • Relu的使用较多,且使用效果较好

使用Relu较多的原因:

  • 计算比Sigmoid函数简单
  • Relu函数图像只有一边比较平坦,而Sigmoid函数图像有两边比较平坦。当函数图像比较平坦时,会造成梯度下降算法较慢

多分类问题

多分类问题指的是,在分类问题中,最终输出的类别是多个,即 y y y 得输出有多个(有限)

在这里插入图片描述

上图左侧是二分类任务,右侧是多分类任务

注意与聚类算法地区别

  • 聚类算法属于无监督学习,即没有标签
  • 多分类是监督学习,有标签

Softmax函数

当分类任务是二分类时,我们可以使用Sigmoid函数作为输出层地激活函数,即最终地输出为 0 或 1

而当我们的任务是多分类任务时,这时就有多个输出的值,输出层的激活函数就要使用Softmax函数,保证我们的模型能够输出多个类别的概率

Softmax回归公式
P ( y = i ∣ x ⃗ ) = a i = e z i ∑ k = 1 N e z k P(y=i | \vec x) =a_i = \frac{e^{z_i}}{\sum_{k=1}^{N}e^{z_k}} P(y=ix )=ai=k=1Nezkezi
并且最终 ∑ k = 1 N P ( y = k ∣ x ⃗ ) = 1 \sum_{k=1}^{N}P(y=k | \vec x) = 1 k=1NP(y=kx )=1

对于Softmax函数,最终输出的是 a 1 … a n a_1 \dots a_n a1an 的值

成本函数

参考逻辑回归的损失函数,对于多分类任务,其损失函数为
l o s s ( a I , a 2 , … , a N , y ) = { − log ⁡ a 1 i f y = 1 − log ⁡ a 2 i f y = 2 ⋮ − log ⁡ a N i f y = N loss(a_I, a_2, \dots ,a_N, y) = \begin{equation} \begin{cases} -\log a_1 \ \ \ \ \ if \ \ y = 1 \\ -\log a_2 \ \ \ \ \ if \ \ y = 2 \\ \vdots \\ -\log a_N \ \ \ \ \ if \ \ y = N \end{cases} \end{equation} loss(aI,a2,,aN,y)= loga1     if  y=1loga2     if  y=2logaN     if  y=N
Tensorflow中对应函数SparseCrossEntropy()

Softmax实现的改进

在计算机中,由于每次计算都可能会有浮点数的舍入,即存在一定的误差,因此不同的计算顺序,可能会导致得到不同的结果

例如,对于二分类的逻辑回归,如果我们按照以下方式计算:
S t e p 1 : a = g ( z ) = 1 1 + e z S t e p 2 : l o s s = − y log ⁡ ( a ) − ( 1 − y ) log ⁡ ( 1 − a ) Step1: \ \ a = g(z) = \frac{1}{1+e^z} \\ Step2: \ \ loss = -y\log (a) - (1-y)\log (1-a) Step1:  a=g(z)=1+ez1Step2:  loss=ylog(a)(1y)log(1a)
那么 a a a 作为中间值,在第一步的赋值和第二步的带入过程中,会产生精度的损失,最终会造成误差较大的结果

而,如果我们直接把 z z z 带入到损失函数里计算,即:
l o s s = − y log ⁡ ( 1 1 + e z ) − ( 1 − y ) log ⁡ ( 1 − 1 1 + e z ) loss = -y\log (\frac{1}{1+e^z}) - (1-y)\log (1-\frac{1}{1+e^z}) loss=ylog(1+ez1)(1y)log(11+ez1)
这样就会避免一些误差的产生,使得我们的结果更加准确

在代码中的改变

model.compile(loss = BinaryCrossEntropy(from_logits=True))

多个输出的分类

对于一个图片里,我们要同时检测是否有人、是否有轿车、是否有公交车

对于这种多个输出的分类任务,我们可以通过一个神经网络的训练,在最后使用三个Sigmoid函数,分别用来输出是否有人,是否有轿车,是否有公交车,即转化为了一个三个二分类任务

在这里插入图片描述

Adam优化算法

对于传统的梯度下降算法,当其学习率较小时,会迭代较多次数,而当其学习率较大时,收敛过程中就会产生震荡,为了改进这个缺点,出现了Adam算法,可以自动的调节学习率

Adam算法并没有全局单一的使用学习率,而是将学习率自适应:

  • 如果我们的梯度下降过程中参数一直朝着一个大致的方向移动,那么Adam算法就会增加学习率,加快算法的步伐
  • 如果参数一致来回不断的震荡,那么Adam算法就会降低学习率,使得算法能够较好的朝收敛方向移动

Adam算法的使用

model.compile(optimizer = tf.keras.optimizers.Adam(learning_rate=1e-3),loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logts=True))

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/397101.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Linux系统驱动(十三)Linux内核定时器

文章目录 一、内核定时器原理二、定时器API三、使用定时器让LED灯闪烁四、使用定时器对按键进行消抖 一、内核定时器原理 内核当前时间通过jiffies获取,它是内核时钟节拍数,在linux内核启动的时候,jiffies开始(按照一定频率&…

【数据结构】顺序结构实现:特殊完全二叉树(堆)+堆排序

二叉树 一.二叉树的顺序结构二.堆的概念及结构三.堆的实现1.堆的结构2.堆的初始化、销毁、打印、判空3.堆中的值交换4.堆顶元素5.堆向上调整算法:实现小堆的插入6.堆向下调整算法:实现小堆的删除7.堆的创建1.堆向上调整算法:建堆建堆的时间复…

CentOS 安装Redis

在 CentOS 安装 Redis 操作系统:centos-7.9.2009-Core 1. 更新系统 首先,确保你的系统是最新的: sudo yum update -y2. 安装 EPEL 仓库 Redis 可能不在默认的 CentOS 仓库中,因此你需要安装 EPEL(Extra Packages f…

TCP详解及其在音视频传输中的应用

传输控制协议(TCP,Transmission Control Protocol)是互联网协议栈中至关重要的传输层协议。它提供了可靠、面向连接的数据传输服务,广泛应用于各种网络应用中。对于音视频传输,虽然TCP协议并不是最常用的传输协议&…

LVS实验——部署DR模式集群

目录 一、实验环境 二、配置 1、LVS 2、router 3、client 4、RS 三、配置策略 四、测试 1.Director服务器采用双IP桥接网络,一个是VPP,一个DIP 2.Web服务器采用和DIP相同的网段和Director连接 3.每个Web服务器配置VIP 4.每个web服务器可以出外网…

《Advanced RAG》-11-RAG查询分类和细化

总结 文章介绍了两种高级的检索增强生成(RAG)技术:自适应 RAG 和 RQ-RAG,以及它们在问题复杂性学习和查询细化方面的应用和优势,以及如何通过小型模型的训练来提高这些技术的性能。 摘要 传统 RAG 技术虽然能够减少大型…

「MyBatis」数据库相关操作2

🎇个人主页 🎇所属专栏:Spring 🎇欢迎点赞收藏加关注哦! #{} 和 ${} 我们前面都是采用 #{} 对参数进行赋值,实际上也可以用 ${} 客户端发送⼀条 SQL 给服务器后,大致流程如下: 1.…

51单片机之动态数码管显示

一、硬件介绍 LED数码管是一种由多个发光二极管(LED)封装在一起,形成“8”字型的显示器件。它广泛用于仪表、时钟、车站、家电等场合,用于显示数字、字母或符号。 通过控制点亮a b c d e f g dp来显示数字,本实验开发板…

前端八股文笔记【三】

JavaScript 基础题型 1.JS的基本数据类型有哪些 基本数据类型:String,Number,Boolean,Nndefined,NULL,Symbol,Bigint 引用数据类型:object NaN是一个数值类型,但不是…

十三、代理模式

文章目录 1 基本介绍2 案例2.1 Sortable 接口2.2 BubbleSort 类2.3 SortTimer 类2.4 Client 类2.5 Client 类的运行结果2.6 总结 3 各角色之间的关系3.1 角色3.1.1 Subject ( 主体 )3.1.2 RealObject ( 目标对象 )3.1.3 Proxy ( 代理 )3.1.4 Client ( 客户端 ) 3.2 类图 4 动态…

Java网络编程、TCP、UDP、Socket通信---初识版

标题 InetAddress----IP地址端口号协议(UDP/TCP)JAVA操作-UDP一发一收模式多发多收 JAVA操作-TCP一发一收多发多收 实现群聊功能BS架构线程池优化 InetAddress----IP地址 端口号 协议(UDP/TCP) JAVA操作-UDP 一发一收模式 多发多收…

React 性能优化

使用 useMemo 缓存数据 (类似 vue 的 computed)使用 useCallback 缓存函数异步组件 ( lazy )路由懒加载( lazy )服务器渲染 SSR用 CSS 模拟 v-show 循环渲染添加 key使用 Fragment (空标签)减少层级 不在JSX 中定义函数&#xff0…

一篇教会搭建ELK日志分析平台

日志分析的概述 日志分析是运维工程师解决系统故障,发现问题的主要手段日志主要包括系统日志、应用程序日志和安全日志系统运维和开发人员可以通过日志了解服务器软硬件信息、检查配置过程中的错误及错误发生的原因经常分析日志可以了解服务器的负荷,性…

使用本地大模型从论文PDF中提取结构化信息

1 安装ollama 点击前往网站 https://ollama.com/ ,下载ollama软件,支持win、Mac、linux 2 下载LLM ollama软件目前支持多种大模型, 如阿里的(qwen、qwen2)、meta的(llama3、llama3.1), 读者根据自己电脑…

C语言:求最大数不用数组

(1)题目: 输入一批正数用空格隔开,个数不限,输入0时结束循环,并且输出这批整数的最大值。 (2)代码: #include "stdio.h" int main() {int max 0; // 假设输入…

Qt——多线程

一、QThread类 如果要设计多线程程序,一般是从QThread继承定义一个线程类,并重新定义QThread的虚函数 run() ,在函数 run() 里处理线程的事件循环。 应用程序的线程称为主线程,创建的其他线程称为工作线程。主线程的 start() 函数…

计算机网络408考研 2014

1 计算机网络408考研2014年真题解析_哔哩哔哩_bilibili 1 111 1 11 1

MyBatis:Maven,Git,TortoiseGit,Gradle

1,Maven Maven是一个非常优秀的项目管理工具,采用一种“约定优于配置(CoC)”的策略来管理项目。使用Maven不仅可以把源代码构建成可发布的项目(包括编译、打包、测试和分发),还可以生成报告、生…

短视频SDK,支持Flutter跨平台框架,加速产品上线进程

在数字内容爆炸式增长的今天,短视频已成为连接用户、传递情感、展现创意的重要桥梁。为助力开发者快速融入这股潮流,美摄科技匠心打造了一款专为Flutter框架优化的短视频SDK解决方案,旨在降低技术门槛,加速产品迭代,让…

主题与分区

主题和分区是Kafka的两个核心概念,分区的划分不仅为Kafka提供了可伸缩性、水平扩展的功能,还通过多副本机制来为Kafka提供数据冗余以提高数据可靠性。 主题创建 主题和分区都是提供给上层用户的抽象,而在副本层面或更加准确地说是Log层面&a…