Spring AI alibaba对话上下文持久化数据库
本文介绍了如何实现Spring AI Alibaba中对话上下文记忆的数据库持久化。主要内容包括: 分析了ChatMemory接口及其默认内存实现InMemoryChatMemory的局限性 提出了数据库持久化方案,设计了包含contextId、message等字段的数据表结构 采用Kryo进行消息对象的序列化/反序列化处理 实现了ChatMemory接口的三个核心方法: add():将消息追加到
本文为个人学习笔记整理,仅供交流参考,非专业教学资料,内容请自行甄别。
前言
本篇介绍Spring AI alibaba中,将对话上下文记忆持久化到数据库的实现。ChatMemory是Spring AI提供的接口,定义了聊天对话历史记录的存储的规范。
它的默认实现是InMemoryChatMemory,即将对话上下文保存到内存中,伴随着服务器的重启,记录会丢失。
在ChatMemory中定义的三个方法:
- add:将一组消息添加到指定对话的记忆中。
conversationId参数是对话的唯一标识,用于区分不同用户或不同会话的上下文。List<Message>代表了要添加的消息列表。Message有三种不同的类型,分别是UserMessage(用户角色),AssistantMessage(助手角色)和SystemMessage(系统角色)实际场景中,当用户发送新消息或 AI 生成响应后,会通过此方法将这些消息存入记忆,确保后续对话能参考历史内容。 - get:从指定对话的记忆中获取最近的lastN条消息。因为在调用chatClient时,通常需要指定保留的上下文记忆条目。每次调用大模型都是需要消耗token的,如果保留的记忆条数过多,成本也会增加。实际实现中,可能会对消息数量做限制(避免上下文过长),lastN参数就是为了灵活控制上下文长度。
- clear:清空指定对话的所有记忆消息。
一、对话上下文持久化数据库
如果需要将对话上下文持久化数据库,需要自己写一个类,实现ChatMemory接口。在进行存储之前,首先需要考虑到表结构的设计,首先需要记录contextId ,作为会话的标识,一次对话中的contextId都是相同的。还需要记录message,即消息的正文。这里采用了blob的格式,因为存入数据库的消息正文需要序列化。
create table if not exists `ai-agent`.context_memory_record
(
id bigint not null comment '主键Id'
primary key,
contextId varchar(255) null comment '上下文Id',
message blob null comment '消息',
createTime datetime default CURRENT_TIMESTAMP not null comment '创建时间'
)
comment '上下文历史消息记录表';
而序列化的选择,采用kryo。不同的Message的实现,其属性也是不同的:
如果采用JSON进行反序列化,那么每次还需要获取到消息的类型,并且进行分支判断处理,无法统一进行处理。所以采用kryo的方案,Kryo 是带类型的二进制格式,序列化时会自动将对象的类信息(如类名、类型标识)写入二进制数据中,反序列化时,Kryo 可以根据二进制数据中包含的类型信息,直接还原出原始对象(可能是 A 或 B),无需显式指定目标类型。
需要引入依赖(还需要引入mybatis-plus 和 mysql的依赖):
<dependency>
<groupId>com.esotericsoftware</groupId>
<artifactId>kryo</artifactId>
<version>5.6.2</version>
</dependency>
在自定义的类中,进行初始化操作:
编写序列化和反序列化的方法:
2.1、add
重写add方法:
整体思路是首先根据conversationId从数据库中查询结果,并且反序列化,如果是第一次操作,那么结果为空,最终会执行插入到数据库的操作,后续则是查询出前一次的上下文,然后进行追加,执行根据conversationId更新数据库的操作。

2.2、get
重写get方法:
这里参数中的lastN,是在调用chatClient时指定的:
2.3、clear

完整代码:
package org.ragdollcat.secondaiagent.chatmemory;
import cn.hutool.core.util.ObjUtil;
import com.baomidou.mybatisplus.core.conditions.query.QueryWrapper;
import com.baomidou.mybatisplus.core.conditions.update.UpdateWrapper;
import com.esotericsoftware.kryo.Kryo;
import com.esotericsoftware.kryo.io.Input;
import com.esotericsoftware.kryo.io.Output;
import jakarta.annotation.Resource;
import lombok.extern.slf4j.Slf4j;
import org.objenesis.strategy.StdInstantiatorStrategy;
import org.ragdollcat.secondaiagent.model.ContextMemoryRecord;
import org.ragdollcat.secondaiagent.service.ContextMemoryRecordService;
import org.springframework.ai.chat.memory.ChatMemory;
import org.springframework.ai.chat.messages.Message;
import org.springframework.stereotype.Component;
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.util.ArrayList;
import java.util.List;
/**
* 基于数据库的对话记忆持久化
*/
@Slf4j
@Component
public class DbBasedChatMemory implements ChatMemory {
@Resource
private ContextMemoryRecordService contextMemoryRecordService;
private static final Kryo kryo = new Kryo();
static {
kryo.setRegistrationRequired(false);
// 设置实例化策略
kryo.setInstantiatorStrategy(new StdInstantiatorStrategy());
}
@Override
public void add(String conversationId, List<Message> messages) {
log.info("conversationId:{},messages:{}", conversationId, messages);
//根据conversationId从数据库中查询结果,并且反序列化
List<Message> conversationMessages = getOrCreateConversation(conversationId);
//将本次结果进行追加
conversationMessages.addAll(messages);
//重新存入数据库
saveConversation(conversationId, conversationMessages);
}
@Override
public List<Message> get(String conversationId, int lastN) {
List<Message> allMessages = getOrCreateConversation(conversationId);
return allMessages.stream()
.skip(Math.max(0, allMessages.size() - lastN))
.toList();
}
@Override
public void clear(String conversationId) {
ContextMemoryRecord contextMemoryRecord = getConversationDB(conversationId);
if (ObjUtil.isNotEmpty(contextMemoryRecord)){
contextMemoryRecordService.removeById(contextMemoryRecord.getId());
}
}
private List<Message> getOrCreateConversation(String conversationId) {
ContextMemoryRecord memoryRecord = getConversationDB(conversationId);
List<Message> messages = new ArrayList<>();
if (ObjUtil.isNotEmpty(memoryRecord) && memoryRecord.getMessage() != null) {
// 从数据库记录中获取字节数组
byte[] messageBytes = memoryRecord.getMessage();
// 使用Kryo反序列化为List<Message>
messages = deserializeMessages(messageBytes);
}
return messages;
}
/**
* 根据Id查询上下文对象
*
* @param conversationId
* @return
*/
private ContextMemoryRecord getConversationDB(String conversationId) {
return contextMemoryRecordService.getOne(new QueryWrapper<>(ContextMemoryRecord.class).eq("contextId", conversationId));
}
private void saveConversation(String conversationId, List<Message> conversationMessages) {
ContextMemoryRecord conversationDB = getConversationDB(conversationId);
//新增
if (ObjUtil.isEmpty(conversationDB)){
ContextMemoryRecord memoryRecord = new ContextMemoryRecord();
memoryRecord.setContextId(conversationId);
//将消息重新序列化
memoryRecord.setMessage(serializeMessages(conversationMessages));
contextMemoryRecordService.save(memoryRecord);
}//更新
else {
ContextMemoryRecord memoryRecord = new ContextMemoryRecord();
memoryRecord.setContextId(conversationId);
//将消息重新序列化
memoryRecord.setMessage(serializeMessages(conversationMessages));
//根据contextId字段更新
contextMemoryRecordService.update(memoryRecord,new QueryWrapper<>(ContextMemoryRecord.class).eq("contextId",conversationId));
}
}
/**
* 反序列化字节数组为List<Message>
*/
private List<Message> deserializeMessages(byte[] data) {
try (Input input = new Input(new ByteArrayInputStream(data))) {
// 反序列化为List<Message>
return kryo.readObject(input, ArrayList.class);
} catch (Exception e) {
// 处理反序列化异常,例如日志记录和返回空列表
log.error("Failed to deserialize messages", e);
return new ArrayList<>();
}
}
// 对应的序列化方法(当你需要保存消息到数据库时使用)
private byte[] serializeMessages(List<Message> messages) {
if (messages == null || messages.isEmpty()) {
return null;
}
try (ByteArrayOutputStream baos = new ByteArrayOutputStream();
Output output = new Output(baos)) {
kryo.writeObject(output, messages);
output.flush();
return baos.toByteArray();
} catch (Exception e) {
log.error("Failed to serialize messages", e);
return null;
}
}
}
用户第一次提问,conversationId下的记录为空,就执行序列化然后新增的操作。
助手进行回答,查询到第一次相同conversationId的记录,进行追加,然后根据conversationId更新数据库。

DAMO开发者矩阵,由阿里巴巴达摩院和中国互联网协会联合发起,致力于探讨最前沿的技术趋势与应用成果,搭建高质量的交流与分享平台,推动技术创新与产业应用链接,围绕“人工智能与新型计算”构建开放共享的开发者生态。
更多推荐
所有评论(0)