pytorch实战-6手写数字加法机-迁移学习

1 概述

迁移学习概念:将已经训练好的识别某些信息的网络拿去经过训练识别另外不同类别的信息

优越性:提高了训练模型利用率,解决了数据缺失的问题(对于新的预测场景,不需要大量的数据,只需要少量数据即可实现训练,可用于数据点很少的场景)

如何实现:将训练好的一个网络拿来和另一个网络连起来去训练即可实现迁移

训练方式:按是否改变源网络参数可分两类,分别是可改变和不可改变

2 案例 南非贫困预测

2.1 背景

南非存在贫困,1990-2021贫困人口从56%下降到43%,但下降的贫困人口数量和国际人道主义援助资源并不对应,而且大量资金援助一定程度加剧了贫富差距。可以看下具体哪些地区需要援助

2.2 方法

一个方法:夜光光亮遥感数据和人类gdp相关性经实验可达0.8-0.9,但夜光遥感和贫富没太大相关性:夜间光照月亮表示该地区越富有,但越安并不表示该地区越贫穷,也可能无人居住。

另一个方法:光亮遥感数据无法准确预测地区贫穷程度,但卫星遥感数据大体可以做到,判定依据有街道混乱程度等。如果要用深度网络训练,还需要对卫星遥感数据的图片标注贫困程度。非洲能获取到的贫困数据很少,但深度网络需要的数据量很大

最终方法:用迁移学习,将前两种方法合起来,见下图

3 案例2

3.1 背景

任务:区分图像里动物是蚂蚁还是蜜蜂,像素均为224x224

难点:只有244个图像,样本太少不足训练大型卷积网络,准确率只有50%左右

3.2 解决方案

解决方案:resnet与模型迁移,即用已训练好的物体分类的网络加全连接用来区分蚂蚁与蜜蜂

resnet:残差网络,对物体分类有较高精度

3.3 代码实现

3.3.1 准备数据

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import numpy as np
import torchvision
from torchvision import datasets, models, transforms
import matplotlib.pyplot as pyplot
import time
import copy
import osdata_path = 'pytorch/jizhi/figure_plus/data'
image_size = 224class TranNet():def __init__(self):super(TranNet, self).__init__()self.train_dataset = datasets.ImageFolder(os.path.join(data_path, 'train'), transforms.Compose([transforms.RandomSizedCrop(image_size),transforms.RandomHorizontalFlip(),transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]))self.verify_dataset = datasets.ImageFolder(os.path.join(data_path, 'verify'), transforms.Compose([transforms.Scale(256),transforms.CenterCrop(image_size),transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]))self.train_loader = torch.utils.data.DataLoader(self.train_dataset, batch_size=4, shuffle=True, num_workers=4)self.verify_loader = torch.utils.data.DataLoader(self.verify_dataset, batch_size=4, shuffle=True, num_workers=4)self.num_classes = len(self.train_dataset.classes)def exec(self):...def main():TranNet().exec()if __name__ == '__main__':main()

3.3.2 模型迁移

    def exec(self):self.model_prepare()def model_prepare(self):net = models.resnet18(pretrained=True)# float net valuesnum_features = net.fc.in_featuresnet.fc = nn.Linear(num_features, 2)criterion = nn.CrossEntropyLoss()optimizer = optim.SGD(net.parameters(), lr=0.0001, momentum=0.9)# fixed net values'''for param in net.parameters():param.requires_grad = Falsenum_features = net.fc.in_featuresnet.fc = nn.Linear(num_features, 2)criterion = nn.CrossEntropyLoss()optimizer = optim.SGD(net.fc.parameters(), lr = 0.001, momentum=0.9)'''

3.3.3 gpu加速

特点:gpu速度快,但内存低,所以尽量减少在gpu中存储的数据,只用来计算就好

    def model_prepare(self):# jusge whether GPUuse_cuda = torch.cuda.is_available()dtype = torch.cuda.FloatTensor if use_cuda else torch.FloatTensoritype = torch.cuda.LongTensor if use_cuda else torch.LongTensornet = models.resnet18(pretrained=True)net = net.cuda() if use_cuda else net

3.3.4 训练

    def model_prepare(self):net = models.resnet18(pretrained=True)# jusge whether GPUuse_cuda = torch.cuda.is_available()if use_cuda:dtype = torch.cuda.FloatTensor if use_cuda else torch.FloatTensoritype = torch.cuda.LongTensor if use_cuda else torch.LongTensornet = net.cuda() if use_cuda else net# float net valuesnum_features = net.fc.in_featuresnet.fc = nn.Linear(num_features, 2)criterion = nn.CrossEntropyLoss()optimizer = optim.SGD(net.parameters(), lr=0.0001, momentum=0.9)# fixed net values'''for param in net.parameters():param.requires_grad = Falsenum_features = net.fc.in_featuresnet.fc = nn.Linear(num_features, 2)criterion = nn.CrossEntropyLoss()optimizer = optim.SGD(net.fc.parameters(), lr = 0.001, momentum=0.9)'''record = []num_epochs = 3net.train(True) # open dropoutfor epoch in range(num_epochs):train_rights = []train_losses = []for batch_index, (data, target) in enumerate(self.train_loader):data, target = data.clone().detach().requires_grad_(True), target.clone().detach()output = net(data)loss = criterion(output, target)optimizer.zero_grad()loss.backward()optimizer.step()right = rightness(output, target)train_rights.append(right)train_losses.append(loss.data.numpy())if batch_index % 400 == 0:verify_rights = []for index, (data_v, target_v) in enumerate(self.verify_loader):data_v, target_v = data_v.clone().detach(), target_v.clone().detach()output_v = net(data_v)right = rightness(output_v, target_v)verify_rights.append(right)verify_accu = sum([row[0] for row in verify_rights]) / sum([row[1] for row in verify_rights])record.append((verify_accu))print(f'verify data accu:{verify_accu}')# plotpyplot.figure(figsize=(8, 6))pyplot.plot(record)pyplot.xlabel('step')pyplot.ylabel('verify loss')pyplot.show()

4 手写数字加法机

4.1 网络结构

可以先用cnn识别出两个待求和数字,不要输出,只保留池化层加后面一层全连接层,可以获取图像一维特征,然后将两个图像识别获取的一维特征合并,然后用全连接作为剩下的网络

4.2 代码实现

4.2.1 数据加载

class FigurePlus():def __init__(self):super(FigurePlus, self).__init__()self.image_size = 28self.num_classes = 10self.num_epochs = 3self.batch_size = 64self.train_dataset = datasets.MNIST(root='./data', train=True, transform=transforms.ToTensor(), download=True)self.test_dataset = datasets.MNIST(root='./data', train=False, transform=transforms.ToTensor())sampler_a = torch.utils.data.sampler.SubsetRandomSampler(np.random.permutation(range(len(self.train_dataset))))sampler_b = torch.utils.data.sampler.SubsetRandomSampler(np.random.permutation(range(len(self.train_dataset))))self.train_loader_a = torch.utils.data.DataLoader(dataset=self.train_dataset, batch_size=self.batch_size, shuffle=False, sampler=sampler_a)self.train_loader_b = torch.utils.data.DataLoader(dataset=self.train_dataset, batch_size=self.batch_size, shuffle=False, sampler=sampler_b)self.verify_size = 5000verify_index_a = range(self.verify_size)verify_index_b = np.random.permutation(range(self.verify_size))test_index_a = range(self.verify_size, len(self.test_dataset))test_index_b = np.random.permutation(test_index_a)verify_sampler_a = torch.utils.data.sampler.SubsetRandomSampler(verify_index_a)verify_sampler_b = torch.utils.data.sampler.SubsetRandomSampler(verify_index_b)test_sampler_a = torch.utils.data.sampler.SubsetRandomSampler(test_index_a)test_sampler_b = torch.utils.data.sampler.SubsetRandomSampler(test_index_b)self.verify_loader_a = torch.utils.data.DataLoader(dataset=self.test_dataset, batch_size=self.batch_size, shuffle=False, sampler=verify_sampler_a)self.verify_loader_b = torch.utils.data.DataLoader(dataset=self.test_dataset, batch_size=self.batch_size, shuffle=False, sampler=verify_sampler_b)self.test_loader_a = torch.utils.data.DataLoader(dataset=self.test_dataset, batch_size=self.batch_size, shuffle=False, sampler=test_sampler_a)self.test_loader_b = torch.utils.data.DataLoader(dataset=self.test_dataset, batch_size=self.batch_size, shuffle=False, sampler=test_sampler_b)def gpu_ok(self):use_cuda = torch.cuda.is_available()dtype = torch.cuda.FloatTensor if use_cuda else torch.FloatTensoritype = torch.cuda.LongTensor if use_cuda else torch.LongTensordef exec(self):passdef main():# TranNet().exec()FigurePlus().exec()if __name__ == '__main__':main()

4.2.2 手写数字加法机实现(网络实现)

    def forward(self, x, y, training=True):x, y = F.relu(self.net1_conv1(x)), F.relu(self.net2_conv1(y))x, y = self.net_pool(x), self.net_pool(y)x, y = F.relu(self.net1_conv2(x)), F.relu(self.net2_conv2(y))x, y = self.net_pool(x), self.net_pool(y)x = x.view(-1, (self.image_size // 4) ** 2 * self.depth[1])y = y.view(-1, (self.image_size // 4) ** 2 * self.depth[1])z = torch.cat((x, y), 1)z = self.fc1(z)z = F.relu(z)z = F.dropout(z, training=self.training)z = F.relu(self.fc2(z))z = F.relu(self.fc3(z))return F.relu(self.fc4(z))

4.2.3 模型迁移

思路:将上一篇弄好的数字识别模型保存到文件,然后在本章节加载进来,将各参数权重赋值到本节创的新网络

torch.save(cnn, model_save_path)

将网络加载进来时,需要源模型的定义,在本章重新定义下,拷贝后稍作修改

class FigureIdentify(nn.Module):def __init__(self):super(FigureIdentify, self).__init__()self.depth = (4, 8)def forward(self, x):x = self.conv1(x)x = F.relu(x)x = self.pool(x)x = self.conv2(x)x = F.relu(x)x = self.pool(x)x = x.view(-1, (image_size // 4)**2 * self.depth[1])x = F.relu(self.fc1(x))x = F.dropout(x, training=self.training)x = self.fc2(x)x = F.log_softmax(x, dim=1)return x

注意

1 从文件加载模型后只加载网络权重,没加载网络的方法,需重新定义

2 加载后模型会赋值给一个对象,这个对象需要和保存网络时的网络架构保持一致(尽量保持一致,不然报错!!

模型文件加载进来后,用预训练模式(从模型文件加载网络权重作为初始权重,参数会随新网络的训练跟随大网络参数调节)

注意:加法器两个数字识别网络不可直接将文件加载的网络赋值,因为会共享地址,实际是一组参数,一个网络训练后另一个网络的参数也会变,可以复制出来再操作

    def copy_origin_weight(self, net):self.net1_conv1.weight.data = copy.deepcopy(net.conv1.weight.data)self.net1_conv1.bias.data = copy.deepcopy(net.conv1.bias.data)self.net1_conv2.weight.data = copy.deepcopy(net.conv2.weight.data)self.net1_conv2.bias.data = copy.deepcopy(net.conv2.bias.data)self.net2_conv1.weight.data = copy.deepcopy(net.conv1.weight.data)self.net2_conv1.bias.data = copy.deepcopy(net.conv1.bias.data)self.net2_conv2.weight.data = copy.deepcopy(net.conv2.weight.data)self.net2_conv2.bias.data = copy.deepcopy(net.conv2.bias.data)def main():# TranNet().exec()net = FigurePlusNet()origin_net = torch.load(model_save_path)net.copy_origin_weight(origin_net)criterion = nn.MSELoss()optmizer = optim.SGD(net.parameters(), lr=0.0001, momentum=0.9)

如要固定值迁移(即新模型训练过程不改变加载进来模型的权重),设requires_grad = False即可

    def copy_origin_weight_nograd(self, net):self.copy_origin_weight(net)self.net1_conv1.weight.requires_grad = Falseself.net1_conv1.bias.requires_grad = Falseself.net1_conv2.weight.requires_grad = Falseself.net1_conv2.bias.requires_grad = Falseself.net2_conv1.weight.requires_grad = Falseself.net2_conv1.bias.requires_grad = Falseself.net2_conv2.weight.requires_grad = Falseself.net2_conv2.bias.requires_grad = False

4.3 训练与测试

    # train records = []for epoch in range(net.num_epochs):losses = []for index, data in enumerate(zip(net.train_loader_a, net.train_loader_b)):(x1, y1), (x2, y2) = dataif net.gpu_ok():x1, y1, x2, y2 = x1.cuda(), y1.cuda(), x2.cuda(), y2.cuda()optimizer.zero_grad()net.train()outputs = net(x1.clone().detach(), x2.clone().detach())outputs = outputs.squeeze()labels = y1 + y2loss = criterion(outputs, labels.type(torch.float))loss.backward()optimizer.step()loss = loss.cpu() if net.gpu_ok() else losslosses.append(loss.data.numpy())if index % 300 == 0:verify_losses = []rights = []net.eval()for verify_data in zip(net.verify_loader_a, net.verify_loader_b):(x1, y1), (x2, y2) = verify_dataif net.gpu_ok():x1, y1, x2, y2 = x1.cuda(), y1.cuda(), x2.cuda(), y2.cuda()outputs = net(x1.clone().detach(), x2.clone().detach())outputs = outputs.squeeze()labels = y1 + y2loss = criterion(outputs, labels.type(torch.float))loss = loss.cpu() if net.gpu_ok() else lossverify_losses.append(loss.data.numpy())right = rightness(outputs.data, labels)rights.append(right)right_ratio = 1.0 * np.sum([i[0] for i in rights]) / np.sum([i[1] for i in rights])print(f'no.{epoch}, {index}/{len(net.train_loader_a)}, train loss:{np.mean(losses)}, verify loss:{np.mean(verify_losses)}, accu: {right_ratio}')# records.append([np.mean(losses), np.mean(verify_losses), right_ratio])records.append([right_ratio])# plot train datapyplot.figure(figsize=(8, 6))pyplot.plot(records)pyplot.xlabel('step')  pyplot.ylabel('loss & accuracy')# testrights = []net.eval()for test_data in zip(net.test_loader_a, net.test_loader_b):(x1, y1), (x2, y2) = test_dataif net.gpu_ok():x1, y1, x2, y2 = x1.cuda(), y1.cuda(), x2.cuda(), y2.cuda()outputs = net(x1.clone().detach(), x2.clone().detach())outputs = outputs.squeeze()labels = y1 + y2loss = criterion(outputs, labels.type(torch.float))right = rightness(outputs, labels)rights.append(right)right_ratio = 1.0 * np.sum([i[0] for i in rights]) / np.sum([i[1] for i in rights])print(f'test accuracy: {right_ratio}')pyplot.show()

4.4 结果

一轮打3个点,总共6轮(多了太慢,没用gpu),总共打大概24个点。发现随着训练轮数增加,准确率逐步提高

4.5 大规模测试

4.5.1 大模型

大模型指迁移学习全连接层用4层网络。以数据量为自变量,分别看5%,50%,100%数据量情况下,迁移学习与不迁移学习准确率随轮数变化趋势

结论:1 数据量100%时,迁移学习与无迁移学习准确率趋势很接近;数据量较小时,迁移学习准确率上升速度会远快于无迁移学习。数据量到50%左右时差异就变得不是很明显了但还是有 2数据量大时,固定值训练模式比预训练模式精度更好

4.5.2 小模型

小模型指迁移学习全连接层用2层而不是4层,仍然以数据量和是否迁移学习为自变量分析

结论:1 数据量小时,迁移学习比无迁移学习准确率上升速度快,数据量大时,迁移学习与无迁移学习这种差异变小 

5 总结

适用场景(不仅限于这些):数据量小

两种迁移方式:固定值和预训练,固定值方式参数变动范围小,训练可能更快

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/245144.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Unity通用渲染管线升级URP、HDRP

Unity通用渲染管线升级URP、HDRP 一、Build-in Pipline升级到 URP 一、Build-in Pipline升级到 URP 安装URP包 升级所有材质(升级完成后材质会变成紫红色,Shader丢失,此为正常现象) 创建 UniversalRenderPipelineAsset 配置文…

Nacos 在云原生架构下的演进

作者:之卫 背景 Nacos 提供的最核心能力是动态服务发现与动态配置管理能力,在云原生环境下,借助云产品,如 EDAS(企业级分布式应用服务)平台中,我们可以很轻松地使用 K8s 来托管 Nacos 体系的微…

蓝桥杯(Python)每日练Day5

题目 OJ1229 题目分析 题目完全符合栈的特征,后进先出。如果能够熟练使用列表的9种方法那么这道题很容易解出。 题解 a[]#存衣服 nint(input()) for i in range(n):llist(input().split())#判断每一步的操作if len(l[0])2:a.append(l[1])else:while a.pop()!l…

大数据平台红蓝对抗 - 磨利刃,淬精兵!

背景 目前大促备战常见备战工作:专项压测(全链路压测、内部压测)、灾备演练、降级演练、限流、巡检(监控、应用健康度)、混沌演练(红蓝对抗),如下图所示。随着平台业务越来越复杂&a…

LabVIEW探测器CAN总线系统

介绍了一个基于FPGA和LabVIEW的CAN总线通信系统,该系统专为与各单机进行系统联调测试而设计。通过设计FPGA的CAN总线功能模块和USB功能模块,以及利用LabVIEW开发的上位机程序,系统成功实现了CAN总线信息的收发、存储、解析及显示功能。测试结…

Obsidian笔记软件结合cpolar实现安卓移动端远程本地群晖WebDAV数据同步

💝💝💝欢迎来到我的博客,很高兴能够在这里和您见面!希望您在这里可以感受到一份轻松愉快的氛围,不仅可以获得有趣的内容和知识,也可以畅所欲言、分享您的想法和见解。 推荐:kwan 的首页,持续学…

使用Robot Framework实现多平台自动化测试

基于Robot Framework、Jenkins、Appium、Selenium、Requests、AutoIt等开源框架和技术,成功打造了通用自动化测试持续集成管理平台(以下简称“平台”),显著提高了测试质量和测试用例的执行效率。 01、设计目标 平台通用且支持不…

Hadoop基本概论

目录 一、大数据概论 1.大数据的概念 2.大数据的特点 3.大数据应用场景 二、Hadoop概述 1.Hadoop定义 2.Hadoop发展历史 3.Hadoop发行版本 4.Hadoop优势 5.Hadoop1.x/2.x/3.x 6.HDFS架构 7.Yarn架构 8.MapReduce架构 9.大数据技术生态体系 一、大数据概论 1.大数…

docker 基础手册

文章目录 docker 基础手册docker 容器技术镜像与容器容器与虚拟机docker 引擎docker 架构docker 底层技术docker 二进制安装docker 镜像加速docker 相关链接docker 生态 docker 基础手册 docker 容器技术 开源的容器项目,使用 Go 语言开发原意“码头工人”&#x…

[极客大挑战 2019]LoveSQL1

万能密码测试,发现注入点 注意这里#要使用url编码才能正常注入 测试列数,得三列 查看table,一个是geekuser另一个是l0ve1ysq1 查看column,有id,username,password,全部打印出来,…

【数据结构】 顺序表的基本操作 (C语言版)

一、顺序表 1、顺序表的定义: 线性表的顺序存储结构,即将表中的结点按逻辑顺序依次存放在一组地址连续的存储单元里。这种存储方式使得在逻辑结构上相邻的数据元素在物理存储上也是相邻的,可以通过数据元素的物理存储位置来反映其逻辑关系。…

SSL证书DV和OV的区别

SSL证书是数字证书的一种,配置在服务器上,起到文件信息传输加密的作用。由受信任的数字证书颁发机构CA在验证服务器身份后颁发,防止第三方窃取或篡改信息。 在选择SSL证书的过程中,一般要注意选择的SSL证书的等级。常见有DV和OV证…

单片机面向对象思维的架构:时间轮片法

今天分享一篇单片机程序框架的文章。 程序架构重要性 很多人尤其是初学者在写代码的时候往往都是想一点写一点,最开始没有一个整体的规划,导致后面代码越写越乱,bug不断。 最终代码跑起来看似没有问题(有可能也真的没有问题),但…

清越 peropure·AI 国内版ChatGP新功能介绍

当OpenAI发布ChatGPT的时候,没有人会意识到,新一代人工智能浪潮将给人类社会带来一场眩晕式变革。其中以ChatGPT为代表的AIGC技术加速成为AI领域的热门发展方向,推动着AI时代的前行发展。面对技术浪潮,清越科技(PeroPure)立足多样化生活场景、精准把握用户实际需求,持续精确Fin…

差分进化算法求解基于移动边缘计算 (MEC) 的无线区块链网络的联合挖矿决策和资源分配(提供MATLAB代码)

一、优化模型介绍 在所研究的区块链网络中,优化的变量为:挖矿决策(即 m)和资源分配(即 p 和 f),目标函数是使所有矿工的总利润最大化。问题可以表述为: max ⁡ m , p , f F miner …

255:vue+openlayers 加载tomtom地图(多种形式)

第255个 点击查看专栏目录 本示例的目的是介绍演示如何在vue+openlayers中添加tomtom地图,这里包含了多种形式,诸如中文标记、英文标记、白天地图、晚上地图、卫星影像图,高山海拔地形图等。 直接复制下面的 vue+openlayers源代码,操作2分钟即可运行实现效果 文章目录 示…

vue3和vite项目在scss中因为本地图片,不用加~

看了很多文章说要加~,真的好坑哦,我的加了~反而出不来了: 304 Not Modified 所以需要去掉~: /* 默认dark主题 */ :root[themered] {--bg-color: #0d1117;--text-color: #f0f6fc;--backImg: url(/assets/images/redBg.png); }/* …

鸿蒙开发踩坑之dataPreferences数据存储后获取为空

问题 在开发中通过PreferencesUtil.setValue(name, 旺财)设置后,通过IDE运行App后获取之前存储的数据都为空。 问题原因 查看控制台,发现如下: $ hdc shell am force-stop com.happy.xxx $ hdc shell bm uninstall com.happy.xxx$ hdc fi…

Java PDFBox 提取页数、PDF转图片

PDF 提取 使用Apache 的pdfbox组件对PDF文件解析读取和转图片。 Maven 依赖 导入下面的maven依赖&#xff1a; <dependency><groupId>org.apache.pdfbox</groupId><artifactId>pdfbox</artifactId><version>2.0.30</version> &l…

数据结构之二叉树的遍历

数据结构是程序设计的重要基础&#xff0c;它所讨论的内容和技术对从事软件项目的开发有重要作用。学习数据结构要达到的目标是学会从问题出发&#xff0c;分析和研究计算机加工的数据的特性&#xff0c;以便为应用所涉及的数据选择适当的逻辑结构、存储结构及其相应的操作方法…