Pytorch深度学习—FashionMNIST数据集训练

文章目录

    • FashionMNIST数据集
    • 需求库导入、数据迭代器生成
    • 设备选择
    • 样例图片展示
    • 日志写入
    • 评估—计数器
    • 模型构建
    • 训练函数
    • 整体代码
    • 训练过程
    • 日志

FashionMNIST数据集

  • FashionMNIST(时尚 MNIST)是一个用于图像分类的数据集,旨在替代传统的手写数字MNIST数据集。它由 Zalando Research 创建,适用于深度学习和计算机视觉的实验。
    • FashionMNIST 包含 10 个类别,分别对应不同的时尚物品。这些类别包括 T恤/上衣、裤子、套头衫、裙子、外套、凉鞋、衬衫、运动鞋、包和踝靴。
    • 每个类别有 6,000 张训练图像和 1,000 张测试图像,总计 70,000 张图像。
    • 每张图像的尺寸为 28x28 像素,与MNIST数据集相同。
    • 数据集中的每个图像都是灰度图像,像素值在0到255之间。
      在这里插入图片描述

需求库导入、数据迭代器生成

import os
import random
import numpy as np
import datetime
import torch
import torch.nn as nn
from torch.utils.data import DataLoaderimport torchvision
from torchvision import transformsimport argparse
from tqdm import tqdmimport matplotlib.pyplot as plt
from torch.utils.tensorboard import SummaryWriterdef _load_data():"""download the data, and generate the dataloader"""trans = transforms.Compose([transforms.ToTensor()])train_dataset = torchvision.datasets.FashionMNIST(root='./data/', train=True, download=True, transform=trans)test_dataset = torchvision.datasets.FashionMNIST(root='./data/', train=False, download=True, transform=trans)# print(len(train_dataset), len(test_dataset))train_loader = DataLoader(train_dataset, shuffle=True, batch_size=args.batch_size, num_workers=args.num_works)test_loader = DataLoader(test_dataset, shuffle=True, batch_size=args.batch_size, num_workers=args.num_works)return (train_loader, test_loader)

设备选择

def _device():device = torch.device("cuda" if torch.cuda.is_available() else "cpu")return device

样例图片展示

"""display data examples"""
def _image_label(labels):text_labels = ['t-shirt', 'trouser', 'pullover', 'dress', 'coat','sandal', 'shirt', 'sneaker', 'bag', 'ankle boot']return [text_labels[int(i)] for i in labels]def _show_images(imgs, rows, columns, titles=None, scale=1.5):figsize = (rows * scale, columns * 1.5)fig, axes = plt.subplots(rows, columns, figsize=figsize)axes = axes.flatten()for i, (img, ax) in enumerate(zip(imgs, axes)):ax.imshow(img)ax.axes.get_xaxis().set_visible(False)ax.axes.get_yaxis().set_visible(False)if titles:ax.set_title(titles[i])plt.show()return axesdef _show_examples():train_loader, test_loader = _load_data()for images, labels in train_loader:images = images.squeeze(1)_show_images(images, 3, 3, _image_label(labels))break

日志写入

class _logger():def __init__(self, log_dir, log_history=True):if log_history:log_dir = os.path.join(log_dir, datetime.datetime.now().strftime("%Y_%m_%d__%H_%M_%S"))self.summary = SummaryWriter(log_dir)def scalar_summary(self, tag, value, step):self.summary.add_scalars(tag, value, step)def images_summary(self, tag, image_tensor, step):self.summary.add_images(tag, image_tensor, step)def figure_summary(self, tag, figure, step):self.summary.add_figure(tag, figure, step)def graph_summary(self, model):self.summary.add_graph(model)def close(self):self.summary.close()

评估—计数器

class AverageMeter():def __init__(self):self.reset()def reset(self):self.val = 0self.avg = 0self.sum = 0self.count = 0def update(self, val, n=1):self.val = valself.sum += val * nself.count += nself.avg = self.sum / self.count

模型构建

class Conv3x3(nn.Module):def __init__(self, in_channels, out_channels, down_sample=False):super(Conv3x3, self).__init__()self.conv = nn.Sequential(nn.Conv2d(in_channels, out_channels, 3, 1, 1),nn.BatchNorm2d(out_channels),nn.ReLU(inplace=True),nn.Conv2d(out_channels, out_channels, 3, 1, 1),nn.BatchNorm2d(out_channels),nn.ReLU(inplace=True))if down_sample:self.conv[3] = nn.Conv2d(out_channels, out_channels, 2, 2, 0)def forward(self, x):return self.conv(x)class SimpleNet(nn.Module):def __init__(self, in_channels, out_channels):super(SimpleNet, self).__init__()self.conv1 = Conv3x3(in_channels, 32)self.conv2 = Conv3x3(32, 64, down_sample=True)self.conv3 = Conv3x3(64, 128)self.conv4 = Conv3x3(128, 256, down_sample=True)self.fc = nn.Linear(256*7*7, out_channels)def forward(self, x):x = self.conv1(x)x = self.conv2(x)x = self.conv3(x)x = self.conv4(x)x = torch.flatten(x, 1)out = self.fc(x)return out

训练函数

def train(model, train_loader, test_loader, criterion, optimizor, epochs, device, writer, save_weight=False):train_loss = AverageMeter()test_loss = AverageMeter()train_precision = AverageMeter()test_precision = AverageMeter()time_tick = datetime.datetime.now().strftime("%Y_%m_%d__%H_%M_%S")for epoch in range(epochs):print('\nEpoch: [%d | %d] LR: %f' % (epoch + 1, args.epochs, args.lr))model.train()for input, label in tqdm(train_loader):input, label = input.to(device), label.to(device)output = model(input)# backwardloss = criterion(output, label)optimizor.zero_grad()loss.backward()optimizor.step()# loggerpredict = torch.argmax(output, dim=1)train_pre = sum(predict == label) / len(label)train_loss.update(loss.item(), input.size(0))train_precision.update(train_pre.item(), input.size(0))model.eval()with torch.no_grad():for X, y in tqdm(test_loader):X, y = X.to(device), y.to(device)y_hat = model(X)loss_te = criterion(y_hat, y)predict_ = torch.argmax(y_hat, dim=1)test_pre = sum(predict_ == y) / len(y)test_loss.update(loss_te.item(), X.size(0))test_precision.update(test_pre.item(), X.size(0))if save_weight:best_dice = args.best_diceweight_dir = os.path.join(args.weight_dir, args.model, time_tick)os.makedirs(weight_dir, exist_ok=True)monitor_dice = test_precision.avgif monitor_dice > best_dice:best_dice = max(monitor_dice, best_dice)name = os.path.join(weight_dir, args.model + '_' + str(epoch) + \'_test_loss-' + str(round(test_loss.avg, 4)) + \'_test_dice-' + str(round(best_dice, 4)) + '.pt')torch.save(model.state_dict(), name)print("train" + '---Loss: {loss:.4f} | Dice: {dice:.4f}'.format(loss=train_loss.avg, dice=train_precision.avg))print("test " + '---Loss: {loss:.4f} | Dice: {dice:.4f}'.format(loss=test_loss.avg, dice=test_precision.avg))# summarywriter.scalar_summary("Loss/loss", {"train": train_loss.avg, "test": test_loss.avg}, epoch)writer.scalar_summary("Loss/precision", {"train": train_precision.avg, "test": test_precision.avg}, epoch)writer.close()

整体代码

import os
import random
import numpy as np
import datetime
import torch
import torch.nn as nn
from torch.utils.data import DataLoaderimport torchvision
from torchvision import transformsimport argparse
from tqdm import tqdmimport matplotlib.pyplot as plt
from torch.utils.tensorboard import SummaryWriter"""Reproduction experiment"""
def setup_seed(seed):random.seed(seed)np.random.seed(seed)torch.manual_seed(seed)torch.cuda.manual_seed(seed)torch.cuda.manual_seed_all(seed)# torch.backends.cudnn.benchmark = False# torch.backends.cudnn.enabled = False# torch.backends.cudnn.deterministic = True"""data related"""
def _base_options():parser = argparse.ArgumentParser(description="Train setting for FashionMNIST")# about datasetparser.add_argument('--batch_size', default=8, type=int, help='the batch size of dataset')parser.add_argument('--num_works', default=4, type=int, help="the num_works used")# trainparser.add_argument('--epochs', default=100, type=int, help='train iterations')parser.add_argument('--lr', default=0.001, type=float, help='learning rate')parser.add_argument('--model', default="SimpleNet", choices=["SimpleNet"], help="the model choosed")# log dirparser.add_argument('--log_dir', default="./logger/", help='the path of log file')#parser.add_argument('--best_dice', default=-100, type=int, help='for save weight')parser.add_argument('--weight_dir', default="./weight/", help='the dir for save weight')args = parser.parse_args()return argsdef _load_data():"""download the data, and generate the dataloader"""trans = transforms.Compose([transforms.ToTensor()])train_dataset = torchvision.datasets.FashionMNIST(root='./data/', train=True, download=True, transform=trans)test_dataset = torchvision.datasets.FashionMNIST(root='./data/', train=False, download=True, transform=trans)# print(len(train_dataset), len(test_dataset))train_loader = DataLoader(train_dataset, shuffle=True, batch_size=args.batch_size, num_workers=args.num_works)test_loader = DataLoader(test_dataset, shuffle=True, batch_size=args.batch_size, num_workers=args.num_works)return (train_loader, test_loader)def _device():device = torch.device("cuda" if torch.cuda.is_available() else "cpu")return device"""display data examples"""
def _image_label(labels):text_labels = ['t-shirt', 'trouser', 'pullover', 'dress', 'coat','sandal', 'shirt', 'sneaker', 'bag', 'ankle boot']return [text_labels[int(i)] for i in labels]def _show_images(imgs, rows, columns, titles=None, scale=1.5):figsize = (rows * scale, columns * 1.5)fig, axes = plt.subplots(rows, columns, figsize=figsize)axes = axes.flatten()for i, (img, ax) in enumerate(zip(imgs, axes)):ax.imshow(img)ax.axes.get_xaxis().set_visible(False)ax.axes.get_yaxis().set_visible(False)if titles:ax.set_title(titles[i])plt.show()return axesdef _show_examples():train_loader, test_loader = _load_data()for images, labels in train_loader:images = images.squeeze(1)_show_images(images, 3, 3, _image_label(labels))break"""log"""
class _logger():def __init__(self, log_dir, log_history=True):if log_history:log_dir = os.path.join(log_dir, datetime.datetime.now().strftime("%Y_%m_%d__%H_%M_%S"))self.summary = SummaryWriter(log_dir)def scalar_summary(self, tag, value, step):self.summary.add_scalars(tag, value, step)def images_summary(self, tag, image_tensor, step):self.summary.add_images(tag, image_tensor, step)def figure_summary(self, tag, figure, step):self.summary.add_figure(tag, figure, step)def graph_summary(self, model):self.summary.add_graph(model)def close(self):self.summary.close()"""evaluate the result"""
class AverageMeter():def __init__(self):self.reset()def reset(self):self.val = 0self.avg = 0self.sum = 0self.count = 0def update(self, val, n=1):self.val = valself.sum += val * nself.count += nself.avg = self.sum / self.count"""define the Net"""
class Conv3x3(nn.Module):def __init__(self, in_channels, out_channels, down_sample=False):super(Conv3x3, self).__init__()self.conv = nn.Sequential(nn.Conv2d(in_channels, out_channels, 3, 1, 1),nn.BatchNorm2d(out_channels),nn.ReLU(inplace=True),nn.Conv2d(out_channels, out_channels, 3, 1, 1),nn.BatchNorm2d(out_channels),nn.ReLU(inplace=True))if down_sample:self.conv[3] = nn.Conv2d(out_channels, out_channels, 2, 2, 0)def forward(self, x):return self.conv(x)class SimpleNet(nn.Module):def __init__(self, in_channels, out_channels):super(SimpleNet, self).__init__()self.conv1 = Conv3x3(in_channels, 32)self.conv2 = Conv3x3(32, 64, down_sample=True)self.conv3 = Conv3x3(64, 128)self.conv4 = Conv3x3(128, 256, down_sample=True)self.fc = nn.Linear(256*7*7, out_channels)def forward(self, x):x = self.conv1(x)x = self.conv2(x)x = self.conv3(x)x = self.conv4(x)x = torch.flatten(x, 1)out = self.fc(x)return out"""progress of train/test"""
def train(model, train_loader, test_loader, criterion, optimizor, epochs, device, writer, save_weight=False):train_loss = AverageMeter()test_loss = AverageMeter()train_precision = AverageMeter()test_precision = AverageMeter()time_tick = datetime.datetime.now().strftime("%Y_%m_%d__%H_%M_%S")for epoch in range(epochs):print('\nEpoch: [%d | %d] LR: %f' % (epoch + 1, args.epochs, args.lr))model.train()for input, label in tqdm(train_loader):input, label = input.to(device), label.to(device)output = model(input)# backwardloss = criterion(output, label)optimizor.zero_grad()loss.backward()optimizor.step()# loggerpredict = torch.argmax(output, dim=1)train_pre = sum(predict == label) / len(label)train_loss.update(loss.item(), input.size(0))train_precision.update(train_pre.item(), input.size(0))model.eval()with torch.no_grad():for X, y in tqdm(test_loader):X, y = X.to(device), y.to(device)y_hat = model(X)loss_te = criterion(y_hat, y)predict_ = torch.argmax(y_hat, dim=1)test_pre = sum(predict_ == y) / len(y)test_loss.update(loss_te.item(), X.size(0))test_precision.update(test_pre.item(), X.size(0))if save_weight:best_dice = args.best_diceweight_dir = os.path.join(args.weight_dir, args.model, time_tick)os.makedirs(weight_dir, exist_ok=True)monitor_dice = test_precision.avgif monitor_dice > best_dice:best_dice = max(monitor_dice, best_dice)name = os.path.join(weight_dir, args.model + '_' + str(epoch) + \'_test_loss-' + str(round(test_loss.avg, 4)) + \'_test_dice-' + str(round(best_dice, 4)) + '.pt')torch.save(model.state_dict(), name)print("train" + '---Loss: {loss:.4f} | Dice: {dice:.4f}'.format(loss=train_loss.avg, dice=train_precision.avg))print("test " + '---Loss: {loss:.4f} | Dice: {dice:.4f}'.format(loss=test_loss.avg, dice=test_precision.avg))# summarywriter.scalar_summary("Loss/loss", {"train": train_loss.avg, "test": test_loss.avg}, epoch)writer.scalar_summary("Loss/precision", {"train": train_precision.avg, "test": test_precision.avg}, epoch)writer.close()if __name__ == "__main__":# configargs = _base_options()device = _device()# datatrain_loader, test_loader = _load_data()# loggerwriter = _logger(log_dir=os.path.join(args.log_dir, args.model))# modelmodel = SimpleNet(in_channels=1, out_channels=10).to(device)optimizor = torch.optim.Adam(model.parameters(), lr=args.lr)criterion = nn.CrossEntropyLoss()train(model, train_loader, test_loader, criterion, optimizor, args.epochs, device, writer, save_weight=True)"""    args = _base_options()_show_examples()  # ———>  样例图片显示
"""

训练过程

在这里插入图片描述

日志

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/156035.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

NSIC2050JBT3G 车规级120V 50mA ±15% 用于LED照明的线性恒流调节器(CCR) 增强汽车安全

随着汽车行业的巨大变革,高品质的汽车氛围灯效、仪表盘等LED指示灯效已成为汽车内饰设计中不可或缺的元素。深力科安森美LED驱动芯片系列赋能智能座舱灯效充满艺术感和科技感——NSIC2050JBT3G LED驱动芯片,实现对每路LED亮度和颜色进行细腻控制&#xf…

【深度学习】UniControl 一个统一的扩散模型用于可控的野外视觉生成

论文:https://arxiv.org/abs/2305.11147 代码:https://github.com/salesforce/UniControl#data-preparation docker快速部署:https://qq742971636.blog.csdn.net/article/details/133129146 文章目录 AbstractIntroductionRelated WorksUniCo…

开放式耳机选择什么品牌?六款口碑好爆的开放式耳机盘点

喜欢把户外运动纳入生活计划的朋友,都是懂得享受生活的乖仔,那么大家也需要一副既匹配运动场景又可保护听力,同时还有好音质相伴的耳机一起出发运动吧?在各商家卯足劲儿推出创新产品、伙计们也都挑得眼花缭乱的时候,我…

nginx服务---2

如何统计连接数,以及根据域名配置虚拟主机 cd /usr/local/nginx/conf vim nginx.conf server {listen 80;server_name www.abc.com;charset utf-8;access_log logs/www.abc.com;error_log logs/www.abc.error.log;location / {root /var/www/html/zzr;in…

STM32 CubeMX ADC采集 单通道,多通道,内部温度(轮询,DMA,中断)(HAL库)

STM32 CubeMX ADC采集(HAL库) STM32 CubeMX STM32 CubeMX ADC采集(HAL库)ADC介绍ADC主要特征Vref的电压(2.4~3.6)就是ADC参考电压2.4V(相当于秤砣) 最小识别电压值:2.4/4…

在Android平板上使用code-server公网远程Ubuntu服务器编程

文章目录 1.ubuntu本地安装code-server2. 安装cpolar内网穿透3. 创建隧道映射本地端口4. 安卓平板测试访问5.固定域名公网地址6.结语 1.ubuntu本地安装code-server 准备一台虚拟机,Ubuntu或者centos都可以,这里以VMwhere ubuntu系统为例 下载code server服务,浏览器…

360测试开发技术面试题目

最近面试了360测试开发的职位,将面试题整理出来分享~ 一、java方面 1、java重载和重写的区别 重载overloading 多个方法、相同的名字,不同的参数 重写overwrite 子类继承父类,对方法进行重写 2、java封装的特性 可以改变内部实现,…

python - excel 设置样式

文章目录 前言python - excel 设置样式1. 准备2. 示例2.1. 给单元格设置样式"等线"、大小为24磅、斜体、红色颜色和粗体2.2. 给第二行设置样式"宋体"、大小为16磅、斜体、红色颜色和粗体2.3. 给第三行数据设置垂直居中和水平居中2.4. 给第四行设置行高为30…

RabbitMQ开启消息跟踪日志(trace)

Trace 是Rabbitmq用于记录每一次发送的消息,方便使用Rabbitmq的开发者调试、排错。 1、启动Tracing插件 在RabbitMQ中默认是关闭的,需手动开启。此处rabbitMQ是使用docker部署的 ## 进入rabbitMq中 docker exec -it rabbitmq1 bash ## 启动日志插件 r…

SpringCloud之Stream框架集成RocketMQ消息中间件

Spring Cloud Stream 是一个用来为微服务应用构建消息驱动能力的框架。它可以基于 Spring Boot 来创建独立的、可用于生产的 Spring 应用程序。Spring Cloud Stream 为一些供应商的消息中间件产品提供了个性化的自动化配置实现,并引入了发布-订阅、消费组、分区这三…

【C语言】文件的操作与文件函数的使用(详细讲解)

前言:我们在学习C语言的时候会发现在编写一个程序的时候,数据是存在内存当中的,而当我们退出这个程序的时候会发现这个数据不复存在了,因此我们可以通过文件把数据记录下来,使用文件我们可以将数据直接存放在电脑的硬盘…

进来了解实现官网搜索引擎的三种方法

做网站的目的是对自己的品牌进行推广,让越来越多的人知道自己的产品,但是如果只是做了一个网站放着,然后等着生意找上门来那是不可能的。在当今数字时代,实现官网搜索引擎对于提升用户体验和推动整体性能至关重要。搜索引擎可以帮…

【C#】什么是并发,C#常规解决高并发的基本方法

给自己一个目标,然后坚持一段时间,总会有收获和感悟! 在实际项目开发中,多少都会遇到高并发的情况,有可能是网络问题,连续点击鼠标无反应快速发起了N多次调用接口, 导致极短时间内重复调用了多次…

互联网性能和可用性优化CDN和DNS

当涉及到互联网性能和可用性优化时,DNS(Domain Name System)和CDN(Content Delivery Network)是两个至关重要的元素。它们各自发挥着关键作用,以确保用户能够快速、可靠地访问网站和应用程序。在本文中&…

第一个 Python 程序

三、第一个 Python 程序 好了,说了那么多,现在我们可以来写一下第一个 Python 程序了。 一开始写 Python 程序,个人不太建议用专门的工具来写,不方便熟悉语法,所以这里我先用 Sublime Text 来写,后期可以…

如何在会计面试中展现自己的优势?

在会计面试中展现自己的优势是非常重要的,因为这将决定你是否能够脱颖而出并获得这个职位。下面是一些可以帮助你展示自己优势的方法: 1. 准备充分:在面试前,确保你对公司的背景和业务有所了解。研究公司的财务报告和新闻&#xf…

智能井盖传感器:提升城市安全与便利的利器

在智能化城市建设的浪潮中,WITBEE万宾智能井盖传感器,正以其卓越的性能和创新的科技,吸引着越来越多的关注。本文小编将为大家详细介绍这款产品的独特优势和广阔应用前景。 在我们生活的城市中,井盖可能是一个最不起眼的存在。然而…

ctfshow萌新计划web9-14(正则匹配绕过)

目录 web9 web10 web11 web12 web13 web14 web9 审一下代码,需要匹配到system|exec|highlight才会执行eval函数 先看一下当前目录下有什么 payload:?csystem(ls); index.php是首页,我们看看config.php payload:?csystem…

Android笔记(二):JetPack Compose定义移动界面概述

一、JetPack Compose组件概述 JetPack Compose是Google公司在2021年正式推出的声明式UI工具包。Compose库用于开发原生Android应用界面。它取代传统XML文件配置界面,不需要界面编辑工具,而是采用强大Kotlin API以及函数搭建移动应用界面,代码…

nextjs构建服务端渲染,同时使用Material UI进行项目配置

一、创建一个next项目 使用create-next-app来启动一个新的Next.js应用,它会自动为你设置好一切 运行命令: npx create-next-applatest 执行结果如下: 启动项目: pnpm dev 执行结果: 启动成功! 二、安装Mater…