迁移学习案例-python代码

在这里插入图片描述

大白话

迁移学习就是用不太相同但又有一些联系的A和B数据,训练同一个网络。比如,先用A数据训练一下网络,然后再用B数据训练一下网络,那么就说最后的模型是从A迁移到B的。

迁移学习的具体形式是多种多样的,比如先用A训练好一个网络,然后复制这个网络的某几个层的参数到一个新的网络作为初始化的参数,然后用B数据去训练这个新网络。又或者,面对中文翻译的问题,中文翻译成英文和中文翻译成火星文,前几层在提取特征,可以共享参数层,后面几层由于任务不同就可以各自私有训练。

案例来源:李宏毅课程-机器学习-迁移学习

A数据:是源数据,量大效果好,并且有标签。
在这里插入图片描述
B数据:量少,没标签。
在这里插入图片描述

目的:希望用A数据先训练网络提取到关键特征,然后预测B数据的标签。但是把他们当作两个任务效果不佳,于是以一种迁移的方法解决--域对抗(先用A训练好模型,再直接用B测试,这样效果不佳;而是希望以一种“迁移”的方法,把A数据的知识拿到B上面用)

直接上代码

import matplotlib.pyplot as pltdef no_axis_show(img, title='', cmap=None):# imshow, 缩放模式为nearest。fig = plt.imshow(img, interpolation='nearest', cmap=cmap)# 不要显示axis。fig.axes.get_xaxis().set_visible(False)fig.axes.get_yaxis().set_visible(False)plt.title(title)titles = ['horse', 'bed', 'clock', 'apple', 'cat', 'plane', 'television', 'dog', 'dolphin', 'spider']
plt.figure(figsize=(18, 18))
for i in range(10):plt.subplot(1, 10, i+1)fig = no_axis_show(plt.imread(f'work/real_or_drawing/train_data/{i}/{500*i}.bmp'), title=titles[i])
plt.figure(figsize=(18, 18))
for i in range(10):plt.subplot(1, 10, i+1)fig = no_axis_show(plt.imread(f'work/real_or_drawing/test_data/0/' + str(i).rjust(5, '0') + '.bmp'))
import cv2
import matplotlib.pyplot as plt
titles = ['horse', 'bed', 'clock', 'apple', 'cat', 'plane', 'television', 'dog', 'dolphin', 'spider']
plt.figure(figsize=(18, 18))original_img = plt.imread(f'work/real_or_drawing/train_data/0/0.bmp')
plt.subplot(1, 5, 1)
no_axis_show(original_img, title='original')gray_img = cv2.cvtColor(original_img, cv2.COLOR_RGB2GRAY)
plt.subplot(1, 5, 2)
no_axis_show(gray_img, title='gray scale', cmap='gray')gray_img = cv2.cvtColor(original_img, cv2.COLOR_RGB2GRAY)
plt.subplot(1, 5, 2)
no_axis_show(gray_img, title='gray scale', cmap='gray')canny_50100 = cv2.Canny(gray_img, 50, 100)
plt.subplot(1, 5, 3)
no_axis_show(canny_50100, title='Canny(50, 100)', cmap='gray')canny_150200 = cv2.Canny(gray_img, 150, 200)
plt.subplot(1, 5, 4)
no_axis_show(canny_150200, title='Canny(150, 200)', cmap='gray')canny_250300 = cv2.Canny(gray_img, 250, 300)
plt.subplot(1, 5, 5)
no_axis_show(canny_250300, title='Canny(250, 300)', cmap='gray')
import cv2
import numpy as np
import paddleimport paddle.optimizer as optim
from paddle.io import DataLoader
from paddle.vision.datasets import DatasetFolder
from paddle.nn import Sequential, Conv2D, BatchNorm1D, BatchNorm2D, ReLU, MaxPool2D, Linear
from paddle.vision.transforms import Compose, Grayscale, Transpose, RandomHorizontalFlip, RandomRotation, Resize, ToTensor
class Canny(paddle.vision.transforms.transforms.BaseTransform):def __init__(self, low, high, keys=None):super(Canny, self).__init__(keys)self.low = lowself.high = highdef _apply_image(self, img):Canny = lambda img: cv2.Canny(np.array(img), self.low, self.high)return Canny(img)
source_transform = Compose([RandomHorizontalFlip(),RandomRotation(15),Grayscale(),Canny(low=170, high=300),# Transpose(),ToTensor()])
target_transform = Compose([Grayscale(),Resize((32, 32)),RandomHorizontalFlip(),RandomRotation(15, fill=(0,)),ToTensor()])source_dataset = DatasetFolder('work/real_or_drawing/train_data', transform=source_transform)
target_dataset = DatasetFolder('work/real_or_drawing/test_data', transform=target_transform)source_dataloader = DataLoader(source_dataset, batch_size=32, shuffle=True)
target_dataloader = DataLoader(target_dataset, batch_size=32, shuffle=True)
test_dataloader = DataLoader(target_dataset, batch_size=128, shuffle=False)
class FeatureExtractor(paddle.nn.Layer):def __init__(self):super(FeatureExtractor, self).__init__()self.conv = Sequential(Conv2D(1, 64, 3, 1, 1),BatchNorm2D(64),ReLU(),MaxPool2D(2),Conv2D(64, 128, 3, 1, 1),BatchNorm2D(128),ReLU(),MaxPool2D(2),Conv2D(128, 256, 3, 1, 1),BatchNorm2D(256),ReLU(),MaxPool2D(2),Conv2D(256, 256, 3, 1, 1),BatchNorm2D(256),ReLU(),MaxPool2D(2),Conv2D(256, 512, 3, 1, 1),BatchNorm2D(512),ReLU(),MaxPool2D(2))def forward(self, x):x = self.conv(x).squeeze()return xclass LabelPredictor(paddle.nn.Layer):def __init__(self):super(LabelPredictor, self).__init__()self.layer = Sequential(Linear(512, 512),ReLU(),Linear(512, 512),ReLU(),Linear(512, 10),)def forward(self, h):c = self.layer(h)return cclass DomainClassifier(paddle.nn.Layer):def __init__(self):super(DomainClassifier, self).__init__()self.layer = Sequential(Linear(512, 512),BatchNorm1D(512),ReLU(),Linear(512, 512),BatchNorm1D(512),ReLU(),Linear(512, 512),BatchNorm1D(512),ReLU(),Linear(512, 512),BatchNorm1D(512),ReLU(),Linear(512, 1),)def forward(self, h):y = self.layer(h)return y
feature_extractor = FeatureExtractor()
label_predictor = LabelPredictor()
domain_classifier = DomainClassifier()class_criterion = paddle.nn.loss.CrossEntropyLoss()
domain_criterion = paddle.nn.BCEWithLogitsLoss()optimizer_F = optim.Adam(parameters=feature_extractor.parameters())
optimizer_C = optim.Adam(parameters=label_predictor.parameters())
optimizer_D = optim.Adam(parameters=domain_classifier.parameters())
def train_epoch(source_dataloader, target_dataloader, lamb):'''Args:source_dataloader: source data的dataloadertarget_dataloader: target data的dataloaderlamb: 调控adversarial的loss系数。'''# D loss: Domain Classifier的loss# F loss: Feature Extrator & Label Predictor的loss# total_hit: 计算目前对了几笔 total_num: 目前经过了几笔running_D_loss, running_F_loss = 0.0, 0.0total_hit, total_num = 0.0, 0.0for i, ((source_data, source_label), (target_data, _)) in enumerate(zip(source_dataloader, target_dataloader)):# source_data = source_data.cuda()# source_label = source_label.cuda()# target_data = target_data.cuda()# 我们把source data和target data混在一起,否则batch_norm可能会算错 (两边的data的mean/var不太一样)mixed_data = paddle.concat([source_data, target_data], axis=0)domain_label = paddle.zeros([source_data.shape[0] + target_data.shape[0], 1])# 设定source data的label为1domain_label[:source_data.shape[0]] = 1# Step 1 : 训练Domain Classifierfeature = feature_extractor(mixed_data)# 因为我们在Step 1不需要训练Feature Extractor,所以把feature detach避免loss backprop上去。domain_logits = domain_classifier(feature.detach())# print('domain_logits.shape:', domain_logits.shape, 'domain_label.shape:', domain_label.shape)loss = domain_criterion(domain_logits, domain_label)# running_D_loss+= loss.numpy()[0]running_D_loss+= loss.numpy()# print('loss:', loss)loss.backward()optimizer_D.step()# Step 2 : 训练Feature Extractor和Domain Classifierclass_logits = label_predictor(feature[:source_data.shape[0]])domain_logits = domain_classifier(feature)# loss为原本的class CE - lamb * domain BCE,相减的原因同GAN中的Discriminator中的G loss。loss = class_criterion(class_logits, source_label) - lamb * domain_criterion(domain_logits, domain_label)# running_F_loss+= loss.numpy()[0]running_F_loss+= loss.numpy()loss.backward()optimizer_F.step()optimizer_C.step()optimizer_D.clear_grad()optimizer_F.clear_grad()optimizer_C.clear_grad()# print('class_logits.shape:', class_logits.shape, 'source_label.shape:', source_label.shape)# print('class_logits[0]:', class_logits[0], 'source_label[0]:', source_label[0])total_hit += np.sum((paddle.argmax(class_logits, axis=1) == source_label).numpy())total_num += source_data.shape[0]print(i, end='\r')return running_D_loss / (i+1), running_F_loss / (i+1), total_hit / total_num# 训练200 epochs
for epoch in range(200):train_D_loss, train_F_loss, train_acc = train_epoch(source_dataloader, target_dataloader, lamb=0.1)paddle.save(feature_extractor.state_dict(), f'extractor_model.pdparams')paddle.save(label_predictor.state_dict(), f'predictor_model.pdparams')print('epoch {:>3d}: train D loss: {:6.4f}, train F loss: {:6.4f}, acc {:6.4f}'.format(epoch, train_D_loss, train_F_loss, train_acc))

训练结束,预测一波

result = []
label_predictor.eval()
feature_extractor.eval()
for i, (test_data, _) in enumerate(test_dataloader):test_data = test_dataclass_logits = label_predictor(feature_extractor(test_data))x = paddle.argmax(class_logits, axis=1).detach().numpy()result.append(x)import pandas as pd
result = np.concatenate(result)# Generate your submission
df = pd.DataFrame({'id': np.arange(0,len(result)), 'label': result})
df.to_csv('work/DaNN_submission.csv',index=False)

训练比较慢,还得是把代码转到cuda上才行,demo可以把epoch减小一点。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/441107.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

HCIA综合实验

实验步骤 1.划分网段 内网部分---三个大块 2.先配交换机 左边:3个vlan ,3个access,1个trunk 右边:2个vlan ,2个access,1个trunk 3.再配路由 3.1 r5先配接口ipg/0/0/0 口配子接口 g0/0/0.1-0.3 g0/0/1 …

【YOLOv8实时产品缺陷检测】

YOLOv8应用于产品缺陷检测实例 项目概况项目实现YOLOv8安装及模型训练关键代码展示动态效果展示 项目概况 本项目是应用YOLOv8框架实现训练自定义模型实现单一零件的缺陷检测,软件界面由PyQt5实现。 功能已正式使用,识别效果达到预期。 项目实现 项目…

手机误删照片?试试这5款免费数据恢复神器!

大家好!今天咱们来聊聊一个大家都关心的话题——免费数据恢复工具。不论是误删照片、视频,还是丢失重要文件,数据恢复都是个让人头疼的问题。但好消息是,现在有众多免费的数据恢复工具能帮助我们找回失去的数据。今天我就来为大家…

力扣16~20题

题16&#xff08;中等&#xff09;&#xff1a; 思路&#xff1a; 双指针法&#xff0c;和15题差不多&#xff0c;就是要排除了&#xff0c;如果total<target则排除了更小的&#xff08;left右移&#xff09;&#xff0c;如果total>target则排除了更大的&#xff08;rig…

pycharm 远程ssh时,mujuco提示mujoco.FatalError: gladLoadGL error

在ubuntu系统运行时完全没问题&#xff0c;但是使用pycharm远程ssh登录时就会提示这个。 解决方法&#xff1a; 1. 可以修改环境变量 2. export LD_PRELOAD/usr/lib/x86_64-linux-gnu/libstdc.so.6 参考【Mujuco】WSL2安装Mujoco用于python,遇到FatalError,以及图形驱动架构…

【Git原理与使用】远程操作标签管理

远程操作&&标签管理 1.理解分布式版本控制系统2.新建远程仓库3.克隆远程仓库4.向远程仓库推送5.拉取远程仓库6.配置 Git7.配置命令别名8.标签管理8.1创建标签8.2操作标签 点赞&#x1f44d;&#x1f44d;收藏&#x1f31f;&#x1f31f;关注&#x1f496;&#x1f496;…

RTOS系统移植

一、完成系统移植 系统移植上官网寻找合适的系统包&#xff0c;下载后将文件移植入工程文件 二、创建任务句柄、内核对象句柄&#xff08;信号量&#xff0c;消息队列&#xff0c;事件标志组&#xff0c;软件定时器&#xff09;、声明全局变量、声明函数 三、创建主函数&#…

Vue2电商项目(七)、订单与支付

文章目录 一、交易业务Trade1. 获取用户地址2. 获取订单信息 二、提交订单三、支付1. 获取支付信息2. 支付页面--ElementUI(1) 引入Element UI(2) 弹框支付的业务逻辑(这个逻辑其实没那么全)(3) 支付逻辑知识点小总结 四、个人中心1. 搭建二级路由2. 展示动态数据(1). 接口(2).…

【计算机网络 - 基础问题】每日 3 题(二十九)

✍个人博客&#xff1a;https://blog.csdn.net/Newin2020?typeblog &#x1f4e3;专栏地址&#xff1a;http://t.csdnimg.cn/fYaBd &#x1f4da;专栏简介&#xff1a;在这个专栏中&#xff0c;我将会分享 C 面试中常见的面试题给大家~ ❤️如果有收获的话&#xff0c;欢迎点赞…

【Docker】03-自制镜像

1. 自制镜像 2. Dockerfile # 基础镜像 FROM openjdk:11.0-jre-buster # 设定时区 ENV TZAsia/Shanghai RUN ln -snf /usr/share/zoneinfo/$TZ /etc/localtime && echo $TZ > /etc/timezone # 拷贝jar包 COPY docker-demo.jar /app.jar # 入口 ENTRYPOINT ["ja…

Redis:通用命令 数据类型

Redis&#xff1a;通用命令 & 数据类型 通用命令SETGETKEYSEXISTSDELEXPIRETTLTYPEFLUSHALL 数据类型 Redis的客户端提供了很多命令用于操控Redis&#xff0c;在Redis中&#xff0c;key的类型都是字符串&#xff0c;而value有多种类型&#xff0c;每种类型都有自己的操作命…

Redis篇(最佳实践)(持续更新迭代)

介绍一&#xff1a;键值设计 一、优雅的key结构 Redis 的 Key 虽然可以自定义&#xff0c;但最好遵循下面的几个最佳实践约定&#xff1a; 遵循基本格式&#xff1a;[业务名称]:[数据名]:[id]长度不超过 44 字节不包含特殊字符 例如&#xff1a; 我们的登录业务&#xff0…

Leetcode—76. 最小覆盖子串【困难】

2024每日刷题&#xff08;167&#xff09; Leetcode—76. 最小覆盖子串 C实现代码 class Solution { public:string minWindow(string s, string t) {int bestL -1;int l 0, r 0;vector<int> cnt(128);for(const char c: t) {cnt[c];}int require t.length();int m…

【实战教程】SpringBoot全面指南:快速上手到项目实战(SpringBoot)

文章目录 【实战教程】SpringBoot全面指南&#xff1a;快速上手到项目实战(SpringBoot)1. SpringBoot介绍1.1 SpringBoot简介1.2系统要求1.3 SpringBoot和SpringMVC区别1.4 SpringBoot和SpringCloud区别 2.快速入门3. Web开发3.1 静态资源访问3.2 渲染Web页面3.3 YML与Properti…

[SpringBoot] 苍穹外卖--面试题总结--上

前言 1--苍穹外卖-SpringBoot项目介绍及环境搭建 详解-CSDN博客 2--苍穹外卖-SpringBoot项目中员工管理 详解&#xff08;一&#xff09;-CSDN博客 3--苍穹外卖-SpringBoot项目中员工管理 详解&#xff08;二&#xff09;-CSDN博客 4--苍穹外码-SpringBoot项目中分类管理 详…

pytest(六)——allure-pytest的基础使用

前言 一、allure-pytest的基础使用 二、需要掌握的allure特性 2.1 Allure报告结构 2.2 Environment 2.3 Categories 2.4 Flaky test 三、allure的特性&#xff0c;allure.step()、allure.attach的详细使用 3.1 allure.step 3.2 allure.attach&#xff08;挺有用的&a…

Redis入门第四步:Redis发布与订阅

欢迎继续跟随《Redis新手指南&#xff1a;从入门到精通》专栏的步伐&#xff01;在本文中&#xff0c;我们将深入探讨Redis的发布与订阅&#xff08;Pub/Sub&#xff09;模式。这是一种强大的消息传递机制&#xff0c;适用于各种实时通信场景&#xff0c;如聊天应用、实时通知和…

3、Redis Stack扩展功能

文章目录 一、了解Redis产品二、申请RedisCloud实例三、Redis Stack体验1、RedisStack有哪些扩展&#xff1f;2、Redis JSON1、Redis JSON是什么2、Redis JSON有什么用3、Redis JSON的优势 3、Search And Query1、传统Scan搜索2、Search And Query搜索 4、Bloom Filter1、布隆过…

LabVIEW提高开发效率技巧----阻塞时钟

在LabVIEW开发中&#xff0c;阻塞时钟&#xff08;Blocking Timed Loops&#xff09;是一种常见且强大的技术&#xff0c;尤其适用于时间关键的应用。在这些应用中&#xff0c;精确控制循环的执行频率是关键任务。阻塞时钟通过等待循环的执行完成后再进入下一次迭代&#xff0c…

如何设置LTE端到端系统

LTE Setup Guide Baseline Hardware Requirements 基础硬件要求 需要2个RF前端和2个装有基于Linux的操作系统的PC。系统架构如下&#xff1a; srsUE&#xff1a;需要1个RF前端和1个PC。srsENB&#xff1a;需要1个RF前端和1个PC。srsEPC&#xff1a;需要1个PC。 系统硬件要…