pytorch实战---IMDB情感分析

💗💗💗欢迎来到我的博客,你将找到有关如何使用技术解决问题的文章,也会找到某个技术的学习路线。无论你是何种职业,我都希望我的博客对你有所帮助。最后不要忘记订阅我的博客以获取最新文章,也欢迎在文章下方留下你的评论和反馈。我期待着与你分享知识、互相学习和建立一个积极的社区。谢谢你的光临,让我们一起踏上这个知识之旅!
请添加图片描述

文章目录

  • 🥦引言
  • 🥦完整代码
  • 🥦代码分析
    • 🥦导库
    • 🥦设置日志
    • 🥦模型定义
      • 🥦GCNN
      • 🥦TextClassificationModel
    • 🥦准备IMDb数据集
    • 🥦整理函数
    • 🥦训练函数
    • 🥦模型初始化和优化器
    • 🥦加载用于训练和评估的数据
    • 🥦恢复训练
    • 🥦调用训练
  • 🥦保存文件的读取
  • 🥦扩展 LSTM、GRU
  • 🥦总结

🥦引言

本文使用IMDB数据集,结合pytorch进行情感分析

🥦完整代码

import torch
import torch.nn as nn
import torch.nn.functional as F
from sklearn.metrics import precision_score, recall_score, f1_score, accuracy_scorefrom torch import utilsimport torchtext
from tqdm import tqdm
from torchtext.datasets import IMDBfrom torchtext.datasets.imdb import NUM_LINES
from torchtext.data import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator
from torchtext.data.functional import to_map_style_datasetimport os
import sys
import logging
import logginglogging.basicConfig(level=logging.WARN, stream=sys.stdout, format = "%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s")VOCAB_SIZE = 15000# step1 编写GCNN模型代码,门(Gate)卷积网络
class GCNN(nn.Module):def __init__(self, vocab_size=VOCAB_SIZE, embedding_dim=64, num_class=2):super(GCNN, self).__init__()self.embedding_table = nn.Embedding(vocab_size, embedding_dim)nn.init.xavier_uniform_(self.embedding_table.weight)# 都是1维卷积self.conv_A_1 = nn.Conv1d(embedding_dim, 64, 15, stride=7)self.conv_B_1 = nn.Conv1d(embedding_dim, 64, 15, stride=7)self.conv_A_2 = nn.Conv1d(64, 64, 15, stride=7)self.conv_B_2 = nn.Conv1d(64, 64, 15, stride=7)self.output_linear1 = nn.Linear(64, 128)self.output_linear2 = nn.Linear(128, num_class)def forward(self, word_index):"""定义GCN网络的算子操作流程,基于句子单词ID输入得到分类logits输出"""# 1. 通过word_index得到word_embedding# word_index shape: [bs, max_seq_len]word_embedding = self.embedding_table(word_index)  # [bs, max_seq_len, embedding_dim]# 2. 编写第一层1D门卷积模块,通道数在第2维word_embedding = word_embedding.transpose(1, 2)  # [bs, embedding_dim, max_seq_len]A = self.conv_A_1(word_embedding)B = self.conv_B_1(word_embedding)H = A * torch.sigmoid(B)  # [bs, 64, max_seq_len]A = self.conv_A_2(H)B = self.conv_B_2(H)H = A * torch.sigmoid(B)  # [bs, 64, max_seq_len]# 3. 池化并经过全连接层pool_output = torch.mean(H, dim=-1)  # 平均池化,得到[bs, 4096]linear1_output = self.output_linear1(pool_output)# 最后一层需要设置为隐含层数目logits = self.output_linear2(linear1_output)  # [bs, 2]return logits# PyTorch官网的简单模型
class TextClassificationModel(nn.Module):"""简单版embedding.DNN模型"""def __init__(self, vocab_size=VOCAB_SIZE, embed_dim=64, num_class=2):super(TextClassificationModel, self).__init__()self.embedding = nn.EmbeddingBag(vocab_size, embed_dim, sparse=False)self.fc = nn.Linear(embed_dim, num_class)def forward(self, token_index):# 词袋embedded = self.embedding(token_index)  # shape: [bs, embedding_dim]return self.fc(embedded)# step2 构建IMDB Dataloader
BATCH_SIZE = 64def yeild_tokens(train_data_iter, tokenizer):for i, sample in enumerate(train_data_iter):label, comment = sampleyield tokenizer(comment)  # 字符串转换为token索引的列表train_data_iter = IMDB(root="./data", split="train")  # Dataset类型的对象
tokenizer = get_tokenizer("basic_english")
# 只使用出现次数大约20的token
vocab = build_vocab_from_iterator(yeild_tokens(train_data_iter, tokenizer), min_freq=20, specials=["<unk>"])
vocab.set_default_index(0)  # 特殊索引设置为0
print(f'单词表大小: len(vocab)')# 校对函数, batch是dataset返回值,主要是处理batch一组数据
def collate_fn(batch):"""对DataLoader所生成的mini-batch进行后处理"""target = []token_index = []max_length = 0  # 最大的token长度for i, (label, comment) in enumerate(batch):tokens = tokenizer(comment)token_index.append(vocab(tokens))  # 字符列表转换为索引列表# 确定最大的句子长度if len(tokens) > max_length:max_length = len(tokens)if label == "pos":target.append(0)else:target.append(1)token_index = [index + [0] * (max_length - len(index)) for index in token_index]# one-hot接收长整形的数据,所以要转换为int64return (torch.tensor(target).to(torch.int64), torch.tensor(token_index).to(torch.int32))# step3 编写训练代码
def train(train_data_loader, eval_data_loader, model, optimizer, num_epoch, log_step_interval, save_step_interval,  eval_step_interval, save_path, resume=""):"""此处data_loader是map-style dataset"""start_epoch = 0start_step = 0if resume != "":# 加载之前训练过的模型的参数文件logging.warning(f"loading from resume")checkpoint = torch.load(resume)model.load_state_dict(checkpoint['model_state_dict'])optimizer.load_state_dict(checkpoint['optimizer_state_dict'])start_epoch = checkpoint['epoch']start_step = checkpoint['step']for epoch_index in tqdm(range(start_epoch, num_epoch), desc="epoch"):ema_loss = 0total_acc_account = 0total_account = 0true_labels = []predicted_labels = []num_batches = len(train_data_loader)for batch_index, (target, token_index) in enumerate(train_data_loader):optimizer.zero_grad()step = num_batches * (epoch_index) + batch_index + 1logits = model(token_index)# one-hot需要转换float32才可以训练bce_loss = F.binary_cross_entropy(torch.sigmoid(logits), F.one_hot(target, num_classes=2).to(torch.float32))ema_loss = 0.9 * ema_loss + 0.1 * bce_loss  # 指数平均lossbce_loss.backward()nn.utils.clip_grad_norm_(model.parameters(), 0.1)  # 梯度的正则进行截断,保证训练稳定optimizer.step()  # 更新参数true_labels.extend(target.tolist())predicted_labels.extend(torch.argmax(logits, dim=-1).tolist())if step % log_step_interval == 0:logging.warning(f"epoch_index: {epoch_index}, batch_index: {batch_index}, ema_loss: {ema_loss}")if step % save_step_interval == 0:os.makedirs(save_path, exist_ok=True)save_file = os.path.join(save_path, f"step_{step}.pt")torch.save({"epoch": epoch_index,"step": step,"model_state_dict": model.state_dict(),'optimizer_state_dict': optimizer.state_dict(),'loss': bce_loss}, save_file)logging.warning(f"checkpoint has been saved in {save_file}")if step % save_step_interval == 0:os.makedirs(save_path, exist_ok=True)save_file = os.path.join(save_path, f"step_{step}.pt")torch.save({"epoch": epoch_index,"step": step,"model_state_dict": model.state_dict(),'optimizer_state_dict': optimizer.state_dict(),'loss': bce_loss,'accuracy': accuracy,'precision': precision,'recall': recall,'f1': f1}, save_file)logging.warning(f"checkpoint has been saved in {save_file}")if step % eval_step_interval == 0:logging.warning("start to do evaluation...")model.eval()ema_eval_loss = 0total_acc_account = 0total_account = 0true_labels = []predicted_labels = []for eval_batch_index, (eval_target, eval_token_index) in enumerate(eval_data_loader):total_account += eval_target.shape[0]eval_logits = model(eval_token_index)total_acc_account += (torch.argmax(eval_logits, dim=-1) == eval_target).sum().item()eval_bce_loss = F.binary_cross_entropy(torch.sigmoid(eval_logits),F.one_hot(eval_target, num_classes=2).to(torch.float32))ema_eval_loss = 0.9 * ema_eval_loss + 0.1 * eval_bce_losstrue_labels.extend(eval_target.tolist())predicted_labels.extend(torch.argmax(eval_logits, dim=-1).tolist())accuracy = accuracy_score(true_labels, predicted_labels)precision = precision_score(true_labels, predicted_labels)recall = recall_score(true_labels, predicted_labels)f1 = f1_score(true_labels, predicted_labels)logging.warning(f"ema_eval_loss: {ema_eval_loss}, eval_acc: {total_acc_account / total_account}")logging.warning(f"Precision: {precision}, Recall: {recall}, F1: {f1}, Accuracy: {accuracy}")model.train()model = GCNN()
# model = TextClassificationModel()
print("模型总参数:", sum(p.numel() for p in model.parameters()))
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)train_data_iter = IMDB(root="data", split="train")  # Dataset类型的对象
train_data_loader = torch.utils.data.DataLoader(to_map_style_dataset(train_data_iter), batch_size=BATCH_SIZE, collate_fn=collate_fn, shuffle=True)eval_data_iter = IMDB(root="data", split="test")  # Dataset类型的对象
# collate校对
eval_data_loader = utils.data.DataLoader(to_map_style_dataset(eval_data_iter), batch_size=8, collate_fn=collate_fn)# resume = "./data/step_500.pt"
resume = ""train(train_data_loader, eval_data_loader, model, optimizer, num_epoch=10, log_step_interval=20, save_step_interval = 500, eval_step_interval = 300, save_path = "./log_imdb_text_classification2", resume = resume)

🥦代码分析

🥦导库

首先导入需要的库

import torch
import torch.nn as nn
import torch.nn.functional as F
from sklearn.metrics import precision_score, recall_score, f1_score, accuracy_score
from torch import utils
import torchtext
from tqdm import tqdm
from torchtext.datasets import IMDB
  • torch (PyTorch):
    PyTorch 是一个用于机器学习和深度学习的开源深度学习框架。它提供了张量计算、自动微分、神经网络层和优化器等功能,使用户能够构建和训练深度学习模型。

  • torch.nn:
    torch.nn 模块包含了PyTorch中用于构建神经网络模型的类和函数。它包括各种神经网络层、损失函数和优化器等。

  • torch.nn.functional:
    torch.nn.functional 模块提供了一组函数,用于构建神经网络的非参数化操作,如激活函数、池化和卷积等。这些函数通常与torch.nn一起使用。

  • sklearn.metrics (scikit-learn):
    scikit-learn是一个用于机器学习的Python库,其中包含了一系列用于评估模型性能的度量工具。导入的precision_score、recall_score、f1_score 和 accuracy_score 用于计算分类模型的精确度、召回率、F1分数和准确性。

  • torch.utils:
    torch.utils 包含了一些实用工具和数据加载相关的函数。在这段代码中,它用于构建数据加载器。

  • torchtext:
    torchtext 是一个PyTorch的自然语言处理库,用于文本数据的处理和加载。它提供了用于文本数据预处理和构建数据集的功能。

  • tqdm:
    tqdm 是一个Python库,用于创建进度条,可用于监视循环迭代的进度。在代码中,它用于显示训练和评估的进度。

  • torchtext.datasets.IMDB:
    torchtext.datasets.IMDB 是TorchText库中的一个数据集,包含了IMDb电影评论的数据。这些评论用于情感分析任务,其中评论被标记为积极或消极。

🥦设置日志

logging.basicConfig(level=logging.WARN, stream=sys.stdout, format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s"
)

在代码中设置日志的作用是记录程序的运行状态、调试信息和重要事件,以便在开发和生产环境中更轻松地诊断问题和了解程序的行为。设置日志有以下作用:

  • 问题诊断:当程序出现错误或异常时,日志记录可以提供有关错误发生的位置、原因和上下文的信息。这有助于开发人员快速定位和修复问题。

  • 性能分析:通过记录程序的运行时间和关键操作的时间戳,日志可以用于性能分析,帮助开发人员识别潜在的性能瓶颈。

  • 跟踪进度:在长时间运行的任务中,例如训练深度学习模型,日志记录可以帮助跟踪任务的进度,以便了解训练状态、完成的步骤和剩余时间。

  • 监控和警报:日志可以与监控系统集成,以便在发生关键事件或异常情况时触发警报。这对于及时响应问题非常重要。

  • 审计和合规:在某些应用中,日志记录是合规性的一部分,用于追踪系统的操作和用户的活动。日志可以用于审计和调查。

在上述代码中,设置日志的目的是跟踪训练进度、记录训练损失以及保存检查点。它允许开发人员监视模型训练的进展并在需要时查看详细信息,例如损失值和评估指标。此外,日志还可以用于调试和查看模型性能。

🥦模型定义

代码定义了两个模型:

GCNN:用于文本分类的门控卷积神经网络。
TextClassificationModel:使用嵌入和线性层的简单文本分类模型。

🥦GCNN

class GCNN(nn.Module):def __init__(self, vocab_size=VOCAB_SIZE, embedding_dim=64, num_class=2):super(GCNN, self).__init__()self.embedding_table = nn.Embedding(vocab_size, embedding_dim)nn.init.xavier_uniform_(self.embedding_table.weight)# 都是1维卷积self.conv_A_1 = nn.Conv1d(embedding_dim, 64, 15, stride=7)self.conv_B_1 = nn.Conv1d(embedding_dim, 64, 15, stride=7)self.conv_A_2 = nn.Conv1d(64, 64, 15, stride=7)self.conv_B_2 = nn.Conv1d(64, 64, 15, stride=7)self.output_linear1 = nn.Linear(64, 128)self.output_linear2 = nn.Linear(128, num_class)def forward(self, word_index):"""定义GCN网络的算子操作流程,基于句子单词ID输入得到分类logits输出"""# 1. 通过word_index得到word_embedding# word_index shape: [bs, max_seq_len]word_embedding = self.embedding_table(word_index)  # [bs, max_seq_len, embedding_dim]# 2. 编写第一层1D门卷积模块,通道数在第2维word_embedding = word_embedding.transpose(1, 2)  # [bs, embedding_dim, max_seq_len]A = self.conv_A_1(word_embedding)B = self.conv_B_1(word_embedding)H = A * torch.sigmoid(B)  # [bs, 64, max_seq_len]A = self.conv_A_2(H)B = self.conv_B_2(H)H = A * torch.sigmoid(B)  # [bs, 64, max_seq_len]# 3. 池化并经过全连接层pool_output = torch.mean(H, dim=-1)  # 平均池化,得到[bs, 4096]linear1_output = self.output_linear1(pool_output)# 最后一层需要设置为隐含层数目logits = self.output_linear2(linear1_output)  # [bs, 2]return logits

🥦TextClassificationModel

class TextClassificationModel(nn.Module):"""简单版embedding.DNN模型"""def __init__(self, vocab_size=VOCAB_SIZE, embed_dim=64, num_class=2):super(TextClassificationModel, self).__init__()self.embedding = nn.EmbeddingBag(vocab_size, embed_dim, sparse=False)self.fc = nn.Linear(embed_dim, num_class)def forward(self, token_index):# 词袋embedded = self.embedding(token_index)  # shape: [bs, embedding_dim]return self.fc(embedded)

🥦准备IMDb数据集

这行代码使用TorchText的IMDB数据集对象,导入IMDb数据集的训练集部分。

# 数据集导入
train_data_iter = IMDB(root="./data", split="train")

这行代码创建了一个用于将文本分词为单词的分词器。

# 数据预处理
tokenizer = get_tokenizer("basic_english")

这里,build_vocab_from_iterator 函数根据文本数据创建了一个词汇表,只包括出现频率大于等于20次的单词。特殊标记用于处理未知单词。然后,set_default_index将特殊标记的索引设置为0。

# 构建词汇表
vocab = build_vocab_from_iterator(yeild_tokens(train_data_iter, tokenizer), min_freq=20, specials=["<unk>"])
vocab.set_default_index(0)

这是一个自定义的校对函数,用于处理DataLoader返回的批次数据,将文本转换为可以输入模型的张量形式。

def collate_fn(batch):"""对DataLoader所生成的mini-batch进行后处理"""target = []token_index = []max_length = 0  # 最大的token长度for i, (label, comment) in enumerate(batch):tokens = tokenizer(comment)token_index.append(vocab(tokens))  # 字符列表转换为索引列表# 确定最大的句子长度if len(tokens) > max_length:max_length = len(tokens)if label == "pos":target.append(0)else:target.append(1)token_index = [index + [0] * (max_length - len(index)) for index in token_index]# one-hot接收长整形的数据,所以要转换为int64return (torch.tensor(target).to(torch.int64), torch.tensor(token_index).to(torch.int32))

这行代码将IMDb训练数据集加载到DataLoader对象中,以便进行模型训练。collate_fn函数用于处理数据的批处理。

train_data_loader = torch.utils.data.DataLoader(to_map_style_dataset(train_data_iter), batch_size=BATCH_SIZE, collate_fn=collate_fn, shuffle=True)

上述代码块执行了IMDb数据集的准备工作,包括导入数据、分词、构建词汇表和设置数据加载器。这些步骤是为了使数据集可用于训练文本分类模型。

🥦整理函数

这个 collate_fn 函数用于对 DataLoader 批次中的数据进行处理,确保每个批次中的文本序列具有相同的长度,并将标签转换为适用于模型输入的张量形式。它的工作包括以下几个方面:

提取标签和评论文本。
使用分词器将评论文本分词为单词。
确定批次中最长评论的长度。
根据最长评论的长度,将所有评论的单词索引序列填充到相同的长度。
将标签转换为适当的张量形式(这里是将标签转换为长整数型)。
返回处理后的批次数据,其中包括标签和填充后的单词索引序列。

这个整理函数确保了模型在训练期间能够处理不同长度的文本序列,并将它们转换为模型可接受的张量输入。

🥦训练函数

def train(train_data_loader, eval_data_loader, model, optimizer, num_epoch, log_step_interval, save_step_interval,  eval_step_interval, save_path, resume=""):"""此处data_loader是map-style dataset"""start_epoch = 0start_step = 0if resume != "":# 加载之前训练过的模型的参数文件logging.warning(f"loading from resume")checkpoint = torch.load(resume)model.load_state_dict(checkpoint['model_state_dict'])optimizer.load_state_dict(checkpoint['optimizer_state_dict'])start_epoch = checkpoint['epoch']start_step = checkpoint['step']for epoch_index in tqdm(range(start_epoch, num_epoch), desc="epoch"):ema_loss = 0total_acc_account = 0total_account = 0true_labels = []predicted_labels = []num_batches = len(train_data_loader)for batch_index, (target, token_index) in enumerate(train_data_loader):optimizer.zero_grad()step = num_batches * (epoch_index) + batch_index + 1logits = model(token_index)# one-hot需要转换float32才可以训练bce_loss = F.binary_cross_entropy(torch.sigmoid(logits), F.one_hot(target, num_classes=2).to(torch.float32))ema_loss = 0.9 * ema_loss + 0.1 * bce_loss  # 指数平均lossbce_loss.backward()nn.utils.clip_grad_norm_(model.parameters(), 0.1)  # 梯度的正则进行截断,保证训练稳定optimizer.step()  # 更新参数true_labels.extend(target.tolist())predicted_labels.extend(torch.argmax(logits, dim=-1).tolist())if step % log_step_interval == 0:logging.warning(f"epoch_index: {epoch_index}, batch_index: {batch_index}, ema_loss: {ema_loss}")if step % save_step_interval == 0:os.makedirs(save_path, exist_ok=True)save_file = os.path.join(save_path, f"step_{step}.pt")torch.save({"epoch": epoch_index,"step": step,"model_state_dict": model.state_dict(),'optimizer_state_dict': optimizer.state_dict(),'loss': bce_loss}, save_file)logging.warning(f"checkpoint has been saved in {save_file}")if step % save_step_interval == 0:os.makedirs(save_path, exist_ok=True)save_file = os.path.join(save_path, f"step_{step}.pt")torch.save({"epoch": epoch_index,"step": step,"model_state_dict": model.state_dict(),'optimizer_state_dict': optimizer.state_dict(),'loss': bce_loss,'accuracy': accuracy,'precision': precision,'recall': recall,'f1': f1}, save_file)logging.warning(f"checkpoint has been saved in {save_file}")if step % eval_step_interval == 0:logging.warning("start to do evaluation...")model.eval()ema_eval_loss = 0total_acc_account = 0total_account = 0true_labels = []predicted_labels = []for eval_batch_index, (eval_target, eval_token_index) in enumerate(eval_data_loader):total_account += eval_target.shape[0]eval_logits = model(eval_token_index)total_acc_account += (torch.argmax(eval_logits, dim=-1) == eval_target).sum().item()eval_bce_loss = F.binary_cross_entropy(torch.sigmoid(eval_logits),F.one_hot(eval_target, num_classes=2).to(torch.float32))ema_eval_loss = 0.9 * ema_eval_loss + 0.1 * eval_bce_losstrue_labels.extend(eval_target.tolist())predicted_labels.extend(torch.argmax(eval_logits, dim=-1).tolist())accuracy = accuracy_score(true_labels, predicted_labels)precision = precision_score(true_labels, predicted_labels)recall = recall_score(true_labels, predicted_labels)f1 = f1_score(true_labels, predicted_labels)logging.warning(f"ema_eval_loss: {ema_eval_loss}, eval_acc: {total_acc_account / total_account}")logging.warning(f"Precision: {precision}, Recall: {recall}, F1: {f1}, Accuracy: {accuracy}")model.train()

这段代码定义了一个名为 train 的函数,用于执行训练过程。下面是该函数的详细说明:

train 函数接受以下参数:train_data_loader: 训练数据的 DataLoader,用于迭代训练数据。eval_data_loader: 用于评估的 DataLoader,用于评估模型性能。model: 要训练的神经网络模型。optimizer: 用于更新模型参数的优化器。num_epoch: 训练的总周期数。log_step_interval: 记录日志的间隔步数。save_step_interval: 保存模型检查点的间隔步数。eval_step_interval: 执行评估的间隔步数。save_path: 保存模型检查点的目录。resume: 可选的,用于恢复训练的检查点文件路径。训练函数的主要工作如下:它首先检查是否有恢复训练的检查点文件。如果有,它会加载之前训练的模型参数和优化器状态,以便继续训练。然后,它开始进行一系列的训练周期(epochs),每个周期内包含多个训练步(batches)。在每个训练步中,它执行以下操作:零化梯度,以准备更新模型参数。计算模型的预测输出(logits)。计算二进制交叉熵损失(binary cross-entropy loss)。使用反向传播(backpropagation)计算梯度并更新模型参数。记录损失、真实标签和预测标签。如果步数达到了 log_step_interval,则记录损失。如果步数达到了 save_step_interval,则保存模型检查点。如果步数达到了 eval_step_interval,则执行评估:将模型切换到评估模式(model.eval())。对评估数据集中的每个批次执行以下操作:计算模型的预测输出。计算二进制交叉熵损失。计算准确性、精确度、召回率和F1分数。记录评估损失和评估指标。将模型切换回训练模式(model.train())。最后,训练函数返回经过训练的模型。

这个训练函数执行了完整的训练过程,包括了模型的前向传播、损失计算、梯度更新、日志记录、模型检查点的保存和评估。通过调用这个函数,你可以训练模型并监视其性能。

🥦模型初始化和优化器

model = GCNN()
# model = TextClassificationModel()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

🥦加载用于训练和评估的数据

在提供的代码中,加载用于训练和评估的数据的部分如下:

train_data_iter = IMDB(root="data", split="train")

这一行代码使用 TorchText 的 IMDB 数据集对象,导入 IMDB 数据集的训练集部分。这部分数据将用于模型的训练。

eval_data_iter = IMDB(root="data", split="test")

这一行代码使用 TorchText 的 IMDB 数据集对象,导入 IMDB 数据集的测试集部分。这部分数据将用于评估模型的性能。


之后,这些数据集通过以下代码转化为 DataLoader 对象,以便用于模型训练和评估:

# 训练数据 DataLoader
train_data_loader = torch.utils.data.DataLoader(to_map_style_dataset(train_data_iter), batch_size=BATCH_SIZE, collate_fn=collate_fn, shuffle=True)
# 评估数据 DataLoader
eval_data_loader = utils.data.DataLoader(to_map_style_dataset(eval_data_iter), batch_size=8, collate_fn=collate_fn)

这些 DataLoader 对象将数据加载到内存中,以便训练和评估使用。collate_fn 函数用于处理数据的批次,确保它们具有适当的格式,以便输入到模型中。

这些部分负责加载和准备用于训练和评估的数据,是机器学习模型训练和评估的重要准备步骤。训练数据用于训练模型,而评估数据用于评估模型的性能。

🥦恢复训练

start_epoch = 0
start_step = 0
if resume != "":# 加载之前训练过的模型的参数文件logging.warning(f"loading from resume")checkpoint = torch.load(resume)model.load_state_dict(checkpoint['model_state_dict'])optimizer.load_state_dict(checkpoint['optimizer_state_dict'])start_epoch = checkpoint['epoch']start_step = checkpoint['step']

上述代码段位于训练函数中的开头部分,主要用于检查是否有已经训练过的模型的检查点文件,以便继续训练。具体解释如下:

如果 resume 变量不为空(即存在要恢复的检查点文件路径),则执行以下操作:
通过 torch.load 加载之前训练过的模型的检查点文件。
使用 load_state_dict 方法将已保存的模型参数加载到当前的模型中,以便继续训练。
同样,使用 load_state_dict 方法将已保存的优化器状态加载到当前的优化器中,以确保继续从之前的状态开始训练。
获取之前训练的轮数和步数,以便从恢复的状态继续训练。

这部分代码的目的是允许从之前保存的模型检查点继续训练,而不是从头开始。这对于长时间运行的训练任务非常有用,可以在中途中断训练并在之后恢复,而不会丢失之前的训练进度。

🥦调用训练

train(train_data_loader, eval_data_loader, model, optimizer, num_epoch=10, log_step_interval=20, save_step_interval=500, eval_step_interval=300, save_path="./log_imdb_text_classification2", resume=resume)

🥦保存文件的读取

import torch# 指定已存在的 .pt 文件路径
file_path = "./log_imdb_text_classification/step_3500.pt"  # 替换为实际的文件路径# 使用 torch.load() 加载文件
checkpoint = torch.load(file_path)# 查看准确率、精确率、召回率和F1分数
accuracy = checkpoint["accuracy"]
precision = checkpoint["precision"]
recall = checkpoint["recall"]
f1 = checkpoint["f1"]print("Accuracy:", accuracy)
print("Precision:", precision)
print("Recall:", recall)
print("F1 Score:", f1)

在这里插入图片描述

🥦扩展 LSTM、GRU

本文原作者使用的是卷积神经网络,但是卷积神经网络的优化模型GCNN,但是这个模型对于图更好,由此我接下来引入两个循环神经网络LSTM和GRU

class LSTMModel(nn.Module):def __init__(self, vocab_size=VOCAB_SIZE, embedding_dim=64, hidden_dim=64, num_class=2):super(LSTMModel, self).__init__()self.embedding_table = nn.Embedding(vocab_size, embedding_dim)self.lstm = nn.LSTM(embedding_dim, hidden_dim, num_layers=1, batch_first=True)self.output_linear = nn.Linear(hidden_dim, num_class)def forward(self, word_index):word_embedding = self.embedding_table(word_index)lstm_out, _ = self.lstm(word_embedding)lstm_out = lstm_out[:, -1, :]  # 取最后一个时间步的输出logits = self.output_linear(lstm_out)return logitsclass GRUModel(nn.Module):def __init__(self, vocab_size=VOCAB_SIZE, embedding_dim=64, hidden_dim=64, num_class=2):super(GRUModel, self).__init__()self.embedding_table = nn.Embedding(vocab_size, embedding_dim)self.gru = nn.GRU(embedding_dim, hidden_dim, num_layers=1, batch_first=True)self.output_linear = nn.Linear(hidden_dim, num_class)def forward(self, word_index):word_embedding = self.embedding_table(word_index)gru_out, _ = self.gru(word_embedding)gru_out = gru_out[:, -1, :]  # 取最后一个时间步的输出logits = self.output_linear(gru_out)return logits
# 创建LSTM模型
lstm_model = LSTMModel()
print("模型总参数:", sum(p.numel() for p in lstm_model.parameters()))
lstm_optimizer = torch.optim.Adam(lstm_model.parameters(), lr=0.001)# 创建GRU模型
# gru_model = GRUModel()
# print("模型总参数:", sum(p.numel() for p in gru_model.parameters()))
# gru_optimizer = torch.optim.Adam(gru_model.parameters(), lr=0.001)
# 训练LSTM模型
train(train_data_loader, eval_data_loader, lstm_model, lstm_optimizer, num_epoch=10, log_step_interval=20, save_step_interval=500, eval_step_interval=300, save_path="./log_imdb_lstm", resume="")# 训练GRU模型
# train(train_data_loader, eval_data_loader, gru_model, gru_optimizer, num_epoch=10, log_step_interval=20, save_step_interval=500, eval_step_interval=300, save_path="./log_imdb_gru", resume="")

感兴趣的小伙伴可以试试,对比一下

🥦总结

本文代码来自网络仅供学习,原文地址

请添加图片描述

挑战与创造都是很痛苦的,但是很充实。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/170605.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Pytorch公共数据集、tensorboard、DataLoader使用

本文将主要介绍torchvision.datasets的使用&#xff0c;并以CIFAR-10为例进行介绍&#xff0c;对可视化工具tensorboard进行介绍&#xff0c;包括安装&#xff0c;使用&#xff0c;可视化过程等&#xff0c;最后介绍DataLoader的使用。希望对你有帮助 Pytorch公共数据集 torc…

vscode安装可以打开docx文件的插件

文章目录 vscode安装可以打开docx文件的插件去插件商城搜索并安装安装后打开一个word vscode安装可以打开docx文件的插件 去插件商城搜索并安装 安装后 打开一个word

脏牛提权 liunx

使用方法 Liunx 普通用户 内核版本 在版本里 我直接脏牛提权 有脚本查看内核版本 上传c脚本 编译 直接执行 获取高权限 提权 Liunx https://github.com/InteliSecureLabs/Linux Exploit Suggester 运行这个脚本 上传到客户端 https://github…

Unity Hub报错:No valid Unity Editor license found. Please activate your license.

最近 遇到一个问题&#xff0c;打开高版本时Hub抛出异常&#xff1a;No valid Unity Editor license found. Please activate your license. 首先你必须排除是否登录Unity Hub&#xff0c;并且激活许可证。 方法一&#xff1a;禁用网络&#xff08;这个可能无效&#xff09; …

如何通过卖虚拟资料月入10万?看这几个卖资料案例

我微信好友里&#xff0c;有近4000个是做创业博主的同行。 你可能会好奇&#xff0c;其中60%的人都通过卖虚拟资料起家&#xff0c;这到底说明了什么呢&#xff1f; 嗯&#xff0c;事实上&#xff0c;这就意味着这些人选择了网络赚钱的首选项目&#xff0c;那就是销售各种资料…

python selenium如何带cookie访问网站

python selenium如何带cookie访问网站 要使用Python的Selenium库带有cookie访问网站&#xff0c;你可以按照以下步骤进行操作&#xff1a; 一、流程介绍 安装Selenium库&#xff08;如果尚未安装&#xff09;&#xff1a; pip install selenium导入Selenium库并启动一个浏览…

导致爬虫无法使用的原因有哪些?

随着互联网的普及和发展&#xff0c;爬虫技术也越来越多地被应用到各个领域。然而&#xff0c;在实际使用中&#xff0c;爬虫可能会遇到各种问题导致无法正常工作。本文将探讨导致爬虫无法使用的原因&#xff0c;并给出相应的解决方法。 一、目标网站反爬虫机制 许多网站为了…

11 个最值得推荐的 Windows 数据恢复软件

您可能已经尝试过许多免费的恢复程序&#xff0c;但它们都不起作用&#xff0c;对吧&#xff1f;这就是您正在寻找最好的数据恢复软件的原因。 个人去过那里。根据个人的经验&#xff0c;大多数免费软件并不能解决这个问题。有时&#xff0c;当个人在 PC 上运行恢复程序时&…

【API篇】七、Flink窗口

文章目录 1、窗口2、分类3、窗口API概览4、窗口分配器 在批处理统计中&#xff0c;可以等待一批数据都到齐后&#xff0c;统一处理。但是在无界流的实时处理统计中&#xff0c;是来一条就得处理一条&#xff0c;那么如何统计最近一段时间内的数据呢&#xff1f; ⇒ 窗口的概念&…

选择合适的项目管理系统来支持专业产品研发团队

专业产品研发团队的公司离不开其严谨的管理和高效的研发流程&#xff0c;为了进一步提升研发效率和管理水平&#xff0c;产研团队需要一个全流程的项目管理系统来支持其研发团队的协同合作。 一、系统需求 IT行业的研发工作涵盖了从立项、项目变更到项目的进程计划等多个环节。…

B-tree(PostgreSQL 14 Internals翻译版)

概览 B树(作为B树访问方法实现)是一种数据结构&#xff0c;它使您能够通过从树的根向下查找树的叶节点中所需的元素。为了明确地标识搜索路径&#xff0c;必须对所有树元素进行排序。B树是为有序数据类型设计的&#xff0c;这些数据类型的值可以进行比较和排序。 下面的机场代…

OpenCV视频车流量识别详解与实践

视频车流量识别基本思想是使用背景消去算法将运动物体从图片中提取出来&#xff0c;消除噪声识别运动物体轮廓&#xff0c;最后&#xff0c;在固定区域统计筛选出来符合条件的轮廓。 基于统计背景模型的视频运动目标检测技术&#xff1a; 背景获取&#xff1a;需要在场景存在…

基于PHP的线上购物商城,MySQL数据库,PHPstudy,原生PHP,前台用户+后台管理,完美运行,有一万五千字论文。

目录 演示视频 基本介绍 论文截图 功能结构 系统截图 演示视频 基本介绍 基于PHP的线上购物商城&#xff0c;MySQL数据库&#xff0c;PHPstudy&#xff0c;原生PHP&#xff0c;前台用户后台管理&#xff0c;完美运行&#xff0c;有一万五千字论文。 现如今,购物网站是商业…

【七】SpringBoot为什么可以打成 jar包启动

SpringBoot为什么可以打成 jar包启动 简介&#xff1a;庆幸的是夜跑的习惯一直都在坚持&#xff0c;正如现在坚持写博客一样。最开始刚接触springboot的时候就觉得很神奇&#xff0c;当时也去研究了一番&#xff0c;今晚夜跑又想起来了这茬事&#xff0c;于是想着应该可以记录一…

Python单元测试

import unittest #必须要导入单元测试的包class Student(object):def __init__(self, name, score):self.name nameself.score scoredef get_grade(self):if self.score > 100:#返回错误不能用return&#xff0c;应该用raise raise ValueError("成绩不能大于100"…

MySQL进阶(日志)——MySQL的日志 bin log (归档日志) 事务日志redo log(重做日志) undo log(回滚日志)

前言 MySQL最为最流行的开源数据库&#xff0c;其重要性不言而喻&#xff0c;也是大多数程序员接触的第一款数据库&#xff0c;深入认识和理解MySQL也比较重要。 本篇博客阐述MySQL的日志&#xff0c;介绍重要的bin log (归档日志) 、 事务日志redo log(重做日志) 、 undo lo…

【iOS逆向与安全】某音App直播间自动发666 和 懒人自动看视频

1.目标 由于看直播的时候主播叫我发 666&#xff0c;支持他&#xff0c;我肯定支持他呀&#xff0c;就一直发&#xff0c;可是后来发现太浪费时间了&#xff0c;能不能做一个直播间自动发 666 呢&#xff1f;于是就花了几分钟做了一个。 2.操作环境 越狱iPhone一台 frida m…

Mybatis应用场景之动态传参、两字段查询、用户存在性的判断

目录 一、动态传参 1、场景描述 2、实现过程 3、代码测试 二、两字段查询 1、场景描述 2、实现过程 3、代码测试 4、注意点 三、用户存在性的判断 1、场景描述 2、实现过程 3、代码测试 一、动态传参 1、场景描述 在进行数据库查询的时候&#xff0c;需要动态传入…

Linux友人帐之日志与备份

一、日志 1.1概述 日志文件是重要的系统信息文件&#xff0c;其中记录了许多重要的系统事件&#xff0c;包括用户的登录信息、系统的启动信息、系统的安全信息、邮件相关信息、各种服务相关信息等。日志对于安全来说也很重要&#xff0c;它记录了系统每天发生的各种事情&#…

openGauss学习笔记-108 openGauss 数据库管理-管理用户及权限-用户

文章目录 openGauss学习笔记-108 openGauss 数据库管理-管理用户及权限-用户108.1 创建、修改和删除用户108.2 私有用户108.3 永久用户108.4 用户认证优先规则 openGauss学习笔记-108 openGauss 数据库管理-管理用户及权限-用户 使用CREATE USER和ALTER USER可以创建和管理数据…