3 文本分类入门finetune:bert-base-chinese

项目实战:

数据准备工作

        `bert-base-chinese` 是一种预训练的语言模型,基于 BERT(Bidirectional Encoder Representations from Transformers)架构,专门用于中文自然语言处理任务。BERT 是由 Google 在 2018 年提出的一种革命性的预训练模型,通过大规模的无监督训练,能够学习到丰富的语言表示。

        `bert-base-chinese` 是 BERT 在中文语料上进行预训练的版本,它包含了 12 层 Transformer 编码器和 110 万个参数。这个模型在中文文本上进行了大规模的预训练,可以用于各种中文自然语言处理任务,如文本分类、命名实体识别、情感分析等。

使用 `bert-base-chinese` 模型时,可以将其作为一个特征提取器,将输入的文本转换为固定长度的向量表示,然后将这些向量输入到其他机器学习模型中进行训练或推断。也可以对 `bert-base-chinese` 进行微调,将其用于特定任务的训练。

        预训练的 `bert-base-chinese` 模型可以通过 Hugging Face 的 Transformers 库进行加载和使用。在加载模型后,可以使用它的 `encode` 方法将文本转换为向量表示,或者使用 `forward` 方法对文本进行特定任务的预测。

        需要注意的是,`bert-base-chinese` 是一个通用的中文语言模型,但它可能在特定的任务上表现不佳。在某些情况下,可能需要使用更大的模型或进行微调来获得更好的性能。

进行微调时,可以按照以下步骤进行操作:

1. 准备数据集:首先,你需要准备一个与你的任务相关的标注数据集。这个数据集应该包含输入文本以及相应的标签或注释,用于训练和评估模型。

2. 加载预训练模型:使用 Hugging Face 的 Transformers 库加载预训练的 `bert-base-chinese` 模型。你可以选择加载整个模型或只加载其中的一部分,具体取决于你的任务需求。

3. 创建模型架构:根据你的任务需求,创建一个适当的模型架构。这通常包括在 `bert-base-chinese` 模型之上添加一些额外的层,用于适应特定的任务。

4. 数据预处理:将你的数据集转换为适合模型输入的格式。这可能包括将文本转换为输入的编码表示,进行分词、填充和截断等操作。

5. 定义损失函数和优化器:选择适当的损失函数来衡量模型预测与真实标签之间的差异,并选择合适的优化器来更新模型的参数。

6. 微调模型:使用训练集对模型进行训练。在每个训练步骤中,将输入文本提供给模型,计算损失并进行反向传播,然后使用优化器更新模型的参数。

7. 评估模型:使用验证集或测试集评估模型的性能。可以计算准确率、精确率、召回率等指标来评估模型在任务上的表现。

8. 调整和优化:根据评估结果,对模型进行调整和优化。你可以尝试不同的超参数设置、模型架构或训练策略,以获得更好的性能。

9. 推断和应用:在微调完成后,你可以使用微调后的模型进行推断和应用。将新的输入文本提供给模型,获取预测结果,并根据任务需求进行后续处理。

需要注意的是,微调的过程可能需要大量的计算资源和时间,并且需要对模型和数据进行仔细的调整和优化。此外,合适的数据集规模和质量对于获得良好的微调结果也非常重要。

准备模型

https://huggingface.co/bert-base-chinese/tree/main

数据集

与之前的训练数据一样使用;

代码部分

# 导入transformers
import transformers
from transformers import BertModel, BertTokenizer, AdamW, get_linear_schedule_with_warmup# 导入torch
import torch
from torch import nn, optim
from torch.utils.data import Dataset, DataLoader
import torch.nn.functional as F# 常用包
import re
import numpy as np
import pandas as pd
import seaborn as sns
from pylab import rcParams
import matplotlib.pyplot as plt
from matplotlib import rc
from sklearn.model_selection import train_test_split
from sklearn.metrics import confusion_matrix, classification_report
from collections import defaultdict
from textwrap import wrap
from tqdm import tqdm%matplotlib inline
%config InlineBackend.figure_format='retina' # 主题device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
devicePRE_TRAINED_MODEL_NAME = '../bert-base-chinese/' # 英文bert预训练模型
tokenizer = BertTokenizer.from_pretrained(PRE_TRAINED_MODEL_NAME)
BertModel.from_pretrained(PRE_TRAINED_MODEL_NAME)
tokenizerdf = pd.read_csv("../data/online_shopping_10_cats.csv")
#myDataset[0]
RANDOM_SEED = 1012
df_train, df_test = train_test_split(df, test_size=0.1, random_state=RANDOM_SEED)
df_val, df_test = train_test_split(df_test, test_size=0.5, random_state=RANDOM_SEED)
df_train.shape, df_val.shape, df_test.shapeclass MyDataSet(Dataset):def __init__(self,texts,labels,tokenizer,max_len):self.texts = textsself.labels = labelsself.tokenizer = tokenizerself.max_len = max_lendef __len__(self):return len(self.texts)def __getitem__(self,item):text = str(self.texts[item])label = self.labels[item]encoding = self.tokenizer(text=text,max_length=self.max_len,pad_to_max_length=True,add_special_tokens=True,return_attention_mask=True,return_token_type_ids=True,return_tensors='pt')return {"text":text,"input_ids":encoding['token_type_ids'].flatten(),"attention_mask":encoding['attention_mask'].flatten(),"labels":torch.tensor(label,dtype=torch.long)}
def create_data_loader(df,tokenizer,max_len,batch_size=4):ds = MyDataSet(texts=df['review'].values,labels=df['label'].values,tokenizer = tokenizer,max_len=max_len)return DataLoader(ds,batch_size=batch_size)
MAX_LEN = 512
BATCH_SIZE = 32
train_data_loader = create_data_loader(df_train,tokenizer,max_len=MAX_LEN, batch_size=BATCH_SIZE)
val_data_loader = create_data_loader(df_val,tokenizer,max_len=MAX_LEN, batch_size=BATCH_SIZE)
test_data_loader = create_data_loader(df_test,tokenizer,max_len=MAX_LEN, batch_size=BATCH_SIZE)class BaseBertModel(nn.Module):def __init__(self,n_class=2):super(BaseBertModel,self).__init__()self.bert = BertModel.from_pretrained(PRE_TRAINED_MODEL_NAME)self.drop = nn.Dropout(0.2)self.out = nn.Linear(self.bert.config.hidden_size, n_class)passdef forward(self,input_ids,attention_mask):_,pooled_output = self.bert(input_ids=input_ids,attention_mask=attention_mask,return_dict = False)out = self.drop(pooled_output)return self.out(out)#pooled_output
model = BaseBertModel()
model = model.to(device)EPOCHS = 10 # 训练轮数
optimizer = AdamW(model.parameters(),lr=2e-4,correct_bias=False)
total_steps = len(train_data_loader) * EPOCHS
scheduler = get_linear_schedule_with_warmup(optimizer,num_warmup_steps=0,num_training_steps=total_steps)
loss_fn = nn.CrossEntropyLoss().to(device)def train_epoch(model,data_loader,loss_fn,optimizer,device,schedule,n_examples):model = model.train()losses = []correct_predictions = 0count = 0for d in tqdm(data_loader):input_ids = d['input_ids'].to(device)attention_mask = d['attention_mask'].to(device)targets = d['labels'].to(device)outputs = model(input_ids=input_ids,attention_mask=attention_mask)_,preds = torch.max(outputs, dim=1)loss = loss_fn(outputs,targets)losses.append(loss.item())correct_predictions += torch.sum(preds==targets)loss.backward()nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)optimizer.step()scheduler.step()optimizer.zero_grad()count += 1# if count > 100:#     break#breakreturn correct_predictions.double() / n_examples, np.mean(losses)def eval_model(model, data_loader, loss_fn, device, n_examples):model = model.eval() # 验证预测模式losses = []correct_predictions = 0with torch.no_grad():for d in data_loader:input_ids = d["input_ids"].to(device)attention_mask = d["attention_mask"].to(device)targets = d["labels"].to(device)outputs = model(input_ids=input_ids,attention_mask=attention_mask)_, preds = torch.max(outputs, dim=1)loss = loss_fn(outputs, targets)correct_predictions += torch.sum(preds == targets)losses.append(loss.item())return correct_predictions.double() / n_examples, np.mean(losses)# train model
history = defaultdict(list) # 记录10轮loss和acc
best_accuracy = 0for epoch in range(EPOCHS):print(f'Epoch {epoch + 1}/{EPOCHS}')print('-' * 10)#(model,data_loader,loss_fn,device,schedule,n_exmaples)train_acc, train_loss = train_epoch(model,train_data_loader,loss_fn,optimizer,device,scheduler,len(df_train))print(f'Train loss {train_loss} accuracy {train_acc}')val_acc, val_loss = eval_model(model,val_data_loader,loss_fn,device,len(df_val))print(f'Val loss {val_loss} accuracy {val_acc}')print()history['train_acc'].append(train_acc)history['train_loss'].append(train_loss)history['val_acc'].append(val_acc)history['val_loss'].append(val_loss)if val_acc > best_accuracy:torch.save(model.state_dict(), 'best_model_state.bin')best_accuracy = val_acc# 模型评估
test_acc, _ = eval_model(model,test_data_loader,loss_fn,device,len(df_test)
)
test_acc.item()def get_predictions(model, data_loader):model = model.eval()texts = []predictions = []prediction_probs = []real_values = []with torch.no_grad():for d in data_loader:texts = d["texts"]input_ids = d["input_ids"].to(device)attention_mask = d["attention_mask"].to(device)targets = d["labels"].to(device)outputs = model(input_ids=input_ids,attention_mask=attention_mask)_, preds = torch.max(outputs, dim=1)probs = F.softmax(outputs, dim=1)texts.extend(texts)predictions.extend(preds)prediction_probs.extend(probs)real_values.extend(targets)predictions = torch.stack(predictions).cpu()prediction_probs = torch.stack(prediction_probs).cpu()real_values = torch.stack(real_values).cpu()return texts, predictions, prediction_probs, real_valuesy_texts, y_pred, y_pred_probs, y_test = get_predictions(model,test_data_loader
)
print(classification_report(y_test, y_pred, target_names=[str(label) for label in class_names]))# 模型预测sample_text='Hard but Robust, Easy but Sensitive: How Encod.'
encoded_text = tokenizer.encode_plus(sample_text,max_length=MAX_LEN,add_special_tokens=True,return_token_type_ids=False,pad_to_max_length=True,return_attention_mask=True,return_tensors='pt',
)input_ids = encoded_text['input_ids'].to(device)
attention_mask = encoded_text['attention_mask'].to(device)output = model(input_ids, attention_mask)
_, prediction = torch.max(output, dim=1)print(f'Sample text: {sample_text}')
print(f'Danger label  : {label_id2cate[prediction.cpu().numpy()[0]]}')

runing....

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/216813.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

迅腾文化助力企业品牌创新,加快增强品牌发展新动能

迅腾文化助力企业品牌创新,加快增强品牌发展新动能 随着市场竞争的日益激烈,品牌创新已成为企业持续发展的关键。为了在市场中脱颖而出,许多企业纷纷寻求外部合作伙伴以加快品牌发展。广州迅腾文化传播有限公司拥有13年品宣经验的企业&#…

关于Cython生成的so动态链接库逆向

来个引子:TPCTF的maze题目 如何生成这个so文件 为了研究逆向,我们先搞个例子感受一下生成so的整个过程,方便后续分析 创建对应python库文件 testso.py def test_add(a,b):a int(a)b int(b)return a bdef test_calc(li):for i in range…

JavaWeb笔记之MySQL数据库

#Author 流云 #Version 1.0 一、引言 1.1 现有的数据存储方式有哪些? Java程序存储数据(变量、对象、数组、集合),数据保存在内存中,属于瞬时状态存储。 文件(File)存储数据,保存…

【基于Flask、MySQL和Echarts的热门游戏数据可视化平台设计与实现】

基于Flask、MySQL和Echarts的热门游戏数据可视化平台设计与实现 前言数据获取与清洗数据集数据获取数据清洗 数据分析与可视化数据分析功能可视化功能 创新点结语 前言 随着游戏产业的蓬勃发展,了解游戏销售数据对于游戏从业者和游戏爱好者都至关重要。为了更好地分…

自动化补丁管理软件

什么是自动化补丁管理 自动补丁管理(或自动补丁)是指整个补丁管理过程的自动化,从扫描网络中的所有系统到检测缺失的补丁,在一组测试系统上测试补丁,将它们部署到所需的系统,并提供定期更新和补丁部署状态…

Duplicate keys detected: This may cause an update error.【Vue遍历渲染报错的解决】

今天在写项目时,写到一个嵌套评论的遍历时,控制台出现了一个报错信息,但是并不影响页面的渲染,然后一看这个错的原因是 key值重复,那么问题的解决方式就很简单了。(vue for循环读取key值时, key…

LLM Agent发展演进历史(观看metagpt视频笔记)

LLM相关的6篇重要的论文,其中4篇来自谷歌,2篇来自openai。技术路径演进大致是:SSL (Self-Supervised Learning) -> SFT (Supervised FineTune) IT (Instruction Tuning) -> RLHF。 word embedding的问题:新词如何处理&…

文档或书籍扫描为 PDF:ScanPapyrus Crack

ScanPapyrus 可让您快速轻松地将文档或书籍扫描为 PDF,批处理模式使扫描过程快速高效,自动处理书籍并将其拆分为单独的页面 用于快速扫描文档、书籍或打印照片的扫描仪软件 快速扫描文档 使用此扫描仪软件,您无需在扫描仪和计算机之间来回移动…

JavaEE 09 锁策略

1.锁策略 1.1 乐观锁与悲观锁 其实前三个锁是同一种锁,只是站在不同的角度上去进行描述,此处的乐观与悲观其实是指在预测的角度上看会发生锁竞争的概率大小,概率大的则是悲观锁,概率小的则是乐观锁 乐观锁在加锁的时候就会做较少的事情,加锁的速度较快,但是消耗的cpu资源等也会…

大数据机器学习与深度学习——过拟合、欠拟合及机器学习算法分类

大数据机器学习与深度学习——过拟合、欠拟合及机器学习算法分类 过拟合,欠拟合 针对模型的拟合,这里引入两个概念:过拟合,欠拟合。 过拟合:在机器学习任务中,我们通常将数据集分为两部分:训…

beebox靶场A3 low级别 xss通关教程(二)

六:xss get型 eval 通过观察我们可以发现url地址中存在一个date函数 那我们可以试一下把后面的date()函数去掉,直接写入一个alert(555) 发现直接弹出一个框,证明有xss漏洞 七:xss href 直接进入页面会看到是get方法&#xff0c…

【JVM从入门到实战】(五)类加载器

一、什么是类加载器 类加载器(ClassLoader)是Java虚拟机提供给应用程序去实现获取类和接口字节码数据的技术。 类加载器只参与加载过程中的字节码获取并加载到内存这一部分。 二、jdk8及之前的版本 类加载器分为三类: 启动类加载器-加载Ja…

Docker Compose入门:打造多容器应用的完美舞台

Docker Compose 是一个强大的工具,它允许开发者通过简单的 YAML 文件定义和管理多容器的应用。本文将深入讨论 Docker Compose 的基本概念、常用命令以及高级应用场景,并通过更为丰富和实际的示例代码,助您轻松掌握如何通过 Docker Compose 打…

VLAN协议与单臂路由

文章目录 VLAN协议与单臂路由一、VLAN的概念及优势1、分割广播域2、VLAN的优势3、VLAN数据帧 二、VLAN的种类1、静态VLAN2、动态VLAN3、VLAN划分方式 三、静态VLAN的配置1、VLAN的范围2、静态VLAN的配置2.1 配置静态VLAN的步骤2.2 vlan三种端口类型举例:配置静态VLA…

1688按关键字搜索工厂数据,商品详情页数据的采集

公共参数 名称类型必须描述keyString是调用key(必须以GET方式拼接在URL中,点击获取测试key和secret)secretString是调用密钥api_nameString是API接口名称(包括在请求地址中)[item_search,item_get,item_search_shop等]cacheStrin…

【稳定检索】2024年物理化学工程与应用力学国际会议(ICPCEAM 2024)

2024年物理化学工程与应用力学国际会议(ICPCEAM 2024) 2024 International Conference on Physical and Chemical Engineering and Applied Mechanics(ICPCEAM) 一、【会议简介】 2024年物理化学工程与应用力学国际会议(ICPCEAM 2024)将于2024年3月9日在中国上海盛大召开。本次…

SpringIOC之@EnableLoadTimeWeaving

博主介绍:✌全网粉丝5W+,全栈开发工程师,从事多年软件开发,在大厂呆过。持有软件中级、六级等证书。可提供微服务项目搭建与毕业项目实战,博主也曾写过优秀论文,查重率极低,在这方面有丰富的经验✌ 博主作品:《Java项目案例》主要基于SpringBoot+MyBatis/MyBatis-plus+…

论文阅读:PointCLIP: Point Cloud Understanding by CLIP

CVPR2022 链接:https://arxiv.org/pdf/2112.02413.pdf 0、Abstract 最近,通过对比视觉语言预训练(CLIP)的零镜头学习和少镜头学习在2D视觉识别方面表现出了鼓舞人心的表现,即学习在开放词汇设置下将图像与相应的文本匹配。然而,…

jdk+zookeeper+kafka 搭建kafka集群

环境准备 环境资源包: jdk-8u341-linux-x64.tar.gz kafka_2.12-2.2.0.tgz zookeeper-3.4.14.tar.gz server-idip状态server110.206.120.10leaderserver210.206.120.2followerserver310.206.120.3follower 一、安装jdk 因为kafka需要Java环境,所以优先…

Linux AMH服务器管理面板本地安装与远程访问

最近,我发现了一个超级强大的人工智能学习网站。它以通俗易懂的方式呈现复杂的概念,而且内容风趣幽默。我觉得它对大家可能会有所帮助,所以我在此分享。点击这里跳转到网站。 文章目录 1. Linux 安装AMH 面板2. 本地访问AMH 面板3. Linux安装…