[oneAPI] 图像分类CIFAR-10

[oneAPI] 图像分类CIFAR-10

  • 图像分类
    • 参数与包
    • 加载数据
    • 模型
    • 训练过程
    • 结果
  • oneAPI

比赛:https://marketing.csdn.net/p/f3e44fbfe46c465f4d9d6c23e38e0517
Intel® DevCloud for oneAPI:https://devcloud.intel.com/oneapi/get_started/aiAnalyticsToolkitSamples/

图像分类

使用了pytorch以及Intel® Optimization for PyTorch,通过优化扩展了 PyTorch,使英特尔硬件的性能进一步提升,让手写数字识别问题更加的快速高效
在这里插入图片描述

使用CIFAR-10数据集, 数据集是一个常用的计算机视觉数据集,包含了 60000 张 32x32 像素的彩色图像,涵盖了 10 个不同的类别,每个类别有 6000 张图像。这个数据集被广泛用于图像分类、物体识别等任务的训练和评估。

数据集被分成了训练集和测试集,其中训练集包含 50000 张图像,测试集包含 10000 张图像。
CIFAR-10 数据集包含以下 10 个类别:

飞机(airplane)
汽车(automobile)
鸟类(bird)
猫(cat)
鹿(deer)
狗(dog)
青蛙(frog)
马(horse)
船(ship)
卡车(truck)

参数与包

import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transformsimport intel_extension_for_pytorch as ipex# Device configuration
device = torch.device('xpu' if torch.cuda.is_available() else 'cpu')# Hyper-parameters
num_epochs = 80
batch_size = 100
learning_rate = 0.001

加载数据

# Image preprocessing modules
transform = transforms.Compose([transforms.Pad(4),transforms.RandomHorizontalFlip(),transforms.RandomCrop(32),transforms.ToTensor()])# CIFAR-10 dataset
train_dataset = torchvision.datasets.CIFAR10(root='./data/',train=True,transform=transform,download=True)test_dataset = torchvision.datasets.CIFAR10(root='./data/',train=False,transform=transforms.ToTensor())# Data loader
train_loader = torch.utils.data.DataLoader(dataset=train_dataset,batch_size=batch_size,shuffle=True)test_loader = torch.utils.data.DataLoader(dataset=test_dataset,batch_size=batch_size,shuffle=False)

模型

# 3x3 convolution
def conv3x3(in_channels, out_channels, stride=1):return nn.Conv2d(in_channels, out_channels, kernel_size=3,stride=stride, padding=1, bias=False)# Residual block
class ResidualBlock(nn.Module):def __init__(self, in_channels, out_channels, stride=1, downsample=None):super(ResidualBlock, self).__init__()self.conv1 = conv3x3(in_channels, out_channels, stride)self.bn1 = nn.BatchNorm2d(out_channels)self.relu = nn.ReLU(inplace=True)self.conv2 = conv3x3(out_channels, out_channels)self.bn2 = nn.BatchNorm2d(out_channels)self.downsample = downsampledef forward(self, x):residual = xout = self.conv1(x)out = self.bn1(out)out = self.relu(out)out = self.conv2(out)out = self.bn2(out)if self.downsample:residual = self.downsample(x)out += residualout = self.relu(out)return out# ResNet
class ResNet(nn.Module):def __init__(self, block, layers, num_classes=10):super(ResNet, self).__init__()self.in_channels = 16self.conv = conv3x3(3, 16)self.bn = nn.BatchNorm2d(16)self.relu = nn.ReLU(inplace=True)self.layer1 = self.make_layer(block, 16, layers[0])self.layer2 = self.make_layer(block, 32, layers[1], 2)self.layer3 = self.make_layer(block, 64, layers[2], 2)self.avg_pool = nn.AvgPool2d(8)self.fc = nn.Linear(64, num_classes)def make_layer(self, block, out_channels, blocks, stride=1):downsample = Noneif (stride != 1) or (self.in_channels != out_channels):downsample = nn.Sequential(conv3x3(self.in_channels, out_channels, stride=stride),nn.BatchNorm2d(out_channels))layers = []layers.append(block(self.in_channels, out_channels, stride, downsample))self.in_channels = out_channelsfor i in range(1, blocks):layers.append(block(out_channels, out_channels))return nn.Sequential(*layers)def forward(self, x):out = self.conv(x)out = self.bn(out)out = self.relu(out)out = self.layer1(out)out = self.layer2(out)out = self.layer3(out)out = self.avg_pool(out)out = out.view(out.size(0), -1)out = self.fc(out)return out

训练过程

model = ResNet(ResidualBlock, [2, 2, 2]).to(device)# Loss and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)'''
Apply Intel Extension for PyTorch optimization against the model object and optimizer object.
'''
model, optimizer = ipex.optimize(model, optimizer=optimizer)# For updating learning rate
def update_lr(optimizer, lr):for param_group in optimizer.param_groups:param_group['lr'] = lr# Train the model
total_step = len(train_loader)
curr_lr = learning_rate
for epoch in range(num_epochs):for i, (images, labels) in enumerate(train_loader):images = images.to(device)labels = labels.to(device)# Forward passoutputs = model(images)loss = criterion(outputs, labels)# Backward and optimizeoptimizer.zero_grad()loss.backward()optimizer.step()if (i + 1) % 100 == 0:print("Epoch [{}/{}], Step [{}/{}] Loss: {:.4f}".format(epoch + 1, num_epochs, i + 1, total_step, loss.item()))# Decay learning rateif (epoch + 1) % 20 == 0:curr_lr /= 3update_lr(optimizer, curr_lr)# Test the model
model.eval()
with torch.no_grad():correct = 0total = 0for images, labels in test_loader:images = images.to(device)labels = labels.to(device)outputs = model(images)_, predicted = torch.max(outputs.data, 1)total += labels.size(0)correct += (predicted == labels).sum().item()print('Accuracy of the model on the test images: {} %'.format(100 * correct / total))# Save the model checkpoint
torch.save(model.state_dict(), 'resnet.ckpt')

结果

在这里插入图片描述

在这里插入图片描述

oneAPI

import intel_extension_for_pytorch as ipex# Device configuration
device = torch.device('xpu' if torch.cuda.is_available() else 'cpu')# 模型
model = ConvNet(num_classes).to(device)# Loss and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)'''
Apply Intel Extension for PyTorch optimization against the model object and optimizer object.
'''
model, optimizer = ipex.optimize(model, optimizer=optimizer)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/92789.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

解决内网GitLab 社区版 15.11.13项目拉取失败

问题描述 GitLab 社区版 发布不久,搭建在内网拉取项目报错,可能提示 unable to access https://github.comxxxxxxxxxxx: Failed to connect to xxxxxxxxxxxxxGit clone error - Invalid argument error:14077438:SSL routines:SSL23_GET_S 15.11.13ht…

【学会动态规划】乘积为正数的最长子数组长度(21)

目录 动态规划怎么学? 1. 题目解析 2. 算法原理 1. 状态表示 2. 状态转移方程 3. 初始化 4. 填表顺序 5. 返回值 3. 代码编写 写在最后: 动态规划怎么学? 学习一个算法没有捷径,更何况是学习动态规划, 跟我…

《系统架构设计师教程》重点章节思维导图

内容来自《系统架构设计师教程》,筛选系统架构设计师考试中分值重点分布的章节,根据章节的内容整理出相关思维导图。 重点章节 第2章:计算机系统知识第5章:软件工程基础知识第7章:系统架构设计基础知识第8章&#xff1…

常见程序搜索关键字转码

个别搜索类的网站因为用户恶意搜索出现误拦截情况,这类网站本身没有非法信息,只是因为把搜索关键字显示在网页中(如下图),可以参考下面方法对输出的关键字进行转码 DEDECMS程序 本文针对Dedecms程序进行搜索转码&…

XXL-JOB任务调度平台的安装使用教程

首先从GitHub上面将项目clone下来。 GitHub地址:https://gitee.com/xuxueli0323/xxl-job.git 下载好之后,然后通过IDEA打开,将Maven编译好后项目结构如下 在数据库中运行这个SQL文件 ,将基础表创建出来。就可以得到左边图中那些表…

复合 类型

字符串和切片 切片 切片的作用是允许你引用集合中部分连续的元素序列,而不是引用整个集合。 例如: let s String::from("hello world");let hello &s[0..5]; // 切片 [0,5) 等效于&s[..5] let world &s[6..11]; // 切片…

电脑缺失msvcp140.dll怎么办?解决msvcp140.dll缺失问题

msvcp140.dll是Microsoft Visual C Redistributable中的一个动态链接库文件,用于提供运行时支持。它的主要作用是为应用程序提供必要的函数和组件,以便在运行时执行特定的任务。当系统中缺少msvcp140.dll文件时,我们打开游戏或许软件时候就会…

Wordcloud | 风中有朵雨做的‘词云‘哦!~

1写在前面 今天可算把key搞好了,不得不说🏥里手握生杀大权的人,都在自己的能力范围内尽可能的难为你。😂 我等小大夫也是很无奈,毕竟奔波霸、霸波奔是要去抓唐僧的。 🤐 好吧,今天是词云&#x…

使用selenium如何实现自动登录

回顾使用requests如何实现自动登录一文中,提到好多网站在我们登录过后,在之后的某段时间内访问该网页时,不会给出请登录的提示,时间到期后就会提示请登录!这样在使用爬虫访问网页时还要登录,打乱我们的节奏…

【BASH】回顾与知识点梳理(二十七)

【BASH】回顾与知识点梳理 二十七 二十七. 磁盘配额(Quota)27.1 磁盘配额 (Quota) 的应用与实作什么是 QuotaQuota 的一般用途Quota 的使用限制Quota 的规范设定项目 27.2 一个 XFS 文件系统的 Quota 实作范例实作 Quota 流程:设定账号实作 Quota 流程-1&#xff1a…

Python random模块用法整理

随机数在计算机科学领域扮演着重要的角色,用于模拟真实世界的随机性、数据生成、密码学等多个领域。Python 中的 random 模块提供了丰富的随机数生成功能,本文整理了 random 模块的使用。 文章目录 Python random 模块注意事项Python random 模块的内置…

用AI攻克“智能文字识别创新赛题”,这场大学生竞赛掀起了什么风潮?

文章目录 一、前言1.1 大赛介绍1.2 项目背景 二、基于智能文字场景个人财务管理创新应用2.1 作品方向2.2 票据识别模型2.2.1 文本卷积神经网络TextCNN2.2.2 Bert 预训练微调2.2.3 模型对比2.2.4 效果展示 2.3 票据文字识别接口 三、未来展望 一、前言 1.1 大赛介绍 中国大学生…

时序预测-Informer简介

文章目录 Informer介绍1. Transformer存在的问题2. Informer研究背景3. Informer 整体架构3.1 ProbSparse Self-attention3.2 Self-attention Distilling3.3 Generative Style Decoder 4. Informer的实验性能5. 相关资料 Informer介绍 1. Transformer存在的问题 Informer实质…

网络套接字

网络套接字 文章目录 网络套接字认识端口号初识TCP协议初识UDP协议网络字节序 socket编程接口socket创建socket文件描述符bind绑定端口号sockaddr结构体netstat -nuap:查看服务器网络信息 代码编译运行展示 实现简单UDP服务器开发 认识端口号 端口号(port)是传输层协…

【Linux】ICMP协议——网络层

ICMP协议 ICMP(Internet Control Message Protoco)Internet控制报文协议,用于在IP主机、路由器之间传递控制信息,是一个TCP/IP协议。该协议是用来检测网络传输的问题,相当于维修人员的工具。 ICMP协议的定位 在TCP/IP…

Aspera替代方案:探索这些安全且可靠的文件传输工具

科技的发展日新月异,文件的传输方式也在不断地更新换代。传统的邮件附件、FTP等方式已经难以满足人们对于传输速度和安全性的需求了。近年来,一些新兴的文件传输工具受到了人们的关注,其中除了知名的Aspera之外,还有许多可靠安全的…

简绘ChatGPT支持Midjourney绘图 支持stable diffusion绘图

简绘支持Midjourney绘图和stable diffusion绘图。 这意味着简绘具备Midjourney绘图和stable diffusion绘图功能的支持。

CSS自学框架之表单

首先我们看一下表单样式,下面共有5张截图 一、CSS代码 /*表单*/fieldset{border: none;margin-bottom: 2em;}fieldset > *{ margin-bottom: 1em }fieldset:last-child{ margin-bottom: 0 }fieldset legend{ margin: 0 0 1em }/* legend标签是CSS中用于定义…

《qt quick核心编程》笔记四

11 Model/View Delegate实际上可以看成是Item的一个模板 11.1 ListView ListView用于显示一个条目列表,数据来自于Model,每个条目的外观来自于Delegate 要使用ListView必须指定一个Model、一个Delegate Model可以是QML内建类型,如ListModel…

QGraphicsView实现简易地图6『异步加载-无底图』

前文链接:QGraphicsView实现简易地图5『经纬网格』 同步加载,虽然程序已做到最少瓦片加载,但或多或少都存在一定程度上的卡顿现象,或者说是不够流畅吧。因此尝试采用异步加载,大致思路是每次缩放或漫游时计算所需重新加…