神经网络以及简单的神经网络模型实现

神经网络基本概念:

  1. 神经元(Neuron)

    神经网络的基本单元,接收输入,应用权重并通过激活函数生成输出
  2. 层(Layer)

    神经网络由多层神经元组成。常见的层包括输入层、隐藏层和输出层
  3. 权重(Weights)和偏置(Biases)

    权重用于调整输入的重要性,偏置用于调整模型的输出
  4. 激活函数(Activation Function)

    在神经元中引入非线性,如ReLU(Rectified Linear Unit)、Sigmoid、Tanh等。
  5. 损失函数(Loss Function)

    用于衡量模型预测与实际结果之间的差异,如均方误差(MSE)、交叉熵损失等。
  6. 优化器(Optimizer)

    用于调整模型权重以最小化损失函数,如随机梯度下降(SGD)、Adam等。

简单的神经网络示例:

下面是一个使用PyTorch构建简单线性回归的神经网络示例代码。这个示例展示了如何定义一个具有一个隐藏层的前馈神经网络,并训练它来逼近一些随机生成的数据点。

import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt# 生成一些随机数据
np.random.seed(0)
X = np.linspace(0, 10, 100).reshape(-1, 1).astype(np.float32)
y = np.sin(X) + np.random.normal(0, 0.1, size=X.shape).astype(np.float32)# 转换为PyTorch的张量
X_tensor = torch.tensor(X)
y_tensor = torch.tensor(y)# 定义一个简单的神经网络模型
class NeuralNet(nn.Module):def __init__(self):super(NeuralNet, self).__init__()self.fc1 = nn.Linear(1, 10)  # 输入层到隐藏层self.relu = nn.ReLU()        # 激活函数self.fc2 = nn.Linear(10, 1)  # 隐藏层到输出层def forward(self, x):x = self.fc1(x)x = self.relu(x)x = self.fc2(x)return x# 实例化模型、损失函数和优化器
model = NeuralNet()
criterion = nn.MSELoss()  # 均方误差损失函数
optimizer = optim.Adam(model.parameters(), lr=0.01)  # Adam优化器# 训练模型
epochs = 5000
losses = []
for epoch in range(epochs):optimizer.zero_grad()outputs = model(X_tensor)loss = criterion(outputs, y_tensor)loss.backward()optimizer.step()losses.append(loss.item())if (epoch+1) % 1000 == 0:print(f'Epoch [{epoch+1}/{epochs}], Loss: {loss.item():.6f}')# 绘制损失函数变化图
plt.plot(losses, label='Training loss')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.legend()
plt.show()# 测试模型
model.eval()
with torch.no_grad():test_x = torch.tensor([[5.0]])  # 测试输入predicted = model(test_x)print(f'预测值: {predicted.item()}')

运行结果展示: 

代码理解:

下面便是详细分解这段代码进行理解: 

  • 生成数据

    • 使用 numpy 生成一些随机的带有噪声的正弦函数数据。
      import numpy as np# 生成带有正态分布噪声的正弦函数数据
      def generate_data(n_samples):np.random.seed(0)  # 设置随机种子以确保结果可复现X = np.random.uniform(low=0, high=10, size=n_samples)y = np.sin(X) + np.random.normal(scale=0.3, size=n_samples)return X, y# 生成数据
      X_train, y_train = generate_data(100)
      

  • 定义神经网络模型

    • NeuralNet 类继承自 nn.Module,定义了一个具有一个隐藏层的前馈神经网络。使用ReLU作为隐藏层的激活函数。
      import torch
      import torch.nn as nn
      import torch.optim as optim# 定义神经网络模型
      class NeuralNet(nn.Module):def __init__(self):super(NeuralNet, self).__init__()self.fc1 = nn.Linear(1, 10)  # 输入大小为1(X),输出大小为10self.fc2 = nn.Linear(10, 1)  # 输入大小为10,输出大小为1self.relu = nn.ReLU()def forward(self, x):x = self.relu(self.fc1(x))x = self.fc2(x)return x# 实例化模型
      model = NeuralNet()# 打印模型结构
      print(model)
      

  • 实例化模型、损失函数和优化器

    • model 是我们定义的神经网络模型。
    • criterion 是损失函数,这里使用均方误差损失。
    • optimizer 是优化器,这里使用Adam优化器来更新模型参数。
      # 定义损失函数(均方误差损失)
      criterion = nn.MSELoss()# 定义优化器(Adam优化器)
      optimizer = optim.Adam(model.parameters(), lr=0.01)
      

  • 训练模型

    • 使用 X_tensor 和 y_tensor 进行训练,优化模型使其逼近 y_tensor
      # 将numpy数组转换为PyTorch张量
      X_tensor = torch.tensor(X_train, dtype=torch.float32).view(-1, 1)
      y_tensor = torch.tensor(y_train, dtype=torch.float32).view(-1, 1)# 训练模型
      def train_model(model, criterion, optimizer, X, y, epochs=1000):model.train()for epoch in range(epochs):optimizer.zero_grad()output = model(X)loss = criterion(output, y)loss.backward()optimizer.step()if (epoch+1) % 100 == 0:print(f'Epoch [{epoch+1}/{epochs}], Loss: {loss.item():.4f}')train_model(model, criterion, optimizer, X_tensor, y_tensor)
      

  • 测试模型

    • 使用 model.eval() 将模型切换到评估模式,使用 torch.no_grad() 关闭梯度计算。
    • 测试输入为 5.0,打印预测结果。
      # 测试模型
      model.eval()
      with torch.no_grad():test_input = torch.tensor([[5.0]], dtype=torch.float32)predicted_output = model(test_input)print(f'预测输入为 5.0 时的输出: {predicted_output.item():.4f}')
      

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/376506.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Camunda如何通过外部任务与其他系统自动交互

文章目录 简介流程图外部系统pom.xmllogback.xml监听类 启动流程实例常见问题Public Key Retrieval is not allowed的解决方法java.lang.reflect.InaccessibleObjectException 流程图xml 简介 前面我们已经介绍了Camunda的基本操作、任务、表: Camunda组件与服务与…

浏览器插件使用方法

如果我们经常使用的浏览器不是edge或者是chrome浏览器时,需要在浏览器安装插件时,无法获取插件以及不知道如何安装插件,本文章教你如何获取以及安装使用。 获取方法 第一种方法(推荐) 无需“魔法”,即可访问…

多表联合的查询(实例)、对于前端返回数据有很多表,可以分开操作、debug调试教程

2024.7.13 一、 对于多表的更深层的认识1. 认识2. 多表联合查询的列子:3. 对于多表查询的进一步认识4. 在实现功能的时候,原本对于省市县这样的表,对于项目的要求,是直接全部查询出来,然后开始使用,但我想着…

PDF 中图表的解析探究

PDF 中图表的解析探究 0. 引言1. 开源方案探究 0. 引言 一直以来,对文档中的图片和表格处理都非常有挑战性。这篇文章记录一下最近工作上在这块的探究。图表分为图片和表格,这篇文章主要记录了对表格的探究。还有,我个人主要做日本项目&…

如何解决VMware 安装Windows10系统出现Time out EFI Network...

一、问题描述 使用VMware 17 安装windows10出现如下图所示Time out EFI Network… Windows10镜像为微软官方下载的ISO格式镜像; 二、问题分析 VMware 17 默认的固件类型是UEFI(E),而微软官网下载的Windows10 ISO格式镜像不支持UEFI(E),支…

Android APT实战

Android开发中,注解平时我们用的比较多,也许我们会比较好奇,注解的背后是如何工作的,这篇文章帮大家一步步创建一个简单的注解处理器。 简介 APT(Annotation Processing Tool)即注解处理器,在编译的时候可以处理注解然后搞一些事情,也可以在编译时生成一些文件之类的。…

网络安全——防御课实验二

在实验一的基础上,完成7-11题 拓扑图 7、办公区设备可以通过电信链路和移动链路上网(多对多的NAT,并且需要保留一个公网IP不能用来转换) 首先,按照之前的操作,创建新的安全区(电信和移动)分别表示两个外网…

nginx的四层负载均衡实战

目录 1 环境准备 1.1 mysql 部署 1.2 nginx 部署 1.3 关闭防火墙和selinux 2 nginx配置 2.1 修改nginx主配置文件 2.2 创建stream配置文件 2.3 重启nginx 3 测试四层代理是否轮循成功 3.1 远程链接通过代理服务器访问 3.2 动图演示 4 四层反向代理算法介绍 4.1 轮询&#xff0…

大数据基础:Hadoop之MapReduce重点架构原理

文章目录 Hadoop之MapReduce重点架构原理 一、MapReduce概念 二、MapReduce 编程思想 2.1、Map阶段 2.2、Reduce阶段 三、MapReduce处理数据流程 四、MapReduce Shuffle 五、MapReduce注意点 六、MapReduce的三次排序 Hadoop之MapReduce重点架构原理 一、MapReduce概…

在word中删除endnote参考文献之间的空行

如图,在References中,每个文献之间都有空行。不建议手动删除。打开Endnote。 打开style manager 删除layout中的换行符。保存,在word中更新参考文献即可。

初阶数据结构—排序

第一章:排序的概念及其运用 1.1 排序的概念 排序:所谓排序,就是使一串记录,按照其中的某个或某些关键字的大小,递增或递减的排列起来的操作。 稳定性:假定在待排序的记录序列中,存在多个具有…

宝塔:如何开启面板ssl并更新过期ssl

1、登录宝塔面板 > 前往面板设置 > 最上方的安全设置 > 面板SSL > 面板SSL配置 打开后先查看自签证书的时间,如果时间是已经过期的,就前往这个目录,将该目录下所有文件都删掉 重新回到面板SSL配置的位置,打开后会看到…

复现ORB3-YOLO8项目记录

文章目录 1.编译错误1.1 错误11.2 错误21.3 错误31.4 错误4 1.编译错误 首先ORB-SLAM相关项目已经写过很多篇博客了,从ORB-SLAM2怎么运行,再到现在的项目。关于环境已经不想多说了 1.1 错误1 – DEPENDENCY_LIBS : /home/lvslam/ORB3-YOLO8/Thirdparty…

【web】-sql注入-login

根据网址提示打开如图: 查看源代码前台并没有过滤限制、扫描后台也没有发现特殊文件。看到标题显示flag is in database,尝试sql注入。 由于post,bp抓包如下: 运行python sqlmap.py -r 1.txt --dump 获取flag 42f4ebc342b6ed4af4aadc1ea75f…

Python打开Excel文档并读取数据

Python 版本 目前 Python 3 版本为主流版本,这里测试的版本是:Python 3.10.5。 常用库说明 Python 操作 Excel 的常用库有:xlrd、xlwt、xlutils、openpyxl、pandas。这里主要说明下 Excel 文档 .xls 格式和 .xlsx 格式的文档打开和读取。 …

实现多层感知机

目录 多层感知机: 介绍: 代码实现: 运行结果: 问题答疑: 线性变换与非线性变换 参数含义 为什么清除梯度? 反向传播的作用 为什么更新权重? 多层感知机: 介绍:…

【C++】———— 继承

作者主页: 作者主页 本篇博客专栏:C 创作时间 :2024年7月5日 一、什么是继承? 继承的概念 定义: 继承机制就是面向对象设计中使代码可以复用的重要手段,它允许在程序员保持原有类特性的基础上进行扩展…

uniapp+vue3嵌入Markdown格式

使用的库是towxml 第一步:下载源文件,那么可以git clone,也可以直接下载压缩包 git clone https://github.com/sbfkcel/towxml.git 第二步:设置文件夹内的config.js,可以选择自己需要的格式 第三步:安装…

redisTemplate报错为nil,通过redis-cli查看前缀有乱码

public void set(String key, String value, long timeout) {redisTemplate.opsForValue().set(key, value, timeout, TimeUnit.SECONDS);} 改完之后 public void set(String key, String value, long timeout) {redisTemplate.setKeySerializer(new StringRedisSerializer()…

以太网中的各种帧结构

帧结构(Ethernet Frame Structure)介绍 以太网信号帧结构(Ethernet Signal Frame Structure),有被称为以太网帧结构,一般可以分为两类 —— 数据帧和管理帧。 按照 IEEE 802.3,ISO/IEC8803-3 …