Bert 在 OCNLI 训练微调

目录

  • 0 资料
  • 1 预训练权重
  • 2 wandb
  • 3 Bert-OCNLI
    • 3.1 目录结构
    • 3.2 导入的库
    • 3.3 数据集
      • 自然语言推断
      • 数据集路径
      • 读取数据集
      • 数据集样例展示
      • 数据集类别统计
      • 数据集类
      • 加载数据
    • 3.4 Bert
    • 3.4 训练
  • 4 训练微调结果
    • 3k
    • 10k
    • 50k

0 资料

【数据集微调】

阿里天池比赛 微调BERT的数据集(“任务1:OCNLI–中文原版自然语言推理”)

数据集地址:https://tianchi.aliyun.com/competition/entrance/531841/information

由于这个比赛已经结束,原地址提交不了榜单看测试结果,请参照下面的信息,下载数据集、提交榜单测试。

  • “任务1:OCNLI–中文原版自然语言推理”数据集的GitHub地址:https://github.com/CLUEbenchmark/OCNLI

  • 榜单提交地址:https://www.cluebenchmarks.com/index.html

  • 榜单提交步骤:

    • 打开“榜单提交地址”,点击“立即测评”——填写相关信息(github地址填https://github.com/CLUEbenchmark/CLUE,其他信息任意填)。
    • 上传一个.zip压缩文件,在压缩文件里存放我们模型预测结果的文件。
    • 点击提交。
  • 【注意】预测结果文件的格式:https://storage.googleapis.com/cluebenchmark/tasks/clue_submit_examples.zip

15.4. 自然语言推断与数据集:https://zh-v2.d2l.ai/chapter_natural-language-processing-applications/natural-language-inference-and-dataset.html

15.7. 自然语言推断:微调BERT:https://zh-v2.d2l.ai/chapter_natural-language-processing-applications/natural-language-inference-bert.html#id3

保姆级教程,用PyTorch和BERT进行文本分类:https://zhuanlan.zhihu.com/p/524487313

1 预训练权重

在国内,一般是手动下载预训练权重,而非网络自动下载。

我们将用到 chinese-macbert-base 这个预训练文件,下载网址如下:

https://huggingface.co/hfl/chinese-macbert-base/tree/main

除了叉掉的,其余都要下载。
在这里插入图片描述

2 wandb

pip install wandb

WandB 是一个用于实验跟踪、版本控制和结果可视化的工具,主要用于机器学习项目。
wandb使用教程(一):基础用法:https://zhuanlan.zhihu.com/p/493093033

3 Bert-OCNLI

3.1 目录结构

在这里插入图片描述

3.2 导入的库

import os
import torch
from torch import nn
import pandas as pd
from transformers import BertModel, BertTokenizer
from torch.optim import Adam
from tqdm import tqdm

3.3 数据集

自然语言推断

自然语言推断(natural language inference)主要研究 假设(hypothesis)是否可以从前提(premise)中推断出来, 其中两者都是文本序列。 换言之,自然语言推断决定了一对文本序列之间的逻辑关系。这类关系通常分为三种类型:

蕴涵(entailment):假设可以从前提中推断出来。矛盾(contradiction):假设的否定可以从前提中推断出来。中性(neutral):所有其他情况。

自然语言推断也被称为识别文本蕴涵任务。 例如,下面的一个文本对将被贴上“蕴涵”的标签,因为假设中的“表白”可以从前提中的“拥抱”中推断出来。

前提:两个女人拥抱在一起。假设:两个女人在示爱。

下面是一个“矛盾”的例子,因为“运行编码示例”表示“不睡觉”,而不是“睡觉”。

前提:一名男子正在运行Dive Into Deep Learning的编码示例。假设:该男子正在睡觉。

第三个例子显示了一种“中性”关系,因为“正在为我们表演”这一事实无法推断出“出名”或“不出名”。

前提:音乐家们正在为我们表演。假设:音乐家很有名。

自然语言推断一直是理解自然语言的中心话题。它有着广泛的应用,从信息检索到开放领域的问答。为了研究这个问题,我们将首先研究一个流行的自然语言推断基准数据集。

数据集路径

# 数据集路径
data_dir = 'OCNLI/data/ocnli'

读取数据集

# 读ocnli,两个参数,data_dir是数据集的路径,is_train为bool类型,True代表训练,False代表验证
def read_ocnli(data_dir, is_train):# 将ocnli解析为前提、假设、标签# labels_map是标签映射,0、1、2代表三类,3代表无法分类(或者应该去除的数据)。labels_map = {'entailment':0, 'neutral':1, 'contradiction':2, '-': 3}file_name = os.path.join(data_dir, 'train.3k.json' if is_train else 'dev.json')rows = pd.read_json(file_name, lines=True)premises = [sentence1 for sentence1 in rows['sentence1'] ]  # 前提hypotheses = [sentence2 for sentence2 in rows['sentence2'] ] # 假设# if label != '-' 是为了去除无法分类的标签labels = [labels_map[label] for label in rows['label'] if label != '-'] # 标签return premises, hypotheses, labels

数据集样例展示

# 样例展示
train_data = read_ocnli(data_dir, is_train=True)
for x0, x1, y in zip(train_data[0][:3], train_data[1][:3], train_data[2][:3]):print("前提:", x0)print("假设:", x1)print("标签:", y)

结果:

前提: 现在,我代表国务院,向大会报告政府工作,请予审议,并请全国政协委员提出意见
假设: 全国政协委员无权提出建议
标签: 2
前提: 不过以后呢,两年增加一次工资.
假设: 多年之后工资很高
标签: 1
前提: 一万块,嗯那头盔要八千.
假设: 说话的人很有钱
标签: 1

数据集类别统计

# 类别数据统计
val_data = read_ocnli(data_dir, is_train=False)label_set = [0, 1, 2]for data in [train_data, val_data]:print([[row for row in data[2]].count(i) for i in label_set])

结果:

[974, 1054, 966]
[947, 1103, 900]

数据集类

tokenizer = BertTokenizer.from_pretrained('bert-base-chinese')
class OCNLI_Dataset(torch.utils.data.Dataset):def __init__(self, dataset):sentence1 = [sentence1 for sentence1 in dataset[0]]sentence2 = [sentence2 for sentence2 in dataset[1]]# 用 _ 将前提和假设拼接在一起,但这应该不是好的做法sentence1_2 = ['{}_{}'.format(a, b) for a, b in zip(sentence1, sentence2)]self.texts = [tokenizer(sentence, padding='max_length', # bert最大可以设置到512,对OCNLI的统计计算中,# 发现所有数据没有超过128,max_length越大,计算量越大max_length = 128, truncation=True,return_tensors="pt") for sentence in sentence1_2 ] self.labels = torch.tensor(dataset[2])def __len__(self):return len(self.labels)def __getitem__(self, idx):return self.texts[idx], self.labels[idx]

加载数据

train_set = OCNLI_Dataset(read_ocnli(data_dir, True))
test_set = OCNLI_Dataset(read_ocnli(data_dir, False))
print(len(train_set))
# for train_input, train_label in train_set:
#     print(train_input)
#     print(train_label)
#     input()

结果:

3000

3.4 Bert

class BertClassifier(nn.Module):def __init__(self, dropout=0.5):super(BertClassifier, self).__init__()self.bert = BertModel.from_pretrained('bert-base-chinese')self.dropout = nn.Dropout(dropout)self.linear = nn.Linear(768, 3) # 这里的3代表输出的类别self.relu = nn.ReLU()def forward(self, input_id, mask):_, pooled_output = self.bert(input_ids= input_id, attention_mask=mask,return_dict=False)dropout_output = self.dropout(pooled_output)linear_output = self.linear(dropout_output)final_layer = self.relu(linear_output)return final_layer

3.4 训练

def train(model, train_data, val_data, learning_rate, epochs):# 通过Dataset类获取训练和验证集train, val = OCNLI_Dataset(train_data), OCNLI_Dataset(val_data)# DataLoader根据batch_size获取数据,训练时选择打乱样本train_dataloader = torch.utils.data.DataLoader(train, batch_size=32, shuffle=True)val_dataloader = torch.utils.data.DataLoader(val, batch_size=32)# 判断是否使用GPUuse_cuda = torch.cuda.is_available()device = torch.device("cuda" if use_cuda else "cpu")# 定义损失函数和优化器criterion = nn.CrossEntropyLoss()optimizer = Adam(model.parameters(), lr=learning_rate)if use_cuda:model = model.cuda()criterion = criterion.cuda()# 开始进入训练循环for epoch_num in range(epochs):# 定义两个变量,用于存储训练集的准确率和损失total_acc_train = 0total_loss_train = 0# 进度条函数tqdmfor train_input, train_label in tqdm(train_dataloader):train_label = train_label.to(device)mask = train_input['attention_mask'].to(device)input_id = train_input['input_ids'].squeeze(1).to(device)# 通过模型得到输出output = model(input_id, mask)# 计算损失batch_loss = criterion(output, train_label)# input()total_loss_train += batch_loss.item()# print("total_loss_train:",total_loss_train)# 计算精度acc = (output.argmax(dim=1) == train_label).sum().item()total_acc_train += acc# 模型更新model.zero_grad()batch_loss.backward()optimizer.step()# ------ 验证模型 -----------# 定义两个变量,用于存储验证集的准确率和损失total_acc_val = 0total_loss_val = 0# 不需要计算梯度with torch.no_grad():# 循环获取数据集,并用训练好的模型进行验证for val_input, val_label in val_dataloader:# 如果有GPU,则使用GPU,接下来的操作同训练val_label = val_label.to(device)mask = val_input['attention_mask'].to(device)input_id = val_input['input_ids'].squeeze(1).to(device)output = model(input_id, mask)batch_loss = criterion(output, val_label)total_loss_val += batch_loss.item()acc = (output.argmax(dim=1) == val_label).sum().item()total_acc_val += accprint(f'''Epochs: {epoch_num + 1} | Train Loss: {total_loss_train / len(train): .3f} | Train Accuracy: {total_acc_train / len(train): .3f} | Val Loss: {total_loss_val / len(train): .3f} | Val Accuracy: {total_acc_val / len(train): .3f}''')     print("total_loss_train:",total_loss_train)print("total_acc_train:",total_acc_train)print("total_loss_val:",total_loss_val)print("total_acc_val:",total_acc_val)print("len(train_data):",len(train))          
EPOCHS = 50
model = BertClassifier()
LR = 1e-6
train(model, read_ocnli(data_dir, True), read_ocnli(data_dir, False), LR, EPOCHS)

在这里插入图片描述

4 训练微调结果

3k

10k

50k

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/323351.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

UE5(射线检测)学习笔记

这一篇会讲解射线检测点击事件、离开悬停、进入悬停事件的检测,以及关闭射线检测的事件,和射线检测蓝图的基础讲解。 创建一个简单的第三人称模板 创建一个射线检测的文件夹RadiationInspection,并且右键蓝图-场景组件-命名为BPC_Radiation…

路由模块封装

目录 一、问题引入 二、步骤 一、问题引入 随着项目内容的不断扩大,路由也会越来越多,把所有的路由配置都堆在main.js中就不太合适了,所以需要将路由模块抽离出来。其好处是:拆分模块,利于维护。 二、步骤 将路由相…

linux PXE高效批量网络装机

PXE批量部署的优点 规模化:同时装配多台服务器 自动化:安装系统、配置各种服务 远程实现:不需要光盘、U盘等安装介质 部署PXE远程安装服务 搭建PXE远程安装服务器 先做好初始化准备 1.安装并启用 TFTP 服务 yum -y install tftp-server …

从开发角度理解漏洞成因(02)

文章目录 文件上传类需求文件上传漏洞 文件下载类需求文件下载漏洞 扩展 留言板类(XSS漏洞)需求XSS漏洞 登录类需求cookie伪造漏洞万能密码登录 持续更新中… 文章中代码资源已上传资源,如需要打包好的请点击PHP开发漏洞环境(SQL注…

缓存相关问题:雪崩、穿透、预热、更新、降级的深度解析

✨✨祝屏幕前的小伙伴们每天都有好运相伴左右✨✨ 🎈🎈作者主页: 喔的嘛呀🎈🎈 目录 引言 1. 缓存雪崩 1.1 问题描述 1.2 解决方案 1.2.1 加锁防止并发重建缓存 2. 缓存穿透 2.1 问题描述 2.2 解决方案 2.2.1 …

基于TL431和CSA的恒压与负压输出

Hello uu们,51去那里玩了呀?该收心回来上班了,嘿嘿! 为什么会有这个命题,因为我的手头只有这些东西如何去实现呢?让我们一起来看电路图吧.电路图如下图1所示 图1:CSA恒压输出电路 图1中,R1给U2提供偏置,Q1给R1提供电流,当U1-VOUT输出大于2.5V时候,U2内部的三极管CE导通,使得…

Golang | Leetcode Golang题解之第73题矩阵置零

题目&#xff1a; 题解&#xff1a; func setZeroes(matrix [][]int) {n, m : len(matrix), len(matrix[0])col0 : falsefor _, r : range matrix {if r[0] 0 {col0 true}for j : 1; j < m; j {if r[j] 0 {r[0] 0matrix[0][j] 0}}}for i : n - 1; i > 0; i-- {for …

Python的Web框架Flask+Vue生成漂亮的词云图

生成效果图 输入待生成词云图的文本&#xff0c;点击生成词云即可&#xff0c;在词云图生成之后&#xff0c;可以点击下载图片保存词云图。 运行步骤 分别用前端和后端编译器&#xff0c;打开backend和frontend文件夹。前端运行 npm install &#xff0c;安装相应的包。后端…

【prometheus】Pushgateway安装和使用

目录 一、Pushgateway概述 1.1 Pushgateway简介 1.2 Pushgateway优点 1.3 pushgateway缺点 二、测试环境 三、安装测试 3.1 pushgateway安装 3.2 prometheus添加pushgateway 3.3 推送指定的数据格式到pushgateway 1.添加单条数据 2.添加复杂数据 3.SDk-prometheus-…

Python深度学习基于Tensorflow(8)自然语言处理基础

RNN 模型 与前后顺序有关的数据称为序列数据&#xff0c;对于序列数据&#xff0c;我们可以使用循环神经网络进行处理&#xff0c;循环神经网络RNN已经成功的运用于自然语言处理&#xff0c;语音识别&#xff0c;图像标注&#xff0c;机器翻译等众多时序问题&#xff1b;RNN模…

16地标准化企业申请!安徽省工业和信息化领域标准化示范企业申报条件

安徽省工业和信息化领域标准化示范企业申报条件有哪些&#xff1f;合肥市 、黄山市 、芜湖市、马鞍山、安庆市、淮南市、阜阳市、淮北市、铜陵市、亳州市、宣城市、蚌埠市、六安市 、滁州市 、池州市、宿州市企业申报安徽省工业和信息化领域标准化示范企业有不明白的可在下文了…

《TAM》论文笔记(上)

原文链接 [2005.06803] TAM: Temporal Adaptive Module for Video Recognition (arxiv.org) 原文代码 GitHub - liu-zhy/temporal-adaptive-module: TAM: Temporal Adaptive Module for Video Recognition 原文笔记 What&#xff1a; TAM: Temporal Adaptive Module for …

JAVA系列:IO流

JAVA IO流 IO流图解 一、什么是IO流 I/O流是Java中用于执行输入和输出操作的抽象。它们被设计成类似于流水&#xff0c;可以在程序和外部源&#xff08;如文件、网络套接字、键盘、显示器等&#xff09;之间传输数据。按处理数据单位分为&#xff1a; 1字符 2字节 、 1字节(…

阿里发布通义千问2.5:一文带你读懂通义千问!

大家好&#xff0c;我是木易&#xff0c;一个持续关注AI领域的互联网技术产品经理&#xff0c;国内Top2本科&#xff0c;美国Top10 CS研究生&#xff0c;MBA。我坚信AI是普通人变强的“外挂”&#xff0c;所以创建了“AI信息Gap”这个公众号&#xff0c;专注于分享AI全维度知识…

【Linux系统编程】31.pthread_detach、线程属性

目录 pthread_detach 参数pthread 返回值 测试代码1 测试结果 pthread_attr_init 参数attr 返回值 pthread_attr_destroy 参数attr 返回值 pthread_attr_setdetachstate 参数attr 参数detachstate 返回值 测试代码2 测试结果 线程使用注意事项 pthread_deta…

AI智能分析高精度烟火算法EasyCVR视频方案助力打造森林防火建设

一、背景 随着夏季的来临&#xff0c;高温、干燥的天气条件使得火灾隐患显著增加&#xff0c;特别是对于广袤的森林地区来说&#xff0c;一旦发生火灾&#xff0c;后果将不堪设想。在这样的背景下&#xff0c;视频汇聚系统EasyCVR视频融合云平台AI智能分析在森林防火中发挥着至…

GeoServer 任意文件上传漏洞分析研究 CVE-2023-51444

目录 前言 漏洞信息 代码审计 漏洞复现 前言 时隔半月&#xff0c;我又再一次地审起了这个漏洞。第一次看到这个漏洞信息时&#xff0c;尝试复现了一下&#xff0c;结果却很不近人意。从官方公布的漏洞信息来看细节还是太少&#xff0c;poc不是一次就能利用成功的。也是当时…

AXI4读时序在AXI Block RAM (BRAM) IP核中的应用

在本文中将展示描述了AXI从设备&#xff08;slave&#xff09;AXI BRAM Controller IP核与Xilinx AXI Interconnect之间的读时序关系。 1 Single Read 图1展示了一个从32位BRAM&#xff08;Block RAM&#xff09;进行AXI单次读取操作的时序示例。 图1 AXI 单次读时序图 在该…

书生浦语训练营第三次课笔记:XTuner 微调 LLM:1.8B、多模态、Agent

Finetune 简介 两种Finetune范式&#xff1a;增量预训练微调、指令跟随微调 微调数据集 上述是我们所期待模型回答的内容&#xff0c;在训练时损失的计算也是基于这个。 训练数据集看起来是这样&#xff0c;但是真正喂给模型的&#xff0c;是经过对话模板组装后的 下图中&…

信息系统项目管理师0097:价值交付系统(6项目管理概论—6.4价值驱动的项目管理知识体系—6.4.6价值交付系统)

点击查看专栏目录 文章目录 6.4.6价值交付系统1.创造价值2.价值交付组件3.信息流6.4.6价值交付系统 价值交付系统描述了项目如何在系统内运作,为组织及其干系人创造价值。价值交付系统包括项目如何创造价值、价值交付组件和信息流。 1.创造价值 项目存在于组织中,包括政府机构…