bert 适合 embedding 的模型

目录

背景

embedding

求最相似的 topk

结果查看


背景

想要求两个文本的相似度,就单纯相似度,不要语义相似度,直接使用 bert 先 embedding 然后找出相似的文本,效果都不太好,试过 bert-base-chinese,bert-wwm,robert-wwm 这些,都有一个问题,那就是明明不相似的文本却在结果中变成了相似,真正相似的有没有,

例如:手机壳迷你版,与这条数据相似的应该都是跟手机壳有关的才合理,但结果不太好,明明不相关的,余弦相似度都能有有 0.9 以上的,所以问题出在 embedding 上,找了适合做 embedding 的模型,再去计算相似效果好了很多,合理很多。

之前写了一篇 bert+np.memap+faiss文本相似度匹配 topN-CSDN博客 是把流程打通,现在是找适合文本相似的来操作。

模型:

bge-small-zh-v1.5

bge-large-zh-v1.5

embedding

数据弄的几条测试数据,方便看那些相似

我用 bge-large-zh-v1.5 来操作,embedding 代码,为了知道 embedding 进度,加了进度条功能,同时打印了当前使用 embedding 的 bert 模型输出为度,这很重要,会影响求相似的 topk

import numpy as np
import pandas as pd
import time
from tqdm.auto import tqdm
from transformers import AutoTokenizer, AutoModel
import torchclass TextEmbedder():def __init__(self, model_name="./bge-large-zh-v1.5"):# self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")  # 自己电脑跑不起来 gpuself.device = torch.device("cpu")self.tokenizer = AutoTokenizer.from_pretrained(model_name)self.model = AutoModel.from_pretrained(model_name).to(self.device)self.model.eval()# 没加进度条的# def embed_sentences(self, sentences):#     encoded_input = self.tokenizer(sentences, padding=True, truncation=True, return_tensors='pt')#     with torch.no_grad():#         model_output = self.model(**encoded_input)#         sentence_embeddings = model_output[0][:, 0]#     sentence_embeddings = torch.nn.functional.normalize(sentence_embeddings, p=2, dim=1)##     return sentence_embeddings# 加进度条def embed_sentences(self, sentences):embedded_sentences = []for sentence in tqdm(sentences):encoded_input = self.tokenizer([sentence], padding=True, truncation=True, return_tensors='pt')with torch.no_grad():model_output = self.model(**encoded_input)sentence_embedding = model_output[0][:, 0]sentence_embedding = torch.nn.functional.normalize(sentence_embedding, p=2)embedded_sentences.append(sentence_embedding.cpu().numpy())print('当前 bert 模型输出维度为,', embedded_sentences[0].shape[1])return np.array(embedded_sentences)def save_embeddings_to_memmap(self, sentences, output_file, dtype=np.float32):embeddings = self.embed_sentences(sentences)shape = embeddings.shapeembeddings_memmap = np.memmap(output_file, dtype=dtype, mode='w+', shape=shape)embeddings_memmap[:] = embeddings[:]del embeddings_memmap  # 关闭并确保数据已写入磁盘def read_data():data = pd.read_excel('新建 XLSX 工作表.xlsx')return data['addr'].to_list()def main():# text_data = ["这是第一个句子", "这是第二个句子", "这是第三个句子"]text_data = read_data()embedder = TextEmbedder()# 设置输出文件路径output_filepath = 'sentence_embeddings.npy'# 将文本数据向量化并保存到内存映射文件embedder.save_embeddings_to_memmap(text_data, output_filepath)if __name__ == "__main__":start = time.time()main()end = time.time()print(end - start)

求最相似的 topk

使用 faiss 索引需要设置 bert 模型的维度,所以我们前面打印出来了,要不然会报错,像这样的:

ValueError: cannot reshape array of size 10240 into shape (768)

所以  print('当前 bert 模型输出维度为,', embedded_sentences[0].shape[1]) 的值换上去,我这里打印的 1024

index = faiss.IndexFlatL2(1024)  # 假设BERT输出维度是768# 确保embeddings_memmap是二维数组,如有需要转换
if len(embeddings_memmap.shape) == 1:embeddings_memmap = embeddings_memmap.reshape(-1, 1024)

完整代码 

import pandas as pd
import numpy as np
import faiss
from tqdm import tqdmdef search_top4_similarities(index_path, data, topk=4):embeddings_memmap = np.memmap(index_path, dtype=np.float32, mode='r')index = faiss.IndexFlatL2(768)  # 假设BERT输出维度是768# 确保embeddings_memmap是二维数组,如有需要转换if len(embeddings_memmap.shape) == 1:embeddings_memmap = embeddings_memmap.reshape(-1, 768)index.add(embeddings_memmap)results = []for i, text_emb in enumerate(tqdm(embeddings_memmap)):D, I = index.search(np.expand_dims(text_emb, axis=0), topk)  # 查找前topk个最近邻# 获取对应的 nature_df_img_id 的索引top_k_indices = I[0][:topk]  ## 根据索引提取 nature_df_img_idtop_k_ids = [data.iloc[index]['index'] for index in top_k_indices]# 计算余弦相似度并构建字典cosine_similarities = [cosine_similarity(text_emb, embeddings_memmap[index]) for index in top_k_indices]top_similarity = dict(zip(top_k_ids, cosine_similarities))results.append((data['index'].to_list()[i], top_similarity))return results# 使用余弦相似度公式,这里假设 cosine_similarity 是一个计算两个向量之间余弦相似度的函数
def cosine_similarity(vec1, vec2):return np.dot(vec1, vec2) / (np.linalg.norm(vec1) * np.linalg.norm(vec2))def main_search():data = pd.read_excel('新建 XLSX 工作表.xlsx')data['index'] = data.indexsimilarities = search_top4_similarities('sentence_embeddings.npy', data)# 输出结果similar_df = pd.DataFrame(similarities, columns=['id', 'top'])similar_df.to_csv('similarities.csv', index=False)# 执行搜索并保存结果
main_search()

结果查看

看一看到余弦数值还是比较合理的,没有那种明明不相关但余弦值是 0.9 的情况了,这两个模型还是可以的

实际案例

以前做过一个地址相似度聚合的,找出每个地址与它相似的地址,最多是 0-3 个相似的地址(当时人工验证过的,这里直接说明)

我们用 bge-small-zh-v1.5 模型来做 embedding,这个模型维度是 512,数据是店名id,地址两列

embedding 代码:

import numpy as np
import pandas as pd
import time
from tqdm.auto import tqdm
from transformers import AutoTokenizer, AutoModel
import torchclass TextEmbedder():def __init__(self, model_name="./bge-small-zh-v1.5"):# self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")  # 自己电脑跑不起来 gpuself.device = torch.device("cpu")self.tokenizer = AutoTokenizer.from_pretrained(model_name)self.model = AutoModel.from_pretrained(model_name).to(self.device)self.model.eval()# 没加进度条的# def embed_sentences(self, sentences):#     encoded_input = self.tokenizer(sentences, padding=True, truncation=True, return_tensors='pt')#     with torch.no_grad():#         model_output = self.model(**encoded_input)#         sentence_embeddings = model_output[0][:, 0]#     sentence_embeddings = torch.nn.functional.normalize(sentence_embeddings, p=2, dim=1)##     return sentence_embeddingsdef embed_sentences(self, sentences):embedded_sentences = []for sentence in tqdm(sentences):encoded_input = self.tokenizer([sentence], padding=True, truncation=True, return_tensors='pt')with torch.no_grad():model_output = self.model(**encoded_input)sentence_embedding = model_output[0][:, 0]sentence_embedding = torch.nn.functional.normalize(sentence_embedding, p=2)embedded_sentences.append(sentence_embedding.cpu().numpy())print('当前 bert 模型输出维度为,', embedded_sentences[0].shape[1])return np.array(embedded_sentences)def save_embeddings_to_memmap(self, sentences, output_file, dtype=np.float32):embeddings = self.embed_sentences(sentences)shape = embeddings.shapeembeddings_memmap = np.memmap(output_file, dtype=dtype, mode='w+', shape=shape)embeddings_memmap[:] = embeddings[:]del embeddings_memmap  # 关闭并确保数据已写入磁盘def read_data():data = pd.read_excel('data.xlsx')return data['address'].to_list()def main():# text_data = ["这是第一个句子", "这是第二个句子", "这是第三个句子"]text_data = read_data()embedder = TextEmbedder()# 设置输出文件路径output_filepath = 'sentence_embeddings.npy'# 将文本数据向量化并保存到内存映射文件embedder.save_embeddings_to_memmap(text_data, output_filepath)if __name__ == "__main__":start = time.time()main()end = time.time()print(end - start)

求 embeddgin 是串行的,要想使用 gpu ,可以需修改 embed_sentences 函数:

    def embed_sentences(self, sentences, batch_size=32):inputs = self.tokenizer(sentences, padding=True, truncation=True, return_tensors='pt').to(self.device)# 计算批次数量batch_count = (len(inputs['input_ids']) + batch_size - 1) // batch_sizeembeddings_list = []with tqdm(total=len(sentences), desc="Embedding Progress") as pbar:for batch_idx in range(batch_count):start = batch_idx * batch_sizeend = min((batch_idx + 1) * batch_size, len(inputs['input_ids']))current_batch_input = inputs[start:end]with torch.no_grad():model_output = self.model(**current_batch_input)sentence_embeddings = model_output[0][:, 0]embedding_batch = torch.nn.functional.normalize(sentence_embeddings, p=2, dim=1).cpu().numpy()# 将当前批次的嵌入向量添加到列表中embeddings_list.extend(embedding_batch.tolist())# 更新进度条pbar.update(end - start)# 将所有批次的嵌入向量堆叠成最终的嵌入矩阵embeddings = np.vstack(embeddings_list)return embeddings

求 topk 的,我们求 top4 就可以了

import pandas as pd
import numpy as np
import faiss
from tqdm import tqdmdef search_top4_similarities(data_target_embedding, data_ori_embedding, data_target, data_ori, topk=4):target_embeddings_memmap = np.memmap(data_target_embedding, dtype=np.float32, mode='r')ori_embeddings_memmap = np.memmap(data_ori_embedding, dtype=np.float32, mode='r')index = faiss.IndexFlatL2(512)  # BERT输出维度# 确保embeddings_memmap是二维数组,如有需要转换if len(target_embeddings_memmap.shape) == 1:target_embeddings_memmap = target_embeddings_memmap.reshape(-1, 512)if len(ori_embeddings_memmap.shape) == 1:ori_embeddings_memmap = ori_embeddings_memmap.reshape(-1, 512)index.add(target_embeddings_memmap)results = []for i, text_emb in enumerate(tqdm(ori_embeddings_memmap)):D, I = index.search(np.expand_dims(text_emb, axis=0), topk)  # 查找前topk个最近邻# 获取对应的 nature_df_img_id 的索引top_k_indices = I[0][:topk]  ## 根据索引提取 nature_df_img_idtop_k_ids = [data_target.iloc[index]['store_id'] for index in top_k_indices]# 计算余弦相似度并构建字典cosine_similarities = [cosine_similarity(text_emb, target_embeddings_memmap[index]) for index in top_k_indices]top_similarity = dict(zip(top_k_ids, cosine_similarities))results.append((data_ori['store_id'].to_list()[i], top_similarity))return results# 使用余弦相似度公式,这里假设 cosine_similarity 是一个计算两个向量之间余弦相似度的函数
def cosine_similarity(vec1, vec2):return np.dot(vec1, vec2) / (np.linalg.norm(vec1) * np.linalg.norm(vec2))def main_search():data_target = pd.read_excel('data.xlsx')data_ori = pd.read_excel('data.xlsx')data_target_embedding = 'sentence_embeddings.npy'data_ori_embedding = 'sentence_embeddings.npy'similarities = search_top4_similarities(data_target_embedding, data_ori_embedding, data_target, data_ori)# 输出结果similar_df = pd.DataFrame(similarities, columns=['id', 'top'])similar_df.to_csv('similarities.csv', index=False)def format_res():similarities_data = pd.read_csv('similarities.csv')ori_data = pd.read_excel('data.xlsx')target_data = pd.read_excel('data.xlsx')res = pd.DataFrame()for index, row in similarities_data.iterrows():ori_id = row['id']tops = row['top']tmp_ori_data = ori_data[ori_data['store_id'] == ori_id]tmp_target_data = target_data[target_data['store_id'].isin(list(eval(tops).keys()))]res_tmp = pd.merge(tmp_ori_data, tmp_target_data, how='cross')res = pd.concat([res, res_tmp])print(f'进度 {index + 1}/{len(similarities_data)}')res.to_excel('format.xlsx', index=False)# 执行搜索并保存结果
# main_search()# 格式化
format_res()

在这里我们把原始数据当两份使用,一份作为目标数据,一份原始数据,要原始数据的每一个地址在目标数据中找相似的

最后为了人工方便查看验证,数据格式化了,开始我说了,这数据结果每个地址跟它相似的有 0-3 条,黄色的每一组,红色的是真正相似的,从结果上来看,还是符合预期的

代码链接:

链接:https://pan.baidu.com/s/1S951j1TNoN9XbRA286jU-w 
提取码:nb4b 
 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/291258.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

pdf在浏览器上无法正常加载的问题

一、背景 觉得很有意思给大家分享一下。事情是这样的,开发给我反馈说,线上环境接口请求展示pdf异常,此时碰巧我前不久正好在ingress前加了一层nginx,恰逢此时内心五谷杂陈,思路第一时间便放在了改动项。捣鼓了好久无果…

vue中使用图片url直接下载图片

vue中使用图片url直接下载图片 // 下载图片downloadByBlob(url, name) {let image new Image()image.setAttribute(crossOrigin, anonymous)image.src urlimage.onload () > {let canvas document.createElement(canvas)canvas.width image.widthcanvas.height image…

Ubuntu20.04下PCL安装,查看,卸载等操作

Ubuntu20.04下PCL安装,查看,卸载等操作 项目来源 https://github.com/PointCloudLibrary/pclhttps://pointclouds.org/documentation/modules.htmlhttps://pcl.readthedocs.io/projects/tutorials/en/master/ 点云学习: https://github.c…

【Spring】SpringMvc项目当中,页面删除最后一条数据,页面不跳转并且数据为空。

期待您的关注 在之前学习SpringMvc的时候遇到过这样一个BUG,当我在一个页面删除该页面的最后一条数据的时候,一旦我删除成功,那么这个页面不会进行跳转,而是还停留在这个本不应该存在的页面,而且数据什么都没有。如下…

零基础教程:R语言lavaan结构方程模型(SEM)

查看原文>>>最新基于R语言lavaan结构方程模型(SEM)实践技术应用 基于R语言lavaan程序包,通过理论讲解和实际操作相结合的方式,由浅入深地系统介绍结构方程模型的建立、拟合、评估、筛选和结果展示的全过程。我们筛选大量…

ChatGPT与Discord的完美结合——团队协作的得力助手

本文将教你如何集成Discord Bot,助力团队在工作中实现更高效的沟通与协作。通过充分发挥ChatGPT的潜力,进一步提升工作效率和团队协作能力。无需编写任何代码即可完成本文所述的操作,进行个性化定制只需对参数进行微调即可。 方案介绍 如果在…

【JavaWeb】Day27.Web入门——Tomcat介绍

目录 WEB服务器-Tomcat 一.服务器概述 二.Web服务器 三.Tomcat- 基本使用 1.下载 2.安装与卸载 3.启动与关闭 4.常见问题 四.Tomcat- 入门程序 WEB服务器-Tomcat 一.服务器概述 服务器硬件:指的也是计算机,只不过服务器要比我们日常使用的计算…

AI绘画核心技术与实战【课程推荐】

AI结合绘画领域属于是《黑客与画家》的结合了,推荐大家来一起学习

Linux:程序地址空间详解

目录 一、堆、栈、环境参数所在位置 二、进程地址空间底层实现原理 ​编辑 三、什么是地址空间 四、为什么要有进程地址空间 五、细谈写实拷贝的实现及意义 在C/C学习中,都学习过如上图所示的一套存储结构,我们大致知道一般存储空间分为堆区&#…

代码随想录算法训练营三刷 day38 | 动态规划之 509. 斐波那契数 70. 爬楼梯 746. 使用最小花费爬楼梯

三刷day38 509. 斐波那契数1 确定dp数组以及下标的含义2 确定递推公式3 dp数组如何初始化4 确定遍历顺序5 举例推导dp数组 70. 爬楼梯1 确定dp数组以及下标的含义2 确定递推公式3 dp数组如何初始化4 确定遍历顺序5 举例推导dp数组 746. 使用最小花费爬楼梯1 确定dp数组以及下标…

[C++初阶] 爱上C++ : 与C++的第一次约会

🔥个人主页:guoguoqiang 🔥专栏:我与C的爱恋 本篇内容带大家浅浅的了解一下C中的命名空间。 在c中,名称(name)可以是符号常量、变量、函数、结构、枚举、类和对象等等。工程越大,名称…

C++进阶,手把手带你学继承

🪐🪐🪐欢迎来到程序员餐厅💫💫💫 主厨:邪王真眼 主厨的主页:Chef‘s blog 所属专栏:c大冒险 总有光环在陨落,总有新星在闪烁 【本节目标】 1.继…

【讲解下go和java的区别】

🔥博主:程序员不想YY啊🔥 💫CSDN优质创作者,CSDN实力新星,CSDN博客专家💫 🤗点赞🎈收藏⭐再看💫养成习惯 🌈希望本文对您有所裨益,如有…

Windows安装tomcat,以服务的方式管理,如何设置虚拟内存

之前工作中,部署tomcat都是使用Linux服务器,最近遇到个客户,提供的服务器是Windows server,并且需要通过服务的方式管理tomcat;以自己多年的码农经验,感觉应该没有问题,结果啪啪打脸了&#xf…

python安装删除以及pip的使用

目录 你无法想象新手到底会在什么地方出问题——十二个小时的血泪之言! 问题引入 python modify setup 隐藏文件夹 环境变量的配置 彻底删除python 其他零碎发现 管理员终端 删不掉的windous应用商店apps 发现问题 总结 你无法想象新手到底会在什么地方…

银河麒麟服务器操作系统安装SQLite数据库

SQLite,是一款轻型的数据库,是遵守ACID的关系型数据库管理系统,它包含在一个相对小的C库中。它是D.RichardHipp建立的公有领域项目。它的设计目标是嵌入式的,而且已经在很多嵌入式产品中使用了它,它占用资源非常的低&a…

服务器配置入门教程

问题环境: 现场调试的时候遇到很多离奇的问题,部分设备已经老到需要使用清华同方 Windows XP 系统的接口,所以写下这边记录,本文主要是基础教程。 快速入门常识 服务器基础知识_mezz卡-CSDN博客 基本接口识别 IOIOI-RJ45串口&a…

NLP 笔记:Latent Dirichlet Allocation (介绍篇)

1 问题介绍 假设我们有一堆新闻,每个新闻都有≥1个主题 我们现在只知道新闻的内容,我们希望一个算法,帮我们把这些新闻分类成主题人类可以根据每个每个文章里面的单词判断主题,那计算机怎么做呢? ——>LDA(Latent D…

wps表格怎么加一行详细介绍

刚接触wps表格的小伙伴肯定很多都不知道该怎么去操作吧,肯定也不知道怎么去加入一行来添加文字,为此我们带来了教程,帮助你们了解wps表格怎么加一行。 wps表格怎么加一行: 1、首先去打开wps软件,然后选中里面的行。 …

零基础入门转录组数据分析——绘制差异火山图

零基础入门转录组数据分析——绘制差异火山图 差异分析的火山图(Volcano Plot)在生物信息学数据分析中,特别是在基因表达差异分析中,是一个非常直观和有用的工具。 本教程将从导入的数据结构开始,一步步带大家在R中绘制好看的火山图,最后对火山图进行解读,确保读者理解…