昇思MindSpore 应用学习-基于MindSpore的GPT2文本摘要

基于MindSpore的GPT2文本摘要 --AI代码解析

数据集加载与处理

  1. 数据集加载 本次实验使用的是nlpcc2017摘要数据,内容为新闻正文及其摘要,总计50000个样本。
from mindnlp.utils import http_get  # 导入http_get模块,用于下载数据集# download dataset
url = 'https://download.mindspore.cn/toolkits/mindnlp/dataset/text_generation/nlpcc2017/train_with_summ.txt'  # 数据集的URL
path = http_get(url, './')  # 使用http_get函数下载数据集并保存到当前目录from mindspore.dataset import TextFileDataset  # 导入TextFileDataset模块,用于加载文本文件数据集# load dataset
dataset = TextFileDataset(str(path), shuffle=False)  # 加载下载的数据集,shuffle设置为False表示不打乱数据顺序
dataset.get_dataset_size()  # 获取数据集的大小# split into training and testing dataset
train_dataset, test_dataset = dataset.split([0.9, 0.1], randomize=False)  # 将数据集按比例分割为训练集(90%)和测试集(10%),randomize设置为False表示不随机打乱

代码解析:

  1. 导入模块
    • from mindnlp.utils import http_get:导入http_get函数,用于从指定的URL下载数据。
  2. 下载数据集
    • url = '...':定义数据集的下载链接。
    • path = http_get(url, './'):调用http_get函数,将数据集下载到当前目录,并返回文件路径。
  3. 加载数据集
    • from mindspore.dataset import TextFileDataset:导入TextFileDataset类,用于处理文本文件类型的数据集。
    • dataset = TextFileDataset(str(path), shuffle=False):创建一个TextFileDataset对象,加载之前下载的数据,shuffle=False表示保持原始顺序。
  4. 获取数据集大小
    • dataset.get_dataset_size():调用该方法获取数据集的大小(即样本数量)。
  5. 分割数据集
    • train_dataset, test_dataset = dataset.split([0.9, 0.1], randomize=False):将数据集按90%和10%的比例分割为训练集和测试集,randomize=False表明不打乱顺序。

API 解析:

  • http_get(url, path):该函数用于从url指定的地址下载数据,并将其保存到path指定的路径。下载完成后返回文件的路径。
  • TextFileDataset:这是一个用于处理文本数据集的类,支持从文件中读取文本数据并进行后续处理。
  • get_dataset_size():该方法用于获取数据集的总样本数量,方便后续的数据处理和分析。
  • split(ratios, randomize):此方法用于将数据集按照指定的比例分割成多个子集,ratios为比例列表,randomize表示是否打乱数据顺序。
  1. 数据预处理 原始数据格式:
article: [CLS] article_context [SEP]
summary: [CLS] summary_context [SEP]

预处理后的数据格式:

[CLS] article_context [SEP] summary_context [SEP]
import json  # 导入json模块,用于处理JSON格式的数据
import numpy as np  # 导入numpy库,用于数值计算和数组处理# preprocess dataset
def process_dataset(dataset, tokenizer, batch_size=6, max_seq_len=1024, shuffle=False):# 定义数据预处理函数def read_map(text):data = json.loads(text.tobytes())  # 将文本数据解析为字典格式return np.array(data['article']), np.array(data['summarization'])  # 返回文章和摘要作为numpy数组def merge_and_pad(article, summary):# 对文章和摘要进行标记化处理# 对于最大序列长度进行填充,只截断文章tokenized = tokenizer(text=article, text_pair=summary,padding='max_length', truncation='only_first', max_length=max_seq_len)return tokenized['input_ids'], tokenized['input_ids']  # 返回标记化的输入IDdataset = dataset.map(read_map, 'text', ['article', 'summary'])  # 使用read_map函数将文本数据映射为文章和摘要# 将列名更改为input_ids和labels以用于后续训练dataset = dataset.map(merge_and_pad, ['article', 'summary'], ['input_ids', 'labels'])  # 使用merge_and_pad处理数据dataset = dataset.batch(batch_size)  # 按照batch_size对数据集进行分批处理if shuffle:dataset = dataset.shuffle(batch_size)  # 如果shuffle为True,则对数据集进行洗牌return dataset  # 返回经过处理的数据集

代码解析:

  1. 导入模块
    • import json:导入JSON模块,用于处理JSON格式的数据。
    • import numpy as np:导入NumPy库,以便进行数值计算和数组操作。
  2. 定义数据预处理函数
    • def process_dataset(dataset, tokenizer, batch_size=6, max_seq_len=1024, shuffle=False):定义一个函数来预处理数据集,包含数据集、分词器、批量大小、最大序列长度和是否打乱的参数。
  3. **内部函数 **read_map
    • def read_map(text):定义一个内部函数,用于处理输入的文本数据。
    • data = json.loads(text.tobytes()):将文本数据转换为字典格式。
    • return np.array(data['article']), np.array(data['summarization']):返回文章和摘要的NumPy数组。
  4. **内部函数 **merge_and_pad
    • def merge_and_pad(article, summary):定义一个函数,用于对文章和摘要进行标记化处理。
    • tokenized = tokenizer(...):使用分词器对文章和摘要进行标记化,设置填充和截断的选项。
    • return tokenized['input_ids'], tokenized['input_ids']:返回标记化后的输入ID。
  5. 映射和处理数据集
    • dataset = dataset.map(read_map, 'text', ['article', 'summary']):使用read_map函数将文本数据映射为文章和摘要。
    • dataset = dataset.map(merge_and_pad, ['article', 'summary'], ['input_ids', 'labels']):使用merge_and_pad函数处理数据并重命名列。
  6. 分批和洗牌
    • dataset = dataset.batch(batch_size):将数据集按指定的批量大小进行分批。
    • if shuffle: dataset = dataset.shuffle(batch_size):如果shuffle为真,则对数据集进行洗牌。
  7. 返回处理后的数据集
    • return dataset:返回经过预处理的数据集。

API 解析:

  • json.loads():该函数用于将JSON格式的字符串解析为Python字典或列表。
  • tokenizer():这是一个用于文本标记化的函数,通常将原始文本转换为模型可接受的输入格式,包括输入ID和其他必要的信息。参数中包含填充、截断策略以及最大长度等设置。
  • dataset.map(func, input_columns, output_columns):该方法应用指定的函数到数据集的指定列,返回新的数据集。
  • dataset.batch(batch_size):将数据集按指定的batch_size进行分批处理,方便后续训练。
  • dataset.shuffle(buffer_size):该方法用于随机打乱数据集的顺序,buffer_size指代在打乱过程中使用的缓冲区大小。

因GPT2无中文的tokenizer,我们使用BertTokenizer替代。

from mindnlp.transformers import BertTokenizer  # 导入BertTokenizer,用于中文文本的标记化# We use BertTokenizer for tokenizing chinese context.
tokenizer = BertTokenizer.from_pretrained('bert-base-chinese')  # 从预训练模型加载BERT分词器len(tokenizer)  # 获取分词器的词汇表大小# 处理训练数据集
train_dataset = process_dataset(train_dataset, tokenizer, batch_size=4)  # 使用process_dataset函数处理训练集,设置batch_size为4next(train_dataset.create_tuple_iterator())  # 创建一个元组迭代器并获取下一个批次的数据

代码解析:

  1. 导入BertTokenizer
    • from mindnlp.transformers import BertTokenizer:导入BertTokenizer类,该类用于处理中文文本的标记化。
  2. 初始化分词器
    • tokenizer = BertTokenizer.from_pretrained('bert-base-chinese'):从预训练的BERT模型加载中文分词器,'bert-base-chinese'是BERT的中文基础版。
  3. 获取词汇表大小
    • len(tokenizer):获取分词器的词汇表大小,以便了解可以处理的词汇数量。
  4. 处理训练数据集
    • train_dataset = process_dataset(train_dataset, tokenizer, batch_size=4):调用process_dataset函数处理训练数据集,设置每个批次的大小为4。
  5. 创建迭代器并获取数据
    • next(train_dataset.create_tuple_iterator()):使用create_tuple_iterator()方法创建一个元组迭代器,并获取下一个批次的数据。这通常用于迭代训练数据集中的样本。

API 解析:

  • BertTokenizer.from_pretrained():该方法用于加载指定预训练BERT模型的分词器,可以处理特定语言的文本,如中文。
  • len(tokenizer):用于获取分词器的词汇表大小,返回值为一个整数,表示分词器可以识别的词语数量。
  • process_dataset(dataset, tokenizer, batch_size):自定义函数,用于处理数据集并返回一个批量处理过的数据集。
  • create_tuple_iterator():该方法用于创建一个元组形式的迭代器,可以迭代数据集中的样本,适合于模型训练过程中的数据输入。
  • next(iterator):用于获取迭代器中的下一个元素,通常在处理批量数据时使用。

模型构建

  1. 构建GPT2ForSummarization模型,注意_shift right_的操作。
from mindspore import ops  # 导入MindSpore的操作模块
from mindnlp.transformers import GPT2LMHeadModel  # 导入GPT2LMHeadModel类,用于文本生成任务class GPT2ForSummarization(GPT2LMHeadModel):  # 定义一个新的类,继承自GPT2LMHeadModeldef construct(self,input_ids=None,  # 输入的token IDattention_mask=None,  # 注意力掩码labels=None,  # 真实标签):# 调用父类的construct方法,获取模型输出outputs = super().construct(input_ids=input_ids, attention_mask=attention_mask)# 对logits进行偏移处理,移除最后一个token的输出shift_logits = outputs.logits[..., :-1, :]  # 将labels进行偏移,移除第一个token的标签shift_labels = labels[..., 1:] # 计算交叉熵损失loss = ops.cross_entropy(shift_logits.view(-1, shift_logits.shape[-1]), shift_labels.view(-1), ignore_index=tokenizer.pad_token_id)return loss  # 返回计算得到的损失

代码解析:

  1. 导入模块
    • from mindspore import ops:导入MindSpore的操作模块,提供各种张量操作和计算功能。
    • from mindnlp.transformers import GPT2LMHeadModel:导入GPT2的语言模型头类,用于文本生成和相关任务。
  2. **定义类 **GPT2ForSummarization
    • class GPT2ForSummarization(GPT2LMHeadModel):定义一个名为GPT2ForSummarization的类,继承自GPT2LMHeadModel,目的是用于文本摘要生成。
  3. 构造方法
    • def construct(self, input_ids=None, attention_mask=None, labels=None):定义构造方法,接受输入的token ID、注意力掩码和真实标签。
  4. 调用父类方法
    • outputs = super().construct(input_ids=input_ids, attention_mask=attention_mask):调用父类的方法,以获取模型的输出,包括logits(预测结果)。
  5. 处理logits和labels
    • shift_logits = outputs.logits[..., :-1, :]:从logits中移除最后一个token的输出,便于后续计算。
    • shift_labels = labels[..., 1:]:从labels中移除第一个token的标签,以对齐logits。
  6. 计算损失
    • loss = ops.cross_entropy(...):使用交叉熵函数计算损失。shift_logits被展平,并与shift_labels进行比较,ignore_index参数用于忽略填充的token。
  7. 返回损失
    • return loss:返回计算得到的损失值。

API 解析:

  • ops.cross_entropy():用于计算交叉熵损失的函数。输入包括预测的logits和真实标签,常用于分类任务的损失计算。
  • super().construct():调用父类(GPT2LMHeadModel)的构造函数,以便获取标准的输出,包括logits。
  • .view():这是一个张量操作,用于改变张量的形状,通常用于将多维张量展平为一维或改变其维度,便于后续处理。
  • tokenizer.pad_token_id:这是分词器中定义的填充token的ID,常用于在计算损失时忽略填充部分的影响。
  1. 动态学习率
from mindspore import ops  # 导入MindSpore的操作模块
from mindspore.nn.learning_rate_schedule import LearningRateSchedule  # 导入学习率调度类class LinearWithWarmUp(LearningRateSchedule):  # 定义一个线性带热身的学习率调度类"""Warmup-decay learning rate."""def __init__(self, learning_rate, num_warmup_steps, num_training_steps):super().__init__()  # 调用父类的构造函数self.learning_rate = learning_rate  # 初始化学习率self.num_warmup_steps = num_warmup_steps  # 初始化热身步骤数self.num_training_steps = num_training_steps  # 初始化训练总步骤数def construct(self, global_step):  # 定义构造方法,接收当前的全局步骤数if global_step < self.num_warmup_steps:  # 如果当前步骤小于热身步骤数# 计算线性热身学习率return global_step / float(max(1, self.num_warmup_steps)) * self.learning_rate# 计算衰减后的学习率return ops.maximum(0.0, (self.num_training_steps - global_step) / float(max(1, self.num_training_steps - self.num_warmup_steps))) * self.learning_rate  # 返回学习率

代码解析:

  1. 导入模块
    • from mindspore import ops:导入MindSpore的操作模块,提供数值计算功能。
    • from mindspore.nn.learning_rate_schedule import LearningRateSchedule:导入学习率调度类,允许创建自定义的学习率调度策略。
  2. **定义类 **LinearWithWarmUp
    • class LinearWithWarmUp(LearningRateSchedule):定义一个新的类,继承自LearningRateSchedule,用于实现带热身的线性学习率调度。
  3. 构造方法
    • def __init__(self, learning_rate, num_warmup_steps, num_training_steps):定义构造方法,接受学习率、热身步骤数和训练总步骤数。
    • super().__init__():调用父类的构造函数进行初始化。
    • self.learning_rate = learning_rate:保存学习率。
    • self.num_warmup_steps = num_warmup_steps:保存热身步骤数。
    • self.num_training_steps = num_training_steps:保存训练总步骤数。
  4. 构造学习率
    • def construct(self, global_step):定义构造方法,根据当前全局步骤计算学习率。
    • if global_step < self.num_warmup_steps:检查当前步骤是否在热身阶段。
    • return global_step / float(max(1, self.num_warmup_steps)) * self.learning_rate:如果在热身阶段,计算线性增长的学习率。
  5. 计算衰减学习率
    • return ops.maximum(...) * self.learning_rate:在热身结束后,计算衰减学习率,并确保学习率不会为负。

API 解析:

  • LearningRateSchedule:一个基类,用于实现自定义的学习率调度策略,允许用户定义如何根据训练进度调整学习率。
  • ops.maximum():用于计算输入张量的最大值,通常用于确保学习率的非负性。
  • max(1, value):返回value和1之间的较大值,避免在计算中出现除以零的情况。
  • global_step:当前的全局训练步骤,通常用于动态调整学习率。
  • construct(...):这是一个方法,用于根据当前的训练状态(如全局步骤)动态计算所需的学习率。

模型训练

num_epochs = 1  # 设置训练的轮数为1
warmup_steps = 2000  # 设置热身步骤数为2000
learning_rate = 1.5e-4  # 设置初始学习率为1.5e-4num_training_steps = num_epochs * train_dataset.get_dataset_size()  # 计算总训练步骤数from mindspore import nn  # 导入MindSpore的神经网络模块
from mindnlp.transformers import GPT2Config, GPT2LMHeadModel  # 导入GPT2的配置和模型类config = GPT2Config(vocab_size=len(tokenizer))  # 创建GPT2模型配置,设置词汇表大小
model = GPT2ForSummarization(config)  # 实例化模型# 创建学习率调度器
lr_scheduler = LinearWithWarmUp(learning_rate=learning_rate, num_warmup_steps=warmup_steps, num_training_steps=num_training_steps)# 定义优化器
optimizer = nn.AdamWeightDecay(model.trainable_params(), learning_rate=lr_scheduler)# 记录模型参数数量
print('number of model parameters: {}'.format(model.num_parameters()))  # 输出模型的参数数量

代码解析:

  1. 设置超参数
    • num_epochs = 1:设置训练的轮数为1。
    • warmup_steps = 2000:设置热身步骤数为2000,以在训练初期逐渐增加学习率。
    • learning_rate = 1.5e-4:设置初始学习率为1.5e-4,控制模型学习的速度。
  2. 计算训练步骤
    • num_training_steps = num_epochs * train_dataset.get_dataset_size():计算总的训练步骤数,乘以训练数据集的大小。
  3. 导入必要模块
    • from mindspore import nn:导入MindSpore的神经网络模块,以便使用神经网络相关的类和函数。
    • from mindnlp.transformers import GPT2Config, GPT2LMHeadModel:导入用于配置和创建GPT2模型的类。
  4. 创建模型配置
    • config = GPT2Config(vocab_size=len(tokenizer)):使用分词器的词汇表大小创建GPT2模型的配置。
  5. 实例化模型
    • model = GPT2ForSummarization(config):根据配置实例化文本摘要生成模型。
  6. 创建学习率调度器
    • lr_scheduler = LinearWithWarmUp(...):使用先前定义的LinearWithWarmUp类创建学习率调度器,传入学习率、热身步骤数和总训练步骤数。
  7. 定义优化器
    • optimizer = nn.AdamWeightDecay(...):使用Adam优化器及权重衰减,对模型的可训练参数进行优化,学习率由学习率调度器提供。
  8. 记录并输出模型参数数量
    • print('number of model parameters: {}'.format(model.num_parameters())):输出模型的参数总数,以便了解模型的复杂性。

API 解析:

  • train_dataset.get_dataset_size():获取训练数据集的大小,通常返回样本数量,用于计算总训练步骤。
  • GPT2Config(vocab_size=len(tokenizer)):用于初始化GPT2模型的配置对象,包括设置词汇表大小等参数。
  • GPT2ForSummarization(config):创建模型实例,使用之前定义的模型配置。
  • LinearWithWarmUp(...):创建一个线性带热身的学习率调度器对象,使学习率在训练初期逐渐增加,随后再进行衰减。
  • nn.AdamWeightDecay(...):Adam优化器的一个变种,支持权重衰减功能,用于正则化以防止过拟合。
  • model.num_parameters():返回模型中可训练参数的数量,帮助用户理解模型的规模。
from mindnlp._legacy.engine import Trainer  # 导入Trainer类,用于训练模型
from mindnlp._legacy.engine.callbacks import CheckpointCallback  # 导入CheckpointCallback类,用于保存训练模型的检查点# 创建检查点回调对象
ckpoint_cb = CheckpointCallback(save_path='checkpoint',  # 指定保存路径ckpt_name='gpt2_summarization',  # 指定检查点的名称epochs=1,  # 设置每个epoch保存一次检查点keep_checkpoint_max=2  # 最大保留2个检查点
)# 创建Trainer对象,用于训练模型
trainer = Trainer(network=model,  # 训练的模型train_dataset=train_dataset,  # 训练数据集epochs=1,  # 设置训练轮数optimizer=optimizer,  # 使用的优化器callbacks=ckpoint_cb  # 添加回调函数
)# 设置混合精度
trainer.set_amp(level='O1')  # 开启混合精度训练,O1级别表示混合精度的使用# 开始训练,指定目标列
trainer.run(tgt_columns="labels")  # 开始训练,"labels"为目标列

代码解析:

  1. 导入模块
    • from mindnlp._legacy.engine import Trainer:导入Trainer类,用于管理和执行模型的训练流程。
    • from mindnlp._legacy.engine.callbacks import CheckpointCallback:导入CheckpointCallback类,用于在训练过程中保存模型的检查点。
  2. 创建检查点回调对象
    • ckpoint_cb = CheckpointCallback(...):创建一个检查点回调对象,允许在训练过程中定期保存模型状态。
      • save_path='checkpoint':指定保存检查点的文件夹名称。
      • ckpt_name='gpt2_summarization':指定保存的检查点文件名。
      • epochs=1:设置每个epoch结束时保存一次检查点。
      • keep_checkpoint_max=2:最多保留2个检查点文件,旧的检查点将会被覆盖。
  3. 创建Trainer对象
    • trainer = Trainer(...):实例化Trainer对象,控制训练进程。
      • network=model:指定要训练的模型。
      • train_dataset=train_dataset:传入训练数据集。
      • epochs=1:设置训练轮数为1。
      • optimizer=optimizer:指定优化器。
      • callbacks=ckpoint_cb:添加回调函数(检查点回调)。
  4. 设置混合精度
    • trainer.set_amp(level='O1'):启用混合精度训练,使用O1级别的配置,O1表示大部分操作使用FP16(半精度浮点),而少数关键操作使用FP32(单精度浮点),以提高计算效率并减少内存使用。
  5. 开始训练
    • trainer.run(tgt_columns="labels"):开始模型的训练过程,tgt_columns="labels"指定用于训练的目标列。

API 解析:

  • Trainer:一个训练管理类,提供训练过程中所需的功能,如训练循环、优化器管理和回调函数管理。
  • CheckpointCallback(...):用于设置训练过程中的检查点保存策略,使得每次训练完成一个epoch后可以保存模型状态,以便后续恢复和继续训练。
  • trainer.set_amp(level='O1'):配置混合精度训练,旨在提高训练性能,降低内存使用。
  • trainer.run(tgt_columns="labels"):执行训练过程,tgt_columns指定模型训练所需的目标输出列,通常用于处理标签。

模型推理

数据处理,将向量数据变为中文数据

def process_test_dataset(dataset, tokenizer, batch_size=1, max_seq_len=1024, max_summary_len=100):# 定义处理测试数据集的函数def read_map(text):# 解析文本数据为文章和摘要data = json.loads(text.tobytes())  # 将字节文本加载为JSONreturn np.array(data['article']), np.array(data['summarization'])  # 返回文章和摘要def pad(article):# 对文章进行分词和填充tokenized = tokenizer(text=article, truncation=True, max_length=max_seq_len - max_summary_len)  # 分词并截断return tokenized['input_ids']  # 返回输入ID# 使用map函数对数据集进行处理dataset = dataset.map(read_map, 'text', ['article', 'summary'])  # 解析数据集dataset = dataset.map(pad, 'article', ['input_ids'])  # 填充分词后的文章# 将数据集按批次处理dataset = dataset.batch(batch_size)return dataset  # 返回处理后的数据集# 处理测试数据集
test_dataset = process_test_dataset(test_dataset, tokenizer, batch_size=1)  
print(next(test_dataset.create_tuple_iterator(output_numpy=True)))  # 打印测试数据集中的一项# 加载预训练模型
model = GPT2LMHeadModel.from_pretrained('./checkpoint/gpt2_summarization_epoch_0.ckpt', config=config)  
model.set_train(False)  # 设置模型为评估模式# 设置模型的结束标记ID
model.config.eos_token_id = model.config.sep_token_id  i = 0
# 迭代测试数据集
for (input_ids, raw_summary) in test_dataset.create_tuple_iterator():# 使用模型生成总结output_ids = model.generate(input_ids, max_new_tokens=50, num_beams=5, no_repeat_ngram_size=2)  output_text = tokenizer.decode(output_ids[0].tolist())  # 解码生成的ID为文本print(output_text)  # 输出生成的总结i += 1if i == 1:  # 只处理一个样本break

代码解析:

  1. **定义函数 **process_test_dataset
    • def process_test_dataset(dataset, tokenizer, batch_size=1, max_seq_len=1024, max_summary_len=100):定义一个函数,用于处理测试数据集。
    • read_map(text):定义内部函数,用于将输入文本转换为文章和摘要。
      • data = json.loads(text.tobytes()):将字节文本解析为JSON格式。
      • return np.array(data['article']), np.array(data['summarization']):返回文章和摘要内容。
    • pad(article):定义另一个内部函数,用于对文章进行分词处理。
      • tokenized = tokenizer(...):使用分词器对文章进行分词和截断,确保长度不超过max_seq_len - max_summary_len
      • return tokenized['input_ids']:返回分词后的输入ID。
  2. 处理数据集
    • dataset = dataset.map(read_map, 'text', ['article', 'summary']):对数据集应用read_map函数,解析输入文本。
    • dataset = dataset.map(pad, 'article', ['input_ids']):对文章进行分词填充。
    • dataset = dataset.batch(batch_size):将数据集按批次进行处理。
  3. 返回处理后的数据集
    • return dataset:返回处理后的测试数据集。
  4. 处理测试数据集
    • test_dataset = process_test_dataset(...):调用函数处理测试集。
    • print(next(test_dataset.create_tuple_iterator(output_numpy=True))):打印处理后的测试数据集中的一项。
  5. 加载预训练模型
    • model = GPT2LMHeadModel.from_pretrained(...):从指定路径加载预训练的GPT2模型。
    • model.set_train(False):设置模型为评估模式,不进行训练。
  6. 设置结束标记ID
    • model.config.eos_token_id = model.config.sep_token_id:将结束标记ID设置为分隔标记ID。
  7. 生成摘要
    • for (input_ids, raw_summary) in test_dataset.create_tuple_iterator()::迭代处理后的测试数据集。
      • output_ids = model.generate(...):使用模型生成摘要,设置最大生成长度、束搜索数量和不重复n-gram的大小。
      • output_text = tokenizer.decode(output_ids[0].tolist()):解码生成的ID,转换为可读文本。
      • print(output_text):输出生成的摘要。
      • i += 1:累加处理的样本数量。
      • if i == 1: break:仅处理一个样本,之后退出循环。

API 解析:

  • json.loads(...):用于将JSON字符串解析为Python对象,适合处理动态数据。
  • tokenizer(...):分词器对象的方法,用于将文本转换为模型输入所需的ID。
  • dataset.map(...):数据集的映射函数,允许用户将自定义函数应用于数据集的每个元素。
  • dataset.batch(batch_size):将数据集划分为指定大小的批次,以便于训练或推理。
  • GPT2LMHeadModel.from_pretrained(...):从指定路径加载预训练的GPT2模型,便于进行推理或再训练。
  • model.generate(...):使用模型生成文本,支持多种生成策略,如束搜索、最大生成长度等。
  • tokenizer.decode(...):用于将生成的ID序列转换回可读的文本格式。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/386953.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

听说它可以让代码更优雅

一提到静态代码检查工具这个词应该比较好理解&#xff0c;所谓静态代码检查工具就是检查静态代码的工具&#xff0c;完美~ 言归正传&#xff0c;相信很多程序员朋友都听说过静态代码检查工具这个概念&#xff0c;它可能是我们IDE里的某一个插件&#xff0c;可能是计算机中的一…

RK3588+MIPI+GMSL+AI摄像机:自动车载4/8通道GMSL采集/边缘计算盒解决方案

RK3588作为目前市面能买到的最强国产SOC&#xff0c;有强大的硬件配置。在智能汽车飞速发展&#xff0c;对图像数据矿场要求越来越多的环境下&#xff0c;如何高效采集数据&#xff0c;或者运行AI应用&#xff0c;成为刚需。 推出的4/8通道GMSL采集/边缘计算盒产品满足这些需求…

Spring验证码

前言&#xff1a;使用Hutool 1.什么是Hutool&#xff1f; 2.代码复制到test类中 3.代码爆红&#xff0c;说明需要引入依赖 4.根据名取Maven仓库相关依赖 5.在pom.xml文件中进行配置 6.引入成功 7. 运行程序 打开d盘&#xff0c;发现已经生成了验证码的图片&#xff0c;路径在…

Codeforces Round 654 (Div. 2) C. A Cookie for You (模拟)

我认为这道题就是个脑筋急转弯。 首先我们知道当a b < n m的时候&#xff0c;饼干总数都不够人的总数&#xff0c;那肯定是NO。 并且注意题干&#xff0c;我们可以得知当a b的时候&#xff0c;第一类和第二类人可以任意选两种饼干中的一种。 之后我们可以分类讨论一下。 …

网格布局 HTML CSS grid layout demo

文章目录 页面效果代码 (HTML CSS)参考 页面效果 代码 (HTML CSS) <!DOCTYPE html> <html lang"en"> <head><meta charset"UTF-8"><meta name"viewport" content"widthdevice-width, initial-scale1.0"…

[ BLE4.0 ] 伦茨ST17H66开发-串口UART0的接收与发送

目录 一、前言 二、实现步骤 1.设置回调函数 2.关闭睡眠模式 三、效果展示 四、工程源代码 一、前言 串口通信在任何一款单片机开发中都是尤为重要的。本文涉及的开发所使用的例程依然是基于[ BLE4.0 ] 伦茨ST17H66开发-OSAL系统中添加自己的Task任务文章的工程源码&#x…

windows@powershell@任务计划@自动任务计划@taskschd.msc.md

文章目录 使用任务计划windows中的任务计划任务计划命令行程序开发windows 应用中相关api传统图形界面FAQ schtasks 命令常见用法创建计划任务删除计划任务查询计划任务修改计划任务运行计划任务 PowerShell ScheduledTasks常用 cmdlet 简介1. Get-ScheduledTask2. Register-Sc…

Git远程仓库推送

这里我只连接了两个站点的远程仓库&#xff0c;一个是国内的Gitee&#xff0c;另一个是Github&#xff0c;这两个站点的连接方式主要有两种&#xff0c;第一种就是通过https来连接远程仓库&#xff0c;另一种是通过ssh公钥来连接&#xff0c;这两个站点练接的大致过程都是一样的…

我出一道面试题,看看你能拿 3k 还是 30k!

大家好&#xff0c;我是程序员鱼皮。欢迎屏幕前的各位来到今天的模拟面试现场&#xff0c;接下来我会出一道经典的后端面试题&#xff0c;你只需要进行 4 个简单的选择&#xff0c;就能判断出来你的水平是新手&#xff08;3k&#xff09;、初级&#xff08;10k&#xff09;、中…

4 款最佳 C# 无头浏览器

摘要&#xff1a; 在当今大数据时代&#xff0c;高效的数据采集成为众多项目的关键一环。对于偏好C#语言的开发者而言&#xff0c;无头浏览器是实现网页自动化交互、数据抓取的强大工具。本文将深入探讨四款顶尖的C#无头浏览器库&#xff0c;分析它们的特性和应用场景&#xf…

怎么把C盘分成两个盘?让C盘分区更简单,赶快试试!

在日常使用电脑的过程中&#xff0c;有时我们可能希望将C盘分割成两个独立的分区&#xff0c;以便更好地管理文件和数据。这种操作需要谨慎进行&#xff0c;因为错误的分区操作可能导致数据丢失。那么&#xff0c;我们该怎么把C盘分成两个盘呢&#xff1f;下面&#xff0c;我将…

lua 游戏架构 之 游戏 AI (六)ai_auto_skill

定义一个为ai_auto_skill的类&#xff0c;继承自ai_base类。ai_auto_skill类的目的是在AI自动战斗模式下&#xff0c;根据配置和条件自动选择并使用技能。 lua 游戏架构 之 游戏 AI &#xff08;一&#xff09;ai_base-CSDN博客文章浏览阅读379次。定义了一套接口和属性&#…

vue3在元素上绑定自定义事件弹出虚拟键盘

最近开发中遇到一个需求: 焊接机器人的屏幕上集成web前端网页, 但是没有接入键盘。这就需要web端开发一个虚拟键盘,在网上找个很多虚拟键盘没有特别适合,索性自己写个简单的 图片: 代码: (代码可能比较垃圾冗余,也没时间优化,凑合看吧) 第一步:创建键盘组件 为了方便使用…

3.2.微调

微调 ​ 对于一些样本数量有限的数据集&#xff0c;如果使用较大的模型&#xff0c;可能很快过拟合&#xff0c;较小的模型可能效果不好。这个问题的一个解决方案是收集更多数据&#xff0c;但其实在很多情况下这是很难做到的。 ​ 另一种方法就是迁移学习(transfer learning…

c++如何理解多态与虚函数

目录 **前言****1. 何为多态**1.1 **编译时多态**1.1.1 函数重载1.1.2 模板 **1.2 运行时多态****1.2.1 虚函数****1.2.2 为什么要用父类指针去调用子类函数** **2. 注意****2.1 基类的析构函数应写为虚函数****2.2 构造函数不能设为虚函数** **本文参考** 前言 在学习 c 的虚…

打造重庆市数字化教育“新名片”,广阳湾珊瑚中学凭实力“出圈”!

分布于教学楼连廊顶部的智能照明设备,根据不同的时间和场景需求自动调节灯光亮度和开关状态;安装于各个教室内的智能黑板、学校同步时钟、学生互动设备,在极简以太全光网的赋能下,为师生提供丰富的教学体验与学习支持......行走于重庆市广阳湾珊瑚中学,像是与充满科技感的“校园…

病理AI领域的基础模型汇总|顶刊专题汇总·24-07-26

小罗碎碎念 本期文献主题&#xff1a;病理AI领域的最新基础模型 今天的推文是一期生日特辑&#xff0c;定时在下午六点二十一分发表&#xff08;今天农历六月二十一&#xff0c;哈哈&#xff09;&#xff0c;算是自己给自己的24岁生日礼物&#xff0c;希望24岁这一年&#xff0…

ollama本地部署大语言模型记录

目录 安装Ollama更改模型存放位置 拉取模型GemmaMistralQwen1.5(通义千问)codellama 部署Open webui测试性能知识广度问题1问题2 代码能力总结 最近突然对大语言模型感兴趣 同时在平时的一些线下断网的CTF比赛中&#xff0c;大语言模型也可以作为一个能对话交互的高级知识检索…

SSRF中伪协议学习

SSRF常用的伪协议 file:// 从文件系统中获取文件内容,如file:///etc/passwd dict:// 字典服务协议,访问字典资源,如 dict:///ip:6739/info: ftp:// 可用于网络端口扫描 sftp:// SSH文件传输协议或安全文件传输协议 ldap://轻量级目录访问协议 tftp:// 简单文件传输协议 gopher…

【JavaScript】函数声明和函数表达式的区别

文章目录 一、函数声明1. 定义方式2. 作用域提升&#xff08;Hoisting&#xff09;3. 块级作用域 二、函数表达式1. 定义方式2. 作用域提升&#xff08;Hoisting&#xff09;3. 自引用 三、其他区别1. 函数名2. 可读性和代码组织3. 使用场景 四、总结函数声明函数表达式 在Java…