Python - 深度学习系列38 重塑实体识别5-预测并行化改造

说明

在重塑实体识别4中梳理了数据流,然后我发现pipeline的串行效率太低了,所以做了并行化改造。里面还是有不少坑的,记录一下。

内容

1 pipeline

官方的pipeline看起来的确是比较好用的,主要是实现了比较好的数据预处理。因为在训练/使用过程中都要进行数据的令牌化与反令牌化,有些字符会被特殊处理,例如 '##A’等。
在这里插入图片描述
在使用过程中,我用200条新闻数据进行测试,用pipeline方法花了11分钟处理完毕,期间CUDA的使用率大约为10%。按此估算,即使用多接口并行的方式,那么一分钟最多处理2000条,一天最多处理0.14*2000~30万条数据。这个效率太低了。

2 并行化

最终的结论是不到30秒处理200条,显存只占用2.6G,理论上可以支持3个服务并行(以确保GPU的完全利用)。按最保守的估计,改造后的并行化应该可以提升3倍的效率,稍微激进一点,可以提升10倍的效率。这个之后可以进行测试。

一些主要的点如下

2.1 结果解析

结果可以分为:

  • 1 仅含解析出的实体列表,用逗号连接字符串表示。
  • 2 含实体及其起始位置的表示,这个用于标注反馈、二次增强处理。
  • 3 仅含BIO标签,主要用于和测试数据进行效果比对。

对应的相关函数,看起来有点繁杂,我自己都不太想看第二眼。

from datasets import ClassLabel
# 定义标签列表
label_list = ['B', 'I', 'O']
# 创建 ClassLabel 对象
class_label = ClassLabel(names=label_list)
def convert_entity_label_batch(x):x1 = xreturn class_label.int2str(x1)
# 定义函数将整数Tensor转换为字符串 | 反令牌函数,但是用不上;因为predict label列表的长度和 ss_padding相同
def tensor_to_string(tensor, tokenizer = None , skip_special_tokens = True):return tokenizer.decode(tensor.tolist(), skip_special_tokens=skip_special_tokens).replace(' ','')from datasets import ClassLabel
def detokenize(word_piece):"""将 WordPiece 令牌还原为原始句子。"""if word_piece.startswith('##'):x = word_piece[2:]else:x = word_piecereturn x
import re
def extract_bio_positions(bio_string):pattern = re.compile(r'B(I+)(O|$)')matches = pattern.finditer(bio_string)results = []for match in matches:start, end = match.span()results.append((start, end - 1))  # end-1 to include the last 'I'return results# 0.1ms
def parse_ent_pos_map_batch(some_dict = None):word_list = some_dict['token_words']label_list = [int(x) for x in list(some_dict['label_list'])]min_len = min(len(word_list),len(label_list))word_list = word_list[:min_len]label_list = label_list[:min_len]label_list1 =  list(map(convert_entity_label_batch,label_list))oriword_list1 = list(map(detokenize,word_list))ori_word_str =''.join(oriword_list1)# 补到等长label_str = ''for i in range(len(label_list1)):len_of_ori_word = len(oriword_list1[i])if len_of_ori_word == 1:tem_str = label_list1[i]else:if label_list1[i] in ['I','O']:tem_str = label_list1[i] * len_of_ori_wordelse:tem_str = 'B' + 'I' * (len_of_ori_word -1)        label_str += tem_strpos_list = extract_bio_positions(label_str)part_ent_list = [(ori_word_str[x[0]:x[1]] , *x) for x in pos_list]return part_ent_list# =============
def make_BIO_by_len(some_len):default_str = 'I' * some_lenstr_list = list(default_str)str_list[0] ='B'return str_list
def gen_BIO_list2(some_dict):the_content = some_dict['clean_data']ent_list =  some_dict['ent_tuple_list']content_list = list(the_content)tag_list = list('O'* len(content_list))for ent_info in ent_list:start = ent_info[1]end = ent_info[2]label_len = end-starttem_bio_list = make_BIO_by_len(label_len)tag_list[start:end] = tem_bio_listres_dict = {}res_dict['x'] = content_listres_dict['y'] = tag_listreturn res_dictdef trim_len(some_dict = None):padding_BIO = some_dict['padding_BIO']ss_len = some_dict['ss_len']return padding_BIO[:ss_len]

2.2 批量预测

看起来同样很繁杂,但是不得不细看。首先,数据会按照几个长度 20,50,198分为三部分处理,batch_predict每次仅处理一个批次。在这里,将数据转为定长的令牌长度,然后转入CUDA进行批量预测。结果再按照实体-位置 tuple, 实体列表和BIO三种方式进行解析。

from functools import partial
import transformers 
import torch 
from transformers import AutoModelForMaskedLM, AutoTokenizer,AutoModelForTokenClassification
from functools import partial
# some_batch 是原文经过padding的数据,['ss_hash','ss','ss_len', 'ss_padding'], 其中ss_padding的长度是固定的
# 模型文件和令牌文件都放在model_path之下,model比较大,避免重载;而tokenize会有padding过程,必须重载
# 模型先载入cuda
def batch_predict(some_batch, ss_padding_len = None, model = None, model_path = None):# 因为tokenize会在令牌的前后加上分隔令牌,所以+2if ss_padding_len is None:ss_padding_len = some_batch['ss_padding'].apply(len).max()print('ss_padding_len is %s ' % ss_padding_len)max_len = ss_padding_len+2tokenizer = AutoTokenizer.from_pretrained(model_path)tencoder = partial(tokenizer.encode,truncation=True, max_length=max_len, is_split_into_words=True, return_tensors="pt",  padding='max_length')some_batch['ss_padding_token'] = some_batch['ss_padding'].apply(list).apply(tencoder)# 构成矩阵minput = torch.cat(list(some_batch['ss_padding_token'].values))# 将数据搬到GPU中处理再返回with torch.no_grad():input_cuda = minput.to(device)outputs_cuda = model(input_cuda).logitspredictions = torch.argmax(outputs_cuda, dim=2)predictions_list = list(predictions.to('cpu').numpy())predict_list1 = []for predictions in predictions_list:tem_pred_tag = [int(x) for x in predictions[1:-1]]predict_list1.append(tem_pred_tag)some_batch['label_list'] = predict_list1_s = cols2s(some_df =some_batch, cols= ['ss_padding','label_list'], cols_key_mapping= ['token_words', 'label_list'])_s1 = _s.apply(parse_ent_pos_map_batch)some_batch['ent_tuple_list'] = list(_s1)some_batch['ent_list'] = some_batch['ent_tuple_list'].apply(lambda x: ','.join([a[0] for a in x ]))_s = cols2s(some_batch, cols= ['ss_padding', 'ent_tuple_list'], cols_key_mapping= ['clean_data', 'ent_tuple_list'])s1 = _s.apply(gen_BIO_list2)ent_tuple_res_df1 = pd.DataFrame(s1.to_list())some_batch['padding_BIO'] = list(ent_tuple_res_df1['y'].apply(lambda x: ''.join(x)))_s00 = cols2s(some_batch, cols = ['ss_len', 'padding_BIO'], cols_key_mapping=['ss_len', 'padding_BIO'])some_batch['BIO'] = list(_s00.apply(trim_len))return some_batch    

3 迭代器

在推送数据处理时,可以采用迭代器来控制不同的批次数据

# 迭代器切分
import pandas as pd
class DataFrameBatchIterator:def __init__(self, dataframe, batch_size):self.dataframe = dataframeself.batch_size = batch_size# 【我增加的】self.fail_batch_list = []def __iter__(self):num_rows = len(self.dataframe)num_batches = (num_rows - 1) // self.batch_size + 1for i in range(num_batches):start_idx = i * self.batch_sizeend_idx = (i + 1) * self.batch_sizebatch_data = self.dataframe.iloc[start_idx:end_idx]yield batch_data# 【我增加的】def clear_fail(self):self.fail_batch_list = []# 【我增加的】def get_some_batch(self, batch_idx):return self.dataframe.iloc[self.batch_size * batch_idx: self.batch_size * (batch_idx + 1)]# 【我增加的】记录失败的批次def rec_fail_batch_idx(self, batch_idx):self.fail_batch_list.append(batch_idx)
# 创建一个示例 DataFrame
data = {'Name': ['John', 'Jane', 'Mike', 'Alice', 'Bob'],'Age': [25, 30, 35, 28, 32],'City': ['New York', 'Paris', 'London', 'Tokyo', 'Sydney']}
df = pd.DataFrame(data)
# 创建 DataFrame 迭代器
batch_iterator = DataFrameBatchIterator(df, batch_size=2)
import tqdm
# 使用迭代器逐批次处理数据
for i,batch in tqdm.tqdm(enumerate(batch_iterator)):try:# 在这里可以对当前批次的数据进行相应的操作# 例如进行数据清洗、特征处理、模型训练等# 示例:打印当前批次的数据
#         raise Exception(e) print(batch)except:print('>>> %s Fail' % i)batch_iterator.rec_fail_batch_idx(i)

以下是实际的调度

# 假设处理长度为1万的句子
# 20 * 2000 ~ 4w
# 50 * 800 ~  4w
# 200 * 200 ~ 4w
import warnings 
warnings.filterwarnings('ignore')
batch_slice_para = {20:2000, 50:800, 200:200}
batch_len_list = sorted(list(batch_slice_para.keys()))
batch_len_list.insert(0,0)batch_df_list = []
for i in range(len(batch_len_list)):if i >0:sel = (ss_df['ss_len'] >=batch_len_list[i-1]) & (ss_df['ss_len'] < batch_len_list[i])if sel.sum():padding_len = batch_len_list[i]padding_batch = batch_slice_para[padding_len]tem_df= ss_df[sel]# tem_df['ss_padding'] = tem_df['ss'].apply(lambda x: x.ljust(padding_len,'a'))tem_df['ss_padding'] = tem_df['ss']tem_df_iterator = DataFrameBatchIterator(tem_df, padding_batch)batch_df_list.append(tem_df_iterator)else:batch_df_list.append(None)

对每个批次执行处理,载入模型

label_list = ['B','I','O']
model_checkpoint = 'model03'
model = AutoModelForTokenClassification.from_pretrained(model_checkpoint, num_labels=len(label_list))device = 'cuda' if torch.cuda.is_available() else 'cpu'
print('Device: %s' % device)
# 自动切换设备
if model.device.type != device:model.to(device)print('>>> 检测到模型设备与当前指定不一致,切换 %s' % device )
else:print('>>> 模型设备一致,不切换 %s' % device)

分批次预测(主要是确保显存不溢出)

batch_res_list = []
for some_iter in batch_df_list:for some_batch in some_iter:batch_res = batch_predict(some_batch, model = model, model_path = 'model03')batch_res_list.append(batch_res)

结果合并

batch_res_df = pd.concat(batch_res_list, ignore_index= True)
mdf = pd.merge(input_df , batch_res_df[['ss_hash', 'ent_list']],how='left', on ='ss_hash')

在这里插入图片描述

4 总结

一个在理论上证明可以显著提升效率的点在于,模型进行实体识别时先切分了短句,然后按短句进行了去重:相同短句的实体结果一定是相同的。

实验中,200条新闻产生了约5万个短句,去重后只剩下约3.5万。所以即使在这一步也是有提升的。当然,这种方式同样也可以被用于pipeline。

还有就是在处理填充时,并不按照最大长度统一填充。而是按照句子长度的统计特性分为了短、中、长三种方式。从统计上看,约70%的短句长度是在20个字符以内的,真正超过50个字符的短句(中间无分隔符),即使从语法上来看也是比较奇怪的。
这样在填充数据时浪费就比pipeline要小,同样显存可以装下更多的数据。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/344059.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

党史馆3d网上展馆

在数字化浪潮的推动下&#xff0c;华锐视点运用实时互动三维引擎技术&#xff0c;为用户带来前所未有的场景搭建体验。那就是领先于同行业的线上三维云展编辑平台搭建编辑器&#xff0c;具有零基础、低门槛、低成本等特点&#xff0c;让您轻松在数字化世界中搭建真实世界的仿真…

【SpringBoot】SpringBoot整合RabbitMQ消息中间件,实现延迟队列和死信队列

&#x1f4dd;个人主页&#xff1a;哈__ 期待您的关注 目录 一、&#x1f525;死信队列 RabbitMQ的工作模式 死信队列的工作模式 二、&#x1f349;RabbitMQ相关的安装 三、&#x1f34e;SpringBoot引入RabbitMQ 1.引入依赖 2.创建队列和交换器 2.1 变量声明 2.2 创建…

Python实现半双工的实时通信SSE(Server-Sent Events)

Python实现半双工的实时通信SSE&#xff08;Server-Sent Events&#xff09; 1 简介 实现实时通信一般有WebSocket、Socket.IO和SSE&#xff08;Server-Sent Events&#xff09;三种方法。WebSocket和Socket.IO是全双工的实时双向通信技术&#xff0c;适合用于聊天和会话等&a…

SwiftUI中Mask修饰符的理解与使用

Mask是一种用于控制图形元素可见性的图形技术&#xff0c;使用给定视图的alpha通道掩码该视图。在SwiftUI中&#xff0c;它类似于创建一个只显示视图的特定部分的模板。 Mask修饰符的定义&#xff1a; func mask<Mask>(alignment: Alignment .center,ViewBuilder _ ma…

AI论文速读 | 2024[KDD]GinAR—变量缺失端到端多元时序预测

题目&#xff1a;GinAR: An End-To-End Multivariate Time Series Forecasting Model Suitable for Variable Missing 作者&#xff1a;Chengqing Yu&#xff08;余澄庆&#xff09;, Fei Wang&#xff08;王飞&#xff09;, Zezhi Shao&#xff08;邵泽志&#xff09;, Tangw…

XML解析库tinyxml2库使用详解

XML语法规则介绍及总结-CSDN博客 TinyXML-2 是一个简单轻量级的 C XML 解析库,它提供了一种快速、高效地解析 XML 文档的方式。 1. 下载地址 Gitee 极速下载/tinyxml2 2. 基本用法 下面将详细介绍 TinyXML-2 的主要使用方法: 2.1. 引入头文件和命名空间 #i…

Docker 国内镜像源更换

实现 替换docker 镜像源 前提要求 安装 docker docker-compose 参考创建一键更换docker国内源 vim /docker_daemon.sh #!/bin/bash # -*- coding: utf-8 -*- # Author: make.han # Email: CIASM@CIASM # Date: 2024/06/07 # docker daemon.jsondaemon_json_file="/et…

js 选择一个音频文件,绘制音频的波形,从右向左逐渐前进。

选择一个音频文件&#xff0c;绘制波形&#xff0c;从右向左逐渐前进。 完整代码&#xff1a; <template><div><input type"file" change"handleFileChange" accept"audio/*" /><button click"stopPlayback" :…

大模型微调工具LLaMA-Factory docker安装、大模型lora微调训练

参考: https://github.com/hiyouga/LLaMA-Factory 报错解决: 1)Docker 构建报错 RuntimeError: can’t start new thread: https://github.com/hiyouga/LLaMA-Factory/issues/3859 修改后的Dockerfile: FROM nvcr.io/nvidia/pytorch:24.01-py3WORKDIR /appCOPY require…

el-input中change事件造成的坑

el-input中change事件造成的坑 一、change事件定义二、如果仅回车时候触发 一、change事件定义 仅在输入框失去焦点或用户按下回车时触发 二、如果仅回车时候触发 <el-inputv-model.trim"questionInput"placeholder"请输入你的问题&#xff0c;按回车发送&…

NDIS Filter开发-PNP响应和安装

NDIS filter驱动可能是最容易生成的驱动之一&#xff0c;如果你安装了VS 2015 WDK之后&#xff0c;你可以直接生成一个能运行的Filter驱动&#xff0c;它一般是ndislwf。 和大部分硬件不同&#xff0c;NDIS Filter驱动介于软件和硬件抽象层之上&#xff0c;它和硬件相关&…

2024年最新的软件测试面试总结(答案+文档)

&#x1f345; 点击文末小卡片 &#xff0c;免费获取软件测试全套资料&#xff0c;资料在手&#xff0c;涨薪更快 测试技术面试题 1、什么是兼容性测试&#xff1f;兼容性测试侧重哪些方面&#xff1f; 参考答案&#xff1a; 兼容测试主要是检查软件在不同的硬件平台、软件平…

老师如何制作高考后志愿填报信息采集系统?

高考结束后&#xff0c;志愿填报成为学生们的头等大事。面对众多选择&#xff0c;如何高效、准确地填报志愿&#xff0c;是每个学生和家长都关心的问题。作为老师&#xff0c;能否利用现有的技术工具&#xff0c;帮助学生更好地完成志愿填报呢&#xff1f; 老师们需要一个能够…

[C#]使用OpenCvSharp图像滤波中值滤波均值滤波高通滤波双边滤波锐化滤波自定义滤波

在使用OpenCvSharp进行图像滤波处理时&#xff0c;各种滤波方法都有其特定的用途和效果。以下是对中值滤波、均值滤波、高通滤波、双边滤波、锐化滤波和自定义滤波的详细解释和归纳&#xff1a; 中值滤波&#xff08;MedianBlur&#xff09; 原理与作用&#xff1a;中值滤波是…

SpringBoot实现参数校验拦截(采用AOP方式)

一、AOP是什么&#xff1f; 目的&#xff1a;分离横切关注点&#xff08;如日志记录、事务管理&#xff09;与核心业务逻辑。 优势&#xff1a;提高代码的可读性和可维护性。 关键概念 切面&#xff08;Aspect&#xff09;&#xff1a;包含横切关注点代码的模块。通知&#xff…

leetcode-04-[24]两两交换链表中的节点[19]删除链表的倒数第N个节点[160]相交链表[142]环形链表II

一、[24]两两交换链表中的节点 重点&#xff1a;暂存节点 class Solution {public ListNode swapPairs(ListNode head) {ListNode dummyHeadnew ListNode(-1);dummyHead.nexthead;ListNode predummyHead;//重点&#xff1a;存节点while(pre.next!null&&pre.next.next…

视频去水印电脑版,视频去水印软件

视频去水印怎么去&#xff0c;一直是视频编辑者们的热门话题。那么&#xff0c;如何去除频水印呢&#xff1f;接下来&#xff0c;我们将为您详细介绍视频去水印方法。 第一种方法&#xff1a; 首先通过浏览器打开 “ 51视频处理官网” 的网站。打开网站后&#xff0c;我们上传…

第一个小爬虫_爬取 股票数据

前言 爬取 雪球网的股票数据 [环境使用]&#xff1a;python 3.12 解释器pycharm 编辑器 【模块使用】&#xff1a;import requests -->数据请求模块 要安装 命令 pip install requestsimport csv -->将数据保存到CSV表格中import pandas -->也可以将数据保…

VS2019 QT无法打开 源 文件 “QTcpSocket“

VS2019 QT无法打开 源 文件 "QTcpSocket" QT5.15.2_msvc2019_64 严重性 代码 说明 项目 文件 行 禁止显示状态 错误(活动) E1696 无法打开 源 文件 "QTcpSocket" auto_pack_line_demo D:\vs_qt_project\auto_pack_line_de…

【启明智显分享】基于工业级芯片Model3A的7寸彩色触摸屏应用于智慧电子桌牌方案

一场大型会议的布置&#xff0c;往往少不了制作安放参会人物的桌牌。制作、打印、裁剪&#xff0c;若有临时参与人员变更&#xff0c;会务方免不了手忙脚乱更新桌牌。由此&#xff0c;智能电子桌牌应运而生&#xff0c;工作人员通过系统操作更新桌牌信息&#xff0c;解决了传统…