数值微分求梯度、计算图求梯度,实现单层线性回归 模型速度差异及损失率比对

文章目录

    • 简述
    • 测试结果
    • 完整代码

简述

先将前面两篇文章的代码重构一下,抽离共同函数到utils.py

重构后结构:
在这里插入图片描述

ComputationGraphLinearNet.py
使用计算图(forward、backward)求梯度构建的线性模型,代码不变;
代码:https://itsven.blog.csdn.net/article/details/141168288

NumericalGradientLinearNet.py:
使用数值微分求梯度构建的线性模型,代码不变;
代码:https://itsven.blog.csdn.net/article/details/141168156

utils.py
测试数据构建;

main.py
对上述代码整合运行;

测试结果

测试前提:
不考虑epoch、batch_size、learning_rate是否合理,只在这些参数相同情况下,比较数值微分求梯度、计算图求梯度两种方式构建的线性模型运算速度、运算结果的差异。

参数误差:

计算方式:相同参数的绝对值误差平均值,运行结果来看,误差很小。

for key in graphNet.params.keys():  diff = np.average(np.abs(graphNet.params[key] - numericalNet.params[key]))  params[key] = diff

在这里插入图片描述

时间差异:

epoch、batch_size=、learning_rate相同情况下,每次使用不同数据(数据大小5000),分别运行5次,明显计算图求梯度方法运算速度更快(时间:毫秒)。
在这里插入图片描述

完整代码

ComputationGraphLinearNet.pyNumericalGradientLinearNet.py 见:

使用计算图(forward、backward)求梯度构建的线性模型https://itsven.blog.csdn.net/article/details/141168288

使用数值微分求梯度构建的线性模型:https://itsven.blog.csdn.net/article/details/141168156

utils.py

import numpy as np  def build_data(weights, bias, num_examples):  x = np.random.randn(num_examples, len(weights))  y = x.dot(weights) + bias  # 给y加个噪声  y += np.random.rand(1)  return x, y  def data_iter(features, labels, batch_size):  num_examples = len(features)  # 按样本数量构造索引  indices = list(range(num_examples))  # 打乱索引数组  np.random.shuffle(indices)  for i in range(0, num_examples, batch_size):  batch_indices = np.array(indices[i:min(i + batch_size, num_examples)])  yield features[batch_indices], labels[batch_indices]

main.py

from ComputationGraphLinearNet import Network as GraphNet  
from NumericalGradientLinearNet import Network as NumericalNet  import utils  
import numpy as np  
import matplotlib.pyplot as plt  
import time  def numerical_test(x_train, y_train, batch_size, num_epochs):  numericalNet = NumericalNet(2, 1, 0.01)  loss_history = list()  for i in range(num_epochs):  for x, y in utils.data_iter(x_train, y_train, batch_size):  grads = numericalNet.numerical_gradient(x, y)  for key in grads:  numericalNet.params[key] -= learning_rate * grads[key]  running_loss = numericalNet.loss(x, y)  loss_history.append(running_loss)  print(f'最后一次损失值:{loss_history[-1]}')  print(f'预测参数: true_w1={numericalNet.params["w1"]}, true_b1={numericalNet.params["b1"]}')  return numericalNet  def graph_test(x_train, y_train, batch_size, num_epochs):  graphNet = GraphNet(2, 1, 0.01)  loss_history = list()  for i in range(num_epochs):  for x_batch, y_batch in utils.data_iter(x_train, y_train, batch_size):  grads = graphNet.gradient(x_batch, y_batch)  for key in grads:  graphNet.params[key] -= learning_rate * grads[key]  running_loss = graphNet.loss(x_batch, y_batch)  loss_history.append(running_loss)  print(f'最后一次损失值:{loss_history[-1]}')  print(f'预测参数: true_w1={graphNet.params["w1"]}, true_b1={graphNet.params["b1"]}')  return graphNet  if __name__ == '__main__':  test_num = 5  num_epochs = 2  batch_size = 50  learning_rate = 0.01  graph_time = list()  numerical_time = list()  error_list = list()  for i in range(test_num):  true_w1 = np.random.rand(2, 1)  true_b1 = np.random.rand(1)  x_train, y_train = utils.build_data(true_w1, true_b1, 5000)  print(f'\n----------------------------第{i+1}次')  print(f'第{i+1}次, 正确参数: true_w1={true_w1}, true_b1={true_b1}\n')  print("------------数值微分法:")  start = time.perf_counter()  numericalNet = numerical_test(x_train, y_train, batch_size, num_epochs)  end = time.perf_counter()  print(f"运行时间:{(end - start) * 1000}毫秒")  numerical_time.append((end - start) * 1000)  print("------------计算图法:")  start = time.perf_counter()  graphNet = graph_test(x_train, y_train, batch_size, num_epochs)  end = time.perf_counter()  print(f"运行时间:{(end - start) * 1000}毫秒")  graph_time.append((end - start) * 1000)  params = {}  for key in graphNet.params.keys():  diff = np.average(np.abs(graphNet.params[key] - numericalNet.params[key]))  params[key] = diff  error_list.append(params)  plt.title("数值微分、计算图两种求导速度差异", fontproperties="STSong")  plt.xlabel("nums")  plt.ylabel("time")  plt.xticks(range(0, test_num))  plt.plot(graph_time, linestyle='dotted', marker='o', label='graph_time')  plt.plot(numerical_time,  linestyle='dotted', marker='*', label='numerical_time')  plt.legend(loc='upper right')  plt.show()  for i in range(len(error_list)):  print(f'第{i+1}次运行, 各参数绝对误差的平均值{error_list[i]}', end="\n")

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/399617.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

分库分表的使用场景和中间件

文章目录 一、为什么要分库分表?分库分表的使用场景?二、分库分表常用中间件1、Cobar2、TDDL3、Atlas4、Sharding-jdbc5、Mycat6、总结 一、为什么要分库分表?分库分表的使用场景? 场景1:注册用户就 20 万&#xff0c…

<数据集>集装箱缺陷识别数据集<目标检测>

数据集格式:VOCYOLO格式 图片数量:3793张 标注数量(xml文件个数):3793 标注数量(txt文件个数):3793 标注类别数:4 标注类别名称:[DAMAGE - DEFRAME, DENT, DAMAGE - RUST, DAMAGE - HOLE] 序号类别名…

飞睿智能8km无人机WiFi图传模块,高清、稳定、超远距!实时传输新高度

在数字化飞速发展的今天,无人机已经从一个遥不可及的科幻概念,变成了我们日常生活中的得力助手。无论是航拍美景、农业植保,还是紧急救援、物流配送,无人机都展现出了其独特的优势。而在这背后,一个至关重要的技术支撑…

ThinkPHP教程

thinkPHP笔记 01. phpEnv配置安装 主讲老师 - 李炎恢 1. 学习基础 ThinkPHP8.x: 前端基础:HTML5/CSS(必须)、JavaScript(可选、但推荐有);后端基础:PHP基础,版本不限,但不能太老,至少PHP5.4以上语法,TP8是兼容PHP8.x的;数据库基础:MySQL数据库,掌握了常规的SQL…

再谈表的约束

文章目录 自增长唯一键外键 自增长 auto_increment:当对应的字段,不给值,会自动的被系统触发,系统会从当前字段中已经有的最大值1操作,得到一个新的不同的值。通常和主键搭配使用,作为逻辑主键。 自增长的…

面向服务架构(SOA)介绍

在汽车电子电气架构还处于分布式时代时,汽车软件的开发方式主要是采用嵌入式软件进行开发,而随着汽车智能化程度的加深,更加复杂且多样的功能需求让汽车软件在复杂度上再上一层。在整车的自动驾驶方面,由于未来高阶自动驾驶能力的…

《Unity3D网络游戏实战》正确收发数据流

TCP数据流 系统缓冲区 当收到对端数据时,操作系统会将数据存入到Socket的接收缓冲区中 操作系统层面上的缓冲区完全由操作系统操作,程序并不能直接操作它们,只能通过socket.Receive、socket.Send等方法来间接操作。当系统的接收缓冲区为空&…

RCE绕过技巧

目录 EVAL长度限制突破技巧 1.使用反引号 2.file_put_contents写入文件 3.php5.6变长参数usort回调后门 命令长度限制突破技巧 1.拼接文件名 无字母数字的webshell命令执行 1.取反码 2.上传临时文件 EVAL长度限制突破技巧 分析代码:首先传递一个param参数&…

OceanBase V4.3 列存引擎之场景问题汇总

在OceanBase 4.3版本发布后(OceanBase社区版 V4.3 免费下载),其新增的列存引擎,及行列混存一体化的能力,可以支持秒级实时分析,引发了用户、开发者及业界人士的广泛讨论。本文选取了这些讨论中较为典型的一…

企业应该如何准备 EcoVadis 审核?

企业准备 EcoVadis 审核可以参考以下步骤: 注册:在网上注册并提供公司的相关信息,包括法律实体名称、国家和地区、企业规模和行业等。如果是受客户邀请参加评估,需按照邀请邮件中的链接进行注册,并确保客户能随时获知评…

安卓默认混淆规则文件的区别

在 Android 项目中,ProGuard 是一个优化和混淆代码的工具。proguard-android-optimize.txt 和 proguard-android.txt 是两个用于配置 ProGuard 的默认规则文件,如图下 它们有以下区别: proguard-android-optimize.txt: 优化:这个配…

Django中事务的基本使用

1. Django事务处理 事务(Transaction): 是一种将多个数据库操作组合成一个单一工作单元的机制. 如果事务中的所有操作都成功完成, 则这些更改将永久保存到数据库中. 如果事务中的某个操作失败, 则整个事务将回滚到事务开始前的状态, 所有的更改都不会被保存到数据库中. 这对于…

系统编程 day10 进程2

进程创建之后: 1.任务-----子进程与父进程干的活差不多 2.父进程创建出子进程之后,子进程做的与父进程完全不同 shell程序-----bash----- 以上为进程运行的过程中,典型的两种应用场景 能够改变子进程的执行效果的函数是exec函数族 l和v&a…

【网盘系统3.0版本】百度云盘手动cookie获取,添加到扫码系统管理平台。

一.获取cookie步骤 1.谷歌浏览器选择开发者模式。 2.选择网路,过滤接口main 3.选择request head,cookie列表里面可查看二.添加到管理平台。 1.登录管理平台,输入账户和密码 2.选择账户设置,添加cookie。 4.复制卡密链接&#xf…

LVS实验的三模式总结

文章目录 LVS的概念叙述NAT工作模式实战案例**思想:**NAT工作模式的优点NAT工作模式的缺点 NAT工作模式的应用场景大致配置 route:打开路由内核功能 部署DR模式集群案例工作思想:大致工作图如下思路模型 具体配置与事实步骤补充 防火墙标签解…

c++编程(20)——类与对象(6)继承

欢迎来到博主的专栏——c编程 博主ID:代码小豪 文章目录 继承继承与权限访问 基类和派生类基类和派生类的赋值兼容转换基类与派生类的类作用域派生类与基类的构造函数基类与派生类拷贝构造函数 继承与静态成员final关键字 面向对象编程的核心思想是封装、继承和多态…

计算机网络408考研 2021

2021 计算机网络408考研2021年真题解析_哔哩哔哩_bilibili 1 1 11 1 1 11

解决No module named ‘tensorflow‘

import tensorflow as tf ModuleNotFoundError: No module named tensorflow 安装合适的tensorflow版本 先查看自己的python版本 或者输入指令;python --version 安装兼容的tensorflow版本,安装指定版本的tensorflow pip install tensorflow-gpu2.3.0…

Qt | QSQLite内存数据库增删改查

点击上方"蓝字"关注我们 01、演示 参数随便设置 查询 修改 右键菜单是重点 手动提交,点击Submit All

【Docker】基础篇

系列综述: 💞目的:本系列是个人整理为了云计算学习的,整理期间苛求每个知识点,平衡理解简易度与深入程度。 🥰来源:材料主要源于–Docker视频教程从入门到进阶,docker视频教程详解–…