vllm 部署GLM4模型进行 Zero-Shot 文本分类实验,让大模型给出分类原因,准确率可提高6%

文章目录

    • 简介
    • 数据集
    • 实验设置
    • 数据集转换
    • 模型推理
    • 评估

简介

本文记录了使用 vllm 部署 GLM4-9B-Chat 模型进行 Zero-Shot 文本分类的实验过程与结果。通过对 AG_News 数据集的测试,研究发现大模型在直接进行分类时的准确率为 77%。然而,让模型给出分类原因描述(reason)后,准确率显著提升至 83%,提升幅度达 6%。这一结果验证了引入 reasoning 机制的有效性。文中详细介绍了实验数据、提示词设计、模型推理方法及评估手段。

复现自这篇论文:Text Classification via Large Language Models. https://arxiv.org/abs/2305.08377 让大模型使用reason。

该项目的文件结构如下所示:

├── cls_vllm.log
├── cls_vllm.py
├── data
│   ├── basic_llm.csv
│   └── reason_llm.csv
├── data_processon.ipynb
├── eval.ipynb
├── output
│   ├── basic_vllm.pkl
│   └── reason_vllm.pkl
├── settings.py
└── utils.py

数据集

现在要找一个数据集做实验,进入 https://paperswithcode.com/。
找到 文本分类,看目前的 SOTA 是在哪些数据集上做的,文本分类. https://paperswithcode.com/task/text-classification

在这里插入图片描述

实验使用了 AG_News 数据集。若您对数据集操作技巧感兴趣,可以参考这篇文章:

datasets库一些基本方法:filter、map、select等. https://blog.csdn.net/sjxgghg/article/details/141384131

实验设置

settings.py 文件中,我们定义了一些实验中使用的提示词:

LABEL_NAMES = ['World', 'Sports', 'Business', 'Science | Technology']BASIC_CLS_PROMPT = """
你是文本分类专家,请你给下述文本分类,把它分到下述类别中:
* World
* Sports
* Business
* Science | Technologytext是待分类的文本。请你一步一步思考,在label中给出最终的分类结果:
text: {text}
label: 
"""REASON_CLS_PROMPT = """
你是文本分类专家,请你给下述文本分类,把它分到下述类别中:
* World
* Sports
* Business
* Science | Technologytext是待分类的文本。请你一步一步思考,首先在reason中说明你的判断理由,然后在label中给出最终的分类结果:
text: {text}
reason: 
label: 
""".lstrip()data_files = ["data/basic_llm.csv","data/reason_llm.csv"
]output_dirs = ["output/basic_vllm.pkl","output/reason_vllm.pkl"
]

这两个数据文件用于存储不同提示词的大模型推理数据:

  • data/basic_llm.csv
  • data/reason_llm.csv

数据集转换

为了让模型能够执行文本分类任务,我们需要对原始数据集进行转换,添加提示词。

原始的数据集样式,要经过提示词转换后,才能让模型做文本分类。

代码如下:

data_processon.ipynb

from datasets import load_datasetfrom settings import LABEL_NAMES, BASIC_CLS_PROMPT, REASON_CLS_PROMPT, data_filesimport os
os.environ['HTTP_PROXY'] = 'http://127.0.0.1:7890'
os.environ['HTTPS_PROXY'] = 'http://127.0.0.1:7890'# 加载 AG_News 数据集的测试集,只使用test的数据去预测
ds = load_dataset("fancyzhx/ag_news")# 转换为 basic 提示词格式
def trans2llm(item):item["text"] = BASIC_CLS_PROMPT.format(text=item["text"])return item
ds["test"].map(trans2llm).to_csv(data_files[0], index=False)# 转换为 reason 提示词格式
def trans2llm(item):item["text"] = REASON_CLS_PROMPT.format(text=item["text"])return item
ds["test"].map(trans2llm).to_csv(data_files[1], index=False)

上述代码实现的功能就是把数据集的文本,放入到提示词的{text} 里面。

模型推理

本文使用 ZhipuAI/glm-4-9b-chat. https://www.modelscope.cn/models/zhipuai/glm-4-9b-chat 智谱9B的chat模型,进行VLLM推理。

为了简化模型调用,我们编写了一些实用工具:

utils.py

import pickle
from transformers import AutoTokenizer
from vllm import LLM, SamplingParams
from modelscope import snapshot_downloaddef save_obj(obj, name):"""将对象保存到文件:param obj: 要保存的对象:param name: 文件的名称(包括路径)"""with open(name, "wb") as f:pickle.dump(obj, f, pickle.HIGHEST_PROTOCOL)def load_obj(name):"""从文件加载对象:param name: 文件的名称(包括路径):return: 反序列化后的对象"""with open(name, "rb") as f:return pickle.load(f)def glm4_vllm(prompts, output_dir, temperature=0, max_tokens=1024):# GLM-4-9B-Chat-1Mmax_model_len, tp_size = 131072, 1model_dir = snapshot_download('ZhipuAI/glm-4-9b-chat')tokenizer = AutoTokenizer.from_pretrained(model_dir, trust_remote_code=True)llm = LLM(model=model_dir,tensor_parallel_size=tp_size,max_model_len=max_model_len,trust_remote_code=True,enforce_eager=True,)stop_token_ids = [151329, 151336, 151338]sampling_params = SamplingParams(temperature=temperature, max_tokens=max_tokens, stop_token_ids=stop_token_ids)inputs = tokenizer.apply_chat_template(prompts, tokenize=False, add_generation_prompt=True)outputs = llm.generate(prompts=inputs, sampling_params=sampling_params)save_obj(outputs, output_dir)

glm4_vllm :

  • 参考自 https://www.modelscope.cn/models/zhipuai/glm-4-9b-chat

    给大家封装好了,以后有任务,直接调用函数

save_obj:

  • 把python对象,序列化保存到本地;

    在本项目中,用来保存 vllm 推理的结果;

模型推理代码
cls_vllm.py

from datasets import load_datasetfrom utils import glm4_vllm
from settings import data_files, output_dirs# basic 预测
basic_dataset = load_dataset("csv",data_files=data_files[0],split="train",
)
prompts = []
for item in basic_dataset:prompts.append([{"role": "user", "content": item["text"]}])
glm4_vllm(prompts, output_dirs[0])# reason 预测,添加了原因说明
reason_dataset = load_dataset("csv",data_files=data_files[1],split="train",
)
prompts = []
for item in reason_dataset:prompts.append([{"role": "user", "content": item["text"]}])
glm4_vllm(prompts, output_dirs[1])# nohup python cls_vllm.py > cls_vllm.log 2>&1 &

在推理过程中,我们使用了 glm4_vllm 函数进行模型推理,并将结果保存到指定路径。

output_dirs: 最终推理完成的结果输出路径;

评估

在获得模型推理结果后,我们需要对其进行评估,以衡量分类的准确性。

eval.ipynb

from settings import LABEL_NAMES
from utils import load_objfrom datasets import load_dataset
from settings import data_files, output_dirsimport os
os.environ['HTTP_PROXY'] = 'http://127.0.0.1:7890'
os.environ['HTTPS_PROXY'] = 'http://127.0.0.1:7890'ds = load_dataset("fancyzhx/ag_news")
def eval(raw_dataset, vllm_predict):right = 0 # 预测正确的数量multi_label = 0 # 预测多标签的数量for data, output in zip(raw_dataset, vllm_predict):true_label = LABEL_NAMES[data['label']]output_text = output.outputs[0].textpred_label = output_text.split("label")[-1]tmp_pred = []for label in LABEL_NAMES:if label in pred_label:tmp_pred.append(label)if len(tmp_pred) > 1:multi_label += 1if " ".join(tmp_pred) == true_label:right += 1return right, multi_label

我们分别对 basic 和 reason 预测结果进行了评估。

basic 预测结果的评估 :

dataset = load_dataset('csv', data_files=data_files[0], split='train')
output = load_obj(output_dirs[0])eval(dataset, output)

输出结果:

(5845, 143)

加了reason 预测结果评估:

dataset = load_dataset('csv', data_files=data_files[1], split='train')
output = load_obj(output_dirs[1])eval(dataset, output)

输出结果:

(6293, 14)

评估结果如下:

  • basic: 直接分类准确率为 77%(5845/7600),误分类为多标签的样本有 143 个。
  • reason: 在输出原因后分类准确率提高至 83%(6293/7600),多标签误分类样本减少至 14 个。

误分类多标签: 这是单分类问题,大模型应该只输出一个类别,但是它输出了多个类别;

可以发现,让大模型输出reason,不仅分类准确率提升了5%,而且在误分类多标签的数量也有所下降。
原先误分类多标签有143条数据,使用reason后,多标签误分类的数量降低到了14条。

这些结果表明,让模型输出 reason的过程,确实能够有效提升分类准确性,还能减少误分类多个标签。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/407864.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【软件测试面试题】WEB功能测试(持续更新)

Hi,大家好,我是小码哥。最近很多朋友都在说今年的互联网行情不好,面试很难,不知道怎么复习,我最近总结了一份在软件测试面试中比较常见的WEB功能测试面试面试题合集,希望对大家有帮助。 建议点赞收藏再阅读…

AI学习记录 - 怎么理解 torch 的 nn.Conv2d

有用就点个赞 怎么理解 nn.Conv2d 参数 conv_layer nn.Conv2d(in_channels1, out_channels 10 // 2, kernel_size3, stride2, padding0, biasFalse) in_channels in_channels 可以设置成1,2,3,4等等都可以,一般来说做图像识别…

微服务案例搭建

目录 一、案例搭建 1.数据库表 2.服务模块 二、具体代码实现如下: (1) 首先是大体框架为: (2)父模块中的pom文件配置 (3)shop_common模块,这个模块里面只需要配置pom.xml,与实体…

MySQL如何判断一个字段里面是否包含汉字

SQL查询中,length() 和 char_length() 都是用来获取字符串长度的函数 在单字节字符集下(如ASCII):每个字符通常占用1个字节,因此length()和char_length()在这类字符集中给出的结果是一样 在多字节字符集下&#xff0…

matplotlib绘制子图以及局部放大效果

需求:绘制1*2的子图,子图1显示两个三角函数,子图2显示三个对数函数,子图2中对指定的区域进行放大。 绘图细节: 每个子图中每个函数的数据存放到一个列表中,然后将每个子图的数据统一存到一个列表中&#…

Go 使用Redis安装、实例和基本操作

Go使用Redis:详解go-redis/v9库 引言 Redis作为一个高性能的键值对数据库,广泛应用于缓存、消息队列、实时数据分析等场景。在Go语言中,go-redis/v9库提供了丰富的接口和高效的数据交互能力,使得在Go项目中集成Redis变得简单而高…

接口限流经典算法

文章目录 限流基于计数器的限流基于滑动窗口的限流桶漏斗算法令牌桶算法 限流 为了保证系统的安全性和稳定性,防止恶意流量和突发大量流量短时间内大量请求接口,造成服务器崩溃,接口的限流是有必要的。 以下是四种经典的限流算法。 基于计数…

Python测试框架Pytest的使用

pytest基础功能 pytset功能及使用示例1.assert断言2.参数化3.运行参数4.生成测试报告5.获取帮助6.控制用例的执行7.多进程运行用例8.通过标记表达式执行用例9.重新运行失败的用例10.setup和teardown函数 pytset功能及使用示例 1.assert断言 借助python的运算符号和关键字实现不…

UE5打包iOS运行查看Crash日志

1、查看Crash 1、通过xCode打开设备 2、选择APP打开最近的日志 3、选择崩溃时间点对应的日志 4、选择对应的工程打开 5、就能看到对应的Crash日志 2、为了防止Crash写代码需要注意 1、UObject在RemoveFromRoot之前先判断是否Root if (SelectedImage && Selecte…

Frog4Shell — FritzFrog 僵尸网络将一日攻击纳入其武器库

FritzFrog 的背景 Akamai 通过我们的全球传感器网络持续监控威胁,包括我们之前发现的威胁。其中包括FritzFrog 僵尸网络(最初于 2020 年发现),这是一个基于 Golang 的复杂点对点僵尸网络,经过编译可同时支持基于 AMD 和 ARM 的机器。该恶意软件得到积极维护,多年来通过增…

百日筑基第六十天-学习一下Tomcat

百日筑基第六十天-学习一下Tomcat 一、Tomcat 顶层架构 Tomcat 中最顶层的容器是 Server,代表着整个服务器,从上图中可以看出,一个 Server可以包含至少一个 Service,用于具体提供服务。Service 主要包含两个部分:Conn…

AI周报(8.18-8.24)

AI应用-XGO-Rider: 全球首款轮腿式桌面 AI 机器人 中国的 Luwu 智能打造的XGO-Rider 是全球首款轮腿式桌面 AI 机器人。这个小巧紧凑的机器人将轮式机器人的灵活性与腿式机器人的障碍处理能力相结合,可以全方位移动,轻松适应各种地形。 XGO-Rider 主要设…

服务商模式实现JSAPI小程序微信支付(javaphp)

官方文档 https://pay.weixin.qq.com/wiki/doc/apiv3_partner/open/pay/chapter2_1.shtml 使用wechatpay-php实现JSAPI支付(服务商和普通商户)文章浏览阅读1.3k次,点赞3次,收藏7次。之前我使用的sdk是“wechatpay-guzzle-middle…

python实用教程(二):安装配置Pycharm及使用(Win10)

上一篇:python实用教程(一):安装配置anaconda(Win10)-CSDN博客 1、简介及下载 PyCharm是一款功能强大的 Python 编辑器,具有跨平台性。是Jetbrains家族中的一个明星产品。 下载地址&#xff…

redis实战——go-redis的使用与redis基础数据类型的使用场景(二)

一.go-redis操作hash 常用命令: redisClient.HSet("map", "name", "jack") // 批量设置 redisClient.HMSet("map", map[string]interface{}{"a": "b", "c": "d", "e"…

计算机毕业设计选题推荐-游戏比赛网上售票系统-Java/Python项目实战

✨作者主页:IT研究室✨ 个人简介:曾从事计算机专业培训教学,擅长Java、Python、微信小程序、Golang、安卓Android等项目实战。接项目定制开发、代码讲解、答辩教学、文档编写、降重等。 ☑文末获取源码☑ 精彩专栏推荐⬇⬇⬇ Java项目 Python…

棚子影院CMS程序PHP源码

01, 棚子影视是我现在最常用的一个看视频的网站,支持观看电影、国漫!动漫,电视剧、综艺、记录片、香港剧等等。同时棚子影视支持手机,PC端在线观看,不用下载任何播放器,直接电脑或者手机打开网址就可以在线…

vue3 RouterLink路由跳转后RouterView组件未加载,页面未显示,且控制台无任何报错

在使用 vue3 开发项目过程中,组件之间使用 router-link 跳转,但是当我开发的组件跳转到其他组件时,其他组件的页面未加载,再跳转回自己的组件时,自己的组件也加载不出来了,浏览器刷新后页面可以加载出来。但…

结合 curl 与住宅代理实现高效数据抓取

引言 什么是 curl?有哪些功能? 基本 curl 命令有哪些? 为什么要使用 curl 处理 HTTP 请求? 如何使用 curl 和住宅代理进行网络抓取? 总结 引言 在当今数据驱动的商业环境中,数据的获取和分析能力是企…

Redis | 非关系型数据库Redis的初步认识

本节内容相对理论,着重看基础通用命令这一节 Redis 非关 kv型{字典} 概念应用ubuntu安装配置 windows添加密码 可能问题【ubuntu】远程连接 基础通用命令 ⭐ 概念 特点: 1、开源的,使用C编写,基于内存且支持持久化 2、没有表 支持…