项目目录结构
code
XihuaChatbot/
├── data/
│ └── train_data.jsonl
├── logs/
│ └── (自动创建的日志文件)
├── models/
│ ├── xihua_model.pth
│ └── bert-base-chinese/
├── icons/
│ └── icon.ico
├── src/
│ ├── __init__.py
│ ├── dataset.py
│ ├── model.py
│ ├── gui.py
│ └── main.py
└── requirements.txt
目录说明
XihuaChatbot/: 项目根目录。
data/: 存放训练数据文件,例如 train_data.jsonl。
logs/: 存放日志文件,自动创建。
models/: 存放模型文件,例如 xihua_model.pth 和预训练模型文件夹 bert-base-chinese。
icons/: 存放图标文件,例如 icon.ico。
src/: 存放源代码文件。
__init__.py: 使 src 成为一个 Python 包(文件名前后各有两个下划线)。
dataset.py: 定义数据集类 XihuaDataset。
model.py: 定义模型类 XihuaModel。
gui.py: 定义图形用户界面类 XihuaChatbotGUI。
main.py: 主入口文件,启动应用程序。
requirements.txt: 列出项目依赖的库。
文件内容
src/dataset.py
python

import os
import json
import jsonlines
from difflib import SequenceMatcher

class XihuaDataset:
    """Tokenized question/answer dataset loaded from a JSON or JSON Lines file.

    Every record must be a dict with a string ``question`` and non-empty
    ``human_answers`` / ``chatgpt_answers`` lists whose first element is a
    string; only that first answer of each kind is used.

    Args:
        file_path: Path to a ``.jsonl`` file (one JSON object per line) or a
            ``.json`` file (a top-level JSON array of objects).
        tokenizer: Callable used like a Hugging Face tokenizer:
            ``tokenizer(text, return_tensors='pt', padding='max_length',
            truncation=True, max_length=...)``, returning a mapping with
            ``input_ids`` and ``attention_mask`` tensors.
        max_length: Length every sequence is padded/truncated to.
    """

    # Fields every record must provide (checked by validate_item).
    REQUIRED_KEYS = ('question', 'human_answers', 'chatgpt_answers')

    def __init__(self, file_path, tokenizer, max_length=128):
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.data = self.load_data(file_path)

    def load_data(self, file_path):
        """Load the valid records from ``file_path``.

        Malformed lines of a ``.jsonl`` file, or a malformed ``.json`` file,
        are logged and skipped; records rejected by :meth:`validate_item` are
        dropped. An unsupported extension yields an empty list plus a warning.

        Returns:
            list[dict]: The valid records, in file order.

        Raises:
            OSError: If the file cannot be opened (e.g. it does not exist).
        """
        import logging

        data = []
        if file_path.endswith('.jsonl'):
            # Parse each line ourselves: when iterating a jsonlines reader the
            # parse error is raised by the ``for`` statement itself, so a
            # try/except inside the loop body can never skip a bad line.
            # 'utf-8-sig' also accepts files saved with a BOM (e.g. by Notepad).
            with open(file_path, 'r', encoding='utf-8-sig') as f:
                for line_no, line in enumerate(f, start=1):
                    if not line.strip():
                        continue
                    try:
                        item = json.loads(line)
                    except json.JSONDecodeError as e:
                        logging.warning(f"跳过无效行 {line_no}: {e}")
                        continue
                    if self.validate_item(item):
                        data.append(item)
        elif file_path.endswith('.json'):
            # Explicit encoding: the platform default (GBK on Chinese Windows)
            # cannot decode UTF-8 training data.
            with open(file_path, 'r', encoding='utf-8-sig') as f:
                try:
                    items = json.load(f)
                except json.JSONDecodeError as e:
                    logging.warning(f"跳过无效文件 {file_path}: {e}")
                    return data
            if not isinstance(items, list):
                logging.warning(f"跳过无效文件 {file_path}: 顶层必须是 JSON 数组")
                return data
            data = [item for item in items if self.validate_item(item)]
        else:
            logging.warning(f"不支持的文件格式: {file_path}")
        return data

    def validate_item(self, item):
        """Return True if ``item`` has everything :meth:`__getitem__` reads.

        ``item`` must be a dict containing all of ``REQUIRED_KEYS`` with a
        string ``question``, and both answer fields must be non-empty lists
        whose first element is a string (``__getitem__`` uses element 0).
        Rejected items are logged.
        """
        import logging

        required_keys = list(self.REQUIRED_KEYS)
        if not isinstance(item, dict) or not all(key in item for key in required_keys):
            logging.warning(f"跳过无效项: 缺少必要键 {required_keys}")
            return False
        if not isinstance(item['question'], str):
            logging.warning("跳过无效项: question 必须是字符串")
            return False
        for key in ('human_answers', 'chatgpt_answers'):
            answers = item[key]
            if not isinstance(answers, list) or not answers or not isinstance(answers[0], str):
                logging.warning(f"跳过无效项: {key} 必须是以字符串开头的非空列表")
                return False
        return True

    def __len__(self):
        return len(self.data)

    def _encode(self, text):
        """Tokenize ``text`` padded/truncated to ``self.max_length`` as PyTorch tensors."""
        return self.tokenizer(text, return_tensors='pt', padding='max_length',
                              truncation=True, max_length=self.max_length)

    def __getitem__(self, idx):
        """Return the tokenized record at ``idx``.

        If tokenizing a record fails, the failure is logged and the following
        records (wrapping around) are tried instead, so one bad record does not
        abort an epoch.

        Returns:
            dict: ``input_ids``/``attention_mask`` for the question, the same
            pair prefixed with ``human_`` / ``chatgpt_`` for the first answer of
            each kind (each tensor with its leading batch dimension of the
            tokenizer output removed), plus the raw ``human_answer`` and
            ``chatgpt_answer`` strings.

        Raises:
            IndexError: If ``idx`` is out of range.
            RuntimeError: If no record in the dataset can be tokenized.
        """
        import logging

        n = len(self.data)
        if not -n <= idx < n:
            raise IndexError(f"index {idx} out of range for dataset of size {n}")

        # Iterate rather than recurse: recursive retries overflowed the stack on
        # long runs of bad records and never terminated if all of them failed.
        for offset in range(n):
            cur = (idx + offset) % n
            item = self.data[cur]
            human_answer = item['human_answers'][0]
            chatgpt_answer = item['chatgpt_answers'][0]
            try:
                inputs = self._encode(item['question'])
                human_inputs = self._encode(human_answer)
                chatgpt_inputs = self._encode(chatgpt_answer)
            except Exception as e:  # tokenizer failures are data-dependent; skip the record
                logging.warning(f"跳过无效项 {cur}: {e}")
                continue

            # squeeze(0) drops only the batch dimension; a bare squeeze() would
            # also collapse the sequence dimension when max_length == 1.
            return {
                'input_ids': inputs['input_ids'].squeeze(0),
                'attention_mask': inputs['attention_mask'].squeeze(0),
                'human_input_ids': human_inputs['input_ids'].squeeze(0),
                'human_attention_mask': human_inputs['attention_mask'].squeeze(0),
                'chatgpt_input_ids': chatgpt_inputs['input_ids'].squeeze(0),
                'chatgpt_attention_mask': chatgpt_inputs['attention_mask'].squeeze(0),
                'human_answer': human_answer,
                'chatgpt_answer': chatgpt_answer
            }
        raise RuntimeError("数据集中没有可以成功分词的样本")

src/model.py

python
import torch
from transformers import BertModel

class XihuaModel(torch.nn.Module):
    """BERT encoder followed by a single-logit linear head on the pooled output.

    Args:
        pretrained_model_name: Model id or local directory handed to
            ``BertModel.from_pretrained``. The default is an absolute,
            machine-specific Windows path; pass the real location elsewhere.
    """

    def __init__(self, pretrained_model_name='F:/models/bert-base-chinese'):
        super().__init__()
        # The attribute names ``bert`` / ``classifier`` become the state_dict
        # key prefixes of saved checkpoints, so they must stay unchanged.
        self.bert = BertModel.from_pretrained(pretrained_model_name)
        hidden_size = self.bert.config.hidden_size
        self.classifier = torch.nn.Linear(hidden_size, 1)

    def forward(self, input_ids, attention_mask):
        """Return the raw (un-activated) output of the linear head.

        Args:
            input_ids: Token ids; presumably ``(batch, seq_len)`` — TODO confirm.
            attention_mask: Mask matching ``input_ids``.

        Returns:
            The classifier output over BERT's ``pooler_output`` (one value per
            sequence, since the head has a single output unit).
        """
        encoded = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        return self.classifier(encoded.pooler_output)

src/gui.py
python

import os
import json
import logging
import tkinter as tk
from datetime import datetime
from difflib import SequenceMatcher
from tkinter import messagebox

import jsonlines
import speech_recognition as sr
import torch
import ttkbootstrap as ttk
from torch.utils.data import Dataset, DataLoader
from transformers import BertModel, BertTokenizer
from ttkbootstrap.constants import *

from src.model import XihuaModel

# Project root: the parent of the directory that contains this file (src/).
PROJECT_ROOT = os.path.abspath(os.path.join(os.path.dirname(os.path.abspath(__file__)), os.pardir))

class XihuaChatbotGUI:
    """Desktop chat window for the Xihua chatbot (ttkbootstrap/Tk).

    The entered question is tokenized and scored by ``XihuaModel``. A positive
    logit selects the "羲和回答" answer type and anything else selects "零回答".
    The displayed answer text comes from the training record whose question is
    the closest ``difflib.SequenceMatcher`` match to the user's question.

    Args:
        root: The Tk/ttkbootstrap root window the widgets are built in.
    """

    # Fields every training record must provide (checked by validate_item).
    REQUIRED_KEYS = ('question', 'human_answers', 'chatgpt_answers')
    # Fallback location of the pretrained BERT files, used when the
    # project-local models/bert-base-chinese directory does not exist.
    DEFAULT_BERT_PATH = 'F:/models/bert-base-chinese'

    def __init__(self, root):
        self.root = root
        self.root.title("羲和聊天机器人")
        self.root.geometry("600x400")
        self._set_icon(os.path.join(PROJECT_ROOT, 'icons/icon.ico'))

        bert_path = self._resolve_bert_path()
        self.tokenizer = BertTokenizer.from_pretrained(bert_path)
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.model = XihuaModel(pretrained_model_name=bert_path).to(self.device)
        self.load_model()
        self.model.eval()

        # Training records that get_specific_answer searches for the reply text.
        self.data = self.load_data(os.path.join(PROJECT_ROOT, 'data/train_data.jsonl'))

        self.create_widgets()

    def _set_icon(self, icon_path):
        """Set the window icon, logging a warning instead of crashing if Tk rejects it."""
        try:
            self.root.iconbitmap(icon_path)
        except tk.TclError as e:
            # Raised for a missing file, and by Tk builds that cannot load .ico
            # files (typically non-Windows); the icon is purely cosmetic.
            logging.warning(f"无法设置窗口图标 {icon_path}: {e}")

    @classmethod
    def _resolve_bert_path(cls):
        """Return <project>/models/bert-base-chinese if it exists, else DEFAULT_BERT_PATH."""
        local_path = os.path.join(PROJECT_ROOT, 'models', 'bert-base-chinese')
        return local_path if os.path.isdir(local_path) else cls.DEFAULT_BERT_PATH

    def create_widgets(self):
        """Build the question-input, answer-display and history-control areas."""
        main_frame = ttk.Frame(self.root, padding=20)
        main_frame.pack(fill=tk.BOTH, expand=True)
        # Let the stacked sections stretch with the window width.
        main_frame.columnconfigure(0, weight=1)

        # Question input area
        question_frame = ttk.LabelFrame(main_frame, text="问题输入", padding=10)
        question_frame.grid(row=0, column=0, padx=10, pady=10, sticky=tk.W+tk.E)

        self.question_label = ttk.Label(question_frame, text="问题:", font=("Arial", 14))
        self.question_label.grid(row=0, column=0, pady=10, sticky=tk.W)

        self.question_entry = ttk.Entry(question_frame, width=50, font=("Arial", 12))
        self.question_entry.grid(row=0, column=1, pady=10, sticky=tk.W)

        # Both action buttons live in their own sub-frame so each gets a
        # separate, equal-width cell; gridding them directly into overlapping
        # cells of question_frame drew one button on top of the other.
        button_row = ttk.Frame(question_frame)
        button_row.grid(row=1, column=0, columnspan=2, pady=10, sticky=tk.W+tk.E)
        for col in (0, 1):
            button_row.columnconfigure(col, weight=1)

        self.answer_button = ttk.Button(button_row, text="获取回答", command=self.get_answer, style='success.TButton')
        self.answer_button.grid(row=0, column=0, padx=(0, 5), sticky=tk.W+tk.E)

        self.voice_button = ttk.Button(button_row, text="语音输入", command=self.recognize_speech, style='primary.TButton')
        self.voice_button.grid(row=0, column=1, padx=(5, 0), sticky=tk.W+tk.E)

        # Answer display area
        answer_frame = ttk.LabelFrame(main_frame, text="回答显示", padding=10)
        answer_frame.grid(row=1, column=0, padx=10, pady=10, sticky=tk.W+tk.E)

        self.answer_label = ttk.Label(answer_frame, text="回答:", font=("Arial", 14))
        self.answer_label.grid(row=0, column=0, pady=10, sticky=tk.W)

        self.answer_text = tk.Text(answer_frame, height=10, width=50, font=("Arial", 12))
        self.answer_text.grid(row=1, column=0, pady=5, sticky=tk.W+tk.E)

        # History control buttons
        control_frame = ttk.Frame(main_frame, padding=10)
        control_frame.grid(row=2, column=0, padx=10, pady=10, sticky=tk.W+tk.E)

        self.clear_button = ttk.Button(control_frame, text="清除历史记录", command=self.clear_history, style='danger.TButton')
        self.clear_button.grid(row=0, column=0, pady=5, padx=5, sticky=tk.W+tk.E)

        self.save_button = ttk.Button(control_frame, text="保存历史记录", command=self.save_history, style='info.TButton')
        self.save_button.grid(row=0, column=1, pady=5, padx=5, sticky=tk.W+tk.E)

    def get_answer(self):
        """Classify the entered question and append the answer to the history box.

        An empty or whitespace-only question only shows a warning dialog.
        """
        question = self.question_entry.get().strip()
        if not question:
            messagebox.showwarning("输入错误", "请输入问题")
            return

        # Temporary progress line. It is removed again afterwards so it does
        # not pile up in (and get saved with) the history.
        status_start = self.answer_text.index('end-1c')
        self.answer_text.insert(tk.END, "正在获取回答...\n")
        self.answer_text.update_idletasks()

        try:
            inputs = self.tokenizer(question, return_tensors='pt', padding='max_length', truncation=True, max_length=128)
            with torch.no_grad():
                input_ids = inputs['input_ids'].to(self.device)
                attention_mask = inputs['attention_mask'].to(self.device)
                logits = self.model(input_ids, attention_mask)
        finally:
            self.answer_text.delete(status_start, tk.END)

        # One question gives one logit; its sign picks which answer list is used.
        answer_type = "羲和回答" if logits.item() > 0 else "零回答"
        specific_answer = self.get_specific_answer(question, answer_type)

        self.answer_text.insert(tk.END, f"问题: {question}\n{answer_type}\n{specific_answer}\n\n")

    def get_specific_answer(self, question, answer_type):
        """Return the reply text for ``question`` from the closest training record.

        The first human answer is returned for "羲和回答" and the first ChatGPT
        answer otherwise. A fixed fallback string is returned when there is no
        data or no record has a non-zero similarity ratio.
        """
        best_match = None
        best_ratio = 0.0
        for item in self.data:
            ratio = SequenceMatcher(None, question, item['question']).ratio()
            if ratio > best_ratio:
                best_ratio = ratio
                best_match = item

        if best_match is None:
            return "这个我也不清楚,你问问零吧"
        if answer_type == "羲和回答":
            return best_match['human_answers'][0]
        return best_match['chatgpt_answers'][0]

    def load_data(self, file_path):
        """Load the valid training records from a ``.jsonl`` or ``.json`` file.

        Malformed lines of a ``.jsonl`` file, or a malformed ``.json`` file,
        are logged and skipped; records rejected by :meth:`validate_item` are
        dropped. An unsupported extension yields an empty list plus a warning.

        Raises:
            OSError: If the file cannot be opened (e.g. it does not exist).
        """
        data = []
        if file_path.endswith('.jsonl'):
            # Parse each line ourselves: when iterating a jsonlines reader the
            # parse error is raised by the ``for`` statement itself, so a
            # try/except inside the loop body could never skip a bad line.
            # 'utf-8-sig' also accepts files saved with a BOM.
            with open(file_path, 'r', encoding='utf-8-sig') as f:
                for line_no, line in enumerate(f, start=1):
                    if not line.strip():
                        continue
                    try:
                        item = json.loads(line)
                    except json.JSONDecodeError as e:
                        logging.warning(f"跳过无效行 {line_no}: {e}")
                        continue
                    if self.validate_item(item):
                        data.append(item)
        elif file_path.endswith('.json'):
            # Explicit encoding: the platform default (GBK on Chinese Windows)
            # cannot decode UTF-8 training data.
            with open(file_path, 'r', encoding='utf-8-sig') as f:
                try:
                    items = json.load(f)
                except json.JSONDecodeError as e:
                    logging.warning(f"跳过无效文件 {file_path}: {e}")
                    return data
            if not isinstance(items, list):
                logging.warning(f"跳过无效文件 {file_path}: 顶层必须是 JSON 数组")
                return data
            data = [item for item in items if self.validate_item(item)]
        else:
            logging.warning(f"不支持的文件格式: {file_path}")
        return data

    def validate_item(self, item):
        """Return True if ``item`` has everything get_specific_answer reads.

        ``item`` must be a dict containing all of ``REQUIRED_KEYS`` with a
        string ``question``, and both answer fields must be non-empty lists
        whose first element is a string. Rejected items are logged.
        """
        required_keys = list(self.REQUIRED_KEYS)
        if not isinstance(item, dict) or not all(key in item for key in required_keys):
            logging.warning(f"跳过无效项: 缺少必要键 {required_keys}")
            return False
        if not isinstance(item['question'], str):
            logging.warning("跳过无效项: question 必须是字符串")
            return False
        for key in ('human_answers', 'chatgpt_answers'):
            answers = item[key]
            if not isinstance(answers, list) or not answers or not isinstance(answers[0], str):
                logging.warning(f"跳过无效项: {key} 必须是以字符串开头的非空列表")
                return False
        return True

    def load_model(self):
        """Load fine-tuned weights from models/xihua_model.pth if the file exists.

        Without that file the model keeps its freshly initialised head on top
        of the pretrained BERT weights.
        """
        model_path = os.path.join(PROJECT_ROOT, 'models/xihua_model.pth')
        if os.path.exists(model_path):
            self.model.load_state_dict(torch.load(model_path, map_location=self.device, weights_only=True))
            logging.info("加载现有模型")
        else:
            logging.info("没有找到现有模型,将使用预训练模型")

    def clear_history(self):
        """Remove all text from the history box."""
        self.answer_text.delete(1.0, tk.END)

    def save_history(self):
        """Write the history box to history_<timestamp>.txt in the working directory."""
        history = self.answer_text.get(1.0, tk.END)
        if not history.strip():
            messagebox.showwarning("保存失败", "没有历史记录可保存")
            return

        timestamp = datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
        filename = f'history_{timestamp}.txt'
        try:
            # UTF-8 explicitly: the platform default may not encode every
            # character that can appear in the answers.
            with open(filename, 'w', encoding='utf-8') as f:
                f.write(history)
        except OSError as e:
            messagebox.showerror("保存失败", f"无法写入 {filename}: {e}")
            return
        messagebox.showinfo("保存成功", "历史记录已保存")

    def recognize_speech(self):
        """Record one utterance from the microphone and put its transcript in the question box.

        Transcription uses Google's web speech API with ``zh-CN``. Failures
        (no speech, no microphone/PyAudio, network errors) are reported in a
        dialog instead of raising.
        """
        recognizer = sr.Recognizer()
        try:
            with sr.Microphone() as source:
                recognizer.adjust_for_ambient_noise(source)
                # Bounded wait: listen() runs on the Tk thread, so waiting
                # indefinitely for speech would freeze the whole window.
                audio = recognizer.listen(source, timeout=5)
        except sr.WaitTimeoutError:
            messagebox.showwarning("识别失败", "未检测到语音输入")
            return
        except (OSError, AttributeError) as e:
            # AttributeError: PyAudio is not installed; OSError: no usable input device.
            messagebox.showerror("麦克风错误", f"无法打开麦克风: {e}")
            return

        try:
            question = recognizer.recognize_google(audio, language='zh-CN')
            self.question_entry.delete(0, tk.END)
            self.question_entry.insert(0, question)
        except sr.UnknownValueError:
            messagebox.showwarning("识别失败", "无法识别语音")
        except sr.RequestError:
            messagebox.showwarning("网络错误", "无法连接到语音识别服务")

src/main.py
python

import os
import ttkbootstrap as ttk

# Project root directory (the parent of src/).
PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))

# `python src/main.py` puts only src/ on sys.path, which makes the `src.`
# package import below fail; prepend the project root so that invocation works
# as well as `python -m src.main` from the project root.
import sys  # noqa: E402

if PROJECT_ROOT not in sys.path:
    sys.path.insert(0, PROJECT_ROOT)

from src.gui import XihuaChatbotGUI  # noqa: E402


def setup_logging(log_dir=None):
    """Send INFO-and-above log records to the console and a timestamped file.

    Args:
        log_dir: Directory for the log file; created if missing. Defaults to
            ``<project root>/logs``.

    Returns:
        str: Path of the log file that was opened.
    """
    import logging
    from datetime import datetime

    if log_dir is None:
        log_dir = os.path.join(PROJECT_ROOT, 'logs')
    os.makedirs(log_dir, exist_ok=True)
    log_file = os.path.join(log_dir, f"xihua_{datetime.now():%Y-%m-%d_%H-%M-%S}.log")
    logging.basicConfig(
        level=logging.INFO,
        format='%(asctime)s %(levelname)s %(name)s: %(message)s',
        handlers=[logging.FileHandler(log_file, encoding='utf-8'), logging.StreamHandler()],
    )
    return log_file


def main():
    """Configure logging, build the chatbot window and run the Tk main loop."""
    setup_logging()
    root = ttk.Window(themename='litera')
    app = XihuaChatbotGUI(root)  # keep a reference for the window's lifetime
    root.mainloop()


if __name__ == "__main__":
    main()

requirements.txt

code
torch
transformers
ttkbootstrap
speechrecognition
pyaudio  # speech_recognition 的 sr.Microphone(语音输入)依赖 PyAudio
jsonlines
# logging、difflib、datetime 属于 Python 标准库,无需(也不应)通过 pip 安装

项目运行
确保所有依赖库已安装:

sh

pip install -r requirements.txt

在项目根目录 XihuaChatbot/ 下以模块方式运行主程序(可确保 src 包被正确导入):

sh

python -m src.main

通过以上步骤,您的项目将具有清晰的目录结构,便于管理和扩展。希望这能帮助到您!

Logo

DAMO开发者矩阵,由阿里巴巴达摩院和中国互联网协会联合发起,致力于探讨最前沿的技术趋势与应用成果,搭建高质量的交流与分享平台,推动技术创新与产业应用链接,围绕“人工智能与新型计算”构建开放共享的开发者生态。

更多推荐