本文为个人学习笔记整理,仅供交流参考,非专业教学资料,内容请自行甄别。


前言

  本篇介绍Spring AI alibaba中,将对话上下文记忆持久化到数据库的实现。ChatMemory是Spring AI提供的接口,定义了聊天对话历史记录的存储的规范。
在这里插入图片描述
  它的默认实现是InMemoryChatMemory,即将对话上下文保存到内存中,伴随着服务器的重启,记录会丢失。
在这里插入图片描述
  在ChatMemory中定义的三个方法:

  • add:将一组消息添加到指定对话的记忆中。conversationId参数是对话的唯一标识,用于区分不同用户或不同会话的上下文。List<Message>代表了要添加的消息列表。Message有三种不同的类型,分别是UserMessage(用户角色)AssistantMessage(助手角色)SystemMessage(系统角色) 实际场景中,当用户发送新消息或 AI 生成响应后,会通过此方法将这些消息存入记忆,确保后续对话能参考历史内容。
  • get:从指定对话的记忆中获取最近的lastN条消息。因为在调用chatClient时,通常需要指定保留的上下文记忆条目。每次调用大模型都是需要消耗token的,如果保留的记忆条数过多,成本也会增加。实际实现中,可能会对消息数量做限制(避免上下文过长),lastN参数就是为了灵活控制上下文长度。
  • clear:清空指定对话的所有记忆消息。

一、对话上下文持久化数据库

  如果需要将对话上下文持久化数据库,需要自己写一个类,实现ChatMemory接口。在进行存储之前,首先需要考虑到表结构的设计,首先需要记录contextId ,作为会话的标识,一次对话中的contextId都是相同的。还需要记录message,即消息的正文。这里采用了blob的格式,因为存入数据库的消息正文需要序列化。

create table if not exists `ai-agent`.context_memory_record
(
    id         bigint                             not null comment '主键Id'
        primary key,
    contextId  varchar(255)                       null comment '上下文Id',
    message    blob                               null comment '消息',
    createTime datetime default CURRENT_TIMESTAMP not null comment '创建时间'
)
    comment '上下文历史消息记录表';

  而序列化的选择,采用kryo。不同的Message的实现,其属性也是不同的:
在这里插入图片描述
  如果采用JSON进行反序列化,那么每次还需要获取到消息的类型,并且进行分支判断处理,无法统一进行处理。所以采用kryo的方案,Kryo 是带类型的二进制格式,序列化时会自动将对象的类信息(如类名、类型标识)写入二进制数据中,反序列化时,Kryo 可以根据二进制数据中包含的类型信息,直接还原出原始对象(可能是 A 或 B),无需显式指定目标类型。
  需要引入依赖(还需要引入mybatis-plus 和 mysql的依赖):

        <dependency>
            <groupId>com.esotericsoftware</groupId>
            <artifactId>kryo</artifactId>
            <version>5.6.2</version>
        </dependency>

  在自定义的类中,进行初始化操作:
在这里插入图片描述

  编写序列化和反序列化的方法:
在这里插入图片描述

2.1、add

  重写add方法:
在这里插入图片描述
  整体思路是首先根据conversationId从数据库中查询结果,并且反序列化,如果是第一次操作,那么结果为空,最终会执行插入到数据库的操作,后续则是查询出前一次的上下文,然后进行追加,执行根据conversationId更新数据库的操作。
在这里插入图片描述
在这里插入图片描述

2.2、get

  重写get方法:
在这里插入图片描述
  这里参数中的lastN,是在调用chatClient时指定的:
在这里插入图片描述

2.3、clear

在这里插入图片描述


  完整代码:

package org.ragdollcat.secondaiagent.chatmemory;

import cn.hutool.core.util.ObjUtil;
import com.baomidou.mybatisplus.core.conditions.query.QueryWrapper;
import com.baomidou.mybatisplus.core.conditions.update.UpdateWrapper;
import com.esotericsoftware.kryo.Kryo;
import com.esotericsoftware.kryo.io.Input;
import com.esotericsoftware.kryo.io.Output;
import jakarta.annotation.Resource;
import lombok.extern.slf4j.Slf4j;
import org.objenesis.strategy.StdInstantiatorStrategy;
import org.ragdollcat.secondaiagent.model.ContextMemoryRecord;
import org.ragdollcat.secondaiagent.service.ContextMemoryRecordService;
import org.springframework.ai.chat.memory.ChatMemory;
import org.springframework.ai.chat.messages.Message;
import org.springframework.stereotype.Component;

import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.util.ArrayList;
import java.util.List;

/**
 * 基于数据库的对话记忆持久化
 */
@Slf4j
@Component
public class DbBasedChatMemory implements ChatMemory {

    @Resource
    private ContextMemoryRecordService contextMemoryRecordService;


    private static final Kryo kryo = new Kryo();

    static {
        kryo.setRegistrationRequired(false);
        // 设置实例化策略
        kryo.setInstantiatorStrategy(new StdInstantiatorStrategy());
    }


    @Override
    public void add(String conversationId, List<Message> messages) {
        log.info("conversationId:{},messages:{}", conversationId, messages);
        //根据conversationId从数据库中查询结果,并且反序列化
        List<Message> conversationMessages = getOrCreateConversation(conversationId);
        //将本次结果进行追加
        conversationMessages.addAll(messages);
        //重新存入数据库
        saveConversation(conversationId, conversationMessages);
    }



    @Override
    public List<Message> get(String conversationId, int lastN) {
        List<Message> allMessages = getOrCreateConversation(conversationId);
        return allMessages.stream()
                .skip(Math.max(0, allMessages.size() - lastN))
                .toList();
    }

    @Override
    public void clear(String conversationId) {
        ContextMemoryRecord contextMemoryRecord = getConversationDB(conversationId);
        if (ObjUtil.isNotEmpty(contextMemoryRecord)){
            contextMemoryRecordService.removeById(contextMemoryRecord.getId());
        }
    }


    private List<Message> getOrCreateConversation(String conversationId) {
        ContextMemoryRecord memoryRecord = getConversationDB(conversationId);
        List<Message> messages = new ArrayList<>();
        if (ObjUtil.isNotEmpty(memoryRecord) && memoryRecord.getMessage() != null) {
            // 从数据库记录中获取字节数组
            byte[] messageBytes = memoryRecord.getMessage();
            // 使用Kryo反序列化为List<Message>
            messages = deserializeMessages(messageBytes);
        }
        return messages;
    }

    /**
     * 根据Id查询上下文对象
     *
     * @param conversationId
     * @return
     */
    private ContextMemoryRecord getConversationDB(String conversationId) {
        return contextMemoryRecordService.getOne(new QueryWrapper<>(ContextMemoryRecord.class).eq("contextId", conversationId));
    }




    private void saveConversation(String conversationId, List<Message> conversationMessages) {
        ContextMemoryRecord conversationDB = getConversationDB(conversationId);
        //新增
        if (ObjUtil.isEmpty(conversationDB)){
            ContextMemoryRecord memoryRecord = new ContextMemoryRecord();
            memoryRecord.setContextId(conversationId);
            //将消息重新序列化
            memoryRecord.setMessage(serializeMessages(conversationMessages));
            contextMemoryRecordService.save(memoryRecord);
        }//更新
        else {
            ContextMemoryRecord memoryRecord = new ContextMemoryRecord();
            memoryRecord.setContextId(conversationId);
            //将消息重新序列化
            memoryRecord.setMessage(serializeMessages(conversationMessages));
            //根据contextId字段更新
            contextMemoryRecordService.update(memoryRecord,new QueryWrapper<>(ContextMemoryRecord.class).eq("contextId",conversationId));
        }
    }

    /**
     * 反序列化字节数组为List<Message>
     */
    private List<Message> deserializeMessages(byte[] data) {
        try (Input input = new Input(new ByteArrayInputStream(data))) {
            // 反序列化为List<Message>
            return kryo.readObject(input, ArrayList.class);
        } catch (Exception e) {
            // 处理反序列化异常,例如日志记录和返回空列表
            log.error("Failed to deserialize messages", e);
            return new ArrayList<>();
        }
    }

    // 对应的序列化方法(当你需要保存消息到数据库时使用)
    private byte[] serializeMessages(List<Message> messages) {
        if (messages == null || messages.isEmpty()) {
            return null;
        }

        try (ByteArrayOutputStream baos = new ByteArrayOutputStream();
             Output output = new Output(baos)) {
            kryo.writeObject(output, messages);
            output.flush();
            return baos.toByteArray();
        } catch (Exception e) {
            log.error("Failed to serialize messages", e);
            return null;
        }
    }

}

  用户第一次提问,conversationId下的记录为空,就执行序列化然后新增的操作。
在这里插入图片描述
  助手进行回答,查询到第一次相同conversationId的记录,进行追加,然后根据conversationId更新数据库。
在这里插入图片描述
在这里插入图片描述


Logo

DAMO开发者矩阵,由阿里巴巴达摩院和中国互联网协会联合发起,致力于探讨最前沿的技术趋势与应用成果,搭建高质量的交流与分享平台,推动技术创新与产业应用链接,围绕“人工智能与新型计算”构建开放共享的开发者生态。

更多推荐