最近在学习研发QA系统,本人单纯想记录一下。项目源码和思路主要参考知乎专栏:

PyTorch搭建聊天机器人(一)词表与数据加载器 - 知乎

PyTorch搭建聊天机器人(二)定义seq2seq网络前向逻辑 - 知乎

PyTorch搭建聊天机器人(三)训练与评估 - 知乎

 

        知乎大佬的思路还是很清晰的。词表和数据加载器使用的数据集本人改用json格式的sogou和web的数据集,这个数据还需要自己处理一下,有些问题没有答案,但是有问答的相关信息(这个文本太长了),为了方便训练,而筛除没有答案的问题,然后做好标签。

 

 json数据获取

#读取json文件内容
sogou_data = json.load(open("qa_datasets/SogouQA.json", 'r', encoding='utf-8'))
web_data = json.load(open("qa_datasets/WebQA.json", 'r', encoding='utf-8')) 
question_list = []
answer_list = []
#获取json字段的相应内容
for i in range(len(sogou_data)):
    if sogou_data[i]['passages'][0]['answer'] != "":
        question_str = ""
        answer_str = ""
        for j in range(len(sogou_data[i]['question'])):
            question_str += sogou_data[i]['question'][j] + " "
        question_list.append(question_str)
        for j in range(len(sogou_data[i]['passages'][0]['answer'])):
            answer_str += sogou_data[i]['passages'][0]['answer'][j] + " "
        answer_list.append(answer_str)
for i in range(len(web_data)):
    if web_data[i]['passages'][0]['answer'] != "":
        question_str = ""
        answer_str = ""
        for j in range(len(web_data[i]['question'])):
            question_str += web_data[i]['question'][j] + " "
        question_list.append(question_str)
        for j in range(len(web_data[i]['passages'][0]['answer'])):
            answer_str += web_data[i]['passages'][0]['answer'][j] + " "
        answer_list.append(answer_str)
for i in range(len(question_list)):
    self.addSentence(question_list[i].strip())
    self.addSentence(answer_list[i].strip())
    pairs.append([question_list[i], answer_list[i]])

词表和数据加载器和参考知乎大佬的源代码!

Logo

DAMO开发者矩阵,由阿里巴巴达摩院和中国互联网协会联合发起,致力于探讨最前沿的技术趋势与应用成果,搭建高质量的交流与分享平台,推动技术创新与产业应用链接,围绕“人工智能与新型计算”构建开放共享的开发者生态。

更多推荐