P6打卡—Pytorch实现人脸识别

  •   🍨 本文为🔗365天深度学习训练营中的学习记录博客
  • 🍖 原作者:K同学啊

1.检查GPU

import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import torchvisiondevice=torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

2.查看数据

import os,PIL,random,pathlib
data_dir = pathlib.Path('data/48-data')
data_dir=pathlib.Path(data_dir)
data_path=list(data_dir.glob("*"))
ClassNames=[str(path).split('\\')[2] for path in data_path]
ClassNames

​​​​​

3.划分数据集

train_trainsforms=transforms.Compose([transforms.Resize([224,224]),transforms.ToTensor(),transforms.Normalize(mean=[0.486,0.456,0.406],std=[0.229,0.224,0.225])
]
)
total_data=datasets.ImageFolder("data/48-data",transform=train_trainsforms)
total_datatotal_data.class_to_idxtrain_size=int(0.8*len(total_data))
test_size=len(total_data)-train_size
train_dataset,test_dataset=torch.utils.data.random_split(total_data,(train_size,test_size))
train_dataset,test_datasetbatch_size=32
train_dl=torch.utils.data.DataLoader(train_dataset,batch_size,shuffle=True,num_workers=1)
test_dl=torch.utils.data.DataLoader(test_dataset,batch_size,shuffle=True,num_workers=1)for X,y in train_dl:print(X.shape)print(y.shape)break

​​

4.调用官方模型

from torchvision.models import vgg16
print("Using {} device".format(device))
model=vgg16(pretrained=True).to(device)
for param in model.parameters():param.requires_grad=False
model.classifier._modules['6']=nn.Linear(4096,len(ClassNames))
model.to(device)
model


​​​

5.动态调整学习率函数

#调用官方动态学习率接口
learning_rate = 1e-4
lambda1=lambda epoch:0.92**(epoch//4)
optimizer=torch.optim.SGD(model.parameters(),lr=learning_rate)
scheduler=torch.optim.lr_scheduler.LambdaLR(optimizer,lr_lambda=lambda1)

6.编译及训练模型

def train(dataloader,model,loss_fn,optimizer):size=len(dataloader.dataset)num_batches=len(dataloader)train_loss,train_acc=0,0for X,y in dataloader:X,y =X.to(device),y.to(device)pred=model(X)loss=loss_fn(pred,y)#反向传播optimizer.zero_grad()loss.backward()optimizer.step()train_loss+=loss.item()train_acc+=(pred.argmax(1)==y).type(torch.float).sum().item()train_acc/=sizetrain_loss/=num_batchesreturn train_acc,train_lossdef test(dataloader,model,loss_fn):size=len(dataloader.dataset)num_batches=len(dataloader)test_loss,test_acc=0,0with torch.no_grad():for imgs,target in dataloader:imgs,target=imgs.to(device),target.to(device)target_pred=model(imgs)loss=loss_fn(target_pred,target)test_loss+=loss.item()test_acc+=(target_pred.argmax(1)==target).type(torch.float).sum().item()test_acc/=sizetest_loss/=num_batchesreturn test_acc,test_lossimport copy
loss_fn=nn.CrossEntropyLoss()
epochs=40
train_loss=[]
train_acc=[]
test_loss=[]
test_acc=[]
best_acc=0
for epoch in range(epochs):model.train()epoch_train_acc,epoch_train_loss=train(train_dl,model,loss_fn,optimizer)#更新学习率scheduler.step()model.eval()epoch_test_acc,epoch_test_loss=test(test_dl,model,loss_fn)if epoch_test_acc>=best_acc:best_acc=epoch_test_accbest_model=copy.deepcopy(model)train_acc.append(epoch_train_acc)train_loss.append(epoch_train_loss)test_acc.append(epoch_test_acc)test_loss.append(epoch_test_loss)lr=optimizer.state_dict()['param_groups'][0]['lr']template = ('Epoch:{:2d}, Train_acc:{:.1f}%, Train_loss:{:.3f}, Test_acc:{:.1f}%, Test_loss:{:.3f}, Lr:{:.2E}')print(template.format(epoch+1, epoch_train_acc*100, epoch_train_loss, epoch_test_acc*100, epoch_test_loss, lr))
PATH='./best_model.pth'
torch.save(best_model.state_dict(),PATH)
print('Finished Training')

​​​​​​

7.结果可视化

import matplotlib.pyplot as plt
import warnings
warnings.filterwarnings("ignore")
plt.rcParams['font.sans-serif']=['SimHei']
plt.rcParams['axes.unicode_minus']=False
plt.rcParams['figure.dpi']=100epochs_range = range(epochs)plt.figure(figsize=(12, 3))
plt.subplot(1, 2, 1)
plt.plot(epochs_range, train_acc, label='Training Accuracy')
plt.plot(epochs_range, test_acc, label='Test Accuracy')
plt.legend(loc='lower right')
plt.title('Training and Validation Accuracy')plt.subplot(1, 2, 2)
plt.plot(epochs_range, train_loss, label='Training Loss')
plt.plot(epochs_range, test_loss, label='Test Loss')
plt.legend(loc='upper right')
plt.title('Training and Validation Loss')
plt.show()

​​

8.预测本地图片

from PIL import Image
classes=list(total_data.class_to_idx)
def predict_one_image(image_path,model,transform,classes):test_img=Image.open(image_path).convert('RGB')plt.imshow(test_img)test_img=transform(test_img)img=test_img.to(device).unsqueeze(0)model=model.eval()output=model(img)_,pred=torch.max(output,1)pred_class=classes[pred]print('预测结果是:{pred_class}')predict_one_image(image_path='data/48-data/Angelina Jolie/005_582c121a.jpg',model=model,transform=train_trainsforms,classes=classes)
#查看最优损失及准确率
best_model.eval()
epoch_test_Acc,epoch_test_loss=test(test_dl,best_model,loss_fn)
epoch_test_Acc,epoch_test_loss

​​

​​总结:

1.VGG-16

VGG-16(Visual Geometry Group-16)是由牛津大学视觉几何组(Visual Geometry Group)提出的一种深度卷积神经网络架构,用于图像分类和对象识别任务。VGG-16在2014年被提出,是VGG系列中的一种。VGG-16之所以备受关注,是因为它在ImageNet图像识别竞赛中取得了很好的成绩,展示了其在大规模图像识别任务中的有效性。

以下是VGG-16的主要特点:

  1. 深度:VGG-16由16个卷积层和3个全连接层组成,因此具有相对较深的网络结构。这种深度有助于网络学习到更加抽象和复杂的特征。
  2. 卷积层的设计:VGG-16的卷积层全部采用3x3的卷积核和步长为1的卷积操作,同时在卷积层之后都接有ReLU激活函数。这种设计的好处在于,通过堆叠多个较小的卷积核,可以提高网络的非线性建模能力,同时减少了参数数量,从而降低了过拟合的风险。
  3. 池化层:在卷积层之后,VGG-16使用最大池化层来减少特征图的空间尺寸,帮助提取更加显著的特征并减少计算量。
  4. 全连接层:VGG-16在卷积层之后接有3个全连接层,最后一个全连接层输出与类别数相对应的向量,用于进行分类。

VGG-16结构说明:

  • 13个卷积层(Convolutional Layer),分别用blockX_convX表示;
  • 3个全连接层(Fully connected Layer),用classifier表示;
  • 5个池化层(Pool layer)。

VGG-16包含了16个隐藏层(13个卷积层和3个全连接层),故称为VGG-16

2.设置动态学习率

#非官方设置动态学习率
def adjust_learning_rate(optimizer, epoch, start_lr):# 每 2 个epoch衰减到原来的 0.98lr = start_lr * (0.92 ** (epoch // 2))for param_group in optimizer.param_groups:param_group['lr'] = lrlearn_rate = 1e-4 # 初始学习率
optimizer  = torch.optim.SGD(model.parameters(), lr=learn_rate)#官方设置动态学习率
# 调用官方动态学习率接口时使用
lambda1 = lambda epoch: 0.92 ** (epoch // 4)
optimizer = torch.optim.SGD(model.parameters(), lr=learn_rate)
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda1) #选定调整方法model = [torch.nn.Parameter(torch.randn(2, 2, requires_grad=True))]
optimizer = SGD(model, 0.1)
scheduler = ExponentialLR(optimizer, gamma=0.9)for epoch in range(20):for input, target in dataset:optimizer.zero_grad()output = model(input)loss = loss_fn(output, target)loss.backward()optimizer.step()scheduler.step()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/491455.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

R square 的计算方法和一点思考

模型的性能评价指标有几种方案:RMSE(平方根误差)、MAE(平均绝对误差)、MSE(平均平方误差)、R2_score 其中,当量纲不同时,RMSE、MAE、MSE难以衡量模型效果好坏。这就需要用到R2_score&#xff1…

解决并发情况下调用 Instruct-pix2pix 模型推理错误:index out of bounds 问题

解决并发情况下调用 Instruct-pix2pix 模型推理错误:index out of bounds 问题 背景介绍 在对 golang 开发的 图像生成网站 进行并发测试时,调用基于 Instruct-pix2pix 模型和 FastAPI 的图像生成 API 遇到了以下错误: Model inference er…

利用DFT画有限长序列的DTFT

MATLAB中没有DTFT函数,计算机不可能给出连续结果,可以只能利用DFT的fft函数来实现。 %% L 7; x ones(1, L) figure; tiledlayout(2,3,"TileSpacing","tight") nexttile; stem([0:L-1],x) box off title([num2str(L), points rect…

【进程篇】03.进程的概念与基本操作

一、进程的概念与理解 1.1 概念 进程是程序的一个执行实例,即正在执行的程序。 1.2 理解 我们编写代码运行后会在磁盘中会形成一个可执行程序,当我们运行这个可执行程序时,这个程序此时就会被操作系统的调度器加载到内存中;操…

基于MATLAB 的数字图像处理技术总结

大家好!欢迎来到本次的总结性的一篇文章,因为咸鱼哥这几个月是真的有点小忙(参加了点小比赛,准备考试等等)所以,在数字图像学习后,我来写一个总结性的文章,同时帮助大家学习&#xf…

llama2——微调lora,第一次参考教程实践完成包括训练和模型

前言:磕磕绊绊,不过收获很多,最大的收获就是解决报错error的分析方法和解决思路 1、首先,我参考的是这篇博客:怎样训练一个自己的大语言模型?全网最简单易懂的教程!_开源模型训练出一个语言模型…

类OCSP靶场-Kioptrix系列-Kioptrix Level 3

一、前情提要 二、实战打靶 1. 信息收集 1.1. 主机发现 1.2. 端口扫描 1.3.目录遍历 1.4. 敏感信息 2.漏洞发现 2.1.登录功能账号密码爆破 2.2.CMS历史漏洞 2.2.1.exp利用 2.2.2.提权 2.3. sql注入getshell 2.3.1.发现注入点 2.3.2. 测试字段和类型 2.3.3.查询字…

WPF实现曲线数据展示【案例:震动数据分析】

wpf实现曲线数据展示,函数曲线展示,实例:震动数据分析为例。 如上图所示,如果你想实现上图中的效果,请详细参考我的内容,创作不易,给个赞吧。 一共有两种方式来实现,一种是使用第三…

PHP代码审计学习(一)--命令注入

1、漏洞原理 参数用户可控&#xff0c;程序将用户可控的恶意参数通过php可执行命令的函数中运行导致。 2、示例代码 <?php echorec-test; $command ping -c 1 .$_GET[ip]; system($command); //system函数特性 执行结果会自动打印 ?> 通过示例代码可知通过system函…

Vivado安装System Generator不支持新版Matlab解决方法

目录 前言&#xff1a; Vivado安装System Generator不支持新版Matlab解决方法 前言&#xff1a; 本文介绍一下Vivado不支持新版Matlab的解决办法&#xff0c;Vivado只支持最近两年3个版本的Matlab&#xff0c;当前最新版vivado 2018.3只支持2017a,2017b,2018a。 Vivado安装Sy…

半导体数据分析(二):徒手玩转STDF格式文件 -- 码农切入半导体系列

一、概述 在上一篇文章中&#xff0c;我们一起学习了STDF格式的文件&#xff0c;知道了这是半导体测试数据的标准格式文件。也解释了为什么码农掌握了STDF文件之后&#xff0c;好比掌握了切入半导体行业的金钥匙。 从今天开始&#xff0c;我们一起来一步步地学习如何解构、熟…

#渗透测试#漏洞挖掘#红蓝攻防#SRC漏洞挖掘02之权限漏洞挖掘技巧

免责声明 本教程仅为合法的教学目的而准备,严禁用于任何形式的违法犯罪活动及其他商业行为,在使用本教程前,您应确保该行为符合当地的法律法规,继续阅读即表示您需自行承担所有操作的后果,如有异议,请立即停止本文章读。 权限相关漏洞 越权、未授权访问、oss、后台暴露、…

IS-IS协议

IS-IS协议介绍 IS-IS&#xff08;Intermediate System to Intermediate System&#xff09;协议是一种链路状态的内部网关协议&#xff08;IGP&#xff09;&#xff0c;用于在同一个自治系统&#xff08;Autonomous System, AS&#xff09;内部的路由器之间交换路由信息。IS-I…

4.7 TCP 的流量控制

欢迎大家订阅【计算机网络】学习专栏&#xff0c;开启你的计算机网络学习之旅&#xff01; 文章目录 前言1 滑动窗口与流量控制2 持续计时器与零窗口探测3 控制TCP发送报文段的时机3.1 控制发送时机的三种机制3.2 糊涂窗口综合症 前言 在网络通信中&#xff0c;流量控制是确保…

不良人系列-复兴数据结构(栈和队列)

个人主页&#xff1a;爱编程的小新☆ 不良人经典语录&#xff1a;“相呴相济 玉汝于成 勿念 心安” 目录 一. 栈(stack) 1. 栈的概念 2. 栈的常见方法 3.栈的模拟实现 ​编辑 二. 队列 1. 队列的概念 2. 队列的使用 2.1 队列的常见方法 2.2 队列的模拟实现 2.3 队列…

在clion中使用MySQL的教程

首先就是配置好东西&#xff0c;也是非常简单的&#xff1a; 1.把mysql安装目录&#xff08;其中的lib好像&#xff09;中的2个文件复制到下面就行 2.然后配置&#xff0c;这个文件 cmake_minimum_required(VERSION 3.24) project(2024_12project)include_directories(D:\\mys…

某名校考研自命题C++程序设计——近10年真题汇总(下)

第二期&#xff0c;相比上一贴本帖的题目难度更高一些&#xff0c;我当然不会告诉你我先挑简单的写~ 某名校考研自命题C程序设计——近10年真题汇总&#xff08;上&#xff09;-CSDN博客文章浏览阅读651次&#xff0c;点赞9次&#xff0c;收藏13次。本帖更新一些某校的编程真题…

探讨不同类型的自动化测试框架

以下为作者观点&#xff1a; 在自动化测试中&#xff0c;框架提供了一种组织和执行测试案例的结构化方式。它们提供了一套准则和最佳实践&#xff0c;使测试人员能够编写可重复使用、可维护和可扩展的测试脚本。在这篇文章中&#xff0c;我们将讨论自动化测试中不同类型的框架…

C# 网络编程--关于Socket编程TCP协议中封包、拆包问题

在使用 Socket 编程&#xff0c;进行TCP协议网络通信时&#xff0c;经常会遇到“粘包”&#xff08;也称为“封包、拆包”&#xff09;的问题。粘包是指发送方发送的多个数据包被接收方合并成一个数据包&#xff0c;或者一个数据包被拆分成多个数据包接收。这通常是由于 TCP协议…

HarmonyOS:@Observed装饰器和@ObjectLink装饰器:嵌套类对象属性变化

装饰器仅能观察到第一层的变化&#xff0c;但是在实际应用开发中&#xff0c;应用会根据开发需要&#xff0c;封装自己的数据模型。对于多层嵌套的情况&#xff0c;比如二维数组&#xff0c;或者数组项class&#xff0c;或者class的属性是class&#xff0c;他们的第二层的属性变…