李沐64_注意力机制——自学笔记

注意力机制

1.卷积、全连接和池化层都只考虑不随意线索

2.注意力机制则显示的考虑随意线索
(1)随意线索倍称之为查询(query)

(2)每个输入是一个值value,和不随意线索key的对

(3)通过注意力池化层来有偏向性的选择某些输入

总结

注意力机制中,通过query(随意线索)和key(不随意线索)来有偏向性的选择输入

代码实现:注意力汇聚:Nadaraya-Watson 核回归

!pip install d2l
import torch
from torch import nn
from d2l import torch as d2l

生成数据集

在这里生成了50个训练样本和50个测试样本。

n_train = 50  # 训练样本数
x_train, _ = torch.sort(torch.rand(n_train) * 5)   # 排序后的训练样本def f(x):return 2 * torch.sin(x) + x**0.8y_train = f(x_train) + torch.normal(0.0, 0.5, (n_train,))  # 训练样本的输出
x_test = torch.arange(0, 5, 0.1)  # 测试样本
y_truth = f(x_test)  # 测试样本的真实输出
n_test = len(x_test)  # 测试样本数
n_test
50

下面的函数将绘制所有的训练样本(样本由圆圈表示), 不带噪声项的真实数据生成函数f
(标记为“Truth”), 以及学习得到的预测函数(标记为“Pred”)。

def plot_kernel_reg(y_hat):d2l.plot(x_test, [y_truth, y_hat], 'x', 'y', legend=['Truth', 'Pred'],xlim=[0, 5], ylim=[-1, 5])d2l.plt.plot(x_train, y_train, 'o', alpha=0.5);

平均汇聚

y_hat = torch.repeat_interleave(y_train.mean(), n_test)
plot_kernel_reg(y_hat)

在这里插入图片描述

非参数注意力汇聚

# X_repeat的形状:(n_test,n_train),
# 每一行都包含着相同的测试输入(例如:同样的查询)
X_repeat = x_test.repeat_interleave(n_train).reshape((-1, n_train))
# x_train包含着键。attention_weights的形状:(n_test,n_train),
# 每一行都包含着要在给定的每个查询的值(y_train)之间分配的注意力权重
attention_weights = nn.functional.softmax(-(X_repeat - x_train)**2 / 2, dim=1)
# y_hat的每个元素都是值的加权平均值,其中的权重是注意力权重
y_hat = torch.matmul(attention_weights, y_train)
plot_kernel_reg(y_hat)

在这里插入图片描述

现在来观察注意力的权重。 这里测试数据的输入相当于查询,而训练数据的输入相当于键。 因为两个输入都是经过排序的,因此由观察可知“查询-键”对越接近, 注意力汇聚的注意力权重就越高。

d2l.show_heatmaps(attention_weights.unsqueeze(0).unsqueeze(0),xlabel='Sorted training inputs',ylabel='Sorted testing inputs')

在这里插入图片描述

批量矩阵乘法

X = torch.ones((2, 1, 4))
Y = torch.ones((2, 4, 6))
torch.bmm(X, Y).shape
torch.Size([2, 1, 6])

在注意力机制的背景中,我们可以使用小批量矩阵乘法来计算小批量数据中的加权平均值。

weights = torch.ones((2, 10)) * 0.1
values = torch.arange(20.0).reshape((2, 10))
torch.bmm(weights.unsqueeze(1), values.unsqueeze(-1))
tensor([[[ 4.5000]],[[14.5000]]])

使用小批量矩阵乘法, 定义Nadaraya-Watson核回归的带参数版本为:

class NWKernelRegression(nn.Module):def __init__(self, **kwargs):super().__init__(**kwargs)self.w = nn.Parameter(torch.rand((1,), requires_grad=True))def forward(self, queries, keys, values):# queries和attention_weights的形状为(查询个数,“键-值”对个数)queries = queries.repeat_interleave(keys.shape[1]).reshape((-1, keys.shape[1]))self.attention_weights = nn.functional.softmax(-((queries - keys) * self.w)**2 / 2, dim=1)# values的形状为(查询个数,“键-值”对个数)return torch.bmm(self.attention_weights.unsqueeze(1),values.unsqueeze(-1)).reshape(-1)

将训练数据集变换为键和值用于训练注意力模型。

# X_tile的形状:(n_train,n_train),每一行都包含着相同的训练输入
X_tile = x_train.repeat((n_train, 1))
# Y_tile的形状:(n_train,n_train),每一行都包含着相同的训练输出
Y_tile = y_train.repeat((n_train, 1))
# keys的形状:('n_train','n_train'-1)
keys = X_tile[(1 - torch.eye(n_train)).type(torch.bool)].reshape((n_train, -1))
# values的形状:('n_train','n_train'-1)
values = Y_tile[(1 - torch.eye(n_train)).type(torch.bool)].reshape((n_train, -1))

训练带参数的注意力汇聚模型时,使用平方损失函数和随机梯度下降。

net = NWKernelRegression()
loss = nn.MSELoss(reduction='none')
trainer = torch.optim.SGD(net.parameters(), lr=0.5)
animator = d2l.Animator(xlabel='epoch', ylabel='loss', xlim=[1, 5])for epoch in range(5):trainer.zero_grad()l = loss(net(x_train, keys, values), y_train)l.sum().backward()trainer.step()print(f'epoch {epoch + 1}, loss {float(l.sum()):.6f}')animator.add(epoch + 1, float(l.sum()))

在这里插入图片描述

如下所示,训练完带参数的注意力汇聚模型后可以发现: 在尝试拟合带噪声的训练数据时, 预测结果绘制的线不如之前非参数模型的平滑。

# keys的形状:(n_test,n_train),每一行包含着相同的训练输入(例如,相同的键)
keys = x_train.repeat((n_test, 1))
# value的形状:(n_test,n_train)
values = y_train.repeat((n_test, 1))
y_hat = net(x_test, keys, values).unsqueeze(1).detach()
plot_kernel_reg(y_hat)

在这里插入图片描述

与非参数的注意力汇聚模型相比, 带参数的模型加入可学习的参数后, 曲线在注意力权重较大的区域变得更不平滑。

d2l.show_heatmaps(net.attention_weights.unsqueeze(0).unsqueeze(0),xlabel='Sorted training inputs',ylabel='Sorted testing inputs')

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/314998.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Python 面向对象——6.封装

本章学习链接如下: Python 面向对象——1.基本概念 Python 面向对象——2.类与对象实例属性补充解释,self的作用等 Python 面向对象——3.实例方法,类方法与静态方法 Python 面向对象——4.继承 Python 面向对象——5.多态 1. 封装的基…

每日一练-LeeCode-移除链表元素

题目 给你一个链表的头节点 head 和一个整数 val ,请你删除链表中所有满足 Node.val val 的节点,并返回 新的头节点 。 示例 1: 输入:head [1,2,6,3,4,5,6], val 6 输出:[1,2,3,4,5] 示例 2: 输入&a…

【AI开发:音频】二、GPT-SoVITS使用方法和过程中出现的问题(GPU版)

1.FileNotFoundError: [Errno 2] No such file or directory: logs/guanshenxxx/2-name2text-0.txt 这个问题中包含了两个: 第一个:No module named pyopenjtalk 我的电脑出现的就是这个 解决:pip install pyopenjtalk 第二个&#xff1a…

156.25MHz的差分晶体振荡器SG3225VEN

数字经济正焕发出勃勃生机,云计算,大数据,5G和人工智能等新技术的发展给行业带来了新的机遇。无论是在数据中心内部还是在数据中心之间,提供低成本,高速的100/200/400G小型化解决方案都是光模块的发展需求。为了使DSP稳定工作,需要一个小型的封装晶体振荡器来提供参…

13.JAVAEE之HTTP协议

HTTP 最新的版本应该是 HTTP/3.0 目前大规模使用的版本 HTTP/1.1 使用 HTTP 协议的场景 1.浏览器打开网站 (基本上) 2.手机 APP 访问对应的服务器 (大概率) 学习 HTTP 协议, 重点学习 HTTP 的报文格式 前面的 TCP/IP/UDP 和这些不同, HTTP 的报文格式,要分两个部分来看待.请求…

「51媒体」城市推介会,地方旅游推荐,怎么做好媒体宣传

传媒如春雨,润物细无声,大家好,我是51媒体网胡老师。 城市推介会和地方旅游推荐是城市形象宣传的重要组成部分,通过有效的媒体宣传可以提升城市的知名度和吸引力。: 一,活动内容层面: 突出亮点…

Jenkins CI/CD 持续集成专题四 Jenkins服务器IP更换

一、查看brew 的 services brew services list 二、编辑 homebrew.mxcl.jenkins-lts.plist 将下面的httpListenAddress值修改为自己的ip 服务器,这里我是用的本机的ip 三 、重新启动 jenkins-lts brew services restart jenkins-lts 四 、浏览器访问 http://10.…

【Django】初识Django快速上手

Django简介 Django是一个高级的、开源的Python Web框架,旨在快速、高效地开发高质量的Web应用程序 https://developer.mozilla.org/zh-CN/docs/Learn/Server-side/Django/Introduction 安装Django pip install Django如果要知道安装的Django的版本,可…

鸿蒙内核源码分析(进程管理篇) | 谁在管理内核资源?

官方基本概念 从系统的角度看,进程是资源管理单元。进程可以使用或等待CPU、使用内存空间等系统资源,并独立于其它进程运行。 OpenHarmony内核的进程模块可以给用户提供多个进程,实现了进程之间的切换和通信,帮助用户管理业务程序…

yolov8旋转目标检测输出的角度转化为适合机械爪抓取的角度

1. 机械爪抓取时旋转的角度定义 以X轴正方向(右)为零度方向,角度取值范围[-90,90)。 确认角度的方法: 逆时针旋转X轴,X轴碰到矩形框长边时旋转过的角度记为angleX: 1.如果angleX小于90&#xf…

【源码】IM即时通讯源码/H5聊天软件/视频通话+语音通话/带文字部署教程

【源码介绍】 IM即时通讯源码/H5聊天软件/视频通话语音通话/带文字部署教程 【源码说明】 测试环境:Linux系统CentOS7.6、宝塔、PHP7.2、MySQL5.6,根目录public,伪静态laravel5,根据情况开启SSL 登录后台看到很熟悉。。原来是…

docker容器技术篇:集群管理实战mesos+zookeeper+marathon(二)

docker集群管理实战mesoszookeepermarathon(二) 一 实验环境 操作系统:centos7.9 二 基础环境配置以及安装mesos 安装过程请点击下面的链接查看: 容器集群管理实战mesoszookeepermarathon(一) 三 安装…

STM32进入睡眠模式的方法

#STM32进入睡眠模式的方法 今天学习了如何控制STM32进入睡眠模式,进入睡眠模式的好处就是省电,今天学习的只是浅度睡眠,通过中断就能唤醒。比如单片机在那一放,也许好几天好几个月都不用一次,整天的在那空跑while循环…

Git重修系列 ------ Git的使用和常用命令总结

一、Git的安装和配置 git安装: Git - Downloads git首次配置用户信息: $ git config --global user.name "kequan" $ git config --global user.email kequanchanqq.com $ git config --global credential store 配置 Git 以使用本地存储机…

鸿蒙OpenHarmony【小型系统 编译】(基于Hi3516开发板)

编译 OpenHarmony支持hb和build.sh两种编译方式。此处介绍hb方式,build.sh脚本编译方式请参考[使用build.sh脚本编译源码]。 使用build.sh脚本编译源码 进入源码根目录,执行如下命令进行版本编译。 ./build.sh --product-name name --ccache 说明&…

刷机维修进阶教程---开机定屏 红字感叹号报错 写字库保资料 救砖 刷官方包保资料的步骤方法解析

在维修各种机型 中经常会遇到开机定屏 进不去系统,正常使用无故定屏进不去系统或者更新降级开机红色感叹号的一些故障机。但顾客需要报资料救砖的要求,遇到这种情况。我们首先要确定故障机型的缘由。是摔 还是更新降级 还是无故使用重启定屏等等。根据原因来对症解决。 通过…

Tiny11作者开源:利用微软官方镜像制作独属于你的Tiny11镜像

微软对Windows 11的最低硬件要求包括至少4GB的内存、双核处理器和64GB的SSD存储。然而,这些基本要求仅仅能保证用户启动和运行系统,而非流畅使用 为了提升体验,不少用户选择通过精简系统来减轻硬件负担,我们熟知的Tiny11便是其中…

Typora for Mac:轻量级Markdown编辑器

Typora for Mac是一款专为Mac用户设计的轻量级Markdown编辑器,它以其简洁的界面和强大的功能,成为了Markdown写作爱好者的首选工具。 Typora for Mac v1.8.10中文激活版下载 Typora的最大特色在于其所见即所得的编辑模式,用户无需关心复杂的M…

(七)Servlet教程——Idea编辑器集成Tomcat

1. 点击桌面上Idea快捷方式打开Idea编辑器,假如没有创建项目的话打开Idea编辑器后的界面展示如下图所示 2. 点击界面左侧菜单中的自定义 3. 然后点击界面中的“所有设置...”,然后点击“构建、执行、部署”,选择其中的“应用程序服务器” 4. 点击“”按钮…

Centos之yum安装好玩的命令

1.会动的小火车 我在root下使用的 yum install sl.x86_64sl2.figlet yum install figlet.x86_64figlet 55553.cowsay会说话 yum install cowsay