李宏毅机器学习2022-HW7-BERT-Question Answering

文章目录

  • Task
  • Baseline
    • Medium
    • Strong
    • Boss
  • Code Link

Task

HW7的任务是通过BERT完成Question Answering。

数据预处理流程梳理

数据解压后包含3个json文件:hw7_train.json, hw7_dev.json, hw7_test.json。

DRCD: 台達閱讀理解資料集 Delta Reading Comprehension Dataset

ODSQA: Open-Domain Spoken Question Answering Dataset

  • train: DRCD + DRCD-TTS
    • 10524 paragraphs, 31690 questions
  • dev: DRCD + DRCD-TTS
    • 1490 paragraphs, 4131 questions
  • test: DRCD + ODSQA
    • 1586 paragraphs, 4957 questions

{train/dev/test}_questions:

  • List of dicts with the following keys:
  • id (int)
  • paragraph_id (int)
  • question_text (string)
  • answer_text (string)
  • answer_start (int)
  • answer_end (int)

{train/dev/test}_paragraphs:

  • List of strings
  • paragraph_ids in questions correspond to indexs in paragraphs
  • A paragraph may be used by several questions

读取这三个文件,每个文件返回相应的question数据和paragraph数据,都是文本数据,不能作为模型的输入。

利用Tokenization将question和paragraph文本数据先按token为单位分开,再转换为tokens_to_ids数字数据。Dataset选取paragraph中固定长度的片段(固定长度为150),片段需包含answer部分,然后使用Tokenization 以CLS + question + SEP + document+ CLS + padding(不足的补0)的形式作为训练输入。

Total sequence length = question length + paragraph length + 3 (special tokens)
Maximum input sequence length of BERT is restricted to 512

在这里插入图片描述
在这里插入图片描述

training

在这里插入图片描述

testing

对于每个窗口,模型预测一个开始分数和一个结束分数,取最大值作为答案

在这里插入图片描述

Baseline

Medium

应用linear learning rate decay+change doc_stride

这里linear learning rate decay选用了两种方法

  • 手动调整学习率

    假设初始学习率为 η 0 η_0 η0,总的步骤数为 T T T,那么在第 t t t步时的学习率 η t η_t ηt 可以表示为:

    η t = η 0 − η 0 T × t η_t=η_0−\frac{η_0}{T}×t ηt=η0Tη0×t

    其中:

    • η 0 η_0 η0 是初始学习率。
    • T T T是总的步骤数(total_step)。
    • t t t 是当前的步骤数(从 0 开始计数)。

    optimizer.param_groups[0]["lr"] -= learning_rate / total_step η t = η t − η 0 T η t η_t=η_t−\frac{η_0}{T}η_t ηt=ηtTη0ηt

    • optimizer.param_groups[0]["lr"] 对应 η t η_t ηt
    • learning_rate 对应 η 0 η_0 η0
    • total_step 对应 T T T
    • i 对应 t t t
    # Medium--Learning rate dacay
    # Method 1: adjust learning rate manually
    total_step = 1000
    for i in range(total_step):optimizer.param_groups[0]["lr"] -= learning_rate / total_step
    
  • 通过scheduler自动调整学习率

    • (recommend) transformer
    • torch.optim
    # Method 2: Adjust learning rate automatically by scheduler# (Recommend) https://huggingface.co/transformers/main_classes/optimizer_schedules.html#transformers.get_linear_schedule_with_warmup
    from transformers import get_linear_schedule_with_warmup
    scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=100, num_training_steps=1000)# https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate
    # 这里如果要用pytorch的ExponentialLR,一定要导入optim模块,并且前面的AdamW是从transformers中import的这里要重新import
    import torch.optim as optim
    optimizer = optim.AdamW(model.parameters(), lr=learning_rate)
    scheduler = optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.9)
    

change doc_stride在QA_Dataset的时候修改段落滑动窗口的步长

##### TODO: Change value of doc_stride #####
# 段落滑动窗口的步长
self.doc_stride = 30  # Medium

Strong

应用➢ Improve preprocessing ➢ Try other pretrained models

  • 尝试其他预训练模型

比如bert-base-multilingual-case,因为它可以避免英文无法tokenization输出[UNK],但是计算量大

model = BertForQuestionAnswering.from_pretrained("hfl/chinese-macbert-large").to(device)
tokenizer = BertTokenizerFast.from_pretrained("hfl/chinese-macbert-large")
  • preprocessing ,在QA_Dataset中修改截取答案的窗口

    1. 随机窗口选择 Random Window Selection
      随机选择窗口的起始位置

      • 随机范围的下界
        start_min = max(0, answer_end_token - self.max_paragraph_len + 1) 答案结束位置向前移动 self.max_paragraph_len - 1 个标记后的位置和 0 较大的那个
      • 随机范围的上界
        start_max = min(answer_start_token, len(tokenized_paragraph) - self.max_paragraph_len)
        • len(tokenized_paragraph) - self.max_paragraph_len:计算段落长度减去窗口长度后的位置,确保窗口不会超出段落末尾。
        • min(answer_start_token, ...):确保上界不超过答案开始位置,避免答案被截断。
      • 随机选择
        paragraph_start = random.randint(start_min, start_max)在计算出的下界和上界之间随机选择一个整数作为窗口的起始位置。
      • 计算窗口结束位置
        paragraph_end = paragraph_start + self.max_paragraph_len确保窗口长度为 self.max_paragraph_len
    2. 滑动窗口大小 Dynamic window size

            ##### TODO: Preprocessing Strong ###### Hint: How to prevent model from learning something it should not learnif self.split == "train":# Convert answer's start/end positions in paragraph_text to start/end positions in tokenized_paragraphanswer_start_token = tokenized_paragraph.char_to_token(question["answer_start"])answer_end_token = tokenized_paragraph.char_to_token(question["answer_end"])# A single window is obtained by slicing the portion of paragraph containing the answer# 在training中paragraph的截取依据的是answer的position id"""mid = (answer_start_token + answer_end_token) // 2paragraph_start = max(0, min(mid - self.max_paragraph_len // 2, len(tokenized_paragraph) - self.max_paragraph_len))paragraph_end = paragraph_start + self.max_paragraph_len"""# Strong# Method 1: Random window selectionstart_min = max(0, answer_end_token - self.max_paragraph_len + 1)  # 计算答案结束位置向前移动 self.max_paragraph_len - 1 个标记后的位置start_max = min(answer_start_token, len(tokenized_paragraph) - self.max_paragraph_len)start_max = max(start_min, start_max)paragraph_start = random.randint(start_min, start_max + 1)paragraph_end = paragraph_start + self.max_paragraph_len"""# Method 2: Dynamic window size # 这个会造成窗口的大小大于max_paragraph_len,那么会造成输入序列的长度不一致,后面padding也要改,这里暂不采用answer_length = answer_end_token - answer_start_tokendynamic_window_size = max(self.max_paragraph_len, answer_length + 20)  # 添加一些额外的空间paragraph_start = max(0, min(answer_start_token - dynamic_window_size // 2, len(tokenized_paragraph) - dynamic_window_size))paragraph_end = paragraph_start + dynamic_window_size"""

Boss

➢ Improve postprocessing ➢ Further improve the above hints

doc_stride + max_length+ learning rate scheduler + preprocessing+ postprocessing + new model + no validation

与strong baseline相比,最大的改变有两个,一是换pretrain model,在hugging face中搜索chinese + QA的模型,根据model card描述选择最好的模型,使用后大概提升2.5%的精度,二是更近一步的postprocessing,查看提交文件可看到很多answer包含CLS, SEP, UNK等字符,CLS和SEP的出现表示预测位置有误,UNK的出现说明有某些字符无法正常编码解码(例如一些生僻字),错误字符的问题均可在evaluate函数中改进,这个步骤提升了大概1%的精度。其他的修改主要是针对overfitting问题,包括减少了learning rate,提升dataset里面的paragraph max length, 将validation集合和train集合进行合并等。另外可使用的办法有ensemble,大概能提升0.5%的精度,改变random seed,也有提升的可能性。

if start_index > end_index or start_index < paragraph_start or end_index > paragraph_end:continueif '[UNK]' in answer:print('发现 [UNK],这表明有文字无法编码, 使用原始文本')#print("Paragraph:", paragraph)#print("Paragraph:", paragraph_tokenized.tokens)print('--直接解码预测:', answer)#找到原始文本中对应的位置raw_start =  paragraph_tokenized.token_to_chars(origin_start)[0]raw_end = paragraph_tokenized.token_to_chars(origin_end)[1]answer = paragraph[raw_start:raw_end]print('--原始文本预测:',answer)

Code Link

github

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/452055.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

react 中的hooks中的useState

(1). State Hook让函数组件也可以有state状态, 并进行状态数据的读写操作 (2). 语法: const [xxx, setXxx] React.useState(initValue) (3). useState()说明:参数: 第一次初始化指定的值在内部作缓存返回值: 包含2个元素的数组, 第1个为内部当前状态值, 第2个为更新状态值的…

jmeter用csv data set config做参数化1

在jmeter中&#xff0c;csv data set config的作用非常强大&#xff0c;用它来做批量测试和参数化非常好用。 csv data set config的常用配置项如下&#xff1a; Variable Names处&#xff0c;写上源文件中的参数名&#xff0c;用于后续接口发送请求时引用 Ignore first line…

【Linux】waitpid函数 及其 非阻塞等待和阻塞等待

父进程等待子进程结束可以通过两种方式实现&#xff1a;阻塞等待和非阻塞等待。这两种方式各有优缺点&#xff0c;适用于不同的场景。 简单来说&#xff1a; 阻塞等待&#xff1a;先等你&#xff0c;我再继续 非阻塞等待&#xff1a;不等你&#xff0c;我继续做自己的事&…

初识适配器模式

适配器模式 引入 生活中的例子&#xff1a;当我们使用手机充电时&#xff0c;充电器起到了转换器的作用&#xff0c;它将家用的220伏特电压转换成适合手机充电的5伏特电压。 适配器模式的三种类型 命名原则&#xff1a;适配器的命名应基于资源如何传递给适配器来进行。 类适配…

Web架构演变历程~

1、背景 对于服务架构&#xff0c;这个名词大家应该都很熟悉了吧&#xff0c;一个好的架构并不是一个最合适的架构&#xff0c;在对于选择那种架构&#xff0c;对于一个项目后续发展致关重要&#xff0c;接下来我们一起走进web服务架构的演变历程看看吧&#xff01; 2、服务架…

基于STM32F407VGT6芯片----跑马灯实验

一、在STM32F407VGT6芯片中配置GPIO环境 对于一个跑马灯实验&#xff0c;首先&#xff0c;要了解的就是&#xff0c;芯片是如何构造出来的&#xff0c;设计GPIO引脚&#xff1a;根据原理图&#xff0c; PC4&#xff0c;PC5,PC6,PC7 为 LED 输出控制管脚&#xff0c;PE0 为蜂鸣…

Spring Boot视频网站:安全与可扩展性设计

4 系统设计 4.1系统概要设计 视频网站系统并没有使用C/S结构&#xff0c;而是基于网络浏览器的方式去访问服务器&#xff0c;进而获取需要的数据信息&#xff0c;这种依靠浏览器进行数据访问的模式就是现在用得比较广泛的适用于广域网并且没有网速限制要求的B/S结构&#xff0c…

SpringDataRedis快速入门

SpringDataRedis 什么是SpringDataRedis SpringData是Spring中数据操作的模块,包含对各种数据库的集成,其中对Redis的集成模块就叫做SpringDataRedis SpringDataRedis中提供了RedsiTemplate工具类,其中封装了各种对Redis的操作。并且将不同数据类型的操作API封装到了不同的类…

Atlas800昇腾服务器(型号:3000)—YOLO全系列NPU推理【检测】(五)

服务器配置如下&#xff1a; CPU/NPU&#xff1a;鲲鹏 CPU&#xff08;ARM64&#xff09;A300I pro推理卡 系统&#xff1a;Kylin V10 SP1【下载链接】【安装链接】 驱动与固件版本版本&#xff1a; Ascend-hdk-310p-npu-driver_23.0.1_linux-aarch64.run【下载链接】 Ascend-…

YOLOv11模型改进-注意力机制-引入自适应稀疏自注意力ASSA

随着目标检测领域的快速发展&#xff0c;YOLO系列模型凭借其端到端、高效的检测性能逐渐成为工业界和学术界的标杆。然而&#xff0c;如何进一步优化YOLOv11的特征提取能力&#xff0c;减少冗余信息并提升模型对复杂场景的适应性&#xff0c;仍是一个值得深入探讨的问题。为此&…

Atlas800昇腾服务器(型号:3000)—驱动与固件安装(一)

服务器配置如下&#xff1a; CPU/NPU&#xff1a;鲲鹏 CPU&#xff08;ARM64&#xff09;A300I pro推理卡 系统&#xff1a;Kylin V10 SP1【下载链接】【安装链接】 驱动与固件版本版本&#xff1a; Ascend-hdk-310p-npu-driver_23.0.1_linux-aarch64.run【下载链接】 Ascend-…

scrapy 爬虫学习之【中医药材】爬虫

本项目纯学习使用。 1 scrapy 代码 爬取逻辑非常简单&#xff0c;根据url来处理翻页&#xff0c;然后获取到详情页面的链接&#xff0c;再去爬取详情页面的内容即可&#xff0c;最终数据落地到excel中。 经测试&#xff0c;总计获取 11299条中医药材数据。 import pandas as…

特斯拉Robotaxi发布会2024:自动驾驶未来的开端

引言 2024年10月&#xff0c;特斯拉在洛杉矶举行了一场引发全球科技界高度关注的发布会&#xff0c;主题为“We Robot”。这场发布会展示了特斯拉的最新自动驾驶技术&#xff0c;包括无人驾驶出租车Cybercab和无人驾驶厢式货车Robovan&#xff0c;并且还展示了人形机器人Optim…

Java项目-基于springboot框架的社区疫情防控平台系统项目实战(附源码+文档)

作者&#xff1a;计算机学长阿伟 开发技术&#xff1a;SpringBoot、SSM、Vue、MySQL、ElementUI等&#xff0c;“文末源码”。 开发运行环境 开发语言&#xff1a;Java数据库&#xff1a;MySQL技术&#xff1a;SpringBoot、Vue、Mybaits Plus、ELementUI工具&#xff1a;IDEA/…

精选的四款强大视频压缩工具的整理:

大家好&#xff01;今天我来跟大家分享一下我使用过的几款视频压缩软件的体验感受&#xff0c;以及它们各自的好用之处&#xff1b;在这个信息爆炸的时代&#xff0c;视频文件越来越大&#xff0c;如何快速有效地压缩视频&#xff0c;同时还能保持较好的画质&#xff0c;是很多…

大模型~合集14

我自己的原文哦~ https://blog.51cto.com/whaosoft/12286799 # Attention as an RNN Bengio等人新作&#xff1a;注意力可被视为RNN&#xff0c;新模型媲美Transformer&#xff0c;但超级省内 , 既能像 Transformer 一样并行训练&#xff0c;推理时内存需求又不随 token 数线性…

基于DNA算法的遥感图像加解密matlab仿真

目录 1.程序功能描述 2.测试软件版本以及运行结果展示 3.核心程序 4.本算法原理 5.完整程序 1.程序功能描述 通过DNA算法对遥感图像进行加密和解密&#xff0c;分析加解密处理后图像的直方图&#xff0c;相关性&#xff0c;熵&#xff0c;解密后图像质量等。 2.测试软件版…

MongoDB安装配置及配置和启动服务

MongoDB 安装配置 附&#xff1a;MongoDB官网下载地址&#xff1a; https://www.mongodb.com/download-center/community 注&#xff1a; 官网可以下载最新版的MongoDB安装包&#xff0c;有MSI安装版和ZIP安装版。我们课堂上使用4.4.4的ZIP安装版。安装版参考博客&#xff1…

jmeter中对于有中文内容的csv文件怎么保存

jmeter的功能很强大&#xff0c;但是细节处没把握好就得不到预期的结果。今天来讲讲有中文内容的csv文件的参数化使用中需要注意的事项。 对于有中文内容&#xff0c;涉及到编码格式&#xff0c;为了让jmeter能正确地读取csv文件中的中文&#xff0c;需要把文件转码为UTF-8BOM…

【服务器部署】Docker部署小程序

一、下载Docker 安装之前&#xff0c;一定查看是否安装docker&#xff0c;如果有&#xff0c;卸载老版本 我是虚拟机装的Centos7&#xff0c;linux 3.10 内核&#xff0c;docker官方说至少3.8以上&#xff0c;建议3.10以上&#xff08;ubuntu下要linux内核3.8以上&#xff0c…