223 lines
8.2 KiB
Python
223 lines
8.2 KiB
Python
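"""
Checkpoint utility functions.

Provides helpers for rebuilding the complete message history from checkpoints.
Mainly used to work around the loss of the original messages after
SummarizationMiddleware has replaced them with a summary.
"""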
|
||
from collections import OrderedDict
|
||
from typing import List, Optional
|
||
|
||
from langchain_core.messages import BaseMessage
|
||
from langgraph.checkpoint.base import CheckpointTuple
|
||
|
||
|
||
def get_message_id(message: BaseMessage) -> str:
    """
    Return the identifier string used to tell messages apart.

    Resolution order:
        1. ``message.id`` when it exists and is truthy (converted with ``str``).
        2. ``"<name>_<id(message)>"`` when the message has a truthy ``name``.
        3. ``"<type>_<id(message)>"`` otherwise (``type`` falls back to ``""``).

    NOTE: both fallbacks embed ``id(message)``, i.e. the identity of this
    particular Python object, not anything derived from the message content.
    Two distinct live objects representing the same logical message therefore
    get different identifiers, and callers that key on this value will not
    de-duplicate such messages.

    Args:
        message: The message object. Only its ``id``, ``name`` and ``type``
            attributes are read, and each of them may be missing.

    Returns:
        The identifier string for the message.
    """
    # Prefer the message's own id whenever one is set.
    explicit_id = getattr(message, 'id', None)
    if explicit_id:
        return str(explicit_id)

    # No id: fall back to the name, disambiguated by object identity.
    name = getattr(message, 'name', None)
    if name:
        return f"{name}_{id(message)}"

    # Last resort: message type plus object identity.
    msg_type = getattr(message, 'type', '') or ''
    return f"{msg_type}_{id(message)}"
|
||
|
||
|
||
def rebuild_full_message_history(checkpoints: List[CheckpointTuple]) -> List[BaseMessage]:
    """
    Rebuild the full message history by walking every historical checkpoint.

    This can recover the original messages from before SummarizationMiddleware
    summarized them.

    Rationale:
        - Every checkpoint stores the state as it was at that moment.
        - SummarizationMiddleware summarizes the history once it grows too
          long, replacing the original messages.
        - Earlier checkpoints still hold the messages from before the summary.
        - Walking all checkpoints in chronological order therefore lets us
          collect the messages contained in each of them.

    Strategy:
        1. Iterate over ``checkpoints`` in reverse (intended: oldest first).
        2. For every message in a checkpoint's ``channel_values["messages"]``:
            - if its identifier (see ``get_message_id``) has not been seen,
              append it;
            - if it has been seen, replace the stored message only when the
              new content is more than 1.2x as long as the stored content;
              otherwise the first stored version is kept. A replacement keeps
              the position of the first occurrence.

    Checkpoints without ``channel_values`` or without ``messages`` (missing or
    ``None``) are skipped.

    Args:
        checkpoints: List of checkpoint tuples, usually ordered newest to
            oldest. It is reversed internally.

    Returns:
        The merged message history, ordered by first appearance.

    Example:
        ```python
        from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver
        from utils.checkpoint_helper import rebuild_full_message_history

        checkpointer = await get_checkpointer()
        checkpoints = [
            checkpoint async for checkpoint in checkpointer.alist(
                {"configurable": {"thread_id": thread_id}}
            )
        ]

        # Rebuild the full message history
        full_messages = rebuild_full_message_history(checkpoints)
        ```
    """
    # A later version of a message only replaces the stored one when its
    # content is more than this many times as long; shorter or similar-length
    # versions (e.g. a summary) never override the first one seen.
    growth_ratio = 1.2

    # Keyed by message identifier: de-duplicates while keeping insertion order.
    message_dict = OrderedDict()

    # The input is usually newest -> oldest, so reverse it to go oldest -> newest.
    for checkpoint_tuple in reversed(checkpoints):
        channel_values = checkpoint_tuple.checkpoint.get("channel_values") or {}
        messages = channel_values.get("messages") or []

        for message in messages:
            msg_id = get_message_id(message)

            if msg_id not in message_dict:
                message_dict[msg_id] = message
                continue

            existing_content = str(getattr(message_dict[msg_id], 'content', '') or '')
            new_content = str(getattr(message, 'content', '') or '')
            if len(new_content) > len(existing_content) * growth_ratio:
                message_dict[msg_id] = message

    return list(message_dict.values())
|
||
|
||
|
||
def extract_new_messages_from_checkpoint(
    current_checkpoint: dict,
    parent_checkpoint: Optional[dict] = None
) -> List[BaseMessage]:
    """
    Return the messages in the current checkpoint that are absent from its parent.

    Messages are matched by the identifier returned by ``get_message_id``; a
    message of the current checkpoint counts as new when no parent message
    has the same identifier.

    Args:
        current_checkpoint: The current checkpoint dict. Messages are read
            from ``current_checkpoint["channel_values"]["messages"]``.
        parent_checkpoint: The parent checkpoint dict (optional). When it is
            ``None`` every message of the current checkpoint is returned.

    Returns:
        A new list with the added messages, in their original order. It is
        empty when the current checkpoint has no ``channel_values`` or no
        ``messages`` (missing or ``None``). The returned list is never the
        list object stored inside the checkpoint.
    """
    current_messages = _messages_in_checkpoint(current_checkpoint)

    if parent_checkpoint is None:
        # Copy so that callers mutating the result cannot alter the checkpoint.
        return list(current_messages)

    parent_message_ids = {
        get_message_id(msg) for msg in _messages_in_checkpoint(parent_checkpoint)
    }
    return [
        msg for msg in current_messages
        if get_message_id(msg) not in parent_message_ids
    ]


def _messages_in_checkpoint(checkpoint: Optional[dict]) -> list:
    """Return ``checkpoint["channel_values"]["messages"]``, or ``[]`` if any part is missing or ``None``."""
    if not checkpoint:
        return []
    channel_values = checkpoint.get("channel_values") or {}
    return channel_values.get("messages") or []
|
||
|
||
|
||
def rebuild_message_history_by_diff(checkpoints: List[CheckpointTuple]) -> List[BaseMessage]:
    """
    Rebuild the full message history by diffing each checkpoint against its parent.

    For every checkpoint, the messages that are new relative to its parent
    checkpoint are extracted with ``extract_new_messages_from_checkpoint``
    and concatenated. This is meant to avoid losing messages to the
    summarization performed by SummarizationMiddleware.

    The parent is looked up by the ``checkpoint_id`` in the tuple's
    ``parent_config``. When a checkpoint has no parent config, or its parent
    is not among ``checkpoints``, it is diffed against no parent.

    Args:
        checkpoints: List of checkpoint tuples, usually ordered newest to
            oldest. It is reversed internally. Every tuple's ``config`` must
            contain ``["configurable"]["checkpoint_id"]``.

    Returns:
        The concatenated new messages of all checkpoints, in iteration order
        (intended: chronological).
    """
    # Map checkpoint_id -> checkpoint so parents can be looked up directly.
    checkpoint_by_id = {
        checkpoint_tuple.config["configurable"]["checkpoint_id"]: checkpoint_tuple.checkpoint
        for checkpoint_tuple in checkpoints
    }

    all_messages = []

    # The input is usually newest -> oldest, so reverse it to go oldest -> newest.
    for checkpoint_tuple in reversed(checkpoints):
        parent_config = checkpoint_tuple.parent_config or {}
        parent_checkpoint_id = (parent_config.get("configurable") or {}).get("checkpoint_id")

        # A parent that is unknown or outside the given list is treated as absent.
        parent_checkpoint = (
            checkpoint_by_id.get(parent_checkpoint_id)
            if parent_checkpoint_id else None
        )

        all_messages.extend(
            extract_new_messages_from_checkpoint(
                checkpoint_tuple.checkpoint, parent_checkpoint
            )
        )

    return all_messages
|
||
|