原型网络Prototypical Network的python代码逐行解释,新手小白也可学会!!-----系列8

在这里插入图片描述

文章目录

  • 前言
  • 一、原始代码
  • 二、对每一行代码的解释:
  • 总结


前言

这是该系列原型网络的最后一段代码及其详细解释,感谢各位的阅读!


一、原始代码

if __name__ == '__main__':##载入数据labels_trainData, labels_testData = load_data()  # labels_trainData是字典,是key:value形式class_number_train = max(list(labels_trainData.keys())) #963class_number_test = max(list(labels_testData.keys())) #658wide = labels_trainData[0][0].shape[0]  # 105      #二维张量,shape[0]代表行数,shape[1]代表列数length = labels_trainData[0][0].shape[1]  # 105for label in labels_trainData.keys():labels_trainData[label] = np.reshape(labels_trainData[label], [-1, 1, wide, length])for label in labels_testData.keys():labels_testData[label] = np.reshape(labels_testData[label], [-1, 1, wide, length])##初始化模型protonets = Protonets((1, wide, length), 10, 5, 5, 60, './log/', 50)  # '''根据需求修改类的初始化参数,参数含义见protonets_net.py'''##训练prototypical_networkfor n in range(100):  ##随机选取x个类进行一个episode的训练protonets.train(labels_trainData, class_number_train)if n % 2 == 0 and n != 0:  # 每5次存储一次模型,并测试模型的准确率,训练集的准确率和测试集的准确率被存储在model_step_eval.txt中torch.save(protonets.model, './log/model_net_' + str(n) + '.pkl')protonets.save_center('./log/model_center_' + str(n) + '.csv')test_accury = protonets.evaluation_model(labels_testData, class_number_test)print(test_accury)str_data = str(n) + ',' + str('       test_accury     ') + str(test_accury) + '\n'with open('./log/model_step_eval.txt', "a") as f:f.write(str_data)print(n)

二、对每一行代码的解释:

  1. if __name__ == '__main__':
    这是一个Python的惯用写法,表示当脚本直接被运行时(而不是被作为模块导入时),才会执行下面的代码块。

  2. labels_trainData, labels_testData = load_data()
    调用 load_data() 函数加载数据,并将返回的标签训练数据和标签测试数据保存到 labels_trainDatalabels_testData 变量中。

  3. class_number_train = max(list(labels_trainData.keys()))
    获取标签训练数据中的最大键(即最大类别数),并将其保存到 class_number_train 变量中。

  4. class_number_test = max(list(labels_testData.keys()))
    获取标签测试数据中的最大键(即最大类别数),并将其保存到 class_number_test 变量中。

  5. wide = labels_trainData[0][0].shape[0]
    获取标签训练数据中第一个样本的宽度,并将其保存到 wide 变量中。

  6. length = labels_trainData[0][0].shape[1]
    获取标签训练数据中第一个样本的长度,并将其保存到 length 变量中。

  7. for label in labels_trainData.keys():
    遍历标签训练数据中的所有键。

  8. labels_trainData[label] = np.reshape(labels_trainData[label], [-1, 1, wide, length])
    对每个标签训练数据进行重塑,将其形状改为 [-1, 1, wide, length],其中 -1 表示自动计算维度大小。

  9. for label in labels_testData.keys():
    遍历标签测试数据中的所有键。

  10. labels_testData[label] = np.reshape(labels_testData[label], [-1, 1, wide, length])
    对每个标签测试数据进行重塑,将其形状改为 [-1, 1, wide, length]

  11. protonets = Protonets((1, wide, length), 10, 5, 5, 60, './log/', 50)
    创建一个 Protonets 类的实例,传入模型的初始化参数。

  12. for n in range(100):
    从0到99的循环中,执行以下代码块。

  13. protonets.train(labels_trainData, class_number_train)
    调用 protonets 实例的 train() 方法进行模型训练,传入标签训练数据和类别数。

  14. if n % 2 == 0 and n != 0:
    如果 n 是偶数且不为0,则执行以下代码块。

  15. torch.save(protonets.model, './log/model_net_' + str(n) + '.pkl')
    保存模型到 './log/model_net_' + str(n) + '.pkl' 的文件路径。

  16. protonets.save_center('./log/model_center_' + str(n) + '.csv')
    调用 protonets 实例的 save_center() 方法,将模型的中心点保存到 './log/model_center_' + str(n) + '.csv'

  17. test_accury = protonets.evaluation_model(labels_testData, class_number_test)
    调用 protonets 实例的 evaluation_model() 方法,对模型进行评估并返回测试准确率,将其保存到 test_accury 变量中。

  18. print(test_accury)
    打印测试准确率。

  19. str_data = str(n) + ',' + str(' test_accury ') + str(test_accury) + '\n'
    构建一个字符串以保存到文件中。

  20. with open('./log/model_step_eval.txt', "a") as f:
    打开一个文件,以追加模式写入。


总结

原型网络(Prototypical Network)是一种用于小样本学习的模型,由Jake Snell等人于2017年提出。它是一种基于元学习(meta-learning)的方法,主要用于解决在具有少量标记样本的情况下进行分类任务的问题。

传统的深度学习模型在处理小样本学习时通常表现不佳,因为它们需要大量的标记样本来进行训练。然而,在现实世界中,我们往往只有少量标记样本可用。原型网络通过引入一个用于表示类别的中心向量(原型)的概念,解决了这个问题。

原型网络的功能和优势如下:

  1. 小样本学习:原型网络适用于具有少量标记样本的分类任务,可以在只有几个样本可用时进行准确的分类。

  2. 元学习能力:原型网络通过学习类别的原型向量,能够在遇到新类别时进行快速学习,从而实现元学习的目标。

  3. 欧氏距离度量:原型网络使用欧氏距离来度量样本与原型之间的相似性,从而进行分类推断。这种度量方式非常直观和可解释,使得模型更易于理解

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/197288.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Linux下非root用户安装CUDA

目录 前言 参考链接 步骤 一. 首先,需要查看系统版本: 二. 安装包下载。 下载CUDA: cuDNN下载 三. 开始安装CUDA和cuDNN 安装CUDA 修改环境变量 安装 cuDNN 查看是否安装成功,输入nvcc -V 前言 由于一些代码实现&…

消息消费过程

前言 本文介绍下Kafka消费过程, 内容涉及消费与消费组, 主题与分区, 位移提交,分区再平衡和消费者拦截器等内容。 消费者与消费组 Kafka将消费者组织为消费组, 消息只会被投递给消费组中的1个消费者。因此, 从不同消费组中的消费者来看, Kafka是多播(Pub/Sub)模式…

十三、Docker的安装

0.安装Docker Docker 分为 CE 和 EE 两大版本。CE 即社区版(免费,支持周期 7 个月),EE 即企业版,强调安全,付费使用,支持周期 24 个月。 Docker CE 分为 stable test 和 nightly 三个更新频道…

CTFhub-RCE-过滤cat

查看当前目录:输入:127.0.0.1|ls 127.0.0.1|cat flag_42211411527984.php 无输出内容 使用单引号绕过 127.0.0.1|cat flag_42211411527984.php|base 64 使用双引号绕过 127.0.0.1|c""at flag_42211411527984.php|base64 使用特殊变量绕过 127.0.0.…

第四篇 《随机点名答题系统》——基础设置详解(类抽奖系统、在线答题系统、线上答题系统、在线点名系统、线上点名系统、在线考试系统、线上考试系统)

目录 1.功能需求 2.数据库设计 3.流程设计 4.关键代码 4.1.设置题库 4.1.1数据请求示意图 4.1.2选择题库(index.php)数据请求代码 4.1.3取消题库(index.php)数据请求代码 4.1.4业务处理Service(xztk.p…

计算机的发展

硬件的发展 第一台电子数字计算机:ENIAC(1946),作者:冯诺依曼,逻辑元件:电子管 bug:小虫子,会影响打点 Intel: 机器字长:计算机一次整数运算所能…

企业计算机服务器中了mallox勒索病毒怎么解决,勒索病毒解密文件恢复

随着科技技术的不断发展,网络技术得到了快速提升,但网络安全威胁也不断增加,近期,云天数据恢复中心陆续接到很多企业的求助信息,企业的计算机服务器遭到了mallox勒索病毒攻击,导致企业的所有业务中断&#…

[nlp] 损失缩放(Loss Scaling)loss sacle

在深度学习中,由于浮点数的精度限制,当模型参数非常大时,会出现数值溢出的问题,这可能会导致模型训练不稳定。为了解决这个问题,损失缩放(Loss Scaling)技术被引入,它通过缩放损失值来解决这个问题。 在深度学习中,损失缩放技术通常是通过将梯度进行缩放来实现的。具…

鸿蒙APP外包开发上线流程

鸿蒙系统的上线流程可能会根据具体的版本和平台要求而略有不同。在进行上线之前,开发人员应该详细了解并遵循鸿蒙生态系统的相关规定和要求。鸿蒙(HarmonyOS)应用的上线流程通常包括以下步骤,希望对大家有所帮助。北京木奇移动技术…

【深度学习】pytorch快速得到mobilenet_v2 pth 和onnx

在linux执行这个程序: import torch import torch.onnx from torchvision import transforms, models from PIL import Image import os# Load MobileNetV2 model model models.mobilenet_v2(pretrainedTrue) model.eval()# Download an example image from the P…

安卓中轻量级数据存储方案分析探讨

轻量级数据存储功能通常用于保存应用的一些常用配置信息,并不适合需要存储大量数据和频繁改变数据的场景。应用的数据保存在文件中,这些文件可以持久化地存储在设备上。需要注意的是,应用访问的实例包含文件所有数据,这些数据会一…

Qt6版使用Qt5中的类遇到的问题解决方案

如果有需要请关注下面微信公众号,会有更多收获! 1.QLinkedList 是 Qt 中的一个双向链表类。它提供了高效的插入和删除操作,尤其是在中间插入和删除元素时,比 QVector 更加优秀。下面是使用 QLinkedList 的一些基本方法&#xff1a…

微服务学习 | Eureka注册中心

微服务远程调用 在order-service的OrderApplication中注册RestTemplate 在查询订单信息时,需要同时返回订单用户的信息,但是由于微服务的关系,用户信息需要在用户的微服务中去查询,故需要用到上面的RestTemplate来让订单的这个微…

JVM虚拟机:通过日志学习PS+PO垃圾回收器

我们刚才设置参数的时候看到了-XXPrintGCDetails表示输出详细的GC处理日志,那么我们如何理解这个日志呢?日志是有规则的,我们需要按照这个规则来理解日志中的内容,它有两个格式,一个格式是GC的格式(新生代&…

YOLOv8/YOLOv7/YOLOv5/YOLOv4/Faster-rcnn系列算法改进【NO.79】改进损失函数为VariFocal Loss

前言 作为当前先进的深度学习目标检测算法YOLOv8,已经集合了大量的trick,但是还是有提高和改进的空间,针对具体应用场景下的检测难点,可以不同的改进方法。此后的系列文章,将重点对YOLOv8的如何改进进行详细的介绍&…

Egress-TLS-Origination

目录 文章目录 目录本节实战1、出口网关TLS发起2、通过 egress 网关发起双向 TLS 连接关于我最后 本节实战 实战名称🚩 实战:Egress TLS Origination-2023.11.19(failed)🚩 实战:通过 egress 网关发起双向 TLS 连接-2023.11.19(测…

CDN是什么,能起到什么作用

随着互联网的快速发展,用户对于快速、稳定、高效的互联网体验的需求日益增长。为了满足这一需求,内容分发网络(CDN)应运而生,并在近年来得到了广泛应用。CDN通过在全球范围内部署大量的服务器和网络节点,实…

excel怎么能锁住行 和/或 列的自增长,保证粘贴公式的时候不自增长或者只有部分自增长

例如在C4单元格中输入了公式: 现在如果把C4拷贝到C5,D3会自增长为D4: 现在如果想拷贝的时候不自增长,可以先把光标放到C4单元格,然后按F4键,行和列的前面加上了$符号,锁定了: …

QEMU显示虚拟化的几种选项

QEMU可以通过通过命令行"-vga type"选择为客户机模拟的VGA卡的类别,可选择的类型有多个: -vga typeSelect type of VGA card to emulate. Valid values for type arecirrusCirrus Logic GD5446 Video card. All Windows versions starting from Windows 95 should …

深度学习YOLO图像视频足球和人体检测 - python opencv 计算机竞赛

文章目录 0 前言1 课题背景2 实现效果3 卷积神经网络4 Yolov5算法5 数据集6 最后 0 前言 🔥 优质竞赛项目系列,今天要分享的是 🚩 深度学习YOLO图像视频足球和人体检测 该项目较为新颖,适合作为竞赛课题方向,学长非…