使用 BERT 进行文本分类 (02/3)

一、说明

        在使用BERT(1)进行文本分类中,我向您展示了一个BERT如何标记文本的示例。在下面的文章中,让我们更深入地研究是否可以使用 BERT 来预测文本是使用 PyTorch 传达积极还是消极的情绪。首先,我们需要准备数据,以便使用 PyTorch 框架进行分析。

二、什么是 PyTorch

        PyTorch 是用于构建深度学习模型的框架,深度学习模型是一种机器学习,通常用于图像识别和语言处理等应用程序。它由Facebook的人工智能研究小组于2016年开发,由于其灵活性,易用性和动态计算图构建而广受欢迎。

        PyTorch 提供了一个基于 Python 的科学计算包,它使用图形处理单元 (GPU) 的强大功能来加速张量运算的计算。它具有简单直观的API,允许开发人员快速构建和训练深度学习模型。PyTorch 还支持自动微分,使用户能够计算任意函数的梯度。

三、准备我们的数据集

        首先,让我们从Github下载我们的数据。这里有一个关于如何从Github下载CSV文件的小提醒。只需继续并单击以下链接:

github.com

        然后,右键单击“原始”,然后左键单击“将链接文件下载为...”。您将看到“垃圾邮件.csv”并下载它。下载后,将其保存到您的首选文件夹中以供以后使用。

        现在,让我们导入数据。我们看到一条错误消息,告诉我们部分数据未采用 UTF-8 编码。

import pandas as pd
df = pd.read_csv("spam.csv")ERROR: 
UnicodeDecodeError: 'utf-8' codec can't decode bytes in position 606-607: invalid continuation byte

我们可以通过了解数据包含的字符编码并在读取数据时调用该编码来修复此错误。

# Use chardet to know the character encoding 
import chardet
with open("spam.csv", 'rb') as rawdata:result = chardet.detect(rawdata.read(100000))
resultOutput: 
{'encoding': 'Windows-1252', 'confidence': 0.7270322499829184, 'language': ''}

似乎我们的数据是在“Windows-1252”中编码的。那让我们再读一遍。它奏效了!

df = pd.read_csv("spam.csv", encoding = 'Windows-1252')
df.head()

        如我们所见,我们实际上并不需要“v1”和“v2”以外的列。此外,如果我们将“v1”和“v2”重命名为“类别”和“消息”,则更容易理解。

df = df.loc[:, ['v1', 'v2']]
df = df.rename(columns={'v1': 'Category', 'v2': 'Message'})
df.head()

        现在,我们应该看看我们的数据集,看看每个类别中有多少条消息。

df['Category'].value_counts()Output: 
ham     4825
spam     747
Name: Category, dtype: int64

四、创建平衡数据集

        事实证明,正常邮件比垃圾邮件多。构建机器学习模型时,如果数据集不平衡,其中一个类中的数据数量明显多于另一个类,则可能会对模型的性能产生各种影响。一些潜在的后果。例如:

-1 有偏差模型:如果数据集不平衡,模型可能会偏向多数类,而对少数类表现不佳。这是因为模型更有可能预测多数类,这将导致少数类的准确性较差。

-2 泛化不良:不平衡的数据集可能导致模型泛化不良。这是因为该模型将在不代表数据真实世界分布的数据集上进行训练,因此它可能无法很好地概括看不见的数据。

-3 评估不准确:如果使用准确性作为指标评估模型,则可能会产生误导性结果。例如,始终预测不平衡数据集中多数类的模型可能具有很高的准确性,但对少数类没有用。

-4 过拟合:由于数据点数量较多,模型可能会过度拟合多数类,从而导致测试数据的性能不佳。

为了解决这些问题,可以使用各种技术来平衡数据集,例如对少数类进行过采样,对多数类进行欠采样,或同时使用两者的组合。在这篇文章中,我将使用欠采样方法。

df_spam = df[df['Category']=='spam']
df_ham = df[df['Category']=='ham']
df_ham_downsampled = df_ham.sample(df_spam.shape[0])
df_balanced = pd.concat([df_ham_downsampled, df_spam])
df_balanced['Category'].value_counts()Output: 
ham     747
spam    747
Name: Category, dtype: int64

五、标记数据

        当数据表示为数字而不是分类为用于训练和测试的模型时,机器学习算法在准确性和其他性能指标方面表现更好。我们需要用数值对分类值进行标签编码。在这里,我们创建了一个新列“标签”,如果邮件是垃圾邮件,我们将其标记为 1,否则为 0。

df_balanced['Label']=df_balanced['Category'].apply(lambda x: 1 if x=='spam' else 0)
df_balanced = df_balanced.reset_index(drop=True)display(df_balanced)

由作者创建

六、训练、验证和测试数据集:谁是谁

        要记住的一件事是,当我们使用 train_test_split 库来训练模型时,我们实际上是将数据集拆分为 TRAINING 数据集和 VALIDATION 数据集,而不是 TRAINING 数据集和 TESTING 数据集。下面提醒一下这些数据集的含义。

  1. 训练集:用于构建我们的模型。我们将使用训练集来找到具有反向传播规则的“最佳”权重和偏差。在此阶段,我们通常会创建多个算法,以便在交叉验证阶段比较它们的性能。
  2. 交叉验证集:此数据集用于比较基于训练集创建的预测算法的性能。我们选择性能最佳的算法。
  3. 测试集:这是“未来”数据集。现在我们已经选择了我们喜欢的预测算法,但我们还不知道它将如何在完全看不见的真实世界数据上执行。因此,我们将我们选择的预测算法应用于我们的测试集,以查看它将如何执行,以便我们可以了解我们的算法在野外的性能。

        因此,在测试集中,我们没有数据的标签,而是使用我们的模型来预测标签。我们只能将手头的数据集拆分为训练集和验证集,因为我们还没有“未来”数据。

七、拆分为训练数据集和验证数据集

        现在我们了解了这三种类型的数据的真正含义,我们可以使用scikit-learn的train_test_split来拆分数据。

from sklearn.model_selection import train_test_split
X_train, X_val, y_train, y_val = train_test_split(df_balanced['Message'],df_balanced['Label'], stratify=df_balanced['Label'], test_size=.2)X_train.head()Output: 
708                      ;-) ok. I feel like john lennon.
1386    Cashbin.co.uk (Get lots of cash this weekend!)...
1492    REMINDER FROM O2: To get 2.50 pounds free call...
119     Back in brum! Thanks for putting us up and kee...
89                       Sorry, I can't help you on this.
Name: Message, dtype: object

八、总结

        我们已经学会了如何下载和拆分数据。在下一篇文章中,我们将首先对其进行标记,并使用DistilBERT训练分类器。达门·

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/92432.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

冉冉升起的星火,再度升级迎来2.0时代!

文章目录 前言权威性评测结果 星火大模型多模态功能插件功能简历生成文档问答PPT生成 代码能力 福利 前言 前几天从技术群里看到大家都在谈论《人工智能大模型体验报告2.0》里边的内容,抱着好奇和学习的态度把报告看了一遍。看完之后瞬间被里边提到的科大讯飞的星火…

PHP实现在线年龄计算器

1. 输入日期查询年龄 2. php laravel框架实现 代码 /*** 在线年龄计算器*/public function ageDateCal(){// 输入的生日时间$birthday $this->request(birthday);// 当前时间$currentDate date(Y-m-d);// 计算周岁$age date_diff(date_create($birthday), date_create($…

SQL Server基础之游标

一:认识游标 游标是SQL Server的一种数据访问机制,它允许用户访问单独的数据行。用户可以对每一行进行单独的处理,从而降低系统开销和潜在的阻隔情况,用户也可以使用这些数据生成的SQL代码并立即执行或输出。 1.游标的概念 游标是…

LiveDataBus 其中的一个库LiveEventBus库的源码解析

EventBus事件通知的框架我们用了很久了,随着LiveData的出现,出现了LiveDataBus来替代EventBus,因为LiveDataBus 会考虑生命周期,EventBus你可能要注意在生命周期结束的时候unregister的,否则会有内存泄漏等问题,而Live…

采用pycharm在虚拟环境使用pyinstaller打包python程序

一年多以前,我写过一篇博客描述了如何虚拟环境打包,这一次有所不同,直接用IDE pycharm构成虚拟环境并运行pyinstaller打包 之前的博文: 虚拟环境venu使用pyinstaller打包python程序_伊玛目的门徒的博客-CSDN博客 第一步&#xf…

【深入了解PyTorch】PyTorch模型解释性和可解释性:探索决策过程与预测结果的奥秘

【深入了解PyTorch】PyTorch模型解释性和可解释性:探索决策过程与预测结果的奥秘 PyTorch模型解释性和可解释性:探索决策过程与预测结果的奥秘1. 引言2. 梯度可视化3. 特征重要性分析4. 结论PyTorch模型解释性和可解释性:探索决策过程与预测结果的奥秘 在机器学习和深度学习…

临床试验设计-平行设计、析因设计、交叉设计

平行设计、析因设计、交叉设计是临床试验中最重要的三种设计方法。 平行设计:最常见 交叉设计:生物等效性试验 析因设计:药物配伍 一、平行设计 双臂试验(two-arm study)(两组):试验…

Steam 灵感的游戏卡悬停效果

先看效果&#xff1a; 再看代码&#xff08;查看更多&#xff09;&#xff1a; <!DOCTYPE html> <html lang"en"> <head><meta charset"UTF-8"><title>Steam 灵感的游戏卡悬停效果</title><style>* {margin: …

【IMX6ULL驱动开发学习】04.应用程序和驱动程序数据传输和交互的4种方式:非阻塞、阻塞、POLL、异步通知

一、数据传输 1.1 APP和驱动 APP和驱动之间的数据访问是不能通过直接访问对方的内存地址来操作的&#xff0c;这里涉及Linux系统中的MMU&#xff08;内存管理单元&#xff09;。在驱动程序中通过这两个函数来获得APP和传给APP数据&#xff1a; copy_to_usercopy_from_user …

【Docker】Docker network之bridge、host、none、container以及自定义网络的详细讲解

&#x1f680;欢迎来到本文&#x1f680; &#x1f349;个人简介&#xff1a;陈童学哦&#xff0c;目前学习C/C、算法、Python、Java等方向&#xff0c;一个正在慢慢前行的普通人。 &#x1f3c0;系列专栏&#xff1a;陈童学的日记 &#x1f4a1;其他专栏&#xff1a;CSTL&…

分布式定时任务系列5:XXL-job中blockingQueue的应用

传送门 分布式定时任务系列1&#xff1a;XXL-job安装 分布式定时任务系列2&#xff1a;XXL-job使用 分布式定时任务系列3&#xff1a;任务执行引擎设计 分布式定时任务系列4&#xff1a;任务执行引擎设计续 Java并发编程实战1&#xff1a;java中的阻塞队列 引子 这篇文章的…

二.net core 自动化发布到docker (Jenkins安装之后向导)

目录 ​​​​​​​​​​​​​​ 参考资料&#xff1a;https://www.jenkins.io/doc/book/installing/docker/#setup-wizard Post-installation setup wizard.(安装后安装向导) 基于上一篇文章安装&#xff0c;在安装并运行Jenkins&#xff08;不包括使用Jenkins Opera…

日志采集分析ELK

这里的 ELK其实对应三种不同组件 1.ElasticSearch&#xff1a;基于Java&#xff0c;一个开源的分布式搜索引擎。 2.LogStash&#xff1a;基于Java&#xff0c;开源的用于收集&#xff0c;分析和存储日志的工具。&#xff08;它和Beats有重叠的功能&#xff0c;Beats出现之后&a…

6.3 社会工程学攻击

数据参考&#xff1a;CISP官方 目录 社会工程学攻击概念社会工程学攻击利用的人性 “弱点”典型社会工程学攻击方式社会工程学攻击防护 一、社会工程学攻击概念 什么是社会工程学攻击 也被称为 "社交工程学" 攻击利用人性弱点 (本能反应、贪婪、易于信任等) 进…

Reis过期删除策略

介绍 在Redis中&#xff0c;我们可以为键值对设置有效期&#xff0c;现在面临一个问题&#xff0c;如果一个键值对过期了&#xff0c;那么我们应该怎么删除呢&#xff1f; 我们目前有三种方案&#xff1a; 定时删除&#xff1a;在设置键的过期时间的同时&#xff0c;为此键设…

小游戏扫雷实现教学(详解)

目录 【前言】 一、模块化程序设计&#xff08;多文件编程&#xff09;介绍 1.概述 2.传统编程的方式 3.模块化程序设计的方法 二、扫雷代码设计思路 三、扫雷代码设计 1.创建菜单函数 2.实现9x9扫雷 3.初始化棋盘 4.打印棋盘 5.随机布置雷的位置 6.排查雷的信息 7.回…

【Apollo】推动创新:探索阿波罗自动驾驶的进步(含安装 Apollo的详细教程)

前言 Apollo (阿波罗)是一个开放的、完整的、安全的平台&#xff0c;将帮助汽车行业及自动驾驶领域的合作伙伴结合车辆和硬件系统&#xff0c;快速搭建一套属于自己的自动驾驶系统。 开放能力、共享资源、加速创新、持续共赢是 Apollo 开放平台的口号。百度把自己所拥有的强大、…

2023年7月京东冰箱行业品牌销售排行榜(京东运营数据分析)

作为日常使用的大家电之一&#xff0c;如今我国冰箱产业已渐趋饱满&#xff0c;市场增长有限。今年上半年&#xff0c;冰箱市场整体销额同比去年来看勉强保持小幅增长。不过&#xff0c;7月份&#xff0c;冰箱大盘的销售表现就略显萧条了。 根据鲸参谋电商数据分析平台的相关数…

NPM与外部服务的集成(下)

目录 1、撤消访问令牌 2、在CI/CD工作流中使用私有包 2.1 创建新的访问令牌 持续整合 持续部署 交互式工作流 CIDR白名单 2.2 将令牌设置为CI/CD服务器上的环境变量 2.3 创建并签入特定于项目的.npmrc文件 2.4 令牌安全 3、Docker和私有模块 3.1 背景&#xff1a;运…

Android的学习系列之Android Studio Setup安装

Android的学习系列之Android Studio Setup安装 [TOC](Android的学习系列之Android Studio Setup安装) 前言Android平台搭建总结 前言 还是项目需要&#xff0c;暂时搭建安卓的运行平台。 Android平台搭建 安装包 双击安装包&#xff0c;进入安装。 下一步 根据自己需求&a…