PyTorch入门之【AlexNet】

参考文献:https://www.bilibili.com/video/BV1DP411C7Bw/?spm_id_from=333.999.0.0&vd_source=98d31d5c9db8c0021988f2c2c25a9620
AlexNet 是一个经典的卷积神经网络模型,用于图像分类任务。

目录

  • 大纲
  • dataloader
  • model
  • train
  • test

大纲

在这里插入图片描述
各个文件的作用:

  • data就是数据集
  • dataloader.py就是数据集的加载以及实例初始化
  • model.py就是AlexNet模块的定义
  • train.py就是模型的训练
  • test.py就是模型的测试

dataloader

import torch
import torchvision
import torchvision.transforms as transformsimport matplotlib.pyplot as plt
import numpy as np# define the dataloader
transform = transforms.Compose([transforms.Resize(224),transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])batch_size = 16trainset = torchvision.datasets.CIFAR10(root='./data', train=True,download=True, transform=transform)
train_loader = torch.utils.data.DataLoader(trainset, batch_size=batch_size,shuffle=True)testset = torchvision.datasets.CIFAR10(root='./data', train=False,download=True, transform=transform)
test_loader = torch.utils.data.DataLoader(testset, batch_size=batch_size,shuffle=False)classes = ('plane', 'car', 'bird', 'cat','deer', 'dog', 'frog', 'horse', 'ship', 'truck')if __name__ == '__main__':# get some random training imagesdataiter = iter(train_loader)images, labels = next(dataiter)# print labelsprint(' '.join('%5s' % classes[labels[j]] for j in range(batch_size)))# show imagesimg_grid = torchvision.utils.make_grid(images)img_grid = img_grid / 2 + 0.5npimg = img_grid.numpy()plt.imshow(np.transpose(npimg, (1, 2, 0)))plt.show()

model

import torch.nn as nn
import torchclass AlexNet(nn.Module):def __init__(self, num_classes=10):super(AlexNet, self).__init__()self.conv_1 = nn.Sequential(nn.Conv2d(3, 96, kernel_size=11, stride=4, padding=2),nn.BatchNorm2d(96),nn.ReLU(),nn.MaxPool2d(kernel_size = 3, stride = 2))self.conv_2 = nn.Sequential(nn.Conv2d(96, 256, kernel_size=5, stride=1, padding=2),nn.BatchNorm2d(256),nn.ReLU(),nn.MaxPool2d(kernel_size = 3, stride = 2))self.conv_3 = nn.Sequential(nn.Conv2d(256, 384, kernel_size=3, stride=1, padding=1),nn.BatchNorm2d(384),nn.ReLU())self.conv_4 = nn.Sequential(nn.Conv2d(384, 384, kernel_size=3, stride=1, padding=1),nn.BatchNorm2d(384),nn.ReLU())self.conv_5 = nn.Sequential(nn.Conv2d(384, 256, kernel_size=3, stride=1, padding=1),nn.BatchNorm2d(256),nn.ReLU(),nn.MaxPool2d(kernel_size = 3, stride = 2))self.fc_1 = nn.Sequential(nn.Dropout(0.5),nn.Linear(9216, 4096),nn.ReLU())self.fc_2 = nn.Sequential(nn.Dropout(0.5),nn.Linear(4096, 4096),nn.ReLU())self.fc_3= nn.Sequential(nn.Linear(4096, num_classes))def forward(self, x):out = self.conv_1(x)out = self.conv_2(out)out = self.conv_3(out)out = self.conv_4(out)out = self.conv_5(out)out = out.reshape(out.size(0), -1)out = self.fc_1(out)out = self.fc_2(out)out = self.fc_3(out)return outif __name__ == '__main__':model = AlexNet()print(model)x = torch.randn(1, 3, 224, 224)y = model(x)print(y.size())

train

import torch
import torch.nn as nnfrom dataloader import train_loader, test_loader
from model import AlexNet# define the hyperparameters
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
num_classes = 10
num_epochs = 20
learning_rate = 1e-3# load the model
model = AlexNet(num_classes).to(device)# loss and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)  # train the model
total_len = len(train_loader)for epoch in range(num_epochs):for i, (images, labels) in enumerate(train_loader):# move tensors to the configured deviceimages = images.to(device)labels = labels.to(device)# forward passoutputs = model(images)loss = criterion(outputs, labels)# backward and optimizeoptimizer.zero_grad()loss.backward()optimizer.step()if (i+1) % 100 == 0:print('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}'.format(epoch+1, num_epochs, i+1, total_len, loss.item()))# Validationwith torch.no_grad():model.eval()correct = 0total = 0for images, labels in test_loader:images = images.to(device)labels = labels.to(device)outputs = model(images)_, predicted = torch.max(outputs.data, 1)total += labels.size(0)correct += (predicted == labels).sum().item()model.train()print('Accuracy of the network on the {} validation images: {} %'.format(10000, 100 * correct / total))# save the model checkpoint
torch.save(model.state_dict(), 'alexnet.pth')

test

import torchfrom dataloader import test_loader, classes
from model import AlexNet# load the pretrained model
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = AlexNet().to(device)
model.load_state_dict(torch.load('alexnet.pth', map_location=device))# test the pretrained model on CIFAR-10 test data
with torch.no_grad():model.eval()correct = 0total = 0for images, labels in test_loader:images = images.to(device)labels = labels.to(device)outputs = model(images)_, predicted = torch.max(outputs.data, 1)total += labels.size(0)correct += (predicted == labels).sum().item()print('Accuracy of the network on the {} validation images: {} %'.format(10000, 100 * correct / total))

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/149710.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

软件设计师_数据结构与算法基础_学习笔记

文章目录 6.1 数组与矩阵6.1.1 数组6.1.2 稀疏矩阵 6.2 线性表6.2.1 数据结构的定义6.2.2 顺序表与链表6.2.2.1 定义6.2.2.2 链表的操作 6.2.3 顺序存储和链式存储的对比6.2.4 队列、循环队列、栈6.2.4.2 循环队列队空与队满条件6.2.4.3 出入后不可能出现的序列练习 6.2.5 串6.…

Hive【Hive(六)窗口函数】

窗口函数(window functions) 概述 定义 窗口函数能够为每行数据划分 一个窗口,然后对窗口范围内的数据进行计算,最后将计算结果返回给该行数据。 语法 窗口函数的语法主要包括 窗口 和 函数 两个部分。其中窗口用于定义计算范围…

【计算机网络面试题(62道)】

文章目录 计算机网络面试题(62道)基础1.**说下计算机网络体系结构2.说一下每一层对应的网络协议有哪些?3.那么数据在各层之间是怎么传输的呢? 网络综合4.**从浏览器地址栏输入 url 到显示主页的过程?5.说说 DNS 的解析…

LabVIEW工业虚拟仪器的标准化实施

LabVIEW工业虚拟仪器的标准化实施 创建计算机化的测试和测量系统,从计算机桌面控制外部测量硬件设备,以及在计算机屏幕上显示的类似仪器的面板上查看来自外部设备的测试或测量数据,所有这些都需要虚拟仪器系统软件。该软件允许用户执行所有这…

VL53L5CX驱动开发(1)----驱动TOF进行区域检测

VL53L5CX驱动开发----1.驱动TOF进行区域检测 闪烁定义视频教学样品申请源码下载主要特点硬件准备技术规格系统框图应用示意图区域映射生成STM32CUBEMX选择MCU 串口配置IIC配置X-CUBE-TOF1串口重定向代码配置Tera Term配置演示结果 闪烁定义 VL53L5CX是一款先进的飞行感应&…

总结二:linux面经

文章目录 1、 Linux中查看进程运行状态的指令、查看内存使用情况的指令、tar解压文件的参数。2、文件权限怎么修改?3、说说常用的Linux命令?4、说说如何以root权限运行某个程序?5、 说说软链接和硬链接的区别?6、说说静态库和动态…

【目标检测】——PE-YOLO精读

yolo,暗光目标检测 论文:PE-YOLO 1. 简介 卷积神经网络(CNNs)在近年来如何推动了物体检测的发展。许多检测器已经被提出,而且在许多基准数据集上的性能正在不断提高。然而,大多数现有的检测器都是在正常条…

1700*C. Number of Ways(贪心前缀和)

Problem - 466C - Codeforces Number of Ways - 洛谷 解析: 首先判断所有数总和是否能被三整除。 之后遍历前缀和数组,如果某个位置的前缀和等于sum/3,则记录。 某个位置前缀和等于sum/3*2则记录答案。 注意由于分成三份,所以同…

Qt 设置软件的版本信息:QMake、CMake工程

本文借鉴了Qt 设置软件的版本信息 - 疯狂delphi - 博客园 (cnblogs.com) 在原文基础增加了CMake工程实现的方法。 Qt设置软件的版本等信息 对于Qt开发的软件,我们如何去方便的查看其软件的版本信息。这里提供了几种方式。 在运行程序期间设置版本信息 大部分的程序…

黑马点评-01基于Redis实现短信登陆的功能

环境准备 当前模型 nginx服务器的作用 手机或者app端向nginx服务器发起请求,nginx基于七层模型走的是HTTP协议,可以实现基于Lua直接绕开tomcat访问Redis nginx也可以作为静态资源服务器,轻松扛下上万并发并负载均衡到下游的tomcat服务器,利用集群支撑起整个项目 使用nginx部…

二分查找:34. 在排序数组中查找元素的第一个和最后一个位置

个人主页 : 个人主页 个人专栏 : 《数据结构》 《C语言》《C》《算法》 文章目录 前言一、题目解析二、解题思路1. 暴力查找2. 一次二分查找 部分遍历3. 两次二分查找分别查找左右端点1.查找区间左端点2. 查找区间右端点 三、代码实现总结 前言 本篇文…

力扣 -- 873. 最长的斐波那契子序列的长度

解题步骤&#xff1a; 参考代码&#xff1a; class Solution { public:int lenLongestFibSubseq(vector<int>& nums) {int nnums.size();unordered_map<int,int> hash;for(int i0;i<n;i){hash[nums[i]]i;}int ret2;vector<vector<int>> dp(n,v…

基于SSM的旅游网站设计与实现

末尾获取源码 开发语言&#xff1a;Java Java开发工具&#xff1a;JDK1.8 后端框架&#xff1a;SSM 前端&#xff1a;采用JSP技术开发 数据库&#xff1a;MySQL5.7和Navicat管理工具结合 服务器&#xff1a;Tomcat8.5 开发软件&#xff1a;IDEA / Eclipse 是否Maven项目&#x…

国庆看坚如磐石

坚如磐石上映了&#xff0c;可以在爱奇艺观看。 而博主在使用蓝牙耳机连接电脑的过程中&#xff0c;发现没有蓝牙开启选项&#xff0c;并且在服务的设备管理器中也没有找到&#xff0c;很明显这是缺少驱动导致的&#xff0c;因此便去联想官方网站下载对应的驱动。 这里可以输入…

外包做了3个月,技术退步明显。。。。。

先说一下自己的情况&#xff0c;大专生&#xff0c;17年通过校招进入广州某软件公司&#xff0c;干了接近4年的功能测试&#xff0c;今年年初&#xff0c;感觉自己不能够在这样下去了&#xff0c;长时间呆在一个舒适的环境会让一个人堕落!而我已经在一个企业干了四年的功能测试…

python开发幸运水果抽奖大转盘

概述 当我女朋友跟我说要吃水果&#xff0c;又不知道吃啥水果时候&#xff0c;她以为难为到我了&#xff0c;有啥事难为到程序员的呢&#xff01; 今天用python利用第三方tkinterthreadingtime库开发一个幸运水果抽奖大转盘&#xff01;抽到啥吃啥 详细 老规矩&#xff01;咱…

初级数值计算理论总结

本文用于总结复习与研究生面试 一问&#xff0c;小伙子会不会数值计算啊一答&#xff1a;会二问&#xff1a;哦&#xff0c;讲讲看二答&#xff1a;讲不出来三问&#xff1a;...... 数值求根 二分法Jacobi 迭代法 Jacobi 迭代改进算法&#xff08;事后加速法&#xff09;&#…

频次直方图、KDE和密度图

Seaborn的主要思想是用高级命令为统计数据探索和统计模型拟合创建各种图形&#xff0c;下面将介绍一些Seaborn中的数据集和图形类型。 虽然所有这些图形都可以用Matplotlib命令实现&#xff08;其实Matplotlib就是Seaborn的底层&#xff09;&#xff0c;但是用 Seaborn API会更…

用JMeter对HTTP接口进行压测(一)压测脚本的书写、调试思路

文章目录 安装JMeter和Groovy为什么选择Groovy&#xff1f; 压测需求以及思路准备JMeter脚本以及脚本正确性验证使用Test Script Recorder来获取整条业务线上涉及的接口为什么使用Test Script Recorder&#xff1f; 配置Test Script Recorder对接口进行动态化处理处理全局变量以…

几种开源协议的区别(Apache、MIT、BSD、MPL、GPL、LGPL)

作为一名软件开发人员&#xff0c;你一定也是经常接触到开源软件&#xff0c;但你真的就了解这些开源软件使用的开源许可协议吗&#xff1f; 你不会真的认为&#xff0c;开源就是完全免费吧&#xff1f;那么让我们通过本文来寻找答案。 一、开源许可协议简述 开源许可协议是指开…