3 Tensorflow构建模型详解

上一篇:2 用TensorFlow构建一个简单的神经网络-CSDN博客

本篇目标是介绍如何构建一个简单的线性回归模型,要点如下:

  • 了解神经网络原理
  • 构建模型的一般步骤
  • 模型重要参数介绍


1、神经网络概念

接上一篇,用tensorflow写了一个猜测西瓜价格的简单模型,理解代码前先了解下什么是神经网络。

下面是百度AI对神经网络的解释:

神经网络是一种运算模型,由大量的节点(或称神经元)之间相互联接构成,每个节点代表一种特定的输出函数,称为激励函数(activation function)。每两个节点间的连接都代表一个对于通过该连接信号的加权值,称之为权重,这相当于人工神经网络的记忆。网络的输出则依网络的连接方式,权重值和激励函数的不同而不同。而网络自身通常都是对自然界某种算法或者函数的逼近,也可能是对一种逻辑策略的表达。
神经网络是一种广泛并行互连的网络,它的组织能够模拟生物神经系统对真实世界物体所做出的交互反应。

首先我们要了解下密集层(也叫全连接层),密集层是一个深度连接的神经网络层,在神经网络中指的是每个神经元都与前一层的所有神经元相连的层。

在上一篇我们创建了预测价格模型,代码为:

model = tf.keras.Sequential([
    tf.keras.layers.Dense(1, input_shape=[1])
])

其中Sequential是顺序的意思,Dense就是密集层。

看文字有点抽象,举个例子,如下图所示:神经元a1与所有输入层数据相连(X1,X2,X3),其他神经元也一样都与上一层神经元相连,这样形成的神经网络就是密集层。

它们之间的数学关系为:

某个神经元是由连接的上一层神经元分别乘上权重(w),再加上偏差(b)得到,例如计算a1:

权重w的数字下标可以按照顺序命名,比如第一个神经元计算的权重可以为w11、w12……,第二个神经元计算的权重可以为w21、w22……

a2、a3计算以此类推。

了解这些基本的原理后,我们就开始创建一个简单的费用预测模型。

2、西瓜费用预测模型详解

代码如下:

import numpy as np
import tensorflow as tf# 西瓜的重量
weight = np.array([1, 3, 4, 5, 6, 8], dtype=float)# 对应的费用
total_cost = np.array([1.7, 4.1, 5.3, 6.5, 7.7, 10.1], dtype=float)model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=[1])
])model.compile(loss=tf.losses.mean_squared_error, optimizer='SGD')history = model.fit(weight, total_cost, epochs=500)# 训练完成后,预测10斤西瓜的总费用
print(model.predict([10]))

上一篇西瓜费用计算公式 :费用=1.2元/斤*重量+0.5元

即:y=1.2x+0.5

这是一个一元线性回归问题,只有一个自变量x和一个因变量y,机器学习要推算出权重w=1.2, 偏差b=0.5,才能准确预测费用。

具体流程如下:

(1)训练数据准备

西瓜重量 weight=[1, 3, 4, 5, 6, 8]

对应的费用 total_cost=[1.7, 4.1, 5.3, 6.5, 7.7, 10.1]

(2)构建模型

model = tf.keras.Sequential([
    tf.keras.layers.Dense(1, input_shape=[1])
])

  • tf.keras.layers.Dense(1, input_shape=[1]),参数1表示1个神经元,我们只要预测费用y,所以输出层只要一个神经元就可以了(注意:神经元不用包含输入层)。
  • input_shape=[1],表示输入数据的形状为单元素列表,即每个输入数据只有一个值。因为只有一个变量x(西瓜的重量),所以此处输入形状是[1]

该模型的示意图:

可以用model.summary()查看模型摘要,代码如下:

import numpy as np
import tensorflow as tf# 西瓜的重量
weight = np.array([1, 3, 4, 5, 6, 8], dtype=float)# 对应的费用
total_cost = np.array([1.7, 4.1, 5.3, 6.5, 7.7, 10.1], dtype=float)model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=[1])
])# 查看模型摘要
model.summary()

运行结果:

可以看到可训练参数有2个,即公式中的w1和b1。

(3)设置损失函数和优化器
model.compile(loss=tf.losses.mean_squared_error, optimizer='SGD')
  • mean_squared_error是均方误差,指的是预测值与真实值差值的平方然后求和再平均。公式为:

                    MSE=1/n Σ(P-G)^2 (P为预测值,G为真实值)

  • SGD即随机梯度下降(Stochastic Gradient Descent),是一种迭代优化算法。

(4)训练模型
history = model.fit(weight, total_cost, epochs=500)
  • 设置训练数据的特征和标签,在上述代码中分别是西瓜的重量和费用:weight、total_cost
  • 设置训练轮次epochs=500,1个epochs是指使用所有样本训练一次。

(5) 查看训练结果

看下面的训练过程,第8个epoch的时候损失值loss已经很小了,训练轮次不需要设置到500就可以有很好的预测效果了。

刚开始loss很高,使用优化算法慢慢调整了权重,loss值可以很好地衡量我们的模型有多好。

我们把epoch的值调小,看看程序猜测的权重(w)和偏差(b)是多少,以及loss值的计算。

 

代码改动如下:

  •  epochs=5
  • 用model.get_weights()获取程序猜测的权重数据
import numpy as np
import tensorflow as tf# 西瓜的重量
weight = np.array([1, 3, 4, 5, 6, 8], dtype=float)# 对应的费用
total_cost = np.array([1.7, 4.1, 5.3, 6.5, 7.7, 10.1], dtype=float)model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=[1])
])model.compile(loss=tf.losses.mean_squared_error, optimizer='SGD')history = model.fit(weight, total_cost, epochs=5)# 获取权重数据
w = model.get_weights()[0]
b = model.get_weights()[1]print('w:')
print(w)
print('b: ')
print(b)# 训练完成后,预测10斤西瓜的总费用
print(model.predict([10]))

运行结果:

训练了5个epoch后,程序猜测w是1.1807659,b为0.33192113

            y=wx+b=1.1807659*10+0.33192113=12.139581

所以预测10斤西瓜的总费用是12.139581

                 

3、创建更复杂一点的模型

现实生活中我们要预测的东西影响因素可能有很多个,如房价预测,房价可能受到房屋面积、房间数量等等因素影响。思考一下,下面的神经网络图创建模型时要如何设置参数呢?

model = tf.keras.Sequential([tf.keras.layers.Dense(2, input_shape=[3]),tf.keras.layers.Dense(1)
])
  • 输入层有3个变量,input_shape=[3]
  • 隐藏层有2个神经元,所以 tf.keras.layers.Dense(2, input_shape=[3]) 的units设为2
  • 输出层只有1个神经元,所以 tf.keras.layers.Dense(1) 的units设为1
  • tf.keras.Sequential的‘Sequential’是顺序的意思,添加的这些layers就按顺序堆叠

         

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/176971.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

微信小程序:自定义组件传值——获取手机验证码

一:遇到的问题 通过自己自定义的组件编写的表单,发现传值不了,点击后收到的值为空。 二:创建组件 先在根目录创建components文件夹,创建img-verify文件夹(这个是我取的组件名字),在…

什么是 DevOps

DevOps是一套融合软件开发(Dev)和 IT 运营(Ops)的实践,旨在缩短应用程序开发周期并确保以高软件质量持续交付,通过采用 DevOps 实践,您可以帮助组织更可靠、更快速、更高效地交付软件。 什么是…

一百九十八、Java——IDEA项目中有参构造、无参构造等快捷键(持续梳理中)

一、目的 由于IDEA项目中有很多快捷键,可以很好的提高开发效率,因此整理一下 二、快捷键 (一)快捷键生成public static void main(String[] args) {} 快捷键:psvm (二)快捷键在test中创建cn…

MacOS安装git

文章目录 通过Xcode Command Lines Tool安装(推荐)终端直接运行git命令根据流程安装先安装Command Lines Tool后再安装git 官网下载二进制文件进行安装官方国外源下载二进制文件(不推荐)国内镜像下载二进制文件(推荐)安装git 通过Xcode Command Lines Tool安装(推荐) 简单来讲C…

ubuntu(18.04)中架设HiGlass docker镜像服务,已尝试mcool、bedpe、wig格式文件

前言 使用到的软件 docker 文档 : https://www.docker.com/ HiGlass 文档:http://docs.higlass.io/higlass_docker.html#running-locally https://github.com/higlass/higlass-dockerhiglass-docker 地址:https://github.com/higla…

17.基干模型Swin-Transformer解读

文章目录 SWin-Transformer解读1.基础介绍关于Shifted Window based Self-Attention相对位置偏置网络整体结构和层级特征欢迎访问个人网络日志🌹🌹知行空间🌹🌹 SWin-Transformer解读 1.基础介绍 Swin-Transformer是2021年03月微软亚洲研究院提交的论文中提出的,比V…

Arduino开发

文章目录 Arduino IDE 的使用1. 使能编译以及烧录的LOG:2. 下载配置3. 下载 Arduino指令程序下载步骤通过下载器下载通过串口下载 关于Arduino IDE工程生成的二进制文件对比Tools-->burn bootloader 和 ArduinoISP例程 的区别自带例程 Arduino IDE 的使用 1. 使…

【发表案例】2区正刊,网络安全、智能系统领域,2个月3天录用,11天见刊,16天检索!

计算机类SCIE 【期刊简介】IF:4.0-5.0,JCR2区,中科院3区 【检索情况】SCIE 在检,正刊 【征稿领域】提高安全性和隐私性的边缘/云的智能方法的研究,如数字孪生等 【截稿日期】2023.11.30 录用案例:2个月…

SpringBoot / Vue 对SSE的基本使用

一、SSE是什么? SSE技术是基于单工通信模式,只是单纯的客户端向服务端发送请求,服务端不会主动发送给客户端。服务端采取的策略是抓住这个请求不放,等数据更新的时候才返回给客户端,当客户端接收到消息后,再…

恒驰服务 | 华为云数据使能专家服务offering之数仓建设

恒驰大数据服务主要针对客户在进行智能数据迁移的过程中,存在业务停机、数据丢失、迁移周期紧张、运维成本高等问题,通过为客户提供迁移调研、方案设计、迁移实施、迁移验收等服务内容,支撑客户实现快速稳定上云,有效降低时间成本…

IntelliJ IDEA快捷键sout不生效

1.刚下载完idea编辑器时,可能idea里的快捷键打印不生效。这时你打开settings 2.点击settings–>Live Templates–>找到Java这个选项,点击展开 3.找到sout 4.点击全选,保存退出就可以了 5.最后大功告成!

物联网整体框架有哪些层面?

物联网是当前非常火热的话题,各个行业对物联网的关注和投入力度也很大,一些互联网巨头都在紧锣密鼓的布局物联网产业,抢占市场先机。 物联网的整体构架大致可以分为以下四个层面: 1.感知识别层 感知层是物联网整体架构的基础&…

基于springboot实现学生就业管理系统项目【项目源码+论文说明】

基于springboot实现学生就业管理系统演示 摘要 随着信息化时代的到来,管理系统都趋向于智能化、系统化,学生就业管理系统也不例外,但目前国内仍都使用人工管理,市场规模越来越大,同时信息量也越来越庞大,人…

【安装】自建Rustdesk Server

文章目录 RustDesk说明RustDesk优点RustDesk相关链接非Docker基于CentOSRustDesk默认程序占用端口说明 启动 hbbr 是中继服务器启动 hbbs 是ID服务器客户端配置编写启动脚本hbbr、hbbs命令详细说明 RustDesk说明 RustDesk优点 自建服务端。搭建在自己的云服务器就相当于独享高…

antv/g6 节点、及自定义节点

节点 AntV G6 中内置节点支持的通用属性通常包括以下几个: id:节点的唯一标识符。 x 和 y:节点的位置坐标。 label:节点的标签文本。 style:节点的样式,用于设置节点的外观,可以包括填充颜色…

windows系统卸载mysql

1. win r 输入 control 打开控制面板 2.搜索mysql,删除搜索内容 3.删除相应路径下的mysql文件夹C:\Program Files C:\ProgramData 4.删除注册表,win r 输入 regedit 打开注册表 5.搜索MySql 删除掉 完成

高等数学啃书汇总重难点(十)重积分

方法性的一章,看着唬人,实际上定积分学得熟练,就可以很轻松的掌握这一章的内容,重点在于计算各种坐标下的二重或三重积分~ 1.几何意义 2.定义 3.性质 4.直角坐标计算二重积分 5.极坐标计算二重积分 6.三重积分 7.重积分的应用

国际物流常见风险如何规避_箱讯科技

外贸物流是国际贸易的重要环节,其管理和效率的高低直接影响着贸易的成本和效益。因此,外贸企业应该重视物流的组织和管理,提高物流运作的效率。 国际物流基础知识 01什么是“双清包税”和“双清不包税” 双清包税上门又叫双清包税到门&…

论文翻译-ImageNet Classification with Deep Convolutional Neural Networks

[toc] 前言 AlexNet是是引领深度学习浪潮的开山之作,即使是我们现在进入了ChatGPT时代,这篇论文依然具有一定的借鉴意义。AlexNet的作者是多伦多大学的Alex Krizhevsky等人。Alex Krizhevsky是Hinton的学生。网上流行说 Hinton、LeCun和Bengio是神经网…

京东h5st逆向 h5st代码之拓展

知识点 node安装模块 crypto-js JavaScript 中的加密库 则更偏向于消息摘要算法、对称加密和简单的哈希函数,支持 AES、DES、SHA-1、HMAC 等诸多算法,适用于对客户端本地存储的数据进行加密、散列或签名处理等场景 axios 一旦安装成功,我们就…