paddle实现手写数字模型(一)

  1. 参考文档:paddle官网文档
  2. 调试代码如下:
    LeNet.py
import paddle
import paddle.nn.functional as Fclass LeNet(paddle.nn.Layer):def __init__(self):super().__init__()self.conv1 = paddle.nn.Conv2D(in_channels=1,out_channels=6,kernel_size=5,stride=1,padding=2)self.max_pool1 = paddle.nn.MaxPool2D(kernel_size=2,  stride=2)self.conv2 = paddle.nn.Conv2D(in_channels=6, out_channels=16, kernel_size=5, stride=1)self.max_pool2 = paddle.nn.MaxPool2D(kernel_size=2, stride=2)self.linear1 = paddle.nn.Linear(in_features=16*5*5, out_features=120)self.linear2 = paddle.nn.Linear(in_features=120, out_features=84)self.linear3 = paddle.nn.Linear(in_features=84, out_features=10)def forward(self, x):x = self.conv1(x)x = F.relu(x)x = self.max_pool1(x)x = self.conv2(x)x = F.relu(x)x = self.max_pool2(x)x = paddle.flatten(x, start_axis=1,stop_axis=-1)x = self.linear1(x)x = F.relu(x)x = self.linear2(x)x = F.relu(x)x = self.linear3(x)return x

train.py


import paddle
from paddle.vision.transforms import Compose,Normalize,ToTensor
import paddle.vision.transforms as T  import numpy as np
import matplotlib.pyplot as plt
from paddle.metric import Accuracyfrom LeNet import LeNet
from PIL import Imageprint(paddle.__version__)
transform = Compose([Normalize(mean=[127.5],std=[127.5],data_format='CHW')])
print('下载和加载训练数据...')
train_dataset = paddle.vision.datasets.MNIST(mode='train',transform=transform)
test_dataset = paddle.vision.datasets.MNIST(mode='test',transform=transform)
print('load finished')train_data0,train_label_0 = train_dataset[0][0],train_dataset[0][1]
train_data0 = train_data0.reshape([28,28])
plt.figure(figsize=(2,2))
plt.imshow(train_data0,cmap=plt.cm.binary)
#plt.show()
print('train_data0 label is: '+str(train_label_0))model = paddle.Model(LeNet())   # 用Model封装模型
optim = paddle.optimizer.Adam(learning_rate=0.001, parameters=model.parameters())# 配置模型
print('配置模型...')
model.prepare(optim,paddle.nn.CrossEntropyLoss(),Accuracy())
# 训练模型
print('训练模型...')
model.fit(train_dataset,epochs=2,batch_size=64,verbose=1)
# 保存模型  
model.save('./model/mnist_model')  # 默认保存模型结构和参数 #预测模型
print('预测模型...')
model.evaluate(test_dataset, batch_size=64, verbose=1)

predicted.py


import paddleimport numpy as npfrom LeNet import LeNet
from PIL import Image# 读取一张本地的样例图片,转变成模型输入的格式
def load_image(img_path):# 从img_path中读取图像,并转为灰度图im = Image.open(img_path).convert('L')#plt.imshow(im,cmap='gray')# print(np.array(im))im = im.resize((28, 28), Image.Resampling.LANCZOS)im = np.array(im).reshape(1, 1, 28, 28).astype(np.float32)# 图像归一化,保持和数据集的数据范围一致im = 1 - im / 255 return im# 加载训练好的模型参数
model = LeNet()
model.load_dict(paddle.load('./model/mnist_model.pdparams'))# 设置模型为评估模式
model.eval()# 准备一个MNIST样例图像
example_image = load_image("d:/8.png")# 转换为Tensor并进行推理
with paddle.no_grad():example_tensor = paddle.to_tensor(example_image)prediction = model(example_tensor)print(prediction)# 获取预测类别
predicted_class = np.argmax(prediction.numpy(), axis=1)[0]
print(f"Predicted class: {predicted_class}")

说明:先通过执行train.py训练数据集,将模型保存在model文件夹中,
然后运行predicted.py加载训练出来的数据集,推理出d:/8.png图片的结果。
结果图片如下:
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/301214.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

YOLOV9 + 双目测距

YOLOV9 双目测距 1. 环境配置2. 测距流程和原理2.1 测距流程2.2 测距原理 3. 代码部分解析3.1 相机参数stereoconfig.py3.2 测距部分3.3 主代码yolov9-stereo.py 4. 实验结果4.1 测距4.2 视频展示 相关文章 1. YOLOV5 双目测距(python) 2. YOLOv7双目…

强化学习MPC——(一)

1.什么是强化学习 强化学习是机器学习的一种,是一种介于监督学习和非监督学习的机器学习方法。 学习二字就很形象的说明了这是一种利用数据(任何形式的)来实现一些已有问题的方法,学习方法,大致可以分为机器学习&…

说说TCP为什么需要三次握手和四次挥手?

一、三次握手 三次握手(Three-way Handshake)其实就是指建立一个TCP连接时,需要客户端和服务器总共发送3个包 主要作用就是为了确认双方的接收能力和发送能力是否正常、指定自己的初始化序列号为后面的可靠性传送做准备 过程如下&#xff…

Redis 常见面试题

目录 1. Redis是什么?2. Redis优缺点?3. Redis为什么这么快?4. 既然Redis那么快,为什么不用它做主数据库,只用它做缓存?5. Redis的线程模型?6. Redis 采用单线程为什么还这么快?7. R…

如何使用生成式人工智能撰写关于新产品发布的文章?

利用生成式人工智能撰写新产品发布文章确实是一种既有创意又高效的内容生成方式。以下是如何做到这一点的指南,附带一些背景信息: • 背景:在撰写文章之前,收集有关您的新产品的信息。这包括产品的名称、类别、特点、优势、目标受…

解决Xshell连接不上虚拟机

相信有很多同学和我一样遇到这个问题,在网上看了很多教程基本上都先让在虚拟机输入ifconfig命令查看ip地址,弄来弄去最后还是解决不了😭😭,其实问题根本就不在命令上,很大概率是我们的虚拟机没有开启网卡&a…

基于单片机便携式测振仪的研制系统设计

**单片机设计介绍,基于单片机便携式测振仪的研制系统设计 文章目录 一 概要二、功能设计三、 软件设计原理图 五、 程序六、 文章目录 一 概要 基于单片机便携式测振仪的研制系统设计概要主要涉及利用单片机作为核心控制器件,结合测振原理和技术&#x…

python-可视化篇-turtle-画爱心

文章目录 原效果替换关键字5为8,看看效果改下颜色 原效果 import turtle as tt.color(red,pink) t.begin_fill() t.width(5) t.left(135) t.fd(100) t.right(180) t.circle(50,-180) t.left(90) t.circle(50,-180) t.right(180) t.fd(100) t.pu() t.goto(50,-30) t…

蓝鲸6.1 CMDB 事件推送的开源替代方案

本文来自腾讯蓝鲸智云社区用户:木讷大叔爱运维 背景 在蓝鲸社区“社区问答”帖子中发现这么一个需求: 究其原因,我在《不是CMDB筑高墙,运维需要一定的开发能力!》一文中已经介绍,在此我再简单重复下&#…

JavaScript实现全选、反选功能(Vue全选、反选,js原生全选、反选)

简介: 在JavaScript中,实现全选和反选通常是通过操作DOM元素和事件监听来实现; 全选功能:当用户点击一个“全选”复选框时,页面中所有具有相同类名的复选框都将被选中; 反选功能:用户点击一个…

ARP寻址过程

当知道目标的IP但是不知道目标的Mac地址的时候就需要借助ARP寻址获取目标的Mac地址,传输层借助四元组(源IP源端口:目标IP目标端口)匹配,网络层借助IP匹配,数据链路层则根据Mac地址匹配,数据传输…

局域网共享文件夹怎么加密?局域网共享文件夹加密方法介绍

在企业局域网中,共享文件夹扮演着重要的角色。为了保护数据安全,我们需要加密保护局域网共享文件夹。那么,局域网共享文件夹怎么加密?下面我们来了解一下吧。 局域网共享文件夹加密方法 局域网共享文件夹加密推荐使用共享文件夹加…

在git上先新建仓库-把本地文件提交远程

一.在git新建远程项目库 1.选择新建仓库 以下以gitee为例 2.输入仓库名称,点击创建 这个可以选择仓库私有化还公开权限 3.获取仓库clone链接 这里选择https模式就行,就不需要配置对电脑进行sshkey配置了。只是需要每次提交输入账号密码 二、远…

万字源码解析!彻底搞懂 HashMap【一】:概念辨析与构造方法源码解析

HashMap 的底层原理和扩容机制一直都是面试的时候经常被问到的问题,同时也是集合源码中最难阅读的一部分😢,之前更新的 ArrayList 源码阅读收获了很多朋友的喜欢,也给了我很多自信;本次我准备完成一个关于 HashMap 源码…

python练习三

模式A num int(input("请输入模式A的层数:")) for i in range(1, num 1):# 画数字for j in range(1, i 1):print(str(j) "\t", end"")print() 模式B num int(input("请输入模式B的层数:")) for i in ran…

九州金榜|孩子叛逆的原因是什么?

孩子随着年龄增长都会出现叛逆心理,很多家长不知道孩子为什么会出现叛逆心理,也不知道如何去引导孩子,下面九州金榜家庭教育就带大家了解一下孩子出现叛逆的原因。 一、心理需求增加 孩子对新事物的探索以及追求会随着人际交往扩大而增加&am…

2024年MathorCup妈妈杯数学建模思路D题思路解析+参考成品

1 赛题思路 (赛题出来以后第一时间在群内分享,点击下方群名片即可加群) 2 比赛日期和时间 报名截止时间:2024年4月11日(周四)12:00 比赛开始时间:2024年4月12日(周五)8:00 比赛结束时间&…

买卖股票的最佳时机IV

题目链接 买卖股票的最佳时机 IV 题目描述 注意点 1 < k < 1001 < prices.length < 10000 < prices[i] < 1000不能同时参与多笔交易&#xff08;必须在再次购买前出售掉之前的股票&#xff09;最多可以完成 k 笔交易 解答思路 本题与买卖股票的最佳时机…

单例模式--理解

单例模式 单例模式是指在内存中只会创建且仅创建一次对象的设计模式。在程序中多次使用同一个对象且作用相同时&#xff0c;为了防止频繁地创建对象使得内存飙升&#xff0c;单例模式可以让程序仅在内存中创建一个对象&#xff0c;让所有需要调用的地方都共享这一单例对象。 单…

Vue - 你知道Vue组件之间是如何进行数据传递的吗

难度级别:中级及以上 提问概率:85% 这道题还可以理解为Vue组件之间的数据是如何进行共享的,也可以理解为组件之间是如何通信的,很多人叫法不同,但都是说的同一个意思。我们知道,在Vue单页面应用项目中,所有的组件都是被嵌套在App.vue内…