CENet及多模态情感计算实战


✨✨ 欢迎大家来访Srlua的博文(づ ̄3 ̄)づ╭❤~✨✨

🌟🌟 欢迎各位亲爱的读者,感谢你们抽出宝贵的时间来阅读我的文章。

我是Srlua小谢,在这里我会分享我的知识和经验。🎥

希望在这里,我们能一起探索IT世界的奥妙,提升我们的技能。🔮

记得先点赞👍后阅读哦~ 👏👏

📘📚 所属专栏:传知代码论文复现

欢迎访问我的主页:Srlua小谢 获取更多信息和资源。✨✨🌙🌙

​​

​​

目录

一、概述

二、论文地址

三、研究背景

四、主要贡献

五、论文思路

六、主要内容和网络架构

七、数据集介绍

八、性能对比

九、复现过程(重要)

十、演示结果


本文所有资源均可在该地址处获取。

一、概述

本文对 “Cross-Modal Enhancement Network for Multimodal Sentiment Analysis” 论文进行讲解和手把手复现教学,解决当下热门的多模态情感计算问题,并展示在MOSI和MOSEI两个数据集上的效果

二、论文地址

DOI: 10.1109/TMM.2022.3183830

三、研究背景

情感分析在人工智能向情感智能发展中起着重要作用。早期的情感分析研究主要集中在分析单模态数据上,包括文本情感分析、图像情感分析、音频情感分析等。然而,人类的情感是通过人体的多种感官来传达的。因此,单模态情感分析忽略了人类情感的多维性。相比之下,多模态情感分析通过结合文本、视觉和音频等多模态数据来推断一个人的情感状态。与单模态情感分析相比,多模态数据包含多样化的情感信息,具有更高的预测精度。目前,多模态情感分析已被广泛应用于视频理解、人机交互、政治活动等领域。近年来,随着互联网和各种多媒体平台的快速发展,通过互联网表达情感的载体和方式也变得越来越多样化。这导致了多媒体数据的快速增长,为多模态情感分析提供了大量的数据源。下图展示了多模态在情感计算任务中的优势。

四、主要贡献

  • 提出了一种跨模态增强网络,通过融入长范围非文本情感语境来增强预训练语言模型中的文本表示;
  • 提出一种特征转换策略,通过减小文本模态和非文本模态的初始表示之间的分布差异,促进了不同模态的融合;
  • 融合了新的预训练语言模型SentiLARE来提高模型对情感词的提取效率,从而提升对情感计算的准确性。

五、论文思路

作者提出的跨模态增强网络(CENet)模型通过将视觉和声学信息集成到语言模型中来增强文本表示。在基于transformer的预训练语言模型中嵌入跨模态增强(CE)模块,根据非对齐非文本数据中隐含的长程情感线索增强每个单词的表示。此外,针对声学模态和视觉模态,提出了一种特征转换策略,以减少语言模态和非语言模态的初始表示之间的分布差异,从而促进不同模态的融合。

六、主要内容和网络架构

首先我们展示一下CENet的总体网络架构


通过该图,我们可以看出该模型主要有以下几部分组成:1.非文本模态特征转化;2.跨模态增强;3.预训练语言模型输出;接下来将对他们分别进行讲解:
1. 非文本模态转换
针对预训练语言模型,初始文本表示是基于词汇表的单词索引序列,而视觉和听觉的表示则是实值向量序列。为了缩小这些异质模态之间的初始分布差异,进而减少在融合过程中非文本特征和文本特征之间的分布差距,本文提出了一种将非文本向量转换为索引的特征转换策略。这种策略有助于促进文本表征与非文本情感语境的有机融合。

具体而言,特征转换策略利用无监督聚类算法分别构建了“声学词汇表”和“视觉词汇表”。通过查询这些非语言词汇表,可以将原始的非语言特征序列转换为索引序列。下图展示了特征转换过程的具体步骤。考虑到k-means方法具有计算复杂度低和实现简单等优点,作者选择使用k-means算法来学习非语言模态的词汇。

2. 跨模态增强模块
本文提出的CE模块旨在将长程视觉和声学信息集成到预训练语言模型中,以增强文本的表示能力。CE模块的核心组件是跨模态嵌入单元,其结构如下图所示。该单元利用跨模态注意力机制捕捉长程非文本情感信息,并生成基于文本的非语言嵌入。嵌入层的参数可学习,用于将经过特征转换策略处理后得到的非文本索引向量映射到高维空间,然后生成文本模态对非文本模态的注意力权重矩阵。

在初始训练阶段,由于语言表征和非语言表征处于不同的特征空间,它们之间的相关性通常较低。因此,注意力权重矩阵中的元素可能较小。为了更有效地学习模型参数,研究者在应用softmax之前使用超参数来缩放这些注意力权重矩阵。

基于注意力权重矩阵,可以生成基于文本的非语言向量。将基于文本的声学嵌入和基于文本的视觉嵌入结合起来,形成非语言增强嵌入。最后,通过整合非语言增强嵌入来更新文本的表示。因此,CE模块的提出旨在为文本提供非语言上下文信息,通过增加非语言增强嵌入来调整文本表示,从而使其在语义上更加准确和丰富。

3. 预训练语言模型输出
作者采用SentiLARE作为语言模型,其利用包括词性和单词情感极性在内的语言知识来学习情感感知的语言表示。CE模块被集成到预训练语言模型的第i层中。值得注意的是,任何基于Transformer的预训练语言模型都可以与CE模块集成。下面是作者根据SentiLARE的设置进行的步骤:

  1. 给定一个单词序列,首先通过Stanford Log-Linear词性(POS)标记器学习其词性序列,并通过SentiwordNet学习单词级情感极性序列。

  2. 然后,使用预训练语言模型的分词器获取词标索引序列。这个序列作为输入,产生一个初步的增强语言知识表示。

  3. 更新后的文本表示将作为第(i+1)层的输入,并通过SentiLARE中的剩余层进行处理。

  4. 每一层的输出将是具有视觉和听觉信息的文本主导的高级情感表示。

  5. 最后,将这些文本表示输入到分类头中,以获取情感强度。

因此,CE模块通过将非语言增强嵌入集成到预训练语言模型中,有助于生成更富有情感感知的语言表示。这种方法能够在文本表示中有效地整合视觉和听觉信息,从而提升情感分析等任务的性能。

七、数据集介绍

1. CMU-MOSI: 它是一个多模态数据集,包括文本、视觉和声学模态。它来自Youtube上的93个电影评论视频。这些视频被剪辑成2199个片段。每个片段都标注了[-3,3]范围内的情感强度。该数据集分为三个部分,训练集(1,284段)、验证集(229段)和测试集(686段)。
2. CMU-MOSEI: 它类似于CMU-MOSI,但规模更大。它包含了来自在线视频网站的23,453个注释视频片段,涵盖了250个不同的主题和1000个不同的演讲者。CMU-MOSEI中的样本被标记为[-3,3]范围内的情感强度和6种基本情绪。因此,CMU-MOSEI可用于情感分析和情感识别任务。

八、性能对比

有下图可以观察到,该论文提出的CENet与其他SOTA模型对比性能有明显提升:

九、复现过程(重要)

1. 数据集准备
下载MOSI和MOSEI数据集已提取好的特征文件(.pkl)。把它放在"./dataset”目录。

2. 下载预训练语言模型
下载SentiLARE语言模型文件,然后将它们放入"/pretrained-model / sentilare_model”目录。

3. 下载需要的包

pip install -r requirements.txt

4. 搭建CENet模块
利用pytorch框架对CENet模块进行搭建:

class CE(nn.Module):def __init__(self, beta_shift_a=0.5, beta_shift_v=0.5, dropout_prob=0.2):super(CE, self).__init__()self.visual_embedding = nn.Embedding(label_size + 1, TEXT_DIM, padding_idx=label_size)self.acoustic_embedding = nn.Embedding(label_size + 1, TEXT_DIM, padding_idx=label_size)self.hv = SelfAttention(TEXT_DIM)self.ha = SelfAttention(TEXT_DIM)self.cat_connect = nn.Linear(2 * TEXT_DIM, TEXT_DIM)def forward(self, text_embedding, visual=None, acoustic=None, visual_ids=None, acoustic_ids=None):visual_ = self.visual_embedding(visual_ids)acoustic_ = self.acoustic_embedding(acoustic_ids)visual_ = self.hv(text_embedding, visual_)acoustic_ = self.ha(text_embedding, acoustic_) visual_acoustic = torch.cat((visual_, acoustic_), dim=-1)shift = self.cat_connect(visual_acoustic)embedding_shift = shift + text_embeddingreturn embedding_shift

5. 将CE模块与预训练语言模型融合

class BertEncoder(nn.Module):def __init__(self, config):super(BertEncoder, self).__init__()self.output_attentions = config.output_attentionsself.output_hidden_states = config.output_hidden_statesself.layer = nn.ModuleList([BertLayer(config) for _ in range(config.num_hidden_layers)])self.CE = CE()def forward(self, hidden_states, visual=None, acoustic=None, visual_ids=None, acoustic_ids=None, attention_mask=None, head_mask=None):all_hidden_states = ()all_attentions = ()for i, layer_module in enumerate(self.layer):if self.output_hidden_states:all_hidden_states = all_hidden_states + (hidden_states,)if i == ROBERTA_INJECTION_INDEX:hidden_states = self.CE(hidden_states, visual=visual, acoustic=acoustic, visual_ids=visual_ids, acoustic_ids=acoustic_ids)

6. 训练代码编写

  • 定义一个整体的训练过程 train()函数,它负责训练模型多个 epoch,并在每个 epoch 结束后评估模型在验证集和测试集上的性能,并记录相关的指标和损失;并在训练最后一轮输出所有测试集id,true label 和 predicted label。
def train(args,model,train_dataloader,validation_dataloader,test_data_loader,optimizer,scheduler,
):valid_losses = []test_accuracies = []for epoch_i in range(int(args.n_epochs)):train_loss, train_pre, train_label = train_epoch(args, model, train_dataloader, optimizer, scheduler)valid_loss, valid_pre, valid_label = evaluate_epoch(args, model, validation_dataloader)test_loss, test_pre, test_label = evaluate_epoch(args, model, test_data_loader)train_acc, train_mae, train_corr, train_f_score = score_model(train_pre, train_label)test_acc, test_mae, test_corr, test_f_score = score_model(test_pre, test_label)non0_test_acc, _, _, non0_test_f_score = score_model(test_pre, test_label, use_zero=True)valid_acc, valid_mae, valid_corr, valid_f_score = score_model(valid_pre, valid_label)print("epoch:{}, train_loss:{}, train_acc:{}, valid_loss:{}, valid_acc:{}, test_loss:{}, test_acc:{}".format(epoch_i, train_loss, train_acc, valid_loss, valid_acc, test_loss, test_acc))valid_losses.append(valid_loss)test_accuracies.append(test_acc)wandb.log(({"train_loss": train_loss,"valid_loss": valid_loss,"train_acc": train_acc,"train_corr": train_corr,"valid_acc": valid_acc,"valid_corr": valid_corr,"test_loss": test_loss,"test_acc": test_acc,"test_mae": test_mae,"test_corr": test_corr,"test_f_score": test_f_score,"non0_test_acc": non0_test_acc,"non0_test_f_score": non0_test_f_score,"best_valid_loss": min(valid_losses),"best_test_acc": max(test_accuracies),}))# 输出测试集的 id、真实标签和预测标签with torch.no_grad():for step, batch in enumerate(test_data_loader):batch = tuple(t.to(DEVICE) for t in batch)input_ids, visual_ids, acoustic_ids, pos_ids, senti_ids, polarity_ids, visual, acoustic, input_mask, segment_ids, label_ids = batchvisual = torch.squeeze(visual, 1)outputs = model(input_ids,visual,acoustic,visual_ids,acoustic_ids,pos_ids, senti_ids, polarity_ids,token_type_ids=segment_ids,attention_mask=input_mask,labels=None,)logits = outputs[0]logits = logits.detach().cpu().numpy()label_ids = label_ids.detach().cpu().numpy()logits = np.squeeze(logits).tolist()label_ids = np.squeeze(label_ids).tolist()# 假设您从 label_ids 中获取 idsids = [f"sample_{idx}" for idx in range(len(label_ids))]  # 这里是示例,您可以根据实际情况生成合适的 ids# 输出所有测试样本的 id、真实标签和预测标签值for i in range(len(ids)):print(f"id: {ids[i]}, true label: {label_ids[i]}, predicted label: {logits[i]}")

6. 开始训练+测试

python train.py

7. 输出结果
此外,为了方便直观地查看模型性能,我在最后一层训练结束后将所有测试集视频的clip id、真实标签和预测标签依次进行输出;并且结合wandb库自动保存结果可视化;结果在后续章节展示。

十、演示结果

  • 以下是我们的训练过程

  • 下面是模型性能结果

  • 接下来是我们自己补充的每个测试集的真实标签和预测标签

  • 可视化

​​

希望对你有帮助!加油!

若您认为本文内容有益,请不吝赐予赞同并订阅,以便持续接收有价值的信息。衷心感谢您的关注和支持!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/482187.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

共享售卖机语音芯片方案选型:WTN6020引领智能化交互新风尚

在共享经济蓬勃发展的今天,共享售卖机作为便捷购物的新形式,正逐步渗透到人们生活的各个角落。为了提升用户体验,增强设备的智能化和互动性,增加共享售卖机的语音功能就显得尤为重要。 共享售卖机语音方案选型: WTN602…

云备份实战项目

文章目录 前言一、整体项目简介二、服务端环境及功能简介三、 客户端环境及功能简介四、服务端文件管理类的实现1. 获取文件大小,最后一次修改时间,最后一次访问时间,文件名称,以及文件内容的读写等功能2. 判断文件是否存在&#…

【Linux】死锁、读写锁、自旋锁

文章目录 1. 死锁1.1 概念1.2 死锁形成的四个必要条件1.3 避免死锁 2. 读者写者问题与读写锁2.1 读者写者问题2.2 读写锁的使用2.3 读写策略 3. 自旋锁3.1 概念3.2 原理3.3 自旋锁的使用3.4 优点与缺点 1. 死锁 1.1 概念 死锁是指在⼀组进程中的各个进程均占有不会释放的资源…

通俗易懂:序列标注与命名实体识别(NER)概述及标注方法解析

目录 一、序列标注(Sequence Tagging)二、命名实体识别(Named Entity Recognition,NER)**命名实体识别的作用****命名实体识别的常见实体类别** : 三、标签类型四、序列标注的三种常见方法1. **BIO&#xf…

设计模式学习[10]---迪米特法则+外观模式

文章目录 前言1. 迪米特法则2. 外观模式2.1 原理阐述2.2 举例说明 总结 前言 之前有写到过 依赖倒置原则,这篇博客中涉及到的迪米特法则和外观模式更像是这个依赖倒置原则的一个拓展。 设计模式的原则嘛,总归还是高内聚低耦合,下面就来阐述…

JUnit介绍:单元测试

1、什么是单元测试 单元测试是针对最小的功能单元编写测试代码(Java 程序最小的功能单元是方法)单元测试就是针对单个Java方法的测试。 2、为什么要使用单元测试 确保单个方法运行正常; 如果修改了代码,只需要确保其对应的单元…

用Transformers和FastAPI快速搭建后端算法api

用Transformers和FastAPI快速搭建后端算法api 如果你对自然语言处理 (NLP, Natural Language Processing) 感兴趣,想要了解ChatGPT是怎么来的,想要搭建自己的聊天机器人,想要做分类、翻译、摘要等各种NLP任务,Huggingface的Trans…

架构师:Dubbo 服务请求失败处理的实践指南

1、简述 在分布式服务中,服务调用失败是不可避免的,可能由于网络抖动、服务不可用等原因导致。Dubbo 作为一款高性能的 RPC 框架,提供了多种机制来处理服务请求失败问题。本文将介绍如何在 Dubbo 中优雅地处理服务请求失败,并结合具体实践步骤进行讲解。 2、常见处理方式 …

shell语法(1)bash

shell是我们通过命令行与操作系统沟通的语言,是一种解释型语言 shell脚本可以直接在命令行中执行,也可以将一套逻辑组织成一个文件,方便复用 Linux系统中一般默认使用bash为脚本解释器 在Linux中创建一个.sh文件,例如vim test.sh…

探索文件系统,Python os库是你的瑞士军刀

文章目录 探索文件系统,Python os库是你的瑞士军刀第一部分:背景介绍第二部分:os库是什么?第三部分:如何安装os库?第四部分:简单库函数使用方法1. 获取当前工作目录2. 改变当前工作目录3. 列出目…

Paper -- 建筑物高度估计 -- 基于深度学习、图像处理和自动地理空间分析的街景图像建筑高度估算

论文题目: Building height estimation from street-view imagery using deep learning, image processing and automated geospatial analysis 中文题目: 基于深度学习、图像处理和自动地理空间分析的街景图像建筑高度估算 作者: Ala’a Al-Habashna, Ryan Murdoch 作者单位: …

FTP介绍与配置

前言: FTP是用来传送文件的协议。使用FTP实现远程文件传输的同时,还可以保证数据传输的可靠性和高效性。 介绍 FTP的应用 在企业网络中部署一台FTP服务器,将网络设备配置为FTP客户端,则可以使用FTP来备份或更新VRP文件和配置文件…

DDD领域应用理论实践分析回顾

目录 一、DDD的重要性 (一)拥抱互联网黑话(抓痛点、谈愿景、搞方法论) (二)DDD真的重要吗? 二、领域驱动设计DDD在B端营销系统的实践 (一)设计落地步骤 &#xff0…

【聚类】K-Means 聚类(无监督)及K-Means ++

1. 原理 2. 算法步骤 3. 目标函数 4. 优缺点 import torch import numpy as np import matplotlib.pyplot as plt from sklearn.cluster import KMeans from sklearn.decomposition import PCA import torch.nn as nn# 数据准备 # 生成数据:100 个张量&#xff0c…

C++学习日记---第14天(蓝桥杯备赛)

笔记复习 1.对象的初始化和清理 对象的初始化和清理是两个非常重要的安全问题,一个对象或者变量没有初始状态,对其使用后果是未知,同样的使用完一个对象或者变量,没有及时清理,也会造成一定的安全问题 构造函数&…

网络安全防护指南:筑牢网络安全防线(5/10)

一、网络安全的基本概念 (一)网络的定义 网络是指由计算机或者其他信息终端及相关设备组成的按照一定的规则和程序对信息收集、存储、传输、交换、处理的系统。在当今数字化时代,网络已经成为人们生活和工作中不可或缺的一部分。它连接了世…

【最新鸿蒙开发——应用导航设计】

大家好,我是小z,不知道大家在开发过程中有没有遇到模块间跳转的问题,今天给大家分享关于模块间跳转的三种方法 文章目录 1. 命名路由(ohos.router)使用步骤 2. 使用navigation组件跳转。步骤缺点 3. 路由管理模块1. 路由管理模块…

Wireshark常用功能使用说明

此处用于记录下本人所使用 wireshark 所可能用到的小技巧。Wireshark是一款强大的数据包分析工具,此处仅介绍常用功能。 Wireshark常用功能使用说明 1.相关介绍1.1.工具栏功能介绍1.1.1.时间戳/分组列表概况等设置 1.2.Windows抓包 2.wireshark过滤器规则2.1.wiresh…

【进阶篇-Day15:JAVA线程-Thread的介绍】

目录 1、进程和线程1.1 进程的介绍1.2 并行和并发1.3 线程的介绍 2、JAVA开启线程的三种方法2.1 继承Thread类:2.2 实现Runnable接口2.3 实现Callable接口2.4 总结: 3、线程相关方法3.1 获取和设置线程名字的方法3.2 线程休眠方法:3.3 线程优…

组播基础实验

当需要同时发给多个接受者或者接收者ip未知时使用组播 一、组播IP地址 1、组播IP地址范围 组播地址属于D类地址:224.0.0.0/4(224.0.0.0-239.255.255.255) 2、分类 (1)链路本地地址(link-local&#xf…