240712_昇思学习打卡-Day24-LSTM+CRF序列标注(3)

240712_昇思学习打卡-Day24-LSTM+CRF序列标注(3)

今天做LSTM+CRF序列标注第三部分,同样,仅作简单记录及注释,最近确实太忙了。

Viterbi算法

在完成前向训练部分后,需要实现解码部分。这里我们选择适合求解序列最优路径的Viterbi算法。与计算Normalizer类似,使用动态规划求解所有可能的预测序列得分。不同的是在解码时同时需要将第𝑖个Token对应的score取值最大的标签保存,供后续使用Viterbi算法求解最优预测序列使用。

取得最大概率得分ScoreScore,以及每个Token对应的标签历史HistoryHistory后,根据Viterbi算法可以得到公式:

请添加图片描述

从第0个至第𝑖个Token对应概率最大的序列,只需要考虑从第0个至第𝑖−1个Token对应概率最大的序列,以及从第𝑖个至第𝑖−1个概率最大的标签即可。因此我们逆序求解每一个概率最大的标签,构成最佳的预测序列。

由于静态图语法限制,我们将Viterbi算法求解最佳预测序列的部分作为后处理函数,不纳入后续CRF层的实现。

# 定义维特比解码算法,用于找出具有最大概率的标签序列
def viterbi_decode(emissions, mask, trans, start_trans, end_trans):# emissions: (seq_length, batch_size, num_tags) 发射概率矩阵# mask: (seq_length, batch_size) 序列掩码,用于标记有效序列长度# trans: 转移概率矩阵# start_trans: 初始状态转移概率向量# end_trans: 终止状态转移概率向量seq_length = mask.shape[0]  # 获取序列长度# 初始化分数矩阵,等于初始状态转移概率加上第一个发射概率score = start_trans + emissions[0]history = ()  # 初始化历史路径记录# 遍历序列中的每个时间步for i in range(1, seq_length):# 扩展维度以便广播运算broadcast_score = score.expand_dims(2)broadcast_emission = emissions[i].expand_dims(1)# 计算所有可能的转移分数next_score = broadcast_score + trans + broadcast_emission# 找出当前Token对应的最大分数标签,并保存indices = next_score.argmax(axis=1)history += (indices,)  # 保存历史路径信息# 取出最大分数next_score = next_score.max(axis=1)# 更新分数矩阵,只更新mask为True的部分score = mnp.where(mask[i].expand_dims(1), next_score, score)# 加上终止状态转移概率score += end_trans# 返回最终的分数矩阵和历史路径信息return score, history# 根据解码过程中的得分和历史路径信息,重构最优标签序列
def post_decode(score, history, seq_length):# score: 最终得分矩阵# history: 历史路径信息# seq_length: 每个样本的实际序列长度batch_size = seq_length.shape[0]  # 获取批次大小seq_ends = seq_length - 1  # 计算每个样本的最后一个Token位置# 初始化最佳标签序列列表best_tags_list = []# 对批次中的每个样本进行解码for idx in range(batch_size):# 找出使最后一个Token对应的预测概率最大的标签best_last_tag = score[idx].argmax(axis=0)best_tags = [int(best_last_tag.asnumpy())]  # 添加最佳标签到序列# 从历史路径信息中反向追踪,找到每个Token的最佳标签for hist in reversed(history[:seq_ends[idx]]):best_last_tag = hist[idx][best_tags[-1]]best_tags.append(int(best_last_tag.asnumpy()))# 将逆序的标签序列反转,得到正序的最优标签序列best_tags.reverse()best_tags_list.append(best_tags)  # 添加到结果列表# 返回最优标签序列列表return best_tags_list

CRF层

完成上述前向训练和解码部分的代码后,将其组装完整的CRF层。考虑到输入序列可能存在Padding的情况,CRF的输入需要考虑输入序列的真实长度,因此除发射矩阵和标签外,加入seq_length参数传入序列Padding前的长度,并实现生成mask矩阵的sequence_mask方法。

综合上述代码,使用nn.Cell进行封装,最后实现完整的CRF层如下:

# 导入MindSpore相关模块
import mindspore as ms
import mindspore.nn as nn
import mindspore.ops as ops
import mindspore.numpy as mnp
from mindspore.common.initializer import initializer, Uniform# 定义序列掩码生成函数
def sequence_mask(seq_length, max_length, batch_first=False):"""根据序列的实际长度和最大长度生成mask矩阵。参数:seq_length: 实际序列长度张量。max_length: 序列的最大长度。batch_first: 是否将批次放在第一维度。返回:mask矩阵,形状为(batch_size, max_length),其中True表示有效位置,False表示填充位置。"""# 生成从0到max_length的范围向量range_vector = mnp.arange(0, max_length, 1, seq_length.dtype)# 创建mask矩阵,shape为(seq_length.shape + (1,))result = range_vector < seq_length.view(seq_length.shape + (1,))# 转换数据类型并根据batch_first参数调整维度顺序if batch_first:return result.astype(ms.int64)return result.astype(ms.int64).swapaxes(0, 1)# 定义条件随机场(CRF)模型类
class CRF(nn.Cell):def __init__(self, num_tags: int, batch_first: bool = False, reduction: str = 'sum') -> None:"""初始化CRF模型。参数:num_tags: 标签数量。batch_first: 是否将批次放在第一维度。reduction: 损失函数的缩减方式。"""# 检查标签数量是否有效if num_tags <= 0:raise ValueError(f'无效的标签数量: {num_tags}')super().__init__()# 检查reduction参数是否有效if reduction not in ('none', 'sum', 'mean', 'token_mean'):raise ValueError(f'无效的缩减方式: {reduction}')self.num_tags = num_tags  # 标签数量self.batch_first = batch_first  # 批次是否在第一维度self.reduction = reduction  # 损失函数缩减方式# 初始化起始和结束状态转移权重self.start_transitions = ms.Parameter(initializer(Uniform(0.1), (num_tags,)), name='start_transitions')self.end_transitions = ms.Parameter(initializer(Uniform(0.1), (num_tags,)), name='end_transitions')# 初始化状态间转移权重self.transitions = ms.Parameter(initializer(Uniform(0.1), (num_tags, num_tags)), name='transitions')def construct(self, emissions, tags=None, seq_length=None):"""CRF模型的前向传播方法。参数:emissions: 发射概率张量。tags: 真实标签张量。seq_length: 序列长度张量。返回:如果tags为None,则返回解码结果;否则返回损失值。"""if tags is None:return self._decode(emissions, seq_length)return self._forward(emissions, tags, seq_length)def _forward(self, emissions, tags=None, seq_length=None):"""计算损失值。参数:emissions: 发射概率张量。tags: 真实标签张量。seq_length: 序列长度张量。返回:损失值。"""# 根据batch_first参数调整emissions和tags的维度顺序if self.batch_first:batch_size, max_length = tags.shapeemissions = emissions.swapaxes(0, 1)tags = tags.swapaxes(0, 1)else:max_length, batch_size = tags.shape# 如果seq_length未给出,则假设所有序列都是最大长度if seq_length is None:seq_length = mnp.full((batch_size,), max_length, ms.int64)# 生成mask矩阵mask = sequence_mask(seq_length, max_length)# 计算分子部分(真实路径的得分)numerator = compute_score(emissions, tags, seq_length-1, mask, self.transitions, self.start_transitions, self.end_transitions)# 计算分母部分(所有可能路径的得分总和)denominator = compute_normalizer(emissions, mask, self.transitions, self.start_transitions, self.end_transitions)# 计算对数似然比llh = denominator - numerator# 根据reduction参数选择损失值的缩减方式if self.reduction == 'none':return llhelif self.reduction == 'sum':return llh.sum()elif self.reduction == 'mean':return llh.mean()return llh.sum() / mask.astype(emissions.dtype).sum()def _decode(self, emissions, seq_length=None):"""解码方法,用于预测最优标签序列。参数:emissions: 发射概率张量。seq_length: 序列长度张量。返回:最优标签序列。"""# 根据batch_first参数调整emissions的维度顺序if self.batch_first:batch_size, max_length = emissions.shape[:2]emissions = emissions.swapaxes(0, 1)else:batch_size, max_length = emissions.shape[:2]# 如果seq_length未给出,则假设所有序列都是最大长度if seq_length is None:seq_length = mnp.full((batch_size,), max_length, ms.int64)# 生成mask矩阵mask = sequence_mask(seq_length, max_length)# 使用维特比算法解码最优路径return viterbi_decode(emissions, mask, self.transitions, self.start_transitions, self.end_transitions)

打卡图片:

请添加图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/376008.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【Linux】进程间通信——消息队列和信号量

目录 消息队列&#xff08;message queue&#xff09; 信号量&#xff08;Semaphore&#xff09; system V版本的进程间通信方式有三种&#xff1a;共享内存&#xff0c;消息队列和信号量。之前我们已经说了共享内存&#xff0c;那么我们来看一下消息队列和信号量以及它们之间…

Docker容器的生命周期

引言 Docker 容器作为一种轻量级虚拟化技术&#xff0c;在现代应用开发和部署中扮演着重要角色。理解容器的生命周期对于有效地管理和运维容器化应用至关重要。本文将深入探讨 Docker 容器的生命周期&#xff0c;从创建到销毁的各个阶段&#xff0c;帮助读者更好地掌握容器管理…

Unity最新第三方开源插件《Stateful Component》管理中大型项目MonoBehaviour各种序列化字段 ,的高级解决方案

上文提到了UIState, ObjectRefactor等,还提到了远古的NGUI, KBEngine-UI等 这个算是比较新的解决方法吧,但是抽象出来,问题还是这些个问题 所以你就说做游戏是不是先要解决这些问题? 而不是高大上的UiImage,DoozyUI等 Mono管理引用基本用法 ① 添加Stateful Component …

【正点原子i.MX93开发板试用连载体验】录音小程序采集语料

本文最早发表于电子发烧友论坛&#xff1a;【新提醒】【正点原子i.MX93开发板试用连载体验】基于深度学习的语音本地控制 - 正点原子学习小组 - 电子技术论坛 - 广受欢迎的专业电子论坛! (elecfans.com) 接下来就是要尝试训练中文提示词。首先要进行语料采集&#xff0c;这是一…

【2-1:RPC设计】

RPC 1. 基础1.1 定义&特点1.2 具体实现框架1.3 应用场景2. RPC的关键技术点&一次调用rpc流程2.1 RPC流程流程两个网络模块如何连接的呢?其它特性RPC优势2.2 序列化技术序列化方式PRC如何选择序列化框架考虑因素2.3 应用层的通信协议-http2.3.1 基础概念大多数RPC大多自…

STM32Cubemx配置生成 Keil AC6支持代码

文章目录 一、前言二、AC 6配置2.1 ARM ComPiler 选择AC62.2 AC6 UTF-8的编译命令会报错 三、STM32Cubemx 配置3.1 找到stm32cubemx的模板位置3.2 替换文件内核文件3.3 修改 cmsis_os.c文件3.4 修改本地 四、编译对比 一、前言 使用keil ARM compiler V5的时候&#xff0c;编译…

M J更改图像生成方式的参数选项

一个完整的/imagine命令可能包含几个内容,例如图像 URL、图像权重、算法版本和其他开关。 /imagine参数应遵循以下顺序: /imagine prompt: https://example/tulip.jpg a field of tulips in the style of Mary Blair --no farms --iw .5 --ar 3:2 在这种情况下,“开关”是指…

如何压缩pdf文件大小,怎么压缩pdf文件大小

在数字化时代&#xff0c;pdf文件因其稳定的格式和跨平台兼容性&#xff0c;成为了工作与学习中不可或缺的一部分。然而&#xff0c;随着pdf文件内容的丰富&#xff0c;pdf文件的体积也随之增大&#xff0c;给传输和存储带来了不少挑战。本文将深入探讨如何高效压缩pdf文件大小…

【保姆级教程】CenterNet的目标检测、3D检测、关键点检测使用教程

一、代码下载 仓库地址:https://github.com/xingyizhou/CenterNet?tab=readme-ov-file 二、目标检测 2.1 下载预训练权重 下载预训练权重ctdet_coco_dla_2x.pth放到models文件夹下 下载链接:https://drive.google.com/file/d/18Q3fzzAsha_3Qid6mn4jcIFPeOGUaj1d/edit …

《昇思25天学习打卡营第19天|生成式-Pix2Pix实现图像转换》

学习内容&#xff1a;Pix2Pix实现图像转换 1.模型简介 Pix2Pix是基于条件生成对抗网络&#xff08;cGAN, Condition Generative Adversarial Networks &#xff09;实现的一种深度学习图像转换模型&#xff0c;该模型是由Phillip Isola等作者在2017年CVPR上提出的&#xff0c…

热题系列9

剑指 Offer 39. 数组中出现次数超过一半的数字 给一个长度为 n 的数组&#xff0c;数组中有一个数字出现的次数超过数组长度的一半&#xff0c;请找出这个数字。 例如输入一个长度为9的数组[1,2,3,2,2,2,5,4,2]。由于数字2在数组中出现了5次&#xff0c;超过数组长度的一半&am…

防火墙nat策略实验和多出口实验和智能选路实验

要求 7&#xff0c;办公区设备可以通过电信链路和移动链路上网(多对多的NAT&#xff0c;并且需要保留一个公网IP不能用来转换) 8&#xff0c;分公司设备可以通过总公司的移动链路和电信链路访问到Dmz区的http服务器 9&#xff0c;多出口环境基于带宽比例进行选路&#xff0c…

GuLi商城-商品服务-API-品牌管理-OSS获取服务端签名(续)

如何进行服务端签名直传_对象存储(OSS)-阿里云帮助中心 gulimall-third-party服务的代码: package com.nanjing.gulimall.thirdparty.controller;import com.aliyun.oss.OSS; import com.aliyun.oss.OSSClientBuilder; import com.aliyun.oss.common.utils.BinaryUtil; impor…

电脑如何快速删除相同的文件?分享5款重复文件删除工具

您有没有发现最近电脑运行速度变慢了&#xff1f;启动时间变得更长&#xff0c;甚至完成简单任务也难以如常&#xff1f;这可能是因为重复文件堆积所致。我们发现&#xff0c;清理或移动这些重复的文件和文件夹可以产生惊人的效果。通过删除不必要的重复文件和垃圾文件&#xf…

【C++】:继承[下篇](友元静态成员菱形继承菱形虚拟继承)

目录 一&#xff0c;继承与友元二&#xff0c;继承与静态成员三&#xff0c;复杂的菱形继承及菱形虚拟继承四&#xff0c;继承的总结和反思 点击跳转上一篇文章&#xff1a; 【C】&#xff1a;继承(定义&&赋值兼容转换&&作用域&&派生类的默认成员函数…

YOLOv5白皮书-第Y5周:yolo.py文件解读

本文为365天深度学习训练营 中的学习记录博客 原作者&#xff1a;K同学啊|接辅导、项目定制 本次训练是在前文《YOLOv5白皮书-第Y2周:训练自己的数据集》的基础上进行的。 前言 文件位置:./models/yolo.Py 这个文件是YOLOv5网络模型的搭建文件&#xff0c;如果你想改进YOLOv5&…

three.js官方案例webgpu_reflection.html学习记录

目录 ​1 判断浏览器是否支持 2 THREE.DirectionalLight 2.1DirectionalLightShadow 3 Texture 3.1 .wrapS 3.2 .wrapT 3.3 .colorSpace 4 创建地面 5 WebGPURenderer 6 OrbitControls 控制器 7 屏幕后处理 import * as THREE from three;import { MeshPhongNodeMa…

Ubuntu使用Nginx部署uniapp打包的项目

使用uniapp导出web项目&#xff1a; 安装&#xff1a; sudo apt install nginx解压web.zip unzip web.zip移动到/var/www/html目录下&#xff1a; sudo cp -r ~/web/h5/ /var/www/html/重启Nginx&#xff1a; sudo service nginx restart浏览器访问&#xff1a;http://19…

基于FPGA的千兆以太网设计(1)----大白话解释什么是以太网

1、什么是以太网? 还记得初学以太网的时候,我就被一大堆专业名词给整懵了:什么以太网,互联网,MAC,IP,局域网,万维网,网络分层模型等等等等。慢着!我学的不是以太网吗?怎么出来这么一大堆东西? 啊!以太网究竟是什么?别急,我接下来就尽量用通俗的大白话来给你解释…

AURORA仿真

AURORA 仿真验证 定义&#xff1a;AURORA是一种高速串行通信协议&#xff0c;通常用于在数字信号处理系统和其他电子设备之间传输数据。它提供了一种高效的方式来传输大量数据&#xff0c;通常用于需要高带宽和低延迟的应用中。AURORA协议通常由Xilinx公司的FPGA器件支持&#…