RAG开源项目Qanything源码阅读2-离线文件处理

原文:前沿重器[46] RAG开源项目Qanything源码阅读2-离线文件处理
项目:https://github.com/netease-youdao/QAnything
上一篇:RAG开源项目Qanything源码阅读1-概述+服务


本文结构:

  • 文件上传
  • 文件读取和切片
  • 索引构造

提前说明,这里忽略了大量的业务代码,聚焦在文件处理和相关算法本身;像新建用户、新建知识库、文件删除之类的功能会有选择地忽略,有需要的可以参考文中的思路,在代码里找到对应的位置。

1,文件上传

文件上传是指将文件从前端传到后端的流程,这个流程在docs\API.md中有说明。首先是接口字段:

  • files(类型:File,必填):需要上传的文件,可多选,目前仅支持[md,txt,pdf,jpg,png,jpeg,docx,xlsx,pptx,eml,csv]
  • user_id(类型:String,必填):用户 id,示例值:zzp
  • kb_id(类型:String,必填):知识库 id,示例值:KBb1dd58e8485443ce81166d24f6febda7
  • mode(类型:String,选填):上传模式,soft:知识库内存在同名文件时当前文件不再上传,strong:文件名重复的文件强制上传,默认值为 soft

针对上传请求,作者给出了两种写法:同步和异步。

1.1 客户端

客户端只需要请求服务即可。这里穿插讲一下同步、异步请求以及文件上传的细节,直接参考源码就好,首先是同步的请求源码:

1.1.1 上传文件同步请求示例
import os
import requests

url = "http://{your_host}:8777/api/local_doc_qa/upload_files"
folder_path = "./docx_data"  # 文件所在文件夹,注意是文件夹!!
data = {
    "user_id": "zzp",
    "kb_id": "KB6dae785cdd5d47a997e890521acbe1c9",
    "mode": "soft"
}

files = []
for root, dirs, file_names in os.walk(folder_path):
    for file_name in file_names:
        if file_name.endswith(".md"):  # 这里只上传后缀是md的文件,请按需修改,支持类型见上表
            file_path = os.path.join(root, file_name)
            files.append(("files", open(file_path, "rb")))

response = requests.post(url, files=files, data=data)
print(response.text)
  • 发请求用的是通用的requests包。
  • 因为是本地测试,所以使用的就是比较直接的本地文件,直接open就行,文件字段存的是open返回的文件对象,注意打开方式是rb。

至于异步,则会复杂一些。

1.1.2 上传文件异步请求示例
import argparse
import os
import sys
import json
import aiohttp
import asyncio
import time
import random
import string

files = []
for root, dirs, file_names in os.walk("./docx_data"):  # 文件夹
    for file_name in file_names:
        if file_name.endswith(".docx"):  # 只上传docx文件
            file_path = os.path.join(root, file_name)
            files.append(file_path)

print(len(files))
response_times = []

async def send_request(round_, files):
    print(len(files))
    url = 'http://{your_host}:8777/api/local_doc_qa/upload_files'
    data = aiohttp.FormData()
    data.add_field('user_id', 'zzp')
    data.add_field('kb_id', 'KBf1dafefdb08742f89530acb7e9ed66dd')
    data.add_field('mode', 'soft')
    total_size = 0
    for file_path in files:
        file_size = os.path.getsize(file_path)
        total_size += file_size
        data.add_field('files', open(file_path, 'rb'))
    print('size:', total_size / (1024 * 1024))
    try:
        start_time = time.time()
        async with aiohttp.ClientSession() as session:
            async with session.post(url, data=data) as response:
                end_time = time.time()
                response_times.append(end_time - start_time)
                print(f"round_:{round_}, 响应状态码: {response.status}, 响应时间: {end_time - start_time}秒")
    except Exception as e:
        print(f"请求发送失败: {e}")

async def main():
    start_time = time.time()
    num = int(sys.argv[1])  # 一次上传数量,http协议限制一次请求data不能大于100M,请自行控制数量
    round_ = 0
    r_files = files[:num]
    tasks = []
    task = asyncio.create_task(send_request(round_, r_files))
    tasks.append(task)
    await asyncio.gather(*tasks)
    print(f"请求完成")
    end_time = time.time()
    total_requests = len(response_times)
    total_time = end_time - start_time
    qps = total_requests / total_time
    print(f"total_time:{total_time}")

if __name__ == '__main__':
    asyncio.run(main())

请求用的是aiohttp,而且使用的是python的协程,即asyncio一套的python技术,具体细节可以参考这篇博客:https://blog.csdn.net/LiamHong_/article/details/134458790。协程在高密度的http请求下,能有效提升CPU的使用率,提升综合性能,毕竟在请求等待过程,可以做很多别的事,就避免CPU空跑了。
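为了更直观地感受"请求等待过程里可以去做别的事",下面给一个和本项目无关的最小示意:用asyncio.gather并发发出多个请求,总耗时约等于最慢的那一个,而不是各个请求耗时之和(其中的url只是占位,需换成自己的服务):

import asyncio
import time
import aiohttp

async def fetch(session, url):
    async with session.get(url) as resp:
        return resp.status

async def main():
    urls = ["https://example.com"] * 5  # 占位url,实际替换成自己的服务地址
    start = time.time()
    async with aiohttp.ClientSession() as session:
        # 5个请求同时在途,事件循环在等待网络时可以切换去处理别的协程
        statuses = await asyncio.gather(*[fetch(session, u) for u in urls])
    print(statuses, f"耗时约{time.time() - start:.2f}秒")  # 耗时接近单个请求,而不是5倍

asyncio.run(main())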

1.1.3 上传文件响应示例

{"code": 200, //状态码"msg": "success,后台正在飞速上传文件,请耐心等待", //提示信息"data": [{"file_id": "1b6c0781fb9245b2973504cb031cc2f3", //文件id"file_name": "网易有道智云平台产品介绍2023.6.ppt", //文件名"status": "gray", //文件状态(red:入库失败-切分失败,green,成功入库,yellow:入库失败-milvus失败,gray:正在入库)"bytes": 17925, //文件大小(字节数)"timestamp": "202401251056" // 上传时间},{"file_id": "aeaec708c7a34952b7de484fb3374f5d","file_name": "有道知识库问答产品介绍.pptx","status": "gray","bytes": 12928, //文件大小(字节数)"timestamp": "202401251056" // 上传时间}] //文件列表
}

1.2,服务端

服务端则比较复杂了,文件上传后要经过大量的校验,并且需要返回最终的处理结果。

文件上传的接口是/api/local_doc_qa/upload_files,我们可以在handlers.py里面找到,排除掉一些校验代码,handlers里面的核心代码是这段(upload_files函数下):

for file, file_name in zip(files, file_names):
    if file_name in exist_file_names:
        continue
    file_id, msg = local_doc_qa.milvus_summary.add_file(user_id, kb_id, file_name, timestamp)
    debug_logger.info(f"{file_name}, {file_id}, {msg}")
    local_file = LocalFile(user_id, kb_id, file, file_id, file_name, local_doc_qa.embeddings)
    local_files.append(local_file)
    local_doc_qa.milvus_summary.update_file_size(file_id, len(local_file.file_content))
    data.append({"file_id": file_id, "file_name": file_name, "status": "gray", "bytes": len(local_file.file_content),
                 "timestamp": timestamp})
asyncio.create_task(local_doc_qa.insert_files_to_milvus(user_id, kb_id, local_files))

关键的函数:

  • local_doc_qa.milvus_summary.add_file:向指定知识库下面增加文件,这是一个mysql操作,要在mysql数据库内记录在案(下方给出一个简化的示意)。

  • local_doc_qa.insert_files_to_milvus:将文档加入到milvus中,当然这里也包含了文件切片、推理向量、存入数据库等一系列操作。
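为了直观理解add_file"记录在案"大概做了什么,下面给一个极简示意(非项目源码:表名File、字段名、用pymysql都只是假设,仅说明"生成file_id并往MySQL插入一条文件记录"这个思路):

import uuid
import pymysql

def add_file_demo(conn, user_id, kb_id, file_name, timestamp):
    # 假设存在一张File表,字段为file_id/user_id/kb_id/file_name/status/timestamp
    file_id = uuid.uuid4().hex
    sql = ("INSERT INTO File (file_id, user_id, kb_id, file_name, status, timestamp) "
           "VALUES (%s, %s, %s, %s, %s, %s)")
    with conn.cursor() as cursor:
        cursor.execute(sql, (file_id, user_id, kb_id, file_name, 'gray', timestamp))
    conn.commit()
    return file_id, 'success'

# conn = pymysql.connect(host='localhost', user='root', password='***', db='qanything')
# file_id, msg = add_file_demo(conn, 'zzp', 'KBxxx', 'test.md', '202401251056')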

回到服务,这里最终还是会收集各种处理的信息,最终以json形式返回,包括状态码、返回信息以及必要的数据信息(例如文件id、上传后的文件名、更新时间等)。

    if exist_file_names:
        msg = f'warning,当前的mode是soft,无法上传同名文件{exist_file_names},如果想强制上传同名文件,请设置mode:strong'
    else:
        msg = "success,后台正在飞速上传文件,请耐心等待"
    return sanic_json({"code": 200, "msg": msg, "data": data})

1.3,文件处理核心流程

继续往里面看 qanything_kernel\core\local_doc_qa.py

async def insert_files_to_milvus(self, user_id, kb_id, local_files: List[LocalFile]):
    debug_logger.info(f'insert_files_to_milvus: {kb_id}')
    milvus_kv = self.match_milvus_kb(user_id, [kb_id])
    assert milvus_kv is not None
    success_list = []
    failed_list = []
    for local_file in local_files:
        start = time.time()
        try:
            local_file.split_file_to_docs(self.get_ocr_result)
            content_length = sum([len(doc.page_content) for doc in local_file.docs])
        except Exception as e:
            error_info = f'split error: {traceback.format_exc()}'
            debug_logger.error(error_info)
            self.milvus_summary.update_file_status(local_file.file_id, status='red')
            failed_list.append(local_file)
            continue
        end = time.time()
        self.milvus_summary.update_content_length(local_file.file_id, content_length)
        debug_logger.info(f'split time: {end - start} {len(local_file.docs)}')
        start = time.time()
        try:
            local_file.create_embedding()
        except Exception as e:
            error_info = f'embedding error: {traceback.format_exc()}'
            debug_logger.error(error_info)
            self.milvus_summary.update_file_status(local_file.file_id, status='red')
            failed_list.append(local_file)
            continue
        end = time.time()
        debug_logger.info(f'embedding time: {end - start} {len(local_file.embs)}')
        self.milvus_summary.update_chunk_size(local_file.file_id, len(local_file.docs))
        ret = await milvus_kv.insert_files(local_file.file_id, local_file.file_name, local_file.file_path,
                                           local_file.docs, local_file.embs)
        insert_time = time.time()
        debug_logger.info(f'insert time: {insert_time - end}')
        if ret:
            self.milvus_summary.update_file_status(local_file.file_id, status='green')
            success_list.append(local_file)
        else:
            self.milvus_summary.update_file_status(local_file.file_id, status='yellow')
            failed_list.append(local_file)
    debug_logger.info(f"insert_to_milvus: success num: {len(success_list)}, failed num: {len(failed_list)}")

除开各种校验和数据的同步更新,主要经历的是这几个流程:

  • local_file.split_file_to_docs:文件的切片,这里还涉及不同类型的文件处理,例如md、图片等。

  • local_file.create_embedding:看名字就知道了,向量化。

  • milvus_kv.insert_files:存入milvus。

这就是文件上传后核心要经历的4个流程,即文件读取、文件切片、向量化和入库,接下来逐个展开讲。

2,文件读取和切片

文件读取和切片在代码里有不少是混在一起的。从代码可以看到,目前支持的是这几种格式:md、txt、pdf、jpg、png、jpeg、docx、xlsx、pptx、eml、csv,另外还有一个基于url的网页,代码里对这几种类型都提供了对应的处理。

2.1,load_and_split

在开始之前,必须了解一下文件读取的基类BaseLoader,这里对加载和切分都有详细的预定义。需要大家关注的点只有一个,就是load_and_split,我只把有关的部分放出来。这是一个在自定义好加载组件和切片组件后,可以一条龙使用的函数。注意这个BaseLoader是在langchain_core里的,不是在Qanything项目里的。

class BaseLoader(ABC):
    def load_and_split(
        self, text_splitter: Optional[TextSplitter] = None
    ) -> List[Document]:
        """Load Documents and split into chunks. Chunks are returned as Documents.

        Do not override this method. It should be considered to be deprecated!

        Args:
            text_splitter: TextSplitter instance to use for splitting documents.
              Defaults to RecursiveCharacterTextSplitter.

        Returns:
            List of Documents.
        """
        if text_splitter is None:
            try:
                from langchain_text_splitters import RecursiveCharacterTextSplitter
            except ImportError as e:
                raise ImportError(
                    "Unable to import from langchain_text_splitters. Please specify "
                    "text_splitter or install langchain_text_splitters with "
                    "`pip install -U langchain-text-splitters`."
                ) from e
            _text_splitter: TextSplitter = RecursiveCharacterTextSplitter()
        else:
            _text_splitter = text_splitter
        docs = self.load()
        return _text_splitter.split_documents(docs)

有这个基类后,只需要继承这个基类就能写自己的加载器了;至于文档切分器(TextSplitter),则可以在load_and_split使用的时候传进去,例如这样:

loader = MyRecursiveUrlLoader(url=self.url)
textsplitter = ChineseTextSplitter(pdf=False, sentence_size=sentence_size)
docs = loader.load_and_split(text_splitter=textsplitter)

MyRecursiveUrlLoader是URL加载器(具体后面会讲),初始化以后,再定义一个中文的切分器ChineseTextSplitter(具体后面也会讲),然后直接用loader.load_and_split(text_splitter=textsplitter)即可把加载、切片都给搞定了。

2.2,文件读取

在这个基类下,根据不同需要,会有各种不一样的加载器,用于应对多种不同的格式,自定义的加载器直接从BaseLoader继承即可。

  • MyRecursiveUrlLoader:URL加载器,即网络链接下的内容加载,内部直接用了langchain的WebBaseLoader,网页解析则使用的是BeautifulSoup,算是爬虫技术里的老朋友了,BeautifulSoup主要用于解析网页代码里暗藏的url,方便进一步查询。

  • UnstructuredFileLoader,直接从langchain里面加载的,from langchain.document_loaders import UnstructuredFileLoader。这个也就只用在了markdown里面(.md)。

  • TextLoader,也是直接从langchain里面加载的:from langchain.document_loaders import UnstructuredFileLoader, TextLoader。这个也就只用在了txt里面(.txt)。

  • UnstructuredPaddlePDFLoader,这个是专门用在pdf文件里的,作者自己写的类,继承自前面提到的UnstructuredFileLoader,但不局限在此,主要重写的是_get_elements函数,内部写了一个函数pdf_ocr_txt,首先用fitz读取pdf每页的图片,然后用ocr_engine来解析(请求ocr接口,本项目里用的是一个triton部署的paddleocr服务),最后用unstructured下的一个函数partition_text来完成切片(pip install unstructured),当然后续还会有针对中文的综合切片,后面会说。

  • UnstructuredPaddleImageLoader,用来解析图片的工具,对应jpg、png、jpeg后缀文件。同样继承自UnstructuredFileLoader,和PDF不同的是加载部分,图片加载使用的是cv2,加载后和PDF的处理一样,都是走一遍ocr_engine和partition_text。

  • UnstructuredWordDocumentLoader用于处理docx文件,来自langchain。

  • xlsx使用的是pandas,值得注意的是engine使用的是openpyxl,另外文件读取后,作者会把内容转为csv,然后用CSVLoader来处理。

  • CSVLoader顾名思义处理的是csv文件,这里用的是csv.DictReader来读取的。

  • UnstructuredPowerPointLoader用于读取PPT,从langchain里面加载的,from langchain.document_loaders import UnstructuredPowerPointLoader

  • UnstructuredEmailLoader用于读取邮件格式的文件.eml,也是从langchain中加载的,from langchain.document_loaders import UnstructuredEmailLoader

至此,所有支持的文件加载都在这里了,这些文件加载都挺有借鉴意义的,后续在做自己的RAG系统的过程中,也可以考虑直接使用。

2.2.1 qanything_kernel\utils\loader\csv_loader.py
import csv
from io import TextIOWrapper
from typing import Any, Dict, List, Optional, Sequence

from langchain_core.documents import Document

from langchain_community.document_loaders.base import BaseLoader
from langchain_community.document_loaders.helpers import detect_file_encodings


class CSVLoader(BaseLoader):
    """Load a `CSV` file into a list of Documents.

    Each document represents one row of the CSV file. Every row is converted into a
    key/value pair and outputted to a new line in the document's page_content.

    The source for each document loaded from csv is set to the value of the
    `file_path` argument for all documents by default.
    You can override this by setting the `source_column` argument to the
    name of a column in the CSV file.
    The source of each document will then be set to the value of the column
    with the name specified in `source_column`.

    Output Example:
        .. code-block:: txt

            column1: value1
            column2: value2
            column3: value3
    """

    def __init__(
        self,
        file_path: str,
        source_column: Optional[str] = None,
        metadata_columns: Sequence[str] = (),
        csv_args: Optional[Dict] = None,
        encoding: Optional[str] = None,
        autodetect_encoding: bool = False,
    ):
        """
        Args:
            file_path: The path to the CSV file.
            source_column: The name of the column in the CSV file to use as the source.
              Optional. Defaults to None.
            metadata_columns: A sequence of column names to use as metadata. Optional.
            csv_args: A dictionary of arguments to pass to the csv.DictReader.
              Optional. Defaults to None.
            encoding: The encoding of the CSV file. Optional. Defaults to None.
            autodetect_encoding: Whether to try to autodetect the file encoding.
        """
        self.file_path = file_path
        self.source_column = source_column
        self.metadata_columns = metadata_columns
        self.encoding = encoding
        self.csv_args = csv_args or {}
        self.autodetect_encoding = autodetect_encoding

    def load(self) -> List[Document]:
        """Load data into document objects."""
        docs = []
        try:
            with open(self.file_path, newline="", encoding=self.encoding) as csvfile:
                docs = self.__read_file(csvfile)
        except UnicodeDecodeError as e:
            if self.autodetect_encoding:
                detected_encodings = detect_file_encodings(self.file_path)
                for encoding in detected_encodings:
                    try:
                        with open(self.file_path, newline="", encoding=encoding.encoding) as csvfile:
                            docs = self.__read_file(csvfile)
                        break
                    except UnicodeDecodeError:
                        continue
            else:
                raise RuntimeError(f"Error loading {self.file_path}") from e
        except Exception as e:
            raise RuntimeError(f"Error loading {self.file_path}") from e

        return docs

    def __read_file(self, csvfile: TextIOWrapper) -> List[Document]:
        docs = []
        csv_reader = csv.DictReader(csvfile, **self.csv_args)  # type: ignore
        # 初始化一个字典,用于存储每一列最后一次的非空值
        last_non_empty_values = {}
        for i, row in enumerate(csv_reader):
            try:
                source = (
                    row[self.source_column]
                    if self.source_column is not None
                    else self.file_path
                )
            except KeyError:
                raise ValueError(
                    f"Source column '{self.source_column}' not found in CSV file."
                )
            line_contents = []
            for k, v in row.items():
                if k in self.metadata_columns:
                    continue
                line_contents.append(f"{k.strip()}: {v.strip() if v else last_non_empty_values.get(k, v)}")
                if v:
                    last_non_empty_values[k] = v
            content = '------------------------\n'
            # content += " & ".join(
            #     f"{k.strip()}: {v.strip() if v is not None else v}"
            #     for k, v in row.items()
            #     if k not in self.metadata_columns
            # )
            content += ' & '.join(line_contents)
            content += '\n------------------------'
            metadata = {"source": source, "row": i}
            for col in self.metadata_columns:
                try:
                    metadata[col] = row[col]
                except KeyError:
                    raise ValueError(f"Metadata column '{col}' not found in CSV file.")
            doc = Document(page_content=content, metadata=metadata)
            docs.append(doc)

        return docs
2.2.2 qanything_kernel\utils\loader\image_loader.py
"""Loader that loads image files."""
from typing import List, Callablefrom langchain.document_loaders.unstructured import UnstructuredFileLoader
import os
from typing import Union, Any
import cv2
import base64class UnstructuredPaddleImageLoader(UnstructuredFileLoader):"""Loader that uses unstructured to load image files, such as PNGs and JPGs."""def __init__(self,file_path: Union[str, List[str]],ocr_engine: Callable,mode: str = "single",**unstructured_kwargs: Any,):"""Initialize with file path."""self.ocr_engine = ocr_enginesuper().__init__(file_path=file_path, mode=mode, **unstructured_kwargs)def _get_elements(self) -> List:def image_ocr_txt(filepath, dir_path="tmp_files"):full_dir_path = os.path.join(os.path.dirname(filepath), dir_path)if not os.path.exists(full_dir_path):os.makedirs(full_dir_path)filename = os.path.split(filepath)[-1]img_np = cv2.imread(filepath)h, w, c = img_np.shapeimg_data = {"img64": base64.b64encode(img_np).decode("utf-8"), "height": h, "width": w, "channels": c}result = self.ocr_engine(img_data)result = [line for line in result if line]ocr_result = [i[1][0] for line in result for i in line]txt_file_path = os.path.join(full_dir_path, "%s.txt" % (filename))with open(txt_file_path, 'w', encoding='utf-8') as fout:fout.write("\n".join(ocr_result))return txt_file_pathtxt_file_path = image_ocr_txt(self.file_path)from unstructured.partition.text import partition_textreturn partition_text(filename=txt_file_path, **self.unstructured_kwargs)
2.2.3 qanything_kernel\utils\loader\my_recursive_url_loader.py
from typing import Iterator, List, Optional, Set
from urllib.parse import urljoin, urldefrag

import requests

from langchain.docstore.document import Document
from langchain.document_loaders.base import BaseLoader


class MyRecursiveUrlLoader(BaseLoader):
    """Loads all child links from a given url."""

    def __init__(
        self,
        url: str,
        exclude_dirs: Optional[str] = None,
        max_depth: int = -1
    ) -> None:
        """Initialize with URL to crawl and any subdirectories to exclude.

        Args:
            url: The URL to crawl.
            exclude_dirs: A list of subdirectories to exclude.
        """
        self.url = url
        self.exclude_dirs = exclude_dirs
        self.max_depth = max_depth

    def get_child_links_recursive(
        self, url: str, depth: int, visited: Optional[Set[str]] = None
    ) -> Iterator[Document]:
        """Recursively get all child links starting with the path of the input URL.

        Args:
            url: The URL to crawl.
            visited: A set of visited URLs.
        """
        from langchain.document_loaders import WebBaseLoader

        try:
            from bs4 import BeautifulSoup
        except ImportError:
            raise ImportError(
                "The BeautifulSoup package is required for the RecursiveUrlLoader."
            )

        # Exclude the root and parent from a list
        visited = set() if visited is None else visited

        if self.max_depth > 0 and depth <= self.max_depth:
            return None

        # Exclude the links that start with any of the excluded directories
        if self.exclude_dirs and any(url.startswith(exclude_dir) for exclude_dir in self.exclude_dirs):
            return visited

        yield from WebBaseLoader(web_path=url).load()

        visited.add(url)

        # Get all links that are relative to the root of the website
        response = requests.get(url, timeout=60)
        soup = BeautifulSoup(response.text, "html.parser")
        all_links = [urljoin(url, link.get("href")) for link in soup.find_all("a")]

        # Filter children url of current url
        child_links = [link for link in set(all_links) if link.startswith(url)]

        # Remove framents to avoid repititions
        defraged_child_links = [urldefrag(link).url for link in child_links]

        # Store the visited links and recursively visit the children
        for link in set(defraged_child_links):
            # Check all unvisited links
            if link not in visited:
                visited.add(link)
                yield from WebBaseLoader(link).load()
                # If the link is a directory (w/ children) then visit it
                if link.endswith("/"):
                    yield from self.get_child_links_recursive(link, depth + 1, visited)

        return visited

    def lazy_load(self) -> Iterator[Document]:
        """Lazy load web pages."""
        return self.get_child_links_recursive(self.url, depth=0)

    def load(self) -> List[Document]:
        """Load web pages."""
        return list(self.lazy_load())
2.2.4 qanything_kernel\utils\loader\pdf_loader.py
"""Loader that loads image files."""
from typing import List, Callablefrom langchain.document_loaders.unstructured import UnstructuredFileLoader
from unstructured.partition.text import partition_text
import os
import fitz
from tqdm import tqdm
from typing import Union, Any
import numpy as np
import base64class UnstructuredPaddlePDFLoader(UnstructuredFileLoader):"""Loader that uses unstructured to load image files, such as PNGs and JPGs."""def __init__(self,file_path: Union[str, List[str]],ocr_engine: Callable,mode: str = "single",**unstructured_kwargs: Any,):"""Initialize with file path."""self.ocr_engine = ocr_enginesuper().__init__(file_path=file_path, mode=mode, **unstructured_kwargs)def _get_elements(self) -> List:def pdf_ocr_txt(filepath, dir_path="tmp_files"):full_dir_path = os.path.join(os.path.dirname(filepath), dir_path)if not os.path.exists(full_dir_path):os.makedirs(full_dir_path)doc = fitz.open(filepath)txt_file_path = os.path.join(full_dir_path, "{}.txt".format(os.path.split(filepath)[-1]))img_name = os.path.join(full_dir_path, 'tmp.png')with open(txt_file_path, 'w', encoding='utf-8') as fout:for i in tqdm(range(doc.page_count)):page = doc.load_page(i)pix = page.get_pixmap()img = np.frombuffer(pix.samples, dtype=np.uint8).reshape((pix.h, pix.w, pix.n))img_data = {"img64": base64.b64encode(img).decode("utf-8"), "height": pix.h, "width": pix.w,"channels": pix.n}result = self.ocr_engine(img_data)result = [line for line in result if line]ocr_result = [i[1][0] for line in result for i in line]fout.write("\n".join(ocr_result))if os.path.exists(img_name):os.remove(img_name)return txt_file_pathtxt_file_path = pdf_ocr_txt(self.file_path)return partition_text(filename=txt_file_path, **self.unstructured_kwargs)
2.2.5 qanything_kernel\core\local_file.py
from qanything_kernel.utils.general_utils import *
from typing import List, Union, Callable
from qanything_kernel.configs.model_config import UPLOAD_ROOT_PATH, SENTENCE_SIZE, ZH_TITLE_ENHANCE
from langchain.docstore.document import Document
from qanything_kernel.utils.loader.my_recursive_url_loader import MyRecursiveUrlLoader
from langchain.document_loaders import UnstructuredFileLoader, TextLoader
from langchain.document_loaders import UnstructuredWordDocumentLoader
from langchain.document_loaders import UnstructuredExcelLoader
from langchain.document_loaders import UnstructuredEmailLoader
from langchain.document_loaders import UnstructuredPowerPointLoader
from qanything_kernel.utils.loader.csv_loader import CSVLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from qanything_kernel.utils.custom_log import debug_logger, qa_logger
from qanything_kernel.utils.splitter import ChineseTextSplitter
from qanything_kernel.utils.loader import UnstructuredPaddleImageLoader, UnstructuredPaddlePDFLoader
from qanything_kernel.utils.splitter import zh_title_enhance
from sanic.request import File
import pandas as pd
import os

text_splitter = RecursiveCharacterTextSplitter(
    separators=["\n", ".", "。", "!", "!", "?", "?", ";", ";", "……", "…", "、", ",", ",", " "],
    chunk_size=400,
    length_function=num_tokens,
)


class LocalFile:
    def __init__(self, user_id, kb_id, file: Union[File, str], file_id, file_name, embedding, is_url=False, in_milvus=False):
        self.user_id = user_id
        self.kb_id = kb_id
        self.file_id = file_id
        self.docs: List[Document] = []
        self.embs = []
        self.emb_infer = embedding
        self.url = None
        self.in_milvus = in_milvus
        self.file_name = file_name
        if is_url:
            self.url = file
            self.file_path = "URL"
            self.file_content = b''
        else:
            if isinstance(file, str):
                self.file_path = file
                with open(file, 'rb') as f:
                    self.file_content = f.read()
            else:
                upload_path = os.path.join(UPLOAD_ROOT_PATH, user_id)
                file_dir = os.path.join(upload_path, self.file_id)
                os.makedirs(file_dir, exist_ok=True)
                self.file_path = os.path.join(file_dir, self.file_name)
                self.file_content = file.body
                with open(self.file_path, "wb+") as f:
                    f.write(self.file_content)
        debug_logger.info(f'success init localfile {self.file_name}')

    def split_file_to_docs(self, ocr_engine: Callable, sentence_size=SENTENCE_SIZE,
                           using_zh_title_enhance=ZH_TITLE_ENHANCE):
        if self.url:
            debug_logger.info("load url: {}".format(self.url))
            loader = MyRecursiveUrlLoader(url=self.url)
            textsplitter = ChineseTextSplitter(pdf=False, sentence_size=sentence_size)
            docs = loader.load_and_split(text_splitter=textsplitter)
        elif self.file_path.lower().endswith(".md"):
            loader = UnstructuredFileLoader(self.file_path, mode="elements")
            docs = loader.load()
        elif self.file_path.lower().endswith(".txt"):
            loader = TextLoader(self.file_path, autodetect_encoding=True)
            texts_splitter = ChineseTextSplitter(pdf=False, sentence_size=sentence_size)
            docs = loader.load_and_split(texts_splitter)
        elif self.file_path.lower().endswith(".pdf"):
            loader = UnstructuredPaddlePDFLoader(self.file_path, ocr_engine)
            texts_splitter = ChineseTextSplitter(pdf=True, sentence_size=sentence_size)
            docs = loader.load_and_split(texts_splitter)
        elif self.file_path.lower().endswith(".jpg") or self.file_path.lower().endswith(".png") or self.file_path.lower().endswith(".jpeg"):
            loader = UnstructuredPaddleImageLoader(self.file_path, ocr_engine, mode="elements")
            texts_splitter = ChineseTextSplitter(pdf=False, sentence_size=sentence_size)
            docs = loader.load_and_split(text_splitter=texts_splitter)
        elif self.file_path.lower().endswith(".docx"):
            loader = UnstructuredWordDocumentLoader(self.file_path, mode="elements")
            texts_splitter = ChineseTextSplitter(pdf=False, sentence_size=sentence_size)
            docs = loader.load_and_split(texts_splitter)
        elif self.file_path.lower().endswith(".xlsx"):
            # loader = UnstructuredExcelLoader(self.file_path, mode="elements")
            docs = []
            xlsx = pd.read_excel(self.file_path, engine='openpyxl', sheet_name=None)
            for sheet in xlsx.keys():
                df = xlsx[sheet]
                df.dropna(how='all', inplace=True)
                csv_file_path = self.file_path[:-5] + '_' + sheet + '.csv'
                df.to_csv(csv_file_path, index=False)
                loader = CSVLoader(csv_file_path, csv_args={"delimiter": ",", "quotechar": '"'})
                docs += loader.load()
        elif self.file_path.lower().endswith(".pptx"):
            loader = UnstructuredPowerPointLoader(self.file_path, mode="elements")
            docs = loader.load()
        elif self.file_path.lower().endswith(".eml"):
            loader = UnstructuredEmailLoader(self.file_path, mode="elements")
            docs = loader.load()
        elif self.file_path.lower().endswith(".csv"):
            loader = CSVLoader(self.file_path, csv_args={"delimiter": ",", "quotechar": '"'})
            docs = loader.load()
        else:
            raise TypeError("文件类型不支持,目前仅支持:[md,txt,pdf,jpg,png,jpeg,docx,xlsx,pptx,eml,csv]")
        if using_zh_title_enhance:
            debug_logger.info("using_zh_title_enhance %s", using_zh_title_enhance)
            docs = zh_title_enhance(docs)
        # 重构docs,如果doc的文本长度大于800tokens,则利用text_splitter将其拆分成多个doc
        # text_splitter: RecursiveCharacterTextSplitter
        debug_logger.info(f"before 2nd split doc lens: {len(docs)}")
        docs = text_splitter.split_documents(docs)
        debug_logger.info(f"after 2nd split doc lens: {len(docs)}")
        # 这里给每个docs片段的metadata里注入file_id
        for doc in docs:
            doc.metadata["file_id"] = self.file_id
            doc.metadata["file_name"] = self.url if self.url else os.path.split(self.file_path)[-1]
        write_check_file(self.file_path, docs)
        if docs:
            debug_logger.info('langchain analysis content head: %s', docs[0].page_content[:100])
        else:
            debug_logger.info('langchain analysis docs is empty!')
        self.docs = docs

    def create_embedding(self):
        self.embs = self.emb_infer._get_len_safe_embeddings([doc.page_content for doc in self.docs])

2.3,文件切片

文件切片作者也写成了通用的工具,方便调用。相比各种文件格式的处理,这部分的泛用性更高,毕竟内容都已经解析成文本了。比较通用的是ChineseTextSplitter,继承自langchain的CharacterTextSplitter(from langchain.text_splitter import CharacterTextSplitter),重写后更符合中文的使用习惯。直接来看源码吧:qanything_kernel\utils\splitter\chinese_text_splitter.py

from langchain.text_splitter import CharacterTextSplitter
import re
from typing import List
from qanything_kernel.configs.model_config import SENTENCE_SIZE


class ChineseTextSplitter(CharacterTextSplitter):
    def __init__(self, pdf: bool = False, sentence_size: int = SENTENCE_SIZE, **kwargs):
        super().__init__(**kwargs)
        self.pdf = pdf
        self.sentence_size = sentence_size

    def split_text1(self, text: str) -> List[str]:
        if self.pdf:
            text = re.sub(r"\n{3,}", "\n", text)
            text = re.sub('\s', ' ', text)
            text = text.replace("\n\n", "")
        sent_sep_pattern = re.compile('([﹒﹔﹖﹗.。!?]["’”」』]{0,2}|(?=["‘“「『]{1,2}|$))')  # del :;
        sent_list = []
        for ele in sent_sep_pattern.split(text):
            if sent_sep_pattern.match(ele) and sent_list:
                sent_list[-1] += ele
            elif ele:
                sent_list.append(ele)
        return sent_list

    def split_text(self, text: str) -> List[str]:   ##此处需要进一步优化逻辑
        if self.pdf:
            text = re.sub(r"\n{3,}", r"\n", text)
            text = re.sub('\s', " ", text)
            text = re.sub("\n\n", "", text)

        text = re.sub(r'([;;.!?。!?\?])([^”’])', r"\1\n\2", text)  # 单字符断句符
        text = re.sub(r'(\.{6})([^"’”」』])', r"\1\n\2", text)  # 英文省略号
        text = re.sub(r'(\…{2})([^"’”」』])', r"\1\n\2", text)  # 中文省略号
        text = re.sub(r'([;;!?。!?\?]["’”」』]{0,2})([^;;!?,。!?\?])', r'\1\n\2', text)
        # 如果双引号前有终止符,那么双引号才是句子的终点,把分句符\n放到双引号后,注意前面的几句都小心保留了双引号
        text = text.rstrip()  # 段尾如果有多余的\n就去掉它
        # 很多规则中会考虑分号;,但是这里我把它忽略不计,破折号、英文双引号等同样忽略,需要的再做些简单调整即可。
        ls = [i for i in text.split("\n") if i]
        for ele in ls:
            if len(ele) > self.sentence_size:
                ele1 = re.sub(r'([,,.]["’”」』]{0,2})([^,,.])', r'\1\n\2', ele)
                ele1_ls = ele1.split("\n")
                for ele_ele1 in ele1_ls:
                    if len(ele_ele1) > self.sentence_size:
                        ele_ele2 = re.sub(r'([\n]{1,}| {2,}["’”」』]{0,2})([^\s])', r'\1\n\2', ele_ele1)
                        ele2_ls = ele_ele2.split("\n")
                        for ele_ele2 in ele2_ls:
                            if len(ele_ele2) > self.sentence_size:
                                ele_ele3 = re.sub('( ["’”」』]{0,2})([^ ])', r'\1\n\2', ele_ele2)
                                ele2_id = ele2_ls.index(ele_ele2)
                                ele2_ls = ele2_ls[:ele2_id] + [i for i in ele_ele3.split("\n") if i] + ele2_ls[ele2_id + 1:]
                        ele_id = ele1_ls.index(ele_ele1)
                        ele1_ls = ele1_ls[:ele_id] + [i for i in ele2_ls if i] + ele1_ls[ele_id + 1:]
                id = ls.index(ele)
                ls = ls[:id] + [i for i in ele1_ls if i] + ls[id + 1:]
        return ls

实际使用的应该是split_text,不带1那个,这里涉及了很多逻辑和替换,主要都是为了做句子片段的划分,这里的正则大家也可以多多了解和尝试。
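可以用下面这个小例子直观感受一下split_text的效果(示例文本和sentence_size都是随手设的,仅作演示,具体输出以实际运行为准):

from qanything_kernel.utils.splitter import ChineseTextSplitter

splitter = ChineseTextSplitter(pdf=False, sentence_size=30)
sample = "RAG系统需要先对文档切片。切片太长会稀释语义,太短又会丢失上下文!因此需要按句子边界切分,再控制每段的长度。"
for piece in splitter.split_text(sample):
    print(piece)
# 预期会按句号、感叹号等标点切成三个短句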

在此基础上,都会再过第二次切分,这次切分旨在对长度太长(800tokens+)的进行进一步切分,此处使用的是langchain的RecursiveCharacterTextSplitter

from langchain.text_splitter import RecursiveCharacterTextSplitter

text_splitter = RecursiveCharacterTextSplitter(
    separators=["\n", ".", "。", "!", "!", "?", "?", ";", ";", "……", "…", "、", ",", ",", " "],
    chunk_size=400,
    length_function=num_tokens,
)
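这里的length_function用的num_tokens来自项目的general_utils,文中没有展开。它的作用是按token数而不是字符数来度量长度;下面是一个基于tiktoken的假设性写法(编码名只是示例,不代表项目的真实实现):

import tiktoken

def num_tokens(text: str) -> int:
    # 把文本编码成token再数个数,仅为示意
    encoding = tiktoken.get_encoding("cl100k_base")
    return len(encoding.encode(text))

print(num_tokens("文件切片是RAG离线处理的第一步"))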

后面,为了确保信息存储的可查性(检索到这段话后,能找到对应的文章),还把文件id和文件名都记录到了doc的metadata内(说白了就是正排)。

# 这里给每个docs片段的metadata里注入file_id
for doc in docs:
    doc.metadata["file_id"] = self.file_id
    doc.metadata["file_name"] = self.url if self.url else os.path.split(self.file_path)[-1]

3,索引构造

在对文本切好片后,就可以开始跑模型、准备向数据库灌数据了。此处把它叫做索引构造,主要包括数据转化和灌库两个操作。

核心的代码同样是在qanything_kernel\core\local_doc_qa.py的insert_files_to_milvus这个函数下(完整代码见1.3节),其中的local_file.create_embedding()就是构造向量的过程,在前面的章节(RAG开源项目Qanything源码阅读1-概述+服务)有提及。


向量化的模型是单独用triton部署的,所以此处是直接请求模型服务获取的(RAG开源项目Qanything源码阅读1-概述+服务4.2 的12行)。

CUDA_VISIBLE_DEVICES=$gpu_id1 nohup /opt/tritonserver/bin/tritonserver --model-store=/model_repos/QAEnsemble_embed_rerank --http-port=9000 --grpc-port=9001 --metrics-port=9002 --log-verbose=1 > /workspace/qanything_local/logs/debug_logs/embed_rerank_tritonserver.log 2>&1 &

而请求方面,先放一个调用的关键入口:qanything_kernel\core\local_file.py

def create_embedding(self):
    self.embs = self.emb_infer._get_len_safe_embeddings([doc.page_content for doc in self.docs])

这里实际的调用链挺深的:首先对于local模式,有YouDaoLocalEmbeddings,这一层是对向量模型的包装,里面更多是考虑并发的concurrent代码;真正的向量推理由内部的embedding_client(一个EmbeddingClient实例)负责(当然EmbeddingClient下也还有concurrent的代码),这个应该才是算法比较关心的部分,直接把EmbeddingClient的核心代码放出来:

import os
import math
import numpy as np
import time
from typing import Optional

import onnxruntime as ort
from tritonclient import utils as client_utils
from tritonclient.grpc import InferenceServerClient, InferInput, InferRequestedOutput
from transformers import AutoTokenizer

WEIGHT2NPDTYPE = {
    "fp32": np.float32,
    "fp16": np.float16,
}


class EmbeddingClient:
    DEFAULT_MAX_RESP_WAIT_S = 120
    embed_version = "local_v0.0.1_20230525_6d4019f1559aef84abc2ab8257e1ad4c"

    def __init__(
        self,
        server_url: str,
        model_name: str,
        model_version: str,
        tokenizer_path: str,
        resp_wait_s: Optional[float] = None,
    ):
        self._server_url = server_url
        self._model_name = model_name
        self._model_version = model_version
        self._response_wait_t = self.DEFAULT_MAX_RESP_WAIT_S if resp_wait_s is None else resp_wait_s
        self._tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)

    def get_embedding(self, sentences, max_length=512):
        # Setting up client
        inputs_data = self._tokenizer(sentences, padding=True, truncation=True, max_length=max_length, return_tensors='np')
        inputs_data = {k: v for k, v in inputs_data.items()}

        client = InferenceServerClient(url=self._server_url)
        model_config = client.get_model_config(self._model_name, self._model_version)
        model_metadata = client.get_model_metadata(self._model_name, self._model_version)

        inputs_info = {tm.name: tm for tm in model_metadata.inputs}
        outputs_info = {tm.name: tm for tm in model_metadata.outputs}
        output_names = list(outputs_info)
        outputs_req = [InferRequestedOutput(name_) for name_ in outputs_info]
        infer_inputs = []
        for name_ in inputs_info:
            data = inputs_data[name_]
            infer_input = InferInput(name_, data.shape, inputs_info[name_].datatype)

            target_np_dtype = client_utils.triton_to_np_dtype(inputs_info[name_].datatype)
            data = data.astype(target_np_dtype)

            infer_input.set_data_from_numpy(data)
            infer_inputs.append(infer_input)

        results = client.infer(
            model_name=self._model_name,
            model_version=self._model_version,
            inputs=infer_inputs,
            outputs=outputs_req,
            client_timeout=120,
        )
        y_pred = {name_: results.as_numpy(name_) for name_ in output_names}
        embeddings = y_pred["output"][:, 0]
        norm_arr = np.linalg.norm(embeddings, axis=1, keepdims=True)
        embeddings_normalized = embeddings / norm_arr
        return embeddings_normalized.tolist()

    def getModelVersion(self):
        return self.embed_version
  • 首先可以看到,tokenizer依旧是本服务做的。

  • 服务的请求主要由client负责,走的是triton的grpc接口,输入和输出的数据结构参考InferInput和InferRequestedOutput。

  • 一个细节:对模型的输出结果,作者还做了额外的处理,主要是一次归一化,用np.linalg.norm求二范数(默认L2),然后把每个向量都除以这个二范数(见下方小例子)。

  • 还可以留意到,作者保留了模型版本号(embed_version),方便做模型迭代时的版本管理。
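上面提到的归一化可以用下面几行独立验证一下(矩阵是随手造的,仅演示np.linalg.norm的用法):

import numpy as np

embeddings = np.array([[3.0, 4.0], [1.0, 0.0]])               # 假设两个2维向量
norm_arr = np.linalg.norm(embeddings, axis=1, keepdims=True)  # 每行的L2范数,形状(2, 1)
normalized = embeddings / norm_arr
print(normalized)                           # [[0.6 0.8] [1.  0. ]]
print(np.linalg.norm(normalized, axis=1))   # 归一化后每行范数都是1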

完成后,就可以开始灌库了,也就是上面insert_files_to_milvus里调用的milvus_kv.insert_files。milvus自己是有开源库的,即pymilvus,作者在此之上写了一个完整的类MilvusClient,pymilvus的具体教程大家可以看:https://zhuanlan.zhihu.com/p/676124465

from pymilvus import connections, FieldSchema, CollectionSchema, DataType, Collection, utility, \
    Partition
from concurrent.futures import ThreadPoolExecutor, as_completed
import asyncio
from functools import partial
import time
import copy
from datetime import datetime
from qanything_kernel.configs.model_config import MILVUS_HOST_LOCAL, MILVUS_HOST_ONLINE, MILVUS_PORT, MILVUS_USER, MILVUS_PASSWORD, MILVUS_DB_NAME, CHUNK_SIZE, VECTOR_SEARCH_TOP_K
from qanything_kernel.utils.custom_log import debug_logger
from langchain.docstore.document import Document
import math
from itertools import groupby
from typing import List

# 混合检索
from .es_client import ElasticsearchClient
from qanything_kernel.configs.model_config import HYBRID_SEARCH

# ret = await milvus_kv.insert_files(local_file.file_id, local_file.file_name, local_file.file_path,
#                                    local_file.docs, local_file.embs)

async def insert_files(self, file_id, file_name, file_path, docs, embs, batch_size=1000):
    debug_logger.info(f'now inser_file {file_name}')
    now = datetime.now()
    timestamp = now.strftime("%Y%m%d%H%M")
    loop = asyncio.get_running_loop()
    contents = [doc.page_content for doc in docs]
    num_docs = len(docs)
    for batch_start in range(0, num_docs, batch_size):
        batch_end = min(batch_start + batch_size, num_docs)
        data = [[] for _ in range(len(self.sess.schema))]
        for idx in range(batch_start, batch_end):
            cont = contents[idx]
            emb = embs[idx]
            chunk_id = f'{file_id}_{idx}'
            data[0].append(chunk_id)
            data[1].append(file_id)
            data[2].append(file_name)
            data[3].append(file_path)
            data[4].append(timestamp)
            data[5].append(cont)
            data[6].append(emb)

        # 执行插入操作
        try:
            debug_logger.info('Inserting into Milvus...')
            mr = await loop.run_in_executor(self.executor, partial(self.partitions[0].insert, data=data))
            debug_logger.info(f'{file_name} {mr}')
        except Exception as e:
            debug_logger.error(f'Milvus insert file_id:{file_id}, file_name:{file_name} failed: {e}')
            return False

    # 混合检索
    if self.hybrid_search:
        debug_logger.info(f'now inser_file for es: {file_name}')
        for batch_start in range(0, num_docs, batch_size):
            batch_end = min(batch_start + batch_size, num_docs)
            data_es = []
            for idx in range(batch_start, batch_end):
                data_es_item = {
                    'file_id': file_id,
                    'content': contents[idx],
                    'metadata': {
                        'file_name': file_name,
                        'file_path': file_path,
                        'chunk_id': f'{file_id}_{idx}',
                        'timestamp': timestamp,
                    }
                }
                data_es.append(data_es_item)

            try:
                debug_logger.info('Inserting into es ...')
                mr = await self.client.insert(data=data_es, refresh=batch_end == num_docs)
                debug_logger.info(f'{file_name} {mr}')
            except Exception as e:
                debug_logger.error(f'ES insert file_id: {file_id}\nfile_name: {file_name}\nfailed: {e}')
                return False

    return True
  • milvus使用的是pymilvus工具来读写,其中self.partitions[0].insert就是用来存储数据的,此处可以注意到data内有很多不同的字段。

  • 执行代码使用的是loop.run_in_executor。可以留意到,MilvusClient内有一个self.executor,它的定义在该类的__init__内:self.executor = ThreadPoolExecutor(max_workers=10),这里新建了一个线程池,新技能get(用法见下方小例子)。

  • 下方是ES的数据灌入。个人感觉,这个ES数据处理写在这个位置并不是很合适,应该单独拆出来处理,毕竟和Milvus的逻辑混在一起不太好读。
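loop.run_in_executor的作用是把同步阻塞的调用(比如pymilvus的insert)丢进线程池执行,避免阻塞事件循环。下面是一个与本项目无关的最小示意(用sleep模拟阻塞的插入操作):

import asyncio
import time
from concurrent.futures import ThreadPoolExecutor
from functools import partial

executor = ThreadPoolExecutor(max_workers=10)

def blocking_insert(data):
    time.sleep(1)  # 模拟Milvus插入这类同步阻塞调用
    return f"inserted {len(data)} rows"

async def main():
    loop = asyncio.get_running_loop()
    # 两个阻塞调用被丢进线程池并发执行,总耗时约1秒而不是2秒
    results = await asyncio.gather(
        loop.run_in_executor(executor, partial(blocking_insert, data=[1, 2, 3])),
        loop.run_in_executor(executor, partial(blocking_insert, data=[4, 5])),
    )
    print(results)

asyncio.run(main())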

MilvusClient 类
from pymilvus import connections, FieldSchema, CollectionSchema, DataType, Collection, utility, \Partition
from concurrent.futures import ThreadPoolExecutor, as_completed
import asyncio
from functools import partial
import time
import copy
from datetime import datetime
from qanything_kernel.configs.model_config import MILVUS_HOST_LOCAL, MILVUS_HOST_ONLINE, MILVUS_PORT, MILVUS_USER, MILVUS_PASSWORD, MILVUS_DB_NAME, CHUNK_SIZE, VECTOR_SEARCH_TOP_K
from qanything_kernel.utils.custom_log import debug_logger
from langchain.docstore.document import Document
import math
from itertools import groupby
from typing import List# 混合检索
from .es_client import ElasticsearchClient
from qanything_kernel.configs.model_config import HYBRID_SEARCHclass MilvusFailed(Exception):"""异常基类"""passclass MilvusClient:def __init__(self, mode, user_id, kb_ids, *, threshold=1.1, client_timeout=3):self.user_id = user_idself.kb_ids = kb_idsif mode == 'local':self.host = MILVUS_HOST_LOCALelse:self.host = MILVUS_HOST_ONLINEself.port = MILVUS_PORTself.user = MILVUS_USERself.password = MILVUS_PASSWORDself.db_name = MILVUS_DB_NAMEself.client_timeout = client_timeoutself.threshold = thresholdself.sess: Collection = Noneself.partitions: List[Partition] = []self.executor = ThreadPoolExecutor(max_workers=10)self.top_k = VECTOR_SEARCH_TOP_Kself.search_params = {"metric_type": "L2", "params": {"nprobe": 256}}if mode == 'local':self.create_params = {"metric_type": "L2", "index_type": "IVF_FLAT", "params": {"nlist": 2048}}else:self.create_params = {"metric_type": "L2", "index_type": "GPU_IVF_FLAT", "params": {"nlist": 2048}}self.last_init_ts = time.time() - 100  # 减去100保证最初的init不会被拒绝self.init()# 混合检索self.hybrid_search = HYBRID_SEARCHif self.hybrid_search:self.index_name = [f"{user_id}++{kb_id}" for kb_id in kb_ids]self.client = ElasticsearchClient(index_name=self.index_name)@propertydef fields(self):fields = [FieldSchema(name='chunk_id', dtype=DataType.VARCHAR, max_length=64, is_primary=True),FieldSchema(name='file_id', dtype=DataType.VARCHAR, max_length=64),FieldSchema(name='file_name', dtype=DataType.VARCHAR, max_length=640),FieldSchema(name='file_path', dtype=DataType.VARCHAR, max_length=640),FieldSchema(name='timestamp', dtype=DataType.VARCHAR, max_length=64),FieldSchema(name='content', dtype=DataType.VARCHAR, max_length=4000),FieldSchema(name='embedding', dtype=DataType.FLOAT_VECTOR, dim=768)]return fieldsdef parse_batch_result(self, batch_result):new_result = []for batch_idx, result in enumerate(batch_result):new_cands = []result.sort(key=lambda x: x.score)valid_results = [cand for cand in result if cand.score <= self.threshold]if len(valid_results) == 0:  # 如果没有合适的结果,就取topkvalid_results = result[:self.top_k]for cand_i, cand in enumerate(valid_results):doc = Document(page_content=cand.entity.get('content'),metadata={"score": cand.score, "file_id": cand.entity.get('file_id'),"file_name": cand.entity.get('file_name'),"chunk_id": cand.entity.get('chunk_id')})new_cands.append(doc)# csv和xlsx文件不做expand_cand_docsneed_expand, not_need_expand = [], []for doc in new_cands:if doc.metadata['file_name'].lower().split('.')[-1] in ['csv', 'xlsx']:doc.metadata["kernel"] = doc.page_contentnot_need_expand.append(doc)else:need_expand.append(doc)expand_res = self.expand_cand_docs(need_expand)new_cands = not_need_expand + expand_resnew_result.append(new_cands)return new_result# 混合检索def parse_es_batch_result(self, es_records, milvus_records):milvus_records_seen = set()for result in milvus_records:result.sort(key=lambda x: x.score)flag = Truefor cand in result:if cand.score <= self.threshold:milvus_records_seen.add(cand.entity.get('chunk_id'))flag = Falseif flag:for cand in result[:self.top_k]:milvus_records_seen.add(cand.entity.get('chunk_id'))new_cands = []for es_record in es_records:if es_record['id'] not in milvus_records_seen:doc = Document(page_content=es_record['content'],metadata={"score": es_record['score'], "file_id": es_record['file_id'],"file_name": es_record['metadata']['file_name'],"chunk_id": es_record['metadata']['chunk_id']})new_cands.append(doc)# csv和xlsx文件不做expand_cand_docsneed_expand, not_need_expand = [], []for doc in new_cands:if doc.metadata['file_name'].lower().split('.')[-1] in ['csv', 
'xlsx']:doc.metadata["kernel"] = doc.page_contentnot_need_expand.append(doc)else:need_expand.append(doc)expand_res = self.expand_cand_docs(need_expand)new_result = not_need_expand + expand_resreturn new_result@propertydef output_fields(self):return ['chunk_id', 'file_id', 'file_name', 'file_path', 'timestamp', 'content']def init(self):try:connections.connect(host=self.host, port=self.port, user=self.user,password=self.password, db_name=self.db_name)  # timeout=3 [cannot set]if utility.has_collection(self.user_id):self.sess = Collection(self.user_id)debug_logger.info(f'collection {self.user_id} exists')else:schema = CollectionSchema(self.fields)debug_logger.info(f'create collection {self.user_id} {schema}')self.sess = Collection(self.user_id, schema)self.sess.create_index(field_name="embedding", index_params=self.create_params)for kb_id in self.kb_ids:if not self.sess.has_partition(kb_id):self.sess.create_partition(kb_id)self.partitions = [Partition(self.sess, kb_id) for kb_id in self.kb_ids]debug_logger.info('partitions: %s', self.kb_ids)self.sess.load()except Exception as e:debug_logger.error(e)def __search_emb_sync(self, embs, expr='', top_k=None, client_timeout=None, queries=None):if not top_k:top_k = self.top_kmilvus_records = self.sess.search(data=embs, partition_names=self.kb_ids, anns_field="embedding",param=self.search_params, limit=top_k,output_fields=self.output_fields, expr=expr, timeout=client_timeout)milvus_records_proc = self.parse_batch_result(milvus_records)# debug_logger.info(milvus_records)# 混合检索if self.hybrid_search:es_records = self.client.search(queries)es_records_proc = self.parse_es_batch_result(es_records, milvus_records)milvus_records_proc.extend(es_records_proc)return milvus_records_procdef search_emb_async(self, embs, expr='', top_k=None, client_timeout=None, queries=None):if not top_k:top_k = self.top_k# 将search_emb_sync函数放入线程池中运行future = self.executor.submit(self.__search_emb_sync, embs, expr, top_k, client_timeout, queries)return future.result()def query_expr_async(self, expr, output_fields=None, client_timeout=None):if client_timeout is None:client_timeout = self.client_timeoutif not output_fields:output_fields = self.output_fieldsfuture = self.executor.submit(partial(self.sess.query, partition_names=self.kb_ids, output_fields=output_fields, expr=expr,timeout=client_timeout))return future.result()async def insert_files(self, file_id, file_name, file_path, docs, embs, batch_size=1000):debug_logger.info(f'now inser_file {file_name}')now = datetime.now()timestamp = now.strftime("%Y%m%d%H%M")loop = asyncio.get_running_loop()contents = [doc.page_content for doc in docs]num_docs = len(docs)for batch_start in range(0, num_docs, batch_size):batch_end = min(batch_start + batch_size, num_docs)data = [[] for _ in range(len(self.sess.schema))]for idx in range(batch_start, batch_end):cont = contents[idx]emb = embs[idx]chunk_id = f'{file_id}_{idx}'data[0].append(chunk_id)data[1].append(file_id)data[2].append(file_name)data[3].append(file_path)data[4].append(timestamp)data[5].append(cont)data[6].append(emb)# 执行插入操作try:debug_logger.info('Inserting into Milvus...')mr = await loop.run_in_executor(self.executor, partial(self.partitions[0].insert, data=data))debug_logger.info(f'{file_name} {mr}')except Exception as e:debug_logger.error(f'Milvus insert file_id:{file_id}, file_name:{file_name} failed: {e}')return False# 混合检索if self.hybrid_search:debug_logger.info(f'now inser_file for es: {file_name}')for batch_start in range(0, num_docs, batch_size):batch_end = min(batch_start + 
batch_size, num_docs)data_es = []for idx in range(batch_start, batch_end):data_es_item = {'file_id': file_id,'content': contents[idx],'metadata': {'file_name': file_name,'file_path': file_path,'chunk_id': f'{file_id}_{idx}','timestamp': timestamp,}}data_es.append(data_es_item)try:debug_logger.info('Inserting into es ...')mr = await self.client.insert(data=data_es, refresh=batch_end==num_docs)debug_logger.info(f'{file_name} {mr}')except Exception as e:debug_logger.error(f'ES insert file_id: {file_id}\nfile_name: {file_name}\nfailed: {e}')return Falsereturn Truedef delete_collection(self):self.sess.release()utility.drop_collection(self.user_id)# 混合检索if self.hybrid_search:index_name_delete = []for index_name in self.client.indices.get_alias().keys():if index_name.startswith(f"{self.user_id}++"):index_name_delete.append(index_name)self.client.delete_index(index_name_delete)def delete_partition(self, partition_name):part = Partition(self.sess, partition_name)part.release()self.sess.drop_partition(partition_name)# 混合检索if self.hybrid_search:index_name_delete = []if isinstance(partition_name, str):index_name_delete.append(f"{self.user_id}++{partition_name}")elif isinstance(partition_name, list) and isinstance(partition_name[0], str):for kb_id in partition_name:index_name_delete.append(f"{self.user_id}++{kb_id}")else:debug_logger.info(f"##ES## - kb_ids not valid: {partition_name}")self.client.delete_index(index_name_delete)debug_logger.info(f"##ES## - success delete kb_ids: {partition_name}")def delete_files(self, files_id):self.sess.delete(expr=f"file_id in {files_id}")debug_logger.info('milvus delete files_id: %s', files_id)# 混合检索if self.hybrid_search:es_records = self.client.search(files_id, field='file_id')delete_index_ids = {}for record in es_records:if record['index'] not in delete_index_ids:delete_index_ids[record['index']] = []delete_index_ids[record['index']].append(record['id'])for index, ids in delete_index_ids.items():self.client.delete_chunks(index_name=index, ids=ids)debug_logger.info(f"##ES## - success delete files_id: {files_id}")def get_files(self, files_id):res = self.query_expr_async(expr=f"file_id in {files_id}", output_fields=["file_id"])valid_ids = [result['file_id'] for result in res]return valid_idsdef seperate_list(self, ls: List[int]) -> List[List[int]]:lists = []ls1 = [ls[0]]for i in range(1, len(ls)):if ls[i - 1] + 1 == ls[i]:ls1.append(ls[i])else:lists.append(ls1)ls1 = [ls[i]]lists.append(ls1)return listsdef process_group(self, group):new_cands = []# 对每个分组按照chunk_id进行排序group.sort(key=lambda x: int(x.metadata['chunk_id'].split('_')[-1]))id_set = set()file_id = group[0].metadata['file_id']file_name = group[0].metadata['file_name']group_scores_map = {}# 先找出该文件所有需要搜索的chunk_idcand_chunks_set = set()  # 使用集合而不是列表for cand_doc in group:current_chunk_id = int(cand_doc.metadata['chunk_id'].split('_')[-1])group_scores_map[current_chunk_id] = cand_doc.metadata['score']# 使用 set comprehension 一次性生成区间内所有可能的 chunk_idchunk_ids = {file_id + '_' + str(i) for i in range(current_chunk_id - 200, current_chunk_id + 200)}# 更新 cand_chunks_set 集合cand_chunks_set.update(chunk_ids)cand_chunks = list(cand_chunks_set)group_relative_chunks = self.query_expr_async(expr=f"file_id == \"{file_id}\" and chunk_id in {cand_chunks}",output_fields=["chunk_id", "content"])group_chunk_map = {int(item['chunk_id'].split('_')[-1]): item['content'] for item in group_relative_chunks}group_file_chunk_num = list(group_chunk_map.keys())for cand_doc in group:current_chunk_id = 
int(cand_doc.metadata['chunk_id'].split('_')[-1])doc = copy.deepcopy(cand_doc)id_set.add(current_chunk_id)docs_len = len(doc.page_content)for k in range(1, 200):break_flag = Falsefor expand_index in [current_chunk_id + k, current_chunk_id - k]:if expand_index in group_file_chunk_num:merge_content = group_chunk_map[expand_index]if docs_len + len(merge_content) > CHUNK_SIZE:break_flag = Truebreakelse:docs_len += len(merge_content)id_set.add(expand_index)if break_flag:breakid_list = sorted(list(id_set))id_lists = self.seperate_list(id_list)for id_seq in id_lists:try:for id in id_seq:if id == id_seq[0]:doc = Document(page_content=group_chunk_map[id],metadata={"score": 0, "file_id": file_id,"file_name": file_name})else:doc.page_content += " " + group_chunk_map[id]doc_score = min([group_scores_map[id] for id in id_seq if id in group_scores_map])doc.metadata["score"] = float(format(1 - doc_score / math.sqrt(2), '.4f'))doc.metadata["kernel"] = '|'.join([group_chunk_map[id] for id in id_seq if id in group_scores_map])new_cands.append(doc)except Exception as e:debug_logger.error(f"process_group error: {e}. maybe chunks in ES not exists in Milvus. Please delete the file and upload again.")return new_candsdef expand_cand_docs(self, cand_docs):cand_docs = sorted(cand_docs, key=lambda x: x.metadata['file_id'])# 按照file_id进行分组m_grouped = [list(group) for key, group in groupby(cand_docs, key=lambda x: x.metadata['file_id'])]debug_logger.info('milvus group number: %s', len(m_grouped))with ThreadPoolExecutor(max_workers=10) as executor:futures = []for group in m_grouped:if not group:continuefuture = executor.submit(self.process_group, group)futures.append(future)new_cands = []for future in as_completed(futures):result = future.result()if result is not None:new_cands.extend(result)return new_cands

补充:
Sanic 是什么?怎么使用?一文带你快速上手 Sanic
aiohttp 官方文档:Welcome to AIOHTTP — aiohttp 3.8.6 documentation
Python asyncio 文档:asyncio — Asynchronous I/O — Python 3.12.0 documentation
掌握异步网络编程利器:Python aiohttp使用教程及代码示例
正排索引 vs 倒排索引 - 搜索引擎具体原理
ES高频面试问题:一张图带你读懂 Elasticsearch 中“正排索引(正向索引)”和“倒排索引(反向索引)”区别
MySQL中的倒排索引与正排索引:区别与用途
