实现多层感知机

目录

多层感知机:

介绍:

代码实现:

运行结果:

问题答疑:

线性变换与非线性变换

参数含义

为什么清除梯度?

反向传播的作用

为什么更新权重?


多层感知机:

介绍:

缩写:MLP,这是一种人工神经网络,由一个输入层、一个或多个隐藏层以及一个输出层组成,每一层都由多个节点(神经元)构成。在MLP中,节点之间只有前向连接,没有循环连接,这使得它属于前馈神经网络的一种。每个节点都应用一个激活函数,如sigmoid、ReLU等,以引入非线性,从而使网络能够拟合复杂的函数和数据分布。

代码实现:

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader# Step 1: Define the MLP model
class SimpleMLP(nn.Module):def __init__(self):super(SimpleMLP, self).__init__()self.fc1 = nn.Linear(784, 128)  # Input layer to hidden layerself.fc2 = nn.Linear(128, 64)   # Hidden layer to another hidden layerself.fc3 = nn.Linear(64, 10)    # Hidden layer to output layerself.relu = nn.ReLU()def forward(self, x):x = x.view(-1, 784)             # Flatten the input from 28x28 to 784x = self.relu(self.fc1(x))x = self.relu(self.fc2(x))x = self.fc3(x)return x# Step 2: Load MNIST dataset
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)# Step 3: Define loss function and optimizer
model = SimpleMLP()
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01)# Step 4: Train the model
num_epochs = 5
for epoch in range(num_epochs):for batch_idx, (data, target) in enumerate(train_loader):optimizer.zero_grad()output = model(data)loss = criterion(output, target)loss.backward()optimizer.step()if batch_idx % 100 == 0:print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(epoch, batch_idx * len(data), len(train_loader.dataset),100. * batch_idx / len(train_loader), loss.item()))# Step 5: Evaluate the model on the test set (optional)
with torch.no_grad():correct = 0total = 0for images, labels in test_loader:outputs = model(images)_, predicted = torch.max(outputs.data, 1)total += labels.size(0)correct += (predicted == labels).sum().item()print('Accuracy of the network on the 10000 test images: {} %'.format(100 * correct / total))

运行结果:

问题答疑:

线性变换与非线性变换

在神经网络中

线性变换通常指的是权重矩阵和输入数据的矩阵乘法,再加上偏置向量。数学上,对于一个输入向量𝑥x和权重矩阵𝑊W,加上偏置向量𝑏b,线性变换可以表示为: 𝑧=𝑊𝑥+𝑏z=Wx+b

非线性变换是指在神经网络的每一层之后应用的激活函数,如ReLU、sigmoid或tanh等。这些函数引入了非线性,使神经网络能够学习和表达复杂的函数关系。没有非线性变换,无论多少层的神经网络最终都将简化为一个线性模型。

参数含义

在上述模型中,参数如784, 128, 64, 10并不是字节,而是神经网络层的尺寸,具体来说是神经元的数量:

  • 784: 这是输入层的神经元数量,对应于MNIST数据集中每个图片的像素数量。MNIST的图片是28x28像素,因此总共有784个像素点。
  • 128 和 64: 这是两个隐藏层的神经元数量。它们代表了第一层和第二层的宽度,即这一层有多少个神经元。
  • 10: 这是输出层的神经元数量,对应于MNIST数据集中的10个数字类别(0到9)。

为什么清除梯度?

在每一次前向传播和反向传播过程中,梯度会被累积在张量的.grad属性中。如果不手动清零,这些梯度将会被累加,导致不正确的梯度值。因此,在每次迭代开始之前,都需要调用optimizer.zero_grad()来清空梯度。

反向传播的作用

反向传播(Backpropagation)是一种算法,用于计算损失函数相对于神经网络中所有权重的梯度。它的目的是为了让神经网络知道,当损失函数值较高时,哪些权重需要调整,以及调整的方向和幅度。这些梯度随后被用于权重更新,以最小化损失函数。

为什么更新权重?

权重更新是基于梯度下降算法进行的。在反向传播计算出梯度后,权重通过optimizer.step()函数更新,以朝着减小损失函数的方向移动。

这是训练神经网络的核心,即通过不断调整权重和偏置,使模型能够更好地拟合训练数据,从而提高预测准确性。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/376484.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【C++】———— 继承

作者主页: 作者主页 本篇博客专栏:C 创作时间 :2024年7月5日 一、什么是继承? 继承的概念 定义: 继承机制就是面向对象设计中使代码可以复用的重要手段,它允许在程序员保持原有类特性的基础上进行扩展…

uniapp+vue3嵌入Markdown格式

使用的库是towxml 第一步:下载源文件,那么可以git clone,也可以直接下载压缩包 git clone https://github.com/sbfkcel/towxml.git 第二步:设置文件夹内的config.js,可以选择自己需要的格式 第三步:安装…

redisTemplate报错为nil,通过redis-cli查看前缀有乱码

public void set(String key, String value, long timeout) {redisTemplate.opsForValue().set(key, value, timeout, TimeUnit.SECONDS);} 改完之后 public void set(String key, String value, long timeout) {redisTemplate.setKeySerializer(new StringRedisSerializer()…

以太网中的各种帧结构

帧结构(Ethernet Frame Structure)介绍 以太网信号帧结构(Ethernet Signal Frame Structure),有被称为以太网帧结构,一般可以分为两类 —— 数据帧和管理帧。 按照 IEEE 802.3,ISO/IEC8803-3 …

跨域解决方案

跨域 当发起请求的协议号、域名、端口号中有一个不一样时就会导致跨域 跨域解决方案 分为两个方面,是否可以修改服务器端。 可以修改服务器端:cors方案、jsonp方案 不可以修改服务器端: 使用代理: 因为跨域主要是针对浏览器…

springboot+vue 开发记录(九)后端打包部署运行

本篇文章主要内容是后端项目写好了,怎么打包部署到服务器上运行。 文章目录 1. 在服务器上安装Docker2. 在Docker中装MySQL3. 在Docker中设置网桥,实现容器间的网络通信4. 修改后端配置文件5. 修改pom.xml文件6. 打包7. 编写DockerFile文件8. 上传文件到…

Qt MV架构-视图类

一、基本概念 在MV架构中,视图包含了模型中的数据项,并将它们呈现给用户。数据项的表示方法,可能和数据项在存储时用的数据结构完全不同。 这种内容与表现分离之所以能够实现,是因为使用了 QAbstractItemModel提供的一个标准模…

Learning vtkjs之hello vtk

学习vtkjs 最近由于工作需要,开始学习vtkjs的相关内容,发现其实在医疗和工业领域,这个vtk的库的example还是非常有帮助,但是实际用的一些开发工具,或者研发生态却没有three的好,也就是能抄写的东西不多&am…

Java常用排序算法

冒泡排序(Bubble Sort) arr[0] 与 arr[1]比较,如果前面元素大就交换,如果后边元素大就不交换。然后依次arr[1]与arr[2]比较,第一轮将最大值排到最后一位。 第二轮arr.length-1个元素进行比较,将第二大元素…

数据处理-Matplotlib 绘图展示

文章目录 1. Matplotlib 简介2. 安装3. Matplotlib Pyplot4. 绘制图表1. 折线图2. 散点图3. 柱状图4. 饼图5. 直方图 5. 中文显示 1. Matplotlib 简介 Matplotlib 是 Python 的绘图库,它能让使用者很轻松地将数据图形化,并且提供多样化的输出格式。 Ma…

C++ | Leetcode C++题解之第232题用栈实现队列

题目&#xff1a; 题解&#xff1a; class MyQueue { private:stack<int> inStack, outStack;void in2out() {while (!inStack.empty()) {outStack.push(inStack.top());inStack.pop();}}public:MyQueue() {}void push(int x) {inStack.push(x);}int pop() {if (outStac…

Linux 下 redis 集群部署

目录 1. redis下载 2. 环境准备 3. redis部署 3.1 修改系统配置文件 3.2 开放端口 3.3 安装 redis 3.4 验证 本文将以三台服务器为例&#xff0c;介绍在 linux 系统下redis的部署方式。 1. redis下载 下载地址&#xff1a;Index of /releases/ 选择需要的介质下载&am…

Windows安装linux子系统

Windows安装linux子系统 步骤 1 - 启用适用于 Linux 的 Windows 子系统 需要先启用“适用于 Linux 的 Windows 子系统”可选功能&#xff0c;然后才能在 Windows 上安装 Linux 分发。 以管理员身份打开 PowerShell&#xff08;“开始”菜单 >“PowerShell” >单击右键 …

uniapp 支付宝小程序 芝麻免押 免押金

orderStr参数如下&#xff1a; my.tradePay({orderStr:res, // 完整的支付参数拼接成的字符串&#xff0c;从 alipay.fund.auth.order.app.freeze 接口获取success: (res) > {console.log(免押成功);console.log(JSON.stringify(res),不是JOSN);console.log(JSON.stringify…

使用机器学习 最近邻算法(Nearest Neighbors)进行点云分析 (scikit-learn Open3D numpy)

使用 NearestNeighbors 进行点云分析 在数据分析和机器学习领域&#xff0c;最近邻算法&#xff08;Nearest Neighbors&#xff09;是一种常用的非参数方法。它广泛应用于分类、回归和聚类分析等任务。下面将介绍如何使用 scikit-learn 库中的 NearestNeighbors 类来进行点云数…

前端JS特效第33波:jQuery旋转木马焦点图轮播插件PicCarousel

jQuery旋转木马焦点图轮播插件PicCarousel&#xff0c;先来看看效果&#xff1a; 部分核心的代码如下&#xff1a; <!doctype html> <html> <head> <meta charset"utf-8"> <meta http-equiv"X-UA-Compatible" content"IE…

FLinkCDC引起的生产事故(二)

背景&#xff1a; 最近在做实时数据的抽取工作&#xff0c;利用FLinkCDC实时抽取目标库Oracle的数据到Doris中&#xff0c;但是在抽取的过程中&#xff0c;会导致目标库的生产库数据库非常卡顿&#xff0c;为了避免对生产环境的数据库造成影响&#xff0c;对生产环境的数据库利…

element UI时间组件两种使用方式

加油&#xff0c;新时代打工&#xff01; 组件官网&#xff1a;https://element.eleme.cn/#/zh-CN/component/date-picker 先上效果图&#xff0c;如下&#xff1a; 第一种实现方式 <div class"app-container"><el-formref"submitForm":model&q…

11计算机视觉—语义分割与转置卷积

目录 1.语义分割应用语义分割和实例分割2.语义分割数据集:Pascal VOC2012 语义分割数据集预处理数据:我们使用图像增广中的随机裁剪,裁剪输入图像和标签的相同区域。3.转置卷积 上采样填充、步幅和多通道填充步幅多通道转置卷积是一种卷积:重新排列输入和核转置卷积是一种卷…

机器学习筑基篇,Jupyter Notebook 精简指南

[ 知识是人生的灯塔&#xff0c;只有不断学习&#xff0c;才能照亮前行的道路 ] 0x00 Jupyter Notebook 简明指南 描述&#xff1a;前面我们已经在机器学习工作站&#xff08;Ubuntu 24.04 Desktop Geforce RTX 4070Ti SUPER&#xff09;中安装 Anaconda 工具包&#xff0c;其…