""" Checkpoint 工具函数模块 提供用于从 checkpoint 中重建完整消息历史的工具函数。 主要用于解决 SummarizationMiddleware 总结消息后导致原始消息丢失的问题。 """ from collections import OrderedDict from typing import List, Optional from langchain_core.messages import BaseMessage from langgraph.checkpoint.base import CheckpointTuple def get_message_id(message: BaseMessage) -> str: """ 获取消息的唯一标识符 Args: message: 消息对象 Returns: 消息的唯一标识符字符串 """ # 优先使用消息的 id 属性(如果存在) if hasattr(message, 'id') and message.id: return str(message.id) # 如果没有 id,尝试使用其他唯一标识符 # 一些消息类型可能有 name 或其他唯一字段 if hasattr(message, 'name') and message.name: return f"{message.name}_{id(message)}" # 最后使用内容和类型生成一个标识符 content = str(getattr(message, 'content', '') or '') msg_type = getattr(message, 'type', '') or '' # 使用对象的内存地址作为额外的唯一性保证 return f"{msg_type}_{id(message)}" def rebuild_full_message_history(checkpoints: List[CheckpointTuple]) -> List[BaseMessage]: """ 通过遍历所有历史 checkpoint 重建完整的消息历史 这个方法可以恢复被 SummarizationMiddleware 总结前的原始消息。 原理: - 每个 checkpoint 都保存了当时的状态 - SummarizationMiddleware 会在消息过长时总结历史消息,替换原始消息 - 但之前的 checkpoint 中仍然保存着总结前的原始消息 - 通过按时间顺序遍历所有 checkpoint,可以提取每个 checkpoint 中的消息 - 对于重复的消息,保留更完整的版本(通常是原始消息) 策略: 1. 按时间顺序(从旧到新)遍历所有 checkpoint 2. 对于每个 checkpoint 中的消息: - 如果消息 ID 不存在,则添加 - 如果消息 ID 已存在,比较内容长度,保留更完整的版本 Args: checkpoints: checkpoint 列表,通常是从新到旧排列的 Returns: 完整的消息历史列表(按时间顺序) Example: ```python from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver from utils.checkpoint_helper import rebuild_full_message_history checkpointer = await get_checkpointer() checkpoints = [ checkpoint async for checkpoint in checkpointer.alist( {"configurable": {"thread_id": thread_id}} ) ] # 重建完整消息历史 full_messages = rebuild_full_message_history(checkpoints) ``` """ # 使用 OrderedDict 来存储消息,key 是消息 ID,value 是消息对象 # 这样可以自动去重,同时保留顺序 message_dict = OrderedDict() # 按时间顺序遍历所有 checkpoint(从旧到新) # checkpoints 通常是从新到旧排列的,所以需要反转 for checkpoint_tuple in reversed(checkpoints): checkpoint = checkpoint_tuple.checkpoint if "channel_values" not in checkpoint: continue channel_values = checkpoint["channel_values"] if "messages" not in channel_values: continue messages = channel_values["messages"] # 遍历当前 checkpoint 中的所有消息 for message in messages: msg_id = get_message_id(message) # 如果消息不存在,直接添加 if msg_id not in message_dict: message_dict[msg_id] = message else: # 如果消息已存在,检查是否需要更新 existing_msg = message_dict[msg_id] existing_content = str(getattr(existing_msg, 'content', '') or '') new_content = str(getattr(message, 'content', '') or '') # 策略:如果新消息的内容更长,说明可能是更完整的版本 # 但也要考虑 SummarizationMiddleware 可能会生成总结消息 # 如果新消息明显更短,可能是总结后的消息,保留原始消息 if len(new_content) > len(existing_content) * 1.2: # 新消息明显更长,更新 message_dict[msg_id] = message elif len(existing_content) > len(new_content) * 1.2: # 原始消息明显更长,保留原始消息(不更新) pass else: # 长度相近,保留第一个(通常是更早的版本,即原始消息) pass # 返回消息列表(按时间顺序) return list(message_dict.values()) def extract_new_messages_from_checkpoint( current_checkpoint: dict, parent_checkpoint: Optional[dict] = None ) -> List[BaseMessage]: """ 从当前 checkpoint 中提取新增的消息(与父 checkpoint 比较) 这个方法通过比较当前 checkpoint 和父 checkpoint 的差异, 提取出新增的消息。这对于理解消息的增量变化很有用。 Args: current_checkpoint: 当前 checkpoint 字典 parent_checkpoint: 父 checkpoint 字典(可选) Returns: 新增的消息列表 """ new_messages = [] if "channel_values" not in current_checkpoint: return new_messages channel_values = current_checkpoint["channel_values"] if "messages" not in channel_values: return new_messages current_messages = channel_values["messages"] if parent_checkpoint is None: # 如果没有父 checkpoint,返回所有消息 return current_messages # 获取父 checkpoint 的消息 parent_messages = [] if "channel_values" in parent_checkpoint and "messages" in parent_checkpoint["channel_values"]: parent_messages = parent_checkpoint["channel_values"]["messages"] # 获取父 checkpoint 的消息 ID 集合 parent_message_ids = {get_message_id(msg) for msg in parent_messages} # 找出新增的消息 for msg in current_messages: msg_id = get_message_id(msg) if msg_id not in parent_message_ids: new_messages.append(msg) return new_messages def rebuild_message_history_by_diff(checkpoints: List[CheckpointTuple]) -> List[BaseMessage]: """ 通过比较相邻 checkpoint 的差异来重建完整的消息历史 这个方法通过比较每个 checkpoint 与其父 checkpoint 的差异, 提取新增的消息,从而重建完整的消息历史。 这样可以避免 SummarizationMiddleware 总结导致的消息丢失问题。 Args: checkpoints: checkpoint 列表,通常是从新到旧排列的 Returns: 完整的消息历史列表(按时间顺序) """ all_messages = [] # 创建一个 checkpoint_id 到 checkpoint 的映射 checkpoint_map = {} for checkpoint_tuple in checkpoints: checkpoint_id = checkpoint_tuple.config["configurable"]["checkpoint_id"] checkpoint_map[checkpoint_id] = checkpoint_tuple # 按时间顺序遍历所有 checkpoint(从旧到新) for checkpoint_tuple in reversed(checkpoints): checkpoint_id = checkpoint_tuple.config["configurable"]["checkpoint_id"] parent_config = checkpoint_tuple.parent_config parent_checkpoint_id = ( parent_config["configurable"]["checkpoint_id"] if parent_config else None ) checkpoint = checkpoint_tuple.checkpoint # 获取父 checkpoint parent_checkpoint = None if parent_checkpoint_id and parent_checkpoint_id in checkpoint_map: parent_checkpoint = checkpoint_map[parent_checkpoint_id].checkpoint # 提取新增的消息 new_messages = extract_new_messages_from_checkpoint(checkpoint, parent_checkpoint) # 将新增的消息添加到列表中 all_messages.extend(new_messages) return all_messages