使用 AlexNet 实现图片分类 | PyTorch 深度学习实战

前一篇文章,CNN 卷积神经网络处理图片任务 | PyTorch 深度学习实战

本系列文章 GitHub Repo: https://github.com/hailiang-wang/pytorch-get-started

本篇文章内容来自于 强化学习必修课:引领人工智能新时代【梗直哥瞿炜】

使用 AlexNet 实现图片分类

  • 经典卷积网络
  • AlexNet 特点
  • 实验代码
  • 实验结果
  • Why AlexNet Works
  • Links

经典卷积网络

以下是卷积神经网络发展的里程碑:

在这里插入图片描述

  • AlexNet 在各项比赛中,比其它算法好,证明了深度神经网络算法的优越性和前景
  • VGGNet 则使用比 AlexNet 更深更宽的网络,取得了比 AlexNet 还好的成绩
  • GoogLeNet 效果则比 VGGNet 更好
  • ResNet 引入残差模块,解决了深度网络训练中的退化问题,超越之前的模型
  • DenseNet 模型采用密集连接的结构,使模型具有更好的鲁棒性

AlexNet 特点

在这里插入图片描述

AlexNet 结构

![[../assets/media/screenshot_20250208184431.png]]

更多详细介绍,阅读作者 Paper 论文, ImageNet Classification with Deep Convolutional Neural Networks

视频资源:9年后重读深度学习奠基作之一:AlexNet【上】【论文精读】

实验代码

from torch.utils.data.sampler import SubsetRandomSampler
from torchvision import transforms
from torchvision import datasets
import torch.nn as nn
import torch
import numpy as np# configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
data_rootdir = os.path.join(os.path.dirname(os.path.abspath(__file__)), os.pardir, "data")################################
# 定义 dataset loader
################################
def get_train_valid_loader(data_dir,batch_size,augment,random_seed,valid_size=0.1,shuffle=True):# 正则化图片的参数,mean 和 std 的值来自于 imagenet 的数据统计normalize = transforms.Normalize(mean=[0.4914, 0.4822, 0.4465],std=[0.2023, 0.1994, 0.2010],)# define transformsvalid_transform = transforms.Compose([transforms.Resize((227, 227)),transforms.ToTensor(),normalize,])if augment:# 数据增强train_transform = transforms.Compose([transforms.RandomCrop(32, padding=4),transforms.RandomHorizontalFlip(),transforms.ToTensor(),normalize,])else:train_transform = transforms.Compose([transforms.Resize((227, 227)),transforms.ToTensor(),normalize,])# load the datasettrain_dataset = datasets.CIFAR10(root=data_dir, train=True,download=True, transform=train_transform,)valid_dataset = datasets.CIFAR10(root=data_dir, train=True,download=True, transform=valid_transform,)num_train = len(train_dataset)indices = list(range(num_train))split = int(np.floor(valid_size * num_train))if shuffle:np.random.seed(random_seed)np.random.shuffle(indices)train_idx, valid_idx = indices[split:], indices[:split]train_sampler = SubsetRandomSampler(train_idx)valid_sampler = SubsetRandomSampler(valid_idx)train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, sampler=train_sampler)valid_loader = torch.utils.data.DataLoader(valid_dataset, batch_size=batch_size, sampler=valid_sampler)return (train_loader, valid_loader)def get_test_loader(data_dir,batch_size,shuffle=True):normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],std=[0.229, 0.224, 0.225],)# define transformtransform = transforms.Compose([transforms.Resize((227, 227)),transforms.ToTensor(),normalize,])dataset = datasets.CIFAR10(root=data_dir, train=False,download=True, transform=transform,)data_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=shuffle)return data_loader# CIFAR10 dataset
CIFAR10_data = os.path.join(data_rootdir, "CIFAR10")
train_loader, valid_loader = get_train_valid_loader(data_dir= CIFAR10_data,                                      batch_size = 64,augment = False,                                          random_seed = 1)test_loader = get_test_loader(data_dir= CIFAR10_data,batch_size = 64)################################
# 定义 Model
################################
class AlexNet(nn.Module):def __init__(self, num_classes=10):super(AlexNet, self).__init__()self.layer1 = nn.Sequential(nn.Conv2d(3, 96, kernel_size=11, stride=4, padding=0),nn.BatchNorm2d(96),nn.ReLU(),nn.MaxPool2d(kernel_size = 3, stride = 2))self.layer2 = nn.Sequential(nn.Conv2d(96, 256, kernel_size=5, stride=1, padding=2),nn.BatchNorm2d(256),nn.ReLU(),nn.MaxPool2d(kernel_size = 3, stride = 2))self.layer3 = nn.Sequential(nn.Conv2d(256, 384, kernel_size=3, stride=1, padding=1),nn.BatchNorm2d(384),nn.ReLU())self.layer4 = nn.Sequential(nn.Conv2d(384, 384, kernel_size=3, stride=1, padding=1),nn.BatchNorm2d(384),nn.ReLU())self.layer5 = nn.Sequential(nn.Conv2d(384, 256, kernel_size=3, stride=1, padding=1),nn.BatchNorm2d(256),nn.ReLU(),nn.MaxPool2d(kernel_size = 3, stride = 2))self.fc = nn.Sequential(nn.Dropout(0.5),nn.Linear(9216, 4096),nn.ReLU())self.fc1 = nn.Sequential(nn.Dropout(0.5),nn.Linear(4096, 4096),nn.ReLU())self.fc2= nn.Sequential(nn.Linear(4096, num_classes))def forward(self, x):out = self.layer1(x)out = self.layer2(out)out = self.layer3(out)out = self.layer4(out)out = self.layer5(out)out = out.reshape(out.size(0), -1)out = self.fc(out)out = self.fc1(out)out = self.fc2(out)return out######################################
# Setting Hyperparameters
######################################
num_classes = 10
num_epochs = 20
batch_size = 64
learning_rate = 0.005model = AlexNet(num_classes).to(device)# Loss and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate, weight_decay = 0.005, momentum = 0.9)  # Train the model
total_step = len(train_loader)######################################
# Training
######################################
for epoch in range(num_epochs):for i, (images, labels) in enumerate(train_loader):  # Move tensors to the configured deviceimages = images.to(device)labels = labels.to(device)# Forward passoutputs = model(images)loss = criterion(outputs, labels)# Backward and optimizeoptimizer.zero_grad()loss.backward()optimizer.step()print ('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}' .format(epoch+1, num_epochs, i+1, total_step, loss.item()))# Validationwith torch.no_grad():correct = 0total = 0for images, labels in valid_loader:images = images.to(device)labels = labels.to(device)outputs = model(images)_, predicted = torch.max(outputs.data, 1)total += labels.size(0)correct += (predicted == labels).sum().item()del images, labels, outputsprint('Accuracy of the network on the {} validation images: {} %'.format(5000, 100 * correct / total))# Now, we see how our model performs on unseen data
with torch.no_grad():correct = 0total = 0for images, labels in test_loader:images = images.to(device)labels = labels.to(device)outputs = model(images)_, predicted = torch.max(outputs.data, 1)total += labels.size(0)correct += (predicted == labels).sum().item()del images, labels, outputsprint('Accuracy of the network on the {} test images: {} %'.format(10000, 100 * correct / total))

在这里插入图片描述

实验结果

以上代码在 NVIDIA GeForce RTX 2050 WDDM 显存上训练和测试,大约花了半个小时时间。
最终,测试集上的准确率达到了 82.24% 。
在这里插入图片描述

Why AlexNet Works

Links

  • Writing AlexNet from Scratch in PyTorch
  • ImageNet Classification with Deep Convolutional
    Neural Networks
  • Conv2d API in PyTorch
  • AlexNet 论文精读(上),9年后重读深度学习奠基作之一
  • AlexNet 论文精读(下)
  • AlexNet Explained: A Step-by-Step Guide

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/15698.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

C# Winform 使用委托实现C++中回调函数的功能

C# Winform 使用委托实现C中回调函数的功能 在项目中遇到了使用C#调用C封装的接口,其中C接口有一个回调函数的参数。参考对比后,在C#中是使用委托(delegate)来实现类似的功能。 下面使用一个示例来介绍具体的使用方式: 第一步:…

攻防世界33 catcat-new【文件包含/flask_session伪造】

题目: 点击一只猫猫: 看这个url像是文件包含漏洞,试试 dirsearch扫出来/admin,访问也没成功(--delay 0.1 -t 5) 会的那几招全用不了了哈哈,那就继续看答案 先总结几个知识点 1./etc/passwd&am…

ArgoCD实战指南:GitOps驱动下的Kubernetes自动化部署与Helm/Kustomize集成

摘要 ArgoCD 是一种 GitOps 持续交付工具,专为 Kubernetes 设计。它能够自动同步 Git 仓库中的声明性配置,并将其应用到 Kubernetes 集群中。本文将介绍 ArgoCD 的架构、安装步骤,以及如何结合 Helm 和 Kustomize 进行 Kubernetes 自动化部署。 引言 为什么选择 ArgoCD?…

尝试一下,交互式的三维计算python库,py3d

py3d是一个我开发的三维计算python库,目前不定期在PYPI上发版,可以通过pip直接安装 pip install py3d 开发这个库主要可视化是想把自己在工作中常用的三维方法汇总积累下来,不必每次重新造轮子。其实现成的python库也有很多,例如…

【愚公系列】《循序渐进Vue.js 3.x前端开发实践》070-商业项目:电商后台管理系统实战(商品管理模块的开发)

标题详情作者简介愚公搬代码头衔华为云特约编辑,华为云云享专家,华为开发者专家,华为产品云测专家,CSDN博客专家,CSDN商业化专家,阿里云专家博主,阿里云签约作者,腾讯云优秀博主&…

5 个释放 安卓潜力的 Shizuku 应用

Shizuku 软件推荐:释放安卓潜力的五款应用 Shizuku (日语:雫,意为“水滴”) 正如其名,是一款轻巧但功能强大的安卓工具。它无需 Root 权限,通过 ADB (Android Debug Bridge) 授权,即可让应用调用系统 API&…

前端权限控制和管理

前端权限控制和管理 1.前言2.权限相关概念2.1权限的分类(1)后端权限(2)前端权限 2.2前端权限的意义 3.前端权限控制思路3.1菜单的权限控制3.2界面的权限控制3.3按钮的权限控制3.4接口的权限控制 4.实现步骤4.1菜单栏控制4.2界面的控制(1)路由导航守卫(2)动态路由 4.3按钮的控制…

分布式kettle调度平台- web版转换,作业编排新功能介绍

介绍 Kettle(也称为Pentaho Data Integration)是一款开源的ETL(Extract, Transform, Load)工具,由Pentaho(现为Hitachi Vantara)开发和维护。它提供了一套强大的数据集成和转换功能&#xff0c…

Docker容器访问外网:启动时的网络参数配置指南

在启动Docker镜像时,可以通过设置网络参数来确保容器能够访问外网。以下是几种常见的方法: 1. 使用默认的bridge网络 Docker的默认网络模式是bridge,它会创建一个虚拟网桥,将容器连接到宿主机的网络上。在大多数情况下,使用默认的bridge网络配置即可使容器访问外网。 启动…

大语言模型RAG,transformer

1、RAG技术流总结 第一张图是比较经典的RAG知识图谱,第二张图是更加详细扎实的介绍图。 1.1 索引 坦白来说这部分的技术并不是大模型领域的,更像是之前技术在大模型领域的应用;早在2019年我就做过faiss部分的尝试,彼时索引技术已…

数据结构与算法(test3)

七、查找 1. 看图填空 查找表是由同一类型的数据元素(或记录)构成的集合。例如上图就是一个查找表。 期中(1)是______________. (2)是______________(3)是_____关键字_______。 2. 查找(Searching) 就是根据给定的某个值, 在查…

机器学习在癌症分子亚型分类中的应用

学习笔记:机器学习在癌症分子亚型分类中的应用——Cancer Cell 研究解析 1. 文章基本信息 标题:Classification of non-TCGA cancer samples to TCGA molecular subtypes using machine learning发表期刊:Cancer Cell发表时间:20…

48V电气架构全面科普和解析:下一代智能电动汽车核心驱动

48V电气架构:下一代智能电动汽车核心驱动 随着全球汽车产业迈入电动化、智能化的新时代,传统12V电气系统逐渐暴露出其无法满足现代高功率需求的不足。在此背景下,48V电气架构应运而生,成为现代电动汽车(EV&#xff09…

Mac(m1)本地部署deepseek-R1模型

1. 下载安装ollama 直接下载软件,下载完成之后,安装即可,安装完成之后,命令行中可出现ollama命令 2. 在ollama官网查看需要下载的模型下载命令 1. 在官网查看deepseek对应的模型 2. 选择使用电脑配置的模型 3. copy 对应模型的安…

操作教程丨使用1Panel开源面板快速部署DeepSeek-R1

近期,DeepSeek-R1模型因其在数学推理、代码生成与自然语言推理等方面的优异表现而受到广泛关注。作为能够有效提升生产力的工具,许多个人和企业用户都希望能在本地部署DeepSeek-R1模型。 通过1Panel的应用商店能够简单、快速地在本地部署DeepSeek-R1模型…

免费在腾讯云Cloud Studio部署DeepSeek-R1大模型

2024年2月2日,腾讯云宣布DeepSeek-R1大模型正式支持一键部署至腾讯云HAI(高性能应用服务)。开发者仅需3分钟即可完成部署并调用模型,大幅简化了传统部署流程中买卡、装驱动、配网络、配存储、装环境、装框架、下载模型等繁琐步骤。…

C语言-结构体

1.共用体: union //联合--共用体 早期的时候,计算机的硬件资源有限, 能不能让多个成员变量 公用同一块空间 //使用方式 类似 结构体 --- 也是构造类型 struct 结构体名 { 成员变量名 }; union 共用体名 { 成员变量名 }; //表示构造了一个共用体…

多头自注意力中的多头作用及相关思考

文章目录 1. num_heads2. pytorch源码演算 1. num_heads 将矩阵的最后一维度进行按照num_heads的方式进行切割矩阵,具体表示如下: 2. pytorch源码演算 pytorch 代码 import torch import torch.nn as nn import torch.nn.functional as Ftorch.set…

数据仓库和商务智能:洞察数据,驱动决策

在数据管理的众多领域中,数据仓库和商务智能(BI)是将数据转化为洞察力、支持决策制定的关键环节。它们通过整合、存储和分析数据,帮助组织更好地理解业务运营,预测市场趋势,从而制定出更明智的战略。今天&a…

C++ ——从C到C++

1、C的学习方法 (1)C知识点概念内容比较多,需要反复复习 (2)偏理论,有的内容不理解,可以先背下来,后续可能会理解更深 (3)学好编程要多练习,简…