聊天机器人羲和的代码04。1
本项目的主要目录与文件：data/ 存放训练数据文件，例如 train_data.jsonl；models/ 存放模型文件，例如 xihua_model.pth 和预训练模型文件夹 bert-base-chinese；icons/ 存放图标文件，例如 icon.ico；model.py 定义模型类 XihuaModel；requirements.txt 列出项目依赖的库。按以下结构组织后，项目目录清晰，便于管理和扩展。
项目目录结构
code
XihuaChatbot/
├── data/
│ └── train_data.jsonl
├── logs/
│ └── (自动创建的日志文件)
├── models/
│ ├── xihua_model.pth
│ └── bert-base-chinese/
├── icons/
│ └── icon.ico
├── src/
│ ├── __init__.py
│ ├── dataset.py
│ ├── model.py
│ ├── gui.py
│ └── main.py
└── requirements.txt
目录说明
XihuaChatbot/: 项目根目录。
data/: 存放训练数据文件,例如 train_data.jsonl。
logs/: 存放日志文件,自动创建。
models/: 存放模型文件,例如 xihua_model.pth 和预训练模型文件夹 bert-base-chinese。
icons/: 存放图标文件,例如 icon.ico。
src/: 存放源代码文件。
__init__.py: 使 src 成为一个 Python 包。
dataset.py: 定义数据集类 XihuaDataset。
model.py: 定义模型类 XihuaModel。
gui.py: 定义图形用户界面类 XihuaChatbotGUI。
main.py: 主入口文件,启动应用程序。
requirements.txt: 列出项目依赖的库。
文件内容
src/dataset.py
python
import os
import json
import jsonlines
from difflib import SequenceMatcher
import logging


class XihuaDataset:
    """Question/answer dataset loaded from a ``.jsonl`` or ``.json`` file.

    Each kept record is a dict with a ``question`` string and non-empty
    ``human_answers`` / ``chatgpt_answers`` lists; only element 0 of each
    answer list is used. ``__getitem__`` returns the tokenized question, both
    tokenized answers, and the raw answer strings.
    """

    def __init__(self, file_path, tokenizer, max_length=128):
        """Load and validate the records in *file_path*.

        Args:
            file_path: path ending in ``.jsonl`` or ``.json``; any other
                extension yields an empty dataset.
            tokenizer: callable invoked like a Hugging Face tokenizer
                (``return_tensors='pt'``) returning ``input_ids`` and
                ``attention_mask`` tensors.
            max_length: length every text is padded/truncated to.
        """
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.data = self.load_data(file_path)

    def load_data(self, file_path):
        """Return the valid records stored in *file_path*.

        For ``.jsonl`` files, lines that are not valid JSON and records
        rejected by :meth:`validate_item` are skipped with a warning. For
        ``.json`` files, the top-level array is filtered with
        :meth:`validate_item`; an undecodable file logs a warning and yields
        an empty list.
        """
        data = []
        if file_path.endswith('.jsonl'):
            with jsonlines.open(file_path) as reader:
                line_no = 0
                while True:
                    line_no += 1
                    try:
                        item = reader.read()
                    except EOFError:
                        break
                    except jsonlines.InvalidLineError as e:
                        # The reader itself raises on a malformed line; catching
                        # it here (per line) lets the rest of the file load.
                        logging.warning(f"跳过无效行 {line_no}: {e}")
                        continue
                    if self.validate_item(item):
                        data.append(item)
        elif file_path.endswith('.json'):
            # Explicit UTF-8: the data is Chinese text and the platform default
            # encoding (e.g. on Windows) is not guaranteed to be UTF-8.
            with open(file_path, 'r', encoding='utf-8') as f:
                try:
                    data = [item for item in json.load(f) if self.validate_item(item)]
                except json.JSONDecodeError as e:
                    logging.warning(f"跳过无效文件 {file_path}: {e}")
        return data

    def validate_item(self, item):
        """Return True if *item* is a usable record, logging a warning otherwise.

        Usable means: a dict holding a ``question`` string and non-empty
        ``human_answers`` / ``chatgpt_answers`` lists (element 0 of each list
        is read by ``__getitem__``, so empty lists would fail there).
        """
        required_keys = ['question', 'human_answers', 'chatgpt_answers']
        if not isinstance(item, dict) or not all(key in item for key in required_keys):
            logging.warning(f"跳过无效项: 缺少必要键 {required_keys}")
            return False
        answers_ok = all(
            isinstance(item[key], list) and item[key]
            for key in ('human_answers', 'chatgpt_answers')
        )
        if not isinstance(item['question'], str) or not answers_ok:
            logging.warning(f"跳过无效项: 问题或回答格式错误 {item.get('question')!r}")
            return False
        return True

    def __len__(self):
        """Return the number of valid records."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return the tokenized tensors and raw answers for record *idx*.

        If tokenizing a record fails, the following records are tried in turn
        (wrapping around) so one bad record does not abort training.

        Raises:
            IndexError: *idx* is out of range (also keeps the legacy
                ``for x in dataset`` iteration protocol terminating).
            RuntimeError: no record in the dataset can be tokenized.
        """
        count = len(self.data)
        if not -count <= idx < count:
            raise IndexError(f"index {idx} out of range for dataset of size {count}")
        last_error = None
        for offset in range(count):
            current = (idx + offset) % count
            try:
                return self._encode(self.data[current])
            except Exception as e:  # best-effort: any tokenizer failure skips the record
                logging.warning(f"跳过无效项 {current}: {e}")
                last_error = e
        raise RuntimeError("所有数据项都无法编码") from last_error

    def _encode(self, item):
        """Tokenize one record into the dict returned by ``__getitem__``."""
        human_answer = item['human_answers'][0]
        chatgpt_answer = item['chatgpt_answers'][0]
        input_ids, attention_mask = self._tokenize(item['question'])
        human_ids, human_mask = self._tokenize(human_answer)
        chatgpt_ids, chatgpt_mask = self._tokenize(chatgpt_answer)
        return {
            'input_ids': input_ids,
            'attention_mask': attention_mask,
            'human_input_ids': human_ids,
            'human_attention_mask': human_mask,
            'chatgpt_input_ids': chatgpt_ids,
            'chatgpt_attention_mask': chatgpt_mask,
            'human_answer': human_answer,
            'chatgpt_answer': chatgpt_answer,
        }

    def _tokenize(self, text):
        """Return ``(input_ids, attention_mask)`` for *text* without the batch axis."""
        encoded = self.tokenizer(text, return_tensors='pt', padding='max_length',
                                 truncation=True, max_length=self.max_length)
        # squeeze(0) removes only the batch axis; a bare squeeze() would also
        # collapse the sequence axis when max_length == 1.
        return encoded['input_ids'].squeeze(0), encoded['attention_mask'].squeeze(0)
src/model.py
python
import torch
from transformers import BertModel
class XihuaModel(torch.nn.Module):
    """BERT encoder followed by a single-output linear head.

    Args:
        pretrained_model_name: model id or local directory handed to
            ``BertModel.from_pretrained``. The default is an absolute Windows
            path; on other machines pass the project's
            ``models/bert-base-chinese`` directory explicitly.
    """

    def __init__(self, pretrained_model_name='F:/models/bert-base-chinese'):
        super().__init__()
        self.bert = BertModel.from_pretrained(pretrained_model_name)
        hidden_size = self.bert.config.hidden_size
        # One output unit: a single score per input sequence.
        self.classifier = torch.nn.Linear(hidden_size, 1)

    def forward(self, input_ids, attention_mask):
        """Return the linear head applied to BERT's ``pooler_output``.

        The last dimension is 1 (``Linear(hidden_size, 1)``); presumably the
        result is ``(batch, 1)`` — callers such as the GUI call ``.item()`` on
        it for a single question.
        """
        encoded = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        return self.classifier(encoded.pooler_output)
src/gui.py
python
import json
import logging
import os
import tkinter as tk
from datetime import datetime
from difflib import SequenceMatcher
from tkinter import messagebox

import jsonlines
import speech_recognition as sr
import torch
import ttkbootstrap as ttk
from torch.utils.data import Dataset, DataLoader
from transformers import BertModel, BertTokenizer
from ttkbootstrap.constants import *

from src.model import XihuaModel
# Project root directory: the parent of the directory holding this file
# (this module lives in src/, so the root is the XihuaChatbot/ folder).
_THIS_DIR = os.path.dirname(os.path.abspath(__file__))
PROJECT_ROOT = os.path.dirname(_THIS_DIR)
class XihuaChatbotGUI:
    """ttkbootstrap window for the Xihua chatbot.

    A question (typed, or transcribed from the microphone) is scored by
    ``XihuaModel``. A positive logit selects the stored human answer
    ("羲和回答"); otherwise the stored ChatGPT answer ("零回答") is used. The
    stored answer comes from the training record whose question is most
    similar (``difflib.SequenceMatcher`` ratio) to the input.
    """

    # Absolute BERT path used by the original code; only used when the
    # project-local models/bert-base-chinese directory does not exist.
    _FALLBACK_BERT_PATH = 'F:/models/bert-base-chinese'

    def __init__(self, root):
        """Configure *root*, load tokenizer/model/data, and build the widgets.

        Args:
            root: the Tk/ttkbootstrap root window hosting the GUI.
        """
        self.root = root
        self.root.title("羲和聊天机器人")
        self.root.geometry("600x400")
        self._set_icon(os.path.join(PROJECT_ROOT, 'icons', 'icon.ico'))

        bert_path = self._resolve_bert_path()
        self.tokenizer = BertTokenizer.from_pretrained(bert_path)
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.model = XihuaModel(pretrained_model_name=bert_path).to(self.device)
        self.load_model()
        self.model.eval()
        # Training data is kept in memory so get_specific_answer can look up answers.
        self.data = self.load_data(os.path.join(PROJECT_ROOT, 'data', 'train_data.jsonl'))
        self.create_widgets()

    def _set_icon(self, icon_path):
        """Set the window icon, logging a warning instead of failing if Tk rejects it."""
        try:
            self.root.iconbitmap(icon_path)
        except tk.TclError as e:
            # iconbitmap may not accept .ico files on every platform, and the
            # file may be missing; the window works without an icon.
            logging.warning(f"无法设置窗口图标 {icon_path}: {e}")

    def _resolve_bert_path(self):
        """Return models/bert-base-chinese under the project root if it exists, else the fallback path."""
        local_path = os.path.join(PROJECT_ROOT, 'models', 'bert-base-chinese')
        return local_path if os.path.isdir(local_path) else self._FALLBACK_BERT_PATH

    def create_widgets(self):
        """Build the question-input, answer-display and control areas."""
        main_frame = ttk.Frame(self.root, padding=20)
        main_frame.pack(fill=tk.BOTH, expand=True)

        # Question input area
        question_frame = ttk.LabelFrame(main_frame, text="问题输入", padding=10)
        question_frame.grid(row=0, column=0, padx=10, pady=10, sticky=tk.W + tk.E)
        self.question_label = ttk.Label(question_frame, text="问题:", font=("Arial", 14))
        self.question_label.grid(row=0, column=0, pady=10, sticky=tk.W)
        self.question_entry = ttk.Entry(question_frame, width=50, font=("Arial", 12))
        self.question_entry.grid(row=0, column=1, pady=10, sticky=tk.W)

        # Both action buttons sit in a sub-frame spanning the two columns
        # above, one button per column, so neither is drawn over the other.
        button_frame = ttk.Frame(question_frame)
        button_frame.grid(row=1, column=0, columnspan=2, pady=10, sticky=tk.W + tk.E)
        for column in (0, 1):
            button_frame.columnconfigure(column, weight=1)
        self.answer_button = ttk.Button(button_frame, text="获取回答", command=self.get_answer, style='success.TButton')
        self.answer_button.grid(row=0, column=0, padx=(0, 5), sticky=tk.W + tk.E)
        self.voice_button = ttk.Button(button_frame, text="语音输入", command=self.recognize_speech, style='primary.TButton')
        self.voice_button.grid(row=0, column=1, padx=(5, 0), sticky=tk.W + tk.E)

        # Answer display area
        answer_frame = ttk.LabelFrame(main_frame, text="回答显示", padding=10)
        answer_frame.grid(row=1, column=0, padx=10, pady=10, sticky=tk.W + tk.E)
        self.answer_label = ttk.Label(answer_frame, text="回答:", font=("Arial", 14))
        self.answer_label.grid(row=0, column=0, pady=10, sticky=tk.W)
        self.answer_text = tk.Text(answer_frame, height=10, width=50, font=("Arial", 12))
        self.answer_text.grid(row=1, column=0, pady=5, sticky=tk.W + tk.E)

        # Control button area
        control_frame = ttk.Frame(main_frame, padding=10)
        control_frame.grid(row=2, column=0, padx=10, pady=10, sticky=tk.W + tk.E)
        self.clear_button = ttk.Button(control_frame, text="清除历史记录", command=self.clear_history, style='danger.TButton')
        self.clear_button.grid(row=0, column=0, pady=5, padx=5, sticky=tk.W + tk.E)
        self.save_button = ttk.Button(control_frame, text="保存历史记录", command=self.save_history, style='info.TButton')
        self.save_button.grid(row=0, column=1, pady=5, padx=5, sticky=tk.W + tk.E)

    def get_answer(self):
        """Answer the question in the entry box and append the result to the history."""
        question = self.question_entry.get().strip()
        if not question:
            messagebox.showwarning("输入错误", "请输入问题")
            return
        self.answer_text.insert(tk.END, "正在获取回答...\n")
        self.answer_text.update_idletasks()
        inputs = self.tokenizer(question, return_tensors='pt', padding='max_length', truncation=True, max_length=128)
        with torch.no_grad():
            input_ids = inputs['input_ids'].to(self.device)
            attention_mask = inputs['attention_mask'].to(self.device)
            logits = self.model(input_ids, attention_mask)
        # One question -> one logit, so .item() is safe; its sign picks the answer source.
        answer_type = "羲和回答" if logits.item() > 0 else "零回答"
        specific_answer = self.get_specific_answer(question, answer_type)
        self.answer_text.insert(tk.END, f"问题: {question}\n{answer_type}\n{specific_answer}\n\n")
        self.answer_text.see(tk.END)

    def get_specific_answer(self, question, answer_type):
        """Return the stored answer of the training question most similar to *question*.

        Args:
            question: the user's question.
            answer_type: "羲和回答" selects ``human_answers[0]``; anything else
                selects ``chatgpt_answers[0]``.

        Returns:
            The chosen answer, or a fallback message when no training question
            has a similarity ratio above 0 (including when no data is loaded).
        """
        best_match = None
        best_ratio = 0.0
        for item in self.data:
            ratio = SequenceMatcher(None, question, item['question']).ratio()
            if ratio > best_ratio:
                best_ratio = ratio
                best_match = item
        if best_match is None:
            return "这个我也不清楚,你问问零吧"
        if answer_type == "羲和回答":
            return best_match['human_answers'][0]
        return best_match['chatgpt_answers'][0]

    def load_data(self, file_path):
        """Return the valid records stored in *file_path* (``.jsonl`` or ``.json``).

        A missing file logs a warning and yields an empty list, so the GUI
        still starts (answers then fall back to the default message).
        Malformed JSON lines and records rejected by :meth:`validate_item`
        are skipped with a warning.
        """
        data = []
        if not os.path.exists(file_path):
            logging.warning(f"数据文件不存在 {file_path}")
            return data
        if file_path.endswith('.jsonl'):
            with jsonlines.open(file_path) as reader:
                line_no = 0
                while True:
                    line_no += 1
                    try:
                        item = reader.read()
                    except EOFError:
                        break
                    except jsonlines.InvalidLineError as e:
                        # The reader itself raises on a malformed line; catching
                        # it per line lets the rest of the file load.
                        logging.warning(f"跳过无效行 {line_no}: {e}")
                        continue
                    if self.validate_item(item):
                        data.append(item)
        elif file_path.endswith('.json'):
            # Explicit UTF-8: the platform default encoding may not be UTF-8.
            with open(file_path, 'r', encoding='utf-8') as f:
                try:
                    data = [item for item in json.load(f) if self.validate_item(item)]
                except json.JSONDecodeError as e:
                    logging.warning(f"跳过无效文件 {file_path}: {e}")
        return data

    def validate_item(self, item):
        """Return True if *item* is a usable record, logging a warning otherwise.

        Usable means: a dict holding a ``question`` string and non-empty
        ``human_answers`` / ``chatgpt_answers`` lists (element 0 of each list
        is read by get_specific_answer).
        """
        required_keys = ['question', 'human_answers', 'chatgpt_answers']
        if not isinstance(item, dict) or not all(key in item for key in required_keys):
            logging.warning(f"跳过无效项: 缺少必要键 {required_keys}")
            return False
        answers_ok = all(
            isinstance(item[key], list) and item[key]
            for key in ('human_answers', 'chatgpt_answers')
        )
        if not isinstance(item['question'], str) or not answers_ok:
            logging.warning(f"跳过无效项: 问题或回答格式错误 {item.get('question')!r}")
            return False
        return True

    def load_model(self):
        """Load fine-tuned weights from models/xihua_model.pth if present; otherwise keep the pretrained weights."""
        model_path = os.path.join(PROJECT_ROOT, 'models', 'xihua_model.pth')
        if os.path.exists(model_path):
            self.model.load_state_dict(torch.load(model_path, map_location=self.device, weights_only=True))
            logging.info("加载现有模型")
        else:
            logging.info("没有找到现有模型,将使用预训练模型")

    def clear_history(self):
        """Remove all text from the answer history box."""
        self.answer_text.delete(1.0, tk.END)

    def save_history(self):
        """Write the answer history to history_<timestamp>.txt in the current working directory."""
        history = self.answer_text.get(1.0, tk.END)
        if not history.strip():
            messagebox.showwarning("保存失败", "没有历史记录可保存")
            return
        timestamp = datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
        # Explicit UTF-8 so every character round-trips regardless of the
        # platform's default encoding.
        with open(f'history_{timestamp}.txt', 'w', encoding='utf-8') as f:
            f.write(history)
        messagebox.showinfo("保存成功", "历史记录已保存")

    def recognize_speech(self):
        """Record one utterance and put its Google (zh-CN) transcription into the entry box.

        Note: recording runs on the Tk thread, so the window is unresponsive
        while listening.
        """
        recognizer = sr.Recognizer()
        try:
            with sr.Microphone() as source:
                recognizer.adjust_for_ambient_noise(source)
                audio = recognizer.listen(source)
        except (AttributeError, OSError) as e:
            # speech_recognition reports a missing PyAudio as AttributeError;
            # a missing or busy input device surfaces as OSError.
            logging.warning(f"无法打开麦克风: {e}")
            messagebox.showwarning("识别失败", "无法打开麦克风")
            return
        try:
            question = recognizer.recognize_google(audio, language='zh-CN')
            self.question_entry.delete(0, tk.END)
            self.question_entry.insert(0, question)
        except sr.UnknownValueError:
            messagebox.showwarning("识别失败", "无法识别语音")
        except sr.RequestError:
            messagebox.showwarning("网络错误", "无法连接到语音识别服务")
src/main.py
python
import os
import ttkbootstrap as ttk
import logging
import sys

# Project root directory (the parent of src/).
PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))

# `python src/main.py` puts src/ (not the project root) at the front of
# sys.path, so the absolute `src.gui` import below needs the root added first.
if PROJECT_ROOT not in sys.path:
    sys.path.insert(0, PROJECT_ROOT)

from src.gui import XihuaChatbotGUI  # noqa: E402  (must follow the sys.path fix)


def setup_logging():
    """Send INFO-and-above log records to logs/xihua_chatbot.log under the project root.

    The logs/ directory is created if it does not exist yet.
    """
    log_dir = os.path.join(PROJECT_ROOT, 'logs')
    os.makedirs(log_dir, exist_ok=True)
    handler = logging.FileHandler(os.path.join(log_dir, 'xihua_chatbot.log'), encoding='utf-8')
    logging.basicConfig(
        level=logging.INFO,
        format='%(asctime)s %(levelname)s %(name)s: %(message)s',
        handlers=[handler],
    )


if __name__ == "__main__":
    setup_logging()
    # Start the GUI.
    root = ttk.Window(themename='litera')
    app = XihuaChatbotGUI(root)
    root.mainloop()
requirements.txt
code
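torch
transformers
ttkbootstrap
speechrecognition
# PyAudio is required by speech_recognition.Microphone (voice input)
PyAudio
jsonlines
# logging, difflib and datetime are Python standard-library modules and are not installed with pip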
项目运行
确保所有依赖库已安装:
sh
pip install -r requirements.txt
运行主程序（在项目根目录 XihuaChatbot/ 下执行，这样 src 包才能被导入）:
sh
python -m src.main
通过以上步骤,您的项目将具有清晰的目录结构,便于管理和扩展。希望这能帮助到您!
DAMO开发者矩阵,由阿里巴巴达摩院和中国互联网协会联合发起,致力于探讨最前沿的技术趋势与应用成果,搭建高质量的交流与分享平台,推动技术创新与产业应用链接,围绕“人工智能与新型计算”构建开放共享的开发者生态。
更多推荐

所有评论(0)