微调大模型时,如何进行数据预处理? 将<input, output>转换为模型所需的<input_ids, labels, attention_mask>

原始训练数据集格式如下:

<input, output>

形式:字符

模型训练所需数据格式如下:

# tokenizer处理后
return {"input_ids": example,"labels": labels,"attention_mask": example_mask,
}

将字符转换为id,生成三个部分。

讲解

在大模型训练中,input_idslabelsattention_mask 是标准的数据格式,用于表示输入文本和相关的辅助信息。这些数据是模型训练中必需的,尤其是在自然语言处理(NLP)任务中。下面是这些字段的详细解释:

1. input_ids

  • 含义input_ids 是输入文本的 Token IDs(即词汇表中的索引)。
  • 用途:它们是模型输入的实际内容。模型通常需要将文本转换为数字形式来处理,input_ids 就是将每个词(或子词)映射到一个整数,这个整数表示该词在预训练模型的词汇表中的位置。
  • 例子
    • 假设输入文本是:"I love AI."
    • 经过分词和映射后,可能得到 input_ids = [101, 1045, 2293, 9474, 102],其中:
      • 101102 是特殊的标记(如开始和结束标记)。
      • 1045 是词 "I" 对应的 ID。
      • 2293 是词 "love" 对应的 ID,依此类推。
  • 注意:每个模型都有自己独特的词汇表,因此同样的文本在不同模型中得到的 input_ids 可能不同。

2. labels

  • 含义labels 是模型输出的目标数据,通常用于监督学习任务中的目标标签。
  • 用途labels 用于计算损失函数(比如交叉熵损失)来优化模型。它通常与 input_ids 具有相同的格式,用于生成目标预测。在一些任务中,labels 是输入文本的某种变换,比如机器翻译的目标句子,或者是文本分类任务中的标签。
  • 例子
    • 对于语言模型的训练,labels 通常与 input_ids 相同,代表的是下一个词的预测。即,模型在每个位置预测下一个词的 input_id
    • 在问答任务中,labels 可能是模型应当输出的答案。
    • 对于分类任务,labels 可以是一个整数值,表示文本的类别(如 012 等)。

语言模型的例子

  • 输入句子:"I love AI."
  • input_ids = [101, 1045, 2293, 9474, 102]
  • labels 对于语言模型任务,可能和 input_ids 一样:
    • labels = [1045, 2293, 9474, 102, -100]
    • 这里,-100 是一个特殊的标记,用于掩盖某些位置的标签,避免在计算损失时对某些位置进行更新。

3. attention_mask

  • 含义attention_mask 是一个与 input_ids 等长的向量,指示模型在哪些位置需要关注(即哪些位置是有效的)以及哪些位置应忽略(即哪些位置是填充的)。
  • 用途:在许多 NLP 模型中,文本可能需要进行填充(padding),特别是在处理不同长度的文本时。attention_mask 用于指示模型应该计算注意力的哪些位置,而哪些位置是填充,应该被忽略。
    • 1 表示该位置是有效的,模型应该关注这个位置。
    • 0 表示该位置是填充,模型应该忽略这个位置。
  • 例子
    • 假设输入文本是两句话,分别是:"I love AI." 和 "I enjoy machine learning."
    • input_ids[101, 1045, 2293, 9474, 102, 101, 1045, 2829, 4974, 102]
    • 由于句子长度不同,第二句需要填充。假设最大长度是 10,所以填充的位置会使用 0,并且 attention_mask 会标明哪些位置需要关注:
    • attention_mask = [1, 1, 1, 1, 1, 1, 1, 1, 1, 0]
    • 这里,1 表示模型应该关注的位置,0 表示填充的位置。

总结:

这些字段在训练时的意义:

  • input_ids:模型的输入序列,表示文本的词汇索引。
  • labels:目标输出序列,通常与 input_ids 相同,用于计算损失函数。
  • attention_mask:指示哪些位置是有效的,哪些是填充,帮助模型避免处理填充部分的数据。

这种格式在 NLP 任务中广泛使用,尤其是在语言模型训练(如 GPT、BERT 等)和序列到序列任务(如机器翻译、问答系统)中。

在大多数情况下,input_ids不包含 训练数据中的真实输出(如目标标签)。input_ids 主要用于表示输入序列,即模型的输入,而真实的输出通常会在 labels 中提供。我们可以通过具体的任务来更好地理解它们之间的关系。

input_ids 和 labels

1. 语言建模任务(如 GPT)

对于语言建模任务,input_idslabels 是非常相似的,甚至有时完全相同。这是因为语言模型的任务是根据前面的上下文预测下一个词。因此,模型的输入(input_ids)和目标输出(labels)是相同的,且 labels 的每个位置都表示目标词。

举例

  • 输入句子: "I love AI"

  • 假设使用的词汇表索引(简化表示):I -> 1045, love -> 2293, AI -> 9474

  • input_ids = [1045, 2293, 9474](这就是模型的输入)

  • labels = [2293, 9474, -100]-100 是一个占位符,表示忽略该位置)

在语言模型任务中,模型的目标是预测每个位置的下一个词。所以 labels 会从第一个词开始,包含实际的目标词。input_idslabels 对于每个位置来说,在训练时是同步的,只是模型预测的是 下一个 词。

2. 序列标注任务(如命名实体识别 NER 或 POS 标注)

在序列标注任务中,input_ids 仍然是输入序列的表示,但 labels 是每个输入单词或标记的标签。这时,input_ids 仅包含输入文本的词汇索引,而真实标签(比如实体类别或词性标签)则在 labels 中。

举例

  • 输入文本: "I love AI"

  • 目标标签(假设为命名实体识别任务):I -> O, love -> O, AI -> B-ORG

  • input_ids = [1045, 2293, 9474](表示输入的文本)

  • labels = [0, 0, 1](表示每个词的标签,其中 0 是普通词,1 表示 "AI" 是一个组织实体)

在这种任务中,input_idslabels 是不同的,input_ids 表示输入,而 labels 表示这些输入对应的标签。

3. 序列到序列任务(如机器翻译或文本生成)

在机器翻译或文本生成任务中,input_idslabels 也会有明显的区别:

  • input_ids 是源语言的文本表示。
  • labels 是目标语言的文本表示。

例如,在翻译任务中:

  • 输入文本(源语言): "I love AI"

  • 目标文本(目标语言): "J'aime l'IA"

  • input_ids = [1045, 2293, 9474](表示源语言)

  • labels = [2013, 1244, 1849](表示目标语言)

在这种情况下,input_idslabels 完全不同,因为它们分别表示源语言和目标语言。

4. 总结:

  • input_ids:表示模型的输入文本的 token 化结果,通常是对文本进行分词、编码后的词汇索引。它仅包含输入数据,不包含目标输出。
  • labels:表示模型的目标输出,通常用于训练期间计算损失。在语言建模任务中,labels 可能和 input_ids 一样;但在其他任务(如分类、序列标注、机器翻译等)中,labelsinput_ids 完全不同,表示模型应该生成或预测的目标结果。

因此,input_ids 不包含 训练数据中的真实输出,而 labels 才是训练时用来计算损失和评估模型性能的目标值。

处理代码

crop_train.json训练数据集格式如下:

{"instruction": "你是农作物领域专门进行关系抽取的专家。请从给定的文本中抽取出关系三元组,不存在的关系返回空列表。请按照JSON字符串的格式回答。","input": "煤是一种常见的化石燃料,家庭用煤经过了从\"煤球\"到\"蜂窝煤\"的演变。","output": "[{\"head\": \"煤\", \"relation\": \"use\", \"tail\": \"燃料\"}]"
},

数据预处理代码如下: 

import json
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForCausalLM, DataCollatorForSeq2Seq, TrainingArguments, Trainerdataset = load_dataset("json", data_files="./crop_train.json", split="train")
print(f"dataset: {dataset}")tokenizer = AutoTokenizer.from_pretrained("./glm-4-9b-chat", trust_remote_code=True)
print(f"tokenizer: {tokenizer}")def process_func(example):MAX_LENGTH = 256input_ids, attention_mask, labels = [], [], []# 合并example的instruction和input字段为一个字符串instruction = f"{example['instruction']} {example['input']}".strip()  # queryinstruction = tokenizer.apply_chat_template([{"role": "user", "content": instruction}],add_generation_prompt=True,tokenize=True,return_tensors="pt",return_dict=True)  # '[gMASK] <sop> <|user|> \nquery <|assistant|>'# 检查example["output"]是否是列表,并相应地处理if isinstance(example["output"], list):response_text = "\n".join(example["output"])else:response_text = "\n" + example["output"]response = tokenizer(response_text, add_special_tokens=False)  # \n response, 缺少eos token# input_ids = input + outputinput_ids = instruction["input_ids"][0].numpy().tolist() + response["input_ids"] + [tokenizer.eos_token_id]attention_mask = instruction["attention_mask"][0].numpy().tolist() + response["attention_mask"] + [1]# labels = input(-100) + outputlabels = [-100] * len(instruction["input_ids"][0].numpy().tolist()) + response["input_ids"] + [tokenizer.eos_token_id]if len(input_ids) > MAX_LENGTH:input_ids = input_ids[:MAX_LENGTH]attention_mask = attention_mask[:MAX_LENGTH]labels = labels[:MAX_LENGTH]return {"input_ids": input_ids,"attention_mask": attention_mask,"labels": labels}# 训练数据集经过预处理后生成<input_ids, labels, attention_mask>
tokenized_ds = dataset.map(process_func, remove_columns=['instruction', 'input', 'output'])print(f"All tokenizer tokens ids: {tokenized_ds}")     # features: ['input_ids', 'attention_mask', 'labels'],# tokenized_ds: 包含input_ids, attention_mask, labels = [], [], []
input_ids_1 = tokenized_ds[0]["input_ids"]
attention_mask_1 = tokenized_ds[0]["attention_mask"]
labels_1 = tokenized_ds[0]["labels"]
print(f"input_ids_1: {input_ids_1}")
print(f"attention_mask_1: {attention_mask_1}")
print(f"labels_1: {labels_1}")input_text_1 = tokenizer.decode(input_ids_1)
print(f"input_ids_1_decode: {input_text_1}")

tokenized_ds:里面包含所有的训练数据经过转换后的 ['input_ids', 'attention_mask', 'labels']集合,用于model直接使用来训练。

模型训练所需数据格式如下: 

input_ids_1: [151331, 151333, 151336, 198, 103408, 112687, 99788, 102014, 98638, 99172, 115023, 98314, 100153, 1773, 98964, 98484, 98602, 100966, 103231, 98322, 100319, 107325, 99172, 120673, 98555, 3837, 107399, 102189, 104559, 98745, 106522, 1773, 98964, 99928, 5370, 121478, 98314, 104714, 99770, 1773, 10231, 227, 97, 100375, 104250, 112075, 106512, 3837, 99716, 98340, 100855, 114094, 98484, 1, 100855, 98781, 1, 98344, 1, 125272, 100855, 1, 98314, 110001, 1773, 151337, 198, 58, 4913, 1983, 788, 330, 100855, 497, 330, 22166, 788, 330, 810, 497, 330, 14576, 788, 330, 106512, 9204, 60, 151329]
attention_mask_1: [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]
labels_1: [-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 198, 58, 4913, 1983, 788, 330, 100855, 497, 330, 22166, 788, 330, 810, 497, 330, 14576, 788, 330, 106512, 9204, 60, 151329]input_ids_1_decode: 
[gMASK] <sop> <|user|> 
你是农作物领域专门进行关系抽取的专家。请从给定的文本中抽取出关系三元组,不存在的关系返回空列表。请按照JSON字符串的格式回答。 煤是一种常见的化石燃料,家庭用煤经过了从"煤球"到"蜂窝煤"的演变。 
<|assistant|> [{"head": "煤", "relation": "use", "tail": "燃料"}] <|endoftext|>

模型训练

# 模型训练参数
args = TrainingArguments(output_dir="./chatbot",per_device_train_batch_size=2,gradient_accumulation_steps=8,gradient_checkpointing=True,logging_steps=100,num_train_epochs=10,learning_rate=1e-4,remove_unused_columns=False,save_strategy="epoch"
)# 开始训练
trainer = Trainer(model=model,args=args,# 使用 <input_ids, labels, attention_mask>train_dataset=tokenized_ds.select(range(10000)),data_collator=DataCollatorForSeq2Seq(tokenizer=tokenizer, padding=True),
)

总结

LLM的训练或者微调都是需要<input_ids, labels, attention_mask>形式的数据,e. g. Structured IE中的三个task的dataset仍是如此。

class MOFDataset(Dataset):def __init__(self, dataset_config, tokenizer, split_name, max_words=1024):#self.data = json.load(open(dataset_config.data_path))if split_name == "train":self.data = json.load(open(dataset_config.data_path+"/train.json")) # self.data[0]["train"]  # Adjust this based on your dataset's structureelse:self.data = json.load(open(dataset_config.data_path+"/val.json"))# self.data[0]["validation"]  # Adjust this based on your dataset's structureself.max_words = max_wordsself.tokenizer = tokenizerdef __len__(self):return len(self.data)def __getitem__(self, index):IGNORE_INDEX = -100  # The default setting in CrossEntropyLossitem = self.data[index]#prompt = f"### Instruction:\n{item['instruction']}\n\n### Input:\n{item['input']}\n\n### Response:"prompt = item['input']#f"item['input']\n\n"# example = input+ outputexample = prompt + item["output"]
#        print(example)prompt = torch.tensor(self.tokenizer.encode(prompt), dtype=torch.int64)# example = input_ids + output_idsexample = self.tokenizer.encode(example)    # input+output
#        print(example)# example = input_ids + output_ids + <eos>example.append(self.tokenizer.eos_token_id)example = torch.tensor(example, dtype=torch.int64)padding = self.max_words - example.shape[0]# 用 -1 填充if padding > 0:example = torch.cat((example, torch.zeros(padding, dtype=torch.int64) - 1))# 截断elif padding < 0:example = example[: self.max_words]# labels = examplelabels = copy.deepcopy(example)# labels = input(-1) + output# 复制 example,并将 example 中与 prompt 对应的部分设置为 -1。这样,模型在训练时只会关注 output 部分作为标签,而忽略掉输入部分。labels[: len(prompt)] = -1# 创建一个 mask,用于标记 example 中不为 -1(即有效)的部分。example_mask = example.ge(0)# 创建一个 mask,用于标记 labels 中不为 IGNORE_INDEX(即有效)的部分。label_mask = labels.ge(0)# 将无效的部分(填充部分)置为 0,这样它们不会对损失计算产生影响。example[~example_mask] = 0labels[~label_mask] = IGNORE_INDEXexample_mask = example_mask.float()label_mask = label_mask.float()return {"input_ids": example,   # example = input_ids + output_ids + <eos>"labels": labels,       # labels = input(-1) + output,输入部分被替换为 -1,只保留 output 部分作为目标标签。"attention_mask": example_mask,  # 返回一个 mask,指示哪些位置是有效的(即不是填充部分)。}

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/495154.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【幼儿园识物】比大小启蒙资料PDF

下载链接::huanxigou-uihttp://yiwub.natapp1.cc/zyweb/#/source/viewPdf?id9

帧缓存的分配

帧缓存实际上就是一块内存。在 Android 系统中分配与回收帧缓存&#xff0c;使用的是一个叫 ION 的内核模块&#xff0c;App 使用 ioctl 系统调用后&#xff0c;会在内核内存中分配一块符合要求的内存&#xff0c;用户态会拿到一个 fd&#xff08;有的地方也称之为 handle&…

SDMTSP:黑翅鸢算法(Black-winged kite algorithm,BKA)求解单仓库多旅行商问题,可以更改数据集和起点(MATLAB代码)

一、黑翅鸢算法BKA 黑翅鸢算法&#xff08;Black-winged kite algorithm&#xff0c;BKA&#xff09;由Wang Jun等人于2024年提出&#xff0c;该算法受黑翅鸢的迁徙和掠食行为启发而得。BKA集成了柯西突变策略和领导者策略&#xff0c;增强了算法的全局搜索能力&#xff0c;提…

【Python】基础语法介绍

目录 一、标识符和关键字 二、注释 三、缩进 四、输入和输出 五、字符串操作 六、基本数据类型 七、复合数据类型 7.1 列表 7.2 元组 7.3 字典 7.4 集合 八、数据类型转换 九、运算符 8.1 算术运算符 8.2 比较运算符 8.3 赋值运算符 8.4 位运算符 8.5 逻辑运…

stm32定时器输出比较----驱动步进电机

定时器输出比较理论 OC(Output Compare)输出比较输出比较可以通过比较CNT与CCR寄存器值的关系,来对输出电平进行置1、置0或翻转的操作,用于输出一定频率和占空比的PWM波形每个高级定时器和通用定时器都拥有4个输出比较通道高级定时器的前3个通道额外拥有死区生成和互补输出…

【NLP 17、NLP的基础——分词】

我始终相信&#xff0c;世间所有的安排都有它的道理&#xff1b;失之东隅&#xff0c;收之桑榆 —— 24.12.20 一、中文分词的介绍 1.为什么讲分词&#xff1f; ① 分词是一个被长期研究的任务&#xff0c;通过了解分词算法的发展&#xff0c;可以看到NLP的研究历程 ② 分词…

Rust 在前端基建中的使用

摘要 随着前端技术的不断发展&#xff0c;前端基础设施&#xff08;前端基建&#xff09;的建设已成为提升开发效率、保障产品质量的关键环节。然而&#xff0c;在应对复杂业务场景与高性能需求时&#xff0c;传统的前端技术栈逐渐暴露出诸多不足。近年来&#xff0c;Rust语言…

谷歌浏览器的网络连接问题解决方案

在数字化时代&#xff0c;网络浏览器已成为日常工作和生活中不可或缺的工具。谷歌浏览器以其快速、稳定和丰富的功能深受用户喜爱。然而&#xff0c;就像其他软件一样&#xff0c;谷歌浏览器也可能遇到网络连接问题&#xff0c;这可能由多种因素引起。本文将为您提供一系列解决…

【Unity3D】Particle粒子特效或3D物体显示在UGUI上的方案

目录 一、RawImage Camera RenderTexture方式 &#xff08;1&#xff09;扩展知识&#xff1a;实现射线检测RawImage内的3D物体 &#xff08;2&#xff09;扩展知识&#xff1a;实现粒子特效显示RawImage上 二、UI摄像机 Canvas(Screen Space - Camera模式)方式 &#…

14-zookeeper环境搭建

0、环境 java&#xff1a;1.8zookeeper&#xff1a;3.5.6 1、下载 zookeeper下载点击这里。 2、安装 下载完成后解压&#xff0c;放到你想放的目录里。先看一下zookeeper的目录结构&#xff0c;如下图&#xff1a; 进入conf目录&#xff0c;复制zoo_sample.cfg&#xff0…

精准提升:从94.5%到99.4%——目标检测调优全纪录

&#x1f680; 目标检测模型调优过程记录 在进行目标检测模型的训练过程中&#xff0c;我们面对了许多挑战与迭代。从初始模型的训练结果到最终的调优优化&#xff0c;每一步的实验和调整都有其独特的思路和收获。本文记录了我在优化目标检测模型的过程中进行的几次尝试&#…

贪心算法(三)

目录 一、k次取反后最大化的数组和 二、优势洗牌 三、最长回文串 四、增减字符串匹配 一、k次取反后最大化的数组和 k次取反后最大化的数组和 贪心策略&#xff1a; 解题代码&#xff1a; class Solution { public:int largestSumAfterKNegations(vector<int>&am…

基于Springboot的在线问卷调查系统【附源码】

基于Springboot的在线问卷调查系统 效果如下&#xff1a; 系统主页面 问卷列表页面 个人中心页面 系统登陆页面 管理员主页面 问卷管理页面 研究背景 随着互联网技术的飞速发展&#xff0c;传统的问卷调查方式因其时间和地点的限制&#xff0c;难以高效地收集到足够的数据。…

Python选择题训练工具:高效学习、答题回顾与音频朗读一站式体验

一、引言 随着人工智能技术的不断进步&#xff0c;传统的教学方式已经逐渐向智能化、互动化转变。在众多英语测试题型中&#xff0c;选择题作为一种高效的方式被广泛应用于各类培训与考试中。为了帮助学生高效学习与自测&#xff0c;本篇文章将采用Python编写一款基于 Python …

《三角洲行动》游戏运行时提示“缺失kernel32.dll”:问题解析与解决方案

《三角洲行动》游戏运行时提示“缺失kernel32.dll”&#xff1a;问题解析与解决方案 作为软件开发领域的一名从业者&#xff0c;我深知电脑游戏运行过程中可能遇到的各种挑战&#xff0c;尤其是文件丢失、文件损坏以及系统报错等问题。今天&#xff0c;我将以经典游戏《三角洲…

【从零开始入门unity游戏开发之——unity篇02】unity6基础入门——软件下载安装、Unity Hub配置、安装unity编辑器、许可证管理

文章目录 一、软件下载安装1、Unity官网2、下载Unity Hub 二、修改Unity Hub配置1、设置Unity Hub中文语言2、修改默认存储目录 三、安装unity编辑器1、点击安装编辑器2、版本选择3、关于版本号4、安装模块选择5、等待下载完成自动安装即可6、追加unity和模块 四、许可证管理专…

AtCoder Beginner Contest 385(A~F)题解

A - Equally 思路&#xff1a;由题可知最多只能分成三组&#xff0c;我们只需要判断是否三个数都相等&#xff0c;或者两个数相加等于另外一个数即可 #include<bits/stdc.h> using namespace std; #define int long long int n; string s; int a,b,c; signed main() {ci…

STM32串口第一次接收数据时第一个字节丢失的问题

解决方法&#xff1a;开启中断之前&#xff0c;先清除标志位【1】。 串口清除标志位&#xff1a; __HAL_UART_CLEAR_PEFLAG(&huart1); HAL_UART_Receive_IT(&huart1,&RxUart, 1); 定时器清除标志位&#xff1a; __HAL_TIM_CLEAR_FLAG(&htim3,TIM_FLAG_UPDATE);…

Unity3d 基于UGUI和VideoPlayer 实现一个多功能视频播放器功能(含源码)

前言 随着Unity3d引擎在数字沙盘、智慧工厂、数字孪生等场景的广泛应用&#xff0c;视频已成为系统程序中展示时&#xff0c;不可或缺的一部分。在 Unity3d 中&#xff0c;我们可以通过强大的 VideoPlayer 组件和灵活的 UGUI 系统&#xff0c;将视频播放功能无缝集成到用户界面…

第22天:信息收集-Web应用各语言框架安全组件联动系统数据特征人工分析识别项目

#知识点 1、信息收集-Web应用-开发框架-识别安全 2、信息收集-Web应用-安全组件-特征分析 一、ICO图标&#xff1a; 1、某个应用系统的标示&#xff0c;如若依系统有自己特点的图标&#xff1b;一旦该系统出问题&#xff0c;使用该系统的网站都会受到影响&#xff1b; 2、某个公…