神经网络代码入门解析

神经网络代码入门解析

import torch
import matplotlib.pyplot as pltimport randomdef create_data(w, b, data_num):  # 数据生成x = torch.normal(0, 1, (data_num, len(w)))y = torch.matmul(x, w) + b  # 矩阵相乘再加bnoise = torch.normal(0, 0.01, y.shape)  # 为y添加噪声y += noisereturn x, ynum = 500true_w = torch.tensor([8.1, 2, 2, 4])
true_b = 1.1X, Y = create_data(true_w, true_b, num)# plt.scatter(X[:, 3], Y, 1)  # 画散点图 对X取全部的行的第三列,标签Y,点大小
# plt.show()def data_provider(data, label, batchsize):  # 每次取batchsize个数据length = len(label)indices = list(range(length))# 这里需要把数据打乱random.shuffle(indices)for each in range(0, length, batchsize):get_indices = indices[each: each+batchsize]get_data = data[get_indices]get_label = label[get_indices]yield get_data, get_label  # 有存档点的returnbatchsize = 16
# for batch_x, batch_y in data_provider(X, Y, batchsize):
#     print(batch_x, batch_y)
#     break# 定义模型
def fun(x, w, b):pred_y = torch.matmul(x, w) + breturn pred_y# 定义loss
def maeLoss(pre_y, y):return torch.sum(abs(pre_y-y))/len(y)# sgd(梯度下降)
def sgd(paras, lr):with torch.no_grad():  # 这部分代码不计算梯度for para in paras:para -= para.grad * lr  # 不能写成 para = para - paras.grad * lr !!!! 这句相当于要创建一个新的para,会导致报错para.grad.zero_()  # 将使用过的梯度归零lr = 0.01
w_0 = torch.normal(0, 0.01, true_w.shape, requires_grad=True)
b_0 = torch.tensor(0.01, requires_grad=True)
print(w_0, b_0)epochs = 50
for epoch in range(epochs):data_loss = 0for batch_x, batch_y in data_provider(X, Y, batchsize):pred_y = fun(batch_x, w_0, b_0)loss = maeLoss(pred_y, batch_y)loss.backward()sgd([w_0, b_0], lr)data_loss += lossprint("epoch %03d: loss: %.6f" % (epoch, data_loss))print("真实函数值:", true_w, true_b)
print("训练得到的函数值:", w_0, b_0)idx = 0
plt.plot(X[:, idx].detach().numpy(), X[:, idx].detach().numpy()*w_0[idx].detach().numpy()+b_0.detach().numpy())
plt.scatter(X[:, idx].detach().numpy(), Y, 1)
plt.show()

逐步分析代码

1.数据生成

image-20250301120222530

首先设计一个函数create_data,提供我们所需要的数据集的x与y

def create_data(w, b, data_num):  # 数据生成x = torch.normal(0, 1, (data_num, len(w)))  # 生成特征数据,形状为 (data_num, len(w))y = torch.matmul(x, w) + b  # 计算目标值 y = x * w + bnoise = torch.normal(0, 0.01, y.shape)  # 生成噪声,形状与 y 相同y += noise  # 为 y 添加噪声,模拟真实数据中的随机误差return x, y
  • torch.normal() 生成一个张量

    • torch.normal(0, 1, (data_num, len(w))):生成一个形状为 (data_num, len(w)) 的张量,其中的元素是从均值为 0、标准差为 1 的正态分布中随机采样的。
  • torch.matmul() 让矩阵相乘

    matmul: matrix multiply

  • 再使用torch.normal()生成一个张量,添加到y上,相当于为y添加了随机的噪声

    噪声的引入是为了模拟真实数据中的随机误差,使生成的数据更接近现实场景。

2.设计一个数据加载器

def data_provider(data, label, batchsize):  # 每次取 batchsize 个数据length = len(label)indices = list(range(length))random.shuffle(indices)  # 打乱数据顺序,避免模型学习到顺序特征for each in range(0, length, batchsize):get_indices = indices[each: each+batchsize]  # 获取当前批次的索引get_data = data[get_indices]  # 获取当前批次的数据get_label = label[get_indices]  # 获取当前批次的标签yield get_data, get_label  # 返回当前批次的数据和标签

data_provider可以分批提供数据,并通过yield来返回已实现记忆功能

首先把list y顺序打乱,这样就相当于从生成的训练集y中随机读取,若不打乱数据,可能造成训练结果的不理想

打乱数据可以避免模型在训练过程中学习到数据的顺序特征,从而提高模型的泛化能力。

之后分段遍历打乱的y,返回对应的局部的数据集来给神经网络进行训练

3.定义模型函数

image-20250301122853184

def fun(x, w, b):pred_y = torch.matmul(x, w) + b  # 计算预测值 y = x * w + breturn pred_y

fun(x, w, b) 是一个线性模型,形式为 y = x * w + b,其中 x 是输入特征,w 是权重,b 是偏置。

4.定义Loss函数

image-20250301122958888

def maeLoss(pre_y, y):return torch.sum(abs(pre_y - y)) / len(y)  # 计算平均绝对误差 (MAE)
  • maeLoss 是平均绝对误差(Mean Absolute Error, MAE),它计算预测值 pre_y 和真实值 y 之间的绝对误差的平均值。
  • 公式为:MAE = (1/n) * Σ|pre_y - y|,其中 n 是样本数量。

5.梯度下降sgd函数

# sgd(梯度下降)
def sgd(paras, lr):with torch.no_grad():  # 这部分代码不计算梯度for para in paras:para -= para.grad * lr  # 不能写成 para = para - paras.grad * lr !!!! 这句相当于要创建一个新的para,会导致报错para.grad.zero_()  # 将使用过的梯度归零

这里需要使用torch.no_grad()来避免重复计算梯度

image-20250301123531781

在前向过程中已经累计过一次梯度了,如果在梯度下降过程中又累计了梯度,那么就会造成不必要的麻烦

PyTorch 会累积梯度,如果不手动清零,梯度会不断累积,导致参数更新错误。

para -= para.grad * lr就是将参数w修正的过程(w=w-(dy^/dw)*learningRate)

torch.no_grad() 是一个上下文管理器,用于禁用梯度计算。在参数更新时,禁用梯度计算可以避免不必要的计算和内存占用。

5.开始训练

epochs = 50
for epoch in range(epochs):data_loss = 0num_batches = len(Y) // batchsize  # 计算批次数量for batch_x, batch_y in data_provider(X, Y, batchsize):pred_y = fun(batch_x, w_0, b_0)  # 前向传播loss = maeLoss(pred_y, batch_y)  # 计算损失loss.backward()  # 反向传播sgd([w_0, b_0], lr)  # 更新参数data_loss += loss.item()  # 累积损失print("epoch %03d: loss: %.6f" % (epoch, data_loss / num_batches))  # 打印平均损失

先定义一个训练轮次epochs=50,表示训练50轮

在每轮训练中将loss记录下来,以此评价训练的效果

首先用data_provider来获取数据集中随机的一部分

接着传入相应数据给模型函数,通过前向传播获得预测y值pred_y

调用Loss计算函数,获取这次的loss,再通过反向传播loss.backward()计算梯度

loss.backward() 是反向传播的核心步骤,用于计算损失函数对模型参数的梯度。

再通过梯度下降sgd([w_0, b_0], lr)来更新模型的参数

最终将这组数据的loss累加到这轮数据的loss中

6.结果绘制

idx = 0
plt.plot(X[:, idx].detach().numpy(), X[:, idx].detach().numpy() * w_0[idx].detach().numpy() + b_0.detach().numpy())  # 绘制预测直线
plt.scatter(X[:, idx].detach().numpy(), Y, 1)  # 绘制真实数据点
plt.show()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/26405.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

解决android studio(ladybug版本) gradle的一些task突然消失了

今天不知道干了啥,AS(ladybug版本)右边gradle的task有些不见了,研究了半天解决了,这里记录下: 操作: File -->Settings-->Experimental--> 取消选项“Enable support for multi-vari…

软件测试之白盒测试知识总结

🍅 点击文末小卡片 ,免费获取软件测试全套资料,资料在手,涨薪更快 概念与定义 白盒测试:侧重于系统或部件内部机制的测试,类型分为分支测试(判定节点测试)、路径测试、语句测试…

Unity中动态切换光照贴图的方法

关键代码:LightmapSettings.lightmaps lightmapDatas; LightmapData中操作三张图:lightmapColor,lightmapDir,以及一张ShadowMap 这里只操作前两张: using UnityEngine; using UnityEngine.EventSystems; using UnityEngine.UI;public cl…

LLC谐振变换器恒压恒流双竞争闭环simulink仿真

1.模型简介 本仿真模型基于MATLAB/Simulink(版本MATLAB 2017Ra)软件。建议采用matlab2017 Ra及以上版本打开。(若需要其他版本可联系代为转换)针对全桥LLC拓扑,利用Matlab软件搭建模型,分别对轻载&#xf…

网络变压器的主要电性参数与测试方法(2)

Hqst盈盛(华强盛)电子导读:网络变压器的主要电性参数与测试方法(2).. 今天我们继续来看看网络变压器的2个主要电性参数与它的测试方法: 1. 线圈间分布电容Cp:线圈间杂散静电容 测试条件:100KHz/0.1…

前端正则表达式完全指南:从入门到实战

文章目录 第一章:正则表达式基础概念1.1 什么是正则表达式1.2 正则表达式工作原理1.3 基础示例演示 第二章:正则表达式核心语法2.1 元字符大全表2.2 量词系统详解2.3 字符集合与排除 第三章:前端常用正则模式3.1 表单验证类3.1.1 邮箱验证3.1…

C++Primer学习(4.8位运算符)

4.8位运算符 位运算符作用于整数类型的运算对象,并把运算对象看成是二进制位的集合。位运算符提供检查和设置二进制位的功能,如17.2节(第640页)将要介绍的,一种名为bitset的标准库类型也可以表示任意大小的二进制位集合,所以位运算符同样能用…

排序算法(3):

这是我们的最后一篇排序算法了,也是我们的初阶数据结构的最后一篇了。 我们来看,我们之前已经讲完了插入排序,选择排序,交换排序,我们还剩下最后一个归并排序,我们今天就讲解归并排序,另外我们还…

【Java项目】基于SpringBoot的Java学习平台

【Java项目】基于SpringBoot的Java学习平台 技术简介:采用Java技术、SpringBoot框架、MySQL数据库等实现。系统基于B/S架构,前端通过浏览器与后端数据库进行信息交互,后端使用SpringBoot框架和MySQL数据库进行数据处理和存储,实现…

单例模式——c++

一个类,只能有1个对象 (对象在堆空间) 再次创建该对象,直接引用之前的对象 so构造函数不能随意调用 so构造函数私有 so对象不能构造 如何调用私有化的构造函数: 公开接口调用构造函数 调用构造函数:singleTon instance; 但…

lqb官方题单-速成刷题清单(上) - python版

预计3月5日 Wednesday 前完成 【2025年3月1日,记】题目太简单了,3月3日前完成 蓝桥杯速成刷题清单(上) https://www.lanqiao.cn/problems/1216/learning/?problem_list_id30&page1 替换题号1216 目录 进度题解和碎碎念1. 排…

虚拟化园区网络部署指南

《虚拟化园区网络部署指南》属于博主的“园区网”专栏,若想成为HCIE,对于园区网相关的知识需要非常了解,更多关于园区网的内容博主会更新在“园区网”专栏里,请持续关注! 一.前言 华为CloudCampus解决方案基于智简网络…

Java数据结构第十五期:走进二叉树的奇妙世界(四)

专栏:Java数据结构秘籍 个人主页:手握风云 目录 一、二叉树OJ练习题(续) 1.1. 二叉树的层序遍历 1.2. 二叉树的最近公共祖先 1.3. 从前序与中序遍历序列构造二叉树 1.4. 从中序与后序遍历序列构造二叉树 1.5. 根据二叉树创建…

ISP 常见流程

1.sensor输出:一般为raw-OBpedestal。加pedestal避免减OB出现负值,同时保证信号超过ADC最小电压阈值,使信号落在ADC正常工作范围。 2. pedestal correction:移除sensor加的基底,确保后续处理信号起点正确。 3. Linea…

Java异常

一,Java异常概述 1.异常概述: 异常:在我们程序运行过程中出现的非正常情况 在开发中,即使我们的代码写的很完善,也有可能由于一些外因(用户输入有误,文件被删除,网络问题&#xff…

Linux下的网络通信编程

在不同主机之间,进行进程间的通信。 1解决主机之间硬件的互通 2.解决主机之间软件的互通. 3.IP地址:来区分不同的主机(软件地址) 4.MAC地址:硬件地址 5.端口号:区分同一主机上的不同应用进程 网络协议…

Metal 学习笔记五:3D变换

在上一章中,您通过在 vertex 函数中计算position,来平移顶点和在屏幕上移动对象。但是,在 3D 空间中,您还想执行更多操作,例如旋转和缩放对象。您还需要一个场景内摄像机,以便您可以在场景中移动。 要移动…

数据集笔记:新加坡LTA MRT 车站出口、路灯 等位置数据集

1 MRT 车站出口 data.gov.sg (geojson格式) 1.1 kml格式 data.gov.sg 2 路灯 data.govsg ——geojson data.gov.sg——kml 版本 3 道路摄像头数据集 data.gov.sg 4 自行车道网络 data.gov.sg 5 学校区域 data.gov.sg 6 自行车停车架&#xff…

【弹性计算】弹性裸金属服务器和神龙虚拟化(一):功能特点

弹性裸金属服务器和神龙虚拟化(一):功能特点 特征一:分钟级交付特征二:兼容 VPC、SLB、RDS 等云平台全业务特征三:兼容虚拟机镜像特征四:云盘启动和数据云盘动态热插拔特征五:虚拟机…

发展中的脑机接口:SSVEP特征提取技术

一、简介 脑机接口(BCI)是先进的系统,能够通过分析大脑信号与外部设备之间建立通信,帮助有障碍的人与环境互动。BCI通过分析大脑信号,提供了一种非侵入式、高效的方式,让人们与外部设备进行交流。BCI技术越…