深度学习入门-第4章-神经网络的学习

学习就是从训练数据中自动获取最优权重参数的过程。引入损失函数这一指标,学习的目的是找出使损失函数达到最小权重参数。使用函数斜率的梯度法来找这个最小值。

人工智能有两派,一派认为实现人工智能必须用逻辑和符号系统,自顶向下看问题;另一派认为通过仿造人脑可以达到人工智能,自底向上看问题。前一派是“想啥来啥”,后一派是“吃啥补啥”。前者偏唯心,后者偏唯物。两派一直是人工智能领域“两个阶级、两条路线”的斗争,这斗争有时还是你死我活。今天学习的是神经网络派。

4.1 从数据中学习

4.1.1 数据驱动

数据是机器学习的命根子。机器学习避免人为介入,通过数据发现模式。比如识别手写数字5,可以从图像中提取特征量,再用机器学习学习这些特征量的模式。其中图像转换为向量时使用的特征量仍由人设计,不同问题需要人工考虑不同的特征量。

神经网络(深度学习)称为端到端学习,图像中的特征量也由机器来学习。不管识别5还是识别狗,神经网络都是通过不断学习数据,尝试发现模式。

4.1.2 训练数据和测试数据

追求的模型泛化能力。训练数据也叫监督数据。一套数据集,无法获得正确的评价。要避免对某数据集的过拟合

4.2 损失函数

损失函数表示神经网络恶劣程度指标。一般乘上一个负值。

4.2.1 均方误差

4.2.2 交叉熵误差

4.2.3 mini-batch学习

从训练数据中选出小批量学习,称为mini-batch学习。随机选择小批量做为全体训练数据的近似值。

4.2.4 mini-batch版交叉熵误差实现

4.2.5 为何要设定损失函数

为了找到使损失函数值尽可能小的地方,需要计算参数导数,以导数为指引,逐步更新参数值。

对权重参数的损失函数求导,表示的是:如果稍微改变这个权重的值,损失函数的值如何变化

1)导数为负,权重参数正向变化,可以减小损失函数的值。

2)导数为正,权重参数负向变化,可以减小损失函数的值。

3)导数为0,权重参数哪个方向变化,损失函数都不变化。

不能直接使用识别精度,是因为大部分地方的参数导数为0,导致参数无法更新。

为啥是0?比如识别精度为32%, 微调权重参数,识别精度仍旧是32%,即使改变,也不会联系变化,而是33%,34%等离散值。而损失函数会连续变化。作为激活函数的阶跃函数也有类似特征,大部分地方导数为0,所以不能使用阶跃函数,要使用斜率连续变化的sigmoid函数。

4.3 数值微分

什么是梯度。

4.3.1 导数

采用中心差分

(f(x+h)-f(x-h))/(2*h)

利用微小的差分求导的过程称为数值微分numerical differentiation

数学公式推导求导称为解析性求导。如y=^X{^{^{2}}} 公式求导为\frac{dy}{dx}=2x,这样算出的是没有误差的真导数

4.3.2 数值微分的例子

数值微分的计算结果和真导数误差很小。

4.3.3 偏导数

有两个变量的情况。或者多个变量。有多个变量的函数的导数称为偏导数

偏导数将多个变量中的某个变量定为目标变量,其他变量固定为某个值

4.4 梯度

由全部变量的偏导数汇总而成的向量称为梯度(gradient)

4.4.1 梯度法

使用梯度寻找损失函数最小值的方法就是梯度法。梯度是各点处函数值减小最多的方向。方向往往不是函数的最小值。是极小值。

不断沿梯度方向前进,逐渐减小函数值的过程,叫梯度法 gradient method

学习率:一次学习,在多大程度上更新参数。

梯度下降法实现:

def gradient_descent(f, init_x, lr=0.01,step_num=100):x = init_xfor i in range(step_num):grad = numerical_gradient(f,x)x -= lr * gradreturn x

4.4.2 神经网络的梯度

神经网络梯度:损失函数关于权重参数的梯度。形状与W相同。

求梯度代码

import sys,os
sys.path.append(os.pardir)
import numpy as np
from common.functions import softmax,cross_entropy_error
from common.gradient  import numerical_gradientclass simpleNet:def __init__(self):self.W = np.random.randn(2,3)def predict(self,x):return np.dot(x,self.W)def loss(self,x,t):z = self.predict(x)y = softmax(z)loss = cross_entropy_error(y,t)return loss

4.5 学习算法的实现

4.5.1 二层神经网络类

#two_layer_net.py
import sys,os
sys.path.append(os.pardir)
from common.functions import *
from common.gradient import numerical_gradientclass TwoLayerNet:def __init__(self,input_size,hidden_size,output_size,weight_init_std=0.01):self.params = {}self.params['W1'] = weight_init_std * np.random.randn(input_size,hidden_size)self.params['b1'] = np.zeros(hidden_size)self.params['W2'] = weight_init_std * np.random.randn(hidden_size,output_size)self.params['b2'] = np.zeros(output_size)def predict(self,x):W1,W2 = self.params['W1'], self.params['W2']b1,b2 = self.params['b1'], self.params['b2']a1 = np.dot(x,W1) + b1z1 = sigmoid(a1)a2 = np.dot(z1,W2) + b2y = softmax(a2)return ydef loss(self, x, t):y = self.predict(x)return cross_entropy_error(y,t)def accuracy(self,x,t):y = self.predict(x)y = np.argmax(y, axis=1)t = np.argmax(t, axis=1)accuracy = np.sum(y==t)/float(x.shape[0])return accuracydef numerical_gradient(self,x,t):loss_W = lambda W:self.loss(x,t)grads = {}grads['W1'] = numerical_gradient(loss_W, self.params['W1']grads['b1'] = numerical_gradient(loss_W, self.params['b1']grads['W2'] = numerical_gradient(loss_W, self.params['W2']grads['b2'] = numerical_gradient(loss_W, self.params['b2']return grads

4.5.2  mini-batch学习

# train_neuralnet.py
import numpy as np
from dataset.mnist import load_mnist
from two_layer_net import TwoLayerNet(x_train,t_train),(x_test,t_test) = load_mnist(normalize=True,one_hot_label=True)
train_loss_list=[]#超参数
iters_num = 10000
train_size = x_train.shape[0]
batch_size = 100
learning_rate = 0.1
network = TwoLayerNet(input_size=784,hidden_size=50,output_size=10)
for i in range(iters_num):# 获取mini-batchbatch_mask = np.random.choice(train_size,batch_size)x_batch = x_train[batch_mask]t_batch = t_train[batch_mask]#计算梯度grad = network.numerical_gradient(x_batch,t_batch)#grad = network.gradient(x_batch,t_batch)  #高速版,下一章介绍反向传播法再说#更新参数for key in ('W1','b1','W2','b2'):network.params[key] -= learning_rate * grad[key]#记录学习过程loss = network.loss(x_batch,t_batch)train_loss_list.append(loss)

4.5.3 基于测试数据评价

epoch是一个单位,所有训练数据被使用一次时的更新次数。10000训练数据,mini-batch为100,共执行梯度下降法10000/100=100次,100次就是一个epoch

import numpy as np
from dataset.mnist import load_mnist
from two_layer_net import TwoLayerNet(x_train,t_train),(x_test,t_test) = load_mnist(normalize=True,one_hot_label=True)
train_loss_list = []
train_acc_list = []
test_acc_list = []
#平均每个epoch的重复次数
iter_per_epoch = max(train_size/batch_size,1)
#超参数
iters_num = 10000
batch_size=100
learning_rate=0.1
network=TwoLayerNet(input_size=784,hidden_size=50,output_size=10)
for i in range(iters_num):batch_mask = np.random.choice(train_size,batch_size)x_batch = x_train[batch_mask]t_batch = t_train[batch_mask]grad = network.numerical_gradient(x_batch,t_batch)for key in ('W1','b1','W2','b2'):network.params[key] -= learning_rate*grad[key]loss = network.loss(x_batch,t_batch)train_loss_list.append(loss)#计算每个epoch的识别精度if i % iter_per_epoch == 0:train_acc = network.accuracy(x_train,t_train)test_acc = network.accuracy(x_test,t_test)train_acc_list.append(train_acc)test_acc_list.append(test_acc)print("train acc,test acc | " + str(train_acc) + "," + str(test_acc))

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/409051.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Sass实现网页背景主题切换

Sass 实现网页背景主题切换 前言准备工作一、 简单的两种主题黑白切换1.定义主题2. 添加主题切换功能3. 修改 data-theme 属性 二、多种主题切换1. 定义主题2. 动态生成 CSS 变量1.遍历列表2.遍历映射3.高级用法 3. 设置默认主题4. 切换功能HTML 三、多种主题多种样式切换1. 定…

Java数组的定义与使用

目录 1. 数组的基本概念 1.1为什么要使用数组 1.2 什么是数组 1.3 数组的创建及初始化 1.3.1 数组的创建 1.3.2 数组的初始化 1. 动态初始化 2. 静态初始化 1.4 数组的使用 1.4.1 数组中元素访问 1.4.2 遍历数组 2. 数组是引用类型 2.1 基本类型变量与引用类型变量…

【C++从小白到大牛】C++智能指针的使用、原理和分类

目录 1、我们为什么需要智能指针? 2、内存泄露 2.1 什么是内存泄漏,内存泄漏的危害 2.2如何避免内存泄漏 总结一下: 3.智能指针的使用及原理 3.1 RAII 3.2关于深拷贝和浅拷贝更深层次的理解: 3.3 std::auto_ptr 3.4 std::unique_pt…

Springboot里集成Mybatis-plus、ClickHouse

🌹作者主页:青花锁 🌹简介:Java领域优质创作者🏆、Java微服务架构公号作者😄 🌹简历模板、学习资料、面试题库、技术互助 🌹文末获取联系方式 📝 Springboot里集成Mybati…

Overleaf参考文献由 BibTex 转 \bibitem 格式

目录 Overleaf参考文献由 BibTex 转 \bibitem 格式一、获取引用论文的BibTex二、编写引用论文对应的bib文件三、编写生成bibitem的tex文件四、转化bibitem格式 参考资料 Overleaf参考文献由 BibTex 转 \bibitem 格式 一、获取引用论文的BibTex 搜索论文引用点击BibTex 跳转出…

怎样快速搭建 Linux 虚拟机呢?(vagrant 篇)

作为一名Coder(程序员或码农),供职于中小型互联网公司,而你恰恰偏向于服务端,那么,产品部署在生产环境的艰巨任务,便毫无疑问的落在你身上了。 只有大厂(大型互联网)企业…

Ps:首选项 - 界面

Ps菜单:编辑/首选项 Edit/Preferences 快捷键:Ctrl K Photoshop 首选项中的“界面” Interface选项卡可以定制 Photoshop 的界面外观和行为,从而创建一个最适合自己工作习惯和需求的工作环境。这些设置有助于提高工作效率,减轻眼…

Simple RPC - 07 从零开始设计一个服务端(下)_RPC服务的实现

文章目录 PreRPC服务实现服务注册请求处理 设计: 请求分发机制 Pre Simple RPC - 01 框架原理及总体架构初探 Simple RPC - 02 通用高性能序列化和反序列化设计与实现 Simple RPC - 03 借助Netty实现异步网络通信 Simple RPC - 04 从零开始设计一个客户端&#…

# 利刃出鞘_Tomcat 核心原理解析(九)-- Tomcat 安全

利刃出鞘_Tomcat 核心原理解析(九)-- Tomcat 安全 一、Tomcat专题 - Tomcat安全 - 配置安全 1、 删除 tomcat 的 webapps 目录下的所有文件,禁用 tomcat 管理界面. 如下目录均可删除: D:\java-test\apache-tomcat-8.5.42-wind…

数据结构-KMP算法

先解决 前缀与后缀串的最长匹配长度信息(前缀或后缀都不能取整体)。如下 位置6的前缀最长串就是abcab(不能取全部,即不能为abcabc) 位置6的后缀最长串就是bcabc(不能取全部,即不能为abcabc)

[Linux#47][网络] 网络协议 | TCP/IP模型 | 以太网通信

目录 1.网络协议 2.协议分层 2.1 OSI七层模型 2.2TCP/IP五层(四层)模型 2.3 以太网通信 1.网络协议 "协议"本质就是一种约定 计算机之间的传输媒介是光信号和电信号. 通过 "频率" 和 "强弱" 来表示 0 和 1 这样的 信息. 要想传递各种不同…

HTML+CSS浮动和清除浮动的效果及其应用场景举例

一、清除浮动的效果 解释 .container:用于展示浮动和清除浮动效果的容器,具有边框和背景色以便于区分。 .float-box:浮动元素,用不同的背景色标识。 .clearfix:使用伪元素清除浮动的类,应用于第二个容器。 …

IDEA 2024.2.0.2 使用 Jrebel and XRebel 热部署

安装 激活 工具网站中url和邮箱复制进去 设置 允许项目自动构建 允许开发过程中自动部署

python面向对象—封装、继承、多态

封装 ① 把现实世界中的主体中的属性和方法书写到类的里面的操作即为封装 ② 封装可以为属性和方法添加为私有权限,不能直接被外部访问 在面向对象代码中,我们可以把属性和方法分为两大类:公有(属性、方法)、私有&…

SQLSugar进阶使用:高级查询与性能优化

文章目录 前言一、高级查询1.查所有2.查询总数3.按条件查询4.动态OR查询5.查前几条6.设置新表名7.分页查询8.排序 OrderBy9.联表查询10.动态表达式11.原生 Sql 操作 ,Sql和存储过程 二、性能优化1.二级缓存2.批量操作3.异步操作4.分表组件,自动分表5.查询…

LCP:60 排列序列[leetcode-4]

LCP:60 排列序列 给出集合 [1,2,3,...,n],其所有元素共有 n! 种排列。 按大小顺序列出所有排列情况,并一一标记,当 n 3 时, 所有排列如下: "123""132""213""231""312"&quo…

09 复合查询

前面的查询都是对一张表进行查询,但这远远不够 基本查询回顾 查询工资高于500或岗位为MANAGER的雇员,同时还要满足他们的姓名首字母为大写的J select * from EMP where (sal>500 or job‘MANAGER’) and ename like ‘J%’; 按照部门号升序而雇员的…

【git】git进阶-blame/stash单个文件/rebase和merge/cherry-pick命令/reflog和log

文章目录 git blame查看单个文件修改历史git stash单个文件git rebase命令git rebase和git merge区别git cherry-pick命令git reflog和git log区别 git blame查看单个文件修改历史 git blame:查看文件中每行最后的修改作者 git blame your_filegit log和git show结合…

基本数据类型及命令

String String 是Redis最基本的类型,Redis所有的数据结构都是以唯一的key字符串作为名称,然后通过这个唯一的key值获取相应的value数据。不同的类型的数据结构差异就在于value的结构不同。 String类型是二进制安全的。意思是string可以包含任何数据&…

requests库

一、pycharm导入requests库 在终端下输入pip install requests 按回车即可导入。 如果使用pip list 可以查到requests库即导入成功。 二、requsets的get请求 url为我们要请求的网址,headers用于伪造请求头,有的网址拒绝爬虫访问。 # # GET # import r…