Pytorch 实现图片分类

CNN 网络适用于图片识别,卷积神经网络主要用于图片的处理识别。卷积神经网络,包括一下几部分,输入层、卷积层、池化层、全链接层和输出层。
在这里插入图片描述
使用 CIFAR-10 进行训练, CIFAR-10 中图片尺寸为 32 * 32。卷积层通过卷积核移动进行计算最终生成特征图。

在这里插入图片描述
通过池化层进行降维度
在这里插入图片描述

卷积网络结构从输入到输出, 3* 32*32 --> 10:

类型WeightBIAS
卷积(3, 12, 5)(12, 3, 5, 5)12
卷积(12, 12, 5)(12, 12, 5, 5)12
Norm1212
卷积(12, 24, 5)(24, 12, 5, 5)24
卷积(24 24, 5)(24, 24, 5, 5)24
Norm2424
Linear(10, 2400)10

训练分类模型

准备数据
from torchvision.datasets import CIFAR10
from torchvision.transforms import transforms
from torch.utils.data import DataLoader# Loading and normalizing the data.
# Define transformations for the training and test sets
transformations = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])# CIFAR10 dataset consists of 50K training images. We define the batch size of 10 to load 5,000 batches of images.
batch_size = 10
number_of_labels = 10 # Create an instance for training. 
# When we run this code for the first time, the CIFAR10 train dataset will be downloaded locally. 
train_set =CIFAR10(root="./data",train=True,transform=transformations,download=True)# Create a loader for the training set which will read the data within batch size and put into memory.
train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers=0)
print("The number of images in a training set is: ", len(train_loader)*batch_size)# Create an instance for testing, note that train is set to False.
# When we run this code for the first time, the CIFAR10 test dataset will be downloaded locally. 
test_set = CIFAR10(root="./data", train=False, transform=transformations, download=True)# Create a loader for the test set which will read the data within batch size and put into memory. 
# Note that each shuffle is set to false for the test loader.
test_loader = DataLoader(test_set, batch_size=batch_size, shuffle=False, num_workers=0)
print("The number of images in a test set is: ", len(test_loader)*batch_size)print("The number of batches per epoch is: ", len(train_loader))
classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
创建网络
import torch
import torch.nn as nn
import torchvision
import torch.nn.functional as F# Define a convolution neural network
class Network(nn.Module):def __init__(self):super(Network, self).__init__()self.conv1 = nn.Conv2d(in_channels=3, out_channels=12, kernel_size=5, stride=1, padding=1)self.bn1 = nn.BatchNorm2d(12)self.conv2 = nn.Conv2d(in_channels=12, out_channels=12, kernel_size=5, stride=1, padding=1)self.bn2 = nn.BatchNorm2d(12)self.pool = nn.MaxPool2d(2,2)self.conv4 = nn.Conv2d(in_channels=12, out_channels=24, kernel_size=5, stride=1, padding=1)self.bn4 = nn.BatchNorm2d(24)self.conv5 = nn.Conv2d(in_channels=24, out_channels=24, kernel_size=5, stride=1, padding=1)self.bn5 = nn.BatchNorm2d(24)self.fc1 = nn.Linear(24*10*10, 10)def forward(self, input):output = F.relu(self.bn1(self.conv1(input)))      output = F.relu(self.bn2(self.conv2(output)))     output = self.pool(output)                        output = F.relu(self.bn4(self.conv4(output)))     output = F.relu(self.bn5(self.conv5(output)))     output = output.view(-1, 24*10*10)output = self.fc1(output)return output# Instantiate a neural network model 
model = Network()

定义损失函数

使用交叉熵函数作为损失函数,交叉熵分为两种

  • 二分类交叉熵函数
    在这里插入图片描述
  • 多分类交叉熵函数
    在这里插入图片描述
loss_fn = nn.CrossEntropyLoss()
optimizer = Adam(model.parameters(), lr=0.001, weight_decay=0.0001)
模型训练
from torch.autograd import Variable# Function to save the model
def saveModel():path = "./myFirstModel.pth"torch.save(model.state_dict(), path)# Function to test the model with the test dataset and print the accuracy for the test images
def testAccuracy():model.eval()accuracy = 0.0total = 0.0device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")with torch.no_grad():for data in test_loader:images, labels = data# run the model on the test set to predict labelsoutputs = model(images.to(device))# the label with the highest energy will be our prediction_, predicted = torch.max(outputs.data, 1)total += labels.size(0)accuracy += (predicted == labels.to(device)).sum().item()# compute the accuracy over all test imagesaccuracy = (100 * accuracy / total)return(accuracy)# Training function. We simply have to loop over our data iterator and feed the inputs to the network and optimize.
def train(num_epochs):best_accuracy = 0.0# Define your execution devicedevice = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")print("The model will be running on", device, "device")# Convert model parameters and buffers to CPU or Cudamodel.to(device)for epoch in range(num_epochs):  # loop over the dataset multiple timesrunning_loss = 0.0running_acc = 0.0for i, (images, labels) in enumerate(train_loader, 0):# get the inputsimages = Variable(images.to(device))labels = Variable(labels.to(device))# zero the parameter gradientsoptimizer.zero_grad()# predict classes using images from the training setoutputs = model(images)# compute the loss based on model output and real labelsloss = loss_fn(outputs, labels)# backpropagate the lossloss.backward()# adjust parameters based on the calculated gradientsoptimizer.step()# Let's print statistics for every 1,000 imagesrunning_loss += loss.item()     # extract the loss valueif i % 1000 == 999:    # print every 1000 (twice per epoch) print('[%d, %5d] loss: %.3f' %(epoch + 1, i + 1, running_loss / 1000))# zero the lossrunning_loss = 0.0# Compute and print the average accuracy fo this epoch when tested over all 10000 test imagesaccuracy = testAccuracy()print('For epoch', epoch+1,'the test accuracy over the whole test set is %d %%' % (accuracy))# we want to save the model if the accuracy is the bestif accuracy > best_accuracy:saveModel()best_accuracy = accuracy
测试模型
import matplotlib.pyplot as plt
import numpy as np# Function to show the images
def imageshow(img):img = img / 2 + 0.5     # unnormalizenpimg = img.numpy()plt.imshow(np.transpose(npimg, (1, 2, 0)))plt.show()# Function to test the model with a batch of images and show the labels predictions
def testBatch():# get batch of images from the test DataLoader  images, labels = next(iter(test_loader))# show all images as one image gridimageshow(torchvision.utils.make_grid(images))# Show the real labels on the screen print('Real labels: ', ' '.join('%5s' % classes[labels[j]] for j in range(batch_size)))# Let's see what if the model identifiers the  labels of those exampleoutputs = model(images)# We got the probability for every 10 labels. The highest (max) probability should be correct label_, predicted = torch.max(outputs, 1)# Let's show the predicted labels on the screen to compare with the real onesprint('Predicted: ', ' '.join('%5s' % classes[predicted[j]] for j in range(batch_size)))
执行模型
if __name__ == "__main__":# Let's build our modeltrain(5)print('Finished Training')# Test which classes performed welltestAccuracy()# Let's load the model we just created and test the accuracy per labelmodel = Network()path = "myFirstModel.pth"model.load_state_dict(torch.load(path))# Test with batch of imagestestBatch()

在这里插入图片描述

总结

pytorch 搭建一个 CNN 模型比较简单,5 轮训练之后,效果就可以达到 60%,10 张图片中预测对了 6 张。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/454597.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

告别微信封号!学会这5招,让你的账号坚不可摧

在这个信息爆炸的时代,无论是工作沟通、社交互动还是获取信息,微信都扮演着极其重要的角色。但是,随着微信平台规则的日益严格,账号被封的风险也随之增加。今天,我们就来聊聊如何有效防止 微信被封,让你的账…

【RL Latest Tech】安全强化学习(Safe RL):理论、方法与应用

📢本篇文章是博主强化学习(RL)领域学习时,用于个人学习、研究或者欣赏使用,并基于博主对相关等领域的一些理解而记录的学习摘录和笔记,若有不当和侵权之处,指出后将会立即改正,还望谅…

Python | Leetcode Python题解之第486题预测赢家

题目: 题解: class Solution:def PredictTheWinner(self, nums: List[int]) -> bool:length len(nums)dp [0] * lengthfor i, num in enumerate(nums):dp[i] numfor i in range(length - 2, -1, -1):for j in range(i 1, length):dp[j] max(num…

uniapp结合uview-ui创建项目

准备工作: 下载HBuilderX;官网地址:HBuilderX-高效极客技巧 安装node.js;官网地址:Node.js — 在任何地方运行 JavaScript,下载所需版本,安装后配置好环境变量即可 方式一(NPM安装方式): 1、打开开发者…

代码随想录-哈希表-两个数组的交集

题目与思路 这个题目可以用大小为1000的数组,因为限制了变量大小。我用的是set,感觉这个才是本意。 下面代码是用set实现的hash结构。先用set1存储nums1中的去重元素,然后用set2存储结果集 最后将set转为array用了两种方式,一是…

LINUX设备OTA时无法从HTTP服务器(TOMCAT)下载文件

疑难问题排查记录 问题 linux设备作为http客户端,执行OTA前先从HTTP服务器下载bin固件,测试nginx没有问题,nodejs编写的HTTP服务器也没有问题,而软件同事使用的Tomcat则无法成功下载。 排查经过 首先利用chrome浏览器测试下载…

WebGl 实现图片平移、缩放和旋转

1.图片平移 在WebGL中实现图片平移,可以通过修改顶点着色器中的顶点位置来实现。平移的基本思想是将每个顶点的位置向量沿着指定的方向(通常是x轴和y轴)进行平移。在顶点着色器中,可以通过添加或减去一个统一的偏移量&#xff08…

「AIGC」n8n AI Agent开源的工作流自动化工具

n8n AI Agent 是一个利用大型语言模型(LLMs)来设计和构建智能体(agents)的工具,这些智能体能够执行一系列复杂的任务,如理解指令、模仿类人推理,以及从用户命令中理解隐含意图。n8n AI Agent 的核心在于构建一系列提示(prompts),使 LLM 能够模拟自主行为。 传送门→ …

C++:stack 和 queue 的使用和模拟实现

使用 要注意&#xff0c;stack 和 queue 没有迭代器。 stack void teststack() {stack<int> st;st.push(1);st.push(2);st.push(3);st.push(4);st.push(5);cout << st.size() << endl;while (!st.empty()){cout << st.top() << " ";…

电机编码器

旋转式编码器 链接: 野火编码器 单圈&#xff0c;多圈绝对式编码器 在单圈绝对值编码器中&#xff0c;单圈并不表示编码器机械转动的圈数&#xff0c;而是指断电记忆的范围。实际上&#xff0c;机械转动是可以无限制地进行圈数的&#xff0c;但是断电记忆仅限于一圈的范围内…

构建后端为etcd的CoreDNS的容器集群(二)、下载最新的etcd容器镜像

在尝试获取etcd的容器的最新版本镜像时&#xff0c;使用latest作为tag取到的并非最新版本&#xff0c;本文尝试用实际最新版本的版本号进行pull&#xff0c;从而取到想的最新版etcd容器镜像。 一、用latest作为tag尝试下载最新etcd的镜像 1、下载镜像 [rootlocalhost opt]# …

bash之基本运算符

一.算术运算符 vim test.sh #!/bin/basha10 b20valexpr $a $b echo "a b : $val"valexpr $a - $b echo "a - b : $val"valexpr $a \* $b echo "a * b : $val"valexpr $b / $a echo "b / a : $val"valexpr $b % $a echo "b % a …

《重置MobaXterm密码并连接Linux虚拟机的完整操作指南》

目录 引言 一、双击MobaXterm_Personal_24.2进入&#xff0c;但是忘记密码。 那么接下来请跟着我操作。 二、点击此链接&#xff0c;重设密码。 三、下载完成后&#xff0c;现在把这个exe文件解压。注意解压要与MobaXterm_Personal_24.2.exe在同一目录下哦&#xff0c;不然…

LMS自适应滤波器原理与应用Matlab程序

1. 引言 自适应滤波器是一种根据输入数据动态调整其滤波系数的滤波器&#xff0c;能够在未知或变化的环境中有效工作。最小均方&#xff08;LMS&#xff0c;Least Mean Squares&#xff09;自适应滤波器是自适应滤波器家族中最经典的一种&#xff0c;其以计算简便、收敛速度快…

Docker基础部署

一、安装Ubuntu系统 1.1 新建虚拟机 打开VMware Workstation&#xff0c;选择文件->新建虚拟机->典型&#xff08;推荐T&#xff09;->安装程序光盘映像文件->输入虚拟的名字->一直下一步即可 安装程序光盘映像文件 注意&#xff1a;选择CentOS-7-x86_64-DVD-…

基于知识图谱的美食推荐系统

想象一下&#xff0c;每次打开应用&#xff0c;它都能为你量身推荐最符合你口味的美食&#xff0c;不需要再为“今天吃什么&#xff1f;”烦恼。这听起来是不是非常吸引人&#xff1f;今天就给大家介绍一个适合做毕业设计的创新项目——基于知识图谱的美食推荐系统&#xff01;…

STM32_实验5_中断实验

通过外部中断来检测四个按键按下的状态&#xff1a; WK_UP 控制蜂鸣器响和停 KEY0 控制 LED_R 互斥点亮 KEY1 控制 LED_G 互斥点亮 KEY2 控制 LED_B 互斥点亮。 中断的基本概念&#xff1a; 中断请求&#xff08;IRQ&#xff09;&#xff1a; 当发生某个特定事件&#xff08;例…

告别ELK,APO提供基于ClickHouse开箱即用的高效日志方案——APO 0.6.0发布

ELK一直是日志领域的主流产品&#xff0c;但是ElasticSearch的成本很高&#xff0c;查询效果随着数据量的增加越来越慢。业界已经有很多公司&#xff0c;比如滴滴、B站、Uber、Cloudflare都已经使用ClickHose作为ElasticSearch的替代品&#xff0c;都取得了不错的效果&#xff…

AOP 面向切面编程

1.准备工作&#xff0c;创建maven项目 1. pom.xml 加入依赖 <dependencies><!--spring核心坐标--><dependency><groupId>org.springframework</groupId><artifactId>spring-context</artifactId><version>6.0.6</version&…

算术移位的学习

术移位&#xff08;Arithmetic Shift&#xff09;是一种位移操作&#xff0c;主要用于有符号整数。它与逻辑移位相似&#xff0c;但在处理负数时有一些显著的不同。算术移位能够保持符号位的完整性&#xff0c;因此在有符号数的移位运算中非常有用。 算术移位的类型 算术左移&…