【卷积神经网络】AlexNet实践

构建模型

模版搭建

# 定义一个AlexNet模型类def __init__(self):# 调用父类的构造函数(如果继承自nn.Module的话)super(AlexNet, self).__init__()# ReLU激活函数self.ReLU = nn.ReLU()# 卷积层1:输入1个通道(灰度图),输出96个通道,卷积核大小11x11,步幅4self.c1 = nn.Conv2d(in_channels=1, out_channels=96, kernel_size=11, stride=4)# 最大池化层2:池化窗口大小为3x3,步幅为2self.s2 = nn.MaxPool2d(kernel_size=3, stride=2)# 卷积层2:输入96个通道,输出256个通道,卷积核大小为5x5,使用padding=2保持输出尺寸self.c3 = nn.Conv2d(in_channels=96, out_channels=256, kernel_size=5, padding=2)# 最大池化层4:池化窗口大小为3x3,步幅为2self.s4 = nn.MaxPool2d(kernel_size=3, stride=2)# 卷积层3:输入256个通道,输出384个通道,卷积核大小为3x3,使用padding=1保持输出尺寸self.c5 = nn.Conv2d(in_channels=256, out_channels=384, kernel_size=3, padding=1)# 卷积层4:输入384个通道,输出384个通道,卷积核大小为3x3,使用padding=1保持输出尺寸self.c6 = nn.Conv2d(in_channels=384, out_channels=384, kernel_size=3, padding=1)# 卷积层5:输入384个通道,输出256个通道,卷积核大小为3x3,使用padding=1保持输出尺寸self.c7 = nn.Conv2d(in_channels=384, out_channels=256, kernel_size=3, padding=1)# 最大池化层8:池化窗口大小为3x3,步幅为2self.s8 = nn.MaxPool2d(kernel_size=3, stride=2)# 展平层:将卷积输出的特征图展平为一维向量,供全连接层处理self.flatten = nn.Flatten()# 全连接层1:输入维度为6*6*256(假设输入图片尺寸为224x224,经过卷积和池化后的输出尺寸),输出4096维self.f1 = nn.Linear(6 * 6 * 256, 4096)# 全连接层2:输入4096维,输出4096维self.f2 = nn.Linear(4096, 4096)# 全连接层3:输入4096维,输出10维(表示10个类别的分类输出)self.f3 = nn.Linear(4096, 10)

forward

def forward(self, x):# 第1步:通过卷积层 c1 提取特征,并通过 ReLU 激活函数处理x = self.ReLU(self.c1(x))# 第2步:通过最大池化层 s2 进行池化操作,减小特征图尺寸x = self.s2(x)# 第3步:通过卷积层 c3 提取更高层次的特征,并通过 ReLU 激活函数处理x = self.ReLU(self.c3(x))# 第4步:通过最大池化层 s4 进行池化操作x = self.s4(x)# 第5步:通过卷积层 c5 提取特征,并通过 ReLU 激活函数处理x = self.ReLU(self.c5(x))# 第6步:通过卷积层 c6 提取更多特征,并通过 ReLU 激活函数处理x = self.ReLU(self.c6(x))# 第7步:通过卷积层 c7 提取更多特征,并通过 ReLU 激活函数处理x = self.ReLU(self.c7(x))# 第8步:通过最大池化层 s8 进行池化操作,进一步减小特征图尺寸x = self.s8(x)# 第9步:展平层,将卷积输出的特征图展平成一维向量,准备输入全连接层x = self.flatten(x)# 第10步:通过全连接层 f1 提取高层次特征,并通过 ReLU 激活函数处理x = self.ReLU(self.f1(x))# 第11步:在全连接层 f1 后应用 dropout,防止过拟合,dropout率为 50%x = F.dropout(x, 0.5)# 第12步:通过全连接层 f2 提取更多特征,并通过 ReLU 激活函数处理x = self.ReLU(self.f2(x))# 第13步:在全连接层 f2 后应用 dropout,防止过拟合,dropout率为 50%x = F.dropout(x, 0.5)# 第14步:通过最后一个全连接层 f3 得到最终的分类输出(10个类别)x = self.f3(x)# 返回输出return x

训练模型

整体代码与LeNet类似,详细分析复习LeNet实践文章

代码实现

def train_val_data_process():train_data = FashionMNIST(root='./data',train=True,transform=transforms.Compose([transforms.Resize(size=227), transforms.ToTensor()]),download=True)train_data, val_data = Data.random_split(train_data, [round(0.8*len(train_data)), round(0.2*len(train_data))])train_dataloader = Data.DataLoader(dataset=train_data,batch_size=16,shuffle=True,num_workers=2)val_dataloader = Data.DataLoader(dataset=val_data,batch_size=16,shuffle=True,num_workers=2)return train_dataloader, val_dataloaderdef train_model_process(model, train_dataloader, val_dataloader, num_epochs):# 设定训练所用到的设备,有GPU用GPU没有GPU用CPUdevice = torch.device("cuda" if torch.cuda.is_available() else "cpu")# device = torch.device("cpu")# device = torch.device("cuda")# 使用Adam优化器,学习率为0.001optimizer = torch.optim.Adam(model.parameters(), lr=0.001)# 损失函数为交叉熵函数criterion = nn.CrossEntropyLoss()# 将模型放入到训练设备中model = model.to(device)# 复制当前模型的参数best_model_wts = copy.deepcopy(model.state_dict())# 初始化参数# 最高准确度best_acc = 0.0# 训练集损失列表train_loss_all = []# 验证集损失列表val_loss_all = []# 训练集准确度列表train_acc_all = []# 验证集准确度列表val_acc_all = []# 当前时间since = time.time()for epoch in range(num_epochs):print("Epoch {}/{}".format(epoch, num_epochs-1))print("-"*10)# 初始化参数# 训练集损失函数train_loss = 0.0# 训练集准确度train_corrects = 0# 验证集损失函数val_loss = 0.0# 验证集准确度val_corrects = 0# 训练集样本数量train_num = 0# 验证集样本数量val_num = 0# 对每一个mini-batch训练和计算for step, (b_x, b_y) in enumerate(train_dataloader):# 将特征放入到训练设备中b_x = b_x.to(device)# 将标签放入到训练设备中b_y = b_y.to(device)# 设置模型为训练模式model.train()# 前向传播过程,输入为一个batch,输出为一个batch中对应的预测output = model(b_x)# 查找每一行中最大值对应的行标pre_lab = torch.argmax(output, dim=1)# 计算每一个batch的损失函数loss = criterion(output, b_y)# 将梯度初始化为0optimizer.zero_grad()# 反向传播计算loss.backward()# 根据网络反向传播的梯度信息来更新网络的参数,以起到降低loss函数计算值的作用optimizer.step()# 对损失函数进行累加train_loss += loss.item() * b_x.size(0)# 如果预测正确,则准确度train_corrects加1train_corrects += torch.sum(pre_lab == b_y.data)# 当前用于训练的样本数量train_num += b_x.size(0)for step, (b_x, b_y) in enumerate(val_dataloader):# 将特征放入到验证设备中b_x = b_x.to(device)# 将标签放入到验证设备中b_y = b_y.to(device)# 设置模型为评估模式model.eval()# 前向传播过程,输入为一个batch,输出为一个batch中对应的预测output = model(b_x)# 查找每一行中最大值对应的行标pre_lab = torch.argmax(output, dim=1)# 计算每一个batch的损失函数loss = criterion(output, b_y)# 对损失函数进行累加val_loss += loss.item() * b_x.size(0)# 如果预测正确,则准确度train_corrects加1val_corrects += torch.sum(pre_lab == b_y.data)# 当前用于验证的样本数量val_num += b_x.size(0)# 计算并保存每一次迭代的loss值和准确率# 计算并保存训练集的loss值train_loss_all.append(train_loss / train_num)# 计算并保存训练集的准确率train_acc_all.append(train_corrects.double().item() / train_num)# 计算并保存验证集的loss值val_loss_all.append(val_loss / val_num)# 计算并保存验证集的准确率val_acc_all.append(val_corrects.double().item() / val_num)print("{} train loss:{:.4f} train acc: {:.4f}".format(epoch, train_loss_all[-1], train_acc_all[-1]))print("{} val loss:{:.4f} val acc: {:.4f}".format(epoch, val_loss_all[-1], val_acc_all[-1]))if val_acc_all[-1] > best_acc:# 保存当前最高准确度best_acc = val_acc_all[-1]# 保存当前最高准确度的模型参数best_model_wts = copy.deepcopy(model.state_dict())# 计算训练和验证的耗时time_use = time.time() - sinceprint("训练和验证耗费的时间{:.0f}m{:.0f}s".format(time_use//60, time_use%60))# 选择最优参数,保存最优参数的模型model.load_state_dict(best_model_wts)# torch.save(model.load_state_dict(best_model_wts), "C:/Users/86159/Desktop/LeNet/best_model.pth")torch.save(best_model_wts, "E:\秋招就业\CNN卷积神经网络\测试用例\AlexNet\\best_model.pth")train_process = pd.DataFrame(data={"epoch":range(num_epochs),"train_loss_all":train_loss_all,"val_loss_all":val_loss_all,"train_acc_all":train_acc_all,"val_acc_all":val_acc_all,})return train_processdef matplot_acc_loss(train_process):# 显示每一次迭代后的训练集和验证集的损失函数和准确率plt.figure(figsize=(12, 4))plt.subplot(1, 2, 1)plt.plot(train_process['epoch'], train_process.train_loss_all, "ro-", label="Train loss")plt.plot(train_process['epoch'], train_process.val_loss_all, "bs-", label="Val loss")plt.legend()plt.xlabel("epoch")plt.ylabel("Loss")plt.subplot(1, 2, 2)plt.plot(train_process['epoch'], train_process.train_acc_all, "ro-", label="Train acc")plt.plot(train_process['epoch'], train_process.val_acc_all, "bs-", label="Val acc")plt.xlabel("epoch")plt.ylabel("acc")plt.legend()plt.show()if __name__ == '__main__':# 加载需要的模型AlexNet = AlexNet()# 加载数据集train_data, val_data = train_val_data_process()# 利用现有的模型进行模型的训练train_process = train_model_process(AlexNet, train_data, val_data, num_epochs=8)matplot_acc_loss(train_process)

训练结果

 

poch 0/19
----------
0 train loss:0.5552 train acc: 0.7938
0 val loss:0.3909 val acc: 0.8588
训练和验证耗费的时间3m36s
Epoch 1/19
----------
1 train loss:0.3775 train acc: 0.8642
1 val loss:0.3952 val acc: 0.8555
训练和验证耗费的时间7m6s
Epoch 2/19
----------
2 train loss:0.3436 train acc: 0.8736
2 val loss:0.3405 val acc: 0.8824
训练和验证耗费的时间10m36s
Epoch 3/19
----------
3 train loss:0.3290 train acc: 0.8820
3 val loss:0.3438 val acc: 0.8778
训练和验证耗费的时间14m7s
Epoch 4/19
----------
4 train loss:0.3233 train acc: 0.8843
4 val loss:0.3224 val acc: 0.8826
训练和验证耗费的时间17m37s
Epoch 5/19
----------
5 train loss:0.3158 train acc: 0.8863
5 val loss:0.3457 val acc: 0.8764
训练和验证耗费的时间21m7s
Epoch 6/19
----------
6 train loss:0.3136 train acc: 0.8851
6 val loss:0.3431 val acc: 0.8812
训练和验证耗费的时间24m38s
Epoch 7/19
----------
7 train loss:0.3103 train acc: 0.8873
7 val loss:0.3542 val acc: 0.8716
训练和验证耗费的时间28m7s
Epoch 8/19
----------
8 train loss:0.3022 train acc: 0.8919
8 val loss:0.3463 val acc: 0.8759
训练和验证耗费的时间31m37s
Epoch 9/19
----------
9 train loss:0.2981 train acc: 0.8919
9 val loss:0.3462 val acc: 0.8792
训练和验证耗费的时间35m7s
Epoch 10/19
----------
10 train loss:0.3009 train acc: 0.8909
10 val loss:0.3683 val acc: 0.8682
训练和验证耗费的时间38m37s
Epoch 11/19
----------
11 train loss:0.2968 train acc: 0.8923
11 val loss:0.5043 val acc: 0.8171
训练和验证耗费的时间42m8s
Epoch 12/19
----------
12 train loss:0.3037 train acc: 0.8888
12 val loss:0.3076 val acc: 0.8921
训练和验证耗费的时间45m38s
Epoch 13/19
----------
13 train loss:0.2950 train acc: 0.8932
13 val loss:0.3680 val acc: 0.8776
训练和验证耗费的时间49m8s
Epoch 14/19
----------
14 train loss:0.3002 train acc: 0.8925
14 val loss:0.3667 val acc: 0.8712
训练和验证耗费的时间52m38s
Epoch 15/19
----------
15 train loss:0.2911 train acc: 0.8945
15 val loss:0.3753 val acc: 0.8758
训练和验证耗费的时间56m9s
Epoch 16/19
----------
16 train loss:0.2958 train acc: 0.8937
16 val loss:0.3534 val acc: 0.8778
训练和验证耗费的时间59m39s
Epoch 17/19
----------
17 train loss:0.2981 train acc: 0.8918
17 val loss:0.3350 val acc: 0.8905
训练和验证耗费的时间63m9s
Epoch 18/19
----------
18 train loss:0.2876 train acc: 0.8968
18 val loss:0.3525 val acc: 0.8891
训练和验证耗费的时间66m39s
Epoch 19/19
----------
19 train loss:0.2937 train acc: 0.8941
19 val loss:0.3477 val acc: 0.8778
训练和验证耗费的时间70m9s

测试模型

测试代码

def test_data_process():test_data = FashionMNIST(root='./data',train=False,transform=transforms.Compose([transforms.Resize(size=227), transforms.ToTensor()]),download=True)test_dataloader = Data.DataLoader(dataset=test_data,batch_size=1,shuffle=True,num_workers=0)return test_dataloaderdef test_model_process(model, test_dataloader):# 设定测试所用到的设备,有GPU用GPU没有GPU用CPUdevice = "cuda" if torch.cuda.is_available() else 'cpu'# 讲模型放入到训练设备中model = model.to(device)# 初始化参数test_corrects = 0.0test_num = 0# 只进行前向传播计算,不计算梯度,从而节省内存,加快运行速度with torch.no_grad():for test_data_x, test_data_y in test_dataloader:# 将特征放入到测试设备中test_data_x = test_data_x.to(device)# 将标签放入到测试设备中test_data_y = test_data_y.to(device)# 设置模型为评估模式model.eval()# 前向传播过程,输入为测试数据集,输出为对每个样本的预测值output= model(test_data_x)# 查找每一行中最大值对应的行标pre_lab = torch.argmax(output, dim=1)# 如果预测正确,则准确度test_corrects加1test_corrects += torch.sum(pre_lab == test_data_y.data)# 将所有的测试样本进行累加test_num += test_data_x.size(0)# 计算测试准确率test_acc = test_corrects.double().item() / test_numprint("测试的准确率为:", test_acc)if __name__ == "__main__":# 加载模型model = AlexNet()model.load_state_dict(torch.load('best_model.pth'))# # 利用现有的模型进行模型的测试test_dataloader = test_data_process()test_model_process(model, test_dataloader)# 设定测试所用到的设备,有GPU用GPU没有GPU用CPUdevice = "cuda" if torch.cuda.is_available() else 'cpu'model = model.to(device)classes = ['T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat', 'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot']with torch.no_grad():for b_x, b_y in test_dataloader:b_x = b_x.to(device)b_y = b_y.to(device)# 设置模型为验证模型model.eval()output = model(b_x)pre_lab = torch.argmax(output, dim=1)result = pre_lab.item()label = b_y.item()print("预测值:",  classes[result], "------", "真实值:", classes[label])

测试结果

测试的准确率为: 0.8871
预测值: Ankle boot ------ 真实值: Ankle boot
预测值: Pullover ------ 真实值: Pullover
预测值: Shirt ------ 真实值: T-shirt/top
预测值: Dress ------ 真实值: Dress
预测值: T-shirt/top ------ 真实值: T-shirt/top
预测值: Shirt ------ 真实值: Pullover
预测值: Dress ------ 真实值: Dress
预测值: Coat ------ 真实值: Coat
预测值: Sneaker ------ 真实值: Sneaker
预测值: Trouser ------ 真实值: Trouser
预测值: Shirt ------ 真实值: Shirt
预测值: Pullover ------ 真实值: Pullover
预测值: Ankle boot ------ 真实值: Ankle boot
预测值: Ankle boot ------ 真实值: Ankle boot
预测值: Coat ------ 真实值: Coat
预测值: Coat ------ 真实值: Coat
预测值: T-shirt/top ------ 真实值: T-shirt/top
预测值: T-shirt/top ------ 真实值: T-shirt/top
预测值: T-shirt/top ------ 真实值: T-shirt/top
预测值: Dress ------ 真实值: Shirt
预测值: Shirt ------ 真实值: Shirt
预测值: Trouser ------ 真实值: Trouser
预测值: Dress ------ 真实值: Dress
...............................................

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/490480.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

socket编程UDP-实现停等机制(接收确认、超时重传)

在下面博客中,我介绍了利用UDP模拟TCP连接、按数据包发送文件的过程,并附上完整源码。 socket编程UDP-文件传输&模拟TCP建立连接脱离连接(进阶篇)_udp socket发送-CSDN博客 下面博客实现的是滑动窗口机制: sock…

Leetcode 面试150题 399.除法求值

系列博客目录 文章目录 系列博客目录题目思路代码 题目 链接 思路 广度优先搜索 我们可以将整个问题建模成一张图:给定图中的一些点(点即变量),以及某些边的权值(权值即两个变量的比值),试…

Python机器视觉的学习

一、二值化 1.1 二值化图 二值化图:就是将图像中的像素改成只有两种值,其操作的图像必须是灰度图。 1.2 阈值法 阈值法(Thresholding)是一种图像分割技术,旨在根据像素的灰度值或颜色值将图像分成不同的区域。该方法…

Linux 支持多个spi-nor flash

1. 需求 通常在嵌入式开发过程中可能会遇到需要再同一个SPI总线上挂载多个spi nor flash才能满足存储需求。 2. 技术简介 对于spi-nor flash驱动通常不需要驱动开发人员手搓,一般内核会有一套固定的驱动,而且走的是内核的MTD子系统那一套,市…

超标量处理器设计笔记(11)发射内容:分配、仲裁、唤醒

发射 概述集中式和分布式数据捕捉和非数据捕捉数据捕捉非数据捕捉总结对比 压缩式和非压缩式压缩式发射队列非压缩式发射队列总结 发射过程的流水线非数据捕捉结构的流水线数据捕捉结构的流水线 分配仲裁1-of-M 的仲裁电路N of M 的仲裁电路 唤醒单周期指令的唤醒多周期指令的…

ArrayList源码分析、扩容机制面试题,数组和List的相互转换,ArrayList与LinkedList的区别

目录 1.java集合框架体系 2. 前置知识-数组 2.1 数组 2.1.1 定义: 2.1.2 数组如何获取其他元素的地址值?(寻址公式) 2.1.3 为什么数组索引从0开始呢?从1开始不行吗? 3. ArrayList 3.1 ArrayList和和…

地下管线三维建模,市面上有哪些软件

1. 地下管线:城市“生命线” 地下管线是城市的重要基础设施,包括供水、排水、燃气、热力、电力、通信等管线,它们如同城市的“生命线”,支撑着城市的正常运转。如果缺乏完整和准确的地下管线信息,施工破坏地下管线的事…

蓝桥杯刷题——day5

蓝桥杯刷题——day5 题目一题干解题思路一代码解题思路二代码 题目二题干解题思路代码 题目一 题干 给定n个整数 a1,a2,⋯ ,an,求它们两两相乘再相加的和,即: 示例一: 输入: 4 1 3 6 9 输出: 117 题目链…

L1-3流量分析

1. 初步分析 数据包下载 流量分析基础篇 使用科来网络分析系统,打开L1-3.pcapng数据包,查看数据包中ssh的协议占的比例较大。 2. 通过分析数据包L1-3,找出黑客的IP地址,并将黑客的IP地址作为FLAG(形式:[IP地址)提交; 获取的fl…

docker启动一个helloworld(公司内网服务器)

这里写目录标题 容易遇到的问题:1、docker连接问题 我来介绍几种启动 Docker Hello World 的方法: 最简单的方式: docker run hello-world这会自动下载并运行官方的 hello-world 镜像。 使用 Nginx 作为 Hello World: docker…

【网络取证篇】取证实战之PHP服务器镜像网站重构及绕密分析

【网络取证篇】取证实战之PHP服务器镜像网站重构及绕密分析 在裸聊敲诈、虚假理财诈骗案件类型中,犯罪分子为了能实现更低成本、更快部署应用的目的,其服务器架构多为常见的初始化网站架构,也称为站库同体服务器!也就是说网站应用…

图像处理 - 车道线检测:智能驾驶的“眼睛”

引言 在智能驾驶技术飞速发展的今天,车道线检测作为一项基础而关键的技术,扮演着车辆“眼睛”的角色。它不仅关系到车辆的导航和定位,还直接影响到自动驾驶系统的安全性和可靠性。本文将带你深入了解车道线检测技术的原理、方法以及在实际应用…

【Linux学习】十五、Linux/CentOS 7 用户和组管理

Linux下组和用户的管理都必须是root用户下进行: 一、组的管理 1.组的创建 格式: groupadd 组名参数: -g:指定用户组的组ID(GID),如果不提供则由系统自动分配。 【案例】创建一个名为 oldg…

Unity类银河战士恶魔城学习总结(P179 Enemy Archer 弓箭手)

教程源地址:https://www.udemy.com/course/2d-rpg-alexdev/ 本章节实现了敌人弓箭手的制作 Enemy_Archer.cs 核心功能 状态机管理敌人的行为 定义了多个状态对象(如 idleState、moveState、attackState 等),通过状态机管理敌人的…

Pikachu靶场——XXE漏洞

XXE(XML External Entity)漏洞 XXE(XML External Entity)漏洞是一种常见的安全漏洞,发生在处理 XML 数据的应用程序中。当应用程序解析 XML 输入时,如果没有正确配置或过滤外部实体的加载,就可能…

使用 ESP-IDF 进行esp32-c3开发第四步:VSCode里安装ESP-IDF插件

很多小伙伴还是习惯在VSCode里写代码,所以今天进行了--使用 ESP-IDF 进行esp32-c3开发第四步:VSCode里安装ESP-IDF插件 安装和配置 首先到VSCode的插件页面,搜索esp,排名第一的就是ESP-IDF插件,点击安装即可。 在命令…

SSM 垃圾分类系统——高效分类的科技保障

第五章 系统功能实现 5.1管理员登录 管理员登录,通过填写用户名、密码、角色等信息,输入完成后选择登录即可进入垃圾分类系统,如图5-1所示。 图5-1管理员登录界面图 5.2管理员功能实现 5.2.1 用户管理 管理员对用户管理进行填写账号、姓名、…

部署GitLab服务器

文章目录 环境准备GitLab部署GitLab服务器GitLab中主要的概念客户端上传代码到gitlab服务器CI-CD概述软件程序上线流程安装Jenkins服务器 配置jenkins软件版本管理配置jenkins访问gitlab远程仓库下载到子目录部署代码到web服务器自动化部署流程 配置共享服务器配置jenkins把git…

泷羽sec学习打卡-brupsuite8伪造IP和爬虫审计

声明 学习视频来自B站UP主 泷羽sec,如涉及侵权马上删除文章 笔记的只是方便各位师傅学习知识,以下网站只涉及学习内容,其他的都 与本人无关,切莫逾越法律红线,否则后果自负 关于brupsuite的那些事儿-Brup-FaskIP 伪造IP配置环境brupsuite导入配置1、扩展中先配置python环境2、安…

如何在 Ubuntu 22.04 上使用 Fail2Ban 保护 SSH

前言 SSH,这玩意儿,简直是连接云服务器的标配。它不仅好用,还很灵活。新的加密技术出来,它也能跟着升级,保证核心协议的安全。但是,再牛的协议和软件,也都有可能被攻破。SSH 在网上用得这么广&…