PyTorch 系列教程:使用CNN实现图像分类

图像分类是计算机视觉领域的一项基本任务,也是深度学习技术的一个常见应用。近年来,卷积神经网络(cnn)和PyTorch库的结合由于其易用性和鲁棒性已经成为执行图像分类的流行选择。

理解卷积神经网络(cnn)

卷积神经网络是一类深度神经网络,对分析视觉图像特别有效。他们利用多层构建一个可以直接从图像中识别模式的模型。这些模型对于图像识别和分类等任务特别有用,因为它们不需要手动提取特征。

cnn的关键组成部分

  • 卷积层:这些层对输入应用卷积操作,将结果传递给下一层。每个过滤器(或核)可以捕获不同的特征,如边缘、角或其他模式。
  • 池化层:这些层减少了表示的空间大小,以减少参数的数量并加快计算速度。池化层简化了后续层的处理。
  • 完全连接层:在这些层中,神经元与前一层的所有激活具有完全连接,就像传统的神经网络一样。它们有助于对前一层识别的对象进行分类。
    在这里插入图片描述

使用PyTorch进行图像分类

PyTorch是开源的深度学习库,提供了极大的灵活性和多功能性。研究人员和从业人员广泛使用它来轻松有效地实现尖端的机器学习模型。

设置PyTorch

首先,确保在开发环境中安装了PyTorch。你可以通过pip安装它:

pip install torch torchvision

用PyTorch创建简单的CNN示例

下面是如何定义简单的CNN来使用PyTorch对图像进行分类的示例。

import torch
import torch.nn as nn
import torch.nn.functional as F# 定义CNN模型(修复了变量引用问题)
class SimpleCNN(nn.Module):def __init__(self):super(SimpleCNN, self).__init__()self.conv1 = nn.Conv2d(3, 6, 5)      # 第一个卷积层:3输入通道,6输出通道,5x5卷积核self.pool = nn.MaxPool2d(2, 2)        # 最大池化层:2x2窗口,步长2self.conv2 = nn.Conv2d(6, 16, 5)     # 第二个卷积层:6输入通道,16输出通道,5x5卷积核self.fc1 = nn.Linear(16 * 5 * 5, 120)# 全连接层1:400输入 -> 120输出self.fc2 = nn.Linear(120, 84)      # 全连接层2:120输入 -> 84输出self.fc3 = nn.Linear(84, 10)       # 输出层:84输入 -> 10类 logitsdef forward(self, x):# 输入形状:[batch_size, 3, 32, 32]x = self.pool(F.relu(self.conv1(x)))  # -> [batch, 6, 14, 14](池化后尺寸减半)x = self.pool(F.relu(self.conv2(x)))  # -> [batch, 16, 5, 5] x = x.view(-1, 16 * 5 * 5)            # 展平为一维向量:16 * 5 * 5=400x = F.relu(self.fc1(x))             # -> [batch, 120]x = F.relu(self.fc2(x))             # -> [batch, 84]x = self.fc3(x)                     # -> [batch, 10](未应用softmax,配合CrossEntropyLoss使用)return x

这个特殊的网络接受一个输入图像,通过两组卷积和池化层,然后是三个完全连接的层。根据数据集的复杂性和大小调整网络的架构和超参数。

模型定义

  • SimpleCNN 继承自 nn.Module
  • 使用两个卷积层提取特征,三个全连接层进行分类
  • 最终输出未应用 softmax,而是直接输出 logits(与 CrossEntropyLoss 配合使用)

训练网络

对于训练,你需要一个数据集。PyTorch通过torchvision包提供了用于数据加载和预处理的实用程序。

import torchvision.transforms as transforms
import torchvision
from torch.utils.data import DataLoader# 初始化模型、损失函数和优化器
net = SimpleCNN()               # 实例化模型
criterion = nn.CrossEntropyLoss()  # 使用交叉熵损失函数(自动处理softmax)
optimizer = torch.optim.SGD(net.parameters(), lr=0.001,      # 学习率momentum=0.9)   # 动量参数# 数据预处理和加载
transform = transforms.Compose([transforms.ToTensor(),          transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))  # 加载CIFAR-10训练集
trainset = torchvision.datasets.CIFAR10(root='./data', train=True,download=True,  # 自动下载数据集transform=transform
)trainloader = DataLoader(trainset, batch_size=4,   # 每个batch包含4张图像shuffle=True)  # 打乱数据顺序

模型配置

  • 损失函数CrossEntropyLoss(自动包含 softmax 和 log_softmax)
  • 优化器:SGD with momentum,学习率 0.001

数据加载

  • 使用 torchvision.datasets.CIFAR10 加载数据集

  • batch_size:4(根据 GPU 内存调整,CIFAR-10 建议 batch size ≥ 32)

  • transforms.Compose 定义数据预处理流程:

    • ToTensor():将图像转换为 PyTorch Tensor
    • Normalize():标准化图像像素值到 [-1, 1]

加载数据后,训练过程包括通过数据集进行多次迭代,使用反向传播和合适的损失函数:

# 训练循环
for epoch in range(2):  # 进行2个epoch的训练running_loss = 0.0for i, data in enumerate(trainloader, 0):inputs, labels = data# 前向传播outputs = net(inputs)loss = criterion(outputs, labels)# 反向传播和优化optimizer.zero_grad()   # 清空梯度loss.backward()         # 计算梯度optimizer.step()       # 更新参数running_loss += loss.item()# 每2000个batch打印一次if i % 2000 == 1999:avg_loss = running_loss / 2000print(f'Epoch [{epoch+1}/{2}], Batch [{i+1}/2000], Loss: {avg_loss:.3f}')running_loss = 0.0print("训练完成!")

训练循环

  • epoch:完整遍历数据集一次
  • batch:数据加载器中的一个批次
  • 梯度清零:每次反向传播前需要清空梯度
  • 损失计算outputs 的形状为 [batch_size, 10]labels 为整数标签

完整代码

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
import torchvision
from torch.utils.data import DataLoader# 定义CNN模型(修复了变量引用问题)
class SimpleCNN(nn.Module):def __init__(self):super(SimpleCNN, self).__init__()self.conv1 = nn.Conv2d(3, 6, 5)      # 第一个卷积层:3输入通道,6输出通道,5x5卷积核self.pool = nn.MaxPool2d(2, 2)        # 最大池化层:2x2窗口,步长2self.conv2 = nn.Conv2d(6, 16, 5)     # 第二个卷积层:6输入通道,16输出通道,5x5卷积核self.fc1 = nn.Linear(16 * 5 * 5, 120)# 全连接层1:400输入 -> 120输出self.fc2 = nn.Linear(120, 84)      # 全连接层2:120输入 -> 84输出self.fc3 = nn.Linear(84, 10)       # 输出层:84输入 -> 10类 logitsdef forward(self, x):# 输入形状:[batch_size, 3, 32, 32]x = self.pool(F.relu(self.conv1(x)))  # -> [batch, 6, 14, 14](池化后尺寸减半)x = self.pool(F.relu(self.conv2(x)))  # -> [batch, 16, 5, 5] x = x.view(-1, 16 * 5 * 5)            # 展平为一维向量:16 * 5 * 5=400x = F.relu(self.fc1(x))             # -> [batch, 120]x = F.relu(self.fc2(x))             # -> [batch, 84]x = self.fc3(x)                     # -> [batch, 10](未应用softmax,配合CrossEntropyLoss使用)return x# 初始化模型、损失函数和优化器
net = SimpleCNN()               # 实例化模型
criterion = nn.CrossEntropyLoss()  # 使用交叉熵损失函数(自动处理softmax)
optimizer = torch.optim.SGD(net.parameters(), lr=0.001,      # 学习率momentum=0.9)   # 动量参数# 数据预处理和加载
transform = transforms.Compose([transforms.ToTensor(),            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))  
])# 加载CIFAR-10训练集
trainset = torchvision.datasets.CIFAR10(root='./data', train=True,download=True,  # 自动下载数据集transform=transform
)
trainloader = DataLoader(trainset, batch_size=4,   # 每个batch包含4张图像shuffle=True)  # 打乱数据顺序# 训练循环
for epoch in range(2):  # 进行2个epoch的训练running_loss = 0.0for i, data in enumerate(trainloader, 0):inputs, labels = data# 前向传播outputs = net(inputs)loss = criterion(outputs, labels)# 反向传播和优化optimizer.zero_grad()   # 清空梯度loss.backward()         # 计算梯度optimizer.step()       # 更新参数running_loss += loss.item()# 每2000个batch打印一次if i % 2000 == 1999:avg_loss = running_loss / 2000print(f'Epoch [{epoch+1}/{2}], Batch [{i+1}/2000], Loss: {avg_loss:.3f}')running_loss = 0.0print("训练完成!")

最后总结

通过PyTorch和卷积神经网络,你可以有效地处理图像分类任务。借助PyTorch的灵活性,可以根据特定的数据集和应用程序构建、训练和微调模型。示例代码仅为理论过程,实际项目中还有大量优化空间。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/34448.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【2025】基于python+django的驾校招生培训管理系统(源码、万字文档、图文修改、调试答疑)

课题功能结构图如下: 驾校招生培训管理系统设计 一、课题背景 随着机动车保有量的不断增加,人们对驾驶技能的需求也日益增长。驾校作为驾驶培训的主要机构,面临着激烈的市场竞争和学员需求多样化等挑战。传统的驾校管理模式往往依赖于人工操作…

【JavaWeb】快速入门——HTMLCSS

文章目录 一、 HTML简介1、HTML概念2、HTML文件结构3、可视化网页结构 二、 HTML标签语法1、标题标签2、段落标签3、超链接4、换行5、无序列表6、路径7、图片8、块1 盒子模型2 布局标签 三、 使用HTML表格展示数据1、定义表格2、合并单元格横向合并纵向合并 四、 使用HTML表单收…

MySQL 优化方案

一、MySQL 查询过程 MySQL 查询过程是指从客户端发送 SQL 语句到 MySQL 服务器,再到服务器返回结果集的整个过程。这个过程涉及多个组件的协作,包括连接管理、查询解析、优化、执行和结果返回等。 1.1 查询过程的关键组件 连接管理器:管理…

服务性能防腐体系:基于自动化压测的熔断机制

01# 背景 在系统架构的演进过程中,项目初始阶段都会通过压力测试构建安全护城河,此时的服务性能与资源水位保持着黄金比例关系。然而在业务高速发展时期,每个冲刺周期都被切割成以业务需求为单位的开发单元,压力测试逐渐从必选项…

六十天前端强化训练之第二十天React Router 基础详解

欢迎来到编程星辰海的博客讲解 看完可以给一个免费的三连吗,谢谢大佬! 目录 一、核心概念 1.1 核心组件 1.2 路由模式对比 二、核心代码示例 2.1 基础路由配置 2.2 动态路由示例 2.3 嵌套路由实现 2.4 完整示例代码 三、关键功能实现效果 四、…

grad_traj_optimization 开源项目

开源项目 grad_traj_optimization 使用教程-CSDN博客 ubuntu如何切换到root用户_ubuntu切换到root用户-CSDN博客 catkin_make: command not found 解决办法_catkin-make not found-CSDN博客 这就说明需要编译的package虽然存在,但不在指定的目录下。catkin_make命…

深圳南柯电子|净水器EMC测试整改:水质安全与电磁兼容性的双赢

在当今注重健康生活的时代,净水器作为家庭用水安全的第一道防线,其性能与安全性备受关注。其中,电磁兼容性(EMC)测试是净水器产品上市前不可或缺的一环,它直接关系到产品在复杂电磁环境中的稳定运行及不对其…

要登录的设备ip未知时的处理方法

目录 1 应用场景... 1 2 解决方法:... 1 2.1 wireshark设置... 1 2.2 获取网口mac地址,wireshark抓包前预过滤掉自身mac地址的影响。... 2 2.3 pc网口和设备对接... 3 2.3.1 情况1:... 3 2.3.2 情…

GHCTF web方向题解

upload?SSTI! import os import refrom flask import Flask, request, jsonify,render_template_string,send_from_directory, abort,redirect from werkzeug.utils import secure_filename import os from werkzeug.utils import secure_filenameapp Flask(__name__)# 配置…

Vision Transformer (ViT):将Transformer带入计算机视觉的革命性尝试(代码实现)

Vision Transformer (ViT):将Transformer带入计算机视觉的革命性尝试 作为一名深度学习研究者,如果你对自然语言处理(NLP)领域的Transformer架构了如指掌,那么你一定不会对它在序列建模中的强大能力感到陌生。然而&am…

蓝耘携手通义万象 2.1 图生视频:开启创意无限的共享新时代

在科技飞速发展的今天,各种新奇的技术不断涌现,改变着我们的生活和工作方式。蓝耘和通义万象 2.1 图生视频就是其中两项非常厉害的技术。蓝耘就像是一个超级大管家,能把各种资源管理得井井有条;而通义万象 2.1 图生视频则像是一个…

IEC61850标准下MMS 缓存报告控制块 ResvTms详细解析

IEC61850标准是电力系统自动化领域唯一的全球通用标准。IEC61850通过标准的实现,使得智能变电站的工程实施变得规范、统一和透明,这大大提高了变电站自动化系统的技术水平和安全稳定运行水平。 在 IEC61850 标准体系中,ResvTms(r…

【DeepSeek应用】DeepSeek模型本地化部署方案及Python实现

DeepSeek实在是太火了,虽然经过扩容和调整,但反应依旧不稳定,甚至小圆圈转半天最后却提示“服务器繁忙,请稍后再试。” 故此,本文通过讲解在本地部署 DeepSeek并配合python代码实现,让你零成本搭建自己的AI助理,无惧任务提交失败的压力。 一、环境准备 1. 安装依赖库 …

蓝思科技冲刺港股上市,双重上市的意欲何为?

首先,蓝思科技冲刺港股上市,这一举措是其国际化战略进入实质性阶段的重要标志。通过港股上市,蓝思科技有望进一步拓宽融资渠道,这不仅能够为公司带来更加多元化的资金来源,还能够降低对单一市场的依赖风险,…

深入探讨RAID 5的性能与容错能力:实验与分析(磁盘阵列)

前言—— 本实验旨在探讨 RAID 5 的性能和容错能力。通过创建 RAID 5 阵列并进行一系列读写性能测试及故障模拟,我们将观察 RAID 5 在数据冗余和故障恢复方面的表现,以验证其在实际应用中的可靠性和效率。 首先说明:最少三块硬盘, 使用 4 块…

excel中两个表格的合并

使用函数: VLOOKUP函数 如果涉及在excel中两个工作表之间进行配对合并,则: VLOOKUP(C1,工作表名字!A:B,2,0) 参考: excel表格中vlookup函数的使用方法步骤https://haokan.baidu.com/v?pdwisenatural&vid132733503560775…

基于ssm的宠物医院信息管理系统(全套)

一、系统架构 前端:html | layui | vue | element-ui 后端:spring | springmvc | mybatis 环境:jdk1.8 | mysql | maven | tomcat | idea | nodejs 二、代码及数据库 三、功能介绍 01. web端-首页1 02. web端-首页…

UE小:UE5.5 PixelStreamingInfrastructure 使用时注意事项

1、鼠标默认显示 player.ts中的Config中添加HoveringMouse:true 然后运行typescript\package.json中的"build":npx webpack --config webpack.prod.js

iOS底层原理系列01-iOS系统架构概览-从硬件到应用层

1. 系统层级结构 iOS系统架构采用分层设计模式,自底向上可分为五个主要层级,每层都有其特定的功能职责和技术组件。这种层级化结构不仅使系统更加模块化,同时也提供了清晰的技术抽象和隔离机制。 1.1 Darwin层:XNU内核、BSD、驱动…

Ubuntu从源代码编译安装QT

1. 下载源码 wget https://download.qt.io/official_releases/qt/5.15/5.15.2/single/qt-everywhere-src-5.15.2.tar.xz tar xf qt-everywhere-src-5.15.2.tar.xz cd qt-everywhere-src-5.15.22. 安装依赖库 sudo apt update sudo apt install build-essential libgl1-mesa-d…