【单层神经网络】基于MXNet的线性回归实现(底层实现)

写在前面

  1. 基于亚马逊的MXNet库
  2. 本专栏是对李沐博士的《动手学深度学习》的笔记,仅用于分享个人学习思考
  3. 以下是本专栏所需的环境(放进一个environment.yml,然后用conda虚拟环境统一配置即可)
  4. 刚开始先从普通的寻优算法开始,熟悉一下学习训练过程
  5. 下面将使用梯度下降法寻优,但这大概只能是局部最优,它并不是一个十分优秀的寻优算法
name: gluon
dependencies:
- python=3.6
- pip:- mxnet==1.5.0- d2lzh==1.0.0- jupyter==1.0.0- matplotlib==2.2.2- pandas==0.23.4

整体流程

  1. 生成训练数据集(实际工程中,需要从实际对象身上采集数据)
  2. 确定模型及其参数(输入输出个数、阶次,偏置等)
  3. 确定学习方式(损失函数、优化算法,学习率,训练次数,终止条件等)
  4. 读取数据集(不同的读取方式会影响最终的训练效果)
  5. 训练模型

完整程序及注释

from IPython import display
from matplotlib import pyplot as plt
from mxnet import autograd, nd
import random'''
获取(生成)训练集
'''
input_num = 2				# 输入个数
examples_num = 1000			# 生成样本个数
# 确定真实模型参数
real_W = [10.9, -8.7]		
real_bias = 6.5	features = nd.random.normal(scale=1, shape=(examples_num, input_num))       # 标准差=1,均值缺省=0
labels = real_W[0]*features[:,0] + real_W[1]*features[:,1] + real_bias		# 根据特征和参数生成对应标签
labels_noise = labels + nd.random.normal(scale=0.1, shape=labels.shape)		# 为标签附加噪声,模拟真实情况# 绘制标签和特征的散点图(矢量图)
# def use_svg_display():
#     display.set_matplotlib_formats('svg')# def set_figure_size(figsize=(3.5,2.5)):
#     use_svg_display()
#     plt.rcParams['figure.figsize'] = figsize# set_figure_size()
# plt.scatter(features[:,0].asnumpy(), labels_noise.asnumpy(), 1)
# plt.scatter(features[:,1].asnumpy(), labels_noise.asnumpy(), 1)
# plt.show()# 创建一个迭代器(确定从数据集获取数据的方式)
def data_iter(batch_size, features, labels):num = len(features)indices = list(range(num))                                  # 生成索引数组random.shuffle(indices)                                     # 打乱indices# 该遍历方式同时确保了随机采样和无遗漏for i in range(0, num, batch_size):j = nd.array(indices[i: min(i+batch_size, num)])        # 对indices从i开始取,取batch_size个样本,并转换为列表yield features.take(j), labels.take(j)                  # take方法使用索引数组,从features和labels提取所需数据"""
训练的基础准备
"""
# 声明训练变量,并赋高斯随机初始值
w = nd.random.normal(scale=0.01, shape=(input_num))
b = nd.zeros(shape=(1,))
# b = nd.zeros(1)       # 不同写法,等价于上面的
w.attach_grad()         # 为需要迭代的参数申请求梯度空间
b.attach_grad()# 定义模型
def linreg(X, w, b):return nd.dot(X,w)+b# 定义损失函数
def squared_loss(y_hat, y):return (y_hat - y.reshape(y_hat.shape)) **2 /2# 定义寻优算法
def sgd(params, learning_rate, batch_size):for param in params:# 新参数 = 原参数 - 学习率*当前批量的参数梯度/当前批量的大小param[:] = param - learning_rate * param.grad / batch_size# 确定超参数和学习方式
lr = 0.03
num_iterations = 5
net = linreg				# 目标模型
loss = squared_loss			# 代价函数(损失函数)
batch_size = 10				# 每次随机小批量的大小'''
开始训练
'''
for iteration in range(num_iterations):		# 确定迭代次数for x, y in data_iter(batch_size, features, labels):with autograd.record():l = loss(net(x,w,b), y)			# 求当前小批量的总损失l.backward()						# 求梯度sgd([w,b], lr, batch_size)			# 梯度更新参数train_l = loss(net(features,w,b), labels)print("iteration %d, loss %f" % (iteration+1, train_l.mean().asnumpy()))
# 打印比较真实参数和训练得到的参数
print("real_w " + str(real_W) + "\n train_w " + str(w))
print("real_w " + str(real_bias) + "\n train_b " + str(b))

具体程序解释

param[:] = param - learning_rate * param.grad / batch_size
将batch_size与参数调整相关联的原因,是为了使得每次更新的步长不受批次大小的影响
具体来说,当计算一批数据的损失函数的梯度时,实际上是将这批数据中每个样本对损失函数的贡献累加起来。这意味着如果批次较大,梯度的模也会相应增大
故更新权值时,使用的是数据集的平均梯度,而不是总和

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/11714.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Shadow DOM举例

这东西具有隔离效果&#xff0c;对于一些插件需要append一些div倒是不错的选择 <!DOCTYPE html> <html lang"zh-CN"> <head> <meta charset"utf-8"> <title>演示例子</title> </head> <body> <style&g…

万字长文深入浅出负载均衡器

前言 本篇博客主要分享Load Balancing&#xff08;负载均衡&#xff09;&#xff0c;将从以下方面循序渐进地全面展开阐述&#xff1a; 介绍什么是负载均衡介绍常见的负载均衡算法 负载均衡简介 初识负载均衡 负载均衡是系统设计中的一个关键组成部分&#xff0c;它有助于…

ChatGPT-4o和ChatGPT-4o mini的差异点

在人工智能领域&#xff0c;OpenAI再次引领创新潮流&#xff0c;近日正式发布了其最新模型——ChatGPT-4o及其经济实惠的小型版本ChatGPT-4o Mini。这两款模型虽同属于ChatGPT系列&#xff0c;但在性能、应用场景及成本上展现出显著的差异。本文将通过图文并茂的方式&#xff0…

双指针算法思想——OJ例题扩展算法解析思路

大家好&#xff01;上一期我发布了关于双指针的OJ平台上的典型例题思路解析&#xff0c;基于上一期的内容&#xff0c;我们这一期从其中内容扩展出来相似例题进行剖析和运用&#xff0c;一起来试一下吧&#xff01; 目录 一、 基于移动零的举一反三 题一&#xff1a;27. 移除…

WSL2中安装的ubuntu开启与关闭探讨

1. PC开机后&#xff0c;查询wsl状态 在cmd或者powersell中输入 wsl -l -vNAME STATE VERSION * Ubuntu Stopped 22. 从windows访问WSL2 wsl -l -vNAME STATE VERSION * Ubuntu Stopped 23. 在ubuntu中打开一个工作区后…

git笔记-简单入门

git笔记 git是一个分布式版本控制系统&#xff0c;它的优点有哪些呢&#xff1f;分为以下几个部分 与集中式的版本控制系统比起来&#xff0c;不用担心单点故障问题&#xff0c;只需要互相同步一下进度即可。支持离线编辑&#xff0c;每一个人都有一个完整的版本库。跨平台支持…

Chapter2 Amplifiers, Source followers Cascodes

Chapter2 Amplifiers, Source followers & Cascodes MOS单管根据输入输出, 可分为CS放大器, source follower和cascode 三种结构. Single-transistor amplifiers 这一章学习模拟电路基本单元-单管放大器 单管运放由Common-Source加上DC电流源组成. Avgm*Rds, gm和rds和…

LabVIEW透镜多参数自动检测系统

在现代制造业中&#xff0c;提升产品质量检测的自动化水平是提高生产效率和准确性的关键。本文介绍了一个基于LabVIEW的透镜多参数自动检测系统&#xff0c;该系统能够在单一工位上完成透镜的多项质量参数检测&#xff0c;并实现透镜的自动搬运与分选&#xff0c;极大地提升了检…

K8S集群部署--亲测好用

最近在自学K8S&#xff0c;花了三天最后终于成功部署一套K8S Cluster集群&#xff08;masternode1node2&#xff09; 在这里先分享一下具体的步骤&#xff0c;后续再更新其他的内容&#xff1a;例如部署期间遇到的问题及其解决办法。 部署步骤是英文写的&#xff0c;最近想练…

PHP实现混合加密方式,提高加密的安全性(代码解密)

代码1&#xff1a; <?php // 需要加密的内容 $plaintext 授权服务器拒绝连接;// 1. AES加密部分 $aesKey openssl_random_pseudo_bytes(32); // 生成256位AES密钥 $iv openssl_random_pseudo_bytes(16); // 生成128位IV// AES加密&#xff08;CBC模式&#xff09…

Linux - 进程间通信(3)

目录 3、解决遗留BUG -- 边关闭信道边回收进程 1&#xff09;解决方案 2&#xff09;两种方法相比较 4、命名管道 1&#xff09;理解命名管道 2&#xff09;创建命名管道 a. 命令行指令 b. 系统调用方法 3&#xff09;代码实现命名管道 构建类进行封装命名管道&#…

C语言 --- 分支

C语言 --- 分支 语句分支语句含义if...else语句单分支if语句语法形式 双分支 if-else 语句语法形式 悬空else含义问题描述 多分支 if-else 语句语法形式 switch...case语句含义语法形式 总结 &#x1f4bb;作者简介&#xff1a;曾与你一样迷茫&#xff0c;现以经验助你入门 C 语…

SSRF 漏洞利用 Redis 实战全解析:原理、攻击与防范

目录 前言 SSRF 漏洞深度剖析 Redis&#xff1a;强大的内存数据库 Redis 产生漏洞的原因 SSRF 漏洞利用 Redis 实战步骤 准备环境 下载安装 Redis 配置漏洞环境 启动 Redis 攻击机远程连接 Redis 利用 Redis 写 Webshell 防范措施 前言 在网络安全领域&#xff0…

Spring Boot - 数据库集成06 - 集成ElasticSearch

Spring boot 集成 ElasticSearch 文章目录 Spring boot 集成 ElasticSearch一&#xff1a;前置工作1&#xff1a;项目搭建和依赖导入2&#xff1a;客户端连接相关构建3&#xff1a;实体类相关注解配置说明 二&#xff1a;客户端client相关操作说明1&#xff1a;检索流程1.1&…

深度学习之“线性代数”

线性代数在深度学习中是解决多维数学对象计算问题的核心工具。这些数学对象包括标量、向量、矩阵和张量&#xff0c;借助它们可以高效地对数据进行操作和建模。以下将详细介绍这些数学对象及其在深度学习中的典型用途。 数学对象概述 标量 标量是最简单的数学对象&#xff0…

使用PyQt5绘制带有刻度的温度计控件

前言&#xff1a;进入学习Python开发上位机界面的第二阶段&#xff0c;学习如何开发自定义控件&#xff0c;从常用的控件入手学习&#xff0c;本期主要学习如何使用PyQt5绘制带有刻度的温度计控件。 1. 先找到一篇参考文章 参考文章&#xff1a;Qt编写自定义控件5-柱状温度计…

问deepseek,如何看待ai降低学习成本而导致软件开发岗位需求降低,和工资下降。 软件从业人员何去何从?

它给我的回答是这样的&#xff1a; 思考逻辑 嗯&#xff0c;用户问的是AI如何降低学习成本&#xff0c;进而导致软件开发岗位需求减少和工资下降&#xff0c;以及软件从业人员该怎么办。这个问题挺复杂的&#xff0c;我得先理清楚各个部分。首先&#xff0c;AI确实在改变很多行…

Error: Expected a mutable image

你的函数用了不支持的图片格式比如我的人脸检测&#xff0c;本来要RGB565我却用JPEG所以报错

海思ISP开发说明

1、概述 ISP&#xff08;Image Signal Processor&#xff09;图像信号处理器是专门用于处理图像信号的硬件或处理单元&#xff0c;广泛应用于图像传感器&#xff08;如 CMOS 或 CCD 传感器&#xff09;与显示设备之间的信号转换过程中。ISP通过一系列数字图像处理算法完成对数字…

2.攻防世界PHP2及知识点

进入题目页面如下 意思是你能访问这个网站吗&#xff1f; ctrlu、F12查看源码&#xff0c;什么都没有发现 用kali中的dirsearch扫描根目录 命令如下&#xff0c;根据题目提示以及需要查看源码&#xff0c;扫描以php、phps、html为后缀的文件 dirsearch -u http://61.147.17…