huoyan-enterprise/backend/utils/checkpoint_helper.py

223 lines
8.2 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
Checkpoint 工具函数模块
提供用于从 checkpoint 中重建完整消息历史的工具函数。
主要用于解决 SummarizationMiddleware 总结消息后导致原始消息丢失的问题。
"""
from collections import OrderedDict
from typing import List, Optional
from langchain_core.messages import BaseMessage
from langgraph.checkpoint.base import CheckpointTuple
def get_message_id(message: BaseMessage) -> str:
"""
获取消息的唯一标识符
Args:
message: 消息对象
Returns:
消息的唯一标识符字符串
"""
# 优先使用消息的 id 属性(如果存在)
if hasattr(message, 'id') and message.id:
return str(message.id)
# 如果没有 id尝试使用其他唯一标识符
# 一些消息类型可能有 name 或其他唯一字段
if hasattr(message, 'name') and message.name:
return f"{message.name}_{id(message)}"
# 最后使用内容和类型生成一个标识符
content = str(getattr(message, 'content', '') or '')
msg_type = getattr(message, 'type', '') or ''
# 使用对象的内存地址作为额外的唯一性保证
return f"{msg_type}_{id(message)}"
def rebuild_full_message_history(checkpoints: List[CheckpointTuple]) -> List[BaseMessage]:
"""
通过遍历所有历史 checkpoint 重建完整的消息历史
这个方法可以恢复被 SummarizationMiddleware 总结前的原始消息。
原理:
- 每个 checkpoint 都保存了当时的状态
- SummarizationMiddleware 会在消息过长时总结历史消息,替换原始消息
- 但之前的 checkpoint 中仍然保存着总结前的原始消息
- 通过按时间顺序遍历所有 checkpoint可以提取每个 checkpoint 中的消息
- 对于重复的消息,保留更完整的版本(通常是原始消息)
策略:
1. 按时间顺序(从旧到新)遍历所有 checkpoint
2. 对于每个 checkpoint 中的消息:
- 如果消息 ID 不存在,则添加
- 如果消息 ID 已存在,比较内容长度,保留更完整的版本
Args:
checkpoints: checkpoint 列表,通常是从新到旧排列的
Returns:
完整的消息历史列表(按时间顺序)
Example:
```python
from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver
from utils.checkpoint_helper import rebuild_full_message_history
checkpointer = await get_checkpointer()
checkpoints = [
checkpoint async for checkpoint in checkpointer.alist(
{"configurable": {"thread_id": thread_id}}
)
]
# 重建完整消息历史
full_messages = rebuild_full_message_history(checkpoints)
```
"""
# 使用 OrderedDict 来存储消息key 是消息 IDvalue 是消息对象
# 这样可以自动去重,同时保留顺序
message_dict = OrderedDict()
# 按时间顺序遍历所有 checkpoint从旧到新
# checkpoints 通常是从新到旧排列的,所以需要反转
for checkpoint_tuple in reversed(checkpoints):
checkpoint = checkpoint_tuple.checkpoint
if "channel_values" not in checkpoint:
continue
channel_values = checkpoint["channel_values"]
if "messages" not in channel_values:
continue
messages = channel_values["messages"]
# 遍历当前 checkpoint 中的所有消息
for message in messages:
msg_id = get_message_id(message)
# 如果消息不存在,直接添加
if msg_id not in message_dict:
message_dict[msg_id] = message
else:
# 如果消息已存在,检查是否需要更新
existing_msg = message_dict[msg_id]
existing_content = str(getattr(existing_msg, 'content', '') or '')
new_content = str(getattr(message, 'content', '') or '')
# 策略:如果新消息的内容更长,说明可能是更完整的版本
# 但也要考虑 SummarizationMiddleware 可能会生成总结消息
# 如果新消息明显更短,可能是总结后的消息,保留原始消息
if len(new_content) > len(existing_content) * 1.2:
# 新消息明显更长,更新
message_dict[msg_id] = message
elif len(existing_content) > len(new_content) * 1.2:
# 原始消息明显更长,保留原始消息(不更新)
pass
else:
# 长度相近,保留第一个(通常是更早的版本,即原始消息)
pass
# 返回消息列表(按时间顺序)
return list(message_dict.values())
def extract_new_messages_from_checkpoint(
current_checkpoint: dict,
parent_checkpoint: Optional[dict] = None
) -> List[BaseMessage]:
"""
从当前 checkpoint 中提取新增的消息(与父 checkpoint 比较)
这个方法通过比较当前 checkpoint 和父 checkpoint 的差异,
提取出新增的消息。这对于理解消息的增量变化很有用。
Args:
current_checkpoint: 当前 checkpoint 字典
parent_checkpoint: 父 checkpoint 字典(可选)
Returns:
新增的消息列表
"""
new_messages = []
if "channel_values" not in current_checkpoint:
return new_messages
channel_values = current_checkpoint["channel_values"]
if "messages" not in channel_values:
return new_messages
current_messages = channel_values["messages"]
if parent_checkpoint is None:
# 如果没有父 checkpoint返回所有消息
return current_messages
# 获取父 checkpoint 的消息
parent_messages = []
if "channel_values" in parent_checkpoint and "messages" in parent_checkpoint["channel_values"]:
parent_messages = parent_checkpoint["channel_values"]["messages"]
# 获取父 checkpoint 的消息 ID 集合
parent_message_ids = {get_message_id(msg) for msg in parent_messages}
# 找出新增的消息
for msg in current_messages:
msg_id = get_message_id(msg)
if msg_id not in parent_message_ids:
new_messages.append(msg)
return new_messages
def rebuild_message_history_by_diff(checkpoints: List[CheckpointTuple]) -> List[BaseMessage]:
"""
通过比较相邻 checkpoint 的差异来重建完整的消息历史
这个方法通过比较每个 checkpoint 与其父 checkpoint 的差异,
提取新增的消息,从而重建完整的消息历史。
这样可以避免 SummarizationMiddleware 总结导致的消息丢失问题。
Args:
checkpoints: checkpoint 列表,通常是从新到旧排列的
Returns:
完整的消息历史列表(按时间顺序)
"""
all_messages = []
# 创建一个 checkpoint_id 到 checkpoint 的映射
checkpoint_map = {}
for checkpoint_tuple in checkpoints:
checkpoint_id = checkpoint_tuple.config["configurable"]["checkpoint_id"]
checkpoint_map[checkpoint_id] = checkpoint_tuple
# 按时间顺序遍历所有 checkpoint从旧到新
for checkpoint_tuple in reversed(checkpoints):
checkpoint_id = checkpoint_tuple.config["configurable"]["checkpoint_id"]
parent_config = checkpoint_tuple.parent_config
parent_checkpoint_id = (
parent_config["configurable"]["checkpoint_id"]
if parent_config else None
)
checkpoint = checkpoint_tuple.checkpoint
# 获取父 checkpoint
parent_checkpoint = None
if parent_checkpoint_id and parent_checkpoint_id in checkpoint_map:
parent_checkpoint = checkpoint_map[parent_checkpoint_id].checkpoint
# 提取新增的消息
new_messages = extract_new_messages_from_checkpoint(checkpoint, parent_checkpoint)
# 将新增的消息添加到列表中
all_messages.extend(new_messages)
return all_messages