[oneAPI] 手写数字识别-GAN

[oneAPI] 手写数字识别-GAN

  • 手写数字识别
    • 参数与包
    • 加载数据
    • 模型
    • 训练过程
    • 结果
  • oneAPI

比赛:https://marketing.csdn.net/p/f3e44fbfe46c465f4d9d6c23e38e0517
Intel® DevCloud for oneAPI:https://devcloud.intel.com/oneapi/get_started/aiAnalyticsToolkitSamples/

手写数字识别

使用了pytorch以及Intel® Optimization for PyTorch,通过优化扩展了 PyTorch,使英特尔硬件的性能进一步提升,让手写数字识别问题更加的快速高效
在这里插入图片描述

使用MNIST数据集,该数据集包含了一系列以黑白图像表示的手写数字,每个图像的大小为28x28像素,数据集组成如下:

  • 训练集:包含60,000个图像和标签,用于训练模型。
  • 测试集:包含10,000个图像和标签,用于测试模型的性能。

每个图像都被标记为0到9之间的一个数字,表示图像中显示的手写数字。这个数据集常常被用来验证图像分类模型的性能,特别是在计算机视觉领域。

参数与包

import os
import torch
import torchvision
import torch.nn as nn
from torchvision import transforms
from torchvision.utils import save_imageimport intel_extension_for_pytorch as ipex# Device configuration
device = torch.device('xpu' if torch.cuda.is_available() else 'cpu')# Hyper-parameters
latent_size = 64
hidden_size = 256
image_size = 784
num_epochs = 200
batch_size = 100
sample_dir = 'samples'

加载数据

# Create a directory if not exists
if not os.path.exists(sample_dir):os.makedirs(sample_dir)# Image processing
# transform = transforms.Compose([
#                 transforms.ToTensor(),
#                 transforms.Normalize(mean=(0.5, 0.5, 0.5),   # 3 for RGB channels
#                                      std=(0.5, 0.5, 0.5))])
transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize(mean=[0.5],  # 1 for greyscale channelsstd=[0.5])])# MNIST dataset
mnist = torchvision.datasets.MNIST(root='./data/',train=True,transform=transform,download=True)# Data loader
data_loader = torch.utils.data.DataLoader(dataset=mnist,batch_size=batch_size,shuffle=True)

模型

# Discriminator
D = nn.Sequential(nn.Linear(image_size, hidden_size),nn.LeakyReLU(0.2),nn.Linear(hidden_size, hidden_size),nn.LeakyReLU(0.2),nn.Linear(hidden_size, 1),nn.Sigmoid())# Generator 
G = nn.Sequential(nn.Linear(latent_size, hidden_size),nn.ReLU(),nn.Linear(hidden_size, hidden_size),nn.ReLU(),nn.Linear(hidden_size, image_size),nn.Tanh())

训练过程

# Device setting
D = D.to(device)
G = G.to(device)# Binary cross entropy loss and optimizer
criterion = nn.BCELoss()
d_optimizer = torch.optim.Adam(D.parameters(), lr=0.0002)
g_optimizer = torch.optim.Adam(G.parameters(), lr=0.0002)'''
Apply Intel Extension for PyTorch optimization against the model object and optimizer object.
'''
D, d_optimizer = ipex.optimize(D, optimizer=d_optimizer)
G, g_optimizer = ipex.optimize(G, optimizer=g_optimizer)def denorm(x):out = (x + 1) / 2return out.clamp(0, 1)def reset_grad():d_optimizer.zero_grad()g_optimizer.zero_grad()# Start training
total_step = len(data_loader)
for epoch in range(num_epochs):for i, (images, _) in enumerate(data_loader):images = images.reshape(batch_size, -1).to(device)# Create the labels which are later used as input for the BCE lossreal_labels = torch.ones(batch_size, 1).to(device)fake_labels = torch.zeros(batch_size, 1).to(device)# ================================================================== ##                      Train the discriminator                       ## ================================================================== ## Compute BCE_Loss using real images where BCE_Loss(x, y): - y * log(D(x)) - (1-y) * log(1 - D(x))# Second term of the loss is always zero since real_labels == 1outputs = D(images)d_loss_real = criterion(outputs, real_labels)real_score = outputs# Compute BCELoss using fake images# First term of the loss is always zero since fake_labels == 0z = torch.randn(batch_size, latent_size).to(device)fake_images = G(z)outputs = D(fake_images)d_loss_fake = criterion(outputs, fake_labels)fake_score = outputs# Backprop and optimized_loss = d_loss_real + d_loss_fakereset_grad()d_loss.backward()d_optimizer.step()# ================================================================== ##                        Train the generator                         ## ================================================================== ## Compute loss with fake imagesz = torch.randn(batch_size, latent_size).to(device)fake_images = G(z)outputs = D(fake_images)# We train G to maximize log(D(G(z)) instead of minimizing log(1-D(G(z)))# For the reason, see the last paragraph of section 3. https://arxiv.org/pdf/1406.2661.pdfg_loss = criterion(outputs, real_labels)# Backprop and optimizereset_grad()g_loss.backward()g_optimizer.step()if (i + 1) % 200 == 0:print('Epoch [{}/{}], Step [{}/{}], d_loss: {:.4f}, g_loss: {:.4f}, D(x): {:.2f}, D(G(z)): {:.2f}'.format(epoch, num_epochs, i + 1, total_step, d_loss.item(), g_loss.item(),real_score.mean().item(), fake_score.mean().item()))# Save real imagesif (epoch + 1) == 1:images = images.reshape(images.size(0), 1, 28, 28)save_image(denorm(images), os.path.join(sample_dir, 'real_images.png'))# Save sampled imagesfake_images = fake_images.reshape(fake_images.size(0), 1, 28, 28)save_image(denorm(fake_images), os.path.join(sample_dir, 'fake_images-{}.png'.format(epoch + 1)))# Save the model checkpoints 
torch.save(G.state_dict(), 'G.ckpt')
torch.save(D.state_dict(), 'D.ckpt')

结果

在这里插入图片描述

在这里插入图片描述
在这里插入图片描述

oneAPI

import intel_extension_for_pytorch as ipex# Device configuration
device = torch.device('xpu' if torch.cuda.is_available() else 'cpu')# Device setting
D = D.to(device)
G = G.to(device)# Binary cross entropy loss and optimizer
criterion = nn.BCELoss()
d_optimizer = torch.optim.Adam(D.parameters(), lr=0.0002)
g_optimizer = torch.optim.Adam(G.parameters(), lr=0.0002)'''
Apply Intel Extension for PyTorch optimization against the model object and optimizer object.
'''
D, d_optimizer = ipex.optimize(D, optimizer=d_optimizer)
G, g_optimizer = ipex.optimize(G, optimizer=g_optimizer)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/99089.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Linux驱动开发(Day4)

思维导图: 字符设备驱动分步注册:

Linux —— 进程间通信(System V)

目录 一,共享内存 申请共享内存 shmget 控制共享内存 shmctl 关联共享内存 shmat / 去联共享内存 shmdt 二,消息队列 创建或打开消息队列 msgget 发送消息 msgsnd / 接收消息 msgrcv 控制消息 msgctl 三,信号量 创建或打开信号量 s…

搭建Everything+cpolar在线资料库,实现随时随地访问

Everythingcpolar搭建在线资料库,实现随时随地访问 文章目录 Everythingcpolar搭建在线资料库,实现随时随地访问前言1.软件安装完成后,打开Everything2.登录cpolar官网 设置空白数据隧道3.将空白数据隧道与本地Everything软件结合起来总结 前…

【算法系列篇】双指针

文章目录 前言什么是双指针算法1.移动零1.1 题目要求1.2 做题思路1.3 Java代码实现 2.复写零2.1 题目要求2.2 做题思路2.3 Java代码实现 3.快乐数3.1 题目要求3.2 做题思路3.3 Java代码实现 4.盛最多水的容器4.1 题目要求4.2 做题思路4.3 Java代码实现 5.有效三角形的个数5.1 题…

HAProxy搭建web集群

HAProxy Haproxy概念 HAProxy是可提供高可用性、负载均衡以及基于TCP和HTTP应用的代理,是免费、快速并且可靠的一种解决方案。HAProxy非常适用于并发大(并发达1w以上)web站点,这些站点通常又需要会话保持或七层处理。HAProxy的运…

023:vue中解决el-date-picker更改样式不生效问题

第023个 查看专栏目录: VUE ------ element UI 本文章目录 修改后的效果示例源代码(共52行)核心内容步骤:(1)更改样式(2)添加参数 专栏目标 在vue项目开发中,我们打算保持颜色的一致…

第三讲:ApplicationContext的实现

这里写目录标题 一、前文回顾二、基础代码准备三、基于XML的ClassPathXmlApplicationContext1. 创建spring-config.xml配置文件2. 指定配置文件的路径 四、基于注解的AnnotationConfigApplicationContext1. 新增一个配置类2.指定配置类信息 五、基于注解和ServletWebServer应用…

Azure虚拟网络对等互连

什么是Azure虚拟网络对等互联 Azure虚拟网络对等互联(Azure Virtual Network peering)是一种连接两个虚拟网络的方法,使得这两个虚拟网络能够在同一地理区域内进行通信。它通过私有IP地址在虚拟网络之间建立网络连接,不论是在同一…

Linux命令200例:head用于显示文件的开头部分(常用)

🏆作者简介,黑夜开发者,全栈领域新星创作者✌。CSDN专家博主,阿里云社区专家博主,2023年6月csdn上海赛道top4。 🏆数年电商行业从业经验,历任核心研发工程师,项目技术负责人。 &…

LVS 负载均衡集群

集群 集群(Cluster)是一组相互连接的计算机或服务器,它们通过网络一起工作以完成共同的任务或提供服务。集群的目标是通过将多台计算机协同工作,提高计算能力、可用性、性能和可伸缩性,适用于大量高并发的场景。 集群…

消息中间件的选择:RabbitMQ是一个明智的选择

💗wei_shuo的个人主页 💫wei_shuo的学习社区 🌐Hello World ! MQ(Message Queue) MQ(消息队列)是一种用于在应用程序之间进行异步通信的技术;允许应用程序通过发送和接收…

【物联网无线通信技术】NFC从理论到实践(FM17XX)

NFC,全称是Near Field Communication,即“近场通信”,也叫“近距离无线通信”。NFC诞生于2004年,是基于RFID非接触式射频识别技术演变而来,由当时的龙头企业NXP(原飞利浦半导体)、诺基亚以及索尼联合发起。NFC采用13.5…

【记录】Python3|selenium4 极速上手入门(Windows)

环境:Windows 版本:python3,selenium 4.11.2 文章目录 1 装ChromeEdge其他浏览器 2 运行报错RequestsDependencyWarning: urllib3 (1.26.9) or chardet (3.0.4) doesn‘t match a supported version打开了浏览器,但是没有显示网页…

linux 安装 kibana

首先下载 kibana https://www.elastic.co/cn/downloads/kibana 然后上传到linux /usr/local 目录下解压安装 修改config/kibana.yml 配置文件,将elasticsearch.hosts 然后再nginx 中做一个端口映射,实现在浏览器中输入后xxxx:5602 nginx 可以将请求转发…

SPSS--如何使用分层分析以及分层分析案例分享

分层分析:将资料按某个或某些需要控制的变量的不同分类进行分层,然后再估计暴露因子与某结局变量之间关系的一种资料分析方法。 分层分析的最重要的用途是评估和控制混杂因子所致的混杂偏倚。通过按混杂因子分层,可使每层内的两个比较组在所控…

从零做软件开发项目系列之二——需求调研

在接到软件开发任务之后,第一件要做的事情就是进行需求调研工作,基于前期的沟通以及合同向用户了解具体需求,从而有针对性地开展后续工作。整个调研过程分为调研准备,调研实施,需求分析。 1 调研准备 俗话说&#x…

go语言恶意代码检测系统--对接前端可视化与算法检测部分

Malware Detect System 1 产品介绍 恶意代码检测系统。 2 产品描述 2.1 产品功能 功能点详细描述注册账号未注册用户注册成为产品用户,从而具备享有产品各项服务的资格登录账号用户登录产品,获得产品提供的各项服务上传恶意样本用户可以将上传自己的…

leetcode 279. 完全平方数

2023.8.18 与零钱兑换相似&#xff0c;本题属于完全背包问题&#xff1a;完全平方数为物品&#xff0c;整数n为背包。 直接上代码&#xff1a; class Solution { public:int numSquares(int n) {vector<int> dp(n1 , INT_MAX);dp[0] 0;for(int i1; i*i<n; i){for(in…

九耶丨阁瑞钛伦特-Spring boot与Spring cloud 之间的关系

Spring Boot和Spring Cloud是两个相互关联的项目&#xff0c;它们可以一起使用来构建微服务架构。 Spring Boot是一个用于简化Spring应用程序开发的框架&#xff0c;它提供了自动配置、快速开发的特性&#xff0c;使得开发人员可以更加轻松地创建独立的、生产级别的Spring应用程…

高效实用小工具之Everything

一&#xff0c;简介 有时候我们电脑文件较多时&#xff0c;想快速找到某个文件不是一件容易的事情&#xff0c;实用windows自带的搜素太耗时&#xff0c;效率不高。今天推荐一个用来搜索电脑文件的小工具——Everything&#xff0c;本文将介绍如何安装以及使用everything&…