LLM - LLaMA-2 获取文本向量并计算 Cos 相似度

 

目录

一.引言

二.获取文本向量

1.hidden_states 与 last_hidden_states

◆ hidden_states

◆ last_hidden_states 

2.LLaMA-2 获取 hidden_states

◆ model config 

◆ get Embedding

三.获取向量 Cos 相似度

1.向量选择

2.Cos 相似度

3.BERT-whitening 特征白化

四.总结


一.引言

前面提到了两种基于统计的机器翻译评估方法: Rouge 与 BLEU,二者通过统计概率计算 N-Gram 的准确率与召回率,在机器翻译这种回答相对固定的场景该方法可以作为一定参考,但在当前大模型更加多样性的场景以及发散的回答的情况下,Rouge 与 BLEU 有时候并不能更好的描述文本之间的相似度,下面我们尝试从 LLM 大模型提取文本的 Embedding 并进行向量相似度计算。

二.获取文本向量

1.hidden_states 与 last_hidden_states

根据 LLM 模型类型的不同,有的 Model 提供 hidden_states 方法,例如 LLaMA-2-13B,有的模型提供 last_hidden_states 方法,例如 GPT-2。查找模型对应方法 API 可以在 Transformer 官网。

 hidden_states

hidden_states 类型为 typing.Optional[typing.Tuple[torch.FloatTensor]],其提供一个 Tuple[Tensor] 分别记录了每层的输出,完整的解释在参数下方: 

模型在每一层输出处的隐藏状态加上可选的初始嵌入输出。这里我们可以通过打印模型 Layer 和索引从而获取 hidden_states 中隐层的输出。

◆ last_hidden_states 

一些传统的模型例如 GPT-2,还有当下一些的新模型例如 ChatGLM2 都有 last_hidden_states 的 API,可以直接获取最后一层的 Embedding 输出,而如果使用 hidden_states 则只需要通过 [-1] 索引即可获得 last_hidden_states,相比来如前者更全面后者更方便。

2.LLaMA-2 获取 hidden_states

model config 

    config_kwargs = {"trust_remote_code": True,"cache_dir": None,"revision": 'main',"use_auth_token": None,"output_hidden_states": True}config = AutoConfig.from_pretrained(ori_model_path, **config_kwargs)llama_model = AutoModelForCausalLM.from_pretrained(ori_model_path,config=config,torch_dtype=torch.float16,low_cpu_mem_usage=True,trust_remote_code=True,revision='main')

 根据 CausalLMOutputWithPast hidden_states 参数的提示,我们只需要在模型 config 中添加:

"output_hidden_states": True

get Embedding

def get_embeddings(result, llm_tokenizer, model, args):fw = open(args.output, 'w', encoding='utf-8')for qa in result:q = qa[0]a = qa[1]# 对输出文本进行 tokenize 和编码tokens = llm_tokenizer.encode_plus(a, add_special_tokens=True, padding='max_length', truncation=True,max_length=128, return_tensors='pt')input_ids = tokens["input_ids"]attention_mask = tokens['attention_mask']# 获取文本 Embeddingwith torch.no_grad():outputs = model(input_ids=input_ids.cuda(), attention_mask=attention_mask)embedding = list(outputs.hidden_states)last_hidden_states = embedding[-1].cpu().numpy()first_hidden_states = embedding[0].cpu().numpy()last_hidden_states = np.squeeze(last_hidden_states)first_hidden_states = np.squeeze(first_hidden_states)fisrt_larst_avg_status = np.mean(first_hidden_states + last_hidden_states, axis=0)log = "%s\t%s\t%s\n" % (q, a, toString(fisrt_larst_avg_status))fw.write(log)fw.close()

predict  预测       ➔  将 model 基于 Question generate 得到的 Answer 存入 result

encode 编码       ➔  对 Answer 进行编码获取对应 Token 与 input_ids、attention_mask

output 模型输出  ➔  直接调用 model 进行输出,有的也可以调用 model.transform 方法进行输出

hidden_states     ➔  outputs.hidden_states 获取各隐层输出

最后获取的向量需要先 cpu 然后再转为 numpy 数组,一般的做法是采用 mean 获得句子的平均表征。

三.获取向量 Cos 相似度

1.向量选择

在 BERT-flow 的论文中,如果不加任何后处理手段,那么基于 BERT 抽取句向量的最好 Pooling 方法是 BERT 的第一层与最后一层的所有 token 向量的平均,即 fisrt-larst-avg,对应 hidden_state 的 0 和 -1 索引,所以后面的相似度计算我们都以 fisrt-larst-avg 为基准来评估 Embedding 相似度。

# 获取文本 Embedding
with torch.no_grad():outputs = model(input_ids=input_ids.cuda(), attention_mask=attention_mask)embedding = list(outputs.hidden_states)last_hidden_states = embedding[-1].cpu().numpy()first_hidden_states = embedding[0].cpu().numpy()last_hidden_states = np.squeeze(last_hidden_states)first_hidden_states = np.squeeze(first_hidden_states)fisrt_larst_avg_status = np.mean(first_hidden_states + last_hidden_states, axis=0)

2.Cos 相似度

# 计算 Cos 相似度
def compute_cosine(a_vec, b_vec):norms1 = np.linalg.norm(a_vec, axis=1)norms2 = np.linalg.norm(b_vec, axis=1)dot_products = np.sum(a_vec * b_vec, axis=1)cos_similarities = dot_products / (norms1 * norms2)return cos_similarities

a_vec 为预测文本转化得到的 Embedding,b_vec 为人工标注正样本文本转化得到的 Embedding,通过计算二者相似度,评估预测文本与人工文本的相似程度。

3.BERT-whitening 特征白化

苏神在 BERT-whitening 一文中提出了一种基于 PCA 降维的无监督 Embedding 评估方式,Bert-whitening 又叫特征白化,其思路与 PCA 降维类似,意在对 SVD 分解后的主成分矩阵取前 λ 个特征向量构造特征值矩阵,提取向量中的关键信息,使输出向量矩阵每个维度均值为零,协方差矩阵为单位阵,λ 个特征值也对应前 λ 个主成分。其算法逻辑如下:

 下面我们调用 Sklearn 的 PCA 库简单实现下:

from sklearn.decomposition import PCA
from sklearn.preprocessing import normalize# 取出句子的平均表示 -> 使用 PCA 降维 -> 白化处理concatenate = np.concatenate((answer_vector, predict_vector))pca = PCA(n_components=2048)pca.fit(concatenate)ans_white_vec = pca.transform(answer_vector)ans_norm_vec = normalize(ans_white_vec)pre_white_vec = pca.transform(predict_vector)pre_norm_vec = normalize(pre_white_vec)pca_cos_similarities = compute_cosine(ans_norm_vec, pre_norm_vec)

answec_vector 和 predict_vector 均通过 first_and_last 方法从 hidden_states 中获取,n_components 即 top_k 的选择,以 LLaMA-2 为例,原始得到的向量维度为 5120,原文中也有使用 n_components = 256 实验。

四.总结

博主采用 1500+ 样本分别使用 cos、pca 和 self_pca [自己实现 SVD 与特征矩阵] 三种方法对向量相似度进行评估,n_components 设为 1024:

可以看到 SVD 处理后得到的 W 和 mu 的 shape,通过下述操作可完成向量的降维:

vecs = (vecs + bias).dot(kernel)

最终得到的结果 Cosine 与 PCA 降维的相似度差距较大,由于自然语言生成的样本没有严格意义的正样本,上面计算采用的参考文本也是人工标注,有一定的不确定性,所以基于不同的度量,我们也可以统计分析,定一个 threshold,认为大于该 threshold 的输入样本为可用。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/118336.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

centos安装nginx实操记录(加安全配置)

1.下载与安装 yum -y install nginx2.启动命令 /usr/sbin/nginx -c /etc/nginx/nginx.conf3.新建配置文件 cd /etc/nginx/conf.d vim index.conf配了一个负责均衡,如不需要,可将 server localhost: 多余的去掉 upstream web_server{server localhost…

软件测试/测试开发丨Selenium 高级定位 CSS

点此获取更多相关资料 本文为霍格沃兹测试开发学社学员学习笔记分享 原文链接:https://ceshiren.com/t/topic/27022 一、CSS选择器概念 CSS拥有自己的语法规则和表达式 CSS通常分为相对定位和绝对定位 CSS常和XPATH一起用于UI自动化测试 二、CSS相对定位使用场景 支…

webpack(一)模块化

模块化演变过程 阶段一:基于文件的划分模块方式 概念:将每个功能和相关数据状态分别放在单独的文件里 约定每一个文件就是一个单独的模块,使用每个模块,直接调用这个模块的成员 缺点:所有的成员都可以在模块外被访问和…

百度搜索清理大量低质量网站

我是卢松松,点点上面的头像,欢迎关注我哦! 据部分站长爆料:百度大规模删低质量网站的百度资源站长平台权限,很多网站都被删除了百度站长资源平台后台权限,以前在百度后台添加的网站大量被删除!…

vue左侧漏斗切换 echart图表动态更新

这个需求是根据点击左侧的箭头部分&#xff0c;右侧图表切换&#xff0c;左侧选中数据高亮&#xff08;图片用的svg&#xff09; 一、效果图 二、vue组件 <template><div class"funnel_wrap"><div class"flex_between"><div class&q…

[机器学习]分类算法系列①:初识概念

目录 1、概念 2、数据集介绍与划分 2.1、数据集的划分 2.2、sklearn数据集介绍 2.2.1、API 2.2.2、分类和回归数据集 分类数据集 回归数据集 返回类型 3、sklearn转换器和估计器 3.1、转换器 三种方法的区别 3.2、估计器 3.2.1、简介 3.2.2、API 3.3、工作流程 …

热门框架漏洞

文章目录 一、Thinkphp5.0.23 代码执行1.thinkphp5框架2.thinkphp5高危漏洞3.漏洞特征4.THinkphp5.0 远程代码执行--poc5.TP5实验一(Windows5.0.20)a.搭建实验环境b.测试phpinfoc.写入shelld.使用菜刀连接 6.TP5实验二(Linux5.0.23)a.搭建实验环境b.测试方法c.测试phpinfod.写入…

小程序快速备案助手代备案小程序开发

小程序快速备案助手代备案小程序开发 用户注册与登录&#xff1a;用户可以通过手机号或其他方式进行注册和登录&#xff0c;以便进行备案相关操作。备案信息填写&#xff1a;用户可以填写小程序的备案信息&#xff0c;包括小程序名称、小程序服务类目、域名等。备案材料上传&a…

ODC现已开源:与开发者共创企业级的数据库协同开发工具

OceanBase 开发者中心&#xff08;OceanBase Developer Center&#xff0c;以下简称 ODC&#xff09;是一款开源的数据库开发和数据库管理协同工具&#xff0c;从首个版本上线距今已经发展了三年有余&#xff0c;ODC 逐步由一款专为 OceanBase 打造的开发者工具演进成为支持多数…

香橙派Orangepi Zero2 刷机步骤

目录 1.香橙派Orangepi Zero2简介 2.刷机 2.1物料准备 2.2 格式化SD卡 2.3 烧录镜像到SD卡 2.4 安装SD卡到Orangepi 2.5 连接Pi电源 2.6 MobaXterm 串口登陆Orangepi 2.6.1 连线示意图 2.6.2 MobaXterm 使用 2.6.3修改登陆密码 2.6.4 网络配置 2.7 SSH登陆开发版…

14 mysql bit/json/enum/set 的数据存储

前言 这里主要是 由于之前的一个 datetime 存储的时间 导致的问题的衍生出来的探究 探究的主要内容为 int 类类型的存储, 浮点类类型的存储, char 类类型的存储, blob 类类型的存储, enum/json/set/bit 类类型的存储 本文主要 的相关内容是 bit/json/enum/set 类类型的相关…

中心差分法-学习笔记《结构动力学-陈政清》

激励分段解析法仅仅对外载荷进行了离散&#xff0c;但对运动方程还是严格满足的&#xff0c;体系的运动在时间轴上依然是满足运动微分方程。然而&#xff0c;一般的时域逐步积分法进一步放松要求&#xff0c;不仅仅对外荷载进行离散化处理&#xff0c;也对体系的运动进行离散化…

c++入门一

参考&#xff1a;https://www.learncpp.com/cpp-tutorial/ When you finish, you will not only know how to program in C, you will know how NOT to program in C, which is arguably as important. Tired or unhappy programmers make mistakes, and debugging code tends…

Swift 如何从图片数据(Data)检测原图片类型?

功能需求 如果我们之前把图片对应的数据(Data)保持在内存或数据库中,那么怎么从 Data 对象检测出原来图片的类型呢? 如上图所示:我们将 11 张不同类型的图片转换为 Data 数据,然后从 Data 对象正确检测出了原图片类型。 目前,我们的代码可以检测出 jpeg(jpg), tiff,…

逻辑回归Logistic

回归 概念 假设现在有一些数据点&#xff0c;我们用一条直线对这些点进行拟合&#xff08;这条直线称为最佳拟合直线&#xff09;&#xff0c;这个拟合的过程就叫做回归。进而可以得到对这些点的拟合直线方程。 最后结果用sigmoid函数输出 因此&#xff0c;为了实现 Logisti…

K8S自动化运维容器化(Docker)集群程序

K8S自动化运维容器化集群程序 一、K8S概述1.什么是K8S2.为什么要用K8S3.作用及功能 二、K8S的特性1.弹性伸缩2.自我修复3.服务发现和负载均衡4.自动发布和回滚5.集中化配置管理和秘钥管理6.存储编排7.任务批量处理运行 三、K8S的集群架构1.架构2.模式3.工作4.流程图 四、K8S的核…

如何解决vue3.0+typescript项目提示找不到模块“./App.vue

一、解决方案如下&#xff1a;需在项目目录下加上下面这段代码即可&#xff01;如果没有vite-env.d.ts目录需要继续往下看 declare module *.vue {import type { DefineComponent } from vueconst vueComponent: DefineComponent<{}, {}, any>export default vueCompon…

TBOX开发需求说明

TBOX功能需求&#xff1a; 支持4G上网功能&#xff0c;可获取外网IP&#xff0c;可和云端平台连通支持路由功能&#xff0c;支持计算平台、网关和云端平台建立网络连接支持USB转网口&#xff0c;智能座舱会通过USB连接AG35建立网络连接&#xff08;类似IVI通过USB口连接TBOX&a…

前端面试中Vue的有经典面试题二

7. Vue中给data中的对象属性添加一个新的属性时会发生什么&#xff0c;如何解决&#xff1f; 示例&#xff1a; 点击button会发现&#xff0c; obj.b 已经成功添加&#xff0c;但是视图并未刷新&#xff1a; 原因在于在Vue实例创建时&#xff0c; obj.b 并未声明&#xff0c;因…

SQL求解用户连续登录天数

数据分析面试过程中&#xff0c;一般都逃不掉对SQL的考察&#xff0c;可能是笔试的形式&#xff0c;也可能是面试过程中面试官当场提问&#xff0c;当场在纸上写出&#xff0c;或者简单说一下逻辑。 今天&#xff0c;就来分享一道面试中常常被问到的一类SQL问题&#xff1a;连…