BERT的中文问答系统39

实现当用户在GUI中输入问题(例如“刘邦”)且输出的答案被标记为不正确时,自动从百度百科中搜索相关内容并显示在GUI中的功能,我们需要对现有的代码进行一些修改。以下是完整的代码,包括对XihuaChatbotGUI类的修改以及新增的功能:

import os
import json
import jsonlines
import torch
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from transformers import BertModel, BertTokenizer
import tkinter as tk
from tkinter import filedialog, messagebox, ttk
import logging
from difflib import SequenceMatcher
from datetime import datetime
import requests
from bs4 import BeautifulSoup# 获取项目根目录
PROJECT_ROOT = os.path.dirname(os.path.abspath(__file__))# 配置日志
LOGS_DIR = os.path.join(PROJECT_ROOT, 'logs')
os.makedirs(LOGS_DIR, exist_ok=True)def setup_logging():log_file = os.path.join(LOGS_DIR, datetime.now().strftime('%Y-%m-%d_%H-%M-%S_羲和.txt'))logging.basicConfig(level=logging.INFO,format='%(asctime)s - %(levelname)s - %(message)s',handlers=[logging.FileHandler(log_file),logging.StreamHandler()])setup_logging()# 数据集类
class XihuaDataset(Dataset):def __init__(self, file_path, tokenizer, max_length=128):self.tokenizer = tokenizerself.max_length = max_lengthself.data = self.load_data(file_path)def load_data(self, file_path):data = []if file_path.endswith('.jsonl'):with jsonlines.open(file_path) as reader:for i, item in enumerate(reader):try:data.append(item)except jsonlines.jsonlines.InvalidLineError as e:logging.warning(f"跳过无效行 {i + 1}: {e}")elif file_path.endswith('.json'):with open(file_path, 'r') as f:try:data = json.load(f)except json.JSONDecodeError as e:logging.warning(f"跳过无效文件 {file_path}: {e}")return datadef __len__(self):return len(self.data)def __getitem__(self, idx):item = self.data[idx]question = item['question']human_answer = item['human_answers'][0]chatgpt_answer = item['chatgpt_answers'][0]try:inputs = self.tokenizer(question, return_tensors='pt', padding='max_length', truncation=True, max_length=self.max_length)human_inputs = self.tokenizer(human_answer, return_tensors='pt', padding='max_length', truncation=True, max_length=self.max_length)chatgpt_inputs = self.tokenizer(chatgpt_answer, return_tensors='pt', padding='max_length', truncation=True, max_length=self.max_length)except Exception as e:logging.warning(f"跳过无效项 {idx}: {e}")return self.__getitem__((idx + 1) % len(self.data))return {'input_ids': inputs['input_ids'].squeeze(),'attention_mask': inputs['attention_mask'].squeeze(),'human_input_ids': human_inputs['input_ids'].squeeze(),'human_attention_mask': human_inputs['attention_mask'].squeeze(),'chatgpt_input_ids': chatgpt_inputs['input_ids'].squeeze(),'chatgpt_attention_mask': chatgpt_inputs['attention_mask'].squeeze(),'human_answer': human_answer,'chatgpt_answer': chatgpt_answer}# 获取数据加载器
def get_data_loader(file_path, tokenizer, batch_size=8, max_length=128):dataset = XihuaDataset(file_path, tokenizer, max_length)return DataLoader(dataset, batch_size=batch_size, shuffle=True)# 模型定义
class XihuaModel(torch.nn.Module):def __init__(self, pretrained_model_name='F:/models/bert-base-chinese'):super(XihuaModel, self).__init__()self.bert = BertModel.from_pretrained(pretrained_model_name)self.classifier = torch.nn.Linear(self.bert.config.hidden_size, 1)def forward(self, input_ids, attention_mask):outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)pooled_output = outputs.pooler_outputlogits = self.classifier(pooled_output)return logits# 训练函数
def train(model, data_loader, optimizer, criterion, device, progress_var=None):model.train()total_loss = 0.0num_batches = len(data_loader)for batch_idx, batch in enumerate(data_loader):try:input_ids = batch['input_ids'].to(device)attention_mask = batch['attention_mask'].to(device)human_input_ids = batch['human_input_ids'].to(device)human_attention_mask = batch['human_attention_mask'].to(device)chatgpt_input_ids = batch['chatgpt_input_ids'].to(device)chatgpt_attention_mask = batch['chatgpt_attention_mask'].to(device)optimizer.zero_grad()human_logits = model(human_input_ids, human_attention_mask)chatgpt_logits = model(chatgpt_input_ids, chatgpt_attention_mask)human_labels = torch.ones(human_logits.size(0), 1).to(device)chatgpt_labels = torch.zeros(chatgpt_logits.size(0), 1).to(device)loss = criterion(human_logits, human_labels) + criterion(chatgpt_logits, chatgpt_labels)loss.backward()optimizer.step()total_loss += loss.item()if progress_var:progress_var.set((batch_idx + 1) / num_batches * 100)except Exception as e:logging.warning(f"跳过无效批次: {e}")return total_loss / len(data_loader)# 主训练函数
def main_train(retrain=False):device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')logging.info(f'使用设备: {device}')tokenizer = BertTokenizer.from_pretrained('F:/models/bert-base-chinese')model = XihuaModel(pretrained_model_name='F:/models/bert-base-chinese').to(device)if retrain:model_path = os.path.join(PROJECT_ROOT, 'models/xihua_model.pth')if os.path.exists(model_path):model.load_state_dict(torch.load(model_path, map_location=device))logging.info("加载现有模型")else:logging.info("没有找到现有模型,将使用预训练模型")optimizer = optim.Adam(model.parameters(), lr=1e-5)criterion = torch.nn.BCEWithLogitsLoss()train_data_loader = get_data_loader(os.path.join(PROJECT_ROOT, 'data/train_data.jsonl'), tokenizer, batch_size=8, max_length=128)num_epochs = 30for epoch in range(num_epochs):train_loss = train(model, train_data_loader, optimizer, criterion, device)logging.info(f'Epoch [{epoch+1}/{num_epochs}], Loss: {train_loss:.10f}')torch.save(model.state_dict(), os.path.join(PROJECT_ROOT, 'models/xihua_model.pth'))logging.info("模型训练完成并保存")# 网络搜索函数
def search_baidu(query):url = f"https://www.baidu.com/s?wd={query}"headers = {'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/58.0.3029.110 Safari/537.3'}response = requests.get(url, headers=headers)soup = BeautifulSoup(response.text, 'html.parser')results = soup.find_all('div', class_='c-abstract')if results:return results[0].get_text().strip()return "没有找到相关信息"# 百度百科搜索函数
def search_baidu_baike(query):url = f"https://baike.baidu.com/item/{query}"headers = {'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/58.0.3029.110 Safari/537.3'}response = requests.get(url, headers=headers)soup = BeautifulSoup(response.text, 'html.parser')meta_description = soup.find('meta', attrs={'name': 'description'})if meta_description:return meta_description['content']return "没有找到相关信息"# GUI界面
class XihuaChatbotGUI:def __init__(self, root):self.root = rootself.root.title("羲和聊天机器人")self.tokenizer = BertTokenizer.from_pretrained('F:/models/bert-base-chinese')self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')self.model = XihuaModel(pretrained_model_name='F:/models/bert-base-chinese').to(self.device)self.load_model()self.model.eval()# 加载训练数据集以便在获取答案时使用self.data = self.load_data(os.path.join(PROJECT_ROOT, 'data/train_data.jsonl'))# 历史记录self.history = []self.create_widgets()def create_widgets(self):# 设置样式style = ttk.Style()style.theme_use('clam')# 顶部框架top_frame = ttk.Frame(self.root)top_frame.pack(pady=10)self.question_label = ttk.Label(top_frame, text="问题:", font=("Arial", 12))self.question_label.grid(row=0, column=0, padx=10)self.question_entry = ttk.Entry(top_frame, width=50, font=("Arial", 12))self.question_entry.grid(row=0, column=1, padx=10)self.answer_button = ttk.Button(top_frame, text="获取回答", command=self.get_answer, style='TButton')self.answer_button.grid(row=0, column=2, padx=10)# 中部框架middle_frame = ttk.Frame(self.root)middle_frame.pack(pady=10)self.chat_text = tk.Text(middle_frame, height=20, width=100, font=("Arial", 12), wrap='word')self.chat_text.grid(row=0, column=0, padx=10, pady=10)self.chat_text.tag_configure("user", justify='right', foreground='blue')self.chat_text.tag_configure("xihua", justify='left', foreground='green')# 底部框架bottom_frame = ttk.Frame(self.root)bottom_frame.pack(pady=10)self.correct_button = ttk.Button(bottom_frame, text="准确", command=self.mark_correct, style='TButton')self.correct_button.grid(row=0, column=0, padx=10)self.incorrect_button = ttk.Button(bottom_frame, text="不准确", command=self.mark_incorrect, style='TButton')self.incorrect_button.grid(row=0, column=1, padx=10)self.train_button = ttk.Button(bottom_frame, text="训练模型", command=self.train_model, style='TButton')self.train_button.grid(row=0, column=2, padx=10)self.retrain_button = ttk.Button(bottom_frame, text="重新训练模型", command=lambda: self.train_model(retrain=True), style='TButton')self.retrain_button.grid(row=0, column=3, padx=10)self.progress_var = tk.DoubleVar()self.progress_bar = ttk.Progressbar(bottom_frame, variable=self.progress_var, maximum=100, length=200, mode='determinate')self.progress_bar.grid(row=1, column=0, columnspan=4, pady=10)self.log_text = tk.Text(bottom_frame, height=10, width=70, font=("Arial", 12))self.log_text.grid(row=2, column=0, columnspan=4, pady=10)self.evaluate_button = ttk.Button(bottom_frame, text="评估模型", command=self.evaluate_model, style='TButton')self.evaluate_button.grid(row=3, column=0, padx=10, pady=10)self.history_button = ttk.Button(bottom_frame, text="查看历史记录", command=self.view_history, style='TButton')self.history_button.grid(row=3, column=1, padx=10, pady=10)self.save_history_button = ttk.Button(bottom_frame, text="保存历史记录", command=self.save_history, style='TButton')self.save_history_button.grid(row=3, column=2, padx=10, pady=10)def get_answer(self):question = self.question_entry.get()if not question:messagebox.showwarning("输入错误", "请输入问题")returninputs = self.tokenizer(question, return_tensors='pt', padding='max_length', truncation=True, max_length=128)with torch.no_grad():input_ids = inputs['input_ids'].to(self.device)attention_mask = inputs['attention_mask'].to(self.device)logits = self.model(input_ids, attention_mask)if logits.item() > 0:answer_type = "羲和回答"else:answer_type = "零回答"specific_answer = self.get_specific_answer(question, answer_type)self.chat_text.insert(tk.END, f"用户: {question}\n", "user")self.chat_text.insert(tk.END, f"羲和: {specific_answer}\n", "xihua")# 添加到历史记录self.history.append({'question': question,'answer_type': answer_type,'specific_answer': specific_answer,'accuracy': None  # 初始状态为未评价})def get_specific_answer(self, question, answer_type):# 使用模糊匹配查找最相似的问题best_match = Nonebest_ratio = 0.0for item in self.data:ratio = SequenceMatcher(None, question, item['question']).ratio()if ratio > best_ratio:best_ratio = ratiobest_match = itemif best_match:if answer_type == "羲和回答":return best_match['human_answers'][0]else:return best_match['chatgpt_answers'][0]return "这个我也不清楚,你问问零吧"def load_data(self, file_path):data = []if file_path.endswith('.jsonl'):with jsonlines.open(file_path) as reader:for i, item in enumerate(reader):try:data.append(item)except jsonlines.jsonlines.InvalidLineError as e:logging.warning(f"跳过无效行 {i + 1}: {e}")elif file_path.endswith('.json'):with open(file_path, 'r') as f:try:data = json.load(f)except json.JSONDecodeError as e:logging.warning(f"跳过无效文件 {file_path}: {e}")return datadef load_model(self):model_path = os.path.join(PROJECT_ROOT, 'models/xihua_model.pth')if os.path.exists(model_path):self.model.load_state_dict(torch.load(model_path, map_location=self.device))logging.info("加载现有模型")else:logging.info("没有找到现有模型,将使用预训练模型")def train_model(self, retrain=False):file_path = filedialog.askopenfilename(filetypes=[("JSONL files", "*.jsonl"), ("JSON files", "*.json")])if not file_path:messagebox.showwarning("文件选择错误", "请选择一个有效的数据文件")returntry:dataset = XihuaDataset(file_path, self.tokenizer)data_loader = DataLoader(dataset, batch_size=8, shuffle=True)# 加载已训练的模型权重if retrain:self.model.load_state_dict(torch.load(os.path.join(PROJECT_ROOT, 'models/xihua_model.pth'), map_location=self.device))self.model.to(self.device)self.model.train()optimizer = torch.optim.Adam(self.model.parameters(), lr=1e-5)criterion = torch.nn.BCEWithLogitsLoss()num_epochs = 30for epoch in range(num_epochs):train_loss = train(self.model, data_loader, optimizer, criterion, self.device, self.progress_var)logging.info(f'Epoch [{epoch+1}/{num_epochs}], Loss: {train_loss:.10f}')self.log_text.insert(tk.END, f'Epoch [{epoch+1}/{num_epochs}], Loss: {train_loss:.10f}\n')self.log_text.see(tk.END)torch.save(self.model.state_dict(), os.path.join(PROJECT_ROOT, 'models/xihua_model.pth'))logging.info("模型训练完成并保存")self.log_text.insert(tk.END, "模型训练完成并保存\n")self.log_text.see(tk.END)messagebox.showinfo("训练完成", "模型训练完成并保存")except Exception as e:logging.error(f"模型训练失败: {e}")self.log_text.insert(tk.END, f"模型训练失败: {e}\n")self.log_text.see(tk.END)messagebox.showerror("训练失败", f"模型训练失败: {e}")def evaluate_model(self):# 这里可以添加模型评估的逻辑messagebox.showinfo("评估结果", "模型评估功能暂未实现")def mark_correct(self):if self.history:self.history[-1]['accuracy'] = Truemessagebox.showinfo("评价成功", "您认为这次回答是准确的")def mark_incorrect(self):if self.history:self.history[-1]['accuracy'] = Falsequestion = self.history[-1]['question']baike_answer = self.search_baidu_baike(question)self.chat_text.insert(tk.END, f"百度百科结果: {baike_answer}\n", "xihua")messagebox.showinfo("评价成功", "您认为这次回答是不准确的")def search_baidu_baike(self, query):return search_baidu_baike(query)def view_history(self):history_window = tk.Toplevel(self.root)history_window.title("历史记录")history_text = tk.Text(history_window, height=20, width=80, font=("Arial", 12))history_text.pack(padx=10, pady=10)for entry in self.history:history_text.insert(tk.END, f"问题: {entry['question']}\n")history_text.insert(tk.END, f"回答类型: {entry['answer_type']}\n")history_text.insert(tk.END, f"具体回答: {entry['specific_answer']}\n")if entry['accuracy'] is None:history_text.insert(tk.END, "评价: 未评价\n")elif entry['accuracy']:history_text.insert(tk.END, "评价: 准确\n")else:history_text.insert(tk.END, "评价: 不准确\n")history_text.insert(tk.END, "-" * 50 + "\n")def save_history(self):file_path = filedialog.asksaveasfilename(defaultextension=".json", filetypes=[("JSON files", "*.json")])if not file_path:returnwith open(file_path, 'w') as f:json.dump(self.history, f, ensure_ascii=False, indent=4)messagebox.showinfo("保存成功", "历史记录已保存到文件")# 主函数
if __name__ == "__main__":# 启动GUIroot = tk.Tk()app = XihuaChatbotGUI(root)root.mainloop()

主要修改点:
增加百度百科搜索函数:search_baidu_baike函数用于从百度百科中搜索问题的相关信息。
修改mark_incorrect方法:当用户标记回答为不正确时,调用search_baidu_baike函数获取百度百科的结果,并将其显示在GUI的Text组件中。
文件结构:
main.py:主程序文件,包含所有代码。
logs/:日志文件存储目录。
models/:模型权重文件存储目录。
data/:训练数据文件存储目录。
运行步骤:
确保安装了所有依赖库,如torch, transformers, requests, beautifulsoup4等。
将训练数据文件放在data/目录下。
运行main.py启动GUI。
这样,当用户在GUI中输入问题并标记回答为不正确时,程序会自动从百度百科中搜索相关信息并显示在GUI中。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/477568.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【SKFramework框架】一、框架介绍

推荐阅读 CSDN主页GitHub开源地址Unity3D插件分享QQ群:398291828小红书小破站 大家好,我是佛系工程师☆恬静的小魔龙☆,不定时更新Unity开发技巧,觉得有用记得一键三连哦。 一、前言 【Unity3D框架】SKFramework框架完全教程《全…

Python 版本的 2024详细代码

2048游戏的Python实现 概述: 2048是一款流行的单人益智游戏,玩家通过滑动数字瓷砖来合并相同的数字,目标是合成2048这个数字。本文将介绍如何使用Python和Pygame库实现2048游戏的基本功能,包括游戏逻辑、界面绘制和用户交互。 主…

排序算法(选择排序、直接插入排序、冒泡排序、二路归并排序)(C语言版)

对数组进行排序,主要演示选择排序、直接排序、冒泡排序、二路归并排序算法,附上代码演示 一、编写好各类排序方法的函数 (1) s_sort(int e[],int n):选择排序。 (2)si_sort(int e[],int n):直接插人排序。…

Unity图形学之Surface Shader结构

1.没有Pass:因为Render Path已经封装好了 Shader "Custom/Test" {Properties{_Color ("Color", Color) (1,1,1,1)_MainTex ("Albedo (RGB)", 2D) "white" {}_Glossiness ("Smoothness", Range(0,1)) 0.5_Meta…

数据集-目标检测系列- 牵牛花 检测数据集 morning_glory >> DataBall

数据集-目标检测系列- 牵牛花 检测数据集 morning DataBall 助力快速掌握数据集的信息和使用方式,会员享有 百种数据集,持续增加中。 贵在坚持! 数据样例项目地址: * 相关项目 1)数据集可视化项目:git…

摄影:相机控色

摄影:相机控色 白平衡(White Balance)白平衡的作用: 白平衡的使用环境色温下相机色温下总结 白平衡偏移与包围白平衡包围 影调 白平衡(White Balance) 人眼看到的白色:会自动适应环境光线。 相…

Odoo :免费且开源的农牧行业ERP管理系统

文 / 开源智造Odoo亚太金牌服务 引言 提供农牧企业数字化、智能化、无人化产品服务及全产业链高度协同的一体化解决方案,提升企业智慧种养、成本领先、产业互联的核心竞争力。 行业典型痛点 一、成本管理粗放,效率低、管控弱 产品研发过程缺少体系化…

oracle查看锁阻塞-谁阻塞了谁

一 模拟锁阻塞 #阻塞1 一个会话正在往一个大表写入大量数据的时候,另一个会话加字段: #会话1 #会话2 会话2被阻塞了。 #阻塞2 模拟一个会话update一条记录,没提交。 另一个会话也update这一条记录: 会话2被阻塞了。 二 简单查…

HDR视频技术之三:色度学与颜色空间

HDR 技术的第二个理论基础是色度学。从前面的内容中可以了解到,光学以及人类视觉感知模型为人类提供了解释与分析人类感知亮度的理论基础,但是 HDR 技术不仅仅关注于提升图像与视频的亮度范围,同时也关注于提供更加丰富的色彩。因此&#xff…

神经网络中常用的激活函数(公式 + 函数图像)

激活函数是人工神经网络中的一个关键组件,负责引入非线性,从而使神经网络能够学习和表示复杂的非线性关系。没有激活函数,神经网络中的所有计算都是线性变换,而线性模型的表达能力有限,无法处理复杂的任务。 激活函数…

Redis——Raft算法

Raft使用较为广泛的强一致性、去中心化、高可用的分布式协议,即使在网络、节点故障等情况下,多个节点依然能达到一致性。 其中redis、etcd等都用到了这种算法 在Redis集群中,采取的主从复制结构,当主节点宕机后,哨兵会…

C 语言复习总结记录二

C 语言复习总结记录二 一 控制语句 1、语句的分类 表达式语句函数调用语句复合语句控制语句空语句 控制语句 控制程序的执行流程,实现程序的各种结构方式 C 语言支持三种结构 :顺序结构、选择结构、循环结构,由特定的语句定义符组成C语言…

网络无人值守批量装机-cobbler

网络无人值守批量装机-cobbler 一、cobbler简介 ​ 上一节中的pxe+kickstart已经可以解决网络批量装机的问题了,但是环境配置过于复杂,而且仅针对某一个版本的操作系统进批量安装则无法满足目前复杂环境的部署需求。 ​ 本小节所讲的cobbler则是基于pxe+kickstart技术的二…

基于深度学习CNN算法的花卉分类识别系统01--带数据集-pyqt5UI界面-全套源码

文章目录 基于深度学习算法的花卉分类识别系统一、项目摘要二、项目运行效果三、项目文件介绍四、项目环境配置1、项目环境库2、环境配置视频教程 五、项目系统架构六、项目构建流程1、数据集2、算法网络Mobilenet3、网络模型训练4、训练好的模型预测5、UI界面设计-pyqt56、项目…

HarmonyOs鸿蒙开发实战(20)=>一文学会基础使用组件导航Navigation

敲黑板,以下是重点技巧。文章末尾有实战项目效果截图及代码截图可参考 1.概要 Navigation是路由导航的根视图容器Navigation组件主要包含​导航页(NavBar)和子页(NavDestination),导航页不存在页面栈中&am…

tcpdump抓包 wireShark

TCPdump抓包工具介绍 TCPdump,全称dump the traffic on anetwork,是一个运行在linux平台可以根据使用者需求对网络上传输的数据包进行捕获的抓包工具。 tcpdump可以支持的功能: 1、在Linux平台将网络中传输的数据包全部捕获过来进行分析 2、支持网络层…

已阻止加载“http://localhost:8086/xxx.js”的模块,它使用了不允许的 MIME 类型 (“text/plain”)。

记录今天解决的一个小bug 在终端启动8080端口号监听后,打开网址http://localhost:8080,发现不能正确加载页面,打开检查-控制台,出现如下警告:已阻止加载“http://localhost:8086/xxx.js”的模块,它使用了不…

使用 helm 部署 gitlab

一、下载 Gitlab chart 进入 artifacthub 官网 选择你想要的版本(我选择的chart版本是 8.4.0 , gitlab 版本是17.4.0 ) 进入到控制台,添加helm仓库 如果你想不改任何配置,你可以执行安装命令,等待安装即可helm instal…

若依-一个请求中返回多个表的信息

背景 主表是点位表关联子表 需要知道对应 合作商的信息关联子表 需要直到对应 区域的信息关联子表 需要直到对应 设备数量 实现的方案 关联实体相关的标签

C++注释

目录 1. 什么是注释 2. 语法 2.1 单行注释 2.2 多行注释 2.3 示例 3. 注释源代码的方法 3.1 使用多行注释 3.2 使用预处理器指令 #if #endif 3.3 使用条件判断语句 if (false) 4. 不能用宏定义,组成注释 5 // \ 会将源代码中的下一行也被当作注释中的一部…