从预训练的BERT中提取Embedding

文章目录

    • 背景
    • 前置准备
    • 思路
    • 利用Transformer 库实现

背景

假设要执行一项情感分析任务,样本数据如下
在这里插入图片描述
可以看到几个句子及其对应的标签,其中1表示正面情绪,0表示负面情绪。我们可以利用给定的数据集训练一个分类器,对句子所表达的情感进行分类。

前置准备

# 安装modelscope包
pip install modelscope
# 下载 bert-base-uncased 模型
modelscope download --model AI-ModelScope/bert-base-uncased

思路

  1. 分词:以第一句为例,我们使用WordPiece对句子进行分词,并得到标记(单词),如下所示。

    tokens = [I, love, Paris]

  2. 添加标记:在开头添加[CLS]标记,在结尾添加[SEP]标记,如下所示。

    tokens = [ [CLS], I, love, Paris, [SEP] ]

  3. 填充:为了保持所有标记的长度一致,我们将数据集中的所有句子的标记长度设为7。句子I loveParis的标记长度是5,为了使其长度为7,需要添加两个标记来填充,即[PAD]。因此,新标记如下所示。

    tokens = [ [CLS], I, love, Paris, [SEP], [PAD], [PAD] ]

    添加两个[PAD]标记后,标记的长度达到所要求的7。

  4. 注意力掩码:下一步,要让模型理解[PAD]标记只是为了匹配标记的长度,而不是实际标记的一部分。为了做到这一点,我们需要引入一个注意力掩码。我们将所有位置的注意力掩码值设置为1,将[PAD]标记的位置设置为0,如下所示。

    attention_mask = [ 1, 1, 1, 1, 1, 0, 0]

  5. 映射到token id:然后,将所有的标记映射到一个唯一的标记ID。假设映射的标记ID如下所示。

    token_ids = [101, 1045, 2293, 3000, 102, 0, 0]

    ID 101表示标记[CLS],1045表示标记I,2293表示标记love,以此类推。

    现在,我们把token_ids和attention_mask一起输入预训练的BERT模型,并获得每个标记的特征向量(嵌入)。通过代码,我们可以进一步理解以上步骤。下图显示的标记+单词而不是id,但实际传入的是id
    在这里插入图片描述

以上,可以得到每个单词的Embedding,整个句子的Embedding是 R [ C L S ] R_{[CLS]} R[CLS]

利用Transformer 库实现

from transformers import BertModel, BertTokenizer
import torch
# 下载并加载预训练的模型
model = BertModel.from_pretrained('bert-base-uncased')
# 下载并加载用于预训练模型的词元分析器。
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')# 下面,让我们看看如何对输入进行预处理。
# 0. 对输入进行预处理假设输入句如下所示。
sentence = 'I love Paris'
# 1. 分词
tokens = tokenizer.tokenize(sentence)
print(tokens) # ['i', 'love', 'paris']# 2. 添加标记
tokens = ['[CLS]'] + tokens + ['[SEP]']
print(tokens) # ['[CLS]', 'i', 'love', 'paris', '[SEP]']# 3. 填充
tokens = tokens + ['[PAD]'] + ['[PAD]']
print(tokens) #['[CLS]', 'i', 'love', 'paris', '[SEP]', '[PAD]', '[PAD]' ]# 4. 注意力掩码
attention_mask = [1 if i!= '[PAD]' else 0 for i in tokens]
print(attention_mask) # [1, 1, 1, 1, 1, 0, 0]# 5. 将所有标记转换为它们的标记ID
token_ids = tokenizer.convert_tokens_to_ids(tokens)
print(token_ids) # [101, 1045, 2293, 3000, 102, 0, 0]# 6. 将token_ids和attention_mask转换为张量
token_ids = torch.tensor(token_ids).unsqueeze(0)
attention_mask = torch.tensor(attention_mask).unsqueeze(0)# 7. 将token_ids和atten-tion_mask送入模型,并得到嵌入向量。
# 需要注意,model返回的输出是一个有两个值的元组。第1个值hidden_rep表示隐藏状态的特征,它包括从顶层编码器(编码器12)获得的所有标记的特征。第2个值cls_head表示[CLS]标记的特征。
hidden_rep, cls_head = model(token_ids, attention_mask = attention_mask)
print(hidden_rep.shape) # torch.Size([1, 7, 768])'''
数组[1, 7, 768]表示[batch_size, se-quence_length, hidden_size],也就是说,批量大小设为1,序列长度等于标记长度,即7。因为有7个标记,所以序列长度为7。隐藏层的大小等于特征向量(嵌入向量)的大小,在BERT-base模型中,其为768。
* hidden_rep[0][0]给出了第1个标记[CLS]的特征。   
* hidden_rep[0][1]给出了第2个标记I的特征。   
* hidden_rep[0][2]给出了第3个标记love的特征
'''print(cls_head.shape) # torch.Size([1, 768])'''
大小[1, 768]表示[batch_size, hid-den_size]。我们知道cls_head持有句子的总特征,所以,可以用cls_head作为句子I love Paris的整句特征。
'''

以上获得的是从顶层编码器(编码器12)获得的特征,如果要获取所有编码器的特征,需要修改以下两个地方。

# 下载并加载预训练的模型时,设置output_hidden_states = True
model = BertModel.from_pretrained('bert-base-uncased', output_hidden_states = True)
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')# 调用模型时,产生的是三元组
last_hidden_state, pooler_output, hidden_states = model(token_ids, attention_mask = attention_mask)'''
* last_hidden_state,它仅有从最后的编码器(编码器12)中获得的所有标记的特征
* pooler_output表示来自最后的编码器的[CLS]标记的特征,它被一个线性激活函数和tanh激活函数进一步处理。
* hidden_states包含从所有编码器层获得的所有标记的特征。它是一个包含13个值的元组,含有所有编码器层(隐藏层)的特征,即从输入嵌入层h到最后的编码器层h。
'''

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/105.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

HarmonyOS鸿蒙开发 弹窗及加载中指示器HUD功能实现

HarmonyOS鸿蒙开发 弹窗及加载中指示器HUD功能实现 最近在学习鸿蒙开发过程中,阅读了官方文档,在之前做flutter时候,经常使用overlay,使用OverlayEntry加入到overlayState来做添加悬浮按钮、提示弹窗、加载中指示器、加载失败的t…

基于华为ENSP的OSPF状态机、工作过程、配置保姆级别详解(2)

本篇技术博文摘要 🌟 基于华为enspOSPF状态机、OSPF工作过程、.OSPF基本配置等保姆级别具体详解步骤;精典图示举例说明、注意点及常见报错问题所对应的解决方法 引言 📘 在这个快速发展的技术时代,与时俱进是每个IT人的必修课。我…

DeepSeek:性能强劲的开源模型

deepseek 全新系列模型 DeepSeek-V3 首个版本上线并同步开源。登录官网 chat.deepseek.com 即可与最新版 V3 模型对话。 性能对齐海外领军闭源模型​ DeepSeek-V3 为自研 MoE 模型,671B 参数,激活 37B,在 14.8T token 上进行了预训练。 论…

Elastic-Job相关

文档参考视频:09_SpringBoot案例演示_哔哩哔哩_bilibili 一、Elastic-Job介绍 Elastic-Job 是一个轻量级、分布式的任务调度框架,旨在解决分布式环境下的定时任务调度问题。 1.1. Elastic-Job 的核心组件 Elastic-Job 是由多个核心组件构成的&#x…

【Linux】文件 文件描述符fd

🌻个人主页:路飞雪吖~ 🌠专栏:Linux 目录 🌻个人主页:路飞雪吖~ 一、C文件接口 🌟写文件 🌠小贴士: 🌠stdin && stdout && stderr Linux下…

Java Spring Boot实现基于URL + IP访问频率限制

点击下载《Java Spring Boot实现基于URL IP访问频率限制(源代码)》 1. 引言 在现代 Web 应用中,接口被恶意刷新或暴力请求是一种常见的攻击手段。为了保护系统资源,防止服务器过载或服务不可用,需要对接口的访问频率进行限制。本文将介绍如…

QML states和transitions的使用

一、介绍 1、states Qml states是指在Qml中定义的一组状态(States),用于管理UI元素的状态转换和属性变化。每个状态都包含一组属性值的集合,并且可以在不同的状态间进行切换。 通过定义不同的状态,可以在不同的应用场…

SpringCloud

1.认识微服务 随着互联网行业的发展,对服务的要求也越来越高,服务架构也从单体架构逐渐演变为现在流行的微服务架构。这些架构之间有怎样的差别呢? 1.0.学习目标 了解微服务架构的优缺点 1.1.单体架构 单体架构:将业务的所有功…

DSP+Simulink——点亮LED灯(TMSDSP28379D)超详细

实现功能:DSP28379D-LED灯闪烁 :matlab为2019a :环境建立见之前文章 Matlab2019a安装C2000 Processors超详细过程 matlab官网链接: Getting Started with Embedded Coder Support Package for Texas Instruments C2000 Processors Overview of Creat…

java_将数据存入elasticsearch进行高效搜索

使用技术简介: (1) 使用Nginx实现反向代理,使前端可以调用多个微服务 (2) 使用nacos将多个服务管理关联起来 (3) 将数据存入elasticsearch进行高效搜索 (4) 使用消息队列rabbitmq进行消息的传递 (5) 使用 openfeign 进行多个服务之间的api调用 参…

最近在盘gitlab.0.先review了一下docker

# 正文 本猿所在产品的代码是保存到了一个本地gitlab实例上,实例是别的同事搭建的。最近又又又想了解一下,而且已经盘了一些了,所以写写记录一下。因为这个事儿没太多的进度压力,索性写到哪儿算哪儿,只要是新了解到的…

计算机网络(四)网络层

4.1、网络层概述 简介 网络层的主要任务是实现网络互连,进而实现数据包在各网络之间的传输 这些异构型网络N1~N7如果只是需要各自内部通信,他们只要实现各自的物理层和数据链路层即可 但是如果要将这些异构型网络互连起来,形成一个更大的互…

AI人工智能(2):机器学习

1 简介 机器学习(Machine Learning)是人工智能(AI)的一个分支,它使计算机系统能够利用数据和算法自动学习和改进其性能。机器学习是让机器通过经验(数据)来做决策和预测。机器学习已经广泛应用于…

Photon最新版本PUN 2.29 PREE,在无网的局域网下,无法连接自己搭建的本地服务器

1.图1为官方解答 2.就是加上这一段段代码:PhotonNetwork.NetworkingClient.SerializationProtocol SerializationProtocol.GpBinaryV16; 完美解决 unity 商店最新PUN 2 插件 不能连接 (环境为:本地局域网 无外网情况 ) …

android 官网刷机和线刷

nexus、pixel可使用google官网线上刷机的方法。网址:https://flash.android.com/ 本文使用google线上刷机,将Android14 刷为Android12 以下是失败的线刷经历。 准备工作 下载升级包。https://developers.google.com/android/images?hlzh-cn 注意&…

Qt官方下载地址

1. 最新版本 Qt官方最新版本下载地址:https://www.qt.io/download-qt-installer 当前最新版本Qt6.8.* 如下图: 2. 历史版本 如果你要下载历史版本安装工具或者源码编译方式安装,请转至此链接进行下载:https://download.qt.i…

带格式 pdf 翻译

支持 openAI 接口,国内 deepseek 接口兼容 openAI 接口, deepseek api 又非常便宜 https://pdf2zh.com/ https://github.com/Byaidu/PDFMathTranslate

WebRTC 在视频联网平台中的应用:开启实时通信新篇章

在当今这个以数字化为显著特征的时代浪潮之下,实时通信已然稳稳扎根于人们生活与工作的方方面面,成为了其中不可或缺的关键一环。回首日常生活,远程办公场景中的视频会议让分散各地的团队成员能够跨越地理距离的鸿沟,齐聚一堂共商…

《ROS2 机器人开发 从入门道实践》 鱼香ROS2——第6章内容

第6章 建模与仿真-创建自己的机器人 6.1 机器人建模与仿真概述 6.2使用URDF创建机器人 6.2.1 帮机器人创建一个身体 1. 新建文件chapt6/chapt6_ws/src/fishbot_description/urdf/ 2. 新建文件first_robot.urdf <?xml version"1.0"?> <robot name &…

Postman接口测试03|执行接口测试、全局变量和环境变量、接口关联、动态参数、断言

目录 七、Postman 1、安装 2、postman的界面介绍 八、Postman执行接口测试 1、请求页签 3、响应页签 九、Postman的环境变量和全局变量 1、创建环境变量和全局变量可以解决的问题 2、postman中的操作-全局变量 1️⃣手动设置 2️⃣代码设置 3️⃣界面获取 4️⃣代…