基于 MindSpore 实现 BERT 对话情绪识别
BERT介绍
BERT(Bidirectional Encoder Representations from Transformers)是一种基于Transformer架构的预训练语言模型,由谷歌在2018年提出。从以下6个方面来介绍BERT:
1. 预训练和微调:BERT采用预训练和微调的策略。首先,在大量无标签文本数据上进行预训练,学习通用的语言表示。然后,在特定任务的有标签数据上进行微调,以适应特定任务的需求。
2. 双向编码器:BERT的核心是双向编码器,它能够同时考虑上下文信息。与传统的单向编码器(如GPT)相比,BERT能够更好地理解语境中的多义词和长距离依赖关系。
3. Masked Language Model(MLM):BERT在预训练阶段采用了MLM任务。这个任务会随机屏蔽输入句子中的一些单词,然后让模型预测这些被屏蔽的单词。这有助于模型学习到更丰富的上下文表示。
4. Next Sentence Prediction(NSP):除了MLM任务,BERT还引入了NSP任务。这个任务的目的是让模型学会判断两个句子是否连续。这有助于模型理解句子之间的关系,提高对文本结构的把握。
5. Transformer架构:BERT基于Transformer架构,这是一种自注意力机制的变体。Transformer具有并行计算能力强、捕捉长距离依赖关系好等优点,使得BERT在处理大规模文本数据时具有很高的效率。
6. 广泛应用:BERT在自然语言处理领域取得了显著的成果,广泛应用于文本分类、命名实体识别、情感分析、问答系统等任务。此外,BERT还催生了许多改进版本和变体,如RoBERTa、ALBERT等,进一步推动了预训练语言模型的发展。
总之,BERT作为一种强大的预训练语言模型,凭借其双向编码器、MLM和NSP任务以及基于Transformer的架构,在自然语言处理领域取得了突破性的成果,为深度学习开发工程师提供了一种高效且有效的工具。
对话情绪识别
对话情绪识别是一种先进的人工智能技术,它通过自然语言处理和深度学习技术来分析对话中的情绪状态。下面将详细介绍一下对话情绪识别:
1. 定义与技术
- 基本概念:对话情绪识别旨在识别对话中的用户情绪状态,如积极、消极、中性等,以提供更精准的用户体验和服务。
- 技术基础:该技术基于自然语言处理(NLP)和深度学习,通过处理和分析自然语言来提取情感信息。
2. 情绪分类与识别
- 情绪分类:对话情绪识别能识别多种情绪,包括正向情绪如喜爱、愉快、感谢,以及负向情绪如抱怨、愤怒、厌恶、恐惧、悲伤等。
- 识别技术:现代对话情绪识别技术能够自动检测用户日常对话中的情绪特征,并提供有针对性的参考回复,有助于企业或应用方快速应对客户情绪。
3. 应用场景与作用
- 智能客服:对话情绪识别可以提升智能客服系统的效率,通过理解用户情感和需求,提供更合适的服务解决方案[。
- 智能推荐:在智能推荐系统中,该技术可以根据用户的情绪状态推荐相应的产品或服务,提高推荐的准确性和用户满意度。
- 市场调研:通过对话情绪识别技术,可以更准确地了解用户对产品的情感态度和反馈,为企业制定市场策略提供依据。
4. 技术优势与挑战
- 实时性:这一技术处理速度快,能够实时分析对话数据,及时提供情绪识别结果,适用于需要快速响应的场景。
- 准确性:通过深度学习技术,对话情绪识别可以达到很高的准确率和精度,但在某些复杂情境下仍面临挑战,如讽刺、幽默等难以识别的言语。
- 易用性:现代对话情绪识别产品通常提供简单易用的API接口,方便企业集成和应用。
5. 未来发展趋势与改进
- 模型优化:随着技术进步,对话情绪识别的模型将持续优化,提升对复杂情绪和语境的识别能力。
- 应用深化:预计对话情绪识别将更广泛地应用于更多领域,如心理健康、在线教育等,为不同行业提供定制化解决方案。
总的来说,对话情绪识别作为一种基于NLP和深度学习的先进技术,能够有效识别和处理对话中的情绪信息,为智能客服、智能推荐等领域提供了强大的技术支持。随着技术的不断发展,其应用范围将进一步扩大,为企业和个人提供更加智能化、个性化的服务。
实践
环境准备
python: 3.9.19
安装环境依赖
pip install -i https://pypi.mirrors.ustc.edu.cn/simple mindspore==2.2.14pip install mindnlp
完整的Python环境依赖
pip list
Package Version
------------------------------ --------------
absl-py 2.1.0
addict 2.4.0
aiofiles 22.1.0
aiohttp 3.9.5
aiosignal 1.3.1
aiosqlite 0.20.0
altair 5.3.0
annotated-types 0.7.0
anyio 4.4.0
argon2-cffi 23.1.0
argon2-cffi-bindings 21.2.0
arrow 1.3.0
astroid 3.2.2
asttokens 2.0.5
astunparse 1.6.3
async-timeout 4.0.3
attrs 23.2.0
auto-tune 0.1.0
autopep8 1.5.5
Babel 2.15.0
backcall 0.2.0
beautifulsoup4 4.12.3
black 24.4.2
bleach 6.1.0
certifi 2024.6.2
cffi 1.16.0
charset-normalizer 3.3.2
click 8.1.7
cloudpickle 3.0.0
colorama 0.4.6
comm 0.2.1
contextlib2 21.6.0
contourpy 1.2.1
cycler 0.12.1
dataflow 0.0.1
datasets 2.20.0
debugpy 1.6.7
decorator 5.1.1
defusedxml 0.7.1
dill 0.3.8
dnspython 2.6.1
download 0.3.5
easydict 1.13
email_validator 2.2.0
entrypoints 0.4
evaluate 0.4.2
exceptiongroup 1.2.0
executing 0.8.3
fastapi 0.111.0
fastapi-cli 0.0.4
fastjsonschema 2.20.0
ffmpy 0.3.2
filelock 3.15.3
flake8 3.8.4
fonttools 4.53.0
fqdn 1.5.1
frozenlist 1.4.1
fsspec 2024.5.0
gitdb 4.0.11
GitPython 3.1.43
gradio 4.26.0
gradio_client 0.15.1
h11 0.14.0
hccl 0.1.0
hccl-parser 0.1
httpcore 1.0.5
httptools 0.6.1
httpx 0.27.0
huggingface-hub 0.23.4
hypothesis 6.105.1
idna 3.7
importlib-metadata 7.0.1
importlib_resources 6.4.0
iniconfig 2.0.0
ipykernel 6.28.0
ipympl 0.9.4
ipython 8.15.0
ipython-genutils 0.2.0
ipywidgets 8.1.3
isoduration 20.11.0
isort 5.13.2
jedi 0.17.2
jieba 0.42.1
Jinja2 3.1.4
joblib 1.4.2
json5 0.9.25
jsonpointer 3.0.0
jsonschema 4.22.0
jsonschema-specifications 2023.12.1
jupyter_client 7.4.9
jupyter_core 5.7.2
jupyter-events 0.10.0
jupyter-lsp 2.2.5
jupyter-resource-usage 0.7.2
jupyter_server 2.14.1
jupyter_server_fileid 0.9.2
jupyter-server-mathjax 0.2.6
jupyter_server_terminals 0.5.3
jupyter_server_ydoc 0.8.0
jupyter-ydoc 0.2.5
jupyterlab 3.6.7
jupyterlab_code_formatter 2.2.1
jupyterlab_git 0.50.1
jupyterlab-language-pack-zh-CN 4.2.post1
jupyterlab-lsp 4.3.0
jupyterlab_pygments 0.3.0
jupyterlab_server 2.27.2
jupyterlab-system-monitor 0.8.0
jupyterlab-topbar 0.6.1
jupyterlab_widgets 3.0.11
kiwisolver 1.4.5
markdown-it-py 3.0.0
MarkupSafe 2.1.5
matplotlib 3.9.0
matplotlib-inline 0.1.6
mccabe 0.6.1
mdurl 0.1.2
mindnlp 0.3.1
mindspore 2.2.14
mindvision 0.1.0
mistune 3.0.2
ml_collections 0.1.1
ml-dtypes 0.4.0
mpmath 1.3.0
msadvisor 1.0.0
multidict 6.0.5
multiprocess 0.70.16
mypy-extensions 1.0.0
nbclassic 1.1.0
nbclient 0.10.0
nbconvert 7.16.4
nbdime 4.0.1
nbformat 5.10.4
nest-asyncio 1.6.0
notebook 6.5.7
notebook_shim 0.2.4
numpy 1.26.4
op-compile-tool 0.1.0
op-gen 0.1
op-test-frame 0.1
opc-tool 0.1.0
opencv-contrib-python-headless 4.10.0.84
opencv-python 4.10.0.84
opencv-python-headless 4.10.0.84
orjson 3.10.5
overrides 7.7.0
packaging 23.2
pandas 2.2.2
pandocfilters 1.5.1
parso 0.7.1
pathlib2 2.3.7.post1
pathspec 0.12.1
pexpect 4.8.0
pickleshare 0.7.5
pillow 10.3.0
pip 24.1
platformdirs 4.2.2
pluggy 1.5.0
prometheus_client 0.20.0
prompt-toolkit 3.0.43
protobuf 5.27.1
psutil 5.9.0
ptyprocess 0.7.0
pure-eval 0.2.2
pyarrow 16.1.0
pyarrow-hotfix 0.6
pycodestyle 2.6.0
pycparser 2.22
pyctcdecode 0.5.0
pydantic 2.7.4
pydantic_core 2.18.4
pydocstyle 6.3.0
pydub 0.25.1
pyflakes 2.2.0
Pygments 2.15.1
pygtrie 2.5.0
pylint 3.2.3
pyparsing 3.1.2
pytest 7.2.0
python-dateutil 2.9.0.post0
python-dotenv 1.0.1
python-json-logger 2.0.7
python-jsonrpc-server 0.4.0
python-language-server 0.36.2
python-multipart 0.0.9
pytoolconfig 1.3.1
pytz 2024.1
PyYAML 6.0.1
pyzmq 25.1.2
referencing 0.35.1
regex 2024.5.15
requests 2.32.3
rfc3339-validator 0.1.4
rfc3986-validator 0.1.1
rich 13.7.1
rope 1.13.0
rpds-py 0.18.1
ruff 0.4.10
safetensors 0.4.3
schedule-search 0.0.1
scikit-learn 1.5.0
scipy 1.13.1
semantic-version 2.10.0
Send2Trash 1.8.3
sentencepiece 0.2.0
setuptools 69.5.1
shellingham 1.5.4
six 1.16.0
smmap 5.0.1
sniffio 1.3.1
snowballstemmer 2.2.0
sortedcontainers 2.4.0
soupsieve 2.5
stack-data 0.2.0
starlette 0.37.2
sympy 1.12.1
synr 0.5.0
te 0.4.0
terminado 0.18.1
threadpoolctl 3.5.0
tinycss2 1.3.0
tokenizers 0.19.1
toml 0.10.2
tomli 2.0.1
tomlkit 0.12.0
toolz 0.12.1
tornado 6.4.1
tqdm 4.66.4
traitlets 5.14.3
typer 0.12.3
types-python-dateutil 2.9.0.20240316
typing_extensions 4.11.0
tzdata 2024.1
ujson 5.10.0
uri-template 1.3.0
urllib3 2.2.2
uvicorn 0.30.1
uvloop 0.19.0
watchfiles 0.22.0
wcwidth 0.2.5
webcolors 24.6.0
webencodings 0.5.1
websocket-client 1.8.0
websockets 11.0.3
wheel 0.43.0
widgetsnbextension 4.0.11
xxhash 3.4.1
y-py 0.6.2
yapf 0.40.2
yarl 1.9.4
ypy-websocket 0.8.4
zipp 3.17.0
数据集介绍
数据集是已标注的、经过分词预处理的机器人聊天数据集,来自于百度飞桨团队。数据由两列组成,以制表符('\t')分隔,第一列是情绪分类的类别(0表示消极;1表示中性;2表示积极),第二列是以空格分词的中文文本,如下示例,文件为 utf8 编码。
label--text_a
0--谁骂人了?我从来不骂人,我骂的都不是人,你是人吗 ?
1--我有事等会儿就回来和你聊
2--我见到你很高兴谢谢你帮我
这部分主要包括数据集读取,数据格式转换,数据 Tokenize 处理和 pad 操作。
# download dataset
!wget https://baidu-nlp.bj.bcebos.com/emotion_detection-dataset-1.0.0.tar.gz -O emotion_detection.tar.gz
!tar xvf emotion_detection.tar.gz
实践代码
import osimport mindspore
from mindspore.dataset import text, GeneratorDataset, transforms
from mindspore import nn, contextfrom mindnlp._legacy.engine import Trainer, Evaluator
from mindnlp._legacy.engine.callbacks import CheckpointCallback, BestModelCallback
from mindnlp._legacy.metrics import Accuracy# prepare dataset
class SentimentDataset:"""Sentiment Dataset"""def __init__(self, path):self.path = pathself._labels, self._text_a = [], []self._load()def _load(self):with open(self.path, "r", encoding="utf-8") as f:dataset = f.read()lines = dataset.split("\n")for line in lines[1:-1]:label, text_a = line.split("\t")self._labels.append(int(label))self._text_a.append(text_a)def __getitem__(self, index):return self._labels[index], self._text_a[index]def __len__(self):return len(self._labels)# 数据加载和数据预处理
# 新建 process_dataset 函数用于数据加载和数据预处理,具体内容可见下面代码注释。import numpy as npdef process_dataset(source, tokenizer, max_seq_len=64, batch_size=32, shuffle=True):is_ascend = mindspore.get_context('device_target') == 'Ascend'column_names = ["label", "text_a"]dataset = GeneratorDataset(source, column_names=column_names, shuffle=shuffle)# transformstype_cast_op = transforms.TypeCast(mindspore.int32)def tokenize_and_pad(text):if is_ascend:tokenized = tokenizer(text, padding='max_length', truncation=True, max_length=max_seq_len)else:tokenized = tokenizer(text)return tokenized['input_ids'], tokenized['attention_mask']# map datasetdataset = dataset.map(operations=tokenize_and_pad, input_columns="text_a", output_columns=['input_ids', 'attention_mask'])dataset = dataset.map(operations=[type_cast_op], input_columns="label", output_columns='labels')# batch datasetif is_ascend:dataset = dataset.batch(batch_size)else:dataset = dataset.padded_batch(batch_size, pad_info={'input_ids': (None, tokenizer.pad_token_id),'attention_mask': (None, 0)})return dataset# 昇腾NPU环境下暂不支持动态Shape,数据预处理部分采用静态Shape处理:
from mindnlp.transformers import BertTokenizer
tokenizer = BertTokenizer.from_pretrained('bert-base-chinese')tokenizer.pad_token_id
dataset_train = process_dataset(SentimentDataset("data/train.tsv"), tokenizer)
dataset_val = process_dataset(SentimentDataset("data/dev.tsv"), tokenizer)
dataset_test = process_dataset(SentimentDataset("data/test.tsv"), tokenizer, shuffle=False)print(dataset_train.get_col_names())
print(next(dataset_train.create_tuple_iterator()))
模型构建
通过 BertForSequenceClassification 构建用于情感分类的 BERT 模型,加载预训练权重,设置情感三分类的超参数自动构建模型。后面对模型采用自动混合精度操作,提高训练的速度,然后实例化优化器,紧接着实例化评价指标,设置模型训练的权重保存策略,最后就是构建训练器,模型开始训练。
from mindnlp.transformers import BertForSequenceClassification, BertModel
from mindnlp._legacy.amp import auto_mixed_precision# set bert config and define parameters for training
model = BertForSequenceClassification.from_pretrained('bert-base-chinese', num_labels=3)
model = auto_mixed_precision(model, 'O1')optimizer = nn.Adam(model.trainable_params(), learning_rate=2e-5)
metric = Accuracy()
# define callbacks to save checkpoints
ckpoint_cb = CheckpointCallback(save_path='checkpoint', ckpt_name='bert_emotect', epochs=1, keep_checkpoint_max=2)
best_model_cb = BestModelCallback(save_path='checkpoint', ckpt_name='bert_emotect_best', auto_load=True)trainer = Trainer(network=model, train_dataset=dataset_train,eval_dataset=dataset_val, metrics=metric,epochs=5, optimizer=optimizer, callbacks=[ckpoint_cb, best_model_cb])%%time
# start training
trainer.run(tgt_columns="labels")
模型验证
将验证数据集加再进训练好的模型,对数据集进行验证,查看模型在验证数据上面的效果,此处的评价指标为准确率。
evaluator = Evaluator(network=model, eval_dataset=dataset_test, metrics=metric)
evaluator.run(tgt_columns="labels")
模型推理
遍历推理数据集,将结果与标签进行统一展示。
dataset_infer = SentimentDataset("data/infer.tsv")def predict(text, label=None):label_map = {0: "消极", 1: "中性", 2: "积极"}text_tokenized = Tensor([tokenizer(text).input_ids])logits = model(text_tokenized)predict_label = logits[0].asnumpy().argmax()info = f"inputs: '{text}', predict: '{label_map[predict_label]}'"if label is not None:info += f" , label: '{label_map[label]}'"print(info)from mindspore import Tensorfor label, text in dataset_infer:predict(text, label)
自定义推理数据集
自己输入推理数据,展示模型的泛化能力。