bert 相似度任务训练简单版本,faiss 寻找相似 topk

目录

任务

代码

train.py

predit.py

faiss 最相似的 topk


任务

使用 bert-base-chinese 训练相似度任务,参考:微调BERT模型实现相似性判断 - 知乎

参考他上面代码,他使用的是 BertForNextSentencePrediction 模型,BertForNextSentencePrediction 原本是设计用于下一个句子预测任务的。在BERT的原始训练中,模型会接收到一对句子,并试图预测第二个句子是否紧跟在第一个句子之后;所以使用这个模型标签(label)只能是 0,1,相当于二分类任务了

但其实在相似度任务中,我们每一条数据都是【text1\ttext2\tlabel】的形式,其中 label 代表相似度,可以给两个文本打分表示相似度,也可以映射为分类任务,0 代表不相似,1 代表相似,他这篇文章利用了这种思想,对新手还挺有用的。

现在我搞了一个招聘数据,里面有办公区域列,处理过了,每一行代表【地址1\t地址2\t相似度】

只要两文本中有一个地址相似我就作为相似,标签为 1,否则 0

利用这数据微调,没有使用验证数据集,就最后使用测试集来看看效果。

代码

train.py

import json
import torch
from transformers import BertTokenizer, BertForNextSentencePrediction
from torch.utils.data import DataLoader, Dataset# 能用gpu就用gpu
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device = torch.device("cpu")bacth_size = 32
epoch = 3
auto_save_batch = 5000
learning_rate = 2e-5# 准备数据集
class MyDataset(Dataset):def __init__(self, data_file_paths):self.texts = []self.labels = []# 分词器用默认的self.tokenizer = BertTokenizer.from_pretrained('../bert-base-chinese')# 自己实现对数据集的解析with open(data_file_paths, 'r', encoding='utf-8') as f:for line in f:text1, text2, label = line.split('\t')self.texts.append((text1, text2))self.labels.append(int(label))def __len__(self):return len(self.texts)def __getitem__(self, idx):text1, text2 = self.texts[idx]label = self.labels[idx]encoded_text = self.tokenizer(text1, text2, padding='max_length', truncation=True, max_length=128, return_tensors='pt')return encoded_text, label# 训练数据文件路径
train_dataset = MyDataset('../data/train.txt')# 定义模型
# num_labels=5 定义相似度评分有几个
model = BertForNextSentencePrediction.from_pretrained('../bert-base-chinese', num_labels=6)
model.to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)# 训练模型
train_loader = DataLoader(train_dataset, batch_size=bacth_size, shuffle=True)
trained_data = 0
batch_after_last_save = 0
total_batch = 0
total_epoch = 0for epoch in range(epoch):trained_data = 0for batch in train_loader:inputs, labels = batch# 不知道为啥,出来的数据维度是 (batch_size, 1, 128),需要把第二维去掉inputs['input_ids'] = inputs['input_ids'].squeeze(1)inputs['token_type_ids'] = inputs['token_type_ids'].squeeze(1)inputs['attention_mask'] = inputs['attention_mask'].squeeze(1)# 因为要用GPU,将数据传输到gpu上inputs = inputs.to(device)labels = labels.to(device)optimizer.zero_grad()outputs = model(**inputs, labels=labels)loss, logits = outputs[:2]loss.backward()optimizer.step()trained_data += len(labels)trained_process = float(trained_data) / len(train_dataset)batch_after_last_save += 1total_batch += 1# 每训练 auto_save_batch 个 batch,保存一次模型if batch_after_last_save >= auto_save_batch:batch_after_last_save = 0model.save_pretrained(f'../output/cn_equal_model_{total_epoch}_{total_batch}.pth')print("保存模型:cn_equal_model_{}_{}.pth".format(total_epoch, total_batch))print("训练进度:{:.2f}%, loss={:.4f}".format(trained_process * 100, loss.item()))total_epoch += 1model.save_pretrained(f'../output/cn_equal_model_{total_epoch}_{total_batch}.pth')print("保存模型:cn_equal_model_{}_{}.pth".format(total_epoch, total_batch))

训练好后的文件,输出的最后一个文件夹才是效果最好的模型:

predit.py

import torch
from transformers import BertTokenizer, BertForNextSentencePredictiontokenizer = BertTokenizer.from_pretrained('../bert-base-chinese')
model = BertForNextSentencePrediction.from_pretrained('../output/cn_equal_model_3_171.pth')with torch.no_grad():with open('../data/test.txt', 'r', encoding='utf8') as f:lines = f.readlines()correct = 0for i, line in enumerate(lines):text1, text2, label = line.split('\t')encoded_text = tokenizer(text1, text2, padding='max_length', truncation=True, max_length=128, return_tensors='pt')outputs = model(**encoded_text)res = torch.argmax(outputs.logits, dim=1).item()print(text1, text2, label, res)if str(res) == label.strip('\n'):correct += 1print(f'{i + 1}/{len(lines)}')print(f'acc:{correct / len(lines)}')

可以看到还是较好的学习了我数据特征:只要两文本中有一个地址相似我就作为相似,标签为 1,否则 0

faiss 最相似的 topk

使用 faiss 寻找 topk 相似的,从结果上看最相似的基本都还是找到排到较为靠前的位置

import torch
import faiss
import pandas as pd
import numpy as np
from transformers import BertTokenizer, BertModel# 假设有一个数据集df,其中包含'index'列和'text'列
df = pd.read_csv('../data/DataAnalyst.csv', encoding='gbk')  # 根据实际情况加载数据集
df = df.dropna().drop_duplicates().reset_index()
df['index'] = df.index
df = df[['index', '公司所在商区']]  # 保留所需列
df['公司所在商区'] = df['公司所在商区'].map(lambda row: ','.join(eval(row)))# device = torch.device('gpu' if torch.cuda.is_available() else 'cpu')
device = torch.device('cpu')# 加载微调好的模型和tokenizer
tokenizer = BertTokenizer.from_pretrained('../bert-base-chinese')
model = BertModel.from_pretrained('../output/cn_equal_model_3_171.pth')
model.eval()# 将数据集转化为模型所需的格式并计算所有样本的向量表示
def encode_texts(df):text_vectors = []for index, row in df.iterrows():text = row['公司所在商区']inputs = tokenizer(text, padding='max_length', truncation=True, max_length=128, return_tensors='pt')with torch.no_grad():embeddings = model(**inputs.to(device))['last_hidden_state'][:, 0]text_vectors.append(embeddings.cpu().numpy())print(f'{index + 1}/{len(df)}')return np.vstack(text_vectors)# 加载数据集并计算所有样本的向量
print('enbedding all data...')
all_embeddings = encode_texts(df)# 初始化Faiss索引
print('init faiss all embedding...')
index = faiss.IndexFlatIP(all_embeddings.shape[1])  # 使用内积空间,适用于余弦相似度
index.add(all_embeddings)
print('init faiss all embedding finish~~~')# 定义查找最相似样本的函数
def find_top_k_similar(query_text, k=100):print('当前 query_text embedding.')query_embedding = encode_single_text(query_text)print('begin to search topk....')D, I = index.search(query_embedding, k)  # 返回距离和索引top_k_indices = df.iloc[I[0]].index.tolist()  # 将索引转换为原始数据集的索引return top_k_indices# 编码单个文本的函数
def encode_single_text(text):inputs = tokenizer(text, padding='max_length', truncation=True, max_length=128, return_tensors='pt')with torch.no_grad():embedding = model(**inputs.to(device))['last_hidden_state'][:, 0].cpu().numpy()print('当前 query_text embedding finish!')return embedding# 示例:找一个query_text的top10相似样本
query_text = "左家庄,国展,西坝河"
top10_indices = find_top_k_similar(query_text)
# 获取与查询文本最相似的前10条原始文本
top10_texts = [df.loc[index, '公司所在商区'] for index in top10_indices]print(f"与'{query_text}'最相似的前100条样本及其文本:")
for i, (idx, text) in enumerate(zip(top10_indices, top10_texts)):print(f"{i+1}. 索引:{idx},文本:{text}")

数据

链接:https://pan.baidu.com/s/1Cpr-ZD9Neakt73naGdsVTw 
提取码:eryw 
链接:https://pan.baidu.com/s/1qHYjXC7UCeUsXVnYTQIPCg 
提取码:o8py 
链接:https://pan.baidu.com/s/1CTntG1Z6AIhiPt6i8Ad97Q 
提取码:x6sz 
 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/269267.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

紫光展锐T618_4G安卓核心板方案定制

紫光展锐T618核心板是一款采用纯国产化方案的高性能产品,搭载了开放的智能Android操作系统,并集成了4G网络,支持2.5G5G双频WIFI、蓝牙近距离无线传输技术以及GNSS无线定位技术。 展锐T618核心板应用旗舰级 DynamlQ架构 12nm 制程工艺&#x…

掌握PDF全面指南:Python开发者的高效编程技巧

掌握PDF全面指南:Python开发者的高效编程技巧 简介PDF基础知识PDF的结构常见用途PDF在开发中的挑战 PDF处理库介绍PyPDF2ReportLabPDFMiner辅助库 读取和分析PDF文件使用PyPDF2读取PDF文件提取PDF中的文本和元数据分析PDF结构和内容 编辑和修改PDF文件合并多个PDF文…

OSPF故障排查,这10大技巧是个网工都在用!

中午好,我的网工朋友。 OSPF这个名词网工们都不陌生吧。 OSPF,即开放式最短路径优先(Open Shortest Path First,OSPF)是广泛使用的一种动态路由协议。 它属于链路状态路由协议,具有路由变化收敛速度快、…

【RISC-V 指令集】RISC-V 向量V扩展指令集介绍(二)-向量元素到向量寄存器状态的映射

1. 引言 以下是《riscv-v-spec-1.0.pdf》文档的关键内容: 这是一份关于向量扩展的详细技术文档,内容覆盖了向量指令集的多个关键方面,如向量寄存器状态映射、向量指令格式、向量加载和存储操作、向量内存对齐约束、向量内存一致性模型、向量…

Python爬取网站视频资源

思路: 在界面找到视频对应的html元素位置,观察发现视频的url为https://www.pearvideo.com/video_视频的id,而这个id在html中的href中,所以第一步需要通过xpath捕获到所需要的id 在https://www.pearvideo.com/video_id的页面&…

LabVIEW非接触式电阻抗层析成像系统

LabVIEW非接触式电阻抗层析成像系统 非接触式电阻抗层析成像(NEIT)技术以其无辐射、非接触、响应速度快的特点,为实时监测提供了新的解决方案。基于LabVIEW的电阻抗层析成像系统,实现了数据的在线采集及实时成像,提高…

记一次dockerfile无法构建问题追溯

我有一个dockerfile如下: ENTRYPOINT ["/sbin/tini","-g", "--"] CMD /home/scrapy/start.sh 我原本的用意是先启动tini,再执行下面的cmd命令启动start.sh。 为啥要用tini? 因为我的这个docker…

如何选择程序员职业赛道

目录 前言1 个人技能分析1.1 技术栈评估1.2 经验积累1.3 数据科学能力 2 兴趣与价值观2.1 用户交互与界面设计2.2 复杂问题解决与系统优化 3 长期目标规划4 市场需求分析4.1 人工智能和云计算4.2 前沿技术趋势 5 就业前景5.1 前端在创意性公司中的应用5.2 后端在大型企业中的广…

Windows Docker 部署 MySQL

部署 MySQL 打开 Docker Desktop,切换到 Linux 容器。然后在 PowerShell 执行下面命令,即可启动一个 MySQL 服务。这里安装的是 8.3.0 Tag版本,如果需要安装其他或者最新版本,可以到 Docker Hub 进行查找。 docker run -itd --n…

YOLO v9训练自己数据集

原以为RT-DETR可以真的干翻YOLO家族,结果,!!!! 究竟能否让卷积神经网络重获新生? 1.数据准备 代码地址:https://github.com/WongKinYiu/yolov9 不能科学上网的评论区留言 数据集…

Web前端---表格和表单

1.表格概述 表格标记&#xff1a;<table></table> 表格标题标记&#xff1a;<caption></caption> 表头&#xff1a;<th></th>------heading 行标记&#xff1a;<tr></tr>-----r是row 列标记&#xff1a;<td></t…

HQL,SQL刷题,尚硅谷

目录 相关表数据&#xff1a; ​编辑 题目及思路解析&#xff1a; 复杂查询&#xff0c;子查询 1、查询所有课程成绩均小于60分的学生的学号、姓名 2、查询没有学全所有课的学生的学号、姓名 3、查询出只选修了三门课程的全部学生的学号和姓名 总结归纳&#xff1a; 知识补充&a…

JavaWeb Tomcat启动、部署、配置、集成IDEA

web服务器软件 服务器是安装了服务器软件的计算机&#xff0c;在web服务器软件中&#xff0c;可以部署web项目&#xff0c;让用户通过浏览器来访问这些项目。 Web服务器是一个应用程序&#xff08;软件&#xff09;&#xff0c;对HTTP协议的操作进行封装&#xff0c;使得程序…

【C语言】Leetcode 876. 链表的中间节点

主页&#xff1a;17_Kevin-CSDN博客 专栏&#xff1a;《Leetcode》 题目 通过题目的要求可以判断出有两种示例要解决&#xff0c;一种是偶数节点的链表&#xff0c;一种是奇数节点的链表&#xff0c;应对这两种情况我们需要使程序对二者都可以兼容。 解决思路 struct ListNode…

吴恩达机器学习笔记:第5周-9 神经网络的学习2(Neural Networks: Learning)

目录 9.4 实现注意&#xff1a;展开参数9.5 梯度检验9.6 随机初始化9.7 综合起来9.8 自主驾驶 9.4 实现注意&#xff1a;展开参数 在上一段视频中&#xff0c;我们谈到了怎样使用反向传播算法计算代价函数的导数。在这段视频中&#xff0c;我想快速地向你介绍一个细节的实现过…

java八股文复习-----2024/03/03

1.接口和抽象类的区别 相似点&#xff1a; &#xff08;1&#xff09;接口和抽象类都不能被实例化 &#xff08;2&#xff09;实现接口或继承抽象类的普通子类都必须实现这些抽象方法 不同点&#xff1a; &#xff08;1&#xff09;抽象类可以包含普通方法和代码块&#x…

Socket网络编程(四)——点对点传输场景方案

目录 场景如何去获取到TCP的IP和Port&#xff1f;UDP的搜索IP地址、端口号方案UDP搜索取消实现相关的流程&#xff1a;代码实现逻辑服务端实现客户端实现UDP搜索代码执行结果 TCP点对点传输实现代码实现步骤点对点传输测试结果 源码下载 场景 在一个局域网当中&#xff0c;不知…

LabVIEW齿轮传动健康状态静电在线监测

LabVIEW齿轮传动健康状态静电在线监测 随着工业自动化的不断发展&#xff0c;齿轮传动作为最常见的机械传动方式之一&#xff0c;在各种机械设备中发挥着至关重要的作用。然而&#xff0c;齿轮在长期运行过程中易受到磨损、变形等因素影响&#xff0c;进而影响整个机械系统的稳…

BUUCTF------[HCTF 2018]WarmUp

开局一个表情&#xff0c;源代码发现source.php <?phphighlight_file(__FILE__);class emmm{public static function checkFile(&$page){$whitelist ["source">"source.php","hint">"hint.php"];if (! isset($page) |…

Vue - 调用接口获取文件数据流并根据类型预览

Vue - 调用接口获取文件数据流并根据类型预览 一、接口返回的数据流格式二. 方法实现1. image 图片类型2. txt 文件类型3. pdf 文件类型 一、接口返回的数据流格式 二. 方法实现 1. image 图片类型 <img :src"imageUrl" alt"" srcset"" /&g…