动手学深度学习(Pytorch版)代码实践 -计算机视觉-48全连接卷积神经网络(FCN)

48全连接卷积神经网络(FCN

在这里插入图片描述

1.构造函数
import torch
import torchvision
from torch import nn
from torch.nn import functional as F
import matplotlib.pyplot as plt
import liliPytorch as lp
from d2l import torch as d2l# 构造模型
pretrained_net = torchvision.models.resnet18(pretrained=True)
# print(list(pretrained_net.children())[-3:]) # ResNet-18模型的最后几层
"""
[Sequential((0): BasicBlock((conv1): Conv2d(256, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)(bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(relu): ReLU(inplace=True)(conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)(bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(downsample): Sequential((0): Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)(1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)))(1): BasicBlock((conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)(bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(relu): ReLU(inplace=True)(conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)(bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True))
), AdaptiveAvgPool2d(output_size=(1, 1)), Linear(in_features=512, out_features=1000, bias=True)]  
"""# 创建一个全卷积网络net。 它复制了ResNet-18中大部分的预训练层,
# 除了最后的全局平均汇聚层和最接近输出的全连接层。
net = nn.Sequential(*list(pretrained_net.children())[:-2])# ResNet-18中
"""
X = torch.rand(size=(1, 1, 96, 96))
for layer in net:X = layer(X)print(layer.__class__.__name__, 'output shape:\t', X.shape)
# Sequential output shape:         torch.Size([1, 64, 24, 24])
# Sequential output shape:         torch.Size([1, 64, 24, 24])
# Sequential output shape:         torch.Size([1, 128, 12, 12])
# Sequential output shape:         torch.Size([1, 256, 6, 6])
# Sequential output shape:         torch.Size([1, 512, 3, 3])
前向传播将输入的高和宽减小至原来的 1/32
"""
# 使用1 X 1 卷积层将输出通道数转换为Pascal VOC2012数据集的类数(21类)。 
# 最后需要将特征图的高度和宽度增加32倍,从而将其变回输入图像的高和宽。
num_classes = 21
net.add_module('final_conv', nn.Conv2d(512, num_classes, kernel_size=1))
net.add_module('transpose_conv', nn.ConvTranspose2d(num_classes, num_classes,kernel_size=64, padding=16, stride=32))
# print(list(net.children())[-2:]) 
"""
[Conv2d(512, 21, kernel_size=(1, 1), stride=(1, 1)), 
ConvTranspose2d(21, 21, kernel_size=(64, 64), stride=(32, 32), padding=(16, 16))]
"""
2.双线性插值
# 初始化转置卷积层
# 在图像处理中,我们有时需要将图像放大,即上采样(upsampling)
# 双线性插值(bilinear interpolation) 是常用的上采样方法之一,
# 它也经常用于初始化转置卷积层
# 双线性插值的上采样可以通过转置卷积层实现,内核由以下bilinear_kernel函数构造。
def bilinear_kernel(in_channels, out_channels, kernel_size):factor = (kernel_size + 1) // 2  # 计算中心因子,用于确定卷积核的中心位置if kernel_size % 2 == 1: # 确定卷积核的中心位置# 如果卷积核大小是奇数center = factor - 1else:# 如果卷积核大小是偶数center = factor - 0.5# 生成坐标网格,用于计算每个位置的双线性内核值og = (torch.arange(kernel_size).reshape(-1, 1),  # 列向量torch.arange(kernel_size).reshape(1, -1))  # 行向量# 计算双线性内核值,基于当前位置与中心位置的距离并归一化filt = (1 - torch.abs(og[0] - center) / factor) * \(1 - torch.abs(og[1] - center) / factor)# 初始化权重张量weight = torch.zeros((in_channels, out_channels, kernel_size, kernel_size))# 填充权重张量,将计算好的双线性内核值赋值给权重张量weight[range(in_channels), range(out_channels), :, :] = filt# 返回生成的双线性卷积核权重张量return weight# 定义转置卷积层 (n + 4 - 2 - 2 ) * 2
conv_trans = nn.ConvTranspose2d(3, 3, kernel_size=4, padding=1, stride=2,bias=False)
# 将双线性卷积核权重复制到转置卷积层的权重
conv_trans.weight.data.copy_(bilinear_kernel(3, 3, 4))img = torchvision.transforms.ToTensor()(d2l.Image.open('../limuPytorch/images/catdog.jpg'))
"""
d2l.Image.open('../limuPytorch/images/catdog.jpg') 首先被执行,返回一个 PIL.Image 对象。
然后,torchvision.transforms.ToTensor() 创建一个 ToTensor 对象。
最后,ToTensor 对象被调用(通过 () 运算符),将 PIL.Image 对象作为参数传递给 ToTensor 的 __call__ 方法,
转换为 PyTorch 张量。
"""
X = img.unsqueeze(0) # 添加一个新的维度,形成形状为 (1, C, H, W) 的张量 X,
Y = conv_trans(X)
out_img = Y[0].permute(1, 2, 0).detach()print('input image shape:', img.permute(1, 2, 0).shape)
# input image shape: torch.Size([561, 728, 3])
plt.imshow(img.permute(1, 2, 0))
plt.show()
print('output image shape:', out_img.shape)
# output image shape: torch.Size([1122, 1456, 3])
# 图片放大了两倍
plt.imshow(out_img)
plt.show()
3.模型训练
# 37微调章节的代码
def train_batch_ch13(net, X, y, loss, trainer, devices):"""使用多GPU训练一个小批量数据。参数:net: 神经网络模型。X: 输入数据,张量或张量列表。y: 标签数据。loss: 损失函数。trainer: 优化器。devices: GPU设备列表。返回:train_loss_sum: 当前批次的训练损失和。train_acc_sum: 当前批次的训练准确度和。"""# 如果输入数据X是列表类型if isinstance(X, list):# 将列表中的每个张量移动到第一个GPU设备X = [x.to(devices[0]) for x in X]else:X = X.to(devices[0])# 如果X不是列表,直接将X移动到第一个GPU设备y = y.to(devices[0])# 将标签数据y移动到第一个GPU设备net.train() # 设置网络为训练模式trainer.zero_grad()# 梯度清零pred = net(X) # 前向传播,计算预测值l = loss(pred, y) # 计算损失l.sum().backward()# 反向传播,计算梯度trainer.step() # 更新模型参数train_loss_sum = l.sum()# 计算当前批次的总损失train_acc_sum = d2l.accuracy(pred, y)# 计算当前批次的总准确度return train_loss_sum, train_acc_sum# 返回训练损失和与准确度和def train_ch13(net, train_iter, test_iter, loss, trainer, num_epochs,devices=d2l.try_all_gpus()):"""训练模型在多GPU参数:net: 神经网络模型。train_iter: 训练数据集的迭代器。test_iter: 测试数据集的迭代器。loss: 损失函数。trainer: 优化器。num_epochs: 训练的轮数。devices: GPU设备列表,默认使用所有可用的GPU。"""# 初始化计时器和训练批次数timer, num_batches = d2l.Timer(), len(train_iter)# 初始化动画器,用于实时绘制训练和测试指标animator = lp.Animator(xlabel='epoch', xlim=[1, num_epochs], ylim=[0, 1],legend=['train loss', 'train acc', 'test acc'])# 将模型封装成 DataParallel 模式以支持多GPU训练,并将其移动到第一个GPU设备net = nn.DataParallel(net, device_ids=devices).to(devices[0])# 训练循环,遍历每个epochfor epoch in range(num_epochs):# 初始化指标累加器,metric[0]表示总损失,metric[1]表示总准确度,# metric[2]表示样本数量,metric[3]表示标签数量metric = lp.Accumulator(4)# 遍历训练数据集for i, (features, labels) in enumerate(train_iter):timer.start()  # 开始计时# 训练一个小批量数据,并获取损失和准确度l, acc = train_batch_ch13(net, features, labels, loss, trainer, devices)metric.add(l, acc, labels.shape[0], labels.numel())   # 更新指标累加器timer.stop()  # 停止计时# 每训练完五分之一的批次或者是最后一个批次时,更新动画器if (i + 1) % (num_batches // 5) == 0 or i == num_batches - 1:animator.add(epoch + (i + 1) / num_batches,(metric[0] / metric[2], metric[1] / metric[3], None))test_acc = d2l.evaluate_accuracy_gpu(net, test_iter) # 在测试数据集上评估模型准确度animator.add(epoch + 1, (None, None, test_acc))# 更新动画器# 打印最终的训练损失、训练准确度和测试准确度print(f'loss {metric[0] / metric[2]:.3f}, train acc 'f'{metric[1] / metric[3]:.3f}, test acc {test_acc:.3f}')# 打印每秒处理的样本数和使用的GPU设备信息print(f'{metric[2] * num_epochs / timer.sum():.1f} examples/sec on 'f'{str(devices)}')# 全卷积网络用双线性插值的上采样初始化转置卷积层
W = bilinear_kernel(num_classes, num_classes, 64)
net.transpose_conv.weight.data.copy_(W)
# 读取数据集
batch_size, crop_size = 32, (320, 480)
train_iter, test_iter = lp.load_data_voc(batch_size, crop_size) # 46语义分割和数据集代码
# 损失函数
def loss(inputs, targets):return F.cross_entropy(inputs, targets, reduction='none').mean(1).mean(1)num_epochs, lr, wd, devices = 5, 0.001, 1e-3, d2l.try_all_gpus()
trainer = torch.optim.SGD(net.parameters(), lr=lr, weight_decay=wd)
train_ch13(net, train_iter, test_iter, loss, trainer, num_epochs, devices)
# loss 0.443, train acc 0.863, test acc 0.848
# 254.0 examples/sec on [device(type='cuda', index=0), device(type='cuda', index=1)]
plt.show()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/370635.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

调试支付分回调下载平台证书

之前的原生代码放到webman里面,死活跑不通 没办法,只能用esayWeChat6.7 (自行下载) 它里面配置要用到平台证书 平台证书又要用到 composer require wechatpay/wechatpay 但是请求接口之前,你先要用到一个临时的平台…

linux下高级IO模型

高级IO 1.高级IO模型基本概念1.1 阻塞IO1.2 非阻塞IO1.3 信号驱动IO1.4 IO多路转接1.5 异步IO 2. 模型代码实现2.1 非阻塞IO2.2 多路转接-selectselect函数介绍什么才叫就绪呢?demoselect特点 2.3 多路转接-pollpoll函数介绍poll优缺点demo 2.4 多路转接-epoll&…

【算法笔记自学】第 5 章 入门篇(3)——数学问题

5.1简单数学 #include <cstdio> #include <algorithm> using namespace std; bool cmp(int a,int b){return a>b; } void to_array(int n,int num[]){for(int i0;i<4;i){num[i]n%10;n /10;} } int to_number(int num[]){int sum0;for(int i0;i<4;i){sumsu…

移动端UI风格营造舒适氛围

移动端UI风格营造舒适氛围

Spring容器Bean之XML配置方式

一、首先看applicationContext.xml里的配置项bean 我们采用xml配置文件的方式对bean进行声明和管理&#xff0c;每一个bean标签都代表着需要被创建的对象并通过property标签可以为该类注入其他依赖对象&#xff0c;通过这种方式Spring容器就可以成功知道我们需要创建那些bean实…

cs224n作业3 代码及运行结果

代码里要求用pytorch1.0.0版本&#xff0c;其实不用也可以的。 【删掉run.py里的assert(torch.version “1.0.0”)即可】 代码里面也有提示让你实现什么&#xff0c;弄懂代码什么意思基本就可以了&#xff0c;看多了感觉大框架都大差不差。多看多练慢慢来&#xff0c;加油&am…

Camunda 整合Springboot 实战篇

1.导入依赖 <dependency><groupId>org.camunda.bpm.springboot</groupId><artifactId>camunda-bpm-spring-boot-starter</artifactId><version>7.18.0</version></dependency><dependency><groupId>org.camunda.b…

C语言图书馆管理系统(管理员版)

案例&#xff1a;图书馆管理系统&#xff08;管理员版&#xff09; 背景&#xff1a; 随着信息技术的发展和普及&#xff0c;传统的图书馆管理方式已经无法满足现代图书馆高效、便捷、智能化的管理需求。传统的手工登记、纸质档案管理不仅耗时耗力&#xff0c;而且容易出现错…

剖析DeFi交易产品之UniswapV3:交易路由合约

本文首发于公众号&#xff1a;Keegan小钢 SwapRouter 合约封装了面向用户的交易接口&#xff0c;但不再像 UniswapV2Router 一样根据不同交易场景拆分为了那么多函数&#xff0c;UniswapV3 的 SwapRouter 核心就只有 4 个交易函数&#xff1a; exactInputSingle&#xff1a;指…

华为机试HJ34图片整理

华为机试HJ34图片整理 题目&#xff1a; 想法&#xff1a; 将输入的字符串中每个字符都转为ASCII码&#xff0c;再通过快速排序进行排序并输出 input_str input() input_list [int(ord(l)) for l in input_str]def partition(arr, low, high):i low - 1pivot arr[high]f…

matlab 有倾斜的椭圆函数图像绘制

matlab 有倾斜的椭圆函数图像绘制 有倾斜的椭圆函数图像绘制xy交叉项引入斜线负向斜线成分正向斜线成分 x^2 y^2 xy 1 &#xff08;负向&#xff09;绘制结果 x^2 y^2 - xy 1 &#xff08;正向&#xff09;绘制结果 有倾斜的椭圆函数图像绘制 为了确定椭圆的长轴和短轴的…

【Python】MacBook M系列芯片Anaconda下载Pytorch,并开发一个简单的数字识别代码(附带踩坑记录)

文章目录 配置镜像源下载Pytorch验证使用Pytorch进行数字识别 配置镜像源 Anaconda下载完毕之后&#xff0c;有两种方式下载pytorch&#xff0c;一种是用页面可视化的方式去下载&#xff0c;另一种方式就是直接用命令行工具去下载。 但是由于默认的Anaconda走的是外网&#x…

9 redis,memcached,nginx网络组件

课程目标: 1.网络模块要处理哪些事情 2.reactor是怎么处理这些事情的 3.reactor怎么封装 4.网络模块与业务逻辑的关系 5.怎么优化reactor? io函数 函数调用 都有两个作用:io检测 是否就绪 io操作 1. int clientfd = accept(listenfd, &addr, &len); 检测 全连接队列…

技术派Spring事件监听机制及原理

Spring事件监听机制是Spring框架中的一种重要技术&#xff0c;允许组件之间进行松耦合通信。通过使用事件监听机制&#xff0c;应用程序的各个组件可以在其他组件不直接引用的情况下&#xff0c;相互发送和接受消息。 需求 在技术派中有这样一个需求&#xff0c;当发布文章或…

简单分享下python多态

目录&#xff1a; 一、多态是啥嘞&#xff08;龙生九子各有不同&#xff0c;这就是多态&#xff09; 二、基础的实例 三、多态的优势与应用场景 四、深入理解 一、多态是啥嘞&#xff08;龙生九子各有不同&#xff0c;这就是多态&#xff09; 多态&#xff08;Polymorphism&…

如何利用算法优化广告效果

效果广告以超过67%的占比&#xff0c;成为了中国互联网广告预算的大头。在BAT、字节等大的媒体平台上&#xff0c;效果广告以CPC实时竞价广告为主。在这种广告产品的投放中&#xff0c;广告主或其代理公司通过针对每个广告点击出价&#xff0c;系统自动把这些点击出价换算成eCP…

【人工智能】-- 智能机器人

个人主页&#xff1a;欢迎来到 Papicatch的博客 课设专栏 &#xff1a;学生成绩管理系统 专业知识专栏&#xff1a; 专业知识 文章目录 &#x1f349;引言 &#x1f349;机器人介绍 &#x1f348;机器人硬件 &#x1f34d;机械结构 &#x1f34d;传感器 &#x1f34d;控…

nginx配置尝试

from fastapi import FastAPI, File, UploadFile, HTTPException from fastapi.responses import JSONResponse, FileResponse, HTMLResponse import logging import os from datetime import datetime import uvicorn# 初始化日志 logging.basicConfig(filenamefile_server.lo…

学java的第3天 后端商城小程序工作

1.数据库的大坑 特殊字段名 ’我的图片表中有一个字段是描述我写成desc了&#xff0c;正好是mysql中的关键字 就不能使用了 2.后端编写 2.1可以把请求分开 在商品浏览页中 只显示商品的大致信息 当用户再点击其他按钮时在发出请求 2.2把请求合并 把数据整合到一起 利用ass…

SpringBoot环境集成 sms4j短信聚合

SpringBoot环境集成 sms4j短信聚合 官方文档 前言 在正式使用sms4j短信功能之前&#xff0c;请详细阅读本文档&#xff0c;依照本篇流程进行操作和配给&#xff0c;即可解决大部分问题&#xff0c;如对我们的文档有建议&#xff0c;请联系开发者团队&#xff0c; 我们将根据可…