用TensorFlow实现线性回归

说明

本文采用TensorFlow框架进行讲解,虽然之前的文章都采用mxnet,但是我发现tensorflow提供了免费的gpu可供使用,所以果断开始改为tensorflow,若要实现文章代码,可以使用colaboratory进行运行,当然,如果您已经安装了tensorflow,可以采用python直接运行。

贡献

学习时采取动手学深度学习第二版作为教材,但由于本书通过引入d2l(著者自写库)进行深度学习,我希望将d2l的影响去掉,即不使用d2l,使用tensorflow,这一点通过查询GitHub中d2l库提供的相关函数尝试进行实现。

如果本系列文章具有良好表现,将译为英文版上传至Github。

预备知识

学习本篇文章之前,您最好具有以下基础知识:

  1. 线性回归的基础知识
  2. python的基础知识

基本原理 

使用一个仿射变换,通过y=wx+b的模型来对数据进行预测(w和x均为矩阵,大小取决于输入规模),反向传播采用随机梯度下降对参数进行更新,参数包括w和b,即权重和偏差。

实现过程

生成数据集

只需要引入tensorflow即可,synthetic_data()函数将初始化X和Y,即通过真实的权重和偏差值生成数据集。

import tensorflow as tfdef synthetic_data(w, b, num_examples):X = tf.zeros((num_examples, w.shape[0]))X += tf.random.normal(shape=X.shape)y = tf.matmul(X, tf.reshape(w, (-1, 1))) + by += tf.random.normal(shape=y.shape, stddev=0.01)y = tf.reshape(y, (-1, 1))return X, ytrue_w = tf.constant([2, -3.4])
true_b = 4.2
features, labels = synthetic_data(true_w, true_b, 1000)

读取数据集

加载刚刚生成的数据集,is_train表示是否进行打乱,默认对数据进行打乱处理,使用load_array函数加载数据集。

def load_array(data_arrays, batch_size, is_train=True):dataset = tf.data.Dataset.from_tensor_slices(data_arrays)if is_train:dataset = dataset.shuffle(buffer_size=1000)dataset = dataset.batch(batch_size)return datasetbatch_size = 10
data_iter = load_array((features, labels), batch_size)

定义模型

模型使用keras API实现,keras是tensorflow中机器学习相关的库。先使用Sequential类定义承载容器,之后添加一个单神经元的全连接层。在TensorFlow中,Sequential表示容器相关的类,layer表示层相关的类。线性回归只需要通过keras中的单神经元的全连接层即可实现,神经元的值即为输出结果。

net = tf.keras.Sequential()
net.add(tf.keras.layers.Dense(1))

示例的线性回归仅有一个输入X,实际在其他线性回归过程中,很有可能有多个x及其对应的w,但keras的代码均不会发生改变,因为keras的Dense类可以自动判断输入的个数。 

初始化模型参数 

stddev表示标准差,initializer生成一个标准差为1,均值为0的正态分布。在构建全连接层时,使用该正态分布进行初始化。

initializer = tf.initializers.RandomNormal(stddev=0.01)
net = tf.keras.Sequential()
net.add(tf.keras.layers.Dense(1, kernel_initializer=initializer))

定义损失函数和优化算法 

损失函数使用平方损失函数进行计算,训练时使用小批量随机梯度下降SGD方法进行训练,学习率为0.03。

loss = tf.keras.losses.MeanSquaredError()
trainer = tf.keras.optimizers.SGD(learning_rate=0.03)

训练

运行以下代码可以观察训练结果。运行轮次为3轮,每一轮对所有训练集数据进行学习。计算w和b的梯度值,使用梯度下降更新权重w和偏差b。每一轮输出损失函数的值,最终显示权重和偏差的估计误差。

num_epochs = 3
for epoch in range(num_epochs):for X, y in data_iter:with tf.GradientTape() as tape:l = loss(net(X, training=True), y)grads = tape.gradient(l, net.trainable_variables)trainer.apply_gradients(zip(grads, net.trainable_variables))l = loss(net(features), labels)print(f'epoch {epoch + 1}, loss {l:f}')
w = net.get_weights()[0]
print('w的估计误差:', true_w - tf.reshape(w, true_w.shape))
b = net.get_weights()[1]
print('b的估计误差:', true_b - b)

运行结果

epoch 1, loss 0.000194

epoch 2, loss 0.000091

epoch 3, loss 0.000091

w的估计误差: tf.Tensor([-0.00026917 0.00094557], shape=(2,), dtype=float32)

b的估计误差: [4.7683716e-06]

 改进尝试

  1. 更改SGD优化算法为Adam
  2. 更改MeanSquaredError为其他损失函数

对于上述改进,损失均有显著增加,表明原有方法已为最好方法。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/406675.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

外挂程序:增强点及辅助

1.关于前几篇介绍的外挂程序,SAP中的业务单据还是要区分具体的操作人员。如建立财务凭证,工号A,B,C使用相同的SAP账号,那就没办法知道是谁操作的了啊,所以sap的业务单据需要细分到具体人员的都要增强实现以下: 如生产工单: 具体的增强点: 2.辅助程序:SAP账号自动锁定功…

从新手到专家必读书籍:官方推荐.NET技术体系架构指南

前言 Microsoft 官方推荐了一系列有关 .NET 体系结构的指南,旨在帮助开发人员掌握最新的技术和最佳实践。这些资源覆盖了从微服务架构到云原生应用开发等多个主题,是开发高质量 .NET 应用程序不可或缺的参考资料。 通过这些指南,可以深入了…

瑞幸x《黑神话》周边秒空,联名营销真的是流量密码吗?

​8月19日,瑞幸上线了与国产3A游戏《黑神话:悟空》合作的联名活动,其中包括黑神话腾云美式咖啡及周边产品。很多人为了抢到联名的周边,一大早就在瑞幸卡点下单,更有一些网友早上6点多就在瑞幸门口“蹲点”,…

会话跟踪方案:Cookie Session Token

什么是会话技术? Cookie 以登录为例,用户在浏览器中将账号密码输入并勾选自动登录,浏览器发送请求,请求头中设置Cookie:userName:张三 ,password:1234aa ,若登录成功,服务器将这个cookie保存…

河南萌新联赛2024第(六)场:郑州大学(补题ABCDFGIL)

文章目录 河南萌新联赛2024第(六)场:郑州大学A 装备二选一(一)简单介绍:思路:代码: B 百变吗喽简单介绍:思路:代码: C 16进制世界简单介绍&#x…

【时时三省】(C语言基础)指针进阶2

山不在高,有仙则名。水不在深,有龙则灵。 ----CSDN 时时三省 数组指针 是一种指针-是指向数组的指针 整型指针-是指向整形的指针 字符指针-是指向字符的指针 什么叫做数组指针 上面的整形指针跟字符指针只需要&am…

【鸿蒙学习】HarmonyOS应用开发者高级认证 - 一次开发,多端部署

一、学习目的 掌握鸿蒙的核心概念和端云一体化开发、数据、网络、媒体、并发、分布式、多设备协同等关键技术能力,具备独立设计和开发鸿蒙应用能力。 二、总体介绍 HarmonyOS 系统面向多终端提供了“一次开发,多端部署”(后文中简称为“一…

日志审计-graylog ssh登录超过6次告警

Apt 设备通过UDP收集日志,在gray创建接收端口192.168.0.187:1514 1、ssh登录失败次数大于5次 ssh日志级别默认为INFO级别,通过系统rsyslog模块处理,日志默认存储在/var/log/auth.log。 将日志转发到graylog vim /etc/rsyslog.conf 文件末…

深入探讨SD NAND的SD模式与SPI模式初始化

在嵌入式系统和存储解决方案中,SD NAND的广泛应用是显而易见的。CS创世推出的SD NAND支持SD模式和SPI模式,这两种模式在功能和实现上各有优劣。在本文中,我们将深入探讨这两种模式的初始化过程,并比较它们在不同应用场景下的优劣&…

MySQL 配置免密码登陆(mysql_config_editor Configuration)

当使用mysql, mysqldump, mysqladmin等客户端连接MySQL数据库服务器时,需要提供用户凭证信息。你可以在每次连接时都输入连接信息(用户名/密码/地址/端口等)或者将用户信息保存在my.cnf配置文件的[client]模块。 第一种方式每次都输入用户密…

深度学习 --- VGG16各层feature map可视化(JupyterNotebook实战)

VGG16模块的可视化 VGG16简介: VGG是继AlexNet之后的后起之秀,相对于AlexNet他有如下特点: 1,更深的层数!相对于仅有8层的AlexNet而言,VGG把层数增加到了16和19层。 2,更小的卷积核!…

数据库MySQL多表设计、查询

目录 1.概述 2.一对多 3.一对一 4.多对多 5.多表查询 5.1内连接 5.2外连接 5.3子查询 1.概述 项目开发中,在进行数据库表结构设计时,会根据业务需求及业务模块之间的关系,分析并设计表结构,由于业务之间相互关联,所以各个…

网络编程TCP与UDP

TCP与UDP UDP头: 包括源端口、目的地端口、用户数据包长度,检验和 数据。 typedef struct _UDP_HEADER {unsigned short m_usSourPort;    // 源端口号16bitunsigned short m_usDestPort;    // 目的端口号16bitunsigned short m_usLen…

Docker!!!

⼀、Docker 1、Docker介绍.pdf 1、Docker 是什么? Docker 是⼀个开源的应⽤容器引擎,可以实现虚拟化,完全采⽤“沙盒”机制,容器之间不会存在任何接⼝。Docker 通过 Linux Container(容器)技术将任意类型…

Ubuntu 安装 mysql 与 远程连接配置

1、安装 mysql ubuntu 默认安装 8.0 版本: sudo apt install mysql-server安装过程中 提示 是否继续操作 y 即可 2、使用ubuntu 系统用户 root 直接进入 mysql 切换至 系统用户 su root 输入命令 可直接进入 mysql: mysql3、创建一个允许远程登录的用户 创建 …

《Python编程:从入门到实践》笔记(一)

一、字符串 1.修改字符串大小写 title()以首字母大写的方式显示每个单词,即将每个单词的首字母都改为大写,其他的改为小写。 upper()将字母都改为大写,lower()将字母都改为小写。 2.合并(拼接)字符串 Python使用加号()来合并字符串。这种合…

超容易出成果的方向:多模态医学图像处理!

哈喽朋友们,今天给大家推荐一个比较容易出成果的方向:多模态医学图像处理。 众所周知,多模态如今火的一塌糊涂,早就成了很多应用科学与AI结合的重要赛道,特别是在医学图像处理领域。 由此提出的多模态医学图像处理融合…

「Java 项目详解」API 文档搜索引擎(万字长文)

目录 运行效果 一、项目介绍 一)需求介绍 二)功能介绍 三)实现思路 四)项目目标 二、前期准备 一)了解正排索引 二)了解倒排索引 三)获取 Java API 开发文档 四)了解分词…

二叉树检验:算法详解

问题描述 /** 检查二叉树是否为有效的二叉搜索树有效的二叉搜索树满足左子树的节点值都小于根节点值,右子树的节点值都大于根节点值并且左右子树也必须是有效的二叉搜索树param root 二叉树的根节点return 如果二叉树是有效的二叉搜索树,则返回true&…

火绒使用详解 为什么选择火绒?使用了自定义规则及其高级功能的火绒,为什么能吊打卡巴斯基,360,瑞星,惠普联想戴尔的电脑管家等?

目录 前言 必看 为什么选择火绒? 使用了自定义规则及其高级功能的火绒,为什么能吊打卡巴斯基,360,瑞星,惠普联想戴尔的电脑管家等? 原因如下: 火绒的主要优势 1. 轻量化设计 2. 强大的自…