huoyan-enterprise/backend/services/vector_service.py

1923 lines
75 KiB
Python
Raw Blame History

This file contains invisible Unicode characters

This file contains invisible Unicode characters that are indistinguishable to humans but may be processed differently by a computer. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
向量化处理服务
"""
import os
import io
import tempfile
import base64
import time
import asyncio
from typing import List, Tuple, Optional, Dict
from pathlib import Path
from dataclasses import dataclass
from concurrent.futures import ThreadPoolExecutor, as_completed
from core.config import settings
from langchain_community.document_loaders import (
PyPDFLoader,
WebBaseLoader,
UnstructuredWordDocumentLoader,
UnstructuredExcelLoader,
UnstructuredPowerPointLoader,
TextLoader,
CSVLoader,
JSONLoader,
UnstructuredHTMLLoader,
UnstructuredMarkdownLoader,
PythonLoader,
UnstructuredXMLLoader,
)
# 尝试导入阿里云 OCR SDK
try:
from alibabacloud_ocr_api20210707.client import Client as OcrClient
from alibabacloud_tea_openapi import models as open_api_models
from alibabacloud_ocr_api20210707 import models as ocr_models
from alibabacloud_darabonba_stream.client import Client as StreamClient
from alibabacloud_tea_util import models as util_models
ALIYUN_OCR_AVAILABLE = True
except ImportError:
ALIYUN_OCR_AVAILABLE = False
OcrClient = None
StreamClient = None
# 尝试导入 python-docx 用于提取 DOCX 中的图片
try:
from docx import Document as DocxDocument
from docx.oxml.text.paragraph import CT_P
from docx.oxml.table import CT_Tbl
from docx.table import _Cell, Table
from docx.text.paragraph import Paragraph
PYTHON_DOCX_AVAILABLE = True
except ImportError:
PYTHON_DOCX_AVAILABLE = False
DocxDocument = None
# 尝试导入 Pillow 用于图片处理
try:
from PIL import Image
PILLOW_AVAILABLE = True
except ImportError:
PILLOW_AVAILABLE = False
Image = None
# 尝试导入 PyMuPDF 用于 PDF 处理
try:
import fitz # PyMuPDF
PYMUPDF_AVAILABLE = True
except ImportError:
PYMUPDF_AVAILABLE = False
fitz = None
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_ollama import OllamaEmbeddings
from langchain_openai import OpenAIEmbeddings
from langchain_chroma import Chroma
import bs4
from logger.logging import get_logger
from core.config import settings
from services.kb_text_limits import validate_kb_text_length, validate_chat_file_text_length
logger = get_logger(__name__)
@dataclass
class ProcessResult:
    """Result of processing one document (load, split, vectorize).

    Field meanings are documented with the inline comments below.
    """
    success: bool  # True when the content was loaded, split and written to the vector store
    # One tuple per stored chunk: (chunk_index, content, metadata, vector_id).
    chunks: List[Tuple[int, str, dict, str]]
    chunk_count: int  # number of entries in ``chunks``
    error_message: Optional[str] = None  # human-readable reason when success is False
    # Temporary image files extracted from a DOCX (intended for a vision model).
    # NOTE(review): presumably the caller owns and deletes these files -- TODO confirm.
    extracted_image_paths: Optional[List[str]] = None
class VectorService:
"""向量化处理服务类"""
# Maps a lower-case file extension (with leading dot) to (primary_loader, fallback_loader).
# Each value is either the name of a LangChain loader class, resolved via globals() in
# _get_loader_for_file, or a special marker handled by custom code:
#   "image_ocr"        -> image file recognized with Aliyun OCR
#   "docx_with_images" -> DOCX text plus OCR of the images embedded in it
# NOTE: _get_loader_for_file tries the fallback before the primary whenever a fallback
# is set and resolvable, so ".md" is currently loaded with TextLoader.
LOADER_MAPPING: Dict[str, tuple] = {
    # Documents. DOCX uses the enhanced path (extract embedded images and OCR them);
    # UnstructuredWordDocumentLoader is used when OCR or python-docx is unavailable.
    ".pdf": ("PyPDFLoader", None),
    ".docx": ("docx_with_images", "UnstructuredWordDocumentLoader"),
    ".ppt": ("UnstructuredPowerPointLoader", None),
    ".pptx": ("UnstructuredPowerPointLoader", None),
    # Image files (recognized with Aliyun OCR).
    ".png": ("image_ocr", None),
    ".jpg": ("image_ocr", None),
    ".jpeg": ("image_ocr", None),
    ".bmp": ("image_ocr", None),
    # Other document types.
    ".txt": ("TextLoader", None),
    ".md": ("UnstructuredMarkdownLoader", "TextLoader"),
    ".xlsx": ("UnstructuredExcelLoader", None),
    ".xls": ("UnstructuredExcelLoader", None),
    ".csv": ("CSVLoader", None),
    ".json": ("JSONLoader", None),
    ".html": ("UnstructuredHTMLLoader", None),
    ".htm": ("UnstructuredHTMLLoader", None),
    ".xml": ("UnstructuredXMLLoader", None),
    ".py": ("PythonLoader", None),
}
def __init__(self):
    """Initialize the embedding model, text splitter, vector-store path and OCR client.

    Embeddings use DashScope's OpenAI-compatible endpoint, configured through the
    ZL_DASHSCOPE_API_KEY / ZL_DASHSCOPE_API_BASE environment variables. The Aliyun
    OCR client is created only when its SDK is importable and credentials are set
    in ``settings``; otherwise ``self.ocr_engine`` stays ``None``.
    """
    # Never print or log the API key itself -- it would end up in stdout/log files.
    api_key = os.getenv("ZL_DASHSCOPE_API_KEY")
    api_base = os.getenv("ZL_DASHSCOPE_API_BASE")
    if not api_key:
        logger.warning("⚠️ ZL_DASHSCOPE_API_KEY is not set; embedding requests will likely fail")
    logger.debug(f"DashScope embeddings: base_url={api_base}, api_key_configured={bool(api_key)}")
    # DashScope's compatible gateway only accepts string inputs. With the default
    # check_embedding_ctx_length=True the client tokenizes text with tiktoken and
    # sends token-id lists, which the gateway rejects with HTTP 400
    # ("contents is neither str nor list of str").
    self.embedding = OpenAIEmbeddings(
        model="text-embedding-v4",
        api_key=api_key,
        base_url=api_base,
        check_embedding_ctx_length=False,
    )
    # Large chunks keep more context per chunk (raised from 1000 to 4096);
    # a moderate overlap preserves continuity between neighbouring chunks.
    self.text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=4096,
        chunk_overlap=200,
        add_start_index=True,
    )
    # Local vector-store directory (get_vector_store currently talks to a Chroma server).
    self.vector_store_path = settings.chroma_persist_directory or "./chroma_db"
    # Aliyun OCR backs images, scanned PDFs and images embedded in DOCX files.
    self.ocr_engine = None
    if ALIYUN_OCR_AVAILABLE and settings.ocr_access_key_id and settings.ocr_access_key_secret:
        try:
            config = open_api_models.Config(
                access_key_id=settings.ocr_access_key_id,
                access_key_secret=settings.ocr_access_key_secret,
                endpoint=settings.ocr_endpoint,
            )
            self.ocr_engine = OcrClient(config)
            logger.info("✅ 阿里云 OCR 已启用,将使用云端 OCR 服务识别图片文字")
        except Exception as e:
            logger.warning(f"⚠️ 阿里云 OCR 初始化失败: {e}")
    elif not ALIYUN_OCR_AVAILABLE:
        logger.warning("⚠️ 阿里云 OCR SDK 未安装,图片与扫描件 OCR 不可用")
    else:
        logger.info(" 未配置阿里云 OCR需要 OCR_ACCESS_KEY_ID 和 OCR_ACCESS_KEY_SECRET图片 OCR 将不可用")
    if not self.ocr_engine:
        logger.warning("⚠️ OCR 服务不可用,图片与扫描件内容将无法通过 OCR 提取。请配置阿里云 OCR")
def _ocr_image(self, image_path: str) -> str:
    """Return the text recognized in a single image.

    Guards ``_ocr_image_aliyun``: never raises. When no OCR client is configured,
    or recognition throws, an empty string is returned (failures are logged as
    warnings).

    Args:
        image_path: Path of the image file.

    Returns:
        The recognized text, or "".
    """
    if self.ocr_engine:
        try:
            recognized = self._ocr_image_aliyun(image_path)
        except Exception as err:
            logger.warning(f"图片 OCR 处理失败: {err}")
            recognized = ""
        return recognized
    return ""
def _ocr_image_aliyun(self, image_path: str) -> str:
    """Recognize text in an image with the Aliyun general OCR API.

    Reads the file as bytes, sends it via ``recognize_general_with_options`` and
    extracts the ``content`` field from the JSON string in ``response.body.data``.

    Args:
        image_path: Path of the image file.

    Returns:
        The recognized text, or "" when no data/text came back or any error occurred
        (errors are logged, never raised).
    """
    try:
        logger.debug(f"🔍 [阿里云OCR] 开始识别图片: {os.path.basename(image_path)}")
        # Read the whole image into memory.
        with open(image_path, 'rb') as f:
            image_bytes = f.read()
        image_size_kb = len(image_bytes) / 1024
        logger.debug(f"📊 [阿里云OCR] 图片大小: {image_size_kb:.2f}KB")
        # The SDK expects the request body as a stream created by StreamClient.
        body_stream = StreamClient.read_from_bytes(image_bytes)
        logger.debug(f"📦 [阿里云OCR] 字节流已创建")
        # Build the request.
        request = ocr_models.RecognizeGeneralRequest(body=body_stream)
        # Runtime options with no overrides.
        runtime = util_models.RuntimeOptions()
        logger.debug(f"☁️ [阿里云OCR] 调用 API: recognize_general_with_options")
        # Call the OCR API (the *_with_options variant takes the runtime options).
        response = self.ocr_engine.recognize_general_with_options(request, runtime)
        logger.debug(f"✅ [阿里云OCR] API 调用成功")
        # Stop when the response carries no payload.
        if not response or not response.body or not response.body.data:
            logger.debug(f"⚠️ [阿里云OCR] 未返回数据: {os.path.basename(image_path)}")
            return ""
        # response.body.data is parsed as a JSON string.
        import json
        logger.debug(f"📝 [阿里云OCR] 开始解析返回数据")
        data = json.loads(response.body.data)
        # The recognized text lives under the 'content' key.
        if not data or 'content' not in data:
            logger.debug(f"⚠️ [阿里云OCR] 未识别到文字: {os.path.basename(image_path)}")
            return ""
        ocr_content = data['content']
        logger.debug(f"📋 [阿里云OCR] content 类型: {type(ocr_content).__name__}")
        # 'content' is handled as a plain string, a list of dicts/strings, or anything
        # else (stringified).
        if isinstance(ocr_content, str):
            full_text = ocr_content
            logger.debug(f"📄 [阿里云OCR] 直接提取文本: {len(full_text)} 字符")
        elif isinstance(ocr_content, list):
            # For a list, take each dict item's 'text' (skipping blank ones) and
            # keep plain string items as-is; lines are joined with newlines.
            texts = []
            for idx, item in enumerate(ocr_content):
                if isinstance(item, dict) and 'text' in item:
                    text = item['text']
                    if text and text.strip():
                        texts.append(text)
                        logger.debug(f" 📌 [阿里云OCR] 行 {idx + 1}: {text[:50]}{'...' if len(text) > 50 else ''}")
                elif isinstance(item, str):
                    texts.append(item)
            full_text = "\n".join(texts)
            logger.debug(f"📄 [阿里云OCR] 合并 {len(texts)} 行文本: {len(full_text)} 字符")
        else:
            full_text = str(ocr_content)
            logger.debug(f"📄 [阿里云OCR] 转换为字符串: {len(full_text)} 字符")
        # Treat whitespace-only output as "nothing recognized".
        if not full_text or not full_text.strip():
            logger.debug(f"⚠️ [阿里云OCR] 识别结果为空: {os.path.basename(image_path)}")
            return ""
        logger.info(f"✅ [阿里云OCR] 识别成功: {os.path.basename(image_path)}, 识别到 {len(full_text)} 字符")
        return full_text
    except Exception as e:
        # Any failure (I/O, SDK, JSON) degrades to an empty result; details go to debug.
        logger.warning(f"阿里云 OCR 处理失败: {e}")
        import traceback
        logger.debug(traceback.format_exc())
        return ""
def _process_image_ocr(self, file_path: str) -> List:
    """OCR a single image file into a LangChain Document list.

    Args:
        file_path: Path of the image file.

    Returns:
        ``[Document]`` whose content is the recognized text and whose metadata marks
        it as an Aliyun-OCR'd image, or ``[]`` when no text was recognized.

    Raises:
        Exception: when no OCR client is configured, or when building the Document
            fails (the error is logged before being re-raised).
    """
    if self.ocr_engine:
        try:
            recognized = self._ocr_image(file_path)
            if recognized:
                logger.info(f"图片 OCR 成功(阿里云 OCR{len(recognized)} 字符")
                from langchain_core.documents import Document
                image_metadata = {
                    "source": file_path,
                    "file_type": "image",
                    "has_ocr": True,
                    "ocr_provider": "aliyun",
                }
                return [Document(page_content=recognized, metadata=image_metadata)]
            logger.warning(f"图片 OCR 未识别到文字: {file_path}")
            return []
        except Exception as err:
            logger.error(f"图片 OCR 处理失败: {err}")
            raise
    raise Exception("OCR 服务不可用,无法处理图片。请配置阿里云 OCROCR_ACCESS_KEY_ID 与 OCR_ACCESS_KEY_SECRET")
def _extract_images_from_docx(self, docx_path: str) -> List[str]:
    """Extract the images embedded in a DOCX file into temporary ``.jpg`` files.

    Images are normalized to RGB JPEG for OCR (see ``_save_docx_image``).
    Externally linked images are skipped because they have no embedded data.

    Args:
        docx_path: Path of the DOCX file.

    Returns:
        Paths of the temporary image files (the caller must delete them), or ``[]``
        when python-docx/Pillow is unavailable or the document cannot be processed.
        On a fatal error any files already written are removed before returning.
    """
    if not PYTHON_DOCX_AVAILABLE or not PILLOW_AVAILABLE:
        logger.warning("⚠️ python-docx 或 Pillow 不可用,无法提取 DOCX 中的图片")
        return []
    logger.info(f"🖼️ [DOCX图片提取] 开始从 DOCX 提取图片: {os.path.basename(docx_path)}")
    image_paths: List[str] = []
    try:
        doc = DocxDocument(docx_path)
        rels = list(doc.part.rels.values())
        logger.debug(f"📊 [DOCX图片提取] 文档关系总数: {len(rels)}")
        for idx, rel in enumerate(rels):
            # Only image relationships. Externally linked images have no embedded
            # part (target_part is undefined for them), so skip those as well.
            if "image" not in rel.target_ref or rel.is_external:
                continue
            try:
                image_data = rel.target_part.blob
                saved_path = self._save_docx_image(image_data, idx)
            except Exception as e:
                logger.warning(f"提取图片 {idx + 1} 失败: {e}")
                continue
            if saved_path:
                image_paths.append(saved_path)
        logger.info(f"从 DOCX 提取了 {len(image_paths)} 张图片(已转换为 JPG 格式)")
        return image_paths
    except Exception as e:
        logger.error(f"提取 DOCX 图片失败: {e}")
        # The caller receives [] and cannot clean up, so remove what was written.
        for path in image_paths:
            try:
                os.remove(path)
            except OSError:
                pass  # best-effort cleanup of a temp file
        return []

def _save_docx_image(self, image_data: bytes, idx: int) -> Optional[str]:
    """Write one DOCX image blob to a temporary ``.jpg`` file and return its path.

    The image is converted to RGB (transparent areas composited onto white) and
    saved as JPEG with quality 95. If Pillow cannot decode/convert it, the raw bytes
    are written unchanged instead. Returns None (and removes the temp file) only
    when even the raw write fails.
    """
    # mkstemp + immediate close: the path is reopened by name below, which would
    # fail on Windows while a NamedTemporaryFile handle is still open.
    fd, tmp_path = tempfile.mkstemp(suffix='.jpg')
    os.close(fd)
    try:
        with Image.open(io.BytesIO(image_data)) as image:
            image_size_kb = len(image_data) / 1024
            logger.debug(f" 📸 [DOCX图片提取] 图片 {idx + 1}: 原始模式={image.mode}, 大小={image.size}, {image_size_kb:.2f}KB")
            if image.mode in ('RGBA', 'LA', 'P'):
                logger.debug(f" 🔄 [DOCX图片提取] 图片 {idx + 1}: 转换模式 {image.mode} -> RGB")
                # Composite possibly-transparent images onto a white background.
                background = Image.new('RGB', image.size, (255, 255, 255))
                source = image.convert('RGBA') if image.mode == 'P' else image
                background.paste(source, mask=source.split()[-1])
                rgb_image = background
            elif image.mode != 'RGB':
                logger.debug(f" 🔄 [DOCX图片提取] 图片 {idx + 1}: 转换模式 {image.mode} -> RGB")
                rgb_image = image.convert('RGB')
            else:
                rgb_image = image
            # JPEG is accepted by Aliyun OCR and keeps files small.
            rgb_image.save(tmp_path, 'JPEG', quality=95)
        logger.info(f"✅ [DOCX图片提取] 图片 {idx + 1} 已提取并转换: {os.path.basename(tmp_path)}")
        return tmp_path
    except Exception as e:
        logger.warning(f"图片 {idx + 1} 格式转换失败: {e},尝试直接保存")
    try:
        # Fallback: keep the original bytes (e.g. formats Pillow cannot decode).
        with open(tmp_path, 'wb') as f:
            f.write(image_data)
        return tmp_path
    except OSError as e:
        logger.warning(f"提取图片 {idx + 1} 失败: {e}")
        try:
            os.remove(tmp_path)
        except OSError:
            pass  # best-effort cleanup of a temp file
        return None
def _ocr_images_concurrent(self, image_paths: List[str], max_workers: int = 4) -> List[Tuple[int, str]]:
    """Run OCR over several images in parallel on a thread pool.

    Args:
        image_paths: Images to recognize; result indices refer to this list.
        max_workers: Thread-pool size (default 4).

    Returns:
        ``[(index, text), ...]`` sorted by index. ``text`` is "" for images where
        nothing was recognized or OCR failed. Empty input yields ``[]``.
    """
    if not image_paths:
        return []
    total = len(image_paths)
    logger.info(f"🚀 [并发OCR] 开始并发识别 {total} 张图片(并发数: {max_workers}, 引擎: 阿里云OCR")
    collected: List[Tuple[int, str]] = []
    started_at = time.time()

    def run_one(position: int, path: str) -> Tuple[int, str]:
        # Isolate failures so one bad image never aborts the batch.
        try:
            logger.debug(f" 🔄 [并发OCR] 线程开始处理图片 {position + 1}")
            text = self._ocr_image(path)
            if not text:
                logger.info(f"⚠️ [并发OCR] 图片 {position + 1} 未识别到文字")
                return (position, "")
            logger.info(f"✅ [并发OCR] 图片 {position + 1} 识别成功: {len(text)} 字符")
            return (position, text)
        except Exception as err:
            logger.warning(f"❌ [并发OCR] 图片 {position + 1} 识别失败: {err}")
            return (position, "")

    logger.debug(f"📦 [并发OCR] 创建线程池max_workers={max_workers}")
    with ThreadPoolExecutor(max_workers=max_workers) as pool:
        pending = {pool.submit(run_one, position, path): position for position, path in enumerate(image_paths)}
        finished = 0
        # Collect in completion order; sorted back into input order afterwards.
        for fut in as_completed(pending):
            try:
                collected.append(fut.result())
                finished += 1
                logger.debug(f" ✓ [并发OCR] 已完成 {finished}/{total} 张图片")
            except Exception as err:
                logger.error(f"❌ [并发OCR] 图片 {pending[fut] + 1} OCR 任务执行失败: {err}")
    collected.sort(key=lambda pair: pair[0])
    elapsed = time.time() - started_at
    recognized_count = len([1 for _, text in collected if text])
    char_count = sum(len(text) for _, text in collected)
    logger.info(f"🎉 [并发OCR] 完成!总计 {total} 张图片,成功 {recognized_count} 张,识别 {char_count} 字符,耗时 {elapsed:.2f}")
    logger.info(f"📊 [并发OCR] 平均速度: {elapsed / total:.2f}秒/张")
    return collected
def _process_docx_with_images(self, file_path: str) -> Tuple[List, List[str]]:
    """Load a DOCX file's text and append the OCR text of its embedded images.

    Embedded images are OCR'd concurrently (4 threads) when an OCR client and
    python-docx are available. On any error the method falls back to the plain
    UnstructuredWordDocumentLoader and deletes any images it had already extracted.

    Args:
        file_path: Path of the DOCX file.

    Returns:
        ``(documents, extracted_image_paths)``. ``documents`` is a single Document
        combining text and image OCR output (or the fallback loader's documents).
        ``extracted_image_paths`` lists temporary image files the caller must delete
        (``[]`` on the fallback path).
    """
    from langchain_core.documents import Document
    extracted_image_paths: List[str] = []  # kept for later vision-model analysis
    try:
        # 1. Plain text via the standard loader.
        loader = UnstructuredWordDocumentLoader(file_path)
        text_docs = loader.load()
        text_content = ""
        if text_docs:
            text_content = "\n\n".join(doc.page_content for doc in text_docs)
        logger.info(f"DOCX 文字内容提取完成,共 {len(text_content)} 字符")
        # 2. Extract embedded images and OCR them concurrently.
        image_texts = []
        if self.ocr_engine and PYTHON_DOCX_AVAILABLE:
            image_paths = self._extract_images_from_docx(file_path)
            extracted_image_paths = image_paths.copy()
            if image_paths:
                logger.info(f"开始并发 OCR 识别 {len(image_paths)} 张图片(并发数: 4")
                ocr_results = self._ocr_images_concurrent(image_paths, max_workers=4)
                for idx, ocr_text in ocr_results:
                    if ocr_text:
                        image_texts.append(f"\n\n[图片 {idx + 1} 内容]\n{ocr_text}")
                logger.info(f"OCR 识别完成,共识别到 {len(image_texts)} 张图片的文字")
        # 3. Merge text and image OCR output.
        full_content = text_content
        if image_texts:
            full_content += "\n\n" + "\n\n".join(image_texts)
        logger.info(f"DOCX 总内容(文字+图片):{len(full_content)} 字符")
        # 4. Metadata holds only scalar values because ChromaDB rejects lists, so the
        #    image paths are returned separately. Note: image_count counts images
        #    whose OCR produced text, not all embedded images.
        doc = Document(
            page_content=full_content,
            metadata={
                "source": file_path,
                "file_type": "docx",
                "has_images": len(image_texts) > 0,
                "image_count": len(image_texts),
                "has_ocr": len(image_texts) > 0,
            },
        )
        return [doc], extracted_image_paths
    except Exception as e:
        logger.error(f"处理 DOCX 文件失败: {e}")
        # The fallback returns no image paths, so delete any already extracted
        # here instead of leaking them.
        for path in extracted_image_paths:
            try:
                os.remove(path)
            except OSError:
                pass  # best-effort cleanup of a temp file
        # Fall back to the standard loader.
        loader = UnstructuredWordDocumentLoader(file_path)
        return loader.load(), []
def _extract_images_from_pdf(self, pdf_path: str) -> List[str]:
    """Render every page of a PDF to a temporary ``.jpg`` file for OCR.

    Pages are rendered at 2x zoom (about 144 DPI) to improve OCR accuracy. Pages
    that fail to render are skipped with a warning.

    Args:
        pdf_path: Path of the PDF file.

    Returns:
        Temporary image paths in page order (the caller must delete them), or ``[]``
        when PyMuPDF/Pillow is unavailable or the PDF cannot be processed. On a fatal
        error any files already written are removed before returning.
    """
    if not PYMUPDF_AVAILABLE or not PILLOW_AVAILABLE:
        logger.warning("⚠️ PyMuPDF 或 Pillow 不可用,无法提取 PDF 页面为图片")
        return []
    logger.info(f"📄 [PDF页面提取] 开始从 PDF 提取页面为图片: {os.path.basename(pdf_path)}")
    image_paths: List[str] = []
    try:
        # The context manager closes the document even if rendering fails midway.
        with fitz.open(pdf_path) as pdf_document:
            total_pages = len(pdf_document)
            logger.info(f"📊 [PDF页面提取] PDF 总页数: {total_pages}")
            mat = fitz.Matrix(2, 2)  # zoom=2 -> about 144 DPI
            for page_num in range(total_pages):
                tmp_path = None
                try:
                    logger.debug(f" 🔄 [PDF页面提取] 处理第 {page_num + 1}/{total_pages}")
                    page = pdf_document[page_num]
                    pix = page.get_pixmap(matrix=mat)
                    # Create and close the temp file first: writing to a path whose
                    # NamedTemporaryFile handle is still open fails on Windows.
                    fd, tmp_path = tempfile.mkstemp(suffix='.jpg')
                    os.close(fd)
                    pix.save(tmp_path)
                    file_size_kb = os.path.getsize(tmp_path) / 1024
                    image_paths.append(tmp_path)
                    logger.info(f"✅ [PDF页面提取] 第 {page_num + 1} 页已转换: {os.path.basename(tmp_path)} ({file_size_kb:.2f}KB)")
                except Exception as e:
                    logger.warning(f"PDF 页面 {page_num + 1} 转换失败: {e}")
                    # Drop the partially written file of this page, if any.
                    if tmp_path and tmp_path not in image_paths:
                        try:
                            os.remove(tmp_path)
                        except OSError:
                            pass  # best-effort cleanup of a temp file
                    continue
        logger.info(f"从 PDF 提取了 {len(image_paths)} 页图片(已转换为 JPG 格式)")
        return image_paths
    except Exception as e:
        logger.error(f"提取 PDF 页面失败: {e}")
        # The caller receives [] and cannot clean up, so remove what was written.
        for path in image_paths:
            try:
                os.remove(path)
            except OSError:
                pass  # best-effort cleanup of a temp file
        return []
def _is_image_pdf(self, docs: List) -> bool:
    """Heuristically decide whether a PDF is image-based (scanned).

    The PDF counts as image-based when the text extracted by the loader has fewer
    than 100 non-whitespace characters in total.

    Args:
        docs: Documents produced by PyPDFLoader for the PDF.

    Returns:
        True for an empty document list or when almost no text was extracted.
    """
    if not docs:
        return True
    # Count only non-whitespace characters. Scanned PDFs often yield blank lines per
    # page. Stripping only the ends of the joined text would let a long scanned PDF
    # exceed the threshold on whitespace alone.
    total_chars = sum(
        1 for doc in docs for ch in (doc.page_content or "") if not ch.isspace()
    )
    if total_chars < 100:
        logger.info(f"检测到图片型 PDF文本内容少于 100 字符:{total_chars} 字符)")
        return True
    return False
def _process_pdf_with_ocr(self, file_path: str) -> List:
    """OCR an image-based (scanned) PDF page by page into a single Document.

    Every page is rendered to a temporary image and OCR'd concurrently (4 threads).
    The temporary images are always deleted, including on error paths.

    Args:
        file_path: Path of the PDF file.

    Returns:
        ``[Document]`` whose content joins the per-page text (each page prefixed with
        a "[第 N 页]" marker), or ``[]`` when OCR is unavailable, no page could be
        rendered, no text was recognized, or any error occurred (logged).
    """
    from langchain_core.documents import Document
    image_paths: List[str] = []
    try:
        if not self.ocr_engine:
            raise Exception("OCR 服务不可用,无法处理图片型 PDF")
        # 1. Render each page to an image.
        image_paths = self._extract_images_from_pdf(file_path)
        if not image_paths:
            logger.warning("未能从 PDF 提取任何页面")
            return []
        # 2. Concurrent OCR.
        logger.info(f"开始并发 OCR 识别 {len(image_paths)} 页 PDF并发数: 4")
        ocr_results = self._ocr_images_concurrent(image_paths, max_workers=4)
        # 3. Collect per-page text.
        page_texts = []
        for idx, ocr_text in ocr_results:
            if ocr_text:
                page_texts.append(f"[第 {idx + 1} 页]\n{ocr_text}")
                logger.info(f"{idx + 1} 页 OCR 识别到 {len(ocr_text)} 字符")
            else:
                logger.warning(f"{idx + 1} 页 OCR 未识别到文字")
        if not page_texts:
            logger.warning("PDF OCR 未识别到任何文字内容")
            return []
        # 4. Merge all pages.
        full_content = "\n\n".join(page_texts)
        logger.info(f"PDF OCR 完成,总计 {len(page_texts)} 页,共 {len(full_content)} 字符")
        # Note: page_count is the number of pages that produced text.
        doc = Document(
            page_content=full_content,
            metadata={
                "source": file_path,
                "file_type": "pdf",
                "is_image_pdf": True,
                "page_count": len(page_texts),
                "has_ocr": True,
                "ocr_provider": "aliyun",
            },
        )
        return [doc]
    except Exception as e:
        logger.error(f"处理图片型 PDF 失败: {e}")
        import traceback
        logger.error(traceback.format_exc())
        return []
    finally:
        # Always remove the rendered page images.
        for img_path in image_paths:
            try:
                os.remove(img_path)
            except FileNotFoundError:
                pass
            except OSError as cleanup_err:
                logger.debug(f"Failed to remove temp page image {img_path}: {cleanup_err}")
def _get_loader_for_file(self, file_path: str, file_type: Optional[str] = None):
    """Choose a document loader for a file based on its extension.

    Args:
        file_path: Path of the file to load.
        file_type: Optional extension such as "pdf" (a leading dot is tolerated).
            When omitted, the suffix of ``file_path`` is used.

    Returns:
        A LangChain loader instance; the marker string "image_ocr" or
        "docx_with_images" for the custom OCR paths; or None when the type is
        unsupported or no loader could be built.

    Raises:
        Exception: for image files when the Aliyun OCR client is not configured.
    """
    if file_type:
        ext = f".{file_type.lower().lstrip('.')}"
    else:
        ext = Path(file_path).suffix.lower()
    loader_config = self.LOADER_MAPPING.get(ext)
    if not loader_config:
        logger.warning(f"不支持的文件类型: {ext}")
        return None
    primary_loader, fallback_loader = loader_config
    # Images: handled by OCR code, signalled with a marker string.
    if primary_loader == "image_ocr":
        if not self.ocr_engine:
            raise Exception("阿里云 OCR 不可用,无法处理图片文件。请配置 OCR_ACCESS_KEY_ID 与 OCR_ACCESS_KEY_SECRET")
        logger.info("使用阿里云 OCR 处理图片文件")
        return "image_ocr"
    # DOCX: enhanced path (extract + OCR embedded images) when possible.
    if primary_loader == "docx_with_images":
        if self.ocr_engine and PYTHON_DOCX_AVAILABLE:
            logger.info(f"使用增强模式处理 DOCX提取图片并 OCR")
            return "docx_with_images"
        logger.info(f"OCR 或 python-docx 不可用,使用标准 DOCX 加载器")
    # The fallback loader is tried first when configured, then the primary one.
    # Marker names resolve to no class in globals() and are skipped.
    for loader_name in (fallback_loader, primary_loader):
        if not loader_name:
            continue
        loader_class = globals().get(loader_name)
        if not loader_class:
            continue
        logger.info(f"使用 {loader_name} 加载文件")
        if loader_name == "TextLoader":
            # Plain text needs an encoding probe (utf-8, gbk, ...).
            return self._load_text_with_encoding(file_path, loader_class)
        if loader_name == "JSONLoader":
            # JSONLoader requires jq_schema; "." keeps the whole document and
            # text_content=False serializes non-string content to text.
            return loader_class(file_path, jq_schema=".", text_content=False)
        return loader_class(file_path)
    return None
def _load_text_with_encoding(self, file_path: str, loader_class):
    """Build a text loader bound to the first encoding that can read the file.

    Candidates are tried in order: utf-8, gbk, gb2312, latin-1. Each candidate is
    validated by fully loading the file once; any exception moves on to the next.

    Args:
        file_path: Path of the text file.
        loader_class: Loader class accepting ``(file_path, encoding=...)``
            (TextLoader).

    Returns:
        The validated loader instance, or None if every candidate failed.
    """
    for candidate in ("utf-8", "gbk", "gb2312", "latin-1"):
        try:
            probe = loader_class(file_path, encoding=candidate)
            probe.load()  # force a full decode to prove the encoding works
            logger.info(f"成功使用编码 {candidate} 加载文本文件")
            return probe
        except Exception:
            pass
    logger.warning(f"无法使用任何编码加载文本文件: {file_path}")
    return None
def get_vector_store(self, collection_name: str) -> Chroma:
    """Return a Chroma vector store for ``collection_name`` on the configured server.

    Connects to the Chroma server at ``settings.chroma_host`` / ``settings.chroma_port``
    and embeds documents with ``self.embedding``. ``self.vector_store_path`` is not
    used in this client/server mode.

    Args:
        collection_name: Collection name, e.g. ``kb_<id>`` or ``thread_<id>``.

    Returns:
        Chroma: the vector store instance.
    """
    return Chroma(
        host=settings.chroma_host,
        port=settings.chroma_port,
        collection_name=collection_name,
        embedding_function=self.embedding,
    )
async def process_document(
    self,
    file_path: str,
    knowledge_base_id: int,
    file_type: str = "pdf",
    file_id: Optional[int] = None,
    source_url: Optional[str] = None
) -> ProcessResult:
    """Load, split and vectorize a file into the knowledge-base collection ``kb_<id>``.

    Supported types (see LOADER_MAPPING):
    - PDF (scanned PDFs fall back to page OCR), DOCX (embedded images OCR'd), PPT/PPTX
    - images: PNG, JPG, JPEG, BMP (require Aliyun OCR)
    - other: TXT, MD, Excel, CSV, JSON, HTML, XML, Python

    Blocking work runs in worker threads via ``asyncio.to_thread`` so the event loop
    stays responsive. This covers loading, OCR, the embedding API call and the
    Chroma client.

    Args:
        file_path: Local path of the file.
        knowledge_base_id: Knowledge-base ID; vectors go to collection ``kb_<id>``.
        file_type: Extension such as "pdf", "docx", "xlsx", "png". When empty, the
            suffix of ``file_path`` is used.
        file_id: Optional file ID stored in each chunk's ``file_id`` metadata.
        source_url: Optional URL that replaces each chunk's ``source`` metadata.

    Returns:
        ProcessResult: success with the stored chunks, or failure with an error
        message. Errors are reported in the result rather than raised.
    """

    def _failure(error_msg: str) -> ProcessResult:
        # Log an expected/recoverable problem and wrap it in a failed result.
        logger.warning(error_msg)
        return ProcessResult(success=False, chunks=[], chunk_count=0, error_message=error_msg)

    try:
        logger.info(f"开始处理文件: {file_path}, 类型: {file_type}")
        # Type used for the scanned-PDF check; tolerates None and a leading dot.
        normalized_type = (file_type or Path(file_path).suffix).lower().lstrip(".")

        # 1. Pick a loader (or a marker string for the custom OCR paths).
        loader = self._get_loader_for_file(file_path, file_type)
        if not loader:
            return _failure(f"不支持的文件类型: {file_type}")

        # 2. Load the content in a worker thread.
        if loader == "image_ocr":
            logger.info("🔄 在线程池中执行图片 OCR...")
            docs = await asyncio.to_thread(self._process_image_ocr, file_path)
        elif loader == "docx_with_images":
            logger.info("🔄 在线程池中处理 DOCX 文件(提取图片并 OCR...")
            docs, image_paths = await asyncio.to_thread(self._process_docx_with_images, file_path)
            # The knowledge base does not use the vision model, so the temporary
            # images extracted from the DOCX are not needed: delete them.
            for image_path in image_paths:
                try:
                    os.remove(image_path)
                except OSError as cleanup_err:
                    logger.debug(f"Failed to remove temp image {image_path}: {cleanup_err}")
        else:
            logger.info(f"🔄 在线程池中加载文档(类型: {file_type}...")
            docs = await asyncio.to_thread(loader.load)

        # Scanned (image-only) PDFs yield almost no text: switch to page OCR.
        if normalized_type == "pdf" and self._is_image_pdf(docs):
            logger.info("检测到图片型 PDF扫描版切换到 OCR 模式")
            if not self.ocr_engine:
                return _failure("检测到图片型 PDF扫描版但 OCR 服务不可用")
            logger.info("🔄 在线程池中执行 PDF OCR...")
            docs = await asyncio.to_thread(self._process_pdf_with_ocr, file_path)
            if not docs:
                return _failure("图片型 PDF OCR 识别失败")

        logger.info(f"文档加载完成,共 {len(docs)} 个文档片段")
        if not docs:
            return _failure("未能从文件加载到任何内容")
        total_chars = sum(len(d.page_content or "") for d in docs)
        try:
            validate_kb_text_length(total_chars)
        except ValueError as e:
            return _failure(str(e))

        # 3. Split into chunks.
        all_splits = self.text_splitter.split_documents(docs)
        logger.info(f"文本分割完成,共 {len(all_splits)} 个块")
        if not all_splits:
            return _failure("文档分割后没有内容,可能是空白文档")
        try:
            validate_kb_text_length(total_chars, chunk_count=len(all_splits))
        except ValueError as e:
            return _failure(str(e))

        # 4. Embed and store. Creating the Chroma client and add_documents (remote
        #    embedding call + Chroma write) are blocking, so run them off the loop.
        collection_name = f"kb_{knowledge_base_id}"
        vector_store = await asyncio.to_thread(self.get_vector_store, collection_name)
        # Tag chunks with file_id / chunk_index / source before vectorizing so the
        # values are stored with the vectors.
        if file_id is not None or source_url is not None:
            for idx, doc in enumerate(all_splits):
                if not doc.metadata:
                    doc.metadata = {}
                doc.metadata['chunk_index'] = idx
                if file_id is not None:
                    doc.metadata['file_id'] = file_id
                if source_url is not None:
                    # Presumably the uploaded file's OSS URL -- replaces the local path.
                    doc.metadata['source'] = source_url
            if file_id is not None:
                logger.info(f"✅ 已为 {len(all_splits)} 个chunks设置 file_id={file_id}")
            if source_url is not None:
                logger.info(f"✅ 已为 {len(all_splits)} 个chunks设置 source={source_url}")
        vector_ids = await asyncio.to_thread(vector_store.add_documents, documents=all_splits)
        logger.info(f"向量化完成,共 {len(vector_ids)} 个向量")

        # 5. Result tuples: (chunk_index, content, metadata, vector_id).
        chunks = [
            (idx, doc.page_content, doc.metadata, vector_id)
            for idx, (doc, vector_id) in enumerate(zip(all_splits, vector_ids))
        ]
        return ProcessResult(success=True, chunks=chunks, chunk_count=len(chunks))
    except Exception as e:
        error_msg = f"处理文件失败 ({file_type}): {str(e)}"
        logger.error(error_msg)
        import traceback
        logger.debug(traceback.format_exc())
        return ProcessResult(success=False, chunks=[], chunk_count=0, error_message=error_msg)
async def process_pdf(
    self,
    file_path: str,
    knowledge_base_id: int
) -> ProcessResult:
    """Vectorize a PDF into a knowledge base (legacy entry point).

    Kept for backward compatibility. It simply delegates to ``process_document``
    with the file type fixed to "pdf".

    Args:
        file_path: Path of the PDF file.
        knowledge_base_id: Target knowledge-base ID.

    Returns:
        ProcessResult: whatever ``process_document`` returns.
    """
    outcome = await self.process_document(file_path, knowledge_base_id, "pdf")
    return outcome
async def process_url(
    self,
    url: str,
    knowledge_base_id: int
) -> ProcessResult:
    """Fetch a web page, split it and vectorize it into collection ``kb_<id>``.

    Blocking work runs in worker threads via ``asyncio.to_thread``. This covers the
    page download, the Chroma client and the embedding/storage call.

    Args:
        url: Web page URL.
        knowledge_base_id: Knowledge-base ID; vectors go to collection ``kb_<id>``.

    Returns:
        ProcessResult: success with the stored chunks, or failure with an error
        message. Errors are reported in the result rather than raised.
    """

    def _failure(error_msg: str) -> ProcessResult:
        # Log an expected/recoverable problem and wrap it in a failed result.
        logger.warning(error_msg)
        return ProcessResult(success=False, chunks=[], chunk_count=0, error_message=error_msg)

    try:
        logger.info(f"开始处理 URL: {url}")
        # NOTE(review): the URL is fetched server-side exactly as given. If it can come
        # from end users, validate scheme/host (e.g. reject internal addresses) to
        # avoid SSRF -- TODO confirm where callers validate it.
        # 1. Load the page. An argument-less SoupStrainer matches every tag; narrow it
        #    here to keep only the main content if needed.
        bs4_strainer = bs4.SoupStrainer()
        loader = WebBaseLoader(
            web_paths=(url,),
            bs_kwargs={"parse_only": bs4_strainer},
        )
        logger.info("🔄 在线程池中加载网页内容...")
        docs = await asyncio.to_thread(loader.load)
        logger.info(f"网页加载完成,共 {len(docs)} 个文档")
        if not docs:
            return _failure("未能从 URL 加载到任何内容")
        total_chars = sum(len(d.page_content or "") for d in docs)
        try:
            validate_kb_text_length(total_chars)
        except ValueError as e:
            return _failure(str(e))

        # 2. Split into chunks.
        all_splits = self.text_splitter.split_documents(docs)
        logger.info(f"文本分割完成,共 {len(all_splits)} 个块")
        if not all_splits:
            return _failure("网页分割后没有内容,可能是空白页面或无法提取文本")
        try:
            validate_kb_text_length(total_chars, chunk_count=len(all_splits))
        except ValueError as e:
            return _failure(str(e))

        # 3. Embed and store off the event loop (remote embedding + Chroma HTTP calls).
        collection_name = f"kb_{knowledge_base_id}"
        vector_store = await asyncio.to_thread(self.get_vector_store, collection_name)
        vector_ids = await asyncio.to_thread(vector_store.add_documents, documents=all_splits)
        logger.info(f"向量化完成,共 {len(vector_ids)} 个向量")

        # 4. Result tuples: (chunk_index, content, metadata, vector_id).
        chunks = [
            (idx, doc.page_content, doc.metadata, vector_id)
            for idx, (doc, vector_id) in enumerate(zip(all_splits, vector_ids))
        ]
        return ProcessResult(success=True, chunks=chunks, chunk_count=len(chunks))
    except Exception as e:
        error_msg = f"处理 URL 失败: {str(e)}"
        logger.error(error_msg)
        import traceback
        logger.debug(traceback.format_exc())
        return ProcessResult(success=False, chunks=[], chunk_count=0, error_message=error_msg)
async def process_chat_thread_file(
    self,
    file_path: str,
    thread_id: str,
    file_type: str = "pdf",
    file_id: Optional[int] = None,
    source_url: Optional[str] = None
) -> ProcessResult:
    """Load, split and vectorize a chat attachment into collection ``thread_<id>``.

    Supported inputs:
    - PDF (scanned PDFs fall back to page OCR), DOCX (embedded images OCR'd), PPT/PPTX
    - images: PNG, JPG, JPEG, BMP (require Aliyun OCR)
    - other: TXT, MD, Excel, CSV, JSON, HTML, XML
    - web pages when ``file_type`` is "url" (case-insensitive)

    Blocking work runs in worker threads via ``asyncio.to_thread``. Images extracted
    from a DOCX are returned in ``extracted_image_paths`` on success. On any failure
    they are deleted.

    Args:
        file_path: Local file path, or the URL when ``file_type`` is "url".
        thread_id: Chat thread ID; vectors go to collection ``thread_<id>``.
        file_type: Extension such as "pdf", "docx", "xlsx", "png", or "url".
        file_id: Optional file ID stored in each chunk's ``file_id`` metadata.
        source_url: Optional URL that replaces each chunk's ``source`` metadata.

    Returns:
        ProcessResult: success with the stored chunks (and any extracted image
        paths), or failure with an error message. Errors are reported, not raised.
    """

    def _failure(error_msg: str) -> ProcessResult:
        # Log an expected/recoverable problem and wrap it in a failed result.
        logger.warning(error_msg)
        return ProcessResult(success=False, chunks=[], chunk_count=0, error_message=error_msg)

    # Temporary images extracted from a DOCX (handed to the caller only on success).
    extracted_image_paths: List[str] = []
    succeeded = False
    try:
        logger.info(f"开始处理聊天文件: {file_path}, thread_id: {thread_id}, 类型: {file_type}")
        docs = []
        if (file_type or "").lower() == "url":
            # NOTE(review): fetched server-side as given; validate user-supplied URLs
            # against SSRF if callers do not already -- TODO confirm.
            bs4_strainer = bs4.SoupStrainer()
            loader = WebBaseLoader(
                web_paths=(file_path,),
                bs_kwargs={"parse_only": bs4_strainer},
            )
            logger.info("🔄 在线程池中加载网页内容...")
            docs = await asyncio.to_thread(loader.load)
            logger.info(f"网页加载完成,共 {len(docs)} 个文档")
        else:
            loader = self._get_loader_for_file(file_path, file_type)
            logger.debug(f"loader: {loader}")
            if not loader:
                return _failure(f"不支持的文件类型: {file_type}")
            # Load in a worker thread to keep the event loop free.
            if loader == "image_ocr":
                logger.info("🔄 在线程池中执行图片 OCR...")
                docs = await asyncio.to_thread(self._process_image_ocr, file_path)
            elif loader == "docx_with_images":
                logger.info("🔄 在线程池中处理 DOCX 文件(提取图片并 OCR...")
                docs, extracted_image_paths = await asyncio.to_thread(self._process_docx_with_images, file_path)
            else:
                logger.info(f"🔄 在线程池中加载文档(类型: {file_type}...")
                docs = await asyncio.to_thread(loader.load)
            # Scanned (image-only) PDFs yield almost no text: switch to page OCR.
            normalized_type = (file_type or Path(file_path).suffix).lower().lstrip(".")
            if normalized_type == "pdf" and self._is_image_pdf(docs):
                logger.info("检测到图片型 PDF扫描版切换到 OCR 模式")
                if not self.ocr_engine:
                    return _failure("检测到图片型 PDF扫描版但 OCR 服务不可用")
                logger.info("🔄 在线程池中执行 PDF OCR...")
                docs = await asyncio.to_thread(self._process_pdf_with_ocr, file_path)
                if not docs:
                    return _failure("图片型 PDF OCR 识别失败")

        logger.info(f"文档加载完成,共 {len(docs)} 个文档片段")
        if not docs:
            return _failure("未能加载到任何内容")
        total_chars = sum(len(d.page_content or "") for d in docs)
        try:
            validate_chat_file_text_length(total_chars)
        except ValueError as e:
            return _failure(str(e))

        # Split into chunks.
        all_splits = self.text_splitter.split_documents(docs)
        logger.info(f"文本分割完成,共 {len(all_splits)} 个块")
        if not all_splits:
            return _failure("文档分割后没有内容,可能是空白文档或无法提取文本")
        try:
            validate_chat_file_text_length(total_chars, chunk_count=len(all_splits))
        except ValueError as e:
            return _failure(str(e))

        # Embed and store in the thread's collection, off the event loop.
        collection_name = f"thread_{thread_id}"
        vector_store = await asyncio.to_thread(self.get_vector_store, collection_name)
        # Tag chunks with file_id / chunk_index / source before vectorizing.
        if file_id is not None or source_url is not None:
            for idx, doc in enumerate(all_splits):
                if not doc.metadata:
                    doc.metadata = {}
                doc.metadata['chunk_index'] = idx
                if file_id is not None:
                    doc.metadata['file_id'] = file_id
                if source_url is not None:
                    # Presumably the uploaded file's OSS URL -- replaces the local path.
                    doc.metadata['source'] = source_url
            if file_id is not None:
                logger.info(f"✅ 已为 {len(all_splits)} 个chunks设置 file_id={file_id}")
            if source_url is not None:
                logger.info(f"✅ 已为 {len(all_splits)} 个chunks设置 source={source_url}")
        vector_ids = await asyncio.to_thread(vector_store.add_documents, documents=all_splits)
        logger.info(f"向量化完成,共 {len(vector_ids)} 个向量")

        # Result tuples: (chunk_index, content, metadata, vector_id).
        chunks = [
            (idx, doc.page_content, doc.metadata, vector_id)
            for idx, (doc, vector_id) in enumerate(zip(all_splits, vector_ids))
        ]
        result = ProcessResult(
            success=True,
            chunks=chunks,
            chunk_count=len(chunks),
            extracted_image_paths=extracted_image_paths if extracted_image_paths else None,
        )
        succeeded = True
        return result
    except Exception as e:
        error_msg = f"处理聊天文件失败: {str(e)}"
        logger.error(error_msg)
        import traceback
        logger.debug(traceback.format_exc())
        return ProcessResult(success=False, chunks=[], chunk_count=0, error_message=error_msg)
    finally:
        # Extracted images are only handed over on success; otherwise delete them.
        if not succeeded:
            for image_path in extracted_image_paths:
                try:
                    os.remove(image_path)
                except OSError as cleanup_err:
                    logger.debug(f"Failed to remove temp image {image_path}: {cleanup_err}")
def search_similar_in_thread(
    self,
    thread_id: str,
    query: str,
    k: int = 5,
    file_id: Optional[int] = None,
    score_threshold: float = 0.0
) -> List[dict]:
    """Similarity-search a chat thread's collection, optionally limited to one file.

    Args:
        thread_id: Chat thread ID; searches collection ``thread_<id>``.
        query: Query text.
        k: Maximum number of results; ``k <= 0`` returns ``[]``.
        file_id: If given, only chunks whose ``file_id`` metadata equals it are
            searched (0 is a valid ID).
        score_threshold: Maximum distance to keep. Chroma returns distances, so
            smaller means more similar. ``0.0`` disables the threshold. The value
            range depends on the collection's distance function and is not
            guaranteed to be 0-1.

    Returns:
        Up to ``k`` dicts with ``content``, ``metadata``, ``score`` and
        ``file_summary``, in the order returned by Chroma; ``[]`` on any error (logged).
    """
    if k <= 0:
        return []
    try:
        collection_name = f"thread_{thread_id}"
        vector_store = self.get_vector_store(collection_name)
        # `is not None`, so file_id=0 is filtered like any other ID.
        filter_dict = {"file_id": file_id} if file_id is not None else None
        # Over-fetch so enough results survive the distance threshold.
        search_k = k * 3 if file_id is not None else k * 2
        results = vector_store.similarity_search_with_score(
            query,
            k=search_k,
            filter=filter_dict,
        )
        formatted_results = []
        for doc, score in results:
            if score_threshold == 0.0 or score <= score_threshold:
                metadata = doc.metadata or {}
                formatted_results.append({
                    "content": doc.page_content,
                    "metadata": metadata,
                    "score": float(score),
                    "file_summary": metadata.get("file_summary", ""),
                })
                # Stop once enough results were collected.
                if len(formatted_results) >= k:
                    break
        logger.info(f"向量检索完成: 查询='{query[:50]}...', 结果数={len(formatted_results)}, file_id={file_id}, 阈值={score_threshold}")
        return formatted_results
    except Exception as e:
        logger.error(f"搜索相似文档失败: {e}")
        return []
def get_all_file_chunks(
    self,
    thread_id: str,
    file_id: int
) -> List[dict]:
    """
    Fetch the full content of every chunk stored for one file in a chat thread.

    Args:
        thread_id: Conversation thread ID; the collection is ``thread_{thread_id}``.
        file_id: File ID matched against each chunk's ``file_id`` metadata.

    Returns:
        List[dict]: One dict per chunk with ``content``, ``metadata``,
        ``chunk_index`` and ``file_summary`` keys, sorted by ``chunk_index``
        (falling back to the retrieval position when metadata lacks it).
        Returns ``[]`` when nothing matches or on error (the error is logged).
    """
    try:
        collection_name = f"thread_{thread_id}"
        logger.info(f"🔍 开始获取文件chunks: collection={collection_name}, file_id={file_id}")
        vector_store = self.get_vector_store(collection_name)
        # Chroma's get() takes the metadata filter through the ``where`` argument.
        all_docs = vector_store.get(
            where={"file_id": file_id},
            include=["documents", "metadatas"]
        )
        # Normalise missing/None fields up front so neither the size log nor
        # the loop below can fail on an unexpected response shape.
        documents = (all_docs or {}).get('documents') or []
        metadatas = (all_docs or {}).get('metadatas') or []
        logger.info(f"📦 ChromaDB返回结果: documents数量={len(documents)}, metadatas数量={len(metadatas)}")
        chunks = []
        if documents:
            logger.info(f"✅ 检测到 {len(documents)} 个文档")
            for idx, (doc_content, metadata) in enumerate(zip(documents, metadatas)):
                # Chroma may return None for a chunk stored without metadata.
                metadata = metadata or {}
                chunk_index = metadata.get("chunk_index", idx)
                has_summary = "file_summary" in metadata
                # Per-chunk detail is debug-level so large files don't flood the info log.
                logger.debug(f" - Chunk {idx}: chunk_index={chunk_index}, 内容长度={len(doc_content)}, 有摘要={has_summary}, metadata_keys={list(metadata.keys())}")
                chunks.append({
                    "content": doc_content,
                    "metadata": metadata,
                    "chunk_index": chunk_index,
                    "file_summary": metadata.get("file_summary", "")
                })
            chunks.sort(key=lambda chunk: chunk['chunk_index'])
            # zip() can yield nothing if metadatas is empty; guard the range lookup.
            if chunks:
                logger.info(f"✅ 排序完成,共 {len(chunks)} 个chunkschunk_index范围: {chunks[0]['chunk_index']} ~ {chunks[-1]['chunk_index']}")
        else:
            logger.warning(f"⚠️ ChromaDB未返回任何文档: all_docs keys={list(all_docs.keys()) if all_docs else 'None'}")
        logger.info(f"📊 最终返回: file_id={file_id}, 总chunks数量={len(chunks)}")
        return chunks
    except Exception as e:
        logger.error(f"❌ 获取文件所有chunks失败: {e}")
        import traceback
        logger.error(traceback.format_exc())
        return []
def update_file_summary_in_vectors(
    self,
    thread_id: str,
    file_id: int,
    summary: str
) -> bool:
    """
    Store ``summary`` as the ``file_summary`` metadata of every vector of a
    file in a chat thread's collection.

    Args:
        thread_id: Conversation thread ID; the collection is ``thread_{thread_id}``.
        file_id: File whose vectors are updated (matched on ``file_id`` metadata).
        summary: Summary text to write.

    Returns:
        bool: True if matching vectors were found and updated; False if none
        matched or an error occurred (the error and traceback are logged).
    """
    try:
        store = self.get_vector_store(f"thread_{thread_id}")
        found = store.get(
            where={"file_id": file_id},
            include=["metadatas"]
        )
        if not found or 'ids' not in found or not found['ids']:
            logger.warning(f"未找到 file_id={file_id} 的向量")
            return False

        def _with_summary(meta):
            # Copy so the dict returned by get() is not mutated in place.
            patched = meta.copy()
            patched['file_summary'] = summary
            return patched

        target_ids = found['ids']
        patched_metadatas = [_with_summary(meta) for meta in found['metadatas']]
        # Metadata is rewritten through the underlying (private) Chroma collection.
        store._collection.update(ids=target_ids, metadatas=patched_metadatas)
        logger.info(f"✅ 已更新 ChromaDB 中 {len(target_ids)} 个向量的 summary (file_id={file_id})")
        return True
    except Exception as exc:
        logger.error(f"更新 ChromaDB summary 失败: {exc}")
        import traceback
        logger.error(traceback.format_exc())
        return False
def update_kb_file_summary_in_vectors(
    self,
    knowledge_base_id: int,
    file_id: int,
    summary: str
) -> bool:
    """
    Store ``summary`` as the ``file_summary`` metadata of every vector of a
    file in a knowledge base's collection.

    Args:
        knowledge_base_id: Knowledge-base ID; the collection is ``kb_{knowledge_base_id}``.
        file_id: File whose vectors are updated (matched on ``file_id`` metadata).
        summary: Summary text to write.

    Returns:
        bool: True if matching vectors were found and updated; False if none
        matched or an error occurred (the error and traceback are logged).
    """
    try:
        store = self.get_vector_store(f"kb_{knowledge_base_id}")
        found = store.get(
            where={"file_id": file_id},
            include=["metadatas"]
        )
        if not found or 'ids' not in found or not found['ids']:
            logger.warning(f"未找到 file_id={file_id} 的向量 (kb_id={knowledge_base_id})")
            return False

        def _with_summary(meta):
            # Copy so the dict returned by get() is not mutated in place.
            patched = meta.copy()
            patched['file_summary'] = summary
            return patched

        target_ids = found['ids']
        patched_metadatas = [_with_summary(meta) for meta in found['metadatas']]
        # Metadata is rewritten through the underlying (private) Chroma collection.
        store._collection.update(ids=target_ids, metadatas=patched_metadatas)
        logger.info(f"✅ 已更新 ChromaDB 中 {len(target_ids)} 个向量的 summary (kb_id={knowledge_base_id}, file_id={file_id})")
        return True
    except Exception as exc:
        logger.error(f"更新 ChromaDB summary 失败 (kb_id={knowledge_base_id}, file_id={file_id}): {exc}")
        import traceback
        logger.error(traceback.format_exc())
        return False
def delete_thread_vectors(
    self,
    thread_id: str,
    vector_ids: List[str]
) -> bool:
    """
    Delete specific vectors from a chat thread's collection.

    Args:
        thread_id: Conversation thread ID; the collection is ``thread_{thread_id}``.
        vector_ids: IDs of the vectors to delete. An empty list is a no-op.

    Returns:
        bool: True on success (or nothing to delete), False if deletion
        raised (the error is logged).
    """
    try:
        if vector_ids:
            store = self.get_vector_store(f"thread_{thread_id}")
            store.delete(ids=vector_ids)
            logger.info(f"从会话 {thread_id} 删除 {len(vector_ids)} 个向量")
        return True
    except Exception as exc:
        logger.error(f"删除向量失败: {exc}")
        return False
def delete_thread_collection(
    self,
    thread_id: str
) -> bool:
    """
    Remove the whole on-disk vector store of a chat thread.

    Args:
        thread_id: Conversation thread ID; the directory is
            ``<vector_store_path>/thread_{thread_id}``.

    Returns:
        bool: True if the directory existed and was removed; False if it did
        not exist or removal failed (the error is logged).
    """
    try:
        import shutil
        collection_name = f"thread_{thread_id}"
        target_dir = os.path.join(self.vector_store_path, collection_name)
        if not os.path.exists(target_dir):
            return False
        shutil.rmtree(target_dir)
        logger.info(f"删除向量库: {collection_name}")
        return True
    except Exception as exc:
        logger.error(f"删除向量库失败: {exc}")
        return False
def search_similar(
    self,
    knowledge_base_id: int,
    query: str,
    k: int = 5
) -> List[dict]:
    """
    Similarity search over a knowledge base's vector collection.

    Args:
        knowledge_base_id: Knowledge-base ID; the collection is ``kb_{knowledge_base_id}``.
        query: Query text.
        k: Number of results to request.

    Returns:
        List[dict]: Results in vector-store order, each with ``content``,
        ``metadata`` and ``score`` (as float). ``[]`` on error (logged).
    """
    try:
        store = self.get_vector_store(f"kb_{knowledge_base_id}")
        hits = store.similarity_search_with_score(query, k=k)
        return [
            {
                "content": hit_doc.page_content,
                "metadata": hit_doc.metadata,
                "score": float(hit_score),
            }
            for hit_doc, hit_score in hits
        ]
    except Exception as exc:
        logger.error(f"搜索相似文档失败: {exc}")
        return []
def delete_vectors_by_ids(
    self,
    knowledge_base_id: int,
    vector_ids: List[str]
) -> bool:
    """
    Delete specific vectors from a knowledge base's collection.

    Args:
        knowledge_base_id: Knowledge-base ID; the collection is ``kb_{knowledge_base_id}``.
        vector_ids: IDs of the vectors to delete. An empty list is a no-op.

    Returns:
        bool: True on success (or nothing to delete), False if deletion
        raised (the error is logged).
    """
    try:
        if vector_ids:
            store = self.get_vector_store(f"kb_{knowledge_base_id}")
            store.delete(ids=vector_ids)
            logger.info(f"从知识库 {knowledge_base_id} 删除 {len(vector_ids)} 个向量")
        return True
    except Exception as exc:
        logger.error(f"删除向量失败: {exc}")
        return False
def delete_collection(self, knowledge_base_id: int) -> bool:
    """
    Remove the whole on-disk vector store of a knowledge base.

    Args:
        knowledge_base_id: Knowledge-base ID; the directory is
            ``<vector_store_path>/kb_{knowledge_base_id}``.

    Returns:
        bool: True if the directory existed and was removed; False if it did
        not exist or removal failed (the error is logged).
    """
    try:
        import shutil
        collection_name = f"kb_{knowledge_base_id}"
        target_dir = os.path.join(self.vector_store_path, collection_name)
        if not os.path.exists(target_dir):
            return False
        shutil.rmtree(target_dir)
        logger.info(f"删除向量库: {collection_name}")
        return True
    except Exception as exc:
        logger.error(f"删除向量库失败: {exc}")
        return False
# ----- 知识图谱 RAG (Chroma 集合名: knowledge_graph_{graphs.id}; 兼容旧名 novel_kg_*) -----
def delete_knowledge_graph_collection(self, knowledge_graph_pk: int) -> bool:
    """
    Remove the Chroma persistence directories of a knowledge graph.

    Both the current name ``knowledge_graph_{pk}`` and the legacy name
    ``novel_kg_{pk}`` are removed when present.

    Returns:
        bool: False if removing any existing directory failed (each failure is
        logged and the other name is still attempted), otherwise True.
    """
    import shutil
    candidate_names = (f"knowledge_graph_{knowledge_graph_pk}", f"novel_kg_{knowledge_graph_pk}")
    all_removed = True
    for collection in candidate_names:
        try:
            target_dir = os.path.join(self.vector_store_path, collection)
            if not os.path.exists(target_dir):
                continue
            shutil.rmtree(target_dir)
            logger.info(f"删除知识图谱向量库: {collection}")
        except Exception as exc:
            logger.error(f"删除知识图谱向量库失败 {collection}: {exc}")
            all_removed = False
    return all_removed
def index_knowledge_graph_text(self, knowledge_graph_pk: int, text: str) -> int:
    """
    Split a knowledge graph's full source text into chunks and write them to Chroma.

    When ``text`` has non-whitespace content, the graph's existing collections
    (current and legacy names) are deleted first and then rebuilt under
    ``knowledge_graph_{pk}``. Each chunk gets a ``chunk_index`` metadata entry.

    Returns:
        int: Number of chunks written (0 for blank text or when splitting
        yields nothing).
    """
    from langchain_core.documents import Document
    if not (text and text.strip()):
        return 0
    self.delete_knowledge_graph_collection(knowledge_graph_pk)
    collection_name = f"knowledge_graph_{knowledge_graph_pk}"
    source_doc = Document(
        page_content=text.strip(),
        metadata={"source": collection_name, "kind": "knowledge_graph"},
    )
    pieces = self.text_splitter.split_documents([source_doc])
    if not pieces:
        return 0
    for position, piece in enumerate(pieces):
        # Make sure there is a metadata dict before tagging the chunk position.
        if not piece.metadata:
            piece.metadata = {}
        piece.metadata["chunk_index"] = position
    store = self.get_vector_store(collection_name)
    store.add_documents(pieces)
    piece_count = len(pieces)
    logger.info(f"知识图谱向量化完成 knowledge_graph_pk={knowledge_graph_pk} chunks={piece_count}")
    return piece_count
def search_similar_knowledge_graph(
    self,
    knowledge_graph_pk: int,
    query: str,
    k: int = 5,
) -> List[dict]:
    """
    Similarity search over a knowledge graph's vector collection.

    The current collection name ``knowledge_graph_{pk}`` is tried first, then
    the legacy ``novel_kg_{pk}``. A name is skipped if its directory does not
    exist or if searching it raises (the error is logged).

    Returns:
        List[dict]: Results of the first collection searched successfully,
        each with ``content``, ``metadata`` and ``score``; ``[]`` otherwise.
    """
    candidate_names = (f"knowledge_graph_{knowledge_graph_pk}", f"novel_kg_{knowledge_graph_pk}")
    for collection in candidate_names:
        if not os.path.exists(os.path.join(self.vector_store_path, collection)):
            continue
        try:
            store = self.get_vector_store(collection)
            hits = store.similarity_search_with_score(query, k=k)
            return [
                {
                    "content": hit_doc.page_content,
                    "metadata": hit_doc.metadata,
                    "score": float(hit_score),
                }
                for hit_doc, hit_score in hits
            ]
        except Exception as exc:
            logger.error(f"知识图谱向量检索失败 ({collection}): {exc}")
    return []
# ==================== 增强的 RAG 功能 ====================
def add_summary_chunk(
    self,
    collection_name: str,
    file_id: int,
    file_name: str,
    summary_text: str,
    metadata: Optional[dict] = None
) -> Optional[str]:
    """
    Add one summary chunk for a file to a collection.

    The chunk is tagged ``chunk_type="summary"`` and ``chunk_index=-1``.
    Keys in ``metadata`` are merged last and therefore override these defaults.

    Args:
        collection_name: Target collection name.
        file_id: File ID stored in the chunk metadata.
        file_name: File name stored in the chunk metadata.
        summary_text: Summary text to embed.
        metadata: Optional extra metadata.

    Returns:
        Optional[str]: ID of the new chunk, or None if nothing was returned
        or an error occurred (the error is logged).
    """
    try:
        store = self.get_vector_store(collection_name)
        summary_metadata = {
            "file_id": file_id,
            "file_name": file_name,
            "chunk_type": "summary",  # marks this chunk as a summary
            "chunk_index": -1,  # summaries have no position in the file
            **(metadata or {}),
        }
        new_ids = store.add_texts(
            texts=[summary_text],
            metadatas=[summary_metadata]
        )
        logger.info(f"添加摘要 chunk 成功: {file_name}")
        if not new_ids:
            return None
        return new_ids[0]
    except Exception as exc:
        logger.error(f"添加摘要 chunk 失败: {exc}")
        return None
def search_by_chunk_type(
    self,
    collection_name: str,
    query: str,
    chunk_type: str = "text",
    top_k: int = 5,
    filter_metadata: Optional[dict] = None
) -> List[Tuple[str, dict, float]]:
    """
    Similarity search restricted to chunks of a given ``chunk_type``.

    Args:
        collection_name: Collection to search.
        query: Query text.
        chunk_type: Chunk type to match (e.g. ``summary`` or ``text``).
        top_k: Number of results to request.
        filter_metadata: Extra metadata conditions; a key that collides with
            ``chunk_type`` overrides it.

    Returns:
        List[Tuple[str, dict, float]]: ``(content, metadata, score)`` tuples
        in vector-store order; ``[]`` on error (the error is logged).
    """
    try:
        vector_store = self.get_vector_store(collection_name)
        conditions = {"chunk_type": chunk_type}
        if filter_metadata:
            conditions.update(filter_metadata)
        # Chroma's ``where`` accepts exactly one top-level key/operator, so
        # several conditions must be combined explicitly with ``$and``.
        clauses = [{key: value} for key, value in conditions.items()]
        where_filter = clauses[0] if len(clauses) == 1 else {"$and": clauses}
        results = vector_store.similarity_search_with_score(
            query=query,
            k=top_k,
            filter=where_filter
        )
        return [
            (doc.page_content, doc.metadata, score)
            for doc, score in results
        ]
    except Exception as e:
        logger.error(f"基于类型检索失败: {e}")
        return []
def get_file_all_chunks(
    self,
    collection_name: str,
    file_id: int,
    chunk_type: str = "text"
) -> List[Tuple[str, dict]]:
    """
    Fetch every chunk of one file with a given ``chunk_type`` (for full-text use).

    Args:
        collection_name: Collection to read from.
        file_id: File ID matched on ``file_id`` metadata.
        chunk_type: Chunk type matched on ``chunk_type`` metadata.

    Returns:
        List[Tuple[str, dict]]: ``(content, metadata)`` pairs sorted by
        ``chunk_index`` (missing or None metadata sorts as index 0); ``[]``
        when nothing matches or on error (the error is logged).
    """
    try:
        vector_store = self.get_vector_store(collection_name)
        # Chroma's ``where`` accepts a single top-level operator, so the two
        # equality conditions must be combined with ``$and``.
        results = vector_store.get(
            where={"$and": [{"file_id": file_id}, {"chunk_type": chunk_type}]}
        )
        if not results or "documents" not in results:
            return []
        documents = results["documents"]
        metadatas = results["metadatas"]
        chunks = list(zip(documents, metadatas))
        # Chroma may hold None metadata; treat it as chunk_index 0 instead of crashing.
        chunks.sort(key=lambda pair: (pair[1] or {}).get("chunk_index", 0))
        return chunks
    except Exception as e:
        logger.error(f"获取文件所有 chunks 失败: {e}")
        return []
def hybrid_search(
    self,
    collection_name: str,
    query: str,
    top_k: int = 5,
    filter_metadata: Optional[dict] = None,
    dense_weight: float = 0.7,
    sparse_weight: float = 0.3
) -> List[Tuple[str, dict, float]]:
    """
    Hybrid (dense + sparse) retrieval; currently dense-only.

    This is a simplified version. Real hybrid retrieval needs a sparse
    algorithm such as BM25, and only dense vector search is performed here.
    A warning is logged on every call to make that visible.

    Args:
        collection_name: Collection to search.
        query: Query text.
        top_k: Number of results to request.
        filter_metadata: Optional metadata filter. A dict with several
            equality keys is combined with ``$and`` so Chroma accepts it.
        dense_weight: Intended weight of the dense score; currently unused.
        sparse_weight: Intended weight of the sparse score; currently unused.

    Returns:
        List[Tuple[str, dict, float]]: ``(content, metadata, score)`` tuples
        in vector-store order; ``[]`` on error (the error is logged).
    """
    try:
        # TODO: implement real hybrid retrieval; only dense search is used for now.
        logger.warning("混合检索功能当前仅支持 dense 向量检索,需要额外集成 BM25")
        vector_store = self.get_vector_store(collection_name)
        # Chroma's ``where`` takes a single top-level key/operator; wrap
        # multi-key filters in ``$and`` instead of letting Chroma reject them.
        where_filter = filter_metadata
        if filter_metadata and len(filter_metadata) > 1:
            where_filter = {"$and": [{key: value} for key, value in filter_metadata.items()]}
        results = vector_store.similarity_search_with_score(
            query=query,
            k=top_k,
            filter=where_filter
        )
        return [
            (doc.page_content, doc.metadata, score)
            for doc, score in results
        ]
    except Exception as e:
        logger.error(f"混合检索失败: {e}")
        return []
def search_with_rerank(
    self,
    collection_name: str,
    query: str,
    top_k: int = 5,
    rerank_top_k: int = 3,
    filter_metadata: Optional[dict] = None
) -> List[Tuple[str, dict, float]]:
    """
    Retrieval with a (future) reranking step.

    The collection is queried directly. The previous delegation to
    ``search_similar`` used a signature that method does not have
    (``knowledge_base_id, query, k``), so it always raised and returned [].

    Args:
        collection_name: Collection to search.
        query: Query text.
        top_k: Number of candidates fetched in the first pass.
        rerank_top_k: Number of results returned after reranking.
        filter_metadata: Optional metadata filter. A dict with several
            equality keys is combined with ``$and`` so Chroma accepts it.

    Returns:
        List[Tuple[str, dict, float]]: Up to ``rerank_top_k``
        ``(content, metadata, score)`` tuples; ``[]`` on error (logged).
    """
    try:
        vector_store = self.get_vector_store(collection_name)
        # Chroma's ``where`` takes a single top-level key/operator.
        where_filter = filter_metadata
        if filter_metadata and len(filter_metadata) > 1:
            where_filter = {"$and": [{key: value} for key, value in filter_metadata.items()]}
        # First pass: fetch a larger candidate set for the reranker.
        results = vector_store.similarity_search_with_score(
            query=query,
            k=top_k,
            filter=where_filter
        )
        candidates = [
            (doc.page_content, doc.metadata, score)
            for doc, score in results
        ]
        # TODO: rerank candidates with a rerank model; vector-store order is kept for now.
        return candidates[:rerank_top_k]
    except Exception as e:
        logger.error(f"重排序检索失败: {e}")
        return []
def get_summary_by_file_ids(
    self,
    collection_name: str,
    file_ids: List[int]
) -> Dict[int, str]:
    """
    Fetch the summary chunk text for each of several files.

    Args:
        collection_name: Collection to read from.
        file_ids: File IDs to look up.

    Returns:
        Dict[int, str]: Maps each file ID that has a ``chunk_type == "summary"``
        chunk to that chunk's text (the first returned if there are several).
        Files without a summary are omitted; any error yields ``{}`` (logged).
    """
    try:
        vector_store = self.get_vector_store(collection_name)
        summaries = {}
        for file_id in file_ids:
            # Chroma's ``where`` accepts a single top-level operator, so the two
            # equality conditions must be combined with ``$and``.
            results = vector_store.get(
                where={"$and": [{"file_id": file_id}, {"chunk_type": "summary"}]}
            )
            if results and "documents" in results and results["documents"]:
                summaries[file_id] = results["documents"][0]
        return summaries
    except Exception as e:
        logger.error(f"批量获取摘要失败: {e}")
        return {}
# Module-level holder for the shared VectorService; filled on first use.
_vector_service = None
def get_vector_service() -> VectorService:
    """
    Return the process-wide VectorService, constructing it on the first call.

    Note: the lazy initialisation is not protected by a lock.
    """
    global _vector_service
    if _vector_service is not None:
        return _vector_service
    _vector_service = VectorService()
    return _vector_service