[oneAPI] 手写数字识别-LSTM

[oneAPI] 手写数字识别-LSTM

  • 手写数字识别
    • 参数与包
    • 加载数据
    • 模型
    • 训练过程
    • 结果
  • oneAPI

比赛:https://marketing.csdn.net/p/f3e44fbfe46c465f4d9d6c23e38e0517
Intel® DevCloud for oneAPI:https://devcloud.intel.com/oneapi/get_started/aiAnalyticsToolkitSamples/

手写数字识别

使用了pytorch以及Intel® Optimization for PyTorch,通过优化扩展了 PyTorch,使英特尔硬件的性能进一步提升,让手写数字识别问题更加的快速高效
在这里插入图片描述

使用MNIST数据集,该数据集包含了一系列以黑白图像表示的手写数字,每个图像的大小为28x28像素,数据集组成如下:

  • 训练集:包含60,000个图像和标签,用于训练模型。
  • 测试集:包含10,000个图像和标签,用于测试模型的性能。

每个图像都被标记为0到9之间的一个数字,表示图像中显示的手写数字。这个数据集常常被用来验证图像分类模型的性能,特别是在计算机视觉领域。

参数与包

import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transformsimport intel_extension_for_pytorch as ipex# Device configuration
device = torch.device('xpu' if torch.cuda.is_available() else 'cpu')# Hyper-parameters
sequence_length = 28
input_size = 28
hidden_size = 128
num_layers = 2
num_classes = 10
batch_size = 100
num_epochs = 2
learning_rate = 0.01

加载数据

# MNIST dataset
train_dataset = torchvision.datasets.MNIST(root='../../data/',train=True,transform=transforms.ToTensor(),download=True)test_dataset = torchvision.datasets.MNIST(root='../../data/',train=False,transform=transforms.ToTensor())# Data loader
train_loader = torch.utils.data.DataLoader(dataset=train_dataset,batch_size=batch_size,shuffle=True)test_loader = torch.utils.data.DataLoader(dataset=test_dataset,batch_size=batch_size,shuffle=False)

模型

# Recurrent neural network (many-to-one)
class RNN(nn.Module):def __init__(self, input_size, hidden_size, num_layers, num_classes):super(RNN, self).__init__()self.hidden_size = hidden_sizeself.num_layers = num_layersself.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True)self.fc = nn.Linear(hidden_size, num_classes)def forward(self, x):# Set initial hidden and cell states h0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size).to(device)c0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size).to(device)# Forward propagate LSTMout, _ = self.lstm(x, (h0, c0))  # out: tensor of shape (batch_size, seq_length, hidden_size)# Decode the hidden state of the last time stepout = self.fc(out[:, -1, :])return out

训练过程

model = RNN(input_size, hidden_size, num_layers, num_classes).to(device)# Loss and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)'''
Apply Intel Extension for PyTorch optimization against the model object and optimizer object.
'''
model, optimizer = ipex.optimize(model, optimizer=optimizer)# Train the model
total_step = len(train_loader)
for epoch in range(num_epochs):for i, (images, labels) in enumerate(train_loader):images = images.reshape(-1, sequence_length, input_size).to(device)labels = labels.to(device)# Forward passoutputs = model(images)loss = criterion(outputs, labels)# Backward and optimizeoptimizer.zero_grad()loss.backward()optimizer.step()if (i + 1) % 100 == 0:print('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}'.format(epoch + 1, num_epochs, i + 1, total_step, loss.item()))# Test the model
model.eval()
with torch.no_grad():correct = 0total = 0for images, labels in test_loader:images = images.reshape(-1, sequence_length, input_size).to(device)labels = labels.to(device)outputs = model(images)_, predicted = torch.max(outputs.data, 1)total += labels.size(0)correct += (predicted == labels).sum().item()print('Test Accuracy of the model on the 10000 test images: {} %'.format(100 * correct / total))# Save the model checkpoint
torch.save(model.state_dict(), 'model.ckpt')

结果

在这里插入图片描述

oneAPI

import intel_extension_for_pytorch as ipex# Device configuration
device = torch.device('xpu' if torch.cuda.is_available() else 'cpu')# 模型
model = ConvNet(num_classes).to(device)# Loss and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)'''
Apply Intel Extension for PyTorch optimization against the model object and optimizer object.
'''
model, optimizer = ipex.optimize(model, optimizer=optimizer)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/93517.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

5G无人露天矿山解决方案

1、5G无人露天矿山解决方案背景 ①2010.10,国家安监总局《金属非金属地下矿山安全避险“六大系统”安装使用和监督检查暂行规定》 ②2016.03,国家发改委《能源技术革命创新行动计划(2016-2030)》,2025 年重点煤矿区采…

6.1 安全漏洞与网络攻击

数据参考:CISP官方 目录 安全漏洞及产生原因信息收集与分析网络攻击实施后门设置与痕迹清除 一、安全漏洞及产生原因 什么是安全漏洞 安全漏洞也称脆弱性,是计算机系统存在的缺陷 漏洞的形式 安全漏洞以不同形式存在漏洞数量逐年递增 漏洞产生的…

强化学习:用Python训练一个简单的机器人

一、介绍 强化学习(RL)是一个令人兴奋的研究领域,它使机器能够通过与环境的交互来学习。在这篇博客中,我们将深入到RL的世界,并探索如何使用Python训练一个简单的机器人。在本文结束时,您将对 RL 概念有基本…

Qt 杂项(Qwt、样式等)

Qt隐藏窗口边框 this->setWindowFlags(Qt::FramelessWindowHint);Qt模态框 this->setWindowModality(Qt::ApplicationModal);QLable隐藏border 代码中设置 lable->setStyleSheet("border:0px");或者UI中直接设置样式:“border:0px” Qwt开源…

什么是DNS服务器的层次化和分布式?

DNS (Domain Name System) 的结构是层次化的,意味着它是由多个级别的服务器组成,每个级别负责不同的部分。以下是 DNS 结构的层次: 根域服务器(Root Servers): 这是 DNS 层次结构的最高级别。全球有13组根域…

chrome解决http自动跳转https问题

1.地址栏输入: chrome://net-internals/#hsts 2.找到底部Delete domain security policies一栏,输入想处理的域名,点击delete。 例如我之前可能访问过这个网址,https://test.apac.com:9090/login 但是后面我去掉了https协议&…

Patch SCN一键解决ORA-600 2662故障---惜分飞

客户强制重启库之后,数据库启动报ORA-600 2037,ORA-745 kcbs_reset_pool/kcbzre1等错误 Wed Aug 09 13:25:38 2023 alter database mount exclusive Successful mount of redo thread 1, with mount id 1672229586 Database mounted in Exclusive Mode Lost write protection d…

ArcGIS 利用cartogram插件制作变形地图

成果图 注:本图数据并不完全对,只做为测试用例 操作 首先需要下载一个插件cartogram 下载地址在这里 https://www.arcgis.com/home/item.html?idd348614c97264ae19b0311019a5f2276 下载完毕之后解压将Cartograms\HelpFiles下的所有文件复制到ArcGIS…

NFT Insider#102:The Sandbox重新上线LAND桥接服务,YGG加入Base生态

引言:NFT Insider由NFT收藏组织WHALE Members(https://twitter.com/WHALEMembers)、BeepCrypto(https://twitter.com/beep_crypto)联合出品,浓缩每周NFT新闻,为大家带来关于NFT最全面、最新鲜、最有价值的讯息。每期周…

【npm run dev报错】无法加载文件 C:\Program Files\nodejs\npm.ps1,因为在此系统上禁止运行脚本。

1.winX键,使用管理员身份运行power shell 2.输入命令:set-executionpolicy remotesigned 3.输入”Y“,回车,问题解决。 文章来源:无法加载文件 C:\Program Files\nodejs\npm.ps1,因为在此系统上禁止运行脚本。 - 前…

java之juc二

JMM 请你谈谈对Volatile的理解 Volatile是jvm提供的轻量级的同步机制(和synchronized差不多,但是没有synchronized那么强大) 保证可见性不保证原子性禁止指令重排 什么是JMM JMM:java内存模型,不存在的东西&#…

Linux 修改信号的响应方式

修改信号的响应方式 1.signal()方法介绍: 修改信号的响应方式要用到方法signal()。需要引用头文件signal.h。signal()的原型: typedef重命名了一个函数指针的类型,这个指针的类型为指向一个参数为int返回值为void的函数的指针。这个函数指针…

小白到运维工程师自学之路 第七十三集 (kubernetes应用部署)

一、安装部署 1、以Deployment YAML方式创建Nginx服务 这个yaml文件在网上可以下载 cat nginx-deployment.yaml apiVersion: apps/v1 #apiVersion是当前配置格式的版本 kind: Deployment #kind是要创建的资源类型,这里是Deploymnet metadata: #metadata是该资源…

Max Compute 操作记录

编译 max compute-spark git clone https://github.com/aliyun/MaxCompute-Spark cd spark-3.x mvn clean package -DskipTests在 target 目录下生成 以下两个文件。 spark-examples_2.12-1.0.0-SNAPSHOT-shaded.jar spark-examples_2.12-1.0.0-SNAPSHOT.jar2. DataWorks 上传…

Genoss GPT简介:使用 Genoss 模型网关实现多个LLM模型的快速切换与集成

一、前言 生成式人工智能领域的发展继续加速,大型语言模型 (LLM) 的用途范围不断扩大。这些用途跨越不同的领域,包括个人助理、文档检索以及图像和文本生成。ChatGPT 等突破性应用程序为公司进入该领域并开始使用这项技术进行构建铺平了道路。 大公司正…

【设计模式】抽象工厂模式

抽象工厂模式(Abstract Factory Pattern)是围绕一个超级工厂创建其他工厂。该超级工厂又称为其他工厂的工厂。这种类型的设计模式属于创建型模式,它提供了一种创建对象的最佳方式。 在抽象工厂模式中,接口是负责创建一个相关对象…

ES基础及面试题

1. 什么是ES ES是一种开源的分布式搜索引擎,可以实现快速存储、搜索、分析大量数据。支持结构化查询和全文检索等多种方式 2. ES的实际用途 1. 全文搜索和信息检索 2. 日志分析,例如埋点分析 3. 监控和指标分析,网络流量,服务器…

如何用树莓派Pico针对IoT编程?

目录 一、Raspberry Pi Pico 系列和功能 二、Raspberry Pi Pico 的替代方案 三、对 Raspberry Pi Pico 进行编程 硬件 软件 第 1 步:连接计算机 第 2 步:在 Pico 上安装 MicroPython 第 3 步:为 Thonny 设置解释器 第 4 步&#xff…

一篇文章教会你搭建私人kindle图书馆,并内网穿透实现公网访问

搭建私人kindle图书馆,并内网穿透实现公网访问 在电子书风靡的时期,大部分人都购买了一本电子书,虽然这本电子书更多的时候是被搁置在储物架上吃灰,或者成为盖泡面的神器,但当亚马逊发布消息将放弃电子书在中国的服务…

windows pip安装出现 error: Microsoft Visual C++ 14.0 is required

可参考:如何解决 Microsoft Visual C 14.0 or greater is required. Get it with “Microsoft C Build Tools“_不吃香菜的小趴菜的博客-CSDN博客 一、安装Visual Studio2022 1、下载:下载 Visual Studio Tools - 免费安装 Windows、Mac、Linux 我这使…