《Python实战进阶》No34:卷积神经网络(CNN)图像分类实战

第34集:卷积神经网络(CNN)图像分类实战


摘要

卷积神经网络(CNN)是计算机视觉领域的核心技术,特别擅长处理图像分类任务。本集将深入讲解 CNN 的核心组件(卷积层、池化层、全连接层),并演示如何使用 PyTorch 构建一个完整的 CNN 模型,在 CIFAR-10 数据集上实现图像分类。我们还将探讨数据增强和正则化技术(如 Dropout 和 BatchNorm)对模型性能的影响。
在这里插入图片描述


核心概念和知识点

1. CNN 的核心组件

  • 卷积层:通过滤波器(Filter)提取局部特征(如边缘、纹理)。
  • 池化层:通过下采样(如最大池化)减少参数数量,增强特征鲁棒性。
  • 全连接层:将提取的特征映射到分类标签。

2. 数据增强技术

  • 常用方法:随机水平翻转、随机裁剪、色彩抖动(调整亮度、对比度)。
  • 作用:增加训练数据的多样性,防止过拟合。

3. 过拟合与正则化

  • 过拟合:模型在训练集表现优异,但在测试集性能下降。
  • 正则化方法
    • Dropout:随机关闭部分神经元,减少对特定特征的依赖。
    • BatchNorm:标准化每层的输入,加速训练并提升泛化能力。

4. 与 AI 大模型的关联

  • 基础架构角色:CNN 是许多大模型(如 ResNet、EfficientNet)的核心组件。
  • 迁移学习:通过预训练的 CNN 模型(如 ImageNet 权重)快速适应新任务。
  • 自监督学习:利用 CNN 提取特征,用于无标签数据的预训练。

实战案例:使用 CNN 分类 CIFAR-10 数据集

背景

CIFAR-10 包含 60,000 张 32x32 彩色图像,分为 10 个类别(飞机、汽车、鸟类等)。我们将构建一个轻量级 CNN 模型,结合数据增强和正则化技术提升分类性能。

代码实现

1. 环境准备

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np
from torch.utils.data import DataLoader

2. 数据加载和预处理

def load_data():# 数据增强transform_train = transforms.Compose([transforms.RandomCrop(32, padding=4),transforms.RandomHorizontalFlip(),transforms.RandomRotation(15),transforms.ToTensor(),transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))])transform_test = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))])# 加载CIFAR-10数据集trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train)testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)return trainset, testset

3. 构建CNN模型

class CNN(nn.Module):def __init__(self):super(CNN, self).__init__()# 第一个卷积块self.conv1 = nn.Conv2d(3, 32, kernel_size=3, padding=1)self.bn1 = nn.BatchNorm2d(32)self.conv2 = nn.Conv2d(32, 32, kernel_size=3, padding=1)self.bn2 = nn.BatchNorm2d(32)self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)self.dropout1 = nn.Dropout(0.25)# 第二个卷积块self.conv3 = nn.Conv2d(32, 64, kernel_size=3, padding=1)self.bn3 = nn.BatchNorm2d(64)self.conv4 = nn.Conv2d(64, 64, kernel_size=3, padding=1)self.bn4 = nn.BatchNorm2d(64)self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)self.dropout2 = nn.Dropout(0.25)# 第三个卷积块self.conv5 = nn.Conv2d(64, 128, kernel_size=3, padding=1)self.bn5 = nn.BatchNorm2d(128)self.conv6 = nn.Conv2d(128, 128, kernel_size=3, padding=1)self.bn6 = nn.BatchNorm2d(128)self.pool3 = nn.MaxPool2d(kernel_size=2, stride=2)self.dropout3 = nn.Dropout(0.25)# 全连接层self.fc1 = nn.Linear(128 * 4 * 4, 512)self.dropout4 = nn.Dropout(0.5)self.fc2 = nn.Linear(512, 10)def forward(self, x):# 第一个卷积块x = self.pool1(F.relu(self.bn2(self.conv2(F.relu(self.bn1(self.conv1(x)))))))x = self.dropout1(x)# 第二个卷积块x = self.pool2(F.relu(self.bn4(self.conv4(F.relu(self.bn3(self.conv3(x)))))))x = self.dropout2(x)# 第三个卷积块x = self.pool3(F.relu(self.bn6(self.conv6(F.relu(self.bn5(self.conv5(x)))))))x = self.dropout3(x)# 全连接层x = x.view(-1, 128 * 4 * 4)x = self.dropout4(F.relu(self.fc1(x)))x = self.fc2(x)return x

4. 训练和评估

def train_model(model, trainloader, criterion, optimizer, device):model.train()running_loss = 0.0correct = 0total = 0for i, data in enumerate(trainloader):inputs, labels = data[0].to(device), data[1].to(device)optimizer.zero_grad()outputs = model(inputs)loss = criterion(outputs, labels)loss.backward()optimizer.step()running_loss += loss.item()_, predicted = outputs.max(1)total += labels.size(0)correct += predicted.eq(labels).sum().item()if (i + 1) % 100 == 0:print(f'Batch [{i + 1}], Loss: {running_loss/100:.4f}, 'f'Acc: {100.*correct/total:.2f}%')running_loss = 0.0def evaluate_model(model, testloader, device):model.eval()correct = 0total = 0with torch.no_grad():for data in testloader:images, labels = data[0].to(device), data[1].to(device)outputs = model(images)_, predicted = outputs.max(1)total += labels.size(0)correct += predicted.eq(labels).sum().item()accuracy = 100. * correct / totalprint(f'测试集准确率: {accuracy:.2f}%')return accuracy

5. 可视化训练过程

def plot_training_history(train_losses, test_accuracies):plt.figure(figsize=(12, 4))# 绘制训练损失plt.subplot(1, 2, 1)plt.plot(train_losses)plt.title('训练损失')plt.xlabel('批次')plt.ylabel('损失')# 绘制测试准确率plt.subplot(1, 2, 2)plt.plot(test_accuracies)plt.title('测试准确率')plt.xlabel('轮次')plt.ylabel('准确率 (%)')plt.tight_layout()plt.show()

程序输出结果:

在这里插入图片描述


总结

通过本集的学习,我们掌握了 CNN 的核心组件和正则化技术,并通过 CIFAR-10 图像分类任务验证了模型的有效性。CNN 的卷积层和池化层能够有效提取图像特征,而数据增强与 Dropout/BatchNorm 的结合显著提升了模型的泛化能力。


扩展思考

1. 迁移学习提升模型性能

  • 使用预训练模型(如 ResNet-18)作为特征提取器,仅微调最后几层。
  • 代码示例:
    import torchvision.models as models
    resnet = models.resnet18(pretrained=True)
    # 冻结卷积层
    for param in resnet.parameters():param.requires_grad = False
    # 替换最后的全连接层
    resnet.fc = nn.Linear(resnet.fc.in_features, 10)
    

2. 自监督学习的潜力

  • 自监督学习通过无标签数据预训练模型(如通过图像旋转预测任务),可在小数据集上取得更好的效果。
  • 例如,使用 MoCo 框架预训练 CNN 编码器。

专栏链接:Python实战进阶
下期预告:No35:循环神经网络(RNN)时间序列预测

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/41553.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【银河麒麟系统常识】命令:uname -m(查看系统架构)

命令: uname -m 功能 常用的 Linux/Unix 终端命令,用于显示当前系统的硬件架构; 返回 返回系统的CPU架构类型,用于判断软件兼容性; 输出结果架构说明常见设备x86_64Intel/AMD 64位 CPU主流 PC、服务器aarch64ARM 64位 …

游戏引擎学习第183天

回顾和今天的计划 我对接下来的进展感到非常兴奋。虽然我们可能会遇到一些问题,但昨天我们差不多完成了将所有内容迁移到新的日志系统的工作,我们正在把一些内容整合进来,甚至是之前通过不同方式记录时间戳的旧平台层部分,现在也…

Redisson 实现分布式锁简单解析

目录 Redisson 实现分布式锁业务方法:加锁逻辑LockUtil 工具类锁余额方法:工具类代码枚举代码 RedisUtil 工具类tryLock 方法及重载【分布式锁具体实现】Supplier 函数式接口调用分析 Redisson 实现分布式锁 业务方法: 如图,简单…

鸿蒙Flutter实战:19-Flutter集成高德地图,跳转页面方式

前言 在之前的文章现有Flutter项目支持鸿蒙II中,介绍了如何使用第三方插件,同时给出了非常多的使用案例,如 flutter_inappwebview,video_player, image_picker 等,本文将开始介绍如何集成高德地图。 整体方案 通过 …

26考研——图_图的代码实操(6)

408答疑 文章目录 五、图的代码实操图的存储邻接矩阵结构定义初始化插入顶点获取顶点位置在顶点 v1 和 v2 之间插入边获取第一个邻接顶点获取下一个邻接顶点显示图 邻接表结构定义初始化图插入顶点获取顶点位置在顶点 v1 和 v2 之间插入边获取第一个邻接顶点获取下一个邻接顶点…

力扣32.最长有效括号(栈)

32. 最长有效括号 - 力扣&#xff08;LeetCode&#xff09; 代码区&#xff1a; #include<stack> #include<string> /*最长有效*/ class Solution { public:int longestValidParentheses(string s) {stack<int> st;int ans0;int ns.length();st.push(-1);fo…

Node.js 下载安装及环境配置教程、卸载删除环境配置超详细步骤(附图文讲解!) 从零基础入门到精通,看完这一篇就够了

Node.js 安装 一、进入官网地址下载安装包 Node.js — Download Node.js 选择对应你系统的Node.js版本&#xff0c;这里我选择的是Windows系统、64位 Tips&#xff1a;如果想下载指定版本&#xff0c;点击【以往的版本】&#xff0c;即可选择自己想要的版本下载 二、安装程序…

SQLark导出功能详解|轻松管理数据库数据与结构

SQLark 作为一款数据库管理工具&#xff0c;为用户提供了丰富且实用的导出功能。在数据库管理与开发过程中&#xff0c;数据及结构的导出操作至关重要&#xff0c;关乎数据的迁移、备份、版本管理以及问题定位等诸多关键环节。接下来&#xff0c;让我们深入了解 SQLark 的导出功…

搭建Redis主从集群

主从集群说明 单节点Redis的并发能力是有上限的&#xff0c;要进一步提高Redis的并发能力&#xff0c;就需要搭建主从集群&#xff0c;实现读写分离。 主从结构 这是一个简单的Redis主从集群结构 集群中有一个master节点、两个slave节点&#xff08;现在叫replica&#xff09;…

自然语言处理(NLP)技术的应用面有哪些

自然语言处理&#xff08;NLP&#xff09;技术在各个领域都有广泛的应用&#xff0c;以下是一些常见的例子&#xff1a; 机器翻译&#xff1a;NLP技术用于开发翻译系统&#xff0c;可以将一个语言的文本自动翻译成另一种语言。例如&#xff0c;谷歌翻译就是一个应用了NLP技术的…

element-plus 的简单应用

前言 本篇博客是 基于 ElementPlus 快速入门_element plus x-CSDN博客 的进阶 最终成果 完成的要求 1 深入学习 设计 | Element Plus 从里面找自己合适的 使用到的 组件有&#xff1a;表格&#xff0c;分页条&#xff0c;表单&#xff0c;卡片 2 具备 前端基础&#xff08;ht…

关于Qt的各类问题

目录 1、问题&#xff1a;Qt中文乱码 2、问题&#xff1a;启动时避免ComBox控件出现默认值 博客会不定期的更新各种Qt开发的Bug与解决方法,敬请关注! 1、问题&#xff1a;Qt中文乱码 问题描述&#xff1a;我在设置标题时出现了中文乱码 this->setWindowTitle("算法…

海思烧录工具HITool电视盒子刷机详解

HiTool是华为开发的一款用于海思芯片设备的刷机和调试工具&#xff0c;可对搭载海思芯片的机顶盒、智能电视等设备进行固件烧录、参数配置等操作。以下为你详细介绍&#xff1a; 功能用途 固件烧录&#xff1a;这是HiTool最主要的功能之一。它能够将下载好的适配固件文件烧录到…

Docker Compose介绍

基本概念 Docker-Compose是Docker官方的开源项目&#xff0c;负责实现对docker容器集群的快速编排。 可以这么理解&#xff0c;docker compose是docker提出的一个工具软件&#xff0c;可以管理多个docker容器组成一个应用&#xff0c;只需要编写一个YAML格式的配置文件docker…

大疆上云api直播功能如何实现

概述 流媒体服务器作为直播画面的中转站,它接收推流端的相机画面,同时拉流端找它获取相机的画面。整个流程如下: 在流媒体服务器上创建流媒体应用(app),一个流媒体服务器上面可以创建多个流媒体应用约定推拉流的地址。假设流媒体服务器工作在1935端口上面,假设创建的流…

LabVIEW远程控制通讯接口

abVIEW提供了多种远程控制与通讯接口&#xff0c;适用于不同场景下的设备交互、数据传输和系统集成。这些接口涵盖从基础的网络协议&#xff08;如TCP/IP、UDP&#xff09;到专用技术&#xff08;如DataSocket、远程面板&#xff09;&#xff0c;以及工业标准协议&#xff08;如…

算法每日一练 (18)

&#x1f4a2;欢迎来到张翊尘的技术站 &#x1f4a5;技术如江河&#xff0c;汇聚众志成。代码似星辰&#xff0c;照亮行征程。开源精神长&#xff0c;传承永不忘。携手共前行&#xff0c;未来更辉煌&#x1f4a5; 文章目录 算法每日一练 (18)删除并获得点数题目描述解题思路解题…

Java后端API限流秘籍:高并发的防护伞与实战指南

目录导航 📜 🛡️ 为什么需要API限流?🧠 主流限流算法大解析👩‍💻 阿里巴巴的限流实践📏 四大黄金定律🤼 限流策略组合拳🏆 限流场景实战💻 技术实现方案🌟 最佳实践分享📈 结语与展望📚 推荐阅读 1. 🛡️ 为什么需要API限流? 在高并发环境中,未…

【软件测试】:软件测试实战

1. ⾃动化实施步骤 1.1 编写web测试⽤例 1.2 ⾃动化测试脚本开发 common public class AutotestUtils {public static EdgeDriver driver;// 创建驱动对象public static EdgeDriver createDriver(){// 驱动对象已经创建好了 / 没有创建if( driver null){driver new EdgeDr…

26考研——栈、队列和数组_栈(3)

408答疑 文章目录 一、栈1、栈&#xff08;Stack&#xff09;的概念和特点定义术语操作特性示例直观理解栈的基本操作初始化栈判断栈是否为空入栈操作出栈操作读取栈顶元素销毁栈 栈的数学性质 2、栈的顺序存储结构顺序栈的定义栈顶指针初始化注意事项 共享栈共享栈的操作共享栈…