残差网络实现

代码中涉及的图片实验数据下载地址:https://download.csdn.net/download/m0_37567738/88235543?spm=1001.2014.3001.5501

代码:

import torch
import torch.nn as nn
import torch.nn.functional as F
#from utils import load_data,get_accur,train
import timeimport torchvision
from torchvision import transforms
from torch.utils.data import DataLoader
import torch
import torch.optim as optim
import numpy as npdef load_data(path, batch_size):datasets = torchvision.datasets.ImageFolder(root = path,transform = transforms.Compose([transforms.ToTensor()]))dataloder = DataLoader(datasets, batch_size=batch_size, shuffle=True)return datasets,dataloderdef get_accur(preds, labels):preds = preds.argmax(dim=1)return torch.sum(preds == labels).item()def train(model, epochs, learning_rate, dataloader, criterion, testdataloader):optimizer = optim.Adam(model.parameters(),lr=learning_rate)train_loss_list = []test_loss_list = []train_accur_list = []test_accur_list = []train_len = len(dataloader.dataset)test_len = len(testdataloader.dataset)for i in range(epochs):train_loss = 0.0train_accur = 0test_loss = 0.0test_accur = 0for batch in dataloader:imgs, labels = batchpreds = model(imgs)optimizer.zero_grad()loss = criterion(preds, labels)loss.backward()optimizer.step()train_loss += loss.item()train_accur += get_accur(preds,labels)train_loss_list.append(train_loss)train_accur_list.append(train_accur / train_len)for batch in testdataloader:imgs, labels = batchpreds = model(imgs)loss = criterion(preds, labels)test_loss += loss.item()test_accur += get_accur(preds,labels)test_loss_list.append(test_loss)test_accur_list.append(test_accur / test_len)print("epoch {} : train_loss : {}; train_accur : {}".format(i + 1, train_loss, train_accur / train_len))return np.array(train_accur_list), np.array(train_loss_list), np.array(test_accur_list), np.array(test_loss_list)class ResidualBlock(nn.Module):def __init__(self, inchannel, outchannel, stride=1):super().__init__()self.left = nn.Sequential(nn.Conv2d(inchannel, outchannel, kernel_size=3, stride=stride, padding=1, bias=False),nn.BatchNorm2d(outchannel),nn.ReLU(inplace=True),nn.Conv2d(outchannel, outchannel, kernel_size=3, stride=1,padding=1, bias=False),# 尺寸不发生变化 通道改变nn.BatchNorm2d(outchannel))self.shortcut = nn.Sequential()# 注意shortcut是对输入X进行卷积,利用1×1卷积改变形状if inchannel != outchannel or stride != 1:self.shortcut = nn.Sequential(nn.Conv2d(inchannel, outchannel, kernel_size=1, stride=stride, bias=False),nn.BatchNorm2d(outchannel))def forward(self, X):h = self.left(X)# 先相加再激活h += self.shortcut(X)out = F.relu(h)return outclass ResidualNet(nn.Module):def __init__(self):super().__init__()self.residual_block = nn.Sequential(ResidualBlock(3, 32),ResidualBlock(32, 64),ResidualBlock(64, 32),ResidualBlock(32, 3))self.fc1 = nn.Linear(3 * 64 * 64, 1024)self.fc2 = nn.Linear(1024, 3)def forward(self, X):h = self.residual_block(X)h = h.view(-1, 3 * 64 * 64)h = self.fc1(h)out = self.fc2(h)return outif __name__ == "__main__":train_path = "./cnn/train/"test_path = "./cnn/test/"_, train_dataloader = load_data(train_path, 32)_, test_dataloader = load_data(test_path, 32)model = ResidualNet()critic = nn.CrossEntropyLoss()epoch = 20lr = 0.01start = time.clock()print("Start training model.....")train_accur_list, train_loss_list, test_accur_list, test_loss_list = train(model, epoch, lr, train_dataloader,critic, test_dataloader)end = time.clock()print("Train cost: {} s".format(end - start))test_accur = 0for batch in test_dataloader:imgs, labels = batchpreds = model(imgs)test_accur += get_accur(preds, labels)print("Accuracy on test datasets : {}".format(test_accur / len(test_dataloader.dataset)))

执行结果:

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/98514.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

茂名 湛江阳江某学校 ibm x3850服务器维修经历

简介:中国广东省阳江市某中学联想 IBM System x3850 x6服务器维修 io板故障处理经历分享: 这一天一位阳江的老师经其他学校老师介绍推荐对接我,说他们学校有一台ibm服务器出问题了,老师大致跟我描述了一下这台服务器发生故障的前…

Android12之com.android.media.swcodec无法生成apex问题(一百六十三)

简介: CSDN博客专家,专注Android/Linux系统,分享多mic语音方案、音视频、编解码等技术,与大家一起成长! 优质专栏:Audio工程师进阶系列【原创干货持续更新中……】🚀 人生格言: 人生…

Apache DolphinScheduler 支持使用 OceanBase 作为元数据库啦!

DolphinScheduler是一个开源的分布式任务调度系统,拥有分布式架构、多任务类型、可视化操作、分布式调度和高可用等特性,适用于大规模分布式任务调度的场景。目前DolphinScheduler支持的元数据库有Mysql、PostgreSQL、H2,如果在业务中需要更好…

iOS UIAlertController控件

ios 9 以后 UIAlertController取代UIAlertView和UIActionSheet UIAlertControllerStyleAlert和UIAlertControllerStyleActionSheet。 在UIAlertController中添加按钮和关联输入框 UIAlertAction共有三种类型,默认(UIAlertActionStyleDefault&#xff0…

【Linux】进程信号篇Ⅰ:信号的产生(signal、kill、raise、abort、alarm)、信号的保存(core dump)

文章目录 一、 signal 函数:用户自定义捕捉信号二、信号的产生1. 通过中断按键产生信号2. 调用系统函数向进程发信号2.1 kill 函数:给任意进程发送任意信号2.2 raise 函数:给调用进程发送任意信号2.3 abort 函数:给调用进程发送 6…

机器学习深度学习——NLP实战(情感分析模型——数据集)

👨‍🎓作者简介:一位即将上大四,正专攻机器学习的保研er 🌌上期文章:机器学习&&深度学习——BERT(来自transformer的双向编码器表示) 📚订阅专栏:机器…

Kubernetes 安全机制 认证 授权 准入控制

客户端应用若想发送请求到 apiserver 操作管理K8S资源对象,需要先通过三关安全验证 认证(Authentication)鉴权(Authorization)准入控制(Admission Control) Kubernetes 作为一个分布式集群的管理…

Flink的Standalone部署实战

在Flink是通用的框架,以混合和匹配的方式支持部署不同场景,而Standalone单机部署方便快速部署,记录本地部署过程,方便备查。 环境要求 1)JDK1.8及以上 2)flink-1.14.3 3)CentOS7 Flink相关信…

ELK日志监控系统搭建docker版

目录 日志来源elk介绍elasticsearch介绍logstash介绍kibana介绍 部署elasticsearch拉取镜像:docker pull elasticsearch:7.17.9修改配置⽂件:/usr/share/elasticsearch/config/elasticsearch.yml启动容器设置密码(123456)忘记密码…

opencv-进阶05 手写数字识别原理及示例

前面我们仅仅取了两个特征维度进行说明。在实际应用中,可能存在着更多特征维度需要计算。 下面以手写数字识别为例进行简单的介绍。 假设我们要让程序识别图 20-2 中上方的数字(当然,你一眼就知道是“8”,但是现在要让计算机识别…

【JUC系列-01】深入理解JMM内存模型的底层实现原理

一,深入理解JMM内存模型 1,什么是可见性 在谈jmm的内存模型之前,先了解一下并发并发编程的三大特性,分别是:可见性,原子性,有序性。可见性指的就是当一个线程修改某个变量的值之后&#xff0c…

自动化测试用例设计实例

在编写用例之间,笔者再次强调几点编写自动化测试用例的原则: 1、一个脚本是一个完整的场景,从用户登陆操作到用户退出系统关闭浏览器。 2、一个脚本脚本只验证一个功能点,不要试图用户登陆系统后把所有的功能都进行验证再退出系统…

智慧水利利用4G物联网技术实现远程监测、控制、管理

智慧水利工业路由器是集合数据采集、实时监控、远程管理的4G物联网通讯设备,能够让传统水利系统实现智能化的实时监控和远程管理。工业路由器利用4G无线网络技术,能够实时传输数据和终端信息,为水利系统的运维提供有效的支持。 智慧水利系统是…

湘潭大学 湘大 XTU OJ 1055 整数分类 题解(非常详细)

链接 整数分类 题目 Description 按照下面方法对整数x进行分类:如果x是一个个位数,则x属于x类;否则将x的各位上的数码累加,得到一个新的x,依次迭代,可以得到x的所属类。比如说24,246&#…

手写模拟SpringBoot核心流程(二):实现Tomcat和Jetty的切换

实现Tomcat和Jetty的切换 前言 上一篇文章我们聊到,SpringBoot中内置了web服务器,包括Tomcat、Jetty,并且实现了SpringBoot启动Tomcat的流程。 那么SpringBoot怎样自动切换成Jetty服务器呢? 接下来我们继续学习如何实现Tomcat…

⛳ TCP 协议面试题

目录 ⛳ TCP 协议面试题🐾 一、为什么关闭连接的需要四次挥⼿,⽽建⽴连接却只要三次握⼿呢?🏭 二、为什么连接建⽴的时候是三次握⼿,可以改成两次握⼿吗?👣 三、为什么主动断开⽅在TIME-WAIT状态…

服务器感染了.360勒索病毒,如何确保数据文件完整恢复?

引言: 随着科技的不断进步,互联网的普及以及数字化生活的发展,网络安全问题也逐渐成为一个全球性的难题。其中,勒索病毒作为一种危害性极高的恶意软件,在近年来频频袭扰用户。本文91数据恢复将重点介绍 360 勒索病毒&a…

Git分布式版本控制系统

目录 2、安装git 2.1 初始环境 2.2 Yum安装Git 2.3 编译安装 2.4 初次运行 Git 前的配置 2.5 初始化及获取 Git 仓库 2.6 Git命令常规操作 2.6.2 添加新文件 2.6.3 删除git内的文件 2.6.4 重命名暂存区数据 2.6.5 查看历史记录 2.6.6 还原历史数据 2.6.7 还原未来…

[机器学习]特征工程:主成分分析

目录 主成分分析 1、简介 2、帮助理解 3、API调用 4、案例 本文介绍主成分分析的概述以及python如何实现算法,关于主成分分析算法数学原理讲解的文章,请看这一篇: 探究主成分分析方法数学原理_逐梦苍穹的博客-CSDN博客https://blog.csdn.…

WebRTC | 网络传输协议RTP与RTCP

目录 一、UDP与TCP 1. TCP 2. UDP 二、RTP 1. RTP协议头 (1)V(Version)字段 (2)P(Padding)字段 (3)X(eXtension)字段 &#x…