sin函数拟合

目录

一、 目的... 1

二、 模型设计... 1

2.1 输入与输出.... 1

2.2 隐藏层设计.... 1

2.3 优化算法与损失函数.... 1

2.4 神经网络结构.... 1

三、 训练... 1

3.1 数据生成.... 2

3.2 训练过程.... 2

3.3 训练参数与设置.... 2

四、 测试与分析... 2

4.1 选取不同激活函数.... 2

4.2 增加偏置.... 3

... 4

4.3 减少训练量.... 4

4.4 损失曲线分析.... 4

4.5 模型预测分析.... 5

五、 代码... 5

  • 目的

通过构建一个简单的三层神经网络,模拟正弦函数 y = sin(2πx) 的映射关系,并使用 PyTorch 框架进行训练与优化,即输入x后会产生一个和正弦函数相同结果的y。

  • 模型设计

2.1 输入与输出

本研究中的神经网络模型包括输入层、隐藏层和输出层。输入层包含一个神经元,用于接收单一的自变量 x。输出层同样包含一个神经元,输出模型计算得到的结果 y,即预测的正弦值。

2.2 隐藏层设计

网络的隐藏层包含 10 个神经元。此设计旨在增强网络的非线性表达能力,使其能够准确模拟正弦函数的波动特性。激活函数选择了 Tanh(双曲正切函数),该函数的输出范围为 [-1, 1],更符合正弦波的输出特性,相较于 Sigmoid 函数,Tanh 能更有效地模拟正弦波的起伏。

2.3 优化算法与损失函数

模型使用 Adam 优化器 进行训练。Adam 优化器结合了动量和自适应学习率,能够有效加速收敛并避免梯度消失或爆炸的情况。在损失函数的选择上,本研究使用了 均方误差(MSE)损失函数,该函数能衡量网络输出与目标正弦值之间的差异,并通过最小化损失函数来优化网络参数。

2.4 神经网络结构

模型的具体结构如下:

输入层

1 个神经元,用于接收输入 x

隐藏层

10 个神经元,激活函数为 Tanh

输出层

1 个神经元,输出拟合的正弦值

  • 训练

3.1 数据生成

为了进行模型训练,首先生成了 x 和 y 的训练数据,其中 x 在区间 [0, 1) 内均匀分布,步长为 0.01,生成 100 个数据点。对应的 y 值则通过正弦函数 y = sin(2πx) 计算得到。这些数据用于训练神经网络,使其学习到 x 与 y 之间的映射关系。

3.2 训练过程

本研究采用 随机梯度下降法(SGD 结合 Adam 优化器 对模型进行训练。训练的核心目标是最小化均方误差损失函数,以不断调整神经网络的权重和偏置。在每次迭代中,网络通过前向传播计算输出,通过反向传播计算梯度,并利用 Adam 优化器更新网络参数。训练过程的停止条件为最大迭代次数 10,000 次,损失值逐渐趋于稳定。

3.3 训练参数与设置

训练过程中使用的主要参数如下:

学习率

0.001,优化器的学习率设置为 0.001

迭代次数

最大迭代次数设置为 10,000

损失函数

均方误差(MSE)损失函数

优化器

Adam 优化器

  • 测试与分析
  1. 选取不同激活函数

如图 1和图 2所示,在本模型中,我们选择使用 Tanh 激活函数而非 Sigmoid 函数,主要是因为二者的输出范围与正弦函数的特性不匹配。Sigmoid 函数的输出范围是 (0, 1),无法有效表示正弦函数的负值部分,而正弦函数的输出范围是 [-1, 1],且具有周期性的波动。相对而言,Tanh 激活函数的输出范围为 [-1, 1],更符合正弦函数的特性,能够同时表示正负值,从而使得神经网络能够更有效地拟合正弦波的起伏。因此,选择 Tanh 激活函数有助于模型更准确地模拟正弦函数。

1 tanh双正切函数

2 sigmoid型函数

  1. 增加偏置

如图 3所示,在神经网络中,增加偏置项可以显著提升模型的拟合能力。偏置项允许每个神经元在计算时具有一个额外的自由度,使得网络能够更好地适应数据的分布。在没有偏置项的情况下,神经元的输出完全依赖于输入的加权和,限制了模型的表达能力。加入偏置项后,神经元的输出不再局限于零点,能够对输入数据进行更灵活的平移,从而更准确地捕捉到数据的特征。在拟合正弦函数的任务中,增加偏置项使得网络能够更有效地模拟正弦波的起伏,改善了拟合的效果,减少了偏差,提升了模型的预测精度。

3 增加偏置后的效果

  1. 减少训练量

减少训练的 epoch 数量可能导致模型出现欠拟合,因为模型没有足够的时间来学习数据的特征,从而无法有效捕捉到数据的复杂模式。。虽然减少 epoch 数量可以节省计算资源,但这往往以牺牲模型的表现为代价。


4减少训练的 epoch 数量

  1. 损失曲线分析

训练过程中,损失曲线的变化呈现出明显的规律性。初期,损失值较高,说明模型尚未有效学习到正弦函数的特性。随着训练的进行,损失逐渐下降,表明模型在不断优化,逐步逼近最优解。最终,损失曲线趋于平稳,接近最小值,表明模型已经学习到了数据中的规律,达到了收敛状态。

  1. 模型预测分析

通过对比模型的预测值与原始数据,可以看出,预测值与实际正弦函数的值非常接近,表明模型已成功模拟了正弦函数的行为。在可视化图中,红色的点表示预测值,蓝色的点表示实际值,两者几乎完全重合,进一步验证了模型在函数拟合任务中的高效性和准确性。

  • 代码

核心代码

介绍:这段代码是使用 Python 编写的,主要利用了 PyTorch NumPy 库来训练一个简单的神经网络模型进行数据拟合。训练过程中的损失值会被记录并展示出来,同时还会展示模型预测结果与原始数据的对比图。

import torch

import torch.nn as nn

import numpy as np

import matplotlib.pyplot as plt

import os

os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"  # 忽略重复的库文件警告

class Network(nn.Module):

    def __init__(self, n_in, n_hidden, n_out):

        super().__init__()

        self.layer1 = nn.Linear(n_in, n_hidden, bias=False)

        self.layer2 = nn.Linear(n_hidden, n_out, bias=False)

    def forward(self, x):

        x = self.layer1(x)

        x = torch.tanh(x)  # 使用 Tanh 激活函数

        return self.layer2(x)

def generate_data(start=0.0, end=1.0, step=0.01):

    """生成训练数据"""

    x = np.arange(start, end, step)

    y = np.sin(2 * np.pi * x)

    return x.reshape(len(x), 1), y.reshape(len(y), 1)

def train_model(model, x, y, criterion, optimizer, num_epochs=10000):

    """训练模型并返回训练过程中的损失值"""

    loss_values = []

    for epoch in range(num_epochs):

        y_pred = model(x)  # 前向传播

        loss = criterion(y_pred, y)  # 计算损失

        loss.backward()  # 反向传播

        optimizer.step()  # 更新参数

        loss_values.append(loss.item())  # 保存损失值

        optimizer.zero_grad()  # 清空梯度

        # 100次打印一次损失值

        if epoch % 100 == 0:

            print(f'After {epoch} iterations, the loss is {loss.item()}')

    return loss_values

def plot_results(x, y, h, loss_values, num_epochs):

    """绘制原始数据、预测数据和训练损失曲线"""

    fig, axs = plt.subplots(1, 2, figsize=(14, 6))  # 一行两列的子图布局

    # 第一个子图:原始数据与预测数据的散点图

    axs[0].scatter(x, y, label='Original Data')

    axs[0].scatter(x, h, label='Predicted Data', color='r')

    axs[0].set_title("Model Prediction vs Original Data")

    axs[0].legend()

    # 第二个子图:训练损失曲线

    axs[1].plot(range(num_epochs), loss_values, label='Loss Curve')

    axs[1].set_xlabel('Epochs')

    axs[1].set_ylabel('Loss')

    axs[1].set_title('Training Loss')

    axs[1].legend()

    plt.tight_layout()  # 自动调整子图间距

    plt.show()

if __name__ == '__main__':

    # 生成数据

    x, y = generate_data()

    x = torch.Tensor(x)

    y = torch.Tensor(y)

    # 初始化模型、损失函数和优化器

    model = Network(1, 10, 1)

    criterion = nn.MSELoss()  # 均方误差损失函数

    optimizer = torch.optim.Adam(model.parameters(), lr=0.001# Adam优化器

    # 训练模型

    loss_values = train_model(model, x, y, criterion, optimizer, num_epochs=10000)

    # 获取预测值

    h = model(x).detach().numpy()  # 获取模型输出并转为numpy数组

    x = x.detach().numpy()  # 获取输入数据

    # 调用绘图函数

    plot_results(x, y, h, loss_values, num_epochs=10000)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/480061.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【鸿蒙】鸿蒙开发过程中this指向问题

文章目录 什么是 this?常见 this 指向问题案例分析:HarmonyOS 组件中的 this 指向问题问题描述问题分析原因 解决方案:绑定 this 的正确方法方法一:使用箭头函数方法二:手动绑定 this 完整代码示例使用箭头函数使用 bi…

【摸鱼】Docker配置主从mysql数据库环境

docker pull mysql拉取docker镜像,国内现在访问不了docker hub,可以去阿里云上镜像加速器地址https://cr.console.aliyun.com/cn-hangzhou/instances/mirrors启动主库docker run -p 3306:3306 --name master-mysql --privilegedtrue -v /app/docker/data…

初试无监督学习 - K均值聚类算法

文章目录 1. K均值聚类算法概述2. k均值聚类算法演示2.1 准备工作2.2 生成聚类用的样本数据集2.3 初始化KMeans模型对象,并指定类别数量2.4 用样本数据训练模型2.5 用训练好的模型生成预测结果2.6 输出预测结果2.7 可视化预测结果 3. 实战小结 1. K均值聚类算法概述…

大数据笔记

第一章、大数据概述 人类的行为及产生的事件的一种记录称之为数据。 1、大数据时代的特征,并结合生活实例谈谈带来的影响。 (一)特征 1、Volume 规模性:数据量大。 2、Velocity高速性:处理速度快。数据的生成和响…

深度学习实战老照片上色

目录 1.研究背景与意义1. 卷积神经网络(CNN)在老照片上色中的应用1.1 卷积层与特征提取1.2 颜色空间转换1.3 损失函数与训练优化 2. 生成对抗网络(GAN)在老照片上色中的应用2.1 生成器与判别器2.2 对抗训练2.3 条件生成对抗网络&a…

C#面向对象,封装、继承、多态、委托与事件实例

一.面向对象封装性编程 创建一个控制台应用程序,要求: 1.定义一个服装类(Cloth),具体要求如下 (1)包含3个字段:服装品牌(mark),服装…

养老院、学校用 安科瑞AAFD-40Z单相电能监测故障电弧探测器

安科瑞戴婷 Acrel-Fanny 安科瑞单相电能监测故障电弧探测器对接入线路中的故障电弧(包括故障并联电弧、故障串联电弧)进行有效的检测,当检测到线路中存在引起火灾的故障电弧时,探测器可以进行现场的声光报警,并将报警…

PAT甲级 1056 Mice and Rice(25)

文章目录 题目题目大意基本思路AC代码总结 题目 原题链接 题目大意 给定参赛的老鼠数量为NP,每NG只老鼠分为一组,组中最胖的老鼠获胜,并进入下一轮,所有在本回合中失败的老鼠排名都相同,获胜的老鼠继续每NG只一组&am…

[SWPUCTF 2021 新生赛]include

参考博客: 文件包含 [SWPUCTF 2021 新生赛]include-CSDN博客 NSSCTF | [SWPUCTF 2021 新生赛]include-CSDN博客 考点:php伪协议和文件包含 PHP伪协议详解-CSDN博客 php://filter php://filter可以获取指定文件源码。当它与包含函数结合时,php://filter流会被当…

spring boot3.3.5 logback-spring.xml 配置

新建 resources/logback-spring.xml 控制台输出颜色有点花 可以自己更改 <?xml version"1.0" encoding"UTF-8"?> <!--关闭文件扫描 scanfalse --> <configuration debug"false" scan"false"><springProperty …

Unity shaderlab 实现LineSDF

实现效果&#xff1a; 实现代码&#xff1a; Shader "Custom/LineSDF" {Properties{}SubShader{Tags { "RenderType""Opaque" }Pass{CGPROGRAM#pragma vertex vert#pragma fragment frag#include "UnityCG.cginc"struct appdata{floa…

PHP 去掉特殊不可见字符 “\u200e“

描述 最近在排查网站业务时&#xff0c;发现有数据匹配失败的情况 肉眼上完全看不出问题所在 当把字符串 【M24308/23-14F‎】复制出来发现 末尾有个不可见的字符 使用删除键或左右移动时才会发现 最后测试通过 var_dump 打印 发现这个"空字符"占了三个长度 &#xf…

Web会话安全测试

Web会话安全测试 - 知乎 1、会话ID不可预测性 【要求】 会话ID必须采用安全随机算法&#xff08;如SecureRandom&#xff09;生成&#xff0c;并且强度不得低于256位&#xff08;32字符&#xff09;&#xff0c;如采用Tomcat原生JSESSIONID【描述】 密码与证书等认证手段&…

springboot336社区物资交易互助平台pf(论文+源码)_kaic

毕 业 设 计&#xff08;论 文&#xff09; 社区物资交易互助平台设计与实现 摘 要 传统办法管理信息首先需要花费的时间比较多&#xff0c;其次数据出错率比较高&#xff0c;而且对错误的数据进行更改也比较困难&#xff0c;最后&#xff0c;检索数据费事费力。因此&#xff…

富文本编辑器图片上传并回显

1.概述 在代码业务需求中&#xff0c;我们会经常涉及到文件上传的功能&#xff0c;通常来说&#xff0c;我们存储文件是不能直接存储到数 据库中的&#xff0c;而是以文件路径存储到数据库中&#xff1b;但是存储文件的路径到数据库中又会有一定的问题&#xff0c;就是 浏览…

结构体详解+代码展示

系列文章目录 &#x1f388; &#x1f388; 我的CSDN主页:OTWOL的主页&#xff0c;欢迎&#xff01;&#xff01;&#xff01;&#x1f44b;&#x1f3fc;&#x1f44b;&#x1f3fc; &#x1f389;&#x1f389;我的C语言初阶合集&#xff1a;C语言初阶合集&#xff0c;希望能…

学习ASP.NET Core的身份认证(基于Session的身份认证1)

ASP.NET Core使用Session也可以实现身份认证&#xff0c;关于Session的介绍请见参考文献5。基于Session的身份认证大致原理就是用户验证成功后将用户信息保存到Session中&#xff0c;然后在其它控制器中从Session中获取用户信息&#xff0c;用户退出时清空Session数据。百度基于…

题目 3209: 蓝桥杯2024年第十五届省赛真题-好数

一个整数如果按从低位到高位的顺序&#xff0c;奇数位&#xff08;个位、百位、万位 &#xff09;上的数字是奇数&#xff0c;偶数位&#xff08;十位、千位、十万位 &#xff09;上的数字是偶数&#xff0c;我们就称之为“好数”。给定一个正整数 N&#xff0c;请计算从…

人工智能如何改变你的生活?

在我们所处的这个快节奏的世界里&#xff0c;科技融入日常生活已然成为司空见惯的事&#xff0c;并且切实成为了我们生活的一部分。在这场科技变革中&#xff0c;最具变革性的角色之一便是人工智能&#xff08;AI&#xff09;。从我们清晨醒来直至夜晚入睡&#xff0c;人工智能…

MATLAB - ROS2 ros2genmsg 生成自定义消息(msg/srv...)

系列文章目录 前言 语法 ros2genmsg(folderpath)ros2genmsg(folderpath,NameValue) 一、说明 ros2genmsg(folderpath) 通过读取指定文件夹路径下的 ROS 2 自定义信息和服务定义来生成 ROS 2 自定义信息。函数文件夹必须包含一个或多个 ROS 2 软件包。这些软件包包含 .msg 文件…