BERT的中文问答系统23

为了确保日志目录问题不会影响模型训练,我们可以在代码中增加更健壮的日志目录处理逻辑。具体来说,我们可以确保日志目录在代码执行的早期阶段就被创建,并且在遇到任何问题时记录详细的错误信息,但不中断整个程序的执行。

以下是修改后的代码,增加了对日志目录的更健壮处理:

python

import os
import json
import jsonlines
import torch
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, DistributedSampler
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.tensorboard import SummaryWriter
from transformers import BertModel, BertTokenizer
import tkinter as tk
from tkinter import filedialog, messagebox, scrolledtext, ttk
import logging
from difflib import SequenceMatcher
from datetime import datetime
from torch.cuda.mp import GradScaler, autocast
import torch.multiprocessing as mp
import psutil
import torch.distributed as dist# 获取项目根目录
PROJECT_ROOT = os.path.dirname(os.path.abspath(__file__))# 配置日志
LOGS_DIR = os.path.join(PROJECT_ROOT, 'logs')def ensure_directory_exists(directory):try:os.makedirs(directory, exist_ok=True)logging.info(f"确保目录存在: {directory}")except OSError as e:logging.error(f"创建目录 {directory} 失败: {e}")def setup_logging():log_file = os.path.join(LOGS_DIR, datetime.now().strftime('%Y-%m-%d/%H-%M-%S/羲和.txt'))log_dir = os.path.dirname(log_file)ensure_directory_exists(log_dir)logging.basicConfig(level=logging.INFO,format='%(asctime)s - %(levelname)s - %(message)s',handlers=[logging.FileHandler(log_file),logging.StreamHandler()])setup_logging()# 数据集类
class XihuaDataset(Dataset):def __init__(self, file_path, tokenizer, max_length=128):self.tokenizer = tokenizerself.max_length = max_lengthself.data = self.load_data(file_path)def load_data(self, file_path):data = []if file_path.endswith('.jsonl'):with jsonlines.open(file_path) as reader:for i, item in enumerate(reader):try:data.append(item)except jsonlines.jsonlines.InvalidLineError as e:logging.warning(f"跳过无效行 {i + 1}: {e}")elif file_path.endswith('.json'):with open(file_path, 'r') as f:try:data = json.load(f)except json.JSONDecodeError as e:logging.warning(f"跳过无效文件 {file_path}: {e}")return datadef __len__(self):return len(self.data)def __getitem__(self, idx):item = self.data[idx]question = item['question']human_answer = item['human_answers'][0]chatgpt_answer = item['chatgpt_answers'][0]try:inputs = self.tokenizer(question, return_tensors='pt', padding='max_length', truncation=True, max_length=self.max_length)human_inputs = self.tokenizer(human_answer, return_tensors='pt', padding='max_length', truncation=True, max_length=self.max_length)chatgpt_inputs = self.tokenizer(chatgpt_answer, return_tensors='pt', padding='max_length', truncation=True, max_length=self.max_length)except Exception as e:logging.warning(f"跳过无效项 {idx}: {e}")return self.__getitem__((idx + 1) % len(self.data))return {'input_ids': inputs['input_ids'].squeeze(),'attention_mask': inputs['attention_mask'].squeeze(),'human_input_ids': human_inputs['input_ids'].squeeze(),'human_attention_mask': human_inputs['attention_mask'].squeeze(),'chatgpt_input_ids': chatgpt_inputs['input_ids'].squeeze(),'chatgpt_attention_mask': chatgpt_inputs['attention_mask'].squeeze(),'human_answer': human_answer,'chatgpt_answer': chatgpt_answer}# 获取数据加载器
def get_data_loader(file_path, tokenizer, batch_size=8, max_length=128, distributed=False, num_workers=4):dataset = XihuaDataset(file_path, tokenizer, max_length)if distributed:sampler = DistributedSampler(dataset)return DataLoader(dataset, batch_size=batch_size, sampler=sampler, num_workers=num_workers)else:return DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers)# 模型定义
class XihuaModel(torch.nn.Module):def __init__(self, pretrained_model_name='F:/models/bert-base-chinese'):super(XihuaModel, self).__init__()self.bert = BertModel.from_pretrained(pretrained_model_name)self.classifier = torch.nn.Linear(self.bert.config.hidden_size, 1)def forward(self, input_ids, attention_mask):outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)pooled_output = outputs.pooler_outputlogits = self.classifier(pooled_output)return logits# 训练函数
def train(model, data_loader, optimizer, criterion, device, scaler=None, gradient_accumulation_steps=1):model.train()total_loss = 0.0optimizer.zero_grad()for step, batch in enumerate(data_loader):try:input_ids = batch['input_ids'].to(device)attention_mask = batch['attention_mask'].to(device)human_input_ids = batch['human_input_ids'].to(device)human_attention_mask = batch['human_attention_mask'].to(device)chatgpt_input_ids = batch['chatgpt_input_ids'].to(device)chatgpt_attention_mask = batch['chatgpt_attention_mask'].to(device)with autocast():  # 使用自动混合精度human_logits = model(human_input_ids, human_attention_mask)chatgpt_logits = model(chatgpt_input_ids, chatgpt_attention_mask)human_labels = torch.ones(human_logits.size(0), 1).to(device)chatgpt_labels = torch.zeros(chatgpt_logits.size(0), 1).to(device)loss = criterion(human_logits, human_labels) + criterion(chatgpt_logits, chatgpt_labels)if gradient_accumulation_steps > 1:loss = loss / gradient_accumulation_stepsscaler.scale(loss).backward()if (step + 1) % gradient_accumulation_steps == 0:scaler.step(optimizer)scaler.update()optimizer.zero_grad()total_loss += loss.item()except Exception as e:logging.warning(f"跳过无效批次: {e}")return total_loss / len(data_loader)# 主训练函数
def main_train(rank, world_size, retrain=False, multi_gpu=False):if multi_gpu:dist.init_process_group(backend='nccl', init_method='env://', world_size=world_size, rank=rank)torch.cuda.set_device(rank)device = torch.device(f'cuda:{rank}' if torch.cuda.is_available() else 'cpu')logging.info(f'Using device: {device}')tokenizer = BertTokenizer.from_pretrained('F:/models/bert-base-chinese')model = XihuaModel(pretrained_model_name='F:/models/bert-base-chinese').to(device)if multi_gpu:model = DDP(model, device_ids=[rank])if retrain:model.load_state_dict(torch.load(os.path.join(PROJECT_ROOT, 'models/xihua_model.pth'), map_location=device, weights_only=True))model.to(device)model.train()model.gradient_checkpointing_enable()optimizer = optim.Adam(model.parameters(), lr=1e-5)criterion = torch.nn.BCEWithLogitsLoss()scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=2)scaler = torch.amp.GradScaler('cuda')max_memory = torch.cuda.get_device_properties(device).total_memory * 0.9 if torch.cuda.is_available() else float('inf')batch_size = get_max_batch_size(model, device, max_memory)logging.info(f'Using batch size: {batch_size}')train_data_loader = get_data_loader(os.path.join(PROJECT_ROOT, 'data/train_data.jsonl'), tokenizer, batch_size=batch_size, max_length=128, distributed=multi_gpu, num_workers=4)num_epochs = 3gradient_accumulation_steps = 2  # 梯度累积步骤best_loss = float('inf')best_model_path = os.path.join(PROJECT_ROOT, 'models/xihua_model.pth')ensure_directory_exists(os.path.dirname(best_model_path))  # 确保模型目录存在writer = SummaryWriter(log_dir=os.path.join(PROJECT_ROOT, 'logs/tensorboard'))for epoch in range(num_epochs):train_loss = train(model, train_data_loader, optimizer, criterion, device, scaler, gradient_accumulation_steps)logging.info(f'Epoch [{epoch+1}/{num_epochs}], Loss: {train_loss:.8f}')writer.add_scalar('Training Loss', train_loss, epoch)scheduler.step(train_loss)if rank == 0:if train_loss < best_loss:best_loss = train_losstorch.save(model.state_dict(), best_model_path)logging.info(f"模型在 Epoch {epoch+1} 更新,Loss: {train_loss:.8f}")if rank == 0:logging.info("模型训练完成并保存")if multi_gpu:dist.destroy_process_group()# 动态调整批大小
def get_max_batch_size(model, device, max_memory=1024 * 1024 * 1024):  # 默认最大显存为1GBbatch_size = 1while True:try:input_ids = torch.randint(0, 100, (batch_size, 128)).to(device)attention_mask = torch.ones(batch_size, 128).to(device)with torch.no_grad():model(input_ids, attention_mask)batch_size *= 2except RuntimeError:return batch_size // 2# 启动多GPU训练
def launch_training(retrain=False, multi_gpu=False):if multi_gpu and torch.cuda.device_count() > 1:world_size = torch.cuda.device_count()mp.spawn(main_train, args=(world_size, retrain, multi_gpu), nprocs=world_size, join=True)else:main_train(0, 1, retrain, multi_gpu)# GUI界面
class XihuaChatbotGUI:def __init__(self, root):self.root = rootself.root.title("羲和聊天机器人")self.tokenizer = BertTokenizer.from_pretrained('F:/models/bert-base-chinese')self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')self.model = XihuaModel(pretrained_model_name='F:/models/bert-base-chinese').to(self.device)self.load_model()self.model.eval()# 加载训练数据集以便在获取答案时使用self.data = self.load_data(os.path.join(PROJECT_ROOT, 'data/train_data.jsonl'))self.create_widgets()def create_widgets(self):self.question_label = tk.Label(self.root, text="问题:")self.question_label.pack()self.question_entry = tk.Entry(self.root, width=50)self.question_entry.pack()self.answer_button = tk.Button(self.root, text="获取回答", command=self.get_answer)self.answer_button.pack()self.answer_label = tk.Label(self.root, text="回答:")self.answer_label.pack()self.answer_text = scrolledtext.ScrolledText(self.root, height=10, width=50)self.answer_text.pack()self.train_button = tk.Button(self.root, text="训练模型", command=self.train_model)self.train_button.pack()self.retrain_button = tk.Button(self.root, text="重新训练模型", command=lambda: self.train_model(retrain=True))self.retrain_button.pack()self.multi_gpu_var = tk.BooleanVar()self.multi_gpu_checkbox = tk.Checkbutton(self.root, text="使用多GPU", variable=self.multi_gpu_var)self.multi_gpu_checkbox.pack()self.log_text = scrolledtext.ScrolledText(self.root, height=10, width=50)self.log_text.pack()self.progress_bar = ttk.Progressbar(self.root, orient='horizontal', length=300, mode='determinate')self.progress_bar.pack()def get_answer(self):question = self.question_entry.get()if not question:messagebox.showwarning("输入错误", "请输入问题")returninputs = self.tokenizer(question, return_tensors='pt', padding='max_length', truncation=True, max_length=128)with torch.no_grad():input_ids = inputs['input_ids'].to(self.device)attention_mask = inputs['attention_mask'].to(self.device)logits = self.model(input_ids, attention_mask)if logits.item() > 0:answer_type = "羲和回答"else:answer_type = "零回答"specific_answer = self.get_specific_answer(question, answer_type)self.answer_text.delete(1.0, tk.END)self.answer_text.insert(tk.END, f"{answer_type}\n{specific_answer}")def get_specific_answer(self, question, answer_type):# 使用模糊匹配查找最相似的问题best_match = Nonebest_ratio = 0.0for item in self.data:ratio = SequenceMatcher(None, question, item['question']).ratio()if ratio > best_ratio:best_ratio = ratiobest_match = itemif best_match:if answer_type == "羲和回答":return best_match['human_answers'][0]else:return best_match['chatgpt_answers'][0]return "这个我也不清楚,你问问零吧"def load_data(self, file_path):data = []if file_path.endswith('.jsonl'):with jsonlines.open(file_path) as reader:for i, item in enumerate(reader):try:data.append(item)except jsonlines.jsonlines.InvalidLineError as e:logging.warning(f"跳过无效行 {i + 1}: {e}")elif file_path.endswith('.json'):with open(file_path, 'r') as f:try:data = json.load(f)except json.JSONDecodeError as e:logging.warning(f"跳过无效文件 {file_path}: {e}")return datadef load_model(self):model_path = os.path.join(PROJECT_ROOT, 'models/xihua_model.pth')ensure_directory_exists(os.path.dirname(model_path))  # 确保模型目录存在if os.path.exists(model_path):self.model.load_state_dict(torch.load(model_path, map_location=self.device, weights_only=True))logging.info("加载现有模型")else:logging.info("没有找到现有模型,将使用预训练模型")def train_model(self, retrain=False):file_path = filedialog.askopenfilename(filetypes=[("JSONL files", "*.jsonl"), ("JSON files", "*.json")])if not file_path:messagebox.showwarning("文件选择错误", "请选择一个有效的数据文件")returntry:dataset = XihuaDataset(file_path, self.tokenizer)data_loader = get_data_loader(file_path, self.tokenizer, batch_size=8, max_length=128, distributed=self.multi_gpu_var.get(), num_workers=4)# 加载已训练的模型权重if retrain:self.model.load_state_dict(torch.load(os.path.join(PROJECT_ROOT, 'models/xihua_model.pth'), map_location=self.device, weights_only=True))self.model.to(self.device)self.model.train()optimizer = torch.optim.Adam(self.model.parameters(), lr=1e-5)criterion = torch.nn.BCEWithLogitsLoss()scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=2)scaler = torch.amp.GradScaler('cuda')# 启用梯度检查点self.model.bert.gradient_checkpointing_enable()max_memory = torch.cuda.get_device_properties(self.device).total_memory * 0.9 if torch.cuda.is_available() else float('inf')batch_size = get_max_batch_size(self.model, self.device, max_memory)logging.info(f'Using batch size: {batch_size}')data_loader = get_data_loader(file_path, self.tokenizer, batch_size=batch_size, max_length=128, distributed=self.multi_gpu_var.get(), num_workers=4)writer = SummaryWriter(log_dir=os.path.join(PROJECT_ROOT, 'logs/tensorboard'))best_loss = float('inf')best_model_path = os.path.join(PROJECT_ROOT, 'models/xihua_model.pth')ensure_directory_exists(os.path.dirname(best_model_path))  # 确保模型目录存在for epoch in range(3):self.progress_bar['value'] = (epoch + 1) / 3 * 100self.root.update_idletasks()train_loss = train(self.model, data_loader, optimizer, criterion, self.device, scaler, 2)logging.info(f'Epoch [{epoch+1}/3], Loss: {train_loss:.4f}')self.log_text.insert(tk.END, f'Epoch [{epoch+1}/3], Loss: {train_loss:.4f}\n')self.log_text.yview(tk.END)writer.add_scalar('Training Loss', train_loss, epoch)scheduler.step(train_loss)if train_loss < best_loss:best_loss = train_losstorch.save(self.model.state_dict(), best_model_path)logging.info(f"模型在 Epoch {epoch+1} 更新,Loss: {train_loss:.4f}")logging.info("模型训练完成并保存")messagebox.showinfo("训练完成", "模型训练完成并保存")except Exception as e:logging.error(f"模型训练失败: {e}")messagebox.showerror("训练失败", f"模型训练失败: {e}")# 主函数
if __name__ == "__main__":# 启动GUIroot = tk.Tk()app = XihuaChatbotGUI(root)root.mainloop()

关键修改点
确保目录存在:在 ensure_directory_exists 函数中,确保日志目录和模型目录在代码执行的早期阶段被创建。
日志配置:在 setup_logging 函数中,调用 ensure_directory_exists 确保日志目录存在。
模型加载:在 load_model 和 train_model 函数中,调用 ensure_directory_exists 确保模型目录存在。
通过这些修改,即使日志目录或模型目录存在问题,也不会中断整个程序的执行。程序会记录详细的错误信息,并继续运行其他部分。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/458235.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

STM32第15章 RCC-使用HSE/HSI配置时钟

时间:2024.10.21-10.23 参考资料: 《零死角玩转STM32》“RCC-使用HSE/HIS配置时钟”章节 TIPS: 从前面的历程中我们知道,程序在启动的时候会执行汇编文件,汇编文件里会调用System_Init(固件库编程的函数),它里面会把时钟初始化成72M,因此前面我们在用固件库写程序的…

MSR寄存器独有的还是共享的

英特尔白皮书Volume 4: Model-Specific Registers 这一章列出了不同英特尔处理器系列的 MSR&#xff08;模型特定寄存器&#xff09;。所有列出的 MSR 都可以使用 RDMSR 和 WRMSR 指令进行读取和写入。MSR 的作用域定义了访问相同 MSR 的处理器集合&#xff0c;具体如下&#x…

栈和队列(上)-栈

1. 栈的概念 引入: 我们平时拿羽毛球,是从盒子顶部的羽毛球开始拿的,而顶部的元素是我们最后放进去的. 栈: 一种特殊的线性表&#xff0c;其只允许在固定的一端进行插入和删除元素操作。进行数据插入和删除操作的一端称为栈顶&#xff0c;另一端称为栈底。栈中的数据元素遵守后…

温泉押金原路退回系统, 押金+手牌+电子押金单——未来之窗行业应用跨平台架构

一、温泉手牌收押金必要性 1. 防止手牌丢失&#xff1a;手牌是顾客在温泉内存储个人物品和进出更衣室的重要凭证。收押金可以让顾客更加重视手牌&#xff0c;降低丢失的概率。比如说&#xff0c;有的顾客可能会因为粗心大意随手放置手牌&#xff0c;如果没有押金的约束&…

STM32之外部中断(实验对射式传感器计次实验)

外部中断配置 #include "stm32f10x.h" // Device headeruint16_t CountSensor_Count;void CountSensor_Init(void) {//RCC--> GPIO--> AFIO--> EXTI--> NVIC五步RCC_APB2PeriphClockCmd(RCC_APB2Periph_GPIOB, ENABLE); //开启GPIOB时…

图---java---黑马

图 概念 图是由顶点(vertex)和边(edge)组成的数据结构&#xff0c;例如 该图有四个顶点&#xff1a;A&#xff0c;B&#xff0c;C&#xff0c;D以及四条有向边&#xff0c;有向图中&#xff0c;边是单向的。 有向 vs 无向 如果是无向图&#xff0c;那么边是双向的&#x…

aarch64-opencv341交叉编译,并在arm上部署helloopencv

背景 当需要在jetson xavier nx或者rk 3562等平台上开发关于视觉检测的工程时&#xff0c;由于arm板子资源不足或者不能联网等原因&#xff0c;通常在虚拟机上利用交叉编译器编译得到可执行程序&#xff0c;然后部署到arm板上。 aarch64-opencv341交叉编译 ubuntu虚拟机中先…

【Linux】环境下升级redis

一、摘要 最近漏洞扫描服务器发现&#xff0c;Redis 缓冲区溢出漏洞(CVE-2024-31449)&#xff0c;解决办法redis更新到6.2.16、7.2.6或7.4.1及以上版本。 二、漏洞描述 漏洞描述&#xff1a;经过身份验证的用户可能会使用特制的 Lua 脚本来触发位库中的堆栈缓冲区溢出&#…

Kaggle比赛复盘

Kaggle - LLM Prompt Recovery 解决方案报告 比赛背景/目标 大型语言模型&#xff08;Large Language Models&#xff0c;LLMs&#xff09;通常被用于改写或对文本进行风格修改。本次Kaggle竞赛的目标是根据给定的改写文本&#xff0c;还原用于将原始文本转换为改写文本的LLM…

MetaArena推出《Final Glory》:引领Web3游戏技术新风向

随着区块链技术的日益成熟&#xff0c;Web3游戏成为了游戏产业探索的新方向&#xff0c;将去中心化经济与虚拟世界结合在一起&#xff0c;形成了一个全新的生态体系。然而&#xff0c;尽管Web3游戏展示了令人兴奋的可能性&#xff0c;但其背后的技术障碍依旧严峻&#xff0c;特…

Android Activity SingleTop启动模式使用场景

通知栏 当用户点击通知栏中的通知时,可以使用单顶启动模式来打开对应的活动,并确保只有一个实例存在。 简单集成极光推送 创建应用 获取appkey参数 切换到极光工作台 极光sdk集成 Project 根目录的主 gradle 配置 Module 的 gradle 配置 Jpush依赖配置 配置推送必须…

华为原生鸿蒙操作系统:我国移动操作系统的新篇章

华为原生鸿蒙操作系统&#xff1a;我国移动操作系统的新篇章 引言 在移动操作系统领域&#xff0c;苹果iOS和安卓系统一直占据主导地位。然而&#xff0c;随着华为原生鸿蒙操作系统的正式发布&#xff0c;这一格局正在发生深刻变化。作为继苹果iOS和安卓系统后的全球第三大移动…

Python酷库之旅-第三方库Pandas(170)

目录 一、用法精讲 781、pandas.arrays.IntervalArray.contains方法 781-1、语法 781-2、参数 781-3、功能 781-4、返回值 781-5、说明 781-6、用法 781-6-1、数据准备 781-6-2、代码示例 781-6-3、结果输出 782、pandas.arrays.IntervalArray.overlaps方法 782-1…

shodan3,vnc空密码批量连接,ip历史记录查找

shodan语法&#xff0c;count&#xff0c;honeyscore count 今天带大家继续学习shodan&#xff0c;今天会带大家学一学这个count命令&#xff0c;再学学其他小命令好其实关键命令也没那么多&#xff0c;就是很方便记忆一下就学会了这样子。 shodan count "/x03/x00/x00…

Docker下载途径

Docker不是Linux自带的&#xff0c;需要我们自己安装 官网&#xff1a;https://www.docker.com/ 安装步骤&#xff1a;https://docs.docker.com/engine/install/centos/ Docker Hub官网(镜像仓库)&#xff1a;https://hub.docker.com/ 在线安装docker 先卸载旧的docker s…

JMeter实战之——模拟登录

本篇介绍使用JMeter 如何对需要登录的站点进行压力测试。 基本Session验证的机制 使用session进行请求验证的机制是一种常见的Web应用认证方式。 该认证方式的主要内容如下&#xff1a; 一、登录过程 用户输入&#xff1a;用户在登录页面输入用户名和密码。发送请求&#x…

JDBC: Java数据库连接的桥梁

什么是JDBC&#xff1f; Java数据库连接&#xff08;Java Database Connectivity&#xff0c;简称JDBC&#xff09;是Java提供的一种API&#xff0c;允许Java应用程序与各种数据库进行交互。JDBC提供了一组标准的接口&#xff0c;开发者可以利用这些接口执行SQL语句、处理结果集…

XQT_UI 组件|02| 按钮 XPushButton

XPushButton 使用文档 简介 XPushButton 是一个自定义的按钮类&#xff0c;基于 Qt 框架构建&#xff0c;提供了丰富的样式和功能选项。它允许开发者轻松创建具有不同外观和行为的按钮&#xff0c;以满足用户界面的需求。 特性 颜色设置&#xff1a;支持多种颜色选择。样式设…

Python之Excel自动化处理(三)

一、Excel数据拆分-xlrd 1.1、代码 import xlrd from xlutils.copy import copydef get_data():wb xlrd.open_workbook(./base_data/data01.xlsx)sh wb.sheet_by_index(0){a: [{},{},{}],b:[{},{},{}],c:[{},{},{}],}all_data {}for r in range(sh.nrows):d {type:sh.cell…