PyTorch定长验证码训练集数字识别(几乎每行注释,开箱即用)

文章目录

  • 前言
  • 一、代码
    • 1.1 MyDataset.py(加载数据集和计算均值,标准差)
    • 1.2 Mymodels.py(使用预训练模型)
      • 1.2.1 ResNet介绍
    • 1.3 main.py(启动代码)
    • 1.4 inferring.py(验证是否识别成功)
    • 1.5 文件目录树
    • 1.6 资源链接
  • 二、借鉴

前言

这是一个识别出验证码图片的代码。训练集和测试集我会放在下面。同理也可以训练文字以及字母。仅限于提前知道长度,本文使用的是4长度的。
如有不妥请,请文明指出。
在这里插入图片描述

一、代码

1.1 MyDataset.py(加载数据集和计算均值,标准差)

import os
import torch
import numpy as npfrom torch.utils.data import Dataset
from torchvision import transforms
from PIL import Image
from tqdm import tqdmclass NumberDataset(Dataset):# 定义一个名为 NumberDataset 的子类,继承 Dataset 类,实现数据集的加载和处理操作def __init__(self, path, type_transform):super(NumberDataset, self).__init__()self.path = path# 加载数据集图片列表self.picture_list = self.load_image_list()# 数据转换self.transform = type_transform# 数字字符映射表self.map = ["1", "2", "3", "4", "5", "6", "7", "8", "9"]def load_image_list(self):"""计算数据集的长度:return:"""# 获取指定路径下的所有文件名list_data = list(os.walk(self.path))[0][-1]return list_datadef __len__(self):# 返回数据集的长度return len(self.picture_list)def __getitem__(self, item):# 打开图片image = Image.open(os.path.join(self.path, self.picture_list[item]))if self.transform:# 转换图像格式为Tensorimage = self.transform(image)# 获取该图像对应的数字字符labels = [self.map.index(i) for i in self.picture_list[item].split("_")[0]]# 将数字字符转换为Tensor类型的标签labels = torch.as_tensor(labels, dtype=torch.int64)return image, labelsif __name__ == "__main__":# 转换图像格式为Tensortransform = transforms.Compose([transforms.ToTensor(), ])# 加载数据集my_train = NumberDataset(path="./train", type_transform=transform)images = []# 遍历数据集中的所有图像for img, _ in tqdm(my_train):images.append(img)"""images = np.stack([img for img, _ in tqdm(my_train)], axis=0):对于NumberDataset数据集中的每个样本,这行代码会获取该样本的图像数据,然后将所有样本的图像数据堆叠成一个大的四维张量,其中第0个维度对应样本数,第1个维度对应通道数,第2个维度对应高度,第3个维度对应宽度。这里使用np.stack函数将图像数据堆叠起来。res_total = np.mean(images, axis=(0, 2, 3)):这行代码会计算整个数据集的均值。由于图像数据被堆叠成了一个四维张量,因此需要在第0个维度(样本维度)、第2个维度(高度维度)和第3个维度(宽度维度)上计算均值。因此,axis=(0, 2, 3)参数表示在这些维度上计算均值。res_std = np.std(images, axis=(0, 2, 3)):这行代码会计算整个数据集的标准差,与计算均值的过程类似,只需在不同维度上计算标准差即可。"""# 将列表中的所有图像数据组成一个numpy数组images = np.stack(images, axis=0)# 计算数据集中所有图像的像素平均值(后续进行正则化的时候需要用到的参数)total_mean = np.mean(images, axis=(0, 2, 3))# 计算数据集中所有图像的像素标准差(后续进行正则化的时候需要用到的参数)total_std = np.std(images, axis=(0, 2, 3))print(total_mean, total_std)

结果
在这里插入图片描述

1.2 Mymodels.py(使用预训练模型)

至于想手搓,还是不建议了,新手先用用预训练模型吧。手搓还是需要调整很多参数的,前期避免看晕,这里就用预训练模型了。

这是我们使用的模型具体情况,想要深入的读者可以自行观看
ResNet残差神经网络

1.2.1 ResNet介绍

ResNet是一种非常强大的卷积神经网络架构,它在许多计算机视觉任务中都表现良好。ResNet的深度可以根据具体任务进行选择,通常可以在18层、34层、50层、101层、152层之间选择。

对于3000张图片的训练集,可以考虑使用ResNet-18或ResNet-34,这两个模型相对较浅,适合小数据集的训练。如果数据集较大,可以考虑使用ResNet-50、ResNet-101或ResNet-152等更深的模型。

需要注意的是,选择合适的模型不仅取决于数据集的大小,还与任务的复杂度有关。如果任务比较简单,可能不需要使用太深的模型。同时,还需要考虑训练时间和计算资源的限制,因为深层的模型需要更长的时间和更多的计算资源来训练。

最好的做法是尝试几个不同深度的ResNet模型并比较它们的性能,以确定哪个模型最适合你的任务和数据集。

from torch import nn
from torchvision import modelsclass BetterNet(nn.Module):def __init__(self):super(BetterNet, self).__init__()# 使用 resnet18 模型,将输出的类别数设为 4 * 9# 一个字母9个类,需要输出4个,所以4*9self.resnet18 = models.resnet18(num_classes=4 * 9)def forward(self, x):# 将输入 x 经过 resnet18 模型的处理x = self.resnet18(x)return x

1.3 main.py(启动代码)

import numpy as np
import torch
import os
import torch.optim as optimfrom tqdm import tqdm
from torchvision import transforms
from MyDataset import NumberDataset
from Mymodels import BetterNet
from torch.utils.data import DataLoader# batch_size(每批处理的数据, 根据性能选择)
"""
对于训练神经网络,batch size 的大小会对模型的训练产生影响,不是越大越好。这是因为 batch size 的大小影响到模型参数的更新方式和更新频率。
较大的 batch size 可以在每个 epoch 内处理更多的样本,从而使梯度下降更新更加稳定,减少了训练时的波动。此外,较大的 batch size 还可以利用 GPU 并行计算的能力,从而加快训练速度。
但是,较大的 batch size 也会导致一些问题。首先,较大的 batch size 可能会导致模型过拟合训练集,因为模型可能会过度依赖于训练集中的噪声和特定样本的特征。其次,较大的 batch size 可能会降低模型的泛化能力,因为模型更容易学习到训练集的特殊性质而忽略其他可能的特征。
因此,在实践中,选择合适的 batch size 是非常重要的。通常情况下,建议选择较小的 batch size,例如 32、64 或 128,同时可以利用优化器的动量等技术来提高训练效果。如果内存和计算资源允许,也可以适当增大 batch size。但需要注意,不同的模型和数据集可能需要不同的 batch size。
由于我们这里的数据集只有3000左右,我们就选择较小的batch_size即可,使用8
"""
BATCH_SIZE = 8
# 选择gpu运行
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# 训练次数
EPOCHS = 5
# 构建transform,对图像做处理
my_transform = transforms.Compose([transforms.ToTensor(),# 这里就使用到了MyDataset.py输出的均值和标准差# 这里的 transforms.Normalize 属于对输入数据进行正则化的操作。这个操作是将输入数据按照指定的均值和标准差进行标准化,从而使得输入数据的分布更加平稳,更有利于模型的训练和收敛。transforms.Normalize(mean=(0.9471762, 0.9478388, 0.9475022), std=(0.18795937, 0.18679644, 0.18724856))])# 定义损失函数
loss_function = torch.nn.CrossEntropyLoss()def deal_database():"""划分数据集:return:"""train_data = NumberDataset(path="./train", type_transform=my_transform)test_data = NumberDataset(path="./test", type_transform=my_transform)# 加载数据集(其中shuffle决定的是是否打乱数据,为了提高模型精度选择True打乱。)# drop_last=True如果数据集的大小不能被batch的size整除,最后一个batch的大小就会小于batch的size,这种情况下如果不开启drop_last,就会出现最后一个batch大小不同的情况,而且最后一个batch的处理会比其他batch多出很多问题。所以如果你希望所有的batch大小都相同,可以开启drop_last,这样最后一个小于batch的size的batch将会被忽略掉。# 训练集part_train_loader = DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=True, drop_last=True)# 测试集part_test_loader = DataLoader(test_data, batch_size=BATCH_SIZE, shuffle=True, drop_last=True)return part_train_loader, part_test_loaderdef train_model():"""训练:return:"""total_loss = []# 这一步借用了tqdm实现了进度条打印的功能dataloader = tqdm(train_loader, total=len(train_loader))# 启动训练model.train()for (image, label) in dataloader:# 部署到DEVICEimage = image.to(DEVICE)label = label.to(DEVICE)# 梯度初始化为0optimizer.zero_grad()# 向前传播output = model(image)# ,9是分类, 4*9就是4个9分类的结果(要分类才需要这些操作)# 将模型的输出 output 和标签 label 进行维度的变换,使它们能够被送入交叉熵损失函数进行计算。output = output.view(BATCH_SIZE * 4, 9)label = label.view(-1)# 计算损失loss = loss_function(output, label)total_loss.append(loss.item())# 反向传播loss.backward()# 优化器更新optimizer.step()# 保存模型torch.save(model.state_dict(), './models/model.pkl')# 保存优化器torch.save(optimizer.state_dict(), './models/optimizer.pkl')return np.mean(total_loss)def test_model():"""测试:return:"""# 统计正确率succeed = []# 这一步借用了tqdm实现了进度条打印的功能dataloader = tqdm(test_loader, total=len(test_loader))# 启动测试model.eval()with torch.no_grad():  # 不计算梯度,不反向传播for (image, label) in dataloader:# 部署到DEVICEimage = image.to(DEVICE)label = label.to(DEVICE)# 向前传播output = model(image)# ,9是分类, 4*9就是4个9分类的结果(要分类才需要这些操作)# 将模型的输出 output 和标签 label 进行维度的变换,使它们能够被送入交叉熵损失函数进行计算。output = output.view(BATCH_SIZE * 4, 9)label = label.view(-1)# 找到概率值最大的下标result = output.argmax(dim=1)# 累计正确率succeed.append(result.eq(label).float().mean().item())return np.mean(succeed)if __name__ == "__main__":# 模型实例化(传给gpu使用)model = BetterNet().to(DEVICE)# 优化器:更新模型参数,使训练结果达到最优值optimizer = optim.Adam(model.parameters())# 加载优化好的模型和优化器继续进行训练if os.path.exists('./models/model.pkl'):model.load_state_dict(torch.load('./models/model.pkl'))optimizer.load_state_dict(torch.load('./models/optimizer.pkl'))# 加载数据集train_loader, test_loader = deal_database()# 训练for epoch in range(EPOCHS):mean_loss = train_model()mean_succeed = test_model()print(f"第{epoch + 1}次epoch---损失: {mean_loss}---成功率: {mean_succeed}")

1.4 inferring.py(验证是否识别成功)

from torchvision import transforms
from torch import load, no_grad
from PIL import Image
from Mymodels import BetterNet# 构建transform,对图像做处理,要和main里面的一致
my_transforms = transforms.Compose([transforms.ToTensor(),transforms.Normalize(mean=(0.9471762, 0.9478388, 0.9475022), std=(0.18795937, 0.18679644, 0.18724856))])
# 数字字符映射表
mapping = ["1", "2", "3", "4", "5", "6", "7", "8", "9"]if __name__ == "__main__":# 模型实例化(传给gpu使用)model = BetterNet()# 加载模型model.load_state_dict(load('./models/model.pkl'))# 打开图片image = Image.open('./test.png')# 对图片进行转换成PyTorch处理的tensor格式,并移动到模型所在设备上image = my_transforms(image)"""第一个参数 1 表示 batch size,即处理的图像数量为 1。第二个参数 3 表示图像有 3 个通道,即 RGB 三个通道。第三个参数 50 表示图像高度为 50 像素。第四个参数 150 表示图像宽度为 150 像素。"""# 这是因为模型要求输入的图像张量形状为 [batch_size, channels, height, width]image = image.view(1, 3, 50, 150)# 预测model.eval()# 不进行梯度计算with no_grad():# 获取结果out_put = model(image)"""第一个参数 4 表示将 batch_size 调整为 4,即预测结果包含 4 个样本的预测值。第二个参数 9 表示将 num_classes 调整为 9,即模型预测的结果是 9 维的向量。(分类)"""out_put = out_put.view(4, 9)# 对模型输出结果out_put在维度1上取最大值。result = out_put.max(dim=1)[1]print([mapping[i] for i in list(result.numpy())])

在这里插入图片描述

1.5 文件目录树

没给后缀的就是目录

  • models
  • train
  • test
  • test_inferring
  • main.py
  • MyDataset.py
  • Mymodels.py
  • inferring.py
  • test.png
    资源链接给的是train(训练)、test(测试)、test_inferring(验证)、test.png的资源,其他需要自己根据代码复现,test.png是验证是否成功的图片,test_inferring是包含199张给你验证的图片文件夹

1.6 资源链接

数据下载地址

二、借鉴

猿人学-安然导师
chatGPT
残差网络

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/71822.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

chatgpt赋能python:Python制表位:优化数据可视化与分析的利器

Python 制表位:优化数据可视化与分析的利器 在数据可视化和分析中,表格是一种常用的数据展示方式。Python 提供了丰富的用于构建表格的库,其中之一便是制表位(Tabulate)。本文将介绍制表位的特点、使用方法以及另外一…

结合代谢组学和网络药理学技术发现的差异代谢物和中药成分的药物靶点关联等技术操作

本期分享一篇中南大学今年发表在Computational and Structural Biotechnology Journal 杂志(影响因子6.018)上的论文《结合代谢组学和网络药理学揭示羟基红花黄色素A抗急性颅脑损伤的机制》。 外伤性脑损伤(Traumatic brain injury,TBI)已成为世界范围内导致死亡、发病和残…

网络药理学分析工具开发好了

上次文章说开发网络药理学工具,其实上周五就已经做好了,但我为什么要今天才通知各位小伙伴呢。因为第一版做的实在太丑了图片,所以我觉得要好好打磨一下,所以今天才写这篇文章。我们先来看下软件打磨前后的对比: 第一版…

论文查重发现他引率为0怎么办

今天准备论文查重,发现虽然查重率低,但是他引率为0。搞得我一脸懵。 格式什么的都是正确的,引用大段的文献也有,为啥他引率为0呢。。。。 被逼无奈,将文章中的上标注和参考文献的标注全部用手打的,不使用…

文末送书 | 图灵宇宙:用漫画讲述图灵奖背后的计算机科学发展简史

张立波,武延军,赵琛 著 电子工业出版社-博文视点2022-09-01 ISBN: 9787121442933定价:109.00 元 新书推荐 🌟今日福利 |关于本书| 这是一本以计算机领域重要奖项——图灵奖为切入点,系统展现计算机科学发展…

“复制”马斯克(三):我们要为他的“反智事业”买单吗?

马斯克首次跻身世界首富,引发大众的强烈关注。 但是,首富的排名对马斯克、对我们而言都并不重要,对我们更为重要的一个影响是,随着马斯克所取得的商业成功和巨大财富积累,他的事业正在进入一个全新的阶段。 去年的12月…

马斯克的 39 页火星计划PPT

????????关注后回复 “进群” ,拉你进程序员交流群???????? 马斯克曾在Twitter上这样写道,“每年建造100艘星际飞船,10年内就达到1000艘,也就意味着每年的运力达到1亿吨。或者说每当地球和火星轨道同步时可以运载…

下任推特 CEO 或是“卷王”?在马斯克手下 20 年,每天工作 16 个小时,还带着家人住办公室!

整理 | 郑丽媛 出品 | CSDN(ID:CSDNnews) 上周,马斯克发起线上投票,让网友决定他是否该卸任 Twitter CEO 一职,最终超过 1700 万 Twitter 用户参与,其中 57.5% 的人投了赞成票。 于是 12 月 21…

【程序人生】马斯克:我一直有种存在的危机感

01 我一直有种存在的危机感 小时候,人们常会问我,长大要做什么,我其实也不知道。 后来我想,搞发明应该会很酷吧,因为科幻小说家亚瑟克拉克(《2001太空漫游》作者)曾说过:任何足够先进的科技,都与魔法无异。 想想看,三百年前的人类,如果看到今天我们可以飞行、可…

马斯克:SpaceX成功的背后,经历了18次失败、被骂是骗子、几近破产

美国太平洋东部时间周二下午,SpaceX发射了“猎鹰重型”(Falcon Heavy)火箭,这是该公司迄今为止最大、也是世界上最强大的运载发射系统。这次发射成功,让传奇人物马斯克和SpaceX再次成为大众瞩目的焦点。 然而,在SpaceX成立的16年里…

说一说埃隆.马斯克他妈妈的故事

特斯拉公司创始人埃隆马斯克被誉为“第二个乔布斯”、“硅谷钢铁侠”,造火箭、移民火星、星链计划……他简直就是一个不折不扣的科技天才。 每一个成功的男人的背后都站着一个优秀的女人,对于埃隆马斯克来说,他之所以取得如此不凡的成绩&…

GPT-3说:马斯克是世界最强的人,但没有他人类会更好

金磊 发自 凹非寺量子位 报道 | 公众号 QbitAI 和GPT-3的一番对话,炸出来个马斯克,既让他当总统,又建议暗杀他…… 怎么回事? 一位叫 Spencer Greenberg (以下简称S先生)的数学家,最近和GPT-3做…

马斯克39页火星计划PPT曝光,我们能学到什么

来源:管理晨读 本文ppt部分转载自公众号北美工程师求职顾问 新闻报道部分来自于中新社 SpaceX公司首席运营官马斯克一直梦想着移民火星,并在之前完成了许多的开发计划和实验。很多人说他是异想天开,也有很多人觉得火星目前没有找到绿色生物&a…

解读本世纪最成功的天才——埃隆·马斯克

转载https://blog.csdn.net/isuccess88/article/details/75500905 解读本世纪最成功的天才——埃隆马斯克 就在昨天的上午9点过,一家名为spacex私人火箭公司,成功回收了从轨道上完整运行的火箭,使火箭不仅发射还软着陆。 这项成就直接干翻了中…

chatgpt赋能python:Python对话机器人:为什么它是最好的选择?

Python对话机器人:为什么它是最好的选择? Python是一种高级编程语言,拥有简单易懂的语法和广泛的应用领域。在人工智能领域,Python成为了首选的编程语言之一。有很多Python对话机器人的框架已经被开发出来,例如Chatte…

十大BI报表可视化工具

一、Tableau 自助式BI典型的代表,目前在国内也还有许多代理商,Tableau也算是众多国外BI产品中,目前在国内还比较有竞争力的国外BI厂商吧。因其操作简单,无右键设计,设计一张报表就只需真正意义上的托拖拽拽就可快速完…

【手把手教你】股票可视化分析之Pyecharts(二)

01 引言 Pyechartss 是基于Echarts 的开源可视化库,可以制作非常精美的图表。公众号推文《【手把手教你】股票可视化分析之Pyecharts(一)》,以股票交易数据为例,为大家展示了使用 Pyehcarts 构建直角坐标系下常用的图表…

python数据分析及可视化(十四)数据分析可视化练习-上市公司可视化数据分析、黑色星期五案例分析

上市公司数据分析 从中商情报网下载的数据,表格中会存在很多的问题,查看数据的信息有无缺失,然后做数据的清晰,有无重复值,异常数据,省份和城市的列名称和数据是不对照的,删除掉一些不需要的数…

一款开源的数据可视化分析平台,提供多种大屏模板,非常炫酷

点击关注公众号,实用技术文章及时了解 DataGear是一款开源的数据可视化分析平台,可自由制作任何您想要的数据可视化看板,支持接入SQL、CSV、Excel、HTTP接口、JSON等多种数据源。系统主要功能包括:数据管理、SQL工作台、数据导入/…